Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 4,628 Bytes
			
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
from typing import Tuple
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from densepose.structures.mesh import create_mesh
from .utils import sample_random_indices
class ShapeToShapeCycleLoss(nn.Module):
    """
    Cycle Loss for Shapes.
    Inspired by:
    "Mapping in a Cycle: Sinkhorn Regularized Unsupervised Learning for Point Cloud Shapes".

    Every forward pass takes one pair of meshes, builds soft correspondences between
    their vertex embeddings in both directions, composes them into round-trip maps
    (mesh -> other mesh -> mesh) and penalizes those maps weighted by the geodesic
    distances of the mesh the round trip starts from.
    """

    def __init__(self, cfg: CfgNode):
        """
        Args:
            cfg (CfgNode): configuration; mesh names are the keys of
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS, loss options (NORM_P,
                TEMPERATURE, MAX_NUM_VERTICES) are read from
                MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS
        """
        super().__init__()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        loss_cfg = cse_cfg.SHAPE_TO_SHAPE_CYCLE_LOSS
        self.shape_names = list(cse_cfg.EMBEDDERS.keys())
        # every unordered pair of distinct shapes appears exactly once
        self.all_shape_pairs = [
            (x, y) for i, x in enumerate(self.shape_names) for y in self.shape_names[i + 1 :]
        ]
        random.shuffle(self.all_shape_pairs)
        # position of the next pair to be served from the shuffled list
        self.cur_pos = 0
        self.norm_p = loss_cfg.NORM_P
        self.temperature = loss_cfg.TEMPERATURE
        self.max_num_vertices = loss_cfg.MAX_NUM_VERTICES

    def _sample_random_pair(self) -> Tuple[str, str]:
        """
        Produce a random pair of different mesh names. Pairs are served from a
        shuffled list, which is reshuffled once all pairs have been served.

        Return:
            tuple(str, str): a pair of different mesh names
        Raises:
            IndexError: if there are no pairs to sample from (fewer than two shapes)
        """
        if not self.all_shape_pairs:
            # same exception type as plain list indexing would raise, but explicit
            raise IndexError(
                "Cannot sample a pair of shapes: at least two different shapes are required"
            )
        if self.cur_pos >= len(self.all_shape_pairs):
            random.shuffle(self.all_shape_pairs)
            self.cur_pos = 0
        shape_pair = self.all_shape_pairs[self.cur_pos]
        self.cur_pos += 1
        return shape_pair

    def forward(self, embedder: nn.Module) -> torch.Tensor:
        """
        Do a forward pass with a random (src, dst) pair of shapes

        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Tensor containing the loss value
        """
        src_mesh_name, dst_mesh_name = self._sample_random_pair()
        return self._forward_one_pair(embedder, src_mesh_name, dst_mesh_name)

    def fake_value(self, embedder: nn.Module) -> torch.Tensor:
        """
        Produce a zero-weighted loss value computed from the embeddings of every mesh
        listed in `embedder.mesh_names`.
        NOTE(review): presumably used to keep all embedder outputs in the autograd
        graph when the actual loss is not computed - confirm with callers.

        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor: mean over meshes of (sum of mesh embeddings) * 0
        """
        zero_terms = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
        return torch.mean(torch.stack(zero_terms))

    @staticmethod
    def _select_pairwise(matrix: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """
        Select the square sub-matrix of a 2D tensor for the given 1D index tensor.

        Args:
            matrix (torch.Tensor): 2D tensor to select from
            indices (torch.Tensor): 1D integer tensor with K row / column indices
        Return:
            Tensor of size [K, K] with entry (i, j) = matrix[indices[i], indices[j]]
        """
        # Broadcasting a column of indices against a row of indices gives the same
        # result as indexing with `torch.meshgrid(indices, indices)`, without the
        # deprecated call and without materializing two [K, K] index grids.
        return matrix[indices[:, None], indices[None, :]]

    def _get_embeddings_and_geodists_for_mesh(
        self, embedder: nn.Module, mesh_name: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Produces embeddings and geodesic distance tensors for a given mesh. May subsample
        the mesh, if it contains too many vertices (controlled by
        SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES parameter).

        Args:
            embedder (nn.Module): module that computes embeddings for mesh vertices
            mesh_name (str): mesh name
        Return:
            embeddings (torch.Tensor of size [N, D]): embeddings for selected mesh
                vertices (N = number of selected vertices, D = embedding space dim)
            geodists (torch.Tensor of size [N, N]): geodesic distances for the selected
                mesh vertices (N = number of selected vertices)
        """
        embeddings = embedder(mesh_name)
        # `None` means that all vertices are kept
        indices = sample_random_indices(
            embeddings.shape[0], self.max_num_vertices, embeddings.device
        )
        mesh = create_mesh(mesh_name, embeddings.device)
        geodists = mesh.geodists
        if indices is not None:
            # keep embeddings and distances aligned on the same vertex subset
            embeddings = embeddings[indices]
            geodists = self._select_pairwise(geodists, indices)
        return embeddings, geodists

    def _forward_one_pair(
        self, embedder: nn.Module, mesh_name_1: str, mesh_name_2: str
    ) -> torch.Tensor:
        """
        Do a forward pass with a selected pair of meshes

        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
            mesh_name_1 (str): first mesh name
            mesh_name_2 (str): second mesh name
        Return:
            Tensor containing the loss value
        """
        embeddings_1, geodists_1 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_1)
        embeddings_2, geodists_2 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_2)
        # [N1, N2] similarities between vertex embeddings of the two meshes
        sim_matrix_12 = embeddings_1.mm(embeddings_2.T)
        # soft assignments: each row of c_12 is a distribution over mesh 2 vertices,
        # each row of c_21 is a distribution over mesh 1 vertices
        c_12 = F.softmax(sim_matrix_12 / self.temperature, dim=1)
        c_21 = F.softmax(sim_matrix_12.T / self.temperature, dim=1)
        # round trips: mesh 1 -> mesh 2 -> mesh 1 and mesh 2 -> mesh 1 -> mesh 2
        c_11 = c_12.mm(c_21)
        c_22 = c_21.mm(c_12)
        # weighting by geodesic distances penalizes round trips that end far
        # from the vertex they started at
        loss_cycle_11 = torch.norm(geodists_1 * c_11, p=self.norm_p)
        loss_cycle_22 = torch.norm(geodists_2 * c_22, p=self.norm_p)
        return loss_cycle_11 + loss_cycle_22
 | 
 
			
