import unittest

import gradio as gr
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

import config
from model import YOLOv3
from utils import (
    cells_to_bboxes,
    non_max_suppression,
    plot_image,
)


def yolov3_reshape_transform(tensor):
    # Keep only the first element when the hooked output is a list/tuple
    # (YOLOv3 returns one prediction tensor per scale).
    return tensor[0]


model = YOLOv3(num_classes=20)

# One-time conversion of the original Lightning checkpoint into a plain state dict:
# fname = 'epoch=38-step=20202.ckpt'
# checkpoint = torch.load(fname, map_location=torch.device('cpu'))
# model_state_dict = checkpoint['state_dict']
# model.load_state_dict(model_state_dict)
# torch.save(model.state_dict(), 'yolov3.pth')

fname = 'yolov3.pth'
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
model.eval()  # inference only: freeze batch-norm statistics and disable dropout

IMAGE_SIZE = config.IMAGE_SIZE
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]

# Rescale the normalized anchors by the grid size at each prediction scale.
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)


def object_detector(input_image, thresh=0.8, iou_thresh=0.5):
    input_img = config.test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)
    with torch.no_grad():
        out = model(input_img)

    # Decode the grid-cell predictions from all three scales into one list of candidate boxes.
    bboxes = []
    for i in range(3):
        _, _, S, _, _ = out[i].shape
        anchor = anchors[i]
        bboxes += cells_to_bboxes(out[i], anchor, S=S, is_preds=True)[0]

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )

    fig = plot_image(
        input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
        nms_boxes,
        return_fig=True,
    )
    plt.gca().set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    plt.axis('off')
    image_path = "plot.png"
    fig.savefig(image_path)
    plt.close()

    # EigenCAM heat-map overlay (not wired into the output yet):
    # target_layers = [model.layers[21]]
    # cam = EigenCAM(model, target_layers, use_cuda=False,
    #                reshape_transform=yolov3_reshape_transform)
    # grayscale_cam = cam(input_img, target_layers)[0][0, :, :]
    # cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)

    # The CAM output is a placeholder for now: both outputs show the detection plot.
    return (
        gr.update(value=image_path, visible=True),
        gr.update(value=image_path, visible=True),
    )


# Define the input and output components for Gradio
input_image = gr.Image(label="Input image")
confidence_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                             label="Confidence level")
iou_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                      label="Intersection over union level")
output_box = gr.Image(label="Output image", visible=False) \
    .style(width=428, height=428)
cam_output = gr.Image(label="CAM output", visible=False) \
    .style(width=428, height=428)

images_path = "examples/"

gr_interface = gr.Interface(
    fn=object_detector,
    inputs=[input_image, confidence_level, iou_level],
    outputs=[output_box, cam_output],
    examples=[
        [images_path + "000015.jpg"],
        [images_path + "000017.jpg"],
        [images_path + "000030.jpg"],
        [images_path + "000069.jpg"],
        [images_path + "000071.jpg"],
        [images_path + "000084.jpg"],
        [images_path + "000086.jpg"],
        [images_path + "000088.jpg"],
        [images_path + "000100.jpg"],
    ],
    cache_examples=False,
)

gr_interface.launch()


# class TestGradioInterfaceInput(unittest.TestCase):
#     def test_valid_image_input(self):
#         # Create a valid input image
#         input_image = images_path + "000015.jpg"
#
#         # Pass the image through the interface
#         output = gr_interface(input_image)
#
#         # Assert the output matches the expected result
#         self.assertEqual(output[0].shape, (3, 416, 416))
#
#
# if __name__ == '__main__':
#     unittest.main()
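# A minimal sketch of exercising the detector outside the Gradio UI, using one of
# the example images above. cv2 loads images as BGR, so convert to RGB first to
# match the arrays that Gradio's Image component passes to object_detector:
#
#     img = cv2.cvtColor(cv2.imread(images_path + "000015.jpg"), cv2.COLOR_BGR2RGB)
#     box_update, cam_update = object_detector(img, thresh=0.6, iou_thresh=0.6)
#     # box_update is a gr.update payload pointing at the saved plot.png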