Updated code
Browse files- model_dir/config.json +0 -23
- model_dir/pytorch_model.bin +0 -3
- push_model_to_hfhub.py +6 -3
- register_model.py +6 -12
model_dir/config.json
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"_name_or_path": "shravvvv/SAG-ViT",
|
3 |
-
"architectures": [
|
4 |
-
"SAGViTClassifier"
|
5 |
-
],
|
6 |
-
"d_model": 64,
|
7 |
-
"dim_feedforward": 64,
|
8 |
-
"gcn_hidden": 128,
|
9 |
-
"gcn_out": 64,
|
10 |
-
"hidden_mlp_features": 64,
|
11 |
-
"in_channels": 2560,
|
12 |
-
"model_type": "sagvit",
|
13 |
-
"nhead": 4,
|
14 |
-
"num_classes": 10,
|
15 |
-
"num_layers": 2,
|
16 |
-
"patch_size": [
|
17 |
-
4,
|
18 |
-
4
|
19 |
-
],
|
20 |
-
"torch_dtype": "float32",
|
21 |
-
"transformers_version": "4.47.0",
|
22 |
-
"use_safetensors": true
|
23 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_dir/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:3472578af5e4fd2f9644e94f8502895fc007abea1e6880364c0480b588c474a0
|
3 |
-
size 32491922
|
|
|
|
|
|
|
|
push_model_to_hfhub.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1 |
from transformers import AutoConfig, AutoModel
|
2 |
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
|
3 |
|
|
|
|
|
|
|
|
|
4 |
# Load config and model
|
5 |
-
config =
|
6 |
model = SAGViTClassifier(config)
|
7 |
|
8 |
-
#
|
9 |
-
model.save_pretrained("./model_dir", safe_serialization=False) # Save in PyTorch format
|
10 |
model.push_to_hub("shravvvv/SAG-ViT")
|
|
|
1 |
from transformers import AutoConfig, AutoModel
|
2 |
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
|
3 |
|
4 |
+
|
5 |
+
AutoConfig.register("sagvit", SAGViTConfig)
|
6 |
+
AutoModel.register(SAGViTConfig, SAGViTClassifier)
|
7 |
+
|
8 |
# Load config and model
|
9 |
+
config = SAGViTConfig()
|
10 |
model = SAGViTClassifier(config)
|
11 |
|
12 |
+
# Push model and code
|
|
|
13 |
model.push_to_hub("shravvvv/SAG-ViT")
|
register_model.py
CHANGED
@@ -1,14 +1,8 @@
|
|
1 |
from transformers import AutoConfig, AutoModel
|
2 |
-
from
|
3 |
-
from transformers.models.auto.modeling_auto import MODEL_MAPPING
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
# Register the model
|
12 |
-
MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
|
13 |
-
|
14 |
-
print("Registered model successfully...")
|
|
|
1 |
from transformers import AutoConfig, AutoModel
|
2 |
+
from modeling_sagvit import SAGViTConfig, SAGViTClassifier
|
|
|
3 |
|
4 |
+
# Register Custom Model and Config
|
5 |
+
print("Registering model")
|
6 |
+
AutoConfig.register("sagvit", SAGViTConfig)
|
7 |
+
AutoModel.register(SAGViTConfig, SAGViTClassifier)
|
8 |
+
print("Registration complete")
|
|
|
|
|
|
|
|
|
|