---
license: mit
base_model:
- timm/deit_small_patch16_224.fb_in1k
- timm/deit_tiny_patch16_224.fb_in1k
- timm/cait_xxs24_224.fb_dist_in1k
metrics:
- accuracy
tags:
- Interpretability
- ViT
- Classification
- XAI
---

# ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning

This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".

## Model Description

[ProtoViT](https://github.com/Henrymachiyu/ProtoViT) combines Vision Transformers with prototype-based learning to build image classifiers that are both accurate and interpretable. Rather than acting as a black box, ProtoViT learns prototypes (representative visual features of each class) and explains each classification decision through the visual similarity between parts of the input image and these prototypes.
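To make "explains decisions through visual similarity" concrete, here is a simplified, ProtoPNet-style sketch of a prototype head. It compares patch embeddings with learned prototypes, keeps each prototype's best match, and maps those scores to class logits with a linear layer. This is not ProtoViT's actual head: ProtoViT's adaptive prototype matching is more involved and is not reproduced here. All names, shapes (196 patch tokens for a 224×224 input with 16×16 patches, 192-dimensional embeddings as in DeiT-Tiny, 10 prototypes per class) and the random values are hypothetical.

```python
import numpy as np


def cosine_similarity(patches: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Cosine similarity between patch tokens (N, D) and prototypes (P, D) -> (N, P)."""
    patches = patches / np.linalg.norm(patches, axis=1, keepdims=True)
    prototypes = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    return patches @ prototypes.T


def prototype_logits(patches, prototypes, class_weights):
    """Hypothetical simplified head: best match per prototype, then a linear layer (C, P)."""
    sims = cosine_similarity(patches, prototypes)  # (N, P)
    scores = sims.max(axis=0)                      # best similarity of each prototype
    best_patch = sims.argmax(axis=0)               # patch index where that match occurs
    return class_weights @ scores, scores, best_patch


rng = np.random.default_rng(0)
num_patches, dim = 196, 192               # 14x14 patches; DeiT-Tiny embedding width
num_classes, protos_per_class = 200, 10   # 200 CUB classes; 10 per class is an assumption
prototypes = rng.normal(size=(num_classes * protos_per_class, dim))

# Class connections: +1 to a class's own prototypes, -0.5 to all others (ProtoPNet-style init).
class_weights = np.full((num_classes, num_classes * protos_per_class), -0.5)
for c in range(num_classes):
    class_weights[c, c * protos_per_class:(c + 1) * protos_per_class] = 1.0

patch_tokens = rng.normal(size=(num_patches, dim))  # stand-in for backbone output
logits, scores, best_patch = prototype_logits(patch_tokens, prototypes, class_weights)
predicted = int(logits.argmax())

# The "explanation": how strongly the predicted class's prototypes fired, and on which patches.
own = slice(predicted * protos_per_class, (predicted + 1) * protos_per_class)
print(predicted, scores[own].round(3), best_patch[own])
```

Each prototype contributes through a single similarity score tied to a specific image patch, which is what lets the prediction be traced back to "this part of the image looks like that prototype".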

### Supported Architectures

We provide three variants of ProtoViT:

- **ProtoViT-T**: built on a DeiT-Tiny backbone
- **ProtoViT-S**: built on a DeiT-Small backbone
- **ProtoViT-CaiT**: built on a CaiT-XXS24 backbone
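The three variants correspond to the timm identifiers in this card's `base_model` metadata. The hypothetical helper below splits such an identifier into the bare architecture name and the pretrained tag. The split matters because the `name.tag` scheme comes from newer timm releases, while timm==0.4.12 (pinned under Requirements) expects only the bare architecture name, e.g. in `timm.create_model(arch, pretrained=True)`. Verify this against your installed version.

```python
from typing import Tuple

# Variant -> timm identifier, copied from the base_model metadata of this card.
BACKBONES = {
    "ProtoViT-T": "timm/deit_tiny_patch16_224.fb_in1k",
    "ProtoViT-S": "timm/deit_small_patch16_224.fb_in1k",
    "ProtoViT-CaiT": "timm/cait_xxs24_224.fb_dist_in1k",
}


def split_timm_id(hub_id: str) -> Tuple[str, str]:
    """Hypothetical helper: 'timm/<arch>.<tag>' -> ('<arch>', '<tag>'); tag is '' if absent."""
    name = hub_id.split("/", 1)[-1]      # drop the 'timm/' organisation prefix
    arch, _, tag = name.partition(".")   # split off the pretrained tag, if any
    return arch, tag


for variant, hub_id in BACKBONES.items():
    arch, tag = split_timm_id(hub_id)
    print(f"{variant}: arch={arch!r}, tag={tag!r}")
```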

## Performance

All models were trained and evaluated on the CUB-200-2011 fine-grained bird species classification dataset.

| Model Version | Backbone | Resolution | Top-1 Accuracy | Checkpoint |
|---------------|----------|------------|----------------|------------|
| ProtoViT-T | DeiT-Tiny | 224×224 | 83.36% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/DeiT_Tiny_finetuned0.8336.pth) |
| ProtoViT-S | DeiT-Small | 224×224 | 85.30% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/DeiT_Small_finetuned0.8530.pth) |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224 | 86.02% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/CaiT_xxs24_224_finetuned0.8602.pth) |
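The checkpoint links point to files in the `chiyum609/ProtoViT` Hugging Face repository. Below is a minimal sketch for fetching them programmatically. It assumes the `huggingface_hub` package, which is not listed under Requirements. This card does not document the format of the `.pth` files. The loader assumes they may be fully pickled models in the ProtoPNet style, so treat `download_checkpoint` and `load_checkpoint` as hypothetical helpers and check the ProtoViT repository for the supported loading code.

```python
import re

REPO_ID = "chiyum609/ProtoViT"  # repository hosting the checkpoints linked in the table
CHECKPOINTS = {
    "ProtoViT-T": "DeiT_Tiny_finetuned0.8336.pth",
    "ProtoViT-S": "DeiT_Small_finetuned0.8530.pth",
    "ProtoViT-CaiT": "CaiT_xxs24_224_finetuned0.8602.pth",
}


def reported_accuracy(filename: str) -> float:
    """Read the top-1 accuracy encoded in a checkpoint file name (e.g. 0.8336)."""
    match = re.search(r"finetuned(\d\.\d+)\.pth$", filename)
    if match is None:
        raise ValueError(f"no accuracy found in {filename!r}")
    return float(match.group(1))


def download_checkpoint(variant: str) -> str:
    """Hypothetical helper: fetch one checkpoint into the local HF cache, return its path."""
    from huggingface_hub import hf_hub_download  # pip install huggingface_hub
    return hf_hub_download(repo_id=REPO_ID, filename=CHECKPOINTS[variant])


def load_checkpoint(path: str):
    """Hypothetical helper: load a checkpoint on CPU.

    Assumption: the file may hold a fully pickled model, which needs the
    ProtoViT source importable and weights_only=False on newer PyTorch.
    Only do this for files you trust, since unpickling can execute code.
    """
    import torch
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:  # older PyTorch without the weights_only argument
        return torch.load(path, map_location="cpu")


# Usage (needs network access and the ProtoViT code on the Python path):
# checkpoint = load_checkpoint(download_checkpoint("ProtoViT-S"))
```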

## Features

- **Interpretable Decisions**: The model classifies each image with self-explanatory reasoning based on the input's similarity to learned prototypes, which capture the key features of each class.
- **High Accuracy**: Achieves competitive performance on fine-grained classification tasks.
- **Multiple Architectures**: Supports DeiT and CaiT Vision Transformer backbones.
- **Analysis Tools**: Comes with tools for both local and global prototype analysis.
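The distinction between local and global analysis can be sketched as follows. This is a generic ProtoPNet-style illustration, not the repository's actual tools. Local analysis ranks the prototypes most activated by one test image and locates the patch each one matched. Global analysis finds, for one prototype, the images in a collection whose best-matching patch is most similar to it. The input is a hypothetical similarity array of shape (images, patches, prototypes), such as a prototype layer might produce.

```python
import numpy as np


def local_analysis(sims: np.ndarray, k: int = 3):
    """One image's similarities (patches, prototypes) -> top-k (prototype, patch, score)."""
    scores = sims.max(axis=0)       # best match of each prototype anywhere in the image
    top = np.argsort(-scores)[:k]   # most activated prototypes first
    return [(int(p), int(sims[:, p].argmax()), float(scores[p])) for p in top]


def global_analysis(all_sims: np.ndarray, prototype: int, k: int = 3):
    """Stack (images, patches, prototypes) -> the k images matching one prototype best."""
    per_image = all_sims[:, :, prototype].max(axis=1)  # best patch within each image
    top = np.argsort(-per_image)[:k]
    return [(int(i), float(per_image[i])) for i in top]


rng = np.random.default_rng(1)
all_sims = rng.uniform(-1, 1, size=(5, 196, 20))  # hypothetical: 5 images, 196 patches, 20 prototypes
print(local_analysis(all_sims[0]))
print(global_analysis(all_sims, prototype=7))
```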

## Requirements

- Python 3.8+
- PyTorch 1.8+
- timm==0.4.12
- torchvision
- numpy
- pillow
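A quick way to check an environment against this list is sketched below. It uses only the standard library (`importlib.metadata`, so the script itself needs Python 3.8+). It compares only the leading numeric version components, a simplification that ignores pre-release and local tags. The `REQUIREMENTS` table mirrors the list above, and the helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Tuple

# (minimum version, exact pin) per distribution, mirroring the Requirements list.
REQUIREMENTS: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    "torch": ("1.8", None),
    "timm": (None, "0.4.12"),
    "torchvision": (None, None),
    "numpy": (None, None),
    "pillow": (None, None),
}


def numeric(v: str) -> Tuple[int, ...]:
    """Keep only the leading dotted integers, e.g. '1.13.1+cu117' -> (1, 13, 1)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(x) for x in m.group(0).split(".")) if m else ()


def check(installed: Dict[str, Optional[str]]) -> List[str]:
    """Return problems for a {distribution: version or None} mapping."""
    problems = []
    for name, (minimum, pin) in REQUIREMENTS.items():
        got = installed.get(name)
        if got is None:
            problems.append(f"{name}: not installed")
        elif pin is not None and numeric(got) != numeric(pin):
            problems.append(f"{name}: found {got}, need =={pin}")
        elif minimum is not None and numeric(got) < numeric(minimum):
            problems.append(f"{name}: found {got}, need >={minimum}")
    return problems


def installed_versions() -> Dict[str, Optional[str]]:
    """Look up installed versions; None marks a missing distribution."""
    found: Dict[str, Optional[str]] = {}
    for name in REQUIREMENTS:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(check(installed_versions()) or "all requirements satisfied")
```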

## Limitations and Bias

- Data Bias: These models are trained on CUB-200-2011 and may not generalize well to images outside this dataset.
- Resolution Constraints: The models are trained at a resolution of 224×224; inputs at higher or lower resolutions may degrade performance.
- Location Misalignment: As with CNN-based models, these models are not fully immune to location misalignment under adversarial attacks.
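Because the models were trained at 224×224, inputs should be resized to that resolution before inference. The sketch below uses Pillow and numpy. It assumes a direct bilinear resize with no cropping and the standard ImageNet mean/std normalization used by the DeiT backbones. Both choices are assumptions, so check the ProtoViT repository's data loaders for the exact training pipeline.

```python
import numpy as np
from PIL import Image

IMAGE_SIZE = 224  # training resolution reported in the Performance table
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # ImageNet statistics (assumption)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image: Image.Image) -> np.ndarray:
    """Resize to 224x224 RGB and return a normalized (3, 224, 224) float32 array."""
    image = image.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR)
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    pixels = (pixels - MEAN) / STD                         # per-channel normalization
    return pixels.transpose(2, 0, 1)                       # channels first, as PyTorch expects


example = Image.new("RGB", (500, 375), color=(128, 64, 32))  # arbitrary non-square input
batch = preprocess(example)[None]  # add a batch axis -> (1, 3, 224, 224)
print(batch.shape, batch.dtype)
```

`torch.from_numpy(batch)` turns the result into a tensor that can be passed to a PyTorch model.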

## Citation

If you use this model in your research, please cite:

```bibtex
@article{ma2024interpretable,
  title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
  author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
  journal={arXiv preprint arXiv:2410.20722},
  year={2024}
}
```

## Acknowledgements

This implementation builds upon the following excellent repositories:

- [DeiT](https://github.com/facebookresearch/deit)
- [CaiT](https://github.com/facebookresearch/deit)
- [ProtoPNet](https://github.com/cfchen-duke/ProtoPNet)

## License

This project is released under the MIT license.

## Contact

For any questions or feedback, please:

1. Open an issue in the [GitHub repository](https://github.com/Henrymachiyu/ProtoViT)
2. Contact [[email protected]]