shravvvv commited on
Commit
32db49c
·
1 Parent(s): aa85c94

Updated code

Browse files
model_dir/config.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "_name_or_path": "shravvvv/SAG-ViT",
3
- "architectures": [
4
- "SAGViTClassifier"
5
- ],
6
- "d_model": 64,
7
- "dim_feedforward": 64,
8
- "gcn_hidden": 128,
9
- "gcn_out": 64,
10
- "hidden_mlp_features": 64,
11
- "in_channels": 2560,
12
- "model_type": "sagvit",
13
- "nhead": 4,
14
- "num_classes": 10,
15
- "num_layers": 2,
16
- "patch_size": [
17
- 4,
18
- 4
19
- ],
20
- "torch_dtype": "float32",
21
- "transformers_version": "4.47.0",
22
- "use_safetensors": true
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_dir/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3472578af5e4fd2f9644e94f8502895fc007abea1e6880364c0480b588c474a0
3
- size 32491922
 
 
 
 
push_model_to_hfhub.py CHANGED
@@ -1,10 +1,13 @@
1
  from transformers import AutoConfig, AutoModel
2
  from modeling_sagvit import SAGViTClassifier, SAGViTConfig
3
 
 
 
 
 
4
  # Load config and model
5
- config = AutoConfig.from_pretrained("shravvvv/SAG-ViT")
6
  model = SAGViTClassifier(config)
7
 
8
- # Save model locally before pushing
9
- model.save_pretrained("./model_dir", safe_serialization=False) # Save in PyTorch format
10
  model.push_to_hub("shravvvv/SAG-ViT")
 
1
  from transformers import AutoConfig, AutoModel
2
  from modeling_sagvit import SAGViTClassifier, SAGViTConfig
3
 
4
+
5
+ AutoConfig.register("sagvit", SAGViTConfig)
6
+ AutoModel.register(SAGViTConfig, SAGViTClassifier)
7
+
8
  # Load config and model
9
+ config = SAGViTConfig()
10
  model = SAGViTClassifier(config)
11
 
12
+ # Push model and code
 
13
  model.push_to_hub("shravvvv/SAG-ViT")
register_model.py CHANGED
@@ -1,14 +1,8 @@
1
  from transformers import AutoConfig, AutoModel
2
- from transformers.models.auto.configuration_auto import CONFIG_MAPPING
3
- from transformers.models.auto.modeling_auto import MODEL_MAPPING
4
 
5
- from sagvit_config import SAGViTConfig
6
- from modeling_sagvit import SAGViTClassifier
7
-
8
- # Register the configuration
9
- CONFIG_MAPPING.register("sagvit", SAGViTConfig)
10
-
11
- # Register the model
12
- MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
13
-
14
- print("Registered model successfully...")
 
1
  from transformers import AutoConfig, AutoModel
2
+ from modeling_sagvit import SAGViTConfig, SAGViTClassifier
 
3
 
4
+ # Register Custom Model and Config
5
+ print("Registering model")
6
+ AutoConfig.register("sagvit", SAGViTConfig)
7
+ AutoModel.register(SAGViTConfig, SAGViTClassifier)
8
+ print("Registration complete")