File size: 2,794 Bytes
197f827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
############################
#
#   Imports
#
############################
import timm 
import torch
from skimage import io 
from src.gradcams import GradCam
import numpy as np 
import cv2 
import gradio  as gr 
from PIL import Image



############################
#
#   model
#
############################
# Pretrained ViT-Small from timm (the rest of this file feeds it 224x224 inputs).
# `.eval()` returns the module itself, so inference mode is set in the same
# expression; it does not disable gradient tracking, which Grad-CAM relies on.
model: torch.nn.Module = timm.create_model(
    "vit_small_patch16_224",
    pretrained=True,
).eval()

############################
#
#   utility functions
#
############################
def prepare_input(
    image: np.ndarray,
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Normalize an (H, W, C) image and turn it into a batched model input.

    Each channel is shifted by ``mean`` and divided by ``std``. The result is
    transposed to channel-first layout and given a leading batch axis.

    Args:
        image: Array of shape (H, W, C). ``attn_viz`` passes float32 values
            already scaled to [0, 1]. Integer arrays are also accepted; they
            are cast to float32 instead of failing on the in-place arithmetic.
        mean: Per-channel value subtracted from the image.
        std: Per-channel value the image is divided by.

    Returns:
        float32 tensor of shape (1, C, H, W) created with
        ``requires_grad=True`` (presumably needed by GradCam — confirm).
    """
    # Work on a float32 copy. This means:
    # - the caller's array is never mutated;
    # - integer inputs no longer raise on in-place float ops;
    # - the dtype matches the float32 model weights even for float64 input.
    arr = np.array(image, dtype=np.float32)
    arr -= np.asarray(mean, dtype=np.float32)
    arr /= np.asarray(std, dtype=np.float32)

    # (H, W, C) -> (C, H, W), as the model expects channel-first input.
    arr = np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))
    arr = arr[np.newaxis, ...]  # (bs, C, H, W)
    return torch.tensor(arr, requires_grad=True)


def gen_cam(image, mask, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Blend a Grad-CAM mask, rendered as a JET heatmap, over the input image.

    Args:
        image: Batched tensor of shape (1, C, H, W), as produced by
            ``prepare_input``, i.e. already normalized with ``mean``/``std``.
        mask: 2-D saliency map whose spatial size matches ``image``.
            Presumably in [0, 1] (values are clipped to that range) — confirm
            against GradCam's output.
        mean: Per-channel mean that was subtracted when normalizing ``image``.
        std: Per-channel std that ``image`` was divided by.

    Returns:
        uint8 array of shape (H, W, 3) in RGB channel order, rescaled so its
        brightest value maps to 255.
    """
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    mask = np.clip(np.asarray(mask), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR. The image and the PIL output
    # (Image.fromarray) use RGB, so without this hot regions would show blue.
    heatmap = np.float32(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)) / 255.0

    # Undo the input normalization so the base image lies in [0, 1] again.
    # Raw normalized values are negative for dark pixels and would wrap
    # around in the uint8 cast below.
    base = image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    base = base * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
    base = np.clip(base, 0.0, 1.0)

    # Superimpose the heatmap on the image with equal weight.
    cam = 0.5 * heatmap + 0.5 * base
    peak = cam.max()
    if peak > 0:  # avoid dividing by zero on an all-black result
        cam = cam / peak
    return np.uint8(255 * cam)



def attn_viz(image, number: int = 2):
    """Compute a Grad-CAM overlay for one transformer block of ``model``.

    Args:
        image: Input picture as a numpy array, from the Gradio ``Image``
            component with ``type='numpy'``. Presumably (H, W, 3) uint8 RGB —
            confirm. It is resized to 224x224 and scaled to [0, 1] before
            normalization.
        number: Index into ``model.blocks`` of the block to explain. Gradio's
            ``Number`` component delivers floats (e.g. ``3.0``) unless
            ``precision=0`` is set, so the value is coerced to ``int`` here.

    Returns:
        PIL image of the 224x224 heatmap overlay.

    Raises:
        gr.Error: if the image or block index is missing, if the index is not
            a whole number, or if it is out of range for ``model.blocks``.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if number is None or float(number) != int(number):
        raise gr.Error("Block number must be a whole number.")
    block_index = int(number)
    try:
        # timm's ViT keeps its blocks in an nn.Sequential, whose indexing
        # rejects floats; hence the int() coercion above.
        target_layer = model.blocks[block_index]
    except IndexError as err:
        raise gr.Error(
            f"Block number must be between 0 and {len(model.blocks) - 1}."
        ) from err

    resized = np.float32(cv2.resize(image, (224, 224))) / 255
    input_tensor = prepare_input(resized)

    # NOTE(review): a new GradCam is built on every call. If it registers hooks
    # on target_layer without removing them, the hooks accumulate on the shared
    # global model. Verify in src/gradcams.
    grad_cam = GradCam(model=model, target=target_layer)
    mask = grad_cam(input_tensor)
    result = gen_cam(image=input_tensor, mask=mask)
    return Image.fromarray(result)


# Build the Gradio UI: a Blocks app with an "Image Processing" tab and a "README" tab
with gr.Blocks(
            title="AttnViz",
    ) as demo:
    with gr.Tab("Image Processing"):
        # Inputs: the picture to explain and the index of the ViT block to target
        image_input = gr.Image(label="Input Image", type='numpy')
        number_input = gr.Number(
            label="Number",
            value=2,        # same default block as attn_viz; avoids sending None
            precision=0,    # deliver an int; it is used as an index into model.blocks
            minimum=0,
            maximum=11,
            show_label=True,
        )
        # Output: the Grad-CAM overlay returned by attn_viz
        image_output = gr.Image(label="Output Image")
        # Run attn_viz on the current inputs when the button is clicked
        process_button = gr.Button("Process Image")
        process_button.click(attn_viz, inputs=[image_input, number_input], outputs=image_output)

        # Clickable presets; the image paths are relative to the working directory
        gr.Examples(
            examples=[
                ["samples/mr_bean.png", 1],
                ["samples/sectional-sofa.png", 8],
            ],
            inputs=[image_input, number_input],
        )


    with gr.Tab("README"):
        # Render the project README. Open read-only ("r+" needlessly required write
        # permission) and as UTF-8 so non-ASCII text decodes on every platform.
        with open("README.md", "r", encoding="utf-8") as file:
            readme_content = file.read()
        gr.Markdown(readme_content)



if __name__ == "__main__":
    # Serve locally only (no public share link) and surface exceptions raised
    # by callbacks in the browser UI.
    demo.launch(share=False, show_error=True)