Spaces:
Runtime error
Runtime error
Upload 6 files
Browse files- README.md +66 -5
- app.py +10 -0
- quad_match_score.py +727 -0
- requirements.txt +3 -0
- tests.py +9 -0
README.md
CHANGED
|
@@ -1,12 +1,73 @@
|
|
| 1 |
---
|
| 2 |
title: Quad Match Score
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 3.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Quad Match Score
|
| 3 |
+
datasets:
|
| 4 |
+
- SemEval2016 Task5
|
| 5 |
+
tags:
|
| 6 |
+
- evaluate
|
| 7 |
+
- metric
|
| 8 |
+
description: "TODO: add a description here"
|
| 9 |
sdk: gradio
|
| 10 |
+
sdk_version: 3.19.1
|
| 11 |
app_file: app.py
|
| 12 |
pinned: false
|
| 13 |
---
|
| 14 |
|
| 15 |
+
# Metric Card for My Metric
|
| 16 |
+
|
| 17 |
+
***Module Card Instructions:*** *评估生成模型的情感四元组抽取结果.*
|
| 18 |
+
|
| 19 |
+
## Metric Description
|
| 20 |
+
*评估生成模型的情感四元组抽取结果.*
|
| 21 |
+
|
| 22 |
+
## How to Use
|
| 23 |
+
```python
|
| 24 |
+
import evaluate
|
| 25 |
+
|
| 26 |
+
module = evaluate.load("yuyijiong/my_metric")
|
| 27 |
+
|
| 28 |
+
predictions=["food | good | food#taste | pos"]
|
| 29 |
+
references=["food | good | food#taste | pos & service | bad | service#general | neg"]
|
| 30 |
+
|
| 31 |
+
module.compute(predictions=predictions, references=references)
|
| 32 |
+
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### Inputs
|
| 36 |
+
*List all input arguments in the format below*
|
| 37 |
+
- **predictions** *(List[str]): 模型生成的四元组,列表中每个字符串代表一个样本的生成结果.*
|
| 38 |
+
- **references** *(Union[List[str],List[List[str]]):
|
| 39 |
+
人工标注的四元组,列表中每个字符串代表一个样本的标签.如果列表元素为list,代表多个reference,评估时取最高分*
|
| 40 |
+
- **weights** *(Tuple[float, float, float, float]):默认为(1,1,1,1),分别代表四个方面的评估指标的权重*
|
| 41 |
+
- **tuple_len** *(str): indicate the format of the quad, see the following mapping
|
| 42 |
+
指示四元组的格式,默认为'0123'。对应关系如下所示*
|
| 43 |
+
```
|
| 44 |
+
{'0123': "四元组(对象 | 观点 | 方面 | 极性)",
|
| 45 |
+
'01':'二元组(对象 | 观点)',
|
| 46 |
+
'012':'三元组(对象 | 观点 | 方面)',
|
| 47 |
+
'013':'三元组(对象 | 观点 | 极性)',
|
| 48 |
+
'023':'三元组(对象 | 方面 | 极性)',
|
| 49 |
+
'23':'二元组(方面 | 极性)',
|
| 50 |
+
'03':'二元组(对象 | 极性)',
|
| 51 |
+
'13':'二元组(观点 | 极性)',
|
| 52 |
+
'3':'单元素(极性)'}
|
| 53 |
+
```
|
| 54 |
+
- **sep_token1**: the token to seperate quads 分割不同四元组的token
|
| 55 |
+
- **sep_token2**: the token to seperate units of one quad 四元组中不同元素之间的分隔token
|
| 56 |
+
|
| 57 |
+
### Output Values
|
| 58 |
+
|
| 59 |
+
*最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值 组成的dict,f1值均在[0,1]之间*
|
| 60 |
+
|
| 61 |
+
*例如: {'ave match score of weight (1, 1, 1, 1)': 0.375,
|
| 62 |
+
'f1 score of exact match': 0.0,
|
| 63 |
+
'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}*
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
## Limitations and Bias
|
| 67 |
+
*对比传统评估指标,得分偏高*
|
| 68 |
+
|
| 69 |
+
## Citation
|
| 70 |
+
*论文即将发表*
|
| 71 |
+
|
| 72 |
+
## Further References
|
| 73 |
+
*Add any useful further references.*
|
app.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import evaluate
|
| 2 |
+
from evaluate.utils import launch_gradio_widget
|
| 3 |
+
|
| 4 |
+
module = evaluate.load("yuyijiong/quad_match_score")
|
| 5 |
+
launch_gradio_widget(module)
|
| 6 |
+
|
| 7 |
+
# predictions=["a | b | c | pos"]
|
| 8 |
+
# references=["a | b | c | pos & e | f | g | neg"]
|
| 9 |
+
#
|
| 10 |
+
# module.compute(predictions=predictions, references=references)
|
quad_match_score.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""TODO: Add a description here."""
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import re
|
| 18 |
+
from typing import List, Dict, Union,Callable
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
import datasets
|
| 23 |
+
import evaluate
|
| 24 |
+
from rouge_chinese import Rouge
|
| 25 |
+
from scipy.optimize import linear_sum_assignment
|
| 26 |
+
|
| 27 |
+
# TODO: Add BibTeX citation
|
| 28 |
+
_CITATION = """\
|
| 29 |
+
@InProceedings{huggingface:module,
|
| 30 |
+
title = {A great new module},
|
| 31 |
+
authors={huggingface, Inc.},
|
| 32 |
+
year={2020}
|
| 33 |
+
}
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
# TODO: Add description of the module here
|
| 37 |
+
_DESCRIPTION = """\
|
| 38 |
+
evaluate sentiment quadruples.
|
| 39 |
+
评估生成模型的情感四元组
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# TODO: Add description of the arguments of the module here
|
| 44 |
+
_KWARGS_DESCRIPTION = """
|
| 45 |
+
Calculates how good are predictions given some references, using certain scores
|
| 46 |
+
Args:
|
| 47 |
+
predictions: list of predictions to score. Each predictions
|
| 48 |
+
should be a string with tokens separated by spaces.
|
| 49 |
+
references: list of reference for each prediction. Each
|
| 50 |
+
reference should be a string with tokens separated by spaces.
|
| 51 |
+
Returns:
|
| 52 |
+
score: sentiment quadruple match score
|
| 53 |
+
|
| 54 |
+
Examples:
|
| 55 |
+
Examples should be written in doctest format, and should illustrate how
|
| 56 |
+
to use the function.
|
| 57 |
+
|
| 58 |
+
>>> my_new_module = evaluate.load("my_new_module")
|
| 59 |
+
>>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
|
| 60 |
+
>>> print(results)
|
| 61 |
+
{'accuracy': 1.0}
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def compute_quadruple_f1(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
|
| 66 |
+
return_rp=False, **kwargs) -> Dict[str, float]:
|
| 67 |
+
assert len(y_pred) == len(y_true)
|
| 68 |
+
correct, pred_num, true_num = 0, 0, 0
|
| 69 |
+
|
| 70 |
+
for pred, true in zip(y_pred, y_true):
|
| 71 |
+
|
| 72 |
+
pred = CommentUnitsSim.from_str(pred, **kwargs)
|
| 73 |
+
# 如果true是list,说明有多个正确答案
|
| 74 |
+
if isinstance(true, str):
|
| 75 |
+
true = CommentUnitsSim.from_str(true, **kwargs)
|
| 76 |
+
else:
|
| 77 |
+
true = [CommentUnitsSim.from_str(t,**kwargs) for t in true]
|
| 78 |
+
|
| 79 |
+
# 如果true是list,说明有多个正确答案,取最高分
|
| 80 |
+
if isinstance(true, list):
|
| 81 |
+
correct_list = [pred.compare_same(t) for t in true]
|
| 82 |
+
correct += max(correct_list) # 获取得分最高的值
|
| 83 |
+
correct_index = correct_list.index(max(correct_list)) # 获取得分最高的索引
|
| 84 |
+
pred_num += pred.num
|
| 85 |
+
true_num += true[correct_index].num
|
| 86 |
+
else:
|
| 87 |
+
correct += pred.compare_same(true)
|
| 88 |
+
pred_num += pred.num
|
| 89 |
+
true_num += true.num
|
| 90 |
+
|
| 91 |
+
# 以下结果保留4位小数
|
| 92 |
+
precision = round(correct / pred_num, 4) + 1e-8
|
| 93 |
+
recall = round(correct / true_num, 4) + 1e-8
|
| 94 |
+
f1 = round(2 * precision * recall / (precision + recall), 4)
|
| 95 |
+
|
| 96 |
+
if return_rp:
|
| 97 |
+
return {"precision": precision, "recall": recall, "f1": f1}
|
| 98 |
+
else:
|
| 99 |
+
return f1
|
| 100 |
+
|
| 101 |
+
# 计算rougel的f1值
|
| 102 |
+
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
|
| 103 |
+
assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
|
| 104 |
+
#如果text_pred_list[0]为空字符串或空格,则返回0
|
| 105 |
+
if not text_pred_list[0].strip():
|
| 106 |
+
return 0
|
| 107 |
+
|
| 108 |
+
rouge = Rouge()
|
| 109 |
+
# 判断text_true[0]是否有中文,有中文则要用空格分割
|
| 110 |
+
if re.search(u"[\u4e00-\u9fa5]+", text_pred_list[0]):
|
| 111 |
+
text_pred_list = [' '.join(list(text_pred)) for text_pred in text_pred_list]
|
| 112 |
+
text_true_list = [' '.join(list(text_true)) for text_true in text_true_list]
|
| 113 |
+
|
| 114 |
+
rouge_l_f1 = rouge.get_scores(text_pred_list, text_true_list, avg=True)['rouge-l']['f']
|
| 115 |
+
|
| 116 |
+
return rouge_l_f1
|
| 117 |
+
|
| 118 |
+
# 记录四元组的函数
|
| 119 |
+
class CommentUnitsSim:
|
| 120 |
+
def __init__(self, data: List[Dict[str, str]],data_source:any=None,abnormal=False,language=None):
|
| 121 |
+
self.data_source=data_source
|
| 122 |
+
self.abnormal=abnormal
|
| 123 |
+
data=copy.deepcopy(data)
|
| 124 |
+
# 如果字典有target,则改名为target_text
|
| 125 |
+
for quad_dict in data:
|
| 126 |
+
if 'target' in quad_dict:
|
| 127 |
+
quad_dict['target_text'] = quad_dict['target']
|
| 128 |
+
del quad_dict['target']
|
| 129 |
+
if 'opinion' in quad_dict:
|
| 130 |
+
quad_dict['opinion_text'] = quad_dict['opinion']
|
| 131 |
+
del quad_dict['opinion']
|
| 132 |
+
|
| 133 |
+
self.data = data
|
| 134 |
+
self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性','pos':'积极','neg':'消极','neu':'中性','积极':'积极','消极':'���极','中性':'中性'}
|
| 135 |
+
self.polarity_zh2en={'积极':'pos','消极':'neg','中性':'neu','pos':'pos','neg':'neg','neu':'neu','positive':'pos','negative':'neg','neutral':'neu'}
|
| 136 |
+
|
| 137 |
+
self.language=language if language is not None else 'zh' if self.check_zh() else 'en'
|
| 138 |
+
self.none_sign='null'
|
| 139 |
+
|
| 140 |
+
@property
|
| 141 |
+
def num(self):
|
| 142 |
+
return len(self.data)
|
| 143 |
+
|
| 144 |
+
#检查四元组中是否有中文
|
| 145 |
+
def check_zh(self):
|
| 146 |
+
for quad_dict in self.data:
|
| 147 |
+
if re.search('[\u4e00-\u9fa5]',quad_dict['target_text']) or re.search('[\u4e00-\u9fa5]',quad_dict['opinion_text']):
|
| 148 |
+
return True
|
| 149 |
+
return False
|
| 150 |
+
|
| 151 |
+
# 检测极性是否正确
|
| 152 |
+
def check_polarity(self):
|
| 153 |
+
#若有某个四元组的极性不是positive、negative、neutral,则返回False
|
| 154 |
+
for quad_dict in self.data:
|
| 155 |
+
if quad_dict['polarity'] not in ['positive', 'negative', 'neutral','pos','neg','neu','积极','消极','中性']:
|
| 156 |
+
self.abnormal=True
|
| 157 |
+
return False
|
| 158 |
+
|
| 159 |
+
#将极性由英文转为中文
|
| 160 |
+
def convert_polarity_en2zh(self):
|
| 161 |
+
for quad_dict in self.data:
|
| 162 |
+
quad_dict['polarity']=self.polarity_en2zh[quad_dict['polarity']]
|
| 163 |
+
return self
|
| 164 |
+
|
| 165 |
+
#将极性由中文转为英文
|
| 166 |
+
def convert_polarity_zh2en(self):
|
| 167 |
+
for quad_dict in self.data:
|
| 168 |
+
quad_dict['polarity']=self.polarity_zh2en[quad_dict['polarity']]
|
| 169 |
+
return self
|
| 170 |
+
|
| 171 |
+
#检查是否有重复的四元组,若有则删除重复的
|
| 172 |
+
def del_duplicate(self):
|
| 173 |
+
new_data=[]
|
| 174 |
+
for quad_dict in self.data:
|
| 175 |
+
if quad_dict not in new_data:
|
| 176 |
+
new_data.append(quad_dict)
|
| 177 |
+
self.data=new_data
|
| 178 |
+
return self
|
| 179 |
+
|
| 180 |
+
#检查是否有target和opinion都为null的四元组,若有则返回True
|
| 181 |
+
def check_target_opinion_null(self):
|
| 182 |
+
for quad_dict in self.data:
|
| 183 |
+
if quad_dict['target_text']=='null' and quad_dict['opinion_text']=='null':
|
| 184 |
+
return True
|
| 185 |
+
return False
|
| 186 |
+
|
| 187 |
+
#检查是否有target或opinion为null的四元组,若有则返回True
|
| 188 |
+
def check_any_null(self):
|
| 189 |
+
for quad_dict in self.data:
|
| 190 |
+
if quad_dict['target_text']=='null' or quad_dict['opinion_text']=='null':
|
| 191 |
+
return True
|
| 192 |
+
return False
|
| 193 |
+
|
| 194 |
+
@classmethod
|
| 195 |
+
def from_str(cls, quadruple_str: str, tuple_len:Union[int,list,str]=4, format_code=0, sep_token1=' & ', sep_token2=' | '):
|
| 196 |
+
data = []
|
| 197 |
+
abnormal=False
|
| 198 |
+
#确保分隔符后面一定是空格
|
| 199 |
+
for i in range(len(quadruple_str)-1):
|
| 200 |
+
if (quadruple_str[i] == sep_token1.strip() or quadruple_str[i] == sep_token2.strip()) and quadruple_str[i + 1] != ' ':
|
| 201 |
+
quadruple_str = quadruple_str[:i + 1] + ' ' + quadruple_str[i + 1:]
|
| 202 |
+
|
| 203 |
+
# 选择几元组,即创建列表索引,从四元组中抽出n元
|
| 204 |
+
if isinstance(tuple_len, int):
|
| 205 |
+
tuple_index = list(range(tuple_len))
|
| 206 |
+
elif isinstance(tuple_len, list):
|
| 207 |
+
tuple_index = tuple_len
|
| 208 |
+
elif isinstance(tuple_len, str):
|
| 209 |
+
# 例如将‘012’转换为[0,1,2]
|
| 210 |
+
tuple_index = [int(i) for i in tuple_len]
|
| 211 |
+
else:
|
| 212 |
+
raise Exception('tuple_len参数错误')
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
for quadruple in quadruple_str.split(sep_token1):
|
| 216 |
+
if format_code == 0:
|
| 217 |
+
# quadruple可能是target|opinion|aspect|polarity,也可能是target|opinion|aspect,也可能是target|opinion,若没有则为“None”
|
| 218 |
+
quadruple_split=[unit.strip() for unit in quadruple.split(sep_token2)]
|
| 219 |
+
if len(quadruple_split)>len(tuple_index):
|
| 220 |
+
print('quadruple格式错误,过多元素', quadruple_str)
|
| 221 |
+
abnormal=True
|
| 222 |
+
quadruple_split=quadruple_split[0:len(tuple_index)] #过长则截断
|
| 223 |
+
elif len(quadruple_split)<len(tuple_index):
|
| 224 |
+
print('quadruple格式错误,过少元素', quadruple_str)
|
| 225 |
+
abnormal=True
|
| 226 |
+
quadruple_split=["None"]*(len(tuple_index)-len(quadruple_split))+quadruple_split #过短则补'None'
|
| 227 |
+
|
| 228 |
+
quadruple_keys=[["target_text","opinion_text","aspect","polarity"][i] for i in tuple_index]
|
| 229 |
+
quadruple_dict=dict(zip(quadruple_keys,quadruple_split))
|
| 230 |
+
|
| 231 |
+
q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
|
| 232 |
+
q.update(quadruple_dict)
|
| 233 |
+
#检查极性是否合法
|
| 234 |
+
if q['polarity'] not in ['pos','neg','neu','None','积极','消极','中性']:
|
| 235 |
+
print('quadruple格式错误,极性格式不对', quadruple_str)
|
| 236 |
+
|
| 237 |
+
else:
|
| 238 |
+
raise Exception('answer_format参数错误')
|
| 239 |
+
|
| 240 |
+
data.append(q)
|
| 241 |
+
|
| 242 |
+
return CommentUnitsSim(data,quadruple_str,abnormal)
|
| 243 |
+
|
| 244 |
+
@classmethod
|
| 245 |
+
def from_list(cls, quadruple_list: List[List[str]],**kwargs):
|
| 246 |
+
data = []
|
| 247 |
+
for quadruple in quadruple_list:
|
| 248 |
+
# #format_code='013'代表list只有四元组的第0、1、3个元素,需要扩充为4元组,空缺位置补上None
|
| 249 |
+
# if format_code=='013':
|
| 250 |
+
# quadruple.insert(2,None)
|
| 251 |
+
|
| 252 |
+
data.append(
|
| 253 |
+
{"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
|
| 254 |
+
"polarity": quadruple[3]})
|
| 255 |
+
|
| 256 |
+
return CommentUnitsSim(data,quadruple_list,**kwargs)
|
| 257 |
+
|
| 258 |
+
@classmethod
|
| 259 |
+
def from_list_dict(cls, quadruple_list: List[dict],**kwargs):
|
| 260 |
+
for quad_dict in quadruple_list:
|
| 261 |
+
if 'target' in quad_dict:
|
| 262 |
+
quad_dict['target_text'] = quad_dict['target']
|
| 263 |
+
del quad_dict['target']
|
| 264 |
+
if 'opinion' in quad_dict:
|
| 265 |
+
quad_dict['opinion_text'] = quad_dict['opinion']
|
| 266 |
+
del quad_dict['opinion']
|
| 267 |
+
|
| 268 |
+
data = []
|
| 269 |
+
for quadruple in quadruple_list:
|
| 270 |
+
#如果quadruple缺少某个key,则补上None
|
| 271 |
+
q={"target_text":'None',"opinion_text":'None',"aspect":'None',"polarity":'None'}
|
| 272 |
+
q.update(quadruple)
|
| 273 |
+
data.append(q)
|
| 274 |
+
|
| 275 |
+
return CommentUnitsSim(data,quadruple_list,**kwargs)
|
| 276 |
+
|
| 277 |
+
#转化为list,即只保留字典的value
|
| 278 |
+
def to_list(self):
|
| 279 |
+
data = []
|
| 280 |
+
for quad_dict in self.data:
|
| 281 |
+
data.append([quad_dict['target_text'],quad_dict['opinion_text'],quad_dict['aspect'],quad_dict['polarity']])
|
| 282 |
+
return data
|
| 283 |
+
|
| 284 |
+
# 将data转换为n元组字符串
|
| 285 |
+
def get_quadruple_str(self, format_code=0, tuple_len:Union[int,list,str]=4,sep_token1=' & ',sep_token2=' | '):
|
| 286 |
+
new_text_list = []
|
| 287 |
+
# 选择几元组,即创建列表索引,从四元组中抽出n元
|
| 288 |
+
if isinstance(tuple_len, int):
|
| 289 |
+
tuple_index = list(range(tuple_len))
|
| 290 |
+
elif isinstance(tuple_len, list):
|
| 291 |
+
tuple_index = tuple_len
|
| 292 |
+
elif isinstance(tuple_len, str):
|
| 293 |
+
# 例如将‘012’转换为[0,1,2]
|
| 294 |
+
tuple_index = [int(i) for i in tuple_len]
|
| 295 |
+
else:
|
| 296 |
+
raise Exception('tuple_len参数错误')
|
| 297 |
+
|
| 298 |
+
try:
|
| 299 |
+
#若语言为中文,则使用中文极性
|
| 300 |
+
if self.language=='zh':
|
| 301 |
+
self.convert_polarity_en2zh()
|
| 302 |
+
else:
|
| 303 |
+
self.convert_polarity_zh2en()
|
| 304 |
+
except:
|
| 305 |
+
print('语言参数错误',self.data)
|
| 306 |
+
print(self.language)
|
| 307 |
+
raise Exception('语言参数错误')
|
| 308 |
+
|
| 309 |
+
#若tuple_index==[3],则返回综合情感极性
|
| 310 |
+
if tuple_index==[3]:
|
| 311 |
+
return self.merge_polarity()
|
| 312 |
+
|
| 313 |
+
for quad_dict in self.data:
|
| 314 |
+
# 提取target_text,如果空列表则为'',如果列表长度大于1则为','.join(list)
|
| 315 |
+
target_text = quad_dict['target_text']
|
| 316 |
+
# 提取opinion_text,如果空列表则为'',如果列表长度大于1则为','.join(list)
|
| 317 |
+
opinion_text = quad_dict['opinion_text']
|
| 318 |
+
# 提取aspect
|
| 319 |
+
aspect = quad_dict['aspect']
|
| 320 |
+
# 提取polarity
|
| 321 |
+
polarity = quad_dict['polarity']
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# 拼接,‘|’分割
|
| 325 |
+
if format_code == 0:
|
| 326 |
+
# 根据tuple_len拼接
|
| 327 |
+
new_text = sep_token2.join([[target_text, opinion_text, aspect, polarity][i] for i in tuple_index])
|
| 328 |
+
else:
|
| 329 |
+
raise Exception('answer_format参数错误')
|
| 330 |
+
|
| 331 |
+
new_text_list.append(new_text)
|
| 332 |
+
|
| 333 |
+
#如果tuple_index为[2,3],则需要去除new_text_list中重复的元素,不要改变顺序。因为可能有重复的方面
|
| 334 |
+
if tuple_index==[2,3]:
|
| 335 |
+
res = []
|
| 336 |
+
for t in new_text_list:
|
| 337 |
+
if t not in res:
|
| 338 |
+
res.append(t)
|
| 339 |
+
new_text_list=res
|
| 340 |
+
|
| 341 |
+
#如果tuple_index为[3],则只保留new_text_list的第一个元素。因为只有一个情感极性
|
| 342 |
+
elif tuple_index==[3]:
|
| 343 |
+
new_text_list=new_text_list[:1]
|
| 344 |
+
|
| 345 |
+
if format_code == 0:
|
| 346 |
+
# 根据tuple_len拼接
|
| 347 |
+
return sep_token1.join(new_text_list)
|
| 348 |
+
|
| 349 |
+
# 与另一个CommentUnits对象对比,检测有几个相同的四元组
|
| 350 |
+
def compare_same(self, other)->int:
|
| 351 |
+
count = 0
|
| 352 |
+
for quad_dict in self.data:
|
| 353 |
+
if quad_dict in other.data:
|
| 354 |
+
count += 1
|
| 355 |
+
return count
|
| 356 |
+
|
| 357 |
+
# 检查自身数据的四元组中target是否有重复
|
| 358 |
+
def check_target_repeat(self):
|
| 359 |
+
target_list = []
|
| 360 |
+
for quad_dict in self.data:
|
| 361 |
+
target_list.append(quad_dict['target_text'])
|
| 362 |
+
return len(target_list) != len(set(target_list))
|
| 363 |
+
|
| 364 |
+
# 检查自身数据的四元组中opinion是否有重复
|
| 365 |
+
def check_opinion_repeat(self):
|
| 366 |
+
opinion_list = []
|
| 367 |
+
for quad_dict in self.data:
|
| 368 |
+
opinion_list.append(quad_dict['opinion_text'])
|
| 369 |
+
return len(opinion_list) != len(set(opinion_list))
|
| 370 |
+
|
| 371 |
+
# 检查自身数据的四元组中aspect是否有重复
|
| 372 |
+
def check_aspect_repeat(self):
|
| 373 |
+
aspect_list = []
|
| 374 |
+
for quad_dict in self.data:
|
| 375 |
+
aspect_list.append(quad_dict['aspect'])
|
| 376 |
+
return len(aspect_list) != len(set(aspect_list))
|
| 377 |
+
|
| 378 |
+
# 输出所有aspect的列表
|
| 379 |
+
def get_aspect_list(self):
|
| 380 |
+
aspect_list = []
|
| 381 |
+
for quad_dict in self.data:
|
| 382 |
+
aspect_list.append(quad_dict['aspect'])
|
| 383 |
+
return aspect_list
|
| 384 |
+
|
| 385 |
+
# 输出所有target的列表
|
| 386 |
+
def get_target_list(self):
|
| 387 |
+
target_list = []
|
| 388 |
+
for quad_dict in self.data:
|
| 389 |
+
target_list.append(quad_dict['target_text'])
|
| 390 |
+
return target_list
|
| 391 |
+
|
| 392 |
+
# 输出所有opinion的列表
|
| 393 |
+
def get_opinion_list(self):
|
| 394 |
+
opinion_list = []
|
| 395 |
+
for quad_dict in self.data:
|
| 396 |
+
opinion_list.append(quad_dict['opinion_text'])
|
| 397 |
+
return opinion_list
|
| 398 |
+
|
| 399 |
+
# 输出所有polarity的列表
|
| 400 |
+
def get_polarity_list(self):
|
| 401 |
+
polarity_list = []
|
| 402 |
+
for quad_dict in self.data:
|
| 403 |
+
polarity_list.append(quad_dict['polarity'])
|
| 404 |
+
return polarity_list
|
| 405 |
+
|
| 406 |
+
#对所有polarity进行综合
|
| 407 |
+
def merge_polarity(self):
|
| 408 |
+
polarity_list = self.get_polarity_list()
|
| 409 |
+
#判断是英文还是中文
|
| 410 |
+
if self.language == 'en':
|
| 411 |
+
if 'pos' in polarity_list and 'neg' in polarity_list:
|
| 412 |
+
return 'neu'
|
| 413 |
+
elif 'pos' in polarity_list:
|
| 414 |
+
return 'pos'
|
| 415 |
+
elif 'neg' in polarity_list:
|
| 416 |
+
return 'neg'
|
| 417 |
+
else:
|
| 418 |
+
return 'neu'
|
| 419 |
+
else:
|
| 420 |
+
if '积极' in polarity_list and '消极' in polarity_list:
|
| 421 |
+
return '中性'
|
| 422 |
+
elif '积极' in polarity_list:
|
| 423 |
+
return '积极'
|
| 424 |
+
elif '消极' in polarity_list:
|
| 425 |
+
return '消极'
|
| 426 |
+
else:
|
| 427 |
+
return '中性'
|
| 428 |
+
|
| 429 |
+
#检测是否有不合法opinion
|
| 430 |
+
def check_opinion_in_comment(self, comment_text):
|
| 431 |
+
for quad_dict in self.data:
|
| 432 |
+
if quad_dict['opinion_text'] !='*' and (not quad_dict['opinion_text'] in comment_text):
|
| 433 |
+
return False
|
| 434 |
+
return True
|
| 435 |
+
|
| 436 |
+
#检测是否有不合法target
|
| 437 |
+
def check_target_in_comment(self,comment_text):
|
| 438 |
+
for quad_dict in self.data:
|
| 439 |
+
if quad_dict['target_text'] !='*' and (not quad_dict['target_text'] in comment_text):
|
| 440 |
+
return False
|
| 441 |
+
return True
|
| 442 |
+
|
| 443 |
+
#计算两个四元组的相似度
|
| 444 |
+
@staticmethod
|
| 445 |
+
def get_similarity(units1, units2: 'CommentUnitsSim'):
|
| 446 |
+
pass
|
| 447 |
+
|
| 448 |
+
#对自身数据进行操作
|
| 449 |
+
def apply(self,func:Callable,field:str):
|
| 450 |
+
for quad_dict in self.data:
|
| 451 |
+
quad_dict[field] = func(quad_dict[field])
|
| 452 |
+
return self
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
#四元组匹配函数
|
| 456 |
+
class CommentUnitsMatch:
|
| 457 |
+
def __init__(self,target_weight=0.5,opinion_weight=0.5,aspect_weight=0.5,polarity_weight=0.5):
|
| 458 |
+
#归一化权重
|
| 459 |
+
weight_sum = target_weight+opinion_weight+aspect_weight+polarity_weight
|
| 460 |
+
self.target_weight = target_weight/weight_sum
|
| 461 |
+
self.opinion_weight = opinion_weight/weight_sum
|
| 462 |
+
self.aspect_weight = aspect_weight/weight_sum
|
| 463 |
+
self.polarity_weight = polarity_weight/weight_sum
|
| 464 |
+
|
| 465 |
+
#特定feature置零
|
| 466 |
+
def set_zero(self,feature:str='polarity'):
|
| 467 |
+
if feature == 'polarity':
|
| 468 |
+
self.polarity_weight = 0
|
| 469 |
+
elif feature == 'aspect':
|
| 470 |
+
self.aspect_weight = 0
|
| 471 |
+
elif 'opinion' in feature:
|
| 472 |
+
self.opinion_weight = 0
|
| 473 |
+
elif 'target' in feature:
|
| 474 |
+
self.target_weight = 0
|
| 475 |
+
else:
|
| 476 |
+
raise Exception('feature参数错误')
|
| 477 |
+
|
| 478 |
+
def re_normalize(self):
|
| 479 |
+
weight_sum = self.target_weight+self.opinion_weight+self.aspect_weight+self.polarity_weight
|
| 480 |
+
self.target_weight = self.target_weight/weight_sum
|
| 481 |
+
self.opinion_weight = self.opinion_weight/weight_sum
|
| 482 |
+
self.aspect_weight = self.aspect_weight/weight_sum
|
| 483 |
+
self.polarity_weight = self.polarity_weight/weight_sum
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
#计算cost矩阵
|
| 487 |
+
def get_cost_matrix(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='polarity'):
|
| 488 |
+
pass
|
| 489 |
+
#检查此feature是否存在,不存在则返回全0矩阵
|
| 490 |
+
if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
|
| 491 |
+
or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
|
| 492 |
+
cost_matrix = np.zeros((len(units1.data),len(units2.data)))
|
| 493 |
+
#对应feature的weight也为0
|
| 494 |
+
self.set_zero(feature)
|
| 495 |
+
|
| 496 |
+
# 并再次归一化
|
| 497 |
+
self.re_normalize()
|
| 498 |
+
|
| 499 |
+
return cost_matrix
|
| 500 |
+
|
| 501 |
+
#检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。不相同则cost为1,相同则cost为0
|
| 502 |
+
cost_matrix = []
|
| 503 |
+
for quad_dict1 in units1.data:
|
| 504 |
+
cost_list = []
|
| 505 |
+
for quad_dict2 in units2.data:
|
| 506 |
+
if quad_dict1[feature] == quad_dict2[feature]:
|
| 507 |
+
cost_list.append(0)
|
| 508 |
+
else:
|
| 509 |
+
cost_list.append(1)
|
| 510 |
+
cost_matrix.append(cost_list)
|
| 511 |
+
|
| 512 |
+
#cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
|
| 513 |
+
cost_matrix = np.array(cost_matrix)
|
| 514 |
+
return cost_matrix
|
| 515 |
+
|
| 516 |
+
#计算cost矩阵,使用rouge指标
|
| 517 |
+
def get_cost_matrix_rouge(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='target_text'):
|
| 518 |
+
#检查此feature是否存在,不存在则返回全0矩阵
|
| 519 |
+
if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
|
| 520 |
+
or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
|
| 521 |
+
cost_matrix = np.zeros((len(units1.data),len(units2.data)))
|
| 522 |
+
#对应feature的weight也为0
|
| 523 |
+
self.set_zero(feature)
|
| 524 |
+
# 并再次归一化
|
| 525 |
+
self.re_normalize()
|
| 526 |
+
return cost_matrix
|
| 527 |
+
|
| 528 |
+
#检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。相同则cost为0,不相同则cost为1-rougel
|
| 529 |
+
cost_matrix = []
|
| 530 |
+
for quad_dict1 in units1.data:
|
| 531 |
+
cost_list = []
|
| 532 |
+
for quad_dict2 in units2.data:
|
| 533 |
+
if quad_dict1[feature] == quad_dict2[feature]:
|
| 534 |
+
cost_list.append(0)
|
| 535 |
+
else:
|
| 536 |
+
cost_list.append(1-get_rougel_f1([quad_dict1[feature]],[quad_dict2[feature]]))
|
| 537 |
+
cost_matrix.append(cost_list)
|
| 538 |
+
|
| 539 |
+
#cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
|
| 540 |
+
cost_matrix = np.array(cost_matrix)
|
| 541 |
+
return cost_matrix
|
| 542 |
+
|
| 543 |
+
def match_units(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',one_match=True)->tuple:
|
| 544 |
+
#计算极性的cost矩阵,矩阵元素在0-1之间
|
| 545 |
+
cost_matrix_polarity = self.get_cost_matrix(units1, units2,feature='polarity')
|
| 546 |
+
#计算aspect的cost矩阵
|
| 547 |
+
cost_matrix_aspect = self.get_cost_matrix(units1, units2,feature='aspect')
|
| 548 |
+
#计算target的cost矩阵
|
| 549 |
+
cost_matrix_target = self.get_cost_matrix_rouge(units1, units2,feature='target_text')
|
| 550 |
+
#计算opinion的cost矩阵
|
| 551 |
+
cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2,feature='opinion_text')
|
| 552 |
+
|
| 553 |
+
#计算总的cost矩阵,矩阵元素在0-1之间。矩阵的行数为units1即pred的数量,列数为units2即true的数量
|
| 554 |
+
cost_matrix = self.target_weight*cost_matrix_target + self.opinion_weight*cost_matrix_opinion + \
|
| 555 |
+
self.aspect_weight*cost_matrix_aspect + self.polarity_weight*cost_matrix_polarity
|
| 556 |
+
score_matrix = 1-cost_matrix
|
| 557 |
+
#使用匈牙利算法进行匹配
|
| 558 |
+
if one_match:
|
| 559 |
+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
| 560 |
+
else:
|
| 561 |
+
#允许一对多的匹配
|
| 562 |
+
row_ind = np.argmin(cost_matrix, axis=0)
|
| 563 |
+
col_ind = np.arange(len(units2.data))
|
| 564 |
+
|
| 565 |
+
max_units_num=max(units1.num,units2.num)
|
| 566 |
+
|
| 567 |
+
#计算这种匹配的cost
|
| 568 |
+
cost = 0
|
| 569 |
+
for i in range(len(row_ind)):
|
| 570 |
+
cost += cost_matrix[row_ind[i]][col_ind[i]]
|
| 571 |
+
|
| 572 |
+
#计算这种匹配下的TP\FP\FN
|
| 573 |
+
TP = 0
|
| 574 |
+
for i in range(len(row_ind)):
|
| 575 |
+
TP += score_matrix[row_ind[i]][col_ind[i]]
|
| 576 |
+
|
| 577 |
+
#len(row_ind)为pred的数量,TP为匹配上的数量
|
| 578 |
+
FP = units1.num-TP
|
| 579 |
+
FN = units2.num-TP
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
#匹配不上的四元组,cost为1
|
| 583 |
+
cost += (max_units_num-len(row_ind))
|
| 584 |
+
|
| 585 |
+
cost_per_quadruple=cost/max_units_num
|
| 586 |
+
if cost_per_quadruple>1 or cost_per_quadruple <0:
|
| 587 |
+
|
| 588 |
+
print('cost错误',cost_per_quadruple,'pred:',units1.data,'true:',units2.data)
|
| 589 |
+
print(self.target_weight,self.opinion_weight,self.aspect_weight,self.polarity_weight)
|
| 590 |
+
|
| 591 |
+
#返回的cost在0-1之间
|
| 592 |
+
return cost_per_quadruple,TP,FP,FN
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 596 |
+
class QuadMatch(evaluate.Metric):
|
| 597 |
+
"""TODO: Short description of my evaluation module."""
|
| 598 |
+
|
| 599 |
+
def _info(self):
|
| 600 |
+
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
| 601 |
+
return evaluate.MetricInfo(
|
| 602 |
+
# This is the description that will appear on the modules page.
|
| 603 |
+
module_type="metric",
|
| 604 |
+
description=_DESCRIPTION,
|
| 605 |
+
citation=_CITATION,
|
| 606 |
+
inputs_description=_KWARGS_DESCRIPTION,
|
| 607 |
+
# This defines the format of each prediction and reference
|
| 608 |
+
features=[
|
| 609 |
+
datasets.Features(
|
| 610 |
+
{
|
| 611 |
+
"predictions": datasets.Value("string", id="sequence"),
|
| 612 |
+
"references": datasets.Sequence(datasets.Value("string", id="sequence")),
|
| 613 |
+
}
|
| 614 |
+
),
|
| 615 |
+
datasets.Features(
|
| 616 |
+
{
|
| 617 |
+
"predictions": datasets.Value("string", id="sequence"),
|
| 618 |
+
"references": datasets.Value("string", id="sequence"),
|
| 619 |
+
}
|
| 620 |
+
),
|
| 621 |
+
],
|
| 622 |
+
# Homepage of the module for documentation
|
| 623 |
+
homepage="http://module.homepage",
|
| 624 |
+
# Additional links to the codebase or references
|
| 625 |
+
codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
|
| 626 |
+
reference_urls=["http://path.to.reference.url/new_module"]
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
def _download_and_prepare(self, dl_manager):
|
| 630 |
+
"""Optional: download external resources useful to compute the scores"""
|
| 631 |
+
# TODO: Download external resources if needed
|
| 632 |
+
pass
|
| 633 |
+
|
| 634 |
+
def _compute(self,
|
| 635 |
+
predictions:List[str],
|
| 636 |
+
references: Union[List[str],List[List[str]]],
|
| 637 |
+
quad_weights:tuple=(1,1,1,1),
|
| 638 |
+
**kwargs) -> dict:
|
| 639 |
+
'''
|
| 640 |
+
|
| 641 |
+
:param predictions: list of predictions of sentiment quads
|
| 642 |
+
:param references: list of references of sentiment quads
|
| 643 |
+
:param quad_weights: weight of target,opinion,aspect,polarity for cost compute
|
| 644 |
+
|
| 645 |
+
:param kwargs:
|
| 646 |
+
:param tuple_len: indicate the format of the quad, see the following mapping
|
| 647 |
+
:param sep_token1: the token to seperate quads
|
| 648 |
+
:param sep_token2: the token to seperate units of one quad
|
| 649 |
+
|
| 650 |
+
:return:average matching score
|
| 651 |
+
|
| 652 |
+
#mapping
|
| 653 |
+
id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
|
| 654 |
+
'':"quadruples (target | opinion | aspect | polarity)",
|
| 655 |
+
'01':'pairs (target | opinion)',
|
| 656 |
+
'012':'triples (target | opinion | aspect)',
|
| 657 |
+
'013':'triples (target | opinion | polarity)',
|
| 658 |
+
'023':'triples (target | aspect | polarity)',
|
| 659 |
+
'23':'pairs (aspect | polarity)',
|
| 660 |
+
'03':'pairs (target | polarity)',
|
| 661 |
+
'13':'pairs (opinion | polarity)',
|
| 662 |
+
'3':'single (polarity)'}
|
| 663 |
+
|
| 664 |
+
#中文版映射
|
| 665 |
+
id2prompt_zh={'0123': "四元组(对象 | 观点 | 方面 | 极性)",
|
| 666 |
+
'':"四元组(对象 | 观点 | 方面 | 极性)",
|
| 667 |
+
'01':'二元组(对象 | 观点)',
|
| 668 |
+
'012':'三元组(对象 | 观点 | 方面)',
|
| 669 |
+
'013':'三元组(对象 | 观点 | 极性)',
|
| 670 |
+
'023':'三元组(对象 | 方面 | 极性)',
|
| 671 |
+
'23':'二元组(方面 | 极性)',
|
| 672 |
+
'03':'二元组(对象 | 极性)',
|
| 673 |
+
'13':'二元组(观点 | 极性)',
|
| 674 |
+
'3':'单元素(极性)'}
|
| 675 |
+
'''
|
| 676 |
+
|
| 677 |
+
assert len(predictions) == len(references)
|
| 678 |
+
if isinstance(predictions,str):
|
| 679 |
+
predictions=[predictions]
|
| 680 |
+
references=[references]
|
| 681 |
+
|
| 682 |
+
cost=0
|
| 683 |
+
TP,FP,FN=0,0,0
|
| 684 |
+
matcher = CommentUnitsMatch(*quad_weights)
|
| 685 |
+
|
| 686 |
+
for pred, true in zip(predictions, references):
|
| 687 |
+
|
| 688 |
+
pred = CommentUnitsSim.from_str(pred,**kwargs)
|
| 689 |
+
# 如果true是list,说明有多个正确答案
|
| 690 |
+
if isinstance(true, str):
|
| 691 |
+
true = CommentUnitsSim.from_str(true, **kwargs)
|
| 692 |
+
elif isinstance(true, list):
|
| 693 |
+
true=[CommentUnitsSim.from_str(t, **kwargs) for t in true]
|
| 694 |
+
else:
|
| 695 |
+
print("true的类型不对",true)
|
| 696 |
+
continue
|
| 697 |
+
|
| 698 |
+
#如果true是list,说明有多个正确答案,取最高分
|
| 699 |
+
if isinstance(true, list):
|
| 700 |
+
cost_list=[matcher.match_units(pred,t,one_match=True) for t in true]
|
| 701 |
+
# 获取得分最高的值的索引,按元组中第一个元素大小排序
|
| 702 |
+
cost_,TP_,FP_,FN_ = cost_list[np.argmax([c[0] for c in cost_list])]
|
| 703 |
+
cost += cost_
|
| 704 |
+
TP+=TP_
|
| 705 |
+
FP+=FP_
|
| 706 |
+
FN+=FN_
|
| 707 |
+
|
| 708 |
+
else:
|
| 709 |
+
cost_,TP_,FP_,FN_ = matcher.match_units(pred,true,one_match=True)
|
| 710 |
+
cost += cost_
|
| 711 |
+
TP+=TP_
|
| 712 |
+
FP+=FP_
|
| 713 |
+
FN+=FN_
|
| 714 |
+
|
| 715 |
+
#平均cost
|
| 716 |
+
cost=cost/len(predictions)
|
| 717 |
+
#由TP\FP\FN计算最优匹配F1
|
| 718 |
+
precision_match=TP/(TP+FP)
|
| 719 |
+
recall_match=TP/(TP+FN)
|
| 720 |
+
f1_match=2*precision_match*recall_match/(precision_match+recall_match)
|
| 721 |
+
|
| 722 |
+
f1=compute_quadruple_f1(y_pred=predictions,y_true=references, **kwargs)
|
| 723 |
+
|
| 724 |
+
#取1-cost为得分
|
| 725 |
+
return {'ave match score of weight '+str(quad_weights):1-cost,
|
| 726 |
+
'f1 score of optimal match of weight '+str(quad_weights): f1_match,
|
| 727 |
+
'f1 score of exact match':f1}
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/huggingface/evaluate@main
|
| 2 |
+
rouge_chinese
|
| 3 |
+
scipy
|
tests.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
test_cases = [
|
| 2 |
+
{
|
| 3 |
+
"predictions": "a | b | c | pos",
|
| 4 |
+
"references": "a | b | c | pos & e | f | g | neg",
|
| 5 |
+
"result": {'ave match score of weight (1, 1, 1, 1)': 0.375,
|
| 6 |
+
'f1 score of exact match': 0.0,
|
| 7 |
+
'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
|
| 8 |
+
}
|
| 9 |
+
]
|