File size: 7,246 Bytes
d4e7f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from typing import Dict, List
import torch
import colorsys
import random
import numpy as np
from skimage.draw import line_aa, circle_perimeter_aa
import cv2
from .util import select_data


def _gen_random_colors(N, bright=True):
    """Return ``N`` fully saturated colours with evenly spaced hues, shuffled.

    Each colour is the ``(r, g, b)`` tuple produced by
    ``colorsys.hsv_to_rgb`` for hue ``i / N``, saturation 1 and a value of
    1.0 (``bright=True``) or 0.7 (``bright=False``). The order is randomised
    in place with the global ``random`` generator, so it is only reproducible
    when that generator has been seeded.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(idx / N, 1, value) for idx in range(N)]
    random.shuffle(palette)
    return palette


# Colour palette indexed by label id; every entry is a 3-component colour with
# values in [0, 1] (channel order presumably RGB -- confirm against callers).
# The first 16 entries are hand-picked float32 arrays; the trailing comments
# give the name found at the same position in `_names_in_static_label_colors`.
# They are followed by 256 shuffled colours from `_gen_random_colors`, which
# are plain tuples and change between runs unless `random` is seeded.
_static_label_colors = [
    np.array((1.0, 1.0, 1.0), np.float32),  # background
    np.array((255, 250, 79), np.float32) / 255.0,  # face
    np.array([255, 125, 138], np.float32) / 255.0,  # lb
    np.array([213, 32, 29], np.float32) / 255.0,  # rb
    np.array([0, 144, 187], np.float32) / 255.0,  # le
    np.array([0, 196, 253], np.float32) / 255.0,  # re
    np.array([255, 129, 54], np.float32) / 255.0,  # nose
    np.array([88, 233, 135], np.float32) / 255.0,  # ulip
    np.array([0, 117, 27], np.float32) / 255.0,  # llip
    np.array([255, 76, 249], np.float32) / 255.0,  # imouth
    np.array((1.0, 0.0, 0.0), np.float32),  # hair
    np.array((255, 250, 100), np.float32) / 255.0,  # lr
    np.array((255, 250, 100), np.float32) / 255.0,  # rr
    np.array((250, 245, 50), np.float32) / 255.0,  # neck
    np.array((0.0, 1.0, 0.5), np.float32),  # cloth
    np.array((1.0, 0.0, 0.5), np.float32),  # eyeg
] + _gen_random_colors(256)

# Label names aligned by position with `_static_label_colors`: a name's index
# in this list is the index of its colour in the palette. 'hat' and 'earr'
# (positions 16 and 17) lie past the 16 hand-picked entries, so they resolve
# to the first two randomly generated colours.
_names_in_static_label_colors = [
    'background', 'face', 'lb', 'rb', 'le', 're', 'nose',
    'ulip', 'llip', 'imouth', 'hair', 'lr', 'rr', 'neck',
    'cloth', 'eyeg', 'hat', 'earr'
]


def _blend_labels(image, labels, label_names_dict=None,
                  default_alpha=0.6, color_offset=None):
    """Overlay one colour per label id on ``image`` and return the result.

    Args:
        image: array blended with the label colours (``H x W x 3`` is what
            the colour broadcasting below requires), or ``None``. It is
            divided by its maximum before blending. With ``None`` a black
            float32 canvas the size of ``labels`` is used and colours are
            drawn at full opacity.
        labels: 2-D integer array of label ids. Id 0 is background and is
            never painted.
        label_names_dict: ``None`` to colour label ``i`` with entry
            ``i % len(_static_label_colors)`` of the palette. Otherwise an
            object indexable by label id that yields the label's name: names
            listed in ``_names_in_static_label_colors`` get the palette
            colour at the same position, other names get white. When it is a
            ``dict``, ids missing from it are treated as background.
        default_alpha: weight of the overlay when ``image`` is given; the
            normalised image keeps ``1 - default_alpha``.
        color_offset: optional value added to every colour that is not
            all-zero.

    Returns:
        A float array with values clipped to [0, 1]; background pixels hold
        the normalised input image (or zeros when ``image`` is ``None``).
    """
    assert labels.ndim == 2
    max_label = labels.max()
    bg_mask = labels == 0

    if label_names_dict is None:
        colors = _static_label_colors
    else:
        # Index 0 is a placeholder for the background, which is never drawn.
        colors = [np.array((1.0, 1.0, 1.0), np.float32)]
        for i in range(1, max_label + 1):
            if isinstance(label_names_dict, dict) and i not in label_names_dict:
                # Unnamed label: hide it by folding it into the background.
                bg_mask = np.logical_or(bg_mask, labels == i)
                colors.append(np.zeros(3))
                continue
            label_name = label_names_dict[i]
            if label_name in _names_in_static_label_colors:
                color = _static_label_colors[
                    _names_in_static_label_colors.index(label_name)]
            else:
                color = np.array((1.0, 1.0, 1.0), np.float32)
            colors.append(color)

    if color_offset is not None:
        shifted = []
        for c in colors:
            nc = np.array(c)  # copy, so the shared palette is never mutated
            if np.any(nc != 0):
                nc += color_offset
            shifted.append(nc)
        colors = shifted

    if image is None:
        canvas_shape = [labels.shape[0], labels.shape[1], 3]
        orig_image = np.zeros(canvas_shape, np.float32)
        image = np.zeros(canvas_shape, np.float32)
        alpha = 1.0
    else:
        peak = np.max(image)
        # An all-black image has a peak of 0; dividing by it would turn every
        # pixel into NaN, so leave such an image unscaled.
        orig_image = image / (peak if peak != 0 else 1)
        image = orig_image * (1.0 - default_alpha)
        alpha = default_alpha

    for i in range(1, max_label + 1):
        # (H, W, 1) mask broadcast against the 3-component colour.
        mask = (labels == i).astype(np.float32)[..., None]
        image += alpha * mask * colors[i % len(colors)]

    np.clip(image, 0.0, 1.0, out=image)
    image[bg_mask] = orig_image[bg_mask]
    return image


def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
    """Draw the annotations in ``data`` onto a copy of one HWC image.

    Keys of ``data`` that are handled:

    * ``'rects'``: iterable of ``(x1, y1, x2, y2)`` boxes. Corners are
      clamped to the image and the four sides are drawn as anti-aliased
      lines. If ``data`` also has ``'scores'``, the score with the same index
      is written at ``(x1, y2)`` using OpenCV; a failure there is printed
      once and otherwise ignored.
    * ``'points'``: iterable of ``npoints x 2`` sets of ``(x, y)``; each
      point becomes a radius-1 anti-aliased circle.
    * ``'seg'``: mapping with ``'label_names'`` and ``'logits'`` (each item
      ``nclasses x h x w``); the per-pixel argmax is blended over the image
      by ``_blend_labels``.

    Every other key is ignored. Strokes are drawn with the value 255, so the
    image is presumably on a 0-255 scale -- confirm against callers.

    Returns:
        A tensor built from the drawn array, moved to ``image``'s device.
    """
    target_device = image.device
    canvas = np.array(image.cpu().numpy(), copy=True)
    canvas_dtype = canvas.dtype
    height, width, _ = canvas.shape

    score_failure_reported = False
    for tag, batch_content in data.items():
        if tag == 'rects':
            for rect_idx, rect in enumerate(batch_content):
                x1, y1, x2, y2 = (int(coord) for coord in rect)
                y1 = max(min(y1, height - 1), 0)
                y2 = max(min(y2, height - 1), 0)
                x1 = max(min(x1, width - 1), 0)
                x2 = max(min(x2, width - 1), 0)

                # The four sides of the box, as (x, y) -> (x, y) segments.
                sides = (
                    (x1, y1, x2, y1),
                    (x1, y2, x2, y2),
                    (x1, y1, x1, y2),
                    (x2, y1, x2, y2),
                )
                for sx1, sy1, sx2, sy2 in sides:
                    rr, cc, weight = line_aa(sy1, sx1, sy2, sx2)
                    # Same anti-aliasing weight for each of the 3 channels.
                    weight = weight[:, None][:, [0, 0, 0]]
                    canvas[rr, cc] = (
                        canvas[rr, cc] * (1.0 - weight) + weight * 255)

                if 'scores' not in data:
                    continue
                try:
                    import cv2
                    score = data['scores'][rect_idx].item()
                    score_str = f'{score:0.3f}'
                    # Text is rendered on a fresh copy and written back
                    # (presumably because OpenCV needs a contiguous buffer).
                    scratch = np.array(canvas).copy()
                    cv2.putText(scratch, score_str, org=(x1, y2),
                                fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                                fontScale=0.6, color=(255, 255, 255),
                                thickness=1)
                    canvas[:, :, :] = scratch
                except Exception as e:
                    # Best effort: report the first failure only.
                    if not score_failure_reported:
                        print(f'Failed to draw scores on image.')
                        print(e)
                    score_failure_reported = True

        elif tag == 'points':
            for point_set in batch_content:
                # point_set: npoints x 2
                for px, py in point_set:
                    px = max(min(int(px), width - 1), 0)
                    py = max(min(int(py), height - 1), 0)
                    rr, cc, weight = circle_perimeter_aa(py, px, 1)
                    # Drop the part of the circle that leaves the image.
                    inside = (
                        (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width))
                    rr, cc, weight = rr[inside], cc[inside], weight[inside]
                    weight = weight[:, None][:, [0, 0, 0]]
                    canvas[rr, cc] = (
                        canvas[rr, cc] * (1.0 - weight) + weight * 255)

        elif tag == 'seg':
            label_names = batch_content['label_names']
            for seg_logits in batch_content['logits']:
                # seg_logits: nclasses x h x w
                seg_labels = (
                    seg_logits.softmax(dim=0).argmax(dim=0).cpu().numpy())
                blended = _blend_labels(
                    canvas.astype(np.float32) / 255, seg_labels,
                    label_names_dict=label_names)
                canvas = (blended * 255).astype(canvas_dtype)

    return torch.from_numpy(canvas).to(device=target_device)


def draw_bchw(images: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Draw the annotations in ``data`` onto every image of a BCHW batch.

    Args:
        images: batch of images, iterated along the first dimension; each
            item is permuted from CHW to HWC before drawing.
        data: annotations for the whole batch. ``data['image_ids']`` is
            compared with each image's index and the resulting mask is passed
            to ``select_data`` to pick that image's entries.

    Returns:
        The drawn images stacked along a new first dimension in CHW layout.
        An empty batch yields an empty copy of ``images``.
    """
    if len(images) == 0:
        # torch.stack raises on an empty list; there is nothing to draw.
        return images.clone()

    drawn = [
        _draw_hwc(
            image_chw.permute(1, 2, 0),
            select_data(image_id == data['image_ids'], data),
        ).permute(2, 0, 1)
        for image_id, image_chw in enumerate(images)
    ]
    return torch.stack(drawn, dim=0)

def draw_landmarks(img, bbox=None, landmark=None, color=(0, 255, 0)):
    """Return ``img`` with an optional bounding box and landmark dots drawn.

    Args:
        img: image array accepted by OpenCV (gray or RGB per the original
            notes).
        bbox: optional sequence whose first four values are
            ``x1, y1, x2, y2``; drawn as a 2-pixel-thick rectangle in the
            fixed colour ``(0, 0, 255)``.
        landmark: optional sequence of ``(x, y)`` points; each is drawn as a
            filled circle of radius 2 in ``color``.
        color: colour of the landmark dots.

    Returns:
        The annotated image. ``img`` is first round-tripped through
        ``cv2.UMat``, so the drawing presumably happens on a copy -- confirm
        before relying on the input staying untouched.
    """
    img = cv2.UMat(img).get()
    if bbox is not None:
        # Hand OpenCV plain Python ints, exactly as for the landmarks below;
        # NumPy scalar coordinates are not accepted by every OpenCV build.
        x1, y1, x2, y2 = np.array(bbox)[:4].astype(np.int32).tolist()
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
    if landmark is not None:
        for x, y in np.array(landmark).astype(np.int32):
            cv2.circle(img, (int(x), int(y)), 2, color, -1)
    return img