jbilcke-hf HF staff committed on
Commit
f779fbc
·
verified ·
1 Parent(s): aec7186

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +12 -0
gradio_app.py CHANGED
@@ -88,6 +88,8 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
88
  try:
89
  # Generate image using FLUX
90
  generator = torch.Generator(device=device).manual_seed(seed)
 
 
91
  generated_image = flux_pipe(
92
  prompt=prompt,
93
  width=width,
@@ -98,19 +100,24 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
98
  ).images[0]
99
 
100
  # Process the generated image
 
101
  rgb_image = generated_image.convert('RGB')
102
 
103
  # Remove background
 
104
  no_bg_image = bg_remover.process(rgb_image)
105
 
106
  # Convert to numpy array to extract mask
 
107
  no_bg_array = np.array(no_bg_image)
108
  mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
109
 
110
  # Create RGBA image
 
111
  rgba_image = create_rgba_image(rgb_image, mask)
112
 
113
  # Auto crop with foreground
 
114
  processed_image = spar3d_utils.foreground_crop(
115
  rgba_image,
116
  crop_ratio=1.3,
@@ -118,12 +125,14 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
118
  no_crop=False
119
  )
120
 
 
121
  # Prepare batch for 3D generation
122
  batch = create_batch(processed_image)
123
  batch = {k: v.to(device) for k, v in batch.items()}
124
 
125
  # Generate mesh
126
  with torch.no_grad():
 
127
  with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
128
  trimesh_mesh, _ = spar3d_model.generate_mesh(
129
  batch,
@@ -135,8 +144,11 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
135
  trimesh_mesh = trimesh_mesh[0]
136
 
137
  # Export to GLB
 
138
  temp_dir = tempfile.mkdtemp()
139
  output_path = os.path.join(temp_dir, 'output.glb')
 
 
140
  trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
141
 
142
  return output_path, generated_image
 
88
  try:
89
  # Generate image using FLUX
90
  generator = torch.Generator(device=device).manual_seed(seed)
91
+
92
+ print("[debug] generating the image using Flux")
93
  generated_image = flux_pipe(
94
  prompt=prompt,
95
  width=width,
 
100
  ).images[0]
101
 
102
  # Process the generated image
103
+ print("[debug] converting the image to rgb")
104
  rgb_image = generated_image.convert('RGB')
105
 
106
  # Remove background
107
+ print("[debug] removing the background by calling bg_remover.process(rgb_image)")
108
  no_bg_image = bg_remover.process(rgb_image)
109
 
110
  # Convert to numpy array to extract mask
111
+ print("[debug] converting to numpy array to extract the mask")
112
  no_bg_array = np.array(no_bg_image)
113
  mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
114
 
115
  # Create RGBA image
116
+ print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
117
  rgba_image = create_rgba_image(rgb_image, mask)
118
 
119
  # Auto crop with foreground
120
+ print(f"[debug] auto-cromming the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
121
  processed_image = spar3d_utils.foreground_crop(
122
  rgba_image,
123
  crop_ratio=1.3,
 
125
  no_crop=False
126
  )
127
 
128
+ print("[debug] preparing the batch by calling create_batch(processed_image)")
129
  # Prepare batch for 3D generation
130
  batch = create_batch(processed_image)
131
  batch = {k: v.to(device) for k, v in batch.items()}
132
 
133
  # Generate mesh
134
  with torch.no_grad():
135
+ print("[debug] calling torch.autocast(....) to generate the mesh")
136
  with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
137
  trimesh_mesh, _ = spar3d_model.generate_mesh(
138
  batch,
 
144
  trimesh_mesh = trimesh_mesh[0]
145
 
146
  # Export to GLB
147
+ print("[debug] creating tmp dir for the .glb output")
148
  temp_dir = tempfile.mkdtemp()
149
  output_path = os.path.join(temp_dir, 'output.glb')
150
+
151
+ print("[debug] calling trimesh_mesh.export(...) to export to .glb")
152
  trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
153
 
154
  return output_path, generated_image