import torch


def make_conv(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    # Stack of n_blocks [Conv3d -> normalization -> activation] units; stride 1
    # with padding kernel//2 preserves the spatial resolution for odd kernels.
    blocks = []
    for i in range(n_blocks):
        in1 = n_in if i == 0 else n_out
        blocks.append(torch.nn.Sequential(
            torch.nn.Conv3d(in1, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
            normalization(n_out),
            activation(inplace=True)
        ))
    return torch.nn.Sequential(*blocks)
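# Illustrative usage (a minimal sketch with hypothetical shapes; Conv3d expects
# NCDHW input):
#   conv = make_conv(n_in=1, n_out=8, n_blocks=2)
#   y = conv(torch.randn(2, 1, 16, 16, 16))  # -> (2, 8, 16, 16, 16)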


def make_conv_2d(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    # 2D counterpart of make_conv.
    blocks = []
    for i in range(n_blocks):
        in1 = n_in if i == 0 else n_out
        blocks.append(torch.nn.Sequential(
            torch.nn.Conv2d(in1, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
            normalization(n_out),
            activation(inplace=True)
        ))
    return torch.nn.Sequential(*blocks)


def make_downscale(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    # Strided convolution that halves each spatial dimension; the padding
    # (kernel-2)//2 gives an exact 2x reduction for even input sizes.
    block = torch.nn.Sequential(
        torch.nn.Conv3d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
        normalization(n_out),
        activation(inplace=True)
    )
    return block


def make_downscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    # 2D counterpart of make_downscale.
    block = torch.nn.Sequential(
        torch.nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
        normalization(n_out),
        activation(inplace=True)
    )
    return block
    

def make_upscale(n_in, n_out, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    # Transposed convolution that doubles each spatial dimension
    # (kernel 6, stride 2, padding 2 yields exactly 2x the input size).
    block = torch.nn.Sequential(
        torch.nn.ConvTranspose3d(n_in, n_out, kernel_size=6, stride=2, padding=2),
        normalization(n_out),
        activation(inplace=True)
    )
    return block


def make_upscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    # 2D counterpart of make_upscale (kernel 4, stride 2, padding 1 also
    # yields exactly 2x the input size).
    block = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
        normalization(n_out),
        activation(inplace=True)
    )
    return block
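# Shape sketch for the scaling blocks (illustrative, assuming even input sizes):
# the downscale blocks halve each spatial dimension and the upscale blocks
# exactly double it, so they compose into an identity-shaped round trip:
#   down = make_downscale_2d(8, 16)
#   up = make_upscale_2d(16, 8)
#   x = torch.randn(1, 8, 64, 64)
#   print(down(x).shape)      # torch.Size([1, 16, 32, 32])
#   print(up(down(x)).shape)  # torch.Size([1, 8, 64, 64])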


class ResBlock(torch.nn.Module):
    # Residual block: Conv-Norm-Act, Conv-Norm, skip connection, final ReLU.
    # Channel count and spatial size are preserved.
    def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
        super().__init__()
        self.block0 = torch.nn.Sequential(
            torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
            normalization(n_out),
            activation(inplace=True)
        )

        self.block1 = torch.nn.Sequential(
            torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
            normalization(n_out),
        )

        self.block2 = torch.nn.ReLU()

    def forward(self, x0):
        x = self.block0(x0)
        x = self.block1(x)
        # Add the skip connection before the final non-linearity.
        return self.block2(x + x0)
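# Because a ResBlock preserves both channel count and spatial size, blocks can
# be chained freely, e.g. (illustrative):
#   block = ResBlock(n_out=16)
#   y = block(torch.randn(2, 16, 8, 8, 8))  # same shape as the input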


class ResBlock2d(torch.nn.Module):
    # 2D counterpart of ResBlock.
    def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
        super().__init__()
        self.block0 = torch.nn.Sequential(
            torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
            normalization(n_out),
            activation(inplace=True)
        )

        self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
            normalization(n_out),
        )

        self.block2 = torch.nn.ReLU()

    def forward(self, x0):
        x = self.block0(x0)
        x = self.block1(x)
        # Add the skip connection before the final non-linearity.
        return self.block2(x + x0)


class Identity(torch.nn.Module):
    # No-op module that accepts (and ignores) any constructor arguments, so it
    # can stand in for a normalization or activation factory in the builders above.
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
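# Illustrative usage: disabling normalization in a conv stack by passing
# Identity as the factory (it silently absorbs the channel argument):
#   conv = make_conv(1, 8, n_blocks=1, normalization=Identity)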


def downscale_gt_flow(flow_gt, flow_mask, image_height, image_width):
    flow_gt_copy = flow_gt.clone()
    flow_mask_copy = flow_mask.clone()

    # Normalize the groundtruth flow; the constant 20.0 is the flow
    # normalization factor assumed by the rest of the pipeline.
    flow_gt_copy = flow_gt_copy / 20.0
    flow_mask_copy = flow_mask_copy.float()

    assert image_height % 64 == 0 and image_width % 64 == 0

    # Downscale to the 1/4 ... 1/64 pyramid levels. Nearest-neighbor
    # interpolation keeps the mask values binary.
    flow_gt_pyramid = []
    flow_mask_pyramid = []
    for factor in (4, 8, 16, 32, 64):
        size = (image_height // factor, image_width // factor)
        flow_gt_pyramid.append(torch.nn.functional.interpolate(input=flow_gt_copy, size=size, mode='nearest'))
        flow_mask_pyramid.append(torch.nn.functional.interpolate(input=flow_mask_copy, size=size, mode='nearest').bool())

    return flow_gt_pyramid, flow_mask_pyramid
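# Illustrative call (hypothetical tensors, NCHW layout, H and W divisible by 64):
#   flow_gt = torch.randn(2, 2, 448, 640)
#   flow_mask = torch.ones(2, 2, 448, 640, dtype=torch.bool)
#   flow_gts, flow_masks = downscale_gt_flow(flow_gt, flow_mask, 448, 640)
#   [f.shape[-2:] for f in flow_gts]  # 1/4, 1/8, 1/16, 1/32, 1/64 resolutions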


def compute_baseline_mask_gt(
    xy_coords_warped, 
    target_matches, valid_target_matches,
    source_points, valid_source_points,
    scene_flow_gt, scene_flow_mask, target_boundary_mask,
    max_pos_flowed_source_to_target_dist, min_neg_flowed_source_to_target_dist
):
    # Scene flow mask.
    scene_flow_mask_0 = scene_flow_mask[:, 0].type(torch.bool)

    # Boundary correspondences mask.
    # Nearest-neighbor sampling is used because the boundary computation
    # already marks all 4 neighboring pixels as boundary.
    target_nonboundary_mask = (~target_boundary_mask).type(torch.float32)
    target_matches_nonboundary_mask = torch.nn.functional.grid_sample(target_nonboundary_mask, xy_coords_warped, padding_mode='zeros', mode='nearest', align_corners=False)
    target_matches_nonboundary_mask = target_matches_nonboundary_mask[:, 0, :, :] >= 0.999

    # Flow the source points with the groundtruth (oracle) scene flow and
    # measure the distance to the predicted target matches.
    flowed_source_points = source_points + scene_flow_gt
    dist = torch.norm(flowed_source_points - target_matches, p=2, dim=1)

    # Combine all masks.
    # We mark a correspondence as positive if:
    # - it is close enough to the groundtruth flow, AND
    # - groundtruth flow exists, AND
    # - the target match is valid, AND
    # - the source point is valid, AND
    # - the target match is not on the boundary.
    mask_pos_gt = (dist <= max_pos_flowed_source_to_target_dist) & scene_flow_mask_0 & valid_target_matches & valid_source_points & target_matches_nonboundary_mask

    # We mark a correspondence as negative if:
    # - groundtruth flow exists, the match is far enough from the groundtruth flow, and source/target points are valid, OR
    # - the target match is on the boundary, groundtruth flow exists, and source/target points are valid.
    mask_neg_gt = ((dist > min_neg_flowed_source_to_target_dist) & scene_flow_mask_0 & valid_source_points & valid_target_matches) \
            | (~target_matches_nonboundary_mask & scene_flow_mask_0 & valid_source_points & valid_target_matches)

    # Everything else is left undecided and is masked out in the loss.
    # In the groundtruth mask, only positive correspondences are set to one.
    valid_mask_pixels = mask_pos_gt | mask_neg_gt
    mask_gt = mask_pos_gt.type(torch.float32)

    return mask_gt, valid_mask_pixels
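# Expected shapes, as implied by the indexing above (stated as assumptions):
#   source_points, target_matches, scene_flow_gt: (B, 3, H, W)
#   xy_coords_warped: (B, H, W, 2) normalized sampling grid for grid_sample
#   scene_flow_mask, target_boundary_mask: (B, 1, H, W)
#   valid_source_points, valid_target_matches: (B, H, W) boolean
# The returned mask_gt is a float (B, H, W) tensor, and valid_mask_pixels
# selects the pixels that should contribute to the mask loss.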


def compute_deformed_points_gt(
    source_points, scene_flow_gt, 
    valid_solve, valid_correspondences, 
    deformed_points_idxs, deformed_points_subsampled
):
    batch_size = source_points.shape[0]
    max_warped_points = deformed_points_idxs.shape[1]

    # Zero-initialized outputs; the mask marks which entries were filled in.
    deformed_points_gt = torch.zeros((batch_size, max_warped_points, 3), dtype=source_points.dtype, device=source_points.device)
    deformed_points_mask = torch.zeros((batch_size, max_warped_points, 3), dtype=source_points.dtype, device=source_points.device)

    for i in range(batch_size):
        if valid_solve[i]:
            valid_correspondences_idxs = torch.where(valid_correspondences[i])

            # Compute the deformed point groundtruth.
            deformed_points_i_gt = source_points[i] + scene_flow_gt[i]
            deformed_points_i_gt = deformed_points_i_gt.permute(1, 2, 0)
            deformed_points_i_gt = deformed_points_i_gt[valid_correspondences_idxs[0], valid_correspondences_idxs[1], :].view(-1, 3, 1)

            # Randomly subsample the points if too many remain.
            if deformed_points_subsampled[i]:
                sampled_idxs_i = deformed_points_idxs[i]
                deformed_points_i_gt = deformed_points_i_gt[sampled_idxs_i]

            num_points = deformed_points_i_gt.shape[0]

            # Store the results.
            deformed_points_gt[i, :num_points, :] = deformed_points_i_gt.view(1, num_points, 3)
            deformed_points_mask[i, :num_points, :] = 1

    return deformed_points_gt, deformed_points_mask
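# Illustrative summary (hypothetical values): for a batch of B frames with at
# most M sampled points each,
#   deformed_points_idxs: (B, M) long indices into the valid correspondences,
#   valid_solve, deformed_points_subsampled: (B,) boolean flags,
# the function returns (B, M, 3) groundtruth points plus a matching 0/1 mask,
# zero-padded for batch items where the solve is invalid.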