eksemyashkina committed on
Commit f514e23 · verified · 1 Parent(s): 285ba68

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/examples/image3.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/examples/image5.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,100 @@
from typing import List
import gradio as gr
import PIL.Image, PIL.ImageOps
import torch
import numpy as np
import torchvision.transforms as T

from src.models.yolov3 import YOLOv3
from src.train import draw_bounding_boxes, decode_predictions_3scales
from src.dataset import ANCHORS, resize_with_padding


device = torch.device("cpu")
model_weight = "weights/checkpoint-best.pth"
label_colors = {"without_mask": (178, 34, 34), "with_mask": (34, 139, 34), "mask_worn_incorrectly": (184, 134, 11)}

model = YOLOv3()
model.load_state_dict(torch.load(model_weight, map_location=device))
model.eval()


def create_combined_image(img: torch.Tensor, results: List[torch.Tensor], mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225]):
    batch_size, _, height, width = img.shape
    combined_height = height
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    for i in range(batch_size):
        image = img[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        pred_image = PIL.Image.fromarray(image.copy())
        draw_bounding_boxes(pred_image, results[i], show_conf=True)
        combined_image[:height, i * width:(i + 1) * width, :] = np.array(pred_image)
    return PIL.Image.fromarray(combined_image)


transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def detect_mask(image, conf_threshold: float) -> PIL.Image.Image:
    img_resized, _, _, _ = resize_with_padding(image)
    img_tensor = transform(img_resized)
    with torch.no_grad():
        out_l, out_m, out_s = model(img_tensor.unsqueeze(0))
    results = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"], conf_threshold=conf_threshold)
    combined_image = create_combined_image(img_tensor.unsqueeze(0), results)
    return combined_image


def generate_legend_html_compact() -> str:
    legend_html = """
    <div style="display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;">
    """
    for idx, (label, color) in enumerate(label_colors.items()):
        legend_html += f"""
        <div style="display: flex; align-items: center; justify-content: center;
                    padding: 5px 10px; border: 1px solid rgb{color};
                    background-color: rgb{color}; border-radius: 5px;
                    color: white; font-size: 12px; text-align: center;">
            {label}
        </div>
        """
    legend_html += "</div>"
    return legend_html


examples = [
    ["assets/examples/image1.jpg"],
    ["assets/examples/image2.jpg"],
    ["assets/examples/image3.jpg"],
    ["assets/examples/image4.jpg"],
    ["assets/examples/image5.jpg"]
]


with gr.Blocks() as demo:
    gr.Markdown("## Mask Detection with YOLOv3")
    with gr.Row():
        with gr.Column():
            pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
            conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Confidence Threshold")
            with gr.Row():
                with gr.Column(scale=1):
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear")

        with gr.Column():
            output = gr.Image(label="Detection", type="pil", height=300, width=300)
            legend = gr.HTML(label="Legend", value=generate_legend_html_compact())

    predict_btn.click(fn=detect_mask, inputs=[pic, conf_slider], outputs=output, api_name="predict")
    clear_btn.click(lambda: (None, None), outputs=[pic, output])
    gr.Examples(examples=examples, inputs=[pic])

demo.launch()
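Because the Predict button is registered with api_name="predict", the running demo can also be queried programmatically. Below is a minimal sketch using gradio_client (installed alongside gradio), assuming the app has already been started locally with python app.py on the default address and that one of the bundled example images is used as input:

from gradio_client import Client, handle_file

# assumes the Gradio app is already serving on the default local address
client = Client("http://127.0.0.1:7860")
result_path = client.predict(
    handle_file("assets/examples/image1.jpg"),   # input image
    0.9,                                         # confidence threshold
    api_name="/predict",
)
print(result_path)   # local path of the rendered detection image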
assets/examples/image1.jpg ADDED
assets/examples/image2.jpg ADDED
assets/examples/image3.jpg ADDED

Git LFS Details

  • SHA256: fc1c5e72c5362f0f1ea703728f51df89ee0715c6c4d33ef03cb143724db6fd56
  • Pointer size: 131 Bytes
  • Size of remote file: 554 kB
assets/examples/image4.jpg ADDED
assets/examples/image5.jpg ADDED

Git LFS Details

  • SHA256: e568c5c129b480bb65705f658e3dbfc89cb3f7c074491122d235c4d2a485c751
  • Pointer size: 131 Bytes
  • Size of remote file: 229 kB
requirements.txt ADDED
@@ -0,0 +1,11 @@
torch==2.6.0
tqdm==4.67.1
Pillow==10.4.0
bs4==0.0.2
scikit-learn==1.6.0
torchvision==0.21.0
wandb==0.19.1
lxml==5.3.0
accelerate==1.1.0
kaggle==1.6.17
gradio==5.14.0
src/dataset.py ADDED
@@ -0,0 +1,173 @@
from typing import List, Tuple, Dict
from pathlib import Path
import PIL.Image
import numpy as np
import torchvision.transforms as T
import torch
from torch.utils.data import Dataset
from bs4 import BeautifulSoup
from bs4.element import Tag


ANCHORS = {
    "small": [(26, 28), (17, 19), (10, 11)],
    "medium": [(78, 88), (55, 59), (37, 42)],
    "large": [(128, 152), (182, 205), (103, 124)]
}
GRID_SIZES = [13, 26, 52]
IMAGE_SIZE = (416, 416)
NUM_CLASSES = 3


def generate_box(obj: Tag) -> List[int]:
    xmin = int(obj.find("xmin").text) - 1
    ymin = int(obj.find("ymin").text) - 1
    xmax = int(obj.find("xmax").text) - 1
    ymax = int(obj.find("ymax").text) - 1
    if obj.find("name").text == "without_mask":
        class_id = 0
    elif obj.find("name").text == "with_mask":
        class_id = 1
    else:
        class_id = 2
    return [xmin, ymin, xmax, ymax, class_id]


def resize_boxes(box: List[int], scale: float, pad_x: int, pad_y: int) -> Tuple[int]:
    xmin, ymin, xmax, ymax, class_id = box
    xmin = int(xmin * scale + pad_x)
    ymin = int(ymin * scale + pad_y)
    xmax = int(xmax * scale + pad_x)
    ymax = int(ymax * scale + pad_y)
    return (xmin, ymin, xmax, ymax, class_id)


def resize_with_padding(image: PIL.Image.Image, target_size: Tuple[int] = IMAGE_SIZE, fill: Tuple[int] = (255, 255, 255)) -> Tuple[PIL.Image.Image, float, int, int]:
    target_w, target_h = target_size
    orig_w, orig_h = image.size
    scale = min(target_w / orig_w, target_h / orig_h)
    new_w = int(orig_w * scale)
    new_h = int(orig_h * scale)
    image_resized = image.resize((new_w, new_h), resample=PIL.Image.LANCZOS)
    new_image = PIL.Image.new("RGB", (target_w, target_h), color=fill)
    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2
    new_image.paste(image_resized, (pad_x, pad_y))
    return new_image, scale, pad_x, pad_y


def build_targets_3scale(bboxes: List[Tuple[int]], image_size: Tuple[int] = IMAGE_SIZE, anchors: Dict[str, List[Tuple[int]]] = ANCHORS, grid_sizes: List[int] = GRID_SIZES, num_classes: int = NUM_CLASSES) -> Tuple[torch.Tensor]:
    img_w, img_h = image_size
    t_large = torch.zeros((grid_sizes[0], grid_sizes[0], 3, 5 + num_classes), dtype=torch.float32)
    t_medium = torch.zeros((grid_sizes[1], grid_sizes[1], 3, 5 + num_classes), dtype=torch.float32)
    t_small = torch.zeros((grid_sizes[2], grid_sizes[2], 3, 5 + num_classes), dtype=torch.float32)
    all_anchors = anchors["large"] + anchors["medium"] + anchors["small"]
    for (xmin, ymin, xmax, ymax, cls_id) in bboxes:
        box_w = xmax - xmin
        box_h = ymax - ymin
        x_center = (xmax + xmin) / 2
        y_center = (ymax + ymin) / 2
        if box_w <= 0 or box_h <= 0:
            continue
        best_iou = 0
        best_idx = 0
        for i, (aw, ah) in enumerate(all_anchors):
            inter = min(box_w, aw) * min(box_h, ah)
            union = box_w * box_h + aw * ah - inter
            iou = inter / union if union > 0 else 0
            if iou > best_iou:
                best_iou = iou
                best_idx = i
        if best_idx <= 2:
            s = grid_sizes[0]
            t = t_large
            local_anchor_id = best_idx
            anchor_w, anchor_h = anchors["large"][local_anchor_id]
        elif best_idx <= 5:
            s = grid_sizes[1]
            t = t_medium
            local_anchor_id = best_idx - 3
            anchor_w, anchor_h = anchors["medium"][local_anchor_id]
        else:
            s = grid_sizes[2]
            t = t_small
            local_anchor_id = best_idx - 6
            anchor_w, anchor_h = anchors["small"][local_anchor_id]
        cell_w = img_w / s
        cell_h = img_h / s
        gx = int(x_center // cell_w)
        gy = int(y_center // cell_h)
        tx = (x_center / cell_w) - gx
        ty = (y_center / cell_h) - gy
        tw = np.log((box_w / (anchor_w + 1e-16)) + 1e-16)
        th = np.log((box_h / (anchor_h + 1e-16)) + 1e-16)
        t[gy, gx, local_anchor_id, 0] = tx
        t[gy, gx, local_anchor_id, 1] = ty
        t[gy, gx, local_anchor_id, 2] = tw
        t[gy, gx, local_anchor_id, 3] = th
        t[gy, gx, local_anchor_id, 4] = 1.0
        t[gy, gx, local_anchor_id, 5 + cls_id] = 1.0
    return t_large, t_medium, t_small


class MaskDataset(Dataset):
    def __init__(self, root: str, train: bool = True, test_size: float = 0.25) -> None:
        super().__init__()
        self.class_counts = [0, 0, 0]
        self.root = root
        self.train = train
        all_imgs = sorted(list((Path(root) / "images").glob("*.png")))
        all_anns = sorted(list((Path(root) / "annotations").glob("*.xml")))
        n_test = int(len(all_imgs) * test_size)
        if train:
            self.images = all_imgs[n_test:]
            self.annots = all_anns[n_test:]
        else:
            self.images = all_imgs[:n_test]
            self.annots = all_anns[:n_test]
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        for ann in self.annots:
            with open(ann, "r") as f:
                data = f.read()
            soup = BeautifulSoup(data, "lxml")
            for obj in soup.find_all("object"):
                cls = obj.find("name").text
                self.class_counts[0 if cls == "without_mask" else 1 if cls == "with_mask" else 2] += 1

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        img_path = self.images[idx]
        ann_path = self.annots[idx]
        img = PIL.Image.open(img_path).convert("RGB")
        img_resized, scale, pad_x, pad_y = resize_with_padding(img)
        with open(ann_path, "r") as f:
            data = f.read()
        soup = BeautifulSoup(data, "lxml")
        objs = soup.find_all("object")
        resized_boxes = []
        for obj in objs:
            b = generate_box(obj)
            b2 = resize_boxes(b, scale, pad_x, pad_y)
            resized_boxes.append(b2)
        t_large, t_medium, t_small = build_targets_3scale(resized_boxes)
        img_tensor = self.transform(img_resized)
        return img_tensor, (t_large, t_medium, t_small)


def collate_fn(batch: List[Tuple[torch.Tensor, Tuple[torch.Tensor]]]) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    imgs, t_l, t_m, t_s = [], [], [], []
    for (img, (tl, tm, ts)) in batch:
        imgs.append(img)
        t_l.append(tl)
        t_m.append(tm)
        t_s.append(ts)
    imgs = torch.stack(imgs, dim=0)
    t_l = torch.stack(t_l, dim=0)
    t_m = torch.stack(t_m, dim=0)
    t_s = torch.stack(t_s, dim=0)
    return imgs, (t_l, t_m, t_s)
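A quick way to sanity-check the target encoding is to push one synthetic box through build_targets_3scale. A minimal sketch, assuming the packages from requirements.txt are installed and the repository root is the working directory; the 80x90 px box is an arbitrary illustrative value chosen to match a medium anchor:

from src.dataset import build_targets_3scale

# one synthetic box in 416x416 letterboxed coordinates: (xmin, ymin, xmax, ymax, class_id)
boxes = [(100, 120, 180, 210, 1)]
t_large, t_medium, t_small = build_targets_3scale(boxes)

# each target is (S, S, 3 anchors, 5 + NUM_CLASSES)
print(t_large.shape, t_medium.shape, t_small.shape)

# the 80x90 box overlaps the (78, 88) medium anchor best, so only the 26x26 grid gets an object cell
print(int(t_large[..., 4].sum()), int(t_medium[..., 4].sum()), int(t_small[..., 4].sum()))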
src/loss.py ADDED
@@ -0,0 +1,177 @@
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F


def box_iou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    N = boxes1.size(0)
    M = boxes2.size(0)
    x1_1, y1_1, x2_1, y2_1 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2], boxes1[:, 3]
    x1_2, y1_2, x2_2, y2_2 = boxes2[:, 0], boxes2[:, 1], boxes2[:, 2], boxes2[:, 3]
    x1_1 = x1_1.unsqueeze(1).expand(N, M)
    y1_1 = y1_1.unsqueeze(1).expand(N, M)
    x2_1 = x2_1.unsqueeze(1).expand(N, M)
    y2_1 = y2_1.unsqueeze(1).expand(N, M)
    x1_2 = x1_2.unsqueeze(0).expand(N, M)
    y1_2 = y1_2.unsqueeze(0).expand(N, M)
    x2_2 = x2_2.unsqueeze(0).expand(N, M)
    y2_2 = y2_2.unsqueeze(0).expand(N, M)
    interX1 = torch.max(x1_1, x1_2)
    interY1 = torch.max(y1_1, y1_2)
    interX2 = torch.min(x2_1, x2_2)
    interY2 = torch.min(y2_1, y2_2)
    interW = (interX2 - interX1).clamp(min=0)
    interH = (interY2 - interY1).clamp(min=0)
    interArea = interW * interH
    area1 = (x2_1 - x1_1).clamp(min=0) * (y2_1 - y1_1).clamp(min=0)
    area2 = (x2_2 - x1_2).clamp(min=0) * (y2_2 - y1_2).clamp(min=0)
    union = area1 + area2 - interArea + 1e-16
    iou = interArea / union
    return iou


def box_giou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    xA = torch.max(boxes1[:, 0], boxes2[:, 0])
    yA = torch.max(boxes1[:, 1], boxes2[:, 1])
    xB = torch.min(boxes1[:, 2], boxes2[:, 2])
    yB = torch.min(boxes1[:, 3], boxes2[:, 3])
    interW = (xB - xA).clamp(min=0)
    interH = (yB - yA).clamp(min=0)
    interArea = interW * interH
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)
    union = area1 + area2 - interArea + 1e-16
    iou = interArea / union
    xC1 = torch.min(boxes1[:, 0], boxes2[:, 0])
    yC1 = torch.min(boxes1[:, 1], boxes2[:, 1])
    xC2 = torch.max(boxes1[:, 2], boxes2[:, 2])
    yC2 = torch.max(boxes1[:, 3], boxes2[:, 3])
    encloseW = (xC2 - xC1).clamp(min=0)
    encloseH = (yC2 - yC1).clamp(min=0)
    encloseArea = encloseW * encloseH + 1e-16
    giou = iou - (encloseArea - union) / encloseArea
    return giou


class YoloLoss(nn.Module):
    def __init__(self, class_counts: List[int], anchors_l: List[Tuple[int, int]] = [(128, 152), (182, 205), (103, 124)], anchors_m: List[Tuple[int, int]] = [(78, 88), (55, 59), (37, 42)], anchors_s: List[Tuple[int, int]] = [(26, 28), (17, 19), (10, 11)], image_size: Tuple[int] = (416, 416), num_classes: int = 3, ignore_thresh: float = 0.7, lambda_noobj: float = 5.0):
        super().__init__()
        self.anchors_l = anchors_l
        self.anchors_m = anchors_m
        self.anchors_s = anchors_s
        self.image_size = image_size
        self.num_classes = num_classes
        self.ignore_thresh = ignore_thresh
        self.lambda_noobj = lambda_noobj
        total = sum(class_counts)
        w_list = [total / (c + 1e-5) * (2.0 if c_id == 0 else (3.0 if c_id == 2 else 1.0)) for c_id, c in enumerate(class_counts)]
        self.class_weight = torch.tensor(w_list, dtype=torch.float32)
        self.bce_obj = nn.BCEWithLogitsLoss(reduction="none")
        self.bce_cls = nn.BCEWithLogitsLoss(weight=self.class_weight, reduction="none")

    def forward(self, outputs: Tuple[torch.Tensor], targets: Tuple[torch.Tensor]) -> torch.Tensor:
        out_l, out_m, out_s = outputs
        t_l, t_m, t_s = targets
        loss_l = self._loss_single_scale(out_l, t_l, self.anchors_l, scale_wh=(13, 13))
        loss_m = self._loss_single_scale(out_m, t_m, self.anchors_m, scale_wh=(26, 26))
        loss_s = self._loss_single_scale(out_s, t_s, self.anchors_s, scale_wh=(52, 52))
        return loss_l + loss_m + loss_s

    def _loss_single_scale(self, pred: torch.Tensor, target: torch.Tensor, anchors: List[Tuple[int]], scale_wh: Tuple[int]) -> torch.Tensor:
        device = pred.device
        B, _, H, W = pred.shape
        A = len(anchors)
        pred = pred.view(B, A, (5 + self.num_classes), H, W)
        pred = pred.permute(0, 3, 4, 1, 2).contiguous()
        pred_tx = pred[..., 0]
        pred_ty = pred[..., 1]
        pred_tw = pred[..., 2]
        pred_th = pred[..., 3]
        pred_obj = pred[..., 4]
        pred_cls = pred[..., 5:]
        tgt_tx = target[..., 0]
        tgt_ty = target[..., 1]
        tgt_tw = target[..., 2]
        tgt_th = target[..., 3]
        tgt_obj = target[..., 4]
        tgt_cls = target[..., 5:]
        obj_mask = (tgt_obj == 1)
        noobj_mask = (tgt_obj == 0)
        img_w, img_h = self.image_size
        stride_x = img_w / W
        stride_y = img_h / H
        grid_x = torch.arange(W, device=device).view(1, 1, W, 1).expand(1, H, W, 1)
        grid_y = torch.arange(H, device=device).view(1, H, 1, 1).expand(1, H, W, 1)
        anchors_t = torch.tensor(anchors, dtype=torch.float, device=device)
        anchor_w = anchors_t[:, 0].view(1, 1, 1, A)
        anchor_h = anchors_t[:, 1].view(1, 1, 1, A)
        pred_box_xc = (grid_x + torch.sigmoid(pred_tx)) * stride_x
        pred_box_yc = (grid_y + torch.sigmoid(pred_ty)) * stride_y
        pred_box_w = torch.exp(pred_tw) * anchor_w
        pred_box_h = torch.exp(pred_th) * anchor_h
        pred_x1 = pred_box_xc - pred_box_w / 2
        pred_y1 = pred_box_yc - pred_box_h / 2
        pred_x2 = pred_box_xc + pred_box_w / 2
        pred_y2 = pred_box_yc + pred_box_h / 2
        gt_box_xc = (grid_x + tgt_tx) * stride_x
        gt_box_yc = (grid_y + tgt_ty) * stride_y
        gt_box_w = torch.exp(tgt_tw) * anchor_w
        gt_box_h = torch.exp(tgt_th) * anchor_h
        gt_x1 = gt_box_xc - gt_box_w / 2
        gt_y1 = gt_box_yc - gt_box_h / 2
        gt_x2 = gt_box_xc + gt_box_w / 2
        gt_y2 = gt_box_yc + gt_box_h / 2
        with torch.no_grad():
            ignore_mask_buf = torch.zeros_like(tgt_obj, dtype=torch.bool)
            noobj_flat = noobj_mask.view(-1)
            obj_flat = obj_mask.view(-1)
            px1f = pred_x1.view(-1)
            py1f = pred_y1.view(-1)
            px2f = pred_x2.view(-1)
            py2f = pred_y2.view(-1)
            gx1f = gt_x1.view(-1)[obj_flat]
            gy1f = gt_y1.view(-1)[obj_flat]
            gx2f = gt_x2.view(-1)[obj_flat]
            gy2f = gt_y2.view(-1)[obj_flat]
            if noobj_flat.sum() > 0 and obj_flat.sum() > 0:
                noobj_idx = noobj_flat.nonzero(as_tuple=True)[0]
                noobj_boxes_xyxy = torch.stack([px1f[noobj_idx], py1f[noobj_idx], px2f[noobj_idx], py2f[noobj_idx]], dim=-1)
                obj_boxes_xyxy = torch.stack([gx1f, gy1f, gx2f, gy2f], dim=-1)
                ious = box_iou_xyxy(noobj_boxes_xyxy, obj_boxes_xyxy)
                best_iou, _ = ious.max(dim=1)
                ignore_flags = (best_iou > self.ignore_thresh)
                all_idx = noobj_idx[ignore_flags]
                ignore_mask_buf.view(-1)[all_idx] = True
            ignore_mask = ignore_mask_buf
        obj_loss = self.bce_obj(pred_obj[obj_mask], torch.ones_like(pred_obj[obj_mask]))
        obj_loss = obj_loss.mean() if obj_loss.numel() > 0 else torch.tensor(0., device=device)
        noobj_mask_final = (noobj_mask & (~ignore_mask))
        noobj_loss = self.bce_obj(pred_obj[noobj_mask_final], torch.zeros_like(pred_obj[noobj_mask_final]))
        noobj_loss = noobj_loss.mean() if noobj_loss.numel() > 0 else torch.tensor(0., device=device)
        objectness_loss = obj_loss + self.lambda_noobj * noobj_loss
        class_loss = torch.tensor(0., device=device, requires_grad=True)
        if obj_mask.sum() > 0:
            self.bce_cls.weight = self.class_weight.to(device)
            cls_pred = pred_cls[obj_mask].to(device)
            cls_gt = tgt_cls[obj_mask].to(device)
            c_loss = self.bce_cls(cls_pred, cls_gt)
            class_loss = c_loss.mean()
        giou_loss = torch.tensor(0., device=device, requires_grad=True)
        if obj_mask.sum() > 0:
            px1_ = pred_x1[obj_mask]
            py1_ = pred_y1[obj_mask]
            px2_ = pred_x2[obj_mask]
            py2_ = pred_y2[obj_mask]
            p_xyxy = torch.stack([px1_, py1_, px2_, py2_], dim=-1)
            gx1_ = gt_x1[obj_mask]
            gy1_ = gt_y1[obj_mask]
            gx2_ = gt_x2[obj_mask]
            gy2_ = gt_y2[obj_mask]
            g_xyxy = torch.stack([gx1_, gy1_, gx2_, gy2_], dim=-1)
            giou = box_giou_xyxy(p_xyxy, g_xyxy)
            giou_loss = (1. - giou).mean()
        total_loss = objectness_loss + class_loss + giou_loss
        return total_loss
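The loss expects raw head outputs of shape (B, 3*(5+num_classes), S, S) and targets of shape (B, S, S, 3, 5+num_classes). A minimal shape check, assuming the repository root as working directory; the class counts and the single box below are made-up values used only to exercise the code path:

import torch
from src.loss import YoloLoss
from src.dataset import build_targets_3scale

criterion = YoloLoss(class_counts=[700, 3000, 120])   # illustrative counts, not dataset statistics

batch_size, channels = 2, 3 * (5 + 3)                 # 3 anchors x (tx, ty, tw, th, obj, 3 classes)
out_l = torch.randn(batch_size, channels, 13, 13)
out_m = torch.randn(batch_size, channels, 26, 26)
out_s = torch.randn(batch_size, channels, 52, 52)

t_l, t_m, t_s = build_targets_3scale([(100, 120, 180, 210, 1)])
targets = tuple(t.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) for t in (t_l, t_m, t_s))

loss = criterion((out_l, out_m, out_s), targets)
print(loss)   # scalar combining objectness, weighted class BCE and GIoU terms over all three scales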
src/models/yolov3.py ADDED
@@ -0,0 +1,114 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_batch(in_ch: int, out_ch: int, kernel_size: int = 3, padding: int = 1, stride: int = 1) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU()
    )


class DarkResidualBlock(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        reduced_channels = in_channels // 2
        self.layer1 = conv_batch(in_channels, reduced_channels, kernel_size=1, padding=0)
        self.layer2 = conv_batch(reduced_channels, in_channels)

    def forward(self, x):
        return x + self.layer2(self.layer1(x))


class Darknet53(nn.Module):
    def __init__(self, block: nn.Module = DarkResidualBlock) -> None:
        super().__init__()
        self.conv1 = conv_batch(3, 32)
        self.conv2 = conv_batch(32, 64, stride=2)
        self.residual_block1 = self.make_layer(block, in_channels=64, num_blocks=1)
        self.conv3 = conv_batch(64, 128, stride=2)
        self.residual_block2 = self.make_layer(block, in_channels=128, num_blocks=2)
        self.conv4 = conv_batch(128, 256, stride=2)
        self.residual_block3 = self.make_layer(block, in_channels=256, num_blocks=8)
        self.conv5 = conv_batch(256, 512, stride=2)
        self.residual_block4 = self.make_layer(block, in_channels=512, num_blocks=8)
        self.conv6 = conv_batch(512, 1024, stride=2)
        self.residual_block5 = self.make_layer(block, in_channels=1024, num_blocks=4)

    def make_layer(self, block: nn.Module, in_channels: int, num_blocks: int) -> nn.Sequential:
        layers = []
        for _ in range(num_blocks):
            layers.append(block(in_channels))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual_block1(x)
        x = self.conv3(x)
        x = self.residual_block2(x)
        x = self.conv4(x)
        x = self.residual_block3(x)
        c4 = x
        x = self.conv5(x)
        x = self.residual_block4(x)
        c5 = x
        x = self.conv6(x)
        x = self.residual_block5(x)
        c6 = x
        return c4, c5, c6


def conv_leaky(in_ch: int, out_ch: int, k: int = 1, s: int = 1, p: int = 0):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.1, inplace=True)
    )


class DetectionHead(nn.Module):
    def __init__(self, in_ch: int, mid_ch: int, num_anchors: int = 3, num_classes: int = 3) -> None:
        super().__init__()
        self.block = nn.Sequential(
            conv_leaky(in_ch, mid_ch, k=1, s=1, p=0),
            conv_leaky(mid_ch, mid_ch * 2, k=3, s=1, p=1),
            conv_leaky(mid_ch * 2, mid_ch, k=1, s=1, p=0),
            conv_leaky(mid_ch, mid_ch * 2, k=3, s=1, p=1),
            conv_leaky(mid_ch * 2, mid_ch, k=1, s=1, p=0)
        )
        self.out_conv = nn.Conv2d(mid_ch, num_anchors * (5 + num_classes), kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block(x)
        out = self.out_conv(x)
        return out


class YOLOv3(nn.Module):
    def __init__(self, num_classes: int = 3) -> None:
        super().__init__()
        self.backbone = Darknet53()
        self.num_classes = num_classes
        self.num_anchors = 3
        self.head_large = DetectionHead(in_ch=1024, mid_ch=512, num_anchors=3, num_classes=num_classes)
        self.head_medium = DetectionHead(in_ch=1024, mid_ch=256, num_anchors=3, num_classes=num_classes)
        self.head_small = DetectionHead(in_ch=512, mid_ch=128, num_anchors=3, num_classes=num_classes)
        self.conv_upsample_l2 = conv_leaky(1024, 512, k=1, s=1, p=0)
        self.conv_upsample_l3 = conv_leaky(1024, 256, k=1, s=1, p=0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        c4, c5, c6 = self.backbone(x)
        out_l = self.head_large(c6)
        x_l2 = self.conv_upsample_l2(c6)
        x_l2_up = F.interpolate(x_l2, scale_factor=2, mode="nearest")
        x_merge_l2 = torch.cat([x_l2_up, c5], dim=1)
        out_m = self.head_medium(x_merge_l2)
        x_l3 = self.conv_upsample_l3(x_merge_l2)
        x_l3_up = F.interpolate(x_l3, scale_factor=2, mode="nearest")
        x_merge_l3 = torch.cat([x_l3_up, c4], dim=1)
        out_s = self.head_small(x_merge_l3)
        return out_l, out_m, out_s
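A forward pass on a dummy 416x416 input confirms the three output strides. A standalone sketch, assuming it is run from the repository root with torch installed:

import torch
from src.models.yolov3 import YOLOv3

model = YOLOv3(num_classes=3)
model.eval()

x = torch.randn(1, 3, 416, 416)   # one dummy RGB image
with torch.no_grad():
    out_l, out_m, out_s = model(x)

# 3 anchors x (5 + 3 classes) = 24 channels per scale
print(out_l.shape)   # torch.Size([1, 24, 13, 13])
print(out_m.shape)   # torch.Size([1, 24, 26, 26])
print(out_s.shape)   # torch.Size([1, 24, 52, 52])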
src/train.py ADDED
@@ -0,0 +1,427 @@
import math
from pathlib import Path
from typing import List, Tuple, Dict
from tqdm import tqdm
import argparse
from accelerate import Accelerator
from accelerate.utils import set_seed
import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.ops as ops
import PIL.Image, PIL.ImageDraw
import numpy as np

from dataset import MaskDataset, collate_fn, ANCHORS
from utils import EMA
from models.yolov3 import YOLOv3
from loss import YoloLoss


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int, eta_min: float = 0, last_epoch: int = -1) -> None:
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch < self.warmup_steps:
            return [base_lr * (self.last_epoch / max(1, self.warmup_steps)) for base_lr in self.base_lrs]
        else:
            current_step = self.last_epoch - self.warmup_steps
            cosine_steps = max(1, self.total_steps - self.warmup_steps)
            return [self.eta_min + (base_lr - self.eta_min) * 0.5 * (1 + math.cos(math.pi * current_step / cosine_steps)) for base_lr in self.base_lrs]


def draw_bounding_boxes(image: PIL.Image.Image, boxes: torch.Tensor, colors: Dict[int, Tuple[int, int, int]] = {0: (178, 34, 34), 1: (34, 139, 34), 2: (184, 134, 11)}, labels={0: "without_mask", 1: "with_mask", 2: "weared_incorrect"}, show_conf=False) -> None:
    draw = PIL.ImageDraw.Draw(image)
    for box in boxes:
        xmin, ymin, xmax, ymax, class_id = int(box[0]), int(box[1]), int(box[2]), int(box[3]), int(box[-1])
        conf_text = ""
        if show_conf and box.shape[0] == 6:
            conf = float(box[4])
            conf_text = f" {conf:.2f}"
        color = colors.get(class_id, (255, 255, 255))
        label = labels.get(class_id, "Unknown") + conf_text
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
        text_bbox = draw.textbbox((xmin, ymin), label)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        draw.rectangle([xmin, ymin - text_height - 2, xmin + text_width + 2, ymin], fill=color)
        draw.text((xmin + 1, ymin - text_height - 1), label, fill="white")


def create_combined_image(img: torch.Tensor, gt_batch: List[torch.Tensor], results: List[torch.Tensor], mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225]):
    batch_size, _, height, width = img.shape
    combined_height = height * 2
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    for i in range(batch_size):
        image = img[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        gt_image = PIL.Image.fromarray(image.copy())
        pred_image = PIL.Image.fromarray(image.copy())
        draw_bounding_boxes(gt_image, gt_batch[i])
        draw_bounding_boxes(pred_image, results[i], show_conf=True)
        combined_image[:height, i * width:(i + 1) * width, :] = np.array(gt_image)
        combined_image[height:, i * width:(i + 1) * width, :] = np.array(pred_image)
    return PIL.Image.fromarray(combined_image)


def decode_yolo_output_single(prediction: torch.Tensor, anchors: List[Tuple[int]], image_size: Tuple[int] = (416, 416), conf_threshold: float = 0.5, iou_threshold: float = 0.3, apply_nms: bool = True, num_classes: int = 3) -> List[torch.Tensor]:
    device = prediction.device
    B, _, H, W = prediction.shape
    A = len(anchors)
    prediction = prediction.view(B, A, 5 + num_classes, H, W)
    prediction = prediction.permute(0, 1, 3, 4, 2).contiguous()
    tx = prediction[..., 0]
    ty = prediction[..., 1]
    tw = prediction[..., 2]
    th = prediction[..., 3]
    obj = prediction[..., 4]
    class_scores = prediction[..., 5:]
    tx = tx.sigmoid()
    ty = ty.sigmoid()
    obj = obj.sigmoid()
    class_scores = class_scores.softmax(dim=-1)
    img_w, img_h = image_size
    cell_w = img_w / W
    cell_h = img_h / H
    grid_x = torch.arange(W, device=device).view(1, 1, W).expand(1, H, W)
    grid_y = torch.arange(H, device=device).view(1, H, 1).expand(1, H, W)
    anchors_tensor = torch.tensor(anchors, dtype=torch.float32, device=device)
    anchor_w = anchors_tensor[:, 0].view(1, A, 1, 1)
    anchor_h = anchors_tensor[:, 1].view(1, A, 1, 1)
    x_center = (grid_x + tx) * cell_w
    y_center = (grid_y + ty) * cell_h
    w = torch.exp(tw) * anchor_w
    h = torch.exp(th) * anchor_h
    xmin = x_center - w / 2
    ymin = y_center - h / 2
    xmax = x_center + w / 2
    ymax = y_center + h / 2
    max_class_probs, class_ids = class_scores.max(dim=-1)
    confidence = obj * max_class_probs
    outputs = []
    for b_i in range(B):
        box_xmin = xmin[b_i].view(-1)
        box_ymin = ymin[b_i].view(-1)
        box_xmax = xmax[b_i].view(-1)
        box_ymax = ymax[b_i].view(-1)
        conf = confidence[b_i].view(-1)
        cls_id = class_ids[b_i].view(-1).float()
        mask = (conf > conf_threshold)
        box_xmin = box_xmin[mask]
        box_ymin = box_ymin[mask]
        box_xmax = box_xmax[mask]
        box_ymax = box_ymax[mask]
        conf = conf[mask]
        cls_id = cls_id[mask]
        if mask.sum() == 0:
            outputs.append(torch.empty((0, 6), device=device))
            continue
        boxes = torch.stack([box_xmin, box_ymin, box_xmax, box_ymax], dim=-1)
        if apply_nms:
            keep = ops.nms(boxes, conf, iou_threshold)
            boxes = boxes[keep]
            conf = conf[keep]
            cls_id = cls_id[keep]
        out = torch.cat([boxes, conf.unsqueeze(-1), cls_id.unsqueeze(-1)], dim=-1)
        outputs.append(out)
    return outputs


def decode_predictions_3scales(out_l: torch.Tensor, out_m: torch.Tensor, out_s: torch.Tensor, anchors_l: List[Tuple[int]], anchors_m: List[Tuple[int, int]], anchors_s: List[Tuple[int, int]], image_size: Tuple[int, int] = (416, 416), conf_threshold: float = 0.5, iou_threshold: float = 0.45, num_classes: int = 3) -> List[torch.Tensor]:
    b_l = decode_yolo_output_single(out_l, anchors_l, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    b_m = decode_yolo_output_single(out_m, anchors_m, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    b_s = decode_yolo_output_single(out_s, anchors_s, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    results = []
    B = len(b_l)
    for i in range(B):
        boxes_all = torch.cat([b_l[i], b_m[i], b_s[i]], dim=0)
        if boxes_all.numel() == 0:
            results.append(boxes_all)
            continue
        xyxy = boxes_all[:, :4]
        scores = boxes_all[:, 4]
        keep = ops.nms(xyxy, scores, iou_threshold)
        final = boxes_all[keep]
        results.append(final)
    return results


def decode_target_single(target: torch.Tensor, anchors: List[Tuple[int]], image_size: Tuple[int] = (416, 416), obj_threshold: float = 0.5) -> List[torch.Tensor]:
    args = parse_args()
    target = target.to(args.device)
    B, S, _, A, _ = target.shape
    img_w, img_h = image_size
    cell_w = img_w / S
    cell_h = img_h / S
    anchors_tensor = torch.tensor(anchors, dtype=torch.float)
    tx = target[..., 0]
    ty = target[..., 1]
    tw = target[..., 2]
    th = target[..., 3]
    tobj = target[..., 4]
    tcls = target[..., 5:]
    results = []
    for b_i in range(B):
        bx_list = []
        tx_b = tx[b_i]
        ty_b = ty[b_i]
        tw_b = tw[b_i]
        th_b = th[b_i]
        tobj_b = tobj[b_i]
        tcls_b = tcls[b_i]
        for i in range(S):
            for j in range(S):
                for a_i in range(A):
                    if tobj_b[i, j, a_i] < obj_threshold:
                        continue
                    cls_one_hot = tcls_b[i, j, a_i]
                    cls_id = cls_one_hot.argmax().item()
                    x_center = (j + tx_b[i, j, a_i].item()) * cell_w
                    y_center = (i + ty_b[i, j, a_i].item()) * cell_h
                    anchor_w = anchors_tensor[a_i, 0]
                    anchor_h = anchors_tensor[a_i, 1]
                    box_w = torch.exp(tw_b[i, j, a_i]) * anchor_w
                    box_h = torch.exp(th_b[i, j, a_i]) * anchor_h
                    xmin = x_center - box_w / 2
                    ymin = y_center - box_h / 2
                    xmax = x_center + box_w / 2
                    ymax = y_center + box_h / 2
                    bx_list.append([xmin.item(), ymin.item(), xmax.item(), ymax.item(), cls_id])
        if len(bx_list) == 0:
            results.append(torch.empty((0, 5), dtype=torch.float32, device=args.device))
        else:
            results.append(torch.tensor(bx_list, dtype=torch.float32, device=args.device))
    return results


def decode_target_3scales(t_l: torch.Tensor, t_m: torch.Tensor, t_s: torch.Tensor, anchors_l: List[Tuple[int]], anchors_m: List[Tuple[int]], anchors_s: List[Tuple[int]], image_size: Tuple[int] = (416, 416), obj_threshold: float = 0.5) -> List[torch.Tensor]:
    dec_l = decode_target_single(t_l, anchors_l, image_size, obj_threshold)
    dec_m = decode_target_single(t_m, anchors_m, image_size, obj_threshold)
    dec_s = decode_target_single(t_s, anchors_s, image_size, obj_threshold)
    results = []
    B = len(dec_l)
    for i in range(B):
        boxes_l = dec_l[i]
        boxes_m = dec_m[i]
        boxes_s = dec_s[i]
        if boxes_l.numel() == 0 and boxes_m.numel() == 0 and boxes_s.numel() == 0:
            results.append(torch.empty((0, 5), dtype=torch.float32, device=boxes_l.device))
        else:
            all_ = torch.cat([boxes_l, boxes_m, boxes_s], dim=0)
            results.append(all_)
    return results


def iou_xyxy(box1: List[int | float], box2: List[int | float]) -> float:
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    w = max(0., x2 - x1)
    h = max(0., y2 - y1)
    inter = w * h
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0


def compute_ap_per_class(boxes_pred: List[List[float]], boxes_gt: List[List[float]], iou_threshold: float = 0.45) -> float:
    boxes_pred = sorted(boxes_pred, key=lambda x: x[4], reverse=True)
    n_gt = len(boxes_gt)
    if n_gt == 0 and len(boxes_pred) == 0:
        return 1.0
    if n_gt == 0:
        return 0.0
    matched = [False] * n_gt
    tps = []
    fps = []
    for i, pred in enumerate(boxes_pred):
        best_iou = 0.0
        best_j = -1
        for j, gt in enumerate(boxes_gt):
            if matched[j]:
                continue
            iou = iou_xyxy(pred, gt)
            if iou > best_iou:
                best_iou = iou
                best_j = j
        if best_iou > iou_threshold and best_j >= 0:
            tps.append(1)
            fps.append(0)
            matched[best_j] = True
        else:
            tps.append(0)
            fps.append(1)
    tps_cum = []
    fps_cum = []
    s_tp = 0
    s_fp = 0
    for i in range(len(tps)):
        s_tp += tps[i]
        s_fp += fps[i]
        tps_cum.append(s_tp)
        fps_cum.append(s_fp)
    precisions = []
    recalls = []
    for i in range(len(tps)):
        prec = tps_cum[i] / (tps_cum[i] + fps_cum[i]) if (tps_cum[i] + fps_cum[i]) > 0 else 0
        rec = tps_cum[i] / n_gt
        precisions.append(prec)
        recalls.append(rec)
    recalls = [0.0] + recalls + [1.0]
    precisions = [1.0] + precisions + [0.0]
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    ap = 0.0
    for i in range(len(precisions) - 1):
        ap += (recalls[i + 1] - recalls[i]) * precisions[i + 1]
    return ap


def compute_map(all_pred: List[float], all_gt: List[float], num_classes: int = 3, iou_threshold: float = 0.45) -> float:
    APs = []
    for c in range(num_classes):
        ap_c = compute_ap_per_class(all_pred[c], all_gt[c], iou_threshold)
        APs.append(ap_c)
    mAP = sum(APs) / len(APs) if len(APs) > 0 else 0.0
    return mAP


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on the face mask detection dataset")
    parser.add_argument("--root", type=str, default="data/masks", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size for training and testing")
    parser.add_argument("--logs-dir", type=str, default="yolo-logs", help="Path to save logs")
    parser.add_argument("--pin-memory", type=bool, default=True, help="Pin Memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate for the optimizer")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--max-norm", type=float, default=10.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--project-name", type=str, default="YOLOv3, mask detection", help="Wandb project name")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run the training on")
    parser.add_argument("--weights-path", type=str, default="weights/darknet53.pth", help="Path to the weights")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--log-steps", type=int, default=13, help="Number of steps between logging training images and metrics")
    parser.add_argument("--num-warmup-steps", type=int, default=400, help="Number of steps")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
        wandb.init(project=args.project_name, dir=logs_dir)
    train_dataset = MaskDataset(root=args.root, train=True)
    test_dataset = MaskDataset(root=args.root, train=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn)
    model = YOLOv3().to(accelerator.device)
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.weights_path:
        weights = torch.load(args.weights_path, map_location="cpu", weights_only=True)
        model.backbone.load_state_dict(weights)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    criterion = YoloLoss(class_counts=train_dataset.class_counts)
    scheduler = WarmupCosineAnnealingLR(optimizer, warmup_steps=args.num_warmup_steps // args.gradient_accumulation_steps, total_steps=args.num_epochs * len(train_loader) // args.gradient_accumulation_steps, eta_min=1e-7)

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    best_map = 0.0
    train_loss_ema = EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch} / {args.num_epochs}")
        for images, (t_l, t_m, t_s) in pbar:
            images = images.to(accelerator.device)
            t_l = t_l.to(accelerator.device)
            t_m = t_m.to(accelerator.device)
            t_s = t_s.to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    out_l, out_m, out_s = model(images)
                    loss = criterion((out_l, out_m, out_s), (t_l, t_m, t_s))
                accelerator.backward(loss)
                grad_norm = None
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
            lr = scheduler.get_last_lr()[0]
            pbar.set_postfix({"loss": train_loss_ema(loss.item())})
            log_data = {
                "train/epoch": epoch,
                "train/loss": loss.item(),
                "train/lr": lr
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()
        model.eval()
        all_pred = [[] for _ in range(model.num_classes)]
        all_gt = [[] for _ in range(model.num_classes)]
        with torch.inference_mode():
            test_loss = 0.0
            pbar = tqdm(test_loader, desc=f"Test epoch {epoch} / {args.num_epochs}")
            for index, (images, (t_l, t_m, t_s)) in enumerate(pbar):
                images = images.to(accelerator.device)
                t_l = t_l.to(accelerator.device)
                t_m = t_m.to(accelerator.device)
                t_s = t_s.to(accelerator.device)
                out_l, out_m, out_s = model(images)
                loss = criterion((out_l, out_m, out_s), (t_l, t_m, t_s))
                test_loss += loss.item()
                results = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
                gt_batch = decode_target_3scales(t_l, t_m, t_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(images, gt_batch, results)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Test, Epoch {epoch})"))
                    wandb.log({"test_samples": images_to_log})
                for b_i in range(len(images)):
                    dets_b = results[b_i].detach().cpu().numpy()
                    gts_b = gt_batch[b_i].detach().cpu().numpy()
                    for db in dets_b:
                        c = int(db[5])
                        all_pred[c].append([db[0], db[1], db[2], db[3], db[4]])
                    for gb in gts_b:
                        c = int(gb[4])
                        all_gt[c].append([gb[0], gb[1], gb[2], gb[3]])
        test_loss /= len(test_loader)
        test_map = compute_map(all_pred, all_gt)
        accelerator.print(f"loss: {test_loss:.3f}, map: {test_map:.3f}")
        if accelerator.is_main_process:
            wandb.log({
                "epoch": epoch,
                "test/loss": test_loss,
                "test/mAP": test_map
            })
        if test_map > best_map:
            best_map = test_map
            accelerator.save(model.state_dict(), logs_dir / "checkpoint-best.pth")
        elif epoch % args.save_frequency == 0:
            accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()
    accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()
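WarmupCosineAnnealingLR ramps the learning rate linearly for the first warmup_steps optimizer updates and then follows a cosine decay toward eta_min. A small sketch of that behaviour in isolation, assuming the packages from requirements.txt are installed and the script is run from inside src/ so train.py's flat imports resolve; the step counts below are illustrative, not the training defaults:

import torch
from train import WarmupCosineAnnealingLR   # run from inside src/ so this import resolves

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=5e-4)
scheduler = WarmupCosineAnnealingLR(optimizer, warmup_steps=200, total_steps=5000, eta_min=1e-7)

lrs = []
for _ in range(5000):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# linear warmup over the first 200 steps, cosine decay toward eta_min afterwards
print(lrs[0], lrs[199], lrs[2499], lrs[-1])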
src/utils.py ADDED
@@ -0,0 +1,11 @@
class EMA:
    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
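EMA is the exponential moving average used to smooth the loss shown in the training progress bar; with the default alpha=0.9 the running value keeps 90% of its previous state on every update. A tiny sketch with made-up loss values:

from src.utils import EMA

ema = EMA(alpha=0.9)
for raw_loss in [2.0, 1.0, 0.5, 0.4]:
    print(round(ema(raw_loss), 3))   # 2.0, 1.9, 1.76, 1.624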
weights/checkpoint-best.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d496cd707cec1135b6d6cfece5c35b92572063914d81ae2bbbc8ded5c7366e10
size 224442922
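The checkpoint is a plain state_dict written by accelerator.save during training, so it can be restored the same way app.py does. A minimal loading sketch, assuming the Git LFS file has actually been pulled and the repository root is the working directory:

import torch
from src.models.yolov3 import YOLOv3

model = YOLOv3()
state_dict = torch.load("weights/checkpoint-best.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()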