# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Tuple, Union

import torch
from mmengine.model import BaseModel
from torch import Tensor

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.models.utils import check_and_update_config
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType,
                                 Optional, OptMultiConfig, OptSampleList,
                                 SampleList)


class BasePoseEstimator(BaseModel, metaclass=ABCMeta):
    """Base class for pose estimators.

    Args:
        data_preprocessor (dict | ConfigDict, optional): The pre-processing
            config of :class:`BaseDataPreprocessor`. Defaults to ``None``.
        init_cfg (dict | ConfigDict): The model initialization config.
            Defaults to ``None``.
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input
            data batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``.
    """
    _version = 2

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.metainfo = self._load_metainfo(metainfo)

        self.backbone = MODELS.build(backbone)

        # PRs #2108 and #2126 modified the interface of the neck and head.
        # The following function automatically detects outdated
        # configurations and updates them accordingly, while also providing
        # clear and concise information on the changes made.
        neck, head = check_and_update_config(neck, head)

        if neck is not None:
            self.neck = MODELS.build(neck)

        if head is not None:
            self.head = MODELS.build(head)

        self.train_cfg = train_cfg if train_cfg else {}
        self.test_cfg = test_cfg if test_cfg else {}

        # Register the hook to automatically convert old-version state dicts
        self._register_load_state_dict_pre_hook(
            self._load_state_dict_pre_hook)

    @property
    def with_neck(self) -> bool:
        """bool: whether the pose estimator has a neck."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """bool: whether the pose estimator has a head."""
        return hasattr(self, 'head') and self.head is not None

    @staticmethod
    def _load_metainfo(metainfo: Optional[dict] = None) -> Optional[dict]:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (dict): Raw data of pose meta information.

        Returns:
            dict: Parsed meta information.
        """
        if metainfo is None:
            return None

        if not isinstance(metainfo, dict):
            raise TypeError(
                f'metainfo should be a dict, but got {type(metainfo)}')

        metainfo = parse_pose_metainfo(metainfo)
        return metainfo

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: 'tensor', 'predict' and 'loss':

        - 'tensor': Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - 'predict': Forward and return the predictions, which are fully
          processed to a list of :obj:`PoseDataSample`.
        - 'loss': Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method does not handle back propagation or optimizer
        updating, which are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...)
                in general.
            data_samples (list[:obj:`PoseDataSample`], optional): The
                annotation of every sample. Defaults to ``None``.
            mode (str): Set the forward mode and return value type. Defaults
                to ``'tensor'``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode='tensor'``, return a tensor or a tuple of tensors
            - If ``mode='predict'``, return a list of :obj:`PoseDataSample`
              that contains the pose predictions
            - If ``mode='loss'``, return a dict of tensor(s) which is the
              loss function value
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            # use custom metainfo to override the default metainfo
            if self.metainfo is not None:
                for data_sample in data_samples:
                    data_sample.set_metainfo(self.metainfo)
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode.')

    @abstractmethod
    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    @abstractmethod
    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing."""

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            Union[Tensor, Tuple[Tensor]]: forward output of the network.
        """
        x = self.extract_feat(inputs)
        if self.with_head:
            x = self.head.forward(x)

        return x

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions.
        """
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)

        return x

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function to convert old-version state dicts of
        :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a format
        compatible with :class:`HeatmapHead`.

        The hook is automatically registered during initialization.
        """
        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # convert old-version state dict
        keys = list(state_dict.keys())
        for k in keys:
            if 'keypoint_head' in k:
                v = state_dict.pop(k)
                k = k.replace('keypoint_head', 'head')
                state_dict[k] = v
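

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): a minimal, hypothetical
# subclass showing how the abstract ``loss``/``predict`` hooks plug into the
# unified ``forward`` dispatch above. It assumes a head that follows the
# common MMPose head interface, i.e. ``head.loss(feats, data_samples,
# train_cfg=...)`` returning a loss dict and ``head.predict(feats,
# data_samples, test_cfg=...)`` returning one ``InstanceData`` per sample.
# The class name and the way predictions are attached to the data samples
# are assumptions for demonstration only; the class is intentionally not
# registered in ``MODELS``.
class _ToyPoseEstimator(BasePoseEstimator):
    """A toy estimator that only exists to illustrate the base-class API."""

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        # ``mode='loss'`` in ``forward`` ends up here: extract features and
        # delegate the loss computation to the head.
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples, train_cfg=self.train_cfg)

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        # ``mode='predict'`` in ``forward`` ends up here: decode the head
        # outputs and attach them to the data samples so that callers
        # receive a fully populated ``SampleList``.
        feats = self.extract_feat(inputs)
        preds = self.head.predict(
            feats, data_samples, test_cfg=self.test_cfg)
        for data_sample, pred_instances in zip(data_samples, preds):
            data_sample.pred_instances = pred_instances
        return data_samples


# With such a subclass, ``model(inputs, data_samples, mode='loss')`` returns
# the loss dict, ``mode='predict'`` returns the updated data samples, and
# ``mode='tensor'`` returns the raw head (or backbone/neck) output.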