xinjie.wang commited on
Commit
211069d
·
1 Parent(s): 95a94b9
Files changed (3) hide show
  1. app.py +7 -3
  2. common.py +10 -12
  3. requirements.txt +3 -3
app.py CHANGED
@@ -35,13 +35,17 @@ with gr.Blocks(
35
  with gr.Tab(
36
  label="Image(auto seg)", id=0
37
  ) as single_image_input_tab:
38
- raw_image_cache = gr.State()
 
 
 
 
39
  image_prompt = gr.Image(
40
  label="Input Image",
41
  format="png",
42
  image_mode="RGBA",
43
  type="pil",
44
- height=500,
45
  )
46
  gr.Markdown(
47
  """
@@ -439,4 +443,4 @@ with gr.Blocks(
439
 
440
 
441
  if __name__ == "__main__":
442
- demo.launch()
 
35
  with gr.Tab(
36
  label="Image(auto seg)", id=0
37
  ) as single_image_input_tab:
38
+ raw_image_cache = gr.Image(
39
+ format="png",
40
+ image_mode="RGB",
41
+ type="pil",
42
+ )
43
  image_prompt = gr.Image(
44
  label="Input Image",
45
  format="png",
46
  image_mode="RGBA",
47
  type="pil",
48
+ height=400,
49
  )
50
  gr.Markdown(
51
  """
 
443
 
444
 
445
  if __name__ == "__main__":
446
+ demo.launch(server_name="10.34.8.82", server_port=8084)
common.py CHANGED
@@ -193,8 +193,8 @@ def render_video(
193
 
194
  @spaces.GPU
195
  def preprocess_image_fn(
196
- image: str | np.ndarray | Image.Image,
197
- ) -> Image.Image:
198
  if isinstance(image, str):
199
  image = Image.open(image)
200
  elif isinstance(image, np.ndarray):
@@ -210,8 +210,8 @@ def preprocess_image_fn(
210
 
211
  @spaces.GPU
212
  def preprocess_sam_image_fn(
213
- image: Image.Image, req: gr.Request
214
- ) -> Image.Image:
215
  if isinstance(image, np.ndarray):
216
  image = Image.fromarray(image)
217
 
@@ -359,11 +359,10 @@ def image_to_3d(
359
  if isinstance(seg_image, np.ndarray):
360
  seg_image = Image.fromarray(seg_image)
361
 
362
- print("raw_image_cache", raw_image_cache)
363
- print("seg_image", seg_image)
364
- os.makedirs(f"{TMP_DIR}/{req.session_hash}", exist_ok=True)
365
- seg_image.save(f"{TMP_DIR}/{req.session_hash}/seg_image.png")
366
- raw_image_cache.save(f"{TMP_DIR}/{req.session_hash}/raw_image.png")
367
  PIPELINE.cuda()
368
  outputs = PIPELINE.run(
369
  seg_image,
@@ -386,9 +385,8 @@ def image_to_3d(
386
  mesh_model = outputs["mesh"][0]
387
  color_images = render_video(gs_model)["color"]
388
  normal_images = render_video(mesh_model)["normal"]
389
- output_root = TMP_DIR
390
- if req is not None:
391
- output_root = os.path.join(output_root, str(req.session_hash))
392
  video_path = os.path.join(output_root, "gs_mesh.mp4")
393
  merge_images_video(color_images, normal_images, video_path)
394
  state = pack_state(gs_model, mesh_model)
 
193
 
194
  @spaces.GPU
195
  def preprocess_image_fn(
196
+ image: str | np.ndarray | Image.Image
197
+ ) -> tuple[Image.Image, Image.Image]:
198
  if isinstance(image, str):
199
  image = Image.open(image)
200
  elif isinstance(image, np.ndarray):
 
210
 
211
  @spaces.GPU
212
  def preprocess_sam_image_fn(
213
+ image: Image.Image
214
+ ) -> tuple[Image.Image, Image.Image]:
215
  if isinstance(image, np.ndarray):
216
  image = Image.fromarray(image)
217
 
 
359
  if isinstance(seg_image, np.ndarray):
360
  seg_image = Image.fromarray(seg_image)
361
 
362
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
363
+ os.makedirs(output_root, exist_ok=True)
364
+ seg_image.save(f"{output_root}/seg_image.png")
365
+ raw_image_cache.save(f"{output_root}/raw_image.png")
 
366
  PIPELINE.cuda()
367
  outputs = PIPELINE.run(
368
  seg_image,
 
385
  mesh_model = outputs["mesh"][0]
386
  color_images = render_video(gs_model)["color"]
387
  normal_images = render_video(mesh_model)["normal"]
388
+
389
+
 
390
  video_path = os.path.join(output_root, "gs_mesh.mp4")
391
  merge_images_video(color_images, normal_images, video_path)
392
  state = pack_state(gs_model, mesh_model)
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu121
2
 
3
 
4
- torch==2.4.0
5
- torchvision==0.19.0
6
- xformers==0.0.27.post2
7
  pytorch-lightning==2.4.0
8
  spconv-cu120==2.3.6
9
  dataclasses_json
 
1
  --extra-index-url https://download.pytorch.org/whl/cu121
2
 
3
 
4
+ torch==2.4.0+cu121
5
+ torchvision==0.19.0+cu121
6
+ xformers==0.0.27.post2+cu121
7
  pytorch-lightning==2.4.0
8
  spconv-cu120==2.3.6
9
  dataclasses_json