"""Gradio demo: CIFAR-10 classification with ResNet-18 and optional Grad-CAM.

The app exposes two features:

* classify an uploaded image and optionally overlay a Grad-CAM heat map;
* render a grid of misclassified test images stored in
  ``misclassified_images.pt``.
"""
import operator

import gradio as gr
import matplotlib
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Select the non-interactive backend before pyplot is imported so that no GUI
# toolkit is required on the server.
matplotlib.use('agg')
from matplotlib import pyplot as plt  # noqa: E402

from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM  # noqa: E402
from model import ResNet18  # noqa: E402

# Load the trained weights once at start-up; inference runs on CPU.
resnet_18_model = ResNet18()
resnet_18_model.load_state_dict(torch.load('resnet18.pth', map_location='cpu'))
resnet_18_model.eval()

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def inference(input_img, n_top_classes, apply_gradcam, transparency=0.5,
              target_layer_number=-1):
    """Classify an image and optionally build a Grad-CAM overlay.

    Args:
        input_img: Image delivered by the ``gr.Image`` input. It is passed to
            ``TEST_TRANSFORM(image=...)`` and divided by 255 for the overlay,
            so it is presumably an HxWx3 uint8 array -- TODO confirm.
        n_top_classes: Number of highest-probability classes to report.
        apply_gradcam: Whether to compute and show the Grad-CAM overlay.
        transparency: Weight of the original image in the overlay
            (forwarded as ``image_weight`` to ``show_cam_on_image``).
        target_layer_number: Index into ``resnet_18_model.layer3`` selecting
            the layer Grad-CAM attaches to.

    Returns:
        A pair of Gradio updates: the label confidences, and either the
        overlay image (made visible) or an instruction to hide the image.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Clicking "Submit" without an image would otherwise fail deep inside
        # the transform with an unhelpful traceback.
        raise gr.Error("Please provide an input image.")

    org_img = input_img
    input_img = TEST_TRANSFORM(image=input_img)['image'].unsqueeze(0)

    # No autograd graph is needed for the plain classification pass; Grad-CAM
    # performs its own forward/backward pass below.
    with torch.no_grad():
        outputs = resnet_18_model(input_img)
    probabilities = torch.softmax(outputs.flatten(), dim=0)

    scores = {label: float(prob) for label, prob in zip(classes, probabilities)}
    ranked = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
    # Slider values may arrive as floats; slicing requires an int.
    confidences = dict(ranked[:int(n_top_classes)])

    if not apply_gradcam:
        return gr.update(value=confidences), gr.update(visible=False)

    # Indexing the layer container also requires an int.
    target_layers = [resnet_18_model.layer3[int(target_layer_number)]]
    # The context manager releases the hooks GradCAM registers on the model,
    # so they do not pile up across requests.
    with GradCAM(model=resnet_18_model, target_layers=target_layers,
                 use_cuda=False) as cam:
        grayscale_cam = cam(input_tensor=input_img, targets=None)[0, :]

    visualization = show_cam_on_image(
        org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return (gr.update(value=confidences),
            gr.update(value=visualization, visible=True))


def show_misclasif(see_misclassif, n_images):
    """Render a two-row grid of misclassified test images to ``plot.png``.

    Args:
        see_misclassif: State of the "View misclassified images" checkbox.
        n_images: Requested number of images; the grid has ``n_images // 2``
            columns.

    Returns:
        A Gradio update showing the saved plot, or hiding the image component
        when the checkbox is not ticked.
    """
    if not see_misclassif:
        return gr.update(visible=False)

    n_images = int(n_images)  # slider values may arrive as floats
    # NOTE: torch.load unpickles the file; only load trusted artifacts.
    subset = torch.load('misclassified_images.pt')
    images, actuals, preds = torch.tensor(subset[0])[:20], subset[1], subset[2]

    nrows = 2
    ncols = max(1, n_images // 2)
    fig, axes = plt.subplots(nrows, ncols, figsize=(n_images, 4))
    fig.suptitle('misclassified images', weight='bold', size=10)
    axes = axes.ravel()

    # Hide every axis up front so unused cells stay blank if fewer samples
    # than grid cells are available.
    for ax in axes:
        ax.axis('off')

    for img, actual, pred, ax in zip(images, actuals, preds, axes):
        ax.imshow(img)
        ax.set_title(
            f'Prediction={CIFAR_CLASS_LABELS[pred]}\n Actual={CIFAR_CLASS_LABELS[actual]}',
            fontsize=8)

    image_path = "plot.png"
    fig.savefig(image_path)
    plt.close(fig)  # close this figure specifically, not just the "current" one
    return gr.update(value=image_path, visible=True)


def turn_on_misclasif(see_misclassif):
    """Toggle the misclassification controls with the checkbox.

    Returns updates for (slider, render button, image): the slider and button
    follow the checkbox state, while the image is always hidden until
    "Render" is clicked again.
    """
    controls_visible = bool(see_misclassif)
    return (gr.update(visible=controls_visible),
            gr.update(visible=controls_visible),
            gr.update(visible=False))


with gr.Blocks() as demo:
    # --- Classification + Grad-CAM ---------------------------------------
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            n_top_classes = gr.Slider(maximum=10, minimum=1, value=3, step=1,
                                      label="Top n classes to show",
                                      interactive=True)
            require_gradcam = gr.Checkbox(
                label="Apply GradCAM",
                info="Do you want see the GRAD-CAM visualization")
            opacity_gradcam = gr.Slider(0, 1, value=0.5,
                                        label="Opacity of GradCAM")
            layer_gradcam = gr.Slider(-2, -1, value=-2, step=1,
                                      label="Which Layer?")
            submit = gr.Button("Submit")
        with gr.Column():
            pred_classes = gr.Label()
            grad_cam = gr.Image(shape=(32, 32), label="Output", visible=False)\
                .style(width=128, height=128)

    # --- Misclassified images --------------------------------------------
    # NOTE(review): nesting reconstructed from a whitespace-collapsed source;
    # confirm that the image belongs inside this column.
    with gr.Row():
        with gr.Column():
            see_misclassif = gr.Checkbox(
                label="View misclassified images",
                info="Do you want see the miscassified images in the test dataset")
            n_misclasif = gr.Slider(
                maximum=20, minimum=2, value=10, step=2,
                label="Number of misclassified images to show",
                interactive=True, visible=False)
            render = gr.Button("Render", visible=False)
            misclasif_display = gr.Image(visible=False)

    # --- Event wiring ----------------------------------------------------
    submit.click(
        inference,
        inputs=[input_image, n_top_classes, require_gradcam,
                opacity_gradcam, layer_gradcam],
        outputs=[pred_classes, grad_cam],
    )
    see_misclassif.change(turn_on_misclasif, see_misclassif,
                          [n_misclasif, render, misclasif_display])
    render.click(show_misclasif, [see_misclassif, n_misclasif],
                 misclasif_display)

    # Examples supply only the first three inputs; `inference` falls back to
    # its defaults for opacity and target layer.
    gr.Examples(
        examples=[
            ["examples/truck.jpg", 3, True],
            ["examples/ship.jpg", 3, True],
            ["examples/dog.jpg", 3, True],
            ["examples/cat.jpg", 3, True],
            ["examples/horse.jpg", 3, True],
            ["examples/airplane.jpg", 3, True],
            ["examples/bird.jpg", 3, True],
            ["examples/automobile.jpg", 3, True],
            ["examples/deer.jpg", 3, True],
            ["examples/frog.jpg", 3, True],
        ],
        inputs=[input_image, n_top_classes, require_gradcam],
        outputs=[pred_classes, grad_cam],
        fn=inference,
    )

demo.launch()