---

# VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials

This model is a convolutional version of a Kolmogorov-Arnold Network with a VGG-11-like architecture, pretrained on the ImageNet1k dataset. KANs were originally presented in [1, 2]; the Gram version of KAN was originally presented in [3]. For more details, visit our [torch-conv-kan](https://github.com/IvanDrokin/torch-conv-kan) repository on GitHub.

## Model description

The model consists of 10 consecutive Gram ConvKAN layers with InstanceNorm2d, polynomial degree equal to 5, global average pooling, and a linear classification head:

1. KAGN Convolution, 32 filters, 3x3
2. Max pooling, 2x2
3. KAGN Convolution, 64 filters, 3x3
4. Max pooling, 2x2
5. KAGN Convolution, 128 filters, 3x3
6. KAGN Convolution, 128 filters, 3x3
7. Max pooling, 2x2
8. KAGN Convolution, 256 filters, 3x3
9. KAGN Convolution, 256 filters, 3x3
10. Max pooling, 2x2
11. KAGN Convolution, 256 filters, 3x3
12. KAGN Convolution, 256 filters, 3x3
13. Max pooling, 2x2
14. KAGN Convolution, 256 filters, 3x3
15. KAGN Convolution, 256 filters, 3x3
16. Global average pooling
17. Output layer, 1000 nodes
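The "polynomial degree equal to 5" means each KAGN activation expands its input in a fixed polynomial basis up to degree 5 and learns the mixing coefficients. The following pure-Python sketch illustrates the idea, using the Legendre three-term recurrence as a stand-in basis and a `tanh` squash on the input; the actual Gram-polynomial basis and normalization are defined in the torch-conv-kan repository.

```python
# Illustrative sketch of a polynomial-basis KAN activation.
# Legendre polynomials are used here purely as a stand-in for the Gram basis
# implemented in torch-conv-kan; the structure (basis expansion + learned
# coefficients) is what matters.

import math

def poly_basis(x, degree=5):
    # three-term Legendre recurrence: (n+1) P_{n+1} = (2n+1) x P_n - n P_{n-1}
    p = [1.0, x]
    for n in range(1, degree):
        p.append(((2 * n + 1) * x * p[n] - n * p[n - 1]) / (n + 1))
    return p[:degree + 1]

def kan_activation(x, coeffs):
    # learnable edge function: weighted sum of basis polynomials of tanh(x);
    # tanh keeps the input inside the basis' natural domain [-1, 1]
    t = math.tanh(x)
    return sum(c * p for c, p in zip(coeffs, poly_basis(t)))
```

With coefficients `[0, 1, 0, 0, 0, 0]` the activation reduces to `tanh(x)` itself; training moves the coefficients away from that starting point.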
## Intended uses & limitations

You can use the raw model for image classification, or use it as a pretrained model for further fine-tuning.

### How to use

First, clone the repository and install the dependencies:

```
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```
Then you can initialize the model and load the weights:

```python
import torch
from models import vggkagn


model = vggkagn(3,
                1000,
                groups=1,
                degree=5,
                dropout=0.15,
                l1_decay=0,
                dropout_linear=0.25,
                width_scale=2,
                vgg_type='VGG11v2',
                expected_feature_shape=(1, 1),
                affine=True)

model.from_pretrained('Brivangl/vgg_kagn11_v2')
```
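A quick sanity check on `expected_feature_shape=(1, 1)`: with 224x224 inputs, each of the five 2x2 max-pooling stages halves the spatial size, and global average pooling then collapses the remaining 7x7 map to 1x1. This assumes the 3x3 convolutions preserve spatial size (stride 1, padding 1), which is an assumption here; the exact settings live in the repository.

```python
# Trace feature-map sizes through the VGG-11-like stack, assuming
# size-preserving 3x3 convolutions and five 2x2 max-pooling stages.

def trace_sizes(input_size=224):
    size = input_size
    sizes = [size]
    for _ in range(5):  # the five max-pooling stages in the layer list
        size //= 2
        sizes.append(size)
    return sizes

print(trace_sizes())  # [224, 112, 56, 28, 14, 7]
```

Global average pooling then reduces the final 7x7 map to the expected 1x1 feature shape.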

Transforms used for validation on ImageNet1k:

```python
import torch
from torchvision.transforms import v2


transforms_val = v2.Compose([
    v2.ToImage(),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
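For intuition on what the last two transforms do to pixel values: `ToDtype(..., scale=True)` maps uint8 values from [0, 255] to [0, 1], and `Normalize` then standardizes each channel with the ImageNet statistics. The same arithmetic in pure Python:

```python
# Per-pixel effect of ToDtype(scale=True) + Normalize with the
# ImageNet channel statistics used in the snippet above.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(rgb_255):
    # scale uint8 [0, 255] -> float [0, 1], then per-channel standardize
    return [((v / 255.0) - m) / s for v, m, s in zip(rgb_255, mean, std)]

print(normalize_pixel([124, 116, 104]))  # close to [0.0, 0.0, 0.0]
```

A pixel near the ImageNet per-channel mean maps to values near zero, which is the point of the normalization.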

## Training data

This model was trained on the ImageNet1k dataset (1,281,167 images in the training set).

## Training procedure

The model was trained for 200 full epochs with the AdamW optimizer, using the following parameters:

```python
{'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
 'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}
```
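The `lr_warmup_steps`, `lr_power`, and `lr_end` values suggest a linear warmup followed by a polynomial decay of the learning rate. The sketch below is one plausible interpretation of these parameters, not the exact schedule code from the training run:

```python
# Illustrative warmup + polynomial-decay learning-rate schedule built from
# the hyperparameters above. The actual schedule implementation is in the
# torch-conv-kan training code.

def lr_at_step(step, total_steps, base_lr=0.0009,
               warmup_steps=7500, power=0.3, lr_end=1e-07):
    if step < warmup_steps:
        # linear warmup from 0 to base_lr
        return base_lr * step / warmup_steps
    # polynomial decay from base_lr down to lr_end
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return (base_lr - lr_end) * (1 - progress) ** power + lr_end

total = 100_000
print(lr_at_step(0, total))       # 0.0
print(lr_at_step(7500, total))    # 0.0009 (end of warmup)
print(lr_at_step(total, total))   # 1e-07 (final learning rate)
```

With `power=0.3` the decay is front-loaded: the learning rate drops quickly after warmup and then flattens out toward `lr_end`.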

And these augmentations:

```python
import torch
from torchvision.transforms import v2, AutoAugmentPolicy


transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

## Evaluation results

On the ImageNet1k validation set:

| Accuracy, top1 | Accuracy, top5 | AUC (ovo) | AUC (ovr) |
|:--------------:|:--------------:|:---------:|:---------:|
|      59.1      |     82.29      |   99.43   |   99.43   |

On the ImageNet1k test set: coming soon.

### BibTeX entry and citation info

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.