jbilcke-hf HF staff commited on
Commit
aec7186
·
verified ·
1 Parent(s): c9fb81c

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +15 -12
gradio_app.py CHANGED
@@ -43,14 +43,14 @@ intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
43
  COND_FOVY, COND_HEIGHT, COND_WIDTH
44
  )
45
 
46
- def create_rgba_image(rgb_image: Image.Image, alpha: np.ndarray = None) -> Image.Image:
47
- """Create an RGBA image from RGB image and optional alpha channel."""
48
- if alpha is None:
49
- alpha = np.full(rgb_image.size[::-1], 255, dtype=np.uint8)
50
- rgba = Image.new('RGBA', rgb_image.size)
51
- rgba.paste(rgb_image)
52
- rgba.putalpha(Image.fromarray(alpha))
53
- return rgba
54
 
55
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
56
  """Prepare image batch for model input."""
@@ -100,12 +100,15 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
100
  # Process the generated image
101
  rgb_image = generated_image.convert('RGB')
102
 
103
- # Remove background and get mask
104
- mask = bg_remover.process(rgb_image)
105
- mask_uint8 = (mask * 255).astype(np.uint8)
 
 
 
106
 
107
  # Create RGBA image
108
- rgba_image = create_rgba_image(rgb_image, mask_uint8)
109
 
110
  # Auto crop with foreground
111
  processed_image = spar3d_utils.foreground_crop(
 
43
  COND_FOVY, COND_HEIGHT, COND_WIDTH
44
  )
45
 
46
+ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.Image:
47
+ """Create an RGBA image from RGB image and optional mask."""
48
+ rgba_image = rgb_image.convert('RGBA')
49
+ if mask is not None:
50
+ # Convert mask to alpha channel format
51
+ alpha = Image.fromarray((mask * 255).astype(np.uint8))
52
+ rgba_image.putalpha(alpha)
53
+ return rgba_image
54
 
55
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
56
  """Prepare image batch for model input."""
 
100
  # Process the generated image
101
  rgb_image = generated_image.convert('RGB')
102
 
103
+ # Remove background
104
+ no_bg_image = bg_remover.process(rgb_image)
105
+
106
+ # Convert to numpy array to extract mask
107
+ no_bg_array = np.array(no_bg_image)
108
+ mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
109
 
110
  # Create RGBA image
111
+ rgba_image = create_rgba_image(rgb_image, mask)
112
 
113
  # Auto crop with foreground
114
  processed_image = spar3d_utils.foreground_crop(