Spaces:
Sleeping
Sleeping
import operator | |
import torch | |
import gradio as gr | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
import gradio as gr | |
from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM | |
import matplotlib | |
from model import ResNet18 | |
matplotlib.use('agg') | |
from matplotlib import pyplot as plt | |
# Build the ResNet-18, restore its trained weights from 'resnet18.pth'
# (mapped onto the CPU) and switch the module to evaluation mode.
resnet_18_model = ResNet18()
_checkpoint = torch.load('resnet18.pth', map_location='cpu')
resnet_18_model.load_state_dict(_checkpoint)
del _checkpoint  # keep the module namespace free of the temporary state dict
resnet_18_model.eval()

# Label for each model output index, in index order (inference() pairs
# these with the softmax probabilities).
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())
def inference(input_img, n_top_classes,
              apply_gradcam, transparency=0.5,
              target_layer_number=-1):
    """Classify one image and optionally build a Grad-CAM overlay for it.

    Args:
        input_img: Image delivered by the ``gr.Image`` input. It is fed to
            ``TEST_TRANSFORM`` and, for the overlay, divided by 255
            (presumably an HxWx3 uint8 array -- TODO confirm).
        n_top_classes: Number of highest-probability classes to return.
            Gradio documents slider values as floats, so it is cast to int
            before slicing.
        apply_gradcam: When truthy, also compute the Grad-CAM visualization.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``,
            the weight of the original image in the blend; higher values make
            the heat-map fainter.
        target_layer_number: Index into ``resnet_18_model.layer3`` selecting
            the block Grad-CAM inspects (cast to int for the same slider
            reason).

    Returns:
        A pair of ``gr.update`` objects: the top-n ``{class: probability}``
        mapping for the label output, and either the visible Grad-CAM image
        or a hidden update when Grad-CAM is off.
    """
    org_img = input_img
    input_tensor = TEST_TRANSFORM(image=input_img)['image'].unsqueeze(0)

    # Plain classification needs no gradients; Grad-CAM below performs its
    # own forward/backward pass.
    with torch.no_grad():
        outputs = resnet_18_model(input_tensor)
    probs = torch.softmax(outputs.flatten(), dim=0)

    # Pair every class label with its probability instead of assuming 10.
    y = {label: float(p) for label, p in zip(classes, probs)}
    sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
    confidences = dict(sorted_pred[:int(n_top_classes)])

    if not apply_gradcam:
        return (gr.update(value=confidences),
                gr.update(visible=False))

    target_layers = [resnet_18_model.layer3[int(target_layer_number)]]
    cam = GradCAM(model=resnet_18_model, target_layers=target_layers, use_cuda=False)
    # targets=None lets Grad-CAM explain the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(
        org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return (gr.update(value=confidences),
            gr.update(value=visualization, visible=True))
def show_misclasif(see_misclassif, n_images):
    """Plot misclassified test images in a 2-row grid and return it for display.

    Args:
        see_misclassif: Value of the "View misclassified images" checkbox.
            When falsy, nothing is plotted and the output is hidden.
        n_images: Number of images to show (two rows of ``n_images // 2``).
            Gradio documents slider values as floats; matplotlib requires an
            integer column count, so the value is cast to int.

    Returns:
        A ``gr.update`` pointing the image output at the saved "plot.png" and
        making it visible, or a hidden update when the checkbox is off.
    """
    if not see_misclassif:
        return gr.update(visible=False)

    n_images = int(n_images)
    # The file is indexed as (images, actual labels, predicted labels);
    # at most the first 20 images are used.
    subset = torch.load('misclassified_images.pt')
    images = torch.as_tensor(subset[0])[:20]
    actuals, preds = subset[1], subset[2]

    nrows = 2
    ncols = max(1, n_images // 2)
    fig, axes = plt.subplots(nrows, ncols, figsize=(n_images, 4))
    try:
        fig.suptitle('misclassified images', weight='bold', size=10)
        for img, actual, pred, ax in zip(images, actuals, preds, axes.ravel()):
            # NOTE(review): imshow assumes each image is HxW(xC) -- confirm
            # the layout stored in misclassified_images.pt.
            ax.imshow(img)
            ax.set_title(
                f'Prediction={CIFAR_CLASS_LABELS[pred]}\n Actual={CIFAR_CLASS_LABELS[actual]}',
                fontsize=8)
            ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
            ax.axis('off')
        image_path = "plot.png"
        fig.savefig(image_path)
    finally:
        # Close exactly this figure (plt.close() with no argument closes the
        # pyplot "current" figure, which need not be this one), even on error.
        plt.close(fig)
    return gr.update(value=image_path, visible=True)
# Gradio UI: the first row classifies an uploaded image (optionally with a
# Grad-CAM overlay); the second row can render misclassified test images.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            n_top_classes = gr.Slider(maximum=10, minimum=1, value=3, step=1,
                                      label="Top n classes to show",
                                      interactive=True)
            require_gradcam = gr.Checkbox(
                label="Apply GradCAM",
                info="Do you want to see the GRAD-CAM visualization")
            # inference() passes this value to show_cam_on_image as
            # image_weight (weight of the original image), so larger values
            # make the heat-map MORE transparent -- the label reflects that.
            opacity_gradcam = gr.Slider(0, 1, value=0.5,
                                        label="Transparency of GradCAM")
            # Index into resnet_18_model.layer3 used as the Grad-CAM target.
            layer_gradcam = gr.Slider(-2, -1, value=-2, step=1,
                                      label="Which Layer?")
            submit = gr.Button("Submit")
        with gr.Column():
            pred_classes = gr.Label()
            grad_cam = gr.Image(shape=(32, 32), label="Output",
                                visible=False).style(width=128, height=128)
    with gr.Row():
        with gr.Column():
            see_misclassif = gr.Checkbox(
                label="View misclassified images",
                info="Do you want to see the misclassified images in the test dataset")
            n_misclasif = gr.Slider(maximum=20, minimum=2, value=10, step=2,
                                    label="Number of misclassified images to show",
                                    interactive=True, visible=False)
            render = gr.Button("Render", visible=False)
        misclasif_display = gr.Image(visible=False)

    submit.click(inference,
                 inputs=[input_image, n_top_classes, require_gradcam,
                         opacity_gradcam, layer_gradcam],
                 outputs=[pred_classes, grad_cam])

    def turn_on_misclasif(see_misclassif):
        """Toggle the slider and Render button with the checkbox.

        The display image is always hidden, so a previous plot disappears
        until Render is clicked again.
        """
        visible = bool(see_misclassif)
        return (gr.update(visible=visible), gr.update(visible=visible),
                gr.update(visible=False))

    see_misclassif.change(turn_on_misclasif, see_misclassif,
                          [n_misclasif, render, misclasif_display])
    render.click(show_misclasif, [see_misclassif, n_misclasif],
                 misclasif_display)

    # Each example fills only the image, top-n and Grad-CAM inputs.
    gr.Examples(
        examples=[[f"examples/{name}.jpg", 3, True]
                  for name in ("truck", "ship", "dog", "cat", "horse",
                               "airplane", "bird", "automobile", "deer",
                               "frog")],
        inputs=[input_image, n_top_classes, require_gradcam],
        outputs=[pred_classes, grad_cam],
        fn=inference,
    )

demo.launch()