File size: 8,493 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
from plyfile import PlyData, PlyElement
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import copy
from lam.models.rendering.utils.typing import *
from lam.models.rendering.utils.utils import trunc_exp, MLP
from einops import rearrange, repeat


def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))`` of ``x`` computed with NumPy.

    This is the inverse of the logistic sigmoid; ``x`` should lie strictly
    inside the open interval (0, 1) for a finite result.
    """
    odds = x / (1 - x)
    return np.log(odds)


class GaussianModel:
    """Per-Gaussian attribute container with PLY export/import helpers.

    The attributes ``xyz``, ``opacity``, ``rotation``, ``scaling``, ``shs``,
    ``albedo``, ``offset`` and ``lights`` are stored exactly as passed to the
    constructor (any of them may be ``None``). Passing ``ply_path`` instead
    populates all of them from a PLY file via :meth:`load_ply`.
    """

    # Attributes moved by ``to_cuda``, in the order they are moved.
    _DEVICE_ATTRIBUTES = ("xyz", "opacity", "rotation", "scaling", "shs", "offset", "albedo", "lights")

    def __init__(self, xyz=None, opacity=None, rotation=None, scaling=None, shs=None, offset=None, ply_path=None, sh2rgb=False, albedo=None, lights=None) -> None:
        self.xyz: Tensor = xyz
        self.opacity: Tensor = opacity
        self.rotation: Tensor = rotation
        self.scaling: Tensor = scaling
        self.shs: Tensor = shs
        self.albedo: Tensor = albedo
        self.offset: Tensor = offset
        self.lights: Tensor = lights
        if ply_path is not None:
            # load_ply reassigns every attribute set above.
            self.load_ply(ply_path, sh2rgb=sh2rgb)

    def update_lights(self, lights):
        """Replace ``self.lights``."""
        self.lights = lights

    def update_albedo(self, albedo):
        """Replace ``self.albedo``."""
        self.albedo = albedo

    def update_shs(self, shs):
        """Replace ``self.shs``."""
        self.shs = shs

    def to_cuda(self):
        """Move the model's tensor attributes to the current CUDA device in place.

        ``lights`` is moved together with the other attributes so they all end
        up on the same device. Attributes that are ``None`` are skipped rather
        than raising ``AttributeError``.
        """
        for name in self._DEVICE_ATTRIBUTES:
            value = getattr(self, name)
            if value is not None:
                setattr(self, name, value.cuda())

    @staticmethod
    def _split_shs(shs):
        """Split a colour/SH tensor into ``(features_dc, features_rest)``, both 3-D.

        A 2-D ``(N, C)`` tensor contributes its first three columns as the DC
        term and the remaining columns as the rest (each gets a singleton
        dim 1). Otherwise the first entry along dim 1 is DC and the others
        are the rest.
        """
        if len(shs.shape) == 2:
            return shs[:, :3].unsqueeze(1), shs[:, 3:].unsqueeze(1)
        return shs[:, :1], shs[:, 1:]

    def construct_list_of_attributes(self, shs=None):
        """Return the PLY vertex property names in the column order used when writing.

        Args:
            shs: colour/SH tensor whose layout determines how many ``f_dc_*``
                and ``f_rest_*`` names are produced. Defaults to ``self.shs``.
        """
        shs = self.shs if shs is None else shs
        features_dc, features_rest = self._split_shs(shs)
        names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        names += ['f_dc_{}'.format(i) for i in range(features_dc.shape[1] * features_dc.shape[2])]
        names += ['f_rest_{}'.format(i) for i in range(features_rest.shape[1] * features_rest.shape[2])]
        names.append('opacity')
        names += ['scale_{}'.format(i) for i in range(self.scaling.shape[1])]
        names += ['rot_{}'.format(i) for i in range(self.rotation.shape[1])]
        return names

    def _write_ply(self, path, xyz, shs, opacities, scale, rotation, rgb2sh):
        """Pack per-Gaussian arrays into a single PLY ``vertex`` element and write it to ``path``.

        ``xyz``, ``opacities``, ``scale`` and ``rotation`` are numpy arrays with
        one row per Gaussian (they are concatenated along axis 1, so they must
        be 2-D). ``shs`` is the colour tensor interpreted by ``_split_shs``.
        Normals are written as zeros; every property is stored as ``f4``.
        """
        normals = np.zeros_like(xyz)
        features_dc, features_rest = self._split_shs(shs)
        f_dc = features_dc.detach().float().flatten(start_dim=1).contiguous().cpu().numpy()
        f_rest = features_rest.detach().float().flatten(start_dim=1).contiguous().cpu().numpy()
        if rgb2sh:
            from lam.models.rendering.utils.sh_utils import RGB2SH
            f_dc = RGB2SH(f_dc)

        # The header is derived from the same ``shs`` that supplies the data so
        # the column count always matches.
        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes(shs)]
        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        PlyData([PlyElement.describe(elements, 'vertex')]).write(path)

    def save_ply(self, path, rgb2sh=False, offset2xyz=False, albedo2rgb=False):
        """Write the Gaussians to ``path`` with opacity and scaling activations inverted.

        Opacity is clamped to ``[1e-3, 1 - 1e-3]`` and passed through
        ``inverse_sigmoid``; scaling is written as its natural log.

        Args:
            path: destination PLY file.
            rgb2sh: convert the DC colour term with ``RGB2SH`` before writing.
            offset2xyz: write ``self.offset`` instead of ``self.xyz`` as positions.
            albedo2rgb: write ``self.albedo`` in place of ``self.shs``; the
                model's own ``shs`` attribute is left unchanged.
        """
        positions = self.offset if offset2xyz else self.xyz
        xyz = positions.detach().cpu().float().numpy()
        shs = self.albedo if albedo2rgb else self.shs
        opacities = inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3).detach().cpu().float().numpy())
        scale = np.log(self.scaling.detach().cpu().float().numpy())
        rotation = self.rotation.detach().cpu().float().numpy()
        self._write_ply(path, xyz, shs, opacities, scale, rotation, rgb2sh)

    def save_ply_nodeact(self, path, rgb2sh=False, albedo2rgb=False):
        """Write the Gaussians to ``path`` with opacity and scaling stored as-is.

        Unlike :meth:`save_ply`, no inverse activation is applied and positions
        always come from ``self.xyz``.

        Args:
            path: destination PLY file.
            rgb2sh: convert the DC colour term with ``RGB2SH`` before writing.
            albedo2rgb: write ``self.albedo`` in place of ``self.shs``; the
                model's own ``shs`` attribute is left unchanged.
        """
        xyz = self.xyz.detach().cpu().float().numpy()
        shs = self.albedo if albedo2rgb else self.shs
        opacities = self.opacity.detach().cpu().float().numpy()
        scale = self.scaling.detach().cpu().float().numpy()
        rotation = self.rotation.detach().cpu().float().numpy()
        self._write_ply(path, xyz, shs, opacities, scale, rotation, rgb2sh)

    @staticmethod
    def _read_indexed_properties(vertex, prefix, num_points):
        """Return a ``(num_points, K)`` float64 array of the ``{prefix}{i}`` properties, ordered by ``i``."""
        names = sorted((p.name for p in vertex.properties if p.name.startswith(prefix)),
                       key=lambda name: int(name.split('_')[-1]))
        values = np.zeros((num_points, len(names)))
        for idx, name in enumerate(names):
            values[:, idx] = np.asarray(vertex[name])
        return values

    @staticmethod
    def _sh_degree_from_rest_count(num_rest):
        """Return the SH degree ``d`` satisfying ``num_rest == 3 * ((d + 1) ** 2 - 1)``.

        Raises:
            ValueError: if no non-negative integer degree matches ``num_rest``.
        """
        coeffs_per_channel, remainder = divmod(num_rest, 3)
        coeffs_per_channel += 1  # account for the DC coefficient
        degree = int(round(math.sqrt(coeffs_per_channel))) - 1
        if remainder or (degree + 1) ** 2 != coeffs_per_channel:
            raise ValueError("Cannot infer SH degree from {} f_rest_* properties".format(num_rest))
        return degree

    @staticmethod
    def _frozen_param(tensor):
        """Wrap ``tensor`` in an ``nn.Parameter`` that does not require grad.

        ``requires_grad=False`` must be passed to ``nn.Parameter`` itself: its
        constructor defaults to ``requires_grad=True`` and ignores the flag of
        the wrapped tensor.
        """
        return nn.Parameter(tensor, requires_grad=False)

    def load_ply(self, path, sh2rgb=False):
        """Populate this model from the first element of the PLY file at ``path``.

        Tensors are created on CPU as float32 with gradients disabled. The SH
        degree is inferred from the number of ``f_rest_*`` properties and stored
        in ``sh_degree`` / ``active_sh_degree``. ``offset``, ``albedo`` and
        ``lights`` are reset to zeros.

        Args:
            path: PLY file to read.
            sh2rgb: convert the DC term with ``SH2RGB`` and apply sigmoid to
                opacity and ``trunc_exp`` to scaling after loading.

        Raises:
            ValueError: if the ``f_rest_*`` property count matches no SH degree.
        """
        def cpu_float(array):
            return torch.tensor(array, dtype=torch.float, device="cpu")

        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        xyz = np.stack((np.asarray(vertex["x"]),
                        np.asarray(vertex["y"]),
                        np.asarray(vertex["z"])), axis=1)
        num_points = xyz.shape[0]
        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        features_dc = np.zeros((num_points, 3, 1))
        for channel in range(3):
            features_dc[:, channel, 0] = np.asarray(vertex["f_dc_{}".format(channel)])

        features_extra = self._read_indexed_properties(vertex, "f_rest_", num_points)
        # Infer the degree rather than assuming 0, so files that carry
        # higher-order f_rest_* coefficients can be reshaped and loaded.
        self.sh_degree = self._sh_degree_from_rest_count(features_extra.shape[1])
        # Reshape (P, F*SH_coeffs) to (P, F, SH_coeffs except DC).
        # NOTE(review): the save methods flatten f_rest without a channel/coefficient
        # transpose, so for sh_degree > 0 a save -> load round trip may permute
        # coefficients — TODO confirm the intended on-disk layout.
        features_extra = features_extra.reshape((num_points, 3, (self.sh_degree + 1) ** 2 - 1))

        scales = self._read_indexed_properties(vertex, "scale_", num_points)
        rots = self._read_indexed_properties(vertex, "rot_", num_points)

        self.xyz = self._frozen_param(cpu_float(xyz))
        self.features_dc = self._frozen_param(cpu_float(features_dc).transpose(1, 2).contiguous())
        if sh2rgb:
            from lam.models.rendering.utils.sh_utils import SH2RGB
            self.features_dc = SH2RGB(self.features_dc)
        self.features_rest = self._frozen_param(cpu_float(features_extra).transpose(1, 2).contiguous())
        self.shs = torch.cat([self.features_dc, self.features_rest], dim=1)
        self.opacity = self._frozen_param(cpu_float(opacities))
        self.scaling = self._frozen_param(cpu_float(scales))
        self.rotation = self._frozen_param(cpu_float(rots))
        self.offset = self._frozen_param(torch.zeros_like(self.xyz))
        self.albedo = self._frozen_param(torch.zeros_like(self.shs))
        self.lights = self._frozen_param(torch.zeros_like(self.shs))
        if sh2rgb:
            self.opacity = torch.sigmoid(self.opacity)
            self.scaling = trunc_exp(self.scaling)

        self.active_sh_degree = self.sh_degree