from typing import List, Tuple, Dict

import gradio as gr
import torch
import torch.nn.functional as F
import timm
from timm.data import create_transform
from timm.models import create_model
from timm.utils import AttentionExtract
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks."""
    all_models = timm.list_pretrained()
    # FIXME Focusing on ViT models for initial impl
    attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')])]
    return attention_models
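
# Hedged note: timm.list_pretrained() returns model names with pretrained weight
# tags, e.g. 'vit_base_patch16_224.augreg_in21k'; the prefix filter above keeps
# ViT-family architectures only, per the FIXME note.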

def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
    """Load a model from timm and prepare it for attention extraction."""
    # Disable fused attention so the explicit softmax node is visible to the extractor
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    # 'hooks' is an alternative to 'fx'; matching names for attention nodes or
    # modules could also be exposed as options.
    extractor = AttentionExtract(model, method='fx')
    return model, extractor
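
# Hedged note: calling the extractor on a batch returns a dict mapping attention
# node names to tensors; for ViT-style models the keys typically look like
# 'blocks.0.attn.softmax' and each tensor is shaped (batch, num_heads, seq_len,
# seq_len), though exact node names depend on the model.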

def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtract) -> Dict[str, torch.Tensor]:
    """Process the input image and return the extracted attention maps."""
    # Build the preprocessing transform from the model's pretrained config
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )

    # Preprocess the image
    tensor = transform(image).unsqueeze(0)

    # Extract attention maps (no gradients are needed for visualization)
    with torch.no_grad():
        attention_maps = extractor(tensor)

    return attention_maps
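
# A minimal usage sketch (illustrative; the model name and file path are examples):
#
#   model, extractor = load_model('vit_base_patch16_224')
#   maps = process_image(Image.open('example.jpg').convert('RGB'), model, extractor)
#   for name, attn in maps.items():
#       print(name, tuple(attn.shape))  # e.g. (1, 12, 197, 197): 196 patches + CLS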

def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Overlay a [0, 1] mask on an RGB image as a colored, alpha-blended highlight."""
    # Broadcast the 2D mask across the three color channels
    mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

    # Convert color to a numpy array for broadcasting
    color = np.array(color)

    # Blend: attenuate the original where the mask is strong and add the highlight color
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)
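
# Worked example: with alpha=0.5 and color=(1, 0, 0), a pixel where mask == 1
# blends to 0.5 * original + 0.5 * (255, 0, 0), while a pixel where mask == 0
# is left unchanged.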

def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
    """Visualize attention maps for the given image and model."""
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 0)
    
    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create visualizations
    visualizations = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
        
        # Take the CLS token's attention to the patch tokens, averaged over heads
        attn_map = attn_map[0, :, 0, num_prefix_tokens:].mean(0)  # Shape: (seq_len - num_prefix_tokens,)

        # Reshape the attention map to a 2D grid (assumes a square patch grid)
        num_patches = int(np.sqrt(attn_map.shape[0]))
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match the image size (attn_map is already a tensor;
        # re-wrapping it with torch.tensor() would trigger a copy-construct warning)
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().numpy()

        # Normalize the attention map to [0, 1]; the epsilon guards against a constant map
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        plt.tight_layout()

        # Convert the plot to a PIL image; buffer_rgba() is used because
        # tostring_rgb() is deprecated/removed in recent matplotlib releases
        fig.canvas.draw()
        vis_image = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')
        visualizations.append(vis_image)
        plt.close(fig)

    return visualizations

# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model")
    ],
    outputs=gr.Gallery(label="Attention Maps"),
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps."
)

iface.launch(debug=True)