Spaces:
Runtime error
Runtime error
import unittest | |
import gradio as gr | |
import torch | |
import cv2 | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
import matplotlib.pyplot as plt | |
import matplotlib | |
matplotlib.use('agg') | |
import torch | |
import cv2 | |
from pytorch_grad_cam import EigenCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
import config | |
from model import YOLOv3 | |
from utils import ( | |
cells_to_bboxes, | |
non_max_suppression, | |
plot_image | |
) | |
def yolov3_reshape_transform(tensor):
    """Return only the first entry of *tensor*.

    Intended as the ``reshape_transform`` hook for EigenCAM (see the
    disabled CAM code in ``object_detector``).
    NOTE(review): assumes the hooked activation is indexable and that its
    first item is the map to visualise — confirm against pytorch_grad_cam.
    """
    first_entry = tensor[0]
    return first_entry
# yolov3.pth appears to have been exported from a training checkpoint with:
#   fname = 'epoch=38-step=20202.ckpt'
#   checkpoint = torch.load(fname, map_location=torch.device('cpu'))
#   model_state_dict = checkpoint['state_dict']
#   model.load_state_dict(model_state_dict)
#   torch.save(model.state_dict(), 'yolov3.pth')
model = YOLOv3(num_classes=20)
fname = 'yolov3.pth'
# map_location="cpu" lets weights saved on a GPU machine load on a CPU-only
# host instead of failing during deserialization.
model.load_state_dict(torch.load(fname, map_location=torch.device("cpu")))
# This module only runs inference: put any BatchNorm/Dropout layers of the
# model into evaluation behaviour (the original never called eval()).
model.eval()

IMAGE_SIZE = config.IMAGE_SIZE
# Grid sizes for the three detection scales (strides 32, 16 and 8).
# NOTE(review): unused below; anchors are scaled with config.S instead,
# presumably holding the same values — confirm.
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
# Scale each detection scale's anchors by that scale's grid size.
# Assumes config.ANCHORS is (3 scales, 3 anchors, 2) with image-relative
# (w, h) values — TODO confirm against config.
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
def object_detector(input_image, thresh=0.8, iou_thresh=0.5):
    """Run YOLOv3 on one image and return Gradio updates showing the detections.

    Args:
        input_image: Value from the ``gr.Image`` input component (presumably
            an RGB numpy array — TODO confirm for the installed Gradio version).
        thresh: Confidence threshold forwarded to ``non_max_suppression``.
        iou_thresh: IoU threshold forwarded to ``non_max_suppression``.

    Returns:
        Two ``gr.update`` objects that make both output images visible and
        point them at the rendered plot. The second output currently gets the
        same plot because the EigenCAM code below is disabled.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Without this, submitting an empty input fails deep inside the
        # transform pipeline with an unhelpful traceback.
        raise gr.Error("Please provide an input image.")

    input_img = config.test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        out = model(input_img)

    bboxes = []
    # One prediction tensor per detection scale, paired with that scale's anchors.
    for scale_out, anchor in zip(out, anchors):
        _, _, S, _, _ = scale_out.shape
        # [0] selects the boxes belonging to the single image in the batch.
        bboxes += cells_to_bboxes(scale_out, anchor, S=S, is_preds=True)[0]

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    # Style the returned figure's own axes rather than pyplot's global
    # "current" axes, which may belong to a different figure.
    for ax in fig.axes:
        ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
        ax.axis('off')

    image_path = "plot.png"
    # NOTE(review): this fixed filename is shared by every request; if the app
    # is served with concurrency > 1, consider a per-request temporary file.
    fig.savefig(image_path)
    # Close exactly this figure so repeated requests do not leak open figures.
    plt.close(fig)

    # target_layers = [model.layers[21]]
    # cam = EigenCAM(model, target_layers, use_cuda=False,
    #                reshape_transform=yolov3_reshape_transform,
    #                )
    # grayscale_cam = cam(input_img, target_layers)[0][0, :, :]
    # cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    return (
        gr.update(value=image_path, visible=True),
        gr.update(value=image_path, visible=True),
    )
# Define the input and output components for Gradio
input_image = gr.Image(label="Input image")
confidence_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                             label="confidence level")
iou_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                      label="Intersection over union level")
# Display size is passed to the constructor: Component.style() was
# deprecated in Gradio 3.x and removed in Gradio 4, where calling it raises
# AttributeError when this module is imported.
output_box = gr.Image(label="Output image", visible=False,
                      width=428, height=428)
cam_output = gr.Image(label="cam output", visible=False,
                      width=428, height=428)

images_path = "examples/"
example_files = [
    "000015.jpg",
    "000017.jpg",
    "000030.jpg",
    "000069.jpg",
    "000071.jpg",
    "000084.jpg",
    "000086.jpg",
    "000088.jpg",
    "000100.jpg",
]
gr_interface = gr.Interface(
    fn=object_detector,
    inputs=[input_image, confidence_level, iou_level],
    outputs=[output_box, cam_output],
    # Each example row supplies only the image input.
    examples=[[images_path + name] for name in example_files],
    cache_examples=False,
)
gr_interface.launch()
# class TestGradioInterfaceInput(unittest.TestCase): | |
# def test_valid_image_input(self): | |
# # Create a valid input image | |
# input_image = images_path + "000015.jpg" | |
# | |
# # Pass the image through the interface | |
# output = gr_interface(input_image) | |
# | |
# # Assert the output matches the expected result | |
# self.assertEqual(output[0].shape, (3, 416, 416)) | |
# | |
# if __name__ == '__main__': | |
# unittest.main() | |
# | |
# Create the Gradio interface | |
# gr.Interface(fn=object_detector, inputs=input_image, | |
# outputs=[output_box, cam_output], | |
# examples=[[images_path + "000015.jpg"], | |
# [images_path + "000017.jpg"], | |
# [images_path + "000030.jpg"], | |
# [images_path + "000069.jpg"], | |
# [images_path + "000071.jpg"], | |
# [images_path + "000084.jpg"], | |
# [images_path + "000086.jpg"], | |
# [images_path + "000088.jpg"], | |
# [images_path + "000095.jpg"], | |
# [images_path + "000100.jpg"], | |
# ], | |
# ).launch() | |
# | |