jbilcke-hf committed (verified)
Commit 751171e · Parent(s): 2728300

Update gradio_app.py

Files changed (1): gradio_app.py (+58 -23)
gradio_app.py CHANGED
@@ -86,8 +86,8 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
     print("[debug] rgb_cond shape:", rgb_cond.shape)
 
     # Permute the tensors to match the expected shape [B, C, H, W]
-    rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
-    mask = mask.permute(2, 0, 1).unsqueeze(0) # [1, 1, H, W]
+    rgb_cond = torch.movedim(rgb_cond, 2, 0).unsqueeze(0) # [1, 3, H, W]
+    mask = torch.movedim(mask, 2, 0).unsqueeze(0) # [1, 1, H, W]
 
     print("[debug] rgb_cond after permute shape:", rgb_cond.shape)
     print("[debug] mask after permute shape:", mask.shape)
@@ -106,12 +106,53 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
 
     return batch
 
+def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
+    """Process batch through model and generate point cloud."""
+    print("[debug] Starting forward_model")
+    batch_size = batch["rgb_cond"].shape[0]
+
+    # Generate point cloud tokens
+    print("[debug] Generating point cloud tokens")
+    cond_tokens = system.forward_pdiff_cond(batch)
+    print("[debug] cond_tokens shape:", cond_tokens.shape)
+
+    # Sample points
+    print("[debug] Sampling points")
+    sample_iter = system.sampler.sample_batch_progressive(
+        batch_size,
+        cond_tokens,
+        guidance_scale=guidance_scale,
+        device=device
+    )
+
+    # Get final samples
+    for x in sample_iter:
+        samples = x["xstart"]
+
+    print("[debug] samples shape before permute:", samples.shape)
+    # Convert samples to point cloud format
+    pc_cond = samples.permute(0, 2, 1).float()
+    print("[debug] pc_cond shape after permute:", pc_cond.shape)
+
+    # Normalize point cloud
+    pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
+    print("[debug] pc_cond shape after normalize:", pc_cond.shape)
+
+    # Subsample to 512 points
+    pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
+    print("[debug] pc_cond final shape:", pc_cond.shape)
+
+    return pc_cond
+
 def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
     """Generate image from prompt and convert to 3D model."""
     try:
+        # Set random seeds
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
         # Generate image using FLUX
         generator = torch.Generator(device=device).manual_seed(seed)
-
         print("[debug] generating the image using Flux")
         generated_image = flux_pipe(
             prompt=prompt,
@@ -138,7 +179,7 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
         print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
         rgba_image = create_rgba_image(rgb_image, mask)
 
-        print(f"[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...)")
+        print("[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...)")
         processed_image = spar3d_utils.foreground_crop(
             rgba_image,
             crop_ratio=1.3,
@@ -146,33 +187,25 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
             no_crop=False
         )
 
-        # Forward pass through SPAR3D
+        # Prepare batch for processing
        print("[debug] preparing the batch by calling create_batch(processed_image)")
        batch = create_batch(processed_image)
        batch = {k: v.to(device) for k, v in batch.items()}
 
+        # Generate point cloud
+        pc_cond = forward_model(
+            batch,
+            spar3d_model,
+            guidance_scale=3.0,
+            seed=seed,
+            device=device
+        )
+        batch["pc_cond"] = pc_cond
+
         # Generate mesh
         with torch.no_grad():
             print("[debug] calling torch.autocast(....) to generate the mesh")
             with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
-                # Add point cloud conditioning to match expected input
-                if "pc_cond" not in batch:
-                    # Sample tokens from model's diffusion process
-                    cond_tokens = spar3d_model.forward_pdiff_cond(batch)
-                    sample_iter = spar3d_model.sampler.sample_batch_progressive(
-                        1, # batch size
-                        cond_tokens,
-                        guidance_scale=3.0,
-                        device=device,
-                    )
-                    for x in sample_iter:
-                        samples = x["xstart"]
-                    # Add point cloud to batch
-                    batch["pc_cond"] = samples.permute(0, 2, 1).float()
-                    batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])
-                    # Subsample to 512 points
-                    batch["pc_cond"] = batch["pc_cond"][:, torch.randperm(batch["pc_cond"].shape[1])[:512]]
-
                 trimesh_mesh, _ = spar3d_model.generate_mesh(
                     batch,
                     1024, # texture_resolution
@@ -194,6 +227,8 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
 
     except Exception as e:
         print(f"Error during generation: {str(e)}")
+        import traceback
+        traceback.print_exc()
         return None, None
 
 # Create Gradio interface
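Note on the permute-to-movedim change in create_batch: for a rank-3 tensor laid out as [H, W, C], torch.movedim(t, 2, 0) yields the same [C, H, W] result as t.permute(2, 0, 1), so the new lines are behavior-preserving. A minimal sketch (the 512x512 shape is illustrative, not taken from the app):

import torch

rgb_cond = torch.rand(512, 512, 3)                # [H, W, C], illustrative shape
old = rgb_cond.permute(2, 0, 1).unsqueeze(0)      # [1, 3, H, W]
new = torch.movedim(rgb_cond, 2, 0).unsqueeze(0)  # [1, 3, H, W]
assert torch.equal(old, new)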
 
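The subsampling step inside the new forward_model keeps 512 randomly chosen points along the point dimension of the sampled cloud. A standalone sketch of that indexing, assuming a [B, N, 3] tensor (shapes illustrative, not the model's actual output):

import torch

pc_cond = torch.rand(1, 2048, 3)              # [B, N, 3], illustrative point cloud
idx = torch.randperm(pc_cond.shape[1])[:512]  # 512 random point indices
pc_cond = pc_cond[:, idx]                     # -> [1, 512, 3]
print(pc_cond.shape)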