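"""Inference runner for the FreeSplatter web demo.

Wraps background removal (BiRefNet), image-to-multiview diffusion
(Zero123++ v1.1/v1.2, Hunyuan3D-1 Std), and the FreeSplatter reconstruction
models (object, object-2dgs, scene) behind a single `FreeSplatterRunner`.
"""
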
import os
import json
import uuid
import time
import rembg
import numpy as np
import trimesh
import torch
import fpsample
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from PIL import Image
from omegaconf import OmegaConf
from einops import rearrange
from scipy.spatial.transform import Rotation
from safetensors import safe_open
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoModelForImageSegmentation
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

from freesplatter.hunyuan.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
from freesplatter.utils.mesh_optim import optimize_mesh
from freesplatter.utils.camera_util import *
from freesplatter.utils.recon_util import *
from freesplatter.utils.infer_util import *
from freesplatter.webui.camera_viewer.visualizer import CameraVisualizer

cmap = plt.get_cmap("hsv")


def inv_sigmoid(x: torch.Tensor) -> torch.Tensor:
    return torch.log(x / (1.0 - x))
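

# A gaussian "latent" packs, per point: xyz position (3), SH color features
# (sh_dim), an opacity logit (1), scales (2 for 2DGS, 3 for 3DGS), and a
# wxyz rotation quaternion (4). save_gaussian unpacks this, optionally drops
# near-transparent points, moves the splats from the reference camera frame
# into world coordinates, and writes a .ply for visualization.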
def save_gaussian(latent, gs_vis_path, model, opacity_threshold=None, pad_2dgs_scale=True):
    if latent.ndim == 3:
        latent = latent[0]

    sh_dim = model.sh_dim
    scale_dim = 2 if model.use_2dgs else 3
    xyz, features, opacity, scaling, rotation = latent.split([3, sh_dim, 1, scale_dim, 4], dim=-1)
    features = features.reshape(features.shape[0], sh_dim // 3, 3)

    if opacity_threshold is not None:
        index = torch.nonzero(opacity.sigmoid() > opacity_threshold)[:, 0]
        xyz = xyz[index]
        features = features[index]
        opacity = opacity[index]
        scaling = scaling[index]
        rotation = rotation[index]

    # transform gaussians from reference view to world view
    cam2world = create_camera_to_world(torch.tensor([0, -2, 0]), camera_system='opencv').to(latent)
    R, T = cam2world[:3, :3], cam2world[:3, 3].reshape(1, 3)
    xyz = xyz @ R.T + T
    rotation = rotation.detach().cpu().numpy()
    # scipy expects xyzw quaternion ordering, hence the reindexing
    rotation = Rotation.from_quat(rotation[:, [1, 2, 3, 0]]).as_matrix()
    rotation = R.detach().cpu().numpy() @ rotation
    rotation = Rotation.from_matrix(rotation).as_quat()[:, [3, 0, 1, 2]]
    rotation = torch.from_numpy(rotation).to(latent)

    # pad 2DGS with an additional z-scale for visualization
    if scaling.shape[-1] == 2 and pad_2dgs_scale:
        z_scaling = inv_sigmoid(torch.ones_like(scaling[:, :1]) * 0.001)
        scaling = torch.cat([scaling, z_scaling], dim=-1)

    pc_vis = model.gs_renderer.gaussian_model.set_data(
        xyz.float(), features.float(), scaling.float(), rotation.float(), opacity.float())
    pc_vis.save_ply_vis(gs_vis_path)
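

# FreeSplatterRunner loads every model once at construction time and exposes
# three entry points: run_img_to_3d (single image -> diffusion multiview -> 3D),
# run_views_to_3d (user-provided views -> 3D), and run_views_to_scene
# (two-view scene reconstruction with the FreeSplatter-S model).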
class FreeSplatterRunner:
    def __init__(self, device):
        self.device = device

        # background remover
        self.rembg = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet",
            # "briaai/RMBG-2.0",
            trust_remote_code=True,
        ).to(device)
        self.rembg.eval()
        # self.rembg = rembg.new_session('birefnet-general')

        # diffusion models
        pipeline = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.1",
            custom_pipeline="sudo-ai/zero123plus-pipeline",
            torch_dtype=torch.float16,
        )
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing='trailing'
        )
        self.zero123plus_v11 = pipeline.to(device)

        pipeline = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="sudo-ai/zero123plus-pipeline",
            torch_dtype=torch.float16,
        )
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing='trailing'
        )
        self.zero123plus_v12 = pipeline.to(device)

        download_dir = snapshot_download('tencent/Hunyuan3D-1', repo_type='model')
        pipeline = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
            os.path.join(download_dir, 'mvd_std'),
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.hunyuan3d_mvd_std = pipeline.to(device)

        # freesplatter reconstruction models
        def load_freesplatter(config_file, ckpt_name):
            ckpt_path = hf_hub_download('TencentARC/FreeSplatter', repo_type='model', filename=ckpt_name)
            model = instantiate_from_config(OmegaConf.load(config_file).model)
            state_dict = {}
            with safe_open(ckpt_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
            model.load_state_dict(state_dict, strict=True)
            return model.eval().to(device)

        self.freesplatter = load_freesplatter(
            'configs/freesplatter-object.yaml', 'freesplatter-object.safetensors')
        self.freesplatter_2dgs = load_freesplatter(
            'configs/freesplatter-object-2dgs.yaml', 'freesplatter-object-2dgs.safetensors')
        self.freesplatter_scene = load_freesplatter(
            'configs/freesplatter-scene.yaml', 'freesplatter-scene.safetensors')

    def run_segmentation(
        self,
        image,
        do_rembg=True,
    ):
        if do_rembg:
            image = remove_background(image, self.rembg)
        return image
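
    # Single-image-to-3D. Implemented as a generator so the UI can display each
    # intermediate result (segmented input, multiview grid, camera plot,
    # gaussians, video, mesh) as soon as it is ready: every yield emits the
    # 6-slot result list, padded with None for stages not yet finished.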
    def run_img_to_3d(
        self,
        image,
        model='Zero123++ v1.2',
        diffusion_steps=30,
        guidance_scale=4.0,
        seed=42,
        view_indices=[],
        gs_type='2DGS',
        mesh_reduction=0.5,
        cache_dir=None,
    ):
        image_rgba = self.run_segmentation(image)
        res = [image_rgba]
        yield res + [None] * (6 - len(res))

        self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
        os.makedirs(self.output_dir, exist_ok=True)

        # image-to-multiview
        input_image = resize_foreground(image_rgba, 0.9)
        seed_everything(seed)
        if model == 'Zero123++ v1.1':
            output_image = self.zero123plus_v11(
                input_image,
                num_inference_steps=diffusion_steps,
                guidance_scale=guidance_scale,
            ).images[0]
        elif model == 'Zero123++ v1.2':
            output_image = self.zero123plus_v12(
                input_image,
                num_inference_steps=diffusion_steps,
                guidance_scale=guidance_scale,
            ).images[0]
        elif model == 'Hunyuan3D Std':
            output_image = self.hunyuan3d_mvd_std(
                input_image,
                num_inference_steps=diffusion_steps,
                guidance_scale=guidance_scale,
                guidance_curve=lambda t: 2.0,
            ).images[0]
        else:
            raise ValueError(f'Unknown model: {model}')

        # preprocess images
        image, alpha = rgba_to_white_background(input_image)
        image = v2.functional.resize(image, 512, interpolation=3, antialias=True).clamp(0, 1)
        alpha = v2.functional.resize(alpha, 512, interpolation=0, antialias=True).clamp(0, 1)

        output_image_rgba = remove_background(output_image, self.rembg)
        if 'Zero123++' in model:
            images, alphas = rgba_to_white_background(output_image_rgba)
        else:
            _, alphas = rgba_to_white_background(output_image_rgba)
            images = torch.from_numpy(np.asarray(output_image) / 255.0).float()
            images = rearrange(images, 'h w c -> c h w')
        # the diffusion models output a 3x2 grid of views; split it into 6 images
        images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
        alphas = rearrange(alphas, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
        if model == 'Hunyuan3D Std':
            images = images[[0, 2, 4, 5, 3, 1]]
            alphas = alphas[[0, 2, 4, 5, 3, 1]]

        images_vis = v2.functional.to_pil_image(rearrange(images, 'nm c h w -> c h (nm w)'))
        res += [images_vis]
        yield res + [None] * (6 - len(res))

        images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
        alphas = v2.functional.resize(alphas, 512, interpolation=0, antialias=True).clamp(0, 1)
        images = torch.cat([image.unsqueeze(0), images], dim=0)    # 7 x 3 x 512 x 512
        alphas = torch.cat([alpha.unsqueeze(0), alphas], dim=0)    # 7 x 1 x 512 x 512

        # run reconstruction
        view_indices = [1, 2, 3, 4, 5, 6] if len(view_indices) == 0 else view_indices
        images, alphas = images[view_indices], alphas[view_indices]
        legends = [f'V{i}' if i != 0 else 'Input' for i in view_indices]

        for item in self.run_freesplatter_object(
                images, alphas, legends=legends, gs_type=gs_type, mesh_reduction=mesh_reduction):
            res += [item]
            yield res + [None] * (6 - len(res))
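
    # Sparse-view-to-3D: the user supplies the views directly, so no diffusion
    # model is involved. Each view is optionally matted, cropped square, and
    # resized to 512x512 before reconstruction.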
    def run_views_to_3d(
        self,
        image_files,
        do_rembg=False,
        gs_type='2DGS',
        mesh_reduction=0.5,
        cache_dir=None,
    ):
        self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
        os.makedirs(self.output_dir, exist_ok=True)

        # preprocess images
        images, alphas = [], []
        for image_file in image_files:
            if isinstance(image_file, tuple):
                image_file = image_file[0]
            image = Image.open(image_file)
            w, h = image.size
            image_rgba = self.run_segmentation(image, do_rembg=do_rembg)
            if image.mode == 'RGBA':
                image, alpha = rgba_to_white_background(image_rgba)
                image = v2.functional.center_crop(image, min(h, w))
                alpha = v2.functional.center_crop(alpha, min(h, w))
            else:
                image_rgba = resize_foreground(image_rgba, 0.9)
                image, alpha = rgba_to_white_background(image_rgba)
            image = v2.functional.resize(image, 512, interpolation=3, antialias=True).clamp(0, 1)
            alpha = v2.functional.resize(alpha, 512, interpolation=0, antialias=True).clamp(0, 1)
            images.append(image)
            alphas.append(alpha)

        images = torch.stack(images, dim=0)
        alphas = torch.stack(alphas, dim=0)
        images_vis = v2.functional.to_pil_image(rearrange(images, 'n c h w -> c h (n w)'))

        # run reconstruction
        legends = [f'V{i}' for i in range(1, 1 + len(images))]
        # run_freesplatter_object yields fig, gs_vis_path, video_path, mesh_fine_path, in that order
        fig, gs_vis_path, video_path, mesh_fine_path = self.run_freesplatter_object(
            images, alphas, legends=legends, gs_type=gs_type, mesh_reduction=mesh_reduction)

        return images_vis, gs_vis_path, video_path, mesh_fine_path, fig
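
    # Core object reconstruction. Stages, in yield order: (1) predict gaussians
    # from all views in a single forward pass, (2) estimate per-view poses and
    # a shared focal length from the gaussians via PnP and plot the cameras,
    # (3) export the gaussians to .ply, (4) render an orbit video, and
    # (5) extract a mesh and optimize its texture.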
    def run_freesplatter_object(
        self,
        images,
        alphas,
        legends=None,
        gs_type='2DGS',
        mesh_reduction=0.5,
    ):
        device = self.device

        freesplatter = self.freesplatter_2dgs if gs_type == '2DGS' else self.freesplatter

        images, alphas = images.to(device), alphas.to(device)
        t0 = time.time()
        with torch.inference_mode():
            gaussians = freesplatter.forward_gaussians(images.unsqueeze(0))
        t1 = time.time()

        # estimate camera parameters and visualize
        c2ws_pred, focals_pred = freesplatter.estimate_poses(images, gaussians, masks=alphas, use_first_focal=True, pnp_iter=10)
        fig = self.visualize_cameras_object(images, c2ws_pred, focals_pred, legends=legends)
        t2 = time.time()
        yield fig

        # save gaussians
        gs_vis_path = os.path.join(self.output_dir, 'gs_vis.ply')
        save_gaussian(gaussians, gs_vis_path, freesplatter, opacity_threshold=5e-3, pad_2dgs_scale=True)
        print(f'Save gaussian at {gs_vis_path}')
        yield gs_vis_path

        # render video
        with torch.inference_mode():
            c2ws_video = get_circular_cameras(N=120, elevation=0, radius=2.0, normalize=True).to(device)
            # intrinsics are normalized by the image size: fx, fy in units of W/H, cx = cy = 0.5
            fx = fy = focals_pred.mean() / 512.0
            cx = cy = torch.ones_like(fx) * 0.5
            fxfycxcy_video = torch.tensor([fx, fy, cx, cy]).unsqueeze(0).repeat(c2ws_video.shape[0], 1).to(device)

            video_frames = freesplatter.forward_renderer(
                gaussians,
                c2ws_video.unsqueeze(0),
                fxfycxcy_video.unsqueeze(0),
            )['image'][0].clamp(0, 1)

        video_path = os.path.join(self.output_dir, 'gs.mp4')
        save_video(video_frames, video_path, fps=30)
        print(f'Save video at {video_path}')
        t3 = time.time()
        yield video_path
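
        # Mesh extraction: render RGB-D from 120 cameras on a Fibonacci sphere
        # and fuse the renderings into a triangle mesh (rgbd_to_mesh, TSDF
        # fusion). Background pixels get depth -1 so the fusion ignores them.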
        # extract mesh
        with torch.inference_mode():
            c2ws_fusion = get_fibonacci_cameras(N=120, radius=2.0)
            c2ws_fusion, _ = normalize_cameras(c2ws_fusion, camera_position=torch.tensor([0., -2., 0.]), camera_system='opencv')
            c2ws_fusion = c2ws_fusion.to(device)
            c2ws_fusion_reference = torch.linalg.inv(c2ws_fusion[0:1]) @ c2ws_fusion
            fx = fy = focals_pred.mean() / 512.0
            cx = cy = torch.ones_like(fx) * 0.5
            fov = np.rad2deg(np.arctan(0.5 / fx.item())) * 2
            fxfycxcy_fusion = torch.tensor([fx, fy, cx, cy]).unsqueeze(0).repeat(c2ws_fusion.shape[0], 1).to(device)

            fusion_render_results = freesplatter.forward_renderer(
                gaussians,
                c2ws_fusion_reference.unsqueeze(0),
                fxfycxcy_fusion.unsqueeze(0),
            )
            images_fusion = fusion_render_results['image'][0].clamp(0, 1).permute(0, 2, 3, 1)
            alphas_fusion = fusion_render_results['alpha'][0].permute(0, 2, 3, 1)
            depths_fusion = fusion_render_results['depth'][0].permute(0, 2, 3, 1)

            fusion_images = (images_fusion.detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            fusion_depths = depths_fusion.detach().cpu().numpy()
            fusion_alphas = alphas_fusion.detach().cpu().numpy()
            fusion_masks = (fusion_alphas > 1e-2).astype(np.uint8)
            fusion_depths = fusion_depths * fusion_masks - np.ones_like(fusion_depths) * (1 - fusion_masks)

            fusion_c2ws = c2ws_fusion.detach().cpu().numpy()

            mesh_path = os.path.join(self.output_dir, 'mesh.obj')
            rgbd_to_mesh(
                fusion_images, fusion_depths, fusion_c2ws, fov, mesh_path, cam_elev_thr=-90)    # use all angles for tsdf fusion
            print(f'Save mesh at {mesh_path}')
        t4 = time.time()
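
        # Texture optimization: farthest-point-sample 16 well-spread cameras,
        # undo the white-background compositing to recover foreground colors,
        # and bake them onto a simplified mesh with optimize_mesh.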
        # optimize texture
        cam_pos = c2ws_fusion[:, :3, 3].cpu().numpy()
        cam_inds = torch.from_numpy(fpsample.fps_sampling(cam_pos, 16).astype(int)).to(device=device)

        alphas_bake = alphas_fusion[cam_inds]
        # invert white-background compositing: composite = fg * alpha + (1 - alpha)
        images_bake = (images_fusion[cam_inds] - (1 - alphas_bake)) / alphas_bake.clamp(min=1e-6)

        fxfycxcy = fxfycxcy_fusion[cam_inds].clone()
        intrinsics = torch.eye(3).unsqueeze(0).repeat(len(cam_inds), 1, 1).to(fxfycxcy)
        intrinsics[:, 0, 0] = fxfycxcy[:, 0]
        intrinsics[:, 0, 2] = fxfycxcy[:, 2]
        intrinsics[:, 1, 1] = fxfycxcy[:, 1]
        intrinsics[:, 1, 2] = fxfycxcy[:, 3]

        out_mesh = trimesh.load(str(mesh_path), process=False)
        out_mesh = optimize_mesh(
            out_mesh,
            images_bake,
            alphas_bake.squeeze(-1),
            c2ws_fusion[cam_inds].inverse(),
            intrinsics,
            simplify=mesh_reduction,
            verbose=False,
        )
        mesh_fine_path = os.path.join(self.output_dir, 'mesh.glb')
        out_mesh.export(mesh_fine_path)
        print(f'Save optimized mesh at {mesh_fine_path}')
        t5 = time.time()

        print(f'Generate Gaussians: {t1-t0:.2f} seconds.')
        print(f'Estimate poses: {t2-t1:.2f} seconds.')
        print(f'Generate video: {t3-t2:.2f} seconds.')
        print(f'Generate mesh: {t4-t3:.2f} seconds.')
        print(f'Optimize mesh: {t5-t4:.2f} seconds.')

        yield mesh_fine_path
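
    # The plotted field of view is recovered from the predicted focal length
    # (in pixels, for 512x512 renders): fov = 2 * atan(256 / focal).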
    def visualize_cameras_object(
        self,
        images,
        c2ws,
        focal_length,
        legends=None,
    ):
        images = v2.functional.resize(images, 128, interpolation=3, antialias=True).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype(np.uint8)

        # align the first camera to a canonical pose before plotting
        cam2world = create_camera_to_world(torch.tensor([0, -2, 0]), camera_system='opencv').to(c2ws)
        transform = cam2world @ torch.linalg.inv(c2ws[0:1])
        c2ws = transform @ c2ws
        c2ws = c2ws.detach().cpu().numpy()
        c2ws[:, :, 1:3] *= -1    # opencv to opengl

        focal_length = focal_length.mean().detach().cpu().numpy()
        fov = np.rad2deg(np.arctan(256.0 / focal_length)) * 2

        colors = [cmap(i / len(images))[:3] for i in range(len(images))]
        legends = [None] * len(images) if legends is None else legends

        viz = CameraVisualizer(c2ws, legends, colors, images=images)
        fig = viz.update_figure(
            3,
            height=320,
            line_width=5,
            base_radius=1,
            zoom_scale=1,
            fov_deg=fov,
            show_grid=True,
            show_ticklabels=True,
            show_background=True,
            y_up=False,
        )
        return fig

    # FreeSplatter-S
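    # Scene reconstruction from two unposed views. Unlike the object pipeline,
    # the predicted cameras are rescaled so that the baseline between the two
    # views equals 1.0, fixing the global scale ambiguity.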
    def run_views_to_scene(
        self,
        image1,
        image2,
        cache_dir=None,
    ):
        self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
        os.makedirs(self.output_dir, exist_ok=True)

        # preprocess images
        images = []
        for image in [image1, image2]:
            w, h = image.size
            image = torch.from_numpy(np.asarray(image) / 255.0).float()
            image = rearrange(image, 'h w c -> c h w')
            image = v2.functional.center_crop(image, min(h, w))
            image = v2.functional.resize(image, 512, interpolation=3, antialias=True).clamp(0, 1)
            images.append(image)
        images = torch.stack(images, dim=0)
        images_vis = v2.functional.to_pil_image(rearrange(images, 'n c h w -> c h (n w)'))

        # run reconstruction
        legends = [f'V{i}' for i in range(1, 1 + len(images))]
        gs_vis_path, video_path, fig = self.run_freesplatter_scene(images, legends=legends)

        return images_vis, gs_vis_path, video_path, fig

    def run_freesplatter_scene(
        self,
        images,
        legends=None,
    ):
        freesplatter = self.freesplatter_scene

        device = self.device
        images = images.to(device)
        t0 = time.time()
        with torch.inference_mode():
            gaussians = freesplatter.forward_gaussians(images.unsqueeze(0))
        t1 = time.time()

        # estimate camera parameters
        c2ws_pred, focals_pred = freesplatter.estimate_poses(images, gaussians, use_first_focal=True, pnp_iter=10)
        # rescale cameras to make the baseline equal to 1.0
        baseline_pred = (c2ws_pred[:, :3, 3] - c2ws_pred[:1, :3, 3]).norm() + 1e-2
        scale_factor = 1.0 / baseline_pred
        c2ws_pred = c2ws_pred.clone()
        c2ws_pred[:, :3, 3] *= scale_factor

        # visualize cameras
        fig = self.visualize_cameras_scene(images, c2ws_pred, focals_pred, legends=legends)
        t2 = time.time()

        # save gaussians
        gs_vis_path = os.path.join(self.output_dir, 'gs_vis.ply')
        save_gaussian(gaussians, gs_vis_path, freesplatter, opacity_threshold=5e-3)
        print(f'Save gaussian at {gs_vis_path}')

        # render video along a path interpolated between the input views
        with torch.inference_mode():
            c2ws_video = generate_interpolated_path(c2ws_pred.detach().cpu().numpy()[:, :3, :], n_interp=120)
            c2ws_video = torch.cat([
                torch.from_numpy(c2ws_video),
                torch.tensor([0, 0, 0, 1]).reshape(1, 1, 4).repeat(c2ws_video.shape[0], 1, 1),
            ], dim=1).to(gaussians)
            fx = fy = focals_pred.mean() / 512.0
            cx = cy = torch.ones_like(fx) * 0.5
            fxfycxcy_video = torch.tensor([fx, fy, cx, cy]).unsqueeze(0).repeat(c2ws_video.shape[0], 1).to(device)

            video_frames = freesplatter.forward_renderer(
                gaussians,
                c2ws_video.unsqueeze(0),
                fxfycxcy_video.unsqueeze(0),
                rescale=scale_factor.reshape(1).to(gaussians),
            )['image'][0].clamp(0, 1)

        video_path = os.path.join(self.output_dir, 'gs.mp4')
        save_video(video_frames, video_path, fps=30)
        print(f'Save video at {video_path}')
        t3 = time.time()

        print(f'Generate Gaussians: {t1-t0:.2f} seconds.')
        print(f'Estimate poses: {t2-t1:.2f} seconds.')
        print(f'Generate video: {t3-t2:.2f} seconds.')

        return gs_vis_path, video_path, fig

    def visualize_cameras_scene(
        self,
        images,
        c2ws,
        focal_length,
        legends=None,
    ):
        images = v2.functional.resize(images, 128, interpolation=3, antialias=True).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype(np.uint8)

        c2ws = c2ws.detach().cpu().numpy()
        c2ws[:, :, 1:3] *= -1    # opencv to opengl

        focal_length = focal_length.mean().detach().cpu().numpy()
        fov = np.rad2deg(np.arctan(256.0 / focal_length)) * 2

        colors = [cmap(i / len(images))[:3] for i in range(len(images))]
        legends = [None] * len(images) if legends is None else legends

        viz = CameraVisualizer(c2ws, legends, colors, images=images)
        fig = viz.update_figure(
            2,
            height=320,
            line_width=5,
            base_radius=1,
            zoom_scale=1,
            fov_deg=fov,
            show_grid=True,
            show_ticklabels=True,
            show_background=True,
            y_up=False,
        )
        return fig
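

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes a CUDA
    # device, a writable ./cache directory, and an 'input.png' on disk (both
    # paths are hypothetical). run_img_to_3d is a generator that yields the
    # 6-slot result list after every stage; unfinished slots hold None.
    runner = FreeSplatterRunner(torch.device('cuda'))
    for partial in runner.run_img_to_3d(Image.open('input.png'), cache_dir='./cache'):
        image_rgba, views_vis, fig, gs_path, video_path, mesh_path = partial
    print(f'gaussians: {gs_path}\nvideo: {video_path}\nmesh: {mesh_path}')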