Commit · 197f827
1 Parent(s): 8b649e6

init gradio
Browse files
- LICENSE +24 -0
- Plybooks.ipynb +199 -0
- README.md +60 -14
- app.py +96 -0
- assets/attention-part.png +0 -0
- assets/embedding.png +0 -0
- assets/patches.png +0 -0
- assets/vit.png +0 -0
- requirements.txt +13 -0
- samples/mr_bean.png +0 -0
- samples/sectional-sofa.png +0 -0
- src/__pycache__/gradcams.cpython-311.pyc +0 -0
- src/datamodule.py +51 -0
- src/gradcams.py +61 -0
- src/old.py.old +194 -0
- src/vit.py +262 -0
LICENSE
ADDED
@@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.

Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.

In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

For more information, please refer to <https://unlicense.org>
Plybooks.ipynb
ADDED
@@ -0,0 +1,199 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightning as pl \n",
    "from src.datamodule import CIFAR10DataModule\n",
    "from src.vit import ViTLightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(max_epochs=15,accelerator='auto',reload_dataloaders_every_n_epochs=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ViTLightning()\n",
    "dm = CIFAR10DataModule()\n",
    "dm.setup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA GeForce RTX 4050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\n",
      "  | Name      | Type               | Params | Mode \n",
      "---------------------------------------------------------\n",
      "0 | vit       | ViT                | 154 K  | train\n",
      "1 | train_acc | MulticlassAccuracy | 0      | train\n",
      "2 | val_acc   | MulticlassAccuracy | 0      | train\n",
      "3 | test_acc  | MulticlassAccuracy | 0      | train\n",
      "---------------------------------------------------------\n",
      "154 K     Trainable params\n",
      "0         Non-trainable params\n",
      "154 K     Total params\n",
      "0.616     Total estimated model params size (MB)\n",
      "37        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|██████████| 1407/1407 [00:28<00:00, 49.90it/s, v_num=0, train_loss=0.518, train_acc=0.875, val_loss=0.996, val_acc=0.644]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=15` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|██████████| 1407/1407 [00:28<00:00, 49.87it/s, v_num=0, train_loss=0.518, train_acc=0.875, val_loss=0.996, val_acc=0.644]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(datamodule=dm,model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 163.27it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_acc </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.6284000277519226 </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 1.0169780254364014 </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m val_acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6284000277519226 \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0169780254364014 \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[{'val_loss': 1.0169780254364014, 'val_acc': 0.6284000277519226}]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.validate(model,dm)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
README.md
CHANGED
@@ -1,14 +1,60 @@
# ViT
- GitHub source repo⭐:: [VitCiFar](https://github.com/Muthukamalan/VitCiFar)

As we all know, the Transformer architecture has taken the world by storm.

In this repo, I practised implementing it for vision, from scratch. Transformers are data hungry, so don't compare this directly with a CNN (it's not an apples-to-apples comparison here).


#### Model
<div align='center'><img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/vit.png" width=500 height=300></div>


**Patches**
```python
nn.Conv2d(
    in_chans,
    emb_dim,
    kernel_size = patch_size,
    stride = patch_size
)
```
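
To see what this strided convolution does to a CIFAR-10 image, trace the shapes: a 32×32 input with `patch_size=4` yields an 8×8 grid of patch embeddings, i.e. 64 tokens of length `emb_dim`. A minimal sketch (the concrete sizes are illustrative and follow this repo's CIFAR-10 setup; the variable names are mine):

```python
import torch
import torch.nn as nn

in_chans, emb_dim, patch_size = 3, 48, 4       # CIFAR-10 channels, token width, 4x4 patches
to_patches = nn.Conv2d(in_chans, emb_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(1, in_chans, 32, 32)           # one CIFAR-10-sized image
feat = to_patches(x)                           # [1, 48, 8, 8] -> an 8x8 grid of patch embeddings
tokens = feat.flatten(2).transpose(1, 2)       # [1, 64, 48]   -> 64 patch tokens for the encoder
print(tokens.shape)
```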
<div align='center'>
<img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/patches.png" width=500 height=300 style="display:inline-block; margin-right: 10px;" alt="patches">
<img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/embedding.png" width=500 height=300 style="display:inline-block;">
</div>


> [!NOTE] CAUSAL MASK
> Unlike with words, we don't use a causal mask here.


<!-- <div align='center'><img src="assets/attention-part.png" width=300 height=500 style="display:inline-block; margin-right: 10px;"></div> -->
<p align="center">
  <img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/attention-part.png" alt="Attention Visualization" />
</p>


At the final projection layer there are two options:
- pool (combine) the token outputs and project the pooled vector to the prediction layer, or
- add one extra token before the transformer blocks, then pick that token and pass it to the projection layer (like `BERT` does) << ViT chooses this

```python

# Transformer Encoder
xformer_out = self.enc(out)                 # [batch, 65, 384]
if self.is_cls_token:
    token_out = xformer_out[:,0]            # [batch, 384]
else:
    token_out = xformer_out.mean(1)

# MLP Head
projection_out = self.mlp_head(token_out)   # [batch, 10]

```
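End to end, the `ViT` in `src/vit.py` maps a batch of CIFAR-10 images straight to 10 class logits (the shape comments above assume 8 patches per side and a 384-dim embedding; the token count and width scale with the constructor arguments). A quick sanity-check sketch, assuming the configuration that `ViTLightning` passes in this commit:

```python
import torch
from src.vit import ViT

model = ViT(in_c=3, num_classes=10, img_size=32, num_patch_1d=16,
            num_enc_layers=2, hidden_dim=96, mlp_hidden_dim=64,
            num_head=8, is_cls_token=True)

x = torch.randn(4, 3, 32, 32)    # a batch of four CIFAR-10-sized images
logits = model(x)                # CLS token -> MLP head -> class scores
print(logits.shape)              # torch.Size([4, 10])
```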


#### Context Grad-CAM
[Xplain AI](https://github.com/jacobgil/pytorch-grad-cam)

- register_forward_hook:: the hook is executed during the forward pass of the model, which lets us capture the target layer's activations (see the sketch below)
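A minimal sketch of that idea, with hypothetical layer and variable names (`src/gradcams.py` applies the same mechanism to a block of a timm ViT):

```python
import torch
import torch.nn as nn

captured = {}

def save_activation(module, inputs, output):
    # runs every time `module` is called during a forward pass
    captured["features"] = output.detach()

layer = nn.Linear(8, 4)                        # stand-in for the target ViT block
handle = layer.register_forward_hook(save_activation)

_ = layer(torch.randn(2, 8))                   # the forward pass triggers the hook
print(captured["features"].shape)              # torch.Size([2, 4])
handle.remove()                                # detach the hook when done
```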
app.py
ADDED
@@ -0,0 +1,96 @@
############################
#
# Imports
#
############################
import timm
import torch
from skimage import io
from src.gradcams import GradCam
import numpy as np
import cv2
import gradio as gr
from PIL import Image



############################
#
# model
#
############################
model:torch.nn.Module = timm.create_model("vit_small_patch16_224",pretrained=True) # num_classes=10
model.eval()

############################
#
# utility functions
#
############################
def prepare_input(image:np.ndarray)->torch.Tensor:
    image = image.copy()                          # (H,W,C)
    mean = np.array([0.5,.5,.5])
    stds = np.array([.5,.5,.5])
    image -= mean
    image /= stds

    image = np.ascontiguousarray(np.transpose(image,(2,0,1)))  # transpose the image to match the model's input format (C,H,W)
    image = image[np.newaxis,...]                 # (bs, C, H, W)
    return torch.tensor(image,requires_grad=True)


def gen_cam(image, mask):
    # create a heatmap from the Grad-CAM mask
    heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap)/255.
    # superimpose the heatmap on the original image
    cam = (.5*heatmap) + (.5*image.squeeze(0).permute(1,2,0).detach().cpu().numpy())
    # normalize
    cam = cam / np.max(cam)
    return np.uint8(255*cam)



def attn_viz(image, number:int=2):
    image = np.float32(cv2.resize(image,(224,224))) / 255
    image = prepare_input(image)

    target_layer = model.blocks[int(number)]      # gr.Number passes a float; the block index must be an int
    grad_cam = GradCam(model=model,target=target_layer)
    mask = grad_cam(image)
    result = gen_cam(image=image,mask=mask)
    return Image.fromarray(result)


# Create a Gradio Blocks app with two tabs
with gr.Blocks(
    title="AttnViz",
) as demo:
    with gr.Tab("Image Processing"):
        # Create an image input and a number input
        image_input = gr.Image(label="Input Image",type='numpy')
        number_input = gr.Number(label="Number",minimum=0,maximum=11,show_label=True)
        # Create an image output
        image_output = gr.Image(label="Output Image")
        # Set up the event listener for the image processing function
        process_button = gr.Button("Process Image")
        process_button.click(attn_viz, inputs=[image_input, number_input], outputs=image_output)

        gr.Examples(
            examples=[
                ["samples/mr_bean.png", 1],
                ["samples/sectional-sofa.png", 8],
            ],
            inputs=[image_input, number_input],
        )


    with gr.Tab("README"):
        # Add a simple text description in the README tab
        with open("README.md", "r") as file: readme_content = file.read()
        gr.Markdown(readme_content)



if __name__=='__main__':
    demo.launch(show_error=True, share=False)
assets/attention-part.png
ADDED
assets/embedding.png
ADDED
assets/patches.png
ADDED
assets/vit.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,13 @@
Pillow==10.1.0
opencv-python==4.8.1.78
opencv-python-headless==4.8.1.78
torch==2.5.0
pytorch-gradcam==0.2.1
torchvision==0.20.0
timm==1.0.9
gradio==4.44.1
gradio_client==1.3.0
lightning==2.4.0
lightning-utilities==0.11.6
pytorch-lightning==2.4.0
numpy==1.26.1
samples/mr_bean.png
ADDED
samples/sectional-sofa.png
ADDED
src/__pycache__/gradcams.cpython-311.pyc
ADDED
Binary file (4.17 kB)
src/datamodule.py
ADDED
@@ -0,0 +1,51 @@
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import lightning as pl

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = r'/home/muthu/GitHub/DATA 📁/CIFAR', batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Define data transforms for train, validation and test
        self.transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.Resize((32,32)),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        ])

        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        ])

    def prepare_data(self):
        # Download CIFAR-10 dataset
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Split dataset for training, validation and test
        if stage == 'fit' or stage is None:
            full_train_dataset = datasets.CIFAR10(root=self.data_dir, train=True, transform=self.transform_train)
            self.train_dataset, self.val_dataset = random_split(full_train_dataset, [45000, 5000])

        if stage == 'test' or stage is None:
            self.test_dataset = datasets.CIFAR10(root=self.data_dir, train=False, transform=self.transform_test)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
src/gradcams.py
ADDED
@@ -0,0 +1,61 @@
import cv2           # OpenCV for image processing
import numpy as np   # NumPy for numerical operations

class GradCam:
    def __init__(self, model, target):
        self.model = model.eval()   # Set the model to evaluation mode
        self.feature = None         # To store the features from the target layer
        self.gradient = None        # To store the gradients from the target layer
        self.handlers = []          # List to keep track of hooks
        self.target = target        # Target layer for Grad-CAM
        self._get_hook()            # Register hooks to the target layer

    # Hook to get features from the forward pass
    def _get_features_hook(self, module, input, output):
        self.feature = self.reshape_transform(output)   # Store and reshape the output features

    # Hook to get gradients from the backward pass
    def _get_grads_hook(self, module, input_grad, output_grad):
        self.gradient = self.reshape_transform(output_grad)   # Store and reshape the output gradients

        def _store_grad(grad):
            self.gradient = self.reshape_transform(grad)   # Store gradients for later use

        output_grad.register_hook(_store_grad)   # Register hook to store gradients

    # Register forward hooks to the target layer
    def _get_hook(self):
        self.target.register_forward_hook(self._get_features_hook)
        self.target.register_forward_hook(self._get_grads_hook)

    # Function to reshape the tensor for visualization
    def reshape_transform(self, tensor, height=14, width=14):
        result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)   # Rearrange dimensions to (C, H, W)
        return result

    # Function to compute the Grad-CAM heatmap
    def __call__(self, inputs):
        self.model.zero_grad()        # Zero the gradients
        output = self.model(inputs)   # Forward pass

        # Get the index of the highest score in the output
        index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]     # Get the target score
        target.backward()             # Backward pass to compute gradients

        # Get the gradients and features
        gradient = self.gradient[0].cpu().data.numpy()
        weight = np.mean(gradient, axis=(1, 2))   # Average the gradients
        feature = self.feature[0].cpu().data.numpy()

        # Compute the weighted sum of the features
        cam = feature * weight[:, np.newaxis, np.newaxis]
        cam = np.sum(cam, axis=0)     # Sum over the channels
        cam = np.maximum(cam, 0)      # Apply ReLU to remove negative values

        # Normalize the heatmap
        cam -= np.min(cam)
        cam /= np.max(cam)
        cam = cv2.resize(cam, (224, 224))   # Resize to match the input image size
        return cam                    # Return the Grad-CAM heatmap
src/old.py.old
ADDED
@@ -0,0 +1,194 @@
import math
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
import torch
from typing import Dict,Optional,Tuple,Union
from dataclasses import dataclass

import lightning as pl
from torchmetrics import Accuracy
# @dataclass
# class ViTCfg:
#     image_size: int
#     patch_size: int
#     num_channels: int
#     model_dim: int
#     num_attn_heads:int
#     attn_dropout: int
#     d_ff: int
#     number_encoders:int
#     classification_heads:int


class PatchEmbedding(torch.nn.Module):
    def __init__(self, cfg:Dict) -> None:
        super().__init__()
        for k,v in cfg.items(): setattr(self,k,v)
        assert self.image_size % self.patch_size==0, "patch_size does not divide image_size evenly"
        self.num_patchs = (self.image_size // self.patch_size)**2
        self.img2flattn:torch.nn.Conv2d = torch.nn.Conv2d(
            in_channels = self.num_channels,
            out_channels=self.model_dim,
            kernel_size = self.patch_size,
            stride = self.patch_size,
            bias=False
        )
    def forward(self,x:torch.Tensor)->torch.Tensor:
        # (bs, 3, 32, 32) >> (bs, model_dim, img_size//patch_size, img_size//patch_size) >> (bs, model_dim, num_patches) >> (bs, num_patches, model_dim)
        return self.img2flattn(x).flatten(2).transpose(1,2)


class Embedding(torch.nn.Module):
    def __init__(self,cfg:Dict ) -> None:
        super().__init__()
        for k,v in cfg.items(): setattr(self,k,v)
        self.patch_embedding:PatchEmbedding = PatchEmbedding(cfg=cfg)

        # single [CLS] token
        self.cls_token:torch.nn.Parameter = torch.nn.Parameter( torch.randn(1,1, self.model_dim ) )

        self.position_embd:torch.nn.Parameter = torch.nn.Parameter(
            torch.randn( 1, int( (self.image_size // self.patch_size)**2 + 1), self.model_dim )
        )
    def forward(self,x:torch.Tensor)->torch.Tensor:
        x = self.patch_embedding(x)
        cls_token = self.cls_token.expand( x.shape[0], -1, -1 )
        x = torch.cat( (cls_token,x) , dim=1)
        x = x + self.position_embd
        return x


class AttentionBlock(torch.nn.Module):
    def __init__(self,cfg:Dict ) -> None:
        super().__init__()
        for k,v in cfg.items(): self.__setattr__(k,v)

        assert self.model_dim % self.num_attn_heads ==0, "model dim is not divisible by n heads"

        self.attn_layer:torch.nn.Linear = torch.nn.Linear(self.model_dim, 3*self.model_dim, bias=False)
        self.out :torch.nn.Linear = torch.nn.Linear(self.model_dim,self.model_dim,bias=False)

        self.attn_dropout:torch.nn.Dropout = torch.nn.Dropout()
        self.resid_dropout:torch.nn.Dropout= torch.nn.Dropout()

        # causal mask to ensure that attention is only applied to the left in the input seq
        # self.register_buffer('bias',tensor= torch.tril(torch.ones(self.block_size,self.block_size)).view(1, 1, self.block_size, self.block_size) )
        '''
        block_size=10
        [[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]

        # Batch-1, Seq-1, Mask-(10,10)
        '''

    def forward(self,x:torch.Tensor, attention_outputs:bool)->Tuple[torch.Tensor, Union[torch.Tensor,None]]:
        '''
        input (bs,seq_len,embedding_dim) >> output (bs,seq_len,embedding_dim)

        x     :: (bs,seq_len,embedding_dim)
        attn  :: (bs, seq_len, 3*embedding_dim)
        .split:: (bs, seq_len, 3*embedding_dim).split(embedding_dim,dim=2)
        # Each chunk (bs,seq_len,embedding_dim) is a view of the original tensor; splitting across embedding_dim gives 3 chunks

        k,q,v >> (bs, seq_len, n_heads, embedding_dim//n_heads) >> (bs, n_heads, seq_len, embedding_dim//n_heads)
        # Each head is responsible for a different context of the sequence
        '''
        B,T,C = x.size() # (bs, seq_len, embedding_dim)

        # calc q,k,v
        q:torch.Tensor
        k:torch.Tensor
        v:torch.Tensor
        q,k,v = self.attn_layer(x).split(split_size=self.model_dim,dim=2)
        q = q.view(B,T,self.num_attn_heads, C//self.num_attn_heads).transpose(1,2)
        k = k.view(B,T,self.num_attn_heads, C//self.num_attn_heads).transpose(1,2)
        v = v.view(B,T,self.num_attn_heads, C//self.num_attn_heads).transpose(1,2)


        attn = (q @ k.transpose(-2,-1)) * (1/math.sqrt(k.size(-1)))
        # attn = attn.masked_fill(self.bias[:,:,:T,:T]==0,float('-inf'))
        attn = torch.nn.functional.softmax(attn,dim=-1)
        attn = self.attn_dropout(attn)

        y:torch.Tensor = attn @ v # (bs, n_heads, T, T) @ (bs, n_heads, T, embedding_dim/n_heads) >> (bs, n_heads, seq_len, embedding_dim/n_heads)
        y:torch.Tensor = y.transpose(1,2).contiguous().view(B,T,C)

        return self.resid_dropout(self.out(y)), (attn if attention_outputs else None)



class MLP(torch.nn.Module):
    def __init__(self,cfg:Dict ) -> None:
        super().__init__()
        for k,v in cfg.items(): self.__setattr__(k,v)
        self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
        self.activation = torch.nn.ReLU()
        self.layernorm = torch.nn.LayerNorm(self.d_ff)
        self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
        self.dropout = torch.nn.Dropout(0.2)
    def forward(self,x:torch.Tensor)->torch.Tensor:
        return self.dropout( self.dense_2( self.layernorm(self.activation( self.dense_1(x) )) ) )


class EncoderBlock(torch.nn.Module):
    def __init__(self,cfg:Dict ) -> None:
        super().__init__()
        for k,v in cfg.items(): self.__setattr__(k,v)
        self.attn_block = AttentionBlock(cfg)
        self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
        self.mlp = MLP(cfg)
        self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)
    def forward(self,x:torch.Tensor, attention_outputs:bool)->Tuple[torch.Tensor, Union[torch.Tensor,None]]:
        # self-attention
        attention_op, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs )
        x = x + attention_op
        # FC
        mlp_output = self.mlp( self.layernorm_2(x) )
        x = x + mlp_output
        return x, (attn if attention_outputs==True else None) # Return the transformer block's output and the attention probabilities (optional)

class Encoder(torch.nn.Module):
    """
    The transformer encoder module.
    """
    def __init__(self,cfg:Dict ) -> None:
        super().__init__()
        for k,v in cfg.items(): self.__setattr__(k,v)
        # Create a list of transformer blocks
        self.blocks = torch.nn.ModuleList([])
        for _ in range(self.number_encoders):
            block = EncoderBlock(cfg)
            self.blocks.append(block)

    def forward(self,x:torch.Tensor,attention_outputs:bool):
        # Calculate the transformer block's output for each block
        all_attn = []
        for block in self.blocks:
            x,attn = block(x,attention_outputs=attention_outputs)
            all_attn.append(attn)
        # Return the encoder's output and the attention probabilities (optional)
        return x, (all_attn if attention_outputs==True else None)


class ViTClassifier(torch.nn.Module):
    def __init__(self, cfg:Dict ) -> None:
        super().__init__()
        for k,v in cfg.items(): self.__setattr__(k,v)
        self.embed:Embedding = Embedding(cfg)
        self.encoders:Encoder = Encoder(cfg=cfg)
        self.classifier:torch.nn.Linear = torch.nn.Linear(self.model_dim ,self.classification_heads,bias=False)

    def forward(self,x:torch.Tensor,attention_outputs=False):
        x = self.embed(x)
        x,attn = self.encoders(x,attention_outputs=attention_outputs)
        return self.classifier(x[:,0]), (attn if attention_outputs else None)
src/vit.py
ADDED
@@ -0,0 +1,262 @@
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module): # Done
    """
    img_size: 1d size of each image (32 for CIFAR-10)
    patch_size: 1d size of each patch (img_size/num_patch_1d, 4 in this experiment)
    in_chans: input channel (3 for RGB images)
    emb_dim: flattened length for each token (or patch)
    """
    def __init__(self, img_size:int, patch_size:int, in_chans:int=3, emb_dim:int=48):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans,
            emb_dim,
            kernel_size = patch_size,
            stride = patch_size
        )

    def forward(self, x):
        with torch.no_grad():   # note: no_grad here means the patch-projection conv itself is not updated during training
            # x: [batch, in_chans, img_size, img_size]
            x = self.proj(x) # [batch, emb_dim, # of patches in a row, # of patches in a col], [batch, 48, 8, 8] in this experiment
            x = x.flatten(2) # [batch, emb_dim, total # of patches], [batch, 48, 64] in this experiment
            x = x.transpose(1, 2) # [batch, total # of patches, emb_dim] => the Transformer encoder requires the dimensions [batch, number of words, word_emb_dim]
        return x


class TransformerEncoder(nn.Module): # Done
    def __init__(self, input_dim:int, mlp_hidden_dim:int, num_head:int=8, dropout:float=0.):
        # input_dim and head for Multi-Head Attention
        super(TransformerEncoder, self).__init__()
        self.norm1 = nn.LayerNorm(input_dim) # LayerNorm is the BatchNorm of NLP
        self.msa = MultiHeadSelfAttention(input_dim, n_heads=num_head)
        self.norm2 = nn.LayerNorm(input_dim)
        # Position-wise Feed-Forward Networks with GELU activation functions
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, input_dim),
            nn.GELU(),
        )

    def forward(self, x):
        out = self.msa(self.norm1(x)) + x # add residual connection
        out = self.mlp(self.norm2(out)) + out # add another residual connection
        return out


class MultiHeadSelfAttention(nn.Module):
    """
    dim: dimension of input and output per-token features (emb dim for tokens)
    n_heads: number of heads
    qkv_bias: whether to have bias in qkv linear layers
    attn_p: dropout probability for attention
    proj_p: dropout probability for the last linear layer
    scale: scaling factor for attention (1/sqrt(dk))
    qkv: initial linear layer for the query, key, and value
    proj: last linear layer
    attn_drop, proj_drop: dropout layers for attn and proj
    """
    def __init__(self, dim:int, n_heads:int=8, qkv_bias:bool=True, attn_p:float=0.01, proj_p:float=0.01):
        super(MultiHeadSelfAttention, self).__init__()
        self.n_heads = n_heads
        self.dim = dim # embedding dimension for input
        self.head_dim = dim // n_heads # d_q, d_k, d_v in the paper (int div needed to preserve input dim = output dim)
        self.scale = self.head_dim ** -0.5 # 1/sqrt(d_k)

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias) # lower linear layers in Figure 2 of the paper
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim) # upper linear layers in Figure 2 of the paper
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """
        Input and Output shape: [batch_size, n_patches + 1, dim]
        """
        batch_size, n_tokens, x_dim = x.shape # n_tokens = n_patches + 1 (1 is cls_token), x_dim is input dim

        # Sanity Check
        if x_dim != self.dim: # make sure input dim is the same as the concatenated dim (output dim)
            raise ValueError
        if self.dim != self.head_dim*self.n_heads: # make sure dim is divisible by n_heads
            raise ValueError(f"Input & Output dim should be divisible by Number of Heads")

        # Linear Layers for Query, Key, Value
        qkv = self.qkv(x) # (batch_size, n_patches+1, 3*dim)
        qkv = qkv.reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim) # (batch_size, n_patches+1, 3, n_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, n_heads, n_patches+1, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2] # (batch_size, n_heads, n_patches+1, head_dim)

        # Scaled Dot-Product Attention
        k_t = k.transpose(-2, -1) # K Transpose: (batch_size, n_heads, head_dim, n_patches+1)
        dot_product = (q @ k_t)*self.scale # Query, Key Dot Product with Scale Factor: (batch_size, n_heads, n_patches+1, n_patches+1)
        attn = dot_product.softmax(dim=-1) # Softmax: (batch_size, n_heads, n_patches+1, n_patches+1)
        attn = self.attn_drop(attn) # Attention Dropout: (batch_size, n_heads, n_patches+1, n_patches+1)
        weighted_avg = attn @ v # (batch_size, n_heads, n_patches+1, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2) # (batch_size, n_patches+1, n_heads, head_dim)

        # Concat and Last Linear Layer
        weighted_avg = weighted_avg.flatten(2) # Concat: (batch_size, n_patches+1, dim)
        x = self.proj(weighted_avg) # Last Linear Layer: (batch_size, n_patches+1, dim)
        x = self.proj_drop(x) # Last Linear Layer Dropout: (batch_size, n_patches+1, dim)

        return x

class ViT(nn.Module): # Done
    def __init__(
        self,
        in_c:int=3,
        num_classes:int=10,
        img_size:int=32,
        num_patch_1d:int=16,
        dropout:float=0.1,
        num_enc_layers:int=2,
        hidden_dim:int=128,
        mlp_hidden_dim:int=128//2,
        num_head:int=4,
        is_cls_token:bool=True
    ):
        super(ViT, self).__init__()
        """
        is_cls_token: are we using a class token?
        num_patch_1d: number of patches in one row (or col), 3 in Figure 1 of the paper, 8 in this experiment
        patch_size: 1d size (size of row or col) of each patch, 16 for ImageNet in the paper, 4 in this experiment
        flattened_patch_dim: flattened vec length for each patch (4 x 4 x 3, each side is 4 and 3 color channels), 48 in this experiment
        num_tokens: number of total patches + 1 (class token), 10 in Figure 1 of the paper, 65 in this experiment
        """
        self.is_cls_token = is_cls_token
        self.num_patch_1d = num_patch_1d
        self.patch_size = img_size//self.num_patch_1d
        num_tokens = (self.num_patch_1d**2)+1 if self.is_cls_token else (self.num_patch_1d**2)

        # Divide each image into patches
        self.images_to_patches = PatchEmbedding(
            img_size=img_size,
            patch_size=img_size//num_patch_1d,
            emb_dim=num_patch_1d*num_patch_1d
        )

        # Linear Projection of Flattened Patches
        self.lpfp = nn.Linear(num_patch_1d*num_patch_1d, hidden_dim) # 48 x 384 (384 is the latent vector size D in the paper)

        # Patch + Position Embedding (Learnable)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) if is_cls_token else None # learnable classification token with dim [1, 1, 384]. 1 in the 2nd dim because there is only one class per image, not per patch
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, hidden_dim)) # learnable positional embedding with dim [1, 65, 384]

        # Transformer Encoder
        enc_list = [TransformerEncoder(hidden_dim, mlp_hidden_dim=mlp_hidden_dim, dropout=dropout, num_head=num_head) for _ in range(num_enc_layers)] # num_enc_layers is L in the Transformer Encoder of Figure 1
        self.enc = nn.Sequential(*enc_list) # * should be added when a regular python list is given to nn.Sequential

        # MLP Head (Standard Classifier)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x): # x: [batch, 3, 32, 32]
        # Images into Patches (including flattening)
        out = self.images_to_patches(x) # [batch, 64, 48]
        # Linear Projection on Flattened Patches
        out = self.lpfp(out) # [batch, 64, 384]

        # Add Class Token and Positional Embedding
        if self.is_cls_token:
            out = torch.cat([self.cls_token.repeat(out.size(0),1,1), out], dim=1) # [batch, 65, 384], added as an extra learnable embedding
        out = out + self.pos_emb # [batch, 65, 384]

        # Transformer Encoder
        out = self.enc(out) # [batch, 65, 384]
        if self.is_cls_token:
            out = out[:,0] # [batch, 384]
        else:
            out = out.mean(1)

        # MLP Head
        out = self.mlp_head(out) # [batch, 10]
        return out



import lightning as pl
from torchmetrics import Accuracy

class ViTLightning(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super(ViTLightning, self).__init__()
        self.vit = ViT(
            in_c=3,
            num_classes=10,
            img_size=32,
            num_patch_1d=16,
            dropout=0.1,
            num_enc_layers=2,
            hidden_dim=96,
            mlp_hidden_dim=64,
            num_head=8,
            is_cls_token=True
        )
        self.train_acc = Accuracy('multiclass',num_classes=10)
        self.val_acc = Accuracy('multiclass',num_classes=10)
        self.test_acc = Accuracy('multiclass',num_classes=10)
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.vit(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        acc = self.train_acc(preds, y)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        self.log('train_acc', acc, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        acc = self.val_acc(preds, y)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        self.log('val_acc', acc, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        acc = self.test_acc(preds, y)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', acc, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.vit.parameters())
        num_epochs = self.trainer.max_epochs
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            total_steps=self.trainer.estimated_stepping_batches,
            epochs=num_epochs,
            pct_start= .3,
            div_factor= 100,
            max_lr= 1e-3,
            three_phase= False,
            final_div_factor= 100,
            anneal_strategy='linear'
        )
        return {
            'optimizer':optimizer,
            'lr_scheduler':{
                'scheduler':scheduler,
                'monitor': "val_loss",
                "interval":"step",
                "frequency":1
            }
        }