jbilcke-hf HF staff committed on
Commit
daf9fe6
·
verified ·
1 Parent(s): e973397

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +44 -25
gradio_app.py CHANGED
@@ -43,31 +43,47 @@ intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
43
  COND_FOVY, COND_HEIGHT, COND_WIDTH
44
  )
45
 
46
- def create_batch(input_image: Image) -> dict[str, Any]:
 
 
 
 
 
 
 
 
 
47
  """Prepare image batch for model input."""
48
- img_cond = (
49
- torch.from_numpy(
50
- np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
51
- / 255.0
52
- )
53
- .float()
54
- .clip(0, 1)
55
- )
56
- mask_cond = img_cond[:, :, -1:]
57
- rgb_cond = torch.lerp(
58
- torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
59
- )
 
 
 
 
 
 
 
60
 
61
  batch = {
62
  "rgb_cond": rgb_cond.unsqueeze(0),
63
- "mask_cond": mask_cond.unsqueeze(0),
64
  "c2w_cond": c2w_cond.unsqueeze(0),
65
  "intrinsic_cond": intrinsic.unsqueeze(0),
66
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
67
  }
68
  return batch
69
 
70
- def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str, Image.Image]:
71
  """Generate image from prompt and convert to 3D model."""
72
  try:
73
  # Generate image using FLUX
@@ -81,23 +97,26 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
81
  guidance_scale=0.0
82
  ).images[0]
83
 
84
- # Convert PIL image to RGBA
85
- input_image = generated_image.convert("RGBA")
 
 
 
 
86
 
87
- # Remove background
88
- rgba_image = bg_remover.process(input_image.convert("RGB"))
89
- rgba_image.putalpha(255) # Add alpha channel
90
 
91
- # Auto crop
92
- input_image = spar3d_utils.foreground_crop(
93
  rgba_image,
94
  crop_ratio=1.3,
95
  newsize=(COND_WIDTH, COND_HEIGHT),
96
  no_crop=False
97
  )
98
 
99
- # Prepare batch
100
- batch = create_batch(input_image)
101
  batch = {k: v.to(device) for k, v in batch.items()}
102
 
103
  # Generate mesh
@@ -120,7 +139,7 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
120
  return output_path, generated_image
121
 
122
  except Exception as e:
123
- print(f"Error: {str(e)}")
124
  return None, None
125
 
126
  # Create Gradio interface
 
43
  COND_FOVY, COND_HEIGHT, COND_WIDTH
44
  )
45
 
46
+ def create_rgba_image(rgb_image: Image.Image, alpha: np.ndarray = None) -> Image.Image:
47
+ """Create an RGBA image from RGB image and optional alpha channel."""
48
+ if alpha is None:
49
+ alpha = np.full(rgb_image.size[::-1], 255, dtype=np.uint8)
50
+ rgba = Image.new('RGBA', rgb_image.size)
51
+ rgba.paste(rgb_image)
52
+ rgba.putalpha(Image.fromarray(alpha))
53
+ return rgba
54
+
55
+ def create_batch(input_image: Image.Image) -> dict[str, Any]:
56
  """Prepare image batch for model input."""
57
+ # Ensure input is RGBA
58
+ if input_image.mode != 'RGBA':
59
+ input_image = input_image.convert('RGBA')
60
+
61
+ # Resize and convert to numpy array
62
+ resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
63
+ img_array = np.array(resized_image).astype(np.float32) / 255.0
64
+
65
+ # Split into RGB and alpha
66
+ rgb = img_array[..., :3]
67
+ alpha = img_array[..., 3:4]
68
+
69
+ # Convert to tensors
70
+ rgb_tensor = torch.from_numpy(rgb).float()
71
+ alpha_tensor = torch.from_numpy(alpha).float()
72
+
73
+ # Create background blend
74
+ bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
75
+ rgb_cond = torch.lerp(bg_tensor, rgb_tensor, alpha_tensor)
76
 
77
  batch = {
78
  "rgb_cond": rgb_cond.unsqueeze(0),
79
+ "mask_cond": alpha_tensor.unsqueeze(0),
80
  "c2w_cond": c2w_cond.unsqueeze(0),
81
  "intrinsic_cond": intrinsic.unsqueeze(0),
82
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
83
  }
84
  return batch
85
 
86
+ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
87
  """Generate image from prompt and convert to 3D model."""
88
  try:
89
  # Generate image using FLUX
 
97
  guidance_scale=0.0
98
  ).images[0]
99
 
100
+ # Process the generated image
101
+ rgb_image = generated_image.convert('RGB')
102
+
103
+ # Remove background and get mask
104
+ mask = bg_remover.process_image(rgb_image)
105
+ mask_uint8 = (mask * 255).astype(np.uint8)
106
 
107
+ # Create RGBA image
108
+ rgba_image = create_rgba_image(rgb_image, mask_uint8)
 
109
 
110
+ # Auto crop with foreground
111
+ processed_image = spar3d_utils.foreground_crop(
112
  rgba_image,
113
  crop_ratio=1.3,
114
  newsize=(COND_WIDTH, COND_HEIGHT),
115
  no_crop=False
116
  )
117
 
118
+ # Prepare batch for 3D generation
119
+ batch = create_batch(processed_image)
120
  batch = {k: v.to(device) for k, v in batch.items()}
121
 
122
  # Generate mesh
 
139
  return output_path, generated_image
140
 
141
  except Exception as e:
142
+ print(f"Error during generation: {str(e)}")
143
  return None, None
144
 
145
  # Create Gradio interface