File size: 2,461 Bytes
d657b96 0124bea d657b96 0124bea d657b96 0124bea d657b96 0124bea d657b96 0124bea d657b96 0124bea d657b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import torch
import yaml
from swin2_mose.model import Swin2MoSE
def to_shape(t1, t2):
    """Broadcast a 1-D per-channel vector to (B, C, 1, 1) for elementwise ops.

    ``t1`` has length C; ``t2`` supplies the target batch/channel sizes
    (its first two dimensions). The result can be combined directly with
    a 4-D (B, C, H, W) tensor.
    """
    batch = t2.shape[0]
    # tile the channel vector across the batch, then add singleton H/W axes
    tiled = t1.unsqueeze(0).repeat(batch, 1)
    return tiled.view(batch, t2.shape[1], 1, 1)
def norm(tensor, mean, std):
    """Channel-wise standardization: (tensor - mean) / std.

    ``mean`` and ``std`` are per-channel sequences; each is moved to the
    tensor's device and broadcast to (B, C, 1, 1) against a 4-D input.
    """
    target_shape = tensor.shape[:2] + (1, 1)
    batch = tensor.shape[0]
    mean_t = torch.tensor(mean).to(tensor.device)[None].repeat(batch, 1).view(target_shape)
    std_t = torch.tensor(std).to(tensor.device)[None].repeat(batch, 1).view(target_shape)
    return (tensor - mean_t) / std_t
def denorm(tensor, mean, std):
    """Invert channel-wise standardization: tensor * std + mean.

    ``mean`` and ``std`` are per-channel sequences; each is moved to the
    tensor's device and broadcast to (B, C, 1, 1) against a 4-D input.
    """
    target_shape = tensor.shape[:2] + (1, 1)
    batch = tensor.shape[0]
    mean_t = torch.tensor(mean).to(tensor.device)[None].repeat(batch, 1).view(target_shape)
    std_t = torch.tensor(std).to(tensor.device)[None].repeat(batch, 1).view(target_shape)
    return tensor * std_t + mean_t
def load_config(path):
    """Parse the YAML configuration file at ``path`` and return it as a dict."""
    with open(path, 'r') as fh:
        return yaml.safe_load(fh)
def load_swin2_mose(model_weights, cfg):
    """Build a Swin2MoSE model and restore its weights from a checkpoint.

    Args:
        model_weights: path to a torch checkpoint containing the key
            'model_state_dict'.
        cfg: parsed config dict; cfg['super_res']['model'] holds the
            keyword arguments for the Swin2MoSE constructor.

    Returns:
        The Swin2MoSE instance with the config attached as ``.cfg``
        (read back by run_swin2_mose). The model is left on CPU; the
        caller is responsible for moving it to the target device.
    """
    # map_location='cpu' lets a CUDA-saved checkpoint load on CPU-only
    # hosts; without it torch.load raises when CUDA is unavailable.
    checkpoint = torch.load(model_weights, map_location='cpu')
    # build model from config kwargs
    sr_model = Swin2MoSE(**cfg['super_res']['model'])
    sr_model.load_state_dict(checkpoint['model_state_dict'])
    # stash the config on the model for later use by run_swin2_mose
    sr_model.cfg = cfg
    return sr_model
def run_swin2_mose(model, lr, hr, device='cuda'):
    """Super-resolve one LR/HR image pair with a Swin2MoSE model.

    Normalizes the inputs with the per-band stats stored in ``model.cfg``,
    runs the model, denormalizes the prediction, and returns uint16 numpy
    arrays cropped to the first 3 bands.

    Args:
        model: Swin2MoSE instance with a ``.cfg`` dict attached
            (as set by load_swin2_mose).
        lr: low-resolution image as a numpy array.
            NOTE(review): indexing below assumes channels-first (C, H, W)
            with at least 8 bands — confirm against caller.
        hr: high-resolution image as a numpy array (channels-first).
        device: torch device for inference (default 'cuda').

    Returns:
        dict with keys 'lr', 'sr', 'hr', each a uint16 numpy array
        holding the first 3 bands.
    """
    cfg = model.cfg
    # per-band mean/std statistics for (de)normalization
    hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8']
    lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8']
    # select 10m lr bands: B02, B03, B04, B08 and hr bands;
    # [None] adds a batch dimension before band selection
    lr_orig = torch.from_numpy(lr)[None].float()[:, [3, 2, 1, 7]].to(device)
    hr_orig = torch.from_numpy(hr)[None].float().to(device)
    # normalize data (rebinds lr/hr to the normalized tensors)
    lr = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std'])
    hr = norm(hr_orig, mean=hr_stats['mean'], std=hr_stats['std'])
    # predict a image
    with torch.no_grad():
        sr = model(lr)
    # some model variants return (output, aux); keep only the output tensor
    if not torch.is_tensor(sr):
        sr, _ = sr
    # denorm sr back to the HR band value range
    sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])
    # Prepare output: round to integers, move to CPU numpy uint16,
    # keep only the first 3 bands
    sr = sr.round().cpu().numpy().astype('uint16').squeeze()[0:3]
    lr = lr_orig[0].cpu().numpy().astype('uint16').squeeze()[0:3]
    hr = hr_orig[0].cpu().numpy().astype('uint16').squeeze()[0:3]
    # Use nn interpolation to go back to x2 without distortion
    # during metrics calculation
    if sr.shape[1] != hr.shape[1]:
        sr = torch.nn.functional.interpolate(
            torch.from_numpy(sr)[None].float(),
            size=hr.shape[1:],
            mode='nearest'
        ).squeeze().numpy().astype('uint16')
    return {
        'lr': lr,
        'sr': sr,
        'hr': hr
    }
|