from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6  # per-point input features (e.g. xyz + rgb)
        out_channels: Optional[int] = 1024  # per-point token dimension
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        # Pass a copy of this module's config on to the transformer, but
        # override in_channels: the transformer operates at a width of
        # num_attention_heads * attention_head_dim.
        transformer_cfg = dict(self.cfg.copy())
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
        # Project raw point features up to the transformer width ...
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        # ... and project the transformer output down to the token dimension.
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
        inputs = self.linear_in(points).permute(0, 2, 1)  # B N Ci -> B Ci N
        out = self.model(inputs).permute(0, 2, 1)  # B Ci N -> B N Ci
        out = self.linear_out(out)  # B N Ci -> B N Co
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
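

# Usage sketch (illustrative, not part of the original file). It assumes the
# threestudio-style BaseModule interface, where the constructor accepts a plain
# dict, parses it into Config, and then calls configure(); check
# spar3d.models.utils.BaseModule for the actual contract. The deliberately
# small config below is only meant for a quick shape check.
if __name__ == "__main__":
    tokenizer = TransformerPointTokenizer(
        {
            "num_attention_heads": 4,
            "attention_head_dim": 16,
            "in_channels": 6,  # assumed per-point layout: xyz + rgb
            "out_channels": 64,
            "num_layers": 2,
        }
    )
    points = torch.rand(2, 512, 6)  # batch of 2 clouds, 512 points each
    tokens = tokenizer(points)  # expected shape: (2, 512, 64)
    print(tokens.shape)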