venkyyuvy's picture
map_location to cpu
3cb2ba0
import operator
import torch
import gradio as gr
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import gradio as gr
from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM
import matplotlib
from model import ResNet18
matplotlib.use('agg')
from matplotlib import pyplot as plt
# Instantiate the network and restore the trained weights. The checkpoint is
# remapped onto the CPU so the demo also runs on hosts without a GPU.
resnet_18_model = ResNet18()
resnet_18_model.load_state_dict(
    torch.load('resnet18.pth', map_location='cpu')
)
# Inference mode: switches layers such as dropout/batch-norm to eval behaviour.
resnet_18_model.eval()

# Display names for the model's ten outputs, listed in output-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def inference(input_img, n_top_classes,
              apply_gradcam, transparency=0.5,
              target_layer_number=-1):
    """Classify an image and optionally overlay a Grad-CAM heatmap.

    Args:
        input_img: Image from the Gradio ``gr.Image`` input, or ``None`` when
            the user submitted without an image. Presumably a uint8 numpy
            array in 0-255 (it is divided by 255 below) — TODO confirm.
        n_top_classes: How many of the most probable classes to report.
        apply_gradcam: When truthy, also compute and show a Grad-CAM overlay.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``,
            i.e. the weight of the original image in the blend.
        target_layer_number: Index into ``resnet_18_model.layer3`` selecting
            the block whose activations Grad-CAM uses.

    Returns:
        A pair of ``gr.update`` objects: the class -> probability mapping for
        the label component, and the Grad-CAM image (hidden when Grad-CAM is
        not requested or there is no input).
    """
    if input_img is None:
        # Nothing to classify: clear the label and hide the overlay.
        return gr.update(value=None), gr.update(visible=False)

    # Slider values may arrive as floats; slicing and indexing need ints.
    n_top_classes = int(n_top_classes)
    target_layer_number = int(target_layer_number)

    org_img = input_img
    input_tensor = TEST_TRANSFORM(image=input_img)['image'].unsqueeze(0)

    # The plain prediction needs no autograd graph; Grad-CAM does its own
    # gradient-enabled pass further down.
    with torch.no_grad():
        outputs = resnet_18_model(input_tensor)
    probs = torch.softmax(outputs.flatten(), dim=0)

    y = {classes[i]: float(probs[i]) for i in range(len(classes))}
    sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
    confidences = dict(sorted_pred[:n_top_classes])

    if not apply_gradcam:
        return (gr.update(value=confidences),
                gr.update(visible=False))

    target_layers = [resnet_18_model.layer3[target_layer_number]]
    # Using GradCAM as a context manager releases the hooks it registers on
    # the model even if the CAM computation raises.
    with GradCAM(model=resnet_18_model, target_layers=target_layers,
                 use_cuda=False) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(
        org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return (gr.update(value=confidences),
            gr.update(value=visualization, visible=True))
def show_misclasif(see_misclassif, n_images):
    """Plot a grid of misclassified test images with predicted/actual labels.

    Args:
        see_misclassif: Checkbox state; when falsy nothing is rendered.
        n_images: Number of images to plot, laid out in two rows (the UI
            slider supplies an even value between 2 and 20).

    Returns:
        ``gr.update`` pointing the image component at the saved plot, or one
        that clears and hides it when ``see_misclassif`` is off.
    """
    if not see_misclassif:
        return gr.update(value=None, visible=False)

    # Slider values may arrive as floats; plt.subplots needs integer counts.
    n_images = int(n_images)
    nrows = 2
    ncols = max(n_images // 2, 1)

    # NOTE(review): assumes the file holds (images, actuals, preds) as its
    # first three items — confirm against the script that produced it.
    subset = torch.load('misclassified_images.pt')
    # as_tensor avoids the extra copy (and warning) torch.tensor makes when
    # the stored images are already a tensor.
    images = torch.as_tensor(subset[0])[:20]
    actuals, preds = subset[1], subset[2]

    fig, axes = plt.subplots(nrows, ncols, figsize=(n_images, 4))
    try:
        fig.suptitle('misclassified images', weight='bold', size=10)
        for img, actual, pred, ax in zip(images, actuals, preds, axes.ravel()):
            ax.imshow(img)
            ax.set_title(
                f'Prediction={CIFAR_CLASS_LABELS[pred]}\n Actual={CIFAR_CLASS_LABELS[actual]}',
                fontsize=8)
            ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
            ax.axis('off')
        image_path = "plot.png"
        fig.savefig(image_path)
    finally:
        # Close this specific figure (not just the "current" one) so repeated
        # renders never accumulate open figures, even on error.
        plt.close(fig)
    return gr.update(value=image_path, visible=True)
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            n_top_classes = gr.Slider(maximum=10, minimum=1, value=3, step=1,
                                      label="Top n classes to show",
                                      interactive=True)
            require_gradcam = gr.Checkbox(
                label="Apply GradCAM",
                info="Do you want to see the GRAD-CAM visualization")
            # inference() passes this value to show_cam_on_image as
            # image_weight, the weight of the ORIGINAL image: higher values
            # make the heatmap fainter, so the control is a transparency.
            opacity_gradcam = gr.Slider(0, 1, value=0.5,
                                        label="Transparency of GradCAM")
            layer_gradcam = gr.Slider(-2, -1, value=-2, step=1,
                                      label="Which Layer?")
            submit = gr.Button("Submit")
        with gr.Column():
            pred_classes = gr.Label()
            grad_cam = gr.Image(shape=(32, 32), label="Output",
                                visible=False).style(width=128, height=128)
    with gr.Row():
        with gr.Column():
            see_misclassif = gr.Checkbox(
                label="View misclassified images",
                info="Do you want to see the misclassified images in the test dataset")
            n_misclasif = gr.Slider(maximum=20, minimum=2, value=10, step=2,
                                    label="Number of misclassified images to show",
                                    interactive=True, visible=False)
            render = gr.Button("Render", visible=False)
        misclasif_display = gr.Image(visible=False)

    submit.click(inference,
                 inputs=[input_image, n_top_classes, require_gradcam,
                         opacity_gradcam, layer_gradcam],
                 outputs=[pred_classes, grad_cam])

    def turn_on_misclasif(see_misclassif):
        """Toggle the misclassification controls; always hide any stale plot."""
        show_controls = bool(see_misclassif)
        return (gr.update(visible=show_controls),
                gr.update(visible=show_controls),
                gr.update(visible=False))

    see_misclassif.change(turn_on_misclasif, see_misclassif,
                          [n_misclasif, render, misclasif_display])
    render.click(show_misclasif, [see_misclassif, n_misclasif],
                 misclasif_display)
    gr.Examples(
        examples=[
            ["examples/truck.jpg", 3, True],
            ["examples/ship.jpg", 3, True],
            ["examples/dog.jpg", 3, True],
            ["examples/cat.jpg", 3, True],
            ["examples/horse.jpg", 3, True],
            ["examples/airplane.jpg", 3, True],
            ["examples/bird.jpg", 3, True],
            ["examples/automobile.jpg", 3, True],
            ["examples/deer.jpg", 3, True],
            ["examples/frog.jpg", 3, True],
        ],
        inputs=[input_image, n_top_classes, require_gradcam],
        outputs=[pred_classes, grad_cam],
        fn=inference,
    )

if __name__ == "__main__":
    demo.launch()