import torch
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput, logging
from huggingface_hub import PyTorchModelHubMixin

from torch_geometric.data import Batch
from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
from graph_construction import build_graph_from_patches, build_graph_data_from_patches


###############################################################################
# SAG-ViT Model:
# This class combines:
# 1) CNN backbone to produce high-fidelity feature maps (Section 3.1),
# 2) Graph construction and GAT to refine local patch embeddings (Section 3.2 and 3.3),
# 3) A Transformer encoder to capture global relationships (Section 3.3),
# 4) A final MLP classifier.
###############################################################################
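
# Illustrative tensor-shape walk-through (a sketch; exact sizes depend on the
# backbone's output resolution and the configured patch size, and the numbers
# below are just the SAGViTConfig defaults):
#   input image         (B, 3, H, W)
#   CNN feature map     (B, C', H', W')
#   patch graph         one node per patch_size x patch_size patch, each with
#                       C' * patch_h * patch_w features (in_channels, e.g. 2560)
#   GAT + pooling       (B, gcn_out) image-level embedding
#   Transformer input   (B, 2, d_model)  # patch token + extra learnable token
#   classifier output   (B, num_classes)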


# Custom configuration (registered with the Auto classes at the bottom of this file)
class SAGViTConfig(PretrainedConfig):
    model_type = "sagvit"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.d_model = kwargs.get("d_model", 64)
        self.dim_feedforward = kwargs.get("dim_feedforward", 64)
        self.gcn_hidden = kwargs.get("gcn_hidden", 128)
        self.gcn_out = kwargs.get("gcn_out", 64)
        self.hidden_mlp_features = kwargs.get("hidden_mlp_features", 64)
        self.in_channels = kwargs.get("in_channels", 2560)
        self.nhead = kwargs.get("nhead", 4)
        self.num_classes = kwargs.get("num_classes", 10)
        self.num_layers = kwargs.get("num_layers", 2)
        self.patch_size = kwargs.get("patch_size", (4, 4))

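# Minimal config usage sketch (illustrative values; the "./sagvit" path and the
# hyperparameters are placeholders, not settings shipped with this repo):
#
#   config = SAGViTConfig(num_classes=100, d_model=64, nhead=4)
#   config.save_pretrained("./sagvit")               # writes ./sagvit/config.json
#   config = SAGViTConfig.from_pretrained("./sagvit")
#
# Note: `patch_size` is serialized to JSON and comes back as a list (e.g. [4, 4]).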

class SAGViTClassifier(PreTrainedModel):
    """
    SAG-ViT: Scale-Aware Graph Attention Vision Transformer

    This model integrates the following steps:
    - Extract multi-scale features from images using a CNN backbone (EfficientNetV2 here).
    - Partition the feature map into patches and build a graph where each node is a patch.
    - Use a Graph Attention Network (GAT) to refine patch embeddings based on local spatial relationships.
    - Utilize a Transformer encoder to model long-range dependencies and integrate multi-scale information.
    - Finally, classify the resulting representation into desired classes.

    Inputs:
    - x (Tensor): Input images (B, 3, H, W)

    Outputs:
    - out (Tensor): Classification logits (B, num_classes)
    """

    config_class = SAGViTConfig
    def __init__(self, config):
        super().__init__(config)

        self.patch_size = config.patch_size
        self.num_classes = config.num_classes

        # CNN feature extractor (frozen pre-trained EfficientNetV2)
        self.cnn = EfficientNetV2FeatureExtractor()

        # Graph Attention Network to process patch embeddings
        self.gcn = GATGNN(
            in_channels=config.in_channels,
            hidden_channels=config.gcn_hidden,
            out_channels=config.gcn_out,
        )

        # Learnable positional embedding for Transformer input
        self.positional_embedding = nn.Parameter(torch.randn(1, 1, config.d_model))
        # Extra embedding token (similar to class token) to summarize global info
        self.extra_embedding = nn.Parameter(torch.randn(1, config.d_model))

        # Transformer encoder to capture long-range global dependencies
        self.transformer_encoder = TransformerEncoder(
            d_model=config.d_model, 
            nhead=config.nhead, 
            num_layers=config.num_layers, 
            dim_feedforward=config.dim_feedforward,
        )

        # MLP classification head
        self.mlp = MLPBlock(config.d_model, config.hidden_mlp_features, config.num_classes)

    def forward(self, x, **kwargs):
        # Step 1: High-fidelity feature extraction from CNN
        feature_map = self.cnn(x)

        # Step 2: Build graphs from patches
        G_global_batch, patches = build_graph_from_patches(feature_map, self.patch_size)

        # Step 3: Convert to PyG Data format and batch
        data_list = build_graph_data_from_patches(G_global_batch, patches)
        device = x.device
        batch = Batch.from_data_list(data_list).to(device)

        # Step 4: GAT stage
        x_gcn = self.gcn(batch)

        # Step 5: The GAT stage pools the patch-node embeddings of each image's
        # graph into a single image-level embedding, so x_gcn has shape (B, D).
        # Add a sequence dimension so the Transformer treats it as one "patch
        # token" (an extra learnable token is appended below).
        B = x.size(0)
        patch_embeddings = x_gcn.unsqueeze(1)  # (B, 1, D)

        # Add positional embedding
        patch_embeddings = patch_embeddings + self.positional_embedding  # (B, 1, D)

        # Add an extra learnable embedding (like a CLS token)
        patch_embeddings = torch.cat([patch_embeddings, self.extra_embedding.unsqueeze(0).expand(B, -1, -1)], dim=1)  # (B, 2, D)

        # Step 6: Transformer encoder
        x_trans = self.transformer_encoder(patch_embeddings)

        # Step 7: Global pooling (here we just take the mean)
        x_pooled = x_trans.mean(dim=1)  # (B, D)

        # Classification
        logits = self.mlp(x_pooled)
        return ModelOutput(logits=logits)

# Register the custom config and model with the Auto classes so that
# AutoConfig / AutoModel can resolve the "sagvit" model type.
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
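
# ---------------------------------------------------------------------------
# Minimal usage sketch (a smoke test, not part of the model definition).
# Assumes `model_components` and `graph_construction` are importable from this
# repo; the 224x224 input size and batch size of 2 are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SAGViTConfig(num_classes=10)
    model = SAGViTClassifier(config).eval()

    dummy_images = torch.randn(2, 3, 224, 224)  # (B, 3, H, W)
    with torch.no_grad():
        output = model(dummy_images)

    print(output.logits.shape)  # expected: torch.Size([2, 10])
    # With the registration above, AutoModel.from_pretrained on a checkpoint
    # whose config.json has model_type "sagvit" also resolves to SAGViTClassifier.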