Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('agg') | |
| import config | |
| from model import YOLOv3 | |
| from utils import ( | |
| cells_to_bboxes, | |
| non_max_suppression, | |
| plot_image | |
| ) | |
# Model weights. yolov3.pth is a plain state_dict, presumably produced from the
# training checkpoint 'epoch=36-step=19166.ckpt' roughly like this:
#   checkpoint = torch.load(<ckpt>, map_location=torch.device('cpu'))
#   model_state_dict = checkpoint['state_dict']
#   torch.save(model.state_dict(), 'yolov3.pth')
fname = 'yolov3.pth'
# 20 classes (presumably PASCAL VOC, judging by the example file names -- confirm).
model = YOLOv3(num_classes=20)
# This app keeps both the model and its inputs on CPU, so load onto CPU
# explicitly. Without map_location, a state_dict saved from a GPU fails to
# deserialize on a CPU-only host.
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
# Inference only: switch any BatchNorm/dropout layers to evaluation behaviour,
# so predictions use stored running statistics rather than per-image ones.
model.eval()
# Network input resolution. test_transforms below letterboxes every image to
# IMAGE_SIZE x IMAGE_SIZE.
IMAGE_SIZE = 416
# Grid size of each of the three prediction scales (strides 32, 16 and 8).
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
# Per-scale anchors in grid-cell units: each scale's anchors from config.ANCHORS
# (presumably fractions of the image size, shaped 3 scales x 3 anchors x (w, h)
# -- confirm) are multiplied by that scale's grid size. The multiplier built
# here has shape (3, 3, 2).
# The local S is used (rather than config.S) so that the scaling always
# follows the IMAGE_SIZE actually fed to the model above.
grid_scale = torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
anchors = torch.tensor(config.ANCHORS) * grid_scale
# Inference-time preprocessing pipeline, applied in order:
#   1. resize so the longest side equals IMAGE_SIZE (aspect ratio kept),
#   2. pad to IMAGE_SIZE x IMAGE_SIZE with a constant border,
#   3. divide pixel values by 255 (mean 0, std 1, max_pixel_value 255),
#   4. convert the HWC image array to a CHW torch tensor.
_resize_longest = A.LongestMaxSize(max_size=IMAGE_SIZE)
_pad_to_square = A.PadIfNeeded(
    min_height=IMAGE_SIZE,
    min_width=IMAGE_SIZE,
    border_mode=cv2.BORDER_CONSTANT,
)
_scale_pixels = A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255)
test_transforms = A.Compose(
    [_resize_longest, _pad_to_square, _scale_pixels, ToTensorV2()]
)
def object_detector(input_image, thresh=0.6, iou_thresh=0.5):
    """Detect objects in one image and return a Gradio update showing them.

    Args:
        input_image: Image delivered by the ``gr.Image`` input component
            (presumably an RGB numpy array of shape (H, W, 3), Gradio's
            default -- confirm), or None when nothing was uploaded.
        thresh: Confidence threshold passed to ``non_max_suppression``.
        iou_thresh: IoU threshold passed to ``non_max_suppression``.

    Returns:
        A ``gr.update`` that points the output image at the rendered plot
        ("plot.png") and makes it visible.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Preprocess and add a batch dimension: a batch of exactly one image.
    input_img = test_transforms(image=input_image)['image'].unsqueeze(0)

    with torch.no_grad():
        out = model(input_img)

    # Gather candidate boxes from all three prediction scales.
    # cells_to_bboxes returns one box list per batch element, so [0] is this
    # image's boxes at the current scale. Extending with [0] for every scale
    # keeps all scales. (Indexing the concatenated per-scale lists with [0]
    # would keep only the first scale's boxes.)
    bboxes = []
    for scale_out, scale_anchors in zip(out, anchors):
        _, _, grid_size, _, _ = scale_out.shape
        bboxes += cells_to_bboxes(
            scale_out, scale_anchors, S=grid_size, is_preds=True
        )[0]

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    # Draw on the preprocessed (resized + padded) image, because the box
    # coordinates refer to that frame. The image is converted back to HWC for
    # plotting.
    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    ax = fig.gca()
    ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    ax.axis('off')

    image_path = "plot.png"
    fig.savefig(image_path)
    # Close this specific figure so figures don't accumulate across requests.
    plt.close(fig)
    return gr.update(value=image_path, visible=True)
# Gradio UI: one uploaded image in, one image with drawn detections out.
input_image = gr.Image(label="Input image")
# Display size goes to the constructor. Component.style() was deprecated in
# Gradio 3.x and removed in Gradio 4, where calling it raises AttributeError
# at startup.
output_box = gr.Image(label="Output image", width=428, height=428)

# Bundled example images, shown as clickable samples under the input.
images_path = "examples/"
example_names = [
    "000015", "000017", "000030", "000069", "000071",
    "000084", "000086", "000088", "000095", "000100",
]

# Create and launch the Gradio interface.
demo = gr.Interface(
    fn=object_detector,
    inputs=input_image,
    outputs=output_box,
    examples=[[images_path + name + ".jpg"] for name in example_names],
)
demo.launch()