vibs08 commited on
Commit
fbf7ae4
·
verified ·
1 Parent(s): 0c4832c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -30
app.py CHANGED
@@ -147,63 +147,52 @@ def check_input_image(input_image):
147
 
148
  def preprocess(input_image, do_remove_background, foreground_ratio):
149
  def fill_background(image):
150
- # torch.cuda.synchronize()
151
- # torch.cuda.empty_cache()
152
  image = np.array(image).astype(np.float32) / 255.0
153
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
154
  image = Image.fromarray((image * 255.0).astype(np.uint8))
155
  return image
156
 
157
  if do_remove_background:
158
- # torch.cuda.synchronize()
159
- # torch.cuda.empty_cache()
160
  image = input_image.convert("RGB")
161
  image = remove_background(image, rembg_session)
162
  image = resize_foreground(image, foreground_ratio)
163
  image = fill_background(image)
164
 
165
- # torch.cuda.synchronize()
166
- # torch.cuda.empty_cache()
167
- print("Background Removed")
168
  else:
169
  image = input_image
170
  if image.mode == "RGBA":
171
  image = fill_background(image)
172
- # torch.cuda.synchronize() # Wait for all CUDA operations to complete
173
- # torch.cuda.empty_cache()
174
  return image
175
 
176
- # @spaces.GPU
 
177
  def generate(image, mc_resolution, formats=["obj", "glb"]):
178
- # torch.cuda.synchronize()
179
  scene_codes = model(image, device=device)
180
- # torch.cuda.synchronize()
181
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
182
- # torch.cuda.synchronize()
183
  mesh = to_gradio_3d_orientation(mesh)
184
- # torch.cuda.synchronize()
185
-
186
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
187
- # torch.cuda.synchronize()
188
  mesh.export(mesh_path_glb.name)
189
- # torch.cuda.synchronize()
190
-
191
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
192
- # torch.cuda.synchronize()
193
- mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
194
  mesh.export(mesh_path_obj.name)
195
- # torch.cuda.synchronize() # Ensure all CUDA operations are complete before clearing cache
196
- # torch.cuda.empty_cache()
197
  return mesh_path_obj.name, mesh_path_glb.name
198
 
199
- def upload_file_to_s3(file_path, bucket_name, object_name=None):
200
- s3_client.upload_file(file_path, bucket_name, object_name)
201
-
202
- # print(f"File {file_path} uploaded successfully to {bucket_name}/{object_name}.")
203
- # torch.cuda.synchronize() # Wait for all CUDA operations to complete
204
- # torch.cuda.empty_cache()
205
- return True
206
-
207
 
208
 
209
  @app.post("/process_image/")
 
147
 
148
  def preprocess(input_image, do_remove_background, foreground_ratio):
149
  def fill_background(image):
150
+ torch.cuda.synchronize() # Ensure previous CUDA operations are complete
 
151
  image = np.array(image).astype(np.float32) / 255.0
152
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
153
  image = Image.fromarray((image * 255.0).astype(np.uint8))
154
  return image
155
 
156
  if do_remove_background:
157
+ torch.cuda.synchronize()
 
158
  image = input_image.convert("RGB")
159
  image = remove_background(image, rembg_session)
160
  image = resize_foreground(image, foreground_ratio)
161
  image = fill_background(image)
162
 
163
+ torch.cuda.synchronize()
 
 
164
  else:
165
  image = input_image
166
  if image.mode == "RGBA":
167
  image = fill_background(image)
168
+ torch.cuda.synchronize() # Wait for all CUDA operations to complete
169
+ torch.cuda.empty_cache()
170
  return image
171
 
172
+
173
+
174
  def generate(image, mc_resolution, formats=["obj", "glb"]):
175
+ torch.cuda.synchronize()
176
  scene_codes = model(image, device=device)
177
+ torch.cuda.synchronize()
178
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
179
+ torch.cuda.synchronize()
180
  mesh = to_gradio_3d_orientation(mesh)
181
+ torch.cuda.synchronize()
182
+
183
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
184
+ torch.cuda.synchronize()
185
  mesh.export(mesh_path_glb.name)
186
+ torch.cuda.synchronize()
187
+
188
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
189
+ torch.cuda.synchronize()
190
+ mesh.apply_scale([-1, 1, 1])
191
  mesh.export(mesh_path_obj.name)
192
+ torch.cuda.synchronize()
193
+ torch.cuda.empty_cache()
194
  return mesh_path_obj.name, mesh_path_glb.name
195
 
 
 
 
 
 
 
 
 
196
 
197
 
198
  @app.post("/process_image/")