Aleš Sršeň committed
Commit b732563 · Parent(s): 81e888b
feat: Initial commit of app.py with the gradio interface
Browse files
- app.py +214 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,214 @@
# %%
# %pip install gradio diffusers

# %%
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import io
from diffusers import StableDiffusionPipeline

# %%
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load BLIP model and processor
model_name = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(model_name)
blip_model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
blip_model.config.vision_config.output_attentions = True

# Load Stable Diffusion model
diffusion_model_name = "CompVis/stable-diffusion-v1-4"
diffusion_pipeline = StableDiffusionPipeline.from_pretrained(diffusion_model_name).to(device)

# Load smol model
smol_model_name = "Michaelj1/INSTRUCT_smolLM2-360M-finetuned-wikitext2-raw-v1"
tokenizer = AutoTokenizer.from_pretrained(smol_model_name)
smol_model = AutoModelForCausalLM.from_pretrained(smol_model_name).to(device)

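A minimal sketch of an optional variant that is not part of this commit: loading the Stable Diffusion pipeline in half precision to cut GPU memory, assuming a CUDA device is available.

# Optional sketch: float16 weights roughly halve the pipeline's GPU memory footprint.
if device == "cuda":
    diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
        diffusion_model_name, torch_dtype=torch.float16
    ).to(device)
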
# %%
def generate_caption(image):
    inputs = blip_processor(images=image, return_tensors="pt").to(device)
    caption_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption, inputs

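For a quick check outside Gradio, the captioner can be exercised directly on a PIL image; a small sketch where "example.jpg" is only a placeholder path.

# Standalone usage sketch; "example.jpg" is a hypothetical file name.
test_image = Image.open("example.jpg").convert("RGB")
test_caption, _ = generate_caption(test_image)
print(test_caption)
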
def generate_gradcam(image, inputs):
    # Attention-based saliency from the ViT encoder: average the last layer's heads
    # and take the CLS token's attention over the image patches.
    with torch.no_grad():
        vision_outputs = blip_model.vision_model(**inputs)
    attentions = vision_outputs.attentions
    last_layer_attentions = attentions[-1]
    avg_attention = last_layer_attentions.mean(dim=1)
    cls_attention = avg_attention[:, 0, 1:]
    num_patches = cls_attention.shape[-1]
    grid_size = int(np.sqrt(num_patches))
    attention_map = cls_attention.cpu().numpy().reshape(grid_size, grid_size)
    attention_map = cv2.resize(attention_map, (image.size[0], image.size[1]))
    attention_map = attention_map - np.min(attention_map)
    attention_map = attention_map / np.max(attention_map)
    img_np = np.array(image)
    heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
    # cv2.applyColorMap returns BGR; convert to RGB so matplotlib renders the overlay correctly.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img_np) / 255
    cam = cam / np.max(cam)
    cam_image = np.uint8(255 * cam)
    return cam_image


def generate_image_from_caption(caption):
    image = diffusion_pipeline(caption).images[0]
    return image


def explain_word(word):
    messages = [{"role": "user", "content": f"Explain the word '{word}' in detail."}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    outputs = smol_model.generate(
        inputs,
        max_new_tokens=150,
        temperature=0.9,
        top_p=0.95,
        do_sample=True
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant's part of the decoded chat transcript
    lines = generated_text.split('\n')
    assistant_response = []
    collect = False
    for line in lines:
        line = line.strip()
        if line.lower() == 'assistant':
            collect = True
            continue
        elif line.lower() in ['system', 'user']:
            collect = False
        if collect and line:
            assistant_response.append(line)
    explanation = '\n'.join(assistant_response).strip()
    return explanation

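The parsing above assumes the 'assistant'/'user' role markers survive decoding. A common alternative, sketched here under the assumption that the SmolLM2 tokenizer ships a chat template (the explain_word_v2 name is only illustrative, not from the commit), is to append the generation prompt and decode just the newly generated tokens.

def explain_word_v2(word):
    # Sketch of an alternative to the line-by-line role parsing in explain_word.
    messages = [{"role": "user", "content": f"Explain the word '{word}' in detail."}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output_ids = smol_model.generate(
        input_ids, max_new_tokens=150, temperature=0.9, top_p=0.95, do_sample=True
    )
    # Decode only the tokens produced after the prompt.
    return tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
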
def get_caption_self_attention(caption):
    text_inputs = blip_processor.tokenizer(
        caption,
        return_tensors="pt",
        add_special_tokens=True
    ).to(device)

    with torch.no_grad():
        outputs = blip_model.text_decoder(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            output_attentions=True,
            return_dict=True,
        )
    decoder_attentions = outputs.attentions
    return decoder_attentions, text_inputs


def generate_self_attention(decoder_attentions, text_inputs):
    last_layer_attentions = decoder_attentions[-1]
    avg_attentions = last_layer_attentions.mean(dim=1)
    attentions = avg_attentions[0].cpu().numpy()
    tokens = blip_processor.tokenizer.convert_ids_to_tokens(text_inputs.input_ids[0])
    cls_token = blip_processor.tokenizer.cls_token or "[CLS]"
    sep_token = blip_processor.tokenizer.sep_token or "[SEP]"
    special_token_indices = [idx for idx, token in enumerate(tokens) if token in [cls_token, sep_token]]
    mask = np.ones(len(tokens), dtype=bool)
    mask[special_token_indices] = False
    filtered_tokens = [token for idx, token in enumerate(tokens) if mask[idx]]
    filtered_attentions = attentions[mask, :][:, mask]
    return filtered_tokens, filtered_attentions

def process_image(image):
    # Ensure input is in the correct format
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")  # keep three channels so the heatmap overlay and diffusion input line up
    caption, inputs = generate_caption(image)
    cam_image = generate_gradcam(image, inputs)
    diffusion_image = generate_image_from_caption(caption)
    decoder_attentions, text_inputs = get_caption_self_attention(caption)
    filtered_tokens, filtered_attentions = generate_self_attention(decoder_attentions, text_inputs)

    # Create visualization grid
    fig, axs = plt.subplots(2, 2, figsize=(18, 18))

    axs[0][0].imshow(image)
    axs[0][0].axis('off')
    axs[0][0].set_title('Original Image')

    axs[0][1].imshow(cam_image)
    axs[0][1].axis('off')
    axs[0][1].set_title('Grad-CAM Overlay')

    axs[1][0].imshow(diffusion_image)
    axs[1][0].axis('off')
    axs[1][0].set_title('Generated Image (Stable Diffusion)')

    ax = axs[1][1]
    im = ax.imshow(filtered_attentions, cmap='viridis')
    ax.set_xticks(range(len(filtered_tokens)))
    ax.set_yticks(range(len(filtered_tokens)))
    ax.set_xticklabels(filtered_tokens, rotation=90, fontsize=8)
    ax.set_yticklabels(filtered_tokens, fontsize=8)
    ax.set_title('Caption Self-Attention')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()

    # Save visualization to a buffer for display
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png')
    plt.close(fig)
    buffer.seek(0)
    visualization_image = Image.open(buffer)

    # Generate word options for dropdown
    words = caption.split()
    return caption, visualization_image, gr.Dropdown(label="Select a Word from Caption", choices=words, interactive=True)


def get_word_explanation(word):
    explanation = explain_word(word)
    return f"Explanation for '{word}':\n\n{explanation}"

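For debugging outside the UI, process_image can be called directly on a PIL image; a small sketch where the file names are placeholders, not from the commit.

# Hypothetical standalone check; "photo.jpg" and "visualization.png" are placeholder paths.
debug_caption, debug_visualization, _ = process_image(Image.open("photo.jpg"))
debug_visualization.save("visualization.png")
print(debug_caption)
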
# %%
# Define Gradio interface
with gr.Blocks() as interface:
    gr.Markdown("# Image Captioning and Visualization with Word Explanation")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image")
            process_button = gr.Button("Process Image")
        with gr.Column():
            caption_output = gr.Textbox(label="Generated Caption")
            visualization_output = gr.Image(type="pil", label="Visualization (Original, Grad-CAM, Stable Diffusion)")

    word_dropdown = gr.Dropdown(label="Select a Word from Caption", choices=[], interactive=True)
    word_explanation = gr.Textbox(label="Word Explanation")

    # Bind functions to components
    process_button.click(
        process_image,
        inputs=image_input,
        outputs=[caption_output, visualization_output, word_dropdown]
    )

    word_dropdown.change(
        get_word_explanation,
        inputs=word_dropdown,
        outputs=word_explanation
    )

# %%
# Launch the Gradio app
interface.launch()

# %%

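When running app.py outside a hosted Space, the launch call is where sharing and network options would go; a sketch of common alternatives using standard Blocks.launch arguments (shown as comments, not part of the committed file).

# Sketch: alternative launch configurations.
# interface.launch(share=True)                                 # temporary public link
# interface.launch(server_name="0.0.0.0", server_port=7860)    # reachable on the local network
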
requirements.txt
ADDED
@@ -0,0 +1,11 @@
diffusers==0.31.0
numpy==1.26.4
opencv-python==4.10.0.84
pillow==10.3.0
matplotlib==3.7.5
transformers==4.45.1
torchvision==0.19.0
torch==2.4.0
torchaudio==2.4.0
jupyterlab==4.2.5
gradio==5.9.1