File size: 4,052 Bytes
8d34f50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn
import torch.nn.functional as F
from raft import RAFT
from nnutils import make_conv_2d, make_upscale_2d, make_downscale_2d, ResBlock2d, Identity


class ImportanceWeights(torch.nn.Module):
    """Predict a per-pixel importance weight in the open interval (0, 1).

    The incoming feature map is reduced to ``fn_0`` channels by a 3x3
    convolution, concatenated with ``x`` along the channel axis, passed through
    a small residual CNN that emits one channel of logits, and squashed with a
    sigmoid.

    Args:
        opt: configuration object. ``opt.small`` selects a 128- (True) or
            256-channel (False) feature input; ``opt.use_batch_norm`` selects
            ``BatchNorm2d`` or ``Identity`` as the normalization layer.
    """

    def __init__(self, opt):
        super().__init__()

        # Channel count of the incoming feature map depends on the model size.
        in_dim = 128 if opt.small else 256
        fn_0 = 16  # channels of the reduced feature map
        fn_1 = 16  # width of the residual trunk
        # The reduced features are concatenated with ``x``, which must carry
        # 3 * 2 = 6 channels (presumably two RGB images -- TODO confirm with caller).
        self.input_fn = fn_0 + 3 * 2
        self.conv1 = torch.nn.Conv2d(in_channels=in_dim, out_channels=fn_0, kernel_size=3, stride=1, padding=1)

        custom_batch_norm = torch.nn.BatchNorm2d if opt.use_batch_norm else Identity

        self.model = nn.Sequential(
            make_conv_2d(self.input_fn, fn_1, n_blocks=1, normalization=custom_batch_norm),
            ResBlock2d(fn_1, normalization=custom_batch_norm),
            ResBlock2d(fn_1, normalization=custom_batch_norm),
            ResBlock2d(fn_1, normalization=custom_batch_norm),
            # Emits raw logits; the sigmoid is applied in ``forward``.
            nn.Conv2d(fn_1, 1, kernel_size=3, padding=1),
        )

    def forward(self, x, features):
        """Return sigmoid importance weights for ``x`` given ``features``.

        Args:
            x: tensor concatenated after the reduced features along dim 1; it
                must have ``self.input_fn - 16`` (= 6) channels.
            features: feature map with ``conv1.in_channels`` channels. Its batch
                and spatial sizes must match ``x`` for the concatenation
                (``conv1`` is 3x3, stride 1, padding 1, so it keeps the spatial size).

        Returns:
            ``torch.sigmoid`` of the single-channel output of ``self.model``.
        """
        # Reduce the number of feature channels (spatial size is unchanged).
        features = self.conv1(features)
        x = torch.cat([features, x], 1)
        assert x.shape[1] == self.input_fn, (
            f"expected {self.input_fn} channels after concatenation, got {x.shape[1]}"
        )
        x = self.model(x)
        return torch.sigmoid(x)

