File size: 5,781 Bytes
db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 a6bbecf db6a3b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from ...utils.random_utils import hammersley_sequence
from .base import SparseTransformerBase
from ...representations import Gaussian
class SLatGaussianDecoder(SparseTransformerBase):
    """Sparse-transformer decoder whose output is a list of ``Gaussian`` objects.

    The transformer trunk comes from ``SparseTransformerBase``. On top of it this
    class adds a zero-initialised sparse linear head whose channels are split,
    according to ``self.layout``, into the per-Gaussian attributes ``_xyz``,
    ``_features_dc``, ``_scaling``, ``_rotation`` and ``_opacity``. Every row of
    the sparse feature matrix yields ``representation_config["num_gaussians"]``
    Gaussians.

    Keys read from ``representation_config`` by this class:
    ``num_gaussians``, ``voxel_size``, ``lr`` (a mapping keyed by attribute
    name), ``perturb_offset``, ``3d_filter_kernel_size``, ``scaling_bias``,
    ``opacity_bias`` and ``scaling_activation``.
    """

    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal[
            "full", "shift_window", "shift_sequence", "shift_order", "swin"
        ] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: Optional[dict] = None,
    ):
        """Build the trunk, the output head and the offset-perturbation buffer.

        Args:
            resolution: Divisor used to turn integer coordinates into
                normalised positions in ``to_representation``.
            model_channels: Width of the transformer and input width of the
                output head.
            latent_channels: Forwarded to the base class as ``in_channels``.
            num_blocks, num_heads, num_head_channels, mlp_ratio, attn_mode,
                window_size, pe_mode, use_checkpoint, qk_rms_norm: Forwarded
                unchanged to ``SparseTransformerBase``.
            use_fp16: Forwarded to the base class; when true,
                ``convert_to_fp16()`` is also called after weight init.
            representation_config: Gaussian settings (see the class docstring
                for the keys that are read). Required despite the ``None``
                default.

        Raises:
            TypeError: If ``representation_config`` is ``None``.
        """
        # The layout and the perturbation buffer both index into the config, so
        # a missing config can never work. Fail here with a readable message
        # instead of a "'NoneType' object is not subscriptable" further down.
        if representation_config is None:
            raise TypeError(
                "representation_config must be a dict, got None"
            )

        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config

        # Order matters: the layout defines ``self.out_channels``, which sizes
        # the head, and the head must exist before weights are initialised.
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
        self._build_perturbation()

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        """Run the base-class initialisation, then zero the output head."""
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _build_perturbation(self) -> None:
        """Register the ``offset_perturbation`` buffer of per-Gaussian offsets.

        One 3-D point per Gaussian is taken from ``hammersley_sequence``,
        mapped through ``p * 2 - 1``, divided by ``voxel_size`` and passed
        through ``atanh``. ``to_representation`` later applies ``tanh`` and
        multiplies by ``voxel_size``, so with a zero predicted offset (the
        state right after ``initialize_weights``) those two steps undo the
        ``atanh`` and the division here.
        """
        num_gaussians = self.rep_config["num_gaussians"]
        # presumably each sample lies in [0, 1), making ``* 2 - 1`` a map to
        # [-1, 1) -- confirm against ``hammersley_sequence``.
        samples = [
            hammersley_sequence(3, index, num_gaussians)
            for index in range(num_gaussians)
        ]
        perturbation = torch.tensor(samples).float() * 2 - 1
        perturbation = perturbation / self.rep_config["voxel_size"]
        # NOTE(review): atanh is only finite for inputs strictly inside
        # (-1, 1); that depends on the sample range and on ``voxel_size``.
        perturbation = torch.atanh(perturbation).to(self.device)
        self.register_buffer("offset_perturbation", perturbation)

    def _calc_layout(self) -> None:
        """Compute the channel layout of the output head.

        Sets ``self.layout``, an ordered mapping from attribute name to a dict
        with ``shape`` (``(num_gaussians, *dims)``), ``size`` (number of
        channels) and ``range`` (half-open ``(start, stop)`` channel interval),
        and sets ``self.out_channels`` to the total channel count.
        """
        num_gaussians = self.rep_config["num_gaussians"]
        # Trailing per-Gaussian dimensions of every attribute, in channel order.
        per_gaussian_dims = {
            "_xyz": (3,),
            "_features_dc": (1, 3),
            "_scaling": (3,),
            "_rotation": (4,),
            "_opacity": (1,),
        }

        self.layout = {}
        start = 0
        for name, dims in per_gaussian_dims.items():
            # Derive the channel count from the shape so the two cannot drift.
            size = num_gaussians
            for dim in dims:
                size *= dim
            self.layout[name] = {
                "shape": (num_gaussians, *dims),
                "size": size,
                "range": (start, start + size),
            }
            start += size
        self.out_channels = start

    def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations, one ``Gaussian`` per batch element, in
            batch order.
        """
        ret = []
        for i in range(x.shape[0]):
            representation = Gaussian(
                sh_degree=0,
                # presumably [origin(3), size(3)] -- confirm against Gaussian.
                aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
                # NOTE(review): the "mininum" spelling is kept on purpose; it
                # has to match the keyword accepted by Gaussian -- confirm.
                mininum_kernel_size=self.rep_config["3d_filter_kernel_size"],
                scaling_bias=self.rep_config["scaling_bias"],
                opacity_bias=self.rep_config["opacity_bias"],
                scaling_activation=self.rep_config["scaling_activation"],
            )

            # Slice this sample out of the batch once instead of once per
            # attribute.
            sample = x.layout[i]
            sample_coords = x.coords[sample]
            sample_feats = x.feats[sample]

            # Column 0 of the coordinates is skipped (presumably the batch
            # index); the rest are shifted by half a cell and normalised.
            centers = (sample_coords[:, 1:].float() + 0.5) / self.resolution

            for k, v in self.layout.items():
                lo, hi = v["range"]
                values = sample_feats[:, lo:hi].reshape(-1, *v["shape"])
                if k == "_xyz":
                    offset = values * self.rep_config["lr"][k]
                    if self.rep_config["perturb_offset"]:
                        offset = offset + self.offset_perturbation
                    # tanh bounds the offset; it is then scaled to half of
                    # ``voxel_size`` in units of 1 / resolution.
                    offset = (
                        torch.tanh(offset)
                        / self.resolution
                        * 0.5
                        * self.rep_config["voxel_size"]
                    )
                    _xyz = centers.unsqueeze(1) + offset
                    setattr(representation, k, _xyz.flatten(0, 1))
                else:
                    feats = values.flatten(0, 1)
                    feats = feats * self.rep_config["lr"][k]
                    setattr(representation, k, feats)
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
        """Decode a sparse latent tensor into one ``Gaussian`` per batch element.

        Runs the base-class forward pass, casts the result to the dtype of
        ``x``, applies a layer norm over the last feature dimension (called
        without weight or bias), projects through the output head and converts
        the result with ``to_representation``.
        """
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)
|