Upload hy3dgen/texgen/differentiable_renderer/mesh_render.py with huggingface_hub

hy3dgen/texgen/differentiable_renderer/mesh_render.py
ADDED
@@ -0,0 +1,833 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
from PIL import Image

from .camera_utils import (
    transform_pos,
    get_mv_matrix,
    get_orthographic_projection_matrix,
    get_perspective_projection_matrix,
)
from .mesh_processor import meshVerticeInpaint
from .mesh_utils import load_mesh, save_mesh


def stride_from_shape(shape):
    # Row-major strides for a dense tensor of the given shape, used to
    # flatten N-D grid indices into 1-D offsets.
    stride = [1]
    for x in reversed(shape[1:]):
        stride.append(stride[-1] * x)
    return list(reversed(stride))
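
# Illustrative check (not part of the original file): for a [4, 8] grid the
# row-major strides are [8, 1], so index (i, j) flattens to i * 8 + j * 1, e.g.
#   stride_from_shape([4, 8])   # -> [8, 1]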


def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    # input: [..., C], D dimension + C channel
    # count: [..., 1], D dimension
    # indices: [N, D], long
    # values: [N, C]

    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    assert len(size) == D

    input = input.view(-1, C)  # [HW, C]
    count = count.view(-1, 1)

    flatten_indices = (indices * torch.tensor(stride,
                                              dtype=torch.long, device=indices.device)).sum(-1)  # [N]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
    count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)

    return input.view(*size, C), count.view(*size, 1)
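
# Minimal usage sketch (illustrative only): accumulate two RGB samples into a
# 4x4 grid and track how much weight landed in each cell.
#   grid = torch.zeros(4, 4, 3); cnt = torch.zeros(4, 4, 1)
#   idx = torch.tensor([[0, 0], [3, 3]]); vals = torch.ones(2, 3)
#   grid, cnt = scatter_add_nd_with_count(grid, cnt, idx, vals)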


def linear_grid_put_2d(H, W, coords, values, return_count=False):
    # coords: [N, 2], float in [0, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = coords * torch.tensor(
        [H - 1, W - 1], dtype=torch.float32, device=coords.device
    )
    indices_00 = indices.floor().long()  # [N, 2]
    indices_00[:, 0].clamp_(0, H - 2)
    indices_00[:, 1].clamp_(0, W - 2)
    indices_01 = indices_00 + torch.tensor(
        [0, 1], dtype=torch.long, device=indices.device
    )
    indices_10 = indices_00 + torch.tensor(
        [1, 0], dtype=torch.long, device=indices.device
    )
    indices_11 = indices_00 + torch.tensor(
        [1, 1], dtype=torch.long, device=indices.device
    )

    h = indices[..., 0] - indices_00[..., 0].float()
    w = indices[..., 1] - indices_00[..., 1].float()
    w_00 = (1 - h) * (1 - w)
    w_01 = (1 - h) * w
    w_10 = h * (1 - w)
    w_11 = h * w

    result = torch.zeros(H, W, C, device=values.device,
                         dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device,
                        dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(
        result, count, indices_00, values * w_00.unsqueeze(1), weights * w_00.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_01, values * w_01.unsqueeze(1), weights * w_01.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_10, values * w_10.unsqueeze(1), weights * w_10.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_11, values * w_11.unsqueeze(1), weights * w_11.unsqueeze(1))

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result
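
# Illustrative sketch (not from the original file): splat three normalized-UV
# samples into an 8x8 texture with bilinear weights; cells touched by several
# samples are averaged by the accumulated weight.
#   coords = torch.tensor([[0.1, 0.2], [0.5, 0.5], [0.9, 0.9]])
#   values = torch.rand(3, 3)
#   tex = linear_grid_put_2d(8, 8, coords, values)   # -> [8, 8, 3]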


class MeshRender:
    def __init__(
        self,
        camera_distance=1.45, camera_type='orth',
        default_resolution=1024, texture_size=1024,
        use_antialias=True, max_mip_level=None, filter_mode='linear',
        bake_mode='linear', raster_mode='cr', device='cuda'):

        self.device = device

        self.set_default_render_resolution(default_resolution)
        self.set_default_texture_resolution(texture_size)

        self.camera_distance = camera_distance
        self.use_antialias = use_antialias
        self.max_mip_level = max_mip_level
        self.filter_mode = filter_mode

        self.bake_angle_thres = 75
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.default_resolution[0], self.default_resolution[1]))
        self.bake_mode = bake_mode

        self.raster_mode = raster_mode
        if self.raster_mode == 'cr':
            import custom_rasterizer as cr
            self.raster = cr
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        if camera_type == 'orth':
            self.ortho_scale = 1.2
            self.camera_proj_mat = get_orthographic_projection_matrix(
                left=-self.ortho_scale * 0.5, right=self.ortho_scale * 0.5,
                bottom=-self.ortho_scale * 0.5, top=self.ortho_scale * 0.5,
                near=0.1, far=100
            )
        elif camera_type == 'perspective':
            self.camera_proj_mat = get_perspective_projection_matrix(
                49.13, self.default_resolution[1] / self.default_resolution[0],
                0.01, 100.0
            )
        else:
            raise ValueError(f'No camera type {camera_type}')
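
    # Construction sketch (illustrative, assumes the custom_rasterizer package
    # is installed): an orthographic renderer at 1024px with a 2048px atlas.
    #   render = MeshRender(camera_type='orth', default_resolution=1024,
    #                       texture_size=2048, device='cuda')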

    def raster_rasterize(self, pos, tri, resolution, ranges=None, grad_db=True):

        if self.raster_mode == 'cr':
            rast_out_db = None
            if pos.dim() == 2:
                pos = pos.unsqueeze(0)
            findices, barycentric = self.raster.rasterize(pos, tri, resolution)
            rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
            rast_out = rast_out.unsqueeze(0)
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return rast_out, rast_out_db

    def raster_interpolate(self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None):

        if self.raster_mode == 'cr':
            textd = None
            barycentric = rast_out[0, ..., :-1]
            findices = rast_out[0, ..., -1]
            if uv.dim() == 2:
                uv = uv.unsqueeze(0)
            textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return textc, textd
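
    # Layout note (derived from the two methods above): rast_out concatenates
    # the per-pixel barycentric weights with the rasterized face index, so
    # rast_out[..., :-1] holds interpolation weights and rast_out[..., -1:] is
    # reused elsewhere in this file as a visibility mask.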

    def raster_texture(self, tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='auto',
                       boundary_mode='wrap', max_mip_level=None):

        if self.raster_mode == 'cr':
            raise NotImplementedError('Texture sampling is not implemented for the cr rasterizer')
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

    def raster_antialias(self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0):

        if self.raster_mode == 'cr':
            # Antialiasing is not supported yet; return the color unchanged.
            pass
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return color

    def load_mesh(
        self,
        mesh,
        scale_factor=1.15,
        auto_center=True,
    ):
        vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
        self.mesh_copy = mesh
        self.set_mesh(vtx_pos, pos_idx,
                      vtx_uv=vtx_uv, uv_idx=uv_idx,
                      scale_factor=scale_factor, auto_center=auto_center
                      )
        if texture_data is not None:
            self.set_texture(texture_data)

    def save_mesh(self):
        texture_data = self.get_texture()
        texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
        return save_mesh(self.mesh_copy, texture_data)
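
    # Round-trip sketch (illustrative; 'mesh.glb' is a placeholder path):
    #   render.load_mesh('mesh.glb')         # positions, UVs and texture in
    #   ...                                  # bake or edit self.tex here
    #   textured_mesh = render.save_mesh()   # current texture written back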

    def set_mesh(
        self,
        vtx_pos, pos_idx,
        vtx_uv=None, uv_idx=None,
        scale_factor=1.15, auto_center=True
    ):

        self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
        self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)
        if (vtx_uv is not None) and (uv_idx is not None):
            self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
            self.uv_idx = torch.from_numpy(uv_idx).to(self.device).to(torch.int)
        else:
            self.vtx_uv = None
            self.uv_idx = None

        # Convert into the renderer's coordinate convention: negate x/y,
        # swap y/z, and flip the V axis of the UVs.
        self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
        self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
        if (vtx_uv is not None) and (uv_idx is not None):
            self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]

        if auto_center:
            max_bb = self.vtx_pos.max(0)[0]
            min_bb = self.vtx_pos.min(0)[0]
            center = (max_bb + min_bb) / 2
            scale = torch.norm(self.vtx_pos - center, dim=1).max() * 2.0
            self.vtx_pos = (self.vtx_pos - center) * (scale_factor / float(scale))
            self.scale_factor = scale_factor
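
    # Worked example of the auto-centering above (illustrative numbers): if the
    # furthest vertex sits 1.0 from the bounding-box center, then scale = 2.0,
    # and with scale_factor = 1.15 every vertex is multiplied by 1.15 / 2.0, so
    # the mesh ends up inside a sphere of radius 0.575.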

    def set_texture(self, tex):
        if isinstance(tex, np.ndarray):
            tex = Image.fromarray((tex * 255).astype(np.uint8))
        elif isinstance(tex, torch.Tensor):
            tex = tex.cpu().numpy()
            tex = Image.fromarray((tex * 255).astype(np.uint8))

        tex = tex.resize(self.texture_size).convert('RGB')
        tex = np.array(tex) / 255.0
        self.tex = torch.from_numpy(tex).to(self.device)
        self.tex = self.tex.float()

    def set_default_render_resolution(self, default_resolution):
        if isinstance(default_resolution, int):
            default_resolution = (default_resolution, default_resolution)
        self.default_resolution = default_resolution

    def set_default_texture_resolution(self, texture_size):
        if isinstance(texture_size, int):
            texture_size = (texture_size, texture_size)
        self.texture_size = texture_size

    def get_mesh(self):
        vtx_pos = self.vtx_pos.cpu().numpy()
        pos_idx = self.pos_idx.cpu().numpy()
        vtx_uv = self.vtx_uv.cpu().numpy()
        uv_idx = self.uv_idx.cpu().numpy()

        # Inverse of the coordinate transform applied in set_mesh.
        vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
        vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]

        vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
        return vtx_pos, pos_idx, vtx_uv, uv_idx

    def get_texture(self):
        return self.tex.cpu().numpy()

    def to(self, device):
        self.device = device

        for attr_name in dir(self):
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, torch.Tensor):
                setattr(self, attr_name, attr_value.to(self.device))

    def color_rgb_to_srgb(self, image):
        if isinstance(image, Image.Image):
            image_rgb = torch.tensor(
                np.array(image) / 255.0).float().to(self.device)
        elif isinstance(image, np.ndarray):
            image_rgb = torch.tensor(image).float()
        else:
            image_rgb = image.to(self.device)

        image_srgb = torch.where(
            image_rgb <= 0.0031308,
            12.92 * image_rgb,
            1.055 * torch.pow(image_rgb, 1 / 2.4) - 0.055
        )

        if isinstance(image, Image.Image):
            image_srgb = Image.fromarray(
                (image_srgb.cpu().numpy() * 255).astype(np.uint8))
        elif isinstance(image, np.ndarray):
            image_srgb = image_srgb.cpu().numpy()
        else:
            image_srgb = image_srgb.to(image.device)

        return image_srgb
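
    # Numeric sanity check for the sRGB transfer above: a linear value of 0.5
    # maps to 1.055 * 0.5**(1/2.4) - 0.055, about 0.735, while values at or
    # below 0.0031308 stay on the 12.92x linear segment.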

    def _render(
        self,
        mvp,
        pos,
        pos_idx,
        uv,
        uv_idx,
        tex,
        resolution,
        max_mip_level,
        keep_alpha,
        filter_mode
    ):
        pos_clip = transform_pos(mvp, pos)
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, pos_idx, resolution=resolution)

        tex = tex.contiguous()
        if filter_mode == 'linear-mipmap-linear':
            texc, texd = self.raster_interpolate(
                uv[None, ...], rast_out, uv_idx, rast_db=rast_out_db, diff_attrs='all')
            color = self.raster_texture(
                tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level)
        else:
            texc, _ = self.raster_interpolate(uv[None, ...], rast_out, uv_idx)
            color = self.raster_texture(tex[None, ...], texc, filter_mode=filter_mode)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        color = color * visible_mask  # Mask out background.
        if self.use_antialias:
            color = self.raster_antialias(color, rast_out, pos_clip, pos_idx)

        if keep_alpha:
            color = torch.cat([color, visible_mask], dim=-1)
        return color[0, ...]

    def render(
        self,
        elev,
        azim,
        camera_distance=None,
        center=None,
        resolution=None,
        tex=None,
        keep_alpha=True,
        bgcolor=None,
        filter_mode=None,
        return_type='th'
    ):

        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)
        r_mvp = np.matmul(proj, r_mv).astype(np.float32)
        if tex is not None:
            if isinstance(tex, Image.Image):
                tex = torch.tensor(np.array(tex) / 255.0)
            elif isinstance(tex, np.ndarray):
                tex = torch.tensor(tex)
            if tex.dim() == 2:
                tex = tex.unsqueeze(-1)
            tex = tex.float().to(self.device)
        image = self._render(r_mvp, self.vtx_pos, self.pos_idx, self.vtx_uv, self.uv_idx,
                             self.tex if tex is None else tex,
                             self.default_resolution if resolution is None else resolution,
                             self.max_mip_level, True, filter_mode if filter_mode else self.filter_mode)
        mask = (image[..., [-1]] == 1).float()
        if bgcolor is None:
            bgcolor = [0 for _ in range(image.shape[-1] - 1)]
        image = image * mask + (1 - mask) * \
            torch.tensor(bgcolor + [0]).to(self.device)
        if not keep_alpha:
            image = image[..., :-1]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image
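
    # Usage sketch (illustrative): render the textured mesh from 20 degrees
    # elevation and 90 degrees azimuth as a PIL image without alpha.
    #   img = render.render(elev=20, azim=90, keep_alpha=False, return_type='pl')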

    def render_normal(
        self,
        elev,
        azim,
        camera_distance=None,
        center=None,
        resolution=None,
        bg_color=[1, 1, 1],
        use_abs_coor=False,
        normalize_rgb=True,
        return_type='th'
    ):

        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)
        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        if use_abs_coor:
            mesh_triangles = self.vtx_pos[self.pos_idx[:, :3], :]
        else:
            pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
            mesh_triangles = pos_camera[self.pos_idx[:, :3], :]
        face_normals = F.normalize(
            torch.cross(mesh_triangles[:, 1, :] - mesh_triangles[:, 0, :],
                        mesh_triangles[:, 2, :] - mesh_triangles[:, 0, :],
                        dim=-1),
            dim=-1)

        vertex_normals = trimesh.geometry.mean_vertex_normals(vertex_count=self.vtx_pos.shape[0],
                                                              faces=self.pos_idx.cpu(),
                                                              face_normals=face_normals.cpu(), )
        vertex_normals = torch.from_numpy(
            vertex_normals).float().to(self.device).contiguous()

        # Interpolate normal values across the rasterized pixels
        normal, _ = self.raster_interpolate(
            vertex_normals[None, ...], rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        normal = normal * visible_mask + \
            torch.tensor(bg_color, dtype=torch.float32, device=self.device) \
            * (1 - visible_mask)  # Mask out background.

        if normalize_rgb:
            normal = (normal + 1) * 0.5
        if self.use_antialias:
            normal = self.raster_antialias(normal, rast_out, pos_clip, self.pos_idx)

        image = normal[0, ...]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))

        return image
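
    # Usage sketch (illustrative): camera-space normals remapped from [-1, 1]
    # into [0, 1] RGB via normalize_rgb, returned as a PIL image.
    #   normal_map = render.render_normal(elev=0, azim=0, return_type='pl')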

    def convert_normal_map(self, image):
        # blue is front, red is left, green is top
        if isinstance(image, Image.Image):
            image = np.array(image)
        mask = (image == [255, 255, 255]).all(axis=-1)

        image = (image / 255.0) * 2.0 - 1.0

        image[..., [1]] = -image[..., [1]]
        image[..., [1, 2]] = image[..., [2, 1]]
        image[..., [0]] = -image[..., [0]]

        image = (image + 1.0) * 0.5

        image = (image * 255).astype(np.uint8)
        image[mask] = [127, 127, 255]

        return Image.fromarray(image)
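
    # The axis remap above converts this renderer's normal convention into the
    # usual normal-map encoding where flat-on pixels read (127, 127, 255); pure
    # white background pixels are reset to that neutral color afterwards.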

    def get_pos_from_mvp(self, elev, azim, camera_distance, center):
        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)

        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
        pos_clip = transform_pos(proj, pos_camera)

        return pos_camera, pos_clip

    def render_depth(
        self,
        elev,
        azim,
        camera_distance=None,
        center=None,
        resolution=None,
        return_type='th'
    ):
        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)

        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
        tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()

        # Interpolate depth values across the rasterized pixels
        depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        depth_max, depth_min = depth[visible_mask > 0].max(), depth[visible_mask > 0].min()
        depth = (depth - depth_min) / (depth_max - depth_min)

        depth = depth * visible_mask  # Mask out background.
        if self.use_antialias:
            depth = self.raster_antialias(depth, rast_out, pos_clip, self.pos_idx)

        image = depth[0, ...]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image
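
    # Note: depth is min-max normalized per view over visible pixels only, so
    # the values are relative within a single render; the minimum visible
    # camera-space depth maps to 0 and the maximum to 1.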

    def render_position(self, elev, azim, camera_distance=None, center=None,
                        resolution=None, bg_color=[1, 1, 1], return_type='th'):
        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)
        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        tex_position = 0.5 - self.vtx_pos[:, :3] / self.scale_factor
        tex_position = tex_position.contiguous()

        # Interpolate position values across the rasterized pixels
        position, _ = self.raster_interpolate(
            tex_position[None, ...], rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)

        position = position * visible_mask + \
            torch.tensor(bg_color, dtype=torch.float32, device=self.device) \
            * (1 - visible_mask)  # Mask out background.
        if self.use_antialias:
            position = self.raster_antialias(position, rast_out, pos_clip, self.pos_idx)

        image = position[0, ...]

        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

    def render_uvpos(self, return_type='th'):
        image = self.uv_feature_map(self.vtx_pos * 0.5 + 0.5)
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

    def uv_feature_map(self, vert_feat, bg=None):
        # Rasterize the mesh in UV space (UVs mapped to clip coordinates with
        # w = 1) so per-vertex features land directly in the texture atlas.
        vtx_uv = self.vtx_uv * 2 - 1.0
        vtx_uv = torch.cat(
            [vtx_uv, torch.zeros_like(self.vtx_uv)], dim=1).unsqueeze(0)
        vtx_uv[..., -1] = 1
        uv_idx = self.uv_idx
        rast_out, rast_out_db = self.raster_rasterize(
            vtx_uv, uv_idx, resolution=self.texture_size)
        feat_map, _ = self.raster_interpolate(vert_feat[None, ...], rast_out, uv_idx)
        feat_map = feat_map[0, ...]
        if bg is not None:
            visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
            feat_map[visible_mask == 0] = bg
        return feat_map
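
    # Usage sketch (illustrative): bake per-vertex positions into the texture
    # atlas, which is exactly what render_uvpos does above.
    #   pos_map = render.uv_feature_map(render.vtx_pos * 0.5 + 0.5)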

    def render_sketch_from_geometry(self, normal_image, depth_image):
        normal_image_np = normal_image.cpu().numpy()
        depth_image_np = depth_image.cpu().numpy()

        normal_image_np = (normal_image_np * 255).astype(np.uint8)
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        normal_image_np = cv2.cvtColor(normal_image_np, cv2.COLOR_RGB2GRAY)

        normal_edges = cv2.Canny(normal_image_np, 80, 150)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)

        combined_edges = np.maximum(normal_edges, depth_edges)

        sketch_image = torch.from_numpy(combined_edges).to(
            normal_image.device).float() / 255.0
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def render_sketch_from_depth(self, depth_image):
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        combined_edges = depth_edges
        sketch_image = torch.from_numpy(combined_edges).to(
            depth_image.device).float() / 255.0
        sketch_image = sketch_image.unsqueeze(-1)
        return sketch_image

    def back_project(self, image, elev, azim,
                     camera_distance=None, center=None, method=None):
        if isinstance(image, Image.Image):
            image = torch.tensor(np.array(image) / 255.0)
        elif isinstance(image, np.ndarray):
            image = torch.tensor(image)
        if image.dim() == 2:
            image = image.unsqueeze(-1)
        image = image.float().to(self.device)
        resolution = image.shape[:2]
        channel = image.shape[-1]
        texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
        cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)

        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)
        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
        pos_clip = transform_pos(proj, pos_camera)
        pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
        v0 = pos_camera[self.pos_idx[:, 0], :]
        v1 = pos_camera[self.pos_idx[:, 1], :]
        v2 = pos_camera[self.pos_idx[:, 2], :]
        face_normals = F.normalize(
            torch.cross(v1 - v0, v2 - v0, dim=-1),
            dim=-1)
        vertex_normals = trimesh.geometry.mean_vertex_normals(vertex_count=self.vtx_pos.shape[0],
                                                              faces=self.pos_idx.cpu(),
                                                              face_normals=face_normals.cpu(), )
        vertex_normals = torch.from_numpy(
            vertex_normals).float().to(self.device).contiguous()
        tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)
        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]

        normal, _ = self.raster_interpolate(
            vertex_normals[None, ...], rast_out, self.pos_idx)
        normal = normal[0, ...]
        uv, _ = self.raster_interpolate(self.vtx_uv[None, ...], rast_out, self.uv_idx)
        depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
        depth = depth[0, ...]

        depth_max, depth_min = depth[visible_mask > 0].max(), depth[visible_mask > 0].min()
        depth_normalized = (depth - depth_min) / (depth_max - depth_min)
        depth_image = depth_normalized * visible_mask  # Mask out background.

        sketch_image = self.render_sketch_from_depth(depth_image)

        # View direction in camera space (the camera looks down -z here).
        lookat = torch.tensor([[0, 0, -1]], device=self.device)
        cos_image = torch.nn.functional.cosine_similarity(
            lookat, normal.view(-1, 3))
        cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)

        cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
        cos_image[cos_image < cos_thres] = 0

        # Shrink the reliable region: dilate the background and the depth-edge
        # sketch so boundary and grazing pixels are not baked.
        kernel_size = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones(
            (1, 1, kernel_size, kernel_size), dtype=torch.float32).to(
            sketch_image.device)

        visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
        visible_mask = F.conv2d(
            1.0 - visible_mask,
            kernel,
            padding=kernel_size // 2)
        visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
        visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)

        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
        sketch_image = (sketch_image > 0).float()  # binarize
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        visible_mask = visible_mask * (sketch_image < 0.5)

        cos_image[visible_mask == 0] = 0

        method = self.bake_mode if method is None else method

        if method == 'linear':
            proj_mask = (visible_mask != 0).view(-1)
            uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
            image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
            cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
            sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]

            texture = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image)
            cos_map = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], cos_image)
            boundary_map = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], sketch_image)
        else:
            raise ValueError(f'No bake mode {method}')

        return texture, cos_map, boundary_map
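
    # Usage sketch (illustrative; view_img is a placeholder image): project one
    # view back into UV space. texture holds the splatted colors, cos_map the
    # view-alignment confidence, boundary_map the distrusted edge texels.
    #   texture, cos_map, boundary = render.back_project(view_img, elev=0, azim=0)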

    def bake_texture(self, colors, elevs, azims,
                     camera_distance=None, center=None, exp=6, weights=None):
        for i in range(len(colors)):
            if isinstance(colors[i], Image.Image):
                colors[i] = torch.tensor(
                    np.array(colors[i]) / 255.0,
                    device=self.device).float()
        if weights is None:
            weights = [1.0 for _ in range(len(colors))]
        textures = []
        cos_maps = []
        for color, elev, azim, weight in zip(colors, elevs, azims, weights):
            texture, cos_map, _ = self.back_project(
                color, elev, azim, camera_distance, center)
            cos_map = weight * (cos_map ** exp)
            textures.append(texture)
            cos_maps.append(cos_map)

        texture_merge, trust_map_merge = self.fast_bake_texture(
            textures, cos_maps)
        return texture_merge, trust_map_merge
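
    # Usage sketch (illustrative; views is a placeholder list of images): fuse
    # six views, weighting texels by cos(view angle)**exp so frontal
    # observations dominate grazing ones.
    #   elevs = [0, 0, 0, 0, 90, -90]; azims = [0, 90, 180, 270, 0, 0]
    #   tex, trust = render.bake_texture(views, elevs, azims, exp=6)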

    @torch.no_grad()
    def fast_bake_texture(self, textures, cos_maps):

        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(
            self.texture_size + (channel,)).to(self.device)
        trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
        for texture, cos_map in zip(textures, cos_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if painted_sum / view_sum > 0.99:
                # Skip views that would add less than 1% new texel coverage.
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1E-8)

        return texture_merge, trust_map_merge > 1E-8

    def uv_inpaint(self, texture, mask):

        if isinstance(texture, torch.Tensor):
            texture_np = texture.cpu().numpy()
        elif isinstance(texture, np.ndarray):
            texture_np = texture
        elif isinstance(texture, Image.Image):
            texture_np = np.array(texture) / 255.0

        vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()

        # Fill masked texels from mesh vertex connectivity first, then let
        # OpenCV Navier-Stokes inpainting handle whatever is still masked.
        texture_np, mask = meshVerticeInpaint(
            texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx)

        texture_np = cv2.inpaint(
            (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS)

        return texture_np
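
# End-to-end sketch of how these pieces fit together (illustrative only; mesh,
# views, elevs and azims are placeholders):
#   render = MeshRender(camera_type='orth', texture_size=2048)
#   render.load_mesh(mesh)                                   # normalize + upload
#   tex, trust = render.bake_texture(views, elevs, azims)    # multi-view fusion
#   mask = (trust.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
#   tex = render.uv_inpaint(tex, mask)                       # fill unseen texels
#   render.set_texture(tex / 255.0)                          # uint8 -> [0, 1]
#   textured_mesh = render.save_mesh()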