Added files

- __init__.py +1 -0
- hubconf.py +1 -1
- sag_vit_model.py → modeling_sagvit.py +53 -25
- push_model_to_hfhub.py +9 -0
- register_model.py +1 -1
- sagvit_config.py +0 -28
- tests/test_sag_vit_model.py +1 -1
- tests/test_train.py +1 -1
- train.py +1 -1
__init__.py
CHANGED
@@ -0,0 +1 @@
+from .modeling_sagvit import SAGViTClassifier
hubconf.py
CHANGED
@@ -1,6 +1,6 @@
 dependencies = ['torch']
 
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 import torch
 
 def SAGViT(pretrained=False, **kwargs):
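Note: hubconf.py makes the model loadable through torch.hub. A minimal usage sketch, assuming the repo is also reachable as a GitHub path under the same name (torch.hub resolves GitHub repos, not Hugging Face ones; the path here is an assumption, not part of this commit):

import torch

# Hypothetical: load the SAGViT entrypoint defined in hubconf.py
model = torch.hub.load("shravvvv/SAG-ViT", "SAGViT", pretrained=False)
model.eval()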
sag_vit_model.py → modeling_sagvit.py
RENAMED
@@ -1,11 +1,16 @@
 import torch
 from torch import nn
+from transformers import AutoConfig, PretrainedConfig, AutoModel, PreTrainedModel
+from transformers.models.auto import AutoConfig, CONFIG_MAPPING, MODEL_MAPPING
+from transformers.utils import logging
+from transformers.modeling_utils import ModelOutput
 from huggingface_hub import PyTorchModelHubMixin
 
 from torch_geometric.data import Batch
 from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
 from graph_construction import build_graph_from_patches, build_graph_data_from_patches
 
+
 ###############################################################################
 # SAG-ViT Model:
 # This class combines:
@@ -15,7 +20,26 @@ from graph_construction import build_graph_from_patches, build_graph_data_from_patches
 # 4) A final MLP classifier.
 ###############################################################################
 
-class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
+
+# Custom model registration
+class SAGViTConfig(PretrainedConfig):
+    model_type = "sagvit"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.d_model = kwargs.get("d_model", 64)
+        self.dim_feedforward = kwargs.get("dim_feedforward", 64)
+        self.gcn_hidden = kwargs.get("gcn_hidden", 128)
+        self.gcn_out = kwargs.get("gcn_out", 64)
+        self.hidden_mlp_features = kwargs.get("hidden_mlp_features", 64)
+        self.in_channels = kwargs.get("in_channels", 2560)
+        self.nhead = kwargs.get("nhead", 4)
+        self.num_classes = kwargs.get("num_classes", 10)
+        self.num_layers = kwargs.get("num_layers", 2)
+        self.patch_size = kwargs.get("patch_size", (4, 4))
+
+
+class SAGViTClassifier(PreTrainedModel):
     """
     SAG-ViT: Scale-Aware Graph Attention Vision Transformer
 
@@ -32,41 +56,41 @@ class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
     Outputs:
     - out (Tensor): Classification logits (B, num_classes)
     """
-    def __init__(
-        self,
-        patch_size=(4, 4),
-        num_classes=10,
-        d_model=64,
-        nhead=4,
-        num_layers=2,
-        dim_feedforward=64,
-        hidden_mlp_features=64,
-        in_channels=2560, # Derived from patch dimensions and CNN output channels
-        gcn_hidden=128,
-        gcn_out=64
-    ):
-        super(SAGViTClassifier, self).__init__()
+
+    config_class = SAGViTConfig
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.patch_size = config.patch_size
+        self.num_classes = config.num_classes
 
         # CNN feature extractor (frozen pre-trained EfficientNetv2)
         self.cnn = EfficientNetV2FeatureExtractor()
 
         # Graph Attention Network to process patch embeddings
-        self.gcn = GATGNN(in_channels=in_channels, hidden_channels=gcn_hidden, out_channels=gcn_out)
+        self.gcn = GATGNN(
+            in_channels=config.in_channels,
+            hidden_channels=config.gcn_hidden,
+            out_channels=config.gcn_out,
+        )
 
         # Learnable positional embedding for Transformer input
-        self.positional_embedding = nn.Parameter(torch.randn(1, 1, d_model))
+        self.positional_embedding = nn.Parameter(torch.randn(1, 1, config.d_model))
         # Extra embedding token (similar to class token) to summarize global info
-        self.extra_embedding = nn.Parameter(torch.randn(1, d_model))
+        self.extra_embedding = nn.Parameter(torch.randn(1, config.d_model))
 
         # Transformer encoder to capture long-range global dependencies
-        self.transformer_encoder = TransformerEncoder(d_model=d_model, nhead=nhead, num_layers=num_layers, dim_feedforward=dim_feedforward)
+        self.transformer_encoder = TransformerEncoder(
+            d_model=config.d_model,
+            nhead=config.nhead,
+            num_layers=config.num_layers,
+            dim_feedforward=config.dim_feedforward,
+        )
 
         # MLP classification head
-        self.mlp = MLPBlock(d_model, hidden_mlp_features, num_classes)
+        self.mlp = MLPBlock(config.d_model, config.hidden_mlp_features, config.num_classes)
 
-        self.patch_size = patch_size
-        self.num_classes = num_classes
-    def forward(self, x):
+    def forward(self, x, **kwargs):
         # Step 1: High-fidelity feature extraction from CNN
         feature_map = self.cnn(x)
 
@@ -103,5 +127,9 @@ class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
         x_pooled = x_trans.mean(dim=1) # (B, D)
 
         # Classification
-        out = self.mlp(x_pooled)
-        return out
+        logits = self.mlp(x_pooled)
+        return ModelOutput(logits=logits)
+
+
+# Register custom model and config
+CONFIG_MAPPING.register("sagvit", SAGViTConfig)
+MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
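Note: after this rename, construction is config-driven rather than keyword-driven. A minimal sketch of the new flow, assuming defaults that mirror the old keyword arguments; the input resolution (3x224x224) is an assumption about what the EfficientNetV2 backbone expects, not something the diff states:

import torch
from modeling_sagvit import SAGViTConfig, SAGViTClassifier

# Build the model from a config object (defaults: d_model=64, num_classes=10, ...)
config = SAGViTConfig(num_classes=10)
model = SAGViTClassifier(config)

x = torch.randn(2, 3, 224, 224)  # assumed input shape
out = model(x)
print(out.logits.shape)          # expected: torch.Size([2, 10])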
push_model_to_hfhub.py
ADDED
@@ -0,0 +1,9 @@
+from transformers import AutoConfig, AutoModel
+from modeling_sagvit import SAGViTClassifier
+
+# Initialize config and model
+config = AutoConfig.from_pretrained("shravvvv/SAG-ViT")
+model = AutoModel.from_pretrained("shravvvv/SAG-ViT", config=config)
+
+# Push model to the Hub
+model.push_to_hub("shravvvv/SAG-ViT")
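Note: because SAGViTClassifier now subclasses PreTrainedModel, the standard local save/load round trip also works, independent of the Hub. A sketch with a placeholder checkpoint directory:

from modeling_sagvit import SAGViTConfig, SAGViTClassifier

model = SAGViTClassifier(SAGViTConfig())
model.save_pretrained("./sag-vit-checkpoint")  # writes config.json + weights
restored = SAGViTClassifier.from_pretrained("./sag-vit-checkpoint")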
register_model.py
CHANGED
@@ -3,7 +3,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.models.auto.modeling_auto import MODEL_MAPPING
 
 from sagvit_config import SAGViTConfig
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 
 # Register the configuration
 CONFIG_MAPPING.register("sagvit", SAGViTConfig)
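Note: CONFIG_MAPPING and MODEL_MAPPING are transformers internals; the public Auto API exposes an equivalent registration. A sketch of that alternative (not what this commit does), importing the config from its post-commit home in modeling_sagvit.py:

from transformers import AutoConfig, AutoModel
from modeling_sagvit import SAGViTConfig, SAGViTClassifier

# Public-API equivalent of the CONFIG_MAPPING/MODEL_MAPPING calls above
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)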
sagvit_config.py
DELETED
@@ -1,28 +0,0 @@
-from transformers import PretrainedConfig
-
-class SAGViTConfig(PretrainedConfig):
-    model_type = "sagvit"
-
-    def __init__(self,
-                 d_model=64,
-                 dim_feedforward=64,
-                 gcn_hidden=128,
-                 gcn_out=64,
-                 hidden_mlp_features=64,
-                 in_channels=2560,
-                 nhead=4,
-                 num_classes=10,
-                 num_layers=2,
-                 patch_size=(4, 4),
-                 **kwargs):
-        super().__init__(**kwargs)
-        self.d_model = d_model
-        self.dim_feedforward = dim_feedforward
-        self.gcn_hidden = gcn_hidden
-        self.gcn_out = gcn_out
-        self.hidden_mlp_features = hidden_mlp_features
-        self.in_channels = in_channels
-        self.nhead = nhead
-        self.num_classes = num_classes
-        self.num_layers = num_layers
-        self.patch_size = patch_size
tests/test_sag_vit_model.py
CHANGED
@@ -1,6 +1,6 @@
 import unittest
 import torch
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 
 class TestSAGViTModel(unittest.TestCase):
     def test_forward_pass(self):
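Note: with the constructor now taking a config, a forward-pass test would be shaped roughly like this. A sketch only: the actual test body is not shown in the diff, and the input resolution is assumed.

import unittest
import torch
from modeling_sagvit import SAGViTConfig, SAGViTClassifier

class TestSAGViTModelSketch(unittest.TestCase):
    def test_forward_pass(self):
        # The model is now built from a config object rather than keyword arguments
        model = SAGViTClassifier(SAGViTConfig(num_classes=10))
        x = torch.randn(1, 3, 224, 224)  # assumed input resolution
        out = model(x)
        self.assertEqual(tuple(out.logits.shape), (1, 10))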
tests/test_train.py
CHANGED
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
 import torch
 import torch.nn as nn
 from train import train_model
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 
 class TestTrain(unittest.TestCase):
     @patch("train.optim.Adam")
train.py
CHANGED
@@ -8,7 +8,7 @@ from sklearn.metrics import (precision_score, recall_score, f1_score,
                              roc_auc_score, cohen_kappa_score, matthews_corrcoef,
                              confusion_matrix)
 
-from sag_vit_model import SAGViTClassifier
+from modeling_sagvit import SAGViTClassifier
 from data_loader import get_dataloaders
 
 #####################################################################