---
inference: false
license: mit
tags:
- Zero-Shot Classification
language:
- multilingual
- af
- am
- ar
- as
- az
- be
- bg
- bn
- br
- bs
- ca
- cs
- cy
- da
- de
- el
- en
- eo
- es
- et
- eu
- fa
- fi
- fr
- fy
- ga
- gd
- gl
- gu
- ha
- he
- hi
- hr
- hu
- hy
- id
- is
- it
- ja
- jv
- ka
- kk
- km
- kn
- ko
- ku
- ky
- la
- lo
- lt
- lv
- mg
- mk
- ml
- mn
- mr
- ms
- my
- ne
- nl
- 'no'
- om
- or
- pa
- pl
- ps
- pt
- ro
- ru
- sa
- sd
- si
- sk
- sl
- so
- sq
- sr
- su
- sv
- sw
- ta
- te
- th
- tl
- tr
- ug
- uk
- ur
- uz
- vi
- xh
- yi
- zh
pipeline_tag: zero-shot-classification
metrics:
- accuracy
---
# Zero-shot text classification (base-sized model) trained with self-supervised tuning

This is a zero-shot text classification model trained with self-supervised tuning (SSTuning).
It was introduced in the paper [Zero-Shot Text Classification via Self-Supervised Tuning](https://arxiv.org/abs/2305.11442) by
Chaoqun Liu, Wenxuan Zhang, Guizhen Chen, Xiaobao Wu, Anh Tuan Luu, Chip Hong Chang, and Lidong Bing,
and first released in [this repository](https://github.com/DAMO-NLP-SG/SSTuning).

The model backbone is RoBERTa-base.
## Model description

The model is tuned with unlabeled data using a first sentence prediction (FSP) learning objective.
The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks.
The training and validation sets are constructed from the unlabeled corpus using FSP.

During tuning, BERT-like pre-trained masked language models such as RoBERTa and ALBERT are employed as the backbone, and an output layer for classification is added.
The learning objective for FSP is to predict the index of the correct label.
A cross-entropy loss is used for tuning the model.

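To make the FSP objective concrete, below is a minimal sketch of how a training example might be constructed from an unlabeled corpus. The toy corpus, sentence splitting, and option count here are illustrative assumptions, not the paper's exact procedure:

```python
import random

# A toy "unlabeled corpus": each document is a list of sentences.
# (Hypothetical data; the paper tunes on large unlabeled corpora.)
corpus = [
    ["The hotel staff were friendly.", "Our room overlooked the bay.", "We would stay again."],
    ["The update broke the login page.", "Users reported errors for hours."],
    ["The recipe uses fresh basil.", "It takes twenty minutes to cook."],
]

def make_fsp_example(doc_id, num_options=3):
    """Build one first-sentence-prediction example: given lettered options
    and the rest of a paragraph, the target is the index of the option
    that is the paragraph's true first sentence."""
    doc = corpus[doc_id]
    correct = doc[0]                # the true first sentence
    text = " ".join(doc[1:])        # the remainder of the paragraph
    # Distractor options: first sentences drawn from other documents.
    others = [d[0] for i, d in enumerate(corpus) if i != doc_id]
    options = random.sample(others, num_options - 1) + [correct]
    random.shuffle(options)
    label = options.index(correct)  # the index the model must predict
    # Same "(A) ... (B) ... </s> text" format used at inference time.
    s_option = " ".join(f"({chr(65 + i)}) {opt}" for i, opt in enumerate(options))
    return f"{s_option} </s> {text}", label

x, y = make_fsp_example(0)
print(x)
print(y)  # e.g. 2 when the true first sentence landed in slot (C)
```

Tuning then reduces to standard sequence classification over such examples with a cross-entropy loss, which is why no labeled data is needed.
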
## Model variations

Four versions of the model have been released. The details are:

| Model | Backbone | #Params | Accuracy | Speed | #Training data |
|-------|----------|---------|----------|-------|----------------|
| [zero-shot-classify-SSTuning-base](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-base) | [roberta-base](https://huggingface.co/roberta-base) | 125M | Low | High | 20.48M |
| [zero-shot-classify-SSTuning-large](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-large) | [roberta-large](https://huggingface.co/roberta-large) | 355M | Medium | Medium | 5.12M |
| [zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) | [albert-xxlarge-v2](https://huggingface.co/albert-xxlarge-v2) | 235M | High | Low | 5.12M |
| [zero-shot-classify-SSTuning-XLM-R](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R) | [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | 278M | - | - | 20.48M |

Please note that zero-shot-classify-SSTuning-XLM-R is trained on 20.48M English samples only. However, it can also be used in other languages as long as XLM-R supports them.

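Every variant is loaded through the same `transformers` auto classes; only the checkpoint name changes. A quick sketch (the choice of the large variant here is arbitrary):

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Same API for every variant in the table above.
checkpoint = "DAMO-NLP-SG/zero-shot-classify-SSTuning-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
```
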
## Intended uses & limitations

The model can be used for zero-shot text classification tasks such as sentiment analysis and topic classification. No further finetuning is needed.

The number of labels should be between 2 and 20.

### How to use

You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]  # option letters (A) ... (T)

def check_text(model, text, list_label, shuffle=False):
    # Ensure every label ends with a period, then pad the list to 20 options.
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Build the input in the FSP format: "(A) label1 (B) label2 ... </s> text".
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    text = f'{s_option} {tokenizer.sep_token} {text}'

    model.to(device).eval()
    encoding = tokenizer([text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    logits = model(**item).logits

    # Without shuffling, only the first len(list_label) options are real labels.
    logits = logits if shuffle else logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x, 5) for x in probs[0]]

    print(f'prediction: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability: {round(probabilities[predictions] * 100, 2)}%')

check_text(model, text, list_label)
# prediction: 1 => (B) positive.
# probability: 99.92%
```

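The same helper also handles multi-class problems. The text and four-label set below are hypothetical, chosen only to illustrate a larger label set; `shuffle=True` randomizes which lettered slot each label occupies, so the full 20-way logits are used for the prediction:

```python
# Hypothetical four-label topic-classification example.
text = "The central bank raised interest rates by 50 basis points on Tuesday."
list_label = ["politics", "business", "sports", "technology"]

# With shuffle=True the labels land in random slots among the 20 options,
# so the printed letter varies run to run; the predicted label text
# should typically be stable.
check_text(model, text, list_label, shuffle=True)
```
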
### BibTeX entry and citation info

```bibtex
@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and
               Wenxuan Zhang and
               Guizhen Chen and
               Xiaobao Wu and
               Anh Tuan Luu and
               Chip Hong Chang and
               Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {https://arxiv.org/abs/2305.11442},
}
```