dkatz2391 commited on
Commit
b078538
·
verified ·
1 Parent(s): 8fec6b4

revert back to 1.1

Browse files
Files changed (1) hide show
  1. app.py +227 -247
app.py CHANGED
@@ -1,20 +1,22 @@
1
- # Version: 1.1.0 - API State Fix + DEBUG (Video Disabled - FINAL CORRECTED BASELINE) (2025-05-04)
2
  # Changes:
3
- # - Based *EXACTLY* on user-provided Version 1.1.0 code that fixed state dict & loaded pipeline.
4
- # - TEMPORARY DEBUGGING STEP: Commented out video rendering/saving in `text_to_3d`
5
- # and return None for video_path to isolate the "Session not found" error.
6
- # - No other changes were made. Removed previous erroneous additions.
 
 
 
7
 
8
  import gradio as gr
9
- # NOTE: Ensuring 'spaces' is imported if decorators are used (was missing in user provided snippet but needed)
10
- # If @spaces.GPU decorators are not used, this import is not needed.
11
- # Assuming they ARE used based on previous context:
12
  import spaces
13
 
14
  import os
15
  import shutil
16
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
17
- os.environ['SPCONV_ALGO'] = 'native' # Direct set as per original user code
 
 
18
 
19
  from typing import *
20
  import torch
@@ -28,128 +30,102 @@ from trellis.utils import render_utils, postprocessing_utils
28
  import traceback
29
  import sys
30
 
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
- # Using path relative to file as in original user provided code
33
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
34
- # Ensure base directory exists
35
- try:
36
- os.makedirs(TMP_DIR, exist_ok=True)
37
- print(f"Using temporary directory: {TMP_DIR}")
38
- except OSError as e:
39
- print(f"Warning: Could not create base temp directory {TMP_DIR}: {e}", file=sys.stderr)
40
- TMP_DIR = '.' # Fallback
41
- print(f"Warning: Falling back to use current directory for temp files: {os.path.abspath(TMP_DIR)}")
42
 
43
  def start_session(req: gr.Request):
44
  """Creates a temporary directory for the user session."""
45
- user_dir = None
46
- try:
47
- session_hash = req.session_hash
48
- if not session_hash:
49
- session_hash = f"no_session_{np.random.randint(10000, 99999)}"
50
- print(f"Warning: No session_hash in request, using temporary ID: {session_hash}")
51
- user_dir = os.path.join(TMP_DIR, str(session_hash))
52
- os.makedirs(user_dir, exist_ok=True)
53
- print(f"Started session, ensured directory exists: {user_dir}")
54
- except Exception as e:
55
- print(f"Error in start_session creating directory '{user_dir}': {e}", file=sys.stderr)
56
 
57
  def end_session(req: gr.Request):
58
  """Removes the temporary directory for the user session."""
59
- user_dir = None
60
- try:
61
- session_hash = req.session_hash
62
- if not session_hash:
63
- print("Warning: No session_hash in end_session request, cannot clean up.")
64
- return
65
- user_dir = os.path.join(TMP_DIR, str(session_hash))
66
- if os.path.exists(user_dir) and os.path.isdir(user_dir):
67
- try:
68
- shutil.rmtree(user_dir)
69
- print(f"Ended session, removed directory: {user_dir}")
70
- except OSError as e:
71
- print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
72
- else:
73
- print(f"Ended session, directory not found or not a directory: {user_dir}")
74
- except Exception as e:
75
- print(f"Error in end_session cleaning directory '{user_dir}': {e}", file=sys.stderr)
76
 
77
 
78
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
79
  """Packs Gaussian and Mesh data into a serializable dictionary."""
 
80
  print("[pack_state] Packing state to dictionary...")
81
- try:
82
- packed_data = {
83
- 'gaussian': {
84
- **{k: v for k, v in gs.init_params.items()},
85
- '_xyz': gs._xyz.detach().cpu().numpy(),
86
- '_features_dc': gs._features_dc.detach().cpu().numpy(),
87
- '_scaling': gs._scaling.detach().cpu().numpy(),
88
- '_rotation': gs._rotation.detach().cpu().numpy(),
89
- '_opacity': gs._opacity.detach().cpu().numpy(),
90
- },
91
- 'mesh': {
92
- 'vertices': mesh.vertices.detach().cpu().numpy(),
93
- 'faces': mesh.faces.detach().cpu().numpy(),
94
- },
95
- }
96
- print(f"[pack_state] Dictionary created. Keys: {list(packed_data.keys())}, Gaussian points: {len(packed_data['gaussian']['_xyz'])}, Mesh vertices: {len(packed_data['mesh']['vertices'])}")
97
- return packed_data
98
- except Exception as e:
99
- print(f"Error during pack_state: {e}", file=sys.stderr)
100
- traceback.print_exc()
101
- raise
102
 
103
 
104
  def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
105
  """Unpacks Gaussian and Mesh data from a dictionary."""
106
  print("[unpack_state] Unpacking state from dictionary...")
