File size: 3,026 Bytes
c87ccf3
 
 
 
 
 
 
 
 
60930a3
c87ccf3
 
 
 
 
 
 
60930a3
 
 
 
 
c87ccf3
60930a3
c87ccf3
 
 
60930a3
 
c87ccf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tempfile

import gradio as gr
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')

import config
from model import YOLOv3
from utils import (
    cells_to_bboxes,
    non_max_suppression,
    plot_image
)

# Weights file holding a bare state dict for YOLOv3.
# It appears to have been exported once from the training checkpoint
# 'epoch=36-step=19166.ckpt' roughly as follows (kept for reference):
#   checkpoint = torch.load(fname, map_location=torch.device('cpu'))
#   model_state_dict = checkpoint['state_dict']
#   torch.save(model.state_dict(), 'yolov3.pth')
fname = 'yolov3.pth'

model = YOLOv3(num_classes=20)
# The model is never moved to another device in this file, so map the stored
# tensors to CPU explicitly; without this, weights saved from a GPU fail to
# load on a CPU-only host.
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
# A freshly constructed module starts in training mode. Switch to inference
# mode so layers such as batch norm / dropout (if the model has any) use
# their inference behaviour and do not update running statistics per request.
model.eval()

# Side length (in pixels) that every input image is resized/padded to.
IMAGE_SIZE = 416

# Grid sizes obtained by dividing the image side by 32, 16 and 8.
# NOTE(review): the anchors below are scaled with config.S, not with this
# list - confirm that the two agree.
S = [IMAGE_SIZE // stride for stride in (32, 16, 8)]

# Each entry of config.S is expanded to a (3, 2) block so that it can be
# multiplied element-wise with config.ANCHORS.
_grid_scale = torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
anchors = torch.tensor(config.ANCHORS) * _grid_scale

# Inference-time preprocessing, applied in this order:
#   1. resize so the longest side equals IMAGE_SIZE,
#   2. pad up to IMAGE_SIZE x IMAGE_SIZE with a constant border,
#   3. normalize with mean 0 / std 1 and max_pixel_value=255,
#   4. convert to a torch tensor.
_resize_longest_side = A.LongestMaxSize(max_size=IMAGE_SIZE)
_pad_to_square = A.PadIfNeeded(
    min_height=IMAGE_SIZE,
    min_width=IMAGE_SIZE,
    border_mode=cv2.BORDER_CONSTANT,
)
_rescale_pixels = A.Normalize(
    mean=[0, 0, 0],
    std=[1, 1, 1],
    max_pixel_value=255,
)
test_transforms = A.Compose(
    [_resize_longest_side, _pad_to_square, _rescale_pixels, ToTensorV2()],
)
def object_detector(input_image):
    """Detect objects in one image and return the annotated plot for Gradio.

    The image is preprocessed with ``test_transforms``, run through ``model``,
    the raw outputs are decoded with ``cells_to_bboxes``, filtered with
    ``non_max_suppression`` and drawn with ``plot_image``. The resulting
    figure is saved as a PNG file.

    Args:
        input_image: Image delivered by the Gradio input component (passed to
            ``test_transforms`` as ``image=``), or ``None`` when the user
            submitted without providing an image.

    Returns:
        A ``gr.update`` for the output image component. Its ``value`` is the
        path of the saved PNG, or ``None`` when no input image was given.
    """
    if input_image is None:
        # Nothing to run on: clear the output instead of failing inside the
        # transform pipeline.
        return gr.update(value=None, visible=True)

    # Preprocess and add a leading batch dimension of size 1.
    input_img = test_transforms(image=input_image)['image'].unsqueeze(0)

    thresh = 0.6      # confidence threshold handed to NMS
    iou_thresh = 0.5  # IoU threshold handed to NMS

    with torch.no_grad():
        out = model(input_img)

        # Collect the boxes of the single input image from all three outputs.
        # NOTE(review): this assumes cells_to_bboxes returns one list of boxes
        # per image in the batch - confirm against utils. The previous code
        # appended those per-image lists and then ran NMS on bboxes[0] only,
        # so everything decoded from out[1] and out[2] was silently dropped.
        bboxes = []
        for scale_idx in range(3):
            scale_out = out[scale_idx]
            grid_size = scale_out.shape[2]
            per_image_boxes = cells_to_bboxes(
                scale_out, anchors[scale_idx], S=grid_size, is_preds=True
            )
            bboxes += per_image_boxes[0]

        nms_boxes = non_max_suppression(
            bboxes, iou_threshold=iou_thresh,
            threshold=thresh, box_format="midpoint",
        )

    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    try:
        # Work on this figure explicitly rather than on pyplot's global
        # "current" figure, which other requests may change concurrently.
        ax = fig.gca()
        ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
        ax.axis('off')

        # Use a unique file per request; a fixed name could be overwritten by
        # a concurrent request before Gradio reads it. The file is left on
        # disk for Gradio to pick up.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image_path = tmp.name
        fig.savefig(image_path)
    finally:
        # Always release this figure, even if saving failed.
        plt.close(fig)

    return gr.update(value=image_path, visible=True)

# Gradio components: the image fed to the detector and the image showing
# the rendered detections.
input_image = gr.Image(label="Input image")
output_box = gr.Image(label="Output image").style(width=428, height=428)

# Sample inputs offered in the UI; each example row holds a single file path
# built from images_path and a file stem.
images_path = "examples/"
_example_stems = (
    "000015",
    "000017",
    "000030",
    "000069",
    "000071",
    "000084",
    "000086",
    "000088",
    "000095",
    "000100",
)
_examples = [[f"{images_path}{stem}.jpg"] for stem in _example_stems]

# Build the interface around object_detector and start serving it.
demo = gr.Interface(
    fn=object_detector,
    inputs=input_image,
    outputs=output_box,
    examples=_examples,
)
demo.launch()