vibs08 commited on
Commit
53b232b
·
verified ·
1 Parent(s): 38c07d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -22
app.py CHANGED
@@ -23,8 +23,6 @@ import io
23
  from io import BytesIO
24
  from botocore.exceptions import NoCredentialsError, PartialCredentialsError
25
  import datetime
26
- from transformers.utils import move_cache
27
- move_cache()
28
 
29
  app = FastAPI()
30
 
@@ -41,7 +39,7 @@ if torch.cuda.is_available():
41
  else:
42
  device = "cpu"
43
 
44
- torch.cuda.synchronize()
45
 
46
  model = TSR.from_pretrained(
47
  "stabilityai/TripoSR",
@@ -151,16 +149,16 @@ def check_input_image(input_image):
151
 
152
  def preprocess(input_image, do_remove_background, foreground_ratio):
153
  def fill_background(image):
154
- torch.cuda.synchronize()
155
- torch.cuda.empty_cache()
156
  image = np.array(image).astype(np.float32) / 255.0
157
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
158
  image = Image.fromarray((image * 255.0).astype(np.uint8))
159
  return image
160
 
161
  if do_remove_background:
162
- torch.cuda.synchronize()
163
- torch.cuda.empty_cache()
164
  image = input_image.convert("RGB")
165
  image = remove_background(image, rembg_session)
166
  image = resize_foreground(image, foreground_ratio)
@@ -173,39 +171,39 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
173
  image = input_image
174
  if image.mode == "RGBA":
175
  image = fill_background(image)
176
- torch.cuda.synchronize() # Wait for all CUDA operations to complete
177
- torch.cuda.empty_cache()
178
  return image
179
 
180
  @spaces.GPU
181
  def generate(image, mc_resolution, formats=["obj", "glb"]):
182
- torch.cuda.synchronize()
183
  scene_codes = model(image, device=device)
184
- torch.cuda.synchronize()
185
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
186
- torch.cuda.synchronize()
187
  mesh = to_gradio_3d_orientation(mesh)
188
- torch.cuda.synchronize()
189
 
190
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
191
- torch.cuda.synchronize()
192
  mesh.export(mesh_path_glb.name)
193
- torch.cuda.synchronize()
194
 
195
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
196
- torch.cuda.synchronize()
197
  mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
198
  mesh.export(mesh_path_obj.name)
199
- torch.cuda.synchronize() # Ensure all CUDA operations are complete before clearing cache
200
- torch.cuda.empty_cache()
201
  return mesh_path_obj.name, mesh_path_glb.name
202
 
203
  def upload_file_to_s3(file_path, bucket_name, object_name=None):
204
  s3_client.upload_file(file_path, bucket_name, object_name)
205
 
206
  # print(f"File {file_path} uploaded successfully to {bucket_name}/{object_name}.")
207
- torch.cuda.synchronize() # Wait for all CUDA operations to complete
208
- torch.cuda.empty_cache()
209
  return True
210
 
211
 
@@ -238,8 +236,8 @@ async def process_image(
238
  object_name_2 = f'object_{timestamp}_2.glb'
239
 
240
  if upload_file_to_s3(mesh_name_obj, 'framebucket3d',object_name) and upload_file_to_s3(mesh_name_glb, 'framebucket3d',object_name_2):
241
- torch.cuda.synchronize() # Wait for all CUDA operations to complete
242
- torch.cuda.empty_cache()
243
  return {
244
  "obj_path": f"https://framebucket3d.s3.amazonaws.com/{object_name}",
245
  "glb_path": f"https://framebucket3d.s3.amazonaws.com/{object_name_2}"
 
23
  from io import BytesIO
24
  from botocore.exceptions import NoCredentialsError, PartialCredentialsError
25
  import datetime
 
 
26
 
27
  app = FastAPI()
28
 
 
39
  else:
40
  device = "cpu"
41
 
42
+ # torch.cuda.synchronize()
43
 
44
  model = TSR.from_pretrained(
45
  "stabilityai/TripoSR",
 
149
 
150
  def preprocess(input_image, do_remove_background, foreground_ratio):
151
  def fill_background(image):
152
+ # torch.cuda.synchronize()
153
+ # torch.cuda.empty_cache()
154
  image = np.array(image).astype(np.float32) / 255.0
155
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
156
  image = Image.fromarray((image * 255.0).astype(np.uint8))
157
  return image
158
 
159
  if do_remove_background:
160
+ # torch.cuda.synchronize()
161
+ # torch.cuda.empty_cache()
162
  image = input_image.convert("RGB")
163
  image = remove_background(image, rembg_session)
164
  image = resize_foreground(image, foreground_ratio)
 
171
  image = input_image
172
  if image.mode == "RGBA":
173
  image = fill_background(image)
174
+ # torch.cuda.synchronize() # Wait for all CUDA operations to complete
175
+ # torch.cuda.empty_cache()
176
  return image
177
 
178
  @spaces.GPU
179
  def generate(image, mc_resolution, formats=["obj", "glb"]):
180
+ # torch.cuda.synchronize()
181
  scene_codes = model(image, device=device)
182
+ # torch.cuda.synchronize()
183
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
184
+ # torch.cuda.synchronize()
185
  mesh = to_gradio_3d_orientation(mesh)
186
+ # torch.cuda.synchronize()
187
 
188
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
189
+ # torch.cuda.synchronize()
190
  mesh.export(mesh_path_glb.name)
191
+ # torch.cuda.synchronize()
192
 
193
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
194
+ # torch.cuda.synchronize()
195
  mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
196
  mesh.export(mesh_path_obj.name)
197
+ # torch.cuda.synchronize() # Ensure all CUDA operations are complete before clearing cache
198
+ # torch.cuda.empty_cache()
199
  return mesh_path_obj.name, mesh_path_glb.name
200
 
201
  def upload_file_to_s3(file_path, bucket_name, object_name=None):
202
  s3_client.upload_file(file_path, bucket_name, object_name)
203
 
204
  # print(f"File {file_path} uploaded successfully to {bucket_name}/{object_name}.")
205
+ # torch.cuda.synchronize() # Wait for all CUDA operations to complete
206
+ # torch.cuda.empty_cache()
207
  return True
208
 
209
 
 
236
  object_name_2 = f'object_{timestamp}_2.glb'
237
 
238
  if upload_file_to_s3(mesh_name_obj, 'framebucket3d',object_name) and upload_file_to_s3(mesh_name_glb, 'framebucket3d',object_name_2):
239
+ # torch.cuda.synchronize() # Wait for all CUDA operations to complete
240
+ # torch.cuda.empty_cache()
241
  return {
242
  "obj_path": f"https://framebucket3d.s3.amazonaws.com/{object_name}",
243
  "glb_path": f"https://framebucket3d.s3.amazonaws.com/{object_name_2}"