107
- try:
108
- if not isinstance(state_dict, dict) or 'gaussian' not in state_dict or 'mesh' not in state_dict:
109
- raise ValueError("Invalid state_dict structure passed to unpack_state.")
110
-
111
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
112
- print(f"[unpack_state] Using device: {device}")
113
-
114
- gauss_data = state_dict['gaussian']
115
- mesh_data = state_dict['mesh']
116
-
117
- gs = Gaussian(
118
- aabb=gauss_data.get('aabb'),
119
- sh_degree=gauss_data.get('sh_degree'),
120
- mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
121
- scaling_bias=gauss_data.get('scaling_bias'),
122
- opacity_bias=gauss_data.get('opacity_bias'),
123
- scaling_activation=gauss_data.get('scaling_activation'),
124
- )
125
- gs._xyz = torch.tensor(gauss_data['_xyz'], device=device, dtype=torch.float32)
126
- gs._features_dc = torch.tensor(gauss_data['_features_dc'], device=device, dtype=torch.float32)
127
- gs._scaling = torch.tensor(gauss_data['_scaling'], device=device, dtype=torch.float32)
128
- gs._rotation = torch.tensor(gauss_data['_rotation'], device=device, dtype=torch.float32)
129
- gs._opacity = torch.tensor(gauss_data['_opacity'], device=device, dtype=torch.float32)
130
- print(f"[unpack_state] Gaussian unpacked. Points: {gs.get_xyz.shape[0]}")
131
-
132
- mesh = edict(
133
- vertices=torch.tensor(mesh_data['vertices'], device=device, dtype=torch.float32),
134
- faces=torch.tensor(mesh_data['faces'], device=device, dtype=torch.int64),
135
- )
136
- print(f"[unpack_state] Mesh unpacked. Vertices: {mesh.vertices.shape[0]}, Faces: {mesh.faces.shape[0]}")
 
 
 
137
 
138
- return gs, mesh
139
- except Exception as e:
140
- print(f"Error during unpack_state: {e}", file=sys.stderr)
141
- traceback.print_exc()
142
- raise
143
 
144
 
145
  def get_seed(randomize_seed: bool, seed: int) -> int:
146
  """Gets a seed value, randomizing if requested."""
147
  new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
148
  print(f"[get_seed] Randomize: {randomize_seed}, Input Seed: {seed}, Output Seed: {new_seed}")
149
- return int(new_seed)
150
 
151
 
152
- # Decorator requires 'import spaces' at the top
153
  @spaces.GPU
