Update README.md
---
language: en
tags:
- text-classification
- onnx
- bge-small-en
- emotions
- multi-class-classification
- multi-label-classification
datasets:
- go_emotions
models:
- BAAI/bge-small-en
license: mit
inference: false
widget:
- text: ONNX is so much faster, its very handy!
---

### Overview

This is a multi-label, multi-class linear classifier for emotions that works with [BGE-small-en embeddings](https://huggingface.co/BAAI/bge-small-en). It was trained on the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset.

### Labels

The 28 labels from the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset are:
```
['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
```

### Metrics (exact match of labels per item)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification. The metrics below are evaluated across all labels per item in the go_emotions test split.

With the threshold for each label chosen to maximise that label's F1, the metrics (evaluated on the go_emotions test split, and unweighted by support) are:

- Precision: 0.429
- Recall: 0.483
- F1: 0.439

Weighted by the relative support of each label in the dataset, these become:

- Precision: 0.457
- Recall: 0.585
- F1: 0.502

Using a fixed threshold of 0.5 to convert the scores to binary predictions for each label, the metrics (evaluated on the go_emotions test split, and unweighted by support) are:

- Precision: 0.650
- Recall: 0.189
- F1: 0.249
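
As an illustration of what choosing a threshold per label to maximise F1 can look like, here is a minimal sketch for a single label. It is not the procedure actually used for this model: the candidate grid (0.05 to 0.95 in steps of 0.05), the helper names `f1_at_threshold` and `best_threshold`, and the toy scores are all assumptions made for the example.

```python
import numpy as np


def f1_at_threshold(y_true, scores, threshold):
    """F1 for one label when scores >= threshold count as positive."""
    y_pred = scores >= threshold
    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    fn = np.sum(~y_pred & (y_true == 1))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else float(2 * tp / denom)


def best_threshold(y_true, scores, grid=None):
    """Return (threshold, f1) with the highest F1 over a candidate grid."""
    if grid is None:
        # Assumed grid: 0.05, 0.10, ..., 0.95
        grid = np.round(np.linspace(0.05, 0.95, 19), 2)
    f1s = [f1_at_threshold(y_true, scores, t) for t in grid]
    best = int(np.argmax(f1s))  # the lowest threshold wins on ties
    return float(grid[best]), f1s[best]


# Toy data for a single label (hypothetical values, not from go_emotions)
y_true = np.array([0, 0, 0, 1, 0, 1, 1, 0])
scores = np.array([0.02, 0.08, 0.12, 0.18, 0.22, 0.35, 0.60, 0.05])

threshold, f1 = best_threshold(y_true, scores)
print(threshold, round(f1, 3))                  # 0.15 0.857
print(f1_at_threshold(y_true, scores, 0.5))     # 0.5
```

In this toy case the fixed 0.5 threshold keeps precision perfect but misses two of the three positives, which mirrors the high-precision, low-recall pattern of the fixed-threshold figures above. In practice the threshold would normally be tuned on a split other than the one used for reporting.
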

### Metrics (per-label)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification, and metrics are better measured per label.

With the threshold for each label chosen to maximise that label's F1, the metrics (evaluated on the go_emotions test split) are:

| label          | f1    | precision | recall | support | threshold |
| -------------- | ----- | --------- | ------ | ------- | --------- |
| admiration     | 0.561 | 0.517     | 0.613  | 504     | 0.25      |
| amusement      | 0.647 | 0.663     | 0.633  | 264     | 0.20      |
| anger          | 0.324 | 0.238     | 0.510  | 198     | 0.10      |
| annoyance      | 0.292 | 0.200     | 0.541  | 320     | 0.10      |
| approval       | 0.335 | 0.297     | 0.385  | 351     | 0.15      |
| caring         | 0.306 | 0.221     | 0.496  | 135     | 0.10      |
| confusion      | 0.360 | 0.400     | 0.327  | 153     | 0.20      |
| curiosity      | 0.461 | 0.392     | 0.560  | 284     | 0.15      |
| desire         | 0.411 | 0.476     | 0.361  | 83      | 0.25      |
| disappointment | 0.204 | 0.150     | 0.318  | 151     | 0.10      |
| disapproval    | 0.357 | 0.291     | 0.461  | 267     | 0.15      |
| disgust        | 0.403 | 0.417     | 0.390  | 123     | 0.20      |
| embarrassment  | 0.424 | 0.483     | 0.378  | 37      | 0.30      |
| excitement     | 0.298 | 0.255     | 0.359  | 103     | 0.15      |
| fear           | 0.609 | 0.590     | 0.628  | 78      | 0.25      |
| gratitude      | 0.801 | 0.819     | 0.784  | 352     | 0.30      |
| grief          | 0.500 | 0.500     | 0.500  | 6       | 0.75      |
| joy            | 0.437 | 0.453     | 0.422  | 161     | 0.20      |
| love           | 0.641 | 0.693     | 0.597  | 238     | 0.30      |
| nervousness    | 0.356 | 0.364     | 0.348  | 23      | 0.45      |
| optimism       | 0.416 | 0.538     | 0.339  | 186     | 0.25      |
| pride          | 0.500 | 0.750     | 0.375  | 16      | 0.65      |
| realization    | 0.247 | 0.228     | 0.269  | 145     | 0.10      |
| relief         | 0.364 | 0.273     | 0.545  | 11      | 0.30      |
| remorse        | 0.581 | 0.529     | 0.643  | 56      | 0.25      |
| sadness        | 0.525 | 0.519     | 0.532  | 156     | 0.20      |
| surprise       | 0.301 | 0.235     | 0.418  | 141     | 0.10      |
| neutral        | 0.626 | 0.519     | 0.786  | 1787    | 0.30      |
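
The aggregate F1 figures reported in the earlier metrics section appear to be the plain mean and the support-weighted mean of the per-label values in this table. The short sketch below recomputes both from the `f1` and `support` columns; the variable names are illustrative only, and the numbers are copied from the table rather than recomputed from the model.

```python
# (f1, support) per label, copied from the optimised-threshold table above
rows = {
    "admiration": (0.561, 504), "amusement": (0.647, 264),
    "anger": (0.324, 198), "annoyance": (0.292, 320),
    "approval": (0.335, 351), "caring": (0.306, 135),
    "confusion": (0.360, 153), "curiosity": (0.461, 284),
    "desire": (0.411, 83), "disappointment": (0.204, 151),
    "disapproval": (0.357, 267), "disgust": (0.403, 123),
    "embarrassment": (0.424, 37), "excitement": (0.298, 103),
    "fear": (0.609, 78), "gratitude": (0.801, 352),
    "grief": (0.500, 6), "joy": (0.437, 161),
    "love": (0.641, 238), "nervousness": (0.356, 23),
    "optimism": (0.416, 186), "pride": (0.500, 16),
    "realization": (0.247, 145), "relief": (0.364, 11),
    "remorse": (0.581, 56), "sadness": (0.525, 156),
    "surprise": (0.301, 141), "neutral": (0.626, 1787),
}

f1_values = [f1 for f1, _ in rows.values()]
supports = [support for _, support in rows.values()]

# Unweighted (macro) mean: every label counts equally
macro_f1 = sum(f1_values) / len(f1_values)

# Support-weighted mean: frequent labels such as neutral dominate
weighted_f1 = sum(f * s for f, s in zip(f1_values, supports)) / sum(supports)

print(f"macro F1:    {macro_f1:.3f}")     # macro F1:    0.439
print(f"weighted F1: {weighted_f1:.3f}")  # weighted F1: 0.502
```

The gap between the two numbers comes largely from `neutral`, which accounts for 1787 of the 6329 label occurrences in the test split and has an above-average F1.
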

Using a fixed threshold of 0.5 to convert the scores to binary predictions for each label, the metrics (evaluated on the go_emotions test split) are:

| label          | f1    | precision | recall | support | threshold |
| -------------- | ----- | --------- | ------ | ------- | --------- |
| admiration     | 0.443 | 0.722     | 0.319  | 504     | 0.5       |
| amusement      | 0.364 | 0.805     | 0.235  | 264     | 0.5       |
| anger          | 0.100 | 0.478     | 0.056  | 198     | 0.5       |
| annoyance      | 0.012 | 0.667     | 0.006  | 320     | 0.5       |
| approval       | 0.082 | 0.882     | 0.043  | 351     | 0.5       |
| caring         | 0.118 | 0.500     | 0.067  | 135     | 0.5       |
| confusion      | 0.107 | 0.600     | 0.059  | 153     | 0.5       |
| curiosity      | 0.242 | 0.550     | 0.155  | 284     | 0.5       |
| desire         | 0.204 | 0.667     | 0.120  | 83      | 0.5       |
| disappointment | 0.026 | 1.000     | 0.013  | 151     | 0.5       |
| disapproval    | 0.084 | 0.600     | 0.045  | 267     | 0.5       |
| disgust        | 0.243 | 0.720     | 0.146  | 123     | 0.5       |
| embarrassment  | 0.217 | 0.556     | 0.135  | 37      | 0.5       |
| excitement     | 0.037 | 0.333     | 0.019  | 103     | 0.5       |
| fear           | 0.466 | 0.711     | 0.346  | 78      | 0.5       |
| gratitude      | 0.757 | 0.915     | 0.645  | 352     | 0.5       |
| grief          | 0.286 | 0.200     | 0.500  | 6       | 0.5       |
| joy            | 0.197 | 0.818     | 0.112  | 161     | 0.5       |
| love           | 0.519 | 0.805     | 0.382  | 238     | 0.5       |
| nervousness    | 0.293 | 0.333     | 0.261  | 23      | 0.5       |
| optimism       | 0.260 | 0.784     | 0.156  | 186     | 0.5       |
| pride          | 0.444 | 0.545     | 0.375  | 16      | 0.5       |
| realization    | 0.014 | 0.500     | 0.007  | 145     | 0.5       |
| relief         | 0.154 | 0.500     | 0.091  | 11      | 0.5       |
| remorse        | 0.449 | 0.606     | 0.357  | 56      | 0.5       |
| sadness        | 0.297 | 0.744     | 0.186  | 156     | 0.5       |
| surprise       | 0.042 | 1.000     | 0.021  | 141     | 0.5       |
| neutral        | 0.528 | 0.649     | 0.445  | 1787    | 0.5       |

### Use with ONNXRuntime

The input to the model is called `logits`, and there is one output per label. Each output is a 2D array with one row per input row and two columns: the first is the probability of the negative case, and the second is the probability of the positive case.

```python
# Assuming you have embeddings from BAAI/bge-small-en for the input sentences,
# e.g. produced with sentence-transformers (huggingface.co/BAAI/bge-small-en)
# or with an ONNX version (huggingface.co/Xenova/bge-small-en)

print(sentences.shape)  # e.g. a batch of 1 sentence
# > (1, 384)

import onnxruntime as ort

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])

outputs = [o.name for o in sess.get_outputs()]  # list of labels, in the order of the outputs
preds_onnx = sess.run(outputs, {'logits': sentences})
# preds_onnx is a list with 28 entries, one per label,
# each a numpy array of shape (1, 2) given the input was a batch of 1

print(outputs[0])
# > surprise
print(preds_onnx[0])
# > array([[0.97136074, 0.02863926]], dtype=float32)
```
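
To turn these per-label outputs into a list of predicted emotions, one option is to compare the positive-class column of each output with a per-label threshold. The sketch below is only an illustration: the helper `positive_labels` is hypothetical, and the arrays are made-up stand-ins for `outputs` and `preds_onnx` (3 labels, batch of 2) so that it runs without the model file.

```python
import numpy as np


def positive_labels(label_names, preds, thresholds, default=0.5):
    """Return, for each input row, the labels whose positive score meets the threshold."""
    batch_size = preds[0].shape[0]
    results = [[] for _ in range(batch_size)]
    for name, pred in zip(label_names, preds):
        cutoff = thresholds.get(name, default)
        # Column 1 holds the score for the positive case
        for row in np.nonzero(pred[:, 1] >= cutoff)[0]:
            results[row].append(name)
    return results


# Stand-ins for `outputs` and `preds_onnx` (made-up scores, not model output)
label_names = ["surprise", "gratitude", "neutral"]
preds = [
    np.array([[0.97, 0.03], [0.85, 0.15]], dtype=np.float32),
    np.array([[0.40, 0.60], [0.90, 0.10]], dtype=np.float32),
    np.array([[0.75, 0.25], [0.65, 0.35]], dtype=np.float32),
]
# Thresholds for these three labels, taken from the per-label table above
thresholds = {"surprise": 0.10, "gratitude": 0.30, "neutral": 0.30}

print(positive_labels(label_names, preds, thresholds))
# [['gratitude'], ['surprise', 'neutral']]
```

With the real session, `label_names` would be `outputs` and `preds` would be `preds_onnx`. Labels missing from `thresholds` fall back to the assumed default of 0.5.
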

### Commentary on the dataset

Some labels (e.g. gratitude) perform very strongly when considered independently, whilst others (e.g. relief) perform very poorly.

This is a challenging dataset. Labels such as relief have far fewer examples in the training data (fewer than 100 out of the 40k+, and only 11 in the test split).

There are also some ambiguities and/or labelling errors visible in the go_emotions training data, which are suspected to constrain performance. Cleaning the dataset to reduce the mistakes, ambiguity, conflicts and duplication in the labelling would likely produce a higher-performing model.