Spaces:
Running
on
L40S
Running
on
L40S
Update gradio_app.py
Browse files- gradio_app.py +12 -0
gradio_app.py
CHANGED
@@ -88,6 +88,8 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
88 |
try:
|
89 |
# Generate image using FLUX
|
90 |
generator = torch.Generator(device=device).manual_seed(seed)
|
|
|
|
|
91 |
generated_image = flux_pipe(
|
92 |
prompt=prompt,
|
93 |
width=width,
|
@@ -98,19 +100,24 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
98 |
).images[0]
|
99 |
|
100 |
# Process the generated image
|
|
|
101 |
rgb_image = generated_image.convert('RGB')
|
102 |
|
103 |
# Remove background
|
|
|
104 |
no_bg_image = bg_remover.process(rgb_image)
|
105 |
|
106 |
# Convert to numpy array to extract mask
|
|
|
107 |
no_bg_array = np.array(no_bg_image)
|
108 |
mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
|
109 |
|
110 |
# Create RGBA image
|
|
|
111 |
rgba_image = create_rgba_image(rgb_image, mask)
|
112 |
|
113 |
# Auto crop with foreground
|
|
|
114 |
processed_image = spar3d_utils.foreground_crop(
|
115 |
rgba_image,
|
116 |
crop_ratio=1.3,
|
@@ -118,12 +125,14 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
118 |
no_crop=False
|
119 |
)
|
120 |
|
|
|
121 |
# Prepare batch for 3D generation
|
122 |
batch = create_batch(processed_image)
|
123 |
batch = {k: v.to(device) for k, v in batch.items()}
|
124 |
|
125 |
# Generate mesh
|
126 |
with torch.no_grad():
|
|
|
127 |
with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
|
128 |
trimesh_mesh, _ = spar3d_model.generate_mesh(
|
129 |
batch,
|
@@ -135,8 +144,11 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
135 |
trimesh_mesh = trimesh_mesh[0]
|
136 |
|
137 |
# Export to GLB
|
|
|
138 |
temp_dir = tempfile.mkdtemp()
|
139 |
output_path = os.path.join(temp_dir, 'output.glb')
|
|
|
|
|
140 |
trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
|
141 |
|
142 |
return output_path, generated_image
|
|
|
88 |
try:
|
89 |
# Generate image using FLUX
|
90 |
generator = torch.Generator(device=device).manual_seed(seed)
|
91 |
+
|
92 |
+
print("[debug] generating the image using Flux")
|
93 |
generated_image = flux_pipe(
|
94 |
prompt=prompt,
|
95 |
width=width,
|
|
|
100 |
).images[0]
|
101 |
|
102 |
# Process the generated image
|
103 |
+
print("[debug] converting the image to rgb")
|
104 |
rgb_image = generated_image.convert('RGB')
|
105 |
|
106 |
# Remove background
|
107 |
+
print("[debug] removing the background by calling bg_remover.process(rgb_image)")
|
108 |
no_bg_image = bg_remover.process(rgb_image)
|
109 |
|
110 |
# Convert to numpy array to extract mask
|
111 |
+
print("[debug] converting to numpy array to extract the mask")
|
112 |
no_bg_array = np.array(no_bg_image)
|
113 |
mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
|
114 |
|
115 |
# Create RGBA image
|
116 |
+
print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
|
117 |
rgba_image = create_rgba_image(rgb_image, mask)
|
118 |
|
119 |
# Auto crop with foreground
|
120 |
+
print(f"[debug] auto-cromming the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
|
121 |
processed_image = spar3d_utils.foreground_crop(
|
122 |
rgba_image,
|
123 |
crop_ratio=1.3,
|
|
|
125 |
no_crop=False
|
126 |
)
|
127 |
|
128 |
+
print("[debug] preparing the batch by calling create_batch(processed_image)")
|
129 |
# Prepare batch for 3D generation
|
130 |
batch = create_batch(processed_image)
|
131 |
batch = {k: v.to(device) for k, v in batch.items()}
|
132 |
|
133 |
# Generate mesh
|
134 |
with torch.no_grad():
|
135 |
+
print("[debug] calling torch.autocast(....) to generate the mesh")
|
136 |
with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
|
137 |
trimesh_mesh, _ = spar3d_model.generate_mesh(
|
138 |
batch,
|
|
|
144 |
trimesh_mesh = trimesh_mesh[0]
|
145 |
|
146 |
# Export to GLB
|
147 |
+
print("[debug] creating tmp dir for the .glb output")
|
148 |
temp_dir = tempfile.mkdtemp()
|
149 |
output_path = os.path.join(temp_dir, 'output.glb')
|
150 |
+
|
151 |
+
print("[debug] calling trimesh_mesh.export(...) to export to .glb")
|
152 |
trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
|
153 |
|
154 |
return output_path, generated_image
|