davidvgilmore committed
Commit 143b538 · verified · 1 Parent(s): 4ae01fa

Upload hy3dgen/texgen/differentiable_renderer/mesh_render.py with huggingface_hub

hy3dgen/texgen/differentiable_renderer/mesh_render.py ADDED
@@ -0,0 +1,833 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
from PIL import Image

from .camera_utils import (
    transform_pos,
    get_mv_matrix,
    get_orthographic_projection_matrix,
    get_perspective_projection_matrix,
)
from .mesh_processor import meshVerticeInpaint
from .mesh_utils import load_mesh, save_mesh

def stride_from_shape(shape):
    # Row-major strides for a tensor of the given shape.
    stride = [1]
    for x in reversed(shape[1:]):
        stride.append(stride[-1] * x)
    return list(reversed(stride))

def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    # input: [..., C], D dimensions + C channels
    # count: [..., 1], D dimensions
    # indices: [N, D], long
    # values: [N, C]

    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    assert len(size) == D

    input = input.view(-1, C)  # [HW, C]
    count = count.view(-1, 1)

    flatten_indices = (indices * torch.tensor(
        stride, dtype=torch.long, device=indices.device)).sum(-1)  # [N]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
    count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)

    return input.view(*size, C), count.view(*size, 1)

def linear_grid_put_2d(H, W, coords, values, return_count=False):
    # coords: [N, 2], float in [0, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = coords * torch.tensor(
        [H - 1, W - 1], dtype=torch.float32, device=coords.device
    )
    indices_00 = indices.floor().long()  # [N, 2]
    indices_00[:, 0].clamp_(0, H - 2)
    indices_00[:, 1].clamp_(0, W - 2)
    indices_01 = indices_00 + torch.tensor(
        [0, 1], dtype=torch.long, device=indices.device
    )
    indices_10 = indices_00 + torch.tensor(
        [1, 0], dtype=torch.long, device=indices.device
    )
    indices_11 = indices_00 + torch.tensor(
        [1, 1], dtype=torch.long, device=indices.device
    )

    h = indices[..., 0] - indices_00[..., 0].float()
    w = indices[..., 1] - indices_00[..., 1].float()
    w_00 = (1 - h) * (1 - w)
    w_01 = (1 - h) * w
    w_10 = h * (1 - w)
    w_11 = h * w

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(
        result, count, indices_00, values * w_00.unsqueeze(1), weights * w_00.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_01, values * w_01.unsqueeze(1), weights * w_01.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_10, values * w_10.unsqueeze(1), weights * w_10.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_11, values * w_11.unsqueeze(1), weights * w_11.unsqueeze(1))

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result

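# Illustrative usage sketch for linear_grid_put_2d (not part of the original
# file; shapes and values are made up). It splats N sampled colors into an
# H x W grid with bilinear weights and normalizes by the accumulated weight:
#
#     coords = torch.rand(1000, 2)                         # UV coords in [0, 1]
#     values = torch.rand(1000, 3)                         # one RGB value per sample
#     grid = linear_grid_put_2d(256, 256, coords, values)  # -> [256, 256, 3]
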
class MeshRender():
    def __init__(
            self,
            camera_distance=1.45, camera_type='orth',
            default_resolution=1024, texture_size=1024,
            use_antialias=True, max_mip_level=None, filter_mode='linear',
            bake_mode='linear', raster_mode='cr', device='cuda'):

        self.device = device

        self.set_default_render_resolution(default_resolution)
        self.set_default_texture_resolution(texture_size)

        self.camera_distance = camera_distance
        self.use_antialias = use_antialias
        self.max_mip_level = max_mip_level
        self.filter_mode = filter_mode

        self.bake_angle_thres = 75
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.default_resolution[0], self.default_resolution[1]))
        self.bake_mode = bake_mode

        self.raster_mode = raster_mode
        if self.raster_mode == 'cr':
            import custom_rasterizer as cr
            self.raster = cr
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        if camera_type == 'orth':
            self.ortho_scale = 1.2
            self.camera_proj_mat = get_orthographic_projection_matrix(
                left=-self.ortho_scale * 0.5, right=self.ortho_scale * 0.5,
                bottom=-self.ortho_scale * 0.5, top=self.ortho_scale * 0.5,
                near=0.1, far=100
            )
        elif camera_type == 'perspective':
            self.camera_proj_mat = get_perspective_projection_matrix(
                49.13, self.default_resolution[1] / self.default_resolution[0],
                0.01, 100.0
            )
        else:
            raise ValueError(f'No camera type {camera_type}')

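    # Construction sketch (illustrative, not part of the original file; assumes
    # the custom_rasterizer package is installed and a CUDA device is available):
    #
    #     renderer = MeshRender(camera_type='orth', default_resolution=1024,
    #                           texture_size=1024, device='cuda')
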
    def raster_rasterize(self, pos, tri, resolution, ranges=None, grad_db=True):

        if self.raster_mode == 'cr':
            rast_out_db = None
            if pos.dim() == 2:
                pos = pos.unsqueeze(0)
            findices, barycentric = self.raster.rasterize(pos, tri, resolution)
            rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
            rast_out = rast_out.unsqueeze(0)
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return rast_out, rast_out_db

    def raster_interpolate(self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None):

        if self.raster_mode == 'cr':
            textd = None
            barycentric = rast_out[0, ..., :-1]
            findices = rast_out[0, ..., -1]
            if uv.dim() == 2:
                uv = uv.unsqueeze(0)
            textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return textc, textd

    def raster_texture(self, tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='auto',
                       boundary_mode='wrap', max_mip_level=None):

        if self.raster_mode == 'cr':
            raise NotImplementedError('Texture sampling is not implemented for the cr rasterizer')
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

    def raster_antialias(self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0):

        if self.raster_mode == 'cr':
            # Antialiasing is not supported by the cr rasterizer yet; pass through unchanged.
            pass
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return color

    def load_mesh(
            self,
            mesh,
            scale_factor=1.15,
            auto_center=True,
    ):
        vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
        self.mesh_copy = mesh
        self.set_mesh(vtx_pos, pos_idx,
                      vtx_uv=vtx_uv, uv_idx=uv_idx,
                      scale_factor=scale_factor, auto_center=auto_center
                      )
        if texture_data is not None:
            self.set_texture(texture_data)

    def save_mesh(self):
        texture_data = self.get_texture()
        texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
        return save_mesh(self.mesh_copy, texture_data)

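    # Round-trip sketch (illustrative, not part of the original file; `mesh` is
    # assumed to be whatever the load_mesh/save_mesh helpers accept):
    #
    #     renderer.load_mesh(mesh)          # loads geometry, UVs, and texture
    #     ...                               # e.g. bake or edit the texture
    #     mesh_out = renderer.save_mesh()   # re-attaches the current texture
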
    def set_mesh(
            self,
            vtx_pos, pos_idx,
            vtx_uv=None, uv_idx=None,
            scale_factor=1.15, auto_center=True
    ):

        self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
        self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)
        if (vtx_uv is not None) and (uv_idx is not None):
            self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
            self.uv_idx = torch.from_numpy(uv_idx).to(self.device).to(torch.int)
        else:
            self.vtx_uv = None
            self.uv_idx = None

        # Convert into the renderer's coordinate convention: negate x/y, then swap y and z.
        self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
        self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
        if (vtx_uv is not None) and (uv_idx is not None):
            self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]

        if auto_center:
            max_bb = (self.vtx_pos - 0).max(0)[0]
            min_bb = (self.vtx_pos - 0).min(0)[0]
            center = (max_bb + min_bb) / 2
            scale = torch.norm(self.vtx_pos - center, dim=1).max() * 2.0
            self.vtx_pos = (self.vtx_pos - center) * (scale_factor / float(scale))
        # Always record the scale factor; render_position relies on it.
        self.scale_factor = scale_factor

    def set_texture(self, tex):
        if isinstance(tex, np.ndarray):
            tex = Image.fromarray((tex * 255).astype(np.uint8))
        elif isinstance(tex, torch.Tensor):
            tex = tex.cpu().numpy()
            tex = Image.fromarray((tex * 255).astype(np.uint8))

        tex = tex.resize(self.texture_size).convert('RGB')
        tex = np.array(tex) / 255.0
        self.tex = torch.from_numpy(tex).to(self.device)
        self.tex = self.tex.float()

    def set_default_render_resolution(self, default_resolution):
        if isinstance(default_resolution, int):
            default_resolution = (default_resolution, default_resolution)
        self.default_resolution = default_resolution

    def set_default_texture_resolution(self, texture_size):
        if isinstance(texture_size, int):
            texture_size = (texture_size, texture_size)
        self.texture_size = texture_size

    def get_mesh(self):
        vtx_pos = self.vtx_pos.cpu().numpy()
        pos_idx = self.pos_idx.cpu().numpy()
        vtx_uv = self.vtx_uv.cpu().numpy()
        uv_idx = self.uv_idx.cpu().numpy()

        # Inverse of the coordinate transform applied in set_mesh.
        vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
        vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]

        vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
        return vtx_pos, pos_idx, vtx_uv, uv_idx

    def get_texture(self):
        return self.tex.cpu().numpy()

    def to(self, device):
        self.device = device

        for attr_name in dir(self):
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, torch.Tensor):
                setattr(self, attr_name, attr_value.to(self.device))

    def color_rgb_to_srgb(self, image):
        if isinstance(image, Image.Image):
            image_rgb = torch.tensor(
                np.array(image) / 255.0).float().to(self.device)
        elif isinstance(image, np.ndarray):
            image_rgb = torch.tensor(image).float()
        else:
            image_rgb = image.to(self.device)

        image_srgb = torch.where(
            image_rgb <= 0.0031308,
            12.92 * image_rgb,
            1.055 * torch.pow(image_rgb, 1 / 2.4) - 0.055
        )

        if isinstance(image, Image.Image):
            image_srgb = Image.fromarray(
                (image_srgb.cpu().numpy() * 255).astype(np.uint8))
        elif isinstance(image, np.ndarray):
            image_srgb = image_srgb.cpu().numpy()
        else:
            image_srgb = image_srgb.to(image.device)

        return image_srgb

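    # Worked example of the transfer above (values rounded): a linear value of
    # 0.5 maps to 1.055 * 0.5 ** (1 / 2.4) - 0.055 ≈ 0.735, while small values
    # (<= 0.0031308) stay on the linear segment 12.92 * x.
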
    def _render(
            self,
            mvp,
            pos,
            pos_idx,
            uv,
            uv_idx,
            tex,
            resolution,
            max_mip_level,
            keep_alpha,
            filter_mode
    ):
        pos_clip = transform_pos(mvp, pos)
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, pos_idx, resolution=resolution)

        tex = tex.contiguous()
        if filter_mode == 'linear-mipmap-linear':
            texc, texd = self.raster_interpolate(
                uv[None, ...], rast_out, uv_idx, rast_db=rast_out_db, diff_attrs='all')
            color = self.raster_texture(
                tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level)
        else:
            texc, _ = self.raster_interpolate(uv[None, ...], rast_out, uv_idx)
            color = self.raster_texture(tex[None, ...], texc, filter_mode=filter_mode)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        color = color * visible_mask  # Mask out background.
        if self.use_antialias:
            color = self.raster_antialias(color, rast_out, pos_clip, pos_idx)

        if keep_alpha:
            color = torch.cat([color, visible_mask], dim=-1)
        return color[0, ...]

    def render(
            self,
            elev,
            azim,
            camera_distance=None,
            center=None,
            resolution=None,
            tex=None,
            keep_alpha=True,
            bgcolor=None,
            filter_mode=None,
            return_type='th'
    ):

        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)
        r_mvp = np.matmul(proj, r_mv).astype(np.float32)
        if tex is not None:
            if isinstance(tex, Image.Image):
                tex = torch.tensor(np.array(tex) / 255.0)
            elif isinstance(tex, np.ndarray):
                tex = torch.tensor(tex)
            if tex.dim() == 2:
                tex = tex.unsqueeze(-1)
            tex = tex.float().to(self.device)
        image = self._render(r_mvp, self.vtx_pos, self.pos_idx, self.vtx_uv, self.uv_idx,
                             self.tex if tex is None else tex,
                             self.default_resolution if resolution is None else resolution,
                             self.max_mip_level, True, filter_mode if filter_mode else self.filter_mode)
        mask = (image[..., [-1]] == 1).float()
        if bgcolor is None:
            bgcolor = [0 for _ in range(image.shape[-1] - 1)]
        image = image * mask + (1 - mask) * torch.tensor(bgcolor + [0]).to(self.device)
        if not keep_alpha:
            image = image[..., :-1]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

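    # Usage sketch (illustrative, not part of the original file; elev/azim are
    # assumed to be in degrees, as consumed by get_mv_matrix):
    #
    #     rgba = renderer.render(elev=0, azim=90)                   # torch tensor, [H, W, 4]
    #     pil = renderer.render(elev=0, azim=90, return_type='pl')  # PIL.Image
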
    def render_normal(
            self,
            elev,
            azim,
            camera_distance=None,
            center=None,
            resolution=None,
            bg_color=[1, 1, 1],
            use_abs_coor=False,
            normalize_rgb=True,
            return_type='th'
    ):

        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)
        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        if use_abs_coor:
            mesh_triangles = self.vtx_pos[self.pos_idx[:, :3], :]
        else:
            pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
            mesh_triangles = pos_camera[self.pos_idx[:, :3], :]
        face_normals = F.normalize(
            torch.cross(mesh_triangles[:, 1, :] - mesh_triangles[:, 0, :],
                        mesh_triangles[:, 2, :] - mesh_triangles[:, 0, :],
                        dim=-1),
            dim=-1)

        vertex_normals = trimesh.geometry.mean_vertex_normals(
            vertex_count=self.vtx_pos.shape[0],
            faces=self.pos_idx.cpu(),
            face_normals=face_normals.cpu())
        vertex_normals = torch.from_numpy(
            vertex_normals).float().to(self.device).contiguous()

        # Interpolate normal values across the rasterized pixels
        normal, _ = self.raster_interpolate(
            vertex_normals[None, ...], rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        normal = normal * visible_mask + \
            torch.tensor(bg_color, dtype=torch.float32, device=self.device) \
            * (1 - visible_mask)  # Mask out background.

        if normalize_rgb:
            normal = (normal + 1) * 0.5
        if self.use_antialias:
            normal = self.raster_antialias(normal, rast_out, pos_clip, self.pos_idx)

        image = normal[0, ...]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))

        return image

    def convert_normal_map(self, image):
        # blue is front, red is left, green is top
        if isinstance(image, Image.Image):
            image = np.array(image)
        mask = (image == [255, 255, 255]).all(axis=-1)

        image = (image / 255.0) * 2.0 - 1.0

        image[..., [1]] = -image[..., [1]]
        image[..., [1, 2]] = image[..., [2, 1]]
        image[..., [0]] = -image[..., [0]]

        image = (image + 1.0) * 0.5

        image = (image * 255).astype(np.uint8)
        image[mask] = [127, 127, 255]

        return Image.fromarray(image)

    def get_pos_from_mvp(self, elev, azim, camera_distance, center):
        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)

        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
        pos_clip = transform_pos(proj, pos_camera)

        return pos_camera, pos_clip

    def render_depth(
            self,
            elev,
            azim,
            camera_distance=None,
            center=None,
            resolution=None,
            return_type='th'
    ):
        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)

        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
        tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()

        # Interpolate depth values across the rasterized pixels
        depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        depth_max = depth[visible_mask > 0].max()
        depth_min = depth[visible_mask > 0].min()
        depth = (depth - depth_min) / (depth_max - depth_min)

        depth = depth * visible_mask  # Mask out background.
        if self.use_antialias:
            depth = self.raster_antialias(depth, rast_out, pos_clip, self.pos_idx)

        image = depth[0, ...]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

    def render_position(self, elev, azim, camera_distance=None, center=None,
                        resolution=None, bg_color=[1, 1, 1], return_type='th'):
        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)
        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        tex_position = 0.5 - self.vtx_pos[:, :3] / self.scale_factor
        tex_position = tex_position.contiguous()

        # Interpolate position values across the rasterized pixels
        position, _ = self.raster_interpolate(
            tex_position[None, ...], rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)

        position = position * visible_mask + \
            torch.tensor(bg_color, dtype=torch.float32, device=self.device) \
            * (1 - visible_mask)  # Mask out background.
        if self.use_antialias:
            position = self.raster_antialias(position, rast_out, pos_clip, self.pos_idx)

        image = position[0, ...]

        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

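    # Sketch of rendering the per-view geometry maps (illustrative, not part of
    # the original file):
    #
    #     normal = renderer.render_normal(elev=0, azim=0)   # camera-space normals
    #     depth = renderer.render_depth(elev=0, azim=0)     # normalized depth
    #     xyz = renderer.render_position(elev=0, azim=0)    # object-space positions
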
    def render_uvpos(self, return_type='th'):
        image = self.uv_feature_map(self.vtx_pos * 0.5 + 0.5)
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

    def uv_feature_map(self, vert_feat, bg=None):
        vtx_uv = self.vtx_uv * 2 - 1.0
        vtx_uv = torch.cat(
            [vtx_uv, torch.zeros_like(self.vtx_uv)], dim=1).unsqueeze(0)
        vtx_uv[..., -1] = 1
        uv_idx = self.uv_idx
        rast_out, rast_out_db = self.raster_rasterize(
            vtx_uv, uv_idx, resolution=self.texture_size)
        feat_map, _ = self.raster_interpolate(vert_feat[None, ...], rast_out, uv_idx)
        feat_map = feat_map[0, ...]
        if bg is not None:
            visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
            feat_map[visible_mask == 0] = bg
        return feat_map

    def render_sketch_from_geometry(self, normal_image, depth_image):
        normal_image_np = normal_image.cpu().numpy()
        depth_image_np = depth_image.cpu().numpy()

        normal_image_np = (normal_image_np * 255).astype(np.uint8)
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        normal_image_np = cv2.cvtColor(normal_image_np, cv2.COLOR_RGB2GRAY)

        normal_edges = cv2.Canny(normal_image_np, 80, 150)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)

        combined_edges = np.maximum(normal_edges, depth_edges)

        sketch_image = torch.from_numpy(combined_edges).to(
            normal_image.device).float() / 255.0
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def render_sketch_from_depth(self, depth_image):
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        combined_edges = depth_edges
        sketch_image = torch.from_numpy(combined_edges).to(
            depth_image.device).float() / 255.0
        sketch_image = sketch_image.unsqueeze(-1)
        return sketch_image

    def back_project(self, image, elev, azim,
                     camera_distance=None, center=None, method=None):
        if isinstance(image, Image.Image):
            image = torch.tensor(np.array(image) / 255.0)
        elif isinstance(image, np.ndarray):
            image = torch.tensor(image)
        if image.dim() == 2:
            image = image.unsqueeze(-1)
        image = image.float().to(self.device)
        resolution = image.shape[:2]
        channel = image.shape[-1]
        texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
        cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)

        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)
        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
        pos_clip = transform_pos(proj, pos_camera)
        pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
        v0 = pos_camera[self.pos_idx[:, 0], :]
        v1 = pos_camera[self.pos_idx[:, 1], :]
        v2 = pos_camera[self.pos_idx[:, 2], :]
        face_normals = F.normalize(
            torch.cross(v1 - v0, v2 - v0, dim=-1),
            dim=-1)
        vertex_normals = trimesh.geometry.mean_vertex_normals(
            vertex_count=self.vtx_pos.shape[0],
            faces=self.pos_idx.cpu(),
            face_normals=face_normals.cpu())
        vertex_normals = torch.from_numpy(
            vertex_normals).float().to(self.device).contiguous()
        tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)
        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]

        normal, _ = self.raster_interpolate(
            vertex_normals[None, ...], rast_out, self.pos_idx)
        normal = normal[0, ...]
        uv, _ = self.raster_interpolate(self.vtx_uv[None, ...], rast_out, self.uv_idx)
        depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
        depth = depth[0, ...]

        depth_max = depth[visible_mask > 0].max()
        depth_min = depth[visible_mask > 0].min()
        depth_normalized = (depth - depth_min) / (depth_max - depth_min)
        depth_image = depth_normalized * visible_mask  # Mask out background.

        sketch_image = self.render_sketch_from_depth(depth_image)

        lookat = torch.tensor([[0, 0, -1]], device=self.device)
        cos_image = torch.nn.functional.cosine_similarity(
            lookat, normal.view(-1, 3))
        cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)

        cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
        cos_image[cos_image < cos_thres] = 0

        # Shrink the reliable region away from silhouettes and depth edges.
        kernel_size = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones(
            (1, 1, kernel_size, kernel_size), dtype=torch.float32).to(
            sketch_image.device)

        visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
        visible_mask = F.conv2d(
            1.0 - visible_mask,
            kernel,
            padding=kernel_size // 2)
        visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
        visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)

        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
        sketch_image = (sketch_image > 0).float()  # binarize
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        visible_mask = visible_mask * (sketch_image < 0.5)

        cos_image[visible_mask == 0] = 0

        method = self.bake_mode if method is None else method

        if method == 'linear':
            proj_mask = (visible_mask != 0).view(-1)
            uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
            image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
            cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
            sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]

            texture = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image)
            cos_map = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], cos_image)
            boundary_map = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], sketch_image)
        else:
            raise ValueError(f'No bake mode {method}')

        return texture, cos_map, boundary_map

    def bake_texture(self, colors, elevs, azims,
                     camera_distance=None, center=None, exp=6, weights=None):
        for i in range(len(colors)):
            if isinstance(colors[i], Image.Image):
                colors[i] = torch.tensor(
                    np.array(colors[i]) / 255.0,
                    device=self.device).float()
        if weights is None:
            weights = [1.0 for _ in range(len(colors))]
        textures = []
        cos_maps = []
        for color, elev, azim, weight in zip(colors, elevs, azims, weights):
            texture, cos_map, _ = self.back_project(
                color, elev, azim, camera_distance, center)
            cos_map = weight * (cos_map ** exp)
            textures.append(texture)
            cos_maps.append(cos_map)

        texture_merge, trust_map_merge = self.fast_bake_texture(
            textures, cos_maps)
        return texture_merge, trust_map_merge

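    # Baking sketch (illustrative, not part of the original file): merge several
    # generated views into one UV texture, weighting texels by view angle. The
    # view images and angles below are made up:
    #
    #     views = [img_front, img_back, img_left, img_right]   # assumed PIL images
    #     elevs, azims = [0, 0, 0, 0], [0, 180, 90, 270]
    #     texture, trust = renderer.bake_texture(views, elevs, azims, exp=6)
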
    @torch.no_grad()
    def fast_bake_texture(self, textures, cos_maps):

        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(
            self.texture_size + (channel,)).to(self.device)
        trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
        for texture, cos_map in zip(textures, cos_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            # Skip views that see nothing at all, or almost nothing new.
            if view_sum == 0 or painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1E-8)

        return texture_merge, trust_map_merge > 1E-8

    def uv_inpaint(self, texture, mask):

        if isinstance(texture, torch.Tensor):
            texture_np = texture.cpu().numpy()
        elif isinstance(texture, np.ndarray):
            texture_np = texture
        elif isinstance(texture, Image.Image):
            texture_np = np.array(texture) / 255.0

        vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()

        texture_np, mask = meshVerticeInpaint(
            texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx)

        texture_np = cv2.inpaint(
            (texture_np * 255).astype(np.uint8),
            255 - mask,
            3,
            cv2.INPAINT_NS)

        return texture_np
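
    # End-of-pipeline sketch (illustrative, not part of the original file): fill
    # texels no view could see, then export. `trust` comes from bake_texture and
    # is converted to a uint8 mask here by assumption:
    #
    #     mask = (trust.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
    #     texture_np = renderer.uv_inpaint(texture, mask)
    #     renderer.set_texture(texture_np / 255.0)
    #     mesh_out = renderer.save_mesh()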