jbilcke-hf HF staff commited on
Commit
e973397
·
verified ·
1 Parent(s): 592f4b9

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +17 -15
gradio_app.py CHANGED
@@ -67,11 +67,11 @@ def create_batch(input_image: Image) -> dict[str, Any]:
67
  }
68
  return batch
69
 
70
- def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> str:
71
  """Generate image from prompt and convert to 3D model."""
72
  try:
73
  # Generate image using FLUX
74
- generator = torch.Generator().manual_seed(seed)
75
  generated_image = flux_pipe(
76
  prompt=prompt,
77
  width=width,
@@ -84,12 +84,13 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
84
  # Convert PIL image to RGBA
85
  input_image = generated_image.convert("RGBA")
86
 
87
- # Remove background if needed
88
- input_image = bg_remover.process(input_image.convert("RGB"))
 
89
 
90
  # Auto crop
91
  input_image = spar3d_utils.foreground_crop(
92
- input_image,
93
  crop_ratio=1.3,
94
  newsize=(COND_WIDTH, COND_HEIGHT),
95
  no_crop=False
@@ -101,7 +102,7 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
101
 
102
  # Generate mesh
103
  with torch.no_grad():
104
- with torch.autocast(device_type=device, dtype=torch.bfloat16):
105
  trimesh_mesh, _ = spar3d_model.generate_mesh(
106
  batch,
107
  1024, # texture_resolution
@@ -112,16 +113,17 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
112
  trimesh_mesh = trimesh_mesh[0]
113
 
114
  # Export to GLB
115
- temp_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
116
- trimesh_mesh.export(temp_file.name, file_type="glb", include_normals=True)
 
117
 
118
- return temp_file.name, generated_image
119
 
120
  except Exception as e:
121
- return str(e), None
122
-
123
-
124
 
 
125
  demo = gr.Interface(
126
  fn=generate_and_process_3d,
127
  inputs=[
@@ -153,8 +155,8 @@ demo = gr.Interface(
153
  ],
154
  outputs=[
155
  gr.File(
156
- label="Download GLB",
157
- file_types=[".glb"],
158
  ),
159
  gr.Image(
160
  label="Generated Image",
@@ -166,4 +168,4 @@ demo = gr.Interface(
166
  )
167
 
168
  if __name__ == "__main__":
169
- demo.launch()
 
67
  }
68
  return batch
69
 
70
+ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str, Image.Image]:
71
  """Generate image from prompt and convert to 3D model."""
72
  try:
73
  # Generate image using FLUX
74
+ generator = torch.Generator(device=device).manual_seed(seed)
75
  generated_image = flux_pipe(
76
  prompt=prompt,
77
  width=width,
 
84
  # Convert PIL image to RGBA
85
  input_image = generated_image.convert("RGBA")
86
 
87
+ # Remove background
88
+ rgba_image = bg_remover.process(input_image.convert("RGB"))
89
+ rgba_image.putalpha(255) # Add alpha channel
90
 
91
  # Auto crop
92
  input_image = spar3d_utils.foreground_crop(
93
+ rgba_image,
94
  crop_ratio=1.3,
95
  newsize=(COND_WIDTH, COND_HEIGHT),
96
  no_crop=False
 
102
 
103
  # Generate mesh
104
  with torch.no_grad():
105
+ with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
106
  trimesh_mesh, _ = spar3d_model.generate_mesh(
107
  batch,
108
  1024, # texture_resolution
 
113
  trimesh_mesh = trimesh_mesh[0]
114
 
115
  # Export to GLB
116
+ temp_dir = tempfile.mkdtemp()
117
+ output_path = os.path.join(temp_dir, 'output.glb')
118
+ trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
119
 
120
+ return output_path, generated_image
121
 
122
  except Exception as e:
123
+ print(f"Error: {str(e)}")
124
+ return None, None
 
125
 
126
+ # Create Gradio interface
127
  demo = gr.Interface(
128
  fn=generate_and_process_3d,
129
  inputs=[
 
155
  ],
156
  outputs=[
157
  gr.File(
158
+ label="Download 3D Model",
159
+ file_types=[".glb"]
160
  ),
161
  gr.Image(
162
  label="Generated Image",
 
168
  )
169
 
170
  if __name__ == "__main__":
171
+ demo.queue().launch()