xinjie.wang committed
Commit ffe3ce4 · 1 Parent(s): 07a8a18
Files changed (3)
  1. app.py +7 -41
  2. asset3d_gen/models/text_model.py +1 -1
  3. common.py +290 -57
app.py CHANGED
@@ -1,12 +1,12 @@
 import os
-import shutil
+os.environ["GRADIO_APP"] = "imageto3d"
+
 from functools import partial
 
 import gradio as gr
 from common import (
     MAX_SEED,
     VERSION,
-    TrellisImageTo3DPipeline,
     active_btn_by_content,
     extract_3d_representations_v2,
     extract_urdf,
@@ -15,36 +15,13 @@ from common import (
     preprocess_image_fn,
     preprocess_sam_image_fn,
     select_point,
+    start_session,
+    end_session,
 )
 from gradio.themes import Default
 from gradio.themes.utils.colors import slate
-from gradio_litmodel3d import LitModel3D
-from asset3d_gen.models.delight_model import DelightingModel
-from asset3d_gen.models.segment_model import RembgRemover, SAMPredictor
-from asset3d_gen.models.sr_model import ImageRealESRGAN
-from asset3d_gen.utils.gpt_clients import GPT_CLIENT
-from asset3d_gen.validators.quality_checkers import (
-    ImageAestheticChecker,
-    ImageSegChecker,
-    MeshGeoChecker,
-)
-from asset3d_gen.validators.urdf_convertor import URDFGenerator
-
-TMP_DIR = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
-)
-os.makedirs(TMP_DIR, exist_ok=True)
-
-
-def start_session(req: gr.Request) -> None:
-    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    os.makedirs(user_dir, exist_ok=True)
-
-
-def end_session(req: gr.Request) -> None:
-    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    if os.path.exists(user_dir):
-        shutil.rmtree(user_dir)
 
 
 with gr.Blocks(
@@ -220,7 +197,7 @@ with gr.Blocks(
                     fn=preprocess_image_fn,
                     outputs=[image_prompt],
                     run_on_click=True,
-                    examples_per_page=32,
+                    examples_per_page=10,
                 )
 
         with gr.Row(visible=False) as single_sam_image_example:
@@ -236,7 +213,7 @@ with gr.Blocks(
                     fn=preprocess_sam_image_fn,
                     outputs=[image_prompt_sam],
                     run_on_click=True,
-                    examples_per_page=32,
+                    examples_per_page=10,
                 )
         with gr.Column(scale=1):
             video_output = gr.Video(
@@ -246,7 +223,7 @@ with gr.Blocks(
                 height=300,
             )
             model_output_gs = gr.Model3D(
-                label="Gaussian Representation", height=300, interactive=False  # , clear_color=[0.9, 0.9, 0.9, 1.0],
+                label="Gaussian Representation", height=300, interactive=False
             )
             aligned_gs = gr.Textbox(visible=False)
             with gr.Row():
@@ -381,7 +358,6 @@ with gr.Blocks(
             image_prompt_sam,
             selected_points,
             fg_bg_radio,
-            # gr.State(lambda: SAM_PREDICTOR),
         ],
         [image_mask_sam, image_seg_sam],
     )
@@ -404,9 +380,6 @@ with gr.Blocks(
             ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
-            # gr.State(lambda: IMAGE_BUFFER),
-            # gr.State(lambda: PIPELINE),
-            gr.State(lambda: TMP_DIR),
            image_seg_sam,
            is_samimage,
        ],
@@ -421,9 +394,6 @@ with gr.Blocks(
         inputs=[
             output_buf,
             project_delight,
-            gr.State(lambda: TMP_DIR),
-            # gr.State(lambda: DELIGHT),
-            # gr.State(lambda: IMAGESR_MODEL),
         ],
         outputs=[
             model_output_mesh,
@@ -445,10 +415,6 @@ with gr.Blocks(
             height_range_text,
             mass_range_text,
             asset_version_text,
-            gr.State(lambda: TMP_DIR),
-            # gr.State(lambda: URDF_CONVERTOR),
-            # gr.State(lambda: IMAGE_BUFFER),
-            # gr.State(lambda: CHECKERS),
         ],
         outputs=[
             download_urdf,
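
Note: with the per-session setup deleted here, app.py now imports start_session and end_session from common.py. The event wiring itself falls outside the hunks above, so the exact calls are assumed; a minimal sketch of how such hooks are typically attached in a Gradio Blocks app:

import gradio as gr
from common import start_session, end_session

with gr.Blocks() as demo:
    ...  # UI components and event handlers
    # Assumed wiring: create sessions/<session_hash>/ when a client connects,
    # and remove it when the client disconnects.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()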
asset3d_gen/models/text_model.py CHANGED
@@ -75,7 +75,7 @@ def build_text2img_ip_pipeline(
     pipe.set_ip_adapter_scale([ref_scale])
 
     pipe = pipe.to(device)
-    pipe.enable_model_cpu_offload()
+    # pipe.enable_model_cpu_offload()
     # pipe.enable_xformers_memory_efficient_attention()
     # pipe.enable_vae_slicing()
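
Note: disabling enable_model_cpu_offload() keeps the whole pipeline resident on the target device. In diffusers, the offload hook manages device placement itself and is generally not meant to be combined with an explicit pipe.to(device). A sketch of the two mutually exclusive placement strategies (the pipeline class here is assumed for illustration; this file builds its pipeline via build_text2img_ip_pipeline):

import torch
from diffusers import StableDiffusionXLPipeline  # assumed pipeline class

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Option A (what this commit keeps): pin everything on one GPU for speed.
pipe = pipe.to("cuda")

# Option B: accelerate-managed offload; submodules hop CPU<->GPU per forward
# pass to save VRAM at a latency cost. Do not combine with .to("cuda").
# pipe.enable_model_cpu_offload()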
 
common.py CHANGED
@@ -4,8 +4,9 @@ import os
 import sys
 from glob import glob
 from typing import Union
-
+import shutil
 import cv2
+import subprocess
 import gradio as gr
 import numpy as np
 import spaces
@@ -45,6 +46,11 @@ from asset3d_gen.validators.quality_checkers import (
 )
 from asset3d_gen.validators.urdf_convertor import URDFGenerator, zip_files
 from asset3d_gen.utils.gpt_clients import GPT_CLIENT
+from asset3d_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
+from asset3d_gen.scripts.text2image import (
+    build_text2img_ip_pipeline,
+    build_text2img_pipeline,
+)
 
 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
@@ -67,25 +73,68 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-MAX_SEED = 100000
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
-
-
-RBG_REMOVER = RembgRemover()
-SAM_PREDICTOR = SAMPredictor(model_type="vit_h")
+MAX_SEED = 100000
+IMAGE_BUFFER = {}
 DELIGHT = DelightingModel()
 IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
-PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
-    "JeffreyXiang/TRELLIS-image-large"
-)
-# PIPELINE.cuda()
 
-IMAGE_BUFFER = {}
-SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
-AESTHETIC_CHECKER = ImageAestheticChecker()
-CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-URDF_CONVERTOR = URDFGenerator(GPT_CLIENT, render_view_num=4)
+if os.getenv("GRADIO_APP") == "imageto3d":
+    RBG_REMOVER = RembgRemover()
+    SAM_PREDICTOR = SAMPredictor(model_type="vit_h")
+    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+        "JeffreyXiang/TRELLIS-image-large"
+    )
+    # PIPELINE.cuda()
+    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+    AESTHETIC_CHECKER = ImageAestheticChecker()
+    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
+    )
+elif os.getenv("GRADIO_APP") == "textto3d":
+    RBG_REMOVER = RembgRemover()
+    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+        "JeffreyXiang/TRELLIS-image-large"
+    )
+    # PIPELINE.cuda()
+    PIPELINE_IMG_IP = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
+    PIPELINE_IMG = build_text2img_pipeline("weights/Kolors")
+    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+    AESTHETIC_CHECKER = ImageAestheticChecker()
+    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
+    )
+elif os.getenv("GRADIO_APP") == "texture_edit":
+    PIPELINE_IP = build_texture_gen_pipe(
+        base_ckpt_dir="./weights",
+        ip_adapt_scale=0.7,
+        device="cuda",
+    )
+    PIPELINE = build_texture_gen_pipe(
+        base_ckpt_dir="./weights",
+        ip_adapt_scale=0,
+        device="cuda",
+    )
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
+    )
+
+os.makedirs(TMP_DIR, exist_ok=True)
+
+
+def start_session(req: gr.Request) -> None:
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request) -> None:
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    if os.path.exists(user_dir):
+        shutil.rmtree(user_dir)
 
 
 @spaces.GPU
@@ -150,8 +199,7 @@ def preprocess_image_fn(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    if IMAGE_BUFFER is not None:
-        IMAGE_BUFFER["raw_image"] = image
+    IMAGE_BUFFER["raw_image"] = image
 
     image = RBG_REMOVER(image)
     image = trellis_preprocess(image)
@@ -160,15 +208,13 @@
 
 
 @spaces.GPU
-def preprocess_sam_image_fn(
-    image: Image.Image, buffer: dict, model: SAMPredictor
-) -> Image.Image:
+def preprocess_sam_image_fn(image: Image.Image) -> Image.Image:
     if isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    buffer["raw_image"] = image
-    sam_image = model.preprocess_image(image)
-    model.predictor.set_image(sam_image)
+    IMAGE_BUFFER["raw_image"] = image
+    sam_image = SAM_PREDICTOR.preprocess_image(image)
+    SAM_PREDICTOR.predictor.set_image(sam_image)
 
     return sam_image
 
@@ -254,7 +300,6 @@ def select_point(
     image: np.ndarray,
     sel_pix: list,
     point_type: str,
-    model: SAMPredictor,
     evt: gr.SelectData,
 ):
     if point_type == "foreground_point":
@@ -264,8 +309,8 @@ def select_point(
     else:
         sel_pix.append((evt.index, 1))  # default foreground_point
 
-    masks = model.generate_masks(image, sel_pix)
-    seg_image = model.get_segmented_image(image, masks)
+    masks = SAM_PREDICTOR.generate_masks(image, sel_pix)
+    seg_image = SAM_PREDICTOR.get_segmented_image(image, masks)
 
     for point, label in sel_pix:
         color = (255, 0, 0) if label == 0 else (0, 255, 0)
@@ -292,9 +337,6 @@ def image_to_3d(
     ss_sampling_steps: int,
     slat_guidance_strength: float,
     slat_sampling_steps: int,
-    buffer: dict,
-    pipeline: TrellisImageTo3DPipeline,
-    output_root: str,
     sam_image: Image.Image = None,
     is_sam_image: bool = False,
     req: gr.Request = None,
@@ -309,10 +351,10 @@ def image_to_3d(
 
     if isinstance(seg_image, np.ndarray):
         seg_image = Image.fromarray(seg_image)
-    buffer["seg_image"] = seg_image
+    IMAGE_BUFFER["seg_image"] = seg_image
 
-    pipeline.cuda()
-    outputs = pipeline.run(
+    PIPELINE.cuda()
+    outputs = PIPELINE.run(
         seg_image,
         seed=seed,
         formats=["gaussian", "mesh"],
@@ -327,12 +369,13 @@ def image_to_3d(
         },
     )
     # Set to cpu for memory saving.
-    pipeline.cpu()
+    PIPELINE.cpu()
 
     gs_model = outputs["gaussian"][0]
     mesh_model = outputs["mesh"][0]
     color_images = render_video(gs_model)["color"]
     normal_images = render_video(mesh_model)["normal"]
+    output_root = TMP_DIR
     if req is not None:
         output_root = os.path.join(output_root, str(req.session_hash))
     video_path = os.path.join(output_root, "gs_mesh.mp4")
@@ -347,9 +390,10 @@
 
 @spaces.GPU
 def extract_3d_representations(
-    state: dict, enable_delight: bool, output_root: str, req: gr.Request
+    state: dict, enable_delight: bool, req: gr.Request
 ):
-    user_dir = os.path.join(output_root, str(req.session_hash))
+    output_root = TMP_DIR
+    output_root = os.path.join(output_root, str(req.session_hash))
     gs_model, mesh_model = unpack_state(state)
 
     mesh = postprocessing_utils.to_glb(
@@ -360,7 +404,7 @@ def extract_3d_representations(
         verbose=True,
     )
     filename = "sample"
-    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
+    gs_path = os.path.join(output_root, f"{filename}_gs.ply")
     gs_model.save_ply(gs_path)
 
     # Rotate mesh and GS by 90 degrees around Z-axis.
@@ -378,9 +422,9 @@ def extract_3d_representations(
     )
 
     mesh.vertices = mesh.vertices @ np.array(rot_matrix)
-    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
+    mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
     mesh.export(mesh_obj_path)
-    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
+    mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
     mesh.export(mesh_glb_path)
 
     torch.cuda.empty_cache()
@@ -392,11 +436,9 @@
 def extract_3d_representations_v2(
     state: dict,
     enable_delight: bool,
-    output_root: str,
-    delight_model: DelightingModel,
-    sr_model: Union[ImageRealESRGAN, ImageStableSR],
     req: gr.Request,
 ):
+    output_root = TMP_DIR
     user_dir = os.path.join(output_root, str(req.session_hash))
     gs_model, mesh_model = unpack_state(state)
 
@@ -432,8 +474,8 @@ def extract_3d_representations_v2(
     mesh.export(mesh_obj_path)
 
     mesh = backproject_api(
-        delight_model=delight_model,
-        imagesr_model=sr_model,
+        delight_model=DELIGHT,
+        imagesr_model=IMAGESR_MODEL,
         color_path=color_path,
         mesh_path=mesh_obj_path,
         output_path=mesh_obj_path,
@@ -457,16 +499,14 @@ def extract_urdf(
     height_range_text: str,
     mass_range_text: str,
     asset_version_text: str,
-    output_root: str,
-    urdf_convertor: URDFGenerator,
-    buffer: dict,
-    checkers: list[BaseChecker],
     req: gr.Request = None,
 ):
+    output_root = TMP_DIR
     if req is not None:
         output_root = os.path.join(output_root, str(req.session_hash))
     # Convert to URDF and recover attrs by gpt4o
     filename = "sample"
+    urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
     asset_attrs = {
         "version": VERSION,
         "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
@@ -522,13 +562,13 @@ def extract_urdf(
     image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
     image_paths = glob(f"{image_dir}/*.png")
     images_list = []
-    for checker in checkers:
+    for checker in CHECKERS:
         images = image_paths
         if isinstance(checker, ImageSegChecker):
-            images = [buffer["raw_image"], buffer["seg_image"]]
+            images = [IMAGE_BUFFER["raw_image"], IMAGE_BUFFER["seg_image"]]
         images_list.append(images)
 
-    results = BaseChecker.validate(checkers, images_list)
+    results = BaseChecker.validate(CHECKERS, images_list)
     urdf_convertor.add_quality_tag(urdf_path, results)
 
     # Zip urdf files
@@ -559,11 +599,7 @@
 @spaces.GPU
 def text2image_fn(
     prompt: str,
-    output_root: str,
     guidance_scale: float,
-    model_ip: StableDiffusionXLPipelineIP,
-    model_img: StableDiffusionXLPipeline,
-    bg_model: RembgRemover,
     infer_step: int = 50,
     ip_image: Image.Image | str = None,
     ip_adapt_scale: float = 0.3,
@@ -574,11 +610,12 @@ def text2image_fn(
 ):
     if isinstance(image_wh, int):
         image_wh = (image_wh, image_wh)
+    output_root = TMP_DIR
     if req is not None:
         output_root = os.path.join(output_root, str(req.session_hash))
     os.makedirs(output_root, exist_ok=True)
 
-    pipeline = model_img if ip_image is None else model_ip
+    pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
     if ip_image is not None:
         pipeline.set_ip_adapter_scale([ip_adapt_scale])
 
@@ -594,7 +631,7 @@ def text2image_fn(
     if postprocess:
         for idx in range(len(images)):
             image = images[idx]
-            images[idx] = preprocess_image_fn(image, bg_model)
+            images[idx] = preprocess_image_fn(image, RBG_REMOVER)
 
     save_paths = []
     for idx, image in enumerate(images):
@@ -608,3 +645,199 @@ def text2image_fn(
     torch.cuda.empty_cache()
 
     return save_paths + save_paths
+
+
+@spaces.GPU
+def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    command = [
+        "drender-cli",
+        "--mesh_path",
+        mesh_path,
+        "--output_root",
+        f"{output_root}/condition",
+        "--uuid",
+        f"{uuid}",
+    ]
+
+    _ = subprocess.run(
+        command, capture_output=True, text=True, encoding="utf-8"
+    )
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return None, None, None
+
+
+@spaces.GPU
+def generate_texture_mvimages(
+    prompt: str,
+    controlnet_cond_scale: float = 0.55,
+    guidance_scale: float = 9,
+    strength: float = 0.9,
+    num_inference_steps: int = 50,
+    seed: int = 0,
+    ip_adapt_scale: float = 0,
+    ip_img_path: str = None,
+    uid: str = "sample",
+    sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
+    req: gr.Request = None,
+) -> list[str]:
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
+    PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
+    img_save_paths = infer_pipe(
+        index_file=f"{output_root}/condition/index.json",
+        controlnet_cond_scale=controlnet_cond_scale,
+        guidance_scale=guidance_scale,
+        strength=strength,
+        num_inference_steps=num_inference_steps,
+        ip_adapt_scale=ip_adapt_scale,
+        ip_img_path=ip_img_path,
+        uid=uid,
+        prompt=prompt,
+        save_dir=f"{output_root}/multi_view",
+        sub_idxs=sub_idxs,
+        pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
+        seed=seed,
+    )
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return img_save_paths + img_save_paths
+
+
+@spaces.GPU
+def backproject_texture(
+    mesh_path: str,
+    input_image: str,
+    texture_size: int,
+    uuid: str = "sample",
+    req: gr.Request = None,
+) -> str:
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    output_dir = os.path.join(output_root, "texture_mesh")
+    os.makedirs(output_dir, exist_ok=True)
+    command = [
+        "backproject-cli",
+        "--mesh_path",
+        mesh_path,
+        "--input_image",
+        input_image,
+        "--output_root",
+        output_dir,
+        "--uuid",
+        f"{uuid}",
+        "--texture_size",
+        str(texture_size),
+        "--skip_fix_mesh",
+    ]
+
+    _ = subprocess.run(
+        command, capture_output=True, text=True, encoding="utf-8"
+    )
+    output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
+    output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
+    _ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
+
+    zip_file = zip_files(
+        input_paths=[
+            output_glb_mesh,
+            output_obj_mesh,
+            os.path.join(output_dir, "material.mtl"),
+            os.path.join(output_dir, "material_0.png"),
+        ],
+        output_zip=os.path.join(output_dir, f"{uuid}.zip"),
+    )
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return output_glb_mesh, output_obj_mesh, zip_file
+
+
+@spaces.GPU
+def backproject_texture_v2(
+    mesh_path: str,
+    input_image: str,
+    texture_size: int,
+    enable_delight: bool = True,
+    fix_mesh: bool = False,
+    uuid: str = "sample",
+    req: gr.Request = None,
+) -> str:
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    output_dir = os.path.join(output_root, "texture_mesh")
+    os.makedirs(output_dir, exist_ok=True)
+
+    textured_mesh = backproject_api(
+        delight_model=DELIGHT,
+        imagesr_model=IMAGESR_MODEL,
+        color_path=input_image,
+        mesh_path=mesh_path,
+        output_path=f"{output_dir}/{uuid}.obj",
+        skip_fix_mesh=not fix_mesh,
+        delight=enable_delight,
+        texture_wh=[texture_size, texture_size],
+    )
+
+    output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
+    output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
+    _ = textured_mesh.export(output_glb_mesh)
+
+    zip_file = zip_files(
+        input_paths=[
+            output_glb_mesh,
+            output_obj_mesh,
+            os.path.join(output_dir, "material.mtl"),
+            os.path.join(output_dir, "material_0.png"),
+        ],
+        output_zip=os.path.join(output_dir, f"{uuid}.zip"),
+    )
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return output_glb_mesh, output_obj_mesh, zip_file
+
+
+@spaces.GPU
+def render_result_video(
+    mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
+) -> str:
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    output_dir = os.path.join(output_root, "texture_mesh")
+    command = [
+        "drender-cli",
+        "--mesh_path",
+        mesh_path,
+        "--output_root",
+        output_dir,
+        "--num_images",
+        "90",
+        "--elevation",
+        "20",
+        "--with_mtl",
+        "--pbr_light_factor",
+        "1.",
+        "--uuid",
+        f"{uuid}",
+        "--gen_color_mp4",
+        "--gen_glonormal_mp4",
+        "--distance",
+        "5.5",
+        "--resolution_hw",
+        f"{video_size}",
+        f"{video_size}",
+    ]
+
+    _ = subprocess.run(
+        command, capture_output=True, text=True, encoding="utf-8"
+    )
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return f"{output_dir}/color.mp4"
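
Note: common.py now decides at import time which model stack to build, keyed on the GRADIO_APP environment variable, so every entry point must set it before importing the module (app.py does this on its second line). A minimal sketch of the contract:

import os

# Must be set before the import: the module-level if/elif in common.py reads
# it to choose which models and TMP_DIR to create.
os.environ["GRADIO_APP"] = "imageto3d"  # or "textto3d" / "texture_edit"

from common import TMP_DIR, start_session, end_session  # triggers model init

print(TMP_DIR)  # .../sessions/imageto3d
# Caveat visible in the diff: any other GRADIO_APP value leaves TMP_DIR
# undefined, so the module-level os.makedirs(TMP_DIR) raises NameError.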