from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        transformer_cfg = dict(self.cfg.copy())
        # the transformer runs at width num_attention_heads * attention_head_dim,
        # so override in_channels before constructing it
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
        # project raw per-point features (Ci channels) up to the transformer width
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        # project back down to the requested token dimension
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
        inputs = self.linear_in(points).permute(0, 2, 1)  # B N Ci -> B Ci N
        out = self.model(inputs).permute(0, 2, 1)  # B Ci N -> B N Ci
        out = self.linear_out(out)  # B N Ci -> B N Co
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
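

# Usage sketch (not part of the module): a minimal, hedged example of how the
# tokenizer maps a point cloud to tokens. It assumes BaseModule is an nn.Module
# whose constructor accepts a config dict and calls configure(); that signature
# and the dict below are illustrative assumptions, not the confirmed spar3d API.
if __name__ == "__main__":
    # assumed BaseModule(cfg_dict) construction; defaults give out_channels=1024
    tokenizer = TransformerPointTokenizer({"in_channels": 6, "out_channels": 1024})
    # B=2 clouds of N=512 points with Ci=6 features each (e.g. xyz + rgb)
    points = torch.rand(2, 512, 6)
    tokens = tokenizer(points)
    print(tokens.shape)  # expected: torch.Size([2, 512, 1024]), i.e. B N Cp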