import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from einops import rearrange
import math
from einops_exts import check_shape, rearrange_many
from torch import Size, Tensor, nn
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal position/timestep embedding (Transformer-style).

    Maps each scalar in the input to a ``dim``-dimensional vector whose
    first half is sines and second half cosines, evaluated at
    geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # output embedding width; split evenly into sin/cos halves

    def forward(self, x):
        """Embed scalar positions.

        Arguments:
            x (Tensor): positions/timesteps, shape (B,)
        Returns:
            Tensor: embeddings, shape (B, dim)
        """
        half = self.dim // 2
        # Frequencies exp(-log(10000) * i / (half - 1)) for i = 0..half-1,
        # i.e. a geometric ladder from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to its positional-encoding (Fourier feature) representation.

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )
    Returns:
        phi(v) (Tensor): the concatenation
            [v, sin(f1*v), cos(f1*v), ..., sin(fN*v), cos(fN*v)]
            along the last dim, shape (B, IFeatures * (1 + 2 * N_freqs))
    """
    # Start with the raw features, then append a sin/cos pair per band.
    parts = [v]
    for band in freq_bands:
        scaled = band * v
        parts.append(torch.sin(scaled))
        parts.append(torch.cos(scaled))
    return torch.cat(parts, dim=-1)
class FeatureMapping(nn.Module):
    """Abstract base for feature mappings v -> phi(v).

    Subclasses implement :meth:`forward` with the concrete transformation.

    Arguments:
        i_dim (int): input dimensions
        o_dim (int): output dimensions
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim = i_dim
        self.o_dim = o_dim

    def forward(self, v: Tensor) -> Tensor:
        """FeatureMapping forward pass.

        Arguments:
            v (Tensor): input features (B, IFeatures)
        Returns:
            phi(v) (Tensor): mapped features (B, OFeatures)
        """
        raise NotImplementedError("Forward pass not implemented yet!")
class PositionalEncoding(FeatureMapping):
    """Positional-encoding feature mapping (NeRF-style Fourier features).

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): #frequency to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        # NOTE(review): o_dim hard-codes 3 input channels, so it is only
        # consistent with i_dim when i_dim == 3 — confirm for other i_dim.
        super().__init__(i_dim, 3 + (2 * N_freqs) * 3)
        self.N_freqs = N_freqs
        # N_freqs frequency bands 2^e with e sampled linearly in
        # [1, N_freqs - 1]; the exponents are generally non-integer.
        # NOTE(review): the classic NeRF encoding uses exponents
        # 0..N_freqs-1 — starting at 1 here may be deliberate; confirm.
        lo, hi = 1, self.N_freqs - 1
        bands = 2 ** torch.linspace(lo, hi, self.N_freqs)
        # Registered as a buffer: moves with .to(device), not trained.
        self.register_buffer("freq_bands", bands)

    def forward(self, v: Tensor) -> Tensor:
        """Map v to positional encoding representation phi(v).

        Arguments:
            v (Tensor): input features (B, IFeatures)
        Returns:
            phi(v) (Tensor): fourier features (B, 3 + (2 * N_freqs) * 3)
        """
        return map_positional_encoding(v, self.freq_bands)
class BaseTemperalPointModel(nn.Module):
    """A base class providing useful methods for point cloud processing.

    Builds the shared pieces (timestep MLP, positional encoding, input and
    output projections); subclasses supply the layer stack via
    :meth:`get_layers` and the computation via :meth:`forward`.

    Arguments:
        num_classes: number of output channels per point
        embed_dim: accepted but not stored here; presumably used by
            subclasses — confirm
        extra_feature_channels: per-point feature channels beyond xyz
        dim (int): model width (default: 768)
        num_layers (int): number of layers (default: 6)
    """

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 768,
        num_layers: int = 6,
    ):
        super().__init__()
        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = 256
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Timestep embedding: sinusoidal features then a two-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, self.timestep_embed_dim),
            nn.SiLU(),
            nn.Linear(self.timestep_embed_dim, self.timestep_embed_dim),
        )

        # Fourier features for the xyz coordinates.
        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3  # = 63

        # Input projection (point coords, point coord encodings, other
        # features, and timestep embeddings).
        # NOTE(review): in_features only covers coords + their encoding
        # (3 + 63); extra_feature_channels and the timestep embedding are
        # not part of this width — presumably concatenated/added by
        # subclasses. Confirm against subclass forward().
        self.input_projection = nn.Linear(
            in_features=3 + positional_encoding_d_out,
            out_features=self.dim,
        )  # b f p c

        # Transformer layers, provided by the subclass hook.
        self.layers = self.get_layers()

        # Output projection
        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        """Subclass hook: build and return the model's layer stack."""
        raise NotImplementedError('This method should be implemented by subclasses')

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """Subclass hook: run the model on `inputs` at timestep(s) `t`."""
        raise NotImplementedError('This method should be implemented by subclasses')