---
inference: false
license: mit
tags:
- Zero-Shot Classification
language:
- multilingual
- af
- am
- ar
- as
- az
- be
- bg
- bn
- br
- bs
- ca
- cs
- cy
- da
- de
- el
- en
- eo
- es
- et
- eu
- fa
- fi
- fr
- fy
- ga
- gd
- gl
- gu
- ha
- he
- hi
- hr
- hu
- hy
- id
- is
- it
- ja
- jv
- ka
- kk
- km
- kn
- ko
- ku
- ky
- la
- lo
- lt
- lv
- mg
- mk
- ml
- mn
- mr
- ms
- my
- ne
- nl
- 'no'
- om
- or
- pa
- pl
- ps
- pt
- ro
- ru
- sa
- sd
- si
- sk
- sl
- so
- sq
- sr
- su
- sv
- sw
- ta
- te
- th
- tl
- tr
- ug
- uk
- ur
- uz
- vi
- xh
- yi
- zh
pipeline_tag: zero-shot-classification
metrics:
- accuracy
---

# Zero-shot text classification (base-sized model) trained with self-supervised tuning

Zero-shot text classification model trained with self-supervised tuning (SSTuning).
It was introduced in the paper [Zero-Shot Text Classification via Self-Supervised Tuning](https://arxiv.org/abs/2305.11442) by
Chaoqun Liu, Wenxuan Zhang, Guizhen Chen, Xiaobao Wu, Anh Tuan Luu, Chip Hong Chang, and Lidong Bing,
and first released in [this repository](https://github.com/DAMO-NLP-SG/SSTuning).

The model backbone is RoBERTa-base.
## Model description

The model is tuned with unlabeled data using a first sentence prediction (FSP) learning objective.
The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks.

The training and validation sets are constructed from the unlabeled corpus using FSP.

During tuning, BERT-like pre-trained masked language models such as RoBERTa and ALBERT are employed as the backbone, and an output layer for classification is added.
The learning objective for FSP is to predict the index of the correct label.
A cross-entropy loss is used for tuning the model.
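
For intuition, the sketch below shows how a single FSP training instance could be constructed from unlabeled text. It is an illustration only, not the released training pipeline: the function name `build_fsp_instance`, the representation of a paragraph as a list of sentences, and the use of `</s>` as the separator are assumptions.

```python
import random
import string

def build_fsp_instance(paragraphs, num_options=20):
    """Hypothetical sketch of one first-sentence-prediction (FSP) example.

    `paragraphs` is a list of paragraphs, each given as a list of sentences;
    at least `num_options` paragraphs are assumed to be available.
    """
    target = random.choice(paragraphs)
    first_sentence, rest = target[0], ' '.join(target[1:])
    # Distractor options are first sentences drawn from other paragraphs.
    negatives = [p[0] for p in paragraphs if p is not target]
    options = random.sample(negatives, num_options - 1) + [first_sentence]
    random.shuffle(options)
    label = options.index(first_sentence)  # index the classifier must predict
    s_option = ' '.join(f'({string.ascii_uppercase[i]}) {o}' for i, o in enumerate(options))
    return f'{s_option} </s> {rest}', label
```

The model is then tuned with a cross-entropy loss to predict `label` from the constructed input, which mirrors the option-list input format used at inference time.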

## Model variations

There are four versions of models released. The details are:

| Model | Backbone | #params | Accuracy | Speed | #Training data |
|------------|-----------|----------|-------|-------|----|
| [zero-shot-classify-SSTuning-base](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-base) | [roberta-base](https://huggingface.co/roberta-base) | 125M | Low | High | 20.48M |
| [zero-shot-classify-SSTuning-large](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-large) | [roberta-large](https://huggingface.co/roberta-large) | 355M | Medium | Medium | 5.12M |
| [zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) | [albert-xxlarge-v2](https://huggingface.co/albert-xxlarge-v2) | 235M | High | Low | 5.12M |
| [zero-shot-classify-SSTuning-XLM-R](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R) | [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | 278M | - | - | 20.48M |

Please note that zero-shot-classify-SSTuning-XLM-R is trained with 20.48M English samples only. However, it can also be used in other languages as long as the language is supported by XLM-R.

## Intended uses & limitations

The model can be used for zero-shot text classification such as sentiment analysis and topic classification. No further fine-tuning is needed.

The number of labels should be between 2 and 20.

### How to use

You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]

def check_text(model, text, list_label, shuffle=False):
    # Ensure every label ends with a period, then pad the label list to
    # 20 options with the pad token (the model always sees 20 slots).
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Build the input as "(A) option1 (B) option2 ... <sep> text".
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    text = f'{s_option} {tokenizer.sep_token} {text}'

    model.to(device).eval()
    encoding = tokenizer([text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    logits = model(**item).logits

    # Without shuffling, only the first len(list_label) logits are meaningful.
    logits = logits if shuffle else logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x, 5) for x in probs[0]]

    print(f'prediction: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability: {round(probabilities[predictions] * 100, 2)}%')

check_text(model, text, list_label)
# prediction: 1 => (B) positive.
# probability: 99.92%
```
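
The same function handles other label sets, such as topic classification with up to 20 labels. The snippet below is purely illustrative (the text and labels are arbitrary, and outputs will vary); for non-English input, the multilingual XLM-R variant listed above can be loaded the same way.

```python
# Hypothetical example: topic classification with four arbitrary labels.
text = "The midfielder scored twice as the home side won the derby 3-1."
check_text(model, text, ["politics", "business", "sports", "technology"])

# For other languages, load the multilingual variant instead:
# tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R")
# model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R")
```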

### BibTeX entry and citation info

```bibtex
@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and
               Wenxuan Zhang and
               Guizhen Chen and
               Xiaobao Wu and
               Anh Tuan Luu and
               Chip Hong Chang and
               Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {https://arxiv.org/abs/2305.11442},
}
```