File size: 2,742 Bytes
c24da45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F


class TetTexNet(nn.Module):
    def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
        super().__init__()
        # self.c_dim = c_dim
        self.plane_reso = plane_reso
        self.padding = padding
        self.fea_concat = fea_concat

    def forward(self, rolled_out_feature, query):
        # rolled_out_feature: rolled-out triplane feature
        # query: queried xyz coordinates (should be scaled consistently to ptr cloud)

        plane_reso = self.plane_reso

        triplane_feature = dict()
        triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
        triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
        triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]

        query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
        query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
        query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')

        if self.fea_concat:
            query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
        else:
            query_feature = query_feature_xy + query_feature_yz + query_feature_zx

        output = query_feature.permute(0, 2, 1)

        return output

    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        # CYF note:
        # for pretraining, query are uniformly sampled positions w.i. [-scale, scale]
        # for training, query are essentially tetrahedra grid vertices, which are
        # also within [-scale, scale] in the current version!
        # xy range [-scale, scale]
        if plane == 'xy':
            xy = query[:, :, [0, 1]]
        elif plane == 'yz':
            xy = query[:, :, [1, 2]]
        elif plane == 'zx':
            xy = query[:, :, [2, 0]]
        else:
            raise ValueError("Error! Invalid plane type!")

        xy = xy[:, :, None].float()
        # not seem necessary to rescale the grid, because from
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
        # it specifies sampling locations normalized by plane_feature's spatial dimension,
        # which is within [-scale, scale] as specified by encoder's calling of coordinate2index()
        vgrid = 1.0 * xy
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)

        return sampled_feat