"""Gradio demo that runs a YOLOv3 checkpoint on an uploaded image.

Loads the weights on CPU at import time, builds the inference transforms,
and launches a Gradio interface whose callback draws the detections that
survive non-max suppression onto the (resized, padded) input image.
"""
import gradio as gr
import torch
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib

from model import YOLOv3
from utils import (
    cells_to_bboxes,
    non_max_suppression,
    plot_image
)

# Render figures off-screen.
matplotlib.use('agg')

# Note these have been rescaled to be between [0, 1]
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]

fname = 'epoch=36-step=19166.ckpt'
checkpoint = torch.load(fname, map_location=torch.device('cpu'))
model_state_dict = checkpoint['state_dict']

model = YOLOv3(num_classes=20)
model.load_state_dict(model_state_dict)
# Inference only: switch layers such as BatchNorm/Dropout out of training
# mode so predictions do not depend on single-image batch statistics.
model.eval()

IMAGE_SIZE = 416
# Grid sizes of the three prediction scales (strides 32, 16 and 8).
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]

# Scale the normalised anchors to grid-cell units, one grid size per scale.
anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)

test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE,
            min_width=IMAGE_SIZE,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
)


def object_detector(input_image):
    """Detect objects in ``input_image`` and return the annotated plot.

    Args:
        input_image: Image handed over by the Gradio ``Image`` input
            (presumably an H x W x 3 uint8 numpy array -- confirm against
            the installed Gradio version).

    Returns:
        A ``gr.update`` pointing the output component at ``plot.png``,
        the saved figure with the surviving boxes drawn on it.
    """
    thresh = 0.6
    iou_thresh = 0.5

    transformed = test_transforms(image=input_image)['image']
    input_img = transformed.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        out = model(input_img)

    # Gather the candidate boxes of the single image from all three scales.
    # cells_to_bboxes is assumed to return one list of boxes per batch
    # element (verify against utils); the previous code appended those
    # per-batch lists and then passed only element 0 -- i.e. only the first
    # scale -- to NMS, silently dropping the other two scales.
    bboxes = []
    for scale_idx in range(3):
        scale_out = out[scale_idx]
        grid_size = scale_out.shape[2]
        scale_boxes = cells_to_bboxes(
            scale_out, anchors[scale_idx], S=grid_size, is_preds=True
        )
        bboxes.extend(scale_boxes[0])

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    plt.gca().set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    plt.axis('off')

    image_path = "plot.png"
    fig.savefig(image_path)
    # Close this specific figure so figures do not pile up across requests.
    plt.close(fig)

    return gr.update(value=image_path, visible=True)


# Define the input and output components for Gradio
input_image = gr.Image(label="Input image")
# NOTE(review): `.style()` exists only in Gradio 3.x; confirm the pinned
# Gradio version before upgrading.
output_box = gr.Image(label="Output image")\
    .style(width=428, height=428)

images_path = "examples/"
example_files = (
    "000015.jpg",
    "000017.jpg",
    "000030.jpg",
    "000069.jpg",
    "000071.jpg",
    "000084.jpg",
    "000086.jpg",
    "000088.jpg",
    "000095.jpg",
    "000100.jpg",
)

# Create the Gradio interface
demo = gr.Interface(
    fn=object_detector,
    inputs=input_image,
    outputs=output_box,
    examples=[[images_path + file_name] for file_name in example_files],
)
demo.launch()