import math
from functools import cache
from typing import Dict, Union

import numpy as np
import spaces
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torchvision.transforms import ToPILImage

from .rasterize import (NVDiffRasterizerContext,
                        rasterize_position_and_normal_maps,
                        render_geo_from_mesh,
                        render_rgb_from_texture_mesh_with_mask)
from utils.file_utils import load_tensor_from_file

# Global variable to store the singleton context
_CTX_INSTANCE = None

@spaces.GPU
def get_rasterizer_context():
    """
    Get the NVDiffRasterizer context using singleton pattern.
    This ensures only one context is created and reused across the application.
    """
    global _CTX_INSTANCE
    if _CTX_INSTANCE is None:
        # Use string 'cuda' instead of torch.device to avoid early CUDA initialization
        _CTX_INSTANCE = NVDiffRasterizerContext('cuda', 'cuda')
    return _CTX_INSTANCE

def setup_lights():
    """
    Set three random point lights in the scene.
    """
    raise NotImplementedError("setup_lights function is not implemented yet.")

@spaces.GPU
def render_views(mesh, texture, mvp_matrix, lights=None, img_size=(512, 512)) -> Image.Image:
    """
    Render the RGB color images of the mesh. The background will be transparent.
    :param mesh: The mesh to be rendered. Class: Mesh.
    :param texture: The texture of the mesh, a tensor of shape (H, W, 3) or a path to a saved tensor file.
    :param mvp_matrix: The Model-View-Projection matrices for rendering, a tensor of shape (n_v, 4, 4) or a path to a saved tensor file.
    :param lights: The lights in the scene.
    :param img_size: The size of the output image, a tuple (height, width).
    :return: A concatenated PIL Image.
    """
    # If texture or mvp_matrix is a file path, load the tensor from file
    if isinstance(texture, str):
        texture = load_tensor_from_file(texture, map_location="cuda")
    if isinstance(mvp_matrix, str):
        mvp_matrix = load_tensor_from_file(mvp_matrix, map_location="cuda")
    mesh = mesh.to("cuda")
    texture = texture.to("cuda")
    mvp_matrix = mvp_matrix.to("cuda")
    
    print("Trying to render views...")
    ctx = get_rasterizer_context()
    if texture.shape[-1] != 3:
        texture = texture.permute(1, 2, 0)
    image_height, image_width = img_size
    rgb_cond, mask = render_rgb_from_texture_mesh_with_mask(
            ctx, mesh, texture, mvp_matrix, image_height, image_width, torch.tensor([0.0, 0.0, 0.0], device=texture.device))
    
    if mvp_matrix.shape[0] == 0:
        return None

    pil_images = []
    for i in range(mvp_matrix.shape[0]):
        rgba_img = torch.cat([rgb_cond[i], mask[i].unsqueeze(-1)], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
        rgba_img = (rgba_img * 255).to(torch.uint8)  # Convert to uint8
        rgba_img = rgba_img.cpu().numpy()  # Convert to numpy array
        pil_images.append(Image.fromarray(rgba_img, mode='RGBA'))

    if not pil_images:
        return None

    total_width = sum(img.width for img in pil_images)
    max_height = max(img.height for img in pil_images)

    concatenated_image = Image.new('RGBA', (total_width, max_height))

    current_x = 0
    for img in pil_images:
        concatenated_image.paste(img, (current_x, 0))
        current_x += img.width
    
    return concatenated_image
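
# Illustrative usage sketch (not part of the original pipeline): render_views is
# typically paired with get_mvp_matrix (defined below) and a flat placeholder
# texture from get_pure_texture. `mesh` is assumed to be a Mesh instance loaded
# elsewhere; the 512x512 resolution and 4 views are arbitrary example values.
def _example_render_views(mesh, num_views: int = 4) -> Image.Image:
    mvp_matrix, _w2c = get_mvp_matrix(mesh, num_views=num_views, width=512, height=512)
    texture = get_pure_texture((1024, 1024))  # solid gray (H, W, 3) texture
    # Returns a horizontally concatenated RGBA image of all rendered views.
    return render_views(mesh, texture, mvp_matrix, img_size=(512, 512))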

@spaces.GPU
def render_geo_views_tensor(mesh, mvp_matrix, img_size=(512, 512)) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Render geometry information (position, normal, and mask maps) for the views implied by the MVP matrices.
    """
    ctx = get_rasterizer_context()
    image_height, image_width = img_size
    position_images, normal_images, mask_images = render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width)
    return position_images, normal_images, mask_images

@spaces.GPU
def render_geo_map(mesh, map_size=(1024, 1024)) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the geometry information including position and normal from UV parameterization.
    """
    ctx = get_rasterizer_context()
    map_height, map_width = map_size
    position_images, normal_images, mask = rasterize_position_and_normal_maps(ctx, mesh, map_height, map_width)
    # out_imgs = []
    # if mask.ndim == 4:
    #     mask = mask[0]
    # for img_map in [position_images, normal_images]:
    #     if img_map.ndim == 4:
    #         img_map = img_map[0]
    #     # normalize to [0, 1]
    #     img_map = (img_map - img_map.min()) / (img_map.max() - img_map.min() + 1e-6)
        
    #     rgba_img = torch.cat([img_map, mask], dim=-1)  # [H, W, 3] + [H, W, 1] -> [H, W, 4]
    #     rgba_img = (rgba_img * 255).to(torch.uint8)  # Convert to uint8
    #     rgba_img = rgba_img.cpu().numpy()  # Convert to numpy array
    #     out_imgs.append(Image.fromarray(rgba_img, mode='RGBA'))
    return position_images, normal_images

@cache
def get_pure_texture(uv_size, color=(0x55, 0x55, 0x55)) -> torch.Tensor:
    """
    Get a solid-color texture image with the specified color.
    :param uv_size: The size of the UV map (height, width).
    :param color: The RGB color of the texture, default is 0x555555 (a dark gray).
    :return: A texture image tensor of shape (height, width, 3), values in [0, 1].
    """
    height, width = uv_size
    
    color = torch.tensor(color, dtype=torch.float32).view(1, 1, 3) / 255.0
    texture = color.repeat(height, width, 1)

    return texture

def get_c2w(
        azimuth_deg,
        elevation_deg,
        camera_distances,):
    assert len(azimuth_deg) == len(elevation_deg) == len(camera_distances)
    n_views = len(azimuth_deg)
    #camera_distances = torch.full_like(elevation_deg, dis)
    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )
    center = torch.zeros_like(camera_positions)
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0
    return c2w
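
# Minimal sketch (illustrative, not in the original code): build a single
# camera-to-world matrix with get_c2w for a front view at 10 degrees elevation
# and distance 2. All inputs are 1-element tensors, so the result has shape (1, 4, 4).
def _example_single_camera() -> Tensor:
    azimuth_deg = torch.tensor([0.0])
    elevation_deg = torch.tensor([10.0])
    camera_distances = torch.tensor([2.0])
    return get_c2w(azimuth_deg, elevation_deg, camera_distances)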

def camera_strategy_test_4_90deg(
        mesh,
        num_views: int = 4,
        **kwargs) -> Dict:
    """
    For sup views: evenly spaced azimuths (90 degrees apart for 4 views) at a fixed
    elevation, with the camera distance derived from the mesh bounding box and a fixed narrow fov.
    :param mesh: The mesh used to determine the camera distance. Class: Mesh.
    :param num_views: number of supervision views
    :param kwargs: additional arguments
    """
    # Default camera intrinsics
    default_elevation = 10
    default_camera_lens = 50
    default_camera_sensor_width = 36
    default_fovy = 2 * np.arctan(default_camera_sensor_width / (2 * default_camera_lens))

    bbox_size = mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0]
    distance = default_camera_lens / default_camera_sensor_width * \
               math.sqrt(bbox_size[0] ** 2 + bbox_size[1] ** 2 + bbox_size[2] ** 2)

    all_azimuth_deg = torch.linspace(0, 360.0, num_views + 1)[:num_views] - 90

    all_elevation_deg = torch.full_like(all_azimuth_deg, default_elevation)

    # Get the corresponding azimuth and elevation
    view_idxs = torch.arange(0, num_views)
    azimuth = all_azimuth_deg[view_idxs]
    elevation = all_elevation_deg[view_idxs]
    camera_distances = torch.full_like(elevation, distance)
    c2w = get_c2w(azimuth, elevation, camera_distances)
    
    if c2w.ndim == 2:
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0

    fovy = torch.full_like(azimuth, default_fovy)

    return {
        'cond_sup_view_idxs': view_idxs,
        'cond_sup_c2w': c2w,
        'cond_sup_w2c': w2c,
        'cond_sup_fovy': fovy,
        # 'cond_sup_azimuth': azimuth,
        # 'cond_sup_elevation': elevation,
    }

def _get_projection_matrix(
    fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
) -> Float[Tensor, "*B 4 4"]:
    if isinstance(fovy, float):
        proj_mtx = torch.zeros(4, 4, dtype=torch.float32)
        proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[1, 1] = -1.0 / math.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[2, 2] = -(far + near) / (far - near)
        proj_mtx[2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[3, 2] = -1.0
    else:
        batch_size = fovy.shape[0]
        proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
        proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
        proj_mtx[:, 1, 1] = -1.0 / torch.tan(
            fovy / 2.0
        )  # add a negative sign here as the y axis is flipped in nvdiffrast output
        proj_mtx[:, 2, 2] = -(far + near) / (far - near)
        proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
        proj_mtx[:, 3, 2] = -1.0
    return proj_mtx
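
# Sanity-check sketch (illustrative, not in the original code): with this
# OpenGL-style projection (y flipped for nvdiffrast), a point on the near plane
# (camera z = -near) maps to NDC depth -1 and a point on the far plane to +1.
def _example_projection_depth_range(near: float = 0.1, far: float = 1000.0) -> None:
    proj = _get_projection_matrix(math.radians(45.0), 1.0, near, far)
    for z, expected in [(-near, -1.0), (-far, 1.0)]:
        clip = proj @ torch.tensor([0.0, 0.0, z, 1.0])
        assert abs(clip[2] / clip[3] - expected) < 1e-4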

def _get_mvp_matrix(
    c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    # calculate w2c from c2w: R' = Rt, t' = -Rt * t
    # mathematically equivalent to (c2w)^-1
    if c2w.ndim == 2:
        assert proj_mtx.ndim == 2
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
        w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
        w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0
    # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
    mvp_mtx = proj_mtx @ w2c
    return mvp_mtx

def get_mvp_matrix(mesh, num_views=4, width=512, height=512, strategy="strategy_test_4_90deg"):
    """
    Get Model-View-Projection (MVP) matrix for rendering views.
    :param mesh: The mesh object to determine camera positioning.
    :param num_views: Number of views to generate, default is 4.
    :param width: Image width for projection matrix calculation.
    :param height: Image height for projection matrix calculation.
    :param strategy: Camera positioning strategy, default is "strategy_test_4_90deg".
    :return: MVP matrix and world-to-camera transformation matrix.
    """
    if strategy == "strategy_test_4_90deg":
        camera_info = camera_strategy_test_4_90deg(
            mesh=mesh,  # Mesh used to derive the camera distance from its bounding box
            num_views=num_views,
        )
        cond_sup_fovy = camera_info["cond_sup_fovy"]
        cond_sup_c2w = camera_info["cond_sup_c2w"]
        cond_sup_w2c = camera_info["cond_sup_w2c"]
        # cond_sup_azimuth = camera_info["cond_sup_azimuth"]
        # cond_sup_elevation = camera_info["cond_sup_elevation"]
    else:
        raise ValueError(f"Unsupported camera strategy: {strategy}")
    cond_sup_proj_mtx: Float[Tensor, "B 4 4"] = _get_projection_matrix(
        cond_sup_fovy, width / height, 0.1, 1000.0
    )
    mvp_mtx: Float[Tensor, "B 4 4"] = _get_mvp_matrix(cond_sup_c2w, cond_sup_proj_mtx)
    return mvp_mtx, cond_sup_w2c

@torch.cuda.amp.autocast(enabled=False)
def _get_depth_normal_map_with_mask(xyz_map, normal_map, mask, w2c, device="cuda", background_color=(0, 0, 0)):
    """
    Get depth and normal maps with a mask from world-space position and normal images.
    :param xyz_map: Position images in world coordinates, shape [B, Nv, H, W, 3] (batched output of `render_geo_views_tensor`).
    :param normal_map: Normal images in world coordinates, shape [B, Nv, H, W, 3] (batched output of `render_geo_views_tensor`).
    :param mask: Mask for the images, shape [B, Nv, H, W] (batched output of `render_geo_views_tensor`).
    :param w2c: World-to-camera transformation matrices, shape [B, Nv, 4, 4].
    :param device: Device to run the computation on, default is "cuda".
    :param background_color: Background color for the depth and normal maps.
    :return: depth_map, normal_map, mask
    """
    w2c = w2c.to(device)

    # Render world coordinate position map and mask
    B, Nv, H, W, C = xyz_map.shape  # B: batch size, Nv: number of views, H/W: height/width, C: channels
    assert Nv == 1
    # Rearrange tensors for batch processing
    xyz_map = rearrange(xyz_map, "B Nv H W C -> (B Nv) (H W) C")
    normal_map = rearrange(normal_map, "B Nv H W C -> (B Nv) (H W) C")
    w2c = rearrange(w2c, "B Nv C1 C2 -> (B Nv) C1 C2")

    # Create homogeneous coordinates and correctly transform to camera coordinate system
    # Points in world coordinate system need to be multiplied by world-to-camera transformation matrix
    B_Nv, N, C = xyz_map.shape
    ones = torch.ones(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_xyz = torch.cat([xyz_map, ones], dim=2)  # [x,y,z,1]
    zeros = torch.zeros(B_Nv, N, 1, dtype=xyz_map.dtype, device=xyz_map.device)
    homogeneous_normal = torch.cat([normal_map, zeros], dim=2)  # [x,y,z,1]
    
    camera_coords = torch.bmm(homogeneous_xyz, w2c.transpose(1, 2))
    camera_normals = torch.bmm(homogeneous_normal, w2c.transpose(1, 2))

    depth_map = camera_coords[..., 2:3]  # Z-axis is the depth direction in camera coordinate system
    depth_map = rearrange(depth_map, "(B Nv) (H W) 1 -> B Nv H W", B=B, Nv=Nv, H=H, W=W)
    normal_map = camera_normals[..., :3]  # Keep only x, y, z components
    normal_map = rearrange(normal_map, "(B Nv) (H W) c -> B Nv H W c", B=B, Nv=Nv, H=H, W=W)
    assert depth_map.dtype == torch.float32, f"depth_map must be float32, otherwise there will be artifact in controlnet generated pictures, but got {depth_map.dtype}"

    # Calculate min and max values
    min_depth = depth_map.amin((1,2,3), keepdim=True)
    max_depth = depth_map.amax((1,2,3), keepdim=True)

    depth_map = (depth_map - min_depth) / (max_depth - min_depth + 1e-6)  # Normalize to [0, 1]
    
    depth_map = depth_map.repeat(1, 3, 1, 1)  # Repeat 3 times to get RGB depth map
    normal_map = normal_map * 0.5 + 0.5  # Normalize to [0, 1], [B, Nv, H, W, 3]
    normal_map = normal_map[:,0].permute(0, 3, 1, 2)  # [B, 3, H, W]

    rgb_background_batched = torch.tensor(background_color, dtype=torch.float32, device=device).view(1, 3, 1, 1)
    depth_map = torch.lerp(rgb_background_batched, depth_map, mask)
    normal_map = torch.lerp(rgb_background_batched, normal_map, mask)

    return depth_map, normal_map, mask

@spaces.GPU
def get_silhouette_image(position_imgs, normal_imgs, mask_imgs, w2c, selected_view="First View") -> tuple[Image.Image, Image.Image, Image.Image]:
    """
    Get silhouette images (depth, normal, and mask) for the selected view based on the geometry images.
    :param position_imgs: Position images from different views, shape [Nv, H, W, 3].
    :param normal_imgs: Normal images from different views, shape [Nv, H, W, 3].
    :param mask_imgs: Mask for the images, shape [Nv, H, W]. It is a return value of `render_geo_views_tensor`.
    :param w2c: World-to-camera transformation matrices, shape [Nv, 4, 4].
    :param selected_view: The view selected for generating the image condition.
    :return: depth map, normal map, and mask as PIL images; depth and normal are in the camera coordinate system.
    """
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3
    }
    view_id = view_id_map[selected_view]
    position_view = position_imgs[view_id: view_id + 1]
    normal_view = normal_imgs[view_id: view_id + 1]
    mask_view = mask_imgs[view_id: view_id + 1]
    w2c = w2c[view_id: view_id + 1]  # Select the corresponding w2c for the view

    depth_img, normal_img, mask = _get_depth_normal_map_with_mask(
        position_view.unsqueeze(0),  # Add batch dimension
        normal_view.unsqueeze(0),
        mask_view.unsqueeze(0),
        w2c.unsqueeze(0),
    )

    to_img = ToPILImage()
    return to_img(depth_img.squeeze(0)), to_img(normal_img.squeeze(0)), to_img(mask.squeeze(0))
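
# End-to-end sketch (illustrative, not part of the original module): produce the
# depth/normal/mask silhouettes for one view of a mesh. `mesh` is assumed to be a
# Mesh instance loaded elsewhere; all other calls use functions defined above.
def _example_silhouettes(mesh, selected_view: str = "First View"):
    mvp_matrix, w2c = get_mvp_matrix(mesh, num_views=4, width=512, height=512)
    position_imgs, normal_imgs, mask_imgs = render_geo_views_tensor(
        mesh, mvp_matrix, img_size=(512, 512))
    depth_img, normal_img, mask_img = get_silhouette_image(
        position_imgs, normal_imgs, mask_imgs, w2c, selected_view=selected_view)
    return depth_img, normal_img, mask_img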