SamLowe committed
Commit 919f303 · 1 Parent(s): f7ca71e

Update README.md

Files changed (1): README.md +153 -0

---
language: en
tags:
- text-classification
- onnx
- bge-small-en
- emotions
- multi-class-classification
- multi-label-classification
datasets:
- go_emotions
models:
- BAAI/bge-small-en
license: mit
inference: false
widget:
- text: ONNX is so much faster, it's very handy!
---

### Overview

This is a multi-label, multi-class linear classifier for emotions that works with [BGE-small-en embeddings](https://huggingface.co/BAAI/bge-small-en), having been trained on the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset.

### Labels

The 28 labels from the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset are:
```
['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
```

### Metrics (exact match of labels per item)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification. Evaluating across all labels per item on the go_emotions test split gives the metrics shown below.

Optimising the threshold per label to maximise the F1 metric, the metrics (evaluated on the go_emotions test split) are:

- Precision: 0.429
- Recall: 0.483
- F1: 0.439

Weighted by the relative support of each label in the dataset, this becomes:

- Precision: 0.457
- Recall: 0.585
- F1: 0.502

Using a fixed threshold of 0.5 to convert the scores to binary predictions for each label, the metrics (evaluated on the go_emotions test split, and unweighted by support) are:

- Precision: 0.650
- Recall: 0.189
- F1: 0.249
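
As a minimal sketch of how such aggregate numbers can be computed (not the exact evaluation code used for this model; `y_true` and `y_score` below are random stand-ins for the gold labels and model scores):

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Illustrative stand-ins: binary gold labels and model scores, n items x 28 labels
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(1000, 28))
y_score = rng.random(size=(1000, 28))

# Fixed threshold of 0.5 for every label
y_pred = (y_score >= 0.5).astype(int)

# average="macro" treats every label equally (unweighted by support);
# average="weighted" weights each label by its relative support instead
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0
)
print(f"Precision: {precision:.3f}  Recall: {recall:.3f}  F1: {f1:.3f}")
```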

### Metrics (per-label)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification, and metrics are better measured per label.

Optimising the threshold per label to maximise the F1 metric, the metrics (evaluated on the go_emotions test split) are shown below; a sketch of how such thresholds can be selected follows the table.

| | f1 | precision | recall | support | threshold |
| -------------- | ----- | --------- | ------ | ------- | --------- |
| admiration | 0.561 | 0.517 | 0.613 | 504 | 0.25 |
| amusement | 0.647 | 0.663 | 0.633 | 264 | 0.20 |
| anger | 0.324 | 0.238 | 0.510 | 198 | 0.10 |
| annoyance | 0.292 | 0.200 | 0.541 | 320 | 0.10 |
| approval | 0.335 | 0.297 | 0.385 | 351 | 0.15 |
| caring | 0.306 | 0.221 | 0.496 | 135 | 0.10 |
| confusion | 0.360 | 0.400 | 0.327 | 153 | 0.20 |
| curiosity | 0.461 | 0.392 | 0.560 | 284 | 0.15 |
| desire | 0.411 | 0.476 | 0.361 | 83 | 0.25 |
| disappointment | 0.204 | 0.150 | 0.318 | 151 | 0.10 |
| disapproval | 0.357 | 0.291 | 0.461 | 267 | 0.15 |
| disgust | 0.403 | 0.417 | 0.390 | 123 | 0.20 |
| embarrassment | 0.424 | 0.483 | 0.378 | 37 | 0.30 |
| excitement | 0.298 | 0.255 | 0.359 | 103 | 0.15 |
| fear | 0.609 | 0.590 | 0.628 | 78 | 0.25 |
| gratitude | 0.801 | 0.819 | 0.784 | 352 | 0.30 |
| grief | 0.500 | 0.500 | 0.500 | 6 | 0.75 |
| joy | 0.437 | 0.453 | 0.422 | 161 | 0.20 |
| love | 0.641 | 0.693 | 0.597 | 238 | 0.30 |
| nervousness | 0.356 | 0.364 | 0.348 | 23 | 0.45 |
| optimism | 0.416 | 0.538 | 0.339 | 186 | 0.25 |
| pride | 0.500 | 0.750 | 0.375 | 16 | 0.65 |
| realization | 0.247 | 0.228 | 0.269 | 145 | 0.10 |
| relief | 0.364 | 0.273 | 0.545 | 11 | 0.30 |
| remorse | 0.581 | 0.529 | 0.643 | 56 | 0.25 |
| sadness | 0.525 | 0.519 | 0.532 | 156 | 0.20 |
| surprise | 0.301 | 0.235 | 0.418 | 141 | 0.10 |
| neutral | 0.626 | 0.519 | 0.786 | 1787 | 0.30 |
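
A minimal sketch of selecting a per-label threshold by grid search (the 0.05-spaced candidate grid is an assumption, chosen because the thresholds in the table are multiples of 0.05; `y_true` and `y_score` are the same illustrative stand-ins as in the earlier sketch):

```python
import numpy as np
from sklearn.metrics import f1_score

# Illustrative stand-ins, as in the earlier sketch
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(1000, 28))
y_score = rng.random(size=(1000, 28))

candidates = np.arange(0.05, 1.0, 0.05)  # assumed candidate thresholds

# For each label, keep the threshold that maximises F1 on that label alone
best_thresholds = []
for label_idx in range(y_true.shape[1]):
    f1s = [f1_score(y_true[:, label_idx], y_score[:, label_idx] >= t) for t in candidates]
    best_thresholds.append(float(candidates[int(np.argmax(f1s))]))
```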

Using a fixed threshold of 0.5 to convert the scores to binary predictions for each label, the metrics (evaluated on the go_emotions test split) are:

| | f1 | precision | recall | support | threshold |
| -------------- | ----- | --------- | ------ | ------- | --------- |
| admiration | 0.443 | 0.722 | 0.319 | 504 | 0.5 |
| amusement | 0.364 | 0.805 | 0.235 | 264 | 0.5 |
| anger | 0.100 | 0.478 | 0.056 | 198 | 0.5 |
| annoyance | 0.012 | 0.667 | 0.006 | 320 | 0.5 |
| approval | 0.082 | 0.882 | 0.043 | 351 | 0.5 |
| caring | 0.118 | 0.500 | 0.067 | 135 | 0.5 |
| confusion | 0.107 | 0.600 | 0.059 | 153 | 0.5 |
| curiosity | 0.242 | 0.550 | 0.155 | 284 | 0.5 |
| desire | 0.204 | 0.667 | 0.120 | 83 | 0.5 |
| disappointment | 0.026 | 1.000 | 0.013 | 151 | 0.5 |
| disapproval | 0.084 | 0.600 | 0.045 | 267 | 0.5 |
| disgust | 0.243 | 0.720 | 0.146 | 123 | 0.5 |
| embarrassment | 0.217 | 0.556 | 0.135 | 37 | 0.5 |
| excitement | 0.037 | 0.333 | 0.019 | 103 | 0.5 |
| fear | 0.466 | 0.711 | 0.346 | 78 | 0.5 |
| gratitude | 0.757 | 0.915 | 0.645 | 352 | 0.5 |
| grief | 0.286 | 0.200 | 0.500 | 6 | 0.5 |
| joy | 0.197 | 0.818 | 0.112 | 161 | 0.5 |
| love | 0.519 | 0.805 | 0.382 | 238 | 0.5 |
| nervousness | 0.293 | 0.333 | 0.261 | 23 | 0.5 |
| optimism | 0.260 | 0.784 | 0.156 | 186 | 0.5 |
| pride | 0.444 | 0.545 | 0.375 | 16 | 0.5 |
| realization | 0.014 | 0.500 | 0.007 | 145 | 0.5 |
| relief | 0.154 | 0.500 | 0.091 | 11 | 0.5 |
| remorse | 0.449 | 0.606 | 0.357 | 56 | 0.5 |
| sadness | 0.297 | 0.744 | 0.186 | 156 | 0.5 |
| surprise | 0.042 | 1.000 | 0.021 | 141 | 0.5 |
| neutral | 0.528 | 0.649 | 0.445 | 1787 | 0.5 |

### Use with ONNXRuntime

The input to the model is called `logits`, and there is one output per label. Each output produces a 2D array, with one row per input row, and each row having two columns: the first is the predicted probability for the negative case, and the second is the predicted probability for the positive case.
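
As a minimal sketch of producing those embeddings (assuming the `sentence-transformers` package; the names `encoder` and `sentence_embeddings` are illustrative, and `sentence_embeddings` is reused in the block below):

```python
from sentence_transformers import SentenceTransformer

# Encode a batch of sentences into 384-dimensional BGE-small-en embeddings
encoder = SentenceTransformer("BAAI/bge-small-en")
sentence_embeddings = encoder.encode(["ONNX is so much faster, it's very handy!"])
```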

```python
import onnxruntime as ort

# Assuming you have embeddings from BAAI/bge-small-en for the input sentences,
# e.g. produced with sentence-transformers (huggingface.co/BAAI/bge-small-en)
# or with an ONNX version (huggingface.co/Xenova/bge-small-en)
print(sentence_embeddings.shape)  # e.g. a batch of 1 sentence
# > (1, 384)

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])

# The output names are the labels, in the order of the model outputs
outputs = [o.name for o in sess.get_outputs()]

preds_onnx = sess.run(outputs, {'logits': sentence_embeddings})
# preds_onnx is a list with 28 entries, one per label,
# each a numpy array of shape (1, 2) given the input was a batch of 1

print(outputs[0])
# > surprise
print(preds_onnx[0])
# > array([[0.97136074, 0.02863926]], dtype=float32)
```
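
The per-label thresholds from the table above can then be applied to the positive-class probabilities. A minimal sketch, reusing `outputs` and `preds_onnx` from the block above (the `thresholds` dict is illustrative and deliberately incomplete):

```python
# Hypothetical per-label thresholds, e.g. taken from the per-label metrics table
thresholds = {'surprise': 0.10, 'gratitude': 0.30}  # one entry per label in practice

# The positive-class probability is the second column of each (n, 2) output
positive_probs = {label: pred[:, 1] for label, pred in zip(outputs, preds_onnx)}

predicted = [
    label
    for label, probs in positive_probs.items()
    if probs[0] >= thresholds.get(label, 0.5)  # fall back to 0.5 if untuned
]
print(predicted)  # labels predicted for the single input sentence
```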

### Commentary on the dataset

Some labels (e.g. gratitude), when considered independently, perform very strongly, whilst others (e.g. relief) perform very poorly.

This is a challenging dataset. Labels such as relief have far fewer examples in the training data (fewer than 100 out of the 40k+, and only 11 in the test split).

But there is also some ambiguity and/or labelling error visible in the training data of go_emotions that is suspected to constrain the performance. Data cleaning to reduce the mistakes, ambiguity, conflicts and duplication in the labelling would likely produce a higher-performing model.