venkyyuvy committed
Commit c87ccf3 · 1 Parent(s): 0e79fb9

model w mosaic aug

app.py ADDED
@@ -0,0 +1,98 @@
import gradio as gr
import torch
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')

from model import YOLOv3
from utils import (
    cells_to_bboxes,
    non_max_suppression,
    plot_image
)


ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]  # Note these have been rescaled to be between [0, 1]

fname = 'epoch=36-step=19166.ckpt'
checkpoint = torch.load(fname, map_location=torch.device('cpu'))
model_state_dict = checkpoint['state_dict']
model = YOLOv3(num_classes=20)
model.load_state_dict(model_state_dict)

IMAGE_SIZE = 416
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)

test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
)


def object_detector(input_image):
    input_img = test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)

    thresh = 0.6
    iou_thresh = 0.5
    with torch.no_grad():
        out = model(input_img)
        bboxes = []
        for i in range(3):
            _, _, S, _, _ = out[i].shape
            anchor = anchors[i]
            bboxes += cells_to_bboxes(
                out[i], anchor, S=S, is_preds=True
            )
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=iou_thresh,
        threshold=thresh, box_format="midpoint",
    )
    fig = plot_image(input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
                     nms_boxes,
                     return_fig=True)
    plt.gca().set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    plt.axis('off')
    image_path = "plot.png"
    fig.savefig(image_path)
    plt.close()
    return gr.update(value=image_path, visible=True)


# Define the input and output components for Gradio
input_image = gr.Image(label="Input image")
output_box = gr.Image(label="Output image").style(width=428, height=428)
images_path = "examples/"


# Create the Gradio interface
gr.Interface(fn=object_detector, inputs=input_image, outputs=output_box,
             examples=[[images_path + "000015.jpg"],
                       [images_path + "000017.jpg"],
                       [images_path + "000030.jpg"],
                       [images_path + "000069.jpg"],
                       [images_path + "000071.jpg"],
                       [images_path + "000084.jpg"],
                       [images_path + "000086.jpg"],
                       [images_path + "000088.jpg"],
                       [images_path + "000095.jpg"],
                       [images_path + "000100.jpg"],
                       ],
             ).launch()
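For a quick sanity check outside the web UI, the Gradio callback above can be called directly on one of the bundled example images, for instance in a notebook after executing the definitions in app.py (importing the module as-is would also call .launch()). This is a minimal sketch under the assumption that the checkpoint file and the config.py module expected by utils are available locally:

# Minimal sketch: call object_detector directly on an example image.
# Assumes the definitions above (model, test_transforms, object_detector) have been
# executed and that `import config` inside utils resolves to a local config.py.
import numpy as np
from PIL import Image

img = np.array(Image.open("examples/000015.jpg").convert("RGB"))  # HWC uint8, same as gr.Image provides
result = object_detector(img)  # draws the detections and saves the figure to plot.png
print(result)                  # gr.update(value="plot.png", visible=True)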
epoch=36-step=19166.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd0d67eac457f890a4c9557393e9b4c705dcbac707d30f9cbe658a1b7d00a678
size 740104313
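The checkpoint above is stored with Git LFS, so a plain clone only contains this pointer file; `git lfs pull` fetches the real weights. Alternatively it can be downloaded programmatically, e.g. with huggingface_hub as sketched below, where the repo id is a placeholder for this Space and is not stated in the commit itself:

# Hedged sketch: fetch the LFS-backed checkpoint from the Hub.
# "username/space-name" is a placeholder repo id, not taken from this commit.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="username/space-name",        # placeholder: replace with the actual Space id
    repo_type="space",
    filename="epoch=36-step=19166.ckpt",
)
print(ckpt_path)  # local cache path to the ~740 MB checkpoint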
examples/000015.jpg ADDED
examples/000017.jpg ADDED
examples/000030.jpg ADDED
examples/000069.jpg ADDED
examples/000071.jpg ADDED
examples/000084.jpg ADDED
examples/000086.jpg ADDED
examples/000088.jpg ADDED
examples/000095.jpg ADDED
examples/000100.jpg ADDED
examples/000186.jpg ADDED
model.py ADDED
@@ -0,0 +1,197 @@
"""
Implementation of YOLOv3 architecture
"""

import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import OneCycleLR  # used by configure_optimizers below

"""
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride)
Every conv is a same convolution.
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""
config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]


class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        else:
            return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for repeat in range(num_repeats):
            self.layers += [
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            ]

        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            if self.use_residual:
                x = x + layer(x)
            else:
                x = layer(x)

        return x


class ScalePrediction(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )


class YOLOv3(nn.Module):
    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        outputs = []  # for each scale
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)

            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()

        return outputs

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))

            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2),)
                    in_channels = in_channels * 3

        return layers

    def configure_optimizers(self):
        # Note: this method follows the PyTorch Lightning convention and assumes
        # self.learning_rate, self.weight_decay, self.best_lr and self.trainer are
        # set by the training setup. The `config` referenced below is expected to be
        # a hyperparameter module (config.NUM_EPOCHS), which the architecture list
        # named `config` at the top of this file currently shadows.
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.best_lr,
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=5 / config.NUM_EPOCHS,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy="linear",
        )
        return [optimizer], [
            {"scheduler": scheduler, "interval": "step", "frequency": 1}
        ]


if __name__ == "__main__":
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    out = model(x)
    assert out[0].shape == (2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5)
    assert out[1].shape == (2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5)
    assert out[2].shape == (2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5)
    print("Success!")
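Each entry of the config list is expanded by _create_conv_layers: a tuple becomes a CNNBlock, a ["B", n] list becomes a ResidualBlock with n repeats, "U" an upsample-and-concatenate step, and each "S" a ScalePrediction head whose output has shape (N, 3, S, S, num_classes + 5) with per-cell layout [objectness, x, y, w, h, class logits]. A small illustrative sketch of slicing one prediction map (variable names here are illustrative only):

# Illustrative sketch: inspect one scale of the YOLOv3 output.
import torch
from model import YOLOv3

model = YOLOv3(num_classes=20)
out = model(torch.randn(1, 3, 416, 416))  # three scales: 13x13, 26x26, 52x52

p = out[0]                             # coarsest scale, shape (1, 3, 13, 13, 25)
objectness = torch.sigmoid(p[..., 0])  # objectness score per anchor and cell
box_offsets = p[..., 1:5]              # raw tx, ty, tw, th (decoded later by cells_to_bboxes)
class_logits = p[..., 5:]              # 20 class scores (Pascal VOC setup used in app.py)
print(objectness.shape, box_offsets.shape, class_logits.shape)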
utils.py ADDED
@@ -0,0 +1,590 @@
import config
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os
import random
import torch

from collections import Counter
from torch.utils.data import DataLoader
from tqdm import tqdm

def iou_width_height(boxes1, boxes2):
    """
    Parameters:
        boxes1 (tensor): width and height of the first bounding boxes
        boxes2 (tensor): width and height of the second bounding boxes
    Returns:
        tensor: Intersection over union of the corresponding boxes
    """
    intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
        boxes1[..., 1], boxes2[..., 1]
    )
    union = (
        boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
    )
    return intersection / union


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    This function calculates intersection over union (iou) given pred boxes
    and target boxes.

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)

    Returns:
        tensor: Intersection over union for all examples
    """

    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2

    if box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    return intersection / (box1_area + box2_area - intersection + 1e-6)


def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Does Non Max Suppression given bboxes

    Parameters:
        bboxes (list): list of lists containing all bboxes with each bboxes
        specified as [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): threshold where predicted bboxes is correct
        threshold (float): threshold to remove predicted bboxes (independent of IoU)
        box_format (str): "midpoint" or "corners" used to specify bboxes

    Returns:
        list: bboxes after performing NMS given a specific IoU threshold
    """

    assert type(bboxes) == list

    bboxes = [box for box in bboxes if box[1] > threshold]
    bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
    bboxes_after_nms = []

    while bboxes:
        chosen_box = bboxes.pop(0)

        bboxes = [
            box
            for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                torch.tensor(chosen_box[2:]),
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]

        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms


def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Video explanation of this function:
    https://youtu.be/FppOzcDvaDI

    This function calculates mean average precision (mAP)

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bboxes
        specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): Similar as pred_boxes except all the correct ones
        iou_threshold (float): threshold where predicted bboxes is correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        float: mAP value across all classes given a specific IoU threshold
    """

    # list storing all AP for respective classes
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in tqdm(range(num_classes)):
        detections = []
        ground_truths = []

        # Go through all predictions and targets,
        # and only add the ones that belong to the
        # current class c
        for detection in pred_boxes:
            if detection[1] == c:
                detections.append(detection)

        for true_box in true_boxes:
            if true_box[1] == c:
                ground_truths.append(true_box)

        # find the amount of bboxes for each training example
        # Counter here finds how many ground truth bboxes we get
        # for each training example, so let's say img 0 has 3,
        # img 1 has 5 then we will obtain a dictionary with:
        # amount_bboxes = {0:3, 1:5}
        amount_bboxes = Counter([gt[0] for gt in ground_truths])

        # We then go through each key, val in this dictionary
        # and convert to the following (w.r.t same example):
        # amount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
        for key, val in amount_bboxes.items():
            amount_bboxes[key] = torch.zeros(val)

        # sort by box probabilities which is index 2
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))
        total_true_bboxes = len(ground_truths)

        # If none exists for this class then we can safely skip
        if total_true_bboxes == 0:
            continue

        for detection_idx, detection in enumerate(detections):
            # Only take out the ground_truths that have the same
            # training idx as detection
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            num_gts = len(ground_truth_img)
            best_iou = 0

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_iou > iou_threshold:
                # only detect ground truth detection once
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    # true positive and add this bounding box to seen
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    FP[detection_idx] = 1

            # if IOU is lower, the detection is a false positive
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))
        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)


def plot_image(image, boxes, return_fig=False):
    """Plots predicted bounding boxes on the image"""
    cmap = plt.get_cmap("tab20b")
    class_labels = config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes
    fig, ax = plt.subplots(1)
    # Display the image
    ax.imshow(im)

    # box[0] is x midpoint, box[2] is width
    # box[1] is y midpoint, box[3] is height

    # Create a Rectangle patch
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_pred = box[0]
        box = box[2:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        # Add the patch to the Axes
        ax.add_patch(rect)
        plt.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    if return_fig:
        return fig
    plt.show()


def get_evaluation_bboxes(
    loader,
    model,
    iou_threshold,
    anchors,
    threshold,
    box_format="midpoint",
):
    # make sure model is in eval before get bboxes
    model.eval()
    train_idx = 0
    all_pred_boxes = []
    all_true_boxes = []
    for batch_idx, (x, labels) in enumerate(tqdm(loader)):
        x = x.to(config.DEVICE)
        with torch.no_grad():
            predictions = model(x)

        batch_size = x.shape[0]
        bboxes = [[] for _ in range(batch_size)]
        for i in range(3):
            S = predictions[i].shape[2]
            anchor = torch.tensor([*anchors[i]]) * S
            boxes_scale_i = cells_to_bboxes(
                predictions[i], anchor, S=S, is_preds=True
            )
            for idx, (box) in enumerate(boxes_scale_i):
                bboxes[idx] += box

        # we just want one bbox for each label, not one for each scale
        true_bboxes = cells_to_bboxes(
            labels[2], anchor, S=S, is_preds=False
        )

        for idx in range(batch_size):
            nms_boxes = non_max_suppression(
                bboxes[idx],
                iou_threshold=iou_threshold,
                threshold=threshold,
                box_format=box_format,
            )

            for nms_box in nms_boxes:
                all_pred_boxes.append([train_idx] + nms_box)

            for box in true_bboxes[idx]:
                if box[1] > threshold:
                    all_true_boxes.append([train_idx] + box)

            train_idx += 1

    model.train()
    return all_pred_boxes, all_true_boxes


def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """
    Scales the predictions coming from the model to
    be relative to the entire image such that they for example later
    can be plotted or evaluated.
    INPUT:
    predictions: tensor of size (N, 3, S, S, num_classes+5)
    anchors: the anchors used for the predictions
    S: the number of cells the image is divided in on the width (and height)
    is_preds: whether the input is predictions or the true bounding boxes
    OUTPUT:
    converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
                      object score, bounding box coordinates
    """
    BATCH_SIZE = predictions.shape[0]
    num_anchors = len(anchors)
    box_predictions = predictions[..., 1:5]
    if is_preds:
        anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    cell_indices = (
        torch.arange(S)
        .repeat(predictions.shape[0], 3, S, 1)
        .unsqueeze(-1)
    ).to(config.DEVICE)
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1) \
        .reshape(BATCH_SIZE, num_anchors * S * S, 6)
    return converted_bboxes.tolist()


def check_class_accuracy(model, batch, threshold, tag='train'):
    model.eval()
    tot_class_preds, correct_class = 0, 0
    tot_noobj, correct_noobj = 0, 0
    tot_obj, correct_obj = 0, 0
    x, y = batch

    x = x.to(config.DEVICE)

    with torch.no_grad():
        out = model(x)

    for i in range(3):
        y[i] = y[i].to(config.DEVICE)
        obj = y[i][..., 0] == 1  # in paper this is Iobj_i
        noobj = y[i][..., 0] == 0  # in paper this is Inoobj_i

        correct_class += torch.sum(
            torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
        )
        tot_class_preds += torch.sum(obj)

        obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
        correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
        tot_obj += torch.sum(obj)
        correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
        tot_noobj += torch.sum(noobj)

    ans = {
        f"{tag}_class_accuracy": (correct_class / (tot_class_preds + 1e-16)) * 100,
        f"{tag}_no_obj_accuracy": (correct_noobj / (tot_noobj + 1e-16)) * 100,
        f"{tag}_obj_accuracy": (correct_obj / (tot_obj + 1e-16)) * 100
    }
    model.train()
    return ans


def get_mean_std(loader):
    # var[X] = E[X**2] - E[X]**2
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0

    for data, _ in tqdm(loader):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5

    return mean, std


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr, device):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


+ def get_loaders(train_csv_path, test_csv_path):
451
+ from dataset import YOLODataset
452
+
453
+ IMAGE_SIZE = config.IMAGE_SIZE
454
+ train_dataset = YOLODataset(
455
+ train_csv_path,
456
+ transform=config.train_transforms,
457
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
458
+ img_dir=config.IMG_DIR,
459
+ label_dir=config.LABEL_DIR,
460
+ anchors=config.ANCHORS,
461
+ )
462
+ test_dataset = YOLODataset(
463
+ test_csv_path,
464
+ transform=config.test_transforms,
465
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
466
+ img_dir=config.IMG_DIR,
467
+ label_dir=config.LABEL_DIR,
468
+ anchors=config.ANCHORS,
469
+ )
470
+ train_loader = DataLoader(
471
+ dataset=train_dataset,
472
+ batch_size=config.BATCH_SIZE,
473
+ num_workers=config.NUM_WORKERS,
474
+ pin_memory=config.PIN_MEMORY,
475
+ shuffle=True,
476
+ drop_last=False,
477
+ )
478
+ test_loader = DataLoader(
479
+ dataset=test_dataset,
480
+ batch_size=config.BATCH_SIZE,
481
+ num_workers=config.NUM_WORKERS,
482
+ pin_memory=config.PIN_MEMORY,
483
+ shuffle=False,
484
+ drop_last=False,
485
+ )
486
+
487
+ train_eval_dataset = YOLODataset(
488
+ train_csv_path,
489
+ transform=config.test_transforms,
490
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
491
+ img_dir=config.IMG_DIR,
492
+ label_dir=config.LABEL_DIR,
493
+ anchors=config.ANCHORS,
494
+ )
495
+ train_eval_loader = DataLoader(
496
+ dataset=train_eval_dataset,
497
+ batch_size=config.BATCH_SIZE,
498
+ num_workers=config.NUM_WORKERS,
499
+ pin_memory=config.PIN_MEMORY,
500
+ shuffle=False,
501
+ drop_last=False,
502
+ )
503
+
504
+ return train_loader, test_loader, train_eval_loader
505
+

def plot_couple_examples(model, batch, thresh, iou_thresh, anchors):
    model.eval()
    x, _ = batch
    x = x.to(config.DEVICE)
    batch_size = x.shape[0]
    with torch.no_grad():
        out = model(x)
        bboxes = [[] for _ in range(x.shape[0])]
        for i in range(3):
            batch_size, _, S, _, _ = out[i].shape
            anchor = anchors[i]
            boxes_scale_i = cells_to_bboxes(
                out[i], anchor, S=S, is_preds=True
            )
            for idx, (box) in enumerate(boxes_scale_i):
                bboxes[idx] += box

    model.train()

    for i in range(batch_size // 4):
        nms_boxes = non_max_suppression(
            bboxes[i], iou_threshold=iou_thresh,
            threshold=thresh, box_format="midpoint",
        )
        plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)


def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * x[..., 0] + padw  # top left x
    y[..., 1] = h * x[..., 1] + padh  # top left y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
    return y


def clip_boxes(boxes, shape):
    # Clip boxes (xyxy) to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
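Note that utils.py imports a config module at the top, but no config.py is included in this commit; the module therefore assumes one exists alongside it (providing at least DEVICE, DATASET, PASCAL_CLASSES/COCO_LABELS and the dataloader settings used above). The pure box helpers need nothing from it beyond a successful import, as this toy sketch shows (the box values are invented for illustration):

# Toy sketch of the box utilities above (values invented for illustration).
# Assumes a config.py is present so that `from utils import ...` succeeds.
import torch
from utils import non_max_suppression, xywhn2xyxy

# Boxes as [class_pred, prob_score, x, y, w, h] in "midpoint" format, normalized to [0, 1].
boxes = [
    [0, 0.9, 0.50, 0.50, 0.40, 0.40],  # confident detection
    [0, 0.7, 0.52, 0.48, 0.40, 0.40],  # same class, heavy overlap -> suppressed
    [1, 0.8, 0.20, 0.20, 0.10, 0.10],  # different class, kept
]
kept = non_max_suppression(boxes, iou_threshold=0.5, threshold=0.6, box_format="midpoint")
print(len(kept))  # 2

# Convert normalized midpoint boxes to pixel-space corners for a 416x416 image.
print(xywhn2xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4]]), w=416, h=416))
# tensor([[124.8000, 124.8000, 291.2000, 291.2000]])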