First try entity_ratio option
FairEval.py CHANGED (+41, -45)
```diff
@@ -147,21 +147,24 @@ class FairEvaluation(evaluate.Metric):
         true_spans = seq_to_fair(true_spans)
         pred_spans = seq_to_fair(pred_spans)
 
-        # (3) COUNT ERRORS AND CALCULATE SCORES
+        # (3) COUNT ERRORS AND CALCULATE SCORES (counting total ground truth entities too)
         total_errors = compare_spans([], [])
+        total_ref_entities = 0
         for i in range(len(true_spans)):
+            total_ref_entities += len(true_spans[i])
             sentence_errors = compare_spans(true_spans[i], pred_spans[i])
             total_errors = add_dict(total_errors, sentence_errors)
 
         if weights is None and mode == 'weighted':
-            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
             weights = {"TP": {"TP": 1},
                        "FP": {"FP": 1},
                        "FN": {"FN": 1},
                        "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
                        "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
                        "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
-            print(weights)
+            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
+            for k in weights:
+                print(k, ":", weights[k])
 
         config = {"labels": "all", "eval_method": [mode], "weights": weights,}
         results = calculate_results(total_errors, config)
```
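Side note on the default weights above: in FairEval's weighted mode, label errors (LE), boundary errors (BE) and label-plus-boundary errors (LBE) are redistributed into partial TP/FP/FN credit rather than counted as plain mistakes. A minimal sketch of that redistribution, with invented raw counts (the real computation happens inside calculate_results):

```python
# Invented raw error counts for one label; weights copied from the diff above.
counts = {"TP": 10, "FP": 2, "FN": 3, "LE": 4, "BE": 2, "LBE": 1}
weights = {"TP": {"TP": 1},
           "FP": {"FP": 1},
           "FN": {"FN": 1},
           "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
           "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
           "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}

# Redistribute every error type into weighted TP/FP/FN totals.
weighted = {"TP": 0.0, "FP": 0.0, "FN": 0.0}
for error_type, n in counts.items():
    for target, w in weights[error_type].items():
        weighted[target] += w * n

precision = weighted["TP"] / (weighted["TP"] + weighted["FP"])  # 11 / 16
recall = weighted["TP"] / (weighted["TP"] + weighted["FN"])     # 11 / 17
```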
```diff
@@ -170,34 +173,36 @@ class FairEvaluation(evaluate.Metric):
         # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
         # initialize empty dictionary and count errors
         output = {}
-        total_trad_errors = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
-        total_fair_errors = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
-                            results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
-                            results['overall']['fair']['LBE']
+        # control the divider for the error_format (count, proportion over total errors or over total entities)
+        if error_format == 'count':
+            trad_divider = 1
+            fair_divider = 1
+        elif error_format == 'entity_ratio':
+            trad_divider = total_ref_entities
+            fair_divider = total_ref_entities
+        elif error_format == 'error_ratio':
+            trad_divider = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
+            fair_divider = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
+                           results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
+                           results['overall']['fair']['LBE']
+
 
         # assert valid options
         assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
-        assert error_format in ['count', 'proportion'], 'error_format must be \'count\' or \'proportion\''
+        assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
 
         # append entity-level errors and scores
         if mode == 'traditional':
             for k, v in results['per_label'][mode].items():
-                if error_format == 'count':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'], 'FN': v['FN']}
-                elif error_format == 'proportion':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
+                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+                             'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
         elif mode == 'fair' or mode == 'weighted':
             for k, v in results['per_label'][mode].items():
-                if error_format == 'count':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'], 'FN': v['FN'], 'LE': v['LE'], 'BE': v['BE'], 'LBE': v['LBE']}
-                elif error_format == 'proportion':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'] / total_fair_errors, 'FN': v['FN'] / total_fair_errors,
-                                 'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
-                                 'LBE': v['LBE'] / total_fair_errors}
+                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+                             'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
+                             'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
 
         # append overall scores
         output['overall_precision'] = results['overall'][mode]['Prec']
```
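What the dividers do in practice, as a quick worked sketch with invented numbers: 'count' leaves the error counts raw, 'entity_ratio' (the option this commit introduces) divides them by the number of gold entities, and 'error_ratio' divides them by the total number of errors:

```python
# Invented overall fair-mode error counts.
fp, fn, le, be, lbe = 4, 6, 3, 2, 1
total_ref_entities = 50                  # gold entities, counted in the loop above
total_errors = fp + fn + le + be + lbe   # 16, the 'error_ratio' divider

print(fp)                        # 'count':        4 false positives
print(fp / total_ref_entities)   # 'entity_ratio': 0.08 FPs per gold entity
print(fp / total_errors)         # 'error_ratio':  0.25 of all errors are FPs
```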
```diff
@@ -206,25 +211,16 @@ class FairEvaluation(evaluate.Metric):
 
         # append overall error counts
         if mode == 'traditional':
-            output['TP'] = results['overall'][mode]['TP']
-            output['FP'] = results['overall'][mode]['FP']
-            output['FN'] = results['overall'][mode]['FN']
-            if error_format == 'proportion':
-                output['FP'] = output['FP'] / total_trad_errors
-                output['FN'] = output['FN'] / total_trad_errors
+            output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / trad_divider
+            output['FN'] = results['overall'][mode]['FN'] / trad_divider
         elif mode == 'fair' or 'weighted':
-            output['TP'] = results['overall'][mode]['TP']
-            output['FP'] = results['overall'][mode]['FP']
-            output['FN'] = results['overall'][mode]['FN']
-            output['LE'] = results['overall'][mode]['LE']
-            output['BE'] = results['overall'][mode]['BE']
-            output['LBE'] = results['overall'][mode]['LBE']
-            if error_format == 'proportion':
-                output['FP'] = output['FP'] / total_fair_errors
-                output['FN'] = output['FN'] / total_fair_errors
-                output['LE'] = output['LE'] / total_fair_errors
-                output['BE'] = output['BE'] / total_fair_errors
-                output['LBE'] = output['LBE'] / total_fair_errors
+            output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / fair_divider
+            output['FN'] = results['overall'][mode]['FN'] / fair_divider
+            output['LE'] = results['overall'][mode]['LE'] / fair_divider
+            output['BE'] = results['overall'][mode]['BE'] / fair_divider
+            output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
 
         return output
 
```
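One caveat in the unchanged context line `elif mode == 'fair' or 'weighted':` above: the right operand is a non-empty string, so the expression is always truthy and the branch catches every mode that is not 'traditional'. The assert on mode keeps the observable behaviour correct here, but `mode in ('fair', 'weighted')` would be the robust spelling:

```python
mode = 'other'
print(mode == 'fair' or 'weighted')   # 'weighted' -> truthy, branch would run
print(mode in ('fair', 'weighted'))   # False, the intended check
```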
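For reference, an end-to-end usage sketch. The space id, the IOB2 example tags and the exact compute signature are assumptions based on this file, not confirmed by the diff:

```python
import evaluate

# Assumed Hub id for this space.
faireval = evaluate.load("hpi-dhc/FairEval")

references = [["O", "B-PER", "I-PER", "O", "B-LOC"]]
predictions = [["O", "B-PER", "O", "O", "B-ORG"]]

# 'entity_ratio' is the error_format added by this commit: error counts are
# divided by the number of gold entities instead of being reported raw
# ('count') or as a share of all errors ('error_ratio').
results = faireval.compute(predictions=predictions,
                           references=references,
                           mode='fair',
                           error_format='entity_ratio')
print(results)
```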