import math
from functools import cache
from typing import Dict, Union
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToPILImage
from .rasterize import (NVDiffRasterizerContext,
rasterize_position_and_normal_maps,
render_geo_from_mesh,
render_rgb_from_texture_mesh_with_mask)
from utils.file_utils import load_tensor_from_file
# Global variable to store the singleton context
_CTX_INSTANCE = None
@spaces.GPU
def get_rasterizer_context():
"""
Get the NVDiffRasterizer context using singleton pattern.
This ensures only one context is created and reused across the application.
"""
global _CTX_INSTANCE
if _CTX_INSTANCE is None:
# Use string 'cuda' instead of torch.device to avoid early CUDA initialization
_CTX_INSTANCE = NVDiffRasterizerContext('cuda', 'cuda')
return _CTX_INSTANCE
def setup_lights():
"""
    Set up three random point lights in the scene.
"""
raise NotImplementedError("setup_lights function is not implemented yet.")
@spaces.GPU
def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
"""
    Render RGB color images of the mesh; the background is transparent.
    :param mesh: The mesh to be rendered. Class: Mesh.
    :param texture: The texture of the mesh, a tensor of shape (H, W, 3) or (3, H, W), or a path to a saved tensor.
    :param mvp_matrix: The Model-View-Projection matrices for rendering, a tensor of shape (n_v, 4, 4), or a path to a saved tensor.
    :param lights: The lights in the scene (currently unused).
    :param img_size: The size of the output image, a tuple (height, width).
    :return: A single PIL Image with all views concatenated horizontally, or None if there are no views.
"""
# If texture or mvp_matrix is a file path, load the tensor from file
if isinstance(texture, str):
texture = load_tensor_from_file(texture, map_location="cuda")
if isinstance(mvp_matrix, str):
mvp_matrix = load_tensor_from_file(mvp_matrix, map_location="cuda")
mesh = mesh.to("cuda")
texture = texture.to("cuda")
mvp_matrix = mvp_matrix.to("cuda")
print("Trying to render views...")
ctx = get_rasterizer_context()
if texture.shape[-1] != 3:
texture = texture.permute(1, 2, 0)
image_height, image_width = img_size
    # Nothing to render if no views are provided
    if mvp_matrix.shape[0] == 0:
        return None
    rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
        ctx, mesh, texture, mvp_matrix, image_height, image_width,
        torch.tensor([0.0, 0.0, 0.0], device=texture.device))
pil_images = []
for i in range(mvp_matrix.shape[0]):
rgba_img = torch.cat([rgb_cond[i], mask[i].unsqueeze(-1)], dim=-1) # [H, W, 3] + [H, W, 1] -> [H, W, 4]
        rgba_img = (rgba_img.clamp(0.0, 1.0) * 255).to(torch.uint8)  # Clamp to [0, 1] and convert to uint8
rgba_img = rgba_img.cpu().numpy() # Convert to numpy array
pil_images.append(Image.fromarray(rgba_img, mode='RGBA'))
if not pil_images:
return None
total_width = sum(img.width for img in pil_images)
max_height = max(img.height for img in pil_images)
concatenated_image = Image.new('RGBA', (total_width, max_height))
current_x = 0
for img in pil_images:
concatenated_image.paste(img, (current_x, 0))
current_x += img.width
return concatenated_image
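# Usage sketch (hypothetical helper, not used elsewhere in this module): split the
# horizontally concatenated strip returned by render_views back into per-view images,
# assuming all views share the same width.
def _example_split_view_strip(strip: Image.Image, num_views: int) -> list[Image.Image]:
    view_width = strip.width // num_views
    # Crop each view out of the strip; crop boxes are (left, upper, right, lower).
    return [strip.crop((i * view_width, 0, (i + 1) * view_width, strip.height))
            for i in range(num_views)]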
@spaces.GPU
def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Render geometry information (position, normal, and mask images) for the views implied by the MVP matrices.
    """
ctx = get_rasterizer_context()
image_height, image_width = img_size
position_images, normal_images, mask_images = render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width)
return position_images, normal_images, mask_images
@spaces.GPU
def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
"""
    Rasterize geometry information (position and normal maps) in the mesh's UV parameterization.
"""
ctx = get_rasterizer_context()
map_height, map_width = map_size
position_images, normal_images, mask = rasterize_position_and_normal_maps(ctx, mesh, map_height, map_width)
# out_imgs = []
# if mask.ndim == 4:
# mask = mask[0]
# for img_map in [position_images, normal_images]:
# if img_map.ndim == 4:
# img_map = img_map[0]
# # normalize to [0, 1]
# img_map = (img_map - img_map.min()) / (img_map.max() - img_map.min() + 1e-6)
# rgba_img = torch.cat([img_map, mask], dim=-1) # [H, W, 3] + [H, W, 1] -> [H, W, 4]
# rgba_img = (rgba_img * 255).to(torch.uint8) # Convert to uint8
# rgba_img = rgba_img.cpu().numpy() # Convert to numpy array
# out_imgs.append(Image.fromarray(rgba_img, mode='RGBA'))
return position_images, normal_images
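# Usage sketch (assumes the position/normal maps are [H, W, 3] or [1, H, W, 3] float
# tensors): min-max normalize a geometry map to [0, 1] and convert it to a PIL image
# for quick inspection. This mirrors the commented-out preview code above.
def _example_geo_map_to_pil(geo_map: torch.Tensor) -> Image.Image:
    if geo_map.ndim == 4:  # Drop a leading batch dimension if present
        geo_map = geo_map[0]
    geo_map = (geo_map - geo_map.min()) / (geo_map.max() - geo_map.min() + 1e-6)
    img = (geo_map * 255).to(torch.uint8).cpu().numpy()
    return Image.fromarray(img, mode='RGB')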
@cache
def get_pure_texture(uv_size, color=(int("0x55", 16), int("0x55", 16), int("0x55", 16))) -> torch.Tensor:
"""
    Get a solid-color texture image.
    :param uv_size: The size of the UV map (height, width).
    :param color: The RGB color of the texture; the default is 0x555555 (a medium gray).
    :return: A texture image tensor of shape (height, width, 3), with values in [0, 1].
"""
height, width = uv_size
color = torch.tensor(color, dtype=torch.float32).view(1, 1, 3) / 255.0
texture = color.repeat(height, width, 1)
return texture
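# Usage sketch: because get_pure_texture is wrapped in functools.cache, its arguments
# must stay hashable (tuples), and repeated calls with the same size/color return the
# same cached tensor object.
def _example_default_gray_texture() -> torch.Tensor:
    return get_pure_texture((1024, 1024))  # (1024, 1024, 3) tensor filled with 0x55 / 255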
def get_c2w(
    azimuth_deg,
    elevation_deg,
    camera_distances,
):
    """
    Build camera-to-world (c2w) matrices for cameras that look at the origin, with +Z as the world up axis.
    :param azimuth_deg: Azimuth angles in degrees, tensor of shape (n_views,).
    :param elevation_deg: Elevation angles in degrees, tensor of shape (n_views,).
    :param camera_distances: Camera distances from the origin, tensor of shape (n_views,).
    :return: c2w matrices, a tensor of shape (n_views, 4, 4).
    """
    assert len(azimuth_deg) == len(elevation_deg) == len(camera_distances)
n_views = len(azimuth_deg)
#camera_distances = torch.full_like(elevation_deg, dis)
elevation = elevation_deg * math.pi / 180
azimuth = azimuth_deg * math.pi / 180
camera_positions = torch.stack(
[
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
camera_distances * torch.sin(elevation),
],
dim=-1,
)
center = torch.zeros_like(camera_positions)
up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
lookat = F.normalize(center - camera_positions, dim=-1)
right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
c2w3x4 = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
dim=-1,
)
c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
c2w[:, 3, 3] = 1.0
return c2w
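# Usage sketch: build a single camera-to-world matrix for a front-facing camera at
# 10 degrees elevation and distance 2.5 (illustrative values, not taken from this module).
def _example_front_camera_c2w() -> Tensor:
    azimuth = torch.tensor([0.0])
    elevation = torch.tensor([10.0])
    distance = torch.tensor([2.5])
    return get_c2w(azimuth, elevation, distance)[0]  # (4, 4) matrix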
def camera_strategy_test_4_90deg(
        mesh,
        num_views: int = 4,
        **kwargs) -> Dict:
    """
    Camera strategy for supervision views: azimuths evenly spaced over 360 degrees (starting at -90),
    a fixed elevation, and a camera distance derived from the mesh bounding box.
    :param mesh: The mesh whose bounding box determines the camera distance.
    :param num_views: Number of supervision views.
    :param kwargs: Additional arguments (unused).
    """
    # Default camera parameters
default_elevation = 10
default_camera_lens = 50
default_camera_sensor_width = 36
default_fovy = 2 * np.arctan(default_camera_sensor_width / (2 * default_camera_lens))
bbox_size = mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0]
distance = default_camera_lens / default_camera_sensor_width * \
math.sqrt(bbox_size[0] ** 2 + bbox_size[1] ** 2 + bbox_size[2] ** 2)
all_azimuth_deg = torch.linspace(0, 360.0, num_views + 1)[:num_views] - 90
all_elevation_deg = torch.full_like(all_azimuth_deg, default_elevation)
# Get the corresponding azimuth and elevation
view_idxs = torch.arange(0, num_views)
azimuth = all_azimuth_deg[view_idxs]
elevation = all_elevation_deg[view_idxs]
camera_distances = torch.full_like(elevation, distance)
c2w = get_c2w(azimuth, elevation, camera_distances)
if c2w.ndim == 2:
w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
w2c[3, 3] = 1.0
else:
w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
w2c[:, 3, 3] = 1.0
fovy = torch.full_like(azimuth, default_fovy)
return {
'cond_sup_view_idxs': view_idxs,
'cond_sup_c2w': c2w,
'cond_sup_w2c': w2c,
'cond_sup_fovy': fovy,
# 'cond_sup_azimuth': azimuth,
# 'cond_sup_elevation': elevation,
}
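# Usage sketch (assumes `mesh` exposes `v_pos` as an (N, 3) vertex tensor): the strategy
# returns per-view camera matrices and FOVs keyed by 'cond_sup_*'.
def _example_camera_strategy(mesh) -> Dict:
    cams = camera_strategy_test_4_90deg(mesh, num_views=4)
    # cams['cond_sup_c2w'] and cams['cond_sup_w2c'] have shape (4, 4, 4): four views, one 4x4 matrix each.
    # cams['cond_sup_fovy'] holds the vertical field of view (in radians) for each view.
    return cams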
def _get_projection_matrix(
fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
) -> Float[Tensor, "*B 4 4"]:
if isinstance(fovy, float):
proj_mtx = torch.zeros(4, 4, dtype=torch.float32)
proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh)
proj_mtx[1, 1] = -1.0 / math.tan(
fovy / 2.0
) # add a negative sign here as the y axis is flipped in nvdiffrast output
proj_mtx[2, 2] = -(far + near) / (far - near)
proj_mtx[2, 3] = -2.0 * far * near / (far - near)
proj_mtx[3, 2] = -1.0
else:
batch_size = fovy.shape[0]
proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
proj_mtx[:, 1, 1] = -1.0 / torch.tan(
fovy / 2.0
) # add a negative sign here as the y axis is flipped in nvdiffrast output
proj_mtx[:, 2, 2] = -(far + near) / (far - near)
proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
proj_mtx[:, 3, 2] = -1.0
return proj_mtx
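# Usage sketch: the projection helper accepts either a Python float (single matrix) or a
# batched tensor of vertical FOVs (one matrix per view); the aspect ratio is width / height.
def _example_projection_matrices() -> tuple[Tensor, Tensor]:
    single = _get_projection_matrix(math.radians(40.0), 1.0, 0.1, 1000.0)                  # (4, 4)
    batched = _get_projection_matrix(torch.full((4,), math.radians(40.0)), 1.0, 0.1, 1000.0)  # (4, 4, 4)
    return single, batched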
def _get_mvp_matrix(
c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    # Compute w2c from c2w: R' = R^T, t' = -R^T @ t
    # (mathematically equivalent to inverting c2w)
if c2w.ndim == 2:
assert proj_mtx.ndim == 2
w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
w2c[3, 3] = 1.0
else:
w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
w2c[:, 3, 3] = 1.0
    # MVP matrix = projection matrix @ world-to-camera (model-view) matrix
mvp_mtx = proj_mtx @ w2c
return mvp_mtx
def get_mvp_matrix(mesh, num_views=4, width=512, height=512, strategy="strategy_test_4_90deg"):
"""
Get Model-View-Projection (MVP) matrix for rendering views.
:param mesh: The mesh object to determine camera positioning.
:param num_views: Number of views to generate, default is 4.
:param width: Image width for projection matrix calculation.
:param height: Image height for projection matrix calculation.
:param strategy: Camera positioning strategy, default is "strategy_test_4_90deg".
:return: MVP matrix and world-to-camera transformation matrix.
"""
if strategy == "strategy_test_4_90deg":
camera_info = camera_strategy_test_4_90deg(
            mesh=mesh,  # Used to derive the camera distance from the mesh bounding box
num_views=num_views,
)
cond_sup_fovy = camera_info["cond_sup_fovy"]
cond_sup_c2w = camera_info["cond_sup_c2w"]
cond_sup_w2c = camera_info["cond_sup_w2c"]
# cond_sup_azimuth = camera_info["cond_sup_azimuth"]
# cond_sup_elevation = camera_info["cond_sup_elevation"]
else:
raise ValueError(f"Unsupported camera strategy: {strategy}")
cond_sup_proj_mtx: Float[Tensor, "B 4 4"] = _get_projection_matrix(
cond_sup_fovy, width / height, 0.1, 1000.0
)
mvp_mtx: Float[Tensor, "B 4 4"] = _get_mvp_matrix(cond_sup_c2w, cond_sup_proj_mtx)
return mvp_mtx, cond_sup_w2c
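# End-to-end sketch (assumes a loaded Mesh instance; the texture size is illustrative):
# build MVP matrices with the default strategy, create a plain gray texture, and render
# a horizontal strip of four preview views.
def _example_preview_mesh(mesh) -> Image.Image:
    mvp_mtx, _w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
    texture = get_pure_texture((1024, 1024))
    return render_views(mesh, texture, mvp_mtx, img_size=(512, 512))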
@torch.cuda.amp.autocast(enabled=False)
def _get_depth_normal_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda", background_color=(0, 0, 0)):
    """
    Get depth and normal maps with a mask from world-space position and normal images.
    :param xyz_map: Position images in world coordinates, shape [B, Nv, H, W, 3] (as returned by `render_geo_views_tensor`).
    :param normal_map: Normal images in world coordinates, shape [B, Nv, H, W, 3] (as returned by `render_geo_views_tensor`).
    :param mask: Mask for the images, shape [B, Nv, H, W] (as returned by `render_geo_views_tensor`).
    :param w2c: World-to-camera transformation matrices, shape [B, Nv, 4, 4].
    :param device: Device to run the computation on, default is "cuda".
    :param background_color: Background color for the depth and normal maps.
    :return: depth_map, normal_map, mask
    """
w2c = w2c.to(device)
# Render world coordinate position map and mask
B, Nv, H, W, C = xyz_map.shape # B: batch size, Nv: number of views, H/W: height/width, C: channels
assert Nv == 1
# Rearrange tensors for batch processing
xyz_map = rearrange(xyz_map, "B Nv H W C -> (B Nv) (H W) C")
normal_map = rearrange(normal_map, "B Nv H W C -> (B Nv) (H W) C")
w2c = rearrange(w2c, "B Nv C1 C2 -> (B Nv) C1 C2")
# Create homogeneous coordinates and correctly transform to camera coordinate system
# Points in world coordinate system need to be multiplied by world-to-camera transformation matrix
B_Nv, N, C = xyz_map.shape
ones = torch.ones(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
homogeneous_xyz = torch.cat([xyz_map, ones], dim=2) # [x,y,z,1]
zeros = torch.zeros(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
homogeneous_normal = torch.cat([normal_map, zeros], dim=2) # [x,y,z,1]
camera_coords = torch.bmm(homogeneous_xyz, w2c.transpose(1, 2))
camera_normals = torch.bmm(homogeneous_normal, w2c.transpose(1, 2))
depth_map = camera_coords[..., 2:3] # Z-axis is the depth direction in camera coordinate system
depth_map = rearrange(depth_map, "(B Nv) (H W) 1 -> B Nv H W", B=B, Nv=Nv, H=H, W=W)
normal_map = camera_normals[..., :3] # Keep only x, y, z components
normal_map = rearrange(normal_map, "(B Nv) (H W) c -> B Nv H W c", B=B, Nv=Nv, H=H, W=W)
    assert depth_map.dtype == torch.float32, f"depth_map must be float32, otherwise there will be artifacts in ControlNet-generated images, but got {depth_map.dtype}"
# Calculate min and max values
min_depth = depth_map.amin((1,2,3), keepdim=True)
max_depth = depth_map.amax((1,2,3), keepdim=True)
depth_map = (depth_map - min_depth) / (max_depth - min_depth + 1e-6) # Normalize to [0, 1]
    depth_map = depth_map.repeat(1, 3, 1, 1)  # [B, Nv=1, H, W] -> [B, 3, H, W]: replicate the single view into 3 channels
normal_map = normal_map * 0.5 + 0.5 # Normalize to [0, 1], [B, Nv, H, W, 3]
normal_map = normal_map[:,0].permute(0, 3, 1, 2) # [B, 3, H, W]
rgb_background_batched = torch.tensor(background_color, dtype=torch.float32, device=device).view(1, 3, 1, 1)
depth_map = torch.lerp(rgb_background_batched, depth_map, mask)
normal_map = torch.lerp(rgb_background_batched, normal_map, mask)
return depth_map, normal_map, mask
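# Usage sketch (assumes single-view inputs, since the helper asserts Nv == 1): add the
# batch/view dimensions expected by _get_depth_normal_map_with_mask and unpack the
# normalized depth/normal maps.
def _example_depth_normal_from_geo(position_imgs, normal_imgs, mask_imgs, w2c):
    depth, normal, mask = _get_depth_normal_map_with_mask(
        position_imgs[:1].unsqueeze(0),  # [1, 1, H, W, 3]
        normal_imgs[:1].unsqueeze(0),    # [1, 1, H, W, 3]
        mask_imgs[:1].unsqueeze(0),      # [1, 1, H, W]
        w2c[:1].unsqueeze(0),            # [1, 1, 4, 4]
    )
    return depth, normal, mask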
@spaces.GPU
def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image, Image.Image]:
    """
    Get silhouette images (depth, normal, and mask) derived from the rendered geometry images.
    :param position_imgs: Position images from different views, shape [Nv, H, W, 3].
    :param normal_imgs: Normal images from different views, shape [Nv, H, W, 3].
    :param mask_imgs: Masks for the images, shape [Nv, H, W] (as returned by `render_geo_views_tensor`).
    :param w2c: World-to-camera transformation matrices, shape [Nv, 4, 4].
    :param selected_view: The view selected for generating the image condition.
    :return: Depth, normal, and mask images as PIL Images; depth and normal are in the camera coordinate system.
    """
view_id_map = {
"First View": 0,
"Second View": 1,
"Third View": 2,
"Fourth View": 3
}
view_id = view_id_map[selected_view]
position_view = position_imgs[view_id: view_id + 1]
normal_view = normal_imgs[view_id: view_id + 1]
mask_view = mask_imgs[view_id: view_id + 1]
w2c = w2c[view_id: view_id + 1] # Select the corresponding w2c for the view
    depth_img, normal_img, mask = _get_depth_normal_map_with_mask(
position_view.unsqueeze(0), # Add batch dimension
normal_view.unsqueeze(0),
mask_view.unsqueeze(0),
w2c.unsqueeze(0),
)
to_img = ToPILImage()
return to_img(depth_img.squeeze(0)), to_img(normal_img.squeeze(0)), to_img(mask.squeeze(0))
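# Pipeline sketch (assumes a loaded Mesh instance and that the geometry buffers come back
# as [Nv, H, W, 3] / [Nv, H, W] tensors): render per-view geometry and turn the first view
# into depth/normal/mask condition images for downstream generation.
def _example_silhouettes(mesh) -> tuple[Image.Image, Image.Image, Image.Image]:
    mvp_mtx, w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
    positions, normals, masks = render_geo_views_tensor(mesh, mvp_mtx, img_size=(512, 512))
    return get_silhouette_image(positions, normals, masks, w2c, selected_view="First View")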