Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # pyre-unsafe | |
| import json | |
| import logging | |
| from typing import List, Optional | |
| import torch | |
| from torch import nn | |
| from detectron2.utils.file_io import PathManager | |
| from densepose.structures.mesh import create_mesh | |
class MeshAlignmentEvaluator:
    """
    Class for evaluation of 3D mesh alignment based on the learned vertex embeddings.

    For every ordered pair of distinct meshes (mesh_1, mesh_2), each named key
    vertex of mesh_1 is matched to the vertex of mesh_2 whose embedding has the
    highest dot-product similarity with it. The geodesic distance on mesh_2
    between that match and mesh_2's key vertex with the same name is the
    geodesic error (GE). The geodesic point similarity (GPS) is
    exp(-GE^2 / (2 * gps_sigma^2)).
    """

    # Default location of the JSON file mapping
    # mesh name -> {key vertex name: vertex index}
    DEFAULT_KEYVERTICES_PATH: str = (
        "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json"
    )
    # Default GPS normalization constant. It was hard-coded as 0.255 (named
    # `Current_Mean_Distances`) in the original implementation; presumably a
    # reference mean geodesic distance -- confirm before changing.
    DEFAULT_GPS_SIGMA: float = 0.255

    def __init__(
        self,
        embedder: nn.Module,
        mesh_names: Optional[List[str]],
        *,
        keyvertices_path: str = DEFAULT_KEYVERTICES_PATH,
        gps_sigma: float = DEFAULT_GPS_SIGMA,
    ):
        """
        Args:
            embedder: module called as ``embedder(mesh_name)`` to obtain the
                vertex embeddings of a mesh. Its ``mesh_names`` attribute is
                used when ``mesh_names`` is None or empty.
            mesh_names: names of the meshes to evaluate. Falls back to
                ``embedder.mesh_names`` when None or an empty list.
            keyvertices_path: path or URL (anything ``PathManager.open``
                accepts) of the JSON file mapping
                mesh name -> {key vertex name: vertex index}.
            gps_sigma: normalization constant of the GPS score; must be > 0.
        Raises:
            ValueError: if ``gps_sigma`` is not positive.
        """
        if gps_sigma <= 0:
            raise ValueError(f"gps_sigma must be positive, got {gps_sigma}")
        self.embedder = embedder
        # use the provided mesh names if not None and not an empty list
        self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
        self.gps_sigma = gps_sigma
        self.logger = logging.getLogger(__name__)
        with PathManager.open(keyvertices_path, "r") as f:
            self.mesh_keyvertices = json.load(f)

    def evaluate(self):
        """
        Compute geodesic error (GE) and geodesic point similarity (GPS).

        Returns:
            A tuple ``(ge_mean_global, gps_mean_global, per_mesh_metrics)``:
            - ``per_mesh_metrics`` is ``{"GE": {mesh: float}, "GPS": {mesh: float}}``.
              Each per-mesh value is the average, over all other meshes, of the
              mean GE / GPS of that mesh's key vertices.
            - The global values are the averages of the per-mesh values.
            With fewer than two meshes there is nothing to compare, and the
            averages are NaN.
        Raises:
            KeyError: if a mesh has no entry in the key vertices data, or if a
                key vertex name of one mesh is missing for another mesh.
        """
        ge_per_mesh = {}
        gps_per_mesh = {}
        # Pure evaluation: no autograd graph is needed for the matmuls below.
        with torch.no_grad():
            # Embeddings are requested by mesh name only, so compute each mesh's
            # embeddings once rather than once per (mesh_1, mesh_2) pair.
            # NOTE(review): assumes the embedder is deterministic for a name.
            embeddings = {name: self.embedder(name) for name in self.mesh_names}
            for mesh_name_1 in self.mesh_names:
                embeddings_1 = embeddings[mesh_name_1]
                keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
                keyvertex_names_1 = list(keyvertices_1.keys())
                keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
                avg_errors = []
                avg_gps = []
                for mesh_name_2 in self.mesh_names:
                    if mesh_name_1 == mesh_name_2:
                        continue
                    embeddings_2 = embeddings[mesh_name_2]
                    keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
                    # similarity of each key vertex of mesh_1 to every vertex of mesh_2
                    sim_matrix_12 = embeddings_1[keyvertex_indices_1].mm(embeddings_2.T)
                    vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(dim=1)
                    # Created per pair on purpose: caching every mesh would keep all
                    # geodesic distance matrices in memory at once.
                    mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
                    # geodesic distance between the predicted match and the key
                    # vertex of mesh_2 that carries the same name
                    geodists = mesh_2.geodists[
                        vertices_2_matching_keyvertices_1,
                        [keyvertices_2[name] for name in keyvertex_names_1],
                    ]
                    gps = (-(geodists**2) / (2 * (self.gps_sigma**2))).exp()
                    avg_errors.append(geodists.mean().item())
                    avg_gps.append(gps.mean().item())
                # empty lists (single mesh) yield NaN means
                ge_per_mesh[mesh_name_1] = torch.as_tensor(avg_errors).mean().item()
                gps_per_mesh[mesh_name_1] = torch.as_tensor(avg_gps).mean().item()
        ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
        gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
        per_mesh_metrics = {
            "GE": ge_per_mesh,
            "GPS": gps_per_mesh,
        }
        return ge_mean_global, gps_mean_global, per_mesh_metrics