|
from typing import Dict, List |
|
import torch |
|
import colorsys |
|
import random |
|
import numpy as np |
|
from skimage.draw import line_aa, circle_perimeter_aa |
|
import cv2 |
|
from .util import select_data |
|
|
|
|
|
def _gen_random_colors(N, bright=True):
    """Return ``N`` fully saturated RGB colours in random order.

    Hues are spaced evenly around the HSV colour wheel (``i / N`` for
    ``i`` in ``0..N-1``) and converted with ``colorsys.hsv_to_rgb``.

    Args:
        N: Number of colours to generate.
        bright: When true the HSV value component is 1.0, otherwise 0.7.

    Returns:
        A list of ``N`` ``(r, g, b)`` float tuples, shuffled in place with
        the global ``random`` generator (so the order depends on its state).
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette
|
|
|
|
|
# Colour palette indexed by label id. The first 16 entries are hand-picked
# float32 RGB arrays in [0, 1]; each row below is (rgb, divisor), where the
# divisor is 255.0 for 8-bit colour triples and 1.0 for triples that are
# already unit-range. The trailing comment gives the label name found at the
# same position in ``_names_in_static_label_colors``. After the fixed entries
# come 256 randomly ordered generated colours (plain float tuples).
_static_label_colors = [
    np.array(rgb, np.float32) / divisor
    for rgb, divisor in (
        ((1.0, 1.0, 1.0), 1.0),      # 0  background
        ((255, 250, 79), 255.0),     # 1  face
        ((255, 125, 138), 255.0),    # 2  lb
        ((213, 32, 29), 255.0),      # 3  rb
        ((0, 144, 187), 255.0),      # 4  le
        ((0, 196, 253), 255.0),      # 5  re
        ((255, 129, 54), 255.0),     # 6  nose
        ((88, 233, 135), 255.0),     # 7  ulip
        ((0, 117, 27), 255.0),       # 8  llip
        ((255, 76, 249), 255.0),     # 9  imouth
        ((1.0, 0.0, 0.0), 1.0),      # 10 hair
        ((255, 250, 100), 255.0),    # 11 lr
        ((255, 250, 100), 255.0),    # 12 rr (same colour as lr)
        ((250, 245, 50), 255.0),     # 13 neck
        ((0.0, 1.0, 0.5), 1.0),      # 14 cloth
        ((1.0, 0.0, 0.5), 1.0),      # 15 eyeg
    )
] + _gen_random_colors(256)
|
|
|
# Label names in palette order: the position of a name in this list is the
# index of its colour in ``_static_label_colors``. Only 16 colours there are
# hand-picked, so 'hat' and 'earr' (positions 16 and 17) resolve to entries
# from the randomly generated tail of the palette.
_names_in_static_label_colors = [
    'background',  # 0
    'face',        # 1
    'lb',          # 2
    'rb',          # 3
    'le',          # 4
    're',          # 5
    'nose',        # 6
    'ulip',        # 7
    'llip',        # 8
    'imouth',      # 9
    'hair',        # 10
    'lr',          # 11
    'rr',          # 12
    'neck',        # 13
    'cloth',       # 14
    'eyeg',        # 15
    'hat',         # 16
    'earr',        # 17
]
|
|
|
|
|
def _blend_labels(image, labels, label_names_dict=None,
                  default_alpha=0.6, color_offset=None):
    """Overlay a colour-coded label map on an image.

    Args:
        image: Array to draw on (expected H x W x 3, matching ``labels``), or
            ``None`` to render the labels alone on a black canvas. The image
            is rescaled by its own maximum before blending; the input array
            is not modified.
        labels: 2-D integer array of label ids. Id 0 is background and is
            never coloured.
        label_names_dict: ``None`` to colour ids straight from
            ``_static_label_colors``. Otherwise a dict or sequence mapping
            label id -> label name: names listed in
            ``_names_in_static_label_colors`` get the matching palette
            colour, other names are drawn white, and ids with no entry
            (missing dict key, or index past the end of a sequence) are
            treated as background.
        default_alpha: Weight of the label colour when ``image`` is given;
            the image keeps ``1 - default_alpha``. With ``image=None`` the
            colours are drawn at full strength.
        color_offset: Optional value added to every non-black colour.

    Returns:
        A float array of shape H x W x 3 clipped to ``[0, 1]``. Background
        pixels hold the (rescaled) original image.
    """
    assert labels.ndim == 2
    # An empty label map has no maximum; treat it as "background only".
    max_label = int(labels.max()) if labels.size else 0
    bg_mask = labels == 0

    if label_names_dict is None:
        colors = _static_label_colors
    else:
        colors = [np.array((1.0, 1.0, 1.0), np.float32)]
        for i in range(1, max_label + 1):
            named = not (isinstance(label_names_dict, dict)
                         and i not in label_names_dict)
            if named:
                try:
                    label_name = label_names_dict[i]
                except (IndexError, KeyError):
                    # A name sequence shorter than the largest label id used
                    # to crash here; handle it like a missing dict key.
                    named = False
            if not named:
                bg_mask = np.logical_or(bg_mask, labels == i)
                # Black placeholder keeps ``colors`` aligned with label ids.
                colors.append(np.zeros(3))
                continue
            if label_name in _names_in_static_label_colors:
                color = _static_label_colors[
                    _names_in_static_label_colors.index(label_name)]
            else:
                color = np.array((1.0, 1.0, 1.0), np.float32)
            colors.append(color)

    if color_offset is not None:
        shifted = []
        for c in colors:
            nc = np.array(c)  # copy, so the shared palette is never mutated
            if nc.any():
                # Pure black marks "no colour" placeholders; leave them black.
                nc += color_offset
            shifted.append(nc)
        colors = shifted

    if image is None:
        canvas_shape = [labels.shape[0], labels.shape[1], 3]
        orig_image = np.zeros(canvas_shape, np.float32)
        image = np.zeros(canvas_shape, np.float32)
        alpha = 1.0
    else:
        image = np.asarray(image)
        peak = np.max(image)
        # Dividing an all-black image by its maximum (0) would fill the
        # result with NaN; leave such an image unscaled instead.
        orig_image = image / (peak if peak != 0 else 1.0)
        image = orig_image * (1.0 - default_alpha)
        alpha = default_alpha

    for i in range(1, max_label + 1):
        # (H, W, 1) mask broadcasts against the 3-channel colour.
        mask = (labels == i).astype(np.float32)[..., None]
        image += alpha * mask * colors[i % len(colors)]

    np.clip(image, 0.0, 1.0, out=image)
    image[bg_mask] = orig_image[bg_mask]
    return image
|
|
|
|
|
def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
    """Draw the annotations in ``data`` onto one H x W x C image.

    The image is copied to a NumPy array, drawn on, and converted back to a
    tensor on the original device; the input tensor is left untouched.

    Recognised keys of ``data`` (any other key is ignored):
        'rects':  iterable of 4-value boxes ``(x1, y1, x2, y2)``; each is
                  clamped to the image and outlined in white.
        'scores': looked up per box index while drawing 'rects'; the value is
                  printed with three decimals at ``(x1, y2)`` via OpenCV.
                  Failures are reported once on stdout and otherwise ignored.
        'points': iterable of point sets, each an iterable of ``(x, y)``;
                  every point becomes a radius-1 anti-aliased white circle.
        'seg':    mapping with 'logits' (iterable of per-class score maps,
                  class axis first) and 'label_names'; the arg-max label map
                  is blended over the image with ``_blend_labels``.

    Drawing assumes a 0-255 value range (white is written as 255).

    Returns:
        The annotated image as a tensor on ``image.device``.
    """
    device = image.device
    canvas = np.array(image.cpu().numpy(), copy=True)
    dtype = canvas.dtype
    h, w, _ = canvas.shape

    def composite_white(target, rows, cols, coverage):
        # Alpha-blend white (255) into the addressed pixels; the coverage
        # column is replicated to three channels before blending.
        weight = coverage[:, None][:, [0, 0, 0]]
        target[rows, cols] = target[rows, cols] * (1.0 - weight) + weight * 255

    score_failure_reported = False
    for tag, batch_content in data.items():
        if tag == 'rects':
            for cid, content in enumerate(batch_content):
                x1, y1, x2, y2 = (int(v) for v in content)
                # Keep the box inside the image so indexing stays valid.
                y1, y2 = (max(min(v, h - 1), 0) for v in (y1, y2))
                x1, x2 = (max(min(v, w - 1), 0) for v in (x1, x2))

                # Four sides as (row0, col0, row1, col1), drawn in this order.
                sides = (
                    (y1, x1, y1, x2),  # horizontal side at y1
                    (y2, x1, y2, x2),  # horizontal side at y2
                    (y1, x1, y2, x1),  # vertical side at x1
                    (y1, x2, y2, x2),  # vertical side at x2
                )
                for r0, c0, r1, c1 in sides:
                    rr, cc, val = line_aa(r0, c0, r1, c1)
                    composite_white(canvas, rr, cc, val)

                if 'scores' in data:
                    try:
                        import cv2
                        score = data['scores'][cid].item()
                        score_text = f'{score:0.3f}'
                        # Draw on a fresh C-ordered copy, then write it back.
                        annotated = canvas.copy()
                        cv2.putText(
                            annotated, score_text,
                            org=(int(x1), int(y2)),
                            fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                            fontScale=0.6, color=(255, 255, 255),
                            thickness=1)
                        canvas[:, :, :] = annotated
                    except Exception as e:
                        # Score text is best-effort: report the first
                        # failure only and keep drawing.
                        if not score_failure_reported:
                            print('Failed to draw scores on image.')
                            print(e)
                            score_failure_reported = True

        elif tag == 'points':
            for content in batch_content:
                for x, y in content:
                    col = max(min(int(x), w - 1), 0)
                    row = max(min(int(y), h - 1), 0)
                    rr, cc, val = circle_perimeter_aa(row, col, 1)
                    # The circle may extend past the border; drop those pixels.
                    inside = (rr >= 0) & (rr < h) & (cc >= 0) & (cc < w)
                    composite_white(canvas, rr[inside], cc[inside], val[inside])

        elif tag == 'seg':
            label_names = batch_content['label_names']
            for seg_logits in batch_content['logits']:
                seg_probs = seg_logits.softmax(dim=0)
                seg_labels = seg_probs.argmax(dim=0).cpu().numpy()
                blended = _blend_labels(
                    canvas.astype(np.float32) / 255, seg_labels,
                    label_names_dict=label_names)
                canvas = (blended * 255).astype(dtype)

    return torch.from_numpy(canvas).to(device=device)
|
|
|
|
|
def draw_bchw(images: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Draw annotations on every image of a B x C x H x W batch.

    For each batch index the entries of ``data`` are narrowed with
    ``select_data`` using the mask ``image_id == data['image_ids']``
    (presumably keeping only the annotations of that image - see
    ``.util.select_data``), and the image is rendered by ``_draw_hwc``.

    Args:
        images: Batch of images, channels first.
        data: Annotation dict; must contain 'image_ids'. The other keys are
            interpreted by ``_draw_hwc``.

    Returns:
        The annotated images stacked back into a B x C x H x W tensor. An
        empty batch yields an empty copy of ``images``.
    """
    if (isinstance(images, torch.Tensor) and images.dim() > 0
            and images.shape[0] == 0):
        # torch.stack refuses an empty list; there is nothing to draw anyway.
        return images.clone()

    drawn = []
    for image_id, image_chw in enumerate(images):
        selected_data = select_data(image_id == data['image_ids'], data)
        image_hwc = image_chw.permute(1, 2, 0)
        drawn.append(_draw_hwc(image_hwc, selected_data).permute(2, 0, 1))
    return torch.stack(drawn, dim=0)