154
  def text_to_3d(
155
  prompt: str,
@@ -159,84 +135,79 @@ def text_to_3d(
159
  slat_guidance_strength: float,
160
  slat_sampling_steps: int,
161
  req: gr.Request,
162
- ) -> Tuple[dict, Optional[str]]: # Return Optional[str] for video path
163
  """
164
  Generates a 3D model (Gaussian and Mesh) from text and returns a
165
- serializable state dictionary and potentially a video preview path.
166
- >>> TEMPORARILY DISABLED VIDEO RENDERING FOR DEBUGGING <<<
167
  """
168
- print(f"[text_to_3d - DEBUG MODE] Received prompt: '{prompt}', Seed: {seed}")
169
- user_dir = None
170
- state_dict = None
 
 
 
171
  try:
172
- session_hash = req.session_hash
173
- if not session_hash:
174
- session_hash = f"no_session_{np.random.randint(10000, 99999)}"
175
- print(f"Warning: No session_hash in text_to_3d request, using temporary ID: {session_hash}")
176
- user_dir = os.path.join(TMP_DIR, str(session_hash))
177
- os.makedirs(user_dir, exist_ok=True)
178
- print(f"[text_to_3d - DEBUG MODE] User directory: {user_dir}")
179
-
180
- # --- Generation Pipeline ---
181
- print("[text_to_3d - DEBUG MODE] Running Trellis pipeline...")
182
  outputs = pipeline.run(
183
- prompt=prompt,
184
  seed=seed,
185
- formats=["gaussian", "mesh"],
186
  sparse_structure_sampler_params={
187
- "steps": int(ss_sampling_steps),
188
  "cfg_strength": float(ss_guidance_strength),
189
  },
190
  slat_sampler_params={
191
- "steps": int(slat_sampling_steps),
192
  "cfg_strength": float(slat_guidance_strength),
193
  },
194
  )
195
- print("[text_to_3d - DEBUG MODE] Pipeline run completed.")
 
 
 
 
 
 
196
 
197
- # --- Create Serializable State Dictionary ---
 
 
198
  state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
 
 
 
 
199
 
 
 
 
 
 
 
 
 
 
 
200
  except Exception as e:
201
- print(f"❌ [text_to_3d - DEBUG MODE] Error during generation or packing: {e}", file=sys.stderr)
202
  traceback.print_exc()
203
- raise gr.Error(f"Core generation failed: {e}")
204
-
205
- # --- Render Video Preview (TEMPORARILY DISABLED FOR DEBUGGING) ---
206
- video_path = None # Explicitly set path to None for this debug version
207
- print("[text_to_3d - DEBUG MODE] Skipping video rendering.")
208
- # --- Start Original Video Code Block (Commented Out) ---
209
- # try:
210
- # print("[text_to_3d] Rendering video preview...")
211
- # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
212
- # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
213
- # video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
214
- # video_path_tmp = os.path.join(user_dir, 'sample.mp4')
215
- # imageio.mimsave(video_path_tmp, video, fps=15, quality=8)
216
- # print(f"[text_to_3d] Video saved to: {video_path_tmp}")
217
- # video_path = video_path_tmp
218
- # except Exception as e:
219
- # print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
220
- # traceback.print_exc()
221
- # video_path = None # Indicate video failure
222
- # --- End Original Video Code Block ---
223
 
224
  # --- Cleanup and Return ---
 
225
  if torch.cuda.is_available():
226
  torch.cuda.empty_cache()
227
- print("[text_to_3d - DEBUG MODE] Cleared CUDA cache.")
228
 
229
- print("[text_to_3d - DEBUG MODE] Returning state dictionary and None video path.")
230
- if state_dict is None:
231
- print("Error: state_dict is None before return, generation likely failed.", file=sys.stderr)
232
- raise gr.Error("State dictionary creation failed.")
233
  return state_dict, video_path
234
 
235
 
236
- # Decorator requires 'import spaces' at the top
237
- @spaces.GPU(duration=120)
238
  def extract_glb(
239
- state_dict: dict,
240
  mesh_simplify: float,
241
  texture_size: int,
242
  req: gr.Request,
@@ -245,37 +216,32 @@ def extract_glb(
245
  Extracts a GLB file from the provided 3D model state dictionary.
246
  """
247
  print(f"[extract_glb] Received request. Simplify: {mesh_simplify}, Texture Size: {texture_size}")
248
- user_dir = None
249
- glb_path = None
250
- try:
251
- session_hash = req.session_hash
252
- if not session_hash:
253
- session_hash = f"no_session_{np.random.randint(10000, 99999)}"
254
- print(f"Warning: No session_hash in extract_glb request, using temporary ID: {session_hash}")
255
 
256
- if not isinstance(state_dict, dict):
257
- print("❌ [extract_glb] Error: Invalid state_dict received (not a dictionary).")
258
- raise gr.Error("Invalid state data received. Please generate the model first.")
259
 
260
- user_dir = os.path.join(TMP_DIR, str(session_hash))
261
- os.makedirs(user_dir, exist_ok=True)
262
- print(f"[extract_glb] User directory: {user_dir}")
263
-
264
- # --- Unpack state from the dictionary ---
265
  gs, mesh = unpack_state(state_dict)
 
 
 
 
266
 
267
- # --- Postprocessing and Export ---
 
268
  print("[extract_glb] Converting to GLB...")
269
- simplify_factor = float(mesh_simplify)
270
- tex_size = int(texture_size)
271
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=simplify_factor, texture_size=tex_size, verbose=True)
272
  glb_path = os.path.join(user_dir, 'sample.glb')
273
  print(f"[extract_glb] Exporting GLB to: {glb_path}")
274
  glb.export(glb_path)
275
  print("[extract_glb] GLB exported successfully.")
276
-
277
  except Exception as e:
278
- print(f"❌ [extract_glb] Error during GLB extraction: {e}", file=sys.stderr)
279
  traceback.print_exc()
280
  raise gr.Error(f"Failed to extract GLB: {e}")
281
 
@@ -284,50 +250,44 @@ def extract_glb(
284
  torch.cuda.empty_cache()
285
  print("[extract_glb] Cleared CUDA cache.")
286
 
 
287
  print("[extract_glb] Returning GLB path.")
288
- if glb_path is None:
289
- print("Error: glb_path is None before return, extraction likely failed.", file=sys.stderr)
290
- raise gr.Error("GLB path generation failed.")
291
  return glb_path, glb_path
292
 
293
 
294
- # Decorator requires 'import spaces' at the top
295
  @spaces.GPU
296
  def extract_gaussian(
297
- state_dict: dict,
298
  req: gr.Request
299
  ) -> Tuple[str, str]:
300
  """
301
  Extracts a PLY (Gaussian) file from the provided 3D model state dictionary.
302
  """
303
  print("[extract_gaussian] Received request.")
304
- user_dir = None
305
- gaussian_path = None
306
- try:
307
- session_hash = req.session_hash
308
- if not session_hash:
309
- session_hash = f"no_session_{np.random.randint(10000, 99999)}"
310
- print(f"Warning: No session_hash in extract_gaussian request, using temporary ID: {session_hash}")
311
-
312
- if not isinstance(state_dict, dict):
313
- print("❌ [extract_gaussian] Error: Invalid state_dict received (not a dictionary).")
314
- raise gr.Error("Invalid state data received. Please generate the model first.")
315
 
316
- user_dir = os.path.join(TMP_DIR, str(session_hash))
317
- os.makedirs(user_dir, exist_ok=True)
318
- print(f"[extract_gaussian] User directory: {user_dir}")
319
 
320
- # --- Unpack state from the dictionary ---
321
- gs, _ = unpack_state(state_dict)
 
 
 
 
 
322
 
323
- # --- Export PLY ---
 
324
  gaussian_path = os.path.join(user_dir, 'sample.ply')
325
  print(f"[extract_gaussian] Saving PLY to: {gaussian_path}")
326
  gs.save_ply(gaussian_path)
327
  print("[extract_gaussian] PLY saved successfully.")
328
-
329
  except Exception as e:
330
- print(f"❌ [extract_gaussian] Error during Gaussian extraction: {e}", file=sys.stderr)
331
  traceback.print_exc()
332
  raise gr.Error(f"Failed to extract Gaussian PLY: {e}")
333
 
@@ -336,10 +296,8 @@ def extract_gaussian(
336
  torch.cuda.empty_cache()
337
  print("[extract_gaussian] Cleared CUDA cache.")
338
 
 
339
  print("[extract_gaussian] Returning PLY path.")
340
- if gaussian_path is None:
341
- print("Error: gaussian_path is None before return, extraction likely failed.", file=sys.stderr)
342
- raise gr.Error("Gaussian PLY path generation failed.")
343
  return gaussian_path, gaussian_path
344
 
345
 
@@ -351,14 +309,17 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
351
  * Type a text prompt and click "Generate" to create a 3D asset preview.
352
  * Adjust extraction settings if desired.
353
  * Click "Extract GLB" or "Extract Gaussian" to get the downloadable 3D file.
354
- *(Note: Video preview is temporarily disabled for debugging)*
355
  """)
356
 
 
 
 
357
  output_buf = gr.State()
358
 
359
  with gr.Row():
360
- with gr.Column(scale=1):
361
  text_prompt = gr.Textbox(label="Text Prompt", lines=5, placeholder="e.g., a cute red dragon")
 
362
  with gr.Accordion(label="Generation Settings", open=False):
363
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
364
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -370,83 +331,99 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
370
  with gr.Row():
371
  slat_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
372
  slat_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
 
373
  generate_btn = gr.Button("Generate 3D Preview", variant="primary")
374
- with gr.Accordion(label="GLB Extraction Settings", open=True):
 
 
375
  mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01, info="Higher value = less simplification (more polys)")
376
  texture_size = gr.Slider(512, 2048, label="Texture Size (pixels)", value=1024, step=512, info="Size of the generated texture map")
 
377
  with gr.Row():
378
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
379
  extract_gs_btn = gr.Button("Extract Gaussian (PLY)", interactive=False)
380
  gr.Markdown("""
381
  *NOTE: Gaussian file (.ply) can be very large (~50MB+) and may take time to process/download.*
382
  """)
383
- with gr.Column(scale=1):
384
- # Video component remains for layout but won't show anything in this debug version
385
- video_output = gr.Video(label="Generated 3D Preview (DISABLED FOR DEBUG)", autoplay=False, loop=False, value=None, height=350)
386
- model_output = gr.Model3D(label="Extracted Model Preview", height=350, clear_color=[0.95, 0.95, 0.95, 1.0])
 
387
  with gr.Row():
 
388
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
389
  download_gs = gr.DownloadButton(label="Download Gaussian (PLY)", interactive=False)
390
 
391
  # --- Event Handlers ---
392
  print("Defining Gradio event handlers...")
393
- # Use demo.load as in original user-provided code
394
- demo.load(start_session, inputs=None, outputs=None)
395
- # Use demo.unload as in original user-provided code (no extra args)
396
- demo.unload(end_session) # Corrected: removed inputs/outputs
397
 
 
 
 
 
 
 
398
  generate_event = generate_btn.click(
399
  get_seed,
400
  inputs=[randomize_seed, seed],
401
  outputs=[seed],
402
- api_name="get_seed"
403
  ).then(
404
  text_to_3d,
405
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
406
- outputs=[output_buf, video_output], # state_dict -> output_buf, None -> video_output
407
  api_name="text_to_3d"
408
  ).then(
409
- lambda: (
410
- gr.Button(interactive=True), gr.Button(interactive=True),
411
- gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)
 
 
412
  ),
413
- inputs=None,
414
- outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
415
  )
