xinjie.wang commited on
Commit
4d0533e
·
1 Parent(s): 73797d3
Files changed (2) hide show
  1. app.py +7 -27
  2. common.py +21 -8
app.py CHANGED
@@ -35,22 +35,6 @@ TMP_DIR = os.path.join(
35
  )
36
  os.makedirs(TMP_DIR, exist_ok=True)
37
 
38
- RBG_REMOVER = RembgRemover()
39
- SAM_PREDICTOR = SAMPredictor(model_type="vit_h")
40
- # DELIGHT = DelightingModel()
41
- # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
42
- # PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
43
- # "JeffreyXiang/TRELLIS-image-large"
44
- # )
45
- # # PIPELINE.cuda()
46
-
47
- IMAGE_BUFFER = {}
48
- # SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
49
- # GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
50
- # AESTHETIC_CHECKER = ImageAestheticChecker()
51
- # CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
52
- # URDF_CONVERTOR = URDFGenerator(GPT_CLIENT, render_view_num=4)
53
-
54
 
55
  def start_session(req: gr.Request) -> None:
56
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -233,11 +217,7 @@ with gr.Blocks(
233
  )
234
  ],
235
  inputs=[image_prompt],
236
- fn=partial(
237
- preprocess_image_fn,
238
- model=RBG_REMOVER,
239
- buffer=IMAGE_BUFFER,
240
- ),
241
  outputs=[image_prompt],
242
  run_on_click=True,
243
  examples_per_page=32,
@@ -305,7 +285,7 @@ with gr.Blocks(
305
  )
306
 
307
  image_prompt.upload(
308
- partial(preprocess_image_fn, model=gr.State(value=RBG_REMOVER), buffer=gr.State(value=IMAGE_BUFFER)),
309
  inputs=[image_prompt],
310
  outputs=[image_prompt],
311
  )
@@ -431,7 +411,7 @@ with gr.Blocks(
431
  slat_guidance_strength,
432
  slat_sampling_steps,
433
  gr.State(lambda: IMAGE_BUFFER),
434
- # gr.State(lambda: PIPELINE),
435
  gr.State(lambda: TMP_DIR),
436
  image_seg_sam,
437
  is_samimage,
@@ -448,8 +428,8 @@ with gr.Blocks(
448
  output_buf,
449
  project_delight,
450
  gr.State(lambda: TMP_DIR),
451
- # gr.State(lambda: DELIGHT),
452
- # gr.State(lambda: IMAGESR_MODEL),
453
  ],
454
  outputs=[
455
  model_output_mesh,
@@ -472,9 +452,9 @@ with gr.Blocks(
472
  mass_range_text,
473
  asset_version_text,
474
  gr.State(lambda: TMP_DIR),
475
- # gr.State(lambda: URDF_CONVERTOR),
476
  gr.State(lambda: IMAGE_BUFFER),
477
- # gr.State(lambda: CHECKERS),
478
  ],
479
  outputs=[
480
  download_urdf,
 
35
  )
36
  os.makedirs(TMP_DIR, exist_ok=True)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def start_session(req: gr.Request) -> None:
40
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
217
  )
218
  ],
219
  inputs=[image_prompt],
220
+ fn=preprocess_image_fn,
 
 
 
 
221
  outputs=[image_prompt],
222
  run_on_click=True,
223
  examples_per_page=32,
 
285
  )
286
 
287
  image_prompt.upload(
288
+ preprocess_image_fn,
289
  inputs=[image_prompt],
290
  outputs=[image_prompt],
291
  )
 
411
  slat_guidance_strength,
412
  slat_sampling_steps,
413
  gr.State(lambda: IMAGE_BUFFER),
414
+ gr.State(lambda: PIPELINE),
415
  gr.State(lambda: TMP_DIR),
416
  image_seg_sam,
417
  is_samimage,
 
428
  output_buf,
429
  project_delight,
430
  gr.State(lambda: TMP_DIR),
431
+ gr.State(lambda: DELIGHT),
432
+ gr.State(lambda: IMAGESR_MODEL),
433
  ],
434
  outputs=[
435
  model_output_mesh,
 
452
  mass_range_text,
453
  asset_version_text,
454
  gr.State(lambda: TMP_DIR),
455
+ gr.State(lambda: URDF_CONVERTOR),
456
  gr.State(lambda: IMAGE_BUFFER),
457
+ gr.State(lambda: CHECKERS),
458
  ],
459
  outputs=[
460
  download_urdf,
common.py CHANGED
@@ -44,6 +44,7 @@ from asset3d_gen.validators.quality_checkers import (
44
  MeshGeoChecker,
45
  )
46
  from asset3d_gen.validators.urdf_convertor import URDFGenerator, zip_files
 
47
 
48
  current_file_path = os.path.abspath(__file__)
49
  current_dir = os.path.dirname(current_file_path)
@@ -70,6 +71,23 @@ MAX_SEED = 100000
70
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  @spaces.GPU
74
  def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
75
  renderer = MeshRenderer()
@@ -126,21 +144,16 @@ def render_video(
126
  @spaces.GPU
127
  def preprocess_image_fn(
128
  image: str | np.ndarray | Image.Image,
129
- model: DelightingModel | RembgRemover,
130
- buffer: dict = None,
131
  ) -> Image.Image:
132
  if isinstance(image, str):
133
  image = Image.open(image)
134
  elif isinstance(image, np.ndarray):
135
  image = Image.fromarray(image)
136
 
137
- if buffer is not None:
138
- buffer["raw_image"] = image
139
 
140
- if isinstance(model, DelightingModel):
141
- image = model(image, preprocess=True, target_wh=(512, 512))
142
- elif isinstance(model, RembgRemover):
143
- image = model(image)
144
  image = trellis_preprocess(image)
145
 
146
  return image
 
44
  MeshGeoChecker,
45
  )
46
  from asset3d_gen.validators.urdf_convertor import URDFGenerator, zip_files
47
+ from asset3d_gen.utils.gpt_clients import GPT_CLIENT
48
 
49
  current_file_path = os.path.abspath(__file__)
50
  current_dir = os.path.dirname(current_file_path)
 
71
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
72
 
73
 
74
+ RBG_REMOVER = RembgRemover()
75
+ SAM_PREDICTOR = SAMPredictor(model_type="vit_h")
76
+ DELIGHT = DelightingModel()
77
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
78
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
79
+ "JeffreyXiang/TRELLIS-image-large"
80
+ )
81
+ # PIPELINE.cuda()
82
+
83
+ IMAGE_BUFFER = {}
84
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
85
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
86
+ AESTHETIC_CHECKER = ImageAestheticChecker()
87
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
88
+ URDF_CONVERTOR = URDFGenerator(GPT_CLIENT, render_view_num=4)
89
+
90
+
91
  @spaces.GPU
92
  def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
93
  renderer = MeshRenderer()
 
144
  @spaces.GPU
145
  def preprocess_image_fn(
146
  image: str | np.ndarray | Image.Image,
 
 
147
  ) -> Image.Image:
148
  if isinstance(image, str):
149
  image = Image.open(image)
150
  elif isinstance(image, np.ndarray):
151
  image = Image.fromarray(image)
152
 
153
+ if IMAGE_BUFFER is not None:
154
+ IMAGE_BUFFER["raw_image"] = image
155
 
156
+ image = RBG_REMOVER(image)
 
 
 
157
  image = trellis_preprocess(image)
158
 
159
  return image