import gradio as gr
import torch
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib
matplotlib.use('agg')  # headless backend so figures render without a display
import matplotlib.pyplot as plt

from model import YOLOv3
from utils import (
    cells_to_bboxes,
    non_max_suppression,
    plot_image,
)
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]  # Note these have been rescaled to be between [0, 1]
fname = 'epoch=36-step=19166.ckpt'
checkpoint = torch.load(fname, map_location=torch.device('cpu'))
model_state_dict = checkpoint['state_dict']

model = YOLOv3(num_classes=20)
model.load_state_dict(model_state_dict)
model.eval()  # inference mode: batch norm uses its running statistics
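# The 'state_dict' key and the epoch=...-step=... filename match PyTorch
# Lightning's default checkpoint format; num_classes=20 suggests the PASCAL VOC
# label set (consistent with the VOC-style example images used below).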
IMAGE_SIZE = 416
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
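# With IMAGE_SIZE = 416 the three prediction grids are S = [13, 26, 52]; the
# coarse 13x13 grid pairs with the largest anchors and the 52x52 grid with the
# smallest, matching the ordering of the ANCHORS list above.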
anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
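# Worked example of the rescaling: at the 13x13 scale the anchor (0.28, 0.22)
# becomes (0.28 * 13, 0.22 * 13) = (3.64, 2.86) grid cells; anchors[i] is what
# object_detector passes to cells_to_bboxes for scale i below.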
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
)
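# This pipeline resizes the longest side to 416 while preserving aspect ratio,
# pads to a 416x416 square, rescales pixels to [0, 1] (divide by 255, zero
# mean shift), and converts the HWC numpy image to a CHW torch tensor.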
def object_detector(input_image):
    # Gradio passes the uploaded image in as an RGB numpy array
    input_img = test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)  # add a batch dimension
    thresh = 0.6      # confidence threshold for keeping predictions
    iou_thresh = 0.5  # IoU threshold for non-max suppression
    with torch.no_grad():
        out = model(input_img)
        bboxes = []
        for i in range(3):
            _, _, grid_size, _, _ = out[i].shape
            anchor = anchors[i]
            # cells_to_bboxes is assumed to return one list of boxes per image
            # in the batch; batch size is 1 here, so take image 0's boxes and
            # accumulate them across all three scales
            bboxes += cells_to_bboxes(
                out[i], anchor, S=grid_size, is_preds=True
            )[0]
    nms_boxes = non_max_suppression(
        bboxes, iou_threshold=iou_thresh,
        threshold=thresh, box_format="midpoint",
    )
    fig = plot_image(input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
                     nms_boxes,
                     return_fig=True)
    plt.gca().set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    plt.axis('off')
    image_path = "plot.png"
    fig.savefig(image_path)
    plt.close(fig)
    return gr.update(value=image_path, visible=True)
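# Minimal local sanity check (hypothetical snippet; assumes numpy is installed
# and the bundled example image exists):
#   import numpy as np
#   img = np.array(Image.open("examples/000015.jpg").convert("RGB"))
#   object_detector(img)  # writes plot.png and returns a gr.update for it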
# Define the input and output components for Gradio
input_image = gr.Image(label="Input image")
# gr.Image no longer supports .style() in recent Gradio releases; pass the
# display size directly to the constructor instead
output_box = gr.Image(label="Output image", width=428, height=428)

images_path = "examples/"
# Create the Gradio interface
gr.Interface(
    fn=object_detector,
    inputs=input_image,
    outputs=output_box,
    examples=[[images_path + "000015.jpg"],
              [images_path + "000017.jpg"],
              [images_path + "000030.jpg"],
              [images_path + "000069.jpg"],
              [images_path + "000071.jpg"],
              [images_path + "000084.jpg"],
              [images_path + "000086.jpg"],
              [images_path + "000088.jpg"],
              [images_path + "000095.jpg"],
              [images_path + "000100.jpg"],
              ],
).launch()
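# launch() starts a local server; on Hugging Face Spaces (where this app runs)
# the platform supplies the hosting, so no extra launch arguments are required.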