Spaces: Running on Zero
Upload 20 files
- difpoint/model/__init__.py +6 -0
- difpoint/model/__pycache__/__init__.cpython-310.pyc +0 -0
- difpoint/model/__pycache__/__init__.cpython-38.pyc +0 -0
- difpoint/model/__pycache__/model.cpython-310.pyc +0 -0
- difpoint/model/__pycache__/model.cpython-38.pyc +0 -0
- difpoint/model/__pycache__/model_utils.cpython-310.pyc +0 -0
- difpoint/model/__pycache__/model_utils.cpython-38.pyc +0 -0
- difpoint/model/__pycache__/point_model.cpython-310.pyc +0 -0
- difpoint/model/__pycache__/point_model.cpython-38.pyc +0 -0
- difpoint/model/model.py +409 -0
- difpoint/model/model_utils.py +35 -0
- difpoint/model/point_model.py +38 -0
- difpoint/model/temporaltrans/__pycache__/temptrans.cpython-310.pyc +0 -0
- difpoint/model/temporaltrans/__pycache__/temptrans.cpython-38.pyc +0 -0
- difpoint/model/temporaltrans/__pycache__/transformer_utils.cpython-310.pyc +0 -0
- difpoint/model/temporaltrans/__pycache__/transformer_utils.cpython-38.pyc +0 -0
- difpoint/model/temporaltrans/pointnet_util.py +311 -0
- difpoint/model/temporaltrans/pointtransformerv2.py +250 -0
- difpoint/model/temporaltrans/temptrans.py +347 -0
- difpoint/model/temporaltrans/transformer_utils.py +146 -0
difpoint/model/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .model import ConditionalPointCloudDiffusionModel


def get_model():
    model = ConditionalPointCloudDiffusionModel()
    return model
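A minimal usage sketch for the factory above (a hypothetical call site, not part of this commit; it assumes the difpoint package and its dependencies are importable):

from difpoint.model import get_model

model = get_model()            # ConditionalPointCloudDiffusionModel with the default beta schedule
print(type(model).__name__)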
difpoint/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (325 Bytes)

difpoint/model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (328 Bytes)

difpoint/model/__pycache__/model.cpython-310.pyc
ADDED
Binary file (5.33 kB)

difpoint/model/__pycache__/model.cpython-38.pyc
ADDED
Binary file (5.25 kB)

difpoint/model/__pycache__/model_utils.cpython-310.pyc
ADDED
Binary file (1.69 kB)

difpoint/model/__pycache__/model_utils.cpython-38.pyc
ADDED
Binary file (1.7 kB)

difpoint/model/__pycache__/point_model.cpython-310.pyc
ADDED
Binary file (1.78 kB)

difpoint/model/__pycache__/point_model.cpython-38.pyc
ADDED
Binary file (1.73 kB)
difpoint/model/model.py
ADDED
@@ -0,0 +1,409 @@
import inspect
from typing import Optional
from einops import rearrange
import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler

from torch import Tensor
from tqdm import tqdm
from diffusers import ModelMixin
from .model_utils import get_custom_betas
from .point_model import PointModel

import copy
class ConditionalPointCloudDiffusionModel(ModelMixin):
    def __init__(
        self,
        beta_start: float = 1e-5,
        beta_end: float = 8e-3,
        beta_schedule: str = 'linear',
        point_cloud_model: str = 'simple',
        point_cloud_model_embed_dim: int = 64,
    ):
        super().__init__()
        self.in_channels = 70  # 3 for 3D point positions
        self.out_channels = 70

        # Checks
        # Create diffusion model schedulers which define the sampling timesteps
        scheduler_kwargs = {}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddim']  # this can be changed for inference

        # Create point cloud model for processing point cloud at each diffusion step
        self.point_model = PointModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

    def forward_train(
        self,
        pc: Optional[Tensor],
        ref_kps: Optional[Tensor],
        ori_kps: Optional[Tensor],
        aud_feat: Optional[Tensor],
        mode: str = 'train',
        return_intermediate_steps: bool = False
    ):

        # Normalize colors and convert to tensor
        x_0 = pc
        B, Nf, Np, D = x_0.shape  # batch, nums of frames, nums of points, 3

        x_0 = x_0[:, :, :, 0]  # batch, nums of frames, 70

        # Sample random noise
        noise = torch.randn_like(x_0)

        # Sample random timesteps for each point_cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add noise to points
        x_t = self.scheduler.add_noise(x_0, noise, timestep)

        # Conditioning
        ref_kps = ref_kps[:, :, 0]

        x_t_input = torch.cat([ori_kps.unsqueeze(1), ref_kps.unsqueeze(1), x_t], dim=1)

        # x_t_input = torch.cat([ref_kps.unsqueeze(1), x_t], dim=1)

        # ori_kps_repeat = torch.repeat_interleave(ori_kps.unsqueeze(1), repeats=Nf+1, dim=1)

        # x_t_input = torch.cat([x_t_input, ori_kps_repeat], dim=-1)  # B, 32+1, 51+45

        aud_feat = torch.cat([torch.zeros(B, 2, 512).cuda(), aud_feat], 1)

        # Augmentation for audio feature
        if mode in 'train':
            if torch.rand(1) > 0.3:
                mean = torch.mean(aud_feat)
                std = torch.std(aud_feat)
                sample = torch.normal(mean=torch.full(aud_feat.shape, mean), std=torch.full(aud_feat.shape, std)).cuda()
                aud_feat = sample + aud_feat
            else:
                pass
        else:
            pass

        # Forward
        noise_pred = self.point_model(x_t_input, timestep, context=aud_feat)  # torch.cat([mel_feat,style_embed],-1))
        noise_pred = noise_pred[:, 2:]
        #
        # Check
        if not noise_pred.shape == noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')

        # Loss
        loss = F.mse_loss(noise_pred, noise)

        loss_pose = F.mse_loss(noise_pred[:, :, :6], noise[:, :, :6])
        loss_exp = F.mse_loss(noise_pred[:, :, 6:], noise[:, :, 6:])

        # Whether to return intermediate steps
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)

        return loss, loss_exp, loss_pose

    # def forward_train(
    #         self,
    #         pc: Optional[Tensor],
    #         ref_kps: Optional[Tensor],
    #         ori_kps: Optional[Tensor],
    #         aud_feat: Optional[Tensor],
    #         mode: str = 'train',
    #         return_intermediate_steps: bool = False
    # ):
    #
    #     # Normalize colors and convert to tensor
    #     x_0 = pc
    #     B, Nf, Np, D = x_0.shape  # batch, nums of frames, nums of points, 3
    #
    #     # ori_kps = torch.repeat_interleave(ori_kps.unsqueeze(1), Nf, dim=1)  # B, Nf, 45
    #     #
    #     # ref_kps = ref_kps[:, :, 0]
    #     # ref_kps = torch.repeat_interleave(ref_kps.unsqueeze(1), Nf, dim=1)  # B, Nf, 91
    #
    #     x_0 = x_0[:, :, :, 0]
    #
    #     # Sample random noise
    #     noise = torch.randn_like(x_0)
    #
    #     # Sample random timesteps for each point_cloud
    #     timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
    #                              device=self.device, dtype=torch.long)
    #
    #     # Add noise to points
    #     x_t = self.scheduler.add_noise(x_0, noise, timestep)
    #
    #     # Conditioning
    #     ref_kps = ref_kps[:, :, 0]
    #
    #     # x_t_input = torch.cat([ref_kps.unsqueeze(1), x_t], dim=1)
    #
    #     # x_0 = torch.cat([x_0, ref_kps, ori_kps], dim=2)  # B, Nf, 91+91+45
    #
    #     x_t_input = torch.cat([ref_kps.unsqueeze(1), x_t], dim=1)
    #     # x_t_input = torch.cat([ori_kps.unsqueeze(1), ref_kps.unsqueeze(1), x_t], dim=1)
    #
    #     aud_feat = torch.cat([torch.zeros(B, 1, 512).cuda(), aud_feat], 1)
    #
    #     # Augmentation for audio feature
    #     if mode in 'train':
    #         if torch.rand(1) > 0.3:
    #             mean = torch.mean(aud_feat)
    #             std = torch.std(aud_feat)
    #             sample = torch.normal(mean=torch.full(aud_feat.shape, mean), std=torch.full(aud_feat.shape, std)).cuda()
    #             aud_feat = sample + aud_feat
    #         else:
    #             pass
    #     else:
    #         pass
    #
    #     # Forward
    #     noise_pred = self.point_model(x_t_input, timestep, context=aud_feat)
    #     noise_pred = noise_pred[:, 1:]
    #
    #     # Check
    #     # if not noise_pred.shape == noise.shape:
    #     #     raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')
    #
    #     # Loss
    #     loss = F.mse_loss(noise_pred, noise)
    #
    #     # loss_kp = F.mse_loss(noise_pred[:, :, :45], noise[:, :, :45])
    #
    #     # Whether to return intermediate steps
    #     if return_intermediate_steps:
    #         return loss, (x_0, x_t, noise, noise_pred)
    #
    #     return loss

    # @torch.no_grad()
    # def forward_sample(
    #         self,
    #         num_points: int,
    #         ref_kps: Optional[Tensor],
    #         ori_kps: Optional[Tensor],
    #         aud_feat: Optional[Tensor],
    #         # Optional overrides
    #         scheduler: Optional[str] = 'ddpm',
    #         # Inference parameters
    #         num_inference_steps: Optional[int] = 1000,
    #         eta: Optional[float] = 0.0,  # for DDIM
    #         # Whether to return all the intermediate steps in generation
    #         return_sample_every_n_steps: int = -1,
    #         # Whether to disable tqdm
    #         disable_tqdm: bool = False,
    # ):
    #
    #     # Get scheduler from mapping, or use self.scheduler if None
    #     scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]
    #
    #     # Get the size of the noise
    #     Np = num_points
    #     Nf = aud_feat.size(1)
    #     B = 1
    #     D = 3
    #     device = self.device
    #
    #     # Sample noise
    #     x_t = torch.randn(B, Nf, Np, D, device=device)
    #
    #     x_t = x_t[:, :, :, 0]
    #
    #     # ori_kps = torch.repeat_interleave(ori_kps.unsqueeze(1), Nf, dim=1)  # B, Nf, 45
    #
    #     ref_kps = ref_kps[:, :, 0]
    #     # ref_kps = torch.repeat_interleave(ref_kps.unsqueeze(1), Nf, dim=1)  # B, Nf, 91
    #
    #     # Set timesteps
    #     accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    #     extra_set_kwargs = {"offset": 1} if accepts_offset else {}
    #     scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
    #
    #     # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    #     # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    #     # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    #     # and should be between [0, 1]
    #     accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
    #     extra_step_kwargs = {"eta": eta} if accepts_eta else {}
    #
    #     # Loop over timesteps
    #     all_outputs = []
    #     return_all_outputs = (return_sample_every_n_steps > 0)
    #     progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)
    #
    #     # ori_kps = torch.repeat_interleave(ori_kps[:, 6:].unsqueeze(1), Nf + 1, dim=1)
    #     aud_feat = torch.cat([torch.zeros(B, 1, 512).cuda(), aud_feat], 1)
    #     # aud_feat = torch.cat([ori_kps, aud_feat], -1)
    #
    #     # aud_feat = torch.cat([torch.zeros(B, 2, 512).cuda(), aud_feat], 1)
    #
    #     for i, t in enumerate(progress_bar):
    #
    #         # Conditioning
    #         x_t_input = torch.cat([ref_kps.unsqueeze(1).detach(), x_t], dim=1)
    #         # x_t_input = torch.cat([ori_kps.unsqueeze(1).detach(), ref_kps.unsqueeze(1).detach(), x_t], dim=1)
    #         # x_t_input = torch.cat([x_t, ref_kps, ori_kps], dim=2)  # B, Nf, 91+91+45
    #
    #         # Forward
    #         # noise_pred = self.point_model(x_t_input, t.reshape(1).expand(B), context=aud_feat)[:, 1:]
    #         noise_pred = self.point_model(x_t_input, t.reshape(1).expand(B), context=aud_feat)[:, 1:]
    #
    #         # noise_pred = noise_pred[:, :, :51]
    #
    #         # Step
    #         # x_t = x_t[:, :, :51]
    #         x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
    #
    #         # Append to output list if desired
    #         if (return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)):
    #             all_outputs.append(x_t)
    #
    #     # Convert output back into a point cloud, undoing normalization and scaling
    #     output = x_t
    #     output = torch.stack([output, output, output], -1)
    #     if return_all_outputs:
    #         all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
    #     return (output, all_outputs) if return_all_outputs else output

    @torch.no_grad()
    def forward_sample(
        self,
        num_points: int,
        ref_kps: Optional[Tensor],
        ori_kps: Optional[Tensor],
        aud_feat: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
    ):

        # Get scheduler from mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the size of the noise
        Np = num_points
        Nf = aud_feat.size(1)
        B = 1
        D = 3
        device = self.device

        # Sample noise
        x_t = torch.randn(B, Nf, Np, D, device=device)

        x_t = x_t[:, :, :, 0]

        ref_kps = ref_kps[:, :, 0]

        # Set timesteps
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)

        # ori_kps = torch.repeat_interleave(ori_kps[:, 6:].unsqueeze(1), Nf + 1, dim=1)
        # aud_feat = torch.cat([torch.zeros(B, 1, 512).cuda(), aud_feat], 1)
        # aud_feat = torch.cat([ori_kps, aud_feat], -1)

        aud_feat = torch.cat([torch.zeros(B, 2, 512).cuda(), aud_feat], 1)

        for i, t in enumerate(progress_bar):

            # Conditioning
            # x_t_input = torch.cat([ref_kps.unsqueeze(1), x_t], dim=1)
            #
            # ori_kps_repeat = torch.repeat_interleave(ori_kps.unsqueeze(1), repeats=Nf + 1, dim=1)
            #
            # x_t_input = torch.cat([x_t_input.detach(), ori_kps_repeat.detach()], dim=-1)  # B, 32+1, 51+45

            x_t_input = torch.cat([ori_kps.unsqueeze(1).detach(), ref_kps.unsqueeze(1).detach(), x_t], dim=1)
            # x_t_input = torch.cat([ref_kps.unsqueeze(1).detach(), x_t], dim=1)

            # Forward
            # noise_pred = self.point_model(x_t_input, t.reshape(1).expand(B), context=aud_feat)[:, 1:]
            noise_pred = self.point_model(x_t_input, t.reshape(1).expand(B), context=aud_feat)[:, 2:]

            # Step
            x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample

            # Append to output list if desired
            if (return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)):
                all_outputs.append(x_t)

        # Convert output back into a point cloud, undoing normalization and scaling
        output = x_t
        output = torch.stack([output, output, output], -1)
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
        return (output, all_outputs) if return_all_outputs else output

    def forward(self, batch: dict, mode: str = 'train', **kwargs):
        """A wrapper around the forward method for training and inference"""

        if mode == 'train':
            return self.forward_train(
                pc=batch['sequence_keypoints'],
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                mode='train',
                **kwargs)
        elif mode == 'val':
            return self.forward_train(
                pc=batch['sequence_keypoints'],
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                mode='val',
                **kwargs)
        elif mode == 'sample':
            num_points = 68
            return self.forward_sample(
                num_points=num_points,
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                **kwargs)
        else:
            raise NotImplementedError()
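A hedged sketch of how the wrapper above might be driven for training. The batch keys come from forward() in this file; the tensor shapes are assumptions inferred from the indexing in forward_train (70 keypoint channels, 512-dim audio features), and a CUDA device is required because of the hard-coded .cuda() calls:

import torch
from difpoint.model import get_model

model = get_model().cuda()

B, Nf = 2, 32                                                # assumed batch size and frame count
batch = {
    'sequence_keypoints': torch.randn(B, Nf, 70, 3).cuda(),  # x_0; only channel 0 is used
    'ref_keypoint': torch.randn(B, 70, 3).cuda(),            # reference keypoints, sliced to [:, :, 0]
    'ori_keypoint': torch.randn(B, 70).cuda(),               # original keypoints, prepended as one frame
    'aud_feat': torch.randn(B, Nf, 512).cuda(),              # per-frame audio features
}

loss, loss_exp, loss_pose = model(batch, mode='train')       # dispatches to forward_train
loss.backward()

For mode='sample', forward() fixes num_points at 68 and passes the same ref_keypoint, ori_keypoint and aud_feat entries to forward_sample, which returns the denoised keypoint sequence stacked into three identical channels.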
difpoint/model/model_utils.py
ADDED
@@ -0,0 +1,35 @@
import cv2
import numpy as np
import torch
import torch.nn as nn


def set_requires_grad(module: nn.Module, requires_grad: bool):
    for p in module.parameters():
        p.requires_grad_(requires_grad)


def compute_distance_transform(mask: torch.Tensor):
    image_size = mask.shape[-1]
    distance_transform = torch.stack([
        torch.from_numpy(cv2.distanceTransform(
            (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
        ) / (image_size / 2))
        for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8)
    ]).unsqueeze(1).clip(0, 1).to(mask.device)
    return distance_transform


def default(x, d):
    return d if x is None else x

def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
    """Custom beta schedule"""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    warmup_frac = 0.3
    warmup_time = int(num_train_timesteps * warmup_frac)
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    warmup_time = min(warmup_time, num_train_timesteps)
    betas[:warmup_time] = warmup_steps[:warmup_time]
    return betas
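As a point of reference, get_custom_betas is consumed through the diffusers trained_betas argument; this sketch simply mirrors the 'custom' branch of model.py above rather than adding new behaviour:

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from difpoint.model.model_utils import get_custom_betas

betas = get_custom_betas(beta_start=1e-5, beta_end=8e-3)           # length-1000 numpy array with a 30% warmup segment
scheduler = DDPMScheduler(trained_betas=betas, clip_sample=False)  # scheduler uses the custom schedule as-is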
difpoint/model/point_model.py
ADDED
@@ -0,0 +1,38 @@
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from torch import Tensor

from .temporaltrans.temptrans import SimpleTemperalPointModel, SimpleTransModel

class PointModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        model_type: str = 'pvcnn',
        in_channels: int = 3,
        out_channels: int = 3,
        embed_dim: int = 64,
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
    ):
        super().__init__()
        self.model_type = model_type
        if self.model_type == 'simple':
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            self.model = SimpleTransModel(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=(in_channels - 3),
            )
            self.model.output_projection.bias.data.normal_(0, 1e-6)
            self.model.output_projection.weight.data.normal_(0, 1e-6)
        else:
            raise NotImplementedError()

    def forward(self, inputs: Tensor, t: Tensor, context=None) -> Tensor:
        """ Receives input of shape (B, N, in_channels) and returns output
        of shape (B, N, out_channels) """
        with self.autocast_context:
            return self.model(inputs, t, context)
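A hedged instantiation sketch for PointModel (the shapes are assumptions based on the 70-channel keypoint layout used elsewhere in this commit; BaseTemperalPointModel from transformer_utils.py is assumed to supply dim, time_mlp and output_projection, and a CUDA device is needed for the torch.autocast('cuda', ...) context):

import torch
from difpoint.model.point_model import PointModel

net = PointModel(model_type='simple', in_channels=70, out_channels=70, embed_dim=64).cuda()
x = torch.randn(2, 34, 70).cuda()        # (B, frames, in_channels): 2 conditioning frames + 32 noisy frames
t = torch.randint(0, 1000, (2,)).cuda()  # one diffusion timestep per batch element
ctx = torch.randn(2, 34, 512).cuda()     # per-frame audio context, matching the frame count of x
eps = net(x, t, context=ctx)             # (B, frames, out_channels) noise prediction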
difpoint/model/temporaltrans/__pycache__/temptrans.cpython-310.pyc
ADDED
Binary file (11 kB)

difpoint/model/temporaltrans/__pycache__/temptrans.cpython-38.pyc
ADDED
Binary file (11.1 kB)

difpoint/model/temporaltrans/__pycache__/transformer_utils.cpython-310.pyc
ADDED
Binary file (5.09 kB)

difpoint/model/temporaltrans/__pycache__/transformer_utils.cpython-38.pyc
ADDED
Binary file (5.09 kB)
difpoint/model/temporaltrans/pointnet_util.py
ADDED
@@ -0,0 +1,311 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np


# reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You


def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()

def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)


def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    raw_size = idx.size()
    idx = idx.reshape(raw_size[0], -1)
    res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
    return res.reshape(*raw_size, -1)


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
    """
    Input:
        npoint:
        radius:
        nsample:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, nsample, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    torch.cuda.empty_cache()
    new_xyz = index_points(xyz, fps_idx)
    torch.cuda.empty_cache()
    if knn:
        dists = square_distance(new_xyz, xyz)  # B x npoint x N
        idx = dists.argsort()[:, :, :nsample]  # B x npoint x K
    else:
        idx = query_ball_point(radius, nsample, xyz, new_xyz)
    torch.cuda.empty_cache()
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    torch.cuda.empty_cache()

    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.knn = knn
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, N, C]
            points: input points data, [B, N, C]
        Return:
            new_xyz: sampled points position data, [B, S, C]
            new_points_concat: sample points feature data, [B, S, D']
        """
        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0].transpose(1, 2)
        return new_xyz, new_points


class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.knn = knn
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points, seed_idx=None):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx)
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            if self.knn:
                dists = square_distance(new_xyz, xyz)  # B x npoint x N
                group_idx = dists.argsort()[:, :, :K]  # B x npoint x K
            else:
                group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)
        return new_xyz, new_points_concat


# NoteL this function swaps N and C
class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
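A small self-contained sketch of the sampling and grouping utilities above on random data (the torch.cuda.empty_cache() calls inside sample_and_group are no-ops when CUDA is not initialised):

import torch
from difpoint.model.temporaltrans.pointnet_util import farthest_point_sample, sample_and_group

xyz = torch.rand(2, 1024, 3)               # (B, N, 3) point positions
fps_idx = farthest_point_sample(xyz, 128)  # (B, 128) indices of the FPS centroids
new_xyz, new_points = sample_and_group(
    npoint=128, radius=0.2, nsample=16, xyz=xyz, points=None)
# new_xyz: (2, 128, 3) centroids; new_points: (2, 128, 16, 3) per-group offsets from each centroid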
difpoint/model/temporaltrans/pointtransformerv2.py
ADDED
@@ -0,0 +1,250 @@
from .transformer_utils import BaseTemperalPointModel
from copy import deepcopy
import torch
import einops
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from einops import rearrange
import pointops
from pointcept.models.utils import offset2batch, batch2offset
class PointBatchNorm(nn.Module):
    """
    Batch Normalization for Point Clouds data in shape of [B*N, C], [B*N, L, C]
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() == 3:
            return (
                self.norm(input.transpose(1, 2).contiguous())
                .transpose(1, 2)
                .contiguous()
            )
        elif input.dim() == 2:
            return self.norm(input)
        else:
            raise NotImplementedError
#https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v2/point_transformer_v2m2_base.py
class GroupedVectorAttention(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        attn_drop_rate=0.0,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
    ):
        super(GroupedVectorAttention, self).__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )

        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        if self.pe_multiplier:
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        self.weight_encoding = nn.Sequential(
            nn.Linear(embed_channels, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups),
        )
        self.softmax = nn.Softmax(dim=1)
        self.attn_drop = nn.Dropout(attn_drop_rate)

    def forward(self, feat, coord, reference_index):
        query, key, value = (
            self.linear_q(feat),
            self.linear_k(feat),
            self.linear_v(feat),
        )
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)
        pos, key = key[:, :, 0:3], key[:, :, 3:]
        relation_qk = key - query.unsqueeze(1)
        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem
        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            value = value + peb

        weight = self.weight_encoding(relation_qk)
        weight = self.attn_drop(self.softmax(weight))

        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
        value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
        feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
        feat = einops.rearrange(feat, "n g i -> n (g i)")
        return feat

class BlockSequence(nn.Module):
    def __init__(
        self,
        depth,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(BlockSequence, self).__init__()

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, float):
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
        else:
            drop_path_rates = [0.0 for _ in range(depth)]

        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                embed_channels=embed_channels,
                groups=groups,
                qkv_bias=qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                enable_checkpoint=enable_checkpoint,
            )
            self.blocks.append(block)

    def forward(self, points):
        coord, feat, offset = points
        # reference index query of neighbourhood attention
        # for windows attention, modify reference index query method
        reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
        for block in self.blocks:
            points = block(points, reference_index)
        return points

class GVAPatchEmbed(nn.Module):
    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(GVAPatchEmbed, self).__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        self.proj = nn.Sequential(
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
        )

    def forward(self, points):
        coord, feat, offset = points
        feat = self.proj(feat)
        return self.blocks([coord, feat, offset])


class Block(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(Block, self).__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, points, reference_index):
        coord, feat, offset = points
        identity = feat
        feat = self.act(self.norm1(self.fc1(feat)))
        feat = (
            self.attn(feat, coord, reference_index)
            if not self.enable_checkpoint
            else checkpoint(self.attn, feat, coord, reference_index)
        )
        feat = self.act(self.norm2(feat))
        feat = self.norm3(self.fc3(feat))
        feat = identity + self.drop_path(feat)
        feat = self.act(feat)
        return [coord, feat, offset]
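Note that Block above references DropPath and checkpoint without importing them in this hunk; in the upstream Pointcept point_transformer_v2m2_base.py these come from timm and torch.utils.checkpoint, so a user of this module would presumably need imports along these lines:

from timm.models.layers import DropPath        # assumed source, following the upstream Pointcept code
from torch.utils.checkpoint import checkpoint  # assumed source for the enable_checkpoint branch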
difpoint/model/temporaltrans/temptrans.py
ADDED
@@ -0,0 +1,347 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
from einops import rearrange
|
6 |
+
from .transformer_utils import BaseTemperalPointModel
|
7 |
+
import math
|
8 |
+
from einops_exts import check_shape, rearrange_many
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
class SinusoidalPosEmb(nn.Module):
|
12 |
+
def __init__(self, dim):
|
13 |
+
super().__init__()
|
14 |
+
self.dim = dim
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
device = x.device
|
18 |
+
half_dim = self.dim // 2
|
19 |
+
emb = math.log(10000) / (half_dim - 1)
|
20 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
21 |
+
emb = x[:, None] * emb[None, :]
|
22 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
23 |
+
return emb
|
24 |
+
|
25 |
+
class RelativePositionBias(nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
heads = 8,
|
29 |
+
num_buckets = 32,
|
30 |
+
max_distance = 128
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.num_buckets = num_buckets
|
34 |
+
self.max_distance = max_distance
|
35 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
|
39 |
+
ret = 0
|
40 |
+
n = -relative_position
|
41 |
+
|
42 |
+
num_buckets //= 2
|
43 |
+
ret += (n < 0).long() * num_buckets
|
44 |
+
n = torch.abs(n)
|
45 |
+
|
46 |
+
max_exact = num_buckets // 2
|
47 |
+
is_small = n < max_exact
|
48 |
+
|
49 |
+
val_if_large = max_exact + (
|
50 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
51 |
+
).long()
|
52 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
53 |
+
|
54 |
+
ret += torch.where(is_small, n, val_if_large)
|
55 |
+
return ret
|
56 |
+
|
57 |
+
def forward(self, n, device):
|
58 |
+
q_pos = torch.arange(n, dtype = torch.long, device = device)
|
59 |
+
k_pos = torch.arange(n, dtype = torch.long, device = device)
|
60 |
+
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
61 |
+
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
62 |
+
values = self.relative_attention_bias(rp_bucket)
|
63 |
+
return rearrange(values, 'i j h -> h i j')
|
64 |
+
def exists(x):
|
65 |
+
return x is not None
|
66 |
+
|
67 |
+
class Residual(nn.Module):
|
68 |
+
def __init__(self, fn):
|
69 |
+
super().__init__()
|
70 |
+
self.fn = fn
|
71 |
+
|
72 |
+
def forward(self, x, *args, **kwargs):
|
73 |
+
return self.fn(x, *args, **kwargs) + x
|
74 |
+
class LayerNorm(nn.Module):
|
75 |
+
def __init__(self, dim, eps = 1e-5):
|
76 |
+
super().__init__()
|
77 |
+
self.eps = eps
|
78 |
+
self.gamma = nn.Parameter(torch.ones(1, 1, dim))
|
79 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
83 |
+
mean = torch.mean(x, dim = -1, keepdim = True)
|
84 |
+
return (x - mean) / (var + self.eps).sqrt() * self.gamma + self.beta
|
85 |
+
|
86 |
+
class PreNorm(nn.Module):
|
87 |
+
def __init__(self, dim, fn):
|
88 |
+
super().__init__()
|
89 |
+
self.fn = fn
|
90 |
+
self.norm = LayerNorm(dim)
|
91 |
+
|
92 |
+
def forward(self, x, **kwargs):
|
93 |
+
x = self.norm(x)
|
94 |
+
return self.fn(x, **kwargs)
|
95 |
+
|
96 |
+
|
97 |
+
class EinopsToAndFrom(nn.Module):
|
98 |
+
def __init__(self, from_einops, to_einops, fn):
|
99 |
+
super().__init__()
|
100 |
+
self.from_einops = from_einops
|
101 |
+
self.to_einops = to_einops
|
102 |
+
self.fn = fn
|
103 |
+
|
104 |
+
def forward(self, x, **kwargs):
|
105 |
+
shape = x.shape
|
106 |
+
reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
|
107 |
+
x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
|
108 |
+
x = self.fn(x, **kwargs)
|
109 |
+
x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
|
110 |
+
return x
|
111 |
+
|
112 |
+
class Attention(nn.Module):
|
113 |
+
def __init__(
|
114 |
+
self, dim, heads=4, attn_head_dim=None, casual_attn=False,rotary_emb = None):
|
115 |
+
super().__init__()
|
116 |
+
self.num_heads = heads
|
117 |
+
head_dim = dim // heads
|
118 |
+
self.casual_attn = casual_attn
|
119 |
+
|
120 |
+
if attn_head_dim is not None:
|
121 |
+
head_dim = attn_head_dim
|
122 |
+
|
123 |
+
all_head_dim = head_dim * self.num_heads
|
124 |
+
self.scale = head_dim ** -0.5
|
125 |
+
self.to_qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
126 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
127 |
+
self.rotary_emb = rotary_emb
|
128 |
+
|
129 |
+
def forward(self, x, pos_bias = None):
|
130 |
+
N, device = x.shape[-2], x.device
|
131 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
132 |
+
|
133 |
+
q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.num_heads)
|
134 |
+
|
135 |
+
q = q * self.scale
|
136 |
+
|
137 |
+
if exists(self.rotary_emb):
|
138 |
+
q = self.rotary_emb.rotate_queries_or_keys(q)
|
139 |
+
k = self.rotary_emb.rotate_queries_or_keys(k)
|
140 |
+
|
141 |
+
sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
|
142 |
+
|
143 |
+
if exists(pos_bias):
|
144 |
+
sim = sim + pos_bias
|
145 |
+
|
146 |
+
if self.casual_attn:
|
147 |
+
mask = torch.tril(torch.ones(sim.size(-1), sim.size(-2))).to(device)
|
148 |
+
sim = sim.masked_fill(mask[..., :, :] == 0, float('-inf'))
|
149 |
+
|
150 |
+
attn = sim.softmax(dim = -1)
|
151 |
+
x = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
|
152 |
+
x = rearrange(x, '... h n d -> ... n (h d)')
|
153 |
+
x = self.proj(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class Block(nn.Module):
|
158 |
+
def __init__(self, dim, dim_out):
|
159 |
+
super().__init__()
|
160 |
+
self.proj = nn.Linear(dim, dim_out)
|
161 |
+
self.norm = LayerNorm(dim)
|
162 |
+
self.act = nn.SiLU()
|
163 |
+
|
164 |
+
def forward(self, x, scale_shift=None):
|
165 |
+
x = self.proj(x)
|
166 |
+
|
167 |
+
if exists(scale_shift):
|
168 |
+
x = self.norm(x)
|
169 |
+
scale, shift = scale_shift
|
170 |
+
x = x * (scale + 1) + shift
|
171 |
+
return self.act(x)
|
172 |
+
|
173 |
+
|
174 |
+
class ResnetBlock(nn.Module):
|
175 |
+
def __init__(self, dim, dim_out, cond_dim=None):
|
176 |
+
super().__init__()
|
177 |
+
self.mlp = nn.Sequential(
|
178 |
+
nn.SiLU(),
|
179 |
+
nn.Linear(cond_dim, dim_out * 2)
|
180 |
+
) if exists(cond_dim) else None
|
181 |
+
|
182 |
+
self.block1 = Block(dim, dim_out)
|
183 |
+
self.block2 = Block(dim_out, dim_out)
|
184 |
+
|
185 |
+
def forward(self, x, cond_emb=None):
|
186 |
+
scale_shift = None
|
187 |
+
if exists(self.mlp):
|
188 |
+
assert exists(cond_emb), 'time emb must be passed in'
|
189 |
+
cond_emb = self.mlp(cond_emb)
|
190 |
+
#cond_emb = rearrange(cond_emb, 'b f c -> b f 1 c')
|
191 |
+
scale_shift = cond_emb.chunk(2, dim=-1)
|
192 |
+
|
193 |
+
h = self.block1(x, scale_shift=scale_shift)
|
194 |
+
h = self.block2(h)
|
195 |
+
return h + x
|
196 |
+
|
197 |
+
from rotary_embedding_torch import RotaryEmbedding


class SimpleTransModel(BaseTemperalPointModel):
    """
    A simple transformer-style model that processes a sequence of per-frame motion
    features (one 70-dimensional vector per frame) with conditional MLP blocks and
    temporal attention.
    """

    def get_layers(self):
        # self.input_projection = nn.Linear(
        #     in_features=51,
        #     out_features=self.dim
        # )

        self.input_projection = nn.Linear(
            in_features=70,
            out_features=self.dim
        )

        cond_dim = 512 + self.timestep_embed_dim

        num_head = self.dim // 64
        rotary_emb = RotaryEmbedding(min(32, num_head))

        self.time_rel_pos_bias = RelativePositionBias(heads=num_head, max_distance=128)  # realistically will not be able to generate that many frames of video... yet

        temporal_casual_attn = lambda dim: Attention(dim, heads=num_head, casual_attn=False, rotary_emb=rotary_emb)

        cond_block = partial(ResnetBlock, cond_dim=cond_dim)

        layers = nn.ModuleList([])

        for _ in range(self.num_layers):
            layers.append(nn.ModuleList([
                cond_block(self.dim, self.dim),
                cond_block(self.dim, self.dim),
                Residual(PreNorm(self.dim, temporal_casual_attn(self.dim)))
            ]))

        return layers

    def forward(self, inputs: torch.Tensor, timesteps: torch.Tensor, context=None):
        """
        Apply the model to an input batch.
        :param inputs: a [B x F x C] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning features concatenated to the timestep embedding.
        """
        # Prepare inputs
        batch, num_frames, channels = inputs.size()
        device = inputs.device
        # assert channels == 3

        # Positional encoding of point coords
        # inputs = rearrange(inputs, 'b f p c->(b f) p c')
        # pos_emb = self.positional_encoding(inputs)
        x = self.input_projection(inputs)
        # x = rearrange(x, '(b f) p c-> b f p c', b=batch)

        t_emb = self.time_mlp(timesteps) if exists(self.time_mlp) else None
        t_emb = t_emb[:, None, :].expand(-1, num_frames, -1)  # b f c
        if context is not None:
            t_emb = torch.cat([t_emb, context], -1)

        time_rel_pos_bias = self.time_rel_pos_bias(num_frames, device=device)

        for block1, block2, temporal_casual_attn in self.layers:
            x = block1(x, t_emb)
            x = block2(x, t_emb)
            x = temporal_casual_attn(x, pos_bias=time_rel_pos_bias)

        # Project
        x = self.output_projection(x)
        return x

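# A minimal forward-pass sketch (not part of the original file): the constructor
# arguments and tensor shapes below are assumptions for illustration; the 70-dim
# input and 512-dim context match the projections defined in get_layers above.
#
#   model = SimpleTransModel(num_classes=70, embed_dim=512, extra_feature_channels=0,
#                            dim=512, num_layers=6)
#   inputs = torch.randn(2, 25, 70)           # (batch, frames, motion features)
#   timesteps = torch.randint(0, 1000, (2,))
#   context = torch.randn(2, 25, 512)         # per-frame conditioning (e.g. audio features)
#   out = model(inputs, timesteps, context)   # -> (2, 25, 70)
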
class SimpleTemperalPointModel(BaseTemperalPointModel):
    """
    A simple model that processes a sequence of point sets by applying conditional
    MLP blocks to each point individually, together with spatial attention over
    points and temporal attention over frames.
    """

    def get_layers(self):
        audio_dim = 512

        cond_dim = audio_dim + self.timestep_embed_dim

        num_head = 4
        rotary_emb = RotaryEmbedding(min(32, num_head))
        self.time_rel_pos_bias = RelativePositionBias(heads=num_head, max_distance=128)  # realistically will not be able to generate that many frames of video... yet

        # Attention across frames (per point) and across points (per frame)
        temporal_casual_attn = lambda dim: EinopsToAndFrom('b f p c', 'b p f c', Attention(dim, heads=num_head, casual_attn=False, rotary_emb=rotary_emb))

        spatial_kp_attn = lambda dim: EinopsToAndFrom('b f p c', 'b f p c', Attention(dim, heads=num_head))

        cond_block = partial(ResnetBlock, cond_dim=cond_dim)

        layers = nn.ModuleList([])

        for _ in range(self.num_layers):
            layers.append(nn.ModuleList([
                cond_block(self.dim, self.dim),
                cond_block(self.dim, self.dim),
                Residual(PreNorm(self.dim, spatial_kp_attn(self.dim))),
                Residual(PreNorm(self.dim, temporal_casual_attn(self.dim)))
            ]))

        return layers

    def forward(self, inputs: torch.Tensor, timesteps: torch.Tensor, context=None):
        """
        Apply the model to an input batch.
        :param inputs: a [B x F x P x C] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning features concatenated to the timestep embedding.
        """
        # Prepare inputs
        batch, num_frames, num_points, channels = inputs.size()
        device = inputs.device
        # assert channels == 3

        # Positional encoding of point coords
        inputs = rearrange(inputs, 'b f p c->(b f) p c')
        pos_emb = self.positional_encoding(inputs)
        x = self.input_projection(torch.cat([inputs, pos_emb], -1))
        x = rearrange(x, '(b f) p c-> b f p c', b=batch)

        t_emb = self.time_mlp(timesteps) if exists(self.time_mlp) else None
        t_emb = t_emb[:, None, :].expand(-1, num_frames, -1)  # b f c
        if context is not None:
            t_emb = torch.cat([t_emb, context], -1)

        time_rel_pos_bias = self.time_rel_pos_bias(num_frames, device=device)

        for block1, block2, spatial_kp_attn, temporal_casual_attn in self.layers:
            x = block1(x, t_emb)
            x = block2(x, t_emb)
            x = spatial_kp_attn(x)
            x = temporal_casual_attn(x, pos_bias=time_rel_pos_bias)

        # Project
        x = self.output_projection(x)
        return x
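# Usage note (a sketch, not part of the original file): this variant keeps an
# explicit point dimension, taking inputs of shape (batch, frames, points, 3) and
# applying spatial attention over points and temporal attention over frames. The
# per-frame conditioning chunked in ResnetBlock.forward has shape (b, f, c); when
# the number of frames differs from the number of points, re-enabling the
# commented-out rearrange(cond_emb, 'b f c -> b f 1 c') there appears to be
# needed for the shapes to broadcast against the (b, f, p, c) activations.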
difpoint/model/temporaltrans/transformer_utils.py
ADDED
@@ -0,0 +1,146 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops_exts import check_shape, rearrange_many
from torch import Size, Tensor


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

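# A minimal sketch (not part of the original file): for a 1-D batch of timesteps
# the module returns one `dim`-dimensional embedding per timestep, with the first
# half of the channels holding sines and the second half cosines. Values below are
# assumptions for illustration.
#
#   emb = SinusoidalPosEmb(dim=128)
#   t = torch.tensor([0., 250., 999.])   # (B,)
#   e = emb(t)                           # -> (3, 128)
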
def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to positional encoding representation phi(v)

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )

    Returns:
        phi(v) (Tensor): Fourier features (B, 3 + (2 * N_freqs) * 3)
    """
    pe = [v]
    for freq in freq_bands:
        fv = freq * v
        pe += [torch.sin(fv), torch.cos(fv)]
    return torch.cat(pe, dim=-1)


class FeatureMapping(nn.Module):
    """FeatureMapping nn.Module

    Maps v to features following transformation phi(v)

    Arguments:
        i_dim (int): input dimensions
        o_dim (int): output dimensions
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim = i_dim
        self.o_dim = o_dim

    def forward(self, v: Tensor) -> Tensor:
        """FeatureMapping forward pass

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): mapped features (B, OFeatures)
        """
        raise NotImplementedError("Forward pass not implemented yet!")


class PositionalEncoding(FeatureMapping):
    """PositionalEncoding module

    Maps v to positional encoding representation phi(v)

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): #frequency to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        super().__init__(i_dim, 3 + (2 * N_freqs) * 3)
        self.N_freqs = N_freqs

        a, b = 1, self.N_freqs - 1
        freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
        self.register_buffer("freq_bands", freq_bands)

    def forward(self, v: Tensor) -> Tensor:
        """Map v to positional encoding representation phi(v)

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): Fourier features (B, 3 + (2 * N_freqs) * 3)
        """
        return map_positional_encoding(v, self.freq_bands)

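# A minimal sketch (not part of the original file): with the settings used in this
# repository (i_dim=3, N_freqs=10) each 3-D point maps to 3 + 2 * 10 * 3 = 63
# Fourier features (the raw coordinates plus a sine and cosine per frequency band).
# The point count below is an assumption for illustration.
#
#   pe = PositionalEncoding(i_dim=3, N_freqs=10)
#   pts = torch.randn(68, 3)
#   feats = pe(pts)                      # -> (68, 63)
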
class BaseTemperalPointModel(nn.Module):
    """ A base class providing useful methods for point cloud processing. """

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 768,
        num_layers: int = 6
    ):
        super().__init__()

        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = 256
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Timestep embedding: sinusoidal encoding followed by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, self.timestep_embed_dim),
            nn.SiLU(),
            nn.Linear(self.timestep_embed_dim, self.timestep_embed_dim)
        )

        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3

        # Input projection (point coords, point coord encodings, other features, and timestep embeddings)
        self.input_projection = nn.Linear(
            in_features=(3 + positional_encoding_d_out),
            out_features=self.dim
        )  # b f p c

        # Transformer layers
        self.layers = self.get_layers()

        # Output projection
        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        raise NotImplementedError('This method should be implemented by subclasses')

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        raise NotImplementedError('This method should be implemented by subclasses')
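# A minimal subclass sketch (not part of the original file), illustrating the
# contract that concrete models such as SimpleTransModel fulfil: get_layers()
# builds the layer stack consumed by forward(), which takes a noisy sample and
# the diffusion timesteps. TinyPointModel and its shapes are assumptions.
#
#   class TinyPointModel(BaseTemperalPointModel):
#       def get_layers(self):
#           return nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(self.num_layers)])
#
#       def forward(self, inputs, t):
#           # inputs: (..., 3 + 63), i.e. raw coords concatenated with their positional encoding
#           x = self.input_projection(inputs)
#           for layer in self.layers:
#               x = F.silu(layer(x))
#           return self.output_projection(x)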