import math

import torch
import torch.nn as nn
from torch import Tensor

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: (B,) timesteps -> (B, dim) sinusoidal embedding
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
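
# Usage sketch for SinusoidalPosEmb (illustrative only; the tensor names and
# values below are assumptions, not part of this module):
#
#   t = torch.randint(0, 1000, (8,))     # (B,) diffusion timesteps
#   emb = SinusoidalPosEmb(dim=768)(t)   # -> (B, 768) embedding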


def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to positional encoding representation phi(v)

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )

    Returns:
        phi(v) (Tensor): Fourier features (B, 3 + (2 * N_freqs) * 3)
    """
    pe = [v]
    for freq in freq_bands:
        fv = freq * v
        pe += [torch.sin(fv), torch.cos(fv)]
    return torch.cat(pe, dim=-1)
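
# Worked size example for map_positional_encoding (illustrative; assumes 3-D
# coordinates and the default N_freqs = 10 used by PositionalEncoding below):
#
#   v: (B, 3), freq_bands: (10,)
#   phi(v): (B, 3 + 2 * 10 * 3) == (B, 63)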

class FeatureMapping(nn.Module):
    """FeatureMapping nn.Module

    Maps v to features following transformation phi(v)

    Arguments:
        i_dim (int): input dimensions
        o_dim (int): output dimensions
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim = i_dim
        self.o_dim = o_dim

    def forward(self, v: Tensor) -> Tensor:
        """FeratureMapping forward pass

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): mapped features (B, OFeatures)
        """
        raise NotImplementedError("Forward pass not implemented yet!")

class PositionalEncoding(FeatureMapping):
    """PositionalEncoding module

    Maps v to positional encoding representation phi(v)

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): number of frequencies to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        super().__init__(i_dim, 3 + (2 * N_freqs) * 3)  # output dim assumes 3-D input coordinates
        self.N_freqs = N_freqs

        # N_freqs log-spaced frequency bands spanning 2^1 ... 2^(N_freqs - 1)
        a, b = 1, self.N_freqs - 1
        freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
        self.register_buffer("freq_bands", freq_bands)

    def forward(self, v: Tensor) -> Tensor:
        """Map v to positional encoding representation phi(v)

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): Fourier features (B, 3 + (2 * N_freqs) * 3)
        """
        return map_positional_encoding(v, self.freq_bands)
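
# Usage sketch for PositionalEncoding (illustrative only; `xyz` is an assumed
# name for a batch of 3-D point coordinates):
#
#   pe = PositionalEncoding(i_dim=3, N_freqs=10)
#   xyz = torch.randn(8, 1024, 3)   # (B, P, 3)
#   feats = pe(xyz)                 # (B, P, 63)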

class BaseTemperalPointModel(nn.Module):
    """ A base class providing useful methods for point cloud processing. """

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 768,
        num_layers: int = 6
    ):
        super().__init__()

        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = 256
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Timestep embedding: sinusoidal encoding followed by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, self.timestep_embed_dim),
            nn.SiLU(),
            nn.Linear(self.timestep_embed_dim, self.timestep_embed_dim),
        )

        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3  # 63 features per point

        # Input projection: raw point coordinates concatenated with their
        # positional encodings, mapped to the model dimension.
        # Operates on tensors laid out as (b, f, p, c).
        self.input_projection = nn.Linear(
            in_features=(3 + positional_encoding_d_out),
            out_features=self.dim,
        )

        # Transformer layers
        self.layers = self.get_layers()

        # Output projection
        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        raise NotImplementedError('This method should be implemented by subclasses')

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        raise NotImplementedError('This method should be implemented by subclasses')
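

# ---------------------------------------------------------------------------
# Minimal usage sketch. The toy subclass below is a placeholder written only to
# show how get_layers / forward are expected to be filled in; it is NOT the
# actual model defined elsewhere in this code base, and all names, shapes, and
# hyperparameters in it are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyTemporalPointModel(BaseTemperalPointModel):
        """Toy subclass: a stack of per-point MLP blocks as placeholder layers."""

        def get_layers(self):
            return nn.ModuleList(
                [nn.Sequential(nn.Linear(self.dim, self.dim), nn.SiLU())
                 for _ in range(self.num_layers)]
            )

        def forward(self, inputs: torch.Tensor, t: torch.Tensor):
            # inputs: (B, P, 3) point coordinates, t: (B,) diffusion timesteps
            t_emb = self.time_mlp(t)  # (B, 256); a real subclass would inject this into each layer
            x = torch.cat([inputs, self.positional_encoding(inputs)], dim=-1)  # (B, P, 66)
            x = self.input_projection(x)                                       # (B, P, dim)
            for layer in self.layers:
                x = layer(x)
            return self.output_projection(x)                                   # (B, P, num_classes)

    model = _ToyTemporalPointModel(num_classes=3, embed_dim=64, extra_feature_channels=0)
    out = model(torch.randn(2, 128, 3), torch.randint(0, 1000, (2,)))
    print(out.shape)  # torch.Size([2, 128, 3])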