mask-detection / src /dataset.py
eksemyashkina's picture
Upload 13 files
f514e23 verified
raw
history blame
6.91 kB
from typing import List, Tuple, Dict
from pathlib import Path
import PIL.Image
import numpy as np
import torchvision.transforms as T
import torch
from torch.utils.data import Dataset
from bs4 import BeautifulSoup
from bs4.element import Tag
# Anchor-box (width, height) priors in pixels, three per detection scale.
# "large" anchors are matched with the coarsest grid (grid_sizes[0] = 13) and
# "small" with the finest (52) -- see how build_targets_3scale pairs them.
ANCHORS = {
    "small": [(26, 28), (17, 19), (10, 11)],
    "medium": [(78, 88), (55, 59), (37, 42)],
    "large": [(128, 152), (182, 205), (103, 124)]
}
# Output grid resolutions, ordered [large-object grid, medium, small].
GRID_SIZES = [13, 26, 52]
# Network input resolution (width, height) after letterbox resize.
IMAGE_SIZE = (416, 416)
# Class ids: 0 = without_mask, 1 = with_mask, 2 = everything else (see generate_box).
NUM_CLASSES = 3
def generate_box(obj: Tag) -> List[int]:
    """Extract one bounding box from a VOC-style ``<object>`` XML tag.

    Coordinates are shifted from 1-based VOC convention to 0-based pixels.
    Class ids: 0 = "without_mask", 1 = "with_mask", 2 = any other label.

    Returns [xmin, ymin, xmax, ymax, class_id].
    """
    coords = [int(obj.find(side).text) - 1 for side in ("xmin", "ymin", "xmax", "ymax")]
    label = obj.find("name").text
    class_id = {"without_mask": 0, "with_mask": 1}.get(label, 2)
    return coords + [class_id]
def resize_boxes(box: List[int], scale: float, pad_x: int, pad_y: int) -> Tuple[int, int, int, int, int]:
    """Map a box from original-image coordinates to the letterboxed image.

    Args:
        box: [xmin, ymin, xmax, ymax, class_id] in original pixel coordinates.
        scale: uniform resize factor applied by resize_with_padding.
        pad_x: left padding (pixels) added after resizing.
        pad_y: top padding (pixels) added after resizing.

    Returns:
        (xmin, ymin, xmax, ymax, class_id) in padded-image coordinates;
        coordinates are truncated to int, class_id passes through unchanged.
        (Fixed annotation: the original declared ``Tuple[int]`` — a 1-tuple —
        but always returned a 5-tuple.)
    """
    xmin, ymin, xmax, ymax, class_id = box
    xmin = int(xmin * scale + pad_x)
    ymin = int(ymin * scale + pad_y)
    xmax = int(xmax * scale + pad_x)
    ymax = int(ymax * scale + pad_y)
    return (xmin, ymin, xmax, ymax, class_id)
def resize_with_padding(
    image: PIL.Image.Image,
    target_size: Tuple[int, int] = IMAGE_SIZE,
    fill: Tuple[int, int, int] = (255, 255, 255),
) -> Tuple[PIL.Image.Image, float, int, int]:
    """Resize ``image`` preserving aspect ratio, then center-pad to ``target_size``.

    Args:
        image: source image; converted output is always RGB.
        target_size: (width, height) of the padded result.
        fill: RGB fill color for the padding border.

    Returns:
        (padded_image, scale, pad_x, pad_y) — scale and paddings are exactly
        what resize_boxes needs to remap annotations.
        (Fixed annotation: the original return type declared 3 elements but
        the function returns 4.)
    """
    target_w, target_h = target_size
    orig_w, orig_h = image.size
    # Fit entirely inside the target: scale by the more constraining axis.
    scale = min(target_w / orig_w, target_h / orig_h)
    new_w = int(orig_w * scale)
    new_h = int(orig_h * scale)
    image_resized = image.resize((new_w, new_h), resample=PIL.Image.LANCZOS)
    new_image = PIL.Image.new("RGB", (target_w, target_h), color=fill)
    # Center the resized content; integer division may leave one extra
    # pixel of padding on the right/bottom for odd differences.
    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2
    new_image.paste(image_resized, (pad_x, pad_y))
    return new_image, scale, pad_x, pad_y
