Spaces:
Runtime error
Runtime error
File size: 4,606 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List
import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData
from mmengine.utils import is_list_of
from .bbox.transforms import get_warp_matrix
from .pose_data_sample import PoseDataSample
def _merge_heatmap_field(data_samples, field_name):
    """Fuse the ``heatmaps`` stored under ``field_name`` of every sample.

    Each sample's heatmap is first mapped back onto its original image with
    :func:`revert_heatmap`. The reverted maps are then combined with an
    element-wise maximum over the sample axis.

    Args:
        data_samples (List[:obj:`PoseDataSample`]): Samples that each provide
            ``<field_name>.heatmaps``, ``gt_instances.bbox_centers``,
            ``gt_instances.bbox_scales`` and ``ori_shape``.
        field_name (str): Name of the pixel-data field to merge, i.e.
            ``'pred_fields'`` or ``'gt_fields'``.

    Returns:
        PixelData: A new container whose ``heatmaps`` is the fused map.
    """
    reverted = [
        revert_heatmap(
            getattr(sample, field_name).heatmaps,
            sample.gt_instances.bbox_centers,
            sample.gt_instances.bbox_scales,
            sample.ori_shape,
        ) for sample in data_samples
    ]
    fused_heatmaps = np.max(reverted, axis=0)
    container = PixelData()
    container.set_data(dict(heatmaps=fused_heatmaps))
    return container


def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample:
    """Combine several data samples into one.

    Typical use: gathering the per-bbox top-down predictions that belong to
    the same image. The result concatenates the ``gt_instances`` and
    ``pred_instances`` of all inputs (when the first input has them). It
    carries the metainfo of the first input. If the first input holds
    ``heatmaps`` in ``pred_fields`` / ``gt_fields``, those heatmaps are
    reverted onto the original image and fused by element-wise maximum.

    Args:
        data_samples (List[:obj:`PoseDataSample`]): The samples to combine.

    Returns:
        PoseDataSample: The combined sample. For an empty input list, an
        empty ``PoseDataSample`` is returned and a warning is emitted.

    Raises:
        ValueError: If ``data_samples`` is not a list of
            :obj:`PoseDataSample`.
    """
    if not is_list_of(data_samples, PoseDataSample):
        raise ValueError('Invalid input type, should be a list of '
                         ':obj:`PoseDataSample`')

    if not data_samples:
        warnings.warn('Try to merge an empty list of data samples.')
        return PoseDataSample()

    # The first sample decides which parts exist and supplies the metainfo.
    first = data_samples[0]
    merged = PoseDataSample(metainfo=first.metainfo)

    for instance_key in ('gt_instances', 'pred_instances'):
        if instance_key in first:
            setattr(
                merged, instance_key,
                InstanceData.cat(
                    [getattr(sample, instance_key)
                     for sample in data_samples]))

    # Prediction heatmaps are merged before ground-truth heatmaps.
    for field_key in ('pred_fields', 'gt_fields'):
        if field_key in first and 'heatmaps' in getattr(first, field_key):
            setattr(merged, field_key,
                    _merge_heatmap_field(data_samples, field_key))

    return merged
def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape):
    """Revert predicted heatmap on the original image.

    The heatmap is warped into the original image frame with the inverse
    affine matrix returned by ``get_warp_matrix(..., inv=True)``. That matrix
    is built from ``bbox_center``, ``bbox_scale`` and the heatmap size.

    Args:
        heatmap (np.ndarray or torch.Tensor): Predicted heatmap, either
            ``[K, H, W]`` (K channels) or ``[H, W]``. Tensors are moved to
            the CPU and converted to numpy first.
        bbox_center (np.ndarray): Bounding box center coordinate. Must hold
            exactly 2 values, since it is reshaped to ``(2, )``.
        bbox_scale (np.ndarray): Bounding box scale. Must hold exactly 2
            values, since it is reshaped to ``(2, )``.
        img_shape (tuple or list): Size of the original image as
            ``(height, width)``.

    Returns:
        np.ndarray: The reverted heatmap. Its shape is ``[K, img_h, img_w]``
        for a 3-D input, or ``[img_h, img_w]`` for a 2-D input.
    """
    if torch.is_tensor(heatmap):
        heatmap = heatmap.cpu().detach().numpy()

    ndim = heatmap.ndim
    # [K, H, W] -> [H, W, K]: OpenCV expects channels-last images.
    if ndim == 3:
        heatmap = heatmap.transpose(1, 2, 0)

    hm_h, hm_w = heatmap.shape[:2]
    img_h, img_w = img_shape
    warp_mat = get_warp_matrix(
        bbox_center.reshape((2, )),
        bbox_scale.reshape((2, )),
        rot=0,
        output_size=(hm_w, hm_h),
        inv=True)

    heatmap = cv2.warpAffine(
        heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR)

    if ndim == 3:
        # OpenCV's Python binding returns a single-channel result as a 2-D
        # array, which drops the trailing axis of a [H, W, 1] input. Restore
        # it so the transpose below also works when K == 1.
        if heatmap.ndim == 2:
            heatmap = heatmap[..., None]
        # [H, W, K] -> [K, H, W]
        heatmap = heatmap.transpose(2, 0, 1)

    return heatmap
def split_instances(instances: InstanceData) -> List[dict]:
    """Convert instances into a list of per-instance dicts.

    Every dict contains ``keypoints`` and ``keypoint_scores``, each converted
    with ``.tolist()``. If ``instances`` has ``bboxes``, the dict also gets a
    ``bbox`` entry. If it additionally has ``bbox_scores``, the dict gets a
    ``bbox_score`` entry (taken as-is, without conversion).

    Args:
        instances (InstanceData | None): The instances to split. ``None``
            means the model detected no instance.

    Returns:
        List[dict]: One dict per instance, in the original order. The list
        is empty when ``instances`` is None.
    """
    # return an empty list if there is no instance detected by the model
    if instances is None:
        return []

    results = []
    for i in range(len(instances.keypoints)):
        result = dict(
            keypoints=instances.keypoints[i].tolist(),
            keypoint_scores=instances.keypoint_scores[i].tolist(),
        )
        if 'bboxes' in instances:
            # NOTE(review): historically a trailing comma made ``bbox`` a
            # 1-tuple wrapping the box list (a nested list once serialized
            # to JSON). Kept explicit because downstream consumers may rely
            # on that shape -- confirm before flattening.
            result['bbox'] = (instances.bboxes[i].tolist(), )
            if 'bbox_scores' in instances:
                result['bbox_score'] = instances.bbox_scores[i]
        results.append(result)

    return results
|