import torch
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput
from torch_geometric.data import Batch
from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
from graph_construction import build_graph_from_patches, build_graph_data_from_patches
###############################################################################
# SAG-ViT Model:
# This class combines:
# 1) A CNN backbone that produces high-fidelity feature maps (Section 3.1),
# 2) Graph construction and a GAT that refine local patch embeddings
#    (Sections 3.2 and 3.3),
# 3) A Transformer encoder that captures global relationships (Section 3.3),
# 4) A final MLP classification head.
###############################################################################
# Custom configuration; registered with the Auto classes at the bottom of this file.
class SAGViTConfig(PretrainedConfig):
    model_type = "sagvit"

    def __init__(
        self,
        d_model=64,
        dim_feedforward=64,
        gcn_hidden=128,
        gcn_out=64,
        hidden_mlp_features=64,
        in_channels=2560,
        nhead=4,
        num_classes=10,
        num_layers=2,
        patch_size=(4, 4),
        **kwargs,
    ):
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.gcn_hidden = gcn_hidden
        self.gcn_out = gcn_out
        self.hidden_mlp_features = hidden_mlp_features
        self.in_channels = in_channels
        self.nhead = nhead
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.patch_size = patch_size
        super().__init__(**kwargs)
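
# Illustrative sketch (not part of the model): SAGViTConfig behaves like any
# PretrainedConfig, so fields can be overridden per instance and round-tripped
# through serialization. Note that saving to config.json turns the patch_size
# tuple into a JSON list.
#
#   cfg = SAGViTConfig(num_classes=100, patch_size=(4, 4))
#   cfg.save_pretrained("./sagvit-checkpoint")               # writes config.json
#   cfg = SAGViTConfig.from_pretrained("./sagvit-checkpoint")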
class SAGViTClassifier(PreTrainedModel):
"""
SAG-ViT: Scale-Aware Graph Attention Vision Transformer
This model integrates the following steps:
    - Extract multi-scale features from images using a CNN backbone (EfficientNetV2 here).
- Partition the feature map into patches and build a graph where each node is a patch.
- Use a Graph Attention Network (GAT) to refine patch embeddings based on local spatial relationships.
- Utilize a Transformer encoder to model long-range dependencies and integrate multi-scale information.
- Finally, classify the resulting representation into desired classes.
Inputs:
- x (Tensor): Input images (B, 3, H, W)
Outputs:
- out (Tensor): Classification logits (B, num_classes)
"""
config_class = SAGViTConfig
def __init__(self, config):
super().__init__(config)
self.patch_size = config.patch_size
self.num_classes = config.num_classes
        # CNN feature extractor (frozen, pre-trained EfficientNetV2)
self.cnn = EfficientNetV2FeatureExtractor()
# Graph Attention Network to process patch embeddings
self.gcn = GATGNN(
in_channels=config.in_channels,
hidden_channels=config.gcn_hidden,
out_channels=config.gcn_out,
)
# Learnable positional embedding for Transformer input
self.positional_embedding = nn.Parameter(torch.randn(1, 1, config.d_model))
# Extra embedding token (similar to class token) to summarize global info
self.extra_embedding = nn.Parameter(torch.randn(1, config.d_model))
# Transformer encoder to capture long-range global dependencies
self.transformer_encoder = TransformerEncoder(
d_model=config.d_model,
nhead=config.nhead,
num_layers=config.num_layers,
dim_feedforward=config.dim_feedforward,
)
# MLP classification head
self.mlp = MLPBlock(config.d_model, config.hidden_mlp_features, config.num_classes)
def forward(self, x, **kwargs):
# Step 1: High-fidelity feature extraction from CNN
feature_map = self.cnn(x)
# Step 2: Build graphs from patches
G_global_batch, patches = build_graph_from_patches(feature_map, self.patch_size)
# Step 3: Convert to PyG Data format and batch
data_list = build_graph_data_from_patches(G_global_batch, patches)
device = x.device
batch = Batch.from_data_list(data_list).to(device)
# Step 4: GAT stage
x_gcn = self.gcn(batch)
        # Step 5: The GAT stage pools the patch nodes of each graph into a single
        # image-level embedding, so x_gcn has shape (B, D). The Transformer expects
        # a sequence, so treat this embedding as one "patch token"; an extra
        # learnable token is appended below.
        B = x.size(0)
        patch_embeddings = x_gcn.unsqueeze(1)  # (B, 1, D)
# Add positional embedding
patch_embeddings = patch_embeddings + self.positional_embedding # (B, 1, D)
# Add an extra learnable embedding (like a CLS token)
patch_embeddings = torch.cat([patch_embeddings, self.extra_embedding.unsqueeze(0).expand(B, -1, -1)], dim=1) # (B, 2, D)
# Step 6: Transformer encoder
x_trans = self.transformer_encoder(patch_embeddings)
        # Step 7: Global pooling (mean over the two tokens)
x_pooled = x_trans.mean(dim=1) # (B, D)
# Classification
logits = self.mlp(x_pooled)
return ModelOutput(logits=logits)
# Register the custom config and model with the Auto classes so that
# AutoConfig/AutoModel can resolve model_type "sagvit"
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
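
# Minimal smoke-test sketch (illustrative, not part of the model). It assumes
# the defaults above match the companion modules, i.e. that
# EfficientNetV2FeatureExtractor yields 2560-channel feature maps for 224x224
# RGB inputs and that GATGNN pools each graph to a single 64-dim embedding.
if __name__ == "__main__":
    config = SAGViTConfig()
    model = SAGViTClassifier(config)
    model.eval()
    images = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
    with torch.no_grad():
        out = model(images)
    print(out.logits.shape)  # expected: torch.Size([2, 10])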