from copy import deepcopy

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

# `checkpoint` and `DropPath` are used below but were missing from the imports;
# the upstream Pointcept file takes them from torch.utils.checkpoint and timm.
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath

import pointops
from pointcept.models.utils import offset2batch, batch2offset

from .transformer_utils import BaseTemperalPointModel


class PointBatchNorm(nn.Module):
    """
    Batch normalization for point cloud data of shape [B*N, C] or [B*N, L, C].
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() == 3:
            # BatchNorm1d normalizes dim 1, so swap [B*N, L, C] -> [B*N, C, L] and back.
            return (
                self.norm(input.transpose(1, 2).contiguous())
                .transpose(1, 2)
                .contiguous()
            )
        elif input.dim() == 2:
            return self.norm(input)
        else:
            raise NotImplementedError


# Adapted from Point Transformer V2:
# https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v2/point_transformer_v2m2_base.py
class GroupedVectorAttention(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        attn_drop_rate=0.0,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
    ):
        super(GroupedVectorAttention, self).__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        if self.pe_multiplier:
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        self.weight_encoding = nn.Sequential(
            nn.Linear(embed_channels, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups),
        )
        self.softmax = nn.Softmax(dim=1)
        self.attn_drop = nn.Dropout(attn_drop_rate)

    def forward(self, feat, coord, reference_index):
        query, key, value = (
            self.linear_q(feat),
            self.linear_k(feat),
            self.linear_v(feat),
        )
        # Gather each point's neighbours; `with_xyz=True` prepends the relative
        # xyz offsets to the grouped key features.
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)
        pos, key = key[:, :, 0:3], key[:, :, 3:]
        relation_qk = key - query.unsqueeze(1)
        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem
        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            value = value + peb

        weight = self.weight_encoding(relation_qk)
        weight = self.attn_drop(self.softmax(weight))

        # Mask out padded neighbours (knn_query marks them with index -1).
        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
        value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
        feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
        feat = einops.rearrange(feat, "n g i -> n (g i)")
        return feat


class BlockSequence(nn.Module):
    def __init__(
        self,
        depth,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(BlockSequence, self).__init__()

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, float):
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
        else:
            drop_path_rates = [0.0 for _ in range(depth)]

        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                embed_channels=embed_channels,
                groups=groups,
                qkv_bias=qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                enable_checkpoint=enable_checkpoint,
            )
            self.blocks.append(block)

    def forward(self, points):
        coord, feat, offset = points
        # reference index query of neighbourhood attention
        # for window attention, modify the reference index query method
        reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
        for block in self.blocks:
            points = block(points, reference_index)
        return points


class GVAPatchEmbed(nn.Module):
    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(GVAPatchEmbed, self).__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        self.proj = nn.Sequential(
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
        )

    def forward(self, points):
        coord, feat, offset = points
        feat = self.proj(feat)
        return self.blocks([coord, feat, offset])


class Block(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(Block, self).__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, points, reference_index):
        coord, feat, offset = points
        identity = feat
        feat = self.act(self.norm1(self.fc1(feat)))
        feat = (
            self.attn(feat, coord, reference_index)
            if not self.enable_checkpoint
            else checkpoint(self.attn, feat, coord, reference_index)
        )
        feat = self.act(self.norm2(feat))
        feat = self.norm3(self.fc3(feat))
        feat = identity + self.drop_path(feat)
        feat = self.act(feat)
        return [coord, feat, offset]
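

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): runs a random
# single-sample point cloud through GVAPatchEmbed. Assumes a CUDA device and
# a built `pointops` extension, since `pointops.knn_query` runs on the GPU;
# the channel/group sizes below are illustrative, not the project's config.
# Pointcept flattens batches: `coord` is [B*N, 3], `feat` is [B*N, C], and
# `offset` holds the cumulative end index of each sample in the flat batch.
if __name__ == "__main__":
    device = torch.device("cuda")
    num_points, in_channels = 1024, 6
    coord = torch.rand(num_points, 3, device=device)
    feat = torch.rand(num_points, in_channels, device=device)
    offset = torch.tensor([num_points], dtype=torch.int32, device=device)

    patch_embed = GVAPatchEmbed(
        depth=1,
        in_channels=in_channels,
        embed_channels=48,  # must be divisible by `groups`
        groups=6,
        neighbours=8,
    ).to(device)

    coord, feat, offset = patch_embed([coord, feat, offset])
    print(feat.shape)  # expected: torch.Size([1024, 48])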