AttnViz / app.py
Muthukamalan's picture
init gradio
197f827
raw
history blame
2.79 kB
############################
#
# Imports
#
############################
import timm
import torch
from skimage import io
from src.gradcams import GradCam
import numpy as np
import cv2
import gradio as gr
from PIL import Image
############################
#
# model
#
############################
# Pretrained ViT-Small (16x16 patches, 224x224 input) loaded through timm.
# `Module.eval()` returns the module itself, so the model is bound already in
# inference mode.  (A custom head could be requested with e.g. num_classes=10.)
model: torch.nn.Module = timm.create_model(
    "vit_small_patch16_224",
    pretrained=True,
).eval()
############################
#
# utility functions
#
############################
def prepare_input(image: np.ndarray) -> torch.Tensor:
    """Turn an (H, W, C) image array into a normalised (1, C, H, W) tensor.

    Each channel is normalised with mean 0.5 and std 0.5, i.e. values in
    [0, 1] are mapped to [-1, 1].  The caller's array is never modified.

    Args:
        image: Channel-last image with three channels.  Any numeric dtype is
            accepted; the data is converted to float32 by value (integer
            pixel values are NOT rescaled to [0, 1] here).

    Returns:
        A float32 tensor of shape (1, C, H, W) with ``requires_grad=True`` so
        gradients with respect to the input can be computed.
    """
    # Copy + cast in one step: keeps the input untouched and guarantees a
    # float32 working buffer (in-place normalisation of an integer array
    # would raise a casting error, and float64 would not match the model).
    pixels = np.array(image, dtype=np.float32)  # (H, W, C)

    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    stds = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    pixels -= mean
    pixels /= stds

    # Channel-last -> channel-first, then add the leading batch dimension.
    pixels = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1)))  # (C, H, W)
    pixels = pixels[np.newaxis, ...]  # (1, C, H, W)
    return torch.tensor(pixels, requires_grad=True)
def gen_cam(image, mask):
    """Blend a Grad-CAM mask over an image and return a uint8 RGB picture.

    Args:
        image: Tensor of shape (1, C, H, W) holding the picture to draw on.
            Channels are treated as RGB, matching what Gradio supplies and
            what ``PIL.Image.fromarray`` expects downstream.
            NOTE(review): values are assumed to lie in [0, 1]; anything
            outside that range is clipped after blending.
        mask: Grad-CAM activation map scaled to [0, 1].
            NOTE(review): assumed to be an (H, W) array matching the image
            size -- confirm against ``GradCam``.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape (H, W, 3) in RGB order.
    """
    # Colourise the mask.  OpenCV colour maps come back in BGR order, so
    # convert to RGB before mixing with the RGB image -- otherwise the JET
    # map is shown with red and blue swapped (hot areas would look cold).
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255.0

    # (1, C, H, W) tensor -> (H, W, C) numpy array.
    picture = image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

    # Equal-weight superimposition of heatmap and picture.
    cam = (0.5 * heatmap) + (0.5 * picture)

    # Normalise by the peak value; skip when the blend is entirely
    # non-positive to avoid dividing by zero.
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak

    # Clip before the uint8 cast so negative values cannot wrap around.
    return np.uint8(255 * np.clip(cam, 0.0, 1.0))
def attn_viz(image, number: int = 2):
    """Run Grad-CAM on one transformer block and return the overlay image.

    Args:
        image: RGB image as a numpy array (as delivered by ``gr.Image`` with
            ``type='numpy'``).
        number: Index into ``model.blocks`` selecting the block whose
            activations are visualised.  Floats (as sent by ``gr.Number``)
            are truncated to int; ``None`` falls back to block 2.

    Returns:
        A ``PIL.Image.Image`` with the Grad-CAM heatmap blended over the
        resized (224x224) input.

    Raises:
        gr.Error: If no image was supplied or the block index is out of range.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # gr.Number may hand over a float (or None when the field is empty);
    # indexing model.blocks needs a real int.
    block_index = 2 if number is None else int(number)
    depth = len(model.blocks)
    if not -depth <= block_index < depth:
        raise gr.Error(f"Block index must be between 0 and {depth - 1}.")

    # Resize to the model's input resolution and scale pixels to [0, 1].
    rgb = np.float32(cv2.resize(image, (224, 224))) / 255
    model_input = prepare_input(rgb)

    target_layer = model.blocks[block_index]
    grad_cam = GradCam(model=model, target=target_layer)
    mask = grad_cam(model_input)

    # Draw the heatmap over the un-normalised pixels: the model input lives
    # in [-1, 1], and its negative values would corrupt the uint8 overlay.
    overlay_base = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    result = gen_cam(image=overlay_base, mask=mask)
    return Image.fromarray(result)
# Build the Gradio Blocks UI: one tab running the visualisation, one tab
# showing the project README.
with gr.Blocks(
    title="AttnViz",
) as demo:
    with gr.Tab("Image Processing"):
        # Inputs: the picture to analyse and the index of the transformer
        # block to inspect.  precision=0 makes Gradio deliver an int, and the
        # initial value mirrors attn_viz's default block.
        image_input = gr.Image(label="Input Image", type='numpy')
        number_input = gr.Number(
            label="Number",
            value=2,
            precision=0,
            minimum=0,
            maximum=11,
            show_label=True,
        )
        # Output: the Grad-CAM overlay.
        image_output = gr.Image(label="Output Image")
        # Run attn_viz when the button is pressed.
        process_button = gr.Button("Process Image")
        process_button.click(attn_viz, inputs=[image_input, number_input], outputs=image_output)
        # Clickable sample inputs (image path, block index).
        gr.Examples(
            examples=[
                ["samples/mr_bean.png", 1],
                ["samples/sectional-sofa.png", 8],
            ],
            inputs=[image_input, number_input],
        )
    with gr.Tab("README"):
        # Render the README as-is.  Read-only mode is enough (and works on a
        # read-only filesystem); the explicit encoding avoids depending on
        # the platform default.
        with open("README.md", "r", encoding="utf-8") as file:
            readme_content = file.read()
        gr.Markdown(readme_content)
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        show_error=True,  # surface Python exceptions in the browser UI
        share=False,      # no public share link
    )