import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont  # ImageDraw / ImageFont are used by Annotator below
|

class Colors:
    """Fixed 20-color palette; call with an integer index to get an (r, g, b) tuple."""

    def __init__(self):
        hexs = (
            "00FF00", "FF3838", "FF701F", "FFB21D", "CFD231",
            "48F90A", "92CC17", "3DDB86", "1A9334", "00D4BB",
            "2C99A8", "00C2FF", "344593", "6473FF", "0018EC",
            "8438FF", "520085", "CB38FF", "FF95C8", "FF37C7",
        )
        self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]  # wrap the index so any integer is valid
        return (c[2], c[1], c[0]) if bgr else c  # reversed channel order for OpenCV

    @staticmethod
    def hex2rgb(h):
        """Convert a '#RRGGBB' hex string to an (r, g, b) integer tuple."""
        return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()  # shared module-level instance
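# For illustration (values follow from the palette above): colors(3) returns
# (255, 178, 29) for hex "FFB21D", and colors(3, bgr=True) returns (29, 178, 255).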
|

def is_ascii(s=""):
    """Return True if the string contains only ASCII characters."""
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode("ascii", "ignore")) == len(s)
|

def clip_boxes(boxes, shape):
    """Clip xyxy boxes in place to image boundaries, where shape is (height, width)."""
    if isinstance(boxes, torch.Tensor):  # faster individually for tensors
        boxes[:, 0].clamp_(0, shape[1])  # x1
        boxes[:, 1].clamp_(0, shape[0])  # y1
        boxes[:, 2].clamp_(0, shape[1])  # x2
        boxes[:, 3].clamp_(0, shape[0])  # y2
    else:  # np.ndarray (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2
|

def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """Rescale xyxy boxes from img1_shape (letterboxed model input) back to img0_shape (original image)."""
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    boxes[:, [0, 2]] -= pad[0]  # x padding
    boxes[:, [1, 3]] -= pad[1]  # y padding
    boxes[:, :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
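
# Illustrative sketch, not part of the original module: map one detection from a
# 640x640 letterboxed input back to a hypothetical 480x640 source frame. The box
# values are made up; the expected output follows from gain=1.0, pad=(0, 80).
def _demo_scale_boxes():
    boxes = torch.tensor([[100.0, 120.0, 300.0, 400.0]])  # xyxy in model-input space
    out = scale_boxes((640, 640), boxes.clone(), (480, 640))
    # y-coordinates shift up by the 80 px top letterbox padding: [[100, 40, 300, 320]]
    return out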
|

def crop_mask(masks, boxes):
    """
    "Crop" predicted masks by zeroing out everything not in the predicted bbox.
    Vectorized by Chong (thanks Chong).

    Args:
        masks: [n, h, w] tensor of masks
        boxes: [n, 4] tensor of bbox coords in relative point form
    """
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # column indices, shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # row indices, shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
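
# Illustrative sketch, not part of the original module: cropping a 4x4 all-ones
# mask to the box (1, 1, 3, 3) keeps only the 2x2 interior region.
def _demo_crop_mask():
    masks = torch.ones(1, 4, 4)  # [n, h, w]
    boxes = torch.tensor([[1.0, 1.0, 3.0, 3.0]])  # xyxy, [n, 4]
    out = crop_mask(masks, boxes)
    assert out.sum() == 4  # only pixels (1,1), (1,2), (2,1), (2,2) survive
    return out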
|

def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    Combine mask prototypes with per-detection coefficients, crop to boxes, then optionally upsample.

    Args:
        protos: [mask_dim, mask_h, mask_w] prototype masks
        masks_in: [n, mask_dim], n is the number of masks after NMS
        bboxes: [n, 4], n is the number of masks after NMS
        shape: input image size, (h, w)

    Returns:
        Binary masks, [n, h, w] if upsample else [n, mask_h, mask_w]
    """
    c, mh, mw = protos.shape  # mask_dim, mask_h, mask_w
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # linear combination, shape(n,mh,mw)

    downsampled_bboxes = bboxes.clone()  # rescale boxes from input-image space to prototype space
    downsampled_bboxes[:, 0] *= mw / iw  # x1
    downsampled_bboxes[:, 2] *= mw / iw  # x2
    downsampled_bboxes[:, 1] *= mh / ih  # y1
    downsampled_bboxes[:, 3] *= mh / ih  # y2

    masks = crop_mask(masks, downsampled_bboxes)  # zero out everything outside each box
    if upsample:
        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # shape(n,h,w)
    return masks.gt_(0.5)  # binarize in place
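
# Illustrative sketch, not part of the original module: typical YOLO-style shapes,
# assuming 32 prototype channels at 1/4 the resolution of a 640x640 input. The
# tensors are random placeholders purely to show the shape flow.
def _demo_process_mask():
    protos = torch.rand(32, 160, 160)  # [mask_dim, mask_h, mask_w]
    masks_in = torch.rand(5, 32)  # coefficients for 5 detections surviving NMS
    bboxes = torch.tensor([[0.0, 0.0, 640.0, 640.0]]).repeat(5, 1)  # xyxy in input space
    masks = process_mask(protos, masks_in, bboxes, shape=(640, 640), upsample=True)
    assert masks.shape == (5, 640, 640)  # one binary mask per detection
    return masks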
|

def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
    """
    Rescale masks from the model-input frame back to the original image frame.

    Args:
        im1_shape: model input shape, [h, w]
        masks: [h, w, num]
        im0_shape: original image shape, [h, w, 3]
    """
    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])

    if len(masks.shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    masks = masks[top:bottom, left:right]  # strip the letterbox padding
    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))  # resize to original (w, h)

    if len(masks.shape) == 2:
        masks = masks[:, :, None]  # keep a trailing channel axis
    return masks
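
# check_pil_font is called by Annotator below but was not defined in this module.
# A minimal fallback sketch is provided here (assumption: the upstream helper also
# locates or downloads missing font files) so the module is self-contained.
def check_pil_font(font="Arial.ttf", size=10):
    try:
        return ImageFont.truetype(str(font), size)  # load the requested TrueType font
    except Exception:
        return ImageFont.load_default()  # fall back to PIL's built-in font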
|

class Annotator:
    """Draws boxes, labels, and masks on an image, with a PIL or OpenCV backend."""

    def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
        assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
        self.pil = pil or non_ascii  # OpenCV cannot render non-ASCII text
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = check_pil_font(
                font="Arial.Unicode.ttf" if non_ascii else font,
                size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12),
            )
        else:  # use cv2
            self.im = im
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width, used by both backends
|

    def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
        """Add one xyxy box to the image with an optional label."""
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self.font.getsize(label)  # text width, height
                outside = box[1] - h >= 0  # label fits above the box
                self.draw.rectangle(
                    (
                        box[0],
                        box[1] - h if outside else box[1],
                        box[0] + w + 1,
                        box[1] + 1 if outside else box[1] + h + 1,
                    ),
                    fill=color,
                )
                self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
        else:  # cv2
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                outside = p1[1] - h >= 3  # label fits above the box
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled label background
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0,
                    self.lw / 3,
                    txt_color,
                    thickness=tf,
                    lineType=cv2.LINE_AA,
                )
|

    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
        """Plot all masks at once.

        Args:
            masks (tensor): predicted masks, shape: [n, h, w]
            colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
            im_gpu (tensor): image on the same device as masks, shape: [3, h, w], range: [0, 1]
            alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
        """
        if isinstance(im_gpu, np.ndarray):  # the original converted unconditionally; guard so
            im_gpu = torch.from_numpy(im_gpu)  # tensor input (per the docstring) also works

        if self.pil:
            self.im = np.asarray(self.im).copy()  # convert to numpy first
        if len(masks) == 0:
            self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
        colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
        colors = colors[:, None, None]  # shape(n,1,1,3)
        masks = masks.unsqueeze(3)  # shape(n,h,w,1)
        masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

        inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
        mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand, shape(h,w,3)

        im_gpu = im_gpu.flip(dims=[0])  # flip channel order
        im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
        im_gpu = im_gpu * inv_alph_masks[-1] + mcs  # blend masks into the image
        im_mask = (im_gpu * 255).byte().cpu().numpy()
        self.im[:] = im_mask if retina_masks else scale_image(im_gpu.shape, im_mask, self.im.shape)
        if self.pil:
            self.fromarray(self.im)  # convert back to PIL and update the draw handle
|

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Add a rectangle to the image (PIL backend only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
        """Add text to the image (PIL backend only)."""
        if anchor == "bottom":  # start y from the font bottom
            w, h = self.font.getsize(text)  # text width, height
            xy[1] += 1 - h
        self.draw.text(xy, text, fill=txt_color, font=self.font)

    def fromarray(self, im):
        """Update self.im from a numpy array and refresh the draw handle."""
        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
        self.draw = ImageDraw.Draw(self.im)

    def result(self):
        """Return the annotated image as a numpy array."""
        return np.asarray(self.im)
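
# Illustrative end-to-end sketch, not part of the original module; guarded so it
# only runs when this file is executed directly. The frame and box are synthetic.
if __name__ == "__main__":
    frame = np.ascontiguousarray(np.zeros((480, 640, 3), dtype=np.uint8))  # blank BGR frame
    annotator = Annotator(frame, line_width=2)
    annotator.box_label([100, 120, 300, 400], label="person 0.91", color=colors(0, bgr=True))
    annotated = annotator.result()
    print(annotated.shape)  # (480, 640, 3)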