File size: 4,306 Bytes
1e5535f
 
02a9751
1e5535f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02a9751
 
1e5535f
 
 
 
7cad53a
02a9751
 
 
 
 
 
 
 
 
 
 
 
 
1e5535f
 
02a9751
1e5535f
 
 
 
 
 
 
 
ded6c2a
1e5535f
 
235efa3
 
 
02a9751
 
1e5535f
 
 
 
 
 
235efa3
1e5535f
235efa3
1e5535f
 
02a9751
1e5535f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235efa3
1e5535f
 
 
 
 
 
 
 
 
 
 
 
 
02a9751
 
1e5535f
02a9751
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pytorch3d
import torch
import imageio
import numpy as np
import os
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
    AmbientLights,
    PerspectiveCameras, 
    RasterizationSettings, 
    look_at_view_transform,
    TexturesVertex,
    MeshRenderer, 
    Materials,
    MeshRasterizer, 
    SoftPhongShader, 
    PointLights
)
import trimesh
from tqdm import tqdm
from pytorch3d.transforms import RotateAxisAngle

from shader import MultiOutputShader

def _load_mesh(input_obj_path, device):
    """Load an OBJ file with trimesh and wrap it as a one-element PyTorch3D ``Meshes``.

    A multi-geometry ``trimesh.Scene`` is flattened by concatenating all of its
    geometries. Per-vertex colours (first three channels) become the vertex
    texture.

    Raises:
        ValueError: If the file yields no geometry, or the result is not a
            triangle mesh with at least one face.
    """
    scene_data = trimesh.load(input_obj_path)

    if isinstance(scene_data, trimesh.Scene):
        geometries = list(scene_data.geometry.values())
        if not geometries:
            raise ValueError(f"No geometry found in OBJ file: {input_obj_path}")
        mesh_data = trimesh.util.concatenate(geometries)
    else:
        mesh_data = scene_data

    # The renderer needs faces; fail with a clear message instead of an
    # AttributeError further down when the file holds e.g. only points.
    if not isinstance(mesh_data, trimesh.Trimesh) or len(mesh_data.faces) == 0:
        raise ValueError(f"OBJ file does not contain a triangle mesh: {input_obj_path}")

    vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
    faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)

    raw_colors = mesh_data.visual.vertex_colors
    if raw_colors is None:
        # NOTE(review): this fallback is 1.0 per channel while colours loaded
        # from the file are used unscaled (presumably 0-255) - confirm the
        # intended scale against MultiOutputShader.
        vertex_colors = torch.ones_like(vertices)[None]
    else:
        # Build the colours directly on the target device so the texture does
        # not depend on a later ``.to(device)`` mutating the object in place.
        vertex_colors = torch.tensor(
            raw_colors[:, :3], dtype=torch.float32, device=device
        )[None]

    textures = TexturesVertex(verts_features=vertex_colors)
    return pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)


def _compose_frame(render_result):
    """Turn one shader output into a uint8 frame: RGB and normal map joined on axis 1.

    ``render_result`` is indexed with the keys "rgb", "mask" and "normal".
    Pixels where the mask is 0 are filled with 255 (RGB) / 1.0 (normals), and
    only the first batch element and first three channels are kept.
    """
    mask = render_result["mask"]
    background = 1 - mask

    # RGB is clipped to [0, 255] below without rescaling, so the background
    # fill uses 255 on the same scale.
    rgb_image = render_result["rgb"] * mask + background * 255.0
    rgb = rgb_image[0, ..., :3].cpu().numpy()

    # Normalise the normals, remap [-1, 1] -> [0, 1], then fill the background.
    normal_map = torch.nn.functional.normalize(render_result["normal"], dim=-1)
    normal_map = (normal_map + 1) / 2
    normal_map = normal_map * mask + background * torch.ones_like(render_result["normal"])
    normal = normal_map[0, ..., :3].cpu().numpy()

    rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    normal = np.clip(normal * 255, 0, 255).astype(np.uint8)
    return np.concatenate((rgb, normal), axis=1)


