# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union
import numpy as np
import torch
from mmcv.transforms import BaseTransform
from mmengine.structures import InstanceData, PixelData
from mmengine.utils import is_seq_of
from mmpose.registry import TRANSFORMS
from mmpose.structures import MultilevelPixelData, PoseDataSample


def image_to_tensor(img: Union[np.ndarray,
                               Sequence[np.ndarray]]) -> torch.Tensor:
"""Translate image or sequence of images to tensor. Multiple image tensors
will be stacked.
Args:
value (np.ndarray | Sequence[np.ndarray]): The original image or
image sequence
Returns:
torch.Tensor: The output tensor.
"""
if isinstance(img, np.ndarray):
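        # give gray-scale (H, W) images a trailing channel axis -> (H, W, 1)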
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img)
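        # convert to tensor and reorder from (H, W, C) to (C, H, W)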
tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
else:
assert is_seq_of(img, np.ndarray)
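        # recursively convert each image and stack along a new batch axis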
tensor = torch.stack([image_to_tensor(_img) for _img in img])
return tensor
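

# An illustrative usage sketch of ``image_to_tensor`` (added for clarity;
# the image size below is an assumption, not from the original file):
#
#   >>> img = np.zeros((256, 192, 3), dtype=np.uint8)
#   >>> image_to_tensor(img).shape         # (H, W, C) -> (C, H, W)
#   torch.Size([3, 256, 192])
#   >>> image_to_tensor([img, img]).shape  # sequences stack to (N, C, H, W)
#   torch.Size([2, 3, 256, 192])
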
@TRANSFORMS.register_module()
class PackPoseInputs(BaseTransform):
"""Pack the inputs data for pose estimation.
The ``img_meta`` item is always populated. The contents of the
``img_meta`` dictionary depends on ``meta_keys``. By default it includes:
- ``id``: id of the data sample
- ``img_id``: id of the image
    - ``category_id``: the id of the instance category
- ``img_path``: path to the image file
    - ``crowd_index`` (optional): a measure of the crowding level of an
        image, defined in the CrowdPose dataset
- ``ori_shape``: original shape of the image as a tuple (h, w, c)
- ``img_shape``: shape of the image input to the network as a tuple \
(h, w). Note that images may be zero padded on the \
bottom/right if the batch tensor is larger than this shape.
- ``input_size``: the input size to the network
- ``flip``: a boolean indicating if image flip transform was used
- ``flip_direction``: the flipping direction
- ``flip_indices``: the indices of each keypoint's symmetric keypoint
    - ``raw_ann_info`` (optional): raw annotation of the instance(s)

    Args:
        meta_keys (Sequence[str], optional): Meta keys which will be stored
            in :obj:`PoseDataSample` as meta info. Defaults to ``('id',
            'img_id', 'img_path', 'category_id', 'crowd_index', 'ori_shape',
            'img_shape', 'input_size', 'input_center', 'input_scale', 'flip',
            'flip_direction', 'flip_indices', 'raw_ann_info')``
        pack_transformed (bool): Whether to pack the ``transformed_keypoints``
            into :obj:`PoseDataSample.gt_instances` for visualizing the
            results of data transforms and augmentations. Defaults to
            ``False``
    """

    # items in `instance_mapping_table` will be directly packed into
# PoseDataSample.gt_instances without converting to Tensor
instance_mapping_table = {
'bbox': 'bboxes',
'head_size': 'head_size',
'bbox_center': 'bbox_centers',
'bbox_scale': 'bbox_scales',
'bbox_score': 'bbox_scores',
'keypoints': 'keypoints',
'keypoints_visible': 'keypoints_visible',
}
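
    # For example (illustrative): with the table above, ``results['bbox']``
    # is stored as ``data_sample.gt_instances.bboxes`` as-is, keeping its
    # numpy dtype rather than being converted to a Tensor.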

    # items in `label_mapping_table` will be packed into
# PoseDataSample.gt_instance_labels and converted to Tensor. These items
# will be used for computing losses
label_mapping_table = {
'keypoint_labels': 'keypoint_labels',
'keypoint_x_labels': 'keypoint_x_labels',
'keypoint_y_labels': 'keypoint_y_labels',
'keypoint_weights': 'keypoint_weights',
'instance_coords': 'instance_coords'
}

    # items in `field_mapping_table` will be packed into
# PoseDataSample.gt_fields and converted to Tensor. These items will be
# used for computing losses
field_mapping_table = {
'heatmaps': 'heatmaps',
'instance_heatmaps': 'instance_heatmaps',
'heatmap_mask': 'heatmap_mask',
'heatmap_weights': 'heatmap_weights',
'displacements': 'displacements',
'displacement_weights': 'displacement_weights',
}
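
    # For example (illustrative): a single ``heatmaps`` array is packed into
    # a PixelData, while a list of per-level heatmaps (e.g. produced by
    # multiple encoders) is packed into a MultilevelPixelData; see
    # ``transform`` below.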

    def __init__(self,
meta_keys=('id', 'img_id', 'img_path', 'category_id',
'crowd_index', 'ori_shape', 'img_shape',
'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices',
'raw_ann_info'),
pack_transformed=False):
self.meta_keys = meta_keys
self.pack_transformed = pack_transformed

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`PoseDataSample`): The annotation info of
              the sample.
        """
# Pack image(s)
if 'img' in results:
img = results['img']
img_tensor = image_to_tensor(img)
data_sample = PoseDataSample()
# pack instance data
gt_instances = InstanceData()
for key, packed_key in self.instance_mapping_table.items():
if key in results:
gt_instances.set_field(results[key], packed_key)
# pack `transformed_keypoints` for visualizing data transform
# and augmentation results
if self.pack_transformed and 'transformed_keypoints' in results:
gt_instances.set_field(results['transformed_keypoints'],
'transformed_keypoints')
data_sample.gt_instances = gt_instances
# pack instance labels
gt_instance_labels = InstanceData()
for key, packed_key in self.label_mapping_table.items():
if key in results:
if isinstance(results[key], list):
# A list of labels is usually generated by combined
# multiple encoders (See ``GenerateTarget`` in
# mmpose/datasets/transforms/common_transforms.py)
# In this case, labels in list should have the same
# shape and will be stacked.
_labels = np.stack(results[key])
gt_instance_labels.set_field(_labels, packed_key)
else:
gt_instance_labels.set_field(results[key], packed_key)
data_sample.gt_instance_labels = gt_instance_labels.to_tensor()
# pack fields
gt_fields = None
for key, packed_key in self.field_mapping_table.items():
if key in results:
if isinstance(results[key], list):
if gt_fields is None:
gt_fields = MultilevelPixelData()
else:
assert isinstance(
gt_fields, MultilevelPixelData
), 'Got mixed single-level and multi-level pixel data.'
else:
if gt_fields is None:
gt_fields = PixelData()
else:
assert isinstance(
gt_fields, PixelData
), 'Got mixed single-level and multi-level pixel data.'
gt_fields.set_field(results[key], packed_key)
if gt_fields:
data_sample.gt_fields = gt_fields.to_tensor()
img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)
packed_results = dict()
packed_results['inputs'] = img_tensor
packed_results['data_samples'] = data_sample
return packed_results

    def __repr__(self) -> str:
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
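

# A minimal end-to-end sketch (illustrative; the field values below are
# assumptions, not part of the original file):
#
#   >>> results = dict(
#   ...     img=np.zeros((256, 192, 3), dtype=np.uint8),
#   ...     id=0, img_id=0, img_path='demo.jpg',
#   ...     bbox=np.array([[0., 0., 192., 256.]], dtype=np.float32),
#   ...     keypoints=np.zeros((1, 17, 2), dtype=np.float32),
#   ...     keypoints_visible=np.ones((1, 17), dtype=np.float32),
#   ...     ori_shape=(256, 192), img_shape=(256, 192))
#   >>> packed = PackPoseInputs()(results)
#   >>> packed['inputs'].shape
#   torch.Size([3, 256, 192])
#   >>> packed['data_samples'].gt_instances.keypoints.shape
#   (1, 17, 2)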