416
 
 
 
 
 
 
417
  extract_glb_event = extract_glb_btn.click(
418
  extract_glb,
419
- inputs=[output_buf, mesh_simplify, texture_size],
420
- outputs=[model_output, download_glb],
421
  api_name="extract_glb"
422
  ).then(
423
- lambda: gr.DownloadButton(interactive=True),
424
- inputs=None,
425
  outputs=[download_glb],
426
  )
427
 
 
 
428
  extract_gs_event = extract_gs_btn.click(
429
  extract_gaussian,
430
- inputs=[output_buf],
431
- outputs=[model_output, download_gs],
432
  api_name="extract_gaussian"
433
  ).then(
434
- lambda: gr.DownloadButton(interactive=True),
435
- inputs=None,
436
  outputs=[download_gs],
437
  )
438
 
 
 
439
  model_output.clear(
440
  lambda: (gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)),
441
- inputs=None,
442
  outputs=[download_glb, download_gs]
443
  )
444
- video_output.clear(
445
  lambda: (
446
- gr.Button(interactive=False), gr.Button(interactive=False),
447
- gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)
 
 
448
  ),
449
- inputs=None,
450
  outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
451
  )
452
 
@@ -456,30 +433,33 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
456
  # --- Launch the Gradio app ---
457
  if __name__ == "__main__":
458
  print("Loading Trellis pipeline...")
459
- pipeline = None
460
- pipeline_loaded = False
461
  try:
462
- # --- Load pipeline WITHOUT torch_dtype (As per original working version) ---
463
  pipeline = TrellisTextTo3DPipeline.from_pretrained(
464
- "JeffreyXiang/TRELLIS-text-xlarge"
 
 
465
  )
 
466
  if torch.cuda.is_available():
467
  pipeline = pipeline.to("cuda")
468
  print("✅ Trellis pipeline loaded successfully to GPU.")
469
  else:
470
- print("⚠️ WARNING: CUDA not available, running on CPU.")
471
  print("✅ Trellis pipeline loaded successfully to CPU.")
472
- pipeline_loaded = True
473
  except Exception as e:
474
  print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
475
  traceback.print_exc()
 
