ImgRoboAssetGen / common.py
xinjie.wang
update
9b53de6
raw
history blame
18.2 kB
import gc
import logging
import os
import sys
from glob import glob
from typing import Union
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from easydict import EasyDict as edict
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from PIL import Image
from tqdm import tqdm
from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
from asset3d_gen.models.delight import DelightingModel
from asset3d_gen.models.gs_model import GaussianOperator
from asset3d_gen.models.segment import (
RembgRemover,
SAMPredictor,
trellis_preprocess,
)
from asset3d_gen.models.super_resolution import ImageRealESRGAN, ImageStableSR
from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
from asset3d_gen.scripts.text2image import text2img_gen
from asset3d_gen.utils.process_media import (
filter_image_small_connected_components,
merge_images_video,
render_asset3d,
)
from asset3d_gen.utils.tags import VERSION
from asset3d_gen.validators.quality_checkers import (
BaseChecker,
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
)
from asset3d_gen.validators.urdf_convertor import URDFGenerator, zip_files
# Make the directory two levels above this file importable so that the
# vendored ``thirdparty`` package (TRELLIS) resolves in the imports below.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
_repo_root = os.path.join(current_dir, "../..")
sys.path.append(_repo_root)
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import (
Gaussian,
MeshExtractResult,
)
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
from thirdparty.TRELLIS.trellis.utils.render_utils import (
render_frames,
yaw_pitch_r_fov_to_extrinsics_intrinsics,
)
# Module-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Upper bound handed to ``np.random.randint`` by ``get_seed``.
MAX_SEED = 100000

