from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint  # used by Block when enable_checkpoint=True
import einops
import pointops
from timm.models.layers import DropPath  # used by Block for stochastic depth
from pointcept.models.utils import offset2batch, batch2offset

from .transformer_utils import BaseTemperalPointModel
class PointBatchNorm(nn.Module):
    """
    Batch normalization for point cloud data in shape of [B*N, C] or [B*N, L, C]
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() == 3:
            # BatchNorm1d expects channels in dim 1, so transpose [B*N, L, C]
            # to [B*N, C, L] and back.
            return (
                self.norm(input.transpose(1, 2).contiguous())
                .transpose(1, 2)
                .contiguous()
            )
        elif input.dim() == 2:
            return self.norm(input)
        else:
            raise NotImplementedError
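
# Usage sketch (illustrative only, not part of the original module): the shapes
# below are made up. A flat [B*N, C] tensor normalizes directly; a grouped
# [B*N, L, C] tensor (e.g. per-neighbour features) goes through the transpose
# path above.
#
#   norm = PointBatchNorm(64)
#   flat = torch.randn(1024, 64)         # [B*N, C]
#   grouped = torch.randn(1024, 16, 64)  # [B*N, L, C]
#   assert norm(flat).shape == flat.shape
#   assert norm(grouped).shape == grouped.shape
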
# Adapted from:
# https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v2/point_transformer_v2m2_base.py
class GroupedVectorAttention(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        attn_drop_rate=0.0,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
    ):
        super(GroupedVectorAttention, self).__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        # Positional encodings on relative xyz offsets: an optional
        # multiplicative term and an additive bias term (the PTv2 default).
        if self.pe_multiplier:
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        # Maps the C-dim relation vector to one scalar weight per channel group.
        self.weight_encoding = nn.Sequential(
            nn.Linear(embed_channels, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups),
        )
        self.softmax = nn.Softmax(dim=1)  # normalize over the neighbour axis
        self.attn_drop = nn.Dropout(attn_drop_rate)

    def forward(self, feat, coord, reference_index):
        query, key, value = (
            self.linear_q(feat),
            self.linear_k(feat),
            self.linear_v(feat),
        )
        # Gather each point's neighbours; with_xyz=True prepends the relative
        # xyz offsets to the grouped keys.
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)
        pos, key = key[:, :, 0:3], key[:, :, 3:]
        relation_qk = key - query.unsqueeze(1)
        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem
        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            value = value + peb

        weight = self.weight_encoding(relation_qk)
        weight = self.attn_drop(self.softmax(weight))

        # reference_index is -1 for padded neighbours; zero out their weights.
        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
        value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
        feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
        feat = einops.rearrange(feat, "n g i -> n (g i)")
        return feat
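
# Shape walk-through for the aggregation above (a sketch with made-up sizes,
# not executed at import): for n points, s neighbours, g groups and i = C // g
# channels per group, each group of i channels shares one attention weight,
# which is what "grouped vector attention" means. The einsum is equivalent to
# a broadcast-multiply-and-sum over the neighbour axis:
#
#   n, s, g, i = 10, 16, 8, 8
#   w = torch.rand(n, s, g)                   # grouped attention weights
#   v = torch.rand(n, s, g, i)                # grouped values
#   out = torch.einsum("n s g i, n s g -> n g i", v, w)
#   same = (w.unsqueeze(-1) * v).sum(dim=1)   # broadcast-and-sum form
#   assert torch.allclose(out, same, atol=1e-6)
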
class BlockSequence(nn.Module):
    def __init__(
        self,
        depth,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(BlockSequence, self).__init__()

        # Accept either one drop-path rate for every block or a per-block list.
        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, float):
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
        else:
            drop_path_rates = [0.0 for _ in range(depth)]

        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                embed_channels=embed_channels,
                groups=groups,
                qkv_bias=qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                enable_checkpoint=enable_checkpoint,
            )
            self.blocks.append(block)

    def forward(self, points):
        coord, feat, offset = points
        # Reference index query of neighbourhood attention; the same kNN result
        # is shared by every block in the sequence. For window attention,
        # modify the reference index query method.
        reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
        for block in self.blocks:
            points = block(points, reference_index)
        return points
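
# Note on the `offset` convention used throughout Pointcept / pointops: all
# point clouds in a batch are flattened into one [N_total, 3] tensor, and
# `offset` holds the cumulative end index of each cloud, so neighbour queries
# never cross cloud boundaries. A sketch with made-up sizes (pointops kernels
# require CUDA tensors):
#
#   coord = torch.rand(300, 3).cuda()           # two clouds: 100 + 200 points
#   offset = torch.tensor([100, 300]).cuda()    # cumulative boundaries
#   idx, _ = pointops.knn_query(16, coord, offset)  # idx: [300, 16]
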
class GVAPatchEmbed(nn.Module):
    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(GVAPatchEmbed, self).__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        self.proj = nn.Sequential(
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
        )

    def forward(self, points):
        coord, feat, offset = points
        feat = self.proj(feat)
        return self.blocks([coord, feat, offset])
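
# Usage sketch (hyperparameters are assumptions for illustration, not from the
# original file): GVAPatchEmbed lifts raw per-point features to the embedding
# width, then runs `depth` attention blocks over kNN neighbourhoods.
#
#   embed = GVAPatchEmbed(depth=1, in_channels=6, embed_channels=96, groups=6)
#   # points = [coord, feat, offset] with coord [N, 3], feat [N, 6], offset [B]
#   # coord, feat, offset = embed(points)   # feat comes back as [N, 96]
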
class Block(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(Block, self).__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, points, reference_index):
        coord, feat, offset = points
        identity = feat
        # Linear -> attention -> linear, wrapped in a residual connection with
        # stochastic depth; gradient checkpointing trades compute for memory.
        feat = self.act(self.norm1(self.fc1(feat)))
        feat = (
            self.attn(feat, coord, reference_index)
            if not self.enable_checkpoint
            else checkpoint(self.attn, feat, coord, reference_index)
        )
        feat = self.act(self.norm2(feat))
        feat = self.norm3(self.fc3(feat))
        feat = identity + self.drop_path(feat)
        feat = self.act(feat)
        return [coord, feat, offset]
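

# Minimal smoke test (a sketch, not part of the original module): the pointops
# CUDA kernels need a GPU, so this only runs when one is available. Sizes and
# hyperparameters are made up for illustration.
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = torch.device("cuda")
        embed = GVAPatchEmbed(
            depth=2, in_channels=6, embed_channels=96, groups=6, neighbours=8
        ).to(device)
        coord = torch.rand(256, 3, device=device)
        feat = torch.rand(256, 6, device=device)
        offset = torch.tensor([128, 256], device=device)  # two clouds of 128 points
        coord, feat, offset = embed([coord, feat, offset])
        print(feat.shape)  # expected: torch.Size([256, 96])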