476
  print("❌ Exiting due to pipeline load failure.")
477
  sys.exit(1)
478
 
479
- if pipeline_loaded:
480
- print("Launching Gradio demo...")
481
- # Use queue and debug=True as in original user-provided code
482
- demo.queue().launch(debug=True)
483
- print("Gradio demo launched.")
484
- else:
485
- print("Gradio demo not launched.")
 
 
 
1
+ # Version: 1.1.0 - API State Fix (2025-05-04)
2
  # Changes:
3
+ # - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
4
+ # as the first return value. This ensures the dictionary is available via the API.
5
+ # - Modified `extract_glb` and `extract_gaussian` to accept `state_dict: dict` as their first argument
6
+ # instead of relying on the implicit `gr.State` object type when called via API.
7
+ # - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
8
+ # so the UI continues to function by passing the dictionary through output_buf.
9
+ # - Added minor safety checks and logging.
10
 
11
  import gradio as gr
 
 
 
12
  import spaces
13
 
14
  import os
15
  import shutil
16
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
17
+ # Fix potential SpConv issue if needed, try 'hash' or 'native'
18
+ # os.environ.setdefault('SPCONV_ALGO', 'native') # Use setdefault to avoid overwriting if already set
19
+ os.environ['SPCONV_ALGO'] = 'native' # Direct set as per original
20
 
21
  from typing import *
22
  import torch
 
30
  import traceback
31
  import sys
32
 
33
+
34
  MAX_SEED = np.iinfo(np.int32).max
 
35
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
36
+ os.makedirs(TMP_DIR, exist_ok=True)
37
+
 
 
 
 
 
 
38
 
39
  def start_session(req: gr.Request):
40
  """Creates a temporary directory for the user session."""
41
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
42
+ os.makedirs(user_dir, exist_ok=True)
43
+ print(f"Started session, created directory: {user_dir}")
44
+
 
 
 
 
 
 
 
45
 
46
  def end_session(req: gr.Request):
47
  """Removes the temporary directory for the user session."""
48
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
49
+ if os.path.exists(user_dir):
50
+ try:
51
+ shutil.rmtree(user_dir)
52
+ print(f"Ended session, removed directory: {user_dir}")
53
+ except OSError as e:
54
+ print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
55
+ else:
56
+ print(f"Ended session, directory already removed: {user_dir}")
 
 
 
 
 
 
 
 
57
 
58
 
59
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
60
  """Packs Gaussian and Mesh data into a serializable dictionary."""
61
+ # Ensure tensors are on CPU and converted to numpy before returning the dict
62
  print("[pack_state] Packing state to dictionary...")
63
+ packed_data = {
64
+ 'gaussian': {
65
+ # Spread init_params first to ensure correct types
66
+ **{k: v for k, v in gs.init_params.items()}, # Ensure init_params are included
67
+ '_xyz': gs._xyz.detach().cpu().numpy(),
68
+ '_features_dc': gs._features_dc.detach().cpu().numpy(),
69
+ '_scaling': gs._scaling.detach().cpu().numpy(),
70
+ '_rotation': gs._rotation.detach().cpu().numpy(),
71
+ '_opacity': gs._opacity.detach().cpu().numpy(),
72
+ },
73
+ 'mesh': {
74
+ 'vertices': mesh.vertices.detach().cpu().numpy(),
75
+ 'faces': mesh.faces.detach().cpu().numpy(),
76
+ },
77
+ }
78
+ print(f"[pack_state] Dictionary created. Keys: {list(packed_data.keys())}, Gaussian points: {len(packed_data['gaussian']['_xyz'])}, Mesh vertices: {len(packed_data['mesh']['vertices'])}")
79
+ return packed_data
 
 
 
 
80
 
81
 
82
  def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
83
  """Unpacks Gaussian and Mesh data from a dictionary."""
84
  print("[unpack_state] Unpacking state from dictionary...")
85
+ if not isinstance(state_dict, dict) or 'gaussian' not in state_dict or 'mesh' not in state_dict:
86
+ raise ValueError("Invalid state_dict structure passed to unpack_state.")
87
+
88
+ # Ensure the device is correctly set when unpacking
89
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
90
+ print(f"[unpack_state] Using device: {device}")
91
+
92
+ gauss_data = state_dict['gaussian']
93
+ mesh_data = state_dict['mesh']
94
+
95
+ # Recreate Gaussian object using parameters stored during packing
96
+ gs = Gaussian(
97
+ aabb=gauss_data.get('aabb'), # Use .get for safety
98
+ sh_degree=gauss_data.get('sh_degree'),
99
+ mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
100
+ scaling_bias=gauss_data.get('scaling_bias'),
101
+ opacity_bias=gauss_data.get('opacity_bias'),
102
+ scaling_activation=gauss_data.get('scaling_activation'),
103
+ )
104
+ # Load tensors, ensuring they are created on the correct device
105
+ gs._xyz = torch.tensor(gauss_data['_xyz'], device=device, dtype=torch.float32)
106
+ gs._features_dc = torch.tensor(gauss_data['_features_dc'], device=device, dtype=torch.float32)
107
+ gs._scaling = torch.tensor(gauss_data['_scaling'], device=device, dtype=torch.float32)
108
+ gs._rotation = torch.tensor(gauss_data['_rotation'], device=device, dtype=torch.float32)
109
+ gs._opacity = torch.tensor(gauss_data['_opacity'], device=device, dtype=torch.float32)
110
+ print(f"[unpack_state] Gaussian unpacked. Points: {gs.get_xyz.shape[0]}")
111
+
112
+ # Recreate mesh object using edict for compatibility if needed elsewhere
113
+ mesh = edict(
114
+ vertices=torch.tensor(mesh_data['vertices'], device=device, dtype=torch.float32),
115
+ faces=torch.tensor(mesh_data['faces'], device=device, dtype=torch.int64), # Faces are typically long/int64
116
+ )
117
+ print(f"[unpack_state] Mesh unpacked. Vertices: {mesh.vertices.shape[0]}, Faces: {mesh.faces.shape[0]}")
118
 
