import gc
import logging
import os
import sys
from glob import glob
from typing import Union

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from easydict import EasyDict as edict
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
    StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import (  # noqa
    StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from PIL import Image
from tqdm import tqdm

from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
from asset3d_gen.models.delight import DelightingModel
from asset3d_gen.models.gs_model import GaussianOperator
from asset3d_gen.models.segment import (
    RembgRemover,
    SAMPredictor,
    trellis_preprocess,
)
from asset3d_gen.models.super_resolution import ImageRealESRGAN, ImageStableSR
from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
from asset3d_gen.scripts.text2image import text2img_gen
from asset3d_gen.utils.process_media import (
    filter_image_small_connected_components,
    merge_images_video,
    render_asset3d,
)
from asset3d_gen.utils.tags import VERSION
from asset3d_gen.validators.quality_checkers import (
    BaseChecker,
    ImageAestheticChecker,
    ImageSegChecker,
    MeshGeoChecker,
)
from asset3d_gen.validators.urdf_convertor import URDFGenerator, zip_files

current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import (
    Gaussian,
    MeshExtractResult,
)
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
from thirdparty.TRELLIS.trellis.utils.render_utils import (
    render_frames,
    yaw_pitch_r_fov_to_extrinsics_intrinsics,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


MAX_SEED = 100000
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"


@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
    """Render per-view normal maps for a mesh under the given camera poses."""
    renderer = MeshRenderer()
    renderer.rendering_options.resolution = options.get("resolution", 512)
    renderer.rendering_options.near = options.get("near", 1)
    renderer.rendering_options.far = options.get("far", 100)
    renderer.rendering_options.ssaa = options.get("ssaa", 4)
    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
        res = renderer.render(sample, extr, intr)
        if "normal" not in rets:
            rets["normal"] = []
        normal = torch.lerp(
            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
        )
        normal = np.clip(
            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
        ).astype(np.uint8)
        rets["normal"].append(normal)

    return rets


@spaces.GPU
def render_video(
    sample,
    resolution=512,
    bg_color=(0, 0, 0),
    num_frames=300,
    r=2,
    fov=40,
    **kwargs,
):
    """Render turntable frames of a Gaussian or mesh sample."""
    yaws = torch.linspace(0, 2 * 3.1415, num_frames).tolist()
    pitch = [0.5] * num_frames
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
        yaws, pitch, r, fov
    )
    render_fn = (
        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
    )
    result = render_fn(
        sample,
        extrinsics,
        intrinsics,
        {"resolution": resolution, "bg_color": bg_color},
        **kwargs,
    )

    return result


@spaces.GPU
def preprocess_image_fn(
    image: str | np.ndarray | Image.Image,
    model: DelightingModel | RembgRemover,
    buffer: dict = None,
) -> Image.Image:
    """Load an image, remove its background, and preprocess it for TRELLIS."""
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if buffer is not None:
        buffer["raw_image"] = image

    if isinstance(model, DelightingModel):
        image = model(image, preprocess=True, target_wh=(512, 512))
    elif isinstance(model, RembgRemover):
        image = model(image)
        image = trellis_preprocess(image)

    return image
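# Usage sketch for `preprocess_image_fn` (a minimal sketch: the input and
# output paths are hypothetical, and `RembgRemover` fetches its weights on
# first use):
#
#   remover = RembgRemover()
#   buffer = {}
#   seg = preprocess_image_fn("assets/sample.jpg", remover, buffer)
#   seg.save("outputs/sample_seg.png")  # RGBA, background removed and
#                                       # normalized by trellis_preprocess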
@spaces.GPU
def preprocess_sam_image_fn(
    image: Image.Image, buffer: dict, model: SAMPredictor
) -> Image.Image:
    """Preprocess an image and register it with the SAM predictor."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    buffer["raw_image"] = image
    sam_image = model.preprocess_image(image)
    model.predictor.set_image(sam_image)

    return sam_image


def active_btn_by_content(content: gr.Image) -> gr.Button:
    interactive = content is not None

    return gr.Button(interactive=interactive)


def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
    interactive = content is not None and len(content) > 0

    return gr.Button(interactive=interactive)


def get_selected_image(
    choice: str, sample1: str, sample2: str, sample3: str
) -> str:
    if choice == "sample1":
        return sample1
    elif choice == "sample2":
        return sample2
    elif choice == "sample3":
        return sample3
    else:
        raise ValueError(f"Invalid choice: {choice}")


@spaces.GPU
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialize a Gaussian and a mesh into a JSON-safe dict for gr.State."""
    return {
        "gaussian": {
            **gs.init_params,
            "_xyz": gs._xyz.cpu().numpy(),
            "_features_dc": gs._features_dc.cpu().numpy(),
            "_scaling": gs._scaling.cpu().numpy(),
            "_rotation": gs._rotation.cpu().numpy(),
            "_opacity": gs._opacity.cpu().numpy(),
        },
        "mesh": {
            "vertices": mesh.vertices.cpu().numpy(),
            "faces": mesh.faces.cpu().numpy(),
        },
    }


@spaces.GPU
def unpack_state(state: dict) -> tuple[Gaussian, edict]:
    """Restore the Gaussian and mesh packed by `pack_state` onto the GPU."""
    gs = Gaussian(
        aabb=state["gaussian"]["aabb"],
        sh_degree=state["gaussian"]["sh_degree"],
        mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
        scaling_bias=state["gaussian"]["scaling_bias"],
        opacity_bias=state["gaussian"]["opacity_bias"],
        scaling_activation=state["gaussian"]["scaling_activation"],
    )
    gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda")
    gs._features_dc = torch.tensor(
        state["gaussian"]["_features_dc"], device="cuda"
    )
    gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda")
    gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda")
    gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda")

    mesh = edict(
        vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
        faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
    )

    return gs, mesh


def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
    return np.random.randint(0, max_seed) if randomize_seed else seed


@spaces.GPU
def select_point(
    image: np.ndarray,
    sel_pix: list,
    point_type: str,
    model: SAMPredictor,
    evt: gr.SelectData,
):
    """Record a clicked point, run SAM, and draw point markers on the image."""
    if point_type == "foreground_point":
        sel_pix.append((evt.index, 1))  # Append the foreground point.
    elif point_type == "background_point":
        sel_pix.append((evt.index, 0))  # Append the background point.
    else:
        sel_pix.append((evt.index, 1))  # Default to a foreground point.

    masks = model.generate_masks(image, sel_pix)
    seg_image = model.get_segmented_image(image, masks)
    for point, label in sel_pix:
        color = (255, 0, 0) if label == 0 else (0, 255, 0)
        marker_type = 1 if label == 0 else 5
        cv2.drawMarker(
            image,
            point,
            color,
            markerType=marker_type,
            markerSize=15,
            thickness=10,
        )

    torch.cuda.empty_cache()

    return (image, masks), seg_image
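# Round-trip sketch for `pack_state` / `unpack_state`: tensors are converted
# to numpy so the dict can sit in a `gr.State` between handlers, then restored
# onto "cuda" on the next call (assumes `gs` and `mesh` come from a prior
# TRELLIS run and a CUDA device is available):
#
#   state = pack_state(gs, mesh)      # torch -> numpy, pickle-friendly dict
#   gs2, mesh2 = unpack_state(state)  # numpy -> torch tensors on "cuda"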
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    buffer: dict,
    pipeline: TrellisImageTo3DPipeline,
    output_root: str,
    sam_image: Image.Image = None,
    is_sam_image: bool = False,
    req: gr.Request = None,
) -> tuple[dict, str]:
    """Run the TRELLIS pipeline and render a color/normal preview video."""
    if is_sam_image:
        seg_image = filter_image_small_connected_components(sam_image)
        seg_image = Image.fromarray(seg_image, mode="RGBA")
        seg_image = trellis_preprocess(seg_image)
        # seg_image.save(f"{TMP_DIR}/seg_image_sam.png")
    else:
        seg_image = image

    if isinstance(seg_image, np.ndarray):
        seg_image = Image.fromarray(seg_image)

    buffer["seg_image"] = seg_image

    pipeline.cuda()
    outputs = pipeline.run(
        seg_image,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    # Move the pipeline back to CPU to free GPU memory.
    pipeline.cpu()

    gs_model = outputs["gaussian"][0]
    mesh_model = outputs["mesh"][0]
    color_images = render_video(gs_model)["color"]
    normal_images = render_video(mesh_model)["normal"]
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))
    video_path = os.path.join(output_root, "gs_mesh.mp4")
    merge_images_video(color_images, normal_images, video_path)
    state = pack_state(gs_model, mesh_model)

    gc.collect()
    torch.cuda.empty_cache()

    return state, video_path


@spaces.GPU
def extract_3d_representations(
    state: dict, enable_delight: bool, output_root: str, req: gr.Request
):
    user_dir = os.path.join(output_root, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state)

    mesh = postprocessing_utils.to_glb(
        gs_model,
        mesh_model,
        simplify=0.9,
        texture_size=1024,
        verbose=True,
    )
    filename = "sample"
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Rotate the mesh and GS by 90 degrees around the Y-axis.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    # Additional rotation for the GS to align it with the mesh.
    gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
        rot_matrix
    )
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
    )

    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)
    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    torch.cuda.empty_cache()

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
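# Generation sketch for `image_to_3d` (a sketch only: the checkpoint id below
# is an assumption, not pinned by this module, and a CUDA device is required):
#
#   pipeline = TrellisImageTo3DPipeline.from_pretrained(
#       "JeffreyXiang/TRELLIS-image-large"
#   )
#   state, video_path = image_to_3d(
#       image=seg_image,  # RGBA output of preprocess_image_fn
#       seed=0,
#       ss_guidance_strength=7.5,
#       ss_sampling_steps=12,
#       slat_guidance_strength=3.0,
#       slat_sampling_steps=12,
#       buffer={},
#       pipeline=pipeline,
#       output_root="outputs",
#   )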
@spaces.GPU
def extract_3d_representations_v2(
    state: dict,
    enable_delight: bool,
    output_root: str,
    delight_model: DelightingModel,
    sr_model: Union[ImageRealESRGAN, ImageStableSR],
    req: gr.Request,
):
    user_dir = os.path.join(output_root, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state)

    filename = "sample"
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Rotate the mesh and GS by 90 degrees around the Y-axis (see the
    # convention check after this function).
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    # Additional rotation for the GS to align it with the mesh.
    gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
    )
    color_path = os.path.join(user_dir, "color.png")
    render_gs_api(aligned_gs_path, color_path)

    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)

    mesh = backproject_api(
        delight_model=delight_model,
        imagesr_model=sr_model,
        color_path=color_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        delight=enable_delight,
    )

    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    torch.cuda.empty_cache()

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
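# Rotation convention check: vertices are row vectors here, so `v @ R`
# applies R on the right. With `rot_matrix` as above:
#
#   R = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
#   np.array([1.0, 0.0, 0.0]) @ R  # -> array([ 0.,  0., -1.])
#
# i.e. +X maps to -Z while Y is unchanged: a 90-degree rotation about Y.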
@spaces.GPU
def extract_urdf(
    gs_path: str,
    mesh_obj_path: str,
    asset_cat_text: str,
    height_range_text: str,
    mass_range_text: str,
    asset_version_text: str,
    output_root: str,
    urdf_convertor: URDFGenerator,
    buffer: dict,
    checkers: list[BaseChecker],
    req: gr.Request = None,
):
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))

    # Convert to URDF and recover asset attributes via GPT-4o.
    filename = "sample"
    asset_attrs = {
        "version": VERSION,
        "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
    }
    if asset_version_text:
        asset_attrs["version"] = asset_version_text
    if asset_cat_text:
        asset_attrs["category"] = asset_cat_text.lower()
    if height_range_text:
        try:
            min_height, max_height = map(float, height_range_text.split("-"))
            asset_attrs["min_height"] = min_height
            asset_attrs["max_height"] = max_height
        except ValueError:
            return "Invalid height input format. Use the format: min-max."
    if mass_range_text:
        try:
            min_mass, max_mass = map(float, mass_range_text.split("-"))
            asset_attrs["min_mass"] = min_mass
            asset_attrs["max_mass"] = max_mass
        except ValueError:
            return "Invalid mass input format. Use the format: min-max."

    urdf_path = urdf_convertor(
        mesh_path=mesh_obj_path,
        output_root=f"{output_root}/URDF_{filename}",
        **asset_attrs,
    )

    # Rescale the GS and save it into the URDF mesh folder.
    real_height = urdf_convertor.get_attr_from_urdf(
        urdf_path, attr_name="real_height"
    )
    out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=out_gs,
        real_height=real_height,
    )

    # Quality check and update the .urdf file.
    mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
    trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
    # image_paths = render_asset3d(
    #     mesh_path=mesh_out,
    #     output_root=f"{output_root}/URDF_{filename}",
    #     output_subdir="qa_renders",
    #     num_images=8,
    #     elevation=(30, -30),
    #     distance=5.5,
    # )
    image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
    image_paths = glob(f"{image_dir}/*.png")

    images_list = []
    for checker in checkers:
        images = image_paths
        if isinstance(checker, ImageSegChecker):
            images = [buffer["raw_image"], buffer["seg_image"]]
        images_list.append(images)

    results = BaseChecker.validate(checkers, images_list)
    urdf_convertor.add_quality_tag(urdf_path, results)

    # Zip the URDF files.
    urdf_zip = zip_files(
        input_paths=[
            f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
            f"{output_root}/URDF_{filename}/{filename}.urdf",
        ],
        output_zip=f"{output_root}/urdf_{filename}.zip",
    )

    torch.cuda.empty_cache()

    estimated_type = urdf_convertor.estimated_attrs["category"]
    estimated_height = urdf_convertor.estimated_attrs["height"]
    estimated_mass = urdf_convertor.estimated_attrs["mass"]
    estimated_mu = urdf_convertor.estimated_attrs["mu"]

    return (
        urdf_zip,
        estimated_type,
        estimated_height,
        estimated_mass,
        estimated_mu,
    )


@spaces.GPU
def text2image_fn(
    prompt: str,
    output_root: str,
    guidance_scale: float,
    model_ip: StableDiffusionXLPipelineIP,
    model_img: StableDiffusionXLPipeline,
    bg_model: RembgRemover,
    infer_step: int = 50,
    ip_image: Image.Image | str = None,
    ip_adapt_scale: float = 0.3,
    image_wh: int | tuple[int, int] = (1024, 1024),
    n_sample: int = 3,
    postprocess: bool = True,
    req: gr.Request = None,
):
    if isinstance(image_wh, int):
        image_wh = (image_wh, image_wh)
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))
    os.makedirs(output_root, exist_ok=True)

    pipeline = model_img if ip_image is None else model_ip
    if ip_image is not None:
        pipeline.set_ip_adapter_scale([ip_adapt_scale])

    images = text2img_gen(
        prompt=prompt,
        n_sample=n_sample,
        guidance_scale=guidance_scale,
        pipeline=pipeline,
        ip_image=ip_image,
        image_wh=image_wh,
        infer_step=infer_step,
    )
    if postprocess:
        for idx in range(len(images)):
            images[idx] = preprocess_image_fn(images[idx], bg_model)

    save_paths = []
    for idx, image in enumerate(images):
        save_path = f"{output_root}/sample_{idx}.png"
        image.save(save_path)
        save_paths.append(save_path)

    logger.info(f"Images saved to {output_root}")

    gc.collect()
    torch.cuda.empty_cache()

    return save_paths + save_paths
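# End-to-end text-to-image sketch (the two Kolors pipelines must be built
# elsewhere in the app; the prompt, pipeline variables, and output path are
# illustrative assumptions):
#
#   paths = text2image_fn(
#       prompt="a wooden chair, white background",
#       output_root="outputs/text2img",
#       guidance_scale=7.5,
#       model_ip=kolors_ip_pipeline,  # StableDiffusionXLPipelineIP
#       model_img=kolors_pipeline,    # StableDiffusionXLPipeline
#       bg_model=RembgRemover(),
#   )
#   # The list is returned twice (save_paths + save_paths), presumably to
#   # populate two gradio output components bound to the same handler.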