dkatz2391 committed
Commit 6ff7518 · verified · 1 Parent(s): 94bae10

retry all in one

Files changed (1):
  1. app.py +80 -426
app.py CHANGED
@@ -3,7 +3,9 @@ import spaces
3
 
4
  import os
5
  import shutil
6
- import json
 
 
7
  import torch
8
  import numpy as np
9
  import imageio
@@ -11,35 +13,18 @@ from easydict import EasyDict as edict
11
  from trellis.pipelines import TrellisTextTo3DPipeline
12
  from trellis.representations import Gaussian, MeshExtractResult
13
  from trellis.utils import render_utils, postprocessing_utils
 
14
  import traceback
15
  import sys
16
- import time
17
- # If psutil is available in the environment, we can use it for memory info
18
- try:
19
- import psutil
20
- PSUTIL_AVAILABLE = True
21
- except ImportError:
22
- PSUTIL_AVAILABLE = False
23
-
24
- # --- Environment Variables ---
25
- os.environ['TOKENIZERS_PARALLELISM'] = 'true'
26
- os.environ['SPCONV_ALGO'] = 'native'
27
- # ---------------------------
28
-
29
- from typing import *
30
-
31
- # Add JSON encoder for NumPy arrays
32
- class NumpyEncoder(json.JSONEncoder):
33
- def default(self, obj):
34
- if isinstance(obj, np.ndarray):
35
- return obj.tolist()
36
- return json.JSONEncoder.default(self, obj)
37
 
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
40
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
41
  os.makedirs(TMP_DIR, exist_ok=True)
42
 
 
 
43
 
44
  def start_session(req: gr.Request):
45
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
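
For reference, the NumpyEncoder removed in the hunk above follows the standard json.JSONEncoder subclassing pattern for serializing NumPy arrays; a minimal, self-contained sketch of that pattern (independent of this app) is:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    # Fall back to a plain Python list whenever json meets an ndarray.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

payload = json.dumps({"xyz": np.zeros((2, 3), dtype=np.float32)}, cls=NumpyEncoder)
print(payload)  # {"xyz": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]}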
@@ -48,8 +33,7 @@ def start_session(req: gr.Request):
48
 
49
  def end_session(req: gr.Request):
50
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
51
- # Use shutil.rmtree with ignore_errors=True for robustness
52
- shutil.rmtree(user_dir, ignore_errors=True)
53
 
54
 
55
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -78,16 +62,15 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
78
  opacity_bias=state['gaussian']['opacity_bias'],
79
  scaling_activation=state['gaussian']['scaling_activation'],
80
  )
81
- # Ensure tensors are created on the correct device ('cuda')
82
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda', dtype=torch.float32)
83
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda', dtype=torch.float32)
84
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda', dtype=torch.float32)
85
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda', dtype=torch.float32)
86
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda', dtype=torch.float32)
87
 
88
  mesh = edict(
89
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda', dtype=torch.float32),
90
- faces=torch.tensor(state['mesh']['faces'], device='cuda', dtype=torch.int64), # Faces are usually integers
91
  )
92
 
93
  return gs, mesh
@@ -109,9 +92,9 @@ def text_to_3d(
109
  slat_guidance_strength: float,
110
  slat_sampling_steps: int,
111
  req: gr.Request,
112
- ) -> dict: # MODIFIED: Now returns only the state dict
113
  """
114
- Convert a text prompt to a 3D model state object.
115
  Args:
116
  prompt (str): The text prompt.
117
  seed (int): The random seed.
@@ -120,14 +103,11 @@ def text_to_3d(
120
  slat_guidance_strength (float): The guidance strength for structured latent generation.
121
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
122
  Returns:
123
- dict: The JSON-serializable state object containing the generated 3D model info.
 
124
  """
125
- # Ensure user directory exists (redundant if start_session is always called, but safe)
126
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
127
- os.makedirs(user_dir, exist_ok=True)
128
-
129
- print(f"[{req.session_hash}] Running text_to_3d for prompt: {prompt}") # Add logging
130
-
131
  outputs = pipeline.run(
132
  prompt,
133
  seed=seed,
@@ -141,63 +121,14 @@ def text_to_3d(
141
  "cfg_strength": slat_guidance_strength,
142
  },
143
  )
144
-
145
- # REMOVED: Video rendering logic moved to render_preview_video
146
-
147
- # Create the state object and ensure it's JSON serializable for API calls
148
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
149
- # Convert to serializable format
150
- serializable_state = json.loads(json.dumps(state, cls=NumpyEncoder))
151
-
152
- print(f"[{req.session_hash}] text_to_3d completed. Returning state.") # Modified log message
153
-
154
- torch.cuda.empty_cache()
155
-
156
- # --- REVERTED DEBUGGING ---
157
- # Remove the temporary simple dictionary return
158
- # print("[DEBUG] Returning simple dict for API test.")
159
- # return {"status": "test_success", "received_prompt": prompt}
160
- # --- END REVERTED DEBUGGING ---
161
-
162
- # Original return line (restored):
163
- return serializable_state # MODIFIED: Return only state
164
-
165
- # --- NEW FUNCTION ---
166
- @spaces.GPU
167
- def render_preview_video(state: dict, req: gr.Request) -> str:
168
- """
169
- Renders a preview video from the provided state object.
170
- Args:
171
- state (dict): The state object containing Gaussian and mesh data.
172
- req (gr.Request): Gradio request object for session hash.
173
- Returns:
174
- str: The path to the rendered video file.
175
- """
176
- if not state:
177
- print(f"[{req.session_hash}] render_preview_video called with empty state. Returning None.")
178
- return None
179
-
180
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
181
- os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
182
-
183
- print(f"[{req.session_hash}] Unpacking state for video rendering.")
184
- # Only unpack gs, as mesh causes type errors with render_utils after unpacking
185
- gs, _ = unpack_state(state) # We still need the mesh for GLB, but not for this video preview
186
-
187
- print(f"[{req.session_hash}] Rendering video (Gaussian only)...")
188
- # Render ONLY the Gaussian splats, as rendering the unpacked mesh fails
189
- video = render_utils.render_video(gs, num_frames=120)['color']
190
- # REMOVED: video_geo = render_utils.render_video(mesh, num_frames=120)['normal']
191
- # REMOVED: video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
192
-
193
- video_path = os.path.join(user_dir, 'preview_sample.mp4')
194
- print(f"[{req.session_hash}] Saving video to {video_path}")
195
- # Save only the Gaussian render
196
  imageio.mimsave(video_path, video, fps=15)
197
-
198
  torch.cuda.empty_cache()
199
- return video_path
200
- # --- END NEW FUNCTION ---
201
 
202
 
203
  @spaces.GPU(duration=90)
@@ -208,67 +139,43 @@ def extract_glb(
208
  req: gr.Request,
209
  ) -> Tuple[str, str]:
210
  """
211
- Extract a GLB file from the 3D model state.
212
  Args:
213
  state (dict): The state of the generated 3D model.
214
  mesh_simplify (float): The mesh simplification factor.
215
  texture_size (int): The texture resolution.
216
  Returns:
217
- str: The path to the extracted GLB file (for Model3D component).
218
- str: The path to the extracted GLB file (for DownloadButton).
219
  """
220
- if not state:
221
- print(f"[{req.session_hash}] extract_glb called with empty state. Returning None.")
222
- return None, None # Return Nones if state is missing
223
-
224
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
225
  os.makedirs(user_dir, exist_ok=True)
226
-
227
- print(f"[{req.session_hash}] Unpacking state for GLB extraction.") # Add logging
228
  gs, mesh = unpack_state(state)
229
-
230
- print(f"[{req.session_hash}] Extracting GLB (simplify={mesh_simplify}, texture={texture_size})...") # Add logging
231
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
232
  glb_path = os.path.join(user_dir, 'sample.glb')
233
- print(f"[{req.session_hash}] Saving GLB to {glb_path}") # Add logging
234
  glb.export(glb_path)
235
-
236
  torch.cuda.empty_cache()
237
- # Return the same path for both Model3D and DownloadButton components
238
  return glb_path, glb_path
239
 
240
 
241
  @spaces.GPU
242
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
243
  """
244
- Extract a Gaussian PLY file from the 3D model state.
245
  Args:
246
  state (dict): The state of the generated 3D model.
247
  Returns:
248
- str: The path to the extracted Gaussian file (for Model3D component).
249
- str: The path to the extracted Gaussian file (for DownloadButton).
250
  """
251
- if not state:
252
- print(f"[{req.session_hash}] extract_gaussian called with empty state. Returning None.")
253
- return None, None # Return Nones if state is missing
254
-
255
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
256
  os.makedirs(user_dir, exist_ok=True)
257
-
258
- print(f"[{req.session_hash}] Unpacking state for Gaussian extraction.") # Add logging
259
  gs, _ = unpack_state(state)
260
-
261
  gaussian_path = os.path.join(user_dir, 'sample.ply')
262
- print(f"[{req.session_hash}] Saving Gaussian PLY to {gaussian_path}") # Add logging
263
  gs.save_ply(gaussian_path)
264
-
265
  torch.cuda.empty_cache()
266
- # Return the same path for both Model3D and DownloadButton components
267
  return gaussian_path, gaussian_path
268
 
269
 
270
- # --- NEW COMBINED API FUNCTION (with HEAVY logging) ---
271
- @spaces.GPU(duration=120)
272
  def generate_and_extract_glb(
273
  prompt: str,
274
  seed: int,
@@ -278,264 +185,51 @@ def generate_and_extract_glb(
278
  slat_sampling_steps: int,
279
  mesh_simplify: float,
280
  texture_size: int,
281
- req: Optional[gr.Request] = None, # Make req optional for robustness
282
- ) -> Optional[str]:
283
  """
284
- Combines 3D model generation and GLB extraction into a single step
285
- for API usage, avoiding the need to transfer the state object.
286
- Includes extensive logging.
287
- Args:
288
- prompt (str): Text prompt for generation.
289
- seed (int): Random seed.
290
- ss_guidance_strength (float): Sparse structure guidance.
291
- ss_sampling_steps (int): Sparse structure steps.
292
- slat_guidance_strength (float): Structured latent guidance.
293
- slat_sampling_steps (int): Structured latent steps.
294
- mesh_simplify (float): Mesh simplification factor for GLB.
295
- texture_size (int): Texture resolution for GLB.
296
- req (Optional[gr.Request]): Gradio request object.
297
- Returns:
298
- Optional[str]: Path to the generated GLB file or None on failure.
299
  """
300
- # --- Setup & Initial Logging ---
301
- pid = os.getpid()
302
- session_hash = f"API_CALL_{pid}_{int(time.time()*1000)}" # More unique ID for API calls
303
- if req and hasattr(req, 'session_hash') and req.session_hash:
304
- session_hash = req.session_hash # Use session hash if available from UI call
305
-
306
- print(f"\n[{session_hash}] ========= generate_and_extract_glb INVOKED =========")
307
- print(f"[{session_hash}] API: PID: {pid}")
308
- if PSUTIL_AVAILABLE:
309
- process = psutil.Process(pid)
310
- mem_info_start = process.memory_info()
311
- print(f"[{session_hash}] API: Initial Memory: RSS={mem_info_start.rss / (1024**2):.2f} MB, VMS={mem_info_start.vms / (1024**2):.2f} MB")
312
- else:
313
- print(f"[{session_hash}] API: psutil not available, cannot log memory usage.")
314
-
315
- user_dir = os.path.join(TMP_DIR, str(session_hash))
316
- try:
317
- print(f"[{session_hash}] API: Ensuring directory exists: {user_dir}")
318
- os.makedirs(user_dir, exist_ok=True)
319
- print(f"[{session_hash}] API: Directory ensured.")
320
- except Exception as e:
321
- print(f"[{session_hash}] API: FATAL ERROR creating directory {user_dir}: {e}")
322
- traceback.print_exc()
323
- print(f"[{session_hash}] ========= generate_and_extract_glb FAILED (Directory Creation) =========")
324
- return None
325
-
326
- print(f"[{session_hash}] API: Input Params: Prompt='{prompt}', Seed={seed}, Simplify={mesh_simplify}, Texture={texture_size}")
327
- print(f"[{session_hash}] API: Input Params: SS Steps={ss_sampling_steps}, SS Cfg={ss_guidance_strength}, Slat Steps={slat_sampling_steps}, Slat Cfg={slat_guidance_strength}")
328
-
329
- # Check CUDA availability
330
- cuda_available = torch.cuda.is_available()
331
- print(f"[{session_hash}] API: torch.cuda.is_available(): {cuda_available}")
332
- if not cuda_available:
333
- print(f"[{session_hash}] API: FATAL ERROR - CUDA not available!")
334
- print(f"[{session_hash}] ========= generate_and_extract_glb FAILED (CUDA Unavailable) =========")
335
- return None
336
-
337
- gs_output = None
338
- mesh_output = None
339
- glb_path = None
340
-
341
- # --- Step 1: Generate 3D Model ---
342
- print(f"\n[{session_hash}] API: --- Starting Step 1: Generation Pipeline --- ")
343
  try:
344
- if pipeline is None:
345
- print(f"[{session_hash}] API: FATAL ERROR - `pipeline` object is None!")
346
- raise ValueError("Trellis pipeline is not loaded.")
347
-
348
- print(f"[{session_hash}] API: Step 1 - Calling pipeline.run()...")
349
- t_start_gen = time.time()
350
- # --- The actual pipeline call ---
351
- outputs = pipeline.run(
352
- prompt,
353
- seed=seed,
354
- formats=["gaussian", "mesh"],
355
- sparse_structure_sampler_params={
356
- "steps": ss_sampling_steps,
357
- "cfg_strength": ss_guidance_strength,
358
- },
359
- slat_sampler_params={
360
- "steps": slat_sampling_steps,
361
- "cfg_strength": slat_guidance_strength,
362
- },
363
  )
364
- # --- End pipeline call ---
365
- t_end_gen = time.time()
366
- print(f"[{session_hash}] API: Step 1 - pipeline.run() completed in {t_end_gen - t_start_gen:.2f}s.")
367
-
368
- # === Validate pipeline outputs ===
369
- print(f"[{session_hash}] API: Step 1 - Validating pipeline outputs...")
370
- if not outputs:
371
- print(f"[{session_hash}] API: ERROR - Pipeline output dictionary is None or empty.")
372
- raise ValueError("Pipeline returned empty output.")
373
-
374
- if 'gaussian' not in outputs or not outputs['gaussian']:
375
- print(f"[{session_hash}] API: ERROR - Pipeline output missing 'gaussian' key or value is empty.")
376
- raise ValueError("Pipeline output missing Gaussian result.")
377
-
378
- if 'mesh' not in outputs or not outputs['mesh']:
379
- print(f"[{session_hash}] API: ERROR - Pipeline output missing 'mesh' key or value is empty.")
380
- raise ValueError("Pipeline output missing Mesh result.")
381
-
382
- gs_output = outputs['gaussian'][0]
383
- mesh_output = outputs['mesh'][0]
384
-
385
- if gs_output is None:
386
- print(f"[{session_hash}] API: ERROR - Pipeline returned gs_output as None.")
387
- raise ValueError("Pipeline returned None for Gaussian output.")
388
-
389
- if mesh_output is None:
390
- print(f"[{session_hash}] API: ERROR - Pipeline returned mesh_output as None.")
391
- raise ValueError("Pipeline returned None for Mesh output.")
392
-
393
- print(f"[{session_hash}] API: Step 1 - Outputs validated successfully.")
394
- print(f"[{session_hash}] API: Step 1 - gs_output type: {type(gs_output)}")
395
- # Add more details if useful, e.g., number of Gaussians
396
- if hasattr(gs_output, '_xyz'):
397
- print(f"[{session_hash}] API: Step 1 - gs_output num points: {len(gs_output._xyz)}")
398
-
399
- print(f"[{session_hash}] API: Step 1 - mesh_output type: {type(mesh_output)}")
400
- # Add more details if useful, e.g., number of vertices/faces
401
- if hasattr(mesh_output, 'vertices') and hasattr(mesh_output, 'faces'):
402
- print(f"[{session_hash}] API: Step 1 - mesh_output verts: {len(mesh_output.vertices)}, faces: {len(mesh_output.faces)}")
403
- # =================================
404
-
405
- if PSUTIL_AVAILABLE:
406
- mem_info_after_gen = process.memory_info()
407
- print(f"[{session_hash}] API: Memory After Gen: RSS={mem_info_after_gen.rss / (1024**2):.2f} MB, VMS={mem_info_after_gen.vms / (1024**2):.2f} MB")
408
-
409
- except Exception as e_gen:
410
- print(f"\n[{session_hash}] API: ******** ERROR IN STEP 1: Generation Pipeline ********")
411
- print(f"[{session_hash}] API: Error Type: {type(e_gen).__name__}, Message: {e_gen}")
412
- print(f"[{session_hash}] API: Printing traceback...")
413
- traceback.print_exc()
414
- print(f"[{session_hash}] API: ********************************************************")
415
- gs_output = None # Ensure reset on error
416
- mesh_output = None
417
- # Fall through to finally block for cleanup
418
- finally:
419
- # Attempt cleanup regardless of success/failure in try block
420
- print(f"[{session_hash}] API: Step 1 - Entering finally block for potential cleanup.")
421
- try:
422
- print(f"[{session_hash}] API: Step 1 - Attempting CUDA cache clear (finally)...")
423
- torch.cuda.empty_cache()
424
- print(f"[{session_hash}] API: Step 1 - CUDA cache cleared (finally).")
425
- except Exception as cache_e_gen:
426
- print(f"[{session_hash}] API: WARNING - Error clearing CUDA cache in Step 1 finally block: {cache_e_gen}")
427
- print(f"[{session_hash}] API: --- Finished Step 1: Generation Pipeline (gs valid: {gs_output is not None}, mesh valid: {mesh_output is not None}) --- \n")
428
-
429
- # --- Step 2: Extract GLB ---
430
- # Proceed only if Step 1 was successful
431
- if gs_output is not None and mesh_output is not None:
432
- print(f"\n[{session_hash}] API: --- Starting Step 2: GLB Extraction --- ")
433
- try:
434
- print(f"[{session_hash}] API: Step 2 - Inputs: gs type {type(gs_output)}, mesh type {type(mesh_output)}")
435
- print(f"[{session_hash}] API: Step 2 - Params: Simplify={mesh_simplify}, Texture Size={texture_size}")
436
- print(f"[{session_hash}] API: Step 2 - Calling postprocessing_utils.to_glb()...")
437
- t_start_glb = time.time()
438
- # --- The actual GLB conversion call ---
439
- glb = postprocessing_utils.to_glb(gs_output, mesh_output, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
440
- # --- End GLB conversion call ---
441
- t_end_glb = time.time()
442
- print(f"[{session_hash}] API: Step 2 - postprocessing_utils.to_glb() completed in {t_end_glb - t_start_glb:.2f}s.")
443
-
444
- # === Validate GLB output ===
445
- print(f"[{session_hash}] API: Step 2 - Validating GLB object...")
446
- if glb is None:
447
- print(f"[{session_hash}] API: ERROR - postprocessing_utils.to_glb returned None.")
448
- raise ValueError("GLB conversion returned None.")
449
- print(f"[{session_hash}] API: Step 2 - GLB object validated successfully (type: {type(glb)})...")
450
- # ==========================
451
-
452
- # === Save GLB ===
453
- glb_path = os.path.join(user_dir, f'api_generated_{session_hash}_{int(time.time()*1000)}.glb') # More unique name
454
- print(f"[{session_hash}] API: Step 2 - Saving GLB to path: {glb_path}...")
455
- t_start_save = time.time()
456
- # --- The actual GLB export call ---
457
- glb.export(glb_path)
458
- # --- End GLB export call ---
459
- t_end_save = time.time()
460
- print(f"[{session_hash}] API: Step 2 - glb.export() completed in {t_end_save - t_start_save:.2f}s.")
461
- # =================
462
-
463
- # === Verify File Exists ===
464
- print(f"[{session_hash}] API: Step 2 - Verifying saved file exists at {glb_path}...")
465
- if not os.path.exists(glb_path):
466
- print(f"[{session_hash}] API: ERROR - GLB file was not found after export at {glb_path}.")
467
- raise IOError(f"GLB export failed, file not found: {glb_path}")
468
- print(f"[{session_hash}] API: Step 2 - Saved file verified.")
469
- # =========================
470
-
471
- print(f"[{session_hash}] API: Step 2 - GLB extraction and saving completed successfully.")
472
- if PSUTIL_AVAILABLE:
473
- mem_info_after_glb = process.memory_info()
474
- print(f"[{session_hash}] API: Memory After GLB: RSS={mem_info_after_glb.rss / (1024**2):.2f} MB, VMS={mem_info_after_glb.vms / (1024**2):.2f} MB")
475
-
476
- except Exception as e_glb:
477
- print(f"\n[{session_hash}] API: ******** ERROR IN STEP 2: GLB Extraction ********")
478
- print(f"[{session_hash}] API: Error Type: {type(e_glb).__name__}, Message: {e_glb}")
479
- print(f"[{session_hash}] API: Printing traceback...")
480
- traceback.print_exc()
481
- print(f"[{session_hash}] API: *****************************************************")
482
- glb_path = None # Ensure reset on error
483
- # Fall through to finally block for cleanup
484
- finally:
485
- # Attempt cleanup regardless of success/failure in try block
486
- print(f"[{session_hash}] API: Step 2 - Entering finally block for potential cleanup.")
487
- # Explicitly delete large objects if possible (might help memory)
488
- del glb
489
- print(f"[{session_hash}] API: Step 2 - Deleted intermediate 'glb' object.")
490
- try:
491
- print(f"[{session_hash}] API: Step 2 - Attempting CUDA cache clear (finally)...")
492
- torch.cuda.empty_cache()
493
- print(f"[{session_hash}] API: Step 2 - CUDA cache cleared (finally).")
494
- except Exception as cache_e_glb:
495
- print(f"[{session_hash}] API: WARNING - Error clearing CUDA cache in Step 2 finally block: {cache_e_glb}")
496
- print(f"[{session_hash}] API: --- Finished Step 2: GLB Extraction (path valid: {glb_path is not None}) --- \n")
497
- else:
498
- print(f"[{session_hash}] API: Skipping Step 2 (GLB Extraction) because Step 1 failed or produced invalid outputs.")
499
- glb_path = None # Ensure glb_path is None if Step 1 failed
500
-
501
- # --- Final Cleanup and Return ---
502
- print(f"[{session_hash}] API: --- Entering Final Cleanup and Return --- ")
503
- # Final attempt to clear CUDA cache
504
- try:
505
- print(f"[{session_hash}] API: Final CUDA cache clear attempt...")
506
- torch.cuda.empty_cache()
507
- print(f"[{session_hash}] API: Final CUDA cache cleared.")
508
- except Exception as cache_e_final:
509
- print(f"[{session_hash}] API: WARNING - Error clearing final CUDA cache: {cache_e_final}")
510
-
511
- # Explicitly delete pipeline outputs if they exist
512
- del gs_output
513
- del mesh_output
514
- print(f"[{session_hash}] API: Deleted intermediate 'gs_output' and 'mesh_output' objects.")
515
-
516
- # Final decision based on glb_path status
517
- if glb_path and os.path.exists(glb_path):
518
- print(f"[{session_hash}] API: Final Result: SUCCESS. GLB Path: {glb_path}")
519
- print(f"[{session_hash}] ========= generate_and_extract_glb END (Success) =========")
520
  return glb_path
521
- else:
522
- print(f"[{session_hash}] API: Final Result: FAILURE. GLB Path: {glb_path} (Exists: {os.path.exists(glb_path) if glb_path else 'N/A'})")
523
- print(f"[{session_hash}] ========= generate_and_extract_glb END (Failure) =========")
524
- return None
525
- # --- END NEW COMBINED API FUNCTION ---
526
527
 
528
- # State object to hold the generated model info between steps
529
- output_buf = gr.State()
530
- # Video component placeholder (will be populated by render_preview_video)
531
- # video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) # Defined later inside the Blocks
532
 
533
  with gr.Blocks(delete_cache=(600, 600)) as demo:
534
  gr.Markdown("""
535
  ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
536
  * Type a text prompt and click "Generate" to create a 3D asset.
537
- * The preview video will appear after generation.
538
- * If you find the generated 3D asset satisfactory, click "Extract GLB" or "Extract Gaussian" to extract the file and download it.
539
  """)
540
 
541
  with gr.Row():
@@ -561,7 +255,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
561
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
562
 
563
  with gr.Row():
564
- # Buttons start non-interactive, enabled after generation
565
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
566
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
567
  gr.Markdown("""
@@ -569,102 +262,63 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
569
  """)
570
 
571
  with gr.Column():
572
- # Define UI components here
573
- video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
574
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
575
 
576
  with gr.Row():
577
- # Buttons start non-interactive, enabled after extraction
578
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
579
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
580
 
581
- # Define the state buffer here, outside the component definitions but inside the Blocks scope
582
  output_buf = gr.State()
583
 
584
- # --- Handlers ---
585
  demo.load(start_session)
586
  demo.unload(end_session)
587
 
588
- # --- MODIFIED UI CHAIN ---
589
- # 1. Get Seed
590
- # 2. Run text_to_3d -> outputs state to output_buf
591
- # 3. Run render_preview_video (using state from output_buf) -> outputs video to video_output
592
- # 4. Enable extraction buttons
593
  generate_btn.click(
594
  get_seed,
595
  inputs=[randomize_seed, seed],
596
  outputs=[seed],
597
- queue=True # Use queue for potentially long-running steps
598
  ).then(
599
  text_to_3d,
600
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
601
- outputs=[output_buf], # text_to_3d now ONLY outputs state
602
- api_name="text_to_3d" # Keep API name consistent if needed
603
- ).then(
604
- render_preview_video, # NEW step: Render video from state
605
- inputs=[output_buf],
606
- outputs=[video_output],
607
- api_name="render_preview_video" # Assign API name if you want to call this separately
608
  ).then(
609
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]), # Enable extraction buttons
610
  outputs=[extract_glb_btn, extract_gs_btn],
611
  )
612
 
613
- # Clear video and disable extraction buttons if prompt is cleared or generation restarted
614
- # (Consider adding logic to clear prompt on successful generation if desired)
615
- text_prompt.change( # Example: Clear video if prompt changes
616
- lambda: (None, gr.Button(interactive=False), gr.Button(interactive=False)),
617
- outputs=[video_output, extract_glb_btn, extract_gs_btn]
618
- )
619
- video_output.clear( # This might be redundant if text_prompt.change handles it
620
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
621
  outputs=[extract_glb_btn, extract_gs_btn],
622
  )
623
 
624
- # --- Extraction Handlers ---
625
- # GLB Extraction: Takes state from output_buf, outputs model and download path
626
  extract_glb_btn.click(
627
  extract_glb,
628
  inputs=[output_buf, mesh_simplify, texture_size],
629
- outputs=[model_output, download_glb], # Outputs to Model3D and DownloadButton path
630
- api_name="extract_glb"
631
  ).then(
632
- lambda: gr.Button(interactive=True), # Enable download button
633
  outputs=[download_glb],
634
  )
635
 
636
- # Gaussian Extraction: Takes state from output_buf, outputs model and download path
637
  extract_gs_btn.click(
638
  extract_gaussian,
639
  inputs=[output_buf],
640
- outputs=[model_output, download_gs], # Outputs to Model3D and DownloadButton path
641
- api_name="extract_gaussian"
642
  ).then(
643
- lambda: gr.Button(interactive=True), # Enable download button
644
  outputs=[download_gs],
645
  )
646
 
647
- # Clear model and disable download buttons if video/state is cleared
648
  model_output.clear(
649
- lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
650
- outputs=[download_glb, download_gs], # Disable both download buttons
651
  )
652
 
653
- # --- Launch the Gradio app ---
 
654
  if __name__ == "__main__":
655
- print("Initializing pipeline...")
656
- pipeline = None # Initialize pipeline variable
657
- try:
658
- # Load the pipeline
659
- pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
660
- # Move pipeline to CUDA device
661
- pipeline.cuda()
662
- print("Pipeline loaded and moved to CUDA successfully.")
663
- except Exception as e:
664
- print(f"FATAL ERROR initializing pipeline: {e}")
665
- traceback.print_exc()
666
- # Optionally exit if pipeline loading fails
667
- sys.exit(1)
668
-
669
- print("Launching Gradio demo with queue enabled...")
670
- demo.queue().launch()

app.py (new side of the split diff: context plus added lines)
3
 
4
  import os
5
  import shutil
6
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
  import torch
10
  import numpy as np
11
  import imageio
 
13
  from trellis.pipelines import TrellisTextTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
16
+
17
  import traceback
18
  import sys
19
+ import logging
20
 
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
24
  os.makedirs(TMP_DIR, exist_ok=True)
25
 
26
+ logging.basicConfig(level=logging.INFO)
27
+
28
 
29
  def start_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
33
 
34
  def end_session(req: gr.Request):
35
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
36
+ shutil.rmtree(user_dir)
 
37
 
38
 
39
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
62
  opacity_bias=state['gaussian']['opacity_bias'],
63
  scaling_activation=state['gaussian']['scaling_activation'],
64
  )
65
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
66
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
67
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
68
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
69
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
70
 
71
  mesh = edict(
72
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
73
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
74
  )
75
 
76
  return gs, mesh
 
92
  slat_guidance_strength: float,
93
  slat_sampling_steps: int,
94
  req: gr.Request,
95
+ ) -> Tuple[dict, str]:
96
  """
97
+ Convert a text prompt to a 3D model.
98
  Args:
99
  prompt (str): The text prompt.
100
  seed (int): The random seed.
 
103
  slat_guidance_strength (float): The guidance strength for structured latent generation.
104
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
105
  Returns:
106
+ dict: The information of the generated 3D model.
107
+ str: The path to the video of the 3D model.
108
  """
 
109
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
110
+ os.makedirs(user_dir, exist_ok=True)
111
  outputs = pipeline.run(
112
  prompt,
113
  seed=seed,
 
121
  "cfg_strength": slat_guidance_strength,
122
  },
123
  )
124
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
125
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
126
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
127
+ video_path = os.path.join(user_dir, 'sample.mp4')
128
  imageio.mimsave(video_path, video, fps=15)
129
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
130
  torch.cuda.empty_cache()
131
+ return state, video_path
 
132
 
133
 
134
  @spaces.GPU(duration=90)
 
139
  req: gr.Request,
140
  ) -> Tuple[str, str]:
141
  """
142
+ Extract a GLB file from the 3D model.
143
  Args:
144
  state (dict): The state of the generated 3D model.
145
  mesh_simplify (float): The mesh simplification factor.
146
  texture_size (int): The texture resolution.
147
  Returns:
148
+ str: The path to the extracted GLB file.
 
149
  """
150
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
151
  os.makedirs(user_dir, exist_ok=True)
 
 
152
  gs, mesh = unpack_state(state)
 
 
153
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
154
  glb_path = os.path.join(user_dir, 'sample.glb')
 
155
  glb.export(glb_path)
 
156
  torch.cuda.empty_cache()
 
157
  return glb_path, glb_path
158
 
159
 
160
  @spaces.GPU
161
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
162
  """
163
+ Extract a Gaussian file from the 3D model.
164
  Args:
165
  state (dict): The state of the generated 3D model.
166
  Returns:
167
+ str: The path to the extracted Gaussian file.
 
168
  """
169
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
170
  os.makedirs(user_dir, exist_ok=True)
 
 
171
  gs, _ = unpack_state(state)
 
172
  gaussian_path = os.path.join(user_dir, 'sample.ply')
 
173
  gs.save_ply(gaussian_path)
 
174
  torch.cuda.empty_cache()
 
175
  return gaussian_path, gaussian_path
176
 
177
 
178
+ @spaces.GPU
 
179
  def generate_and_extract_glb(
180
  prompt: str,
181
  seed: int,
 
185
  slat_sampling_steps: int,
186
  mesh_simplify: float,
187
  texture_size: int,
188
+ req: gr.Request,
189
+ ) -> str:
190
  """
191
+ Runs the full text_to_3d and extract_glb pipeline internally.
192
  """
193
+ request_hash = str(req.session_hash)[:8]
194
+ logging.info(f"[{request_hash}] ENTER generate_and_extract_glb")
195
+ logging.info(f"[{request_hash}] Received parameters: prompt='{prompt}', seed={seed}, simplify={mesh_simplify}, tex_size={texture_size}, ...")
196
+
197
  try:
198
+ logging.info(f"[{request_hash}] Calling internal text_to_3d...")
199
+ state, _ = text_to_3d(
200
+ prompt, seed, ss_guidance_strength, ss_sampling_steps,
201
+ slat_guidance_strength, slat_sampling_steps, req
202
  )
203
+ if state is None:
204
+ logging.error(f"[{request_hash}] Internal text_to_3d returned None state!")
205
+ raise ValueError("Internal text_to_3d failed to return state")
206
+ logging.info(f"[{request_hash}] Internal text_to_3d completed. State type: {type(state)}")
207
+
208
+ logging.info(f"[{request_hash}] Calling internal extract_glb...")
209
+ glb_path, _ = extract_glb(
210
+ state, mesh_simplify, texture_size, req
211
+ )
212
+ if glb_path is None:
213
+ logging.error(f"[{request_hash}] Internal extract_glb returned None path!")
214
+ raise ValueError("Internal extract_glb failed to return GLB path")
215
+ logging.info(f"[{request_hash}] Internal extract_glb completed. GLB path: {glb_path}")
216
+
217
+ logging.info(f"[{request_hash}] EXIT generate_and_extract_glb - Returning path: {glb_path}")
218
  return glb_path
219
 
220
+ except Exception as e:
221
+ logging.error(f"[{request_hash}] ERROR in generate_and_extract_glb: {e}", exc_info=True)
222
+ raise gr.Error(f"Pipeline failed: {e}")
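
The combined function above only reads req.session_hash, so it can be exercised outside Gradio with a stand-in request object. The following is a hypothetical local smoke test, not part of the app: the parameter values are illustrative, and it assumes the module-level pipeline has already been loaded on a CUDA machine as in the __main__ block.

# Hypothetical smoke test; values are illustrative, not the app's defaults.
from types import SimpleNamespace

fake_req = SimpleNamespace(session_hash="local-smoke-test")
glb_path = generate_and_extract_glb(
    prompt="a small wooden chair",
    seed=42,
    ss_guidance_strength=7.5,
    ss_sampling_steps=25,
    slat_guidance_strength=7.5,
    slat_sampling_steps=25,
    mesh_simplify=0.95,
    texture_size=1024,
    req=fake_req,
)
print("GLB written to:", glb_path)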
223
 
224
+
225
+ output_buf = gr.State()
226
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
227
 
228
  with gr.Blocks(delete_cache=(600, 600)) as demo:
229
  gr.Markdown("""
230
  ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
231
  * Type a text prompt and click "Generate" to create a 3D asset.
232
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
 
233
  """)
234
 
235
  with gr.Row():
 
255
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
256
 
257
  with gr.Row():
 
258
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
259
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
260
  gr.Markdown("""
 
262
  """)
263
 
264
  with gr.Column():
265
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
266
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
267
 
268
  with gr.Row():
 
269
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
270
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
271
 
 
272
  output_buf = gr.State()
273
 
274
+ # Handlers
275
  demo.load(start_session)
276
  demo.unload(end_session)
277
 
278
  generate_btn.click(
279
  get_seed,
280
  inputs=[randomize_seed, seed],
281
  outputs=[seed],
 
282
  ).then(
283
  text_to_3d,
284
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
285
+ outputs=[output_buf, video_output],
286
  ).then(
287
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
288
  outputs=[extract_glb_btn, extract_gs_btn],
289
  )
290
 
291
+ video_output.clear(
292
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
293
  outputs=[extract_glb_btn, extract_gs_btn],
294
  )
295
 
 
 
296
  extract_glb_btn.click(
297
  extract_glb,
298
  inputs=[output_buf, mesh_simplify, texture_size],
299
+ outputs=[model_output, download_glb],
 
300
  ).then(
301
+ lambda: gr.Button(interactive=True),
302
  outputs=[download_glb],
303
  )
304
 
 
305
  extract_gs_btn.click(
306
  extract_gaussian,
307
  inputs=[output_buf],
308
+ outputs=[model_output, download_gs],
 
309
  ).then(
310
+ lambda: gr.Button(interactive=True),
311
  outputs=[download_gs],
312
  )
313
 
 
314
  model_output.clear(
315
+ lambda: gr.Button(interactive=False),
316
+ outputs=[download_glb],
317
  )
318
 
319
+
320
+ # Launch the Gradio app
321
  if __name__ == "__main__":
322
+ pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
323
+ pipeline.cuda()
324
+ demo.launch()
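
If generate_and_extract_glb is later exposed as an API endpoint (in this commit it is wired to no event and has no api_name), a remote call through gradio_client might look like the sketch below. The Space id and the endpoint name are placeholders/assumptions, and the positional arguments mirror the function signature above.

from gradio_client import Client

client = Client("<user>/<space>")  # placeholder Space id
glb_file = client.predict(
    "a small wooden chair",   # prompt
    42,                       # seed
    7.5, 25,                  # ss_guidance_strength, ss_sampling_steps
    7.5, 25,                  # slat_guidance_strength, slat_sampling_steps
    0.95, 1024,               # mesh_simplify, texture_size
    api_name="/generate_and_extract_glb",  # assumed endpoint name
)
print(glb_file)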