---
tags:
- image_classification
- computer_vision
license: mit
datasets:
- p2pfl/CIFAR10
language:
- en
pipeline_tag: image-classification
metrics:
- f1
---

# SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers

### Model Description

Implementation of the ***SAG-ViT*** model as proposed in the [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420) paper.

SAG-ViT is a novel transformer framework designed to enhance Vision Transformers (ViT) with scale-awareness and refined patch-level feature embeddings. It extracts multiscale features with an EfficientNetV2 backbone, organizes the resulting patches into a graph based on their spatial relationships, and refines the patch embeddings with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding. A minimal, illustrative sketch of this pipeline is included after the usage example below.

### Model Architecture

[SAG-ViT Model Architecture]

_Image source: [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420)_

### Usage

SAG-ViT expects input images normalized in the same way as during training: mini-batches of 3-channel RGB images of shape `(N, 3, H, W)`, where `N` is the number of images and `H` and `W` are each expected to be at least `49` pixels. The images have to be loaded into a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`.

To train or run inference with our model, follow these steps.

Clone our repository and load the model pretrained on the CIFAR-10 dataset:

```bash
git clone https://huggingface.co/shravvvv/SAG-ViT
cd SAG-ViT
```

Install the required dependencies:

```bash
pip install -r requirements.txt
```

Use `from_pretrained` to load the model from the Hugging Face Hub and run inference on a sample input image:

```python
from transformers import AutoModel, AutoConfig
from PIL import Image
from torchvision import transforms
import torch

# Step 1: Load the model and configuration directly from the Hugging Face Hub
repo_name = "shravvvv/SAG-ViT"
config = AutoConfig.from_pretrained(repo_name)               # Load config from hub
model = AutoModel.from_pretrained(repo_name, config=config)  # Load model from hub

# Step 2: Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the expected input size
    transforms.ToTensor(),          # Converts to a tensor in the range [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Step 3: Load and preprocess the input image
input_image_path = "path/to/your/image.jpg"
img = Image.open(input_image_path).convert("RGB")
img = transform(img).unsqueeze(0)  # Add batch dimension

# Step 4: Ensure the model is in evaluation mode
model.eval()

# Step 5: Run inference
with torch.no_grad():
    outputs = model(img)
    logits = outputs.logits  # Access logits from the ModelOutput

# Step 6: Post-process the predictions
predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Get the predicted class name from the class index
predicted_class_name = class_names[predicted_class_index.item()]
print(f"Predicted class: {predicted_class_name}")
```
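For intuition, the following is a minimal, self-contained sketch of the data flow described in the Model Description (CNN feature extraction → patch nodes → attention-based refinement → Transformer encoding). It is an illustration under simplifying assumptions, not the official implementation: a tiny convolutional stack stands in for the EfficientNetV2 backbone, and the spatial patch graph with GAT refinement is approximated by plain multi-head self-attention over all patches. All class and variable names below are hypothetical.

```python
import torch
import torch.nn as nn

class PipelineSketch(nn.Module):
    """Simplified, illustrative stand-in for the SAG-ViT data flow (not the official model)."""

    def __init__(self, num_classes=10, dim=64):
        super().__init__()
        # Stand-in for the EfficientNetV2 backbone: any CNN that yields a feature map works here.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, dim, kernel_size=3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(7),  # 7x7 = 49 patch nodes, keeps the sketch lightweight
        )
        # Stand-in for the GAT: attention over all patches instead of a spatial patch graph.
        self.patch_attention = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        # Transformer encoder integrating the refined patch embeddings globally.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True),
            num_layers=2,
        )
        self.head = nn.Linear(dim, num_classes)  # classification head

    def forward(self, x):                                   # x: (N, 3, H, W), normalized
        feats = self.backbone(x)                            # (N, dim, 7, 7) feature map
        tokens = feats.flatten(2).transpose(1, 2)           # patches as a sequence of 49 nodes
        refined, _ = self.patch_attention(tokens, tokens, tokens)  # patch-level refinement
        encoded = self.encoder(refined)                     # long-range, global interactions
        return self.head(encoded.mean(dim=1))               # pooled logits, (N, num_classes)

logits = PipelineSketch()(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])
```

In the actual model, the graph is built from spatially related patches of the CNN feature map and refined with a GAT before the Transformer encoder; the sketch above only conveys the overall data flow and tensor shapes.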
### Running Tests

If you clone our [repository](https://github.com/shravan-18/SAG-ViT), the *tests* folder contains unit tests for each of the model's modules. Make sure you have a proper Python environment with the required dependencies installed, then run:

```bash
python -m unittest discover -s tests
```

or, if you are using `pytest`:

```bash
pytest tests
```

### Results

We evaluated SAG-ViT on diverse datasets:

- **CIFAR-10** (natural images)
- **GTSRB** (traffic sign recognition)
- **NCT-CRC-HE-100K** (histopathological images)
- **NWPU-RESISC45** (remote sensing imagery)
- **PlantVillage** (agricultural imagery)

SAG-ViT achieves state-of-the-art results across all benchmarks, as shown in the table below (F1 scores):