# Spaces: Runtime error  (web-page status banner captured with this file; not part of the module)
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from abc import ABCMeta, abstractmethod | |
| from typing import Tuple, Union | |
| import torch | |
| from mmengine.model import BaseModel | |
| from torch import Tensor | |
| from mmpose.datasets.datasets.utils import parse_pose_metainfo | |
| from mmpose.models.utils import check_and_update_config | |
| from mmpose.registry import MODELS | |
| from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType, | |
| Optional, OptMultiConfig, OptSampleList, | |
| SampleList) | |
class BasePoseEstimator(BaseModel, metaclass=ABCMeta):
    """Base class for pose estimators.

    Subclasses must implement :meth:`loss` and :meth:`predict`.

    Args:
        backbone (dict | ConfigDict): The backbone config. It is built with
            ``MODELS.build``.
        neck (dict | ConfigDict, optional): The neck config. If given, it is
            built with ``MODELS.build``. Defaults to ``None``
        head (dict | ConfigDict, optional): The head config. If given, it is
            built with ``MODELS.build``. Defaults to ``None``
        train_cfg (dict | ConfigDict, optional): The training config. An
            empty dict is stored when it is not given. Defaults to ``None``
        test_cfg (dict | ConfigDict, optional): The testing config. An
            empty dict is stored when it is not given. Defaults to ``None``
        data_preprocessor (dict | ConfigDict, optional): The pre-processing
            config of :class:`BaseDataPreprocessor`. Defaults to ``None``
        init_cfg (dict | ConfigDict): The model initialization config.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    # Compared against the ``version`` entry of a checkpoint's local
    # metadata in ``_load_state_dict_pre_hook``.
    _version = 2

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.metainfo = self._load_metainfo(metainfo)

        self.backbone = MODELS.build(backbone)

        # the PR #2108 and #2126 modified the interface of neck and head.
        # The following function automatically detects outdated
        # configurations and updates them accordingly, while also providing
        # clear and concise information on the changes made.
        neck, head = check_and_update_config(neck, head)

        # ``neck`` / ``head`` attributes only exist when a config is given;
        # ``with_neck`` / ``with_head`` report their presence.
        if neck is not None:
            self.neck = MODELS.build(neck)

        if head is not None:
            self.head = MODELS.build(head)

        self.train_cfg = train_cfg if train_cfg else {}
        self.test_cfg = test_cfg if test_cfg else {}

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    @property
    def with_neck(self) -> bool:
        """bool: whether the pose estimator has a neck."""
        # Must be a property: it is read as ``self.with_neck`` (no call), and
        # a bound method would always evaluate as truthy.
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """bool: whether the pose estimator has a head."""
        # Must be a property for the same reason as ``with_neck``.
        return hasattr(self, 'head') and self.head is not None

    @staticmethod
    def _load_metainfo(metainfo: Optional[dict] = None) -> Optional[dict]:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (dict, optional): Raw data of pose meta information.
                Defaults to ``None``

        Returns:
            dict | None: Parsed meta information, or ``None`` if ``metainfo``
            is ``None``.

        Raises:
            TypeError: If ``metainfo`` is neither ``None`` nor a dict.
        """
        if metainfo is None:
            return None

        if not isinstance(metainfo, dict):
            raise TypeError(
                f'metainfo should be a dict, but got {type(metainfo)}')

        return parse_pose_metainfo(metainfo)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: 'tensor', 'predict' and 'loss':

        - 'tensor': Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - 'predict': Forward and return the predictions, which are fully
          processed to a list of :obj:`PoseDataSample`.
        - 'loss': Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general
            data_samples (list[:obj:`PoseDataSample`], optional): The
                annotation of every sample. Defaults to ``None``
            mode (str): Set the forward mode and return value type. Defaults
                to ``'tensor'``

        Returns:
            The return type depends on ``mode``.

            - If ``mode='tensor'``, return a tensor or a tuple of tensors
            - If ``mode='predict'``, return a list of :obj:`PoseDataSample`
              that contains the pose predictions
            - If ``mode='loss'``, return a dict of tensor(s) which is the loss
              function value

        Raises:
            RuntimeError: If ``mode`` is not one of ``'loss'``, ``'predict'``
                or ``'tensor'``.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            # use custom metainfo to override the default metainfo; skipped
            # when no data samples are given so ``None`` is not iterated
            if self.metainfo is not None and data_samples is not None:
                for data_sample in data_samples:
                    data_sample.set_metainfo(self.metainfo)
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode.')

    @abstractmethod
    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    @abstractmethod
    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (list[:obj:`PoseDataSample`], optional): Not used
                by this implementation. Defaults to ``None``

        Returns:
            Union[Tensor | Tuple[Tensor]]: forward output of the network.
        """
        x = self.extract_feat(inputs)
        if self.with_head:
            x = self.head.forward(x)

        return x

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions.
        """
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)

        return x

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function to convert old-version state dict of
        :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a
        compatible format of :class:`HeatmapHead`.

        The hook will be automatically registered during initialization.

        Args:
            state_dict (dict): The state dict being loaded. It is modified
                in place.
            prefix (str): Not used by this hook.
            local_meta (dict): Metadata whose ``'version'`` entry, if any,
                is compared with ``self._version``.
        """
        version = local_meta.get('version', None)
        # Nothing to convert when the checkpoint is already up to date.
        if version and version >= self._version:
            return

        # convert old-version state dict: rename every key containing
        # 'keypoint_head' so that it contains 'head' instead. Iterate over a
        # snapshot of the keys because the dict is mutated in the loop.
        keys = list(state_dict.keys())
        for k in keys:
            if 'keypoint_head' in k:
                v = state_dict.pop(k)
                k = k.replace('keypoint_head', 'head')
                state_dict[k] = v