File size: 1,242 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch

from .mesh import TriMesh


@dataclass
class TorchMesh:
    """
    A triangle mesh in 3D held as torch tensors, with optional named data
    attached to its vertices and faces.
    """

    # Vertex coordinates as an [N x 3] array.
    verts: torch.Tensor

    # Triangles as an [M x 3] array; each entry is an index into verts.
    faces: torch.Tensor

    # Optional named extra data for the vertices and for the faces.
    # Each field defaults to its own fresh empty dict and may also be None.
    vertex_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict)
    face_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict)

    def tri_mesh(self) -> TriMesh:
        """
        Build a TriMesh from this mesh using CPU numpy arrays.

        verts and every channel tensor are detached before being moved to the
        CPU and converted; faces is converted without a detach call. A channel
        dict that is None is handed to TriMesh as None.
        """
        cpu_verts = self.verts.detach().cpu().numpy()
        cpu_faces = self.faces.cpu().numpy()

        # Channels are optional: only convert them when a dict is present.
        cpu_vertex_channels = None
        if self.vertex_channels is not None:
            cpu_vertex_channels = {}
            for name, tensor in self.vertex_channels.items():
                cpu_vertex_channels[name] = tensor.detach().cpu().numpy()

        cpu_face_channels = None
        if self.face_channels is not None:
            cpu_face_channels = {}
            for name, tensor in self.face_channels.items():
                cpu_face_channels[name] = tensor.detach().cpu().numpy()

        return TriMesh(
            verts=cpu_verts,
            faces=cpu_faces,
            vertex_channels=cpu_vertex_channels,
            face_channels=cpu_face_channels,
        )