JiantaoLin
commited on
Commit
Β·
6a8a55e
1
Parent(s):
7800520
new
Browse files
models/lrm/models/lrm_mesh.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
|
15 |
import numpy as np
|
16 |
import torch
|
|
|
17 |
import torch.nn as nn
|
18 |
import nvdiffrast.torch as dr
|
19 |
from einops import rearrange, repeat
|
@@ -78,6 +79,7 @@ class PRM(nn.Module):
|
|
78 |
samples_per_ray=rendering_samples_per_ray,
|
79 |
)
|
80 |
|
|
|
81 |
def init_flexicubes_geometry(self, device, fovy=50.0):
|
82 |
camera = PerspectiveCamera(fovy=fovy, device=device)
|
83 |
renderer = NeuralRender(device, camera_model=camera)
|
@@ -103,6 +105,7 @@ class PRM(nn.Module):
|
|
103 |
|
104 |
return planes
|
105 |
|
|
|
106 |
def get_sdf_deformation_prediction(self, planes):
|
107 |
'''
|
108 |
Predict SDF and deformation for tetrahedron vertices
|
@@ -195,6 +198,7 @@ class PRM(nn.Module):
|
|
195 |
|
196 |
return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
|
197 |
|
|
|
198 |
def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True):
|
199 |
'''
|
200 |
Predict Texture given triplanes
|
@@ -364,7 +368,8 @@ class PRM(nn.Module):
|
|
364 |
'planes': planes,
|
365 |
**out
|
366 |
}
|
367 |
-
|
|
|
368 |
def extract_mesh(
|
369 |
self,
|
370 |
planes: torch.Tensor,
|
|
|
14 |
|
15 |
import numpy as np
|
16 |
import torch
|
17 |
+
import spaces
|
18 |
import torch.nn as nn
|
19 |
import nvdiffrast.torch as dr
|
20 |
from einops import rearrange, repeat
|
|
|
79 |
samples_per_ray=rendering_samples_per_ray,
|
80 |
)
|
81 |
|
82 |
+
@spaces.GPU
|
83 |
def init_flexicubes_geometry(self, device, fovy=50.0):
|
84 |
camera = PerspectiveCamera(fovy=fovy, device=device)
|
85 |
renderer = NeuralRender(device, camera_model=camera)
|
|
|
105 |
|
106 |
return planes
|
107 |
|
108 |
+
@spaces.GPU
|
109 |
def get_sdf_deformation_prediction(self, planes):
|
110 |
'''
|
111 |
Predict SDF and deformation for tetrahedron vertices
|
|
|
198 |
|
199 |
return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
|
200 |
|
201 |
+
@spaces.GPU
|
202 |
def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True):
|
203 |
'''
|
204 |
Predict Texture given triplanes
|
|
|
368 |
'planes': planes,
|
369 |
**out
|
370 |
}
|
371 |
+
|
372 |
+
@spaces.GPU
|
373 |
def extract_mesh(
|
374 |
self,
|
375 |
planes: torch.Tensor,
|