# Revision metadata (extraction residue): xinjie.wang, "update", 55ed985
import argparse
import logging
import math
import os
import cv2
import numpy as np
import torch
from tqdm import tqdm
from asset3d_gen.data.utils import (
CameraSetting,
init_kal_camera,
normalize_vertices_array,
)
from asset3d_gen.models.gs_model import GaussianOperator
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def parse_args():
    """Parse the command-line options of the GS color renderer.

    Options are declared in a table and registered in order, so the
    generated ``--help`` output keeps the same ordering.

    Returns:
        argparse.Namespace: Parsed options read from ``sys.argv``.
    """
    option_table = (
        (
            "--input_gs",
            {"type": str, "help": "Input render GS.ply path."},
        ),
        (
            "--output_path",
            {
                "type": str,
                "help": "Output grid image path for rendered GS color images.",
            },
        ),
        (
            "--num_images",
            {"type": int, "default": 6, "help": "Number of images to render."},
        ),
        (
            "--elevation",
            {
                "type": float,
                "nargs": "+",
                "default": [20.0, -10.0],
                "help": "Elevation angles for the camera (default: [20.0, -10.0])",
            },
        ),
        (
            "--distance",
            {
                "type": float,
                "default": 5,
                "help": "Camera distance (default: 5)",
            },
        ),
        (
            "--resolution_hw",
            {
                "type": int,
                "nargs": 2,
                "default": (512, 512),
                "help": "Resolution of the output images (default: (512, 512))",
            },
        ),
        (
            "--fov",
            {
                "type": float,
                "default": 30,
                "help": "Field of view in degrees (default: 30)",
            },
        ),
        (
            "--device",
            {
                "type": str,
                "choices": ["cpu", "cuda"],
                "default": "cuda",
                "help": "Device to run on (default: `cuda`)",
            },
        ),
        (
            "--image_size",
            {
                "type": int,
                "default": 512,
                "help": "Output image size for single view in color grid (default: 512)",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Render GS color images")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def load_gs_model(
    input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
) -> GaussianOperator:
    """Load a Gaussian-splat PLY, then recentre, rotate and rescale it.

    Args:
        input_gs: Path of the GS ``.ply`` file to load.
        pre_quat: Quaternion components appended after the translation in
            the instance pose (component order is whatever
            ``get_gaussians`` expects -- TODO confirm).

    Returns:
        The ``GaussianOperator`` returned by ``get_gaussians`` after it has
        been rescaled.
    """
    loaded = GaussianOperator.load_from_ply(input_gs)

    # Only the scale and centre reported by the normalization helper are
    # used; the normalized vertices themselves are discarded. Per the
    # original note this maps vertices to [-1, 1] centred at the origin.
    _, norm_scale, norm_center = normalize_vertices_array(loaded._means)
    scale_factor = float(norm_scale)
    center_xyz = norm_center.tolist()

    # Instance pose layout: negated centre followed by the quaternion.
    pose_values = [-coord for coord in center_xyz]
    pose_values.extend(pre_quat)
    instance_pose = torch.tensor(pose_values).to(loaded.device)

    posed = loaded.get_gaussians(instance_pose=instance_pose)
    posed.rescale(scale_factor)

    return posed
def entrypoint(input_gs: str = None, output_path: str = None) -> None:
    """Render a Gaussian-splat model from several views and save a grid image.

    Command-line options are always parsed with ``parse_args``; the two
    keyword arguments override the parsed values when given as strings.
    The rendered views are resized to ``--image_size`` and tiled into a
    2 x 3 grid.

    Args:
        input_gs: Path of the GS ``.ply`` file; overrides ``--input_gs``.
        output_path: Destination of the grid image; overrides
            ``--output_path``.

    Raises:
        OSError: If the grid image could not be written to ``output_path``.
    """
    args = parse_args()
    if isinstance(input_gs, str):
        args.input_gs = input_gs
    if isinstance(output_path, str):
        args.output_path = output_path

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    camera = init_kal_camera(camera_params)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    # Negate the translation column of every view matrix before inverting.
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)

    # Load GS model and normalize.
    gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])

    # Render GS color images, one per camera pose.
    images = []
    for c2w in tqdm(c2ws, desc="Rendering GS"):
        result = gs_model.render(
            c2w,
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        color = cv2.resize(
            result.rgba,
            (args.image_size, args.image_size),
            interpolation=cv2.INTER_AREA,
        )
        images.append(color)

    # Cat color images into grid image and save.
    # NOTE(review): the layout indexes views 0..5, so at least six rendered
    # views are required; confirm how --num_images/--elevation map to that.
    select_idxs = [[0, 2, 1], [5, 4, 3]]  # fix order for 6 views
    grid_rows = [
        np.concatenate([images[view_idx] for view_idx in row_idxs], axis=1)
        for row_idxs in select_idxs
    ]
    grid_image = np.concatenate(grid_rows, axis=0)

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # cv2.imwrite signals failure through its return value; do not report
    # success when nothing was written.
    if not cv2.imwrite(args.output_path, grid_image):
        raise OSError(f"Failed to write grid image to {args.output_path}")
    logger.info(f"Saved grid image to {args.output_path}")
# Run the renderer only when this file is executed as a script.
if __name__ == "__main__":
    entrypoint()