# Spaces: Running on Zero
import argparse | |
import logging | |
import math | |
import os | |
import cv2 | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
from asset3d_gen.data.utils import ( | |
CameraSetting, | |
init_kal_camera, | |
normalize_vertices_array, | |
) | |
from asset3d_gen.models.gs_model import GaussianOperator | |
# Root logging setup (runs on import): INFO threshold and
# "time - level - message" formatted records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
def parse_args(argv: list[str] = None) -> argparse.Namespace:
    """Parse the command-line options for rendering GS color images.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse reads them from ``sys.argv[1:]``, which is the
            behavior callers relied on before this parameter existed.

    Returns:
        The populated ``argparse.Namespace``. ``input_gs`` and
        ``output_path`` have no default and are ``None`` when not given.
    """
    parser = argparse.ArgumentParser(description="Render GS color images")

    # Input / output locations.
    parser.add_argument(
        "--input_gs", type=str, help="Input render GS.ply path."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output grid image path for rendered GS color images.",
    )

    # Camera configuration.
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )

    # Runtime / output formatting.
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="Output image size for single view in color grid (default: 512)",
    )

    return parser.parse_args(argv)
def load_gs_model(
    input_gs: str, pre_quat: list[float] = None
) -> GaussianOperator:
    """Load a GS model from a PLY file, then re-center and rescale it.

    The model's means are passed to ``normalize_vertices_array`` to obtain a
    scale and a center. The model is then re-posed with the negated center
    followed by ``pre_quat``, and finally rescaled by that scale.

    Args:
        input_gs: Path to the GS ``.ply`` file.
        pre_quat: Quaternion components appended after the translation in
            the instance pose. Defaults to ``[0.0, 0.7071, 0.0, -0.7071]``
            when omitted.
            NOTE(review): the component order (wxyz vs. xyzw) is whatever
            ``GaussianOperator.get_gaussians`` expects -- confirm there.

    Returns:
        The model returned by ``get_gaussians`` after ``rescale`` was applied.
    """
    # Resolve the default here instead of using a mutable list as the
    # parameter default, which would be shared across calls.
    if pre_quat is None:
        pre_quat = (0.0, 0.7071, 0.0, -0.7071)

    gs_model = GaussianOperator.load_from_ply(input_gs)

    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()

    # Pose vector: translation that moves the center to the origin,
    # followed by the quaternion components.
    pose_values = [*[-v for v in center], *pre_quat]
    instance_pose = torch.tensor(pose_values).to(gs_model.device)

    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(scale)

    return gs_model
def entrypoint(input_gs: str = None, output_path: str = None) -> None:
    """Render a GS model from several cameras and save a 2x3 color grid.

    Command-line options are parsed first; ``input_gs`` and ``output_path``
    override the parsed values when given as strings.

    Args:
        input_gs: Optional path to the GS ``.ply`` file (overrides
            ``--input_gs``).
        output_path: Optional path of the grid image to write (overrides
            ``--output_path``).

    Raises:
        ValueError: If no input or output path is available from either the
            arguments or the command line.
        OSError: If the grid image could not be written.
    """
    args = parse_args()
    if isinstance(input_gs, str):
        args.input_gs = input_gs
    if isinstance(output_path, str):
        args.output_path = output_path

    # Fail fast on missing paths instead of erroring after the (expensive)
    # camera setup and rendering have already run.
    if not args.input_gs:
        raise ValueError("No input GS path given (use --input_gs).")
    if not args.output_path:
        raise ValueError("No output path given (use --output_path).")

    # Setup camera parameters
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    camera = init_kal_camera(camera_params)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    # Negate the translation column of every view matrix.
    # NOTE(review): presumably converts to the convention expected by the GS
    # renderer -- confirm against GaussianOperator.render.
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)

    # Load GS model and normalize.
    gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])

    # Render GS color images, one per camera, resized to a square tile.
    images = []
    for c2w in tqdm(c2ws, desc="Rendering GS"):
        result = gs_model.render(
            c2w,
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        color = cv2.resize(
            result.rgba,
            (args.image_size, args.image_size),
            interpolation=cv2.INTER_AREA,
        )
        images.append(color)

    # Cat color images into grid image and save.
    # NOTE: this layout indexes views 0..5, so at least six rendered views
    # are required.
    select_idxs = [[0, 2, 1], [5, 4, 3]]  # fix order for 6 views
    grid_rows = [
        np.concatenate([images[view_idx] for view_idx in row_idxs], axis=1)
        for row_idxs in select_idxs
    ]
    grid_image = np.concatenate(grid_rows, axis=0)

    # A bare filename has an empty dirname, and os.makedirs("") raises, so
    # only create the directory when there is one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # cv2.imwrite signals failure through its return value; do not report
    # success unless the file was actually written.
    if not cv2.imwrite(args.output_path, grid_image):
        raise OSError(f"Failed to write grid image to {args.output_path}")
    logger.info(f"Saved grid image to {args.output_path}")
# Script entry: render using options taken from the command line only.
if __name__ == "__main__":
    entrypoint()