jbilcke-hf HF staff committed on
Commit
9e70cab
·
verified ·
1 Parent(s): 751171e

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +26 -11
gradio_app.py CHANGED
@@ -71,26 +71,30 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
71
  rgb = img_array
72
  mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
73
 
74
- # Convert to tensors
75
- rgb = torch.from_numpy(rgb).float()
76
- mask = torch.from_numpy(mask).float()
77
  print("[debug] rgb tensor shape:", rgb.shape)
78
  print("[debug] mask tensor shape:", mask.shape)
79
 
80
  # Create background blend
81
- bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
82
  print("[debug] bg_tensor shape:", bg_tensor.shape)
83
 
84
  # Blend RGB with background using mask
85
- rgb_cond = torch.lerp(bg_tensor, rgb, mask)
86
- print("[debug] rgb_cond shape:", rgb_cond.shape)
 
 
 
 
87
 
88
- # Permute the tensors to match the expected shape [B, C, H, W]
89
- rgb_cond = torch.movedim(rgb_cond, 2, 0).unsqueeze(0) # [1, 3, H, W]
90
- mask = torch.movedim(mask, 2, 0).unsqueeze(0) # [1, 1, H, W]
91
 
92
- print("[debug] rgb_cond after permute shape:", rgb_cond.shape)
93
- print("[debug] mask after permute shape:", mask.shape)
94
 
95
  batch = {
96
  "rgb_cond": rgb_cond,
@@ -109,6 +113,17 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
109
  def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
110
  """Process batch through model and generate point cloud."""
111
  print("[debug] Starting forward_model")
 
 
 
 
 
 
 
 
 
 
 
112
  batch_size = batch["rgb_cond"].shape[0]
113
 
114
  # Generate point cloud tokens
 
71
  rgb = img_array
72
  mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
73
 
74
+ # Convert to tensors and keep in channel-last format initially
75
+ rgb = torch.from_numpy(rgb).float() # [H, W, 3]
76
+ mask = torch.from_numpy(mask).float() # [H, W, 1]
77
  print("[debug] rgb tensor shape:", rgb.shape)
78
  print("[debug] mask tensor shape:", mask.shape)
79
 
80
  # Create background blend
81
+ bg_tensor = torch.tensor(BACKGROUND_COLOR) # [3]
82
  print("[debug] bg_tensor shape:", bg_tensor.shape)
83
 
84
  # Blend RGB with background using mask
85
+ rgb_cond = torch.lerp(
86
+ bg_tensor.view(1, 1, 3), # [1, 1, 3]
87
+ rgb, # [H, W, 3]
88
+ mask # [H, W, 1]
89
+ )
90
+ print("[debug] rgb_cond shape after blend:", rgb_cond.shape)
91
 
92
+ # Permute the tensors to [B, C, H, W] format at the end
93
+ rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
94
+ mask = mask.permute(2, 0, 1).unsqueeze(0) # [1, 1, H, W]
95
 
96
+ print("[debug] rgb_cond final shape:", rgb_cond.shape)
97
+ print("[debug] mask final shape:", mask.shape)
98
 
99
  batch = {
100
  "rgb_cond": rgb_cond,
 
113
  def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
114
  """Process batch through model and generate point cloud."""
115
  print("[debug] Starting forward_model")
116
+ print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
117
+
118
+ # Ensure input is in correct format [B, C, H, W]
119
+ if batch["rgb_cond"].shape[1] != 3:
120
+ batch["rgb_cond"] = batch["rgb_cond"].permute(0, 3, 1, 2)
121
+ if batch["mask_cond"].shape[1] != 1:
122
+ batch["mask_cond"] = batch["mask_cond"].permute(0, 3, 1, 2)
123
+
124
+ print("[debug] Processed rgb_cond shape:", batch["rgb_cond"].shape)
125
+ print("[debug] Processed mask_cond shape:", batch["mask_cond"].shape)
126
+
127
  batch_size = batch["rgb_cond"].shape[0]
128
 
129
  # Generate point cloud tokens