jbilcke-hf HF staff commited on
Commit
4b4ce8a
·
verified ·
1 Parent(s): 9e70cab

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +41 -33
gradio_app.py CHANGED
@@ -71,68 +71,77 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
71
  rgb = img_array
72
  mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
73
 
74
- # Convert to tensors and keep in channel-last format initially
75
  rgb = torch.from_numpy(rgb).float() # [H, W, 3]
76
  mask = torch.from_numpy(mask).float() # [H, W, 1]
77
  print("[debug] rgb tensor shape:", rgb.shape)
78
  print("[debug] mask tensor shape:", mask.shape)
79
 
80
- # Create background blend
81
- bg_tensor = torch.tensor(BACKGROUND_COLOR) # [3]
82
  print("[debug] bg_tensor shape:", bg_tensor.shape)
83
 
84
- # Blend RGB with background using mask
85
- rgb_cond = torch.lerp(
86
- bg_tensor.view(1, 1, 3), # [1, 1, 3]
87
- rgb, # [H, W, 3]
88
- mask # [H, W, 1]
89
- )
90
  print("[debug] rgb_cond shape after blend:", rgb_cond.shape)
91
 
92
- # Permute the tensors to [B, C, H, W] format at the end
93
- rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
94
- mask = mask.permute(2, 0, 1).unsqueeze(0) # [1, 1, H, W]
 
95
 
96
  print("[debug] rgb_cond final shape:", rgb_cond.shape)
97
  print("[debug] mask final shape:", mask.shape)
98
 
 
99
  batch = {
100
- "rgb_cond": rgb_cond,
101
- "mask_cond": mask,
102
- "c2w_cond": c2w_cond.unsqueeze(0),
103
- "intrinsic_cond": intrinsic.unsqueeze(0),
104
- "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
105
  }
106
 
107
- # Final shapes check
108
  for k, v in batch.items():
109
  print(f"[debug] {k} final shape:", v.shape)
 
 
 
110
 
111
  return batch
112
 
113
  def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
114
  """Process batch through model and generate point cloud."""
115
- print("[debug] Starting forward_model")
116
  print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
117
-
118
- # Ensure input is in correct format [B, C, H, W]
119
- if batch["rgb_cond"].shape[1] != 3:
120
- batch["rgb_cond"] = batch["rgb_cond"].permute(0, 3, 1, 2)
121
- if batch["mask_cond"].shape[1] != 1:
122
- batch["mask_cond"] = batch["mask_cond"].permute(0, 3, 1, 2)
123
-
124
- print("[debug] Processed rgb_cond shape:", batch["rgb_cond"].shape)
125
- print("[debug] Processed mask_cond shape:", batch["mask_cond"].shape)
126
 
127
  batch_size = batch["rgb_cond"].shape[0]
 
 
 
 
 
 
 
128
 
129
  # Generate point cloud tokens
130
- print("[debug] Generating point cloud tokens")
131
- cond_tokens = system.forward_pdiff_cond(batch)
132
- print("[debug] cond_tokens shape:", cond_tokens.shape)
 
 
 
 
 
 
 
 
 
133
 
134
  # Sample points
135
- print("[debug] Sampling points")
136
  sample_iter = system.sampler.sample_batch_progressive(
137
  batch_size,
138
  cond_tokens,
@@ -145,7 +154,6 @@ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
145
  samples = x["xstart"]
146
 
147
  print("[debug] samples shape before permute:", samples.shape)
148
- # Convert samples to point cloud format
149
  pc_cond = samples.permute(0, 2, 1).float()
150
  print("[debug] pc_cond shape after permute:", pc_cond.shape)
151
 
 
71
  rgb = img_array
72
  mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
73
 
74
+ # Convert to tensors while keeping channel-last format
75
  rgb = torch.from_numpy(rgb).float() # [H, W, 3]
76
  mask = torch.from_numpy(mask).float() # [H, W, 1]
77
  print("[debug] rgb tensor shape:", rgb.shape)
78
  print("[debug] mask tensor shape:", mask.shape)
79
 
80
+ # Create background blend (match channel-last format)
81
+ bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
82
  print("[debug] bg_tensor shape:", bg_tensor.shape)
83
 
84
+ # Blend RGB with background using mask (all in channel-last format)
85
+ rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
 
 
 
 
86
  print("[debug] rgb_cond shape after blend:", rgb_cond.shape)
87
 
88
+ # Move channels to correct dimension and add batch dimension
89
+ # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
90
+ rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
91
+ mask = mask.unsqueeze(0) # [1, H, W, 1]
92
 
93
  print("[debug] rgb_cond final shape:", rgb_cond.shape)
94
  print("[debug] mask final shape:", mask.shape)
95
 
96
+ # Create the batch dictionary
97
  batch = {
98
+ "rgb_cond": rgb_cond, # [1, H, W, 3]
99
+ "mask_cond": mask, # [1, H, W, 1]
100
+ "c2w_cond": c2w_cond.unsqueeze(0), # [1, 4, 4]
101
+ "intrinsic_cond": intrinsic.unsqueeze(0), # [1, 3, 3]
102
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
103
  }
104
 
105
+ print("\nFinal batch shapes:")
106
  for k, v in batch.items():
107
  print(f"[debug] {k} final shape:", v.shape)
108
+ print("\nrgb_cond max:", batch["rgb_cond"].max())
109
+ print("rgb_cond min:", batch["rgb_cond"].min())
110
+ print("mask_cond unique values:", torch.unique(batch["mask_cond"]))
111
 
112
  return batch
113
 
114
  def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
115
  """Process batch through model and generate point cloud."""
116
+ print("\n[debug] Starting forward_model")
117
  print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
118
+ print("[debug] Input mask_cond shape:", batch["mask_cond"].shape)
 
 
 
 
 
 
 
 
119
 
120
  batch_size = batch["rgb_cond"].shape[0]
121
+ assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
122
+
123
+ # Print value ranges for debugging
124
+ print("\nValue ranges:")
125
+ print("rgb_cond max:", batch["rgb_cond"].max())
126
+ print("rgb_cond min:", batch["rgb_cond"].min())
127
+ print("mask_cond unique values:", torch.unique(batch["mask_cond"]))
128
 
129
  # Generate point cloud tokens
130
+ print("\n[debug] Generating point cloud tokens")
131
+ try:
132
+ cond_tokens = system.forward_pdiff_cond(batch)
133
+ print("[debug] cond_tokens shape:", cond_tokens.shape)
134
+ except Exception as e:
135
+ print("\n[ERROR] Failed in forward_pdiff_cond:")
136
+ print(e)
137
+ print("\nInput tensor properties:")
138
+ print("rgb_cond dtype:", batch["rgb_cond"].dtype)
139
+ print("rgb_cond device:", batch["rgb_cond"].device)
140
+ print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
141
+ raise
142
 
143
  # Sample points
144
+ print("\n[debug] Sampling points")
145
  sample_iter = system.sampler.sample_batch_progressive(
146
  batch_size,
147
  cond_tokens,
 
154
  samples = x["xstart"]
155
 
156
  print("[debug] samples shape before permute:", samples.shape)
 
157
  pc_cond = samples.permute(0, 2, 1).float()
158
  print("[debug] pc_cond shape after permute:", pc_cond.shape)
159