JiantaoLin commited on
Commit
6a8a55e
Β·
1 Parent(s): 7800520
Files changed (1) hide show
  1. models/lrm/models/lrm_mesh.py +6 -1
models/lrm/models/lrm_mesh.py CHANGED
@@ -14,6 +14,7 @@
14
 
15
  import numpy as np
16
  import torch
 
17
  import torch.nn as nn
18
  import nvdiffrast.torch as dr
19
  from einops import rearrange, repeat
@@ -78,6 +79,7 @@ class PRM(nn.Module):
78
  samples_per_ray=rendering_samples_per_ray,
79
  )
80
 
 
81
  def init_flexicubes_geometry(self, device, fovy=50.0):
82
  camera = PerspectiveCamera(fovy=fovy, device=device)
83
  renderer = NeuralRender(device, camera_model=camera)
@@ -103,6 +105,7 @@ class PRM(nn.Module):
103
 
104
  return planes
105
 
 
106
  def get_sdf_deformation_prediction(self, planes):
107
  '''
108
  Predict SDF and deformation for tetrahedron vertices
@@ -195,6 +198,7 @@ class PRM(nn.Module):
195
 
196
  return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
197
 
 
198
  def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True):
199
  '''
200
  Predict Texture given triplanes
@@ -364,7 +368,8 @@ class PRM(nn.Module):
364
  'planes': planes,
365
  **out
366
  }
367
-
 
368
  def extract_mesh(
369
  self,
370
  planes: torch.Tensor,
 
14
 
15
  import numpy as np
16
  import torch
17
+ import spaces
18
  import torch.nn as nn
19
  import nvdiffrast.torch as dr
20
  from einops import rearrange, repeat
 
79
  samples_per_ray=rendering_samples_per_ray,
80
  )
81
 
82
+ @spaces.GPU
83
  def init_flexicubes_geometry(self, device, fovy=50.0):
84
  camera = PerspectiveCamera(fovy=fovy, device=device)
85
  renderer = NeuralRender(device, camera_model=camera)
 
105
 
106
  return planes
107
 
108
+ @spaces.GPU
109
  def get_sdf_deformation_prediction(self, planes):
110
  '''
111
  Predict SDF and deformation for tetrahedron vertices
 
198
 
199
  return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
200
 
201
+ @spaces.GPU
202
  def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True):
203
  '''
204
  Predict Texture given triplanes
 
368
  'planes': planes,
369
  **out
370
  }
371
+
372
+ @spaces.GPU
373
  def extract_mesh(
374
  self,
375
  planes: torch.Tensor,