File size: 584 Bytes
5e83696 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from main import *
def get_satclip(ckpt_path, device, return_all=False):
    """Rebuild a SatCLIP model from a Lightning checkpoint and return it.

    Args:
        ckpt_path: Checkpoint location, passed straight to ``torch.load``.
        device: Used both as ``map_location`` for ``torch.load`` and as the
            target of ``.to(device)`` on the rebuilt module.
        return_all: If true, return ``lightning_model.model``; otherwise
            return only its ``location`` sub-module (presumably the
            location encoder -- TODO confirm against SatCLIPLightningModule).

    Returns:
        ``lightning_model.model`` when ``return_all`` is true, otherwise
        ``lightning_model.model.location``. ``eval()`` has been called on the
        parent Lightning module before returning.

    Note:
        ``torch.load`` unpickles the file, so only load checkpoints from
        trusted sources.
    """
    ckpt = torch.load(ckpt_path, map_location=device)

    # These entries are stripped before the hyper-parameters are splatted
    # into the constructor. Presumably SatCLIPLightningModule.__init__ does
    # not accept them (TODO confirm). pop(..., None) tolerates checkpoints
    # that never stored one of these keys instead of raising KeyError.
    hparams = ckpt['hyper_parameters']
    for key in ('eval_downstream', 'air_temp_data_path', 'election_data_path'):
        hparams.pop(key, None)

    lightning_model = SatCLIPLightningModule(**hparams).to(device)
    lightning_model.load_state_dict(ckpt['state_dict'])
    lightning_model.eval()

    geo_model = lightning_model.model
    return geo_model if return_all else geo_model.location