diff --git "a/lrm/lrm.py" "b/lrm/lrm.py"
new file mode 100644
--- /dev/null
+++ "b/lrm/lrm.py"
@@ -0,0 +1,3373 @@
+import collections
+import itertools
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import xatlas
+from diffusers import ConfigMixin, ModelMixin
+from transformers import PreTrainedModel, ViTConfig, ViTImageProcessor
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (BaseModelOutput,
+                                           BaseModelOutputWithPooling)
+from transformers.pytorch_utils import (find_pruneable_heads_and_indices,
+                                        prune_linear_layer)
+
+
+def generate_planes():
+    """
+    Defines planes by the three vectors that form the "axes" of the
+    plane. Should work with an arbitrary number of planes and planes of
+    arbitrary orientation.
+
+    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
+    """
+    return torch.tensor([[[1, 0, 0],
+                          [0, 1, 0],
+                          [0, 0, 1]],
+                         [[1, 0, 0],
+                          [0, 0, 1],
+                          [0, 1, 0]],
+                         [[0, 0, 1],
+                          [0, 1, 0],
+                          [1, 0, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+    """
+    Does a projection of 3D points onto a batch of 2D planes,
+    returning 2D plane coordinates.
+
+    Takes plane axes of shape (n_planes, 3, 3) and coordinates of shape
+    (N, M, 3); returns projections of shape (N * n_planes, M, 2).
+    """
+    N, M, C = coordinates.shape
+    n_planes, _, _ = planes.shape
+    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+    projections = torch.bmm(coordinates, inv_planes)
+    return projections[..., :2]
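+
+# Shape sketch for project_onto_planes (illustrative values only): with the 3
+# canonical planes above, a batch of N point sets projects to N * 3 coordinate
+# sets, one per (batch, plane) pair.
+#
+#   planes = generate_planes()                   # (3, 3, 3)
+#   coords = torch.rand(2, 100, 3) * 2 - 1       # (N=2, M=100, 3)
+#   proj = project_onto_planes(planes, coords)   # (6, 100, 2)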
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+    assert padding_mode == 'zeros'
+    N, n_planes, C, H, W = plane_features.shape
+    _, M, _ = coordinates.shape
+    plane_features = plane_features.view(N*n_planes, C, H, W)
+    dtype = plane_features.dtype
+
+    coordinates = (2/box_warp) * coordinates  # map [-box_warp/2, box_warp/2] to the grid_sample range [-1, 1]
+
+    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+    output_features = torch.nn.functional.grid_sample(
+        plane_features,
+        projected_coordinates.to(dtype),
+        mode=mode,
+        padding_mode=padding_mode,
+        align_corners=False,
+    ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+    return output_features
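+
+# A small usage sketch for sample_from_planes, assuming 40-channel triplane
+# features at 64x64 resolution and queries inside the [-1, 1]^3 box
+# (box_warp=2.0, matching DEFAULT_RENDERING_KWARGS below):
+#
+#   plane_axes = generate_planes()
+#   features = torch.randn(1, 3, 40, 64, 64)     # (N, n_planes, C, H, W)
+#   points = torch.rand(1, 500, 3) * 2 - 1       # (N, M, 3)
+#   out = sample_from_planes(plane_axes, features, points, box_warp=2.0)
+#   # out: (1, 3, 500, 40), per-plane features for every query point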
+
+
+class OSGDecoder(nn.Module):
+    """
+    Triplane decoder that gives RGB and sigma values from sampled features.
+    Uses ReLU here instead of the Softplus in the original implementation.
+
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+    """
+    def __init__(self, n_features: int,
+                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
+        super().__init__()
+
+        self.net_sdf = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 1),
+        )
+        self.net_rgb = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 3),
+        )
+        self.net_deformation = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 3),
+        )
+        self.net_weight = nn.Sequential(
+            nn.Linear(8 * 3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 21),
+        )
+
+        # init all bias to zero
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.zeros_(m.bias)
+
+    def get_geometry_prediction(self, sampled_features, flexicubes_indices):
+        _N, n_planes, _M, _C = sampled_features.shape
+        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+
+        sdf = self.net_sdf(sampled_features)
+        deformation = self.net_deformation(sampled_features)
+
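+        # Gather the 8 corner feature vectors of every FlexiCubes cell and
+        # concatenate them (8 * 3 * n_features inputs) so net_weight can
+        # predict the 21 per-cube weights (12 edge betas, 8 corner alphas,
+        # and 1 gamma), scaled down by 0.1.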
+        grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1)
+        grid_features = grid_features.reshape(
+            sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1])
+        weight = self.net_weight(grid_features) * 0.1
+
+        return sdf, deformation, weight
+    
+    def get_texture_prediction(self, sampled_features):
+        _N, n_planes, _M, _C = sampled_features.shape
+        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+
+        rgb = self.net_rgb(sampled_features)
+        rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
+
+        return rgb
+
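+# Shape sketch for OSGDecoder (sizes are illustrative): per-point features are
+# concatenated across the 3 planes (3 * n_features) before decoding.
+#
+#   decoder = OSGDecoder(n_features=80)
+#   feats = torch.randn(1, 3, 1000, 80)          # (N, n_planes, M, C)
+#   rgb = decoder.get_texture_prediction(feats)  # (1, 1000, 3), roughly [0, 1]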
+
+class TriplaneSynthesizer(nn.Module):
+    """
+    Synthesizer that renders a triplane volume with planes and a camera.
+    
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
+    """
+
+    DEFAULT_RENDERING_KWARGS = {
+        'ray_start': 'auto',
+        'ray_end': 'auto',
+        'box_warp': 2.,
+        'white_back': True,
+        'disparity_space_sampling': False,
+        'clamp_mode': 'softplus',
+        'sampler_bbox_min': -1.,
+        'sampler_bbox_max': 1.,
+    }
+
+    def __init__(self, triplane_dim: int, samples_per_ray: int):
+        super().__init__()
+
+        # attributes
+        self.triplane_dim = triplane_dim
+        self.rendering_kwargs = {
+            **self.DEFAULT_RENDERING_KWARGS,
+            'depth_resolution': samples_per_ray // 2,
+            'depth_resolution_importance': samples_per_ray // 2,
+        }
+
+        # modules
+        self.plane_axes = generate_planes()
+        self.decoder = OSGDecoder(n_features=triplane_dim)
+
+    def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
+        plane_axes = self.plane_axes.to(planes.device)
+        sampled_features = sample_from_planes(
+            plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
+
+        sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
+        return sdf, deformation, weight
+    
+    def get_texture_prediction(self, planes, sample_coordinates):
+        plane_axes = self.plane_axes.to(planes.device)
+        sampled_features = sample_from_planes(
+            plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
+
+        rgb = self.decoder.get_texture_prediction(sampled_features)
+        return rgb
+
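+# A minimal usage sketch for TriplaneSynthesizer (hypothetical sizes, not a
+# trained configuration):
+#
+#   synthesizer = TriplaneSynthesizer(triplane_dim=80, samples_per_ray=128)
+#   planes = torch.randn(1, 3, 80, 64, 64)
+#   points = torch.rand(1, 2048, 3) * 2 - 1      # queries inside the unit box
+#   rgb = synthesizer.get_texture_prediction(planes, points)  # (1, 2048, 3)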
+
+
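+# Lookup tables for FlexiCubes below, ported from the official FlexiCubes
+# implementation: dmc_table enumerates the Dual Marching Cubes edge groups for
+# each of the 256 occupancy cases, num_vd_table gives the number of dual
+# vertices per case, and check_table / tet_table drive the non-manifold case
+# correction and tetrahedralization.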
+dmc_table = [
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
+]
+num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
+2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
+1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
+1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
+2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
+3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
+2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
+1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
+1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
+1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
+check_table = [
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 194],
+[1, -1, 0, 0, 193],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 164],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 161],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 152],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 145],
+[1, 0, 0, 1, 144],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 137],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 133],
+[1, 0, 1, 0, 132],
+[1, 1, 0, 0, 131],
+[1, 1, 0, 0, 130],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 100],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 98],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 96],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 88],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 82],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 74],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 72],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 70],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 67],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 65],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 56],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 52],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 44],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 40],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 38],
+[1, 0, -1, 0, 37],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 33],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 28],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 26],
+[1, 0, 0, -1, 25],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 20],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 18],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 9],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 6],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0]
+]
+tet_table = [
+[-1, -1, -1, -1, -1, -1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, -1],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, -1],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, -1, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, -1, 2, 4, 4, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, 5, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, -1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[-1, 1, 1, 4, 4, 1],
+[0, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[8, 8, 8, 8, 8, 8],
+[1, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 4, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 5, 5, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[6, 6, 6, 6, 6, 6],
+[6, -1, 0, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 4, -1, 6, 4, 6],
+[6, 4, 0, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 2, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 1, 1, 6, -1, 6],
+[6, 1, 1, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 4],
+[2, 2, 2, 2, 2, 2],
+[6, 1, 1, 6, 4, 6],
+[6, 1, 1, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 5, 0, 5, 0, 5],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 5, 5, 5, 5],
+[0, 5, 0, 5, 0, 5],
+[-1, 5, 0, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 5, -1, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[4, 5, 0, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 6, 6, 6, 6, 6],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, -1, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 5, 2, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 4],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 6, 2, 6, 6, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 1, 4, 1],
+[0, 1, 1, 1, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 0, 0, 6, 0, 6],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[4, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[8, 8, 8, 8, 8, 8],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 1, 1, 4, 4, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 4, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[12, 12, 12, 12, 12, 12]
+]
+
+
+class FlexiCubes:
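+    """
+    Differentiable isosurface extraction from a scalar field on a voxel grid,
+    following "Flexible Isosurface Extraction for Gradient-Based Mesh
+    Optimization" (SIGGRAPH 2023); see the __call__ docstring for details.
+    """
+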
+    def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
+        self.device = device
+        self.dmc_table = torch.tensor(
+            dmc_table, dtype=torch.long, device=device, requires_grad=False
+        )
+        self.num_vd_table = torch.tensor(
+            num_vd_table, dtype=torch.long, device=device, requires_grad=False
+        )
+        self.check_table = torch.tensor(
+            check_table, dtype=torch.long, device=device, requires_grad=False
+        )
+
+        self.tet_table = torch.tensor(
+            tet_table, dtype=torch.long, device=device, requires_grad=False
+        )
+        self.quad_split_1 = torch.tensor(
+            [0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False
+        )
+        self.quad_split_2 = torch.tensor(
+            [0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False
+        )
+        self.quad_split_train = torch.tensor(
+            [0, 1, 1, 2, 2, 3, 3, 0],
+            dtype=torch.long,
+            device=device,
+            requires_grad=False,
+        )
+
+        self.cube_corners = torch.tensor(
+            [
+                [0, 0, 0],
+                [1, 0, 0],
+                [0, 1, 0],
+                [1, 1, 0],
+                [0, 0, 1],
+                [1, 0, 1],
+                [0, 1, 1],
+                [1, 1, 1],
+            ],
+            dtype=torch.float,
+            device=device,
+        )
+        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
+        self.cube_edges = torch.tensor(
+            [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4],
+            dtype=torch.long,
+            device=device,
+            requires_grad=False,
+        )
+
+        self.edge_dir_table = torch.tensor(
+            [0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=torch.long, device=device
+        )
+        self.dir_faces_table = torch.tensor(
+            [
+                [[5, 4], [3, 2], [4, 5], [2, 3]],
+                [[5, 4], [1, 0], [4, 5], [0, 1]],
+                [[3, 2], [1, 0], [2, 3], [0, 1]],
+            ],
+            dtype=torch.long,
+            device=device,
+        )
+        self.adj_pairs = torch.tensor(
+            [0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device
+        )
+        self.qef_reg_scale = qef_reg_scale
+        self.weight_scale = weight_scale
+
+    def construct_voxel_grid(self, res):
+        """
+        Generates a voxel grid based on the specified resolution.
+
+        Args:
+            res (int or list[int]): The resolution of the voxel grid. If an integer
+                is provided, it is used for all three dimensions. If a list or tuple
+                of 3 integers is provided, they define the resolution for the x,
+                y, and z dimensions respectively.
+
+        Returns:
+            (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
+                cube corners (index into vertices) of the constructed voxel grid.
+                The vertices are centered at the origin, with the length of each
+                dimension in the grid being one.
+        """
+        base_cube_f = torch.arange(8).to(self.device)
+        if isinstance(res, int):
+            res = (res, res, res)
+        voxel_grid_template = torch.ones(res, device=self.device)
+
+        res = torch.tensor([res], dtype=torch.float, device=self.device)
+        coords = torch.nonzero(voxel_grid_template).float() / res  # N, 3
+        verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(
+            -1, 3
+        )
+        cubes = (
+            base_cube_f.unsqueeze(0)
+            + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8
+        ).reshape(-1)
+
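+        # Round to 5 decimal places so coincident corners of neighboring cubes
+        # become identical rows and are merged by torch.unique below.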
+        verts_rounded = torch.round(verts * 10**5) / (10**5)
+        verts_unique, inverse_indices = torch.unique(
+            verts_rounded, dim=0, return_inverse=True
+        )
+        cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
+
+        return verts_unique - 0.5, cubes
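+
+    # Example (illustrative): res=8 yields 9**3 = 729 unique vertices spanning
+    # [-0.5, 0.5]^3 and 8**3 = 512 cubes of 8 corner indices each.
+    #
+    #   fc = FlexiCubes(device="cpu")
+    #   verts, cubes = fc.construct_voxel_grid(8)   # (729, 3), (512, 8)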
+
+    def __call__(
+        self,
+        x_nx3,
+        s_n,
+        cube_fx8,
+        res,
+        beta_fx12=None,
+        alpha_fx8=None,
+        gamma_f=None,
+        training=False,
+        output_tetmesh=False,
+        grad_func=None,
+    ):
+        r"""
+        Main function for mesh extraction from scalar field using FlexiCubes. This function converts
+        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
+        to triangle or tetrahedral meshes using a differentiable operation as described in
+        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
+        mesh quality and geometric fidelity by adjusting the surface representation based on gradient
+        optimization. The output surface is differentiable with respect to the input vertex positions,
+        scalar field values, and weight parameters.
+
+        If you intend to extract a surface mesh from a fixed Signed Distance Field without the
+        optimization of parameters, it is suggested to provide the "grad_func" which should
+        return the surface gradient at any given 3D position. When grad_func is provided, the process
+        to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
+        described in the `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
+        Please note, this approach is non-differentiable.
+
+        For more details and example usage in optimization, refer to the
+        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
+
+        Args:
+            x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
+            s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
+                denote that the corresponding vertex resides inside the isosurface. This affects
+                the directions of the extracted triangle faces and volume to be tetrahedralized.
+            cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
+            res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
+                is used for all three dimensions. If a list or tuple of 3 integers is provided, they
+                specify the resolution for the x, y, and z dimensions respectively.
+            beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
+                vertices positioning. Defaults to uniform value for all edges.
+            alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
+                vertices positioning. Defaults to uniform value for all vertices.
+            gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
+                quadrilaterals into triangles. Defaults to uniform value for all cubes.
+            training (bool, optional): If set to True, applies differentiable quad splitting for
+                training. Defaults to False.
+            output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
+                outputs a triangular mesh. Defaults to False.
+            grad_func (callable, optional): A function to compute the surface gradient at specified
+                3D positions (input: Nx3 positions). The function should return gradients as an Nx3
+                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
+
+        Returns:
+            (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
+                - Vertices for the extracted triangular/tetrahedral mesh.
+                - Faces for the extracted triangular/tetrahedral mesh.
+                - Regularizer L_dev, computed per dual vertex.
+
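+        Example:
+            A minimal, illustrative sketch (the sphere SDF and its radius are
+            assumptions for demonstration, not part of the API)::
+
+                fc = FlexiCubes("cuda")
+                x_nx3, cube_fx8 = fc.construct_voxel_grid(64)
+                s_n = torch.norm(x_nx3, dim=-1) - 0.35  # assumed sphere SDF
+                verts, faces, L_dev = fc(x_nx3, s_n, cube_fx8, 64)
+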
+        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
+            https://research.nvidia.com/labs/toronto-ai/flexicubes/
+        .. _Manifold Dual Contouring:
+            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
+        """
+
+        surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
+        if surf_cubes.sum() == 0:
+            return (
+                torch.zeros((0, 3), device=self.device),
+                (
+                    torch.zeros((0, 4), dtype=torch.long, device=self.device)
+                    if output_tetmesh
+                    else torch.zeros((0, 3), dtype=torch.long, device=self.device)
+                ),
+                torch.zeros((0), device=self.device),
+            )
+        beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(
+            beta_fx12, alpha_fx8, gamma_f, surf_cubes
+        )
+
+        case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
+
+        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
+            s_n, cube_fx8, surf_cubes
+        )
+
+        vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
+            x_nx3,
+            cube_fx8[surf_cubes],
+            surf_edges,
+            s_n,
+            case_ids,
+            beta_fx12,
+            alpha_fx8,
+            gamma_f,
+            idx_map,
+            grad_func,
+        )
+        vertices, faces, s_edges, edge_indices = self._triangulate(
+            s_n,
+            surf_edges,
+            vd,
+            vd_gamma,
+            edge_counts,
+            idx_map,
+            vd_idx_map,
+            surf_edges_mask,
+            training,
+            grad_func,
+        )
+        if not output_tetmesh:
+            return vertices, faces, L_dev
+        else:
+            vertices, tets = self._tetrahedralize(
+                x_nx3,
+                s_n,
+                cube_fx8,
+                vertices,
+                faces,
+                surf_edges,
+                s_edges,
+                vd_idx_map,
+                case_ids,
+                edge_indices,
+                surf_cubes,
+                training,
+            )
+            return vertices, tets, L_dev
+
+    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
+        """
+        Regularizer L_dev as in Equation 8
+        """
+        dist = torch.norm(
+            ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1
+        )
+        mean_l2 = torch.zeros_like(vd[:, 0])
+        mean_l2 = mean_l2.index_add_(
+            0, edge_group_to_vd, dist
+        ) / vd_num_edges.squeeze(1).float()
+        mad = (
+            dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)
+        ).abs()
+        return mad
+
+    def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
+        """
+        Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
+        """
+        n_cubes = surf_cubes.shape[0]
+
+        if beta_fx12 is not None:
+            beta_fx12 = torch.tanh(beta_fx12) * self.weight_scale + 1
+        else:
+            beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
+
+        if alpha_fx8 is not None:
+            alpha_fx8 = torch.tanh(alpha_fx8) * self.weight_scale + 1
+        else:
+            alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
+
+        if gamma_f is not None:
+            gamma_f = (
+                torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale) / 2
+            )
+        else:
+            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
+
+        return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
+
+    @torch.no_grad()
+    def _get_case_id(self, occ_fx8, surf_cubes, res):
+        """
+        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
+        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
+        supplementary material. It should be noted that this function assumes a regular grid.
+        """
+        case_ids = (
+            occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)
+        ).sum(-1)
+
+        problem_config = self.check_table.to(self.device)[case_ids]
+        to_check = problem_config[..., 0] == 1
+        problem_config = problem_config[to_check]
+        if not isinstance(res, (list, tuple)):
+            res = [res, res, res]
+
+        # 'problem_config' only contains configurations for surface cubes. Next, we construct a 3D array,
+        # 'problem_config_full', to store configurations for all cubes (with a default config for
+        # non-surface cubes). This allows efficient checking of adjacent cubes.
+        problem_config_full = torch.zeros(
+            list(res) + [5], device=self.device, dtype=torch.long
+        )
+        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
+        vol_idx_problem = vol_idx[surf_cubes][to_check]
+        problem_config_full[
+            vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]
+        ] = problem_config
+        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
+
+        within_range = (
+            (vol_idx_problem_adj[..., 0] >= 0)
+            & (vol_idx_problem_adj[..., 0] < res[0])
+            & (vol_idx_problem_adj[..., 1] >= 0)
+            & (vol_idx_problem_adj[..., 1] < res[1])
+            & (vol_idx_problem_adj[..., 2] >= 0)
+            & (vol_idx_problem_adj[..., 2] < res[2])
+        )
+
+        vol_idx_problem = vol_idx_problem[within_range]
+        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
+        problem_config = problem_config[within_range]
+        problem_config_adj = problem_config_full[
+            vol_idx_problem_adj[..., 0],
+            vol_idx_problem_adj[..., 1],
+            vol_idx_problem_adj[..., 2],
+        ]
+        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
+        to_invert = problem_config_adj[..., 0] == 1
+        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][
+            within_range
+        ][to_invert]
+        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
+        return case_ids
+
+    @torch.no_grad()
+    def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
+        """
+        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
+        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
+        and marks the cube edges with this index.
+        """
+        occ_n = s_n < 0
+        all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(
+            all_edges, dim=0, return_inverse=True, return_counts=True
+        )
+
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
+
+        surf_edges_mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+
+        mapping = (
+            torch.ones(
+                (unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device
+            )
+            * -1
+        )
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
+        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
+        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
+        idx_map = mapping[_idx_map]
+        surf_edges = unique_edges[mask_edges]
+        return surf_edges, idx_map, counts, surf_edges_mask
+
+    @torch.no_grad()
+    def _identify_surf_cubes(self, s_n, cube_fx8):
+        """
+        Identifies grid cubes that intersect with the underlying surface by checking if the signs at
+        all corners are not identical.
+        """
+        occ_n = s_n < 0
+        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
+        _occ_sum = torch.sum(occ_fx8, -1)
+        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
+        return surf_cubes, occ_fx8
+
+    def _linear_interp(self, edges_weight, edges_x):
+        """
+        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
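+
+        For endpoint weights (w0, w1) of opposite sign and positions (x0, x1), the
+        crossing lies at (w1 * x0 - w0 * x1) / (w1 - w0); the code below builds the
+        weights [w1, -w0] and divides by their sum.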
+        """
+        edge_dim = edges_weight.dim() - 2
+        assert edges_weight.shape[edge_dim] == 2
+        edges_weight = torch.cat(
+            [
+                torch.index_select(
+                    input=edges_weight,
+                    index=torch.tensor(1, device=self.device),
+                    dim=edge_dim,
+                ),
+                -torch.index_select(
+                    input=edges_weight,
+                    index=torch.tensor(0, device=self.device),
+                    dim=edge_dim,
+                ),
+            ],
+            edge_dim,
+        )
+        denominator = edges_weight.sum(edge_dim)
+        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
+        return ue
+
+    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
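+        """
+        Solves a regularized Quadratic Error Function per dual vertex: finds x
+        minimizing sum_i (n_i . (x - p_i))^2 over up to 7 plane constraints, plus
+        a Tikhonov term (scaled by self.qef_reg_scale) pulling x towards the
+        centroid c_bx3. Only used on the non-differentiable grad_func path.
+        """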
+        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
+        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
+        c_bx3 = c_bx3.reshape(-1, 3)
+        A = norm_bxnx3
+        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
+
+        A_reg = (
+            (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale)
+            .unsqueeze(0)
+            .repeat(p_bxnx3.shape[0], 1, 1)
+        )
+        B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
+        A = torch.cat([A, A_reg], 1)
+        B = torch.cat([B, B_reg], 1)
+        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
+        return dual_verts
+
+    def _compute_vd(
+        self,
+        x_nx3,
+        surf_cubes_fx8,
+        surf_edges,
+        s_n,
+        case_ids,
+        beta_fx12,
+        alpha_fx8,
+        gamma_f,
+        idx_map,
+        grad_func,
+    ):
+        """
+        Computes the location of dual vertices as described in Section 4.2
+        """
+        alpha_nx12x2 = torch.index_select(
+            input=alpha_fx8, index=self.cube_edges, dim=1
+        ).reshape(-1, 12, 2)
+        surf_edges_x = torch.index_select(
+            input=x_nx3, index=surf_edges.reshape(-1), dim=0
+        ).reshape(-1, 2, 3)
+        surf_edges_s = torch.index_select(
+            input=s_n, index=surf_edges.reshape(-1), dim=0
+        ).reshape(-1, 2, 1)
+        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
+
+        idx_map = idx_map.reshape(-1, 12)
+        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
+        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = (
+            [],
+            [],
+            [],
+            [],
+            [],
+        )
+
+        total_num_vd = 0
+        vd_idx_map = torch.zeros(
+            (case_ids.shape[0], 12),
+            dtype=torch.long,
+            device=self.device,
+            requires_grad=False,
+        )
+        if grad_func is not None:
+            normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
+            vd = []
+        for num in torch.unique(num_vd):
+            cur_cubes = (
+                num_vd == num
+            )  # consider cubes with the same numbers of vd emitted (for batching)
+            curr_num_vd = cur_cubes.sum() * num
+            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(
+                -1, num * 7
+            )
+            curr_edge_group_to_vd = (
+                torch.arange(curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7)
+                + total_num_vd
+            )
+            total_num_vd += curr_num_vd
+            curr_edge_group_to_cube = (
+                torch.arange(idx_map.shape[0], device=self.device)[cur_cubes]
+                .unsqueeze(-1)
+                .repeat(1, num * 7)
+                .reshape_as(curr_edge_group)
+            )
+
+            curr_mask = curr_edge_group != -1
+            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
+            edge_group_to_vd.append(
+                torch.masked_select(
+                    curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask
+                )
+            )
+            edge_group_to_cube.append(
+                torch.masked_select(curr_edge_group_to_cube, curr_mask)
+            )
+            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
+            vd_gamma.append(
+                torch.masked_select(gamma_f, cur_cubes)
+                .unsqueeze(-1)
+                .repeat(1, num)
+                .reshape(-1)
+            )
+
+            if grad_func is not None:
+                with torch.no_grad():
+                    cube_e_verts_idx = idx_map[cur_cubes]
+                    curr_edge_group[~curr_mask] = 0
+
+                    verts_group_idx = torch.gather(
+                        input=cube_e_verts_idx, dim=1, index=curr_edge_group
+                    )
+                    verts_group_idx[verts_group_idx == -1] = 0
+                    verts_group_pos = torch.index_select(
+                        input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0
+                    ).reshape(-1, num.item(), 7, 3)
+                    v0 = (
+                        x_nx3[surf_cubes_fx8[cur_cubes][:, 0]]
+                        .reshape(-1, 1, 1, 3)
+                        .repeat(1, num.item(), 1, 1)
+                    )
+                    curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
+                    verts_centroid = (verts_group_pos * curr_mask).sum(2) / (
+                        curr_mask.sum(2)
+                    )
+
+                    normals_bx7x3 = torch.index_select(
+                        input=normals, index=verts_group_idx.reshape(-1), dim=0
+                    ).reshape(-1, num.item(), 7, 3)
+                    curr_mask = curr_mask.squeeze(2)
+                    vd.append(
+                        self._solve_vd_QEF(
+                            (verts_group_pos - v0) * curr_mask,
+                            normals_bx7x3 * curr_mask,
+                            verts_centroid - v0.squeeze(2),
+                        )
+                        + v0.reshape(-1, 3)
+                    )
+        edge_group = torch.cat(edge_group)
+        edge_group_to_vd = torch.cat(edge_group_to_vd)
+        edge_group_to_cube = torch.cat(edge_group_to_cube)
+        vd_num_edges = torch.cat(vd_num_edges)
+        vd_gamma = torch.cat(vd_gamma)
+
+        if grad_func is not None:
+            vd = torch.cat(vd)
+            L_dev = torch.zeros([1], device=self.device)
+        else:
+            vd = torch.zeros((total_num_vd, 3), device=self.device)
+            beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
+
+            idx_group = torch.gather(
+                input=idx_map.reshape(-1),
+                dim=0,
+                index=edge_group_to_cube * 12 + edge_group,
+            )
+
+            x_group = torch.index_select(
+                input=surf_edges_x, index=idx_group.reshape(-1), dim=0
+            ).reshape(-1, 2, 3)
+            s_group = torch.index_select(
+                input=surf_edges_s, index=idx_group.reshape(-1), dim=0
+            ).reshape(-1, 2, 1)
+
+            zero_crossing_group = torch.index_select(
+                input=zero_crossing, index=idx_group.reshape(-1), dim=0
+            ).reshape(-1, 3)
+
+            alpha_group = torch.index_select(
+                input=alpha_nx12x2.reshape(-1, 2),
+                dim=0,
+                index=edge_group_to_cube * 12 + edge_group,
+            ).reshape(-1, 2, 1)
+            ue_group = self._linear_interp(s_group * alpha_group, x_group)
+
+            beta_group = torch.gather(
+                input=beta_fx12.reshape(-1),
+                dim=0,
+                index=edge_group_to_cube * 12 + edge_group,
+            ).reshape(-1, 1)
+            beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
+            vd = (
+                vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group)
+                / beta_sum
+            )
+            L_dev = self._compute_reg_loss(
+                vd, zero_crossing_group, edge_group_to_vd, vd_num_edges
+            )
+
+        v_idx = torch.arange(vd.shape[0], device=self.device)
+
+        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(
+            dim=0,
+            index=edge_group_to_cube * 12 + edge_group,
+            src=v_idx[edge_group_to_vd],
+        )
+
+        return vd, L_dev, vd_gamma, vd_idx_map
+
+    def _triangulate(
+        self,
+        s_n,
+        surf_edges,
+        vd,
+        vd_gamma,
+        edge_counts,
+        idx_map,
+        vd_idx_map,
+        surf_edges_mask,
+        training,
+        grad_func,
+    ):
+        """
+        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
+        triangles based on the gamma parameter, as described in Section 4.3.
+        """
+        with torch.no_grad():
+            group_mask = (
+                edge_counts == 4
+            ) & surf_edges_mask  # surface edges shared by 4 cubes.
+            group = idx_map.reshape(-1)[group_mask]
+            vd_idx = vd_idx_map[group_mask]
+            edge_indices, indices = torch.sort(group, stable=True)
+            quad_vd_idx = vd_idx[indices].reshape(-1, 4)
+
+            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
+            s_edges = s_n[
+                surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)
+            ].reshape(-1, 2)
+            flip_mask = s_edges[:, 0] > 0
+            quad_vd_idx = torch.cat(
+                (
+                    quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
+                    quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]],
+                )
+            )
+        if grad_func is not None:
+            # When grad_func is given, split each quadrilateral along the diagonal with more consistent gradients.
+            with torch.no_grad():
+                vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
+                quad_gamma = torch.index_select(
+                    input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0
+                ).reshape(-1, 4, 3)
+                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
+                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
+        else:
+            quad_gamma = torch.index_select(
+                input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0
+            ).reshape(-1, 4)
+            gamma_02 = torch.index_select(
+                input=quad_gamma, index=torch.tensor(0, device=self.device), dim=1
+            ) * torch.index_select(
+                input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1
+            )
+            gamma_13 = torch.index_select(
+                input=quad_gamma, index=torch.tensor(1, device=self.device), dim=1
+            ) * torch.index_select(
+                input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1
+            )
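+        # Pick the split diagonal: at eval time the diagonal with the larger gamma
+        # product wins; during training, a gamma-weighted center vertex makes the
+        # split differentiable.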
+        if not training:
+            mask = (gamma_02 > gamma_13).squeeze(1)
+            faces = torch.zeros(
+                (quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device
+            )
+            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
+            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
+            faces = faces.reshape(-1, 3)
+        else:
+            vd_quad = torch.index_select(
+                input=vd, index=quad_vd_idx.reshape(-1), dim=0
+            ).reshape(-1, 4, 3)
+            vd_02 = (
+                torch.index_select(
+                    input=vd_quad, index=torch.tensor(0, device=self.device), dim=1
+                )
+                + torch.index_select(
+                    input=vd_quad, index=torch.tensor(2, device=self.device), dim=1
+                )
+            ) / 2
+            vd_13 = (
+                torch.index_select(
+                    input=vd_quad, index=torch.tensor(1, device=self.device), dim=1
+                )
+                + torch.index_select(
+                    input=vd_quad, index=torch.tensor(3, device=self.device), dim=1
+                )
+            ) / 2
+            weight_sum = (gamma_02 + gamma_13) + 1e-8
+            vd_center = (
+                (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1))
+                / weight_sum.unsqueeze(-1)
+            ).squeeze(1)
+            vd_center_idx = (
+                torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
+            )
+            vd = torch.cat([vd, vd_center])
+            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
+            faces = torch.cat(
+                [faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1
+            ).reshape(-1, 3)
+        return vd, faces, s_edges, edge_indices
+
+    def _tetrahedralize(
+        self,
+        x_nx3,
+        s_n,
+        cube_fx8,
+        vertices,
+        faces,
+        surf_edges,
+        s_edges,
+        vd_idx_map,
+        case_ids,
+        edge_indices,
+        surf_cubes,
+        training,
+    ):
+        """
+        Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
+        """
+        occ_n = s_n < 0
+        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
+        occ_sum = torch.sum(occ_fx8, -1)
+
+        inside_verts = x_nx3[occ_n]
+        mapping_inside_verts = (
+            torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
+        )
+        mapping_inside_verts[occ_n] = (
+            torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
+        )
+        """ 
+        For each grid edge connecting two grid vertices with different
+        signs, we first form a four-sided pyramid by connecting one
+        of the grid vertices with four mesh vertices that correspond
+        to the grid edge and then subdivide the pyramid into two tetrahedra
+        """
+        inside_verts_idx = mapping_inside_verts[
+            surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[s_edges < 0]
+        ]
+        if not training:
+            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
+        else:
+            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
+
+        tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
+        """ 
+        For each grid edge connecting two grid vertices with the
+        same sign, the tetrahedron is formed by the two grid vertices
+        and two vertices in consecutive adjacent cells
+        """
+        inside_cubes = occ_sum == 8
+        inside_cubes_center = (
+            x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
+        )
+        inside_cubes_center_idx = (
+            torch.arange(inside_cubes_center.shape[0], device=inside_cubes.device)
+            + vertices.shape[0]
+            + inside_verts.shape[0]
+        )
+
+        surface_n_inside_cubes = surf_cubes | inside_cubes
+        edge_center_vertex_idx = (
+            torch.ones(
+                ((surface_n_inside_cubes).sum(), 13),
+                dtype=torch.long,
+                device=x_nx3.device,
+            )
+            * -1
+        )
+        surf_cubes = surf_cubes[surface_n_inside_cubes]
+        inside_cubes = inside_cubes[surface_n_inside_cubes]
+        edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
+        edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
+
+        all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(
+            all_edges, dim=0, return_inverse=True, return_counts=True
+        )
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
+        mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+        mapping = (
+            torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device)
+            * -1
+        )
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
+        idx_map = mapping[_idx_map]
+
+        group_mask = (counts == 4) & mask
+        group = idx_map.reshape(-1)[group_mask]
+        edge_indices, indices = torch.sort(group)
+        cube_idx = (
+            torch.arange(
+                (_idx_map.shape[0] // 12), dtype=torch.long, device=self.device
+            )
+            .unsqueeze(1)
+            .expand(-1, 12)
+            .reshape(-1)[group_mask]
+        )
+        edge_idx = (
+            torch.arange((12), dtype=torch.long, device=self.device)
+            .unsqueeze(0)
+            .expand(_idx_map.shape[0] // 12, -1)
+            .reshape(-1)[group_mask]
+        )
+        # Identify the face shared by the adjacent cells.
+        cube_idx_4 = cube_idx[indices].reshape(-1, 4)
+        edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
+        shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
+        cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
+        # Identify an edge of the face with different signs and
+        # select the mesh vertex corresponding to the identified edge.
+        case_ids_expand = (
+            torch.ones(
+                (surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device
+            )
+            * 255
+        )
+        case_ids_expand[surf_cubes] = case_ids
+        cases = case_ids_expand[cube_idx_4x2]
+        quad_edge = edge_center_vertex_idx[
+            cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]
+        ].reshape(-1, 2)
+        mask = (quad_edge == -1).sum(-1) == 0
+        inside_edge = mapping_inside_verts[
+            unique_edges[mask_edges][edge_indices].reshape(-1)
+        ].reshape(-1, 2)
+        tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
+
+        tets = torch.cat([tets_surface, tets_inside])
+        vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
+        return vertices, tets
+
+
+def get_center_boundary_index(grid_res, device):
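+    """
+    Returns the flat index of a single near-center vertex and the flat indices
+    of all vertices within the two outermost layers of a (grid_res + 1)^3
+    vertex grid. Used by FlexiCubesGeometry when fixing the boundary SDF.
+    """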
+    v = torch.zeros(
+        (grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device
+    )
+    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
+    center_indices = torch.nonzero(v.reshape(-1))
+
+    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
+    v[:2, ...] = True
+    v[-2:, ...] = True
+    v[:, :2, ...] = True
+    v[:, -2:, ...] = True
+    v[:, :, :2] = True
+    v[:, :, -2:] = True
+    boundary_indices = torch.nonzero(v.reshape(-1))
+    return center_indices, boundary_indices
+
+
+class Geometry:
+    def __init__(self):
+        pass
+
+    def forward(self):
+        pass
+
+
+class FlexiCubesGeometry(Geometry):
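+    """
+    Wraps FlexiCubes over a fixed, optionally rescaled voxel grid: `get_mesh`
+    turns per-vertex deformations, SDF values and FlexiCubes weights into a
+    triangle mesh, and `render_mesh`/`render` rasterize it with the attached
+    neural renderer.
+    """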
+    def __init__(
+        self,
+        grid_res=64,
+        scale=2.0,
+        device="cuda",
+        renderer=None,
+        render_type="neural_render",
+        args=None,
+    ):
+        super(FlexiCubesGeometry, self).__init__()
+        self.grid_res = grid_res
+        self.device = device
+        self.args = args
+        self.fc = FlexiCubes(device, weight_scale=0.5)
+        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
+        if isinstance(scale, list):
+            self.verts[:, 0] = self.verts[:, 0] * scale[0]
+            self.verts[:, 1] = self.verts[:, 1] * scale[1]
+            self.verts[:, 2] = self.verts[:, 2] * scale[2]
+        else:
+            self.verts = self.verts * scale
+
+        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
+        self.all_edges = torch.unique(all_edges, dim=0)
+
+        # Parameters used to fix the boundary SDF
+        self.center_indices, self.boundary_indices = get_center_boundary_index(
+            self.grid_res, device
+        )
+        self.renderer = renderer
+        self.render_type = render_type
+
+    def getAABB(self):
+        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
+
+    def get_mesh(
+        self,
+        v_deformed_nx3,
+        sdf_n,
+        weight_n=None,
+        with_uv=False,
+        indices=None,
+        is_training=False,
+    ):
+        if indices is None:
+            indices = self.indices
+
+        # FlexiCubes falls back to uniform weights when these are None.
+        verts, faces, v_reg_loss = self.fc(
+            v_deformed_nx3,
+            sdf_n,
+            indices,
+            self.grid_res,
+            beta_fx12=weight_n[:, :12] if weight_n is not None else None,
+            alpha_fx8=weight_n[:, 12:20] if weight_n is not None else None,
+            gamma_f=weight_n[:, 20] if weight_n is not None else None,
+            training=is_training,
+        )
+        return verts, faces, v_reg_loss
+
+    def render_mesh(
+        self,
+        mesh_v_nx3,
+        mesh_f_fx3,
+        camera_mv_bx4x4,
+        resolution=256,
+        hierarchical_mask=False,
+    ):
+        return_value = dict()
+        if self.render_type == "neural_render":
+            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = (
+                self.renderer.render_mesh(
+                    mesh_v_nx3.unsqueeze(dim=0),
+                    mesh_f_fx3.int(),
+                    camera_mv_bx4x4,
+                    mesh_v_nx3.unsqueeze(dim=0),
+                    resolution=resolution,
+                    device=self.device,
+                    hierarchical_mask=hierarchical_mask,
+                )
+            )
+
+            return_value["tex_pos"] = tex_pos
+            return_value["mask"] = mask
+            return_value["hard_mask"] = hard_mask
+            return_value["rast"] = rast
+            return_value["v_pos_clip"] = v_pos_clip
+            return_value["mask_pyramid"] = mask_pyramid
+            return_value["depth"] = depth
+            return_value["normal"] = normal
+        else:
+            raise NotImplementedError
+
+        return return_value
+
+    def render(
+        self,
+        v_deformed_bxnx3=None,
+        sdf_bxn=None,
+        camera_mv_bxnviewx4x4=None,
+        resolution=256,
+    ):
+        # Assume a batch of meshes (each item may be a different mesh/geometry);
+        # for the other tensors, the batch size is 1.
+        v_list = []
+        f_list = []
+        n_batch = v_deformed_bxnx3.shape[0]
+        all_render_output = []
+        for i_batch in range(n_batch):
+            verts_nx3, faces_fx3, _ = self.get_mesh(
+                v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]
+            )
+            v_list.append(verts_nx3)
+            f_list.append(faces_fx3)
+            render_output = self.render_mesh(
+                verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution
+            )
+            all_render_output.append(render_output)
+
+        # Gather render outputs per key; concatenation is left to the caller.
+        return_keys = all_render_output[0].keys()
+        return_value = dict()
+        for k in return_keys:
+            value = [v[k] for v in all_render_output]
+            return_value[k] = value
+        return return_value
+
+
+def interpolate(attr, rast, attr_idx, rast_db=None):
+    return dr.interpolate(
+        attr.contiguous(),
+        rast,
+        attr_idx,
+        rast_db=rast_db,
+        diff_attrs=None if rast_db is None else "all",
+    )
+
+
+def xfm_points(points, matrix, use_python=True):
+    """Transform points by a batch of 4x4 matrices.
+    Args:
+        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
+        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
+        use_python: Unused; the PyTorch (torch.matmul) path is always taken.
+    Returns:
+        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
+    """
+    out = torch.matmul(
+        torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=1.0),
+        torch.transpose(matrix, 1, 2),
+    )
+    if torch.is_anomaly_enabled():
+        assert torch.all(
+            torch.isfinite(out)
+        ), "Output of xfm_points contains inf or NaN"
+    return out
+
+
+def dot(x, y):
+    return torch.sum(x * y, -1, keepdim=True)
+
+
+def compute_vertex_normal(v_pos, t_pos_idx):
+    i0 = t_pos_idx[:, 0]
+    i1 = t_pos_idx[:, 1]
+    i2 = t_pos_idx[:, 2]
+
+    v0 = v_pos[i0, :]
+    v1 = v_pos[i1, :]
+    v2 = v_pos[i2, :]
+
+    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+
+    # Splat face normals to vertices
+    v_nrm = torch.zeros_like(v_pos)
+    v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
+    v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
+    v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
+
+    # Normalize, replace zero (degenerated) normals with some default value
+    v_nrm = torch.where(
+        dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
+    )
+    v_nrm = F.normalize(v_nrm, dim=1)
+    assert torch.all(torch.isfinite(v_nrm))
+
+    return v_nrm
+
+
+class Renderer:
+    def __init__(self):
+        pass
+
+    def forward(self):
+        pass
+
+
+class NeuralRender(Renderer):
+    def __init__(self, device="cuda", camera_model=None):
+        super(NeuralRender, self).__init__()
+        self.device = device
+        self.ctx = dr.RasterizeCudaContext(device=device)
+        self.projection_mtx = None
+        self.camera = camera_model
+
+    def render_mesh(
+        self,
+        mesh_v_pos_bxnx3,
+        mesh_t_pos_idx_fx3,
+        camera_mv_bx4x4,
+        mesh_v_feat_bxnxd,
+        resolution=256,
+        spp=1,
+        device="cuda",
+        hierarchical_mask=False,
+    ):
+        assert not hierarchical_mask
+
+        mtx_in = (
+            torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device)
+            if not torch.is_tensor(camera_mv_bx4x4)
+            else camera_mv_bx4x4
+        )
+        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
+        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera
+
+        v_nrm = compute_vertex_normal(
+            mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()
+        )  # vertex normals in world coordinates
+
+        # Render the image. We only return the per-pixel feature (the 3D location),
+        # which is used as the input to the neural renderer.
+        num_layers = 1
+        mask_pyramid = None
+        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
+        mesh_v_feat_bxnxd = torch.cat(
+            [mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1
+        )  # Concatenate the pos
+
+        with dr.DepthPeeler(
+            self.ctx,
+            v_pos_clip,
+            mesh_t_pos_idx_fx3,
+            [resolution * spp, resolution * spp],
+        ) as peeler:
+            for _ in range(num_layers):
+                rast, db = peeler.rasterize_next_layer()
+                gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
+
+        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
+        antialias_mask = dr.antialias(
+            hard_mask.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3
+        )
+
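+        # The last 4 feature channels are the camera-space homogeneous positions
+        # concatenated above; channel -2 is therefore the camera-space depth.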
+        depth = gb_feat[..., -2:-1]
+        ori_mesh_feature = gb_feat[..., :-4]
+
+        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
+        normal = dr.antialias(
+            normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3
+        )
+        normal = F.normalize(normal, dim=-1)
+        normal = torch.lerp(
+            torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()
+        )  # black background
+
+        return (
+            ori_mesh_feature,
+            antialias_mask,
+            hard_mask,
+            rast,
+            v_pos_clip,
+            mask_pyramid,
+            depth,
+            normal,
+        )
+
+
+def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
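+    """
+    Builds an OpenGL-style perspective projection matrix (camera looking down
+    -z, with the y axis negated). `x` sets the frustum half-extent
+    (PerspectiveCamera passes tan(fovy / 2)); `near_plane`, when given,
+    replaces `n` in the depth terms only.
+    """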
+    if near_plane is None:
+        near_plane = n
+    return np.array(
+        [
+            [n / x, 0, 0, 0],
+            [0, n / -x, 0, 0],
+            [
+                0,
+                0,
+                -(f + near_plane) / (f - near_plane),
+                -(2 * f * near_plane) / (f - near_plane),
+            ],
+            [0, 0, -1, 0],
+        ]
+    ).astype(np.float32)
+
+
+class Camera(nn.Module):
+    def __init__(self):
+        super(Camera, self).__init__()
+        pass
+
+
+class PerspectiveCamera(Camera):
+    def __init__(self, fovy=49.0, device="cuda"):
+        super(PerspectiveCamera, self).__init__()
+        self.device = device
+        focal = np.tan(fovy / 180.0 * np.pi * 0.5)
+        self.proj_mtx = (
+            torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1))
+            .to(self.device)
+            .unsqueeze(dim=0)
+        )
+
+    def project(self, points_bxnx4):
+        out = torch.matmul(points_bxnx4, torch.transpose(self.proj_mtx, 1, 2))
+        return out
+
+
+class ViTEmbeddings(nn.Module):
+    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        self.mask_token = (
+            nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+            if use_mask_token
+            else None
+        )
+        self.patch_embeddings = ViTPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(
+            torch.randn(1, num_patches + 1, config.hidden_size)
+        )
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.config = config
+
+    def interpolate_pos_encoding(
+        self, embeddings: torch.Tensor, height: int, width: int
+    ) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be
+        used on higher-resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(
+            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        assert (
+            int(h0) == patch_pos_embed.shape[-2]
+            and int(w0) == patch_pos_embed.shape[-1]
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        if bool_masked_pos is not None:
+            seq_length = embeddings.shape[1]
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(
+                embeddings, height, width
+            )
+        else:
+            embeddings = embeddings + self.position_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class ViTPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = (
+            image_size
+            if isinstance(image_size, collections.abc.Iterable)
+            else (image_size, image_size)
+        )
+        patch_size = (
+            patch_size
+            if isinstance(patch_size, collections.abc.Iterable)
+            else (patch_size, patch_size)
+        )
+        num_patches = (image_size[1] // patch_size[1]) * (
+            image_size[0] // patch_size[0]
+        )
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(
+            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
+        )
+
+    def forward(
+        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
+    ) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+                f" Expected {self.num_channels} but got {num_channels}."
+            )
+        if not interpolate_pos_encoding:
+            if height != self.image_size[0] or width != self.image_size[1]:
+                raise ValueError(
+                    f"Input image size ({height}*{width}) doesn't match model"
+                    f" ({self.image_size[0]}*{self.image_size[1]})."
+                )
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
+
+
+class ViTSelfAttention(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+            config, "embedding_size"
+        ):
+            raise ValueError(
+                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(
+            config.hidden_size, self.all_head_size, bias=config.qkv_bias
+        )
+        self.key = nn.Linear(
+            config.hidden_size, self.all_head_size, bias=config.qkv_bias
+        )
+        self.value = nn.Linear(
+            config.hidden_size, self.all_head_size, bias=config.qkv_bias
+        )
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (
+            (context_layer, attention_probs) if output_attentions else (context_layer,)
+        )
+
+        return outputs
+
+
+class ViTSelfOutput(nn.Module):
+    """
+    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+class ViTAttention(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.attention = ViTSelfAttention(config)
+        self.output = ViTSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads: Set[int]) -> None:
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads,
+            self.attention.num_attention_heads,
+            self.attention.attention_head_size,
+            self.pruned_heads,
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(
+            heads
+        )
+        self.attention.all_head_size = (
+            self.attention.attention_head_size * self.attention.num_attention_heads
+        )
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[
+            1:
+        ]  # add attentions if we output them
+        return outputs
+
+
+class ViTIntermediate(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+
+        return hidden_states
+
+
+class ViTOutput(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        hidden_states = hidden_states + input_tensor
+
+        return hidden_states
+
+
+def modulate(x, shift, scale):
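+    """
+    adaLN modulation: applies a per-sample affine transform
+    x * (1 + scale) + shift, broadcast over the token dimension.
+    """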
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class ViTLayer(nn.Module):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = ViTAttention(config)
+        self.intermediate = ViTIntermediate(config)
+        self.output = ViTOutput(config)
+        self.layernorm_before = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+        self.layernorm_after = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
+        )
+        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        adaln_input: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(
+            adaln_input
+        ).chunk(4, dim=1)
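+        # The adaLN head above is zero-initialized, so all shifts/scales start at
+        # zero and the layer initially behaves like an unmodulated ViT block.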
+
+        self_attention_outputs = self.attention(
+            modulate(
+                self.layernorm_before(hidden_states), shift_msa, scale_msa
+            ),  # in ViT, layernorm is applied before self-attention
+            head_mask,
+            output_attentions=output_attentions,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[
+            1:
+        ]  # add self attentions if we output attention weights
+
+        # first residual connection
+        hidden_states = attention_output + hidden_states
+
+        # in ViT, layernorm is also applied after self-attention
+        layer_output = modulate(
+            self.layernorm_after(hidden_states), shift_mlp, scale_mlp
+        )
+        layer_output = self.intermediate(layer_output)
+
+        # second residual connection is done here
+        layer_output = self.output(layer_output, hidden_states)
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+class ViTEncoder(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList(
+            [ViTLayer(config) for _ in range(config.num_hidden_layers)]
+        )
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        adaln_input: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    adaln_input,
+                    layer_head_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states, adaln_input, layer_head_mask, output_attentions
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions]
+                if v is not None
+            )
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class ViTPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ViTConfig
+    base_model_prefix = "vit"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
+
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+
+            module.cls_token.data = nn.init.trunc_normal_(
+                module.cls_token.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.cls_token.dtype)
+
+
+class ViTModel(ViTPreTrainedModel):
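+    """
+    ViT backbone built from the adaLN-conditioned encoder above; mirrors
+    `transformers.ViTModel` but forwards `adaln_input` through to every
+    encoder layer.
+    """
+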
+    def __init__(
+        self,
+        config: ViTConfig,
+        add_pooling_layer: bool = True,
+        use_mask_token: bool = False,
+    ):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = ViTEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pooler = ViTPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> ViTPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        adaln_input: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
+        embedding_output = self.embeddings(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            adaln_input=adaln_input,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = (
+            self.pooler(sequence_output) if self.pooler is not None else None
+        )
+
+        if not return_dict:
+            head_outputs = (
+                (sequence_output, pooled_output)
+                if pooled_output is not None
+                else (sequence_output,)
+            )
+            return head_outputs + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class ViTPooler(nn.Module):
+    def __init__(self, config: ViTConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class DinoWrapper(nn.Module):
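+    """
+    DINO ViT wrapper. Per-view camera matrices (flattened to 16 values) are
+    embedded by a small MLP and injected into the ViT as `adaln_input`.
+    """
+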
+    def __init__(self, model_name: str, freeze: bool = True):
+        super().__init__()
+        self.model, self.processor = self._build_dino(model_name)
+        self.camera_embedder = nn.Sequential(
+            nn.Linear(16, self.model.config.hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(
+                self.model.config.hidden_size, self.model.config.hidden_size, bias=True
+            ),
+        )
+        if freeze:
+            self._freeze()
+
+    def forward(self, image, camera):
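+        # `image` is (B, V, C, H, W) or (B*V, C, H, W); `camera` carries a
+        # flattened 4x4 pose per view (last dim 16) for the camera embedder.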
+        if image.ndim == 5:
+            image = image.view(-1, *image.shape[2:])
+        dtype = image.dtype
+        inputs = (
+            self.processor(
+                images=image.float(),
+                return_tensors="pt",
+                do_rescale=False,
+                do_resize=False,
+            )
+            .to(self.model.device)
+            .to(dtype)
+        )
+        # embed camera
+        camera_embeddings = self.camera_embedder(camera)
+        embeddings = camera_embeddings.view(-1, camera_embeddings.shape[-1])
+        # This resampling of positional embedding uses bicubic interpolation
+        outputs = self.model(
+            **inputs, adaln_input=embeddings, interpolate_pos_encoding=True
+        )
+        last_hidden_states = outputs.last_hidden_state
+        return last_hidden_states
+
+    def _freeze(self):
+        self.model.eval()
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    @staticmethod
+    def _build_dino(
+        model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5
+    ):
+        import requests
+
+        try:
+            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
+            processor = ViTImageProcessor.from_pretrained(model_name)
+            return model, processor
+        except requests.exceptions.ProxyError as err:
+            if proxy_error_retries > 0:
+                print(
+                    f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds..."
+                )
+                import time
+
+                time.sleep(proxy_error_cooldown)
+                return DinoWrapper._build_dino(
+                    model_name, proxy_error_retries - 1, proxy_error_cooldown
+                )
+            else:
+                raise err
+
+
+class BasicTransformerBlock(nn.Module):
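+    """
+    Pre-norm transformer block for the triplane decoder: cross-attention to
+    the image features, then self-attention, then an MLP, each with a
+    residual connection.
+    """
+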
+    def __init__(
+        self,
+        inner_dim: int,
+        cond_dim: int,
+        num_heads: int,
+        eps: float,
+        attn_drop: float = 0.0,
+        attn_bias: bool = False,
+        mlp_ratio: float = 4.0,
+        mlp_drop: float = 0.0,
+    ):
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
+        self.cross_attn = nn.MultiheadAttention(
+            embed_dim=inner_dim,
+            num_heads=num_heads,
+            kdim=cond_dim,
+            vdim=cond_dim,
+            dropout=attn_drop,
+            bias=attn_bias,
+            batch_first=True,
+        )
+        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
+        self.self_attn = nn.MultiheadAttention(
+            embed_dim=inner_dim,
+            num_heads=num_heads,
+            dropout=attn_drop,
+            bias=attn_bias,
+            batch_first=True,
+        )
+        self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
+        self.mlp = nn.Sequential(
+            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+            nn.GELU(),
+            nn.Dropout(mlp_drop),
+            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+            nn.Dropout(mlp_drop),
+        )
+
+    def forward(self, x, cond):
+        x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
+        before_sa = self.norm2(x)
+        x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
+        x = x + self.mlp(self.norm3(x))
+        return x
+
+
+class TriplaneTransformer(nn.Module):
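+    """
+    Decodes image features into triplane features: learned positional tokens
+    (3 * low_res^2 of them) are refined by cross-/self-attention, then
+    upsampled 2x with a transposed convolution to the high-res triplanes.
+    """
+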
+    def __init__(
+        self,
+        inner_dim: int,
+        image_feat_dim: int,
+        triplane_low_res: int,
+        triplane_high_res: int,
+        triplane_dim: int,
+        num_layers: int,
+        num_heads: int,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+
+        self.triplane_low_res = triplane_low_res
+        self.triplane_high_res = triplane_high_res
+        self.triplane_dim = triplane_dim
+
+        self.pos_embed = nn.Parameter(
+            torch.randn(1, 3 * triplane_low_res**2, inner_dim)
+            * (1.0 / inner_dim) ** 0.5
+        )
+        self.layers = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim=inner_dim,
+                    cond_dim=image_feat_dim,
+                    num_heads=num_heads,
+                    eps=eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm = nn.LayerNorm(inner_dim, eps=eps)
+        self.deconv = nn.ConvTranspose2d(
+            inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0
+        )
+
+    def forward(self, image_feats):
+        N = image_feats.shape[0]
+        H = W = self.triplane_low_res
+
+        x = self.pos_embed.repeat(N, 1, 1)
+        for layer in self.layers:
+            x = layer(x, image_feats)
+        x = self.norm(x)
+
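+        # (N, 3*H*W, D) tokens -> (3*N, D, H, W) per-plane maps for the 2x
+        # deconv, then back to (N, 3, D', 2H, 2W).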
+        x = x.view(N, 3, H, W, -1)
+        x = torch.einsum("nihwd->indhw", x)
+        x = x.contiguous().view(3 * N, -1, H, W)
+        x = self.deconv(x)
+        x = x.view(3, N, *x.shape[-3:])
+        x = torch.einsum("indhw->nidhw", x)
+        x = x.contiguous()
+
+        return x
+
+
+def interpolate_atlas(attr, rast, attr_idx, rast_db=None):
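+    # Wrapper around nvdiffrast interpolation: attribute differentials are
+    # requested only when screen-space derivatives (`rast_db`) are supplied.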
+    return dr.interpolate(
+        attr.contiguous(),
+        rast,
+        attr_idx,
+        rast_db=rast_db,
+        diff_attrs=None if rast_db is None else "all",
+    )
+
+
+def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
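+    """
+    UV-unwrap the mesh with xatlas, then rasterize the UV layout to bake
+    per-texel 3D surface positions (`gb_pos`) and a coverage mask at the
+    requested `resolution`.
+    """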
+    _, indices, uvs = xatlas.parametrize(
+        mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()
+    )
+
+    indices_int64 = indices.astype(np.uint64, casting="same_kind").view(np.int64)
+
+    uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
+    mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
+    uv_clip = uvs[None, ...] * 2.0 - 1.0
+
+    uv_clip4 = torch.cat(
+        (
+            uv_clip,
+            torch.zeros_like(uv_clip[..., 0:1]),
+            torch.ones_like(uv_clip[..., 0:1]),
+        ),
+        dim=-1,
+    )
+
+    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
+
+    gb_pos, _ = interpolate_atlas(mesh_v[None, ...], rast, mesh_pos_idx.int())
+    mask = rast[..., 3:4] > 0
+    return uvs, mesh_tex_idx, gb_pos, mask
+
+
+class LRM(ModelMixin, ConfigMixin):
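+    """
+    Large Reconstruction Model: DINO image encoder -> triplane transformer ->
+    triplane synthesizer, with FlexiCubes geometry for differentiable mesh
+    extraction and rendering.
+
+    A minimal usage sketch (assumes `images`, `cameras`, and `render_cameras`
+    are already prepared in the expected shapes):
+
+        model = LRM().to(device)
+        model.init_flexicubes_geometry(device)
+        out = model(images, cameras, render_cameras, render_size=256)
+        rgb = out["img"]  # (B, NV, C, H, W); C is the synthesizer's color dim
+    """
+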
+    def __init__(
+        self,
+        encoder_freeze: bool = False,
+        encoder_model_name: str = "facebook/dino-vitb16",
+        encoder_feat_dim: int = 768,
+        transformer_dim: int = 1024,
+        transformer_layers: int = 16,
+        transformer_heads: int = 16,
+        triplane_low_res: int = 32,
+        triplane_high_res: int = 64,
+        triplane_dim: int = 80,
+        rendering_samples_per_ray: int = 128,
+        grid_res: int = 128,
+        grid_scale: float = 2.1,
+    ):
+        super().__init__()
+
+        self.grid_res = grid_res
+        self.grid_scale = grid_scale
+        self.deformation_multiplier = 4.0
+
+        self.encoder = DinoWrapper(
+            model_name=encoder_model_name,
+            freeze=encoder_freeze,
+        )
+
+        self.transformer = TriplaneTransformer(
+            inner_dim=transformer_dim,
+            num_layers=transformer_layers,
+            num_heads=transformer_heads,
+            image_feat_dim=encoder_feat_dim,
+            triplane_low_res=triplane_low_res,
+            triplane_high_res=triplane_high_res,
+            triplane_dim=triplane_dim,
+        )
+
+        self.synthesizer = TriplaneSynthesizer(
+            triplane_dim=triplane_dim,
+            samples_per_ray=rendering_samples_per_ray,
+        )
+
+    def init_flexicubes_geometry(self, device, fovy=50.0):
+        camera = PerspectiveCamera(fovy=fovy, device=device)
+        renderer = NeuralRender(device, camera_model=camera)
+        self.geometry = FlexiCubesGeometry(
+            grid_res=self.grid_res,
+            scale=self.grid_scale,
+            renderer=renderer,
+            render_type="neural_render",
+            device=device,
+        )
+
+    def forward_planes(self, images, cameras):
+        B = images.shape[0]
+
+        image_feats = self.encoder(images, cameras)
+        image_feats = image_feats.view(B, -1, image_feats.shape[-1])
+
+        planes = self.transformer(image_feats)
+
+        return planes
+
+    def get_sdf_deformation_prediction(self, planes):
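+        # Query SDF, deformation, and FlexiCubes weights at every grid vertex;
+        # gradient checkpointing keeps the dense grid query memory-friendly.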
+        init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1)
+
+        sdf, deformation, weight = torch.utils.checkpoint.checkpoint(
+            self.synthesizer.get_geometry_prediction,
+            planes,
+            init_position,
+            self.geometry.indices,
+            use_reentrant=False,
+        )
+
+        deformation = (
+            1.0
+            / (self.grid_res * self.deformation_multiplier)
+            * torch.tanh(deformation)
+        )
+        sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)
+
+        sdf_bxnxnxn = sdf.reshape(
+            (sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)
+        )
+        sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
+        pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
+        neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
+        zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
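+        # For samples whose SDF never changes sign (no surface in the grid),
+        # inject a synthetic one (positive at the grid center, negative at the
+        # boundary) and add an |SDF| penalty so training recovers a surface.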
+        if zero_surface.any():
+            update_sdf = torch.zeros_like(sdf[0:1])
+            max_sdf = sdf.max()
+            min_sdf = sdf.min()
+            update_sdf[:, self.geometry.center_indices] += 1.0 - min_sdf
+            update_sdf[:, self.geometry.boundary_indices] += -1.0 - max_sdf
+            new_sdf = torch.zeros_like(sdf)
+            for i_batch in range(zero_surface.shape[0]):
+                if zero_surface[i_batch]:
+                    new_sdf[i_batch : i_batch + 1] += update_sdf
+            update_mask = (new_sdf == 0).float()
+            sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
+            sdf_reg_loss = sdf_reg_loss * zero_surface.float()
+            sdf = sdf * update_mask + new_sdf * (1 - update_mask)
+
+        final_sdf = []
+        final_def = []
+        for i_batch in range(zero_surface.shape[0]):
+            if zero_surface[i_batch]:
+                final_sdf.append(sdf[i_batch : i_batch + 1].detach())
+                final_def.append(deformation[i_batch : i_batch + 1].detach())
+            else:
+                final_sdf.append(sdf[i_batch : i_batch + 1])
+                final_def.append(deformation[i_batch : i_batch + 1])
+        sdf = torch.cat(final_sdf, dim=0)
+        deformation = torch.cat(final_def, dim=0)
+        return sdf, deformation, sdf_reg_loss, weight
+
+    def get_geometry_prediction(self, planes=None):
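+        # SDF/deformation -> deformed grid vertices -> per-sample FlexiCubes
+        # meshing, collecting the surface and weight regularization terms.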
+        sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(
+            planes
+        )
+        v_deformed = (
+            self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1)
+            + deformation
+        )
+        tets = self.geometry.indices
+        n_batch = planes.shape[0]
+        v_list = []
+        f_list = []
+        flexicubes_surface_reg_list = []
+
+        for i_batch in range(n_batch):
+            verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
+                v_deformed[i_batch],
+                sdf[i_batch].squeeze(dim=-1),
+                with_uv=False,
+                indices=tets,
+                weight_n=weight[i_batch].squeeze(dim=-1),
+                is_training=self.training,
+            )
+            flexicubes_surface_reg_list.append(flexicubes_surface_reg)
+            v_list.append(verts)
+            f_list.append(faces)
+
+        flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
+        flexicubes_weight_reg = (weight**2).mean()
+
+        return (
+            v_list,
+            f_list,
+            sdf,
+            deformation,
+            v_deformed,
+            (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg),
+        )
+
+    def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
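+        # Sample triplane colors at surface positions. With a `hard_mask`,
+        # only covered texels are queried: points are gathered per sample,
+        # zero-padded to the batch max, decoded as one dense batch, and then
+        # scattered back into the full texel grid.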
+        tex_pos = torch.cat(tex_pos, dim=0)
+        if hard_mask is not None:
+            tex_pos = tex_pos * hard_mask.float()
+        batch_size = tex_pos.shape[0]
+        tex_pos = tex_pos.reshape(batch_size, -1, 3)
+        if hard_mask is not None:
+            n_point_list = torch.sum(
+                hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1
+            )
+            sample_tex_pos_list = []
+            max_point = n_point_list.max()
+            expanded_hard_mask = (
+                hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
+            )
+            for i in range(tex_pos.shape[0]):
+                tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
+                if tex_pos_one_shape.shape[1] < max_point:
+                    tex_pos_one_shape = torch.cat(
+                        [
+                            tex_pos_one_shape,
+                            torch.zeros(
+                                1,
+                                max_point - tex_pos_one_shape.shape[1],
+                                3,
+                                device=tex_pos_one_shape.device,
+                                dtype=torch.float32,
+                            ),
+                        ],
+                        dim=1,
+                    )
+                sample_tex_pos_list.append(tex_pos_one_shape)
+            tex_pos = torch.cat(sample_tex_pos_list, dim=0)
+
+        tex_feat = torch.utils.checkpoint.checkpoint(
+            self.synthesizer.get_texture_prediction,
+            planes,
+            tex_pos,
+            use_reentrant=False,
+        )
+
+        if hard_mask is not None:
+            final_tex_feat = torch.zeros(
+                planes.shape[0],
+                hard_mask.shape[1] * hard_mask.shape[2],
+                tex_feat.shape[-1],
+                device=tex_feat.device,
+            )
+            expanded_hard_mask = (
+                hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(
+                    -1, -1, final_tex_feat.shape[-1]
+                )
+                > 0.5
+            )
+            for i in range(planes.shape[0]):
+                final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][
+                    : n_point_list[i]
+                ].reshape(-1)
+            tex_feat = final_tex_feat
+
+        if hard_mask is None:
+            # Without a mask there is no texel grid; return dense per-point
+            # features as-is (the reshape below would fail on a None mask).
+            return tex_feat
+
+        return tex_feat.reshape(
+            planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]
+        )
+
+    def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
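+        # Rasterize each mesh under its own cameras, then regroup the
+        # per-mesh output dicts into per-key batches.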
+        return_value_list = []
+        for i_mesh in range(len(mesh_v)):
+            return_value = self.geometry.render_mesh(
+                mesh_v[i_mesh],
+                mesh_f[i_mesh].int(),
+                cam_mv[i_mesh],
+                resolution=render_size,
+                hierarchical_mask=False,
+            )
+            return_value_list.append(return_value)
+
+        return_keys = return_value_list[0].keys()
+        return_value = dict()
+        for k in return_keys:
+            value = [v[k] for v in return_value_list]
+            return_value[k] = value
+
+        mask = torch.cat(return_value["mask"], dim=0)
+        hard_mask = torch.cat(return_value["hard_mask"], dim=0)
+        tex_pos = return_value["tex_pos"]
+        depth = torch.cat(return_value["depth"], dim=0)
+        normal = torch.cat(return_value["normal"], dim=0)
+        return mask, hard_mask, tex_pos, depth, normal
+
+    def forward_geometry(self, planes, render_cameras, render_size=256):
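+        # Full differentiable pipeline: triplanes -> FlexiCubes meshes ->
+        # rasterized masks/positions -> triplane texture sampling -> images.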
+        B, NV = render_cameras.shape[:2]
+
+        mesh_v, mesh_f, sdf, _, _, sdf_reg_loss = self.get_geometry_prediction(planes)
+
+        cam_mv = render_cameras
+        run_n_view = cam_mv.shape[1]
+        antialias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(
+            mesh_v, mesh_f, cam_mv, render_size=render_size
+        )
+
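+        # Concatenate all views of each sample along the width axis so that
+        # texture prediction runs once per sample rather than once per view.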
+        tex_hard_mask = hard_mask
+        tex_pos = [
+            torch.cat([pos[i_view : i_view + 1] for i_view in range(run_n_view)], dim=2)
+            for pos in tex_pos
+        ]
+        tex_hard_mask = torch.cat(
+            [
+                torch.cat(
+                    [
+                        tex_hard_mask[
+                            i * run_n_view + i_view : i * run_n_view + i_view + 1
+                        ]
+                        for i_view in range(run_n_view)
+                    ],
+                    dim=2,
+                )
+                for i in range(planes.shape[0])
+            ],
+            dim=0,
+        )
+
+        tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
+        background_feature = torch.ones_like(tex_feat)
+
+        img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
+
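+        # Split each sample's wide strip back into its individual views.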
+        img_feat = torch.cat(
+            [
+                torch.cat(
+                    [
+                        img_feat[
+                            i : i + 1,
+                            :,
+                            render_size * i_view : render_size * (i_view + 1),
+                        ]
+                        for i_view in range(run_n_view)
+                    ],
+                    dim=0,
+                )
+                for i in range(len(tex_pos))
+            ],
+            dim=0,
+        )
+
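+        # Fold views back into a (B, NV, C, H, W) layout; the depth sign is
+        # flipped, which assumes the renderer's camera-space z is negative in
+        # front of the camera.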
+        img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
+        antialias_mask = antialias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
+        depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV))
+        normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
+
+        out = {
+            "img": img,
+            "mask": antilias_mask,
+            "depth": depth,
+            "normal": normal,
+            "sdf": sdf,
+            "mesh_v": mesh_v,
+            "mesh_f": mesh_f,
+            "sdf_reg_loss": sdf_reg_loss,
+        }
+        return out
+
+    def forward(self, images, cameras, render_cameras, render_size: int):
+        planes = self.forward_planes(images, cameras)
+        out = self.forward_geometry(planes, render_cameras, render_size=render_size)
+
+        return {"planes": planes, **out}
+
+    def extract_mesh(
+        self,
+        planes: torch.Tensor,
+        use_texture_map: bool = False,
+        texture_resolution: int = 1024,
+        **kwargs,
+    ):
+        """
+        Extract a 3D mesh from FlexiCubes. Only support batch_size 1.
+        :param planes: triplane features
+        :param use_texture_map: use texture map or vertex color
+        :param texture_resolution: the resolution of texure map
+        """
+        assert planes.shape[0] == 1
+
+        # predict geometry first
+        mesh_v, mesh_f, _, _, _, _ = self.get_geometry_prediction(planes)
+        vertices, faces = mesh_v[0], mesh_f[0]
+
+        if not use_texture_map:
+            # query vertex colors
+            vertices_tensor = vertices.unsqueeze(0)
+            vertices_colors = (
+                self.synthesizer.get_texture_prediction(planes, vertices_tensor)
+                .clamp(0, 1)
+                .squeeze(0)
+                .cpu()
+                .numpy()
+            )
+            vertices_colors = (vertices_colors * 255).astype(np.uint8)
+
+            return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors
+
+        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
+            self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution
+        )
+        tex_hard_mask = tex_hard_mask.float()
+
+        tex_feat = self.get_texture_prediction(planes, [gb_pos], tex_hard_mask)
+        background_feature = torch.zeros_like(tex_feat)
+        img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
+        texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
+
+        return vertices, faces, uvs, mesh_tex_idx, texture_map