from PIL import Image import torch import torch.nn.functional as F import numpy as np import math import trimesh import cv2 import xatlas from typing import Union def get_perspective_projection_matrix(fovy, aspect_wh, near, far): fovy_rad = math.radians(fovy) return np.array( [ [1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0], [0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0], [ 0, 0, -(far + near) / (far - near), -2.0 * far * near / (far - near), ], [0, 0, -1, 0], ] ).astype(np.float32) def load_mesh(mesh): vtx_pos = mesh.vertices if hasattr(mesh, "vertices") else None pos_idx = mesh.faces if hasattr(mesh, "faces") else None vtx_uv = mesh.visual.uv if hasattr(mesh.visual, "uv") else None uv_idx = mesh.faces if hasattr(mesh, "faces") else None texture_data = None return vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data def save_mesh(mesh, texture_data): material = trimesh.visual.texture.SimpleMaterial( image=texture_data, diffuse=(255, 255, 255) ) texture_visuals = trimesh.visual.TextureVisuals( uv=mesh.visual.uv, image=texture_data, material=material ) mesh.visual = texture_visuals return mesh def transform_pos(mtx, pos, keepdim=False): t_mtx = ( torch.from_numpy(mtx).to(pos.device) if isinstance(mtx, np.ndarray) else mtx ) if pos.shape[-1] == 3: posw = torch.cat( [pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1 ) else: posw = pos if keepdim: return torch.matmul(posw, t_mtx.t())[...] else: return torch.matmul(posw, t_mtx.t())[None, ...] def get_mv_matrix(elev, azim, camera_distance, center=None): elev = -elev elev_rad = math.radians(elev) azim_rad = math.radians(azim) camera_position = np.array( [ camera_distance * math.cos(elev_rad) * math.cos(azim_rad), camera_distance * math.cos(elev_rad) * math.sin(azim_rad), camera_distance * math.sin(elev_rad), ] ) if center is None: center = np.array([0, 0, 0]) else: center = np.array(center) lookat = center - camera_position lookat = lookat / np.linalg.norm(lookat) up = np.array([0, 0, 1.0]) right = np.cross(lookat, up) right = right / np.linalg.norm(right) up = np.cross(right, lookat) up = up / np.linalg.norm(up) c2w = np.concatenate( [np.stack([right, up, -lookat], axis=-1), camera_position[:, None]], axis=-1, ) w2c = np.zeros((4, 4)) w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0)) w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:]) w2c[3, 3] = 1.0 return w2c.astype(np.float32) def stride_from_shape(shape): stride = [1] for x in reversed(shape[1:]): stride.append(stride[-1] * x) return list(reversed(stride)) def scatter_add_nd_with_count(input, count, indices, values, weights=None): # input: [..., C], D dimension + C channel # count: [..., 1], D dimension # indices: [N, D], long # values: [N, C] D = indices.shape[-1] C = input.shape[-1] size = input.shape[:-1] stride = stride_from_shape(size) assert len(size) == D input = input.view(-1, C) # [HW, C] count = count.view(-1, 1) flatten_indices = ( indices * torch.tensor(stride, dtype=torch.long, device=indices.device) ).sum( -1 ) # [N] if weights is None: weights = torch.ones_like(values[..., :1]) input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values) count.scatter_add_(0, flatten_indices.unsqueeze(1), weights) return input.view(*size, C), count.view(*size, 1) def linear_grid_put_2d(H, W, coords, values, return_count=False): # coords: [N, 2], float in [0, 1] # values: [N, C] C = values.shape[-1] indices = coords * torch.tensor( [H - 1, W - 1], dtype=torch.float32, device=coords.device ) indices_00 = indices.floor().long() # [N, 2] indices_00[:, 0].clamp_(0, H - 2) indices_00[:, 1].clamp_(0, W - 2) indices_01 = indices_00 + torch.tensor( [0, 1], dtype=torch.long, device=indices.device ) indices_10 = indices_00 + torch.tensor( [1, 0], dtype=torch.long, device=indices.device ) indices_11 = indices_00 + torch.tensor( [1, 1], dtype=torch.long, device=indices.device ) h = indices[..., 0] - indices_00[..., 0].float() w = indices[..., 1] - indices_00[..., 1].float() w_00 = (1 - h) * (1 - w) w_01 = (1 - h) * w w_10 = h * (1 - w) w_11 = h * w result = torch.zeros( H, W, C, device=values.device, dtype=values.dtype ) # [H, W, C] count = torch.zeros( H, W, 1, device=values.device, dtype=values.dtype ) # [H, W, 1] weights = torch.ones_like(values[..., :1]) # [N, 1] result, count = scatter_add_nd_with_count( result, count, indices_00, values * w_00.unsqueeze(1), weights * w_00.unsqueeze(1), ) result, count = scatter_add_nd_with_count( result, count, indices_01, values * w_01.unsqueeze(1), weights * w_01.unsqueeze(1), ) result, count = scatter_add_nd_with_count( result, count, indices_10, values * w_10.unsqueeze(1), weights * w_10.unsqueeze(1), ) result, count = scatter_add_nd_with_count( result, count, indices_11, values * w_11.unsqueeze(1), weights * w_11.unsqueeze(1), ) if return_count: return result, count mask = count.squeeze(-1) > 0 result[mask] = result[mask] / count[mask].repeat(1, C) return result def meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx): texture_height, texture_width, texture_channel = texture.shape vtx_num = vtx_pos.shape[0] vtx_mask = np.zeros(vtx_num, dtype=np.float32) vtx_color = [ np.zeros(texture_channel, dtype=np.float32) for _ in range(vtx_num) ] uncolored_vtxs = [] G = [[] for _ in range(vtx_num)] for i in range(uv_idx.shape[0]): for k in range(3): vtx_uv_idx = uv_idx[i, k] vtx_idx = pos_idx[i, k] uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1))) uv_u = int( round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1)) ) if mask[uv_u, uv_v] > 0: vtx_mask[vtx_idx] = 1.0 vtx_color[vtx_idx] = texture[uv_u, uv_v] else: uncolored_vtxs.append(vtx_idx) G[pos_idx[i, k]].append(pos_idx[i, (k + 1) % 3]) smooth_count = 2 last_uncolored_vtx_count = 0 while smooth_count > 0: uncolored_vtx_count = 0 for vtx_idx in uncolored_vtxs: sum_color = np.zeros(texture_channel, dtype=np.float32) total_weight = 0.0 vtx_0 = vtx_pos[vtx_idx] for connected_idx in G[vtx_idx]: if vtx_mask[connected_idx] > 0: vtx1 = vtx_pos[connected_idx] dist = np.sqrt(np.sum((vtx_0 - vtx1) ** 2)) dist_weight = 1.0 / max(dist, 1e-4) dist_weight *= dist_weight sum_color += vtx_color[connected_idx] * dist_weight total_weight += dist_weight if total_weight > 0: vtx_color[vtx_idx] = sum_color / total_weight vtx_mask[vtx_idx] = 1.0 else: uncolored_vtx_count += 1 if last_uncolored_vtx_count == uncolored_vtx_count: smooth_count -= 1 else: smooth_count += 1 last_uncolored_vtx_count = uncolored_vtx_count new_texture = texture.copy() new_mask = mask.copy() for face_idx in range(uv_idx.shape[0]): for k in range(3): vtx_uv_idx = uv_idx[face_idx, k] vtx_idx = pos_idx[face_idx, k] if vtx_mask[vtx_idx] == 1.0: uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1))) uv_u = int( round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1)) ) new_texture[uv_u, uv_v] = vtx_color[vtx_idx] new_mask[uv_u, uv_v] = 255 return new_texture, new_mask def mesh_uv_wrap(mesh): if isinstance(mesh, trimesh.Scene): mesh = mesh.dump(concatenate=True) if len(mesh.faces) > 500000000: raise ValueError( "The mesh has more than 500,000,000 faces, which is not supported." ) vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces) mesh.vertices = mesh.vertices[vmapping] mesh.faces = indices mesh.visual.uv = uvs return mesh class MeshRender: def __init__( self, camera_distance=1.45, default_resolution=1024, texture_size=1024, use_antialias=True, max_mip_level=None, filter_mode="linear", bake_mode="linear", raster_mode="cr", device="cuda", ): self.device = device self.set_default_render_resolution(default_resolution) self.set_default_texture_resolution(texture_size) self.camera_distance = camera_distance self.use_antialias = use_antialias self.max_mip_level = max_mip_level self.filter_mode = filter_mode self.bake_angle_thres = 75 self.bake_unreliable_kernel_size = int( (2 / 512) * max(self.default_resolution[0], self.default_resolution[1]) ) self.bake_mode = bake_mode self.raster_mode = raster_mode if self.raster_mode == "cr": import custom_rasterizer as cr self.raster = cr else: raise f"No raster named {self.raster_mode}" fov = 30 self.camera_proj_mat = get_perspective_projection_matrix( fov, self.default_resolution[1] / self.default_resolution[0], 0.01, 100.0, ) def raster_rasterize( self, pos, tri, resolution, ranges=None, grad_db=True ): if self.raster_mode == "cr": rast_out_db = None if pos.dim() == 2: pos = pos.unsqueeze(0) findices, barycentric = self.raster.rasterize(pos, tri, resolution) rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1) rast_out = rast_out.unsqueeze(0) else: raise f"No raster named {self.raster_mode}" return rast_out, rast_out_db def raster_interpolate( self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None ): if self.raster_mode == "cr": textd = None barycentric = rast_out[0, ..., :-1] findices = rast_out[0, ..., -1] if uv.dim() == 2: uv = uv.unsqueeze(0) textc = self.raster.interpolate(uv, findices, barycentric, uv_idx) else: raise f"No raster named {self.raster_mode}" return textc, textd def load_mesh( self, mesh, ): vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh) self.mesh_copy = mesh self.set_mesh( vtx_pos, pos_idx, vtx_uv=vtx_uv, uv_idx=uv_idx, ) if texture_data is not None: self.set_texture(texture_data) def save_mesh(self): texture_data = self.get_texture() texture_data = Image.fromarray((texture_data * 255).astype(np.uint8)) return save_mesh(self.mesh_copy, texture_data) def set_mesh( self, vtx_pos, pos_idx, vtx_uv=None, uv_idx=None, ): self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float() self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int) if (vtx_uv is not None) and (uv_idx is not None): self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float() self.uv_idx = ( torch.from_numpy(uv_idx).to(self.device).to(torch.int) ) else: self.vtx_uv = None self.uv_idx = None self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]] self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]] if (vtx_uv is not None) and (uv_idx is not None): self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1] def set_texture(self, tex): if isinstance(tex, np.ndarray): tex = Image.fromarray((tex * 255).astype(np.uint8)) elif isinstance(tex, torch.Tensor): tex = tex.cpu().numpy() tex = Image.fromarray((tex * 255).astype(np.uint8)) tex = tex.resize(self.texture_size).convert("RGB") tex = np.array(tex) / 255.0 self.tex = torch.from_numpy(tex).to(self.device) self.tex = self.tex.float() def set_default_render_resolution(self, default_resolution): if isinstance(default_resolution, int): default_resolution = (default_resolution, default_resolution) self.default_resolution = default_resolution def set_default_texture_resolution(self, texture_size): if isinstance(texture_size, int): texture_size = (texture_size, texture_size) self.texture_size = texture_size def get_mesh(self): vtx_pos = self.vtx_pos.cpu().numpy() pos_idx = self.pos_idx.cpu().numpy() vtx_uv = self.vtx_uv.cpu().numpy() uv_idx = self.uv_idx.cpu().numpy() # 坐标变换的逆变换 vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]] vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]] vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1] return vtx_pos, pos_idx, vtx_uv, uv_idx def get_texture(self): return self.tex.cpu().numpy() def render_sketch_from_depth(self, depth_image): depth_image_np = depth_image.cpu().numpy() depth_image_np = (depth_image_np * 255).astype(np.uint8) depth_edges = cv2.Canny(depth_image_np, 30, 80) combined_edges = depth_edges sketch_image = ( torch.from_numpy(combined_edges).to(depth_image.device).float() / 255.0 ) sketch_image = sketch_image.unsqueeze(-1) return sketch_image def back_project( self, image, elev, azim, camera_distance=None, center=None, method=None ): if isinstance(image, Image.Image): image = torch.tensor(np.array(image) / 255.0) elif isinstance(image, np.ndarray): image = torch.tensor(image) if image.dim() == 2: image = image.unsqueeze(-1) image = image.float().to(self.device) resolution = image.shape[:2] channel = image.shape[-1] texture = torch.zeros(self.texture_size + (channel,)).to(self.device) cos_map = torch.zeros(self.texture_size + (1,)).to(self.device) proj = self.camera_proj_mat r_mv = get_mv_matrix( elev=elev, azim=azim, camera_distance=( self.camera_distance if camera_distance is None else camera_distance ), center=center, ) pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True) pos_clip = transform_pos(proj, pos_camera) pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4] v0 = pos_camera[self.pos_idx[:, 0], :] v1 = pos_camera[self.pos_idx[:, 1], :] v2 = pos_camera[self.pos_idx[:, 2], :] face_normals = F.normalize( torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1 ) vertex_normals = trimesh.geometry.mean_vertex_normals( vertex_count=self.vtx_pos.shape[0], faces=self.pos_idx.cpu(), face_normals=face_normals.cpu(), ) vertex_normals = ( torch.from_numpy(vertex_normals) .float() .to(self.device) .contiguous() ) tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous() rast_out, rast_out_db = self.raster_rasterize( pos_clip, self.pos_idx, resolution=resolution ) visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...] normal, _ = self.raster_interpolate( vertex_normals[None, ...], rast_out, self.pos_idx ) normal = normal[0, ...] uv, _ = self.raster_interpolate( self.vtx_uv[None, ...], rast_out, self.uv_idx ) depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx) depth = depth[0, ...] depth_max, depth_min = ( depth[visible_mask > 0].max(), depth[visible_mask > 0].min(), ) depth_normalized = (depth - depth_min) / (depth_max - depth_min) depth_image = depth_normalized * visible_mask # Mask out background. sketch_image = self.render_sketch_from_depth(depth_image) cv2.imwrite("d_depth.png", depth_image.cpu().numpy() * 255) cv2.imwrite("d_normal.png", normal.cpu().numpy() * 255) cv2.imwrite( "d_image.png", image.cpu().numpy()[..., :3][..., ::-1] * 255 ) cv2.imwrite("d_sketch_image.png", sketch_image.cpu().numpy() * 255) cv2.imwrite("d_uv1.png", uv.cpu().numpy()[0, ..., 0] * 255) cv2.imwrite("d_uv2.png", uv.cpu().numpy()[0, ..., 1] * 255) # p uv[0,...,0].mean(axis=0) # import pdb; pdb.set_trace() # depth_image = None # normal = None # image = None sketch_image = self.render_sketch_from_depth(depth_image) channel = image.shape[-1] lookat = torch.tensor([[0, 0, -1]], device=self.device) cos_image = torch.nn.functional.cosine_similarity( lookat, normal.view(-1, 3) ) cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1) cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi) cos_image[cos_image < cos_thres] = 0 # shrink kernel_size = self.bake_unreliable_kernel_size * 2 + 1 kernel = torch.ones( (1, 1, kernel_size, kernel_size), dtype=torch.float32 ).to(sketch_image.device) visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float() visible_mask = F.conv2d( 1.0 - visible_mask, kernel, padding=kernel_size // 2 ) visible_mask = 1.0 - (visible_mask > 0).float() # 二值化 visible_mask = visible_mask.squeeze(0).permute(1, 2, 0) sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0) sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2) sketch_image = (sketch_image > 0).float() # 二值化 sketch_image = sketch_image.squeeze(0).permute(1, 2, 0) visible_mask = visible_mask * (sketch_image < 0.5) cos_image[visible_mask == 0] = 0 proj_mask = (visible_mask != 0).view(-1) uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask] image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask] cos_image = cos_image.contiguous().view(-1, 1)[proj_mask] sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask] import pdb pdb.set_trace() texture = linear_grid_put_2d( self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image ) cos_map = linear_grid_put_2d( self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], cos_image, ) boundary_map = linear_grid_put_2d( self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], sketch_image, ) return texture, cos_map, boundary_map @torch.no_grad() def fast_bake_texture(self, textures, cos_maps): channel = textures[0].shape[-1] texture_merge = torch.zeros(self.texture_size + (channel,)).to( self.device ) trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device) for texture, cos_map in zip(textures, cos_maps): view_sum = (cos_map > 0).sum() painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum() if painted_sum / view_sum > 0.99: continue texture_merge += texture * cos_map trust_map_merge += cos_map texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8) return texture_merge, trust_map_merge > 1e-8 def uv_inpaint(self, texture, mask): if isinstance(texture, torch.Tensor): texture_np = texture.cpu().numpy() elif isinstance(texture, np.ndarray): texture_np = texture elif isinstance(texture, Image.Image): texture_np = np.array(texture) / 255.0 vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh() texture_np, mask = meshVerticeInpaint_smooth( texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx ) texture_np = cv2.inpaint( (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS ) return texture_np def get_images_from_file(img_path: str, img_size: int) -> list[np.array]: input_image = Image.open(img_path) view_images = np.array(input_image) view_images = np.concatenate( [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1 ) images = np.split(view_images, view_images.shape[1] // img_size, axis=1) return images def bake_from_multiview( render, views, camera_elevs, camera_azims, view_weights, method="fast" ): project_textures, project_weighted_cos_maps = [], [] project_boundary_maps = [] for view, camera_elev, camera_azim, weight in zip( views, camera_elevs, camera_azims, view_weights ): project_texture, project_cos_map, project_boundary_map = ( render.back_project(view, camera_elev, camera_azim) ) project_cos_map = weight * (project_cos_map**4) project_textures.append(project_texture) project_weighted_cos_maps.append(project_cos_map) project_boundary_maps.append(project_boundary_map) if method == "fast": texture, ori_trust_map = render.fast_bake_texture( project_textures, project_weighted_cos_maps ) else: raise f"no method {method}" return texture, ori_trust_map > 1e-8 def post_process(texture: np.ndarray, iter: int = 2) -> np.ndarray: for _ in range(iter): texture = cv2.fastNlMeansDenoisingColored(texture, None, 11, 11, 9, 25) texture = cv2.bilateralFilter( texture, d=7, sigmaColor=80, sigmaSpace=80 ) return texture class Image_Super_Net: def __init__(self, device="cuda"): from diffusers import StableDiffusionUpscalePipeline self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained( "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, ).to(device) self.up_pipeline_x4.set_progress_bar_config(disable=True) def __call__(self, image, prompt=""): with torch.no_grad(): upscaled_image = self.up_pipeline_x4( prompt=[prompt], image=image, num_inference_steps=10, ).images[0] return upscaled_image class Image_GANNet: def __init__(self, outscale: int): from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet self.outscale = outscale model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4, ) self.upsampler = RealESRGANer( scale=4, model_path="/home/users/xinjie.wang/xinjie/Real-ESRGAN/weights/RealESRGAN_x4plus.pth", model=model, pre_pad=0, half=True, ) def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: if isinstance(image, Image.Image): image = np.array(image) output, _ = self.upsampler.enhance(image, outscale=self.outscale) return Image.fromarray(output) if __name__ == "__main__": device = "cuda" # super_model = Image_Super_Net(device) super_model = Image_GANNet(outscale=4) selected_camera_elevs = [20, 20, 20, -10, -10, -10] selected_camera_azims = [-180, -60, 60, -120, 0, 120] selected_view_weights = [1, 0.2, 0.2, 0.2, 1, 0.2] # selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05] multiviews = get_images_from_file( "scripts/apps/texture_sessions/mfq4e7u4ko/multi_view/color_sample1.png", 512, ) target_image_size = (2048, 2048) render = MeshRender( camera_distance=5, default_resolution=2048, texture_size=2048, ) mesh = trimesh.load("scripts/apps/assets/example_texture/meshes/robot.obj") from asset3d_gen.data.utils import normalize_vertices_array mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices) mesh = mesh_uv_wrap(mesh) render.load_mesh(mesh) # multiviews = [Image.fromarray(img) for img in multiviews] # multiviews = [Image.fromarray(img).convert("RGB") for img in multiviews] # for idx, img in enumerate(multiviews): # img.save(f"robot/raw/res_{idx}.png") multiviews = [super_model(img) for img in multiviews] multiviews = [img.convert("RGB") for img in multiviews] for idx, img in enumerate(multiviews): img.save(f"robot/super_gan_res_{idx}.png") texture, mask = bake_from_multiview( render, multiviews, selected_camera_elevs, selected_camera_azims, selected_view_weights, ) texture_np = (texture.cpu().numpy() * 255).astype(np.uint8)[..., :3][ ..., ::-1 ] cv2.imwrite("robot/raw_texture.png", texture_np) print("texture done.") mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8) texture_np = render.uv_inpaint(texture, mask_np) cv2.imwrite("robot/inpaint_texture.png", texture_np[..., ::-1]) # texture_np = post_process(texture_np, 2) # cv2.imwrite("robot/inpaint_conv_texture.png", texture_np[..., ::-1]) print("inpaint done.") texture = torch.tensor(texture_np / 255).float().to(texture.device) render.set_texture(texture) textured_mesh = render.save_mesh() _ = textured_mesh.export("robot/robot.obj")