Muthukamalan committed on
Commit 197f827 · 1 Parent(s): 8b649e6

init gradio

LICENSE ADDED
@@ -0,0 +1,24 @@
+ This is free and unencumbered software released into the public domain.
+
+ Anyone is free to copy, modify, publish, use, compile, sell, or
+ distribute this software, either in source code form or as a compiled
+ binary, for any purpose, commercial or non-commercial, and by any
+ means.
+
+ In jurisdictions that recognize copyright laws, the author or authors
+ of this software dedicate any and all copyright interest in the
+ software to the public domain. We make this dedication for the benefit
+ of the public at large and to the detriment of our heirs and
+ successors. We intend this dedication to be an overt act of
+ relinquishment in perpetuity of all present and future rights to this
+ software under copyright law.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+ OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+ OTHER DEALINGS IN THE SOFTWARE.
+
+ For more information, please refer to <https://unlicense.org>
Plybooks.ipynb ADDED
@@ -0,0 +1,199 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import lightning as pl \n",
+ "from src.datamodule import CIFAR10DataModule\n",
+ "from src.vit import ViTLightning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer = pl.Trainer(max_epochs=15,accelerator='auto',reload_dataloaders_every_n_epochs=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ViTLightning()\n",
+ "dm = CIFAR10DataModule()\n",
+ "dm.setup()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using a CUDA device ('NVIDIA GeForce RTX 4050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Files already downloaded and verified\n",
+ "Files already downloaded and verified\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Loading `train_dataloader` to estimate number of stepping batches.\n",
+ "\n",
+ " | Name | Type | Params | Mode \n",
+ "---------------------------------------------------------\n",
+ "0 | vit | ViT | 154 K | train\n",
+ "1 | train_acc | MulticlassAccuracy | 0 | train\n",
+ "2 | val_acc | MulticlassAccuracy | 0 | train\n",
+ "3 | test_acc | MulticlassAccuracy | 0 | train\n",
+ "---------------------------------------------------------\n",
+ "154 K Trainable params\n",
+ "0 Non-trainable params\n",
+ "154 K Total params\n",
+ "0.616 Total estimated model params size (MB)\n",
+ "37 Modules in train mode\n",
+ "0 Modules in eval mode\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 14: 100%|██████████| 1407/1407 [00:28<00:00, 49.90it/s, v_num=0, train_loss=0.518, train_acc=0.875, val_loss=0.996, val_acc=0.644]"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=15` reached.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 14: 100%|██████████| 1407/1407 [00:28<00:00, 49.87it/s, v_num=0, train_loss=0.518, train_acc=0.875, val_loss=0.996, val_acc=0.644]\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer.fit(datamodule=dm,model=model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Files already downloaded and verified\n",
+ "Files already downloaded and verified\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 163.27it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_acc </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.6284000277519226 </span>│\n",
+ "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 1.0169780254364014 </span>│\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "</pre>\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m val_acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6284000277519226 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0169780254364014 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[{'val_loss': 1.0169780254364014, 'val_acc': 0.6284000277519226}]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.validate(model,dm)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
README.md CHANGED
@@ -1,14 +1,60 @@
- ---
- title: AttnViz
- emoji: 🐨
- colorFrom: pink
- colorTo: green
- sdk: gradio
- sdk_version: 5.1.0
- app_file: app.py
- pinned: false
- license: unlicense
- short_description: Trained on Cifar10 & explore attention visually
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ViT
+ - GitHub source repo⭐:: [VitCiFar](https://github.com/Muthukamalan/VitCiFar)
+
+ As we all know, the Transformer architecture has taken the world by storm.
+
+ In this repo, I practised implementing it for vision from scratch. Transformers are data-hungry, so don't compare them directly with CNNs (it's not an apples-to-apples comparison).
+
+
+ #### Model
+ <div align='center'><img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/vit.png" width=500 height=300></div>
+
+
+ **Patches**
+ ```python
+ nn.Conv2d(
+     in_chans,
+     emb_dim,
+     kernel_size = patch_size,
+     stride = patch_size
+ )
+ ```
+ <div align='center'>
+     <img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/patches.png" width=500 height=300 style="display:inline-block; margin-right: 10px;" alt="patches">
+     <img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/embedding.png" width=500 height=300 style="display:inline-block;">
+ </div>
+
+
+ > [!NOTE] CAUSAL MASK
+ > Unlike with text, we don't use a causal mask here.
+
+
+ <!-- <div align='center'><img src="assets/attention-part.png" width=300 height=500 style="display:inline-block; margin-right: 10px;"></div> -->
+ <p align="center">
+     <img src="https://raw.githubusercontent.com/Muthukamalan/VitCiFar/main/assets/attention-part.png" alt="Attention Visualization" />
+ </p>
+
+
+ At the final projection layer, we can either:
+ - pool (average) the patch tokens and project the result to the prediction layer, or
+ - add one extra learnable token before the transformer blocks, then pick that token and pass it to the projection layer (like `BERT` does) << this is what ViT chooses
+
+ ```python
+
+ # Transformer Encoder
+ xformer_out = self.enc(out)                 # [batch, 65, 384]
+ if self.is_cls_token:
+     token_out = xformer_out[:,0]            # [batch, 384]
+ else:
+     token_out = xformer_out.mean(1)
+
+ # MLP Head
+ projection_out = self.mlp_head(token_out)   # [batch, 10]
+
+ ```
+
+
+ #### Context Grad-CAM
+ [Xplain AI](https://github.com/jacobgil/pytorch-grad-cam)
+
+ - register_forward_hook:: the hook will be executed during the forward pass of the model (see the sketch below)
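A minimal sketch of that `register_forward_hook` idea (illustrative only: the block index, dummy input, and expected token shape are assumptions based on the `timm` ViT loaded in `app.py`, not part of this commit):

```python
import timm
import torch

# pretrained=False is enough for a shape check; app.py uses pretrained=True
model = timm.create_model("vit_small_patch16_224", pretrained=False).eval()
captured = {}  # hypothetical store for the hooked activation

def save_block_output(module, inputs, output):
    # executed during the forward pass; output is [batch, n_tokens, dim]
    captured["block2"] = output.detach()

handle = model.blocks[2].register_forward_hook(save_block_output)
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))   # dummy 224x224 image batch
print(captured["block2"].shape)          # expected: torch.Size([1, 197, 384])
handle.remove()                          # detach the hook when done
```

`src/gradcams.py` uses the same mechanism, plus a tensor-level `register_hook` so the gradients are also captured during the backward pass.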
app.py ADDED
@@ -0,0 +1,96 @@
+ ############################
+ #
+ # Imports
+ #
+ ############################
+ import timm
+ import torch
+ from skimage import io
+ from src.gradcams import GradCam
+ import numpy as np
+ import cv2
+ import gradio as gr
+ from PIL import Image
+
+
+
+ ############################
+ #
+ # model
+ #
+ ############################
+ model:torch.nn.Module = timm.create_model("vit_small_patch16_224",pretrained=True) # num_classes=10
+ model.eval()
+
+ ############################
+ #
+ # utility functions
+ #
+ ############################
+ def prepare_input(image:np.ndarray)->torch.Tensor:
+     image = image.copy()   # (H,W,C)
+     mean = np.array([0.5,.5,.5])
+     stds = np.array([.5,.5,.5])
+     image -= mean
+     image /= stds
+
+     image = np.ascontiguousarray(np.transpose(image,(2,0,1)))  # transpose the image to match the model's input format (C,H,W)
+     image = image[np.newaxis,...]   # (bs, C, H, W)
+     return torch.tensor(image,requires_grad=True)
+
+
+ def gen_cam(image, mask):
+     # create a heatmap from the Grad-CAM mask
+     heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
+     heatmap = np.float32(heatmap)/255.
+     # superimpose the heatmap on the original image
+     cam = (.5*heatmap) + (.5*image.squeeze(0).permute(1,2,0).detach().cpu().numpy())
+     # normalize
+     cam = cam/ np.max(cam)
+     return np.uint8(255*cam)
+
+
+
+ def attn_viz(image,number:int=2):
+     image = np.float32(cv2.resize(image,(224,224) )) / 255
+     image = prepare_input(image)
+
+     target_layer = model.blocks[int(number)]   # gr.Number returns a float, so cast before indexing
+     grad_cam = GradCam(model=model,target=target_layer)
+     mask = grad_cam(image)
+     result = gen_cam(image=image,mask=mask)
+     return Image.fromarray(result)
+
+
+ # Create a Gradio Blocks app with two tabs
+ with gr.Blocks(
+     title="AttnViz",
+ ) as demo:
+     with gr.Tab("Image Processing"):
+         # Create an image input and a number input
+         image_input = gr.Image(label="Input Image",type='numpy')
+         number_input = gr.Number(label="Number",minimum=0,maximum=11,show_label=True)
+         # Create an image output
+         image_output = gr.Image(label="Output Image")
+         # Set up the event listener for the image processing function
+         process_button = gr.Button("Process Image")
+         process_button.click(attn_viz, inputs=[image_input, number_input], outputs=image_output)
+
+         gr.Examples(
+             examples=[
+                 ["samples/mr_bean.png", 1],
+                 ["samples/sectional-sofa.png", 8],
+             ],
+             inputs=[image_input, number_input],
+         )
+
+
+     with gr.Tab("README"):
+         # Render the README in this tab
+         with open("README.md", "r") as file: readme_content = file.read()
+         gr.Markdown(readme_content)
+
+
+
+ if __name__=='__main__':
+     demo.launch(show_error=True,share=False,)
assets/attention-part.png ADDED
assets/embedding.png ADDED
assets/patches.png ADDED
assets/vit.png ADDED
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ Pillow==10.1.0
+ opencv-python==4.8.1.78
+ opencv-python-headless==4.8.1.78
+ torch==2.5.0
+ pytorch-gradcam==0.2.1
+ torchvision==0.20.0
+ timm==1.0.9
+ gradio==4.44.1
+ gradio_client==1.3.0
+ lightning==2.4.0
+ lightning-utilities==0.11.6
+ pytorch-lightning==2.4.0
+ numpy==1.26.1
samples/mr_bean.png ADDED
samples/sectional-sofa.png ADDED
src/__pycache__/gradcams.cpython-311.pyc ADDED
Binary file (4.17 kB).
 
src/datamodule.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from torch.utils.data import DataLoader, random_split
+ from torchvision import transforms, datasets
+ import lightning as pl
+
+ class CIFAR10DataModule(pl.LightningDataModule):
+     def __init__(self, data_dir: str = r'/home/muthu/GitHub/DATA 📁/CIFAR', batch_size: int = 32, num_workers: int = 4):
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+
+         # Define data transforms for train, validation and test
+         self.transform_train = transforms.Compose([
+             transforms.RandomHorizontalFlip(),
+             transforms.RandomCrop(32, padding=4),
+             transforms.Resize((32,32)),
+             transforms.ToTensor(),
+             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
+         ])
+
+         self.transform_test = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
+         ])
+
+     def prepare_data(self):
+         # Download CIFAR-10 dataset
+         datasets.CIFAR10(root=self.data_dir, train=True, download=True)
+         datasets.CIFAR10(root=self.data_dir, train=False, download=True)
+
+     def setup(self, stage=None):
+         # Split dataset for training, validation and test
+         if stage == 'fit' or stage is None:
+             full_train_dataset = datasets.CIFAR10(root=self.data_dir, train=True, transform=self.transform_train)
+             self.train_dataset, self.val_dataset = random_split(full_train_dataset, [45000, 5000])
+
+         if stage == 'test' or stage is None:
+             self.test_dataset = datasets.CIFAR10(root=self.data_dir, train=False, transform=self.transform_test)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+
+     def test_dataloader(self):
+         return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
+
+     def predict_dataloader(self):
+         return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
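A quick, hedged way to exercise this DataModule outside the Trainer (the `data_dir` override and batch size below are hypothetical, not part of this commit):

```python
# a minimal sketch, assuming CIFAR-10 can be downloaded into ./data
from src.datamodule import CIFAR10DataModule

dm = CIFAR10DataModule(data_dir="./data", batch_size=8)  # hypothetical local path
dm.prepare_data()   # downloads CIFAR-10 if it is not already present
dm.setup("fit")     # builds the 45k/5k train/val split

images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # expected: torch.Size([8, 3, 32, 32]) torch.Size([8])
```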
src/gradcams.py ADDED
@@ -0,0 +1,61 @@
+ import cv2          # OpenCV for image processing
+ import numpy as np  # NumPy for numerical operations
+
+ class GradCam:
+     def __init__(self, model, target):
+         self.model = model.eval()   # Set the model to evaluation mode
+         self.feature = None         # To store the features from the target layer
+         self.gradient = None        # To store the gradients from the target layer
+         self.handlers = []          # List to keep track of hooks
+         self.target = target        # Target layer for Grad-CAM
+         self._get_hook()            # Register hooks to the target layer
+
+     # Forward hook to capture features from the target layer
+     def _get_features_hook(self, module, input, output):
+         self.feature = self.reshape_transform(output)   # Store and reshape the output features
+
+     # Forward hook that attaches a tensor hook so gradients are captured during the backward pass
+     def _get_grads_hook(self, module, input_grad, output_grad):
+         self.gradient = self.reshape_transform(output_grad)   # Placeholder until the backward pass runs
+
+         def _store_grad(grad):
+             self.gradient = self.reshape_transform(grad)   # Store gradients for later use
+
+         output_grad.register_hook(_store_grad)   # Register hook to store gradients
+
+     # Register forward hooks to the target layer
+     def _get_hook(self):
+         self.target.register_forward_hook(self._get_features_hook)
+         self.target.register_forward_hook(self._get_grads_hook)
+
+     # Function to reshape the tensor for visualization
+     def reshape_transform(self, tensor, height=14, width=14):
+         result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
+         result = result.transpose(2, 3).transpose(1, 2)   # Rearrange dimensions to (C, H, W)
+         return result
+
+     # Function to compute the Grad-CAM heatmap
+     def __call__(self, inputs):
+         self.model.zero_grad()        # Zero the gradients
+         output = self.model(inputs)   # Forward pass
+
+         # Get the index of the highest score in the output
+         index = np.argmax(output.cpu().data.numpy())
+         target = output[0][index]     # Get the target score
+         target.backward()             # Backward pass to compute gradients
+
+         # Get the gradients and features
+         gradient = self.gradient[0].cpu().data.numpy()
+         weight = np.mean(gradient, axis=(1, 2))   # Average the gradients
+         feature = self.feature[0].cpu().data.numpy()
+
+         # Compute the weighted sum of the features
+         cam = feature * weight[:, np.newaxis, np.newaxis]
+         cam = np.sum(cam, axis=0)    # Sum over the channels
+         cam = np.maximum(cam, 0)     # Apply ReLU to remove negative values
+
+         # Normalize the heatmap
+         cam -= np.min(cam)
+         cam /= np.max(cam)
+         cam = cv2.resize(cam, (224, 224))   # Resize to match the input image size
+         return cam   # Return the Grad-CAM heatmap
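A small usage sketch for `GradCam` (assumptions: the same 224x224 input and 16x16-patch `timm` ViT that `app.py` uses, so the 14x14 grid in `reshape_transform` matches; the random tensor is a stand-in for a preprocessed image):

```python
import timm
import torch
from src.gradcams import GradCam

model = timm.create_model("vit_small_patch16_224", pretrained=False).eval()
cam_extractor = GradCam(model=model, target=model.blocks[2])   # hook transformer block 2

x = torch.randn(1, 3, 224, 224, requires_grad=True)  # stand-in for a preprocessed image
heatmap = cam_extractor(x)                            # forward + backward happen inside __call__
print(heatmap.shape, heatmap.min(), heatmap.max())    # expected: (224, 224), values in [0, 1]
```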
src/old.py.old ADDED
@@ -0,0 +1,194 @@
+ import math
+ from lightning.pytorch.utilities.types import EVAL_DATALOADERS
+ import torch
+ from typing import Dict,Optional,Tuple,Union
+ from dataclasses import dataclass
+
+ import lightning as pl
+ from torchmetrics import Accuracy
+ # @dataclass
+ # class ViTCfg:
+ #     image_size: int
+ #     patch_size: int
+ #     num_channels: int
+ #     model_dim: int
+ #     num_attn_heads:int
+ #     attn_dropout: int
+ #     d_ff: int
+ #     number_encoders:int
+ #     classification_heads:int
+
+
+ class PatchEmbedding(torch.nn.Module):
+     def __init__(self, cfg:Dict) -> None:
+         super().__init__()
+         for k,v in cfg.items(): setattr(self,k,v)
+         assert self.image_size % self.patch_size==0, "patch_size does not divide image_size evenly"
+         self.num_patchs = (self.image_size // self.patch_size)**2
+         self.img2flattn:torch.nn.Conv2d = torch.nn.Conv2d(
+             in_channels = self.num_channels,
+             out_channels=self.model_dim,
+             kernel_size = self.patch_size,
+             stride = self.patch_size,
+             bias=False
+         )
+     def forward(self,x:torch.Tensor)->torch.Tensor:
+         # (bs, 3, 32, 32) >> (bs, model_dim, img_size//patch_size, img_size//patch_size) >> (bs, model_dim, num_patches) >> (bs, num_patches, model_dim)
+         return self.img2flattn(x).flatten(2).transpose(1,2)
+
+
+ class Embedding(torch.nn.Module):
+     def __init__(self,cfg:Dict ) -> None:
+         super().__init__()
+         for k,v in cfg.items(): setattr(self,k,v)
+         self.patch_embedding:PatchEmbedding = PatchEmbedding(cfg=cfg)
+
+         # single [CLS] token
+         self.cls_token:torch.nn.Parameter = torch.nn.Parameter( torch.randn(1,1, self.model_dim ) )
+
+         self.position_embd:torch.nn.Parameter = torch.nn.Parameter(
+             torch.randn( 1, int( (self.image_size // self.patch_size)**2 + 1), self.model_dim )
+         )
+     def forward(self,x:torch.Tensor)->torch.Tensor:
+         x = self.patch_embedding(x)
+         cls_token = self.cls_token.expand( x.shape[0], -1, -1 )
+         x = torch.cat( (cls_token,x) , dim=1)
+         x = x + self.position_embd
+         return x
+
+
+ class AttentionBlock(torch.nn.Module):
+     def __init__(self,cfg:Dict ) -> None:
+         super().__init__()
+         for k,v in cfg.items(): self.__setattr__(k,v)
+
+         assert self.model_dim % self.num_attn_heads ==0, "model_dim is not divisible by the number of heads"
+
+         self.attn_layer:torch.nn.Linear = torch.nn.Linear(self.model_dim, 3*self.model_dim, bias=False)
+         self.out :torch.nn.Linear = torch.nn.Linear(self.model_dim,self.model_dim,bias=False)
+
+         self.attn_dropout:torch.nn.Dropout = torch.nn.Dropout()
+         self.resid_dropout:torch.nn.Dropout= torch.nn.Dropout()
+
+         # causal mask to ensure that attention is only applied to the left in the input seq
+         # self.register_buffer('bias',tensor= torch.tril(torch.ones(self.block_size,self.block_size)).view(1, 1, self.block_size, self.block_size) )
+         '''
+         block_size=10
+         [[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]
+
+         # Batch-1, Seq-1, Mask-(10,10)
+         '''
+
+     def forward(self,x:torch.Tensor, attention_outputs:bool)->Tuple[torch.Tensor, Union[torch.Tensor,None]]:
+         '''
+         input (bs,seq_len,embedding_dim) >> output (bs,seq_len,embedding_dim)
+
+         x     :: (bs,seq_len,embedding_dim)
+         attn  :: (bs, seq_len, 3*embedding_dim)
+         .split:: (bs, seq_len, 3*embedding_dim).split(embedding_dim,dim=2)
+         # Each chunk (bs,seq_len,embedding_dim) is a view of the original tensor, split across embedding_dim, so we get 3 chunks
+
+         k,q,v >> (bs, seq_len, n_heads, embedding_dim//n_heads) >> (bs, n_heads, seq_len, embedding_dim//n_heads)
+         # Each head is responsible for a different context of the sequence
+         '''
+         B,T,C = x.size()  # (bs, seq_len, embedding_dim)
+
+         # calc q,k,v
+         q:torch.Tensor;
+         k:torch.Tensor;
+         v:torch.Tensor;
+         q,k,v = self.attn_layer(x).split(split_size=self.model_dim,dim=2)
+         q = q.view(B,T,self.num_attn_heads, C//self.num_attn_heads).transpose(1,2)
+         k = k.view(B,T,self.num_attn_heads, C//self.num_attn_heads).transpose(1,2)
+         v = v.view(B,T,self.num_attn_heads, C//self.num_attn_heads).transpose(1,2)
+
+
+         attn = (q @ k.transpose(-2,-1)) * (1/math.sqrt(k.size(-1)))
+         # attn = attn.masked_fill(self.bias[:,:,:T,:T]==0,float('-inf'))
+         attn = torch.nn.functional.softmax(attn,dim=-1)
+         attn = self.attn_dropout(attn)
+
+         y:torch.Tensor = attn @ v  # (bs, n_heads, T, T) @ (bs, n_heads, T, embedding_dim/n_heads) >> (bs, n_heads, seq_len, embedding_dim/n_heads)
+         y:torch.Tensor = y.transpose(1,2).contiguous().view(B,T,C)
+
+         return self.resid_dropout(self.out(y)), attn if attention_outputs else None
+
+
+
+ class MLP(torch.nn.Module):
+     def __init__(self,cfg:Dict ) -> None:
+         super().__init__()
+         for k,v in cfg.items(): self.__setattr__(k,v)
+         super().__init__()
+         self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
+         self.activation = torch.nn.ReLU()
+         self.layernorm = torch.nn.LayerNorm(self.d_ff)
+         self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
+         self.dropout = torch.nn.Dropout(0.2)
+     def forward(self,x:torch.Tensor)->torch.Tensor:
+         return self.dropout( self.dense_2( self.layernorm(self.activation( self.dense_1(x) )) ) )
+
+
+ class EncoderBlock(torch.nn.Module):
+     def __init__(self,cfg:Dict ) -> None:
+         super().__init__()
+         for k,v in cfg.items(): self.__setattr__(k,v)
+         self.attn_block = AttentionBlock(cfg)
+         self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
+         self.mlp = MLP(cfg)
+         self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)
+     def forward(self,x:torch.Tensor, attention_outputs:bool)->Tuple[torch.Tensor, Union[torch.Tensor,None]]:
+         # self-attention
+         attention_op, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs )
+         x = x + attention_op
+         # FC
+         mlp_output = self.mlp( self.layernorm_2(x) )
+         x = x + mlp_output
+         return x, attn if attention_outputs==True else None  # Return the transformer block's output and the attention probabilities (optional)
+
+ class Encoder(torch.nn.Module):
+     """
+     The transformer encoder module.
+     """
+     def __init__(self,cfg:Dict ) -> None:
+         super().__init__()
+         for k,v in cfg.items(): self.__setattr__(k,v)
+         # Create a list of transformer blocks
+         self.blocks = torch.nn.ModuleList([])
+         for _ in range(self.number_encoders):
+             block = EncoderBlock(cfg)
+             self.blocks.append(block)
+
+     def forward(self,x:torch.Tensor,attention_outputs:bool):
+         # Calculate the transformer block's output for each block
+         all_attn = []
+         for block in self.blocks:
+             x,attn = block(x,attention_outputs=attention_outputs)
+             all_attn.append(attn)
+         # Return the encoder's output and the attention probabilities (optional)
+         return x,all_attn if attention_outputs==True else None
+
+
+ class ViTClassifier(torch.nn.Module):
+     def __init__(self, cfg:Dict ) -> None:
+         super().__init__()
+         for k,v in cfg.items(): self.__setattr__(k,v)
+         self.embed:Embedding = Embedding(cfg)
+         self.encoders:Encoder = Encoder(cfg=cfg)
+         self.classifier:torch.nn.Linear = torch.nn.Linear(self.model_dim ,self.classification_heads,bias=False)
+
+     def forward(self,x:torch.Tensor,attention_outputs=False):
+         x = self.embed(x)
+         x,attn = self.encoders(x,attention_outputs=attention_outputs)
+         return self.classifier(x[:,0]), attn if attention_outputs else None
+
+
src/vit.py ADDED
@@ -0,0 +1,262 @@
+ import torch
+ import torch.nn as nn
+
+
+ class PatchEmbedding(nn.Module): # Done
+     """
+     img_size: 1d size of each image (32 for CIFAR-10)
+     patch_size: 1d size of each patch (img_size/num_patch_1d, 4 in this experiment)
+     in_chans: input channel (3 for RGB images)
+     emb_dim: flattened length for each token (or patch)
+     """
+     def __init__(self, img_size:int, patch_size:int, in_chans:int=3, emb_dim:int=48):
+         super(PatchEmbedding, self).__init__()
+         self.img_size = img_size
+         self.patch_size = patch_size
+
+         self.proj = nn.Conv2d(
+             in_chans,
+             emb_dim,
+             kernel_size = patch_size,
+             stride = patch_size
+         )
+
+     def forward(self, x):
+         with torch.no_grad():  # note: under no_grad the patch-projection conv receives no gradients and is not updated during training
+             # x: [batch, in_chans, img_size, img_size]
+             x = self.proj(x)       # [batch, embed_dim, # of patches in a row, # of patches in a col], [batch, 48, 8, 8] in this experiment
+             x = x.flatten(2)       # [batch, embed_dim, total # of patches], [batch, 48, 64] in this experiment
+             x = x.transpose(1, 2)  # [batch, total # of patches, emb_dim] => the Transformer encoder requires these dimensions [batch, number of words, word_emb_dim]
+         return x
+
+
+ class TransformerEncoder(nn.Module): # Done
+     def __init__(self, input_dim:int, mlp_hidden_dim:int, num_head:int=8, dropout:float=0.):
+         # input_dim and head for Multi-Head Attention
+         super(TransformerEncoder, self).__init__()
+         self.norm1 = nn.LayerNorm(input_dim)  # LayerNorm plays the role for Transformers that BatchNorm plays for CNNs
+         self.msa = MultiHeadSelfAttention(input_dim, n_heads=num_head)
+         self.norm2 = nn.LayerNorm(input_dim)
+         # Position-wise Feed-Forward Networks with GELU activation functions
+         self.mlp = nn.Sequential(
+             nn.Linear(input_dim, mlp_hidden_dim),
+             nn.GELU(),
+             nn.Linear(mlp_hidden_dim, input_dim),
+             nn.GELU(),
+         )
+
+     def forward(self, x):
+         out = self.msa(self.norm1(x)) + x      # add residual connection
+         out = self.mlp(self.norm2(out)) + out  # add another residual connection
+         return out
+
+
+ class MultiHeadSelfAttention(nn.Module):
+     """
+     dim: dimension of input and output per-token features (emb dim for tokens)
+     n_heads: number of heads
+     qkv_bias: whether to have bias in qkv linear layers
+     attn_p: dropout probability for attention
+     proj_p: dropout probability for the last linear layer
+     scale: scaling factor for attention (1/sqrt(dk))
+     qkv: initial linear layer for the query, key, and value
+     proj: last linear layer
+     attn_drop, proj_drop: dropout layers for attn and proj
+     """
+     def __init__(self, dim:int, n_heads:int=8, qkv_bias:bool=True, attn_p:float=0.01, proj_p:float=0.01):
+         super(MultiHeadSelfAttention, self).__init__()
+         self.n_heads = n_heads
+         self.dim = dim                      # embedding dimension for input
+         self.head_dim = dim // n_heads      # d_q, d_k, d_v in the paper (int div needed to preserve input dim = output dim)
+         self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)
+
+         self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)  # lower linear layers in Figure 2 of the paper
+         self.attn_drop = nn.Dropout(attn_p)
+         self.proj = nn.Linear(dim, dim)                  # upper linear layers in Figure 2 of the paper
+         self.proj_drop = nn.Dropout(proj_p)
+
+     def forward(self, x):
+         """
+         Input and Output shape: [batch_size, n_patches + 1, dim]
+         """
+         batch_size, n_tokens, x_dim = x.shape  # n_tokens = n_patches + 1 (1 is cls_token), x_dim is input dim
+
+         # Sanity Check
+         if x_dim != self.dim:  # make sure input dim is same as concatenated dim (output dim)
+             raise ValueError
+         if self.dim != self.head_dim*self.n_heads:  # make sure dim is divisible by n_heads
+             raise ValueError(f"Input & Output dim should be divisible by Number of Heads")
+
+         # Linear Layers for Query, Key, Value
+         qkv = self.qkv(x)                                                        # (batch_size, n_patches+1, 3*dim)
+         qkv = qkv.reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim)  # (batch_size, n_patches+1, 3, n_heads, head_dim)
+         qkv = qkv.permute(2, 0, 3, 1, 4)                                         # (3, batch_size, n_heads, n_patches+1, head_dim)
+         q, k, v = qkv[0], qkv[1], qkv[2]                                         # (batch_size, n_heads, n_patches+1, head_dim)
+
+         # Scaled Dot-Product Attention
+         k_t = k.transpose(-2, -1)           # K Transpose: (batch_size, n_heads, head_dim, n_patches+1)
+         dot_product = (q @ k_t)*self.scale  # Query, Key Dot Product with Scale Factor: (batch_size, n_heads, n_patches+1, n_patches+1)
+         attn = dot_product.softmax(dim=-1)  # Softmax: (batch_size, n_heads, n_patches+1, n_patches+1)
+         attn = self.attn_drop(attn)         # Attention Dropout: (batch_size, n_heads, n_patches+1, n_patches+1)
+         weighted_avg = attn @ v             # (batch_size, n_heads, n_patches+1, head_dim)
+         weighted_avg = weighted_avg.transpose(1, 2)  # (batch_size, n_patches+1, n_heads, head_dim)
+
+         # Concat and Last Linear Layer
+         weighted_avg = weighted_avg.flatten(2)  # Concat: (batch_size, n_patches+1, dim)
+         x = self.proj(weighted_avg)             # Last Linear Layer: (batch_size, n_patches+1, dim)
+         x = self.proj_drop(x)                   # Last Linear Layer Dropout: (batch_size, n_patches+1, dim)
+
+         return x
+
+ class ViT(nn.Module): # Done
+     def __init__(
+         self,
+         in_c:int=3,
+         num_classes:int=10,
+         img_size:int=32,
+         num_patch_1d:int=16,
+         dropout:float=0.1,
+         num_enc_layers:int=2,
+         hidden_dim:int=128,
+         mlp_hidden_dim:int=128//2,
+         num_head:int=4,
+         is_cls_token:bool=True
+     ):
+         super(ViT, self).__init__()
+         """
+         is_cls_token: are we using the class token?
+         num_patch_1d: number of patches in one row (or col), 3 in Figure 1 of the paper, 8 in this experiment
+         patch_size: 1d size (size of row or col) of each patch, 16 for ImageNet in the paper, 4 in this experiment
+         flattened_patch_dim: flattened vec length for each patch (4 x 4 x 3, each side is 4 and 3 color channels), 48 in this experiment
+         num_tokens: number of total patches + 1 (class token), 10 in Figure 1 of the paper, 65 in this experiment
+         """
+         self.is_cls_token = is_cls_token
+         self.num_patch_1d = num_patch_1d
+         self.patch_size = img_size//self.num_patch_1d
+         num_tokens = (self.num_patch_1d**2)+1 if self.is_cls_token else (self.num_patch_1d**2)
+
+         # Divide each image into patches
+         self.images_to_patches = PatchEmbedding(
+             img_size=img_size,
+             patch_size=img_size//num_patch_1d,
+             emb_dim=num_patch_1d*num_patch_1d
+         )
+
+         # Linear Projection of Flattened Patches
+         self.lpfp = nn.Linear(num_patch_1d*num_patch_1d, hidden_dim)  # 48 x 384 (384 is the latent vector size D in the paper)
+
+         # Patch + Position Embedding (Learnable)
+         self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) if is_cls_token else None  # learnable classification token with dim [1, 1, 384]; 1 in the 2nd dim because there is only one class per image, not per patch
+         self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, hidden_dim))  # learnable positional embedding with dim [1, 65, 384]
+
+         # Transformer Encoder
+         enc_list = [TransformerEncoder(hidden_dim, mlp_hidden_dim=mlp_hidden_dim, dropout=dropout, num_head=num_head) for _ in range(num_enc_layers)]  # num_enc_layers is L in the Transformer Encoder of Figure 1
+         self.enc = nn.Sequential(*enc_list)  # * should be added when a regular python list is given to nn.Sequential
+
+         # MLP Head (Standard Classifier)
+         self.mlp_head = nn.Sequential(
+             nn.LayerNorm(hidden_dim),
+             nn.Linear(hidden_dim, num_classes)
+         )
+
+     def forward(self, x):  # x: [batch, 3, 32, 32]
+         # Images into Patches (including flattening)
+         out = self.images_to_patches(x)  # [batch, 64, 48]
+         # Linear Projection on Flattened Patches
+         out = self.lpfp(out)             # [batch, 64, 384]
+
+         # Add Class Token and Positional Embedding
+         if self.is_cls_token:
+             out = torch.cat([self.cls_token.repeat(out.size(0),1,1), out], dim=1)  # [batch, 65, 384], added as extra learnable embedding
+         out = out + self.pos_emb         # [batch, 65, 384]
+
+         # Transformer Encoder
+         out = self.enc(out)              # [batch, 65, 384]
+         if self.is_cls_token:
+             out = out[:,0]               # [batch, 384]
+         else:
+             out = out.mean(1)
+
+         # MLP Head
+         out = self.mlp_head(out)         # [batch, 10]
+         return out
+
+
+
+ import lightning as pl
+ from torchmetrics import Accuracy
+
+ class ViTLightning(pl.LightningModule):
+     def __init__(self, learning_rate: float = 1e-3):
+         super(ViTLightning, self).__init__()
+         self.vit = ViT(
+             in_c=3,
+             num_classes=10,
+             img_size=32,
+             num_patch_1d=16,
+             dropout=0.1,
+             num_enc_layers=2,
+             hidden_dim=96,
+             mlp_hidden_dim=64,
+             num_head=8,
+             is_cls_token=True
+         )
+         self.train_acc = Accuracy('multiclass',num_classes=10)
+         self.val_acc = Accuracy('multiclass',num_classes=10)
+         self.test_acc = Accuracy('multiclass',num_classes=10)
+         self.learning_rate = learning_rate
+
+     def forward(self, x):
+         return self.vit(x)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         preds = self.forward(x)
+         loss = nn.CrossEntropyLoss()(preds, y)
+         acc = self.train_acc(preds, y)
+         self.log('train_loss', loss, prog_bar=True, logger=True)
+         self.log('train_acc', acc, prog_bar=True, logger=True)
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         preds = self.forward(x)
+         loss = nn.CrossEntropyLoss()(preds, y)
+         acc = self.val_acc(preds, y)
+         self.log('val_loss', loss, prog_bar=True, logger=True)
+         self.log('val_acc', acc, prog_bar=True, logger=True)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         x, y = batch
+         preds = self.forward(x)
+         loss = nn.CrossEntropyLoss()(preds, y)
+         acc = self.test_acc(preds, y)
+         self.log('test_loss', loss, prog_bar=True, logger=True)
+         self.log('test_acc', acc, prog_bar=True, logger=True)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam( self.vit.parameters(), )
+         num_epochs = self.trainer.max_epochs  # note: no trailing comma here, otherwise this becomes a tuple
+         scheduler = torch.optim.lr_scheduler.OneCycleLR(
+             optimizer=optimizer,
+             total_steps=self.trainer.estimated_stepping_batches,
+             epochs=num_epochs,
+             pct_start= .3,
+             div_factor= 100,
+             max_lr= 1e-3,
+             three_phase= False,
+             final_div_factor= 100,
+             anneal_strategy='linear'
+         )
+         return {
+             'optimizer':optimizer,
+             'lr_scheduler':{
+                 'scheduler':scheduler,
+                 'monitor': "val_loss",
+                 "interval":"step",
+                 "frequency":1
+             }
+         }
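A hedged smoke test for the model defined above (the batch size is arbitrary; the parameter count is the one reported in `Plybooks.ipynb`):

```python
# a minimal sketch: instantiate the same ViT configuration ViTLightning uses and check shapes
import torch
from src.vit import ViT, ViTLightning

model = ViT(in_c=3, num_classes=10, img_size=32, num_patch_1d=16,
            num_enc_layers=2, hidden_dim=96, mlp_hidden_dim=64, num_head=8)
x = torch.randn(4, 3, 32, 32)   # a fake CIFAR-10 batch
logits = model(x)
print(logits.shape)             # expected: torch.Size([4, 10])

lit = ViTLightning()            # the Lightning wrapper around the same backbone
print(sum(p.numel() for p in lit.parameters()))  # roughly 154K params, as logged in Plybooks.ipynb
```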