# Copyright (c) OpenMMLab. All rights reserved. import math from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn as nn from mmcv.cnn import build_conv_layer from mmengine.model import BaseModule, ModuleDict from mmengine.structures import InstanceData, PixelData from torch import Tensor from mmpose.models.utils.tta import flip_heatmaps from mmpose.registry import KEYPOINT_CODECS, MODELS from mmpose.utils.typing import (ConfigType, Features, OptConfigType, OptSampleList, Predictions) from ..base_head import BaseHead def smooth_heatmaps(heatmaps: Tensor, blur_kernel_size: int) -> Tensor: """Smooth the heatmaps by blurring and averaging. Args: heatmaps (Tensor): The heatmaps to smooth. blur_kernel_size (int): The kernel size for blurring the heatmaps. Returns: Tensor: The smoothed heatmaps. """ smoothed_heatmaps = torch.nn.functional.avg_pool2d( heatmaps, blur_kernel_size, 1, (blur_kernel_size - 1) // 2) smoothed_heatmaps = (heatmaps + smoothed_heatmaps) / 2.0 return smoothed_heatmaps class TruncSigmoid(nn.Sigmoid): """A sigmoid activation function that truncates the output to the given range. Args: min (float, optional): The minimum value to clamp the output to. Defaults to 0.0 max (float, optional): The maximum value to clamp the output to. Defaults to 1.0 """ def __init__(self, min: float = 0.0, max: float = 1.0): super(TruncSigmoid, self).__init__() self.min = min self.max = max def forward(self, input: Tensor) -> Tensor: """Computes the truncated sigmoid activation of the input tensor.""" output = torch.sigmoid(input) output = output.clamp(min=self.min, max=self.max) return output class IIAModule(BaseModule): """Instance Information Abstraction module introduced in `CID`. This module extracts the feature representation vectors for each instance. Args: in_channels (int): Number of channels in the input feature tensor out_channels (int): Number of channels of the output heatmaps clamp_delta (float, optional): A small value that prevents the sigmoid activation from becoming saturated. Defaults to 1e-4. init_cfg (Config, optional): Config to control the initialization. See :attr:`default_init_cfg` for default settings """ def __init__( self, in_channels: int, out_channels: int, clamp_delta: float = 1e-4, init_cfg: OptConfigType = None, ): super().__init__(init_cfg=init_cfg) self.keypoint_root_conv = build_conv_layer( dict( type='Conv2d', in_channels=in_channels, out_channels=out_channels, kernel_size=1)) self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta) def forward(self, feats: Tensor): heatmaps = self.keypoint_root_conv(feats) heatmaps = self.sigmoid(heatmaps) return heatmaps def _sample_feats(self, feats: Tensor, indices: Tensor) -> Tensor: """Extract feature vectors at the specified indices from the input feature map. Args: feats (Tensor): Input feature map. indices (Tensor): Indices of the feature vectors to extract. Returns: Tensor: Extracted feature vectors. """ assert indices.dtype == torch.long if indices.shape[1] == 3: b, w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)] instance_feats = feats[b, :, h, w] elif indices.shape[1] == 2: w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)] instance_feats = feats[:, :, h, w] instance_feats = instance_feats.permute(0, 2, 1) instance_feats = instance_feats.reshape(-1, instance_feats.shape[-1]) else: raise ValueError(f'`indices` should have 2 or 3 channels, ' f'but got f{indices.shape[1]}') return instance_feats def _hierarchical_pool(self, heatmaps: Tensor) -> Tensor: """Conduct max pooling on the input heatmaps with different kernel size according to the input size. Args: heatmaps (Tensor): Input heatmaps. Returns: Tensor: Result of hierarchical pooling. """ map_size = (heatmaps.shape[-1] + heatmaps.shape[-2]) / 2.0 if map_size > 300: maxm = torch.nn.functional.max_pool2d(heatmaps, 7, 1, 3) elif map_size > 200: maxm = torch.nn.functional.max_pool2d(heatmaps, 5, 1, 2) else: maxm = torch.nn.functional.max_pool2d(heatmaps, 3, 1, 1) return maxm def forward_train(self, feats: Tensor, instance_coords: Tensor, instance_imgids: Tensor) -> Tuple[Tensor, Tensor]: """Forward pass during training. Args: feats (Tensor): Input feature tensor. instance_coords (Tensor): Coordinates of the instance roots. instance_imgids (Tensor): Sample indices of each instances in the batch. Returns: Tuple[Tensor, Tensor]: Extracted feature vectors and heatmaps for the instances. """ heatmaps = self.forward(feats) indices = torch.cat((instance_imgids[:, None], instance_coords), dim=1) instance_feats = self._sample_feats(feats, indices) return instance_feats, heatmaps def forward_test( self, feats: Tensor, test_cfg: Dict ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: """Forward pass during testing. Args: feats (Tensor): Input feature tensor. test_cfg (Dict): Testing configuration, including: - blur_kernel_size (int, optional): Kernel size for blurring the heatmaps. Defaults to 3. - max_instances (int, optional): Maximum number of instances to extract. Defaults to 30. - score_threshold (float, optional): Minimum score for extracting an instance. Defaults to 0.01. - flip_test (bool, optional): Whether to compute the average of the heatmaps across the batch dimension. Defaults to False. Returns: A tuple of Tensor including extracted feature vectors, coordinates, and scores of the instances. Any of these can be empty Tensor if no instances are extracted. """ blur_kernel_size = test_cfg.get('blur_kernel_size', 3) max_instances = test_cfg.get('max_instances', 30) score_threshold = test_cfg.get('score_threshold', 0.01) H, W = feats.shape[-2:] # compute heatmaps heatmaps = self.forward(feats).narrow(1, -1, 1) if test_cfg.get('flip_test', False): heatmaps = heatmaps.mean(dim=0, keepdims=True) smoothed_heatmaps = smooth_heatmaps(heatmaps, blur_kernel_size) # decode heatmaps maximums = self._hierarchical_pool(smoothed_heatmaps) maximums = torch.eq(maximums, smoothed_heatmaps).float() maximums = (smoothed_heatmaps * maximums).reshape(-1) scores, pos_ind = maximums.topk(max_instances, dim=0) select_ind = (scores > (score_threshold)).nonzero().squeeze(1) scores, pos_ind = scores[select_ind], pos_ind[select_ind] # sample feature vectors from feature map instance_coords = torch.stack((pos_ind % W, pos_ind // W), dim=1) instance_feats = self._sample_feats(feats, instance_coords) return instance_feats, instance_coords, scores class ChannelAttention(nn.Module): """Channel-wise attention module introduced in `CID`. Args: in_channels (int): The number of channels of the input instance vectors. out_channels (int): The number of channels of the transformed instance vectors. """ def __init__(self, in_channels: int, out_channels: int): super(ChannelAttention, self).__init__() self.atn = nn.Linear(in_channels, out_channels) def forward(self, global_feats: Tensor, instance_feats: Tensor) -> Tensor: """Applies attention to the channel dimension of the input tensor.""" instance_feats = self.atn(instance_feats).unsqueeze(2).unsqueeze(3) return global_feats * instance_feats class SpatialAttention(nn.Module): """Spatial-wise attention module introduced in `CID`. Args: in_channels (int): The number of channels of the input instance vectors. out_channels (int): The number of channels of the transformed instance vectors. """ def __init__(self, in_channels, out_channels): super(SpatialAttention, self).__init__() self.atn = nn.Linear(in_channels, out_channels) self.feat_stride = 4 self.conv = nn.Conv2d(3, 1, 5, 1, 2) def _get_pixel_coords(self, heatmap_size: Tuple, device: str = 'cpu'): """Get pixel coordinates for each element in the heatmap. Args: heatmap_size (tuple): Size of the heatmap in (W, H) format. device (str): Device to put the resulting tensor on. Returns: Tensor of shape (batch_size, num_pixels, 2) containing the pixel coordinates for each element in the heatmap. """ w, h = heatmap_size y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) pixel_coords = torch.stack((x, y), dim=-1).reshape(-1, 2) pixel_coords = pixel_coords.float().to(device) + 0.5 return pixel_coords def forward(self, global_feats: Tensor, instance_feats: Tensor, instance_coords: Tensor) -> Tensor: """Perform spatial attention. Args: global_feats (Tensor): Tensor containing the global features. instance_feats (Tensor): Tensor containing the instance feature vectors. instance_coords (Tensor): Tensor containing the root coordinates of the instances. Returns: Tensor containing the modulated global features. """ B, C, H, W = global_feats.size() instance_feats = self.atn(instance_feats).reshape(B, C, 1, 1) feats = global_feats * instance_feats.expand_as(global_feats) fsum = torch.sum(feats, dim=1, keepdim=True) pixel_coords = self._get_pixel_coords((W, H), feats.device) relative_coords = instance_coords.reshape( -1, 1, 2) - pixel_coords.reshape(1, -1, 2) relative_coords = relative_coords.permute(0, 2, 1) / 32.0 relative_coords = relative_coords.reshape(B, 2, H, W) input_feats = torch.cat((fsum, relative_coords), dim=1) mask = self.conv(input_feats).sigmoid() return global_feats * mask class GFDModule(BaseModule): """Global Feature Decoupling module introduced in `CID`. This module extracts the decoupled heatmaps for each instance. Args: in_channels (int): Number of channels in the input feature map out_channels (int): Number of channels of the output heatmaps for each instance gfd_channels (int): Number of channels in the transformed feature map clamp_delta (float, optional): A small value that prevents the sigmoid activation from becoming saturated. Defaults to 1e-4. init_cfg (Config, optional): Config to control the initialization. See :attr:`default_init_cfg` for default settings """ def __init__( self, in_channels: int, out_channels: int, gfd_channels: int, clamp_delta: float = 1e-4, init_cfg: OptConfigType = None, ): super().__init__(init_cfg=init_cfg) self.conv_down = build_conv_layer( dict( type='Conv2d', in_channels=in_channels, out_channels=gfd_channels, kernel_size=1)) self.channel_attention = ChannelAttention(in_channels, gfd_channels) self.spatial_attention = SpatialAttention(in_channels, gfd_channels) self.fuse_attention = build_conv_layer( dict( type='Conv2d', in_channels=gfd_channels * 2, out_channels=gfd_channels, kernel_size=1)) self.heatmap_conv = build_conv_layer( dict( type='Conv2d', in_channels=gfd_channels, out_channels=out_channels, kernel_size=1)) self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta) def forward( self, feats: Tensor, instance_feats: Tensor, instance_coords: Tensor, instance_imgids: Tensor, ) -> Tensor: """Extract decoupled heatmaps for each instance. Args: feats (Tensor): Input feature maps. instance_feats (Tensor): Tensor containing the instance feature vectors. instance_coords (Tensor): Tensor containing the root coordinates of the instances. instance_imgids (Tensor): Sample indices of each instances in the batch. Returns: A tensor containing decoupled heatmaps. """ global_feats = self.conv_down(feats) global_feats = global_feats[instance_imgids] cond_instance_feats = torch.cat( (self.channel_attention(global_feats, instance_feats), self.spatial_attention(global_feats, instance_feats, instance_coords)), dim=1) cond_instance_feats = self.fuse_attention(cond_instance_feats) cond_instance_feats = torch.nn.functional.relu(cond_instance_feats) cond_instance_feats = self.heatmap_conv(cond_instance_feats) heatmaps = self.sigmoid(cond_instance_feats) return heatmaps @MODELS.register_module() class CIDHead(BaseHead): """Contextual Instance Decoupling head introduced in `Contextual Instance Decoupling for Robust Multi-Person Pose Estimation (CID)`_ by Wang et al (2022). The head is composed of an Instance Information Abstraction (IIA) module and a Global Feature Decoupling (GFD) module. Args: in_channels (int | Sequence[int]): Number of channels in the input feature map num_keypoints (int): Number of keypoints gfd_channels (int): Number of filters in GFD module max_train_instances (int): Maximum number of instances in a batch during training. Defaults to 200 heatmap_loss (Config): Config of the heatmap loss. Defaults to use :class:`KeypointMSELoss` coupled_heatmap_loss (Config): Config of the loss for coupled heatmaps. Defaults to use :class:`SoftWeightSmoothL1Loss` decoupled_heatmap_loss (Config): Config of the loss for decoupled heatmaps. Defaults to use :class:`SoftWeightSmoothL1Loss` contrastive_loss (Config): Config of the contrastive loss for representation vectors of instances. Defaults to use :class:`InfoNCELoss` decoder (Config, optional): The decoder config that controls decoding keypoint coordinates from the network output. Defaults to ``None`` init_cfg (Config, optional): Config to control the initialization. See :attr:`default_init_cfg` for default settings .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_ Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_ CVPR_2022_paper.html """ _version = 2 def __init__(self, in_channels: Union[int, Sequence[int]], gfd_channels: int, num_keypoints: int, prior_prob: float = 0.01, coupled_heatmap_loss: OptConfigType = dict( type='FocalHeatmapLoss'), decoupled_heatmap_loss: OptConfigType = dict( type='FocalHeatmapLoss'), contrastive_loss: OptConfigType = dict(type='InfoNCELoss'), decoder: OptConfigType = None, init_cfg: OptConfigType = None): if init_cfg is None: init_cfg = self.default_init_cfg super().__init__(init_cfg) self.in_channels = in_channels self.num_keypoints = num_keypoints if decoder is not None: self.decoder = KEYPOINT_CODECS.build(decoder) else: self.decoder = None # build sub-modules bias_value = -math.log((1 - prior_prob) / prior_prob) self.iia_module = IIAModule( in_channels, num_keypoints + 1, init_cfg=init_cfg + [ dict( type='Normal', layer=['Conv2d', 'Linear'], std=0.001, override=dict( name='keypoint_root_conv', type='Normal', std=0.001, bias=bias_value)) ]) self.gfd_module = GFDModule( in_channels, num_keypoints, gfd_channels, init_cfg=init_cfg + [ dict( type='Normal', layer=['Conv2d', 'Linear'], std=0.001, override=dict( name='heatmap_conv', type='Normal', std=0.001, bias=bias_value)) ]) # build losses self.loss_module = ModuleDict( dict( heatmap_coupled=MODELS.build(coupled_heatmap_loss), heatmap_decoupled=MODELS.build(decoupled_heatmap_loss), contrastive=MODELS.build(contrastive_loss), )) # Register the hook to automatically convert old version state dicts self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) @property def default_init_cfg(self): init_cfg = [ dict(type='Normal', layer=['Conv2d', 'Linear'], std=0.001), dict(type='Constant', layer='BatchNorm2d', val=1) ] return init_cfg def forward(self, feats: Tuple[Tensor]) -> Tensor: """Forward the network. The input is multi scale feature maps and the output is the heatmap. Args: feats (Tuple[Tensor]): Multi scale feature maps. Returns: Tensor: output heatmap. """ feats = feats[-1] instance_info = self.iia_module.forward_test(feats, {}) instance_feats, instance_coords, instance_scores = instance_info instance_imgids = torch.zeros( instance_coords.size(0), dtype=torch.long, device=feats.device) instance_heatmaps = self.gfd_module(feats, instance_feats, instance_coords, instance_imgids) return instance_heatmaps def predict(self, feats: Features, batch_data_samples: OptSampleList, test_cfg: ConfigType = {}) -> Predictions: """Predict results from features. Args: feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage features (or multiple multi-stage features in TTA) batch_data_samples (List[:obj:`PoseDataSample`]): The batch data samples test_cfg (dict): The runtime config for testing process. Defaults to {} Returns: Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If ``test_cfg['output_heatmap']==True``, return both pose and heatmap prediction; otherwise only return the pose prediction. The pose prediction is a list of ``InstanceData``, each contains the following fields: - keypoints (np.ndarray): predicted keypoint coordinates in shape (num_instances, K, D) where K is the keypoint number and D is the keypoint dimension - keypoint_scores (np.ndarray): predicted keypoint scores in shape (num_instances, K) The heatmap prediction is a list of ``PixelData``, each contains the following fields: - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) """ metainfo = batch_data_samples[0].metainfo if test_cfg.get('flip_test', False): assert isinstance(feats, list) and len(feats) == 2 feats_flipped = flip_heatmaps(feats[1][-1], shift_heatmap=False) feats = torch.cat((feats[0][-1], feats_flipped)) else: feats = feats[-1] instance_info = self.iia_module.forward_test(feats, test_cfg) instance_feats, instance_coords, instance_scores = instance_info if len(instance_coords) > 0: instance_imgids = torch.zeros( instance_coords.size(0), dtype=torch.long, device=feats.device) if test_cfg.get('flip_test', False): instance_coords = torch.cat((instance_coords, instance_coords)) instance_imgids = torch.cat( (instance_imgids, instance_imgids + 1)) instance_heatmaps = self.gfd_module(feats, instance_feats, instance_coords, instance_imgids) if test_cfg.get('flip_test', False): flip_indices = batch_data_samples[0].metainfo['flip_indices'] instance_heatmaps, instance_heatmaps_flip = torch.chunk( instance_heatmaps, 2, dim=0) instance_heatmaps_flip = \ instance_heatmaps_flip[:, flip_indices, :, :] instance_heatmaps = (instance_heatmaps + instance_heatmaps_flip) / 2.0 instance_heatmaps = smooth_heatmaps( instance_heatmaps, test_cfg.get('blur_kernel_size', 3)) preds = self.decode((instance_heatmaps, instance_scores[:, None])) preds = InstanceData.cat(preds) preds.keypoints[..., 0] += metainfo['input_size'][ 0] / instance_heatmaps.shape[-1] / 2.0 preds.keypoints[..., 1] += metainfo['input_size'][ 1] / instance_heatmaps.shape[-2] / 2.0 preds = [preds] else: preds = [ InstanceData( keypoints=np.empty((0, self.num_keypoints, 2)), keypoint_scores=np.empty((0, self.num_keypoints))) ] instance_heatmaps = torch.empty(0, self.num_keypoints, *feats.shape[-2:]) if test_cfg.get('output_heatmaps', False): pred_fields = [ PixelData( heatmaps=instance_heatmaps.reshape( -1, *instance_heatmaps.shape[-2:])) ] return preds, pred_fields else: return preds def loss(self, feats: Tuple[Tensor], batch_data_samples: OptSampleList, train_cfg: ConfigType = {}) -> dict: """Calculate losses from a batch of inputs and data samples. Args: feats (Tuple[Tensor]): The multi-stage features batch_data_samples (List[:obj:`PoseDataSample`]): The batch data samples train_cfg (dict): The runtime config for training process. Defaults to {} Returns: dict: A dictionary of losses. """ # load targets gt_heatmaps, gt_instance_coords, keypoint_weights = [], [], [] heatmap_mask = [] instance_imgids, gt_instance_heatmaps = [], [] for i, d in enumerate(batch_data_samples): gt_heatmaps.append(d.gt_fields.heatmaps) gt_instance_coords.append(d.gt_instance_labels.instance_coords) keypoint_weights.append(d.gt_instance_labels.keypoint_weights) instance_imgids.append( torch.ones( len(d.gt_instance_labels.instance_coords), dtype=torch.long) * i) instance_heatmaps = d.gt_fields.instance_heatmaps.reshape( -1, self.num_keypoints, *d.gt_fields.instance_heatmaps.shape[1:]) gt_instance_heatmaps.append(instance_heatmaps) if 'heatmap_mask' in d.gt_fields: heatmap_mask.append(d.gt_fields.heatmap_mask) gt_heatmaps = torch.stack(gt_heatmaps) heatmap_mask = torch.stack(heatmap_mask) if heatmap_mask else None gt_instance_coords = torch.cat(gt_instance_coords, dim=0) gt_instance_heatmaps = torch.cat(gt_instance_heatmaps, dim=0) keypoint_weights = torch.cat(keypoint_weights, dim=0) instance_imgids = torch.cat(instance_imgids).to(gt_heatmaps.device) # feed-forward feats = feats[-1] pred_instance_feats, pred_heatmaps = self.iia_module.forward_train( feats, gt_instance_coords, instance_imgids) # conpute contrastive loss contrastive_loss = 0 for i in range(len(batch_data_samples)): pred_instance_feat = pred_instance_feats[instance_imgids == i] contrastive_loss += self.loss_module['contrastive']( pred_instance_feat) contrastive_loss = contrastive_loss / max(1, len(instance_imgids)) # limit the number of instances max_train_instances = train_cfg.get('max_train_instances', -1) if (max_train_instances > 0 and len(instance_imgids) > max_train_instances): selected_indices = torch.randperm( len(instance_imgids), device=gt_heatmaps.device, dtype=torch.long)[:max_train_instances] gt_instance_coords = gt_instance_coords[selected_indices] keypoint_weights = keypoint_weights[selected_indices] gt_instance_heatmaps = gt_instance_heatmaps[selected_indices] instance_imgids = instance_imgids[selected_indices] pred_instance_feats = pred_instance_feats[selected_indices] # calculate the decoupled heatmaps for each instance pred_instance_heatmaps = self.gfd_module(feats, pred_instance_feats, gt_instance_coords, instance_imgids) # calculate losses losses = { 'loss/heatmap_coupled': self.loss_module['heatmap_coupled'](pred_heatmaps, gt_heatmaps, None, heatmap_mask) } if len(instance_imgids) > 0: losses.update({ 'loss/heatmap_decoupled': self.loss_module['heatmap_decoupled'](pred_instance_heatmaps, gt_instance_heatmaps, keypoint_weights), 'loss/contrastive': contrastive_loss }) return losses def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, **kwargs): """A hook function to convert old-version state dict of :class:`CIDHead` (before MMPose v1.0.0) to a compatible format of :class:`CIDHead`. The hook will be automatically registered during initialization. """ version = local_meta.get('version', None) if version and version >= self._version: return # convert old-version state dict keys = list(state_dict.keys()) for k in keys: if 'keypoint_center_conv' in k: v = state_dict.pop(k) k = k.replace('keypoint_center_conv', 'iia_module.keypoint_root_conv') state_dict[k] = v if 'conv_down' in k: v = state_dict.pop(k) k = k.replace('conv_down', 'gfd_module.conv_down') state_dict[k] = v if 'c_attn' in k: v = state_dict.pop(k) k = k.replace('c_attn', 'gfd_module.channel_attention') state_dict[k] = v if 's_attn' in k: v = state_dict.pop(k) k = k.replace('s_attn', 'gfd_module.spatial_attention') state_dict[k] = v if 'fuse_attn' in k: v = state_dict.pop(k) k = k.replace('fuse_attn', 'gfd_module.fuse_attention') state_dict[k] = v if 'heatmap_conv' in k: v = state_dict.pop(k) k = k.replace('heatmap_conv', 'gfd_module.heatmap_conv') state_dict[k] = v