119
+ return gs, mesh
 
 
 
 
120
 
121
 
122
  def get_seed(randomize_seed: bool, seed: int) -> int:
123
  """Gets a seed value, randomizing if requested."""
124
  new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
125
  print(f"[get_seed] Randomize: {randomize_seed}, Input Seed: {seed}, Output Seed: {new_seed}")
126
+ return int(new_seed) # Ensure it's a standard int
127
 
128
 
 
129
  @spaces.GPU
130
  def text_to_3d(
131
  prompt: str,
 
135
  slat_guidance_strength: float,
136
  slat_sampling_steps: int,
137
  req: gr.Request,
138
+ ) -> Tuple[dict, str]: # Return type changed for clarity
139
  """
140
  Generates a 3D model (Gaussian and Mesh) from text and returns a
141
+ serializable state dictionary and a video preview path.
 
142
  """
143
+ print(f"[text_to_3d] Received prompt: '{prompt}', Seed: {seed}")
144
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
145
+ os.makedirs(user_dir, exist_ok=True)
146
+ print(f"[text_to_3d] User directory: {user_dir}")
147
+
148
+ # --- Generation Pipeline ---
149
  try:
150
+ print("[text_to_3d] Running Trellis pipeline...")
 
 
 
 
 
 
 
 
 
151
  outputs = pipeline.run(
152
+ prompt,
153
  seed=seed,
154
+ formats=["gaussian", "mesh"], # Ensure both are generated
155
  sparse_structure_sampler_params={
156
+ "steps": int(ss_sampling_steps), # Ensure steps are int
157
  "cfg_strength": float(ss_guidance_strength),
158
  },
159
  slat_sampler_params={
160
+ "steps": int(slat_sampling_steps), # Ensure steps are int
161
  "cfg_strength": float(slat_guidance_strength),
162
  },
163
  )
164
+ print("[text_to_3d] Pipeline run completed.")
165
+ except Exception as e:
166
+ print(f"❌ [text_to_3d] Pipeline error: {e}", file=sys.stderr)
167
+ traceback.print_exc()
168
+ # Return an empty dict and maybe an error indicator path or None?
169
+ # For now, re-raise to signal failure clearly upstream.
170
+ raise gr.Error(f"Trellis pipeline failed: {e}")
171
 
172
+ # --- Create Serializable State Dictionary --- VITAL CHANGE for API
173
+ # This dictionary holds the necessary data for later extraction.
174
+ try:
175
  state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
176
+ except Exception as e:
177
+ print(f"❌ [text_to_3d] pack_state error: {e}", file=sys.stderr)
178
+ traceback.print_exc()
179
+ raise gr.Error(f"Failed to pack state: {e}")
180
 
181
+ # --- Render Video Preview ---
182
+ try:
183
+ print("[text_to_3d] Rendering video preview...")
184
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
185
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
186
+ # Ensure video frames are uint8
187
+ video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
188
+ video_path = os.path.join(user_dir, 'sample.mp4')
189
+ imageio.mimsave(video_path, video, fps=15, quality=8) # Added quality setting
190
+ print(f"[text_to_3d] Video saved to: {video_path}")
191
  except Exception as e:
192
+ print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
193
  traceback.print_exc()
194
+ # Still return state_dict, but maybe signal video error? Return None for path.
195
+ video_path = None # Indicate video failure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  # --- Cleanup and Return ---
198
+ # Clear CUDA cache if GPU was used
199
  if torch.cuda.is_available():
200
  torch.cuda.empty_cache()
201
+ print("[text_to_3d] Cleared CUDA cache.")
202
 
203
+ # --- Return Serializable Dictionary and Video Path --- VITAL CHANGE for API
204
+ print("[text_to_3d] Returning state dictionary and video path.")
 
 
205
  return state_dict, video_path
206
 
207
 
208
+ @spaces.GPU(duration=120) # Increased duration slightly
 