# Opt this app out of Gradio usage analytics.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options=None, **kwargs):
    """Render one normal map per camera view with TRELLIS' ``MeshRenderer``.

    Args:
        sample: Mesh to render (passed straight to ``MeshRenderer.render``).
        extrinsics: Per-view camera extrinsics, paired with ``intrinsics``.
        intrinsics: Per-view camera intrinsics.
        options: Optional rendering overrides. Recognised keys and defaults:
            ``resolution`` (512), ``near`` (1), ``far`` (100), ``ssaa`` (4).
            Any other key (e.g. ``bg_color``) is ignored.
        **kwargs: Accepted so the signature matches ``render_frames``; unused.

    Returns:
        dict: ``{"normal": [uint8 HxWxC array, ...]}`` with one image per
        view. The list is empty (rather than the key missing) when no views
        are given.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    options = {} if options is None else options

    renderer = MeshRenderer()
    renderer.rendering_options.resolution = options.get("resolution", 512)
    renderer.rendering_options.near = options.get("near", 1)
    renderer.rendering_options.far = options.get("far", 100)
    renderer.rendering_options.ssaa = options.get("ssaa", 4)

    # Pre-create the key so callers can always index ``["normal"]``.
    rets = {"normal": []}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
        res = renderer.render(sample, extr, intr)
        # lerp(0, normal, mask) == normal * mask: weights the normals by the
        # render mask (presumably a 0/1 coverage mask - confirm).
        normal = torch.lerp(
            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
        )
        # CHW float -> HWC uint8 image.
        normal = np.clip(
            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
        ).astype(np.uint8)
        rets["normal"].append(normal)
    return rets
@spaces.GPU
def render_video(
    sample,
    resolution=512,
    bg_color=(0, 0, 0),
    num_frames=300,
    r=2,
    fov=40,
    **kwargs,
):
    """Render a turntable sequence around ``sample``.

    Yaw sweeps linearly from 0 to ``2 * 3.1415`` (roughly one full turn,
    assuming radians) over ``num_frames`` frames at a constant pitch of 0.5.
    Cameras sit at radius ``r`` with field of view ``fov``.

    ``MeshExtractResult`` samples are rendered with ``render_mesh``; anything
    else goes through TRELLIS' ``render_frames``. Extra ``kwargs`` are
    forwarded to the chosen renderer.

    Returns:
        Whatever the chosen renderer returns (a dict of per-view image lists).
    """
    yaw_list = torch.linspace(0, 2 * 3.1415, num_frames).tolist()
    pitch_list = [0.5 for _ in range(num_frames)]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
        yaw_list, pitch_list, r, fov
    )

    if isinstance(sample, MeshExtractResult):
        renderer_fn = render_mesh
    else:
        renderer_fn = render_frames

    render_options = {"resolution": resolution, "bg_color": bg_color}
    return renderer_fn(
        sample, extrinsics, intrinsics, render_options, **kwargs
    )
@spaces.GPU
def preprocess_image_fn(
    image: str | np.ndarray | Image.Image,
    model: DelightingModel | RembgRemover,
    buffer: dict = None,
) -> Image.Image:
    """Turn ``image`` into a PIL image, clean it with ``model``, then TRELLIS-preprocess it.

    Args:
        image: File path, NumPy array, or PIL image.
        model: A ``DelightingModel`` (called with ``preprocess=True`` and a
            512x512 target) or a ``RembgRemover``. Any other object leaves
            the image unchanged before ``trellis_preprocess``.
        buffer: Optional dict. When given, the PIL version of the input is
            stored under ``"raw_image"``.

    Returns:
        The result of ``trellis_preprocess``.
    """
    if isinstance(image, str):
        pil_image = Image.open(image)
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        pil_image = image

    if buffer is not None:
        buffer["raw_image"] = pil_image

    if isinstance(model, DelightingModel):
        cleaned = model(pil_image, preprocess=True, target_wh=(512, 512))
    elif isinstance(model, RembgRemover):
        cleaned = model(pil_image)
    else:
        cleaned = pil_image

    return trellis_preprocess(cleaned)
@spaces.GPU
def preprocess_sam_image_fn(
    image: Image.Image, buffer: dict, model: SAMPredictor
) -> Image.Image:
    """Prepare an image for interactive SAM segmentation.

    NumPy input is converted to PIL first. The PIL image is stored in
    ``buffer["raw_image"]``, passed through ``model.preprocess_image``, and
    the result is handed to ``model.predictor.set_image``.

    Returns:
        The preprocessed image given to ``set_image``.
    """
    pil_image = (
        Image.fromarray(image) if isinstance(image, np.ndarray) else image
    )
    buffer["raw_image"] = pil_image

    prepared = model.preprocess_image(pil_image)
    model.predictor.set_image(prepared)
    return prepared
def active_btn_by_content(content: gr.Image) -> gr.Button:
    """Return a button update that is interactive only when ``content`` is set.

    Args:
        content: Current value of the watched input; ``None`` means empty.

    Returns:
        ``gr.Button`` with ``interactive`` equal to ``content is not None``.
    """
    # The comparison already yields a bool; no ``True if ... else False`` needed.
    return gr.Button(interactive=content is not None)
def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
    """Return a button update that is interactive only for non-empty text.

    Args:
        content: Current text value; ``None`` and ``""`` both count as empty.

    Returns:
        ``gr.Button`` whose ``interactive`` flag is ``bool(content)``.
    """
    # Truthiness covers both ``None`` and the empty string in one check.
    return gr.Button(interactive=bool(content))
def get_selected_image(
choice: str, sample1: str, sample2: str, sample3: str
) -> str:
if choice == "sample1":
return sample1
elif choice == "sample2":
return sample2
elif choice == "sample3":
return sample3
else:
raise ValueError(f"Invalid choice: {choice}")
@spaces.GPU
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Snapshot a Gaussian model and a mesh into a NumPy-only dict.

    The result can be passed around (e.g. as app state) and rebuilt with
    ``unpack_state``.

    Returns:
        ``{"gaussian": {**gs.init_params, "_xyz": ..., "_features_dc": ...,
        "_scaling": ..., "_rotation": ..., "_opacity": ...},
        "mesh": {"vertices": ..., "faces": ...}}``, with all tensors moved to
        CPU and converted to NumPy arrays.
    """

    def to_numpy(tensor):
        return tensor.cpu().numpy()

    # Constructor params first, then the learned per-point tensors.
    gaussian_state = dict(gs.init_params)
    for attr_name in (
        "_xyz",
        "_features_dc",
        "_scaling",
        "_rotation",
        "_opacity",
    ):
        gaussian_state[attr_name] = to_numpy(getattr(gs, attr_name))

    mesh_state = {
        "vertices": to_numpy(mesh.vertices),
        "faces": to_numpy(mesh.faces),
    }
    return {"gaussian": gaussian_state, "mesh": mesh_state}
