# Author: venkyyuvy
# Commit aad51cc: "cache examples false"
import unittest
import gradio as gr
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import torch
import cv2
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import config
from model import YOLOv3
from utils import (
cells_to_bboxes,
non_max_suppression,
plot_image
)
def yolov3_reshape_transform(tensor):
    """Return the element at index 0 of ``tensor``.

    Intended as the ``reshape_transform`` hook for pytorch-grad-cam's
    EigenCAM (it is passed that way in the disabled CAM code of this module).
    NOTE(review): presumably index 0 selects the first of YOLOv3's
    multi-scale outputs — confirm against ``model.YOLOv3.forward``.
    """
    first_item = tensor[0]
    return first_item
# ``yolov3.pth`` was exported once from the training checkpoint
# 'epoch=38-step=20202.ckpt': its 'state_dict' was loaded on the CPU into
# YOLOv3(num_classes=20) and ``model.state_dict()`` was re-saved.
model = YOLOv3(num_classes=20)
fname = 'yolov3.pth'
# map_location keeps loading working on CPU-only hosts even if a tensor in
# the file was saved from a GPU; the model itself is never moved off CPU.
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
# The model is only used for inference. torch.no_grad() disables autograd
# but does NOT switch layers such as BatchNorm/Dropout (if YOLOv3 uses
# them) out of training behaviour; eval() does.
model.eval()

IMAGE_SIZE = config.IMAGE_SIZE
# Grid sizes for the three detection scales (image size / 32, / 16, / 8).
# NOTE(review): `anchors` below is scaled by config.S, not by this list;
# presumably the two agree — confirm.
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]

# Scale each entry of config.ANCHORS by the matching grid size from
# config.S. The (k,) grid-size vector is expanded to (k, 3, 2) so it
# broadcasts over what is presumably 3 (w, h) anchor pairs per scale.
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
def object_detector(input_image, thresh = 0.8, iou_thresh = 0.5):
    """Run YOLOv3 on one image and return Gradio updates showing detections.

    Args:
        input_image: Value of the ``gr.Image`` input, passed to
            ``config.test_transforms(image=...)``. Presumably an HxWx3 numpy
            array — TODO confirm. ``None`` when the user submits without an
            uploaded image.
        thresh: Confidence threshold forwarded to ``non_max_suppression``.
        iou_thresh: IoU threshold forwarded to ``non_max_suppression``.

    Returns:
        A pair of ``gr.update`` objects, one per output component (detection
        image, CAM image). Both currently show the same rendered detection
        plot, because the CAM computation below is disabled. If
        ``input_image`` is None, both outputs are hidden instead.
    """
    import tempfile

    if input_image is None:
        # Nothing to detect: hide the outputs rather than failing inside
        # the transform with an obscure error.
        return gr.update(visible=False), gr.update(visible=False)

    input_img = config.test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        out = model(input_img)

    # Decode each of the three prediction scales into boxes for the single
    # image in the batch ([0]), then merge them before NMS.
    bboxes = []
    for i in range(3):
        _, _, grid_size, _, _ = out[i].shape
        bboxes += cells_to_bboxes(
            out[i], anchors[i], S=grid_size, is_preds=True
        )[0]
    nms_boxes = non_max_suppression(
        bboxes, iou_threshold=iou_thresh,
        threshold=thresh, box_format="midpoint",
    )

    fig = plot_image(input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
                     nms_boxes,
                     return_fig=True)
    # Style and close the returned figure explicitly instead of whatever
    # pyplot currently considers the "current" figure/axes.
    ax = fig.gca()
    ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    ax.axis('off')
    # Use a fresh file per call: a fixed "plot.png" is shared by concurrent
    # requests, which could then be served each other's results. The file
    # must outlive this call because Gradio reads it from the returned
    # path, so it is not deleted here.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_path = tmp.name
    fig.savefig(image_path)
    plt.close(fig)

    # NOTE(review): CAM visualisation is disabled; `img` below is not
    # defined in this function, so it needs fixing before re-enabling.
    # target_layers = [model.layers[21]]
    # cam = EigenCAM(model, target_layers, use_cuda=False,
    #                reshape_transform=yolov3_reshape_transform,
    #                )
    # grayscale_cam = cam(input_img, target_layers)[0][0, :, :]
    # cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    return (gr.update(value=image_path, visible=True),
            gr.update(value=image_path, visible=True))
# Define the input and output components for Gradio.
input_image = gr.Image(label="Input image")
confidence_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                             label="confidence level")
# IoU = "intersection over union" (the label previously said "Interference").
iou_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                      label="Intersection over union level")
# Both outputs start hidden; object_detector makes them visible.
output_box = gr.Image(label="Output image", visible=False,)\
    .style(width=428, height=428)
cam_output = gr.Image(label="cam output", visible=False)\
    .style(width=428, height=428)

images_path = "examples/"
example_images = [
    "000015.jpg", "000017.jpg", "000030.jpg",
    "000069.jpg", "000071.jpg", "000084.jpg",
    "000086.jpg", "000088.jpg", "000100.jpg",
]
gr_interface = gr.Interface(
    fn=object_detector,
    inputs=[input_image, confidence_level, iou_level],
    outputs=[output_box, cam_output],
    # Each example row supplies only the first input (the image).
    examples=[[images_path + name] for name in example_images],
    # Do not run object_detector on the examples ahead of time.
    cache_examples=False,
)
gr_interface.launch()
# class TestGradioInterfaceInput(unittest.TestCase):
# def test_valid_image_input(self):
# # Create a valid input image
# input_image = images_path + "000015.jpg"
#
# # Pass the image through the interface
# output = gr_interface(input_image)
#
# # Assert the output matches the expected result
# self.assertEqual(output[0].shape, (3, 416, 416))
#
# if __name__ == '__main__':
# unittest.main()
#
# Create the Gradio interface
# gr.Interface(fn=object_detector, inputs=input_image,
# outputs=[output_box, cam_output],
# examples=[[images_path + "000015.jpg"],
# [images_path + "000017.jpg"],
# [images_path + "000030.jpg"],
# [images_path + "000069.jpg"],
# [images_path + "000071.jpg"],
# [images_path + "000084.jpg"],
# [images_path + "000086.jpg"],
# [images_path + "000088.jpg"],
# [images_path + "000095.jpg"],
# [images_path + "000100.jpg"],
# ],
# ).launch()
#