209
  def extract_glb(
210
+ state_dict: dict, # <-- VITAL CHANGE: Accept the dictionary directly
211
  mesh_simplify: float,
212
  texture_size: int,
213
  req: gr.Request,
 
216
  Extracts a GLB file from the provided 3D model state dictionary.
217
  """
218
  print(f"[extract_glb] Received request. Simplify: {mesh_simplify}, Texture Size: {texture_size}")
219
+ if not isinstance(state_dict, dict):
220
+ print("❌ [extract_glb] Error: Invalid state_dict received (not a dictionary).")
221
+ raise gr.Error("Invalid state data received. Please generate the model first.")
 
 
 
 
222
 
223
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
224
+ os.makedirs(user_dir, exist_ok=True)
225
+ print(f"[extract_glb] User directory: {user_dir}")
226
 
227
+ # --- Unpack state from the dictionary --- VITAL CHANGE for API
228
+ try:
 
 
 
229
  gs, mesh = unpack_state(state_dict)
230
+ except Exception as e:
231
+ print(f"❌ [extract_glb] unpack_state error: {e}", file=sys.stderr)
232
+ traceback.print_exc()
233
+ raise gr.Error(f"Failed to unpack state: {e}")
234
 
235
+ # --- Postprocessing and Export ---
236
+ try:
237
  print("[extract_glb] Converting to GLB...")
238
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=float(mesh_simplify), texture_size=int(texture_size), verbose=True) # Verbose for debugging
 
 
239
  glb_path = os.path.join(user_dir, 'sample.glb')
240
  print(f"[extract_glb] Exporting GLB to: {glb_path}")
241
  glb.export(glb_path)
242
  print("[extract_glb] GLB exported successfully.")
 
243
  except Exception as e:
244
+ print(f"❌ [extract_glb] GLB conversion/export error: {e}", file=sys.stderr)
245
  traceback.print_exc()
246
  raise gr.Error(f"Failed to extract GLB: {e}")
247
 
 
250
  torch.cuda.empty_cache()
251
  print("[extract_glb] Cleared CUDA cache.")
252
 
253
+ # Return path twice for both Model3D and DownloadButton components
254
  print("[extract_glb] Returning GLB path.")
 
 
 
255
  return glb_path, glb_path
256
 
257
 
 
258
  @spaces.GPU
259
  def extract_gaussian(
260
+ state_dict: dict, # <-- VITAL CHANGE: Accept the dictionary directly
261
  req: gr.Request
262
  ) -> Tuple[str, str]:
263
  """
264
  Extracts a PLY (Gaussian) file from the provided 3D model state dictionary.
265
  """
266
  print("[extract_gaussian] Received request.")
267
+ if not isinstance(state_dict, dict):
268
+ print("❌ [extract_gaussian] Error: Invalid state_dict received (not a dictionary).")
269
+ raise gr.Error("Invalid state data received. Please generate the model first.")
 
 
 
 
 
 
 
 
270
 
271
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
272
+ os.makedirs(user_dir, exist_ok=True)
273
+ print(f"[extract_gaussian] User directory: {user_dir}")
274
 
275
+ # --- Unpack state from the dictionary --- VITAL CHANGE for API
276
+ try:
277
+ gs, _ = unpack_state(state_dict) # Only need Gaussian part
278
+ except Exception as e:
279
+ print(f"❌ [extract_gaussian] unpack_state error: {e}", file=sys.stderr)
280
+ traceback.print_exc()
281
+ raise gr.Error(f"Failed to unpack state: {e}")
282
 
283
+ # --- Export PLY ---
284
+ try:
285
  gaussian_path = os.path.join(user_dir, 'sample.ply')
286
  print(f"[extract_gaussian] Saving PLY to: {gaussian_path}")
287
  gs.save_ply(gaussian_path)
288
  print("[extract_gaussian] PLY saved successfully.")
 
289
  except Exception as e:
290
+ print(f"❌ [extract_gaussian] PLY saving error: {e}", file=sys.stderr)
291
  traceback.print_exc()
292
  raise gr.Error(f"Failed to extract Gaussian PLY: {e}")
293
 
 
296
  torch.cuda.empty_cache()
297
  print("[extract_gaussian] Cleared CUDA cache.")
298
 
299
+ # Return path twice for both Model3D and DownloadButton components
300
  print("[extract_gaussian] Returning PLY path.")
 
 
 
301
  return gaussian_path, gaussian_path
302
 
303
 
 
309
  * Type a text prompt and click "Generate" to create a 3D asset preview.
310
  * Adjust extraction settings if desired.
311
  * Click "Extract GLB" or "Extract Gaussian" to get the downloadable 3D file.
 
312
  """)
313
 
314
+ # --- State Buffer ---
315
+ # This hidden component will hold the dictionary returned by text_to_3d,
316
+ # acting as the state link between generation and extraction for the UI/API.
317
  output_buf = gr.State()
318
 
319
  with gr.Row():
320
+ with gr.Column(scale=1): # Input column
321
  text_prompt = gr.Textbox(label="Text Prompt", lines=5, placeholder="e.g., a cute red dragon")
322
+
323
  with gr.Accordion(label="Generation Settings", open=False):
324
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
325
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
331
  with gr.Row():
332
  slat_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
333
  slat_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
334
+
335
  generate_btn = gr.Button("Generate 3D Preview", variant="primary")
336
+
337
+ with gr.Accordion(label="GLB Extraction Settings", open=True): # Open by default
338
+ # Tooltips added for clarity
339
  mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01, info="Higher value = less simplification (more polys)")
340
  texture_size = gr.Slider(512, 2048, label="Texture Size (pixels)", value=1024, step=512, info="Size of the generated texture map")
341
+
342
  with gr.Row():
343
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
344
  extract_gs_btn = gr.Button("Extract Gaussian (PLY)", interactive=False)
345
  gr.Markdown("""
346
  *NOTE: Gaussian file (.ply) can be very large (~50MB+) and may take time to process/download.*
347
  """)
