Commit bf9a8c8 (verified), committed by eksemyashkina
Parent(s): f514e23

Update src/train.py
Files changed (1):
  1. src/train.py +426 -426
src/train.py CHANGED
@@ -1,427 +1,427 @@
The diff marks every line of src/train.py as changed, but the old and new versions are identical except for the four project-local imports, which now resolve through the src package:

- from dataset import MaskDataset, collate_fn, ANCHORS
- from utils import EMA
- from models.yolov3 import YOLOv3
- from loss import YoloLoss
+ from src.dataset import MaskDataset, collate_fn, ANCHORS
+ from src.utils import EMA
+ from src.models.yolov3 import YOLOv3
+ from src.loss import YoloLoss

Full file after the change:

import math
from pathlib import Path
from typing import List, Tuple, Dict
from tqdm import tqdm
import argparse
from accelerate import Accelerator
from accelerate.utils import set_seed
import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.ops as ops
import PIL
import numpy as np

from src.dataset import MaskDataset, collate_fn, ANCHORS
from src.utils import EMA
from src.models.yolov3 import YOLOv3
from src.loss import YoloLoss


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int, eta_min: int = 0, last_epoch: int = -1) -> None:
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch < self.warmup_steps:
            return [base_lr * (self.last_epoch / max(1, self.warmup_steps)) for base_lr in self.base_lrs]
        else:
            current_step = self.last_epoch - self.warmup_steps
            cosine_steps = max(1, self.total_steps - self.warmup_steps)
            return [self.eta_min + (base_lr - self.eta_min) * 0.5 * (1 + math.cos(math.pi * current_step / cosine_steps)) for base_lr in self.base_lrs]


def draw_bounding_boxes(image: PIL.Image.Image, boxes: torch.Tensor, colors: Dict[int, int] = {0: (178, 34, 34), 1: (34, 139, 34), 2: (184, 134, 11)}, labels = {0: "without_mask", 1: "with_mask", 2: "weared_incorrect"}, show_conf = False) -> None:
    draw = PIL.ImageDraw.Draw(image)
    for box in boxes:
        xmin, ymin, xmax, ymax, class_id = int(box[0]), int(box[1]), int(box[2]), int(box[3]), int(box[-1])
        conf_text = ""
        if show_conf and box.shape[0] == 6:
            conf = float(box[4])
            conf_text = f" {conf:.2f}"
        color = colors.get(class_id, (255, 255, 255))
        label = labels.get(class_id, "Unknown") + conf_text
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
        text_bbox = draw.textbbox((xmin, ymin), label)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        draw.rectangle([xmin, ymin - text_height - 2, xmin + text_width + 2, ymin], fill=color)
        draw.text((xmin + 1, ymin - text_height - 1), label, fill="white")


def create_combined_image(img: torch.Tensor, gt_batch: List[torch.Tensor], results: List[torch.Tensor], mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225]):
    batch_size, _, height, width = img.shape
    combined_height = height * 2
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    for i in range(batch_size):
        image = img[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        gt_image = PIL.Image.fromarray(image.copy())
        pred_image = PIL.Image.fromarray(image.copy())
        draw_bounding_boxes(gt_image, gt_batch[i])
        draw_bounding_boxes(pred_image, results[i], show_conf=True)
        combined_image[:height, i * width:(i + 1) * width, :] = np.array(gt_image)
        combined_image[height:, i * width:(i + 1) * width, :] = np.array(pred_image)
    return PIL.Image.fromarray(combined_image)


def decode_yolo_output_single(prediction: torch.Tensor, anchors: List[Tuple[int]], image_size: Tuple[int] = (416, 416), conf_threshold: float = 0.5, iou_threshold: float = 0.3, apply_nms: bool = True, num_classes: int = 3) -> List[torch.Tensor]:
    device = prediction.device
    B, _, H, W = prediction.shape
    A = len(anchors)
    prediction = prediction.view(B, A, 5 + num_classes, H, W)
    prediction = prediction.permute(0, 1, 3, 4, 2).contiguous()
    tx = prediction[..., 0]
    ty = prediction[..., 1]
    tw = prediction[..., 2]
    th = prediction[..., 3]
    obj = prediction[..., 4]
    class_scores = prediction[..., 5:]
    tx = tx.sigmoid()
    ty = ty.sigmoid()
    obj = obj.sigmoid()
    class_scores = class_scores.softmax(dim=-1)
    img_w, img_h = image_size
    cell_w = img_w / W
    cell_h = img_h / H
    grid_x = torch.arange(W, device=device).view(1, 1, W).expand(1, H, W)
    grid_y = torch.arange(H, device=device).view(1, H, 1).expand(1, H, W)
    anchors_tensor = torch.tensor(anchors, dtype=torch.float32, device=device)
    anchor_w = anchors_tensor[:, 0].view(1, A, 1, 1)
    anchor_h = anchors_tensor[:, 1].view(1, A, 1, 1)
    x_center = (grid_x + tx) * cell_w
    y_center = (grid_y + ty) * cell_h
    w = torch.exp(tw) * anchor_w
    h = torch.exp(th) * anchor_h
    xmin = x_center - w / 2
    ymin = y_center - h / 2
    xmax = x_center + w / 2
    ymax = y_center + h / 2
    max_class_probs, class_ids = class_scores.max(dim=-1)
    confidence = obj * max_class_probs
    outputs = []
    for b_i in range(B):
        box_xmin = xmin[b_i].view(-1)
        box_ymin = ymin[b_i].view(-1)
        box_xmax = xmax[b_i].view(-1)
        box_ymax = ymax[b_i].view(-1)
        conf = confidence[b_i].view(-1)
        cls_id = class_ids[b_i].view(-1).float()
        mask = (conf > conf_threshold)
        box_xmin = box_xmin[mask]
        box_ymin = box_ymin[mask]
        box_xmax = box_xmax[mask]
        box_ymax = box_ymax[mask]
        conf = conf[mask]
        cls_id = cls_id[mask]
        if mask.sum() == 0:
            outputs.append(torch.empty((0, 6), device=device))
            continue
        boxes = torch.stack([box_xmin, box_ymin, box_xmax, box_ymax], dim=-1)
        if apply_nms:
            keep = ops.nms(boxes, conf, iou_threshold)
            boxes = boxes[keep]
            conf = conf[keep]
            cls_id = cls_id[keep]
        out = torch.cat([boxes, conf.unsqueeze(-1), cls_id.unsqueeze(-1)], dim=-1)
        outputs.append(out)
    return outputs


def decode_predictions_3scales(out_l: torch.Tensor, out_m: torch.Tensor, out_s: torch.Tensor, anchors_l: List[Tuple[int]], anchors_m: List[Tuple[int, int]], anchors_s: List[Tuple[int, int]], image_size: Tuple[int, int] = (416, 416), conf_threshold: float = 0.5, iou_threshold: float = 0.45, num_classes: int = 3) -> List[torch.Tensor]:
    b_l = decode_yolo_output_single(out_l, anchors_l, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    b_m = decode_yolo_output_single(out_m, anchors_m, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    b_s = decode_yolo_output_single(out_s, anchors_s, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    results = []
    B = len(b_l)
    for i in range(B):
        boxes_all = torch.cat([b_l[i], b_m[i], b_s[i]], dim=0)
        if boxes_all.numel() == 0:
            results.append(boxes_all)
            continue
        xyxy = boxes_all[:, :4]
        scores = boxes_all[:, 4]
        keep = ops.nms(xyxy, scores, iou_threshold)
        final = boxes_all[keep]
        results.append(final)
    return results


def decode_target_single(target: torch.Tensor, anchors: List[Tuple[int]], image_size: Tuple[int] = (416, 416), obj_threshold: float = 0.5) -> List[torch.Tensor]:
    args = parse_args()
    target = target.to(args.device)
    B, S, _, A, _ = target.shape
    img_w, img_h = image_size
    cell_w = img_w / S
    cell_h = img_h / S
    anchors_tensor = torch.tensor(anchors, dtype=torch.float)
    tx = target[..., 0]
    ty = target[..., 1]
    tw = target[..., 2]
    th = target[..., 3]
    tobj = target[..., 4]
    tcls = target[..., 5:]
    results = []
    for b_i in range(B):
        bx_list = []
        tx_b = tx[b_i]
        ty_b = ty[b_i]
        tw_b = tw[b_i]
        th_b = th[b_i]
        tobj_b = tobj[b_i]
        tcls_b = tcls[b_i]
        for i in range(S):
            for j in range(S):
                for a_i in range(A):
                    if tobj_b[i,j,a_i] < obj_threshold:
                        continue
                    cls_one_hot = tcls_b[i, j, a_i]
                    cls_id = cls_one_hot.argmax().item()
                    x_center = (j + tx_b[i, j, a_i].item()) * cell_w
                    y_center = (i + ty_b[i, j, a_i].item()) * cell_h
                    anchor_w = anchors_tensor[a_i, 0]
                    anchor_h = anchors_tensor[a_i, 1]
                    box_w = torch.exp(tw_b[i, j, a_i]) * anchor_w
                    box_h = torch.exp(th_b[i, j, a_i]) * anchor_h
                    xmin = x_center - box_w / 2
                    ymin = y_center - box_h / 2
                    xmax = x_center + box_w / 2
                    ymax = y_center + box_h / 2
                    bx_list.append([xmin.item(), ymin.item(), xmax.item(), ymax.item(), cls_id])
        if len(bx_list) == 0:
            results.append(torch.empty((0, 5), dtype=torch.float32, device=args.device))
        else:
            results.append(torch.tensor(bx_list, dtype=torch.float32, device=args.device))
    return results


def decode_target_3scales(t_l: torch.Tensor, t_m: torch.Tensor, t_s: torch.Tensor, anchors_l: List[Tuple[int]], anchors_m: List[Tuple[int]], anchors_s: List[Tuple[int]], image_size: Tuple[int] = (416, 416), obj_threshold: float = 0.5) -> List[torch.Tensor]:
    dec_l = decode_target_single(t_l, anchors_l, image_size, obj_threshold)
    dec_m = decode_target_single(t_m, anchors_m, image_size, obj_threshold)
    dec_s = decode_target_single(t_s, anchors_s, image_size, obj_threshold)
    results = []
    B = len(dec_l)
    for i in range(B):
        boxes_l = dec_l[i]
        boxes_m = dec_m[i]
        boxes_s = dec_s[i]
        if boxes_l.numel() == 0 and boxes_m.numel() == 0 and boxes_s.numel() == 0:
            results.append(torch.empty((0, 5), dtype=torch.float32, device=boxes_l.device))
        else:
            all_ = torch.cat([boxes_l, boxes_m, boxes_s], dim=0)
            results.append(all_)
    return results


def iou_xyxy(box1: List[int | float], box2: List[int | float]) -> float:
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    w = max(0., x2 - x1)
    h = max(0., y2 - y1)
    inter = w * h
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0


def compute_ap_per_class(boxes_pred: List[List[float]], boxes_gt: List[List[float]], iou_threshold: float = 0.45) -> float:
    boxes_pred = sorted(boxes_pred, key=lambda x: x[4], reverse=True)
    n_gt = len(boxes_gt)
    if n_gt == 0 and len(boxes_pred) == 0:
        return 1.0
    if n_gt == 0:
        return 0.0
    matched = [False] * n_gt
    tps = []
    fps = []
    for i, pred in enumerate(boxes_pred):
        best_iou = 0.0
        best_j = -1
        for j, gt in enumerate(boxes_gt):
            if matched[j]:
                continue
            iou = iou_xyxy(pred, gt)
            if iou > best_iou:
                best_iou = iou
                best_j = j
        if best_iou > iou_threshold and best_j >= 0:
            tps.append(1)
            fps.append(0)
            matched[best_j] = True
        else:
            tps.append(0)
            fps.append(1)
    tps_cum = []
    fps_cum = []
    s_tp = 0
    s_fp = 0
    for i in range(len(tps)):
        s_tp += tps[i]
        s_fp += fps[i]
        tps_cum.append(s_tp)
        fps_cum.append(s_fp)
    precisions = []
    recalls = []
    for i in range(len(tps)):
        prec = tps_cum[i] / (tps_cum[i] + fps_cum[i]) if (tps_cum[i] + fps_cum[i]) > 0 else 0
        rec = tps_cum[i] / n_gt
        precisions.append(prec)
        recalls.append(rec)
    recalls = [0.0] + recalls + [1.0]
    precisions = [1.0] + precisions + [0.0]
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i+1])
    ap = 0.0
    for i in range(len(precisions) - 1):
        ap += (recalls[i+1] - recalls[i]) * precisions[i+1]
    return ap


def compute_map(all_pred: List[float], all_gt: List[float], num_classes: int = 3, iou_threshold: float = 0.45) -> float:
    APs = []
    for c in range(num_classes):
        ap_c = compute_ap_per_class(all_pred[c], all_gt[c], iou_threshold)
        APs.append(ap_c)
    mAP = sum(APs) / len(APs) if len(APs) > 0 else 0.0
    return mAP


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on the face mask detection dataset")
    parser.add_argument("--root", type=str, default="data/masks", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size for training and testing")
    parser.add_argument("--logs-dir", type=str, default="yolo-logs", help="Path to save logs")
    parser.add_argument("--pin-memory", type=bool, default=True, help="Pin Memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate for the optimizer")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--max-norm", type=float, default=10.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--project-name", type=str, default="YOLOv3, mask detection", help="Wandb project name")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run the training on")
    parser.add_argument("--weights-path", type=str, default="weights/darknet53.pth", help="Path to the weights")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--log-steps", type=int, default=13, help="Number of steps between logging training images and metrics")
    parser.add_argument("--num-warmup-steps", type=int, default=400, help="Number of steps")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
    wandb.init(project=args.project_name, dir=logs_dir)
    train_dataset = MaskDataset(root=args.root, train=True)
    test_dataset = MaskDataset(root=args.root, train=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn)
    model = YOLOv3().to(accelerator.device)
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.weights_path:
        weights = torch.load(args.weights_path, map_location="cpu", weights_only=True)
        model.backbone.load_state_dict(weights)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    criterion = YoloLoss(class_counts=train_dataset.class_counts)
    scheduler = WarmupCosineAnnealingLR(optimizer, warmup_steps=args.num_warmup_steps//args.gradient_accumulation_steps, total_steps=args.num_epochs*len(train_loader)//args.gradient_accumulation_steps, eta_min=1e-7)

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    best_map = 0.0
    train_loss_ema = EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc = f"Train epoch {epoch} / {args.num_epochs}")
        for images, (t_l, t_m, t_s) in pbar:
            images = images.to(accelerator.device)
            t_l = t_l.to(accelerator.device)
            t_m = t_m.to(accelerator.device)
            t_s = t_s.to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    out_l, out_m, out_s = model(images)
                    loss = criterion((out_l, out_m, out_s), (t_l, t_m, t_s))
                accelerator.backward(loss)
                grad_norm = None
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
            lr = scheduler.get_last_lr()[0]
            pbar.set_postfix({"loss": train_loss_ema(loss.item())})
            log_data = {
                "train/epoch": epoch,
                "train/loss": loss.item(),
                "train/lr": lr
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()
        model.eval()
        all_pred = [[] for _ in range(model.num_classes)]
        all_gt = [[] for _ in range(model.num_classes)]
        with torch.inference_mode():
            test_loss = 0.0
            pbar = tqdm(test_loader, desc=f"Test epoch {epoch} / {args.num_epochs}")
            for index, (images, (t_l, t_m, t_s)) in enumerate(pbar):
                images = images.to(accelerator.device)
                t_l = t_l.to(accelerator.device)
                t_m = t_m.to(accelerator.device)
                t_s = t_s.to(accelerator.device)
                out_l, out_m, out_s = model(images)
                loss = criterion((out_l, out_m, out_s), (t_l, t_m, t_s))
                test_loss += loss.item()
                results = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
                gt_batch = decode_target_3scales(t_l, t_m, t_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(images, gt_batch, results)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Test, Epoch {epoch})"))
                    wandb.log({"test_samples": images_to_log})
                for b_i in range(len(images)):
                    dets_b = results[b_i].detach().cpu().numpy()
                    gts_b = gt_batch[b_i].detach().cpu().numpy()
                    for db in dets_b:
                        c = int(db[5])
                        all_pred[c].append([db[0], db[1], db[2], db[3], db[4]])
                    for gb in gts_b:
                        c = int(gb[4])
                        all_gt[c].append([gb[0], gb[1], gb[2], gb[3]])
        test_loss /= len(test_loader)
        test_map = compute_map(all_pred, all_gt)
        accelerator.print(f"loss: {test_loss:.3f}, map: {test_map:.3f}")
        if accelerator.is_main_process:
            wandb.log({
                "epoch": epoch,
                "test/loss": test_loss,
                "test/mAP": test_map
            })
        if test_map > best_map:
            best_map = test_map
            accelerator.save(model.state_dict(), logs_dir / "checkpoint-best.pth")
        elif epoch % args.save_frequency == 0:
            accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()
    accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()
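
Because the helpers are now imported through the src package, the script presumably has to be launched from the repository root so that src is on the import path, e.g. python -m src.train or accelerate launch -m src.train (the exact entry point and repository layout are assumptions, not confirmed by this commit). A minimal smoke-test sketch under those assumptions, using an untrained YOLOv3 with its default constructor as in main(); it is not part of the committed code:

# Hypothetical check that the package-style imports resolve; run from the repo root.
import torch

from src.dataset import ANCHORS
from src.models.yolov3 import YOLOv3
from src.train import decode_predictions_3scales

model = YOLOv3().eval()              # same default construction as in main()
dummy = torch.zeros(1, 3, 416, 416)  # one 416x416 RGB image
with torch.inference_mode():
    out_l, out_m, out_s = model(dummy)
boxes = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
print(boxes[0].shape)                # (N, 6): xmin, ymin, xmax, ymax, confidence, class id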