# NOTE: removed non-code scrape residue ("Spaces: Running on Zero" page chrome)
# Temporal point-cloud model components: positional encodings and a base model.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops_exts import check_shape, rearrange_many
from torch import Size, Tensor
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal (transformer-style) embedding for scalar timesteps.

    Maps a 1-D tensor of scalars ``x`` of shape (B,) to embeddings of shape
    (B, dim): the first ``dim // 2`` channels are sines and the remaining
    channels are cosines over a geometric frequency ladder with base 10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Frequencies decay geometrically: exp(-log(10000) * k / (half - 1)).
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to its Fourier-feature (positional encoding) representation.

    Concatenates the raw input with sin/cos of the input scaled by each
    frequency, in the order [v, sin(f1*v), cos(f1*v), sin(f2*v), ...].

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )

    Returns:
        Tensor: fourier features (B, IFeatures * (1 + 2 * N_freqs))
    """
    features = [v]
    features.extend(
        trig(freq * v) for freq in freq_bands for trig in (torch.sin, torch.cos)
    )
    return torch.cat(features, dim=-1)
class FeatureMapping(nn.Module):
    """Abstract base for feature mappings v -> phi(v).

    Subclasses define the transformation phi; this base only records the
    input/output dimensionality.

    Arguments:
        i_dim (int): input dimensions
        o_dim (int): output dimensions
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim = i_dim
        self.o_dim = o_dim

    def forward(self, v: Tensor) -> Tensor:
        """Map input features (B, i_dim) to (B, o_dim).

        Must be overridden by subclasses; the base class always raises.
        """
        raise NotImplementedError("Forward pass not implemented yet!")
class PositionalEncoding(FeatureMapping):
    """NeRF-style positional encoding phi(v) over sampled frequency bands.

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): number of frequencies to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        super().__init__(i_dim, 3 + (2 * N_freqs) * 3)
        self.N_freqs = N_freqs
        # NOTE(review): exponents are sampled over [1, N_freqs - 1] at
        # N_freqs points, which produces non-integer exponents; the classic
        # NeRF ladder is 2 ** linspace(0, N_freqs - 1, N_freqs). Kept as-is
        # to preserve behavior — confirm this range is intentional.
        lo, hi = 1, self.N_freqs - 1
        bands = 2 ** torch.linspace(lo, hi, self.N_freqs)
        # Buffer (not a Parameter): moves with .to()/device but is not trained.
        self.register_buffer("freq_bands", bands)

    def forward(self, v: Tensor) -> Tensor:
        """Return Fourier features of v: (B, 3 + (2 * N_freqs) * 3)."""
        return map_positional_encoding(v, self.freq_bands)
class BaseTemperalPointModel(nn.Module):
    """A base class providing useful methods for point cloud processing."""

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 768,
        num_layers: int = 6
    ):
        super().__init__()

        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = 256
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Timestep embedding: sinusoidal features followed by a 2-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, self.timestep_embed_dim),
            nn.SiLU(),
            nn.Linear(self.timestep_embed_dim, self.timestep_embed_dim),
        )

        # Fourier features for xyz coordinates: 3 + (2 * 10) * 3 = 63 dims.
        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3

        # Project raw coords + their encodings into the transformer width.
        # NOTE(review): in_features covers only coords + encodings (3 + 63);
        # extra_feature_channels and the timestep embedding are NOT included
        # here — presumably subclasses fuse those separately; verify.
        self.input_projection = nn.Linear(
            in_features=3 + positional_encoding_d_out,
            out_features=self.dim,
        )  # b f p c

        # Transformer layers are supplied by subclasses.
        self.layers = self.get_layers()

        # Per-point output head.
        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        """Build and return the model's layer stack; subclasses must override."""
        raise NotImplementedError('This method should be implemented by subclasses')

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """Run the model on a batch of point clouds; subclasses must override."""
        raise NotImplementedError('This method should be implemented by subclasses')