import logging
import math
from typing import Union

import custom_rasterizer as cr
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
import xatlas
from PIL import Image

from asset3d_gen.data.utils import (
    get_images_from_file,
    normalize_vertices_array,
    post_process_texture,
    save_mesh_with_mtl,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

__all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]


def get_perspective_projection(
    fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
) -> np.ndarray:
    """Compute the perspective projection matrix for 3D rendering."""
    fov_rad = math.radians(fov)
    tan_half_fov = math.tan(fov_rad / 2.0)

    return np.array(
        [
            [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
            [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
            [
                0.0,
                0.0,
                -(far + near) / (far - near),
                -(2.0 * far * near) / (far - near),
            ],
            [0.0, 0.0, -1.0, 0.0],
        ],
        dtype=np.float32,
    )
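

# Usage sketch for get_perspective_projection (illustrative values, not from
# the original source): a 45-degree FOV with square aspect puts
# 1 / tan(22.5 deg) ~= 2.414 on both focal terms of the diagonal.
#   proj = get_perspective_projection(45.0, 1.0)
#   assert abs(proj[0, 0] - 2.414) < 1e-3 and abs(proj[1, 1] - 2.414) < 1e-3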


def transform_vertices(
    mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
    """Apply a 4x4 transform to vertices in homogeneous coordinates."""
    t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
    # Lift (..., 3) positions to homogeneous (..., 4) before multiplying.
    if pos.size(-1) == 3:
        pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)

    result = pos @ t_mtx.T

    return result if keepdim else result.unsqueeze(0)
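

# Usage sketch for transform_vertices (hypothetical tensors): an (N, 3) batch
# is lifted to (N, 4) homogeneous coordinates, so the transformed result has
# shape (N, 4) with keepdim=True, or (1, N, 4) with the default keepdim=False.
#   pts = torch.rand(100, 3)
#   clip = transform_vertices(np.eye(4, dtype=np.float32), pts)  # (1, 100, 4)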


def compute_w2c_matrix(
    elev_deg: float, azim_deg: float, cam_dist: float
) -> np.ndarray:
    """Compute w2c 4x4 transformation matrix from spherical coordinates."""
    elev_rad = math.radians(-elev_deg)
    azim_rad = math.radians(azim_deg)
    sin_elev = math.sin(elev_rad)
    cos_elev = math.cos(elev_rad)
    sin_azim = math.sin(azim_rad)
    cos_azim = math.cos(azim_rad)

    cam_pos = np.array(
        [
            cam_dist * cos_elev * cos_azim,
            cam_dist * cos_elev * sin_azim,
            cam_dist * sin_elev,
        ]
    )
    look_dir = -cam_pos / np.linalg.norm(cam_pos)
    right_dir = np.cross(look_dir, [0, 0, 1])
    right_dir /= np.linalg.norm(right_dir)
    up_dir = np.cross(right_dir, look_dir)

    c2w = np.eye(4)
    c2w[:3, 0] = right_dir
    c2w[:3, 1] = up_dir
    c2w[:3, 2] = -look_dir
    c2w[:3, 3] = cam_pos

    try:
        w2c = np.linalg.inv(c2w)
    except np.linalg.LinAlgError as e:
        raise ArithmeticError("Failed to invert camera-to-world matrix") from e

    return w2c.astype(np.float32)
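

# Convention sketch for compute_w2c_matrix: the camera orbits a z-up world
# (elevation measured from the xy-plane, azimuth about +z) and looks at the
# origin with an OpenGL-style frame (camera -z is the viewing direction).
# For example, elev=0, azim=0, dist=2 places the camera at (2, 0, 0) looking
# toward -x.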


def _bilinear_interpolation_scattering(
    image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
    """Bilinear interpolation scattering for grid-based value accumulation."""
    device = values.device
    dtype = values.dtype
    C = values.shape[-1]

    indices = coords * torch.tensor(
        [image_h - 1, image_w - 1], dtype=dtype, device=device
    )
    i, j = indices.unbind(-1)

    # Clamp each axis by its own bound; chaining both clamps on the stacked
    # tensor would mix the row and column limits.
    i0, j0 = indices.floor().long().unbind(-1)
    i0 = i0.clamp(0, image_h - 2)
    j0 = j0.clamp(0, image_w - 2)
    i1, j1 = i0 + 1, j0 + 1

    w_i = i - i0.float()
    w_j = j - j0.float()
    weights = torch.stack(
        [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
        dim=1,
    )

    indices_comb = torch.stack(
        [
            torch.stack([i0, j0], dim=1),
            torch.stack([i0, j1], dim=1),
            torch.stack([i1, j0], dim=1),
            torch.stack([i1, j1], dim=1),
        ],
        dim=1,
    )

    grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
    cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
    stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)

    for k in range(4):
        idx = indices_comb[:, k]
        w = weights[:, k].unsqueeze(-1)
        flat_idx = (idx * stride).sum(-1)
        grid.view(-1, C).scatter_add_(
            0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
        )
        cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)

    mask = cnt.squeeze(-1) > 0
    grid[mask] = grid[mask] / cnt[mask].repeat(1, C)

    return grid
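

# Behavior sketch for _bilinear_interpolation_scattering (hypothetical sizes):
# each sample in `coords` (values in [0, 1], (row, col) layout) splats its
# value into the 4 surrounding texels with bilinear weights; overlapping
# splats are averaged by the accumulated weight `cnt`.
#   tex = _bilinear_interpolation_scattering(
#       64, 64, torch.rand(10, 2), torch.rand(10, 3)
#   )  # -> (64, 64, 3)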


def _texture_inpaint_smooth(
    texture: np.ndarray,
    mask: np.ndarray,
    vertices: np.ndarray,
    faces: np.ndarray,
    uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Perform texture inpainting using vertex-based color propagation."""
    image_h, image_w, C = texture.shape
    N = vertices.shape[0]

    # Initialize vertex data structures.
    vtx_mask = np.zeros(N, dtype=np.float32)
    vtx_colors = np.zeros((N, C), dtype=np.float32)
    unprocessed = []
    adjacency = [[] for _ in range(N)]

    # Build adjacency graph and initial color assignment. After xatlas
    # re-indexing, UV indices and vertex indices coincide per face corner.
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            v_idx = faces[face_idx, k]

            # Convert UV to pixel coordinates with boundary clamping.
            u = np.clip(
                int(round(uv_map[v_idx, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[v_idx, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            if mask[v, u]:
                vtx_mask[v_idx] = 1.0
                vtx_colors[v_idx] = texture[v, u]
            elif v_idx not in unprocessed:
                unprocessed.append(v_idx)

            # Build undirected adjacency graph.
            neighbor = faces[face_idx, (k + 1) % 3]
            if neighbor not in adjacency[v_idx]:
                adjacency[v_idx].append(neighbor)
            if v_idx not in adjacency[neighbor]:
                adjacency[neighbor].append(v_idx)

    # Color propagation with dynamic stopping.
    remaining_iters, prev_count = 2, 0
    while remaining_iters > 0:
        current_unprocessed = []

        for v_idx in unprocessed:
            valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
            if not valid_neighbors:
                current_unprocessed.append(v_idx)
                continue

            # Weight neighbors by inverse squared distance.
            neighbors_pos = vertices[valid_neighbors]
            dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
            weights = 1 / np.maximum(dist_sq, 1e-8)

            vtx_colors[v_idx] = np.average(
                vtx_colors[valid_neighbors], weights=weights, axis=0
            )
            vtx_mask[v_idx] = 1.0

        # Stop once the unprocessed set stops shrinking.
        if len(current_unprocessed) == prev_count:
            remaining_iters -= 1
        else:
            remaining_iters = min(remaining_iters + 1, 2)
        prev_count = len(current_unprocessed)
        unprocessed = current_unprocessed

    # Write propagated vertex colors back into the texture.
    inpainted_texture, updated_mask = texture.copy(), mask.copy()
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            v_idx = faces[face_idx, k]
            if not vtx_mask[v_idx]:
                continue

            u = np.clip(
                int(round(uv_map[v_idx, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[v_idx, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            inpainted_texture[v, u] = vtx_colors[v_idx]
            updated_mask[v, u] = 255

    return inpainted_texture, updated_mask
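

# Note on _texture_inpaint_smooth: this is a two-stage fill. Stage one
# propagates colors over the mesh graph (inverse-squared-distance weighted
# average of already-colored neighbors) until the uncolored set stops
# shrinking; stage two writes those vertex colors back into the texture at
# each vertex's UV pixel and marks it valid in the returned mask.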


class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion."""

    def __init__(
        self,
        camera_elevs: list[float],
        camera_azims: list[float],
        camera_distance: float,
        camera_fov: float,
        view_weights: list[float] = None,
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        use_antialias: bool = True,
        bake_angle_thresh: int = 75,
        device="cuda",
    ):
        self.camera_elevs = camera_elevs
        self.camera_azims = camera_azims
        self.view_weights = (
            view_weights
            if view_weights is not None
            else [1] * len(camera_elevs)
        )
        self.device = device
        self.render_wh = render_wh
        self.texture_wh = texture_wh

        self.camera_distance = camera_distance
        self.use_antialias = use_antialias
        self.bake_angle_thresh = bake_angle_thresh
        # Erosion kernel scales with resolution: 2 px at a 512 px render.
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )
        self.camera_proj_mat = get_perspective_projection(
            camera_fov,
            self.render_wh[1] / self.render_wh[0],
        )
        self.cnt = 0
    def rasterize_mesh(
        self,
        vertex: torch.Tensor,
        face: torch.Tensor,
        resolution: tuple[int, int],
    ) -> torch.Tensor:
        vertex = vertex[None] if vertex.ndim == 2 else vertex
        indices, weights = cr.rasterize(vertex, face, resolution)

        return torch.cat(
            [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
        ).unsqueeze(0)

    def raster_interpolate(
        self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        barycentric = rast_out[0, ..., :-1]
        findices = rast_out[0, ..., -1]
        if uv.dim() == 2:
            uv = uv.unsqueeze(0)

        return cr.interpolate(uv, findices, barycentric, faces)[0]
    def load_mesh(self, mesh_path: str) -> None:
        mesh = trimesh.load(mesh_path)
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.dump(concatenate=True)

        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center

        # Re-index the mesh so UV and vertex indices coincide.
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs

        self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
        self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        # Convert to the renderer's coordinate system and V-flip the UVs.
        self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
        self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
        self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
    def get_mesh_attrs(
        self,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        vertices = self.vertices.cpu().numpy()
        faces = self.faces.cpu().numpy()
        uv_map = self.uv_map.cpu().numpy()

        # Undo the normalization applied in load_mesh.
        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center

        return vertices, faces, uv_map
    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        sketch_image = (
            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
        )
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image
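
    # Note on _render_depth_edges: the Canny thresholds (30, 80) apply to the
    # depth map rescaled to uint8 [0, 255]; treat them as tuned heuristics
    # rather than constants derived from the render resolution.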
    def back_project(
        self, image: Image.Image, elev: float, azim: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Back-project one view into UV space with a per-texel confidence."""
        if isinstance(image, Image.Image):
            image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255.0

        view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
        pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
        pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
        pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]

        v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
        face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
        vertex_norm = (
            torch.from_numpy(
                trimesh.geometry.mean_vertex_normals(
                    len(pos_cam), self.faces.cpu(), face_norm.cpu()
                )
            )
            .to(self.device)
            .contiguous()
        )

        rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
        vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]

        interp_data = {
            "normal": self.raster_interpolate(
                vertex_norm[None], rast_out, self.faces
            ),
            "uv": self.raster_interpolate(
                self.uv_map[None], rast_out, self.faces
            ),
            "depth": self.raster_interpolate(
                pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
            ),
        }

        valid_depth = interp_data["depth"][vis_mask > 0]
        depth_norm = (interp_data["depth"] - valid_depth.min()) / (
            valid_depth.max() - valid_depth.min()
        )
        depth_norm[vis_mask <= 0] = 0
        sketch_image = self._render_depth_edges(depth_norm * vis_mask)
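
        # The steps below zero out unreliable texels before scattering:
        # grazing viewing angles (beyond `bake_angle_thresh` degrees), a
        # dilated band around invisible regions, and a dilated band around
        # depth discontinuities found in the edge sketch.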
# cv2.imwrite("vis_mask.png", (vis_mask.cpu().numpy() * 255).astype(np.uint8)) | |
# cv2.imwrite("normal.png", (interp_data['normal'].cpu().numpy() * 255).astype(np.uint8)) | |
# cv2.imwrite("depth.png", (depth_norm.cpu().numpy() * 255).astype(np.uint8)) | |
# cv2.imwrite("uv.png", (interp_data['uv'][..., 0].cpu().numpy() * 255).astype(np.uint8)) | |
# import pdb; pdb.set_trace() | |
        cos = F.cosine_similarity(
            torch.tensor([[0, 0, -1]], device=self.device),
            interp_data["normal"].view(-1, 3),
        ).view_as(interp_data["normal"][..., :1])
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)

        vis_mask = vis_mask * (sketch_image < 0.5)
        cos[vis_mask == 0] = 0

        valid_pixels = (vis_mask != 0).view(-1)

        return (
            self._scatter_texture(interp_data["uv"], image, valid_pixels),
            self._scatter_texture(interp_data["uv"], cos, valid_pixels),
        )
    def back_project2(
        self, image, vis_mask, depth, normal, uv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Back-project one pre-rendered view (mask/depth/normal/uv) to UV."""
        if isinstance(image, Image.Image):
            image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255.0

        depth_inv = (1.0 - depth) * vis_mask
        sketch_image = self._render_depth_edges(depth_inv)

        # Debug dump for visual inspection.
        cv2.imwrite(
            f"v3_depth_inv{self.cnt}.png",
            (depth_inv.cpu().numpy() * 255).astype(np.uint8),
        )

        cos = F.cosine_similarity(
            torch.tensor([[0, 0, 1]], device=self.device),
            normal.view(-1, 3),
        ).view_as(normal[..., :1])
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
        self.cnt += 1

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)

        vis_mask = vis_mask * (sketch_image < 0.5)
        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        # Debug dumps for visual inspection.
        cv2.imwrite(
            f"v3_db_sketch{self.cnt}.png",
            (sketch_image.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_db_uv{self.cnt}.png",
            (uv[..., 0].cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_db_uv2{self.cnt}.png",
            (uv[..., 1].cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_db_color{self.cnt}.png",
            (image.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_db_cos{self.cnt}.png",
            (cos.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_db_mask{self.cnt}.png",
            (vis_mask.cpu().numpy() * 255).astype(np.uint8),
        )

        return (
            self._scatter_texture(uv, image, valid_pixels),
            self._scatter_texture(uv, cos, valid_pixels),
        )
    def _scatter_texture(self, uv, data, mask):
        def __filter_data(data, mask):
            return data.view(-1, data.shape[-1])[mask]

        # The scatter helper expects (row, col) order, hence the UV swap.
        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            __filter_data(uv, mask)[..., [1, 0]],
            __filter_data(data, mask),
        )
    def fast_bake_texture(
        self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
            self.device
        )
        trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)

        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            # Skip views that contribute nothing or are >99% already covered.
            if view_sum == 0 or painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8
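
    # Fusion sketch for fast_bake_texture: the merged texel is a
    # confidence-weighted mean over views,
    #   texture = sum_i(w_i * t_i) / sum_i(w_i),
    # where w_i is each view's (pre-weighted) cosine map and t_i its
    # back-projected texture.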
    def uv_inpaint(
        self, texture: torch.Tensor, mask: torch.Tensor
    ) -> np.ndarray:
        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
        vertices, faces, uv_map = self.get_mesh_attrs()

        texture_np, mask_np = _texture_inpaint_smooth(
            texture_np, mask_np, vertices, faces, uv_map
        )
        texture_np = texture_np.clip(0, 1)
        texture_np = cv2.inpaint(
            (texture_np * 255).astype(np.uint8),
            255 - mask_np,
            3,
            cv2.INPAINT_NS,
        )

        return texture_np
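
    # Note on uv_inpaint: mesh-aware propagation (_texture_inpaint_smooth)
    # fills texels reachable through the surface graph first, then
    # cv2.INPAINT_NS (Navier-Stokes image inpainting) fills whatever the mask
    # still marks as missing, so UV island seams get plausible colors from
    # both passes.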
    def __call__(
        self, colors: list[Image.Image], input_mesh: str, output_path: str
    ) -> trimesh.Trimesh:
        self.load_mesh(input_mesh)

        textures, weighted_cos_maps = [], []
        for color, cam_elev, cam_azim, weight in zip(
            colors, self.camera_elevs, self.camera_azims, self.view_weights
        ):
            texture, cos_map = self.back_project(color, cam_elev, cam_azim)
            textures.append(texture)
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
        texture_np = self.uv_inpaint(texture, mask)
        texture_np = post_process_texture(texture_np)
        vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)

        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh
    def forward(
        self,
        colors: list[Image.Image],
        masks,
        depths,
        normals,
        uvs,
        output_path: str,
    ) -> trimesh.Trimesh:
        textures, weighted_cos_maps = [], []
        for color, mask, depth, normal, uv, weight in zip(
            colors, masks, depths, normals, uvs, self.view_weights
        ):
            texture, cos_map = self.back_project2(
                color, mask, depth, normal, uv
            )
            # Debug dumps for visual inspection.
            cv2.imwrite(
                f"v3_texture{self.cnt}.png",
                (texture.cpu().numpy() * 255).astype(np.uint8),
            )
            cv2.imwrite(
                f"v3_texture_cos{self.cnt}.png",
                (cos_map.cpu().numpy() * 255).astype(np.uint8),
            )
            textures.append(texture)
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
        texture_np = self.uv_inpaint(texture, mask)
        texture_np = post_process_texture(texture_np)
        vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
        cv2.imwrite("v3_texture_np.png", texture_np)

        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh


class Image_Super_Net:
    def __init__(self, device="cuda"):
        from diffusers import StableDiffusionUpscalePipeline

        self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
            "stabilityai/stable-diffusion-x4-upscaler",
            torch_dtype=torch.float16,
        ).to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)

    def __call__(self, image, prompt=""):
        with torch.no_grad():
            upscaled_image = self.up_pipeline_x4(
                prompt=[prompt],
                image=image,
                num_inference_steps=10,
            ).images[0]

        return upscaled_image
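

# Usage sketch for Image_Super_Net (assumes the
# stabilityai/stable-diffusion-x4-upscaler weights are downloadable):
#   sr = Image_Super_Net(device="cuda")
#   hires = sr(Image.open("low_res.png").convert("RGB"))  # 4x upscale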


class Image_GANNet:
    def __init__(self, outscale: int):
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        self.outscale = outscale
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        self.upsampler = RealESRGANer(
            scale=4,
            model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth",  # noqa
            model=model,
            pre_pad=0,
            half=True,
        )

    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        if isinstance(image, Image.Image):
            image = np.array(image)
        output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
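

# Usage sketch for Image_GANNet (the RealESRGAN_x4plus.pth path above is
# environment-specific; substitute your own checkpoint location):
#   gan = Image_GANNet(outscale=4)
#   hires = gan(Image.open("low_res.png").convert("RGB"))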


if __name__ == "__main__":
    device = "cuda"
    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
    mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
    output_path = "robot_test_v6/robot.obj"
    target_image_size = (2048, 2048)

    super_model = Image_GANNet(outscale=4)
    multiviews = get_images_from_file(color_path, img_size=512)
    multiviews = [super_model(img) for img in multiviews]
    multiviews = [img.convert("RGB") for img in multiviews]

    from asset3d_gen.data.utils import (
        CameraSetting,
        DiffrastRender,
        init_kal_camera,
    )
    import nvdiffrast.torch as dr

    camera_params = CameraSetting(
        num_images=6,
        elevation=[20.0, -10.0],
        distance=5,
        resolution_hw=(2048, 2048),
        fov=math.radians(30),
        device="cuda",
    )
    camera = init_kal_camera(camera_params)
    mv = camera.view_matrix()  # (n, 4, 4) world-to-camera matrices
    p = camera.intrinsics.projection_matrix()
    # NOTE: flip the sign of P[1, 1] because the y axis is flipped in
    # `nvdiffrast` output.
    p[:, 1, 1] = -p[:, 1, 1]
    renderer = DiffrastRender(
        p_matrix=p,
        mv_matrix=mv,
        resolution_hw=camera_params.resolution_hw,
        context=dr.RasterizeCudaContext(),
        mask_thresh=0.5,
        grad_db=False,
        device=camera_params.device,
        antialias_mask=True,
    )

    mesh = trimesh.load(mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)
    mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
    vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
    uvs[:, 1] = 1 - uvs[:, 1]
    mesh.vertices = mesh.vertices[vmapping]
    mesh.faces = indices
    mesh.visual.uv = uvs

    vertices = torch.from_numpy(mesh.vertices).to(camera_params.device).float()
    faces = (
        torch.from_numpy(mesh.faces).to(camera_params.device).to(torch.int64)
    )
    uvs = torch.from_numpy(mesh.visual.uv).to(camera_params.device).float()

    # Render per-view normals in camera space.
    rendered_view_normals = []
    rast, vertices_clip = renderer.compute_dr_raster(vertices, faces)
    for idx in range(len(mv)):
        pos_cam = transform_vertices(mv[idx], vertices, keepdim=True)
        pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
        v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
        face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
        vertex_norm = (
            torch.from_numpy(
                trimesh.geometry.mean_vertex_normals(
                    len(pos_cam), faces.cpu(), face_norm.cpu()
                )
            )
            .to(camera_params.device)
            .contiguous()
        )
        im_base_normals, _ = dr.interpolate(
            vertex_norm[None, ...].float(),
            rast[idx : idx + 1],
            faces.to(torch.int32),
        )
        rendered_view_normals.append(im_base_normals)
    rendered_view_normals = torch.cat(rendered_view_normals, dim=0)

    # Render per-view depth (normalized inside the visibility mask) and UVs.
    rendered_depth, masks = renderer.render_depth(vertices, faces)
    norm_depths = []
    for idx in range(len(rendered_depth)):
        norm_depth = renderer.normalize_map_by_mask(
            rendered_depth[idx : idx + 1], masks[idx : idx + 1]
        )
        norm_depths.append(norm_depth)
    norm_depths = torch.cat(norm_depths, dim=0)
    render_uvs, _ = renderer.render_uv(vertices, faces, uvs)

    # Debug dumps for visual inspection.
    for index in range(6):
        cv2.imwrite(
            f"v3_mask{index}.png",
            (masks[index] * 255).cpu().numpy().astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_normalv2{index}.png",
            (rendered_view_normals[index] * 255)
            .cpu()
            .numpy()
            .astype(np.uint8)[..., ::-1],
        )
        cv2.imwrite(
            f"v3_depth{index}.png",
            (norm_depths[index] * 255).cpu().numpy().astype(np.uint8),
        )
        cv2.imwrite(
            f"v3_uv{index}.png",
            (render_uvs[index, ..., 0] * 255).cpu().numpy().astype(np.uint8),
        )
        multiviews[index].save(f"v3_color{index}.png")
    texture_backer = TextureBacker(
        camera_elevs=[20, 20, 20, -10, -10, -10],
        camera_azims=[-180, -60, 60, -120, 0, 120],
        view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
        camera_distance=5,
        camera_fov=30,
        render_wh=(2048, 2048),
        texture_wh=(2048, 2048),
    )
    # Inject the pre-computed mesh state (load_mesh is bypassed because the
    # mesh was already parametrized and rendered above).
    texture_backer.vertices = vertices
    texture_backer.faces = faces
    uvs[:, 1] = 1.0 - uvs[:, 1]
    texture_backer.uv_map = uvs
    texture_backer.center = center
    texture_backer.scale = scale
    textured_mesh = texture_backer.forward(
        multiviews,
        masks,
        norm_depths,
        rendered_view_normals,
        render_uvs,
        output_path,
    )

    # Alternative path: bake directly from the mesh file via __call__.
    # multiviews = [super_model(img) for img in multiviews]
    # multiviews = [img.convert("RGB") for img in multiviews]
    # textured_mesh = texture_backer(multiviews, mesh_path, output_path)