def build_targets_3scale(bboxes: List[Tuple[int]], image_size: Tuple[int] = IMAGE_SIZE, anchors: Dict[str, List[Tuple[int]]] = ANCHORS, grid_sizes: List[int] = GRID_SIZES, num_classes: int = NUM_CLASSES) -> Tuple[torch.Tensor]:
    """Encode boxes into YOLO-style training targets for three scales.

    Each box is assigned to the single anchor with the highest shape-IoU
    (widths/heights compared with boxes aligned at the origin); the winning
    anchor determines which scale's grid receives the box. Per cell/anchor
    the layout is [tx, ty, tw, th, objectness, one-hot class].

    Args:
        bboxes: (xmin, ymin, xmax, ymax, class_id) boxes in padded-image pixels.
        image_size: (width, height) of the network input.
        anchors: per-scale anchor (w, h) priors; keys "large"/"medium"/"small".
        grid_sizes: grid resolution per scale, ordered [large, medium, small].
        num_classes: size of the one-hot class segment.

    Returns:
        (t_large, t_medium, t_small), each of shape (S, S, 3, 5 + num_classes).
    """
    img_w, img_h = image_size
    t_large = torch.zeros((grid_sizes[0], grid_sizes[0], 3, 5 + num_classes), dtype=torch.float32)
    t_medium = torch.zeros((grid_sizes[1], grid_sizes[1], 3, 5 + num_classes), dtype=torch.float32)
    t_small = torch.zeros((grid_sizes[2], grid_sizes[2], 3, 5 + num_classes), dtype=torch.float32)
    # Flat anchor list; index // 3 selects the scale (0=large, 1=medium, 2=small).
    all_anchors = anchors["large"] + anchors["medium"] + anchors["small"]
    scale_tables = (
        (grid_sizes[0], t_large, anchors["large"]),
        (grid_sizes[1], t_medium, anchors["medium"]),
        (grid_sizes[2], t_small, anchors["small"]),
    )
    for (xmin, ymin, xmax, ymax, cls_id) in bboxes:
        box_w = xmax - xmin
        box_h = ymax - ymin
        if box_w <= 0 or box_h <= 0:
            continue  # degenerate box (can result from aggressive resizing)
        x_center = (xmax + xmin) / 2
        y_center = (ymax + ymin) / 2
        # Shape-only IoU against every anchor; keep the best match.
        best_iou = 0.0
        best_idx = 0
        for i, (aw, ah) in enumerate(all_anchors):
            inter = min(box_w, aw) * min(box_h, ah)
            union = box_w * box_h + aw * ah - inter
            iou = inter / union if union > 0 else 0.0
            if iou > best_iou:
                best_iou = iou
                best_idx = i
        s, t, scale_anchors = scale_tables[best_idx // 3]
        local_anchor_id = best_idx % 3
        anchor_w, anchor_h = scale_anchors[local_anchor_id]
        cell_w = img_w / s
        cell_h = img_h / s
        # BUGFIX: clamp the cell index. A center exactly on the right/bottom
        # edge (e.g. x_center == img_w) previously produced gx == s and an
        # IndexError when writing the target.
        gx = min(int(x_center // cell_w), s - 1)
        gy = min(int(y_center // cell_h), s - 1)
        tx = (x_center / cell_w) - gx
        ty = (y_center / cell_h) - gy
        # Log-space size offsets relative to the matched anchor; the small
        # epsilons guard against division by zero and log(0).
        tw = np.log((box_w / (anchor_w + 1e-16)) + 1e-16)
        th = np.log((box_h / (anchor_h + 1e-16)) + 1e-16)
        t[gy, gx, local_anchor_id, 0] = tx
        t[gy, gx, local_anchor_id, 1] = ty
        t[gy, gx, local_anchor_id, 2] = tw
        t[gy, gx, local_anchor_id, 3] = th
        t[gy, gx, local_anchor_id, 4] = 1.0  # objectness flag
        t[gy, gx, local_anchor_id, 5 + cls_id] = 1.0  # one-hot class
    return t_large, t_medium, t_small
class MaskDataset(Dataset):
    """Face-mask detection dataset over VOC-style PNG images + XML annotations.

    Splits the sorted file lists into train/test by a leading test fraction
    and yields (normalized image tensor, (t_large, t_medium, t_small)) pairs.
    """

    def __init__(self, root: str, train: bool = True, test_size: float = 0.25) -> None:
        super().__init__()
        self.root = root
        self.train = train
        images = sorted((Path(root) / "images").glob("*.png"))
        annots = sorted((Path(root) / "annotations").glob("*.xml"))
        # First n_test files form the test split; the remainder is train.
        n_test = int(len(images) * test_size)
        if train:
            self.images, self.annots = images[n_test:], annots[n_test:]
        else:
            self.images, self.annots = images[:n_test], annots[:n_test]
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Per-class object counts over the chosen split, indexed like the
        # class ids in generate_box (0/1/2).
        self.class_counts = [0, 0, 0]
        label_to_id = {"without_mask": 0, "with_mask": 1}
        for ann_path in self.annots:
            soup = BeautifulSoup(ann_path.read_text(), "lxml")
            for obj in soup.find_all("object"):
                label = obj.find("name").text
                self.class_counts[label_to_id.get(label, 2)] += 1

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        image = PIL.Image.open(self.images[idx]).convert("RGB")
        padded, scale, pad_x, pad_y = resize_with_padding(image)
        soup = BeautifulSoup(self.annots[idx].read_text(), "lxml")
        # Remap every annotated box into the letterboxed coordinate frame.
        boxes = [
            resize_boxes(generate_box(obj), scale, pad_x, pad_y)
            for obj in soup.find_all("object")
        ]
        t_large, t_medium, t_small = build_targets_3scale(boxes)
        return self.transform(padded), (t_large, t_medium, t_small)
def collate_fn(batch: List[Tuple[torch.Tensor, Tuple[torch.Tensor]]]) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    """Stack per-sample images and 3-scale targets into batched tensors.

    Takes a list of (image, (t_large, t_medium, t_small)) samples and
    returns (images, (t_large, t_medium, t_small)) with a leading batch dim.
    """
    images = torch.stack([img for img, _ in batch], dim=0)
    # Transpose the per-sample target triples into three per-scale lists.
    batched_large, batched_medium, batched_small = (
        torch.stack(scale_targets, dim=0)
        for scale_targets in zip(*(targets for _, targets in batch))
    )
    return images, (batched_large, batched_medium, batched_small)