348
+
349
+ with gr.Column(scale=1): # Output column
350
+ video_output = gr.Video(label="Generated 3D Preview (Geometry | Texture)", autoplay=True, loop=True, height=350) # Slightly larger height
351
+ model_output = gr.Model3D(label="Extracted Model Preview", height=350, clear_color=[0.95, 0.95, 0.95, 1.0]) # Light background
352
+
353
  with gr.Row():
354
+ # Link download button visibility/interactivity to model_output potentially
355
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
356
  download_gs = gr.DownloadButton(label="Download Gaussian (PLY)", interactive=False)
357
 
358
  # --- Event Handlers ---
359
  print("Defining Gradio event handlers...")
 
 
 
 
360
 
361
+ # Handle session start/end
362
+ demo.load(start_session, inputs=None, outputs=None) # Pass None for clarity
363
+ demo.unload(end_session, inputs=None, outputs=None)
364
+
365
+ # --- Generate Button Click Flow ---
366
+ # 1. Get Seed -> 2. Run text_to_3d -> 3. Enable extraction buttons
367
  generate_event = generate_btn.click(
368
  get_seed,
369
  inputs=[randomize_seed, seed],
370
  outputs=[seed],
371
+ api_name="get_seed" # Optional API name
372
  ).then(
373
  text_to_3d,
374
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
375
+ outputs=[output_buf, video_output], # output_buf receives state_dict
376
  api_name="text_to_3d"
377
  ).then(
378
+ lambda: ( # Return tuple for multiple outputs
379
+ gr.Button(interactive=True),
380
+ gr.Button(interactive=True),
381
+ gr.DownloadButton(interactive=False), # Ensure download buttons are disabled initially
382
+ gr.DownloadButton(interactive=False)
383
  ),
384
+ outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs], # Update interactivity
 
385
  )
386
 
387
+ # --- Clear video/model outputs if prompt changes (optional, prevents confusion)
388
+ # text_prompt.change(lambda: (None, None, gr.Button(interactive=False), gr.Button(interactive=False)), outputs=[video_output, model_output, extract_glb_btn, extract_gs_btn])
389
+
390
+ # --- Extract GLB Button Click Flow ---
391
+ # 1. Run extract_glb -> 2. Update Model3D and Download Button
392
  extract_glb_event = extract_glb_btn.click(
393
  extract_glb,
394
+ inputs=[output_buf, mesh_simplify, texture_size], # Pass the state_dict via output_buf
395
+ outputs=[model_output, download_glb], # Returns path to both
396
  api_name="extract_glb"
397
  ).then(
398
+ lambda: gr.DownloadButton(interactive=True), # Enable download button
 
399
  outputs=[download_glb],
400
  )
401
 
402
+ # --- Extract Gaussian Button Click Flow ---
403
+ # 1. Run extract_gaussian -> 2. Update Model3D and Download Button
404
  extract_gs_event = extract_gs_btn.click(
405
  extract_gaussian,
406
+ inputs=[output_buf], # Pass the state_dict via output_buf
407
+ outputs=[model_output, download_gs], # Returns path to both
408
  api_name="extract_gaussian"
409
  ).then(
410
+ lambda: gr.DownloadButton(interactive=True), # Enable download button
 
411
  outputs=[download_gs],
412
  )
413
 
414
+ # --- Clear Download Button Interactivity when model preview is cleared ---
415
+ # This might be redundant if generate disables them, but adds safety
416
  model_output.clear(
417
  lambda: (gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)),
 
418
  outputs=[download_glb, download_gs]
419
  )
420
+ video_output.clear( # Also disable extraction if video is cleared (e.g., new generation starts)
421
  lambda: (
422
+ gr.Button(interactive=False),
423
+ gr.Button(interactive=False),
424
+ gr.DownloadButton(interactive=False),
425
+ gr.DownloadButton(interactive=False)
426
  ),
 
427
  outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
428
  )
429
 
 
433
  # --- Launch the Gradio app ---
434
  if __name__ == "__main__":
435
  print("Loading Trellis pipeline...")
 
 
436
  try:
437
+ # Ensure model/variant matches requirements, use revision if needed
438
  pipeline = TrellisTextTo3DPipeline.from_pretrained(
439
+ "JeffreyXiang/TRELLIS-text-xlarge",
440
+ # revision="main", # Specify if needed
441
+ torch_dtype=torch.float16 # Use float16 if GPU supports it for less memory
442
  )
443
+ # Move to GPU if available
444
  if torch.cuda.is_available():
445
  pipeline = pipeline.to("cuda")
446
  print("✅ Trellis pipeline loaded successfully to GPU.")
447
  else:
448
+ print("⚠️ WARNING: CUDA not available, running on CPU (will be very slow).")
449
  print("✅ Trellis pipeline loaded successfully to CPU.")
 
450
  except Exception as e:
451
  print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
452
  traceback.print_exc()
453
+ # Exit if pipeline is critical for the app to run
454
  print("❌ Exiting due to pipeline load failure.")
455
  sys.exit(1)
456
 
457
+ print("Launching Gradio demo...")
458
+ # Set share=True if you need a public link (e.g., for testing from outside local network)
459
+ # Set server_name="0.0.0.0" to allow access from local network IP
460
+ demo.queue().launch( # Use queue for potentially long-running tasks
461
+ # server_name="0.0.0.0",
462
+ # share=False,
463
+ debug=True # Enable debug mode for more logs
464
+ )
465
+ print("Gradio demo launched.")