Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -172,16 +172,30 @@ def step4_track(state):
|
|
| 172 |
from pytorch3d.io import load_obj
|
| 173 |
|
| 174 |
from pixel3dmm.tracking.flame.FLAME import FLAME
|
|
|
|
| 175 |
from pixel3dmm.tracking.tracker import Tracker
|
| 176 |
|
| 177 |
flame = FLAME(base_conf) # CPU instantiation
|
| 178 |
flame = flame.to(DEVICE) # CUDA init happens here
|
| 179 |
_model_cache["flame_model"] = flame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
flame_model = _model_cache["flame_model"]
|
|
|
|
| 182 |
session_id = state.get("session_id")
|
| 183 |
base_conf.video_name = f'{session_id}'
|
| 184 |
-
tracker = Tracker(base_conf, flame_model)
|
| 185 |
tracker.run()
|
| 186 |
|
| 187 |
tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
|
|
|
|
| 172 |
from pytorch3d.io import load_obj
|
| 173 |
|
| 174 |
from pixel3dmm.tracking.flame.FLAME import FLAME
|
| 175 |
+
from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
|
| 176 |
from pixel3dmm.tracking.tracker import Tracker
|
| 177 |
|
| 178 |
flame = FLAME(base_conf) # CPU instantiation
|
| 179 |
flame = flame.to(DEVICE) # CUDA init happens here
|
| 180 |
_model_cache["flame_model"] = flame
|
| 181 |
+
|
| 182 |
+
_mesh_file = env_paths.head_template
|
| 183 |
+
|
| 184 |
+
_obj_faces = load_obj(_mesh_file)[1]
|
| 185 |
+
|
| 186 |
+
_model_cache["diff_renderer"] = NVDRenderer(
|
| 187 |
+
image_size=base_conf.size,
|
| 188 |
+
obj_filename=_mesh_file,
|
| 189 |
+
no_sh=False,
|
| 190 |
+
white_bg=True
|
| 191 |
+
).to(DEVICE)
|
| 192 |
+
|
| 193 |
|
| 194 |
flame_model = _model_cache["flame_model"]
|
| 195 |
+
diff_renderer = _model_cache["diff_renderer"]
|
| 196 |
session_id = state.get("session_id")
|
| 197 |
base_conf.video_name = f'{session_id}'
|
| 198 |
+
tracker = Tracker(base_conf, flame_model, diff_renderer)
|
| 199 |
tracker.run()
|
| 200 |
|
| 201 |
tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
|