Update gradio_app.py
- gradio_app.py: +44 -25

gradio_app.py CHANGED
@@ -43,31 +43,47 @@ intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
     COND_FOVY, COND_HEIGHT, COND_WIDTH
 )
 
-def
+def create_rgba_image(rgb_image: Image.Image, alpha: np.ndarray = None) -> Image.Image:
+    """Create an RGBA image from RGB image and optional alpha channel."""
+    if alpha is None:
+        alpha = np.full(rgb_image.size[::-1], 255, dtype=np.uint8)
+    rgba = Image.new('RGBA', rgb_image.size)
+    rgba.paste(rgb_image)
+    rgba.putalpha(Image.fromarray(alpha))
+    return rgba
+
+def create_batch(input_image: Image.Image) -> dict[str, Any]:
     """Prepare image batch for model input."""
-
-
-
-
-
-
-
-
-
-
-
-
+    # Ensure input is RGBA
+    if input_image.mode != 'RGBA':
+        input_image = input_image.convert('RGBA')
+
+    # Resize and convert to numpy array
+    resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
+    img_array = np.array(resized_image).astype(np.float32) / 255.0
+
+    # Split into RGB and alpha
+    rgb = img_array[..., :3]
+    alpha = img_array[..., 3:4]
+
+    # Convert to tensors
+    rgb_tensor = torch.from_numpy(rgb).float()
+    alpha_tensor = torch.from_numpy(alpha).float()
+
+    # Create background blend
+    bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
+    rgb_cond = torch.lerp(bg_tensor, rgb_tensor, alpha_tensor)
 
     batch = {
         "rgb_cond": rgb_cond.unsqueeze(0),
-        "mask_cond":
+        "mask_cond": alpha_tensor.unsqueeze(0),
         "c2w_cond": c2w_cond.unsqueeze(0),
         "intrinsic_cond": intrinsic.unsqueeze(0),
         "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
     }
     return batch
 
-def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str, Image.Image]:
+def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
     """Generate image from prompt and convert to 3D model."""
     try:
         # Generate image using FLUX
@@ -81,23 +97,26 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
             guidance_scale=0.0
         ).images[0]
 
-        #
-
+        # Process the generated image
+        rgb_image = generated_image.convert('RGB')
+
+        # Remove background and get mask
+        mask = bg_remover.process_image(rgb_image)
+        mask_uint8 = (mask * 255).astype(np.uint8)
 
-        #
-        rgba_image =
-        rgba_image.putalpha(255) # Add alpha channel
+        # Create RGBA image
+        rgba_image = create_rgba_image(rgb_image, mask_uint8)
 
-        # Auto crop
-
+        # Auto crop with foreground
+        processed_image = spar3d_utils.foreground_crop(
             rgba_image,
             crop_ratio=1.3,
             newsize=(COND_WIDTH, COND_HEIGHT),
             no_crop=False
         )
 
-        # Prepare batch
-        batch = create_batch(
+        # Prepare batch for 3D generation
+        batch = create_batch(processed_image)
         batch = {k: v.to(device) for k, v in batch.items()}
 
         # Generate mesh
@@ -120,7 +139,7 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
         return output_path, generated_image
 
     except Exception as e:
-        print(f"Error: {str(e)}")
+        print(f"Error during generation: {str(e)}")
         return None, None
 
 # Create Gradio interface
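
For reference, a minimal standalone sketch (not taken from this commit) of what the two new pieces do: create_rgba_image attaches a uint8 mask as the alpha channel of a PIL image, and the torch.lerp call in create_batch composites the masked RGB over the background color (bg * (1 - alpha) + rgb * alpha). The helper is repeated verbatim so the snippet runs on its own; the 64x64 test image, the synthetic half mask, and the white stand-in for BACKGROUND_COLOR are illustrative assumptions.

# Standalone sketch: exercises the new create_rgba_image helper and the
# background blend used in create_batch. Only PIL, NumPy and torch are needed;
# the sample image, mask and background color below are made up for illustration.
import numpy as np
import torch
from PIL import Image


def create_rgba_image(rgb_image: Image.Image, alpha: np.ndarray = None) -> Image.Image:
    """Create an RGBA image from RGB image and optional alpha channel."""
    if alpha is None:
        alpha = np.full(rgb_image.size[::-1], 255, dtype=np.uint8)
    rgba = Image.new('RGBA', rgb_image.size)
    rgba.paste(rgb_image)
    rgba.putalpha(Image.fromarray(alpha))
    return rgba


# A 64x64 red image whose right half is masked out (alpha = 0).
rgb = Image.new('RGB', (64, 64), (255, 0, 0))
mask = np.zeros((64, 64), dtype=np.uint8)
mask[:, :32] = 255
rgba = create_rgba_image(rgb, mask)
assert rgba.getpixel((0, 0)) == (255, 0, 0, 255)   # left half kept
assert rgba.getpixel((63, 0)) == (255, 0, 0, 0)    # right half transparent

# The blend in create_batch: torch.lerp(bg, rgb, alpha) == bg*(1-alpha) + rgb*alpha,
# i.e. masked-out pixels fall back to the background color (white here).
img = np.array(rgba).astype(np.float32) / 255.0
rgb_t, alpha_t = torch.from_numpy(img[..., :3]), torch.from_numpy(img[..., 3:4])
bg_t = torch.tensor([1.0, 1.0, 1.0])[None, None, :]        # stand-in for BACKGROUND_COLOR
rgb_cond = torch.lerp(bg_t, rgb_t, alpha_t)
assert torch.allclose(rgb_cond[0, 0], torch.tensor([1.0, 0.0, 0.0]))   # foreground pixel
assert torch.allclose(rgb_cond[0, 63], torch.tensor([1.0, 1.0, 1.0]))  # background pixel

This mirrors what the commit changes in generate_and_process_3d: the mask produced by bg_remover now drives both the "mask_cond" entry and the background compositing, instead of the previous constant putalpha(255).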