Commit 58bf798 (parent: 8e3d188), committed by xinjie.wang
app.py CHANGED
@@ -120,7 +120,7 @@ with gr.Blocks(
             )
             project_delight = gr.Checkbox(
                 label="Backproject delighting",
-                value=True,
+                value=False,
             )
             gr.Markdown("Geo Structure Generation")
             with gr.Row():
asset3d_gen/data/backproject_v2.py CHANGED
@@ -220,6 +220,7 @@ class TextureBacker:
         texture_wh: tuple[int, int] = (2048, 2048),
         bake_angle_thresh: int = 75,
         mask_thresh: float = 0.5,
+        smooth_texture: bool = True,
     ):
 
         self.camera_params = camera_params
@@ -229,6 +230,7 @@ class TextureBacker:
         self.render_wh = render_wh
         self.texture_wh = texture_wh
         self.mask_thresh = mask_thresh
+        self.smooth_texture = smooth_texture
 
         self.bake_angle_thresh = bake_angle_thresh
         self.bake_unreliable_kernel_size = int(
@@ -468,7 +470,9 @@ class TextureBacker:
         texture_np, mask_np = self.compute_texture(colors, mesh)
 
         texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
-        texture_np = post_process_texture(texture_np)
+        if self.smooth_texture:
+            texture_np = post_process_texture(texture_np)
+
         vertices, faces, uv_map = self.get_mesh_np_attrs(
             mesh, self.scale, self.center
         )
@@ -551,7 +555,11 @@ def parse_args():
     parser.add_argument(
         "--delight", action="store_true", help="Use delighting model."
     )
-    args = parser.parse_args()
+    parser.add_argument(
+        "--smooth_texture", type=bool, default=True, help="Smooth the texture."
+    )
+
+    args, unknown = parser.parse_known_args()
 
     return args
 
@@ -619,6 +627,7 @@ def entrypoint(
         view_weights=view_weights,
         render_wh=camera_params.resolution_hw,
         texture_wh=args.texture_wh,
+        smooth_texture=args.smooth_texture,
     )
 
     textured_mesh = texture_backer(multiviews, mesh, args.output_path)
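Note on the new `--smooth_texture` flag: argparse with `type=bool` converts via `bool(str)`, so any non-empty value, including the literal string "False", parses as `True`; the flag is effectively only controllable through its default. A minimal sketch of the usual workaround, assuming a `str2bool` helper that is illustrative and not part of this commit:

    import argparse

    def str2bool(value: str) -> bool:
        # bool("False") is True, so map the string explicitly.
        if value.lower() in ("yes", "true", "t", "1"):
            return True
        if value.lower() in ("no", "false", "f", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smooth_texture", type=str2bool, default=True, help="Smooth the texture."
    )
    args, unknown = parser.parse_known_args()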
asset3d_gen/data/utils.py CHANGED
@@ -933,21 +933,11 @@ def get_images_from_grid(
     return images
 
 
-# def post_process_texture(texture: np.ndarray, iter: int = 2) -> np.ndarray:
-#     for _ in range(iter):
-#         texture = cv2.fastNlMeansDenoisingColored(texture, None, 13, 13, 9, 27)
-#         texture = cv2.bilateralFilter(
-#             texture, d=9, sigmaColor=80, sigmaSpace=80
-#         )
-#
-#     return texture
-
-
 def post_process_texture(texture: np.ndarray, iter: int = 1) -> np.ndarray:
     for _ in range(iter):
-        texture = cv2.fastNlMeansDenoisingColored(texture, None, 5, 5, 7, 19)
+        texture = cv2.fastNlMeansDenoisingColored(texture, None, 2, 2, 7, 15)
         texture = cv2.bilateralFilter(
-            texture, d=7, sigmaColor=50, sigmaSpace=50
+            texture, d=5, sigmaColor=20, sigmaSpace=20
        )
 
     return texture
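For context, the positional signature here is `cv2.fastNlMeansDenoisingColored(src, dst, h, hColor, templateWindowSize, searchWindowSize)`, so this commit lowers the denoising strengths (h, hColor) from 5 to 2 and shrinks the search window from 19 to 15, and similarly softens the bilateral pass (smaller neighborhood d and sigmas); presumably the intent is to keep more high-frequency detail in baked textures. A small sketch comparing the two presets, with an illustrative input path:

    import cv2
    import numpy as np

    texture = cv2.imread("texture.png")  # illustrative path

    # Previous preset: stronger smoothing.
    old = cv2.fastNlMeansDenoisingColored(texture, None, 5, 5, 7, 19)
    old = cv2.bilateralFilter(old, d=7, sigmaColor=50, sigmaSpace=50)

    # New preset: gentler smoothing, more detail preserved.
    new = cv2.fastNlMeansDenoisingColored(texture, None, 2, 2, 7, 15)
    new = cv2.bilateralFilter(new, d=5, sigmaColor=20, sigmaSpace=20)

    cv2.imwrite("compare.png", np.hstack([old, new]))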
asset3d_gen/scripts/imageto3d.py ADDED
@@ -0,0 +1,344 @@
+import argparse
+import logging
+import os
+import sys
+from glob import glob
+
+import numpy as np
+import trimesh
+from PIL import Image
+from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
+from asset3d_gen.models.delight_model import DelightingModel
+from asset3d_gen.models.gs_model import GaussianOperator
+from asset3d_gen.models.segment_model import (
+    BMGG14Remover,
+    RembgRemover,
+    SAMPredictor,
+    trellis_preprocess,
+)
+from asset3d_gen.models.sr_model import ImageRealESRGAN
+from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
+from asset3d_gen.utils.gpt_clients import GPT_CLIENT
+from asset3d_gen.utils.process_media import (
+    merge_images_video,
+    render_asset3d,
+    render_mesh,
+    render_video,
+)
+from asset3d_gen.utils.tags import VERSION
+from asset3d_gen.validators.quality_checkers import (
+    BaseChecker,
+    ImageAestheticChecker,
+    ImageSegChecker,
+    MeshGeoChecker,
+)
+from asset3d_gen.validators.urdf_convertor import URDFGenerator
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import (
+    Gaussian,
+    MeshExtractResult,
+)
+from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
+    build_scaling_rotation,
+    inverse_sigmoid,
+    strip_symmetric,
+)
+from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
+from thirdparty.TRELLIS.trellis.utils.render_utils import (
+    render_frames,
+    yaw_pitch_r_fov_to_extrinsics_intrinsics,
+)
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
+    "~/.cache/torch_extensions"
+)
+os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
+os.environ["SPCONV_ALGO"] = "native"
+
+
+DELIGHT = DelightingModel()
+IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+
+RBG_REMOVER = RembgRemover()
+RBG14_REMOVER = BMGG14Remover()
+SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
+PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+    "JeffreyXiang/TRELLIS-image-large"
+)
+PIPELINE.cuda()
+SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+AESTHETIC_CHECKER = ImageAestheticChecker()
+CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+TMP_DIR = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
+)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
+    parser.add_argument(
+        "--image_path", type=str, nargs="+", help="Path to the input images."
+    )
+    parser.add_argument(
+        "--image_root", type=str, help="Path to the input images folder."
+    )
+    parser.add_argument(
+        "--output_root",
+        type=str,
+        required=True,
+        help="Root directory for saving outputs.",
+    )
+    parser.add_argument(
+        "--no_mesh", action="store_true", help="Do not output mesh files."
+    )
+    parser.add_argument(
+        "--height_range",
+        type=str,
+        default=None,
+        help="The height in meters to restore the mesh's real size.",
+    )
+    parser.add_argument(
+        "--mass_range",
+        type=str,
+        default=None,
+        help="The mass in kg to restore the mesh's real weight.",
+    )
+    parser.add_argument("--asset_type", type=str, default=None)
+    parser.add_argument("--skip_exists", action="store_true")
+    parser.add_argument("--strict_seg", action="store_true")
+    parser.add_argument("--version", type=str, default=VERSION)
+    args = parser.parse_args()
+
+    assert (
+        args.image_path or args.image_root
+    ), "Please provide either --image_path or --image_root."
+    if not args.image_path:
+        args.image_path = glob(os.path.join(args.image_root, "*.png"))
+        args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
+        args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
+
+    return args
+
+
+def get_segmented_image(
+    image,
+    sam_remover,
+    rbg_remover,
+    seg_checker,
+    image_path,
+    seg_path,
+    mode="loose",
+) -> Image.Image:
+    def _is_valid_seg(img: Image.Image) -> bool:
+        return img.mode == "RGBA" and seg_checker([image_path, seg_path])[0]
+
+    seg_image = sam_remover(image, save_path=seg_path)
+    if not _is_valid_seg(seg_image):
+        logger.warning(
+            f"Failed to segment {image_path} by SAM, retry with `rembg`."
+        )  # noqa
+        seg_image = rbg_remover(image, save_path=seg_path)
+
+    if not _is_valid_seg(seg_image):
+        if mode == "strict":
+            raise RuntimeError(
+                f"Failed to segment {image_path} by SAM and rembg, abort."
+            )
+        logger.warning(
+            f"Failed to segment {image_path} by rembg, use raw image."
+        )  # noqa
+        seg_image = image.convert("RGBA")
+        seg_image.save(seg_path)
+
+    return seg_image
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    for image_path in args.image_path:
+        try:
+            filename = os.path.basename(image_path).split(".")[0]
+            output_root = args.output_root
+            if args.image_root is not None:
+                output_root = os.path.join(output_root, filename)
+            os.makedirs(output_root, exist_ok=True)
+
+            mesh_out = f"{output_root}/{filename}.obj"
+            if args.skip_exists and os.path.exists(mesh_out):
+                logger.info(
+                    f"Skip {image_path}, already processed in {mesh_out}"
+                )
+                continue
+
+            image = Image.open(image_path)
+            image.save(f"{output_root}/{filename}_raw.png")
+
+            # Segmentation: Get segmented image using SAM or Rembg.
+            seg_path = f"{output_root}/{filename}_cond.png"
+            if image.mode != "RGBA":
+                seg_image = RBG_REMOVER(image, save_path=seg_path)
+                seg_image = trellis_preprocess(seg_image)
+            else:
+                seg_image = image
+                seg_image.save(seg_path)
+
+            # Run the pipeline
+            try:
+                outputs = PIPELINE.run(
+                    seg_image,
+                    preprocess_image=False,
+                    # Optional parameters
+                    # seed=1,
+                    # sparse_structure_sampler_params={
+                    #     "steps": 12,
+                    #     "cfg_strength": 7.5,
+                    # },
+                    # slat_sampler_params={
+                    #     "steps": 12,
+                    #     "cfg_strength": 3,
+                    # },
+                )
+            except Exception as e:
+                logger.error(
+                    f"[Pipeline Failed] process {image_path}: {e}, skip."
+                )
+                continue
+
+            # Render and save color and mesh videos
+            gs_model = outputs["gaussian"][0]
+            mesh_model = outputs["mesh"][0]
+            color_images = render_video(gs_model)["color"]
+            normal_images = render_video(mesh_model)["normal"]
+            video_path = os.path.join(output_root, "gs_mesh.mp4")
+            merge_images_video(color_images, normal_images, video_path)
+
+            if not args.no_mesh:
+                # Save the raw Gaussian model
+                gs_path = mesh_out.replace(".obj", "_gs.ply")
+                gs_model.save_ply(gs_path)
+
+                # Rotate mesh and GS by 90 degrees around Z-axis.
+                rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
+                gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
+                mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
+
+                # Additional rotation for GS to align with the mesh.
+                gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
+                pose = GaussianOperator.trans_to_quatpose(gs_rot)
+                aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
+                GaussianOperator.resave_ply(
+                    in_ply=gs_path,
+                    out_ply=aligned_gs_path,
+                    instance_pose=pose,
+                    device="cpu",
+                )
+                color_path = os.path.join(output_root, "color.png")
+                render_gs_api(aligned_gs_path, color_path)
+
+                mesh = trimesh.Trimesh(
+                    vertices=mesh_model.vertices.cpu().numpy(),
+                    faces=mesh_model.faces.cpu().numpy(),
+                )
+                mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
+                mesh.vertices = mesh.vertices @ np.array(rot_matrix)
+
+                mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
+                mesh.export(mesh_obj_path)
+
+                mesh = backproject_api(
+                    delight_model=DELIGHT,
+                    imagesr_model=IMAGESR_MODEL,
+                    color_path=color_path,
+                    mesh_path=mesh_obj_path,
+                    output_path=mesh_obj_path,
+                    skip_fix_mesh=False,
+                    delight=True,
+                    texture_wh=[2048, 2048],
+                )
+
+                mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
+                mesh.export(mesh_glb_path)
+
+                urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
+                asset_attrs = {
+                    "version": VERSION,
+                    "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
+                }
+                if args.height_range:
+                    min_height, max_height = map(
+                        float, args.height_range.split("-")
+                    )
+                    asset_attrs["min_height"] = min_height
+                    asset_attrs["max_height"] = max_height
+                if args.mass_range:
+                    min_mass, max_mass = map(float, args.mass_range.split("-"))
+                    asset_attrs["min_mass"] = min_mass
+                    asset_attrs["max_mass"] = max_mass
+                if args.asset_type:
+                    asset_attrs["category"] = args.asset_type
+                if args.version:
+                    asset_attrs["version"] = args.version
+
+                urdf_path = urdf_convertor(
+                    mesh_path=mesh_obj_path,
+                    output_root=f"{output_root}/URDF_{filename}",
+                    **asset_attrs,
+                )
+
+                # Rescale GS and save to URDF/mesh folder.
+                real_height = urdf_convertor.get_attr_from_urdf(
+                    urdf_path, attr_name="real_height"
+                )
+                out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
+                GaussianOperator.resave_ply(
+                    in_ply=aligned_gs_path,
+                    out_ply=out_gs,
+                    real_height=real_height,
+                    device="cpu",
+                )
+
+                # Quality check and update .urdf file.
+                mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
+                trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
+                # image_paths = render_asset3d(
+                #     mesh_path=mesh_out,
+                #     output_root=f"{output_root}/URDF_{filename}",
+                #     output_subdir="qa_renders",
+                #     num_images=8,
+                #     elevation=(30, -30),
+                #     distance=5.5,
+                # )
+
+                image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
+                image_paths = glob(f"{image_dir}/*.png")
+                images_list = []
+                for checker in CHECKERS:
+                    images = image_paths
+                    if isinstance(checker, ImageSegChecker):
+                        images = [
+                            f"{output_root}/{filename}_raw.png",
+                            f"{output_root}/{filename}_cond.png",
+                        ]
+                    images_list.append(images)
+
+                results = BaseChecker.validate(CHECKERS, images_list)
+                urdf_convertor.add_quality_tag(urdf_path, results)
+
+        except Exception as e:
+            logger.error(f"Failed to process {image_path}: {e}, skip.")
+            continue
+
+    logger.info(f"Processing complete. Outputs saved to {args.output_root}")
asset3d_gen/scripts/render_gs.py CHANGED
@@ -75,7 +75,7 @@ def parse_args():
         help="Output image size for single view in color grid (default: 512)",
     )
 
-    args = parser.parse_args()
+    args, unknown = parser.parse_known_args()
 
     return args
 
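Both this script and backproject_v2.py switch from `parse_args()` to `parse_known_args()`, which returns an `(args, unknown)` tuple and collects unrecognized options instead of exiting with an error; presumably this is what lets these entrypoints run under a caller (such as imageto3d.py above) whose command line carries extra flags. A minimal sketch of the difference:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--size", type=int, default=512)

    # parse_args() would exit with an error on "--output_root";
    # parse_known_args() returns it in the `unknown` list instead.
    args, unknown = parser.parse_known_args(["--size", "256", "--output_root", "out"])
    print(args.size)  # 256
    print(unknown)    # ['--output_root', 'out']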
asset3d_gen/utils/process_media.py CHANGED
@@ -3,6 +3,7 @@ import logging
 import math
 import os
 import subprocess
+import sys
 from glob import glob
 from io import BytesIO
 from typing import Union
@@ -12,7 +13,23 @@ import imageio
 import numpy as np
 import PIL.Image as Image
 import spaces
+import torch
 from moviepy.editor import VideoFileClip, clips_array
+from tqdm import tqdm
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import (
+    Gaussian,
+    MeshExtractResult,
+)
+from thirdparty.TRELLIS.trellis.utils.render_utils import (
+    render_frames,
+    yaw_pitch_r_fov_to_extrinsics_intrinsics,
+)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -24,6 +41,8 @@ __all__ = [
     "filter_small_connected_components",
     "filter_image_small_connected_components",
     "combine_images_to_base64",
+    "render_mesh",
+    "render_video",
 ]
 
 
@@ -176,6 +195,59 @@ def combine_images_to_base64(
     return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
 
+@spaces.GPU
+def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
+    renderer = MeshRenderer()
+    renderer.rendering_options.resolution = options.get("resolution", 512)
+    renderer.rendering_options.near = options.get("near", 1)
+    renderer.rendering_options.far = options.get("far", 100)
+    renderer.rendering_options.ssaa = options.get("ssaa", 4)
+    rets = {}
+    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
+        res = renderer.render(sample, extr, intr)
+        if "normal" not in rets:
+            rets["normal"] = []
+        normal = torch.lerp(
+            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
+        )
+        normal = np.clip(
+            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
+        ).astype(np.uint8)
+        rets["normal"].append(normal)
+
+    return rets
+
+
+@spaces.GPU
+def render_video(
+    sample,
+    resolution=512,
+    bg_color=(0, 0, 0),
+    num_frames=300,
+    r=2,
+    fov=40,
+    **kwargs,
+):
+    yaws = torch.linspace(0, 2 * 3.1415, num_frames)
+    yaws = yaws.tolist()
+    pitch = [0.5] * num_frames
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
+        yaws, pitch, r, fov
+    )
+    render_fn = (
+        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
+    )
+    result = render_fn(
+        sample,
+        extrinsics,
+        intrinsics,
+        {"resolution": resolution, "bg_color": bg_color},
+        **kwargs,
+    )
+
+    return result
+
+
 if __name__ == "__main__":
     # Example usage:
     merge_video_video(
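These two functions are moved here verbatim from common.py (see the removal below). `render_video` dispatches on the sample type: a TRELLIS `MeshExtractResult` is rendered through the local `render_mesh`, which returns only per-frame normal maps, while anything else, such as a Gaussian representation, goes through TRELLIS's `render_frames`. The intended call pattern, as used in imageto3d.py above:

    # `outputs` is the dict returned by TrellisImageTo3DPipeline.run().
    gs_model = outputs["gaussian"][0]  # rendered via render_frames -> "color" frames
    mesh_model = outputs["mesh"][0]    # MeshExtractResult -> render_mesh -> "normal" frames

    color_images = render_video(gs_model)["color"]
    normal_images = render_video(mesh_model)["normal"]
    merge_images_video(color_images, normal_images, "gs_mesh.mp4")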
common.py CHANGED
@@ -26,7 +26,7 @@ from asset3d_gen.models.segment_model import (
     SAMPredictor,
     trellis_preprocess,
 )
-from asset3d_gen.models.sr_model import ImageRealESRGAN
+from asset3d_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
 from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
 from asset3d_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
 from asset3d_gen.scripts.text2image import (
@@ -39,6 +39,8 @@ from asset3d_gen.utils.process_media import (
     filter_image_small_connected_components,
     merge_images_video,
     render_asset3d,
+    render_mesh,
+    render_video,
 )
 from asset3d_gen.utils.tags import VERSION
 from asset3d_gen.validators.quality_checkers import (
@@ -84,6 +86,7 @@ os.environ["SPCONV_ALGO"] = "native"
 MAX_SEED = 100000
 DELIGHT = DelightingModel()
 IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+# IMAGESR_MODEL = ImageStableSR()
 
 
 def patched_setup_functions(self):
@@ -234,59 +237,6 @@ def end_session(req: gr.Request) -> None:
     shutil.rmtree(user_dir)
 
 
-@spaces.GPU
-def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
-    renderer = MeshRenderer()
-    renderer.rendering_options.resolution = options.get("resolution", 512)
-    renderer.rendering_options.near = options.get("near", 1)
-    renderer.rendering_options.far = options.get("far", 100)
-    renderer.rendering_options.ssaa = options.get("ssaa", 4)
-    rets = {}
-    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
-        res = renderer.render(sample, extr, intr)
-        if "normal" not in rets:
-            rets["normal"] = []
-        normal = torch.lerp(
-            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
-        )
-        normal = np.clip(
-            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
-        ).astype(np.uint8)
-        rets["normal"].append(normal)
-
-    return rets
-
-
-@spaces.GPU
-def render_video(
-    sample,
-    resolution=512,
-    bg_color=(0, 0, 0),
-    num_frames=300,
-    r=2,
-    fov=40,
-    **kwargs,
-):
-    yaws = torch.linspace(0, 2 * 3.1415, num_frames)
-    yaws = yaws.tolist()
-    pitch = [0.5] * num_frames
-    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
-        yaws, pitch, r, fov
-    )
-    render_fn = (
-        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
-    )
-    result = render_fn(
-        sample,
-        extrinsics,
-        intrinsics,
-        {"resolution": resolution, "bg_color": bg_color},
-        **kwargs,
-    )
-
-    return result
-
-
 @spaces.GPU
 def preprocess_image_fn(
     image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
@@ -495,11 +445,11 @@ def image_to_3d(
 
 @spaces.GPU
 def extract_3d_representations(
-    state: dict, enable_delight: bool, req: gr.Request
+    state: dict, enable_delight: bool, texture_size: int, req: gr.Request
 ):
     output_root = TMP_DIR
     output_root = os.path.join(output_root, str(req.session_hash))
-    gs_model, mesh_model = unpack_state(state)
+    gs_model, mesh_model = unpack_state(state, device="cuda")
 
     mesh = postprocessing_utils.to_glb(
         gs_model,
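Since `extract_3d_representations` now takes a `texture_size` parameter, any Gradio event binding that invokes it must supply a matching input component. A hypothetical wiring is sketched below; the component names are placeholders, not from this commit, and `gr.Request` is injected automatically by Gradio because of the type annotation rather than passed as an input:

    # Hypothetical Gradio wiring; component names are illustrative.
    texture_size = gr.Slider(
        minimum=1024, maximum=4096, value=2048, step=1024, label="Texture size"
    )
    extract_btn.click(
        extract_3d_representations,
        inputs=[output_state, enable_delight, texture_size],
        outputs=[model_output],
    )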