jbilcke-hf HF staff committed on
Commit
03dc078
·
verified ·
1 Parent(s): c882a68

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +19 -19
gradio_app.py CHANGED
@@ -55,38 +55,38 @@ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.
55
  print("[debug] alpha size:", alpha.size)
56
  rgba_image.putalpha(alpha)
57
  return rgba_image
58
-
59
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
60
  """Prepare image batch for model input."""
61
- # Ensure input is RGBA
62
- if input_image.mode != 'RGBA':
63
- input_image = input_image.convert('RGBA')
64
-
65
- # Resize and convert to numpy array
66
- resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
67
- img_array = np.array(resized_image).astype(np.float32) / 255.0
68
-
69
  print("[debug] img_array shape:", img_array.shape)
70
 
71
- # Split into RGB and alpha
72
  rgb = torch.from_numpy(img_array[..., :3]).float()
73
- alpha = torch.from_numpy(img_array[..., 3:4]).float()
74
-
75
  print("[debug] rgb tensor shape:", rgb.shape)
76
- print("[debug] alpha tensor shape:", alpha.shape)
77
 
78
- # Create background blend using torch.lerp()
79
  bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
80
  print("[debug] bg_tensor shape:", bg_tensor.shape)
81
 
82
- rgb_cond = torch.lerp(bg_tensor, rgb, alpha)
 
83
  print("[debug] rgb_cond shape:", rgb_cond.shape)
84
 
 
 
 
 
 
 
85
  batch = {
86
- "rgb_cond": rgb_cond.unsqueeze(0),
87
- "mask_cond": alpha.unsqueeze(0),
88
  "c2w_cond": c2w_cond.unsqueeze(0),
89
- "intrinsic_cond": intrinsic.unsqueeze(0),
90
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
91
  }
92
 
@@ -130,7 +130,7 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
130
  rgba_image = create_rgba_image(rgb_image, mask)
131
 
132
  # Auto crop with foreground
133
- print(f"[debug] auto-cromming the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
134
  processed_image = spar3d_utils.foreground_crop(
135
  rgba_image,
136
  crop_ratio=1.3,
 
55
  print("[debug] alpha size:", alpha.size)
56
  rgba_image.putalpha(alpha)
57
  return rgba_image
58
+
59
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
60
  """Prepare image batch for model input."""
61
+ # Convert input image to numpy array and normalize
62
+ img_array = np.array(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0
 
 
 
 
 
 
63
  print("[debug] img_array shape:", img_array.shape)
64
 
65
+ # Extract RGB and alpha channels
66
  rgb = torch.from_numpy(img_array[..., :3]).float()
67
+ mask = torch.from_numpy(img_array[..., 3:4]).float()
 
68
  print("[debug] rgb tensor shape:", rgb.shape)
69
+ print("[debug] mask tensor shape:", mask.shape)
70
 
71
+ # Create background blend
72
  bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
73
  print("[debug] bg_tensor shape:", bg_tensor.shape)
74
 
75
+ # Blend RGB with background using mask
76
+ rgb_cond = torch.lerp(bg_tensor, rgb, mask)
77
  print("[debug] rgb_cond shape:", rgb_cond.shape)
78
 
79
+ # Note: We need to permute the tensors to match the expected shape
80
+ rgb_cond = rgb_cond.permute(2, 0, 1) # Change from [H, W, C] to [C, H, W]
81
+ mask = mask.permute(2, 0, 1) # Change from [H, W, 1] to [1, H, W]
82
+ print("[debug] rgb_cond after permute shape:", rgb_cond.shape)
83
+ print("[debug] mask after permute shape:", mask.shape)
84
+
85
  batch = {
86
+ "rgb_cond": rgb_cond.unsqueeze(0), # Add batch dimension
87
+ "mask_cond": mask.unsqueeze(0),
88
  "c2w_cond": c2w_cond.unsqueeze(0),
89
+ "intrinsic_cond": intrinsic.unsqueeze(0),
90
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
91
  }
92
 
 
130
  rgba_image = create_rgba_image(rgb_image, mask)
131
 
132
  # Auto crop with foreground
133
+ print(f"[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
134
  processed_image = spar3d_utils.foreground_crop(
135
  rgba_image,
136
  crop_ratio=1.3,