---
# VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials

This model is a convolutional version of the Kolmogorov-Arnold Network with a VGG11-like architecture, pretrained on the ImageNet1k dataset. KANs were originally presented in [1, 2]; the Gram-polynomial version of KAN was introduced in [3]. For more details, visit our [torch-conv-kan](https://github.com/IvanDrokin/torch-conv-kan) repository on GitHub.

## Model description

The model consists of 10 consecutive Gram ConvKAN (KAGN) layers with InstanceNorm2d and polynomial degree 5, followed by global average pooling and a linear classification head:

1. KAGN convolution, 32 filters, 3x3
2. Max pooling, 2x2
3. KAGN convolution, 64 filters, 3x3
4. Max pooling, 2x2
5. KAGN convolution, 128 filters, 3x3
6. KAGN convolution, 128 filters, 3x3
7. Max pooling, 2x2
8. KAGN convolution, 256 filters, 3x3
9. KAGN convolution, 256 filters, 3x3
10. Max pooling, 2x2
11. KAGN convolution, 256 filters, 3x3
12. KAGN convolution, 256 filters, 3x3
13. Max pooling, 2x2
14. KAGN convolution, 256 filters, 3x3
15. KAGN convolution, 256 filters, 3x3
16. Global average pooling
17. Output layer, 1000 nodes

![model image](https://github.com/IvanDrokin/torch-conv-kan/blob/main/assets/vgg_kagn_11_v2.png)
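
As a quick sanity check on the layer list above, the spatial resolution through the five 2x2 max-pooling stages can be traced with simple arithmetic. This is an illustrative sketch, assuming 224x224 inputs and size-preserving padding in the KAGN convolutions:

```python
def trace_spatial(size=224, n_pools=5):
    """Halve the spatial resolution at each 2x2 max-pooling stage."""
    sizes = [size]
    for _ in range(n_pools):
        size //= 2
        sizes.append(size)
    return sizes

# Feature-map side length before global average pooling:
print(trace_spatial())  # [224, 112, 56, 28, 14, 7]
```

The 7x7 map after the last pooling stage is what global average pooling collapses into a single value per channel before the output layer.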

## Intended uses & limitations

You can use the raw model for image classification, or use it as a pretrained backbone for further fine-tuning on a downstream task.

### How to use

First, clone the repository and install its dependencies:

```
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```

Then you can initialize the model and load the pretrained weights:

```python
import torch
from models import vggkagn

model = vggkagn(3,
                1000,
                groups=1,
                degree=5,
                dropout=0.15,
                l1_decay=0,
                dropout_linear=0.25,
                width_scale=2,
                vgg_type='VGG11v2',
                expected_feature_shape=(1, 1),
                affine=True)
model.from_pretrained('Brivangl/vgg_kagn11_v2')
```

The transforms used for validation on ImageNet1k:

```python
import torch
from torchvision.transforms import v2

transforms_val = v2.Compose([
    v2.ToImage(),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
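
To make the preprocessing concrete, the geometry of `Resize(256)` followed by `CenterCrop(224)` can be traced with plain arithmetic. This is only an illustrative sketch of the image sizes involved; torchvision does the actual tensor work:

```python
def resize_then_crop(w, h, short_side=256, crop=224):
    """Resize so the shorter side equals `short_side` (aspect ratio preserved),
    then take a centered `crop` x `crop` window, mirroring transforms_val."""
    scale = short_side / min(w, h)
    resized = (round(w * scale), round(h * scale))
    return resized, (crop, crop)

# A 640x480 image is resized to 341x256, then center-cropped to 224x224.
print(resize_then_crop(640, 480))  # ((341, 256), (224, 224))
```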

## Training data

This model was trained on the ImageNet1k dataset (1,281,167 images in the training set).

## Training procedure

The model was trained for 200 full epochs with the AdamW optimizer, using the following parameters:

```python
{'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
 'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}
```
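
The `lr_warmup_steps`, `lr_power`, and `lr_end` entries suggest a linear warmup followed by polynomial decay. Below is a minimal sketch of such a schedule; the `total_steps` value is a hypothetical placeholder, not a figure from the actual training run:

```python
def lr_at(step, base_lr=9e-4, warmup=7500, total_steps=100_000, power=0.3, lr_end=1e-7):
    """Linear warmup to base_lr, then polynomial decay to lr_end."""
    if step < warmup:
        return base_lr * step / warmup
    progress = min(1.0, (step - warmup) / (total_steps - warmup))
    return (base_lr - lr_end) * (1.0 - progress) ** power + lr_end

# The rate ramps up to base_lr over the warmup, then decays toward lr_end.
print(lr_at(0), lr_at(7500), lr_at(100_000))
```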
And the following augmentations:

```python
import torch
from torchvision.transforms import v2
from torchvision.transforms import AutoAugmentPolicy

transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

## Evaluation results

On the ImageNet1k validation set:

| Accuracy, top-1 | Accuracy, top-5 | AUC (ovo) | AUC (ovr) |
|:---------------:|:---------------:|:---------:|:---------:|
|      59.1       |      82.29      |   99.43   |   99.43   |

On the ImageNet1k test set: coming soon.
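
For reference, the top-1 and top-5 accuracy figures reported above correspond to the standard top-k metric over model logits. The following is a plain-Python sketch of that metric, not the evaluation code used for this model:

```python
def topk_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    correct = 0
    for scores, label in zip(logits, labels):
        topk = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        correct += label in topk
    return correct / len(labels)

# Toy example with three samples and three classes:
logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
labels = [1, 2, 0]
print(topk_accuracy(logits, labels, k=1))  # one of three top-1 hits
print(topk_accuracy(logits, labels, k=2))  # two of three top-2 hits
```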

### BibTeX entry and citation info

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.