Upload 13 files

- .gitattributes +2 -0
- app.py +100 -0
- assets/examples/image1.jpg +0 -0
- assets/examples/image2.jpg +0 -0
- assets/examples/image3.jpg +3 -0
- assets/examples/image4.jpg +0 -0
- assets/examples/image5.jpg +3 -0
- requirements.txt +11 -0
- src/dataset.py +173 -0
- src/loss.py +177 -0
- src/models/yolov3.py +114 -0
- src/train.py +427 -0
- src/utils.py +11 -0
- weights/checkpoint-best.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/examples/image3.jpg filter=lfs diff=lfs merge=lfs -text
+assets/examples/image5.jpg filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,100 @@
from typing import List
import gradio as gr
import PIL.Image, PIL.ImageOps
import torch
import numpy as np
import torchvision.transforms as T

from src.models.yolov3 import YOLOv3
from src.train import draw_bounding_boxes, decode_predictions_3scales
from src.dataset import ANCHORS, resize_with_padding


device = torch.device("cpu")
model_weight = "weights/checkpoint-best.pth"
label_colors = {"without_mask": (178, 34, 34), "with_mask": (34, 139, 34), "mask_worn_incorrectly": (184, 134, 11)}

model = YOLOv3()
model.load_state_dict(torch.load(model_weight, map_location=device))
model.eval()


def create_combined_image(img: torch.Tensor, results: List[torch.Tensor], mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225]):
    batch_size, _, height, width = img.shape
    combined_height = height
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    for i in range(batch_size):
        image = img[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        pred_image = PIL.Image.fromarray(image.copy())
        draw_bounding_boxes(pred_image, results[i], show_conf=True)
        combined_image[:height, i * width:(i + 1) * width, :] = np.array(pred_image)
    return PIL.Image.fromarray(combined_image)


transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def detect_mask(image, conf_threshold: float) -> PIL.Image.Image:
    img_resized, _, _, _ = resize_with_padding(image)
    img_tensor = transform(img_resized)
    with torch.no_grad():
        out_l, out_m, out_s = model(img_tensor.unsqueeze(0))
    results = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"], conf_threshold=conf_threshold)
    combined_image = create_combined_image(img_tensor.unsqueeze(0), results)
    return combined_image


def generate_legend_html_compact() -> str:
    legend_html = """
    <div style="display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;">
    """
    for idx, (label, color) in enumerate(label_colors.items()):
        legend_html += f"""
        <div style="display: flex; align-items: center; justify-content: center;
                    padding: 5px 10px; border: 1px solid rgb{color};
                    background-color: rgb{color}; border-radius: 5px;
                    color: white; font-size: 12px; text-align: center;">
            {label}
        </div>
        """
    legend_html += "</div>"
    return legend_html


examples = [
    ["assets/examples/image1.jpg"],
    ["assets/examples/image2.jpg"],
    ["assets/examples/image3.jpg"],
    ["assets/examples/image4.jpg"],
    ["assets/examples/image5.jpg"]
]


with gr.Blocks() as demo:
    gr.Markdown("## Mask Detection with YOLOv3")
    with gr.Row():
        with gr.Column():
            pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
            conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Confidence Threshold")
            with gr.Row():
                with gr.Column(scale=1):
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear")

        with gr.Column():
            output = gr.Image(label="Detection", type="pil", height=300, width=300)
            legend = gr.HTML(label="Legend", value=generate_legend_html_compact())

    predict_btn.click(fn=detect_mask, inputs=[pic, conf_slider], outputs=output, api_name="predict")
    clear_btn.click(lambda: (None, None), outputs=[pic, output])
    gr.Examples(examples=examples, inputs=[pic])

demo.launch()
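The same pipeline can be exercised without the Gradio UI. This is a minimal sketch under assumptions: the checkpoint, the example images and the pinned dependencies are in place, and, because src/train.py imports its siblings as top-level modules (from dataset import ...), src/ is put on sys.path first.

import sys
sys.path.insert(0, "src")  # train.py imports "dataset", "models", "loss" as top-level modules

import torch
import PIL.Image
import torchvision.transforms as T
from models.yolov3 import YOLOv3
from train import draw_bounding_boxes, decode_predictions_3scales
from dataset import ANCHORS, resize_with_padding

model = YOLOv3()
model.load_state_dict(torch.load("weights/checkpoint-best.pth", map_location="cpu"))
model.eval()

transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
image, _, _, _ = resize_with_padding(PIL.Image.open("assets/examples/image1.jpg").convert("RGB"))
with torch.no_grad():
    out_l, out_m, out_s = model(transform(image).unsqueeze(0))
boxes = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"], conf_threshold=0.9)[0]
draw_bounding_boxes(image, boxes, show_conf=True)  # draws onto the 416x416 letterboxed image in place
image.save("prediction.jpg")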
assets/examples/image1.jpg
ADDED
assets/examples/image2.jpg
ADDED
assets/examples/image3.jpg
ADDED (stored with Git LFS)
assets/examples/image4.jpg
ADDED
assets/examples/image5.jpg
ADDED (stored with Git LFS)
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch==2.6.0
tqdm==4.67.1
Pillow==10.4.0
bs4==0.0.2
scikit-learn==1.6.0
torchvision==0.21.0
wandb==0.19.1
lxml==5.3.0
accelerate==1.1.0
kaggle==1.6.17
gradio==5.14.0
src/dataset.py
ADDED
@@ -0,0 +1,173 @@
from typing import List, Tuple, Dict
from pathlib import Path
import PIL.Image
import numpy as np
import torchvision.transforms as T
import torch
from torch.utils.data import Dataset
from bs4 import BeautifulSoup
from bs4.element import Tag


ANCHORS = {
    "small": [(26, 28), (17, 19), (10, 11)],
    "medium": [(78, 88), (55, 59), (37, 42)],
    "large": [(128, 152), (182, 205), (103, 124)]
}
GRID_SIZES = [13, 26, 52]
IMAGE_SIZE = (416, 416)
NUM_CLASSES = 3


def generate_box(obj: Tag) -> List[int]:
    xmin = int(obj.find("xmin").text) - 1
    ymin = int(obj.find("ymin").text) - 1
    xmax = int(obj.find("xmax").text) - 1
    ymax = int(obj.find("ymax").text) - 1
    if obj.find("name").text == "without_mask":
        class_id = 0
    elif obj.find("name").text == "with_mask":
        class_id = 1
    else:
        class_id = 2
    return [xmin, ymin, xmax, ymax, class_id]


def resize_boxes(box: List[int], scale: float, pad_x: int, pad_y: int) -> Tuple[int]:
    xmin, ymin, xmax, ymax, class_id = box
    xmin = int(xmin * scale + pad_x)
    ymin = int(ymin * scale + pad_y)
    xmax = int(xmax * scale + pad_x)
    ymax = int(ymax * scale + pad_y)
    return (xmin, ymin, xmax, ymax, class_id)


def resize_with_padding(image: PIL.Image.Image, target_size: Tuple[int] = IMAGE_SIZE, fill: Tuple[int] = (255, 255, 255)) -> Tuple[PIL.Image.Image, float, int, int]:
    target_w, target_h = target_size
    orig_w, orig_h = image.size
    scale = min(target_w / orig_w, target_h / orig_h)
    new_w = int(orig_w * scale)
    new_h = int(orig_h * scale)
    image_resized = image.resize((new_w, new_h), resample=PIL.Image.LANCZOS)
    new_image = PIL.Image.new("RGB", (target_w, target_h), color=fill)
    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2
    new_image.paste(image_resized, (pad_x, pad_y))
    return new_image, scale, pad_x, pad_y


def build_targets_3scale(bboxes: List[Tuple[int]], image_size: Tuple[int] = IMAGE_SIZE, anchors: Dict[str, List[Tuple[int]]] = ANCHORS, grid_sizes: List[int] = GRID_SIZES, num_classes: int = NUM_CLASSES) -> Tuple[torch.Tensor]:
    img_w, img_h = image_size
    t_large = torch.zeros((grid_sizes[0], grid_sizes[0], 3, 5 + num_classes), dtype=torch.float32)
    t_medium = torch.zeros((grid_sizes[1], grid_sizes[1], 3, 5 + num_classes), dtype=torch.float32)
    t_small = torch.zeros((grid_sizes[2], grid_sizes[2], 3, 5 + num_classes), dtype=torch.float32)
    all_anchors = anchors["large"] + anchors["medium"] + anchors["small"]
    for (xmin, ymin, xmax, ymax, cls_id) in bboxes:
        box_w = xmax - xmin
        box_h = ymax - ymin
        x_center = (xmax + xmin) / 2
        y_center = (ymax + ymin) / 2
        if box_w <= 0 or box_h <= 0:
            continue
        best_iou = 0
        best_idx = 0
        for i, (aw, ah) in enumerate(all_anchors):
            inter = min(box_w, aw) * min(box_h, ah)
            union = box_w * box_h + aw * ah - inter
            iou = inter / union if union > 0 else 0
            if iou > best_iou:
                best_iou = iou
                best_idx = i
        if best_idx <= 2:
            s = grid_sizes[0]
            t = t_large
            local_anchor_id = best_idx
            anchor_w, anchor_h = anchors["large"][local_anchor_id]
        elif best_idx <= 5:
            s = grid_sizes[1]
            t = t_medium
            local_anchor_id = best_idx - 3
            anchor_w, anchor_h = anchors["medium"][local_anchor_id]
        else:
            s = grid_sizes[2]
            t = t_small
            local_anchor_id = best_idx - 6
            anchor_w, anchor_h = anchors["small"][local_anchor_id]
        cell_w = img_w / s
        cell_h = img_h / s
        gx = int(x_center // cell_w)
        gy = int(y_center // cell_h)
        tx = (x_center / cell_w) - gx
        ty = (y_center / cell_h) - gy
        tw = np.log((box_w / (anchor_w + 1e-16)) + 1e-16)
        th = np.log((box_h / (anchor_h + 1e-16)) + 1e-16)
        t[gy, gx, local_anchor_id, 0] = tx
        t[gy, gx, local_anchor_id, 1] = ty
        t[gy, gx, local_anchor_id, 2] = tw
        t[gy, gx, local_anchor_id, 3] = th
        t[gy, gx, local_anchor_id, 4] = 1.0
        t[gy, gx, local_anchor_id, 5 + cls_id] = 1.0
    return t_large, t_medium, t_small


class MaskDataset(Dataset):
    def __init__(self, root: str, train: bool = True, test_size: float = 0.25) -> None:
        super().__init__()
        self.class_counts = [0, 0, 0]
        self.root = root
        self.train = train
        all_imgs = sorted(list((Path(root) / "images").glob("*.png")))
        all_anns = sorted(list((Path(root) / "annotations").glob("*.xml")))
        n_test = int(len(all_imgs) * test_size)
        if train:
            self.images = all_imgs[n_test:]
            self.annots = all_anns[n_test:]
        else:
            self.images = all_imgs[:n_test]
            self.annots = all_anns[:n_test]
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        for ann in self.annots:
            with open(ann, "r") as f:
                data = f.read()
            soup = BeautifulSoup(data, "lxml")
            for obj in soup.find_all("object"):
                cls = obj.find("name").text
                self.class_counts[0 if cls == "without_mask" else 1 if cls == "with_mask" else 2] += 1

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        img_path = self.images[idx]
        ann_path = self.annots[idx]
        img = PIL.Image.open(img_path).convert("RGB")
        img_resized, scale, pad_x, pad_y = resize_with_padding(img)
        with open(ann_path, "r") as f:
            data = f.read()
        soup = BeautifulSoup(data, "lxml")
        objs = soup.find_all("object")
        resized_boxes = []
        for obj in objs:
            b = generate_box(obj)
            b2 = resize_boxes(b, scale, pad_x, pad_y)
            resized_boxes.append(b2)
        t_large, t_medium, t_small = build_targets_3scale(resized_boxes)
        img_tensor = self.transform(img_resized)
        return img_tensor, (t_large, t_medium, t_small)


def collate_fn(batch: List[Tuple[torch.Tensor, Tuple[torch.Tensor]]]) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    imgs, t_l, t_m, t_s = [], [], [], []
    for (img, (tl, tm, ts)) in batch:
        imgs.append(img)
        t_l.append(tl)
        t_m.append(tm)
        t_s.append(ts)
    imgs = torch.stack(imgs, dim=0)
    t_l = torch.stack(t_l, dim=0)
    t_m = torch.stack(t_m, dim=0)
    t_s = torch.stack(t_s, dim=0)
    return imgs, (t_l, t_m, t_s)
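A quick sanity check of build_targets_3scale, assuming it is run from the repository root: a single box of roughly 100x120 px is matched by width/height IoU to one of the "large" anchors, so exactly one cell of the 13x13 target receives objectness 1.

import torch
from src.dataset import build_targets_3scale

# one ground-truth box of ~100x120 px centred at (208, 208), class 1 ("with_mask")
t_large, t_medium, t_small = build_targets_3scale([(158, 148, 258, 268, 1)])
print(t_large.shape, t_medium.shape, t_small.shape)  # (13, 13, 3, 8) (26, 26, 3, 8) (52, 52, 3, 8)
print(t_large[..., 4].nonzero())                     # a single assigned cell/anchor, at grid cell (6, 6)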
src/loss.py
ADDED
@@ -0,0 +1,177 @@
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F


def box_iou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    N = boxes1.size(0)
    M = boxes2.size(0)
    x1_1, y1_1, x2_1, y2_1 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2], boxes1[:, 3]
    x1_2, y1_2, x2_2, y2_2 = boxes2[:, 0], boxes2[:, 1], boxes2[:, 2], boxes2[:, 3]
    x1_1 = x1_1.unsqueeze(1).expand(N, M)
    y1_1 = y1_1.unsqueeze(1).expand(N, M)
    x2_1 = x2_1.unsqueeze(1).expand(N, M)
    y2_1 = y2_1.unsqueeze(1).expand(N, M)
    x1_2 = x1_2.unsqueeze(0).expand(N, M)
    y1_2 = y1_2.unsqueeze(0).expand(N, M)
    x2_2 = x2_2.unsqueeze(0).expand(N, M)
    y2_2 = y2_2.unsqueeze(0).expand(N, M)
    interX1 = torch.max(x1_1, x1_2)
    interY1 = torch.max(y1_1, y1_2)
    interX2 = torch.min(x2_1, x2_2)
    interY2 = torch.min(y2_1, y2_2)
    interW = (interX2 - interX1).clamp(min=0)
    interH = (interY2 - interY1).clamp(min=0)
    interArea = interW * interH
    area1 = (x2_1 - x1_1).clamp(min=0) * (y2_1 - y1_1).clamp(min=0)
    area2 = (x2_2 - x1_2).clamp(min=0) * (y2_2 - y1_2).clamp(min=0)
    union = area1 + area2 - interArea + 1e-16
    iou = interArea / union
    return iou


def box_giou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    xA = torch.max(boxes1[:, 0], boxes2[:, 0])
    yA = torch.max(boxes1[:, 1], boxes2[:, 1])
    xB = torch.min(boxes1[:, 2], boxes2[:, 2])
    yB = torch.min(boxes1[:, 3], boxes2[:, 3])
    interW = (xB - xA).clamp(min=0)
    interH = (yB - yA).clamp(min=0)
    interArea = interW * interH
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)
    union = area1 + area2 - interArea + 1e-16
    iou = interArea / union
    xC1 = torch.min(boxes1[:, 0], boxes2[:, 0])
    yC1 = torch.min(boxes1[:, 1], boxes2[:, 1])
    xC2 = torch.max(boxes1[:, 2], boxes2[:, 2])
    yC2 = torch.max(boxes1[:, 3], boxes2[:, 3])
    encloseW = (xC2 - xC1).clamp(min=0)
    encloseH = (yC2 - yC1).clamp(min=0)
    encloseArea = encloseW * encloseH + 1e-16
    giou = iou - (encloseArea - union) / encloseArea
    return giou


class YoloLoss(nn.Module):
    def __init__(self, class_counts: List[int], anchors_l: List[Tuple[int, int]] = [(128, 152), (182, 205), (103, 124)], anchors_m: List[Tuple[int, int]] = [(78, 88), (55, 59), (37, 42)], anchors_s: List[Tuple[int, int]] = [(26, 28), (17, 19), (10, 11)], image_size: Tuple[int, int] = (416, 416), num_classes: int = 3, ignore_thresh: float = 0.7, lambda_noobj: float = 5.0):
        super().__init__()
        self.anchors_l = anchors_l
        self.anchors_m = anchors_m
        self.anchors_s = anchors_s
        self.image_size = image_size
        self.num_classes = num_classes
        self.ignore_thresh = ignore_thresh
        self.lambda_noobj = lambda_noobj
        total = sum(class_counts)
        w_list = [total / (c + 1e-5) * (2.0 if c_id == 0 else (3.0 if c_id == 2 else 1.0)) for c_id, c in enumerate(class_counts)]
        self.class_weight = torch.tensor(w_list, dtype=torch.float32)
        self.bce_obj = nn.BCEWithLogitsLoss(reduction="none")
        self.bce_cls = nn.BCEWithLogitsLoss(weight=self.class_weight, reduction="none")

    def forward(self, outputs: Tuple[torch.Tensor], targets: Tuple[torch.Tensor]) -> torch.Tensor:
        out_l, out_m, out_s = outputs
        t_l, t_m, t_s = targets
        loss_l = self._loss_single_scale(out_l, t_l, self.anchors_l, scale_wh=(13, 13))
        loss_m = self._loss_single_scale(out_m, t_m, self.anchors_m, scale_wh=(26, 26))
        loss_s = self._loss_single_scale(out_s, t_s, self.anchors_s, scale_wh=(52, 52))
        return loss_l + loss_m + loss_s

    def _loss_single_scale(self, pred: torch.Tensor, target: torch.Tensor, anchors: List[Tuple[int]], scale_wh: Tuple[int]) -> torch.Tensor:
        device = pred.device
        B, _, H, W = pred.shape
        A = len(anchors)
        pred = pred.view(B, A, (5 + self.num_classes), H, W)
        pred = pred.permute(0, 3, 4, 1, 2).contiguous()
        pred_tx = pred[..., 0]
        pred_ty = pred[..., 1]
        pred_tw = pred[..., 2]
        pred_th = pred[..., 3]
        pred_obj = pred[..., 4]
        pred_cls = pred[..., 5:]
        tgt_tx = target[..., 0]
        tgt_ty = target[..., 1]
        tgt_tw = target[..., 2]
        tgt_th = target[..., 3]
        tgt_obj = target[..., 4]
        tgt_cls = target[..., 5:]
        obj_mask = (tgt_obj == 1)
        noobj_mask = (tgt_obj == 0)
        img_w, img_h = self.image_size
        stride_x = img_w / W
        stride_y = img_h / H
        grid_x = torch.arange(W, device=device).view(1, 1, W, 1).expand(1, H, W, 1)
        grid_y = torch.arange(H, device=device).view(1, H, 1, 1).expand(1, H, W, 1)
        anchors_t = torch.tensor(anchors, dtype=torch.float, device=device)
        anchor_w = anchors_t[:, 0].view(1, 1, 1, A)
        anchor_h = anchors_t[:, 1].view(1, 1, 1, A)
        pred_box_xc = (grid_x + torch.sigmoid(pred_tx)) * stride_x
        pred_box_yc = (grid_y + torch.sigmoid(pred_ty)) * stride_y
        pred_box_w = torch.exp(pred_tw) * anchor_w
        pred_box_h = torch.exp(pred_th) * anchor_h
        pred_x1 = pred_box_xc - pred_box_w / 2
        pred_y1 = pred_box_yc - pred_box_h / 2
        pred_x2 = pred_box_xc + pred_box_w / 2
        pred_y2 = pred_box_yc + pred_box_h / 2
        gt_box_xc = (grid_x + tgt_tx) * stride_x
        gt_box_yc = (grid_y + tgt_ty) * stride_y
        gt_box_w = torch.exp(tgt_tw) * anchor_w
        gt_box_h = torch.exp(tgt_th) * anchor_h
        gt_x1 = gt_box_xc - gt_box_w / 2
        gt_y1 = gt_box_yc - gt_box_h / 2
        gt_x2 = gt_box_xc + gt_box_w / 2
        gt_y2 = gt_box_yc + gt_box_h / 2
        with torch.no_grad():
            ignore_mask_buf = torch.zeros_like(tgt_obj, dtype=torch.bool)
            noobj_flat = noobj_mask.view(-1)
            obj_flat = obj_mask.view(-1)
            px1f = pred_x1.view(-1)
            py1f = pred_y1.view(-1)
            px2f = pred_x2.view(-1)
            py2f = pred_y2.view(-1)
            gx1f = gt_x1.view(-1)[obj_flat]
            gy1f = gt_y1.view(-1)[obj_flat]
            gx2f = gt_x2.view(-1)[obj_flat]
            gy2f = gt_y2.view(-1)[obj_flat]
            if noobj_flat.sum() > 0 and obj_flat.sum() > 0:
                noobj_idx = noobj_flat.nonzero(as_tuple=True)[0]
                noobj_boxes_xyxy = torch.stack([px1f[noobj_idx], py1f[noobj_idx], px2f[noobj_idx], py2f[noobj_idx]], dim=-1)
                obj_boxes_xyxy = torch.stack([gx1f, gy1f, gx2f, gy2f], dim=-1)
                ious = box_iou_xyxy(noobj_boxes_xyxy, obj_boxes_xyxy)
                best_iou, _ = ious.max(dim=1)
                ignore_flags = (best_iou > self.ignore_thresh)
                all_idx = noobj_idx[ignore_flags]
                ignore_mask_buf.view(-1)[all_idx] = True
            ignore_mask = ignore_mask_buf
        obj_loss = self.bce_obj(pred_obj[obj_mask], torch.ones_like(pred_obj[obj_mask]))
        obj_loss = obj_loss.mean() if obj_loss.numel() > 0 else torch.tensor(0., device=device)
        noobj_mask_final = (noobj_mask & (~ignore_mask))
        noobj_loss = self.bce_obj(pred_obj[noobj_mask_final], torch.zeros_like(pred_obj[noobj_mask_final]))
        noobj_loss = noobj_loss.mean() if noobj_loss.numel() > 0 else torch.tensor(0., device=device)
        objectness_loss = obj_loss + self.lambda_noobj * noobj_loss
        class_loss = torch.tensor(0., device=device, requires_grad=True)
        if obj_mask.sum() > 0:
            self.bce_cls.weight = self.class_weight.to(device)
            cls_pred = pred_cls[obj_mask].to(device)
            cls_gt = tgt_cls[obj_mask].to(device)
            c_loss = self.bce_cls(cls_pred, cls_gt)
            class_loss = c_loss.mean()
        giou_loss = torch.tensor(0., device=device, requires_grad=True)
        if obj_mask.sum() > 0:
            px1_ = pred_x1[obj_mask]
            py1_ = pred_y1[obj_mask]
            px2_ = pred_x2[obj_mask]
            py2_ = pred_y2[obj_mask]
            p_xyxy = torch.stack([px1_, py1_, px2_, py2_], dim=-1)
            gx1_ = gt_x1[obj_mask]
            gy1_ = gt_y1[obj_mask]
            gx2_ = gt_x2[obj_mask]
            gy2_ = gt_y2[obj_mask]
            g_xyxy = torch.stack([gx1_, gy1_, gx2_, gy2_], dim=-1)
            giou = box_giou_xyxy(p_xyxy, g_xyxy)
            giou_loss = (1. - giou).mean()
        total_loss = objectness_loss + class_loss + giou_loss
        return total_loss
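A shape-level smoke test for YoloLoss with model-shaped random outputs, all-zero targets and made-up class counts (run from the repository root): with no assigned cells, only the weighted no-object BCE term contributes.

import torch
from src.loss import YoloLoss

criterion = YoloLoss(class_counts=[100, 300, 30])  # hypothetical class frequencies
outputs = (torch.randn(2, 24, 13, 13),             # 3 anchors * (5 + 3 classes) = 24 channels
           torch.randn(2, 24, 26, 26),
           torch.randn(2, 24, 52, 52))
targets = (torch.zeros(2, 13, 13, 3, 8),
           torch.zeros(2, 26, 26, 3, 8),
           torch.zeros(2, 52, 52, 3, 8))
loss = criterion(outputs, targets)
print(loss)  # scalar tensor; only the no-object objectness term is non-trivial here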
src/models/yolov3.py
ADDED
@@ -0,0 +1,114 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_batch(in_ch: int, out_ch: int, kernel_size: int = 3, padding: int = 1, stride: int = 1) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU()
    )


class DarkResidualBlock(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        reduced_channels = in_channels // 2
        self.layer1 = conv_batch(in_channels, reduced_channels, kernel_size=1, padding=0)
        self.layer2 = conv_batch(reduced_channels, in_channels)

    def forward(self, x):
        return x + self.layer2(self.layer1(x))


class Darknet53(nn.Module):
    def __init__(self, block: nn.Module = DarkResidualBlock) -> None:
        super().__init__()
        self.conv1 = conv_batch(3, 32)
        self.conv2 = conv_batch(32, 64, stride=2)
        self.residual_block1 = self.make_layer(block, in_channels=64, num_blocks=1)
        self.conv3 = conv_batch(64, 128, stride=2)
        self.residual_block2 = self.make_layer(block, in_channels=128, num_blocks=2)
        self.conv4 = conv_batch(128, 256, stride=2)
        self.residual_block3 = self.make_layer(block, in_channels=256, num_blocks=8)
        self.conv5 = conv_batch(256, 512, stride=2)
        self.residual_block4 = self.make_layer(block, in_channels=512, num_blocks=8)
        self.conv6 = conv_batch(512, 1024, stride=2)
        self.residual_block5 = self.make_layer(block, in_channels=1024, num_blocks=4)

    def make_layer(self, block: nn.Module, in_channels: int, num_blocks: int) -> nn.Sequential:
        layers = []
        for _ in range(num_blocks):
            layers.append(block(in_channels))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual_block1(x)
        x = self.conv3(x)
        x = self.residual_block2(x)
        x = self.conv4(x)
        x = self.residual_block3(x)
        c4 = x
        x = self.conv5(x)
        x = self.residual_block4(x)
        c5 = x
        x = self.conv6(x)
        x = self.residual_block5(x)
        c6 = x
        return c4, c5, c6


def conv_leaky(in_ch: int, out_ch: int, k: int = 1, s: int = 1, p: int = 0):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.1, inplace=True)
    )


class DetectionHead(nn.Module):
    def __init__(self, in_ch: int, mid_ch: int, num_anchors: int = 3, num_classes: int = 3) -> None:
        super().__init__()
        self.block = nn.Sequential(
            conv_leaky(in_ch, mid_ch, k=1, s=1, p=0),
            conv_leaky(mid_ch, mid_ch * 2, k=3, s=1, p=1),
            conv_leaky(mid_ch * 2, mid_ch, k=1, s=1, p=0),
            conv_leaky(mid_ch, mid_ch * 2, k=3, s=1, p=1),
            conv_leaky(mid_ch * 2, mid_ch, k=1, s=1, p=0)
        )
        self.out_conv = nn.Conv2d(mid_ch, num_anchors * (5 + num_classes), kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block(x)
        out = self.out_conv(x)
        return out


class YOLOv3(nn.Module):
    def __init__(self, num_classes: int = 3) -> None:
        super().__init__()
        self.backbone = Darknet53()
        self.num_classes = num_classes
        self.num_anchors = 3
        self.head_large = DetectionHead(in_ch=1024, mid_ch=512, num_anchors=3, num_classes=num_classes)
        self.head_medium = DetectionHead(in_ch=1024, mid_ch=256, num_anchors=3, num_classes=num_classes)
        self.head_small = DetectionHead(in_ch=512, mid_ch=128, num_anchors=3, num_classes=num_classes)
        self.conv_upsample_l2 = conv_leaky(1024, 512, k=1, s=1, p=0)
        self.conv_upsample_l3 = conv_leaky(1024, 256, k=1, s=1, p=0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        c4, c5, c6 = self.backbone(x)
        out_l = self.head_large(c6)
        x_l2 = self.conv_upsample_l2(c6)
        x_l2_up = F.interpolate(x_l2, scale_factor=2, mode="nearest")
        x_merge_l2 = torch.cat([x_l2_up, c5], dim=1)
        out_m = self.head_medium(x_merge_l2)
        x_l3 = self.conv_upsample_l3(x_merge_l2)
        x_l3_up = F.interpolate(x_l3, scale_factor=2, mode="nearest")
        x_merge_l3 = torch.cat([x_l3_up, c4], dim=1)
        out_s = self.head_small(x_merge_l3)
        return out_l, out_m, out_s
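A quick shape check of the three heads on a 416x416 input (run from the repository root); each head predicts 3 anchors * (5 + 3 classes) = 24 channels per cell at strides 32, 16 and 8.

import torch
from src.models.yolov3 import YOLOv3

model = YOLOv3(num_classes=3).eval()
with torch.no_grad():
    out_l, out_m, out_s = model(torch.randn(1, 3, 416, 416))
print(out_l.shape)  # torch.Size([1, 24, 13, 13])
print(out_m.shape)  # torch.Size([1, 24, 26, 26])
print(out_s.shape)  # torch.Size([1, 24, 52, 52])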
src/train.py
ADDED
@@ -0,0 +1,427 @@
import math
from pathlib import Path
from typing import List, Tuple, Dict
from tqdm import tqdm
import argparse
from accelerate import Accelerator
from accelerate.utils import set_seed
import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.ops as ops
import PIL.Image, PIL.ImageDraw
import numpy as np

from dataset import MaskDataset, collate_fn, ANCHORS
from utils import EMA
from models.yolov3 import YOLOv3
from loss import YoloLoss


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int, eta_min: int = 0, last_epoch: int = -1) -> None:
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch < self.warmup_steps:
            return [base_lr * (self.last_epoch / max(1, self.warmup_steps)) for base_lr in self.base_lrs]
        else:
            current_step = self.last_epoch - self.warmup_steps
            cosine_steps = max(1, self.total_steps - self.warmup_steps)
            return [self.eta_min + (base_lr - self.eta_min) * 0.5 * (1 + math.cos(math.pi * current_step / cosine_steps)) for base_lr in self.base_lrs]


def draw_bounding_boxes(image: PIL.Image.Image, boxes: torch.Tensor, colors: Dict[int, int] = {0: (178, 34, 34), 1: (34, 139, 34), 2: (184, 134, 11)}, labels = {0: "without_mask", 1: "with_mask", 2: "weared_incorrect"}, show_conf = False) -> None:
    draw = PIL.ImageDraw.Draw(image)
    for box in boxes:
        xmin, ymin, xmax, ymax, class_id = int(box[0]), int(box[1]), int(box[2]), int(box[3]), int(box[-1])
        conf_text = ""
        if show_conf and box.shape[0] == 6:
            conf = float(box[4])
            conf_text = f" {conf:.2f}"
        color = colors.get(class_id, (255, 255, 255))
        label = labels.get(class_id, "Unknown") + conf_text
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
        text_bbox = draw.textbbox((xmin, ymin), label)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        draw.rectangle([xmin, ymin - text_height - 2, xmin + text_width + 2, ymin], fill=color)
        draw.text((xmin + 1, ymin - text_height - 1), label, fill="white")


def create_combined_image(img: torch.Tensor, gt_batch: List[torch.Tensor], results: List[torch.Tensor], mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225]):
    batch_size, _, height, width = img.shape
    combined_height = height * 2
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    for i in range(batch_size):
        image = img[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        gt_image = PIL.Image.fromarray(image.copy())
        pred_image = PIL.Image.fromarray(image.copy())
        draw_bounding_boxes(gt_image, gt_batch[i])
        draw_bounding_boxes(pred_image, results[i], show_conf=True)
        combined_image[:height, i * width:(i + 1) * width, :] = np.array(gt_image)
        combined_image[height:, i * width:(i + 1) * width, :] = np.array(pred_image)
    return PIL.Image.fromarray(combined_image)


def decode_yolo_output_single(prediction: torch.Tensor, anchors: List[Tuple[int]], image_size: Tuple[int] = (416, 416), conf_threshold: float = 0.5, iou_threshold: float = 0.3, apply_nms: bool = True, num_classes: int = 3) -> List[torch.Tensor]:
    device = prediction.device
    B, _, H, W = prediction.shape
    A = len(anchors)
    prediction = prediction.view(B, A, 5 + num_classes, H, W)
    prediction = prediction.permute(0, 1, 3, 4, 2).contiguous()
    tx = prediction[..., 0]
    ty = prediction[..., 1]
    tw = prediction[..., 2]
    th = prediction[..., 3]
    obj = prediction[..., 4]
    class_scores = prediction[..., 5:]
    tx = tx.sigmoid()
    ty = ty.sigmoid()
    obj = obj.sigmoid()
    class_scores = class_scores.softmax(dim=-1)
    img_w, img_h = image_size
    cell_w = img_w / W
    cell_h = img_h / H
    grid_x = torch.arange(W, device=device).view(1, 1, W).expand(1, H, W)
    grid_y = torch.arange(H, device=device).view(1, H, 1).expand(1, H, W)
    anchors_tensor = torch.tensor(anchors, dtype=torch.float32, device=device)
    anchor_w = anchors_tensor[:, 0].view(1, A, 1, 1)
    anchor_h = anchors_tensor[:, 1].view(1, A, 1, 1)
    x_center = (grid_x + tx) * cell_w
    y_center = (grid_y + ty) * cell_h
    w = torch.exp(tw) * anchor_w
    h = torch.exp(th) * anchor_h
    xmin = x_center - w / 2
    ymin = y_center - h / 2
    xmax = x_center + w / 2
    ymax = y_center + h / 2
    max_class_probs, class_ids = class_scores.max(dim=-1)
    confidence = obj * max_class_probs
    outputs = []
    for b_i in range(B):
        box_xmin = xmin[b_i].view(-1)
        box_ymin = ymin[b_i].view(-1)
        box_xmax = xmax[b_i].view(-1)
        box_ymax = ymax[b_i].view(-1)
        conf = confidence[b_i].view(-1)
        cls_id = class_ids[b_i].view(-1).float()
        mask = (conf > conf_threshold)
        box_xmin = box_xmin[mask]
        box_ymin = box_ymin[mask]
        box_xmax = box_xmax[mask]
        box_ymax = box_ymax[mask]
        conf = conf[mask]
        cls_id = cls_id[mask]
        if mask.sum() == 0:
            outputs.append(torch.empty((0, 6), device=device))
            continue
        boxes = torch.stack([box_xmin, box_ymin, box_xmax, box_ymax], dim=-1)
        if apply_nms:
            keep = ops.nms(boxes, conf, iou_threshold)
            boxes = boxes[keep]
            conf = conf[keep]
            cls_id = cls_id[keep]
        out = torch.cat([boxes, conf.unsqueeze(-1), cls_id.unsqueeze(-1)], dim=-1)
        outputs.append(out)
    return outputs


def decode_predictions_3scales(out_l: torch.Tensor, out_m: torch.Tensor, out_s: torch.Tensor, anchors_l: List[Tuple[int]], anchors_m: List[Tuple[int, int]], anchors_s: List[Tuple[int, int]], image_size: Tuple[int, int] = (416, 416), conf_threshold: float = 0.5, iou_threshold: float = 0.45, num_classes: int = 3) -> List[torch.Tensor]:
    b_l = decode_yolo_output_single(out_l, anchors_l, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    b_m = decode_yolo_output_single(out_m, anchors_m, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    b_s = decode_yolo_output_single(out_s, anchors_s, image_size, conf_threshold, iou_threshold, apply_nms=False, num_classes=num_classes)
    results = []
    B = len(b_l)
    for i in range(B):
        boxes_all = torch.cat([b_l[i], b_m[i], b_s[i]], dim=0)
        if boxes_all.numel() == 0:
            results.append(boxes_all)
            continue
        xyxy = boxes_all[:, :4]
        scores = boxes_all[:, 4]
        keep = ops.nms(xyxy, scores, iou_threshold)
        final = boxes_all[keep]
        results.append(final)
    return results


def decode_target_single(target: torch.Tensor, anchors: List[Tuple[int]], image_size: Tuple[int] = (416, 416), obj_threshold: float = 0.5) -> List[torch.Tensor]:
    args = parse_args()
    target = target.to(args.device)
    B, S, _, A, _ = target.shape
    img_w, img_h = image_size
    cell_w = img_w / S
    cell_h = img_h / S
    anchors_tensor = torch.tensor(anchors, dtype=torch.float)
    tx = target[..., 0]
    ty = target[..., 1]
    tw = target[..., 2]
    th = target[..., 3]
    tobj = target[..., 4]
    tcls = target[..., 5:]
    results = []
    for b_i in range(B):
        bx_list = []
        tx_b = tx[b_i]
        ty_b = ty[b_i]
        tw_b = tw[b_i]
        th_b = th[b_i]
        tobj_b = tobj[b_i]
        tcls_b = tcls[b_i]
        for i in range(S):
            for j in range(S):
                for a_i in range(A):
                    if tobj_b[i, j, a_i] < obj_threshold:
                        continue
                    cls_one_hot = tcls_b[i, j, a_i]
                    cls_id = cls_one_hot.argmax().item()
                    x_center = (j + tx_b[i, j, a_i].item()) * cell_w
                    y_center = (i + ty_b[i, j, a_i].item()) * cell_h
                    anchor_w = anchors_tensor[a_i, 0]
                    anchor_h = anchors_tensor[a_i, 1]
                    box_w = torch.exp(tw_b[i, j, a_i]) * anchor_w
                    box_h = torch.exp(th_b[i, j, a_i]) * anchor_h
                    xmin = x_center - box_w / 2
                    ymin = y_center - box_h / 2
                    xmax = x_center + box_w / 2
                    ymax = y_center + box_h / 2
                    bx_list.append([xmin.item(), ymin.item(), xmax.item(), ymax.item(), cls_id])
        if len(bx_list) == 0:
            results.append(torch.empty((0, 5), dtype=torch.float32, device=args.device))
        else:
            results.append(torch.tensor(bx_list, dtype=torch.float32, device=args.device))
    return results


def decode_target_3scales(t_l: torch.Tensor, t_m: torch.Tensor, t_s: torch.Tensor, anchors_l: List[Tuple[int]], anchors_m: List[Tuple[int]], anchors_s: List[Tuple[int]], image_size: Tuple[int] = (416, 416), obj_threshold: float = 0.5) -> List[torch.Tensor]:
    dec_l = decode_target_single(t_l, anchors_l, image_size, obj_threshold)
    dec_m = decode_target_single(t_m, anchors_m, image_size, obj_threshold)
    dec_s = decode_target_single(t_s, anchors_s, image_size, obj_threshold)
    results = []
    B = len(dec_l)
    for i in range(B):
        boxes_l = dec_l[i]
        boxes_m = dec_m[i]
        boxes_s = dec_s[i]
        if boxes_l.numel() == 0 and boxes_m.numel() == 0 and boxes_s.numel() == 0:
            results.append(torch.empty((0, 5), dtype=torch.float32, device=boxes_l.device))
        else:
            all_ = torch.cat([boxes_l, boxes_m, boxes_s], dim=0)
            results.append(all_)
    return results


def iou_xyxy(box1: List[int | float], box2: List[int | float]) -> float:
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    w = max(0., x2 - x1)
    h = max(0., y2 - y1)
    inter = w * h
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0


def compute_ap_per_class(boxes_pred: List[List[float]], boxes_gt: List[List[float]], iou_threshold: float = 0.45) -> float:
    boxes_pred = sorted(boxes_pred, key=lambda x: x[4], reverse=True)
    n_gt = len(boxes_gt)
    if n_gt == 0 and len(boxes_pred) == 0:
        return 1.0
    if n_gt == 0:
        return 0.0
    matched = [False] * n_gt
    tps = []
    fps = []
    for i, pred in enumerate(boxes_pred):
        best_iou = 0.0
        best_j = -1
        for j, gt in enumerate(boxes_gt):
            if matched[j]:
                continue
            iou = iou_xyxy(pred, gt)
            if iou > best_iou:
                best_iou = iou
                best_j = j
        if best_iou > iou_threshold and best_j >= 0:
            tps.append(1)
            fps.append(0)
            matched[best_j] = True
        else:
            tps.append(0)
            fps.append(1)
    tps_cum = []
    fps_cum = []
    s_tp = 0
    s_fp = 0
    for i in range(len(tps)):
        s_tp += tps[i]
        s_fp += fps[i]
        tps_cum.append(s_tp)
        fps_cum.append(s_fp)
    precisions = []
    recalls = []
    for i in range(len(tps)):
        prec = tps_cum[i] / (tps_cum[i] + fps_cum[i]) if (tps_cum[i] + fps_cum[i]) > 0 else 0
        rec = tps_cum[i] / n_gt
        precisions.append(prec)
        recalls.append(rec)
    recalls = [0.0] + recalls + [1.0]
    precisions = [1.0] + precisions + [0.0]
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    ap = 0.0
    for i in range(len(precisions) - 1):
        ap += (recalls[i + 1] - recalls[i]) * precisions[i + 1]
    return ap


def compute_map(all_pred: List[float], all_gt: List[float], num_classes: int = 3, iou_threshold: float = 0.45) -> float:
    APs = []
    for c in range(num_classes):
        ap_c = compute_ap_per_class(all_pred[c], all_gt[c], iou_threshold)
        APs.append(ap_c)
    mAP = sum(APs) / len(APs) if len(APs) > 0 else 0.0
    return mAP


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on the face mask detection dataset")
    parser.add_argument("--root", type=str, default="data/masks", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size for training and testing")
    parser.add_argument("--logs-dir", type=str, default="yolo-logs", help="Path to save logs")
    parser.add_argument("--pin-memory", type=bool, default=True, help="Pin Memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate for the optimizer")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--max-norm", type=float, default=10.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--project-name", type=str, default="YOLOv3, mask detection", help="Wandb project name")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run the training on")
    parser.add_argument("--weights-path", type=str, default="weights/darknet53.pth", help="Path to the weights")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--log-steps", type=int, default=13, help="Number of steps between logging training images and metrics")
    parser.add_argument("--num-warmup-steps", type=int, default=400, help="Number of steps")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
        wandb.init(project=args.project_name, dir=logs_dir)
    train_dataset = MaskDataset(root=args.root, train=True)
    test_dataset = MaskDataset(root=args.root, train=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn)
    model = YOLOv3().to(accelerator.device)
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.weights_path:
        weights = torch.load(args.weights_path, map_location="cpu", weights_only=True)
        model.backbone.load_state_dict(weights)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    criterion = YoloLoss(class_counts=train_dataset.class_counts)
    scheduler = WarmupCosineAnnealingLR(optimizer, warmup_steps=args.num_warmup_steps // args.gradient_accumulation_steps, total_steps=args.num_epochs * len(train_loader) // args.gradient_accumulation_steps, eta_min=1e-7)

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    best_map = 0.0
    train_loss_ema = EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch} / {args.num_epochs}")
        for images, (t_l, t_m, t_s) in pbar:
            images = images.to(accelerator.device)
            t_l = t_l.to(accelerator.device)
            t_m = t_m.to(accelerator.device)
            t_s = t_s.to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    out_l, out_m, out_s = model(images)
                    loss = criterion((out_l, out_m, out_s), (t_l, t_m, t_s))
                accelerator.backward(loss)
                grad_norm = None
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
            lr = scheduler.get_last_lr()[0]
            pbar.set_postfix({"loss": train_loss_ema(loss.item())})
            log_data = {
                "train/epoch": epoch,
                "train/loss": loss.item(),
                "train/lr": lr
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()
        model.eval()
        all_pred = [[] for _ in range(model.num_classes)]
        all_gt = [[] for _ in range(model.num_classes)]
        with torch.inference_mode():
            test_loss = 0.0
            pbar = tqdm(test_loader, desc=f"Test epoch {epoch} / {args.num_epochs}")
            for index, (images, (t_l, t_m, t_s)) in enumerate(pbar):
                images = images.to(accelerator.device)
                t_l = t_l.to(accelerator.device)
                t_m = t_m.to(accelerator.device)
                t_s = t_s.to(accelerator.device)
                out_l, out_m, out_s = model(images)
                loss = criterion((out_l, out_m, out_s), (t_l, t_m, t_s))
                test_loss += loss.item()
                results = decode_predictions_3scales(out_l, out_m, out_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
                gt_batch = decode_target_3scales(t_l, t_m, t_s, ANCHORS["large"], ANCHORS["medium"], ANCHORS["small"])
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(images, gt_batch, results)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Test, Epoch {epoch})"))
                    wandb.log({"test_samples": images_to_log})
                for b_i in range(len(images)):
                    dets_b = results[b_i].detach().cpu().numpy()
                    gts_b = gt_batch[b_i].detach().cpu().numpy()
                    for db in dets_b:
                        c = int(db[5])
                        all_pred[c].append([db[0], db[1], db[2], db[3], db[4]])
                    for gb in gts_b:
                        c = int(gb[4])
                        all_gt[c].append([gb[0], gb[1], gb[2], gb[3]])
        test_loss /= len(test_loader)
        test_map = compute_map(all_pred, all_gt)
        accelerator.print(f"loss: {test_loss:.3f}, map: {test_map:.3f}")
        if accelerator.is_main_process:
            wandb.log({
                "epoch": epoch,
                "test/loss": test_loss,
                "test/mAP": test_map
            })
            if test_map > best_map:
                best_map = test_map
                accelerator.save(model.state_dict(), logs_dir / "checkpoint-best.pth")
            elif epoch % args.save_frequency == 0:
                accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
            accelerator.wait_for_everyone()
        accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()
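A small sketch of the WarmupCosineAnnealingLR schedule defined at the top of train.py, stepped in isolation with a hypothetical optimizer and step counts; it assumes the pinned dependencies are installed, and src/ goes on sys.path because train.py imports its siblings as top-level modules.

import sys
sys.path.insert(0, "src")
import torch
from train import WarmupCosineAnnealingLR

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
sched = WarmupCosineAnnealingLR(opt, warmup_steps=400, total_steps=5000, eta_min=1e-7)
lrs = []
for _ in range(5000):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(max(lrs), lrs[-1])  # rises linearly to ~5e-4 over the first 400 steps, then decays towards eta_min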
src/utils.py
ADDED
@@ -0,0 +1,11 @@
class EMA:
    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
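The EMA helper only smooths the scalar loss shown in the training progress bar; a tiny illustration:

from src.utils import EMA

ema = EMA(alpha=0.9)
print(ema(1.0))   # 1.0   (first value is taken as-is)
print(ema(0.5))   # ~0.95 = 0.9 * 1.0  + 0.1 * 0.5
print(ema(0.25))  # ~0.88 = 0.9 * 0.95 + 0.1 * 0.25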
weights/checkpoint-best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d496cd707cec1135b6d6cfece5c35b92572063914d81ae2bbbc8ded5c7366e10
size 224442922