import torch

# code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py


class Arap_Loss:
    """PyTorch implementation of an as-rigid-as-possible (ARAP) loss.

    Rest-pose edge vectors and cotangent edge weights are captured from
    ``meshes`` at construction time.  Calling the instance with deformed
    meshes returns, divided by the batch size,

        sum_i sum_{j in N(i)} v_i * w_ij * || e'_ij - R_i e_ij ||

    where ``e_ij = x_i - x_j`` in the rest pose, ``e'_ij`` is the same edge in
    the deformed pose, ``R_i`` is the per-vertex best-fit rotation and ``v_i``
    the optional per-vertex weight.  Note that the residual norm is not squared.

    ``meshes`` is presumably a ``pytorch3d.structures.Meshes`` batch (only
    ``len``, ``num_verts_per_mesh``, ``edges_packed``, ``verts_packed``,
    ``faces_packed`` and ``_compute_packed`` are used) -- confirm with callers.
    All meshes in the batch are assumed to share the same topology.
    """

    def __init__(self, meshes, device='cpu', vertex_w=None):
        """Capture rest-pose edges and edge weights from ``meshes``.

        Args:
            meshes: rest-pose mesh batch.
            device: device the per-vertex rotations are moved to in ``step_1``.
            vertex_w: optional per-vertex weights in packed vertex order
                (presumably shape (V,) -- confirm), or None for uniform weights.
        """
        with torch.no_grad():
            # new nadine
            self.device = device
            self.bn = len(meshes)

            # Cotangent weights, one per directed edge, in coalesced
            # (row-major (i, j)) order; negative weights are clamped to 0.
            L = self.get_laplacian_cot(meshes)
            self.wij = L.values().clamp(min=0.0)

            # Directed adjacency: every undirected edge in both directions.
            # Assumes edges_packed() lists each undirected edge exactly once,
            # so the directed edge set matches the sparsity pattern of L.
            num_verts = int(meshes.num_verts_per_mesh().sum())
            e0, e1 = meshes.edges_packed().unbind(1)
            idx01 = torch.stack([e0, e1], dim=1)
            idx10 = torch.stack([e1, e0], dim=1)
            idx = torch.cat([idx01, idx10], dim=0).t()

            # Degree of vertex i == number of directed edges leaving i.
            self.deg = torch.bincount(idx[0], minlength=num_verts).long()
            # Put the edges in the same row-major order as self.wij.
            self.idx = self.sort_idx(idx)

            # Edge vectors of the rest-pose mesh.
            self.eij = self.get_edges(meshes)

            # Per-vertex regularization strength.
            self.vertex_w = vertex_w

    def __call__(self, new_meshes):
        """Return the ARAP loss of ``new_meshes`` relative to the rest pose."""
        new_meshes._compute_packed()

        optimal_R = self.step_1(new_meshes)
        arap_loss = self.step_2(optimal_R, new_meshes)

        return arap_loss

    def step_1(self, new_meshes):
        """Fit one rotation per vertex (no gradient flows through this step).

        For every vertex i, accumulates the 3x3 covariance
        ``S_i = sum_j w_ij * e_ij e'_ij^T``. It then takes the SVD
        ``S_i = U diag(s) V^T`` and returns ``R_i = V U^T``. When ``det(R_i) < 0``,
        the last column of U is flipped so that ``R_i`` is a proper rotation.
        The computation runs on CPU and the result is moved to ``self.device``.

        Returns:
            Tensor of shape (bn * V_per_mesh, 3, 3), in packed vertex order.
        """
        bn = self.bn
        eij = self.eij.view(bn, -1, 3).cpu()

        with torch.no_grad():
            eij_ = self.get_edges(new_meshes)

        eij_ = eij_.view(bn, -1, 3).cpu()
        wij = self.wij.view(bn, -1).cpu()
        deg_1 = self.deg.view(bn, -1)[0].cpu()  # assuming same topology
        num_v = deg_1.shape[0]

        # Vertex that owns each directed edge of a single mesh (edges are
        # grouped by source vertex, deg_1[i] consecutive edges per vertex).
        owner = torch.repeat_interleave(torch.arange(num_v), deg_1)

        # Weighted outer products w_ij * e_ij e'_ij^T -> (bn, E, 3, 3), then
        # summed per owning vertex; equals P^T diag(w) P' per vertex.
        outer = wij[:, :, None, None] * eij[:, :, :, None] * eij_[:, :, None, :]
        S = torch.zeros(bn, num_v, 3, 3, dtype=outer.dtype)
        S.index_add_(1, owner, outer)
        S = S.view(-1, 3, 3)

        u, _, v = torch.svd(S)
        R = v @ u.transpose(-2, -1)
        det = torch.det(R)

        # Reflection fix: flip the singular vector of the smallest
        # singular value (last column) and recompute.
        u[det < 0, :, -1] *= -1
        R = v @ u.transpose(-2, -1)
        R = R.to(self.device)

        return R

    def step_2(self, R, new_meshes):
        """Return the weighted rotation residual, summed and divided by ``self.bn``.

        Args:
            R: per-vertex rotations from ``step_1``, in packed vertex order.
            new_meshes: deformed mesh batch.
        """
        # One rotation per directed edge (edges are grouped by source vertex).
        R = torch.repeat_interleave(R, self.deg, dim=0)
        # squeeze(-1) only drops the trailing singleton (safe for one edge).
        Reij = (R @ self.eij.unsqueeze(2)).squeeze(-1)

        eij_ = self.get_edges(new_meshes)
        arap_loss = self.wij * (eij_ - Reij).norm(dim=1)

        if self.vertex_w is not None:
            vertex_w = torch.repeat_interleave(self.vertex_w, self.deg, dim=0)
            arap_loss = arap_loss * vertex_w

        arap_loss = arap_loss.sum() / self.bn

        return arap_loss

    def get_edges(self, meshes):
        """Return ``x_i - x_j`` for every directed edge, in ``self.idx`` order."""
        verts_packed = meshes.verts_packed()
        vi = torch.repeat_interleave(verts_packed, self.deg, dim=0)
        vj = verts_packed[self.idx[1]]
        eij = vi - vj
        return eij

    def sort_idx(self, idx):
        """Sort a (2, E) edge-index tensor lexicographically by (row, col).

        This matches the order of a coalesced sparse COO tensor. An exact
        int64 key is used; a float key would lose the column ordering for
        larger vertex indices.
        """
        if idx.numel() == 0:
            return idx
        n = int(idx.max()) + 1
        order = torch.argsort(idx[0] * n + idx[1])
        return idx[:, order]

    def get_laplacian_cot(self, meshes):
        """Return the symmetrised cotangent matrix of ``meshes`` (sparse, coalesced).

        Routine modified from pytorch3d/loss/mesh_laplacian_smoothing.py.
        Entry (i, j) is half the sum of the cotangents of the angles
        opposite edge (i, j) in its incident faces (normalized according
        to the ARAP paper).
        """
        verts_packed = meshes.verts_packed()
        faces_packed = meshes.faces_packed()
        V, F = verts_packed.shape[0], faces_packed.shape[0]

        face_verts = verts_packed[faces_packed]
        v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

        # Side lengths opposite v0, v1 and v2.
        A = (v1 - v2).norm(dim=1)
        B = (v0 - v2).norm(dim=1)
        C = (v0 - v1).norm(dim=1)

        # Heron's formula; the clamp guards degenerate (zero-area) faces.
        s = 0.5 * (A + B + C)
        area = (s * (s - A) * (s - B) * (s - C)).clamp(min=1e-12).sqrt()

        # (b^2 + c^2 - a^2) / area == 4 * cot(angle opposite a).
        A2, B2, C2 = A * A, B * B, C * C
        cota = (B2 + C2 - A2) / area
        cotb = (A2 + C2 - B2) / area
        cotc = (A2 + B2 - C2) / area
        cot = torch.stack([cota, cotb, cotc], dim=1)
        cot = cot / 4.0

        # The cotangent at corner k belongs to the edge opposite that corner.
        ii = faces_packed[:, [1, 2, 0]]
        jj = faces_packed[:, [2, 0, 1]]
        idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
        vals = cot.view(-1)

        # L + L^T, halved (normalized according to arap paper); built with
        # sparse_coo_tensor so it works on any device.
        idx = torch.cat([idx, idx.flip(0)], dim=1)
        vals = torch.cat([vals, vals]) / 2.0
        L = torch.sparse_coo_tensor(idx, vals, (V, V)).coalesce()

        return L