venkyyuvy's picture
ckpt to pth file
60930a3
raw
history blame
3.03 kB
import gradio as gr
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import config
from model import YOLOv3
from utils import (
cells_to_bboxes,
non_max_suppression,
plot_image
)
# Trained weights as a plain state_dict. Per the author's earlier notes, this
# file was derived from the Lightning checkpoint 'epoch=36-step=19166.ckpt'
# (its 'state_dict' entry) and re-saved with torch.save.
fname = 'yolov3.pth'

model = YOLOv3(num_classes=20)
# map_location keeps loading working on CPU-only hosts even if the tensors
# were saved from a GPU run; inference below runs on CPU tensors.
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
# Switch to inference behaviour for layers such as BatchNorm/Dropout
# (presumably present in YOLOv3 — TODO confirm); without this they would use
# training-mode statistics on a batch of one image.
model.eval()

IMAGE_SIZE = 416
# Prediction grid sizes for strides 32, 16 and 8 at the inference IMAGE_SIZE.
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
# Scale the configured anchors by each grid size: tensor(S) becomes
# (3, 1, 1) and is repeated to (3, 3, 2), one (w, h) factor per anchor.
# Assumes config.ANCHORS is 3 scales x 3 anchors x (w, h), given as fractions
# of the image size — TODO confirm. The local S (which matches the size the
# transforms resize to) is used instead of config.S, so the anchors always
# correspond to the grids produced for these inputs.
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
def _build_test_transforms():
    """Return the albumentations pipeline applied to every inference image.

    Steps, in order:
      1. shrink/enlarge so the longest side equals IMAGE_SIZE;
      2. pad to an IMAGE_SIZE x IMAGE_SIZE canvas with a constant border;
      3. divide pixel values by 255 (mean 0, std 1), mapping [0, 255] to [0, 1];
      4. convert the HWC numpy image into a CHW torch tensor.
    """
    resize_longest = A.LongestMaxSize(max_size=IMAGE_SIZE)
    pad_to_square = A.PadIfNeeded(
        min_height=IMAGE_SIZE,
        min_width=IMAGE_SIZE,
        border_mode=cv2.BORDER_CONSTANT,
    )
    scale_to_unit = A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255)
    to_tensor = ToTensorV2()
    return A.Compose([resize_longest, pad_to_square, scale_to_unit, to_tensor])


test_transforms = _build_test_transforms()
def object_detector(input_image, thresh=0.6, iou_thresh=0.5):
    """Detect objects in one image with YOLOv3 and return the annotated plot.

    Args:
        input_image: Image supplied by the ``gr.Image`` input component
            (presumably an (H, W, 3) uint8 numpy array, gradio's default —
            TODO confirm), or ``None`` if the user submitted without an image.
        thresh: Passed to ``non_max_suppression`` as ``threshold``
            (presumably the minimum box score kept). Defaults to 0.6, the
            value previously hard-coded here.
        iou_thresh: Passed to ``non_max_suppression`` as ``iou_threshold``.
            Defaults to 0.5, the value previously hard-coded here.

    Returns:
        A ``gr.update`` pointing the output image at the saved plot
        ("plot.png"), or one clearing the output when ``input_image`` is None.
    """
    if input_image is None:
        # Nothing uploaded: clear the output rather than failing inside the transforms.
        return gr.update(value=None)

    # Resize/pad/normalise into a tensor, then add a leading batch dimension of 1.
    input_img = test_transforms(image=input_image)['image'].unsqueeze(0)

    with torch.no_grad():
        out = model(input_img)

    # Collect candidate boxes from every prediction scale. cells_to_bboxes
    # returns one box list per batch element, so take element [0] (our single
    # image) and concatenate across scales. Previously the per-image lists
    # were appended and only bboxes[0] -- the first scale's boxes -- reached NMS.
    bboxes = []
    for scale_out, anchor in zip(out, anchors):
        _, _, grid_size, _, _ = scale_out.shape
        bboxes += cells_to_bboxes(scale_out, anchor, S=grid_size, is_preds=True)[0]

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    # Plot on the preprocessed image the model saw (CHW -> HWC for matplotlib).
    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    # Hide ticks/axes on this figure's axes rather than whatever pyplot's
    # implicit "current" axes happens to be.
    ax = fig.gca()
    ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    ax.axis('off')

    # NOTE(review): a fixed filename is shared by every request; concurrent
    # requests could overwrite each other's plot before it is served.
    image_path = "plot.png"
    fig.savefig(image_path)
    # Close exactly this figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return gr.update(value=image_path, visible=True)
# Gradio components: the uploaded image goes in, the rendered detection plot comes out.
input_image = gr.Image(label="Input image")
output_box = gr.Image(label="Output image").style(width=428, height=428)

# Bundled sample images offered as one-click examples (one input per example).
images_path = "examples/"
_example_files = (
    "000015.jpg",
    "000017.jpg",
    "000030.jpg",
    "000069.jpg",
    "000071.jpg",
    "000084.jpg",
    "000086.jpg",
    "000088.jpg",
    "000095.jpg",
    "000100.jpg",
)
_examples = [[images_path + file_name] for file_name in _example_files]

# Build the interface, then start serving it.
demo = gr.Interface(
    fn=object_detector,
    inputs=input_image,
    outputs=output_box,
    examples=_examples,
)
demo.launch()