# KDTalker/model/temporaltrans/transformer_utils.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Size, Tensor
from einops import rearrange
from einops_exts import check_shape, rearrange_many


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding for scalar inputs (e.g. diffusion timesteps)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
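
# Usage sketch (illustrative, not part of the original file): `x` is a 1-D
# tensor of timesteps of shape (B,) and the embedding has shape (B, dim).
#   t = torch.arange(8, dtype=torch.float32)
#   emb = SinusoidalPosEmb(dim=256)(t)   # emb.shape == (8, 256)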


def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to its positional encoding representation phi(v).

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )

    Returns:
        phi(v) (Tensor): Fourier features (B, IFeatures * (1 + 2 * N_freqs))
    """
    pe = [v]
    for freq in freq_bands:
        fv = freq * v
        pe += [torch.sin(fv), torch.cos(fv)]
    return torch.cat(pe, dim=-1)
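
# Example (illustrative): with two frequency bands, a 3-feature input becomes
# 3 * (1 + 2 * 2) = 15 features, ordered [v, sin(f1*v), cos(f1*v), sin(f2*v), cos(f2*v)].
#   v = torch.rand(4, 3)
#   bands = 2 ** torch.arange(2, dtype=torch.float32)   # tensor([1., 2.])
#   out = map_positional_encoding(v, bands)             # out.shape == (4, 15)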


class FeatureMapping(nn.Module):
    """FeatureMapping nn.Module.

    Maps v to features following the transformation phi(v).

    Arguments:
        i_dim (int): input dimension
        o_dim (int): output dimension
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim = i_dim
        self.o_dim = o_dim

    def forward(self, v: Tensor) -> Tensor:
        """FeatureMapping forward pass.

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): mapped features (B, OFeatures)
        """
        raise NotImplementedError("Forward pass not implemented yet!")


class PositionalEncoding(FeatureMapping):
    """PositionalEncoding module.

    Maps v to its positional encoding representation phi(v).

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): number of frequency bands to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        # Output dimension assumes 3-D (xyz) inputs, i.e. i_dim == 3.
        super().__init__(i_dim, 3 + (2 * N_freqs) * 3)
        self.N_freqs = N_freqs

        # Geometrically spaced frequency bands 2**a .. 2**b.
        a, b = 1, self.N_freqs - 1
        freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
        self.register_buffer("freq_bands", freq_bands)

    def forward(self, v: Tensor) -> Tensor:
        """Map v to its positional encoding representation phi(v).

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): Fourier features (B, 3 + (2 * N_freqs) * 3)
        """
        return map_positional_encoding(v, self.freq_bands)
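
# Usage sketch (illustrative, not part of the original file): with the defaults
# used below (i_dim=3, N_freqs=10), 3-D coordinates map to 3 + 2 * 10 * 3 = 63 features.
#   pe = PositionalEncoding(i_dim=3, N_freqs=10)
#   xyz = torch.rand(4, 3)
#   feats = pe(xyz)        # feats.shape == (4, 63); pe.o_dim == 63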


class BaseTemperalPointModel(nn.Module):
    """A base class providing useful methods for point cloud processing."""

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 768,
        num_layers: int = 6
    ):
        super().__init__()

        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = 256
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Timestep embedding: sinusoidal features followed by a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, self.timestep_embed_dim),
            nn.SiLU(),
            nn.Linear(self.timestep_embed_dim, self.timestep_embed_dim)
        )

        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3

        # Input projection: raw point coordinates concatenated with their
        # positional encodings, shaped (B, F, P, C).
        self.input_projection = nn.Linear(
            in_features=(3 + positional_encoding_d_out),
            out_features=self.dim
        )

        # Transformer layers (defined by subclasses).
        self.layers = self.get_layers()

        # Output projection
        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        raise NotImplementedError('This method should be implemented by subclasses')

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        raise NotImplementedError('This method should be implemented by subclasses')
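

# What follows is a hedged, illustrative sketch (not part of the original
# KDTalker code): a minimal subclass showing how the base class is meant to be
# specialised. The layer type, head count, and the way the timestep embedding
# is injected are assumptions chosen for the example, and the class name
# `_ExampleTemporalPointModel` is hypothetical.
class _ExampleTemporalPointModel(BaseTemperalPointModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Project the 256-d timestep embedding up to the token width so it can
        # be added to every point token (one plausible conditioning scheme).
        self.time_proj = nn.Linear(self.timestep_embed_dim, self.dim)

    def get_layers(self):
        # A plain stack of Transformer encoder layers.
        return nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.dim,
                nhead=8,
                dim_feedforward=4 * self.dim,
                batch_first=True,
            )
            for _ in range(self.num_layers)
        ])

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        # inputs: (B, P, 3) point coordinates, t: (B,) diffusion timesteps.
        x = torch.cat([inputs, self.positional_encoding(inputs)], dim=-1)  # (B, P, 66)
        x = self.input_projection(x)                                       # (B, P, dim)
        x = x + self.time_proj(self.time_mlp(t))[:, None, :]               # timestep conditioning
        for layer in self.layers:
            x = layer(x)
        return self.output_projection(x)                                   # (B, P, num_classes)

# Example usage of the sketch above (shapes only):
#   model = _ExampleTemporalPointModel(num_classes=70, embed_dim=768, extra_feature_channels=0)
#   out = model(torch.rand(2, 100, 3), torch.rand(2))   # out.shape == (2, 100, 70)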