Luisgust commited on
Commit
7e1d1d2
·
verified ·
1 Parent(s): 109cddf

Create vtoonify/model/raft/core/corr.py

Browse files
Files changed (1) hide show
  1. vtoonify/model/raft/core/corr.py +91 -0
vtoonify/model/raft/core/corr.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from model.raft.core.utils.utils import bilinear_sampler, coords_grid
4
+
5
+ try:
6
+ import alt_cuda_corr
7
+ except:
8
+ # alt_cuda_corr is not compiled
9
+ pass
10
+
11
+
12
+ class CorrBlock:
13
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14
+ self.num_levels = num_levels
15
+ self.radius = radius
16
+ self.corr_pyramid = []
17
+
18
+ # all pairs correlation
19
+ corr = CorrBlock.corr(fmap1, fmap2)
20
+
21
+ batch, h1, w1, dim, h2, w2 = corr.shape
22
+ corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23
+
24
+ self.corr_pyramid.append(corr)
25
+ for i in range(self.num_levels-1):
26
+ corr = F.avg_pool2d(corr, 2, stride=2)
27
+ self.corr_pyramid.append(corr)
28
+
29
+ def __call__(self, coords):
30
+ r = self.radius
31
+ coords = coords.permute(0, 2, 3, 1)
32
+ batch, h1, w1, _ = coords.shape
33
+
34
+ out_pyramid = []
35
+ for i in range(self.num_levels):
36
+ corr = self.corr_pyramid[i]
37
+ dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38
+ dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40
+
41
+ centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42
+ delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43
+ coords_lvl = centroid_lvl + delta_lvl
44
+
45
+ corr = bilinear_sampler(corr, coords_lvl)
46
+ corr = corr.view(batch, h1, w1, -1)
47
+ out_pyramid.append(corr)
48
+
49
+ out = torch.cat(out_pyramid, dim=-1)
50
+ return out.permute(0, 3, 1, 2).contiguous().float()
51
+
52
+ @staticmethod
53
+ def corr(fmap1, fmap2):
54
+ batch, dim, ht, wd = fmap1.shape
55
+ fmap1 = fmap1.view(batch, dim, ht*wd)
56
+ fmap2 = fmap2.view(batch, dim, ht*wd)
57
+
58
+ corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59
+ corr = corr.view(batch, ht, wd, 1, ht, wd)
60
+ return corr / torch.sqrt(torch.tensor(dim).float())
61
+
62
+
63
+ class AlternateCorrBlock:
64
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65
+ self.num_levels = num_levels
66
+ self.radius = radius
67
+
68
+ self.pyramid = [(fmap1, fmap2)]
69
+ for i in range(self.num_levels):
70
+ fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71
+ fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72
+ self.pyramid.append((fmap1, fmap2))
73
+
74
+ def __call__(self, coords):
75
+ coords = coords.permute(0, 2, 3, 1)
76
+ B, H, W, _ = coords.shape
77
+ dim = self.pyramid[0][0].shape[1]
78
+
79
+ corr_list = []
80
+ for i in range(self.num_levels):
81
+ r = self.radius
82
+ fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83
+ fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84
+
85
+ coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86
+ corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87
+ corr_list.append(corr.squeeze(1))
88
+
89
+ corr = torch.stack(corr_list, dim=1)
90
+ corr = corr.reshape(B, -1, H, W)
91
+ return corr / torch.sqrt(torch.tensor(dim).float())