Spaces: Running on L40S

Update gradio_app.py

gradio_app.py  +26 -11
gradio_app.py
CHANGED
@@ -71,26 +71,30 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
     rgb = img_array
     mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)

-    # Convert to tensors
-    rgb = torch.from_numpy(rgb).float()
-    mask = torch.from_numpy(mask).float()
+    # Convert to tensors and keep in channel-last format initially
+    rgb = torch.from_numpy(rgb).float()  # [H, W, 3]
+    mask = torch.from_numpy(mask).float()  # [H, W, 1]
     print("[debug] rgb tensor shape:", rgb.shape)
     print("[debug] mask tensor shape:", mask.shape)

     # Create background blend
-    bg_tensor = torch.tensor(BACKGROUND_COLOR)[
+    bg_tensor = torch.tensor(BACKGROUND_COLOR)  # [3]
     print("[debug] bg_tensor shape:", bg_tensor.shape)

     # Blend RGB with background using mask
-    rgb_cond = torch.lerp(
-
+    rgb_cond = torch.lerp(
+        bg_tensor.view(1, 1, 3),  # [1, 1, 3]
+        rgb,                      # [H, W, 3]
+        mask                      # [H, W, 1]
+    )
+    print("[debug] rgb_cond shape after blend:", rgb_cond.shape)

-    # Permute the tensors to
-    rgb_cond =
-    mask =
+    # Permute the tensors to [B, C, H, W] format at the end
+    rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
+    mask = mask.permute(2, 0, 1).unsqueeze(0)  # [1, 1, H, W]

-    print("[debug] rgb_cond
-    print("[debug] mask
+    print("[debug] rgb_cond final shape:", rgb_cond.shape)
+    print("[debug] mask final shape:", mask.shape)

     batch = {
         "rgb_cond": rgb_cond,
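For readers following the shape bookkeeping, here is a minimal, self-contained sketch of the approach this hunk settles on: blend against the background while the tensors are still channel-last (letting torch.lerp broadcast the [1, 1, 3] background and [H, W, 1] mask over the [H, W, 3] image), and only then permute to [B, C, H, W]. The prepare_cond helper, the BACKGROUND_COLOR value, and the random input are placeholders for illustration, not the app's actual code.

# Sketch only: BACKGROUND_COLOR and prepare_cond are assumed names, not the app's real constants.
import numpy as np
import torch

BACKGROUND_COLOR = (0.5, 0.5, 0.5)  # assumed grey background, stand-in for the app's constant

def prepare_cond(img_array: np.ndarray, alpha: np.ndarray) -> dict:
    """img_array: [H, W, 3] float RGB in [0, 1]; alpha: [H, W, 1] float mask."""
    rgb = torch.from_numpy(img_array).float()   # [H, W, 3]
    mask = torch.from_numpy(alpha).float()      # [H, W, 1]

    # Blend while still channel-last: lerp broadcasts the [1, 1, 3] background
    # and the [H, W, 1] mask over the [H, W, 3] image (mask=1 keeps the pixel).
    bg = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3)
    rgb_cond = torch.lerp(bg, rgb, mask)        # [H, W, 3]

    # Only now move to [B, C, H, W], the layout the model expects downstream.
    rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
    mask = mask.permute(2, 0, 1).unsqueeze(0)          # [1, 1, H, W]
    return {"rgb_cond": rgb_cond, "mask_cond": mask}

if __name__ == "__main__":
    h, w = 64, 64
    batch = prepare_cond(np.random.rand(h, w, 3).astype(np.float32),
                         np.ones((h, w, 1), dtype=np.float32))
    print(batch["rgb_cond"].shape, batch["mask_cond"].shape)
    # torch.Size([1, 3, 64, 64]) torch.Size([1, 1, 64, 64])

Doing the blend channel-last keeps the broadcasting easy to read; the alternative of permuting first would force the background to be reshaped to [1, 3, 1, 1].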
@@ -109,6 +113,17 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
 def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
     """Process batch through model and generate point cloud."""
     print("[debug] Starting forward_model")
+    print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
+
+    # Ensure input is in correct format [B, C, H, W]
+    if batch["rgb_cond"].shape[1] != 3:
+        batch["rgb_cond"] = batch["rgb_cond"].permute(0, 3, 1, 2)
+    if batch["mask_cond"].shape[1] != 1:
+        batch["mask_cond"] = batch["mask_cond"].permute(0, 3, 1, 2)
+
+    print("[debug] Processed rgb_cond shape:", batch["rgb_cond"].shape)
+    print("[debug] Processed mask_cond shape:", batch["mask_cond"].shape)
+
     batch_size = batch["rgb_cond"].shape[0]

     # Generate point cloud tokens
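The second hunk's guard can be read as a small layout normalizer for the conditioning tensors. Below is an illustrative sketch under the assumption that both tensors are 4-D; the ensure_nchw name is hypothetical, and only the shape[1]-vs-channel-count heuristic comes from the diff.

# Sketch of the layout guard: normalize channel-last [B, H, W, C] to [B, C, H, W].
import torch

def ensure_nchw(t: torch.Tensor, channels: int) -> torch.Tensor:
    """Return t as [B, C, H, W], permuting from [B, H, W, C] if needed."""
    if t.dim() != 4:
        raise ValueError(f"expected a 4-D tensor, got shape {tuple(t.shape)}")
    if t.shape[1] != channels:       # already channel-first if dim 1 matches
        t = t.permute(0, 3, 1, 2)
    return t

if __name__ == "__main__":
    batch = {
        "rgb_cond": torch.rand(1, 64, 64, 3),   # channel-last input
        "mask_cond": torch.ones(1, 1, 64, 64),  # already channel-first
    }
    batch["rgb_cond"] = ensure_nchw(batch["rgb_cond"], channels=3)
    batch["mask_cond"] = ensure_nchw(batch["mask_cond"], channels=1)
    print(batch["rgb_cond"].shape, batch["mask_cond"].shape)
    # torch.Size([1, 3, 64, 64]) torch.Size([1, 1, 64, 64])

Note that the heuristic assumes the spatial height never happens to equal the channel count (3 for RGB, 1 for the mask); normalizing the layout once in create_batch, as the first hunk does, is the more robust place for this fix, with the forward_model check acting as a safety net.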