import math
import os
from typing import Optional, Sequence, Union

import numpy as np
import open3d as o3d
import torch
import torch.nn.functional as F
import kaolin as kal
from trimesh import Trimesh

from pytorch3d.common.datatypes import Device
from pytorch3d.io import load_objs_as_meshes, save_obj, IO
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras,
    AmbientLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    TexturesUV,
)

from .geometry import HardGeometryShader
from .shader import HardNChannelFlatShader
from .voronoi import voronoi_solve

LIST_TYPE = Union[list, np.ndarray, torch.Tensor]

_R = torch.eye(3)[None]
_T = torch.zeros(1, 3)
_BatchFloatType = Union[float, Sequence[float], torch.Tensor]


class CustomOrthographicCameras(FoVOrthographicCameras):
    """Orthographic cameras with a modified depth channel.

    Overrides the projection matrix so that depth is mapped to [-1, 1]
    over [znear, zfar] (OpenGL convention) instead of PyTorch3D's
    default [0, 1] mapping.
    """

    def __init__(
        self,
        znear: _BatchFloatType = 1.0,
        zfar: _BatchFloatType = 100.0,
        max_y: _BatchFloatType = 1.0,
        min_y: _BatchFloatType = -1.0,
        max_x: _BatchFloatType = 1.0,
        min_x: _BatchFloatType = -1.0,
        scale_xyz=((1.0, 1.0, 1.0),),
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
    ):
        super().__init__(
            znear=znear,
            zfar=zfar,
            max_y=max_y,
            min_y=min_y,
            max_x=max_x,
            min_x=min_x,
            scale_xyz=scale_xyz,
            R=R,
            T=T,
            K=K,
            device=device,
        )

    def compute_projection_matrix(
        self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
    ) -> torch.Tensor:
        """Compute the custom orthographic projection matrix.

        Argument shapes:
            znear/zfar: (N,)
            max_x/min_x: (N,)
            max_y/min_y: (N,)
            scale_xyz: (N, 3)
        """
        K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)

        K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
        K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
        K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
        K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
        K[:, 3, 3] = ones

        # Depth channel: map z in [znear, zfar] to [-1, 1] rather than [0, 1].
        K[:, 2, 2] = -2.0 * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
        K[:, 2, 3] = -(znear + zfar) / (zfar - znear)

        return K


def erode_torch_batch(binary_img_batch, kernel_size):
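    """Morphologically erode a batch of (N, H, W) binary masks.

    Erosion is implemented as a min-pool, i.e. a max-pool over the
    negated mask, with reflect padding so the output keeps its size.
    """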
    pad = (kernel_size - 1) // 2
    bin_img = F.pad(
        binary_img_batch.unsqueeze(1), pad=[pad, pad, pad, pad], mode="reflect"
    )
    out = -F.max_pool2d(-bin_img, kernel_size=kernel_size, stride=1, padding=0)
    out = out.squeeze(1)
    return out


def dilate_torch_batch(binary_img_batch, kernel_size):
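    """Morphologically dilate a batch of (N, H, W) binary masks.

    Dilation is a plain max-pool with reflect padding, mirroring
    erode_torch_batch above.
    """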
    pad = (kernel_size - 1) // 2
    bin_img = F.pad(binary_img_batch, pad=[pad, pad, pad, pad], mode="reflect")
    out = F.max_pool2d(bin_img, kernel_size=kernel_size, stride=1, padding=0)
    # Keep the (N, H, W) shape; an argument-less squeeze() here would drop
    # the batch dimension whenever N == 1, unlike erode_torch_batch.
    return out


class UVProjection:
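    """Projects multi-view renders into UV texture space and back.

    Wraps a PyTorch3D mesh together with a set of cameras, a renderer,
    and the precomputed per-view weight maps needed to bake rendered
    views into a single UV texture.
    """
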
    def __init__(
        self,
        texture_size=96,
        render_size=64,
        sampling_mode="nearest",
        channels=3,
        device=None,
    ):
        self.channels = channels
        self.device = device or torch.device("cpu")
        self.lights = AmbientLights(
            ambient_color=((1.0,) * channels,), device=self.device
        )
        self.target_size = (texture_size, texture_size)
        self.render_size = render_size
        self.sampling_mode = sampling_mode

    def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False):
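        """Load a mesh from a Trimesh object or an OBJ file path.

        Optionally recenters the mesh at the origin and rescales it so
        its largest half-extent equals scale_factor, then UV-unwraps it
        when requested or when it carries no texture.
        """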
        if isinstance(mesh, Trimesh):
            vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
            faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
            mesh = Meshes(verts=[vertices], faces=[faces])
            verts = mesh.verts_packed()
            mesh = mesh.update_padded(verts[None, :, :])
        elif isinstance(mesh, str) and os.path.isfile(mesh):
            mesh = load_objs_as_meshes([mesh], device=self.device)
        if auto_center:
            verts = mesh.verts_packed()
            max_bb = verts.max(0)[0]
            min_bb = verts.min(0)[0]
            scale = (max_bb - min_bb).max() / 2
            center = (max_bb + min_bb) / 2
            mesh.offset_verts_(-center)
            mesh.scale_verts_(scale_factor / float(scale))
        else:
            mesh.scale_verts_(scale_factor)

        if autouv or (mesh.textures is None):
            mesh = self.uv_unwrap(mesh)
        self.mesh = mesh

    def load_glb_mesh(
        self, mesh_path, trimesh, scale_factor=1.0, auto_center=True, autouv=False
    ):
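        """Load a GLB mesh via PyTorch3D's experimental glTF IO.

        Mirrors load_mesh: optional recentering/rescaling followed by an
        optional UV unwrap. The trimesh argument is currently unused.
        """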
        from pytorch3d.io.experimental_gltf_io import MeshGlbFormat

        io = IO()
        io.register_meshes_format(MeshGlbFormat())
        with open(mesh_path, "rb") as f:
            mesh = io.load_mesh(f, include_textures=True, device=self.device)
        if auto_center:
            verts = mesh.verts_packed()
            max_bb = verts.max(0)[0]
            min_bb = verts.min(0)[0]
            scale = (max_bb - min_bb).max() / 2
            center = (max_bb + min_bb) / 2
            mesh.offset_verts_(-center)
            mesh.scale_verts_(scale_factor / float(scale))
            verts = mesh.verts_packed()
            mesh = mesh.update_padded(verts[None, :, :])
        else:
            mesh.scale_verts_(scale_factor)
        if autouv or (mesh.textures is None):
            mesh = self.uv_unwrap(mesh)
        self.mesh = mesh

    def save_mesh(self, mesh_path, texture):
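        """Save the mesh as an OBJ with the given texture map."""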
        save_obj(
            mesh_path,
            self.mesh.verts_list()[0],
            self.mesh.faces_list()[0],
            verts_uvs=self.mesh.textures.verts_uvs_list()[0],
            faces_uvs=self.mesh.textures.faces_uvs_list()[0],
            texture_map=texture,
        )

    def uv_unwrap(self, mesh):
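        """UV-unwrap a mesh with xatlas and attach an empty texture.

        The original vertex and face lists are kept; only new UV
        coordinates (verts_uvs/faces_uvs) and a zero-initialized texture
        map are generated.
        """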
        verts_list = mesh.verts_list()[0]
        faces_list = mesh.faces_list()[0]

        import xatlas

        v_np = verts_list.cpu().numpy()
        f_np = faces_list.int().cpu().numpy()
        atlas = xatlas.Atlas()
        atlas.add_mesh(v_np, f_np)
        chart_options = xatlas.ChartOptions()
        chart_options.max_iterations = 4
        atlas.generate(chart_options=chart_options)
        # atlas[0] yields (vmapping, indices, uvs) for the first mesh.
        vmapping, ft_np, vt_np = atlas[0]

        vt = (
            torch.from_numpy(vt_np.astype(np.float32))
            .type(verts_list.dtype)
            .to(mesh.device)
        )
        ft = (
            torch.from_numpy(ft_np.astype(np.int64))
            .type(faces_list.dtype)
            .to(mesh.device)
        )

        new_map = torch.zeros(self.target_size + (self.channels,), device=mesh.device)
        new_tex = TexturesUV([new_map], [ft], [vt], sampling_mode=self.sampling_mode)

        mesh.textures = new_tex
        return mesh

""" |
|
A functions that disconnect faces in the mesh according to |
|
its UV seams. The number of vertices are made equal to the |
|
number of unique vertices its UV layout, while the faces list |
|
is intact. |
|
""" |
|
|
|
def disconnect_faces(self): |
        mesh = self.mesh
        verts_list = mesh.verts_list()
        faces_list = mesh.faces_list()
        verts_uvs_list = mesh.textures.verts_uvs_list()
        faces_uvs_list = mesh.textures.faces_uvs_list()
        packed_list = [v[f] for v, f in zip(verts_list, faces_list)]
        verts_disconnect_list = [
            torch.zeros(
                (verts_uvs_list[i].shape[0], 3),
                dtype=verts_list[0].dtype,
                device=verts_list[0].device,
            )
            for i in range(len(verts_list))
        ]
        for i in range(len(verts_list)):
            # Scatter per-face vertex positions onto the UV-indexed vertices.
            verts_disconnect_list[i][faces_uvs_list[i]] = packed_list[i]
        assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
        self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
        return self.mesh_d

""" |
|
A function that construct a temp mesh for back-projection. |
|
Take a disconnected mesh and a rasterizer, the function calculates |
|
the projected faces as the UV, as use its original UV with pseudo |
|
z value as world space geometry. |
|
""" |
|
|
|
def construct_uv_mesh(self): |
        mesh = self.mesh_d
        verts_list = mesh.verts_list()
        verts_uvs_list = mesh.textures.verts_uvs_list()

        new_verts_list = []
        for verts, verts_uv in zip(verts_list, verts_uvs_list):
            verts = verts.clone()
            verts_uv = verts_uv.clone()
            # Replace x/y with the UV coordinates, then remap all components
            # from [0, 1] to [-1, 1]; the shifted z serves as a pseudo depth.
            verts[..., 0:2] = verts_uv[..., :]
            verts = (verts - 0.5) * 2
            new_verts_list.append(verts)
        textures_uv = mesh.textures.clone()
        self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
        return self.mesh_uv

    def set_texture_map(self, texture):
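        """Assign a (C, H, W) texture tensor as the mesh's UV texture."""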
        new_map = texture.permute(1, 2, 0)
        new_map = new_map.to(self.device)
        new_tex = TexturesUV(
            [new_map],
            self.mesh.textures.faces_uvs_padded(),
            self.mesh.textures.verts_uvs_padded(),
            sampling_mode=self.sampling_mode,
        )
        self.mesh.textures = new_tex

    def set_noise_texture(self, channels=None):
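        """Fill the texture with standard-normal noise and return it."""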
        if not channels:
            channels = self.channels
        noise_texture = torch.normal(
            0, 1, (channels,) + self.target_size, device=self.device
        )
        self.set_texture_map(noise_texture)
        return noise_texture

    def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
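        """Build orthographic cameras from (elevation, azimuth) pose pairs."""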
        elev = torch.FloatTensor([pose[0] for pose in camera_poses])
        azim = torch.FloatTensor([pose[1] for pose in camera_poses])
        print(f"camera_distance: {camera_distance}")
        R, T = look_at_view_transform(
            dist=camera_distance, elev=elev, azim=azim, at=centers or ((0, 0, 0),)
        )
        print(f"R size: {R.size()}, T size: {T.size()}")

        self.cameras = FoVOrthographicCameras(
            device=self.device, R=R, T=T, scale_xyz=scale or ((1, 1, 1),)
        )

    def set_cameras_and_render_settings(
        self,
        camera_poses,
        centers=None,
        camera_distance=2.7,
        render_size=None,
        scale=None,
    ):
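        """Set the cameras, then build the renderer and the cached per-view
        maps (texture gradients, visibility masks, cosine weights) used for
        baking.
        """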
        self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
        if render_size is None:
            render_size = self.render_size
        if not hasattr(self, "renderer"):
            self.setup_renderer(size=render_size)
        if not hasattr(self, "mesh_d"):
            self.disconnect_faces()
        if not hasattr(self, "mesh_uv"):
            self.construct_uv_mesh()
        self.calculate_tex_gradient()
        self.calculate_visible_triangle_mask()
        _, _, _, cos_maps, _, _ = self.render_geometry()
        self.calculate_cos_angle_weights(cos_maps)

    def setup_renderer(
        self,
        size=64,
        blur=0.0,
        face_per_pix=1,
        perspective_correct=False,
        channels=None,
    ):
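        """Create the rasterizer and flat-shaded renderer used throughout."""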
        if not channels:
            channels = self.channels

        self.raster_settings = RasterizationSettings(
            image_size=size,
            blur_radius=blur,
            faces_per_pixel=face_per_pix,
            perspective_correct=perspective_correct,
            cull_backfaces=True,
            max_faces_per_bin=30000,
        )

        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=self.raster_settings,
            ),
            shader=HardNChannelFlatShader(
                device=self.device,
                cameras=self.cameras,
                lights=self.lights,
                channels=channels,
            ),
        )

    @torch.enable_grad()
    def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
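        """Bake per-view cosine (viewing-angle) weights into UV space.

        For each camera, a zero texture is optimized for one SGD step so
        that the rendering gradient scatters the view's cosine map back
        onto the texture; dividing by the cached gradient maps normalizes
        the result, and voronoi_solve optionally fills unseen texels.
        """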
        if not channels:
            channels = self.channels
        cos_maps = []
        tmp_mesh = self.mesh.clone()
        for i in range(len(self.cameras)):
            zero_map = torch.zeros(
                self.target_size + (channels,), device=self.device, requires_grad=True
            )
            optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
            optimizer.zero_grad()
            zero_tex = TexturesUV(
                [zero_map],
                self.mesh.textures.faces_uvs_padded(),
                self.mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
            tmp_mesh.textures = zero_tex

            images_predicted = self.renderer(
                tmp_mesh, cameras=self.cameras[i], lights=self.lights
            )

            loss = torch.sum((cos_angles[i, :, :, 0:1] - images_predicted) ** 2)
            loss.backward()
            optimizer.step()

            # Normalize by the texel-wise rendering gradient, then optionally
            # fill texels this view never saw.
            zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
            if fill:
                zero_map = voronoi_solve(
                    zero_map, self.gradient_maps[i][..., 0], self.device
                )
            cos_maps.append(zero_map)
        self.cos_maps = cos_maps

    @torch.no_grad()
    def render_geometry(self, image_size=None):
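        """Render per-view geometry buffers with a temporary geometry shader.

        Returns (verts, normals, depths, cos_angles, texels, fragments),
        temporarily overriding the raster image size when one is given.
        """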
        if image_size:
            size = self.renderer.rasterizer.raster_settings.image_size
            self.renderer.rasterizer.raster_settings.image_size = image_size
        shader = self.renderer.shader
        self.renderer.shader = HardGeometryShader(
            device=self.device, cameras=self.cameras[0], lights=self.lights
        )
        tmp_mesh = self.mesh.clone()

        verts, normals, depths, cos_angles, texels, fragments = self.renderer(
            tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights
        )
        self.renderer.shader = shader

        if image_size:
            self.renderer.rasterizer.raster_settings.image_size = size

        return verts, normals, depths, cos_angles, texels, fragments

    @torch.no_grad()
    def decode_view_normal(self, normals):
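        """Convert world-space normals to view space and encode them as RGB
        in [0, 1]; background pixels get the neutral color (0.5, 0.5, 1).
        """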
        w2v_mat = self.cameras.get_full_projection_transform()
        normals_view = torch.clone(normals)[:, :, :, 0:3]
        normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
        normals_view = w2v_mat.transform_normals(normals_view)
        normals_view = normals_view.reshape(normals.shape[0:3] + (3,))
        normals_view[:, :, :, 2] *= -1
        # Alpha-blend the encoded normals over the neutral background color.
        alpha = normals[..., 3:]
        background = torch.FloatTensor((0.5, 0.5, 1)).to(self.device)
        normals = (normals_view[..., 0:3] + 1) / 2 * alpha + background * (1 - alpha)

        normals = normals.clamp(0, 1)
        return normals

    @torch.no_grad()
    def decode_normalized_depth(self, depths, batched_norm=False):
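        """Normalize inverse depth to [0, 1] per view (or over the whole
        batch when batched_norm is set) and expand it to three channels.
        """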
        view_z, mask = depths.unbind(-1)
        # Background pixels get z = 100 so they do not dominate the max;
        # inv_z_min re-inflates them so the min is taken over foreground only.
        view_z = view_z * mask + 100 * (1 - mask)
        inv_z = 1 / view_z
        inv_z_min = inv_z * mask + 100 * (1 - mask)
        if not batched_norm:
            max_ = torch.max(inv_z, 1, keepdim=True)
            max_ = torch.max(max_[0], 2, keepdim=True)[0]

            min_ = torch.min(inv_z_min, 1, keepdim=True)
            min_ = torch.min(min_[0], 2, keepdim=True)[0]
        else:
            max_ = torch.max(inv_z)
            min_ = torch.min(inv_z_min)
        inv_z = (inv_z - min_) / (max_ - min_)
        inv_z = inv_z.clamp(0, 1)
        inv_z = inv_z[..., None].repeat(1, 1, 1, 3)

        return inv_z

    @torch.enable_grad()
    def calculate_tex_gradient(self, channels=None):
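        """Cache, per view, how strongly each texel influences the render.

        Optimizing a zero texture against an all-white target for one SGD
        step leaves each texel's accumulated rendering gradient in the map;
        these maps later normalize back-projected quantities.
        """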
        if not channels:
            channels = self.channels
        tmp_mesh = self.mesh.clone()
        gradient_maps = []
        for i in range(len(self.cameras)):
            zero_map = torch.zeros(
                self.target_size + (channels,), device=self.device, requires_grad=True
            )
            optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
            optimizer.zero_grad()
            zero_tex = TexturesUV(
                [zero_map],
                self.mesh.textures.faces_uvs_padded(),
                self.mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
            tmp_mesh.textures = zero_tex
            images_predicted = self.renderer(
                tmp_mesh, cameras=self.cameras[i], lights=self.lights
            )
            loss = torch.sum((1 - images_predicted) ** 2)
            loss.backward()
            optimizer.step()

            gradient_maps.append(zero_map.detach())

        self.gradient_maps = gradient_maps

    @torch.no_grad()
    def get_c2w(
        self,
        elevation_deg: LIST_TYPE,
        distance: LIST_TYPE,
        azimuth_deg: Optional[LIST_TYPE],
        num_views: Optional[int] = 1,
        device: Optional[str] = None,
    ) -> torch.FloatTensor:
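        """Build (N, 4, 4) camera-to-world matrices from spherical
        coordinates, looking at the origin with +z as the up axis.
        """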
        if azimuth_deg is None:
            assert (
                num_views is not None
            ), "num_views must be provided if azimuth_deg is None."
            azimuth_deg = torch.linspace(
                0, 360, num_views + 1, dtype=torch.float32, device=device
            )[:-1]
        else:
            num_views = len(azimuth_deg)

        def list_to_pt(
            x: LIST_TYPE,
            dtype: Optional[torch.dtype] = None,
            device: Optional[str] = None,
        ) -> torch.Tensor:
            if isinstance(x, (list, np.ndarray)):
                return torch.tensor(x, dtype=dtype, device=device)
            return x.to(dtype=dtype, device=device)

        azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
        elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
        camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180
        camera_positions = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )
        center = torch.zeros_like(camera_positions)
        up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[
            None, :
        ].repeat(num_views, 1)
        lookat = F.normalize(center - camera_positions, dim=-1)
        right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4 = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
        c2w[:, 3, 3] = 1.0
        return c2w

    @torch.no_grad()
    def calculate_visible_triangle_mask(self, channels=None, image_size=(512, 512)):
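        """Compute, per view, a UV-space mask of texels whose triangles are
        visible from that camera, dilated by one pixel to avoid seams.
        """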
        if not channels:
            channels = self.channels

        # Rasterize each view at high resolution to collect visible face ids.
        pix2face_list = []
        for i in range(len(self.cameras)):
            self.renderer.rasterizer.raster_settings.image_size = image_size
            pix2face = self.renderer.rasterizer(
                self.mesh_d, cameras=self.cameras[i]
            ).pix_to_face
            self.renderer.rasterizer.raster_settings.image_size = self.render_size
            pix2face_list.append(pix2face)

        if not hasattr(self, "mesh_uv"):
            self.construct_uv_mesh()

        raster_settings = RasterizationSettings(
            image_size=self.target_size,
            blur_radius=0,
            faces_per_pixel=1,
            perspective_correct=False,
            cull_backfaces=False,
            max_faces_per_bin=30000,
        )

        R, T = look_at_view_transform(dist=2, elev=0, azim=0)
        cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)

        # Rasterize the UV-layout mesh to map each texel to a face id.
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        uv_pix2face = rasterizer(self.mesh_uv).pix_to_face

        visible_triangles = []
        for i in range(len(pix2face_list)):
            valid_faceid = torch.unique(pix2face_list[i])
            valid_faceid = valid_faceid[1:] if valid_faceid[0] == -1 else valid_faceid
            mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)

            triangle_mask = torch.ones(self.target_size + (1,), device=self.device)
            triangle_mask[~mask] = 0

            # Dilate the mask by one pixel in each direction.
            triangle_mask[:, 1:][triangle_mask[:, :-1] > 0] = 1
            triangle_mask[:, :-1][triangle_mask[:, 1:] > 0] = 1
            triangle_mask[1:, :][triangle_mask[:-1, :] > 0] = 1
            triangle_mask[:-1, :][triangle_mask[1:, :] > 0] = 1
            visible_triangles.append(triangle_mask)

        self.visible_triangles = visible_triangles

    def render_textured_views(self):
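        """Render the textured mesh from every camera; returns a list of
        (C, H, W) images.
        """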
        meshes = self.mesh.extend(len(self.cameras))
        images_predicted = self.renderer(
            meshes, cameras=self.cameras, lights=self.lights
        )

        return [image.permute(2, 0, 1) for image in images_predicted]

    @torch.no_grad()
    def get_point_validation_by_o3d(
        self, points, eye_position, hidden_point_removal_radius=200
    ):
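        """Flag which points are visible from a camera using Open3D's
        hidden-point-removal (spherical projection) test.
        """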
        point_visibility = torch.zeros((points.shape[0]), device=points.device).bool()

        pcd = o3d.geometry.PointCloud(
            points=o3d.utility.Vector3dVector(points.cpu().numpy())
        )
        camera_pose = (
            eye_position.get_camera_center().squeeze().cpu().numpy().astype(np.float64)
        )

        diameter = np.linalg.norm(
            np.asarray(pcd.get_max_bound()) - np.asarray(pcd.get_min_bound())
        )
        radius = diameter * hidden_point_removal_radius
        _, pt_map = pcd.hidden_point_removal(camera_pose, radius)

        visible_point_ids = np.array(pt_map)

        point_visibility[visible_point_ids] = True
        return point_visibility

    @torch.no_grad()
    def hidden_judge(self, camera, texture_dim):
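        """Compute a (texture_dim, texture_dim) UV visibility map.

        Rasterizes world-space positions into UV space with kaolin, then
        tests each covered texel's 3D point against the camera with the
        Open3D hidden-point-removal check above.
        """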
        mesh = self.mesh

        verts = mesh.verts_packed()
        faces = mesh.faces_packed()
        verts_uv = mesh.textures.verts_uvs_padded()[0]
        faces_uv = mesh.textures.faces_uvs_padded()[0]
        uv_face_attr = torch.index_select(verts_uv, 0, faces_uv.view(-1))
        uv_face_attr = uv_face_attr.view(
            faces.shape[0], faces_uv.shape[1], 2
        ).unsqueeze(0)
        face_vertices_world = kal.ops.mesh.index_vertices_by_faces(
            verts.unsqueeze(0), faces
        )
        face_vertices_z = torch.zeros_like(
            face_vertices_world[:, :, :, -1], device=verts.device
        )
        # Rasterize in UV space, interpolating world positions as features.
        uv_position, face_idx = kal.render.mesh.rasterize(
            texture_dim,
            texture_dim,
            face_vertices_z,
            uv_face_attr * 2 - 1,
            face_features=face_vertices_world,
        )
        uv_position = torch.clamp(uv_position, -1, 1)
        uv_position[face_idx == -1] = 0

        points = uv_position.reshape(-1, 3)
        mask = points[:, 0] != 0
        valid_points = points[mask]

        points_visibility = self.get_point_validation_by_o3d(
            valid_points, camera
        ).float()
        visibility_map = torch.zeros(
            (texture_dim * texture_dim,), device=self.device
        )
        visibility_map[mask] = points_visibility
        visibility_map = visibility_map.reshape((texture_dim, texture_dim))
        return visibility_map

    @torch.enable_grad()
    def bake_texture(
        self,
        views=None,
        main_views=[],
        cos_weighted=True,
        channels=None,
        exp=None,
        noisy=False,
        generator=None,
        smooth_colorize=False,
    ):
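        """Back-project the given views into a single UV texture.

        Each view is scattered into UV space via one SGD step (as in
        calculate_tex_gradient), normalized, hole-filled with voronoi_solve,
        and blended with per-view weights combining visibility masks and
        cosine maps raised to exp. Returns the re-rendered views, the baked
        texture, and the weight sums. (main_views and cos_weighted are
        currently unused.)
        """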
        if not exp:
            exp = 1
        if not channels:
            channels = self.channels
        views = [view.permute(1, 2, 0) for view in views]

        tmp_mesh = self.mesh
        bake_maps = [
            torch.zeros(
                self.target_size + (views[0].shape[2],),
                device=self.device,
                requires_grad=True,
            )
            for view in views
        ]
        optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
        optimizer.zero_grad()
        loss = 0
        for i in range(len(self.cameras)):
            bake_tex = TexturesUV(
                [bake_maps[i]],
                tmp_mesh.textures.faces_uvs_padded(),
                tmp_mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
            tmp_mesh.textures = bake_tex
            images_predicted = self.renderer(
                tmp_mesh, cameras=self.cameras[i], lights=self.lights
            )
            predicted_rgb = images_predicted[..., :-1]
            loss += ((predicted_rgb - views[i]) ** 2).sum()
        loss.backward(retain_graph=False)
        optimizer.step()

        total_weights = 0
        baked = 0
        for i in range(len(bake_maps)):
            normalized_baked_map = bake_maps[i].detach() / (
                self.gradient_maps[i] + 1e-8
            )
            bake_map = voronoi_solve(
                normalized_baked_map, self.gradient_maps[i][..., 0], self.device
            )

            # Weight each view by its visibility mask and cosine map.
            weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
            if smooth_colorize:
                visibility_map = self.hidden_judge(
                    self.cameras[i], self.target_size[0]
                ).unsqueeze(-1)
                weight *= visibility_map
            if noisy:
                noise = (
                    torch.rand(weight.shape[:-1] + (1,), generator=generator)
                    .type(weight.dtype)
                    .to(weight.device)
                )
                weight *= noise
            total_weights += weight

            baked += bake_map * weight
        baked /= total_weights + 1e-8

        whole_visible_mask = None
        baked = voronoi_solve(baked, total_weights[..., 0], self.device)
        if smooth_colorize:
            # Zero out texels that no view sees at all.
            whole_visible_mask = self.visible_triangles[0].to(torch.int32)
            for tensor in self.visible_triangles[1:]:
                whole_visible_mask = torch.bitwise_or(
                    whole_visible_mask, tensor.to(torch.int32)
                )
            baked *= whole_visible_mask

        tmp_mesh.textures = TexturesUV(
            [baked],
            tmp_mesh.textures.faces_uvs_padded(),
            tmp_mesh.textures.verts_uvs_padded(),
            sampling_mode=self.sampling_mode,
        )

        extended_mesh = tmp_mesh.extend(len(self.cameras))
        images_predicted = self.renderer(
            extended_mesh, cameras=self.cameras, lights=self.lights
        )
        learned_views = [image.permute(2, 0, 1) for image in images_predicted]

        return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)

    def to(self, device):
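        """Move the cached meshes and weight maps to the given device."""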
        for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
            if hasattr(self, mesh_name):
                mesh = getattr(self, mesh_name)
                setattr(self, mesh_name, mesh.to(device))
        for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
            if hasattr(self, list_name):
                map_list = getattr(self, list_name)
                for i in range(len(map_list)):
                    map_list[i] = map_list[i].to(device)