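"""Utility helpers for loading and running the Swin2-MoSE super-resolution model.

Covers per-band normalization/denormalization, YAML config and checkpoint
loading, and a convenience wrapper that super-resolves a low-resolution
multispectral patch and returns LR/SR/HR arrays ready for metric computation.
"""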
import torch
import yaml

from swin2_mose.model import Swin2MoSE


def to_shape(t1, t2):
    # broadcast per-band stats t1 of shape (C,) to (B, C, 1, 1) so they
    # align with the (B, C, H, W) tensor t2
    t1 = t1[None].repeat(t2.shape[0], 1)
    t1 = t1.view((t2.shape[:2] + (1, 1)))
    return t1


def norm(tensor, mean, std):
    # get stats
    mean = torch.tensor(mean).to(tensor.device)
    std = torch.tensor(std).to(tensor.device)
    # normalize
    return (tensor - to_shape(mean, tensor)) / to_shape(std, tensor)


def denorm(tensor, mean, std):
    # get stats
    mean = torch.tensor(mean).to(tensor.device)
    std = torch.tensor(std).to(tensor.device)
    # denorm
    return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor)


def load_config(path):
    # load config
    with open(path, 'r') as f:
        cfg = yaml.safe_load(f)
    return cfg


def load_swin2_mose(model_weights, cfg):
    # load checkpoint
    checkpoint = torch.load(model_weights)

    # build model
    sr_model = Swin2MoSE(**cfg['super_res']['model'])
    sr_model.load_state_dict(
        checkpoint['model_state_dict'])

    sr_model.cfg = cfg

    return sr_model


def run_swin2_mose(model, lr, hr, device='cuda'):

    cfg = model.cfg

    # make sure the model runs on the requested device in eval mode
    model = model.to(device).eval()

    # per-band normalization stats for the hr and lr tensors
    hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8']
    lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8']

    # select 10m lr bands: B02, B03, B04, B08 and hr bands
    lr_orig = torch.from_numpy(lr)[None].float()[:, [3, 2, 1, 7]].to(device)
    hr_orig = torch.from_numpy(hr)[None].float().to(device)

    # normalize data
    lr = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std'])
    hr = norm(hr_orig, mean=hr_stats['mean'], std=hr_stats['std'])

    # predict an image
    with torch.no_grad():
        sr = model(lr)
        # some configurations return a tuple; keep only the sr tensor
        if not torch.is_tensor(sr):
            sr, _ = sr

    # denorm sr
    sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])

    # Prepare output: round, cast to uint16 and keep the first three bands
    sr = sr.round().cpu().numpy().astype('uint16').squeeze()[0:3]
    lr = lr_orig[0].cpu().numpy().astype('uint16').squeeze()[0:3]
    hr = hr_orig[0].cpu().numpy().astype('uint16').squeeze()[0:3]

    # Use nearest-neighbour interpolation to bring sr to the hr spatial size,
    # so that metrics can be computed without distortion
    if sr.shape[1] != hr.shape[1]:
        sr = torch.nn.functional.interpolate(
            torch.from_numpy(sr)[None].float(),
            size=hr.shape[1:],
            mode='nearest'
        ).squeeze().numpy().astype('uint16')

    
    return {
        'lr': lr,
        'sr': sr,
        'hr': hr
    }
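

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of how the helpers above fit together. The config and
# checkpoint paths and the `load_lr_hr_pair` helper are hypothetical; the
# inputs are assumed to be numpy arrays laid out as (bands, H, W).
#
# if __name__ == '__main__':
#     cfg = load_config('cfg.yml')                        # hypothetical path
#     model = load_swin2_mose('model_weights.pth', cfg)   # hypothetical path
#     lr, hr = load_lr_hr_pair()                          # hypothetical loader
#     out = run_swin2_mose(model, lr, hr, device='cuda')
#     print(out['lr'].shape, out['sr'].shape, out['hr'].shape)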