jbilcke-hf HF staff commited on
Commit
e02679c
·
verified ·
1 Parent(s): f779fbc

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +26 -29
gradio_app.py CHANGED
@@ -53,35 +53,32 @@ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.
53
  return rgba_image
54
 
55
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
56
- """Prepare image batch for model input."""
57
- # Ensure input is RGBA
58
- if input_image.mode != 'RGBA':
59
- input_image = input_image.convert('RGBA')
60
-
61
- # Resize and convert to numpy array
62
- resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
63
- img_array = np.array(resized_image).astype(np.float32) / 255.0
64
-
65
- # Split into RGB and alpha
66
- rgb = img_array[..., :3]
67
- alpha = img_array[..., 3:4]
68
-
69
- # Convert to tensors
70
- rgb_tensor = torch.from_numpy(rgb).float()
71
- alpha_tensor = torch.from_numpy(alpha).float()
72
-
73
- # Create background blend
74
- bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
75
- rgb_cond = torch.lerp(bg_tensor, rgb_tensor, alpha_tensor)
76
-
77
- batch = {
78
- "rgb_cond": rgb_cond.unsqueeze(0),
79
- "mask_cond": alpha_tensor.unsqueeze(0),
80
- "c2w_cond": c2w_cond.unsqueeze(0),
81
- "intrinsic_cond": intrinsic.unsqueeze(0),
82
- "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
83
- }
84
- return batch
85
 
86
  def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
87
  """Generate image from prompt and convert to 3D model."""
 
53
  return rgba_image
54
 
55
  def create_batch(input_image: Image.Image) -> dict[str, Any]:
56
+ """Prepare image batch for model input."""
57
+ # Ensure input is RGBA
58
+ if input_image.mode != 'RGBA':
59
+ input_image = input_image.convert('RGBA')
60
+
61
+ # Resize and convert to numpy array
62
+ resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
63
+ img_array = np.array(resized_image).astype(np.float32) / 255.0
64
+
65
+ # Split into RGB and alpha
66
+ mask_cond = img_array[..., 3:4] # Alpha channel
67
+ # Blend RGB with background based on alpha
68
+ rgb_cond = np.clip(
69
+ img_array[..., :3] * mask_cond + BACKGROUND_COLOR * (1 - mask_cond),
70
+ 0,
71
+ 1
72
+ )
73
+
74
+ batch = {
75
+ "rgb_cond": torch.from_numpy(rgb_cond).unsqueeze(0),
76
+ "mask_cond": torch.from_numpy(mask_cond).unsqueeze(0),
77
+ "c2w_cond": c2w_cond.unsqueeze(0),
78
+ "intrinsic_cond": intrinsic.unsqueeze(0),
79
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
80
+ }
81
+ return batch
 
 
 
82
 
83
  def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
84
  """Generate image from prompt and convert to 3D model."""