|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from einops import rearrange |
|
import math |
|
from einops_exts import check_shape, rearrange_many |
|
from torch import Size, Tensor, nn |
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal positional embedding.

    Maps a 1-D tensor of positions/timesteps (B,) to embeddings of shape
    (B, 2 * (dim // 2)): the first half are sine terms, the second half
    cosine terms, over a geometric progression of frequencies.

    Arguments:
        dim (int): embedding dimension (assumed even; an odd dim yields
            2 * (dim // 2) output features, not dim)
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
|
def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to its Fourier-feature (positional encoding) representation.

    The raw input is kept and augmented with sin/cos of each
    frequency-scaled copy, concatenated along the last dimension.

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )

    Returns:
        phi(v) (Tensor): fourier features (B, IFeatures * (1 + 2 * N_freqs))
    """
    encoded = [v]
    for band in freq_bands:
        scaled = band * v
        encoded.append(torch.sin(scaled))
        encoded.append(torch.cos(scaled))
    return torch.cat(encoded, dim=-1)
|
|
|
class FeatureMapping(nn.Module):
    """Abstract base class for feature mappings phi(v).

    Subclasses implement ``forward`` to transform input features of
    dimension ``i_dim`` into mapped features of dimension ``o_dim``.

    Arguments:
        i_dim (int): input dimensions
        o_dim (int): output dimensions
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim, self.o_dim = i_dim, o_dim

    def forward(self, v: Tensor) -> Tensor:
        """Map input features v (B, IFeatures) to phi(v) (B, OFeatures).

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Forward pass not implemented yet!")
|
|
|
class PositionalEncoding(FeatureMapping):
    """PositionalEncoding module

    Maps v to positional encoding representation phi(v)

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): #frequency to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        # Each input feature is kept and augmented with sin/cos at N_freqs
        # frequencies, so o_dim = i_dim * (1 + 2 * N_freqs).
        # Fix: this was previously hard-coded as 3 + (2 * N_freqs) * 3,
        # which mis-reported o_dim whenever i_dim != 3 (identical for the
        # default i_dim == 3 usage).
        super().__init__(i_dim, i_dim * (1 + 2 * N_freqs))
        self.N_freqs = N_freqs

        # Frequency bands 2^a .. 2^b sampled at N_freqs points.
        # NOTE(review): NeRF-style encodings usually span exponents
        # 0 .. N_freqs - 1; here the range starts at 1, giving non-integer
        # exponent steps — confirm this is intentional before changing.
        a, b = 1, self.N_freqs - 1
        freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
        # Buffer (not a Parameter): moves with .to(device) but is not trained.
        self.register_buffer("freq_bands", freq_bands)

    def forward(self, v: Tensor) -> Tensor:
        """Map v to positional encoding representation phi(v)

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): fourier features (B, i_dim * (1 + 2 * N_freqs))
        """
        return map_positional_encoding(v, self.freq_bands)
|
|
|
class BaseTemperalPointModel(nn.Module):
    """A base class providing useful methods for point cloud processing."""

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 768,
        num_layers: int = 6
    ):
        super().__init__()

        # NOTE(review): embed_dim is accepted but never read here —
        # presumably used by subclasses; confirm before removing.
        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = 256
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Embeds a diffusion timestep t into a (B, timestep_embed_dim)
        # conditioning vector: sinusoidal features followed by a 2-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, self.timestep_embed_dim),
            nn.SiLU(),
            nn.Linear(self.timestep_embed_dim, self.timestep_embed_dim),
        )

        # Fourier features for 3-D coordinates: 3 + (2 * 10) * 3 = 63 dims.
        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3

        # Projects raw xyz (3) concatenated with its positional encoding
        # into the model width.
        self.input_projection = nn.Linear(
            in_features=3 + positional_encoding_d_out,
            out_features=self.dim,
        )

        # Trunk comes from the subclass; raises NotImplementedError here,
        # so this base class cannot be instantiated directly.
        self.layers = self.get_layers()

        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        """Build and return the model trunk; subclasses must override."""
        raise NotImplementedError('This method should be implemented by subclasses')

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """Run the model on inputs at timestep t; subclasses must override."""
        raise NotImplementedError('This method should be implemented by subclasses')
|
|