File size: 4,541 Bytes
753fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch

# code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py


class Arap_Loss():
    '''
    PyTorch implementation of an as-rigid-as-possible (ARAP) loss.

    The reference meshes given at construction define the rest edges and the
    cotangent weights.  Calling the instance on a deformed batch first fits a
    rotation per vertex (step 1, no gradient) and then measures how far the
    deformed edges are from the rotated rest edges (step 2, differentiable).

    Edges are stored as a directed list (every undirected edge appears in both
    directions), sorted by (source vertex, target vertex) so that they line up
    with the values of the coalesced sparse Laplacian.

    NOTE(review): `meshes` is presumably a pytorch3d `Meshes` batch; only the
    accessors used below (`verts_packed`, `faces_packed`, `edges_packed`,
    `num_verts_per_mesh`, `len`) are required.
    '''

    def __init__(self, meshes, device='cpu', vertex_w=None):
        '''
        Precompute everything that depends only on the reference meshes.

        Args:
            meshes: reference (rest-pose) mesh batch.
            device: device the fitted rotations are moved to in `step_1`.
            vertex_w: optional per-vertex weights (indexed by packed vertex
                id) multiplied into the per-edge loss; None disables it.
        '''
        with torch.no_grad():
            self.device = device
            self.bn = len(meshes)

            # Cotangent Laplacian weights, one per directed edge; negative
            # weights are clamped to zero.
            L = self.get_laplacian_cot(meshes)
            self.wij = L.values().clamp(min=0.)

            # Directed adjacency list: each edge (a, b) also appears as (b, a).
            V = int(meshes.num_verts_per_mesh().sum())
            e0, e1 = meshes.edges_packed().unbind(1)
            idx = torch.stack([torch.cat([e0, e1]), torch.cat([e1, e0])], dim=0)

            # Vertex degree = number of directed edges leaving each vertex.
            self.deg = torch.bincount(idx[0], minlength=V)
            self.idx = self.sort_idx(idx)

            # Edge vectors of the reference meshes.
            self.eij = self.get_edges(meshes)

            # Per-vertex regularization strength.
            self.vertex_w = vertex_w


    def __call__(self, new_meshes):
        '''Return the ARAP loss of `new_meshes` w.r.t. the reference meshes.'''
        new_meshes._compute_packed()

        optimal_R = self.step_1(new_meshes)
        arap_loss = self.step_2(optimal_R, new_meshes)
        return arap_loss


    def step_1(self, new_meshes):
        '''
        Fit one rotation per vertex that best maps rest edges to new edges.

        For every vertex i the 3x3 matrix S_i = sum_j w_ij * e_ij * e'_ij^T is
        accumulated over its outgoing edges and the rotation is taken from
        the SVD of S_i.  Runs under `torch.no_grad()`; accumulation and SVD
        are done on the CPU and the result is moved to `self.device`.

        Returns:
            Tensor of shape (num_packed_vertices, 3, 3).
        '''
        with torch.no_grad():
            eij = self.eij.cpu()
            eij_ = self.get_edges(new_meshes).cpu()
            wij = self.wij.cpu()
            src = self.idx[0].cpu()

            # Weighted outer product per directed edge, scattered onto the
            # edge's source vertex.  Working on the packed edge list avoids a
            # Python loop over vertices and does not require all meshes in
            # the batch to share one topology.
            outer = wij.view(-1, 1, 1) * eij.unsqueeze(2) * eij_.unsqueeze(1)
            S = torch.zeros(self.deg.shape[0], 3, 3, dtype=outer.dtype)
            S.index_add_(0, src, outer)

            u, _, v = torch.svd(S)
            R = v @ u.transpose(-2, -1)
            det = torch.det(R)

            # Flip the last singular direction where the fit is a reflection
            # so that every returned matrix has determinant +1.
            u[det < 0, :, -1] *= -1
            R = v @ u.transpose(-2, -1)
            R = R.to(self.device)

        return R


    def step_2(self, R, new_meshes):
        '''
        Compute the weighted edge residual given the per-vertex rotations.

        Args:
            R: rotations from `step_1`, one 3x3 matrix per packed vertex.
            new_meshes: deformed mesh batch.

        Returns:
            Scalar tensor: sum over directed edges of
            w_ij * ||e'_ij - R_i e_ij|| (optionally times the source vertex
            weight), divided by the batch size.
        '''
        src = self.idx[0]

        # Rotate every rest edge by the rotation of its source vertex.
        # squeeze(-1) only drops the matmul column, never the edge dimension.
        Reij = (R[src] @ self.eij.unsqueeze(2)).squeeze(-1)

        eij_ = self.get_edges(new_meshes)
        arap_loss = self.wij * (eij_ - Reij).norm(dim=1)

        if self.vertex_w is not None:
            arap_loss = arap_loss * self.vertex_w[src]

        arap_loss = arap_loss.sum() / self.bn

        return arap_loss


    def get_edges(self, meshes):
        '''
        Return the directed edge vectors v_i - v_j of `meshes`, ordered like
        `self.idx` (column 0 = source i, column 1 = target j).
        '''
        verts_packed = meshes.verts_packed()
        return verts_packed[self.idx[0]] - verts_packed[self.idx[1]]


    def sort_idx(self, idx):
        '''
        Sort a (2, N) index tensor lexicographically by (row 0, row 1).

        An exact integer key is used; a floating-point key such as
        `row + col * 1e-6` loses the column information once indices get
        large, which would misalign the edges with the Laplacian weights.
        '''
        if idx.numel() == 0:
            return idx

        rows, cols = idx[0].long(), idx[1].long()
        key = rows * (cols.max() + 1) + cols
        order = torch.argsort(key)

        return idx[:, order]


    def get_laplacian_cot(self, meshes):
        '''
        Build the sparse cotangent-weight matrix of the packed meshes.

        Routine modified from :
        pytorch3d/loss/mesh_laplacian_smoothing.py

        Returns:
            Coalesced sparse (V, V) tensor whose entry (i, j) is half the sum
            of the cotangents of the angles opposite edge (i, j).
        '''
        verts_packed = meshes.verts_packed()
        faces_packed = meshes.faces_packed()
        V, F = verts_packed.shape[0], faces_packed.shape[0]

        face_verts = verts_packed[faces_packed]
        v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

        # Side lengths; A is opposite v0, B opposite v1, C opposite v2.
        A = (v1 - v2).norm(dim=1)
        B = (v0 - v2).norm(dim=1)
        C = (v0 - v1).norm(dim=1)

        # Triangle area by Heron's formula, clamped to avoid division by zero.
        s = 0.5 * (A + B + C)
        area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()

        # cot(angle) = (b^2 + c^2 - a^2) / (4 * area)
        A2, B2, C2 = A * A, B * B, C * C
        cota = (B2 + C2 - A2) / area
        cotb = (A2 + C2 - B2) / area
        cotc = (A2 + B2 - C2) / area
        cot = torch.stack([cota, cotb, cotc], dim=1)
        cot /= 4.0

        # The cotangent at vertex k of a face is assigned to the opposite edge.
        ii = faces_packed[:, [1, 2, 0]]
        jj = faces_packed[:, [2, 0, 1]]
        idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
        L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V))
        L += L.t()
        L = L.coalesce()
        L /= 2.0  # normalized according to arap paper

        return L