|
|
|
def draw_landmarks(img, bbox=None, landmark=None, color=(0, 255, 0)):
    """Mark a bounding box and landmark points on an image.

    Input:
        - img: image array accepted by ``cv2.UMat`` (the original docstring
          says gray or RGB). It is round-tripped through ``cv2.UMat`` before
          drawing - presumably to obtain a fresh array OpenCV can draw on;
          confirm if in-place semantics matter to a caller.
        - bbox: optional array-like whose first four values are
          ``x1, y1, x2, y2``; extra values are ignored. Drawn as a rectangle
          of thickness 2 in the fixed colour ``(0, 0, 255)``.
        - landmark: optional array-like of ``(x, y)`` pairs; each is drawn as
          a filled circle of radius 2.
        - color: colour of the landmark circles.
    Output:
        - img marked with landmark and bbox

    Coordinates are cast to int32 (fractional parts are truncated).
    """
    img = cv2.UMat(img).get()
    if bbox is not None:
        # Hand OpenCV plain Python ints, matching the landmark branch below,
        # rather than NumPy int32 scalars.
        x1, y1, x2, y2 = (int(v) for v in np.array(bbox)[:4].astype(np.int32))
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
    if landmark is not None:
        for x, y in np.array(landmark).astype(np.int32):
            cv2.circle(img, (int(x), int(y)), 2, color, -1)
    return img