import unittest
import gradio as gr
import torch
import cv2
import matplotlib
matplotlib.use('agg')  # headless backend: figures are only saved to disk
import matplotlib.pyplot as plt
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

import config
from model import YOLOv3
from utils import (
    cells_to_bboxes,
    non_max_suppression,
    plot_image,
)

def yolov3_reshape_transform(tensor):
    # pytorch_grad_cam applies this to the activations it captures; YOLOv3's
    # multi-scale head comes back as a tuple/list, so keep only the first scale.
    return tensor[0]
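
# A minimal EigenCAM sketch, reconstructed from the commented-out block inside
# object_detector below. It assumes model.layers[21] is a suitable late feature
# block (taken from that commented code) and that rgb_float_img is the input
# image as a float32 RGB array scaled to [0, 1]. EigenCAM needs no labels or
# gradients: it projects the layer activations onto their first principal
# component. This helper is defined but not yet wired into the interface.
def eigencam_overlay(input_img, rgb_float_img):
    target_layers = [model.layers[21]]
    cam = EigenCAM(model, target_layers, use_cuda=False,
                   reshape_transform=yolov3_reshape_transform)
    grayscale_cam = cam(input_img)[0, :, :]  # (batch, H, W) -> (H, W)
    return show_cam_on_image(rgb_float_img, grayscale_cam, use_rgb=True)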

# One-time conversion of the Lightning checkpoint to a plain state dict:
# checkpoint = torch.load('epoch=38-step=20202.ckpt', map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint['state_dict'])
# torch.save(model.state_dict(), 'yolov3.pth')
model = YOLOv3(num_classes=20)
fname = 'yolov3.pth'
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout

IMAGE_SIZE = config.IMAGE_SIZE
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
# Scale the normalized anchor boxes by the grid size of each prediction scale.
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
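
# Optional sanity check, assuming the usual layout of config.ANCHORS
# (3 scales x 3 anchors x (w, h)): the scaled tensor should be (3, 3, 2).
assert anchors.shape == (3, 3, 2)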

def object_detector(input_image, thresh=0.8, iou_thresh=0.5):
    # Gradio passes the image in as an RGB numpy array; apply the same
    # resize/normalize pipeline used at test time and add a batch dimension.
    input_img = config.test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)

    with torch.no_grad():
        out = model(input_img)
        # Decode each of the three prediction scales into one list of boxes.
        bboxes = []
        for i in range(3):
            _, _, grid_size, _, _ = out[i].shape
            anchor = anchors[i]
            bboxes += cells_to_bboxes(
                out[i], anchor, S=grid_size, is_preds=True
            )[0]
        nms_boxes = non_max_suppression(
            bboxes, iou_threshold=iou_thresh,
            threshold=thresh, box_format="midpoint",
        )
        fig = plot_image(input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
                         nms_boxes,
                         return_fig=True)
        plt.gca().set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
        plt.axis('off')
        image_path = "plot.png"
        fig.savefig(image_path)
        plt.close()

        # EigenCAM overlay (disabled for now); see eigencam_overlay above.
        # target_layers = [model.layers[21]]
        # cam = EigenCAM(model, target_layers, use_cuda=False,
        #                reshape_transform=yolov3_reshape_transform)
        # grayscale_cam = cam(input_img)[0, :, :]
        # cam_image = show_cam_on_image(rgb_float_img, grayscale_cam, use_rgb=True)

        # Until the CAM path is enabled, both outputs show the detection plot.
        return (gr.update(value=image_path, visible=True),
                gr.update(value=image_path, visible=True))
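
# Example usage outside Gradio (hypothetical path): gr.Image hands the function
# an RGB numpy array, so a cv2-loaded image must be converted from BGR first.
#
#   img = cv2.cvtColor(cv2.imread("examples/000015.jpg"), cv2.COLOR_BGR2RGB)
#   box_update, cam_update = object_detector(img, thresh=0.6, iou_thresh=0.6)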

# Define the input and output components for Gradio
input_image = gr.Image(label="Input image")
confidence_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                             label="Confidence level")
iou_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                      label="Intersection over union level")
output_box = gr.Image(label="Output image", visible=False)\
                 .style(width=428, height=428)
cam_output = gr.Image(label="CAM output", visible=False)\
                 .style(width=428, height=428)
images_path = "examples/"

gr_interface = gr.Interface(
    fn=object_detector,
    inputs=[input_image, confidence_level, iou_level],
    outputs=[output_box, cam_output],
    examples=[[images_path + "000015.jpg"],
              [images_path + "000017.jpg"],
              [images_path + "000030.jpg"],
              [images_path + "000069.jpg"],
              [images_path + "000071.jpg"],
              [images_path + "000084.jpg"],
              [images_path + "000086.jpg"],
              [images_path + "000088.jpg"],
              [images_path + "000100.jpg"],
              ],
    cache_examples=False,
)

gr_interface.launch()

# class TestGradioInterfaceInput(unittest.TestCase):
#     def test_valid_image_input(self):
#         # Load a sample image as the RGB array Gradio would pass in.
#         img = cv2.cvtColor(cv2.imread(images_path + "000015.jpg"),
#                            cv2.COLOR_BGR2RGB)
#
#         # Call the detector directly; an Interface object is not callable.
#         box_update, cam_update = object_detector(img)
#
#         # Both outputs should be visible and point at the saved plot.
#         self.assertTrue(box_update["visible"])
#         self.assertEqual(box_update["value"], "plot.png")
#
# if __name__ == '__main__':
#     unittest.main()
