# Kiss3DGen / video_render.py
# JiantaoLin
# new
# ded6c2a
# raw
# history blame
# 5.14 kB
import pytorch3d
import torch
import imageio
import numpy as np
import os
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
AmbientLights,
PerspectiveCameras,
RasterizationSettings,
look_at_view_transform,
TexturesVertex,
MeshRenderer,
Materials,
MeshRasterizer,
SoftPhongShader,
PointLights
)
import trimesh
from tqdm import tqdm
from pytorch3d.transforms import RotateAxisAngle
from shader import MultiOutputShader
def _vertex_colors_0_255(mesh_data, num_vertices):
    """Return a ``(num_vertices, 3)`` float32 array of RGB vertex colours in 0-255.

    Texture-based visuals are first converted to per-vertex colours via
    ``to_color()``. When the mesh carries no usable colour information the
    result is opaque white (255), which matches the 0-255 scale the render
    loop composites and clips in.
    """
    visual = getattr(mesh_data, "visual", None)

    # Texture visuals do not expose ``vertex_colors``; convert them first.
    if getattr(visual, "kind", None) == "texture":
        visual = visual.to_color()

    # ``kind`` is "vertex"/"face" only when colours were actually defined.
    if getattr(visual, "kind", None) in ("vertex", "face"):
        colors = np.asarray(visual.vertex_colors)
        if colors.ndim == 2 and colors.shape[0] == num_vertices and colors.shape[1] >= 3:
            # Drop alpha, keep RGB.
            return np.ascontiguousarray(colors[:, :3], dtype=np.float32)

    # No colours available: fall back to white on the 0-255 scale.
    return np.full((num_vertices, 3), 255.0, dtype=np.float32)


def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, image_size=512, fps=30, device="cuda"):
    """Render a 360-degree turntable video of a mesh, RGB on the left and normals on the right.

    The mesh is loaded with trimesh, wrapped in a PyTorch3D ``Meshes`` object
    with per-vertex colours, and rendered from ``num_frames`` evenly spaced
    azimuth angles at zero elevation. Each frame is the horizontal
    concatenation of the colour render and the normal map, both composited
    over a white background, and the frames are written with
    ``imageio.mimsave``.

    Args:
        input_obj_path: Path of the mesh file to load.
        output_video_path: Destination path passed to ``imageio.mimsave``.
        num_frames: Number of frames in one full rotation; must be positive.
        image_size: Value forwarded to ``RasterizationSettings(image_size=...)``.
        fps: Frame rate forwarded to ``imageio.mimsave``.
        device: Torch device on which tensors and renderer objects are created.

    Raises:
        FileNotFoundError: If ``input_obj_path`` does not exist.
        ValueError: If ``num_frames`` is not positive or the loaded scene
            contains no geometry.
    """
    if not os.path.exists(input_obj_path):
        raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
    if num_frames <= 0:
        raise ValueError(f"num_frames must be a positive integer, got {num_frames}")

    # Load the 3D model; a multi-object file comes back as a Scene.
    scene_data = trimesh.load(input_obj_path)

    # Extract the single mesh, or merge every geometry of the scene into one.
    if isinstance(scene_data, trimesh.Scene):
        geometries = list(scene_data.geometry.values())
        if not geometries:
            raise ValueError(f"No geometry found in: {input_obj_path}")
        mesh_data = trimesh.util.concatenate(geometries)
    else:
        mesh_data = scene_data

    # Vertex positions and triangle indices, created directly on the target device.
    vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
    faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)

    # Per-vertex RGB colours (0-255 scale) with a leading batch dimension.
    vertex_colors = torch.as_tensor(
        _vertex_colors_0_255(mesh_data, vertices.shape[0]),
        dtype=torch.float32,
        device=device,
    )[None]

    # Build the texture and the PyTorch3D mesh.
    textures = TexturesVertex(verts_features=vertex_colors)
    mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)

    # Renderer configuration shared by every frame.
    lights = AmbientLights(ambient_color=((2.0,) * 3,), device=device)
    raster_settings = RasterizationSettings(
        image_size=image_size,  # size of the rendered image
        blur_radius=0.0,        # no blur
        faces_per_pixel=1,      # keep a single face per pixel
    )
    materials = Materials(
        device=device,
        diffuse_color=((0.0, 0.0, 0.0),),
        ambient_color=((1.0, 1.0, 1.0),),
        specular_color=((0.0, 0.0, 0.0),),
        shininess=0.0,
    )
    rasterizer = MeshRasterizer(raster_settings=raster_settings)

    # Orbit parameters.
    camera_distance = 6.5
    elevation = 0.0
    center = (0.0, 0.0, 0.0)

    frames = []
    for frame_idx in tqdm(range(num_frames)):
        azimuth = 360.0 * frame_idx / num_frames
        R, T = look_at_view_transform(
            dist=camera_distance,
            elev=elevation,
            azim=azimuth,
            at=(center,),
            degrees=True,
        )

        # Camera for this view; near/far planes are set as plain attributes.
        cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
        cameras.znear = 0.0001
        cameras.zfar = 10000000.0

        # The shader is rebuilt per frame because it is constructed with the
        # current camera. NOTE(review): MultiOutputShader is project-local;
        # confirm whether it could be hoisted out of the loop.
        shader = MultiOutputShader(
            device=device,
            cameras=cameras,
            lights=lights,
            materials=materials,
            choices=["rgb", "mask", "normal"],
        )
        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

        # Render colour, coverage mask and normals in one pass.
        render_result = renderer(mesh, cameras=cameras)
        mask = render_result["mask"]
        background = 1 - mask

        # Colour image composited over a white (255) background.
        rgb_image = render_result["rgb"] * mask + background * torch.ones_like(render_result["rgb"]) * 255.0
        rgb = rgb_image[0, ..., :3].cpu().numpy()

        # Unit-length normals remapped from [-1, 1] to [0, 1], then
        # composited over a white (1.0) background.
        normal_map = torch.nn.functional.normalize(render_result["normal"], dim=-1)
        normal_map = (normal_map + 1) / 2
        normal_map = normal_map * mask + background * torch.ones_like(render_result["normal"])
        normal = normal_map[0, ..., :3].cpu().numpy()

        rgb = np.clip(rgb, 0, 255).astype(np.uint8)
        normal = np.clip(normal * 255, 0, 255).astype(np.uint8)

        # Side-by-side frame: RGB on the left, normal map on the right.
        frames.append(np.concatenate((rgb, normal), axis=1))

    # Write the video with imageio.
    imageio.mimsave(output_video_path, frames, fps=fps)
    print(f"Video saved to {output_video_path}")
if __name__ == '__main__':
    # Command-line entry point. Both paths are optional positionals whose
    # defaults reproduce the original hard-coded example call, so running the
    # script without arguments behaves as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Render a turntable video (RGB | normal map) from a mesh file."
    )
    parser.add_argument(
        "input_obj_path",
        nargs="?",
        default="/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/a_owl_wearing_a_hat/ISOMER/rgb_projected.obj",
        help="Path of the mesh file to render.",
    )
    parser.add_argument(
        "output_video_path",
        nargs="?",
        default="output.mp4",
        help="Path of the video file to write.",
    )
    args = parser.parse_args()

    render_video_from_obj(args.input_obj_path, args.output_video_path)