vibs08 commited on
Commit
fe7389a
·
verified ·
1 Parent(s): 50b44f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -36,19 +36,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
36
 
37
  HEADER = """FRAME AI"""
38
 
39
- # if torch.cuda.is_available():
40
- # device = "cuda:0"
41
- # else:
42
- # device = "cpu"
43
-
44
- # torch.cuda.set_device(1)
45
-
46
- # CUDA_LAUNCH_BLOCKING=1
47
-
48
-
49
  if torch.cuda.is_available():
50
- torch.cuda.set_device(1)
51
-
 
52
 
53
 
54
  model = TSR.from_pretrained(
@@ -168,6 +159,7 @@ def check_input_image(input_image):
168
  raise gr.Error("No image uploaded!")
169
 
170
  def preprocess(input_image, do_remove_background, foreground_ratio):
 
171
  def fill_background(image):
172
  image = np.array(image).astype(np.float32) / 255.0
173
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
@@ -185,10 +177,12 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
185
  image = input_image
186
  if image.mode == "RGBA":
187
  image = fill_background(image)
 
188
  return image
189
 
190
  # @spaces.GPU
191
  def generate(image, mc_resolution, formats=["obj", "glb"]):
 
192
  torch.cuda.empty_cache()
193
  scene_codes = model(image, device=device)
194
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
@@ -200,6 +194,8 @@ def generate(image, mc_resolution, formats=["obj", "glb"]):
200
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
201
  mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
202
  mesh.export(mesh_path_obj.name)
 
 
203
 
204
  return mesh_path_obj.name, mesh_path_glb.name
205
 
 
36
 
37
  HEADER = """FRAME AI"""
38
 
 
 
 
 
 
 
 
 
 
 
39
  if torch.cuda.is_available():
40
+ device = "cuda:0"
41
+ else:
42
+ device = "cpu"
43
 
44
 
45
  model = TSR.from_pretrained(
 
159
  raise gr.Error("No image uploaded!")
160
 
161
  def preprocess(input_image, do_remove_background, foreground_ratio):
162
+ torch.cuda.synchronize()
163
  def fill_background(image):
164
  image = np.array(image).astype(np.float32) / 255.0
165
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
 
177
  image = input_image
178
  if image.mode == "RGBA":
179
  image = fill_background(image)
180
+ torch.cuda.synchronize()
181
  return image
182
 
183
  # @spaces.GPU
184
  def generate(image, mc_resolution, formats=["obj", "glb"]):
185
+ torch.cuda.synchronize()
186
  torch.cuda.empty_cache()
187
  scene_codes = model(image, device=device)
188
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
 
194
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
195
  mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
196
  mesh.export(mesh_path_obj.name)
197
+
198
+ torch.cuda.synchronize()
199
 
200
  return mesh_path_obj.name, mesh_path_glb.name
201