---
license: mit
datasets:
- imagenet1k
metrics:
- accuracy
---
# VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials

This model is a convolutional version of the Kolmogorov-Arnold Network (KAN) with a VGG-11-like architecture, pretrained on the Imagenet1k dataset. KANs were originally presented in [1, 2], and the Gram version of KAN in [3]. For more details, visit our [torch-conv-kan](https://github.com/IvanDrokin/torch-conv-kan) repository on GitHub.

## Model description

The model consists of 10 consecutive Gram ConvKAN layers with InstanceNorm2d and polynomial degree 5, followed by global average pooling and a linear classification head:

1. KAGN Convolution, 32 filters, 3x3
2. Max pooling, 2x2
3. KAGN Convolution, 64 filters, 3x3
4. Max pooling, 2x2
5. KAGN Convolution, 128 filters, 3x3
6. KAGN Convolution, 128 filters, 3x3
7. Max pooling, 2x2
8. KAGN Convolution, 256 filters, 3x3
9. KAGN Convolution, 256 filters, 3x3
10. Max pooling, 2x2
11. KAGN Convolution, 256 filters, 3x3
12. KAGN Convolution, 256 filters, 3x3
13. Max pooling, 2x2
14. KAGN Convolution, 256 filters, 3x3
15. KAGN Convolution, 256 filters, 3x3
16. Global Average pooling
17. Output layer, 1000 nodes.

![model image](https://github.com/IvanDrokin/torch-conv-kan/blob/main/assets/vgg_kagn_11_v2.png?raw=true)
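
As a quick sanity check on the layer list above, the sketch below transcribes it into a plain Python list and walks through it. The name `LAYOUT` and the helper `summarize` are hypothetical and not part of the repository; the sketch also assumes that the 3x3 convolutions preserve spatial size and that each 2x2 max pooling halves it, which the list itself does not state.

```python
# Layer layout transcribed from the numbered list above (hypothetical name):
# an int is a 3x3 KAGN convolution with that many filters, "M" is a 2x2 max pooling.
LAYOUT = [32, "M", 64, "M", 128, 128, "M", 256, 256, "M", 256, 256, "M", 256, 256]


def summarize(layout, input_size=224):
    """Count convolutions and track the spatial size through the stack."""
    size = input_size
    n_convs = 0
    for item in layout:
        if item == "M":
            size //= 2      # assumption: pooling uses stride 2
        else:
            n_convs += 1    # assumption: convolutions keep the spatial size
    return n_convs, size


n_convs, final_size = summarize(LAYOUT)
print(n_convs, final_size)  # 10 convolutions, 7x7 feature map for a 224x224 input
```

Under these assumptions, global average pooling reduces the final 7x7 feature map to a single vector per image before the 1000-way output layer.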


## Intended uses & limitations

You can use the raw model for image classification or as a pretrained model for further fine-tuning.

### How to use

First, clone the repository:

```bash
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```
Then you can initialize the model and load weights.

```python
import torch
from models import vggkagn


model = vggkagn(3,
                1000,
                groups=1,
                degree=5,
                dropout=0.15,
                l1_decay=0,
                dropout_linear=0.25,
                width_scale=2,
                vgg_type='VGG11v2',
                expected_feature_shape=(1, 1),
                affine=True
                )

model.from_pretrained('brivangl/vgg_kagn11_v2')
```

Transforms used for validation on Imagenet1k:

```python
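import torch
from torchvision.transforms import v2


transforms_val = v2.Compose([
    v2.ToImage(),                           # convert the input (e.g. a PIL image) to a tensor image
    v2.Resize(256, antialias=True),         # resize the shorter side to 256 pixels
    v2.CenterCrop(224),                     # take the central 224x224 crop
    v2.ToDtype(torch.float32, scale=True),  # cast to float32 and scale values to [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])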
```
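
To turn model outputs into predictions, something like the following sketch could be used. It assumes the model's forward pass returns class logits of shape `(N, num_classes)`, which this card does not state explicitly; the helper `predict_topk` is hypothetical, and a tiny stand-in network replaces the pretrained model so the snippet runs on its own.

```python
import torch
from torch import nn


@torch.no_grad()
def predict_topk(model: nn.Module, batch: torch.Tensor, k: int = 5):
    """Return (probabilities, class indices) of the k most likely classes per image."""
    model.eval()                    # switch dropout layers to inference mode
    logits = model(batch)           # assumption: output has shape (N, num_classes)
    probs = logits.softmax(dim=-1)  # convert logits to class probabilities
    return probs.topk(k, dim=-1)    # sorted from most to least likely


# Stand-in classifier (hypothetical) so the sketch runs without pretrained weights.
dummy = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 1000))
batch = torch.rand(2, 3, 224, 224)  # in practice: transforms_val(image).unsqueeze(0)
top_probs, top_classes = predict_topk(dummy, batch)
```

With the real model, `dummy` would be replaced by the initialized `model`, and `batch` by one or more images passed through `transforms_val`.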



## Training data
This model was trained on the Imagenet1k dataset (1,281,167 images in the train set).

## Training procedure

The model was trained for 200 full epochs with the AdamW optimizer and the following parameters:
```python
{'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}
```
And the following augmentations:
```python
import torch
from torchvision.transforms import AutoAugmentPolicy, v2


transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    # Randomly apply one of the two AutoAugment policies to each sample
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)
                     ]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
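
The `lr_warmup_steps`, `lr_power`, and `lr_end` entries in the parameters above suggest a polynomial learning-rate decay with linear warmup, although this card does not name the scheduler. The sketch below shows what such a schedule could look like under that assumption; the function `lr_at_step` and the value of `TOTAL_STEPS` are hypothetical, since the batch size and total step count are not given here.

```python
def lr_at_step(step, total_steps, lr_init=0.0009, lr_end=1e-07,
               warmup_steps=7500, power=0.3):
    """Learning rate at a given optimizer step (assumed schedule, see lead-in)."""
    if step < warmup_steps:
        # Linear warmup from 0 to lr_init
        return lr_init * step / max(1, warmup_steps)
    if step >= total_steps:
        # After training ends, stay at the final learning rate
        return lr_end
    # Polynomial decay from lr_init down to lr_end
    decay_steps = total_steps - warmup_steps
    pct_remaining = 1 - (step - warmup_steps) / decay_steps
    return (lr_init - lr_end) * pct_remaining ** power + lr_end


TOTAL_STEPS = 1_000_000  # hypothetical: depends on the batch size, which is not stated

for step in (0, 3750, 7500, 500_000, TOTAL_STEPS):
    print(step, lr_at_step(step, TOTAL_STEPS))
```

With `power` below 1, the rate stays close to its peak for most of training and drops sharply near the end.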

## Evaluation results

On Imagenet1k Validation:

| Top-1 accuracy (%) | Top-5 accuracy (%) | AUC, ovo (%) | AUC, ovr (%) |
|:------------------:|:------------------:|:------------:|:------------:|
|        59.1        |        82.29       |     99.43    |     99.43    |
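
For reference, top-k accuracy as reported in the table counts a prediction as correct when the true label is among the k highest-scoring classes. The sketch below is a minimal pure-Python illustration on toy scores, not the evaluation code used for these results; the function name `topk_accuracy` is hypothetical.

```python
def topk_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k top-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score, truncated to the best k
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)


# Toy example: 3 samples, 3 classes
scores = [[0.1, 0.7, 0.2],
          [0.5, 0.3, 0.2],
          [0.2, 0.3, 0.5]]
labels = [1, 1, 0]

print(topk_accuracy(scores, labels, k=1))  # only the first sample is correct
print(topk_accuracy(scores, labels, k=2))  # the first two samples are correct
```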

On Imagenet1k Test:
Coming soon

### BibTeX entry and citation info

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.

```bibtex
@misc{torch-conv-kan,
  author = {Ivan Drokin},
  title = {Torch Conv KAN},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/IvanDrokin/torch-conv-kan}}
}
```

## References

- [1] Ziming Liu et al., "KAN: Kolmogorov-Arnold Networks", 2024, arXiv. https://arxiv.org/abs/2404.19756
- [2] https://github.com/KindXiaoming/pykan
- [3] https://github.com/Khochawongwat/GRAMKAN