from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import math
import trimesh
import cv2
import xatlas
from typing import Union
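# Build an OpenGL-style perspective projection matrix from a vertical FOV in
# degrees, an aspect ratio (width / height) and near/far clipping planes.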
def get_perspective_projection_matrix(fovy, aspect_wh, near, far):
fovy_rad = math.radians(fovy)
return np.array(
[
[1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0],
[0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0],
[
0,
0,
-(far + near) / (far - near),
-2.0 * far * near / (far - near),
],
[0, 0, -1, 0],
]
).astype(np.float32)
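# Extract vertex positions, triangle indices and per-vertex UVs from a trimesh
# mesh; the texture image is returned as None and supplied separately.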
def load_mesh(mesh):
vtx_pos = mesh.vertices if hasattr(mesh, "vertices") else None
pos_idx = mesh.faces if hasattr(mesh, "faces") else None
vtx_uv = mesh.visual.uv if hasattr(mesh.visual, "uv") else None
uv_idx = mesh.faces if hasattr(mesh, "faces") else None
texture_data = None
return vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data
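# Attach a baked texture image to a mesh as a trimesh TextureVisuals material.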
def save_mesh(mesh, texture_data):
material = trimesh.visual.texture.SimpleMaterial(
image=texture_data, diffuse=(255, 255, 255)
)
texture_visuals = trimesh.visual.TextureVisuals(
uv=mesh.visual.uv, image=texture_data, material=material
)
mesh.visual = texture_visuals
return mesh
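# Apply a 4x4 transform to vertex positions, promoting [N, 3] inputs to
# homogeneous coordinates; adds a batch dimension unless keepdim is True.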
def transform_pos(mtx, pos, keepdim=False):
t_mtx = (
torch.from_numpy(mtx).to(pos.device)
if isinstance(mtx, np.ndarray)
else mtx
)
if pos.shape[-1] == 3:
posw = torch.cat(
[pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1
)
else:
posw = pos
if keepdim:
return torch.matmul(posw, t_mtx.t())[...]
else:
return torch.matmul(posw, t_mtx.t())[None, ...]
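# Build a world-to-camera (view) matrix from elevation/azimuth angles in
# degrees, a camera distance and an optional look-at center, with +Z as up.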
def get_mv_matrix(elev, azim, camera_distance, center=None):
elev = -elev
elev_rad = math.radians(elev)
azim_rad = math.radians(azim)
camera_position = np.array(
[
camera_distance * math.cos(elev_rad) * math.cos(azim_rad),
camera_distance * math.cos(elev_rad) * math.sin(azim_rad),
camera_distance * math.sin(elev_rad),
]
)
if center is None:
center = np.array([0, 0, 0])
else:
center = np.array(center)
lookat = center - camera_position
lookat = lookat / np.linalg.norm(lookat)
up = np.array([0, 0, 1.0])
right = np.cross(lookat, up)
right = right / np.linalg.norm(right)
up = np.cross(right, lookat)
up = up / np.linalg.norm(up)
c2w = np.concatenate(
[np.stack([right, up, -lookat], axis=-1), camera_position[:, None]],
axis=-1,
)
w2c = np.zeros((4, 4))
w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0))
w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:])
w2c[3, 3] = 1.0
return w2c.astype(np.float32)
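# Row-major strides used to flatten multi-dimensional grid indices.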
def stride_from_shape(shape):
stride = [1]
for x in reversed(shape[1:]):
stride.append(stride[-1] * x)
return list(reversed(stride))
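# Scatter-add `values` (and their `weights`) into a flattened N-D grid while
# tracking the accumulated weight per cell, so results can be normalized later.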
def scatter_add_nd_with_count(input, count, indices, values, weights=None):
# input: [..., C], D dimension + C channel
# count: [..., 1], D dimension
# indices: [N, D], long
# values: [N, C]
D = indices.shape[-1]
C = input.shape[-1]
size = input.shape[:-1]
stride = stride_from_shape(size)
assert len(size) == D
input = input.view(-1, C) # [HW, C]
count = count.view(-1, 1)
flatten_indices = (
indices * torch.tensor(stride, dtype=torch.long, device=indices.device)
).sum(
-1
) # [N]
if weights is None:
weights = torch.ones_like(values[..., :1])
input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
return input.view(*size, C), count.view(*size, 1)
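# Bilinearly splat per-point values onto an H x W grid: each point contributes
# to its four neighbouring cells, and the result is normalized by the
# accumulated weights unless the raw counts are requested.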
def linear_grid_put_2d(H, W, coords, values, return_count=False):
# coords: [N, 2], float in [0, 1]
# values: [N, C]
C = values.shape[-1]
indices = coords * torch.tensor(
[H - 1, W - 1], dtype=torch.float32, device=coords.device
)
indices_00 = indices.floor().long() # [N, 2]
indices_00[:, 0].clamp_(0, H - 2)
indices_00[:, 1].clamp_(0, W - 2)
indices_01 = indices_00 + torch.tensor(
[0, 1], dtype=torch.long, device=indices.device
)
indices_10 = indices_00 + torch.tensor(
[1, 0], dtype=torch.long, device=indices.device
)
indices_11 = indices_00 + torch.tensor(
[1, 1], dtype=torch.long, device=indices.device
)
h = indices[..., 0] - indices_00[..., 0].float()
w = indices[..., 1] - indices_00[..., 1].float()
w_00 = (1 - h) * (1 - w)
w_01 = (1 - h) * w
w_10 = h * (1 - w)
w_11 = h * w
result = torch.zeros(
H, W, C, device=values.device, dtype=values.dtype
) # [H, W, C]
count = torch.zeros(
H, W, 1, device=values.device, dtype=values.dtype
) # [H, W, 1]
weights = torch.ones_like(values[..., :1]) # [N, 1]
result, count = scatter_add_nd_with_count(
result,
count,
indices_00,
values * w_00.unsqueeze(1),
weights * w_00.unsqueeze(1),
)
result, count = scatter_add_nd_with_count(
result,
count,
indices_01,
values * w_01.unsqueeze(1),
weights * w_01.unsqueeze(1),
)
result, count = scatter_add_nd_with_count(
result,
count,
indices_10,
values * w_10.unsqueeze(1),
weights * w_10.unsqueeze(1),
)
result, count = scatter_add_nd_with_count(
result,
count,
indices_11,
values * w_11.unsqueeze(1),
weights * w_11.unsqueeze(1),
)
if return_count:
return result, count
mask = count.squeeze(-1) > 0
result[mask] = result[mask] / count[mask].repeat(1, C)
return result
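# Vertex-level texture inpainting: propagate colors from textured vertices to
# uncolored neighbours using inverse-squared-distance weights, then write the
# recovered vertex colors back into the texture at their UV locations.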
def meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx):
texture_height, texture_width, texture_channel = texture.shape
vtx_num = vtx_pos.shape[0]
vtx_mask = np.zeros(vtx_num, dtype=np.float32)
vtx_color = [
np.zeros(texture_channel, dtype=np.float32) for _ in range(vtx_num)
]
uncolored_vtxs = []
G = [[] for _ in range(vtx_num)]
for i in range(uv_idx.shape[0]):
for k in range(3):
vtx_uv_idx = uv_idx[i, k]
vtx_idx = pos_idx[i, k]
uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
uv_u = int(
round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
)
if mask[uv_u, uv_v] > 0:
vtx_mask[vtx_idx] = 1.0
vtx_color[vtx_idx] = texture[uv_u, uv_v]
else:
uncolored_vtxs.append(vtx_idx)
G[pos_idx[i, k]].append(pos_idx[i, (k + 1) % 3])
smooth_count = 2
last_uncolored_vtx_count = 0
while smooth_count > 0:
uncolored_vtx_count = 0
for vtx_idx in uncolored_vtxs:
sum_color = np.zeros(texture_channel, dtype=np.float32)
total_weight = 0.0
vtx_0 = vtx_pos[vtx_idx]
for connected_idx in G[vtx_idx]:
if vtx_mask[connected_idx] > 0:
vtx1 = vtx_pos[connected_idx]
dist = np.sqrt(np.sum((vtx_0 - vtx1) ** 2))
dist_weight = 1.0 / max(dist, 1e-4)
dist_weight *= dist_weight
sum_color += vtx_color[connected_idx] * dist_weight
total_weight += dist_weight
if total_weight > 0:
vtx_color[vtx_idx] = sum_color / total_weight
vtx_mask[vtx_idx] = 1.0
else:
uncolored_vtx_count += 1
if last_uncolored_vtx_count == uncolored_vtx_count:
smooth_count -= 1
else:
smooth_count += 1
last_uncolored_vtx_count = uncolored_vtx_count
new_texture = texture.copy()
new_mask = mask.copy()
for face_idx in range(uv_idx.shape[0]):
for k in range(3):
vtx_uv_idx = uv_idx[face_idx, k]
vtx_idx = pos_idx[face_idx, k]
if vtx_mask[vtx_idx] == 1.0:
uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
uv_u = int(
round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
)
new_texture[uv_u, uv_v] = vtx_color[vtx_idx]
new_mask[uv_u, uv_v] = 255
return new_texture, new_mask
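# Generate a UV parametrization with xatlas; vertices are re-indexed by the
# returned vmapping so positions, faces and UVs stay consistent.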
def mesh_uv_wrap(mesh):
if isinstance(mesh, trimesh.Scene):
mesh = mesh.dump(concatenate=True)
if len(mesh.faces) > 500000000:
raise ValueError(
"The mesh has more than 500,000,000 faces, which is not supported."
)
vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
mesh.vertices = mesh.vertices[vmapping]
mesh.faces = indices
mesh.visual.uv = uvs
return mesh
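# Rasterization-based renderer that back-projects multi-view images into UV
# space and bakes them into a single texture. Requires the `custom_rasterizer`
# extension when raster_mode="cr".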
class MeshRender:
def __init__(
self,
camera_distance=1.45,
default_resolution=1024,
texture_size=1024,
use_antialias=True,
max_mip_level=None,
filter_mode="linear",
bake_mode="linear",
raster_mode="cr",
device="cuda",
):
self.device = device
self.set_default_render_resolution(default_resolution)
self.set_default_texture_resolution(texture_size)
self.camera_distance = camera_distance
self.use_antialias = use_antialias
self.max_mip_level = max_mip_level
self.filter_mode = filter_mode
self.bake_angle_thres = 75
self.bake_unreliable_kernel_size = int(
(2 / 512)
* max(self.default_resolution[0], self.default_resolution[1])
)
self.bake_mode = bake_mode
self.raster_mode = raster_mode
if self.raster_mode == "cr":
import custom_rasterizer as cr
self.raster = cr
else:
raise f"No raster named {self.raster_mode}"
fov = 30
self.camera_proj_mat = get_perspective_projection_matrix(
fov,
self.default_resolution[1] / self.default_resolution[0],
0.01,
100.0,
)
def raster_rasterize(
self, pos, tri, resolution, ranges=None, grad_db=True
):
if self.raster_mode == "cr":
rast_out_db = None
if pos.dim() == 2:
pos = pos.unsqueeze(0)
findices, barycentric = self.raster.rasterize(pos, tri, resolution)
rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
rast_out = rast_out.unsqueeze(0)
else:
raise f"No raster named {self.raster_mode}"
return rast_out, rast_out_db
def raster_interpolate(
self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None
):
if self.raster_mode == "cr":
textd = None
barycentric = rast_out[0, ..., :-1]
findices = rast_out[0, ..., -1]
if uv.dim() == 2:
uv = uv.unsqueeze(0)
textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
else:
raise f"No raster named {self.raster_mode}"
return textc, textd
def load_mesh(
self,
mesh,
):
vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
self.mesh_copy = mesh
self.set_mesh(
vtx_pos,
pos_idx,
vtx_uv=vtx_uv,
uv_idx=uv_idx,
)
if texture_data is not None:
self.set_texture(texture_data)
def save_mesh(self):
texture_data = self.get_texture()
texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
return save_mesh(self.mesh_copy, texture_data)
def set_mesh(
self,
vtx_pos,
pos_idx,
vtx_uv=None,
uv_idx=None,
):
self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)
if (vtx_uv is not None) and (uv_idx is not None):
self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
self.uv_idx = (
torch.from_numpy(uv_idx).to(self.device).to(torch.int)
)
else:
self.vtx_uv = None
self.uv_idx = None
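        # Convert to the renderer's coordinate frame: negate X/Y, swap Y/Z and
        # flip the V coordinate (get_mesh applies the inverse transform).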
self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
if (vtx_uv is not None) and (uv_idx is not None):
self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]
def set_texture(self, tex):
if isinstance(tex, np.ndarray):
tex = Image.fromarray((tex * 255).astype(np.uint8))
elif isinstance(tex, torch.Tensor):
tex = tex.cpu().numpy()
tex = Image.fromarray((tex * 255).astype(np.uint8))
tex = tex.resize(self.texture_size).convert("RGB")
tex = np.array(tex) / 255.0
self.tex = torch.from_numpy(tex).to(self.device)
self.tex = self.tex.float()
def set_default_render_resolution(self, default_resolution):
if isinstance(default_resolution, int):
default_resolution = (default_resolution, default_resolution)
self.default_resolution = default_resolution
def set_default_texture_resolution(self, texture_size):
if isinstance(texture_size, int):
texture_size = (texture_size, texture_size)
self.texture_size = texture_size
def get_mesh(self):
vtx_pos = self.vtx_pos.cpu().numpy()
pos_idx = self.pos_idx.cpu().numpy()
vtx_uv = self.vtx_uv.cpu().numpy()
uv_idx = self.uv_idx.cpu().numpy()
        # Inverse of the coordinate transform applied in set_mesh.
vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]
vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
return vtx_pos, pos_idx, vtx_uv, uv_idx
def get_texture(self):
return self.tex.cpu().numpy()
def render_sketch_from_depth(self, depth_image):
depth_image_np = depth_image.cpu().numpy()
depth_image_np = (depth_image_np * 255).astype(np.uint8)
depth_edges = cv2.Canny(depth_image_np, 30, 80)
combined_edges = depth_edges
sketch_image = (
torch.from_numpy(combined_edges).to(depth_image.device).float()
/ 255.0
)
sketch_image = sketch_image.unsqueeze(-1)
return sketch_image
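    # Render the mesh from one view, build reliability masks (visibility,
    # viewing-angle cosine and depth-edge erosion), then splat the view image
    # into UV space, returning a partial texture, cosine map and boundary map.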
def back_project(
self, image, elev, azim, camera_distance=None, center=None, method=None
):
if isinstance(image, Image.Image):
image = torch.tensor(np.array(image) / 255.0)
elif isinstance(image, np.ndarray):
image = torch.tensor(image)
if image.dim() == 2:
image = image.unsqueeze(-1)
image = image.float().to(self.device)
resolution = image.shape[:2]
channel = image.shape[-1]
texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)
proj = self.camera_proj_mat
r_mv = get_mv_matrix(
elev=elev,
azim=azim,
camera_distance=(
self.camera_distance
if camera_distance is None
else camera_distance
),
center=center,
)
pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
pos_clip = transform_pos(proj, pos_camera)
pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
v0 = pos_camera[self.pos_idx[:, 0], :]
v1 = pos_camera[self.pos_idx[:, 1], :]
v2 = pos_camera[self.pos_idx[:, 2], :]
face_normals = F.normalize(
torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
)
vertex_normals = trimesh.geometry.mean_vertex_normals(
vertex_count=self.vtx_pos.shape[0],
faces=self.pos_idx.cpu(),
face_normals=face_normals.cpu(),
)
vertex_normals = (
torch.from_numpy(vertex_normals)
.float()
.to(self.device)
.contiguous()
)
tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
rast_out, rast_out_db = self.raster_rasterize(
pos_clip, self.pos_idx, resolution=resolution
)
visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
normal, _ = self.raster_interpolate(
vertex_normals[None, ...], rast_out, self.pos_idx
)
normal = normal[0, ...]
uv, _ = self.raster_interpolate(
self.vtx_uv[None, ...], rast_out, self.uv_idx
)
depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
depth = depth[0, ...]
depth_max, depth_min = (
depth[visible_mask > 0].max(),
depth[visible_mask > 0].min(),
)
depth_normalized = (depth - depth_min) / (depth_max - depth_min)
depth_image = depth_normalized * visible_mask # Mask out background.
        sketch_image = self.render_sketch_from_depth(depth_image)
channel = image.shape[-1]
        lookat = torch.tensor([[0.0, 0.0, -1.0]], device=self.device)
cos_image = torch.nn.functional.cosine_similarity(
lookat, normal.view(-1, 3)
)
cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)
cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
cos_image[cos_image < cos_thres] = 0
# shrink
kernel_size = self.bake_unreliable_kernel_size * 2 + 1
kernel = torch.ones(
(1, 1, kernel_size, kernel_size), dtype=torch.float32
).to(sketch_image.device)
visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
visible_mask = F.conv2d(
1.0 - visible_mask, kernel, padding=kernel_size // 2
)
        visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)
sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
        sketch_image = (sketch_image > 0).float()  # binarize
sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
visible_mask = visible_mask * (sketch_image < 0.5)
cos_image[visible_mask == 0] = 0
proj_mask = (visible_mask != 0).view(-1)
uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]
texture = linear_grid_put_2d(
self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image
)
cos_map = linear_grid_put_2d(
self.texture_size[1],
self.texture_size[0],
uv[..., [1, 0]],
cos_image,
)
boundary_map = linear_grid_put_2d(
self.texture_size[1],
self.texture_size[0],
uv[..., [1, 0]],
sketch_image,
)
return texture, cos_map, boundary_map
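    # Fuse per-view textures by cosine-weighted averaging, skipping views whose
    # visible area is already more than 99% covered by previous views.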
@torch.no_grad()
def fast_bake_texture(self, textures, cos_maps):
channel = textures[0].shape[-1]
texture_merge = torch.zeros(self.texture_size + (channel,)).to(
self.device
)
trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
for texture, cos_map in zip(textures, cos_maps):
view_sum = (cos_map > 0).sum()
painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
if painted_sum / view_sum > 0.99:
continue
texture_merge += texture * cos_map
trust_map_merge += cos_map
texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
return texture_merge, trust_map_merge > 1e-8
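    # Fill unpainted texels: first by vertex-color smoothing on the mesh, then
    # with Navier-Stokes image inpainting on the remaining holes.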
def uv_inpaint(self, texture, mask):
        if isinstance(texture, torch.Tensor):
            texture_np = texture.cpu().numpy()
        elif isinstance(texture, np.ndarray):
            texture_np = texture
        elif isinstance(texture, Image.Image):
            texture_np = np.array(texture) / 255.0
        else:
            raise TypeError(f"Unsupported texture type: {type(texture)}")
vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()
texture_np, mask = meshVerticeInpaint_smooth(
texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx
)
texture_np = cv2.inpaint(
(texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS
)
return texture_np
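# Split a multi-view sheet stored as two stacked rows of img_size x img_size
# tiles into a list of individual view images.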
def get_images_from_file(img_path: str, img_size: int) -> list[np.ndarray]:
input_image = Image.open(img_path)
view_images = np.array(input_image)
view_images = np.concatenate(
[view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
)
images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
return images
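# Back-project every view into UV space, weight each cosine map by
# view_weight * cos^4, and fuse the results into a single texture.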
def bake_from_multiview(
render, views, camera_elevs, camera_azims, view_weights, method="fast"
):
project_textures, project_weighted_cos_maps = [], []
project_boundary_maps = []
for view, camera_elev, camera_azim, weight in zip(
views, camera_elevs, camera_azims, view_weights
):
project_texture, project_cos_map, project_boundary_map = (
render.back_project(view, camera_elev, camera_azim)
)
project_cos_map = weight * (project_cos_map**4)
project_textures.append(project_texture)
project_weighted_cos_maps.append(project_cos_map)
project_boundary_maps.append(project_boundary_map)
if method == "fast":
texture, ori_trust_map = render.fast_bake_texture(
project_textures, project_weighted_cos_maps
)
else:
raise f"no method {method}"
return texture, ori_trust_map > 1e-8
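# Optional texture clean-up: iterative non-local-means denoising followed by a
# bilateral filter.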
def post_process(texture: np.ndarray, iter: int = 2) -> np.ndarray:
for _ in range(iter):
texture = cv2.fastNlMeansDenoisingColored(texture, None, 11, 11, 9, 25)
texture = cv2.bilateralFilter(
texture, d=7, sigmaColor=80, sigmaSpace=80
)
return texture
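# x4 super-resolution using the Stable Diffusion upscaler pipeline.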
class Image_Super_Net:
def __init__(self, device="cuda"):
from diffusers import StableDiffusionUpscalePipeline
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler",
torch_dtype=torch.float16,
).to(device)
self.up_pipeline_x4.set_progress_bar_config(disable=True)
def __call__(self, image, prompt=""):
with torch.no_grad():
upscaled_image = self.up_pipeline_x4(
prompt=[prompt],
image=image,
num_inference_steps=10,
).images[0]
return upscaled_image
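# x4 super-resolution using Real-ESRGAN (RRDBNet backbone); expects the
# RealESRGAN_x4plus weights at the configured local path.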
class Image_GANNet:
def __init__(self, outscale: int):
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
self.outscale = outscale
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
self.upsampler = RealESRGANer(
scale=4,
model_path="/home/users/xinjie.wang/xinjie/Real-ESRGAN/weights/RealESRGAN_x4plus.pth",
model=model,
pre_pad=0,
half=True,
)
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
if isinstance(image, Image.Image):
image = np.array(image)
output, _ = self.upsampler.enhance(image, outscale=self.outscale)
return Image.fromarray(output)
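# Demo pipeline: upscale the multi-view renders, bake them into a UV texture
# from fixed camera poses, inpaint uncovered texels and export a textured OBJ.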
if __name__ == "__main__":
device = "cuda"
# super_model = Image_Super_Net(device)
super_model = Image_GANNet(outscale=4)
selected_camera_elevs = [20, 20, 20, -10, -10, -10]
selected_camera_azims = [-180, -60, 60, -120, 0, 120]
selected_view_weights = [1, 0.2, 0.2, 0.2, 1, 0.2]
# selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
multiviews = get_images_from_file(
"scripts/apps/texture_sessions/mfq4e7u4ko/multi_view/color_sample1.png",
512,
)
target_image_size = (2048, 2048)
render = MeshRender(
camera_distance=5,
default_resolution=2048,
texture_size=2048,
)
mesh = trimesh.load("scripts/apps/assets/example_texture/meshes/robot.obj")
from asset3d_gen.data.utils import normalize_vertices_array
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
mesh = mesh_uv_wrap(mesh)
render.load_mesh(mesh)
# multiviews = [Image.fromarray(img) for img in multiviews]
# multiviews = [Image.fromarray(img).convert("RGB") for img in multiviews]
# for idx, img in enumerate(multiviews):
# img.save(f"robot/raw/res_{idx}.png")
multiviews = [super_model(img) for img in multiviews]
multiviews = [img.convert("RGB") for img in multiviews]
for idx, img in enumerate(multiviews):
img.save(f"robot/super_gan_res_{idx}.png")
texture, mask = bake_from_multiview(
render,
multiviews,
selected_camera_elevs,
selected_camera_azims,
selected_view_weights,
)
texture_np = (texture.cpu().numpy() * 255).astype(np.uint8)[..., :3][
..., ::-1
]
cv2.imwrite("robot/raw_texture.png", texture_np)
print("texture done.")
mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
texture_np = render.uv_inpaint(texture, mask_np)
cv2.imwrite("robot/inpaint_texture.png", texture_np[..., ::-1])
# texture_np = post_process(texture_np, 2)
# cv2.imwrite("robot/inpaint_conv_texture.png", texture_np[..., ::-1])
print("inpaint done.")
texture = torch.tensor(texture_np / 255).float().to(texture.device)
render.set_texture(texture)
textured_mesh = render.save_mesh()
_ = textured_mesh.export("robot/robot.obj")