def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, image_size=512, fps=15, device="cuda"):
    """Render an OBJ mesh from a camera sweeping 360 degrees in azimuth and save a video.

    Each frame is the RGB render and the normal map concatenated along axis 1.
    The camera looks at the origin from a fixed distance at zero elevation, and
    the azimuth advances by ``360 / num_frames`` degrees per frame.

    Args:
        input_obj_path: Path of the OBJ file to load.
        output_video_path: Destination passed to ``imageio.mimsave``. Its parent
            directory is created if it does not exist.
        num_frames: Number of frames (and azimuth steps) to render; must be >= 1.
        image_size: Value forwarded to ``RasterizationSettings(image_size=...)``.
        fps: Frame rate forwarded to ``imageio.mimsave``.
        device: Torch device used for the mesh, lights, materials and cameras.

    Raises:
        FileNotFoundError: If ``input_obj_path`` does not exist.
        ValueError: If ``num_frames`` is smaller than 1, or the file does not
            contain a triangle mesh.
    """
    if not os.path.exists(input_obj_path):
        raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
    if num_frames < 1:
        # An empty frame list would only fail later inside the video writer.
        raise ValueError(f"num_frames must be a positive integer, got {num_frames}")

    mesh = _load_mesh(input_obj_path, device)

    lights = AmbientLights(ambient_color=((2.0,) * 3,), device=device)
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
    )
    materials = Materials(
        device=device,
        diffuse_color=((1.0, 1.0, 1.0),),
        ambient_color=((1.0, 1.0, 1.0),),
        specular_color=((1.0, 1.0, 1.0),),
        shininess=0.0,
    )
    # The rasterizer has no camera of its own; the per-frame camera is supplied
    # through the renderer call below.
    rasterizer = MeshRasterizer(raster_settings=raster_settings)

    camera_distance = 6.5
    elevation = 0.0
    center = (0.0, 0.0, 0.0)

    frames = []
    for i in tqdm(range(num_frames)):
        azimuth = 360.0 * i / num_frames
        R, T = look_at_view_transform(
            dist=camera_distance,
            elev=elevation,
            azim=azimuth,
            at=(center,),
            degrees=True,
        )

        # Manually set the camera's rotation matrix.
        cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
        cameras.znear = 0.0001
        cameras.zfar = 10000000.0
        # The shader is rebuilt per frame because it is constructed with the
        # frame's camera.
        shader = MultiOutputShader(
            device=device,
            cameras=cameras,
            lights=lights,
            materials=materials,
            choices=["rgb", "mask", "normal"],
        )

        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
        frames.append(_compose_frame(renderer(mesh, cameras=cameras)))

    output_dir = os.path.dirname(output_video_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    imageio.mimsave(output_video_path, frames, fps=fps)

    print(f"Video saved to {output_video_path}")

if __name__ == '__main__':
    # Command-line entry point. Every argument is optional and defaults to the
    # values the script used before, so running it with no arguments behaves
    # the same as the previous hard-coded invocation.
    import argparse

    parser = argparse.ArgumentParser(
        description="Render a 360-degree RGB + normal-map video of an OBJ mesh."
    )
    parser.add_argument(
        "input_obj_path",
        nargs="?",
        default="/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/a_owl_wearing_a_hat/ISOMER/rgb_projected.obj",
        help="Path of the OBJ file to render.",
    )
    parser.add_argument(
        "output_video_path",
        nargs="?",
        default="output.mp4",
        help="Where to write the rendered video.",
    )
    parser.add_argument("--num-frames", type=int, default=60, help="Number of frames to render.")
    parser.add_argument("--image-size", type=int, default=512, help="Rasterization image size.")
    parser.add_argument("--fps", type=int, default=15, help="Frame rate of the output video.")
    parser.add_argument("--device", default="cuda", help="Torch device to render on.")
    args = parser.parse_args()

    render_video_from_obj(
        args.input_obj_path,
        args.output_video_path,
        num_frames=args.num_frames,
        image_size=args.image_size,
        fps=args.fps,
        device=args.device,
    )