from typing import List, Tuple, Dict
from pathlib import Path

import PIL.Image
import numpy as np
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset
from bs4 import BeautifulSoup
from bs4.element import Tag
ANCHORS = {
    "small": [(26, 28), (17, 19), (10, 11)],
    "medium": [(78, 88), (55, 59), (37, 42)],
    "large": [(128, 152), (182, 205), (103, 124)],
}

GRID_SIZES = [13, 26, 52]
IMAGE_SIZE = (416, 416)
NUM_CLASSES = 3
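# Note: the anchor sizes above are (width, height) in pixels on the 416x416
# letterboxed input. build_targets_3scale() below pairs the "large" anchors with
# the coarse 13x13 grid, "medium" with 26x26, and "small" with the fine 52x52
# grid, following the usual YOLOv3-style assignment of big objects to coarse cells.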
def generate_box(obj: Tag) -> List[int]:
    """Parse one <object> entry into [xmin, ymin, xmax, ymax, class_id]."""
    xmin = int(obj.find("xmin").text) - 1
    ymin = int(obj.find("ymin").text) - 1
    xmax = int(obj.find("xmax").text) - 1
    ymax = int(obj.find("ymax").text) - 1
    name = obj.find("name").text
    if name == "without_mask":
        class_id = 0
    elif name == "with_mask":
        class_id = 1
    else:
        class_id = 2
    return [xmin, ymin, xmax, ymax, class_id]
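# For reference, generate_box() assumes Pascal VOC-style annotations, where each
# object entry looks roughly like this (tag names taken from the parsing code above;
# the values are just an illustrative example):
#
#   <object>
#     <name>with_mask</name>
#     <bndbox>
#       <xmin>79</xmin><ymin>105</ymin><xmax>109</xmax><ymax>142</ymax>
#     </bndbox>
#   </object>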
def resize_boxes(box: List[int], scale: float, pad_x: int, pad_y: int) -> Tuple[int, int, int, int, int]:
    """Map a box from original-image coordinates into the letterboxed image."""
    xmin, ymin, xmax, ymax, class_id = box
    xmin = int(xmin * scale + pad_x)
    ymin = int(ymin * scale + pad_y)
    xmax = int(xmax * scale + pad_x)
    ymax = int(ymax * scale + pad_y)
    return (xmin, ymin, xmax, ymax, class_id)
def resize_with_padding(
    image: PIL.Image.Image,
    target_size: Tuple[int, int] = IMAGE_SIZE,
    fill: Tuple[int, int, int] = (255, 255, 255),
) -> Tuple[PIL.Image.Image, float, int, int]:
    """Letterbox `image` to `target_size`, preserving aspect ratio and centering with padding."""
    target_w, target_h = target_size
    orig_w, orig_h = image.size
    scale = min(target_w / orig_w, target_h / orig_h)
    new_w = int(orig_w * scale)
    new_h = int(orig_h * scale)
    image_resized = image.resize((new_w, new_h), resample=PIL.Image.LANCZOS)
    new_image = PIL.Image.new("RGB", (target_w, target_h), color=fill)
    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2
    new_image.paste(image_resized, (pad_x, pad_y))
    return new_image, scale, pad_x, pad_y
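# Worked example (hypothetical input size): a 600x400 photo letterboxed to
# 416x416 gets scale = min(416/600, 416/400) ≈ 0.693, so it is resized to
# roughly 416x277 and centered with pad_x = 0 and pad_y ≈ 69 rows of fill above
# and below. resize_boxes() then applies the same scale and offsets to each box.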
def build_targets_3scale(
    bboxes: List[Tuple[int, int, int, int, int]],
    image_size: Tuple[int, int] = IMAGE_SIZE,
    anchors: Dict[str, List[Tuple[int, int]]] = ANCHORS,
    grid_sizes: List[int] = GRID_SIZES,
    num_classes: int = NUM_CLASSES,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode boxes into YOLO-style targets for the 13x13, 26x26 and 52x52 grids."""
    img_w, img_h = image_size
    # One target tensor per scale: (grid, grid, anchors, [tx, ty, tw, th, obj, classes...]).
    t_large = torch.zeros((grid_sizes[0], grid_sizes[0], 3, 5 + num_classes), dtype=torch.float32)
    t_medium = torch.zeros((grid_sizes[1], grid_sizes[1], 3, 5 + num_classes), dtype=torch.float32)
    t_small = torch.zeros((grid_sizes[2], grid_sizes[2], 3, 5 + num_classes), dtype=torch.float32)
    all_anchors = anchors["large"] + anchors["medium"] + anchors["small"]

    for (xmin, ymin, xmax, ymax, cls_id) in bboxes:
        box_w = xmax - xmin
        box_h = ymax - ymin
        x_center = (xmax + xmin) / 2
        y_center = (ymax + ymin) / 2
        if box_w <= 0 or box_h <= 0:
            continue

        # Pick the anchor with the highest IoU against the box, comparing
        # widths/heights only (both treated as if centered at the same point).
        best_iou = 0.0
        best_idx = 0
        for i, (aw, ah) in enumerate(all_anchors):
            inter = min(box_w, aw) * min(box_h, ah)
            union = box_w * box_h + aw * ah - inter
            iou = inter / union if union > 0 else 0
            if iou > best_iou:
                best_iou = iou
                best_idx = i

        # Indices 0-2 are "large" anchors (13x13 grid), 3-5 "medium" (26x26), 6-8 "small" (52x52).
        if best_idx <= 2:
            s = grid_sizes[0]
            t = t_large
            local_anchor_id = best_idx
            anchor_w, anchor_h = anchors["large"][local_anchor_id]
        elif best_idx <= 5:
            s = grid_sizes[1]
            t = t_medium
            local_anchor_id = best_idx - 3
            anchor_w, anchor_h = anchors["medium"][local_anchor_id]
        else:
            s = grid_sizes[2]
            t = t_small
            local_anchor_id = best_idx - 6
            anchor_w, anchor_h = anchors["small"][local_anchor_id]

        cell_w = img_w / s
        cell_h = img_h / s
        gx = int(x_center // cell_w)
        gy = int(y_center // cell_h)
        # Offsets within the cell, plus log-space width/height relative to the matched anchor.
        tx = (x_center / cell_w) - gx
        ty = (y_center / cell_h) - gy
        tw = np.log((box_w / (anchor_w + 1e-16)) + 1e-16)
        th = np.log((box_h / (anchor_h + 1e-16)) + 1e-16)

        t[gy, gx, local_anchor_id, 0] = tx
        t[gy, gx, local_anchor_id, 1] = ty
        t[gy, gx, local_anchor_id, 2] = tw
        t[gy, gx, local_anchor_id, 3] = th
        t[gy, gx, local_anchor_id, 4] = 1.0
        t[gy, gx, local_anchor_id, 5 + cls_id] = 1.0

    return t_large, t_medium, t_small
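# Example of the encoding (hypothetical box): a face at (xmin, ymin, xmax, ymax)
# = (100, 120, 160, 190) is 60x70 px, which best matches the (55, 59) "medium"
# anchor, so it lands on the 26x26 grid (cell size 416/26 = 16). Its center
# (130, 155) falls in cell (gx, gy) = (8, 9), with tx = 130/16 - 8 = 0.125 and
# ty = 155/16 - 9 = 0.6875; tw and th are the log-ratios of 60 and 70 to the
# anchor width and height.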
class MaskDataset(Dataset):
    def __init__(self, root: str, train: bool = True, test_size: float = 0.25) -> None:
        super().__init__()
        self.class_counts = [0, 0, 0]
        self.root = root
        self.train = train
        all_imgs = sorted((Path(root) / "images").glob("*.png"))
        all_anns = sorted((Path(root) / "annotations").glob("*.xml"))
        # Simple deterministic split: the first `test_size` fraction of the sorted files is held out.
        n_test = int(len(all_imgs) * test_size)
        if train:
            self.images = all_imgs[n_test:]
            self.annots = all_anns[n_test:]
        else:
            self.images = all_imgs[:n_test]
            self.annots = all_anns[:n_test]
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Count the annotated objects per class across this split.
        for ann in self.annots:
            with open(ann, "r") as f:
                data = f.read()
            soup = BeautifulSoup(data, "lxml")
            for obj in soup.find_all("object"):
                cls = obj.find("name").text
                self.class_counts[0 if cls == "without_mask" else 1 if cls == "with_mask" else 2] += 1

    def __len__(self) -> int:
        return len(self.images)
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        img_path = self.images[idx]
        ann_path = self.annots[idx]
        img = PIL.Image.open(img_path).convert("RGB")
        img_resized, scale, pad_x, pad_y = resize_with_padding(img)
        with open(ann_path, "r") as f:
            data = f.read()
        soup = BeautifulSoup(data, "lxml")
        objs = soup.find_all("object")
        # Map every annotated box into the letterboxed 416x416 coordinate frame.
        resized_boxes = []
        for obj in objs:
            b = generate_box(obj)
            b2 = resize_boxes(b, scale, pad_x, pad_y)
            resized_boxes.append(b2)
        t_large, t_medium, t_small = build_targets_3scale(resized_boxes)
        img_tensor = self.transform(img_resized)
        # img_tensor: (3, 416, 416); targets: (13, 13, 3, 8), (26, 26, 3, 8), (52, 52, 3, 8).
        return img_tensor, (t_large, t_medium, t_small)
def collate_fn(
    batch: List[Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Stack images and the three per-scale target tensors into batched tensors."""
    imgs, t_l, t_m, t_s = [], [], [], []
    for (img, (tl, tm, ts)) in batch:
        imgs.append(img)
        t_l.append(tl)
        t_m.append(tm)
        t_s.append(ts)
    imgs = torch.stack(imgs, dim=0)
    t_l = torch.stack(t_l, dim=0)
    t_m = torch.stack(t_m, dim=0)
    t_s = torch.stack(t_s, dim=0)
    return imgs, (t_l, t_m, t_s)
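# Minimal usage sketch. Assumptions: the dataset root contains the "images/" and
# "annotations/" folders MaskDataset expects, and the path below is only a placeholder.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MaskDataset(root="data/face-mask", train=True)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

    images, (targets_l, targets_m, targets_s) = next(iter(loader))
    print(images.shape)     # torch.Size([8, 3, 416, 416])
    print(targets_l.shape)  # torch.Size([8, 13, 13, 3, 8])
    print(targets_m.shape)  # torch.Size([8, 26, 26, 3, 8])
    print(targets_s.shape)  # torch.Size([8, 52, 52, 3, 8])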