Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import warnings | |
from typing import List | |
import cv2 | |
import numpy as np | |
import torch | |
from mmengine.structures import InstanceData, PixelData | |
from mmengine.utils import is_list_of | |
from .bbox.transforms import get_warp_matrix | |
from .pose_data_sample import PoseDataSample | |
def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample:
    """Combine several data samples into one data sample.

    A typical use is gathering the top-down predictions that were made for
    the individual bboxes of one image. The instances of all inputs are
    concatenated, heatmaps are reverted onto the original image and reduced
    with an element-wise maximum, and the metainfo is copied from the first
    input data sample.

    Args:
        data_samples (List[:obj:`PoseDataSample`]): The data samples to
            merge

    Returns:
        PoseDataSample: The merged data sample. For an empty input list a
        warning is issued and an empty data sample is returned.

    Raises:
        ValueError: If ``data_samples`` is not a list of
            :obj:`PoseDataSample`.
    """
    if not is_list_of(data_samples, PoseDataSample):
        raise ValueError('Invalid input type, should be a list of '
                         ':obj:`PoseDataSample`')

    if not data_samples:
        warnings.warn('Try to merge an empty list of data samples.')
        return PoseDataSample()

    # The first sample alone decides which optional parts are merged.
    first = data_samples[0]
    merged = PoseDataSample(metainfo=first.metainfo)

    if 'gt_instances' in first:
        gt_parts = [sample.gt_instances for sample in data_samples]
        merged.gt_instances = InstanceData.cat(gt_parts)

    if 'pred_instances' in first:
        pred_parts = [sample.pred_instances for sample in data_samples]
        merged.pred_instances = InstanceData.cat(pred_parts)

    # Predicted and ground-truth heatmaps are merged the same way; in both
    # cases the bbox geometry is read from ``gt_instances``.
    for field_name in ('pred_fields', 'gt_fields'):
        if field_name not in first:
            continue
        if 'heatmaps' not in getattr(first, field_name):
            continue

        per_sample_heatmaps = []
        for sample in data_samples:
            per_sample_heatmaps.append(
                revert_heatmap(
                    getattr(sample, field_name).heatmaps,
                    sample.gt_instances.bbox_centers,
                    sample.gt_instances.bbox_scales,
                    sample.ori_shape))

        # Element-wise maximum over the reverted heatmaps of all samples.
        fused_heatmaps = np.max(per_sample_heatmaps, axis=0)

        pixel_data = PixelData()
        pixel_data.set_data(dict(heatmaps=fused_heatmaps))
        setattr(merged, field_name, pixel_data)

    return merged
def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape):
    """Revert predicted heatmap on the original image.

    Args:
        heatmap (np.ndarray or torch.tensor): predicted heatmap. A 3-D
            input is treated as [K, H, W] and warped channel-wise; any
            other input is passed to OpenCV as it is.
        bbox_center (np.ndarray): bounding box center coordinate.
        bbox_scale (np.ndarray): bounding box scale.
        img_shape (tuple or list): size of original image, unpacked as
            (height, width).

    Returns:
        np.ndarray: The heatmap warped to the size of the original image.
        For a [K, H, W] input the result is [K, img_h, img_w].
    """
    if torch.is_tensor(heatmap):
        heatmap = heatmap.cpu().detach().numpy()

    ndim = heatmap.ndim
    # [K, H, W] -> [H, W, K]
    if ndim == 3:
        heatmap = heatmap.transpose(1, 2, 0)

    hm_h, hm_w = heatmap.shape[:2]
    img_h, img_w = img_shape

    # inv=True requests the heatmap -> image direction of the transform.
    warp_mat = get_warp_matrix(
        bbox_center.reshape((2, )),
        bbox_scale.reshape((2, )),
        rot=0,
        output_size=(hm_w, hm_h),
        inv=True)

    heatmap = cv2.warpAffine(
        heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR)

    # [H, W, K] -> [K, H, W]
    if ndim == 3:
        # OpenCV hands back a 2-D array for single-channel images, so a
        # K == 1 heatmap loses its channel axis during warping. Restore it,
        # otherwise the transpose below fails on the 2-D result.
        if heatmap.ndim == 2:
            heatmap = heatmap[:, :, None]
        heatmap = heatmap.transpose(2, 0, 1)

    return heatmap
def split_instances(instances: InstanceData) -> List[dict]:
    """Convert instances into a list where each element is a dict that contains
    information about one instance.

    Args:
        instances (:obj:`InstanceData`): The instances to split. ``None``
            is accepted and stands for "no instance detected".

    Returns:
        List[dict]: One dict per instance with the keys ``keypoints`` and
        ``keypoint_scores``. ``bbox`` is added when ``instances`` has
        ``bboxes``, and ``bbox_score`` when it additionally has
        ``bbox_scores``.
    """
    results = []

    # return an empty list if there is no instance detected by the model
    if instances is None:
        return results

    for i in range(len(instances.keypoints)):
        result = dict(
            keypoints=instances.keypoints[i].tolist(),
            keypoint_scores=instances.keypoint_scores[i].tolist(),
        )
        if 'bboxes' in instances:
            # Store the box itself as a plain list. No trailing comma here:
            # it would silently wrap the list in a 1-tuple.
            result['bbox'] = instances.bboxes[i].tolist()
            if 'bbox_scores' in instances:
                result['bbox_score'] = instances.bbox_scores[i]
        results.append(result)

    return results