|  |  | 
					
						
						|  | from typing import Tuple | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn as nn | 
					
						
						|  | from torch.autograd import Function | 
					
						
						|  |  | 
					
						
						|  | from ..utils import ext_loader | 
					
						
						|  | from .ball_query import ball_query | 
					
						
						|  | from .knn import knn | 
					
						
						|  |  | 
					
						
						|  | ext_module = ext_loader.load_ext( | 
					
						
						|  | '_ext', ['group_points_forward', 'group_points_backward']) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class QueryAndGroup(nn.Module): | 
					
						
						|  | """Groups points with a ball query of radius. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | max_radius (float): The maximum radius of the balls. | 
					
						
						|  | If None is given, we will use kNN sampling instead of ball query. | 
					
						
						|  | sample_num (int): Maximum number of features to gather in the ball. | 
					
						
						|  | min_radius (float, optional): The minimum radius of the balls. | 
					
						
						|  | Default: 0. | 
					
						
						|  | use_xyz (bool, optional): Whether to use xyz. | 
					
						
						|  | Default: True. | 
					
						
						|  | return_grouped_xyz (bool, optional): Whether to return grouped xyz. | 
					
						
						|  | Default: False. | 
					
						
						|  | normalize_xyz (bool, optional): Whether to normalize xyz. | 
					
						
						|  | Default: False. | 
					
						
						|  | uniform_sample (bool, optional): Whether to sample uniformly. | 
					
						
						|  | Default: False | 
					
						
						|  | return_unique_cnt (bool, optional): Whether to return the count of | 
					
						
						|  | unique samples. Default: False. | 
					
						
						|  | return_grouped_idx (bool, optional): Whether to return grouped idx. | 
					
						
						|  | Default: False. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, | 
					
						
						|  | max_radius, | 
					
						
						|  | sample_num, | 
					
						
						|  | min_radius=0, | 
					
						
						|  | use_xyz=True, | 
					
						
						|  | return_grouped_xyz=False, | 
					
						
						|  | normalize_xyz=False, | 
					
						
						|  | uniform_sample=False, | 
					
						
						|  | return_unique_cnt=False, | 
					
						
						|  | return_grouped_idx=False): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.max_radius = max_radius | 
					
						
						|  | self.min_radius = min_radius | 
					
						
						|  | self.sample_num = sample_num | 
					
						
						|  | self.use_xyz = use_xyz | 
					
						
						|  | self.return_grouped_xyz = return_grouped_xyz | 
					
						
						|  | self.normalize_xyz = normalize_xyz | 
					
						
						|  | self.uniform_sample = uniform_sample | 
					
						
						|  | self.return_unique_cnt = return_unique_cnt | 
					
						
						|  | self.return_grouped_idx = return_grouped_idx | 
					
						
						|  | if self.return_unique_cnt: | 
					
						
						|  | assert self.uniform_sample, \ | 
					
						
						|  | 'uniform_sample should be True when ' \ | 
					
						
						|  | 'returning the count of unique samples' | 
					
						
						|  | if self.max_radius is None: | 
					
						
						|  | assert not self.normalize_xyz, \ | 
					
						
						|  | 'can not normalize grouped xyz when max_radius is None' | 
					
						
						|  |  | 
					
						
						|  | def forward(self, points_xyz, center_xyz, features=None): | 
					
						
						|  | """ | 
					
						
						|  | Args: | 
					
						
						|  | points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. | 
					
						
						|  | center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods. | 
					
						
						|  | features (Tensor): (B, C, N) Descriptors of the features. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.max_radius is None: | 
					
						
						|  | idx = knn(self.sample_num, points_xyz, center_xyz, False) | 
					
						
						|  | idx = idx.transpose(1, 2).contiguous() | 
					
						
						|  | else: | 
					
						
						|  | idx = ball_query(self.min_radius, self.max_radius, self.sample_num, | 
					
						
						|  | points_xyz, center_xyz) | 
					
						
						|  |  | 
					
						
						|  | if self.uniform_sample: | 
					
						
						|  | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) | 
					
						
						|  | for i_batch in range(idx.shape[0]): | 
					
						
						|  | for i_region in range(idx.shape[1]): | 
					
						
						|  | unique_ind = torch.unique(idx[i_batch, i_region, :]) | 
					
						
						|  | num_unique = unique_ind.shape[0] | 
					
						
						|  | unique_cnt[i_batch, i_region] = num_unique | 
					
						
						|  | sample_ind = torch.randint( | 
					
						
						|  | 0, | 
					
						
						|  | num_unique, (self.sample_num - num_unique, ), | 
					
						
						|  | dtype=torch.long) | 
					
						
						|  | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) | 
					
						
						|  | idx[i_batch, i_region, :] = all_ind | 
					
						
						|  |  | 
					
						
						|  | xyz_trans = points_xyz.transpose(1, 2).contiguous() | 
					
						
						|  |  | 
					
						
						|  | grouped_xyz = grouping_operation(xyz_trans, idx) | 
					
						
						|  | grouped_xyz_diff = grouped_xyz - \ | 
					
						
						|  | center_xyz.transpose(1, 2).unsqueeze(-1) | 
					
						
						|  | if self.normalize_xyz: | 
					
						
						|  | grouped_xyz_diff /= self.max_radius | 
					
						
						|  |  | 
					
						
						|  | if features is not None: | 
					
						
						|  | grouped_features = grouping_operation(features, idx) | 
					
						
						|  | if self.use_xyz: | 
					
						
						|  |  | 
					
						
						|  | new_features = torch.cat([grouped_xyz_diff, grouped_features], | 
					
						
						|  | dim=1) | 
					
						
						|  | else: | 
					
						
						|  | new_features = grouped_features | 
					
						
						|  | else: | 
					
						
						|  | assert (self.use_xyz | 
					
						
						|  | ), 'Cannot have not features and not use xyz as a feature!' | 
					
						
						|  | new_features = grouped_xyz_diff | 
					
						
						|  |  | 
					
						
						|  | ret = [new_features] | 
					
						
						|  | if self.return_grouped_xyz: | 
					
						
						|  | ret.append(grouped_xyz) | 
					
						
						|  | if self.return_unique_cnt: | 
					
						
						|  | ret.append(unique_cnt) | 
					
						
						|  | if self.return_grouped_idx: | 
					
						
						|  | ret.append(idx) | 
					
						
						|  | if len(ret) == 1: | 
					
						
						|  | return ret[0] | 
					
						
						|  | else: | 
					
						
						|  | return tuple(ret) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GroupAll(nn.Module): | 
					
						
						|  | """Group xyz with feature. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | use_xyz (bool): Whether to use xyz. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, use_xyz: bool = True): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.use_xyz = use_xyz | 
					
						
						|  |  | 
					
						
						|  | def forward(self, | 
					
						
						|  | xyz: torch.Tensor, | 
					
						
						|  | new_xyz: torch.Tensor, | 
					
						
						|  | features: torch.Tensor = None): | 
					
						
						|  | """ | 
					
						
						|  | Args: | 
					
						
						|  | xyz (Tensor): (B, N, 3) xyz coordinates of the features. | 
					
						
						|  | new_xyz (Tensor): new xyz coordinates of the features. | 
					
						
						|  | features (Tensor): (B, C, N) features to group. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Tensor: (B, C + 3, 1, N) Grouped feature. | 
					
						
						|  | """ | 
					
						
						|  | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) | 
					
						
						|  | if features is not None: | 
					
						
						|  | grouped_features = features.unsqueeze(2) | 
					
						
						|  | if self.use_xyz: | 
					
						
						|  |  | 
					
						
						|  | new_features = torch.cat([grouped_xyz, grouped_features], | 
					
						
						|  | dim=1) | 
					
						
						|  | else: | 
					
						
						|  | new_features = grouped_features | 
					
						
						|  | else: | 
					
						
						|  | new_features = grouped_xyz | 
					
						
						|  |  | 
					
						
						|  | return new_features | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GroupingOperation(Function): | 
					
						
						|  | """Group feature with given index.""" | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def forward(ctx, features: torch.Tensor, | 
					
						
						|  | indices: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | """ | 
					
						
						|  | Args: | 
					
						
						|  | features (Tensor): (B, C, N) tensor of features to group. | 
					
						
						|  | indices (Tensor): (B, npoint, nsample) the indices of | 
					
						
						|  | features to group with. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Tensor: (B, C, npoint, nsample) Grouped features. | 
					
						
						|  | """ | 
					
						
						|  | features = features.contiguous() | 
					
						
						|  | indices = indices.contiguous() | 
					
						
						|  |  | 
					
						
						|  | B, nfeatures, nsample = indices.size() | 
					
						
						|  | _, C, N = features.size() | 
					
						
						|  | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) | 
					
						
						|  |  | 
					
						
						|  | ext_module.group_points_forward(B, C, N, nfeatures, nsample, features, | 
					
						
						|  | indices, output) | 
					
						
						|  |  | 
					
						
						|  | ctx.for_backwards = (indices, N) | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def backward(ctx, | 
					
						
						|  | grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | 
					
						
						|  | """ | 
					
						
						|  | Args: | 
					
						
						|  | grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients | 
					
						
						|  | of the output from forward. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Tensor: (B, C, N) gradient of the features. | 
					
						
						|  | """ | 
					
						
						|  | idx, N = ctx.for_backwards | 
					
						
						|  |  | 
					
						
						|  | B, C, npoint, nsample = grad_out.size() | 
					
						
						|  | grad_features = torch.cuda.FloatTensor(B, C, N).zero_() | 
					
						
						|  |  | 
					
						
						|  | grad_out_data = grad_out.data.contiguous() | 
					
						
						|  | ext_module.group_points_backward(B, C, N, npoint, nsample, | 
					
						
						|  | grad_out_data, idx, | 
					
						
						|  | grad_features.data) | 
					
						
						|  | return grad_features, None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Functional alias: grouping_operation(features, indices) runs
# GroupingOperation.forward through torch.autograd.Function.apply so that
# GroupingOperation.backward is used for gradients.
grouping_operation = GroupingOperation.apply
					
						
						|  |  |