# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData
from mmengine.utils import is_list_of

from .bbox.transforms import get_warp_matrix
from .pose_data_sample import PoseDataSample


def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample:
    """Merge the given data samples into a single data sample.

    This function can be used to merge top-down predictions of multiple
    bboxes from the same image. The merged data sample contains all
    instances from the input data samples, and the metainfo of the first
    input data sample.

    Args:
        data_samples (List[:obj:`PoseDataSample`]): The data samples to
            merge.

    Returns:
        PoseDataSample: The merged data sample.
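
    Example:
        A minimal sketch; ``sample_a`` and ``sample_b`` are assumed to be
        top-down predictions of two bboxes from the same image::

            >>> merged = merge_data_samples([sample_a, sample_b])
            >>> # ``merged`` holds the instances of both inputs and the
            >>> # metainfo of ``sample_a``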
    """

    if not is_list_of(data_samples, PoseDataSample):
        raise ValueError('Invalid input type, should be a list of '
                         ':obj:`PoseDataSample`')

    if len(data_samples) == 0:
        warnings.warn('Trying to merge an empty list of data samples.')
        return PoseDataSample()

    merged = PoseDataSample(metainfo=data_samples[0].metainfo)

    if 'gt_instances' in data_samples[0]:
        merged.gt_instances = InstanceData.cat(
            [d.gt_instances for d in data_samples])

    if 'pred_instances' in data_samples[0]:
        merged.pred_instances = InstanceData.cat(
            [d.pred_instances for d in data_samples])

    if 'pred_fields' in data_samples[0] and 'heatmaps' in data_samples[
            0].pred_fields:
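        # revert each sample's heatmaps from bbox space to the original
        # image space before merging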
        reverted_heatmaps = [
            revert_heatmap(data_sample.pred_fields.heatmaps,
                           data_sample.gt_instances.bbox_centers,
                           data_sample.gt_instances.bbox_scales,
                           data_sample.ori_shape)
            for data_sample in data_samples
        ]

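        # merge by taking the pixel-wise maximum across samples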
        merged_heatmaps = np.max(reverted_heatmaps, axis=0)
        pred_fields = PixelData()
        pred_fields.set_data(dict(heatmaps=merged_heatmaps))
        merged.pred_fields = pred_fields

    if 'gt_fields' in data_samples[0] and 'heatmaps' in data_samples[
            0].gt_fields:
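        # same procedure as above, applied to the ground-truth heatmaps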
        reverted_heatmaps = [
            revert_heatmap(data_sample.gt_fields.heatmaps,
                           data_sample.gt_instances.bbox_centers,
                           data_sample.gt_instances.bbox_scales,
                           data_sample.ori_shape)
            for data_sample in data_samples
        ]

        merged_heatmaps = np.max(reverted_heatmaps, axis=0)
        gt_fields = PixelData()
        gt_fields.set_data(dict(heatmaps=merged_heatmaps))
        merged.gt_fields = gt_fields

    return merged


def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape):
    """Revert predicted heatmap on the original image.

    Args:
        heatmap (np.ndarray or torch.tensor): predicted heatmap.
        bbox_center (np.ndarray): bounding box center coordinate.
        bbox_scale (np.ndarray): bounding box scale.
        img_shape (tuple or list): size of original image.
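
    Example:
        A minimal sketch with made-up sizes; the center and scale values
        are illustrative only::

            >>> import numpy as np
            >>> heatmap = np.random.rand(17, 64, 48).astype(np.float32)
            >>> center = np.array([320., 240.])
            >>> scale = np.array([192., 256.])
            >>> out = revert_heatmap(heatmap, center, scale, (480, 640))
            >>> out.shape
            (17, 480, 640)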
    """
    if torch.is_tensor(heatmap):
        heatmap = heatmap.cpu().detach().numpy()

    ndim = heatmap.ndim
    # [K, H, W] -> [H, W, K]
    if ndim == 3:
        heatmap = heatmap.transpose(1, 2, 0)

    hm_h, hm_w = heatmap.shape[:2]
    img_h, img_w = img_shape
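    # compute the inverse affine transform that maps bbox-space coordinates
    # back to the original image, then warp the heatmap accordingly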
    warp_mat = get_warp_matrix(
        bbox_center.reshape((2, )),
        bbox_scale.reshape((2, )),
        rot=0,
        output_size=(hm_w, hm_h),
        inv=True)

    heatmap = cv2.warpAffine(
        heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR)

    # [H, W, K] -> [K, H, W]
    if ndim == 3:
        heatmap = heatmap.transpose(2, 0, 1)

    return heatmap


def split_instances(instances: InstanceData) -> List[Dict]:
    """Convert instances into a list where each element is a dict that
    contains information about one instance.

    Args:
        instances (InstanceData): The instances to split. ``None`` yields
            an empty list.

    Returns:
        List[Dict]: Per-instance dicts with ``keypoints`` and
        ``keypoint_scores``, plus ``bbox`` and ``bbox_score`` when
        available.
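
    Example:
        A minimal sketch; ``pred`` is assumed to be the ``pred_instances``
        field of a :obj:`PoseDataSample`::

            >>> instance_list = split_instances(pred)
            >>> instance_list[0]['keypoints']  # keypoints of one instance
    """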
    results = []

    # return an empty list if there is no instance detected by the model
    if instances is None:
        return results

    for i in range(len(instances.keypoints)):
        result = dict(
            keypoints=instances.keypoints[i].tolist(),
            keypoint_scores=instances.keypoint_scores[i].tolist(),
        )
        if 'bboxes' in instances:
            # note: no trailing comma here, which would wrap the list in
            # a tuple
            result['bbox'] = instances.bboxes[i].tolist()
            if 'bbox_scores' in instances:
                # convert to a plain Python float for serialization
                result['bbox_score'] = float(instances.bbox_scores[i])
        results.append(result)

    return results