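# NOTE(review): the bare text "Spaces: / Runtime error / Runtime error"
# preceded this module; it looks like captured page status text rather than
# source and is kept only as a comment -- confirm and remove.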
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union | |
import numpy as np | |
import torch | |
from mmcv.transforms import BaseTransform | |
from mmengine.structures import InstanceData, PixelData | |
from mmengine.utils import is_seq_of | |
from mmpose.registry import TRANSFORMS | |
from mmpose.structures import MultilevelPixelData, PoseDataSample | |
def image_to_tensor(
        img: Union[np.ndarray, Sequence[np.ndarray]]) -> torch.Tensor:
    """Translate image or sequence of images to tensor.

    A single array with fewer than 3 dimensions first gets a trailing axis
    appended; its last axis is then moved to the front, e.g.
    ``(H, W, C) -> (C, H, W)``. A sequence of arrays is converted element by
    element and the resulting tensors are stacked along a new leading axis.

    Args:
        img (np.ndarray | Sequence[np.ndarray]): The original image or
            image sequence

    Returns:
        torch.Tensor: The output tensor.
    """
    if isinstance(img, np.ndarray):
        if img.ndim < 3:
            # e.g. a grayscale (H, W) image: add the missing last axis so
            # the permutation below always sees three axes
            img = np.expand_dims(img, -1)

        img = np.ascontiguousarray(img)
        tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
    else:
        assert is_seq_of(img, np.ndarray), (
            'img must be an np.ndarray or a sequence of np.ndarray, '
            f'but got {type(img)}')
        tensor = torch.stack([image_to_tensor(_img) for _img in img])

    return tensor
# NOTE(review): ``TRANSFORMS`` is imported by this module but this class is
# not registered with it here -- confirm whether a
# ``@TRANSFORMS.register_module()`` decorator is expected.
class PackPoseInputs(BaseTransform):
    """Pack the inputs data for pose estimation.

    The ``img_meta`` item is always populated. The contents of the
    ``img_meta`` dictionary depends on ``meta_keys``. By default it includes:

        - ``id``: id of the data sample

        - ``img_id``: id of the image

        - ``'category_id'``: the id of the instance category

        - ``img_path``: path to the image file

        - ``crowd_index`` (optional): measure the crowding level of an image,
            defined in CrowdPose dataset

        - ``ori_shape``: original shape of the image as a tuple (h, w, c)

        - ``img_shape``: shape of the image input to the network as a tuple \
            (h, w).  Note that images may be zero padded on the \
            bottom/right if the batch tensor is larger than this shape.

        - ``input_size``: the input size to the network

        - ``flip``: a boolean indicating if image flip transform was used

        - ``flip_direction``: the flipping direction

        - ``flip_indices``: the indices of each keypoint's symmetric keypoint

        - ``raw_ann_info`` (optional): raw annotation of the instance(s)

    Args:
        meta_keys (Sequence[str], optional): Meta keys which will be stored in
            :obj: `PoseDataSample` as meta info. Defaults to ``('id',
            'img_id', 'img_path', 'category_id', 'crowd_index', 'ori_shape',
            'img_shape', 'input_size', 'input_center', 'input_scale', 'flip',
            'flip_direction', 'flip_indices', 'raw_ann_info')``
        pack_transformed (bool, optional): Whether to also pack
            ``results['transformed_keypoints']`` (when present) into
            ``gt_instances``, for visualizing data transform and
            augmentation results. Defaults to ``False``
    """

    # items in `instance_mapping_table` will be directly packed into
    # PoseDataSample.gt_instances without converting to Tensor
    instance_mapping_table = {
        'bbox': 'bboxes',
        'head_size': 'head_size',
        'bbox_center': 'bbox_centers',
        'bbox_scale': 'bbox_scales',
        'bbox_score': 'bbox_scores',
        'keypoints': 'keypoints',
        'keypoints_visible': 'keypoints_visible',
    }

    # items in `label_mapping_table` will be packed into
    # PoseDataSample.gt_instance_labels and converted to Tensor. These items
    # will be used for computing losses
    label_mapping_table = {
        'keypoint_labels': 'keypoint_labels',
        'keypoint_x_labels': 'keypoint_x_labels',
        'keypoint_y_labels': 'keypoint_y_labels',
        'keypoint_weights': 'keypoint_weights',
        'instance_coords': 'instance_coords'
    }

    # items in `field_mapping_table` will be packed into
    # PoseDataSample.gt_fields and converted to Tensor. These items will be
    # used for computing losses
    field_mapping_table = {
        'heatmaps': 'heatmaps',
        'instance_heatmaps': 'instance_heatmaps',
        'heatmap_mask': 'heatmap_mask',
        'heatmap_weights': 'heatmap_weights',
        'displacements': 'displacements',
        'displacement_weights': 'displacement_weights',
    }

    def __init__(self,
                 meta_keys: Sequence[str] = ('id', 'img_id', 'img_path',
                                             'category_id', 'crowd_index',
                                             'ori_shape', 'img_shape',
                                             'input_size', 'input_center',
                                             'input_scale', 'flip',
                                             'flip_direction', 'flip_indices',
                                             'raw_ann_info'),
                 pack_transformed: bool = False) -> None:
        self.meta_keys = meta_keys
        self.pack_transformed = pack_transformed

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`PoseDataSample`): The annotation info of the
                sample.

        Raises:
            KeyError: If ``results`` does not contain the key ``'img'``.
        """
        # Pack image(s). Fail early with a clear message: without an image
        # there is nothing to put into 'inputs'.
        if 'img' not in results:
            raise KeyError(
                f'{self.__class__.__name__} requires the key "img" in '
                '`results` to build the model inputs, but it is missing.')
        img_tensor = image_to_tensor(results['img'])

        data_sample = PoseDataSample()

        # pack instance data
        gt_instances = InstanceData()
        for key, packed_key in self.instance_mapping_table.items():
            if key in results:
                gt_instances.set_field(results[key], packed_key)

        # pack `transformed_keypoints` for visualizing data transform
        # and augmentation results
        if self.pack_transformed and 'transformed_keypoints' in results:
            gt_instances.set_field(results['transformed_keypoints'],
                                   'transformed_keypoints')

        data_sample.gt_instances = gt_instances

        # pack instance labels
        gt_instance_labels = InstanceData()
        for key, packed_key in self.label_mapping_table.items():
            if key not in results:
                continue
            labels = results[key]
            if isinstance(labels, list):
                # A list of labels is usually generated by combined
                # multiple encoders (See ``GenerateTarget`` in
                # mmpose/datasets/transforms/common_transforms.py)
                # In this case, labels in list should have the same
                # shape and will be stacked.
                labels = np.stack(labels)
            gt_instance_labels.set_field(labels, packed_key)
        data_sample.gt_instance_labels = gt_instance_labels.to_tensor()

        # pack fields
        gt_fields = None
        for key, packed_key in self.field_mapping_table.items():
            if key not in results:
                continue
            field = results[key]
            # a list value means multi-level data; anything else is packed
            # as single-level pixel data
            if isinstance(field, list):
                field_cls = MultilevelPixelData
            else:
                field_cls = PixelData

            if gt_fields is None:
                gt_fields = field_cls()
            else:
                # the container type is fixed by the first packed field;
                # all later fields must agree with it
                assert isinstance(
                    gt_fields, field_cls
                ), 'Got mixed single-level and multi-level pixel data.'

            gt_fields.set_field(field, packed_key)

        # compare against None explicitly instead of relying on the
        # truthiness of the data element
        if gt_fields is not None:
            data_sample.gt_fields = gt_fields.to_tensor()

        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)

        packed_results = dict()
        packed_results['inputs'] = img_tensor
        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys}, '
        repr_str += f'pack_transformed={self.pack_transformed})'
        return repr_str