shravvvv committed
Commit b99e299 · 1 Parent(s): 0fe9461

Added files

__init__.py CHANGED
@@ -0,0 +1 @@
+from .modeling_sagvit import SAGViTClassifier
hubconf.py CHANGED
@@ -1,6 +1,6 @@
 dependencies = ['torch']
 
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 import torch
 
 def SAGViT(pretrained=False, **kwargs):
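
hubconf.py is the torch.hub entry point for this repository. A minimal usage sketch, assuming the hubconf is served from a GitHub repository named shravvvv/SAG-ViT and that SAGViT forwards its keyword arguments to SAGViTClassifier (neither is confirmed by this diff):

    import torch

    # Hypothetical repo path; torch.hub.load resolves "owner/repo" against GitHub by default.
    model = torch.hub.load("shravvvv/SAG-ViT", "SAGViT", pretrained=False)
    model.eval()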
sag_vit_model.py → modeling_sagvit.py RENAMED
@@ -1,11 +1,16 @@
 import torch
 from torch import nn
+from transformers import AutoConfig, PretrainedConfig, AutoModel, PreTrainedModel
+from transformers.models.auto import AutoConfig, CONFIG_MAPPING, MODEL_MAPPING
+from transformers.utils import logging
+from transformers.modeling_utils import ModelOutput
 from huggingface_hub import PyTorchModelHubMixin
 
 from torch_geometric.data import Batch
 from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
 from graph_construction import build_graph_from_patches, build_graph_data_from_patches
 
+
 ###############################################################################
 # SAG-ViT Model:
 # This class combines:
@@ -15,7 +20,26 @@ from graph_construction import build_graph_from_patches, build_graph_data_from_p
 # 4) A final MLP classifier.
 ###############################################################################
 
-class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
+
+# Custom model registration
+class SAGViTConfig(PretrainedConfig):
+    model_type = "sagvit"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.d_model = kwargs.get("d_model", 64)
+        self.dim_feedforward = kwargs.get("dim_feedforward", 64)
+        self.gcn_hidden = kwargs.get("gcn_hidden", 128)
+        self.gcn_out = kwargs.get("gcn_out", 64)
+        self.hidden_mlp_features = kwargs.get("hidden_mlp_features", 64)
+        self.in_channels = kwargs.get("in_channels", 2560)
+        self.nhead = kwargs.get("nhead", 4)
+        self.num_classes = kwargs.get("num_classes", 10)
+        self.num_layers = kwargs.get("num_layers", 2)
+        self.patch_size = kwargs.get("patch_size", (4, 4))
+
+
+class SAGViTClassifier(PreTrainedModel):
     """
     SAG-ViT: Scale-Aware Graph Attention Vision Transformer
 
@@ -32,41 +56,41 @@ class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
     Outputs:
     - out (Tensor): Classification logits (B, num_classes)
     """
-    def __init__(
-        self,
-        patch_size=(4,4),
-        num_classes=10,
-        d_model=64,
-        nhead=4,
-        num_layers=2,
-        dim_feedforward=64,
-        hidden_mlp_features=64,
-        in_channels=2560, # Derived from patch dimensions and CNN output channels
-        gcn_hidden=128,
-        gcn_out=64
-    ):
-        super(SAGViTClassifier, self).__init__()
+
+    config_class = SAGViTConfig
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.patch_size = config.patch_size
+        self.num_classes = config.num_classes
 
         # CNN feature extractor (frozen pre-trained EfficientNetv2)
        self.cnn = EfficientNetV2FeatureExtractor()
 
         # Graph Attention Network to process patch embeddings
-        self.gcn = GATGNN(in_channels=in_channels, hidden_channels=gcn_hidden, out_channels=gcn_out)
+        self.gcn = GATGNN(
+            in_channels=config.in_channels,
+            hidden_channels=config.gcn_hidden,
+            out_channels=config.gcn_out,
+        )
 
         # Learnable positional embedding for Transformer input
-        self.positional_embedding = nn.Parameter(torch.randn(1, 1, d_model))
+        self.positional_embedding = nn.Parameter(torch.randn(1, 1, config.d_model))
         # Extra embedding token (similar to class token) to summarize global info
-        self.extra_embedding = nn.Parameter(torch.randn(1, d_model))
+        self.extra_embedding = nn.Parameter(torch.randn(1, config.d_model))
 
         # Transformer encoder to capture long-range global dependencies
-        self.transformer_encoder = TransformerEncoder(d_model, nhead, num_layers, dim_feedforward)
+        self.transformer_encoder = TransformerEncoder(
+            d_model=config.d_model,
+            nhead=config.nhead,
+            num_layers=config.num_layers,
+            dim_feedforward=config.dim_feedforward,
+        )
 
         # MLP classification head
-        self.mlp = MLPBlock(d_model, hidden_mlp_features, num_classes)
+        self.mlp = MLPBlock(config.d_model, config.hidden_mlp_features, config.num_classes)
 
-        self.patch_size = patch_size
-
-    def forward(self, x):
+    def forward(self, x, **kwargs):
         # Step 1: High-fidelity feature extraction from CNN
         feature_map = self.cnn(x)
 
@@ -103,5 +127,9 @@ class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
         x_pooled = x_trans.mean(dim=1) # (B, D)
 
         # Classification
-        out = self.mlp(x_pooled)
-        return out
+        logits = self.mlp(x_pooled)
+        return ModelOutput(logits=logits)
+
+# Register custom model and config
+CONFIG_MAPPING.register("sagvit", SAGViTConfig)
+MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
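
With this change the classifier is constructed from a SAGViTConfig instead of loose keyword arguments, and forward returns a ModelOutput. A minimal local sketch of the new entry point, assuming the default config values and a 3x224x224 input (the actual input resolution depends on the EfficientNetV2 backbone and graph construction settings, which this diff does not fix):

    import torch
    from modeling_sagvit import SAGViTConfig, SAGViTClassifier

    config = SAGViTConfig(num_classes=10)  # defaults: d_model=64, patch_size=(4, 4), in_channels=2560, ...
    model = SAGViTClassifier(config)
    model.eval()

    dummy = torch.randn(2, 3, 224, 224)  # assumed input size, for illustration only
    with torch.no_grad():
        out = model(dummy)
    print(out["logits"].shape)  # expected: torch.Size([2, 10])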
push_model_to_hfhub.py ADDED
@@ -0,0 +1,9 @@
+from transformers import AutoConfig, AutoModel
+from modeling_sagvit import SAGViTClassifier
+
+# Initialize config and model
+config = AutoConfig.from_pretrained("shravvvv/SAG-ViT")
+model = AutoModel.from_pretrained("shravvvv/SAG-ViT", config=config)
+
+# Push model to the Hub
+model.push_to_hub("shravvvv/SAG-ViT")
register_model.py CHANGED
@@ -3,7 +3,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.models.auto.modeling_auto import MODEL_MAPPING
 
 from sagvit_config import SAGViTConfig
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 
 # Register the configuration
 CONFIG_MAPPING.register("sagvit", SAGViTConfig)
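
Registering the config and model with the Auto mappings is what lets the generic AutoConfig / AutoModel entry points resolve the "sagvit" model type without importing SAGViTClassifier directly. A minimal sketch of what the registration enables, assuming register_model (or modeling_sagvit, which now performs the registration itself) has been imported first so the registration side effect has run:

    from transformers import AutoConfig, AutoModel

    import register_model  # noqa: F401 -- side effect: registers SAGViTConfig / SAGViTClassifier

    config = AutoConfig.for_model("sagvit", num_classes=10)  # resolves the registered config class
    model = AutoModel.from_config(config)                    # resolves the registered model class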
sagvit_config.py DELETED
@@ -1,28 +0,0 @@
-from transformers import PretrainedConfig
-
-class SAGViTConfig(PretrainedConfig):
-    model_type = "sagvit"
-
-    def __init__(self,
-                 d_model=64,
-                 dim_feedforward=64,
-                 gcn_hidden=128,
-                 gcn_out=64,
-                 hidden_mlp_features=64,
-                 in_channels=2560,
-                 nhead=4,
-                 num_classes=10,
-                 num_layers=2,
-                 patch_size=(4, 4),
-                 **kwargs):
-        super().__init__(**kwargs)
-        self.d_model = d_model
-        self.dim_feedforward = dim_feedforward
-        self.gcn_hidden = gcn_hidden
-        self.gcn_out = gcn_out
-        self.hidden_mlp_features = hidden_mlp_features
-        self.in_channels = in_channels
-        self.nhead = nhead
-        self.num_classes = num_classes
-        self.num_layers = num_layers
-        self.patch_size = patch_size
tests/test_sag_vit_model.py CHANGED
@@ -1,6 +1,6 @@
 import unittest
 import torch
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 
 class TestSAGViTModel(unittest.TestCase):
     def test_forward_pass(self):
tests/test_train.py CHANGED
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
 import torch
 import torch.nn as nn
 from train import train_model
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 
 class TestTrain(unittest.TestCase):
     @patch("train.optim.Adam")
train.py CHANGED
@@ -8,7 +8,7 @@ from sklearn.metrics import (precision_score, recall_score, f1_score,
                              roc_auc_score, cohen_kappa_score, matthews_corrcoef,
                              confusion_matrix)
 
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 from data_loader import get_dataloaders
 
 #####################################################################