import math
from functools import cache
from typing import Dict, Union

import numpy as np
import spaces
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToPILImage

from .rasterize import (NVDiffRasterizerContext,
                        rasterize_position_and_normal_maps,
                        render_geo_from_mesh,
                        render_rgb_from_texture_mesh_with_mask)
from utils.file_utils import load_tensor_from_file

# Global variable to store the singleton context
_CTX_INSTANCE = None


@spaces.GPU
def get_rasterizer_context():
    """
    Get the NVDiffRasterizer context using the singleton pattern.
    This ensures only one context is created and reused across the application.
    """
    global _CTX_INSTANCE
    if _CTX_INSTANCE is None:
        # Use string 'cuda' instead of torch.device to avoid early CUDA initialization
        _CTX_INSTANCE = NVDiffRasterizerContext('cuda', 'cuda')
    return _CTX_INSTANCE


def setup_lights():
    """
    Set three random point lights in the scene.
    """
    raise NotImplementedError("setup_lights function is not implemented yet.")


@spaces.GPU
def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
    """
    Render the RGB color images of the mesh. The background will be transparent.

    :param mesh: The mesh to be rendered. Class: Mesh.
    :param texture: The texture of the mesh, a tensor of shape (H, W, 3).
    :param mvp_matrix: The Model-View-Projection matrices for rendering, a tensor of shape (n_v, 4, 4).
    :param lights: The lights in the scene.
    :param img_size: The size of the output image, a tuple (height, width).
    :return: A concatenated PIL Image.
    """
    # If texture or mvp_matrix is a file path, load the tensor from file
    if isinstance(texture, str):
        texture = load_tensor_from_file(texture, map_location="cuda")
    if isinstance(mvp_matrix, str):
        mvp_matrix = load_tensor_from_file(mvp_matrix, map_location="cuda")

    mesh = mesh.to("cuda")
    texture = texture.to("cuda")
    mvp_matrix = mvp_matrix.to("cuda")

    # Nothing to render if no views were requested
    if mvp_matrix.shape[0] == 0:
        return None

    print("Trying to render views...")
    ctx = get_rasterizer_context()

    if texture.shape[-1] != 3:
        texture = texture.permute(1, 2, 0)

    image_height, image_width = img_size
    rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
        ctx, mesh, texture, mvp_matrix, image_height, image_width,
        torch.tensor([0.0, 0.0, 0.0], device=texture.device))

    pil_images = []
    for i in range(mvp_matrix.shape[0]):
        rgba_img = torch.cat([rgb_cond[i], mask[i].unsqueeze(-1)], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
        rgba_img = (rgba_img * 255).to(torch.uint8)  # Convert to uint8
        rgba_img = rgba_img.cpu().numpy()  # Convert to numpy array
        pil_images.append(Image.fromarray(rgba_img, mode='RGBA'))

    if not pil_images:
        return None

    # Concatenate the per-view images horizontally
    total_width = sum(img.width for img in pil_images)
    max_height = max(img.height for img in pil_images)
    concatenated_image = Image.new('RGBA', (total_width, max_height))
    current_x = 0
    for img in pil_images:
        concatenated_image.paste(img, (current_x, 0))
        current_x += img.width
    return concatenated_image
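# --- Usage sketch (illustrative, not part of the original pipeline) ----------
# A minimal example of how `render_views` is typically driven: `get_mvp_matrix`
# (defined further below in this module) derives per-view MVP matrices from the
# mesh, and a flat placeholder texture comes from `get_pure_texture`. The
# function name below is hypothetical; `mesh` is assumed to be a Mesh instance
# loaded elsewhere in the repo.
def _example_render_turntable(mesh) -> Image.Image:
    """Illustrative only: render four 512x512 turntable views with a gray texture."""
    mvp_matrix, _w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
    texture = get_pure_texture((1024, 1024))
    return render_views(mesh, texture, mvp_matrix, img_size=(512, 512))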
""" ctx = get_rasterizer_context() image_height, image_width = img_size position_images, normal_images, mask_images = render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width) return position_images, normal_images, mask_images @spaces.GPU def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]: """ Render the geometry information including position and normal from UV parameterization. """ ctx = get_rasterizer_context() map_height, map_width = map_size position_images, normal_images, mask = rasterize_position_and_normal_maps(ctx, mesh, map_height, map_width) # out_imgs = [] # if mask.ndim == 4: # mask = mask[0] # for img_map in [position_images, normal_images]: # if img_map.ndim == 4: # img_map = img_map[0] # # normalize to [0, 1] # img_map = (img_map - img_map.min()) / (img_map.max() - img_map.min() + 1e-6) # rgba_img = torch.cat([img_map, mask], dim=-1) # [H, W, 3] + [H, W, 1] -> [H, W, 4] # rgba_img = (rgba_img * 255).to(torch.uint8) # Convert to uint8 # rgba_img = rgba_img.cpu().numpy() # Convert to numpy array # out_imgs.append(Image.fromarray(rgba_img, mode='RGBA')) return position_images, normal_images @cache def get_pure_texture(uv_size, color=(int("0x55", 16), int("0x55", 16), int("0x55", 16))) -> torch.Tensor: """ get a pure texture image with the specified color. :param uv_size: The size of the UV map (height, width). :param color: The color of the texture, default is "0x555555" (light gray). :return: A texture image tensor of shape (height, width, 3). """ height, width = uv_size color = torch.tensor(color, dtype=torch.float32).view(1, 1, 3) / 255.0 texture = color.repeat(height, width, 1) return texture def get_c2w( azimuth_deg, elevation_deg, camera_distances,): assert len(azimuth_deg) == len(elevation_deg) == len(camera_distances) n_views = len(azimuth_deg) #camera_distances = torch.full_like(elevation_deg, dis) elevation = elevation_deg * math.pi / 180 azimuth = azimuth_deg * math.pi / 180 camera_positions = torch.stack( [ camera_distances * torch.cos(elevation) * torch.cos(azimuth), camera_distances * torch.cos(elevation) * torch.sin(azimuth), camera_distances * torch.sin(elevation), ], dim=-1, ) center = torch.zeros_like(camera_positions) up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1) lookat = F.normalize(center - camera_positions, dim=-1) right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1) up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1) c2w3x4 = torch.cat( [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], dim=-1, ) c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1) c2w[:, 3, 3] = 1.0 return c2w def camera_strategy_test_4_90deg( mesh: Dict, num_views: int = 4, **kwargs) -> Dict: """ For sup views: Random elevation and azimuth, fixed distance and close fov. 
    :param mesh: The mesh used to derive the camera distance from its bounding box.
    :param num_views: Number of supervision views.
    :param kwargs: Additional arguments (unused).
    """
    # Default camera intrinsics
    default_elevation = 10
    default_camera_lens = 50
    default_camera_sensor_width = 36
    default_fovy = 2 * np.arctan(default_camera_sensor_width / (2 * default_camera_lens))

    # Camera distance scales with the diagonal of the mesh bounding box
    bbox_size = mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0]
    distance = default_camera_lens / default_camera_sensor_width * \
        math.sqrt(bbox_size[0] ** 2 + bbox_size[1] ** 2 + bbox_size[2] ** 2)

    all_azimuth_deg = torch.linspace(0, 360.0, num_views + 1)[:num_views] - 90
    all_elevation_deg = torch.full_like(all_azimuth_deg, default_elevation)

    # Get the corresponding azimuth and elevation
    view_idxs = torch.arange(0, num_views)
    azimuth = all_azimuth_deg[view_idxs]
    elevation = all_elevation_deg[view_idxs]
    camera_distances = torch.full_like(elevation, distance)

    c2w = get_c2w(azimuth, elevation, camera_distances)

    if c2w.ndim == 2:
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0

    fovy = torch.full_like(azimuth, default_fovy)

    return {
        'cond_sup_view_idxs': view_idxs,
        'cond_sup_c2w': c2w,
        'cond_sup_w2c': w2c,
        'cond_sup_fovy': fovy,
        # 'cond_sup_azimuth': azimuth,
        # 'cond_sup_elevation': elevation,
    }


def _get_projection_matrix(
    fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
) -> Float[Tensor, "*B 4 4"]:
    if isinstance(fovy, float):
        proj_mtx = torch.zeros(4, 4, dtype=torch.float32)
        proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[1, 1] = -1.0 / math.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[2, 2] = -(far + near) / (far - near)
        proj_mtx[2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[3, 2] = -1.0
    else:
        batch_size = fovy.shape[0]
        proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
        proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[:, 1, 1] = -1.0 / torch.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[:, 2, 2] = -(far + near) / (far - near)
        proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[:, 3, 2] = -1.0
    return proj_mtx


def _get_mvp_matrix(
    c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    # calculate w2c from c2w: R' = Rt, t' = -Rt * t
    # mathematically equivalent to (c2w)^-1
    if c2w.ndim == 2:
        assert proj_mtx.ndim == 2
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0
    # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
    mvp_mtx = proj_mtx @ w2c
    return mvp_mtx


def get_mvp_matrix(mesh, num_views=4, width=512, height=512, strategy="strategy_test_4_90deg"):
    """
    Get Model-View-Projection (MVP) matrices for rendering views.

    :param mesh: The mesh object used to determine camera positioning.
    :param num_views: Number of views to generate, default is 4.
    :param width: Image width for projection matrix calculation.
    :param height: Image height for projection matrix calculation.
    :param strategy: Camera positioning strategy, default is "strategy_test_4_90deg".
    :return: MVP matrices and world-to-camera transformation matrices.
    """
    if strategy == "strategy_test_4_90deg":
        camera_info = camera_strategy_test_4_90deg(
            mesh=mesh,  # Used to derive the camera distance from the mesh bounding box
            num_views=num_views,
        )
        cond_sup_fovy = camera_info["cond_sup_fovy"]
        cond_sup_c2w = camera_info["cond_sup_c2w"]
        cond_sup_w2c = camera_info["cond_sup_w2c"]
        # cond_sup_azimuth = camera_info["cond_sup_azimuth"]
        # cond_sup_elevation = camera_info["cond_sup_elevation"]
    else:
        raise ValueError(f"Unsupported camera strategy: {strategy}")

    cond_sup_proj_mtx: Float[Tensor, "B 4 4"] = _get_projection_matrix(
        cond_sup_fovy, width / height, 0.1, 1000.0
    )
    mvp_mtx: Float[Tensor, "B 4 4"] = _get_mvp_matrix(cond_sup_c2w, cond_sup_proj_mtx)
    return mvp_mtx, cond_sup_w2c


@torch.cuda.amp.autocast(enabled=False)
def _get_depth_normal_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda", background_color=(0, 0, 0)):
    """
    Get depth and normal maps with mask from position and normal images.

    :param xyz_map: Position images in world coordinates, shape [B, Nv, H, W, 3], derived from the outputs of `render_geo_views_tensor`.
    :param normal_map: Normal images in world coordinates, shape [B, Nv, H, W, 3], derived from the outputs of `render_geo_views_tensor`.
    :param mask: Mask for the images, shape [B, Nv, H, W], derived from the outputs of `render_geo_views_tensor`.
    :param w2c: World-to-camera transformation matrices, shape [B, Nv, 4, 4].
    :param device: Device to run the computation on, default is "cuda".
    :param background_color: Background color for the depth and normal maps.
    :return: depth_map, normal_map, mask
    """
    w2c = w2c.to(device)

    # Render world coordinate position map and mask
    B, Nv, H, W, C = xyz_map.shape  # B: batch size, Nv: number of views, H/W: height/width, C: channels
    assert Nv == 1

    # Rearrange tensors for batch processing
    xyz_map = rearrange(xyz_map, "B Nv H W C -> (B Nv) (H W) C")
    normal_map = rearrange(normal_map, "B Nv H W C -> (B Nv) (H W) C")
    w2c = rearrange(w2c, "B Nv C1 C2 -> (B Nv) C1 C2")

    # Create homogeneous coordinates and transform to the camera coordinate system:
    # points in the world coordinate system are multiplied by the world-to-camera matrix
    B_Nv, N, C = xyz_map.shape
    ones = torch.ones(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_xyz = torch.cat([xyz_map, ones], dim=2)  # [x, y, z, 1]
    zeros = torch.zeros(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_normal = torch.cat([normal_map, zeros], dim=2)  # [x, y, z, 0]: directions ignore translation
    camera_coords = torch.bmm(homogeneous_xyz, w2c.transpose(1, 2))
    camera_normals = torch.bmm(homogeneous_normal, w2c.transpose(1, 2))

    depth_map = camera_coords[..., 2:3]  # Z-axis is the depth direction in the camera coordinate system
    depth_map = rearrange(depth_map, "(B Nv) (H W) 1 -> B Nv H W", B=B, Nv=Nv, H=H, W=W)
    normal_map = camera_normals[..., :3]  # Keep only x, y, z components
    normal_map = rearrange(normal_map, "(B Nv) (H W) c -> B Nv H W c", B=B, Nv=Nv, H=H, W=W)

    assert depth_map.dtype == torch.float32, \
        f"depth_map must be float32, otherwise there will be artifacts in ControlNet-generated pictures, but got {depth_map.dtype}"

    # Normalize depth to [0, 1] using the per-batch min and max values
    min_depth = depth_map.amin((1, 2, 3), keepdim=True)
    max_depth = depth_map.amax((1, 2, 3), keepdim=True)
    depth_map = (depth_map - min_depth) / (max_depth - min_depth + 1e-6)
    depth_map = depth_map.repeat(1, 3, 1, 1)  # Repeat the single view (Nv == 1) 3 times to get an RGB depth map [B, 3, H, W]
    normal_map = normal_map * 0.5 + 0.5  # Normalize to [0, 1], [B, Nv, H, W, 3]
    normal_map = normal_map[:, 0].permute(0, 3, 1, 2)  # [B, 3, H, W]

    # Composite depth and normal maps over the background color using the mask
    rgb_background_batched = torch.tensor(background_color, dtype=torch.float32, device=device).view(1, 3, 1, 1)
    depth_map = torch.lerp(rgb_background_batched, depth_map, mask)
    normal_map = torch.lerp(rgb_background_batched, normal_map, mask)
    return depth_map, normal_map, mask


@spaces.GPU
def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c,
                         selected_view="First View") -> tuple[Image.Image, Image.Image, Image.Image]:
    """
    Get the silhouette images based on the geometry images.

    :param position_imgs: Position images from different views, shape [Nv, H, W, 3].
    :param normal_imgs: Normal images from different views, shape [Nv, H, W, 3].
    :param mask_imgs: Mask for the images, shape [Nv, H, W]. It is the return value of `render_geo_views_tensor`.
    :param w2c: World-to-camera transformation matrices, shape [Nv, 4, 4].
    :param selected_view: The view selected for generating the image condition.
    :return: Depth, normal, and mask silhouette images (depth and normal are in the camera coordinate system).
    """
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3
    }
    view_id = view_id_map[selected_view]

    position_view = position_imgs[view_id: view_id + 1]
    normal_view = normal_imgs[view_id: view_id + 1]
    mask_view = mask_imgs[view_id: view_id + 1]
    w2c = w2c[view_id: view_id + 1]  # Select the corresponding w2c for the view

    depth_img, normal_img, mask = _get_depth_normal_map_with_mask(
        position_view.unsqueeze(0),  # Add batch dimension
        normal_view.unsqueeze(0),
        mask_view.unsqueeze(0),
        w2c.unsqueeze(0),
    )

    to_img = ToPILImage()
    return to_img(depth_img.squeeze(0)), to_img(normal_img.squeeze(0)), to_img(mask.squeeze(0))
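# --- Usage sketch (illustrative, not part of the original pipeline) ----------
# A minimal example of how the geometry helpers compose into the silhouette
# conditioning path: MVP and world-to-camera matrices come from `get_mvp_matrix`,
# geometry buffers from `render_geo_views_tensor`, and the depth/normal/mask
# conditions from `get_silhouette_image`. The function name below is
# hypothetical; `mesh` is assumed to be a Mesh instance already loaded (and
# resident on the GPU) elsewhere in the repo.
def _example_silhouette_condition(mesh, selected_view="First View"):
    """Illustrative only: produce depth, normal, and mask PIL images for one view."""
    mvp_matrix, w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
    position_imgs, normal_imgs, mask_imgs = render_geo_views_tensor(
        mesh, mvp_matrix.to("cuda"), img_size=(512, 512))
    depth_img, normal_img, mask_img = get_silhouette_image(
        position_imgs, normal_imgs, mask_imgs, w2c, selected_view=selected_view)
    return depth_img, normal_img, mask_img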