shravvvv committed
Commit 41a1cb5 · Parent: 4708396

Removed Files

.gitignore DELETED
@@ -1,3 +0,0 @@
- data/
- __pycache__
- tests/__pycache__
 
LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) 2024 Shravan Venkatraman
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
 
data_loader.py DELETED
@@ -1,47 +0,0 @@
- import os
- from torch.utils.data import DataLoader, random_split
- from torchvision import datasets, transforms
-
- def get_dataloaders(data_dir="path/to/data/dir", batch_size=512, train_split=0.8, img_size=224, num_workers=4):
-     """
-     Returns training and validation dataloaders for an image classification dataset.
-
-     Parameters:
-     - data_dir (str): Path to the directory containing image data in a folder structure compatible with ImageFolder.
-     - batch_size (int): Number of samples per batch.
-     - train_split (float): Fraction of data to use for training. The remainder is used for validation.
-     - img_size (int): Side length to which all images are resized (img_size x img_size), after the size check below.
-     - num_workers (int): Number of worker processes for data loading.
-
-     Image Size Validation:
-     - Minimum allowed image size: 49x49 pixels.
-     - If img_size is less than 49, a ValueError is raised.
-
-     Returns:
-     - train_dataloader (DataLoader): DataLoader for the training split.
-     - val_dataloader (DataLoader): DataLoader for the validation split.
-     """
-
-     # Check that the requested image size is valid
-     if img_size < 49:
-         raise ValueError(f"Image size must be at least 49x49 pixels, but got {img_size}x{img_size}.")
-
-     transform = transforms.Compose([
-         transforms.Resize((img_size, img_size)),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-     ])
-
-     # Load the full dataset
-     full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
-
-     # Split into training and validation sets
-     train_size = int(train_split * len(full_dataset))
-     val_size = len(full_dataset) - train_size
-     train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
-
-     # Create dataloaders; only the training split needs shuffling
-     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-     val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
-     return train_dataloader, val_dataloader
 
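For context, a minimal usage sketch of the deleted loader, assuming an ImageFolder-style directory with one subfolder per class (the "data/PlantVillage" path is the example used in train.py below):

import torch
from data_loader import get_dataloaders

train_loader, val_loader = get_dataloaders(
    data_dir="data/PlantVillage",  # substitute your own dataset root
    batch_size=32,
    train_split=0.8,
    img_size=224,  # must be at least 49, per the check above
    num_workers=4,
)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 3, 224, 224])
print(labels.shape)  # torch.Size([32])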
 
graph_construction.py DELETED
@@ -1,138 +0,0 @@
- import torch
- import networkx as nx
- from torch_geometric.utils import from_networkx
-
- ####################################################################
- # These functions reflect the methods described in Sections 3.1 and 3.2
- # of the SAG-ViT paper, where high-fidelity feature patches are extracted
- # from the CNN feature maps and organized into a graph structure.
- ####################################################################
-
- def extract_patches(feature_map, patch_size=(4, 4)):
-     """
-     Extracts non-overlapping patches from a feature map to form nodes in a graph.
-
-     Parameters:
-     - feature_map (Tensor): The feature map from the CNN of shape (B, C, H', W').
-       H' and W' are reduced spatial dimensions after CNN feature extraction.
-     - patch_size (tuple): Spatial size (height, width) of each patch.
-
-     Returns:
-     - patches (Tensor): Tensor of shape (B, N, C, patch_h, patch_w), where N is the number of patches per image.
-     """
-     b, c, h, w = feature_map.size()
-     patch_h, patch_w = patch_size
-
-     # Unfold extracts sliding patches; a step equal to the patch size makes them non-overlapping
-     patches = feature_map.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
-
-     # Rearrange to have patches as separate units
-     patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
-     patches = patches.view(b, -1, c, patch_h, patch_w)
-     return patches
-
- def construct_graph_from_patch(patch_index, patch_shape, image_shape):
-     """
-     Constructs edges between patch nodes based on spatial adjacency (k-connectivity).
-     This follows the approach described in Section 3.2 of SAG-ViT, where patches
-     are arranged in a grid and connected to their spatial neighbors.
-
-     Parameters:
-     - patch_index (int): Index of the current patch node.
-     - patch_shape (tuple): (patch_height, patch_width).
-     - image_shape (tuple): (height, width) of the feature map.
-
-     Returns:
-     - G (nx.Graph): A graph with a single node and edges to its neighbors (to be composed globally).
-     """
-     G = nx.Graph()
-
-     # Compute grid dimensions (how many patches along height and width)
-     grid_height = image_shape[0] // patch_shape[0]
-     grid_width = image_shape[1] // patch_shape[1]
-
-     # Current node index in a flattened grid
-     current_node = patch_index
-
-     G.add_node(current_node)
-
-     # 8-neighborhood connectivity (up, down, left, right, diagonals)
-     neighbor_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1),
-                         (-1, -1), (-1, 1), (1, -1), (1, 1)]
-
-     # Recover row, col from patch_index
-     row = current_node // grid_width
-     col = current_node % grid_width
-
-     for dr, dc in neighbor_offsets:
-         neighbor_row = row + dr
-         neighbor_col = col + dc
-         if 0 <= neighbor_row < grid_height and 0 <= neighbor_col < grid_width:
-             neighbor_node = neighbor_row * grid_width + neighbor_col
-             G.add_edge(current_node, neighbor_node)
-
-     return G
-
- def build_graph_from_patches(feature_map, patch_size=(4, 4)):
-     """
-     Builds a global graph for each image in the batch, where each node corresponds
-     to a patch and edges represent spatial adjacency. This graph captures local
-     spatial relationships of the patches, as outlined in Sections 3.1 and 3.2 of SAG-ViT.
-
-     Parameters:
-     - feature_map (Tensor): CNN output (B, C, H', W').
-     - patch_size (tuple): Size of each patch (patch_h, patch_w).
-
-     Returns:
-     - G_global_batch (list): A list of NetworkX graphs, one per image in the batch.
-     - patches (Tensor): The extracted patches (B, N, C, patch_h, patch_w).
-     """
-     patches = extract_patches(feature_map, patch_size)
-     batch_size = patches.size(0)
-
-     grid_height = feature_map.size(2) // patch_size[0]
-     grid_width = feature_map.size(3) // patch_size[1]
-     num_patches = grid_height * grid_width
-
-     G_global_batch = []
-     for batch_idx in range(batch_size):
-         G_global = nx.Graph()
-         # Construct a global graph by composing individual patch-based graphs
-         for patch_idx in range(num_patches):
-             G_patch = construct_graph_from_patch(
-                 patch_index=patch_idx,
-                 patch_shape=patch_size,
-                 image_shape=(feature_map.size(2), feature_map.size(3))
-             )
-             G_global = nx.compose(G_global, G_patch)
-         G_global_batch.append(G_global)
-
-     return G_global_batch, patches
-
- def build_graph_data_from_patches(G_global_batch, patches):
-     """
-     Converts NetworkX graphs and associated patches into PyTorch Geometric Data objects.
-     Each node's features are the corresponding patch flattened into a vector.
-
-     Parameters:
-     - G_global_batch (list): List of global graphs (one per image) in NetworkX form.
-     - patches (Tensor): (B, N, C, patch_h, patch_w) patch tensor.
-
-     Returns:
-     - data_list (list): List of PyTorch Geometric Data objects, where data.x holds the node features
-       and data.edge_index holds the adjacency from the constructed graph.
-     """
-     from_networkx_ = from_networkx  # local alias to avoid confusion
-
-     data_list = []
-     batch_size, num_patches, channels, patch_h, patch_w = patches.size()
-
-     for batch_idx, G_global in enumerate(G_global_batch):
-         # Flatten each patch into a feature vector
-         node_features = patches[batch_idx].view(num_patches, -1)
-
-         G_pygeom = from_networkx_(G_global)
-         G_pygeom.x = node_features
-         data_list.append(G_pygeom)
-
-     return data_list
 
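A minimal sketch of the patch-to-graph pipeline above, mirroring the deleted unit tests (a dummy tensor stands in for the CNN feature map):

import torch
from graph_construction import build_graph_from_patches, build_graph_data_from_patches

# Dummy CNN feature map: batch of 2, 16 channels, 32x32 spatial grid.
feature_map = torch.randn(2, 16, 32, 32)

# 4x4 patches on a 32x32 map -> an 8x8 grid of 64 nodes per image.
graphs, patches = build_graph_from_patches(feature_map, patch_size=(4, 4))
data_list = build_graph_data_from_patches(graphs, patches)

print(len(data_list))                 # 2 (one PyG Data object per image)
print(data_list[0].x.shape)           # torch.Size([64, 256]): C * 4 * 4 = 256 features per node
print(data_list[0].edge_index.shape)  # (2, E) adjacency from the 8-neighborhood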
 
hubconf.py DELETED
@@ -1,23 +0,0 @@
- dependencies = ['torch']
-
- from sag_vit_model import SAGViTClassifier
- import torch
-
- def SAGViT(pretrained=False, **kwargs):
-     """
-     SAG-ViT model endpoint.
-     Args:
-         pretrained (bool): If True, loads pretrained weights.
-         **kwargs: Additional arguments for the model.
-     Returns:
-         model (nn.Module): The SAG-ViT model as proposed in the paper
-         "SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with
-         Graph Attention for Vision Transformers".
-         https://doi.org/10.48550/arXiv.2411.09420
-     """
-     model = SAGViTClassifier(**kwargs)
-     if pretrained:
-         checkpoint = ''  # no checkpoint URL is set in this commit, so pretrained=True cannot succeed as-is
-         state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
-         model.load_state_dict(state_dict)
-     return model
 
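With this hubconf.py in place, the model could be pulled via torch.hub; a hedged sketch (the "shravvvv/SAG-ViT" repo path is an assumption based on the commit author, and pretrained=True would fail since no checkpoint URL is wired up):

import torch

# Hypothetical hub path; adjust to the actual GitHub repo hosting hubconf.py.
model = torch.hub.load("shravvvv/SAG-ViT", "SAGViT", pretrained=False, num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])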
 
images/SAG-ViT.png DELETED
Binary file (358 kB)
 
model_components.py DELETED
@@ -1,119 +0,0 @@
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torch_geometric.nn import GATConv, global_mean_pool
-
- from torchvision import models
-
- ###############################################################
- # These modules correspond to core building blocks of SAG-ViT:
- # 1. A CNN feature extractor for high-fidelity multi-scale feature maps.
- # 2. A Graph Attention Network (GAT) to refine patch embeddings.
- # 3. A Transformer Encoder to capture global long-range dependencies.
- # 4. An MLP classifier head.
- ###############################################################
-
- class EfficientNetV2FeatureExtractor(nn.Module):
-     """
-     Extracts multi-scale, spatially rich, and semantically meaningful feature maps
-     from images using an (optionally pre-trained) EfficientNetV2-S model. This
-     corresponds to Section 3.1, where a CNN backbone (EfficientNetV2-S) produces
-     rich feature maps that preserve semantic information at multiple scales.
-     """
-     def __init__(self, pretrained=False):
-         super(EfficientNetV2FeatureExtractor, self).__init__()
-
-         # Load EfficientNetV2-S, optionally with ImageNet weights
-         efficientnet = models.efficientnet_v2_s(
-             weights="IMAGENET1K_V1" if pretrained else None
-         )
-
-         # Keep layers up to the last block before downsampling below 16x16
-         self.extractor = nn.Sequential(*list(efficientnet.features.children())[:-2])
-
-         # Freeze the extractor parameters so the backbone is not fine-tuned
-         for param in self.extractor.parameters():
-             param.requires_grad = False
-
-     def forward(self, x):
-         """
-         Forward pass through the CNN backbone.
-
-         Input:
-         - x (Tensor): Input images of shape (B, 3, H, W)
-
-         Output:
-         - features (Tensor): Extracted feature map of shape (B, C, H', W'),
-           where H' and W' are reduced spatial dimensions.
-         """
-         features = self.extractor(x)
-         return features
-
- class GATGNN(nn.Module):
-     """
-     A Graph Attention Network (GAT) that processes patch-graph embeddings.
-     This module corresponds to the Graph Attention stage (Section 3.3),
-     refining local relationships between patches in a learned manner.
-     """
-     def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
-         super(GATGNN, self).__init__()
-         # GAT layers:
-         # The first layer maps raw patch embeddings to a higher-level representation.
-         self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
-         # The second layer produces final node embeddings with a single head.
-         self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
-         self.pool = global_mean_pool
-
-     def forward(self, data):
-         """
-         Input:
-         - data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
-
-         Output:
-         - x (Tensor): Aggregated graph-level embedding after mean pooling.
-         """
-         x, edge_index, batch = data.x, data.edge_index, data.batch
-         x = F.elu(self.conv1(x, edge_index))
-         x = self.conv2(x, edge_index)
-         x = self.pool(x, batch)
-         return x
-
- class TransformerEncoder(nn.Module):
-     """
-     A Transformer encoder to capture long-range dependencies among patch embeddings.
-     Integrates global dependencies after GAT processing, as per Section 3.3.
-     """
-     def __init__(self, d_model, nhead, num_layers, dim_feedforward):
-         super(TransformerEncoder, self).__init__()
-         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
-         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-
-     def forward(self, x):
-         """
-         Input:
-         - x (Tensor): Sequence of patch embeddings with shape (B, N, D).
-
-         Output:
-         - (Tensor): Transformed embeddings with global relationships integrated (B, N, D).
-         """
-         # nn.TransformerEncoder defaults to (N, B, D) input, so transpose first
-         x = x.transpose(0, 1)  # (N, B, D)
-         x = self.transformer_encoder(x)
-         x = x.transpose(0, 1)  # (B, N, D)
-         return x
-
- class MLPBlock(nn.Module):
-     """
-     An MLP classification head that maps final global embeddings to classification logits.
-     """
-     def __init__(self, in_features, hidden_features, out_features):
-         super(MLPBlock, self).__init__()
-         self.mlp = nn.Sequential(
-             nn.Linear(in_features, hidden_features),
-             nn.ReLU(),
-             nn.Linear(hidden_features, out_features)
-         )
-
-     def forward(self, x):
-         return self.mlp(x)
 
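Each block can be exercised in isolation; a minimal sketch (the backbone's output shape is an expectation for this particular truncation of EfficientNetV2-S at 224x224 input, consistent with the in_channels=2560 = 160*4*4 default used below):

import torch
from model_components import EfficientNetV2FeatureExtractor, TransformerEncoder, MLPBlock

backbone = EfficientNetV2FeatureExtractor(pretrained=False)
backbone.eval()
with torch.no_grad():
    fmap = backbone(torch.randn(1, 3, 224, 224))
print(fmap.shape)  # expected: torch.Size([1, 160, 14, 14])

encoder = TransformerEncoder(d_model=64, nhead=4, num_layers=2, dim_feedforward=64)
tokens = torch.randn(1, 5, 64)   # (B, N, D)
print(encoder(tokens).shape)     # shape-preserving: torch.Size([1, 5, 64])

head = MLPBlock(in_features=64, hidden_features=64, out_features=10)
print(head(torch.randn(1, 64)).shape)  # torch.Size([1, 10])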
 
requirements.txt DELETED
@@ -1,12 +0,0 @@
- numpy==1.26.4
- pandas==2.2.3
- matplotlib==3.7.5
- seaborn==0.12.2
- tqdm==4.66.4
- psutil==5.9.3
- pynvml==11.4.1
- scikit-learn==1.2.2
- torch==2.4.0
- torch-geometric==2.6.1
- torchvision==0.19.0
- networkx==3.3
 
sag_vit_model.py DELETED
@@ -1,106 +0,0 @@
- import torch
- from torch import nn
-
- from torch_geometric.data import Batch
- from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
- from graph_construction import build_graph_from_patches, build_graph_data_from_patches
-
- ###############################################################################
- # SAG-ViT Model:
- # This class combines:
- # 1) a CNN backbone to produce high-fidelity feature maps (Section 3.1),
- # 2) graph construction and a GAT to refine local patch embeddings (Sections 3.2 and 3.3),
- # 3) a Transformer encoder to capture global relationships (Section 3.3),
- # 4) a final MLP classifier.
- ###############################################################################
-
- class SAGViTClassifier(nn.Module):
-     """
-     SAG-ViT: Scale-Aware Graph Attention Vision Transformer
-
-     This model integrates the following steps:
-     - Extract multi-scale features from images using a CNN backbone (EfficientNetV2-S here).
-     - Partition the feature map into patches and build a graph where each node is a patch.
-     - Use a Graph Attention Network (GAT) to refine patch embeddings based on local spatial relationships.
-     - Utilize a Transformer encoder to model long-range dependencies and integrate multi-scale information.
-     - Finally, classify the resulting representation into the desired classes.
-
-     Inputs:
-     - x (Tensor): Input images (B, 3, H, W)
-
-     Outputs:
-     - out (Tensor): Classification logits (B, num_classes)
-     """
-     def __init__(
-         self,
-         patch_size=(4, 4),
-         num_classes=10,
-         d_model=64,
-         nhead=4,
-         num_layers=2,
-         dim_feedforward=64,
-         hidden_mlp_features=64,
-         in_channels=2560,  # flattened patch dim: C * patch_h * patch_w (160 * 4 * 4 for the EfficientNetV2-S backbone)
-         gcn_hidden=128,
-         gcn_out=64
-     ):
-         super(SAGViTClassifier, self).__init__()
-
-         # CNN feature extractor (frozen EfficientNetV2-S backbone)
-         self.cnn = EfficientNetV2FeatureExtractor()
-
-         # Graph Attention Network to process patch embeddings
-         self.gcn = GATGNN(in_channels=in_channels, hidden_channels=gcn_hidden, out_channels=gcn_out)
-
-         # Learnable positional embedding for Transformer input
-         self.positional_embedding = nn.Parameter(torch.randn(1, 1, d_model))
-         # Extra embedding token (similar to a class token) to summarize global info
-         self.extra_embedding = nn.Parameter(torch.randn(1, d_model))
-
-         # Transformer encoder to capture long-range global dependencies
-         self.transformer_encoder = TransformerEncoder(d_model, nhead, num_layers, dim_feedforward)
-
-         # MLP classification head
-         self.mlp = MLPBlock(d_model, hidden_mlp_features, num_classes)
-
-         self.patch_size = patch_size
-
-     def forward(self, x):
-         # Step 1: High-fidelity feature extraction from the CNN
-         feature_map = self.cnn(x)
-
-         # Step 2: Build graphs from patches
-         G_global_batch, patches = build_graph_from_patches(feature_map, self.patch_size)
-
-         # Step 3: Convert to PyG Data format and batch
-         data_list = build_graph_data_from_patches(G_global_batch, patches)
-         device = x.device
-         batch = Batch.from_data_list(data_list).to(device)
-
-         # Step 4: GAT stage; global mean pooling reduces each patch graph to a
-         # single vector, so x_gcn has shape (B, gcn_out)
-         x_gcn = self.gcn(batch)
-
-         # Step 5: The Transformer expects a sequence, so treat each image-level
-         # embedding as one "patch token"
-         B = x.size(0)
-         patch_embeddings = x_gcn.unsqueeze(1)  # (B, 1, D)
-
-         # Add positional embedding
-         patch_embeddings = patch_embeddings + self.positional_embedding  # (B, 1, D)
-
-         # Append an extra learnable embedding (like a CLS token)
-         patch_embeddings = torch.cat([patch_embeddings, self.extra_embedding.unsqueeze(0).expand(B, -1, -1)], dim=1)  # (B, 2, D)
-
-         # Step 6: Transformer encoder
-         x_trans = self.transformer_encoder(patch_embeddings)
-
-         # Step 7: Global pooling (mean over the token dimension)
-         x_pooled = x_trans.mean(dim=1)  # (B, D)
-
-         # Classification
-         out = self.mlp(x_pooled)
-         return out
 
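End to end, the classifier runs as in the deleted test suite; a minimal sketch:

import torch
from sag_vit_model import SAGViTClassifier

# Defaults match a 224x224 input with the EfficientNetV2-S backbone.
model = SAGViTClassifier(num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])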
 
tests/test_graph_construction.py DELETED
@@ -1,39 +0,0 @@
- import unittest
- import torch
- from graph_construction import extract_patches, build_graph_from_patches, build_graph_data_from_patches
-
- class TestGraphConstruction(unittest.TestCase):
-     def test_extract_patches_shape(self):
-         # Create a dummy feature map: B=2, C=16, H=32, W=32
-         feature_map = torch.randn(2, 16, 32, 32)
-         patches = extract_patches(feature_map, patch_size=(4, 4))
-         # Check dimensions: number_of_patches = (H/4)*(W/4) = 8*8 = 64 per image
-         self.assertEqual(patches.shape, (2, 64, 16, 4, 4))
-
-     def test_build_graph_from_patches_graph_structure(self):
-         feature_map = torch.randn(1, 16, 32, 32)
-         G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4, 4))
-         # 1 image => G_batch[0] is the graph
-         G = G_batch[0]
-         # We have 64 patches
-         self.assertEqual(len(G.nodes), 64)
-         # Check that edges exist (8-neighborhood):
-         # interior nodes should have edges to their neighbors,
-         # so check a node in the interior of the 8x8 grid.
-         node_index = 9  # row=1, col=1 in an 8x8 grid
-         self.assertTrue(len(list(G.neighbors(node_index))) > 0)
-
-     def test_build_graph_data_from_patches_conversion(self):
-         feature_map = torch.randn(2, 16, 32, 32)
-         G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4, 4))
-         data_list = build_graph_data_from_patches(G_batch, patches)
-         self.assertEqual(len(data_list), 2)
-         # Check node feature shape: C * patch_h * patch_w = 16*4*4 = 256
-         self.assertEqual(data_list[0].x.shape[1], 16 * 4 * 4)
-         # Check that edges are present
-         self.assertTrue(data_list[0].edge_index.shape[1] > 0)
-
- if __name__ == '__main__':
-     unittest.main()
 
tests/test_model_components.py DELETED
@@ -1,53 +0,0 @@
- import unittest
- import torch
- from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
- from torch_geometric.data import Data
-
- class TestModelComponents(unittest.TestCase):
-     def test_efficientnetv2_extractor_output_shape(self):
-         model = EfficientNetV2FeatureExtractor()
-         model.eval()
-         x = torch.randn(2, 3, 224, 224)
-         with torch.no_grad():
-             features = model(x)
-         # The exact output shape depends on where the EfficientNetV2 backbone is
-         # truncated (here (2, 160, 14, 14) for 224x224 input), so assert only
-         # generic properties.
-         self.assertEqual(features.size(0), 2)
-         self.assertTrue(features.size(1) > 0)
-         self.assertTrue(features.size(2) > 0)
-         self.assertTrue(features.size(3) > 0)
-
-     def test_gatgnn_forward(self):
-         # Graph with 4 nodes, each node feature dim=256
-         x = torch.randn(4, 256)
-         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 3]], dtype=torch.long)
-         batch = torch.tensor([0, 0, 0, 0])
-         data = Data(x=x, edge_index=edge_index, batch=batch)
-
-         gnn = GATGNN(in_channels=256, hidden_channels=64, out_channels=32)
-         output = gnn(data)
-         # After pooling: should be (batch_size, out_channels) = (1, 32)
-         self.assertEqual(output.shape, (1, 32))
-
-     def test_transformer_encoder(self):
-         # (B, N, D) = (2, 10, 64)
-         x = torch.randn(2, 10, 64)
-         encoder = TransformerEncoder(d_model=64, nhead=4, num_layers=2, dim_feedforward=64)
-         out = encoder(x)
-         # Same shape as input
-         self.assertEqual(out.shape, (2, 10, 64))
-
-     def test_mlp_block(self):
-         mlp = MLPBlock(in_features=64, hidden_features=128, out_features=10)
-         x = torch.randn(2, 64)
-         out = mlp(x)
-         self.assertEqual(out.shape, (2, 10))
-
-     def test_efficientnetv2_freeze(self):
-         # Ensure the backbone parameters are frozen
-         model = EfficientNetV2FeatureExtractor()
-         for param in model.parameters():
-             self.assertFalse(param.requires_grad)
-
- if __name__ == '__main__':
-     unittest.main()
 
tests/test_sag_vit_model.py DELETED
@@ -1,39 +0,0 @@
- import unittest
- import torch
- from sag_vit_model import SAGViTClassifier
-
- class TestSAGViTModel(unittest.TestCase):
-     def test_forward_pass(self):
-         model = SAGViTClassifier(
-             patch_size=(4, 4),
-             num_classes=10,  # smaller number of classes for the test
-             d_model=64,
-             nhead=4,
-             num_layers=2,
-             dim_feedforward=64,
-             hidden_mlp_features=64,
-             in_channels=2560,  # matches the default flattened patch dimension
-             gcn_hidden=128,
-             gcn_out=64
-         )
-         model.eval()
-         x = torch.randn(2, 3, 224, 224)
-         with torch.no_grad():
-             out = model(x)
-         # Check output shape: (B, num_classes) = (2, 10)
-         self.assertEqual(out.shape, (2, 10))
-
-     def test_empty_input(self):
-         model = SAGViTClassifier()
-         # Passing an empty tensor should fail gracefully
-         with self.assertRaises(Exception):
-             model(torch.empty(0, 3, 224, 224))
-
-     def test_invalid_input_dimensions(self):
-         model = SAGViTClassifier()
-         # Incorrect dimensions (e.g., missing channel dimension)
-         with self.assertRaises(RuntimeError):
-             model(torch.randn(2, 224, 224))  # no channel dimension
-
- if __name__ == '__main__':
-     unittest.main()
 
tests/test_train.py DELETED
@@ -1,54 +0,0 @@
- import unittest
- from unittest.mock import MagicMock, patch
- import torch
- import torch.nn as nn
- from train import train_model
- from sag_vit_model import SAGViTClassifier
-
- class TestTrain(unittest.TestCase):
-     @patch("train.optim.Adam")
-     def test_train_model_loop(self, mock_adam):
-         # Mock the optimizer
-         mock_optimizer = MagicMock()
-         mock_adam.return_value = mock_optimizer
-
-         # Mock dataloaders with a small dummy dataset:
-         # just one batch with a couple of samples
-         train_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
-         val_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
-
-         model = SAGViTClassifier(num_classes=2)
-
-         criterion = nn.CrossEntropyLoss()
-         device = torch.device("cpu")
-
-         # Test a single training epoch
-         history = train_model(model, "TestModel", train_dataloader, val_dataloader,
-                               num_epochs=1, criterion=criterion, optimizer=mock_optimizer, device=device, patience=2, verbose=False)
-
-         # Check that history is properly recorded
-         self.assertIn("train_loss", history)
-         self.assertIn("val_loss", history)
-         self.assertGreaterEqual(len(history["train_loss"]), 1)
-         self.assertGreaterEqual(len(history["val_loss"]), 1)
-
-     def test_early_stopping(self):
-         # Dataloaders of random data, so the validation loss is unlikely to keep improving
-         model = SAGViTClassifier(num_classes=2)
-         criterion = nn.CrossEntropyLoss()
-         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-         device = torch.device("cpu")
-
-         train_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
-         val_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
-
-         history = train_model(model, "TestModelEarlyStop", train_dataloader, val_dataloader,
-                               num_epochs=5, criterion=criterion, optimizer=optimizer, device=device, patience=1, verbose=False)
-
-         # With patience=1, training should stop at or before the 5-epoch cap
-         self.assertLessEqual(len(history["train_loss"]), 5)
-
- if __name__ == '__main__':
-     unittest.main()
 
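The test modules import the top-level files directly (e.g. from train import train_model), so the suite was meant to run from the repository root; a minimal runner sketch:

import unittest

# Discover tests/test_*.py; run this from the repo root so that
# sag_vit_model, model_components, graph_construction, and train are importable.
suite = unittest.defaultTestLoader.discover("tests", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)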
 
train.py DELETED
@@ -1,189 +0,0 @@
- import os
- import torch
- from torch import nn, optim
- from tqdm import tqdm
- from sklearn.metrics import (precision_score, recall_score, f1_score,
-                              roc_auc_score, cohen_kappa_score, matthews_corrcoef,
-                              confusion_matrix)
-
- from sag_vit_model import SAGViTClassifier
- from data_loader import get_dataloaders
-
- #####################################################################
- # This file provides the training loop and metric computation. It uses
- # the SAG-ViT model defined in sag_vit_model.py and the data from data_loader.py.
- # The training loop implements early stopping and tracks various metrics.
- #####################################################################
-
- def train_model(model, model_name, train_loader, val_loader, num_epochs, criterion, optimizer, device, patience=8, verbose=True):
-     """
-     Trains the SAG-ViT model and evaluates it on the validation set.
-     Implements early stopping based on validation loss.
-
-     Parameters:
-     - model (nn.Module): The SAG-ViT model.
-     - model_name (str): A name to identify the model (used for saving checkpoints).
-     - train_loader, val_loader: DataLoaders for training and validation.
-     - num_epochs (int): Maximum number of epochs.
-     - criterion (nn.Module): Loss function.
-     - optimizer (torch.optim.Optimizer): Optimization algorithm.
-     - device (torch.device): Device to run the computations on (CPU/GPU).
-     - patience (int): Early stopping patience.
-     - verbose (bool): Whether to print a per-epoch summary.
-
-     Returns:
-     - history (dict): Dictionary containing training and validation metrics per epoch.
-       (Note: the 'train_auc'/'val_auc' keys are allocated but not populated by this
-       loop; class probabilities are collected in all_probs should AUC be needed.)
-     """
-
-     history = {
-         'train_loss': [], 'train_acc': [], 'train_prec': [], 'train_rec': [], 'train_f1': [],
-         'train_auc': [], 'train_mcc': [], 'train_cohen_kappa': [], 'train_confusion_matrix': [],
-         'val_loss': [], 'val_acc': [], 'val_prec': [], 'val_rec': [], 'val_f1': [],
-         'val_auc': [], 'val_mcc': [], 'val_cohen_kappa': [], 'val_confusion_matrix': []
-     }
-
-     best_val_loss = float('inf')
-     patience_counter = 0
-     best_model_state = None
-
-     for epoch in range(num_epochs):
-         print(f'Epoch {epoch+1}/{num_epochs}')
-         model.train()
-
-         train_loss_total, correct, total = 0, 0, 0
-         all_preds, all_labels, all_probs = [], [], []
-
-         # Training loop
-         for batch_idx, (X, y) in enumerate(tqdm(train_loader)):
-             inputs, labels = X.to(device), y.to(device)
-             optimizer.zero_grad()
-
-             outputs = model(inputs)
-             loss = criterion(outputs, labels)
-             loss.backward()
-             optimizer.step()
-
-             train_loss_total += loss.item()
-
-             probs = torch.softmax(outputs, dim=1)
-             _, preds = torch.max(outputs, 1)
-             correct += (preds == labels).sum().item()
-             total += labels.size(0)
-
-             all_preds.extend(preds.cpu().numpy())
-             all_labels.extend(labels.cpu().numpy())
-             all_probs.extend(probs.detach().cpu().numpy())
-
-         # Compute training metrics
-         train_acc = correct / total
-         train_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
-         train_rec = recall_score(all_labels, all_preds, average='macro')
-         train_f1 = f1_score(all_labels, all_preds, average='macro')
-         train_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
-         train_mcc = matthews_corrcoef(all_labels, all_preds)
-         train_confusion = confusion_matrix(all_labels, all_preds)
-
-         history['train_loss'].append(train_loss_total / len(train_loader))
-         history['train_acc'].append(train_acc)
-         history['train_prec'].append(train_prec)
-         history['train_rec'].append(train_rec)
-         history['train_f1'].append(train_f1)
-         history['train_cohen_kappa'].append(train_cohen_kappa)
-         history['train_mcc'].append(train_mcc)
-         history['train_confusion_matrix'].append(train_confusion)
-
-         # Validation
-         model.eval()
-         val_loss_total, correct, total = 0, 0, 0
-         all_preds, all_labels, all_probs = [], [], []
-
-         with torch.no_grad():
-             for batch_idx, (X, y) in enumerate(tqdm(val_loader)):
-                 inputs, labels = X.to(device), y.to(device)
-                 outputs = model(inputs)
-                 loss = criterion(outputs, labels)
-
-                 val_loss_total += loss.item()
-                 probs = torch.softmax(outputs, dim=1)
-                 _, preds = torch.max(outputs, 1)
-                 correct += (preds == labels).sum().item()
-                 total += labels.size(0)
-
-                 all_preds.extend(preds.cpu().numpy())
-                 all_labels.extend(labels.cpu().numpy())
-                 all_probs.extend(probs.detach().cpu().numpy())
-
-         # Compute validation metrics
-         val_acc = correct / total
-         val_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
-         val_rec = recall_score(all_labels, all_preds, average='macro')
-         val_f1 = f1_score(all_labels, all_preds, average='macro')
-         val_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
-         val_mcc = matthews_corrcoef(all_labels, all_preds)
-         val_confusion = confusion_matrix(all_labels, all_preds)
-
-         history['val_loss'].append(val_loss_total / len(val_loader))
-         history['val_acc'].append(val_acc)
-         history['val_prec'].append(val_prec)
-         history['val_rec'].append(val_rec)
-         history['val_f1'].append(val_f1)
-         history['val_cohen_kappa'].append(val_cohen_kappa)
-         history['val_mcc'].append(val_mcc)
-         history['val_confusion_matrix'].append(val_confusion)
-
-         # Print epoch summary
-         if verbose:
-             print(f"Train Loss: {history['train_loss'][-1]:.4f}, Train Acc: {history['train_acc'][-1]:.4f}, "
-                   f"Val Loss: {history['val_loss'][-1]:.4f}, Val Acc: {history['val_acc'][-1]:.4f}")
-
-         # Early stopping
-         current_val_loss = history['val_loss'][-1]
-         if current_val_loss < best_val_loss:
-             best_val_loss = current_val_loss
-             best_model_state = model.state_dict()
-             patience_counter = 0
-         else:
-             patience_counter += 1
-             print(f"Patience counter: {patience_counter}/{patience}")
-             if patience_counter >= patience:
-                 print("Early stopping triggered.")
-                 model.load_state_dict(best_model_state)
-                 torch.save(model.state_dict(), f'{model_name}-best.pth')
-                 return history
-
-     model.load_state_dict(best_model_state)
-     torch.save(model.state_dict(), f'{model_name}-{num_epochs}_epochs.pth')
-
-     return history
-
-
- if __name__ == "__main__":
-     # Example usage:
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     print(f"Training on device: {device}")
-     data_dir = "data/PlantVillage"  # "path/to/data/dir"
-     num_classes = len(os.listdir(data_dir))
-     train_loader, val_loader = get_dataloaders(data_dir=data_dir, img_size=224, batch_size=32)  # img_size must be at least 49
-
-     model = SAGViTClassifier(num_classes=num_classes).to(device)
-
-     criterion = nn.CrossEntropyLoss()
-     optimizer = optim.Adam(model.parameters(), lr=0.0001)
-     num_epochs = 100
-
-     history = train_model(
-         model,
-         'SAGViT',
-         train_loader,
-         val_loader,
-         num_epochs,
-         criterion,
-         optimizer,
-         device
-     )
-
-     # You may save history to a CSV or analyze it further as needed. For example:
-     # import pandas as pd
-     # history_df = pd.DataFrame(history)
-     # history_df.to_csv("training_history.csv", index=False)
- # history_df.to_csv("training_history.csv", index=False)