liyy201912's picture
Upload folder using huggingface_hub
cc0dd3c
raw
history blame
7.95 kB
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Tuple, Union
import torch
from mmengine.model import BaseModel
from torch import Tensor
from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.models.utils import check_and_update_config
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType,
Optional, OptMultiConfig, OptSampleList,
SampleList)
class BasePoseEstimator(BaseModel, metaclass=ABCMeta):
    """Base class for pose estimators.

    Subclasses must implement :meth:`loss` and :meth:`predict`. This base
    class builds the backbone (and the optional neck and head) from their
    configs and dispatches :meth:`forward` according to ``mode``.

    Args:
        backbone (dict | ConfigDict): The backbone config, built with
            ``MODELS.build``
        neck (dict | ConfigDict, optional): The neck config. The ``neck``
            attribute is only created when this is not ``None``.
            Defaults to ``None``
        head (dict | ConfigDict, optional): The head config. The ``head``
            attribute is only created when this is not ``None``.
            Defaults to ``None``
        train_cfg (dict | ConfigDict, optional): The runtime config for
            training, stored as ``self.train_cfg`` (an empty dict if not
            given). Defaults to ``None``
        test_cfg (dict | ConfigDict, optional): The runtime config for
            testing, stored as ``self.test_cfg`` (an empty dict if not
            given). Defaults to ``None``
        data_preprocessor (dict | ConfigDict, optional): The pre-processing
            config of :class:`BaseDataPreprocessor`. Defaults to ``None``
        init_cfg (dict | ConfigDict): The model initialization config.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    # State-dict format version. Checkpoints whose metadata reports a lower
    # (or missing) version are converted by ``_load_state_dict_pre_hook``.
    _version = 2

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.metainfo = self._load_metainfo(metainfo)

        self.backbone = MODELS.build(backbone)

        # the PR #2108 and #2126 modified the interface of neck and head.
        # The following function automatically detects outdated
        # configurations and updates them accordingly, while also providing
        # clear and concise information on the changes made.
        neck, head = check_and_update_config(neck, head)

        # ``neck`` / ``head`` attributes are deliberately left undefined when
        # no config is given; ``with_neck`` / ``with_head`` rely on that.
        if neck is not None:
            self.neck = MODELS.build(neck)

        if head is not None:
            self.head = MODELS.build(head)

        self.train_cfg = train_cfg if train_cfg else {}
        self.test_cfg = test_cfg if test_cfg else {}

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    @property
    def with_neck(self) -> bool:
        """bool: whether the pose estimator has a neck."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """bool: whether the pose estimator has a head."""
        return hasattr(self, 'head') and self.head is not None

    @staticmethod
    def _load_metainfo(metainfo: Optional[dict] = None) -> Optional[dict]:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (dict, optional): Raw data of pose meta information.
                Defaults to ``None``

        Returns:
            dict | None: Parsed meta information, or ``None`` if no
            ``metainfo`` is given.

        Raises:
            TypeError: If ``metainfo`` is neither ``None`` nor a dict.
        """
        if metainfo is None:
            return None

        if not isinstance(metainfo, dict):
            raise TypeError(
                f'metainfo should be a dict, but got {type(metainfo)}')

        return parse_pose_metainfo(metainfo)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: 'tensor', 'predict' and 'loss':

        - 'tensor': Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - 'predict': Forward and return the predictions, which are fully
          processed to a list of :obj:`PoseDataSample`.
        - 'loss': Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general
            data_samples (list[:obj:`PoseDataSample`], optional): The
                annotation of every sample. Defaults to ``None``
            mode (str): Set the forward mode and return value type. Defaults
                to ``'tensor'``

        Returns:
            The return type depends on ``mode``.

            - If ``mode='tensor'``, return a tensor or a tuple of tensors
            - If ``mode='predict'``, return a list of :obj:``PoseDataSample``
              that contains the pose predictions
            - If ``mode='loss'``, return a dict of tensor(s) which is the loss
              function value

        Raises:
            RuntimeError: If ``mode`` is not one of ``'loss'``,
                ``'predict'`` or ``'tensor'``.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)

        if mode == 'predict':
            # use custom metainfo to override the default metainfo; skip the
            # override when no data samples are given, since there is nothing
            # to update and iterating ``None`` would raise a TypeError
            if self.metainfo is not None and data_samples is not None:
                for data_sample in data_samples:
                    data_sample.set_metainfo(self.metainfo)
            return self.predict(inputs, data_samples)

        if mode == 'tensor':
            return self._forward(inputs)

        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode.')

    @abstractmethod
    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    @abstractmethod
    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (list[:obj:`PoseDataSample`], optional): Not used
                by this implementation. Defaults to ``None``

        Returns:
            Union[Tensor | Tuple[Tensor]]: forward output of the network.
        """
        x = self.extract_feat(inputs)
        if self.with_head:
            # NOTE(review): ``head.forward`` is invoked directly rather than
            # ``self.head(x)``; kept as-is to preserve existing behavior.
            x = self.head.forward(x)

        return x

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions.
        """
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)

        return x

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function to convert old-version state dict of
        :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a
        compatible format of :class:`HeatmapHead`.

        Every key of ``state_dict`` that contains ``'keypoint_head'`` is
        renamed in place so that it uses ``'head'`` instead. Nothing is done
        when ``local_meta`` reports a version that is not older than
        ``self._version``.

        The hook will be automatically registered during initialization.

        Args:
            state_dict (dict): The state dict being loaded; modified in
                place.
            prefix (str): Not used by this hook.
            local_meta (dict): Metadata of the state dict; only its
                ``'version'`` entry is read.
        """
        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # convert old-version state dict; snapshot the keys first because
        # the dict is mutated while renaming
        for old_key in list(state_dict.keys()):
            if 'keypoint_head' in old_key:
                value = state_dict.pop(old_key)
                new_key = old_key.replace('keypoint_head', 'head')
                state_dict[new_key] = value