# difpoint/model/temporaltrans/pointtransformerv2.py
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath

import pointops
from pointcept.models.utils import offset2batch, batch2offset

from .transformer_utils import BaseTemperalPointModel

class PointBatchNorm(nn.Module):
"""
Batch Normalization for Point Clouds data in shape of [B*N, C], [B*N, L, C]
"""
def __init__(self, embed_channels):
super().__init__()
self.norm = nn.BatchNorm1d(embed_channels)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if input.dim() == 3:
return (
self.norm(input.transpose(1, 2).contiguous())
.transpose(1, 2)
.contiguous()
)
elif input.dim() == 2:
return self.norm(input)
else:
raise NotImplementedError
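# Usage sketch (illustrative, not part of the original module): PointBatchNorm
# applies BatchNorm1d over the channel dimension of flattened point batches.
#   norm = PointBatchNorm(64)
#   flat = torch.randn(1024, 64)         # [B*N, C]
#   out = norm(flat)                     # -> [1024, 64]
#   grouped = torch.randn(1024, 16, 64)  # [B*N, L, C], e.g. neighbourhood features
#   out = norm(grouped)                  # transposed to [B*N, C, L] internally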
# adapted from https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v2/point_transformer_v2m2_base.py
class GroupedVectorAttention(nn.Module):
def __init__(
self,
embed_channels,
groups,
attn_drop_rate=0.0,
qkv_bias=True,
pe_multiplier=False,
pe_bias=True,
):
super(GroupedVectorAttention, self).__init__()
self.embed_channels = embed_channels
self.groups = groups
assert embed_channels % groups == 0
self.attn_drop_rate = attn_drop_rate
self.qkv_bias = qkv_bias
self.pe_multiplier = pe_multiplier
self.pe_bias = pe_bias
self.linear_q = nn.Sequential(
nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
PointBatchNorm(embed_channels),
nn.ReLU(inplace=True),
)
self.linear_k = nn.Sequential(
nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
PointBatchNorm(embed_channels),
nn.ReLU(inplace=True),
)
self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)
if self.pe_multiplier:
self.linear_p_multiplier = nn.Sequential(
nn.Linear(3, embed_channels),
PointBatchNorm(embed_channels),
nn.ReLU(inplace=True),
nn.Linear(embed_channels, embed_channels),
)
if self.pe_bias:
self.linear_p_bias = nn.Sequential(
nn.Linear(3, embed_channels),
PointBatchNorm(embed_channels),
nn.ReLU(inplace=True),
nn.Linear(embed_channels, embed_channels),
)
self.weight_encoding = nn.Sequential(
nn.Linear(embed_channels, groups),
PointBatchNorm(groups),
nn.ReLU(inplace=True),
nn.Linear(groups, groups),
)
self.softmax = nn.Softmax(dim=1)
self.attn_drop = nn.Dropout(attn_drop_rate)
    def forward(self, feat, coord, reference_index):
        # feat: [N, C] point features, coord: [N, 3] point coordinates,
        # reference_index: [N, K] neighbour indices (-1 marks padded entries)
        query, key, value = (
            self.linear_q(feat),
            self.linear_k(feat),
            self.linear_v(feat),
        )
        # gather neighbour keys/values; with_xyz=True prepends the relative xyz
        # of each neighbour, which is then split off as the positional input
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)
        pos, key = key[:, :, 0:3], key[:, :, 3:]
relation_qk = key - query.unsqueeze(1)
if self.pe_multiplier:
pem = self.linear_p_multiplier(pos)
relation_qk = relation_qk * pem
if self.pe_bias:
peb = self.linear_p_bias(pos)
relation_qk = relation_qk + peb
value = value + peb
weight = self.weight_encoding(relation_qk)
weight = self.attn_drop(self.softmax(weight))
        # zero out attention weights of padded neighbours (reference_index == -1)
        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
feat = einops.rearrange(feat, "n g i -> n (g i)")
return feat
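# Shape walkthrough for GroupedVectorAttention.forward (a sketch; N points,
# K neighbours, C = embed_channels, G = groups):
#   query:            [N, C]
#   key (grouped):    [N, K, 3 + C] with with_xyz=True, split into
#                     pos [N, K, 3] and key [N, K, C]
#   value (grouped):  [N, K, C]
#   relation_qk:      [N, K, C], optionally scaled by linear_p_multiplier(pos)
#                     and shifted by linear_p_bias(pos)
#   weight:           [N, K, G] from weight_encoding, softmax over K, masked
#                     where reference_index == -1
#   output feat:      [N, C] after per-group einsum aggregation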
class BlockSequence(nn.Module):
def __init__(
self,
depth,
embed_channels,
groups,
neighbours=16,
qkv_bias=True,
pe_multiplier=False,
pe_bias=True,
attn_drop_rate=0.0,
drop_path_rate=0.0,
enable_checkpoint=False,
):
super(BlockSequence, self).__init__()
if isinstance(drop_path_rate, list):
drop_path_rates = drop_path_rate
assert len(drop_path_rates) == depth
elif isinstance(drop_path_rate, float):
drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
else:
drop_path_rates = [0.0 for _ in range(depth)]
self.neighbours = neighbours
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
embed_channels=embed_channels,
groups=groups,
qkv_bias=qkv_bias,
pe_multiplier=pe_multiplier,
pe_bias=pe_bias,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rates[i],
enable_checkpoint=enable_checkpoint,
)
self.blocks.append(block)
def forward(self, points):
coord, feat, offset = points
        # neighbourhood query for local attention; for window attention,
        # replace this reference index query method
reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
for block in self.blocks:
points = block(points, reference_index)
return points
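# Note on the neighbourhood query (assumed pointops semantics): knn_query
# returns, for each point, the indices of its `neighbours` nearest points
# within the same offset-delimited batch, padding with -1 when fewer are
# available; all blocks in the sequence share this single query.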
class GVAPatchEmbed(nn.Module):
def __init__(
self,
depth,
in_channels,
embed_channels,
groups,
neighbours=16,
qkv_bias=True,
pe_multiplier=False,
pe_bias=True,
attn_drop_rate=0.0,
drop_path_rate=0.0,
enable_checkpoint=False,
):
super(GVAPatchEmbed, self).__init__()
self.in_channels = in_channels
self.embed_channels = embed_channels
self.proj = nn.Sequential(
nn.Linear(in_channels, embed_channels, bias=False),
PointBatchNorm(embed_channels),
nn.ReLU(inplace=True),
)
self.blocks = BlockSequence(
depth=depth,
embed_channels=embed_channels,
groups=groups,
neighbours=neighbours,
qkv_bias=qkv_bias,
pe_multiplier=pe_multiplier,
pe_bias=pe_bias,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
enable_checkpoint=enable_checkpoint,
)
def forward(self, points):
coord, feat, offset = points
feat = self.proj(feat)
return self.blocks([coord, feat, offset])
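# Usage sketch (illustrative, assumes a CUDA build of pointops; values are
# made up for demonstration):
#   embed = GVAPatchEmbed(depth=1, in_channels=6, embed_channels=64, groups=8).cuda()
#   coord = torch.rand(2048, 3).cuda()            # all batches stacked along dim 0
#   feat = torch.rand(2048, 6).cuda()             # per-point input features
#   offset = torch.tensor([1024, 2048]).cuda()    # cumulative points per batch
#   coord, feat, offset = embed([coord, feat, offset])  # feat: [2048, 64]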
class Block(nn.Module):
def __init__(
self,
embed_channels,
groups,
qkv_bias=True,
pe_multiplier=False,
pe_bias=True,
attn_drop_rate=0.0,
drop_path_rate=0.0,
enable_checkpoint=False,
):
super(Block, self).__init__()
self.attn = GroupedVectorAttention(
embed_channels=embed_channels,
groups=groups,
qkv_bias=qkv_bias,
attn_drop_rate=attn_drop_rate,
pe_multiplier=pe_multiplier,
pe_bias=pe_bias,
)
self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
self.norm1 = PointBatchNorm(embed_channels)
self.norm2 = PointBatchNorm(embed_channels)
self.norm3 = PointBatchNorm(embed_channels)
self.act = nn.ReLU(inplace=True)
self.enable_checkpoint = enable_checkpoint
self.drop_path = (
DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
)
    def forward(self, points, reference_index):
        coord, feat, offset = points
        identity = feat  # save input features for the residual connection
feat = self.act(self.norm1(self.fc1(feat)))
feat = (
self.attn(feat, coord, reference_index)
if not self.enable_checkpoint
else checkpoint(self.attn, feat, coord, reference_index)
)
feat = self.act(self.norm2(feat))
feat = self.norm3(self.fc3(feat))
feat = identity + self.drop_path(feat)
feat = self.act(feat)
return [coord, feat, offset]
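# Design note: Block follows the PTv2 bottleneck residual layout, i.e.
# fc1 -> norm1 -> ReLU -> grouped vector attention -> norm2 -> ReLU ->
# fc3 -> norm3, with DropPath applied on the residual branch and a final
# ReLU after the skip addition.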