@spaces.GPU
def unpack_state(state: dict) -> tuple[Gaussian, edict]:
    """Rebuild a Gaussian model and a lightweight mesh from ``pack_state`` output.

    Args:
        state: Dict with a ``"gaussian"`` entry (``Gaussian`` constructor
            params plus ``_xyz``, ``_features_dc``, ``_scaling``,
            ``_rotation`` and ``_opacity`` arrays) and a ``"mesh"`` entry
            (``vertices`` and ``faces`` arrays).

    Returns:
        ``(gs, mesh)``: a ``Gaussian`` whose learned tensors are on CUDA, and
        an ``EasyDict`` holding CUDA ``vertices`` / ``faces`` tensors.
        (The previous annotation advertised a 3-tuple; only two values are
        returned.)
    """
    gs_state = state["gaussian"]
    gs = Gaussian(
        aabb=gs_state["aabb"],
        sh_degree=gs_state["sh_degree"],
        # "mininum" (sic) is the spelling used by both the state key and the
        # ``Gaussian`` keyword; do not "fix" one without the other.
        mininum_kernel_size=gs_state["mininum_kernel_size"],
        scaling_bias=gs_state["scaling_bias"],
        opacity_bias=gs_state["opacity_bias"],
        scaling_activation=gs_state["scaling_activation"],
    )
    for attr_name in (
        "_xyz",
        "_features_dc",
        "_scaling",
        "_rotation",
        "_opacity",
    ):
        setattr(gs, attr_name, torch.tensor(gs_state[attr_name], device="cuda"))

    mesh = edict(
        vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
        faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
    """Return ``seed`` as-is, or a fresh random seed when ``randomize_seed`` is truthy.

    The random seed is drawn with ``np.random.randint(0, max_seed)``, so it
    lies in ``[0, max_seed)``.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, max_seed)
@spaces.GPU
def select_point(
    image: np.ndarray,
    sel_pix: list,
    point_type: str,
    model: SAMPredictor,
    evt: gr.SelectData,
):
    """Record a clicked prompt point, re-run SAM, and draw all prompt points.

    Args:
        image: Image being annotated. Markers are drawn onto it in place,
            after the masks have been computed.
        sel_pix: Accumulated ``(evt.index, label)`` prompts; the new point is
            appended in place.
        point_type: ``"background_point"`` yields label 0. Any other value
            (including ``"foreground_point"``) yields label 1.
        model: SAM wrapper exposing ``generate_masks`` and
            ``get_segmented_image``.
        evt: Gradio select event; ``evt.index`` is the clicked location.

    Returns:
        ``((image, masks), seg_image)``.
    """
    label = 0 if point_type == "background_point" else 1
    sel_pix.append((evt.index, label))

    masks = model.generate_masks(image, sel_pix)
    seg_image = model.get_segmented_image(image, masks)

    # Background points: (255, 0, 0) tilted cross.
    # Foreground points: (0, 255, 0) upward triangle.
    for coord, point_label in sel_pix:
        is_background = point_label == 0
        cv2.drawMarker(
            image,
            coord,
            (255, 0, 0) if is_background else (0, 255, 0),
            markerType=(
                cv2.MARKER_TILTED_CROSS
                if is_background
                else cv2.MARKER_TRIANGLE_UP
            ),
            markerSize=15,
            thickness=10,
        )

    torch.cuda.empty_cache()
    return (image, masks), seg_image
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    buffer: dict,
    pipeline: TrellisImageTo3DPipeline,
    output_root: str,
    sam_image: Image.Image = None,
    is_sam_image: bool = False,
    req: gr.Request = None,
) -> tuple[dict, str]:
    """Run TRELLIS image-to-3D and write a Gaussian/mesh preview video.

    Args:
        image: Already-preprocessed input image (PIL or NumPy). Used when
            ``is_sam_image`` is False.
        seed: Sampling seed for ``pipeline.run``.
        ss_guidance_strength: ``cfg_strength`` for the sparse-structure sampler.
        ss_sampling_steps: ``steps`` for the sparse-structure sampler.
        slat_guidance_strength: ``cfg_strength`` for the SLAT sampler.
        slat_sampling_steps: ``steps`` for the SLAT sampler.
        buffer: Dict that receives the image actually fed to TRELLIS under
            ``"seg_image"``.
        pipeline: TRELLIS pipeline. It is moved to CUDA for the run and back
            to CPU afterwards, even if the run fails.
        output_root: Root output directory.
        sam_image: SAM segmentation result, used when ``is_sam_image`` is True.
        is_sam_image: Whether to build the input from ``sam_image``.
        req: Optional Gradio request. If given, outputs go into a
            ``session_hash`` sub-directory, which is created if missing.

    Returns:
        ``(state, video_path)``: the ``pack_state`` dict and the path of the
        side-by-side colour/normal preview video.
    """
    if is_sam_image:
        # Drop small disconnected blobs from the SAM cut-out, then apply the
        # standard TRELLIS preprocessing.
        seg_image = filter_image_small_connected_components(sam_image)
        seg_image = Image.fromarray(seg_image, mode="RGBA")
        seg_image = trellis_preprocess(seg_image)
    else:
        seg_image = image
    if isinstance(seg_image, np.ndarray):
        seg_image = Image.fromarray(seg_image)
    buffer["seg_image"] = seg_image

    pipeline.cuda()
    try:
        outputs = pipeline.run(
            seg_image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
    finally:
        # Move back to CPU to save GPU memory, even when generation raises.
        pipeline.cpu()

    gs_model = outputs["gaussian"][0]
    mesh_model = outputs["mesh"][0]
    color_images = render_video(gs_model)["color"]
    normal_images = render_video(mesh_model)["normal"]

    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))
    # The per-session directory may not exist yet on a first request.
    os.makedirs(output_root, exist_ok=True)
    video_path = os.path.join(output_root, "gs_mesh.mp4")
    merge_images_video(color_images, normal_images, video_path)

    state = pack_state(gs_model, mesh_model)
    gc.collect()
    torch.cuda.empty_cache()
    return state, video_path
@spaces.GPU
def extract_3d_representations(
    state: dict, enable_delight: bool, output_root: str, req: gr.Request
):
    """Export the packed TRELLIS result as GLB/OBJ meshes and Gaussian PLYs.

    Args:
        state: Output of ``pack_state``.
        enable_delight: Unused here; kept for signature parity with
            ``extract_3d_representations_v2``.
        output_root: Root output directory.
        req: Gradio request whose ``session_hash`` selects a per-session
            sub-directory. If ``None``, files go directly under
            ``output_root``. The directory is created if missing.

    Returns:
        ``(mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path)``.
    """
    if req is None:
        user_dir = output_root
    else:
        user_dir = os.path.join(output_root, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    gs_model, mesh_model = unpack_state(state)
    mesh = postprocessing_utils.to_glb(
        gs_model,
        mesh_model,
        simplify=0.9,
        texture_size=1024,
        verbose=True,
    )

    filename = "sample"
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Shared 90-degree rotation about the Y axis (the Y row/column is the
    # identity), applied to both the mesh and the Gaussians.
    rot_matrix = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    # Additional 180-degree flip about the X axis so the GS aligns with the
    # mesh.
    gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ rot_matrix
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    # splitext only touches the extension, unlike ``str.replace``, which would
    # also rewrite a ".ply" occurring in a directory name.
    aligned_gs_path = os.path.splitext(gs_path)[0] + "_aligned.ply"
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
    )

    mesh.vertices = mesh.vertices @ rot_matrix
    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)
    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    torch.cuda.empty_cache()
    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
@spaces.GPU
def extract_3d_representations_v2(
    state: dict,
    enable_delight: bool,
    output_root: str,
    delight_model: DelightingModel,
    sr_model: Union[ImageRealESRGAN, ImageStableSR],
    req: gr.Request,
):
    """Export Gaussian PLYs and a mesh textured by back-projecting a GS render.

    Args:
        state: Output of ``pack_state``.
        enable_delight: Forwarded to ``backproject_api`` as ``delight``.
        output_root: Root output directory.
        delight_model: Forwarded to ``backproject_api``.
        sr_model: Forwarded to ``backproject_api`` as ``imagesr_model``.
        req: Gradio request whose ``session_hash`` selects a per-session
            sub-directory. If ``None``, files go directly under
            ``output_root``. The directory is created if missing.

    Returns:
        ``(mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path)``. The OBJ
        at ``mesh_obj_path`` is overwritten with the back-projected result.
    """
    if req is None:
        user_dir = output_root
    else:
        user_dir = os.path.join(output_root, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    gs_model, mesh_model = unpack_state(state)

    filename = "sample"
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Shared 90-degree rotation about the Y axis (the Y row/column is the
    # identity), applied to both the mesh and the Gaussians.
    rot_matrix = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    # Additional 180-degree flip about the X axis so the GS aligns with the
    # mesh.
    gs_add_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
    # Additional 90-degree rotation about the X axis applied to the raw mesh
    # before ``rot_matrix``.
    mesh_add_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])

    pose = GaussianOperator.trans_to_quatpose(gs_add_rot @ rot_matrix)
    # splitext only touches the extension, unlike ``str.replace``, which would
    # also rewrite a ".ply" occurring in a directory name.
    aligned_gs_path = os.path.splitext(gs_path)[0] + "_aligned.ply"
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
    )

    # Colour render of the aligned GS, used as the texture source below.
    color_path = os.path.join(user_dir, "color.png")
    render_gs_api(aligned_gs_path, color_path)

    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    mesh.vertices = mesh.vertices @ mesh_add_rot
    mesh.vertices = mesh.vertices @ rot_matrix
    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)

    # Back-project the colour render onto the mesh; the result overwrites
    # the OBJ written above.
    mesh = backproject_api(
        delight_model=delight_model,
        imagesr_model=sr_model,
        color_path=color_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        delight=enable_delight,
    )
    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    torch.cuda.empty_cache()
    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def _parse_float_range(text: str, label: str) -> tuple[float, float]:
    """Parse a ``"min-max"`` string (e.g. ``"0.1-0.5"``) into two floats.

    Raises:
        gr.Error: If ``text`` does not split on ``-`` into exactly two
            floats. Gradio shows the message to the user.
    """
    try:
        low, high = map(float, text.split("-"))
    except ValueError:
        raise gr.Error(
            f"Invalid {label} input format. Use the format: min-max."
        ) from None
    return low, high


@spaces.GPU
def extract_urdf(
    gs_path: str,
    mesh_obj_path: str,
    asset_cat_text: str,
    height_range_text: str,
    mass_range_text: str,
    asset_version_text: str,
    output_root: str,
    urdf_convertor: URDFGenerator,
    buffer: dict,
    checkers: list[BaseChecker],
    req: gr.Request = None,
):
    """Convert the mesh to a quality-tagged URDF package and zip it.

    Args:
        gs_path: Gaussian PLY to rescale and bundle with the URDF.
        mesh_obj_path: OBJ mesh to convert.
        asset_cat_text: Optional category (lower-cased before use).
        height_range_text: Optional ``"min-max"`` height range.
        mass_range_text: Optional ``"min-max"`` mass range.
        asset_version_text: Optional version overriding ``VERSION``.
        output_root: Root output directory.
        urdf_convertor: Converter producing the URDF and ``estimated_attrs``.
        buffer: Must hold ``"raw_image"`` and ``"seg_image"`` when any checker
            is an ``ImageSegChecker``.
        checkers: Quality checkers whose results are tagged into the URDF.
        req: Optional Gradio request; adds a ``session_hash`` sub-directory.

    Returns:
        ``(urdf_zip, category, height, mass, mu)``, where the last four come
        from ``urdf_convertor.estimated_attrs``.

    Raises:
        gr.Error: If a height or mass range is not in ``min-max`` format.
            (Previously a bare message string was returned, which does not
            match the 5-tuple this handler otherwise returns.)
    """
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))

    filename = "sample"
    urdf_root = f"{output_root}/URDF_{filename}"

    asset_attrs = {
        "version": VERSION,
        "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
    }
    if asset_version_text:
        asset_attrs["version"] = asset_version_text
    if asset_cat_text:
        asset_attrs["category"] = asset_cat_text.lower()
    if height_range_text:
        min_height, max_height = _parse_float_range(
            height_range_text, "height"
        )
        asset_attrs.update(min_height=min_height, max_height=max_height)
    if mass_range_text:
        min_mass, max_mass = _parse_float_range(mass_range_text, "mass")
        asset_attrs.update(min_mass=min_mass, max_mass=max_mass)

    # Convert to URDF; the convertor also recovers asset attributes, read
    # back below via ``estimated_attrs``.
    urdf_path = urdf_convertor(
        mesh_path=mesh_obj_path,
        output_root=urdf_root,
        **asset_attrs,
    )
    mesh_dir = f"{urdf_root}/{urdf_convertor.output_mesh_dir}"

    # Rescale the GS to the URDF's real height and save it in the mesh folder.
    real_height = urdf_convertor.get_attr_from_urdf(
        urdf_path, attr_name="real_height"
    )
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=f"{mesh_dir}/{filename}_gs.ply",
        real_height=real_height,
    )

    # Also export the URDF mesh as GLB. splitext changes only the extension,
    # never an ".obj" that happens to appear in a directory name.
    mesh_out = f"{mesh_dir}/{filename}.obj"
    trimesh.load(mesh_out).export(os.path.splitext(mesh_out)[0] + ".glb")

    # Quality checks: ImageSegChecker inspects the raw/segmented input pair,
    # every other checker inspects the colour renders (sorted for a
    # deterministic order).
    image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color"
    image_paths = sorted(glob(f"{image_dir}/*.png"))
    images_list = [
        [buffer["raw_image"], buffer["seg_image"]]
        if isinstance(checker, ImageSegChecker)
        else image_paths
        for checker in checkers
    ]
    results = BaseChecker.validate(checkers, images_list)
    urdf_convertor.add_quality_tag(urdf_path, results)

    urdf_zip = zip_files(
        input_paths=[mesh_dir, f"{urdf_root}/{filename}.urdf"],
        output_zip=f"{output_root}/urdf_{filename}.zip",
    )
    torch.cuda.empty_cache()

    estimated = urdf_convertor.estimated_attrs
    return (
        urdf_zip,
        estimated["category"],
        estimated["height"],
        estimated["mass"],
        estimated["mu"],
    )
@spaces.GPU
def text2image_fn(
    prompt: str,
    output_root: str,
    guidance_scale: float,
    model_ip: StableDiffusionXLPipelineIP,
    model_img: StableDiffusionXLPipeline,
    bg_model: RembgRemover,
    infer_step: int = 50,
    ip_image: Image.Image | str | None = None,
    ip_adapt_scale: float = 0.3,
    image_wh: int | tuple[int, int] = (1024, 1024),
    n_sample: int = 3,
    postprocess: bool = True,
    req: gr.Request = None,
):
    """Generate ``n_sample`` images from ``prompt`` and save them as PNGs.

    Args:
        prompt: Text prompt.
        output_root: Root output directory (created if missing).
        guidance_scale: Forwarded to ``text2img_gen``.
        model_ip: IP-adapter pipeline, used when ``ip_image`` is given.
        model_img: Plain text-to-image pipeline, used otherwise.
        bg_model: Background remover used by ``preprocess_image_fn``.
        infer_step: Number of inference steps.
        ip_image: Optional image prompt for the IP-adapter.
        ip_adapt_scale: IP-adapter scale, set only when ``ip_image`` is given.
        image_wh: Output size as ``(w, h)``, or a single int for square.
            The default is a tuple so it cannot be mutated across calls.
        n_sample: Number of images to generate.
        postprocess: Whether to run ``preprocess_image_fn`` on each image.
        req: Optional Gradio request; adds a ``session_hash`` sub-directory.

    Returns:
        The saved PNG paths, concatenated with themselves (the list appears
        twice - presumably to fill two bound output groups; confirm against
        the caller).
    """
    if isinstance(image_wh, int):
        image_wh = (image_wh, image_wh)
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))
    os.makedirs(output_root, exist_ok=True)

    if ip_image is None:
        pipeline = model_img
    else:
        pipeline = model_ip
        pipeline.set_ip_adapter_scale([ip_adapt_scale])

    images = text2img_gen(
        prompt=prompt,
        n_sample=n_sample,
        guidance_scale=guidance_scale,
        pipeline=pipeline,
        ip_image=ip_image,
        image_wh=image_wh,
        infer_step=infer_step,
    )
    if postprocess:
        images = [preprocess_image_fn(image, bg_model) for image in images]

    save_paths = []
    for idx, image in enumerate(images):
        save_path = f"{output_root}/sample_{idx}.png"
        image.save(save_path)
        save_paths.append(save_path)
    logger.info("Images saved to %s", output_root)

    gc.collect()
    torch.cuda.empty_cache()
    return save_paths + save_paths