Patrick Ramos
commited on
Commit
·
b093863
1
Parent(s):
d81f7ec
Update README.md
Browse files
README.md
CHANGED
@@ -3083,7 +3083,7 @@ widget:
|
|
3083 |
|
3084 |
# Model description
|
3085 |
|
3086 |
-
This is a LogisticRegressionCV model trained on averages of patch embeddings from the Imagenette dataset
|
3087 |
|
3088 |
## Intended uses & limitations
|
3089 |
|
@@ -3145,9 +3145,40 @@ Use the code below to get started with the model.
|
|
3145 |
<summary> Click to expand </summary>
|
3146 |
|
3147 |
```python
|
3148 |
-
import
|
3149 |
-
|
3150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3151 |
```
|
3152 |
|
3153 |
</details>
|
@@ -3168,11 +3199,16 @@ You can contact the model card authors through following channels:
|
|
3168 |
|
3169 |
# Citation
|
3170 |
|
3171 |
-
Below you can find information related to citation.
|
3172 |
|
3173 |
**BibTeX:**
|
3174 |
```
|
3175 |
-
|
|
|
|
|
|
|
|
|
|
|
3176 |
```
|
3177 |
|
3178 |
|
|
|
3083 |
|
3084 |
# Model description
|
3085 |
|
3086 |
+
This is a LogisticRegressionCV model trained on averages of patch embeddings from the Imagenette dataset. This forms the GAM of an [Emb-GAM](https://arxiv.org/abs/2209.11799) extended to images. Patch embeddings are meant to be extracted with the [`facebook/dino-vitb16` DINO checkpoint](https://huggingface.co/facebook/dino-vitb16).
|
3087 |
|
3088 |
## Intended uses & limitations
|
3089 |
|
|
|
3145 |
<summary> Click to expand </summary>
|
3146 |
|
3147 |
```python
|
3148 |
+
from PIL import Image
|
3149 |
+
from skops import hub_utils
|
3150 |
+
import torch
|
3151 |
+
from transformers import ViTFeatureExtractor, ViTModel
|
3152 |
+
import pickle
|
3153 |
+
import os
|
3154 |
+
|
3155 |
+
# load DINO
|
3156 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
3157 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
|
3158 |
+
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
|
3159 |
+
|
3160 |
+
# load logistic regression
|
3161 |
+
os.mkdir('logistic regression')
|
3162 |
+
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
|
3163 |
+
|
3164 |
+
with open('emb-gam-dino/model.pkl', 'rb') as file:
|
3165 |
+
logistic_regression = pickle.load(file)
|
3166 |
+
|
3167 |
+
# load image
|
3168 |
+
img = Image.open('examples/english_springer.png')
|
3169 |
+
|
3170 |
+
# preprocess image
|
3171 |
+
inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt').items()}
|
3172 |
+
|
3173 |
+
# extract patch embeddings
|
3174 |
+
with torch.no_grad():
|
3175 |
+
patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
|
3176 |
+
|
3177 |
+
# classify
|
3178 |
+
pred = logistic_regression.predict(patch_embeddings.mean(dim=0).view(1, -1))
|
3179 |
+
|
3180 |
+
# get patch contributions
|
3181 |
+
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
|
3182 |
```
|
3183 |
|
3184 |
</details>
|
|
|
3199 |
|
3200 |
# Citation
|
3201 |
|
3202 |
+
Below you can find information related to citation. Note that this is **not our own paper**.
|
3203 |
|
3204 |
**BibTeX:**
|
3205 |
```
|
3206 |
+
@article{singh2022emb,
|
3207 |
+
title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
|
3208 |
+
author={Singh, Chandan and Gao, Jianfeng},
|
3209 |
+
journal={arXiv preprint arXiv:2209.11799},
|
3210 |
+
year={2022}
|
3211 |
+
}
|
3212 |
```
|
3213 |
|
3214 |
|