venkyyuvy's picture
model w mosaic aug
c87ccf3
raw
history blame
3.16 kB
import gradio as gr
import torch
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
from model import YOLOv3
from utils import (
cells_to_bboxes,
non_max_suppression,
plot_image
)
# Anchor boxes for the three detection scales, listed in the same order as the
# grid sizes in S below (coarsest 13x13 grid first, finest 52x52 last).
# Each pair is presumably (width, height) -- TODO confirm against model/utils.
# Note these have been rescaled to be between [0, 1]
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]

# Load the trained checkpoint onto the CPU and restore the weights.
fname = 'epoch=36-step=19166.ckpt'
checkpoint = torch.load(fname, map_location=torch.device('cpu'))
model_state_dict = checkpoint['state_dict']
model = YOLOv3(num_classes=20)
model.load_state_dict(model_state_dict)
# Put the network in inference mode. Without this, layers whose behaviour
# differs between training and evaluation (e.g. BatchNorm, Dropout, if the
# model uses them) run in training mode on single-image batches and give
# degraded predictions.
model.eval()

IMAGE_SIZE = 416
# Output grid sizes for strides 32, 16 and 8: [13, 26, 52] for a 416 input.
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
# Rescale the [0, 1] anchors into grid-cell units for each scale by
# multiplying every anchor of scale i by S[i]; result has shape (3, 3, 2).
anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
# Inference-time preprocessing pipeline:
#   1. resize so the longest side equals IMAGE_SIZE (aspect ratio kept),
#   2. pad to an IMAGE_SIZE x IMAGE_SIZE canvas with a constant border,
#   3. divide pixel values by 255 (mean 0, std 1, max_pixel_value 255),
#   4. convert the HWC array into a CHW torch tensor.
test_transforms = A.Compose(
    transforms=[
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE,
            min_width=IMAGE_SIZE,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(
            mean=[0, 0, 0],
            std=[1, 1, 1],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
)
def object_detector(input_image, thresh=0.6, iou_thresh=0.5):
    """Detect objects in one image and return a Gradio update showing them.

    The image is preprocessed with ``test_transforms`` and run through the
    module-level ``model``. The raw outputs of every detection scale are
    decoded with ``cells_to_bboxes``, merged, and filtered with
    ``non_max_suppression``. The surviving boxes are drawn with
    ``plot_image`` and the plot is saved to ``plot.png``.

    Args:
        input_image: the value delivered by the ``gr.Image`` input
            (presumably an RGB HxWx3 numpy array -- confirm against the
            Gradio version in use), or ``None`` when nothing was uploaded.
        thresh: confidence threshold forwarded to ``non_max_suppression`` as
            ``threshold``. Defaults to 0.6, the previously hard-coded value.
        iou_thresh: IoU threshold forwarded to ``non_max_suppression`` as
            ``iou_threshold``. Defaults to 0.5, the previously hard-coded
            value.

    Returns:
        A ``gr.update`` that points the output image at the saved plot, or
        clears it when no input image was provided.
    """
    if input_image is None:
        # Gradio passes None when the form is submitted without an image;
        # clear the output instead of crashing inside the transform pipeline.
        return gr.update(value=None, visible=True)

    # Preprocess to a square IMAGE_SIZE tensor and add a batch axis of 1.
    input_img = test_transforms(image=input_image)['image'].unsqueeze(0)

    with torch.no_grad():
        out = model(input_img)

    # Collect candidate boxes for the single image from *every* scale.
    # cells_to_bboxes presumably returns one list of boxes per batch image
    # (as in the reference YOLOv3 utils -- confirm against utils.py). Index
    # [0] selects our image, so the boxes from all three scales are
    # concatenated into one flat list for NMS.
    bboxes = []
    for preds, scale_anchors in zip(out, anchors):
        _, _, grid_size, _, _ = preds.shape
        bboxes += cells_to_bboxes(
            preds, scale_anchors, S=grid_size, is_preds=True
        )[0]

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    # Draw on the preprocessed (resized + padded) image: that is the frame
    # the model saw, so the predicted boxes are expressed relative to it.
    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    # Work on the returned figure's axes rather than pyplot's global state.
    ax = fig.gca()
    ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    ax.axis('off')
    image_path = "plot.png"
    fig.savefig(image_path)
    # Close exactly this figure so repeated requests don't leak open figures.
    plt.close(fig)
    return gr.update(value=image_path, visible=True)
# Gradio I/O components: the uploaded image goes in, the rendered detection
# plot comes out (displayed at a fixed 428x428 size).
input_image = gr.Image(label="Input image")
output_box = gr.Image(label="Output image").style(width=428, height=428)
images_path = "examples/"

# Build the demo UI with a gallery of bundled sample images, then serve it.
gr.Interface(
    fn=object_detector,
    inputs=input_image,
    outputs=output_box,
    examples=[
        [images_path + example_file]
        for example_file in (
            "000015.jpg",
            "000017.jpg",
            "000030.jpg",
            "000069.jpg",
            "000071.jpg",
            "000084.jpg",
            "000086.jpg",
            "000088.jpg",
            "000095.jpg",
            "000100.jpg",
        )
    ],
).launch()