# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List import torch from torch import Tensor from torch import nn as nn from .furthest_point_sample import (furthest_point_sample, furthest_point_sample_with_dist) def calc_square_dist(point_feat_a: Tensor, point_feat_b: Tensor, norm: bool = True) -> Tensor: """Calculating square distance between a and b. Args: point_feat_a (torch.Tensor): (B, N, C) Feature vector of each point. point_feat_b (torch.Tensor): (B, M, C) Feature vector of each point. norm (bool, optional): Whether to normalize the distance. Default: True. Returns: torch.Tensor: (B, N, M) Square distance between each point pair. """ num_channel = point_feat_a.shape[-1] dist = torch.cdist(point_feat_a, point_feat_b) if norm: dist = dist / num_channel else: dist = torch.square(dist) return dist def get_sampler_cls(sampler_type: str) -> nn.Module: """Get the type and mode of points sampler. Args: sampler_type (str): The type of points sampler. The valid value are "D-FPS", "F-FPS", or "FS". Returns: class: Points sampler type. """ sampler_mappings = { 'D-FPS': DFPSSampler, 'F-FPS': FFPSSampler, 'FS': FSSampler, } try: return sampler_mappings[sampler_type] except KeyError: raise KeyError( f'Supported `sampler_type` are {sampler_mappings.keys()}, but got \ {sampler_type}') class PointsSampler(nn.Module): """Points sampling. Args: num_point (list[int]): Number of sample points. fps_mod_list (list[str], optional): Type of FPS method, valid mod ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS']. F-FPS: using feature distances for FPS. D-FPS: using Euclidean distances of points for FPS. FS: using F-FPS and D-FPS simultaneously. fps_sample_range_list (list[int], optional): Range of points to apply FPS. Default: [-1]. """ def __init__(self, num_point: List[int], fps_mod_list: List[str] = ['D-FPS'], fps_sample_range_list: List[int] = [-1]) -> None: super().__init__() # FPS would be applied to different fps_mod in the list, # so the length of the num_point should be equal to # fps_mod_list and fps_sample_range_list. assert len(num_point) == len(fps_mod_list) == len( fps_sample_range_list) self.num_point = num_point self.fps_sample_range_list = fps_sample_range_list self.samplers = nn.ModuleList() for fps_mod in fps_mod_list: self.samplers.append(get_sampler_cls(fps_mod)()) self.fp16_enabled = False def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor: """ Args: points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of the points. features (torch.Tensor): (B, C, N) features of the points. Returns: torch.Tensor: (B, npoint, sample_num) Indices of sampled points. """ if points_xyz.dtype == torch.half: points_xyz = points_xyz.to(torch.float32) if features is not None and features.dtype == torch.half: features = features.to(torch.float32) indices = [] last_fps_end_index = 0 for fps_sample_range, sampler, npoint in zip( self.fps_sample_range_list, self.samplers, self.num_point): assert fps_sample_range < points_xyz.shape[1] if fps_sample_range == -1: sample_points_xyz = points_xyz[:, last_fps_end_index:] if features is not None: sample_features = features[:, :, last_fps_end_index:] else: sample_features = None else: sample_points_xyz = points_xyz[:, last_fps_end_index: fps_sample_range] if features is not None: sample_features = features[:, :, last_fps_end_index: fps_sample_range] else: sample_features = None fps_idx = sampler(sample_points_xyz.contiguous(), sample_features, npoint) indices.append(fps_idx + last_fps_end_index) last_fps_end_index = fps_sample_range indices = torch.cat(indices, dim=1) return indices class DFPSSampler(nn.Module): """Using Euclidean distances of points for FPS.""" def __init__(self) -> None: super().__init__() def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor: """Sampling points with D-FPS.""" fps_idx = furthest_point_sample(points.contiguous(), npoint) return fps_idx class FFPSSampler(nn.Module): """Using feature distances for FPS.""" def __init__(self) -> None: super().__init__() def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor: """Sampling points with F-FPS.""" assert features is not None, \ 'feature input to FFPS_Sampler should not be None' features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2) features_dist = calc_square_dist( features_for_fps, features_for_fps, norm=False) fps_idx = furthest_point_sample_with_dist(features_dist, npoint) return fps_idx class FSSampler(nn.Module): """Using F-FPS and D-FPS simultaneously.""" def __init__(self) -> None: super().__init__() def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor: """Sampling points with FS_Sampling.""" assert features is not None, \ 'feature input to FS_Sampler should not be None' ffps_sampler = FFPSSampler() dfps_sampler = DFPSSampler() fps_idx_ffps = ffps_sampler(points, features, npoint) fps_idx_dfps = dfps_sampler(points, features, npoint) fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1) return fps_idx