First attempt at an entity_ratio option
Browse files- FairEval.py +41 -45
    	
        FairEval.py
    CHANGED
    
    | @@ -147,21 +147,24 @@ class FairEvaluation(evaluate.Metric): | |
| 147 | 
             
                    true_spans = seq_to_fair(true_spans)
         | 
| 148 | 
             
                    pred_spans = seq_to_fair(pred_spans)
         | 
| 149 |  | 
| 150 | 
            -
                    # (3) COUNT ERRORS AND CALCULATE SCORES
         | 
| 151 | 
             
                    total_errors = compare_spans([], [])
         | 
|  | |
| 152 | 
             
                    for i in range(len(true_spans)):
         | 
|  | |
| 153 | 
             
                        sentence_errors = compare_spans(true_spans[i], pred_spans[i])
         | 
| 154 | 
             
                        total_errors = add_dict(total_errors, sentence_errors)
         | 
| 155 |  | 
| 156 | 
             
                    if weights is None and mode == 'weighted':
         | 
| 157 | 
            -
                        print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
         | 
| 158 | 
             
                        weights = {"TP": {"TP": 1},
         | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 163 | 
            -
             | 
| 164 | 
            -
                        print(weights)
         | 
|  | |
|  | |
| 165 |  | 
| 166 | 
             
                    config = {"labels": "all", "eval_method": [mode], "weights": weights,}
         | 
| 167 | 
             
                    results = calculate_results(total_errors, config)
         | 
| @@ -170,34 +173,36 @@ class FairEvaluation(evaluate.Metric): | |
| 170 | 
             
                    # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
         | 
| 171 | 
             
                    # initialize empty dictionary and count errors
         | 
| 172 | 
             
                    output = {}
         | 
| 173 | 
            -
                     | 
| 174 | 
            -
                     | 
| 175 | 
            -
             | 
| 176 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 177 |  | 
| 178 | 
             
                    # assert valid options
         | 
| 179 | 
             
                    assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
         | 
| 180 | 
            -
                    assert error_format in ['count', ' | 
| 181 |  | 
| 182 | 
             
                    # append entity-level errors and scores
         | 
| 183 | 
             
                    if mode == 'traditional':
         | 
| 184 | 
             
                        for k, v in results['per_label'][mode].items():
         | 
| 185 | 
            -
                             | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 188 | 
            -
                            elif error_format == 'proportion':
         | 
| 189 | 
            -
                                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
         | 
| 190 | 
            -
                                             'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
         | 
| 191 | 
             
                    elif mode == 'fair' or mode == 'weighted':
         | 
| 192 | 
             
                        for k, v in results['per_label'][mode].items():
         | 
| 193 | 
            -
                             | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 196 | 
            -
             | 
| 197 | 
            -
                                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
         | 
| 198 | 
            -
                                             'FP': v['FP'] / total_fair_errors, 'FN': v['FN'] / total_fair_errors,
         | 
| 199 | 
            -
                                             'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
         | 
| 200 | 
            -
                                             'LBE': v['LBE'] / total_fair_errors}
         | 
| 201 |  | 
| 202 | 
             
                    # append overall scores
         | 
| 203 | 
             
                    output['overall_precision'] = results['overall'][mode]['Prec']
         | 
| @@ -206,25 +211,16 @@ class FairEvaluation(evaluate.Metric): | |
| 206 |  | 
| 207 | 
             
                    # append overall error counts
         | 
| 208 | 
             
                    if mode == 'traditional':
         | 
| 209 | 
            -
                        output['TP'] = results['overall'][mode]['TP']
         | 
| 210 | 
            -
                        output['FP'] = results['overall'][mode]['FP']
         | 
| 211 | 
            -
                        output['FN'] = results['overall'][mode]['FN']
         | 
| 212 | 
            -
                        if error_format == 'proportion':
         | 
| 213 | 
            -
                            output['FP'] = output['FP'] / total_trad_errors
         | 
| 214 | 
            -
                            output['FN'] = output['FN'] / total_trad_errors
         | 
| 215 | 
             
                    elif mode == 'fair' or 'weighted':
         | 
| 216 | 
            -
                        output['TP'] = results['overall'][mode]['TP']
         | 
| 217 | 
            -
                        output['FP'] = results['overall'][mode]['FP']
         | 
| 218 | 
            -
                        output['FN'] = results['overall'][mode]['FN']
         | 
| 219 | 
            -
                        output['LE'] = results['overall'][mode]['LE']
         | 
| 220 | 
            -
                        output['BE'] = results['overall'][mode]['BE']
         | 
| 221 | 
            -
                        output['LBE'] = results['overall'][mode]['LBE']
         | 
| 222 | 
            -
                        if error_format == 'proportion':
         | 
| 223 | 
            -
                            output['FP'] = output['FP'] / total_fair_errors
         | 
| 224 | 
            -
                            output['FN'] = output['FN'] / total_fair_errors
         | 
| 225 | 
            -
                            output['LE'] = output['LE'] / total_fair_errors
         | 
| 226 | 
            -
                            output['BE'] = output['BE'] / total_fair_errors
         | 
| 227 | 
            -
                            output['LBE'] = output['LBE'] / total_fair_errors
         | 
| 228 |  | 
| 229 | 
             
                    return output
         | 
| 230 |  | 
|  | |
| 147 | 
             
                    true_spans = seq_to_fair(true_spans)
         | 
| 148 | 
             
                    pred_spans = seq_to_fair(pred_spans)
         | 
| 149 |  | 
| 150 | 
            +
                    # (3) COUNT ERRORS AND CALCULATE SCORES (counting total ground truth entities too)
         | 
| 151 | 
             
                    total_errors = compare_spans([], [])
         | 
| 152 | 
            +
                    total_ref_entities = 0
         | 
| 153 | 
             
                    for i in range(len(true_spans)):
         | 
| 154 | 
            +
                        total_ref_entities += len(true_spans[i])
         | 
| 155 | 
             
                        sentence_errors = compare_spans(true_spans[i], pred_spans[i])
         | 
| 156 | 
             
                        total_errors = add_dict(total_errors, sentence_errors)
         | 
| 157 |  | 
| 158 | 
             
                    if weights is None and mode == 'weighted':
         | 
|  | |
| 159 | 
             
                        weights = {"TP": {"TP": 1},
         | 
| 160 | 
            +
                                   "FP": {"FP": 1},
         | 
| 161 | 
            +
                                   "FN": {"FN": 1},
         | 
| 162 | 
            +
                                   "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
         | 
| 163 | 
            +
                                   "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
         | 
| 164 | 
            +
                                   "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
         | 
| 165 | 
            +
                        print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
         | 
| 166 | 
            +
                        for k in weights:
         | 
| 167 | 
            +
                            print(k, ":", weights[k])
         | 
| 168 |  | 
| 169 | 
             
                    config = {"labels": "all", "eval_method": [mode], "weights": weights,}
         | 
| 170 | 
             
                    results = calculate_results(total_errors, config)
         | 
|  | |
| 173 | 
             
                    # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
         | 
| 174 | 
             
                    # initialize empty dictionary and count errors
         | 
| 175 | 
             
                    output = {}
         | 
| 176 | 
            +
                    # control the divider for the error_format (count, proportion over total errors or over total entities)
         | 
| 177 | 
            +
                    if error_format == 'count':
         | 
| 178 | 
            +
                        trad_divider = 1,
         | 
| 179 | 
            +
                        fair_divider = 1,
         | 
| 180 | 
            +
                    elif error_format == 'entity_ratio':
         | 
| 181 | 
            +
                        trad_divider = total_ref_entities
         | 
| 182 | 
            +
                        fair_divider = total_ref_entities
         | 
| 183 | 
            +
                    elif error_format == 'error_ratio':
         | 
| 184 | 
            +
                        trad_divider = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
         | 
| 185 | 
            +
                        fair_divider = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
         | 
| 186 | 
            +
                                            results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
         | 
| 187 | 
            +
                                            results['overall']['fair']['LBE']
         | 
| 188 | 
            +
             | 
| 189 |  | 
| 190 | 
             
                    # assert valid options
         | 
| 191 | 
             
                    assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
         | 
| 192 | 
            +
                    assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
         | 
| 193 |  | 
| 194 | 
             
                    # append entity-level errors and scores
         | 
| 195 | 
             
                    if mode == 'traditional':
         | 
| 196 | 
             
                        for k, v in results['per_label'][mode].items():
         | 
| 197 | 
            +
                            output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
         | 
| 198 | 
            +
                                         'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
         | 
| 199 | 
            +
                                         'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
         | 
|  | |
|  | |
|  | |
| 200 | 
             
                    elif mode == 'fair' or mode == 'weighted':
         | 
| 201 | 
             
                        for k, v in results['per_label'][mode].items():
         | 
| 202 | 
            +
                            output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
         | 
| 203 | 
            +
                                         'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
         | 
| 204 | 
            +
                                         'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
         | 
| 205 | 
            +
                                         'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
         | 
|  | |
|  | |
|  | |
|  | |
| 206 |  | 
| 207 | 
             
                    # append overall scores
         | 
| 208 | 
             
                    output['overall_precision'] = results['overall'][mode]['Prec']
         | 
|  | |
| 211 |  | 
| 212 | 
             
                    # append overall error counts
         | 
| 213 | 
             
                    if mode == 'traditional':
         | 
| 214 | 
            +
                        output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
         | 
| 215 | 
            +
                        output['FP'] = results['overall'][mode]['FP'] / trad_divider
         | 
| 216 | 
            +
                        output['FN'] = results['overall'][mode]['FN'] / trad_divider
         | 
|  | |
|  | |
|  | |
| 217 | 
             
                    elif mode == 'fair' or 'weighted':
         | 
| 218 | 
            +
                        output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
         | 
| 219 | 
            +
                        output['FP'] = results['overall'][mode]['FP'] / fair_divider
         | 
| 220 | 
            +
                        output['FN'] = results['overall'][mode]['FN'] / fair_divider
         | 
| 221 | 
            +
                        output['LE'] = results['overall'][mode]['LE'] / fair_divider
         | 
| 222 | 
            +
                        output['BE'] = results['overall'][mode]['BE'] / fair_divider
         | 
| 223 | 
            +
                        output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 224 |  | 
| 225 | 
             
                    return output
         | 
| 226 |  | 

