---
license: mit
base_model:
- timm/deit_small_patch16_224.fb_in1k
- timm/deit_tiny_patch16_224.fb_in1k
- timm/cait_xxs24_224.fb_dist_in1k
metrics:
- accuracy
tags:
- Interpretability
- ViT
- Classification
- XAI
---
# ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning
This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".
## Model Description
[ProtoViT](https://github.com/Henrymachiyu/ProtoViT) combines Vision Transformers with prototype-based learning to create models that are both highly accurate and interpretable. Rather than functioning as a black box, ProtoViT learns interpretable prototypes that explain its classification decisions through visual similarities.
### Supported Architectures
We provide three variants of ProtoViT:
- **ProtoViT-T**: Built on DeiT-Tiny backbone
- **ProtoViT-S**: Built on DeiT-Small backbone
- **ProtoViT-CaiT**: Built on CaiT-XXS24 backbone
## Performance
All models were trained and evaluated on the CUB-200-2011 fine-grained bird species classification dataset.
| Model Version | Backbone | Resolution | Top-1 Accuracy | Checkpoint |
|--------------|----------|------------|----------------|------------|
| ProtoViT-T | DeiT-Tiny | 224×224 | 83.36% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/DeiT_Tiny_finetuned0.8336.pth) |
| ProtoViT-S | DeiT-Small | 224×224 | 85.30% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/DeiT_Small_finetuned0.8530.pth) |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224 | 86.02% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/CaiT_xxs24_224_finetuned0.8602.pth) |
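The checkpoints can also be fetched programmatically. The following is a minimal sketch, assuming the `huggingface_hub` package is installed (it is not listed under Requirements). The repository ID and filenames come from the table above; the `CHECKPOINTS` mapping and the `download_checkpoint` helper are illustrative names, not part of this repository.

```python
# Repository ID and filenames are copied from the checkpoint links in the table.
REPO_ID = "chiyum609/ProtoViT"
CHECKPOINTS = {
    "ProtoViT-T": "DeiT_Tiny_finetuned0.8336.pth",
    "ProtoViT-S": "DeiT_Small_finetuned0.8530.pth",
    "ProtoViT-CaiT": "CaiT_xxs24_224_finetuned0.8602.pth",
}


def download_checkpoint(variant: str) -> str:
    """Download the checkpoint for `variant` and return its local cache path.

    Hypothetical helper for illustration; requires network access.
    """
    if variant not in CHECKPOINTS:
        raise KeyError(f"Unknown variant {variant!r}; choose from {sorted(CHECKPOINTS)}")
    # Imported lazily so the mapping above is usable without huggingface_hub.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=REPO_ID, filename=CHECKPOINTS[variant])
```

Calling `download_checkpoint("ProtoViT-T")` returns the path of the cached `.pth` file. How the file should be loaded (for example, as a full pickled model or as a state dict) depends on how it was saved; see the GitHub repository for the loading code.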
## Features
- ๐Ÿ” **Interpretable Decisions**: The model performs classification with self-explainatory reasoning based on the inputโ€™s similarity to learned prototypes, the key features for each classes.
- ๐ŸŽฏ **High Accuracy**: Achieves competitive performance on fine-grained classification tasks
- ๐Ÿš€ **Multiple Architectures**: Supports various Vision Transformer backbones
- ๐Ÿ“Š **Analysis Tools**: Comes with tools for both local and global prototype analysis
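To illustrate what "classification by similarity to learned prototypes" means, here is a generic sketch of a prototype-based classifier in NumPy. It is not the ProtoViT implementation: the function name, the use of cosine similarity with a max over patches, and the toy shapes are all assumptions made for illustration, and the actual prototype matching is described in the paper.

```python
import numpy as np


def prototype_logits(patch_tokens, prototypes, class_weights):
    """Score classes from patch-to-prototype cosine similarities.

    patch_tokens:  (num_patches, dim) features for one image
    prototypes:    (num_prototypes, dim) learned prototype vectors
    class_weights: (num_classes, num_prototypes) final linear layer
    """
    # L2-normalize so that the dot product equals cosine similarity.
    tokens = patch_tokens / np.linalg.norm(patch_tokens, axis=1, keepdims=True)
    protos = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    similarity = protos @ tokens.T          # (num_prototypes, num_patches)
    best_patch = similarity.argmax(axis=1)  # patch where each prototype matches best
    activation = similarity.max(axis=1)     # strength of that best match
    logits = class_weights @ activation     # (num_classes,)
    return logits, activation, best_patch


# Toy data: 196 patches (14x14 for a 224x224 image with patch size 16).
rng = np.random.default_rng(0)
tokens = rng.normal(size=(196, 64))
prototypes = rng.normal(size=(10, 64))
# Two classes, each connected to five of the ten prototypes.
weights = np.zeros((2, 10))
weights[0, :5] = 1.0
weights[1, 5:] = 1.0

logits, activation, best_patch = prototype_logits(tokens, prototypes, weights)
print("predicted class:", int(logits.argmax()))
```

The interpretability comes from `activation` and `best_patch`: each class score is a weighted sum of prototype activations, and every activation can be traced back to a specific image patch.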
## Requirements
- Python 3.8+
- PyTorch 1.8+
- timm==0.4.12
- torchvision
- numpy
- pillow
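A quick way to check an environment against this list is sketched below. It uses only the standard library (`importlib.metadata`, available from Python 3.8); the helper names and the choice to compare only major and minor versions for PyTorch are assumptions made for illustration.

```python
import sys
from importlib import metadata

# Distribution names as published on PyPI, taken from the list above.
REQUIRED = ["torch", "timm", "torchvision", "numpy", "pillow"]
PINNED = {"timm": "0.4.12"}   # exact version required
MINIMUM = {"torch": (1, 8)}   # minimum (major, minor) version


def installed_versions(packages=REQUIRED):
    """Map each package name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def find_problems(versions, python_version=sys.version_info[:2]):
    """Return human-readable mismatches between `versions` and the requirements."""
    problems = []
    if tuple(python_version) < (3, 8):
        problems.append(f"Python {python_version[0]}.{python_version[1]} is older than 3.8")
    for name, version in versions.items():
        if version is None:
            problems.append(f"{name} is not installed")
        elif name in PINNED and version != PINNED[name]:
            problems.append(f"{name} {version} does not match pinned {PINNED[name]}")
        elif name in MINIMUM:
            # Compare only "major.minor"; local suffixes such as "+cu118" are ignored.
            major_minor = tuple(int(part) for part in version.split(".")[:2])
            if major_minor < MINIMUM[name]:
                wanted = ".".join(str(part) for part in MINIMUM[name])
                problems.append(f"{name} {version} is older than {wanted}")
    return problems


for problem in find_problems(installed_versions()):
    print("WARNING:", problem)
```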
## Limitations and Bias
- Data Bias: These models are trained only on CUB-200-2011 and may not generalize well to images outside this dataset.
- Resolution Constraints: The models are trained at a resolution of 224×224; higher or lower input resolutions may degrade performance.
- Location Misalignment: Like CNN-based models, these models are not fully immune to location misalignment under adversarial attack.
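Because of the resolution constraint, inputs should be brought to 224×224 before inference. The sketch below does this with Pillow and NumPy. The ImageNet mean and standard deviation and the direct (non-aspect-preserving) resize are assumptions commonly made for DeiT and CaiT backbones, not values confirmed by this model card; check the training code in the GitHub repository for the exact pipeline.

```python
import numpy as np
from PIL import Image

IMG_SIZE = 224
# Assumption: standard ImageNet channel statistics.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image: Image.Image) -> np.ndarray:
    """Resize to 224x224, scale to [0, 1], normalize, and return a (3, 224, 224) array."""
    image = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, 3)
    pixels = (pixels - MEAN) / STD                        # per-channel normalization
    return pixels.transpose(2, 0, 1)                      # channels first: (3, H, W)


# Synthetic gray image standing in for a real photo.
example = Image.new("RGB", (500, 375), color=(128, 128, 128))
array = preprocess(example)
print(array.shape, array.dtype)
```

The resulting array can be wrapped with `torch.from_numpy(array).unsqueeze(0)` to form a batch of one.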
## Citation
If you use this model in your research, please cite:
```bibtex
@article{ma2024interpretable,
title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
journal={arXiv preprint arXiv:2410.20722},
year={2024}
}
```
## Acknowledgements
This implementation builds upon the following excellent repositories:
- [DeiT](https://github.com/facebookresearch/deit)
- [CaiT](https://github.com/facebookresearch/deit)
- [ProtoPNet](https://github.com/cfchen-duke/ProtoPNet)
## License
This project is released under the MIT license.
## Contact
For any questions or feedback, please:
1. Open an issue in the GitHub repository
2. Contact [[email protected]]