class NeuralNRT(nn.Module):
    """Full-image source-to-target flow predicted from a pair of image crops.

    A RAFT network (``self.CorresPred``) estimates flow between a source crop
    and a target crop; ``forward`` maps that crop-space flow into the pixel
    coordinates of the full ``opt.height`` x ``opt.width`` image.

    Args:
        opt: configuration object. ``forward`` reads ``iters``, ``height``,
            ``width``, ``crop_height`` and ``crop_width``; ``opt`` is also
            passed to ``RAFT`` and ``ImportanceWeights``.
        path: optional checkpoint file holding weights for the RAFT sub-network,
            either as ``{'state_dict': ...}`` or as a bare state dict.
        device: not used by this class; kept for interface compatibility.
    """

    def __init__(self, opt, path=None, device="cuda:0"):
        super().__init__()
        self.opt = opt
        self.CorresPred = RAFT(opt)
        self.ImportanceW = ImportanceWeights(opt)
        if path is not None:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load checkpoints from trusted sources.
            data = torch.load(path, map_location='cpu')
            if 'state_dict' in data.keys():
                state_dict = data['state_dict']
            else:
                # Bare state dict, possibly saved from a DataParallel wrapper.
                state_dict = self._strip_module_prefix(data)
            self.CorresPred.load_state_dict(state_dict)
            print("load done")

    @staticmethod
    def _strip_module_prefix(state_dict, prefix='module.'):
        """Return a dict copy of *state_dict* with a leading *prefix* removed from keys.

        Only a prefix is removed, so keys that contain the substring elsewhere
        (e.g. ``'cnet.module.w'``) are left intact.
        """
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def forward(self, src_im, tar_im, src_im_raw, tar_im_raw, Crop_param):
        """Predict full-resolution flow from the source image to the target image.

        Args:
            src_im, tar_im: source / target crops, batched along dim 0. They are
                multiplied by 255 before RAFT (presumably given in [0, 1] --
                TODO confirm).
            src_im_raw, tar_im_raw: not used here.
            Crop_param: integer tensor read as ``Crop_param[i, k, 0]``. ``k`` 0-3
                are ``(x0, x1, y0, y1)`` of the source crop box and ``k`` 4-7 are
                ``(x0, x1, y0, y1)`` of the target crop box, in full-image pixels.

        Returns:
            ``(N, 2, opt.height, opt.width)`` float tensor of flow (channel 0 = x,
            channel 1 = y). Pixels outside each sample's source box are zero.
        """
        opt = self.opt
        N = src_im.shape[0]
        dev = src_im.device

        # RAFT is fed images scaled by 255; its second output (features) is unused here.
        flow_fw_crop, _ = self.CorresPred(src_im * 255.0, tar_im * 255.0, iters=opt.iters)

        # Pixel-coordinate grid at crop resolution, built directly on the target
        # device and broadcast over the batch: channel 0 = x (column), 1 = y (row).
        xs = torch.arange(opt.crop_width, device=dev, dtype=torch.float32)
        ys = torch.arange(opt.crop_height, device=dev, dtype=torch.float32)
        grid_crop = torch.cat((
            xs.view(1, 1, 1, -1).expand(1, 1, opt.crop_height, opt.crop_width),
            ys.view(1, 1, -1, 1).expand(1, 1, opt.crop_height, opt.crop_width),
        ), 1)

        # Per-sample affine map from crop pixels to full-image pixels:
        #   full = leftup + scale * crop   (dim 1 ordered as (x, y)).
        leftup1 = torch.cat((Crop_param[:, 0:1, 0], Crop_param[:, 2:3, 0]), 1)[:, :, None, None]
        leftup2 = torch.cat((Crop_param[:, 4:5, 0], Crop_param[:, 6:7, 0]), 1)[:, :, None, None]
        scale1 = torch.cat((
            (Crop_param[:, 1:2, 0] - Crop_param[:, 0:1, 0]).float() / opt.crop_width,
            (Crop_param[:, 3:4, 0] - Crop_param[:, 2:3, 0]).float() / opt.crop_height,
        ), 1)[:, :, None, None]
        scale2 = torch.cat((
            (Crop_param[:, 5:6, 0] - Crop_param[:, 4:5, 0]).float() / opt.crop_width,
            (Crop_param[:, 7:8, 0] - Crop_param[:, 6:7, 0]).float() / opt.crop_height,
        ), 1)[:, :, None, None]

        # Treating RAFT's output f as crop-space flow (source crop pixel g lands on
        # target crop pixel g + f), the flow in full-image coordinates is
        #   (leftup2 + scale2 * (g + f)) - (leftup1 + scale1 * g)
        #   = (scale2 - scale1) * g + scale2 * f + (leftup2 - leftup1).
        # The constant offset term is added after resizing, below.
        flow_fw_crop = (scale2 - scale1) * grid_crop + scale2 * flow_fw_crop
        offset = leftup2 - leftup1

        flow_fw = torch.zeros((N, 2, opt.height, opt.width), device=dev)
        for i in range(N):
            # Source crop box of sample i as Python ints (one host transfer).
            x0, x1, y0, y1 = Crop_param[i, 0:4, 0].tolist()
            # Resize crop-resolution flow to the box's size in the full image.
            flow_i = F.interpolate(flow_fw_crop[i:i + 1], (y1 - y0, x1 - x0), mode='bilinear', align_corners=True)
            flow_fw[i, :, y0:y1, x0:x1] = (flow_i + offset[i:i + 1])[0]
        return flow_fw