Aleš Sršeň committed
Commit b732563 · 1 Parent(s): 81e888b

feat: Initial commit of app.py with the gradio interface

Files changed (2):
  1. app.py +214 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,214 @@
+ # %%
+ # %pip install gradio diffusers
+
+ # %%
+ import gradio as gr
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import cv2
+ import torch
+ from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
+ import io
+ from diffusers import StableDiffusionPipeline
+
+ # %%
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load BLIP model and processor
+ model_name = "Salesforce/blip-image-captioning-large"
+ blip_processor = BlipProcessor.from_pretrained(model_name)
+ blip_model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
+ blip_model.config.vision_config.output_attentions = True
+
+ # Load Stable Diffusion model
+ diffusion_model_name = "CompVis/stable-diffusion-v1-4"
+ diffusion_pipeline = StableDiffusionPipeline.from_pretrained(diffusion_model_name).to(device)
+
+ # Load smol model
+ smol_model_name = "Michaelj1/INSTRUCT_smolLM2-360M-finetuned-wikitext2-raw-v1"
+ tokenizer = AutoTokenizer.from_pretrained(smol_model_name)
+ smol_model = AutoModelForCausalLM.from_pretrained(smol_model_name).to(device)
+
+ # %%
+ def generate_caption(image):
+     inputs = blip_processor(images=image, return_tensors="pt").to(device)
+     caption_ids = blip_model.generate(**inputs, max_new_tokens=50)
+     caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
+     return caption, inputs
+
+ def generate_gradcam(image, inputs):
+     with torch.no_grad():
+         vision_outputs = blip_model.vision_model(**inputs)
+     attentions = vision_outputs.attentions
+     last_layer_attentions = attentions[-1]
+     avg_attention = last_layer_attentions.mean(dim=1)
+     cls_attention = avg_attention[:, 0, 1:]
+     num_patches = cls_attention.shape[-1]
+     grid_size = int(np.sqrt(num_patches))
+     attention_map = cls_attention.cpu().numpy().reshape(grid_size, grid_size)
+     attention_map = cv2.resize(attention_map, (image.size[0], image.size[1]))
+     attention_map = attention_map - np.min(attention_map)
+     attention_map = attention_map / np.max(attention_map)
+     img_np = np.array(image)
+     heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
+     heatmap = np.float32(heatmap) / 255
+     cam = heatmap + np.float32(img_np) / 255
+     cam = cam / np.max(cam)
+     cam_image = np.uint8(255 * cam)
+     return cam_image
+
+
+ def generate_image_from_caption(caption):
+     image = diffusion_pipeline(caption).images[0]
+     return image
+
+
+ def explain_word(word):
+     messages = [{"role": "user", "content": f"Explain the word '{word}' in detail."}]
+     input_text = tokenizer.apply_chat_template(messages, tokenize=False)
+     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+     outputs = smol_model.generate(
+         inputs,
+         max_new_tokens=150,
+         temperature=0.9,
+         top_p=0.95,
+         do_sample=True
+     )
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     lines = generated_text.split('\n')
+     assistant_response = []
+     collect = False
+     for line in lines:
+         line = line.strip()
+         if line.lower() == 'assistant':
+             collect = True
+             continue
+         elif line.lower() in ['system', 'user']:
+             collect = False
+         if collect and line:
+             assistant_response.append(line)
+     explanation = '\n'.join(assistant_response).strip()
+     return explanation
+
+ def get_caption_self_attention(caption):
+     text_inputs = blip_processor.tokenizer(
+         caption,
+         return_tensors="pt",
+         add_special_tokens=True
+     ).to(device)
+
+     with torch.no_grad():
+         outputs = blip_model.text_decoder(
+             input_ids=text_inputs.input_ids,
+             attention_mask=text_inputs.attention_mask,
+             output_attentions=True,
+             return_dict=True,
+         )
+     decoder_attentions = outputs.attentions
+     return decoder_attentions, text_inputs
+
+
+ def generate_self_attention(decoder_attentions, text_inputs):
+     last_layer_attentions = decoder_attentions[-1]
+     avg_attentions = last_layer_attentions.mean(dim=1)
+     attentions = avg_attentions[0].cpu().numpy()
+     tokens = blip_processor.tokenizer.convert_ids_to_tokens(text_inputs.input_ids[0])
+     cls_token = blip_processor.tokenizer.cls_token or "[CLS]"
+     sep_token = blip_processor.tokenizer.sep_token or "[SEP]"
+     special_token_indices = [idx for idx, token in enumerate(tokens) if token in [cls_token, sep_token]]
+     mask = np.ones(len(tokens), dtype=bool)
+     mask[special_token_indices] = False
+     filtered_tokens = [token for idx, token in enumerate(tokens) if mask[idx]]
+     filtered_attentions = attentions[mask, :][:, mask]
+     return filtered_tokens, filtered_attentions
+
+ def process_image(image):
+     # Ensure input is in the correct format
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+     caption, inputs = generate_caption(image)
+     cam_image = generate_gradcam(image, inputs)
+     diffusion_image = generate_image_from_caption(caption)
+     decoder_attentions, text_inputs = get_caption_self_attention(caption)
+     filtered_tokens, filtered_attentions = generate_self_attention(decoder_attentions, text_inputs)
+
+     # Create visualization grid
+     fig, axs = plt.subplots(2, 2, figsize=(18, 18))
+
+     axs[0][0].imshow(image)
+     axs[0][0].axis('off')
+     axs[0][0].set_title('Original Image')
+
+     axs[0][1].imshow(cam_image)
+     axs[0][1].axis('off')
+     axs[0][1].set_title('Grad-CAM Overlay')
+
+     axs[1][0].imshow(diffusion_image)
+     axs[1][0].axis('off')
+     axs[1][0].set_title('Generated Image (Stable Diffusion)')
+
+     ax = axs[1][1]
+     im = ax.imshow(filtered_attentions, cmap='viridis')
+     ax.set_xticks(range(len(filtered_tokens)))
+     ax.set_yticks(range(len(filtered_tokens)))
+     ax.set_xticklabels(filtered_tokens, rotation=90, fontsize=8)
+     ax.set_yticklabels(filtered_tokens, fontsize=8)
+     ax.set_title('Caption Self-Attention')
+     plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+
+     plt.tight_layout()
+
+     # Save visualization to a buffer for display
+     buffer = io.BytesIO()
+     plt.savefig(buffer, format='png')
+     plt.close(fig)
+     buffer.seek(0)
+     visualization_image = Image.open(buffer)
+
+     # Generate word options for dropdown
+     words = caption.split()
+     return caption, visualization_image, gr.Dropdown(label="Select a Word from Caption", choices=words, interactive=True)
+
+
+ def get_word_explanation(word):
+     explanation = explain_word(word)
+     return f"Explanation for '{word}':\n\n{explanation}"
+
+ # %%
+ # Define Gradio interface
+ with gr.Blocks() as interface:
+     gr.Markdown("# Image Captioning and Visualization with Word Explanation")
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Upload an Image")
+             process_button = gr.Button("Process Image")
+         with gr.Column():
+             caption_output = gr.Textbox(label="Generated Caption")
+             visualization_output = gr.Image(type="pil", label="Visualization (Original, Grad-CAM, Stable Diffusion)")
+
+     word_dropdown = gr.Dropdown(label="Select a Word from Caption", choices=[], interactive=True)
+     word_explanation = gr.Textbox(label="Word Explanation")
+
+     # Bind functions to components
+     process_button.click(
+         process_image,
+         inputs=image_input,
+         outputs=[caption_output, visualization_output, word_dropdown]
+     )
+
+     word_dropdown.change(
+         get_word_explanation,
+         inputs=word_dropdown,
+         outputs=word_explanation
+     )
+
+ # %%
+ # Launch the Gradio app
+ interface.launch()
+
+ # %%
+
+
+
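
Because app.py is organized as "# %%" notebook cells, the caption/visualization pipeline can also be exercised without the Gradio UI: run every cell except the final interface.launch() cell, then call process_image directly. A minimal smoke-test sketch (not part of the commit), assuming a local test image named sample.jpg:

    # Hypothetical smoke test: assumes the model-loading and function cells above have
    # already run, and that "sample.jpg" exists in the working directory.
    from PIL import Image

    test_image = Image.open("sample.jpg").convert("RGB")
    caption, visualization, word_dropdown = process_image(test_image)
    print("Caption:", caption)
    visualization.save("visualization_grid.png")  # the 2x2 matplotlib grid returned as a PIL image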
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ diffusers==0.31.0
+ numpy==1.26.4
+ opencv-python==4.10.0.84
+ pillow==10.3.0
+ matplotlib==3.7.5
+ transformers==4.45.1
+ torchvision==0.19.0
+ torch==2.4.0
+ torchaudio==2.4.0
+ jupyterlab==4.2.5
+ gradio==5.9.1
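
After installing these pins with pip install -r requirements.txt, a quick sanity check that the environment matches what app.py imports can look like the following sketch (hypothetical helper, not part of the commit):

    # Hypothetical version check: compares a few installed packages against the pins above.
    import importlib.metadata as metadata

    for package in ["torch", "transformers", "diffusers", "gradio"]:
        print(f"{package}=={metadata.version(package)}")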