Updated the documentation and added more test cases.
README.md
CHANGED
@@ -25,49 +25,65 @@ summary with the reference overlap summary. It evaluates the semantic overlap su

computes precision, recall and F1 scores.

## How to Use

Sem-F1 takes 2 mandatory arguments:

- `predictions`: List of predictions. Format varies based on the `tokenize_sentences` and `multi_references` flags.
- `references`: List of references. Format varies based on the `tokenize_sentences` and `multi_references` flags.

```python
from evaluate import load

predictions = [
    ["I go to School.", "You are stupid."],
    ["I love adventure sports."],
]
references = [
    ["I go to School.", "You are stupid."],
    ["I love outdoor sports."],
]
metric = load("semf1")
results = metric.compute(predictions=predictions, references=references)
for score in results:
    print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
```
Sem-F1 also accepts multiple optional arguments (see the sketch after this list):

- `model_type (str)`: Model to use for encoding sentences. Options: `['pv1', 'stsb', 'use']`
  - `pv1` - [paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)
  - `stsb` - [stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)
  - `use` - [Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual) (Default)
- `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
- `multi_references (bool)`: Flag to indicate whether multiple references are provided. Default: False.
- `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU, or multiple processes for computation.
- `batch_size (int)`: Batch size for encoding. Default: 32.
- `verbose (bool)`: Flag to indicate verbose output. Default: False.
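
A minimal sketch combining the optional flags. The values are illustrative; `model_type`, `tokenize_sentences`, and `multi_references` behave as described in the list above:

```python
from evaluate import load

metric = load("semf1")

# Two untokenized predictions, each paired with two whole-document references.
predictions = [
    "I go to School. You are stupid.",
    "I love adventure sports.",
]
references = [
    ["I go to School. You are stupid.", "You are not smart."],
    ["I love outdoor sports.", "Adventure sports are my passion."],
]

results = metric.compute(
    predictions=predictions,
    references=references,
    model_type="stsb",        # any of 'pv1', 'stsb', 'use'
    tokenize_sentences=True,  # inputs are whole documents; let Sem-F1 split them
    multi_references=True,    # each example carries two references
    batch_size=16,
)
for score in results:
    print(f"Precision: {score.precision:.3f}, Recall: {score.recall}, F1: {score.f1:.3f}")
```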

Refer to the inputs description for more detailed usage:

```python
import evaluate

metric = evaluate.load("semf1")
print(metric.inputs_description)
```

[//]: # (*List all input arguments in the format below*)

[//]: # (- **input_field** *(type): Definition of input, with explanation if necessary. State any default value(s).*)

### Output Values

A list of `Scores` dataclasses, one per sample (see the usage sketch below):

- `precision: float`: Precision score, which ranges from 0.0 to 1.0.
- `recall: List[float]`: Recall scores corresponding to each reference, each ranging from 0.0 to 1.0.
- `f1: float`: F1 score (computed between precision and the average recall).
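
A short sketch of consuming the returned list. The field semantics follow the list above; the corpus-level averaging is our own illustration, not a convention the metric prescribes:

```python
# Continues the usage snippet above: `results` is a list of Scores, one per sample.
results = metric.compute(predictions=predictions, references=references)

mean_f1 = sum(score.f1 for score in results) / len(results)
print(f"Corpus-level Sem-F1: {mean_f1:.4f}")

# With multiple references, `recall` holds one value per reference.
for idx, score in enumerate(results):
    print(f"Sample {idx}: precision={score.precision:.3f}, best recall={max(score.recall):.3f}")
```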

## Future Extensions

Currently, we have only implemented the 3 encoders* that we experimented with in our
[paper](https://aclanthology.org/2022.emnlp-main.49/). However, it can easily be extended to more models by simply
extending the `Encoder` base class, as sketched below (refer to the `encoder_models.py` file).

`*` *In our paper, we used the TensorFlow [version](https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder)
of the USE model; however, in our current implementation, we use the [PyTorch version](https://huggingface.co/sentence-transformers/use-cmlm-multilingual).*
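
A rough sketch of what such an extension could look like. The `Encoder` base class lives in `encoder_models.py`, which this diff does not show, so the class name and the `encode` signature below are assumptions modeled on how `SBertEncoder` is exercised in `tests.py`:

```python
# Hypothetical sketch: the real base class is in encoder_models.py and its
# exact interface may differ from what is assumed here.
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer

from encoder_models import Encoder  # assumed name of the base class


class MyCustomEncoder(Encoder):
    """Wraps any sentence-transformers checkpoint behind the assumed Encoder interface."""

    def __init__(self, model_name: str, device, batch_size: int, verbose: bool):
        self.model = SentenceTransformer(model_name)
        self.device = device
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, sentences: List[str]) -> np.ndarray:
        # SentenceTransformer.encode returns one embedding per input sentence.
        return self.model.encode(
            sentences,
            device=self.device,
            batch_size=self.batch_size,
            show_progress_bar=self.verbose,
        )
```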

[//]: # (*Give examples, preferably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*)
semf1.py
CHANGED
@@ -14,7 +14,6 @@

# TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
"""Sem-F1 metric"""

from typing import List, Optional, Tuple

import datasets

(removed: `from functools import partial`)
@@ -56,69 +55,93 @@ sentence level and computes precision, recall and F1 scores.

"""

_KWARGS_DESCRIPTION = """
Sem-F1 compares the system-generated summaries (predictions) with ground-truth reference summaries (references)
using precision, recall, and F1 score based on sentence embeddings.

Args:
    predictions (list): List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
    references (list): List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
    model_type (str): Model to use for encoding sentences. Options: ['pv1', 'stsb', 'use']
        pv1 - paraphrase-distilroberta-base-v1 (Default)
        stsb - stsb-roberta-large
        use - Universal Sentence Encoder
    tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
    multi_references (bool): Flag to indicate whether multiple references are provided. Default: False.
    gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
        bool -
            False - CPU (Default)
            True - GPU (device 0) if a GPU is available, else CPU
        int -
            n - GPU, device index n
        str -
            'cuda', 'gpu', 'cpu'
        List[Union[str, int]] - multiple GPUs/CPUs, i.e. use multiple processes when computing embeddings
    batch_size (int): Batch size for encoding. Default: 32.
    verbose (bool): Flag to indicate verbose output. Default: False.

Returns:
    List of Scores dataclass with attributes as follows -
        precision: float - precision score
        recall: List[float] - list of recall scores corresponding to single/multiple references
        f1: float - F1 score (computed between precision and the average recall)

Examples of input formats:

    Case 1: multi_references = False, tokenize_sentences = False
        predictions: List[List[str]] - list of predictions where each prediction is a list of sentences.
        references: List[List[str]] - list of references where each reference is a list of sentences.
        Example:
            predictions = [["This is a prediction sentence 1.", "This is a prediction sentence 2."]]
            references = [["This is a reference sentence 1.", "This is a reference sentence 2."]]

    Case 2: multi_references = False, tokenize_sentences = True
        predictions: List[str] - list of predictions where each prediction is a document.
        references: List[str] - list of references where each reference is a document.
        Example:
            predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
            references = ["This is a reference sentence 1. This is a reference sentence 2."]

    Case 3: multi_references = True, tokenize_sentences = False
        predictions: List[List[str]] - list of predictions where each prediction is a list of sentences.
        references: List[List[List[str]]] - list of references where each example has multiple references
            (List[r1, r2, ...]) and each ri is a list of sentences.
        Example:
            predictions = [["Prediction sentence 1.", "Prediction sentence 2."]]
            references = [
                [
                    ["Reference sentence 1.", "Reference sentence 2."],        # Reference 1
                    ["Alternative reference 1.", "Alternative reference 2."],  # Reference 2
                ]
            ]

    Case 4: multi_references = True, tokenize_sentences = True
        predictions: List[str] - list of predictions where each prediction is a document.
        references: List[List[str]] - list of references where each example has multiple references
            (List[r1, r2, ...]) and each ri is a document.
        Example:
            predictions = ["Prediction sentence 1. Prediction sentence 2."]
            references = [
                [
                    "Reference sentence 1. Reference sentence 2.",        # Reference 1
                    "Alternative reference 1. Alternative reference 2.",  # Reference 2
                ]
            ]

Examples:

    >>> import evaluate
    >>> predictions = [
    ...     ["I go to School. You are stupid."],
    ...     ["I love adventure sports."],
    ... ]
    >>> references = [
    ...     ["I go to School. You are stupid."],
    ...     ["I love outdoor sports."],
    ... ]
    >>> metric = evaluate.load("semf1")
    >>> results = metric.compute(predictions=predictions, references=references)
    >>> for score in results:
    ...     print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
"""
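
The `gpu` argument accepts several shapes. A quick sketch of plausible values, assuming the semantics documented in the docstring above (device resolution itself presumably happens in `get_gpu` from `utils.py`, which this diff does not show):

```python
# Illustrative values for the `gpu` argument, following the docstring above.
# `preds` and `refs` stand in for any valid predictions/references pair.
metric.compute(predictions=preds, references=refs, gpu=False)   # CPU (default)
metric.compute(predictions=preds, references=refs, gpu=True)    # GPU 0 if available, else CPU
metric.compute(predictions=preds, references=refs, gpu=3)       # GPU with device index 3
metric.compute(predictions=preds, references=refs, gpu="cuda")  # 'cuda', 'gpu', or 'cpu'
metric.compute(predictions=preds, references=refs, gpu=[0, 1])  # multiple processes across devices 0 and 1
```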

@@ -194,7 +217,12 @@ def _validate_input_format(

    - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
    """
    if len(predictions) != len(references):
        raise ValueError("Predictions and references must have the same length.")

    def is_list_of_strings_at_depth(lst_obj, depth: int):
        return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)

    if tokenize_sentences and multi_references:
        condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
    elif not tokenize_sentences and multi_references:
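
`is_nested_list_of_type` lives in `utils.py`, which this diff does not show; the behavior sketched below is inferred from how the helper is used above and is an assumption, not documented behavior:

```python
# Inferred behavior of the depth helper (assumption: utils.is_nested_list_of_type
# checks that every element at the given nesting depth is a str).
def is_list_of_strings_at_depth(lst_obj, depth: int):
    if depth == 0:
        return isinstance(lst_obj, str)
    return isinstance(lst_obj, list) and all(
        is_list_of_strings_at_depth(item, depth - 1) for item in lst_obj
    )

assert is_list_of_strings_at_depth(["a", "b"], depth=1)           # List[str]
assert is_list_of_strings_at_depth([["a"], ["b", "c"]], depth=2)  # List[List[str]]
assert not is_list_of_strings_at_depth(["a", "b"], depth=2)
```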

@@ -225,7 +253,7 @@ class SemF1(evaluate.Metric):

            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=[
                # F0: Multi References: False, Tokenize_Sentences = False
                datasets.Features(
                    {
                        # predictions: List[List[str]] - List of predictions where prediction is a list of sentences

@@ -234,7 +262,7 @@

                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                # F1: Multi References: False, Tokenize_Sentences = True
                datasets.Features(
                    {
                        # predictions: List[str] - List of predictions

@@ -243,7 +271,7 @@

                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
                # F2: Multi References: True, Tokenize_Sentences = False
                datasets.Features(
                    {
                        # predictions: List[List[str]] - List of predictions where prediction is a list of sentences

@@ -255,7 +283,7 @@

                            datasets.Sequence(datasets.Value("string", id="sequence"), id="ref"), id="references"),
                    }
                ),
                # F3: Multi References: True, Tokenize_Sentences = True
                datasets.Features(
                    {
                        # predictions: List[str] - List of predictions
@@ -319,6 +347,12 @@ class SemF1(evaluate.Metric):

        :return: List of Scores dataclass with precision, recall, and F1 scores.
        """

        # Note: I have to specifically handle this case because the library considers the feature
        # corresponding to this case (F2) as the feature for the other case (F0), i.e. it can't make
        # any distinction between List[str] and List[List[str]]
        if not tokenize_sentences and multi_references:
            references = [[eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references]

        # Validate inputs corresponding to flags
        _validate_input_format(tokenize_sentences, multi_references, predictions, references)
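
What that `eval` round-trip does in practice: each tokenized multi-reference example apparently reaches `compute` with its inner lists stringified by the feature coercion, and `eval` turns them back into lists. A sketch using `ast.literal_eval`, the safer standard-library equivalent for plain literals (the shipped code uses `eval` itself):

```python
import ast

# A tokenized multi-reference example as it apparently arrives after feature
# coercion: the inner List[str] has been stringified.
mul_ref_ex = ["['Reference sentence 1.', 'Reference sentence 2.']"]

restored = [ast.literal_eval(ref) for ref in mul_ref_ex]
assert restored == [['Reference sentence 1.', 'Reference sentence 2.']]
```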

@@ -363,10 +397,11 @@

            # Precision: Concatenate all the sentences in all the references
            concat_refs = np.concatenate(refs, axis=0)
            precision, _ = _compute_cosine_similarity(preds, concat_refs)
            precision = np.clip(precision, a_min=0.0, a_max=1.0).item()

            # Recall: Compute individually for each reference
            recall_scores = [_compute_cosine_similarity(r_embeds, preds) for r_embeds in refs]
            recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]

            results.append(Scores(precision, recall_scores))
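
For intuition, the precision/recall quantities used above can be reproduced with scikit-learn exactly as the new tests below do: precision is the mean best-match similarity of each prediction sentence against the references, and recall swaps the roles.

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Mirrors the expected-value computation in the new TestCosineSimilarity tests.
pred_embeds = np.random.rand(5, 3)  # 5 prediction-sentence embeddings
ref_embeds = np.random.rand(3, 3)   # 3 reference-sentence embeddings

cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
precision = np.mean(np.max(cosine_scores, axis=-1)).item()  # best reference match per prediction sentence
recall = np.mean(np.max(cosine_scores, axis=0)).item()      # best prediction match per reference sentence
```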
tests.py
CHANGED
@@ -3,9 +3,12 @@ import unittest

import numpy as np
import torch
from numpy.testing import assert_almost_equal
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from encoder_models import SBertEncoder, get_encoder
from semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
@@ -178,5 +181,321 @@ class TestGetEncoder(unittest.TestCase):

        # self.assertEqual(encoder.verbose, verbose)


class TestSemF1(unittest.TestCase):
    def setUp(self):
        self.semf1_metric = SemF1()

        # Example cases, #Samples = 1
        self.untokenized_single_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."]
        self.untokenized_single_reference_references = [
            "This is a reference sentence 1. This is a reference sentence 2."]

        self.tokenized_single_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."],
        ]
        self.tokenized_single_reference_references = [
            ["This is a reference sentence 1.", "This is a reference sentence 2."],
        ]

        self.untokenized_multi_reference_predictions = [
            "Prediction sentence 1. Prediction sentence 2."
        ]
        self.untokenized_multi_reference_references = [
            ["Reference sentence 1. Reference sentence 2.", "Alternative reference 1. Alternative reference 2."],
        ]

        self.tokenized_multi_reference_predictions = [
            ["Prediction sentence 1.", "Prediction sentence 2."],
        ]
        self.tokenized_multi_reference_references = [
            [
                ["Reference sentence 1.", "Reference sentence 2."],
                ["Alternative reference 1.", "Alternative reference 2."]
            ],
        ]

    def test_untokenized_single_reference(self):
        scores = self.semf1_metric.compute(
            predictions=self.untokenized_single_reference_predictions,
            references=self.untokenized_single_reference_references,
            tokenize_sentences=True,
            multi_references=False,
            gpu=False,
            batch_size=32,
            verbose=False
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(self.untokenized_single_reference_predictions))

    def test_tokenized_single_reference(self):
        scores = self.semf1_metric.compute(
            predictions=self.tokenized_single_reference_predictions,
            references=self.tokenized_single_reference_references,
            tokenize_sentences=False,
            multi_references=False,
            gpu=False,
            batch_size=32,
            verbose=False
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))

        for score in scores:
            self.assertIsInstance(score, Scores)
            self.assertTrue(0.0 <= score.precision <= 1.0)
            self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))

    def test_untokenized_multi_reference(self):
        scores = self.semf1_metric.compute(
            predictions=self.untokenized_multi_reference_predictions,
            references=self.untokenized_multi_reference_references,
            tokenize_sentences=True,
            multi_references=True,
            gpu=False,
            batch_size=32,
            verbose=False
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(self.untokenized_multi_reference_predictions))

    def test_tokenized_multi_reference(self):
        scores = self.semf1_metric.compute(
            predictions=self.tokenized_multi_reference_predictions,
            references=self.tokenized_multi_reference_references,
            tokenize_sentences=False,
            multi_references=True,
            gpu=False,
            batch_size=32,
            verbose=False
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(self.tokenized_multi_reference_predictions))

        for score in scores:
            self.assertIsInstance(score, Scores)
            self.assertTrue(0.0 <= score.precision <= 1.0)
            self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))

    def test_same_predictions_and_references(self):
        scores = self.semf1_metric.compute(
            predictions=self.tokenized_single_reference_predictions,
            references=self.tokenized_single_reference_predictions,
            tokenize_sentences=False,
            multi_references=False,
            gpu=False,
            batch_size=32,
            verbose=False
        )

        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))

        for score in scores:
            self.assertIsInstance(score, Scores)
            self.assertAlmostEqual(score.precision, 1.0, places=6)
            assert_almost_equal(score.recall, 1, decimal=5, err_msg="Not all values are almost equal to 1")

    def test_exact_output_scores(self):
        predictions = [
            ["I go to School.", "You are stupid."],
            ["I love adventure sports."],
        ]
        references = [
            ["I go to playground.", "You are genius.", "You need to be admired."],
            ["I love adventure sports."],
        ]
        scores = self.semf1_metric.compute(
            predictions=predictions,
            references=references,
            tokenize_sentences=False,
            multi_references=False,
            gpu=False,
            batch_size=32,
            verbose=False,
            model_type="use",
        )

        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(predictions))

        score = scores[0]
        self.assertIsInstance(score, Scores)
        self.assertAlmostEqual(score.precision, 0.73, places=2)
        self.assertAlmostEqual(score.recall[0], 0.63, places=2)


class TestCosineSimilarity(unittest.TestCase):

    def setUp(self):
        # Sample embeddings for testing
        self.pred_embeds = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])
        self.ref_embeds = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])

        self.pred_embeds_random = np.random.rand(3, 3)
        self.ref_embeds_random = np.random.rand(3, 3)

    def test_cosine_similarity_perfect_match(self):
        precision, recall = _compute_cosine_similarity(self.pred_embeds, self.ref_embeds)

        # Expected values are 1.0 for both precision and recall since embeddings are identical
        self.assertAlmostEqual(precision, 1.0, places=5)
        self.assertAlmostEqual(recall, 1.0, places=5)

    def _test_cosine_similarity_base(self, pred_embeds, ref_embeds):
        precision, recall = _compute_cosine_similarity(pred_embeds, ref_embeds)

        # Calculate expected precision and recall using sklearn's cosine similarity function
        cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
        expected_precision = np.mean(np.max(cosine_scores, axis=-1)).item()
        expected_recall = np.mean(np.max(cosine_scores, axis=0)).item()

        self.assertAlmostEqual(precision, expected_precision, places=5)
        self.assertAlmostEqual(recall, expected_recall, places=5)

    def test_cosine_similarity_random(self):
        self._test_cosine_similarity_base(self.pred_embeds_random, self.ref_embeds_random)

    def test_cosine_similarity_different_shapes(self):
        pred_embeds_diff = np.random.rand(5, 3)
        ref_embeds_diff = np.random.rand(3, 3)
        self._test_cosine_similarity_base(pred_embeds_diff, ref_embeds_diff)


class TestValidateInputFormat(unittest.TestCase):
    def setUp(self):
        # Sample predictions and references for different scenarios where number of samples = 1
        # Naming convention: "untokenized" pairs with tokenize_sentences = True and vice-versa

        # When tokenize_sentences = True (untokenized input) and multi_references = False
        self.untokenized_single_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."
        ]
        self.untokenized_single_reference_references = [
            "This is a reference sentence 1. This is a reference sentence 2."
        ]

        # When tokenize_sentences = False (tokenized input) and multi_references = False
        self.tokenized_single_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
        ]
        self.tokenized_single_reference_references = [
            ["This is a reference sentence 1.", "This is a reference sentence 2."]
        ]

        # When tokenize_sentences = True (untokenized input) and multi_references = True
        self.untokenized_multi_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."
        ]
        self.untokenized_multi_reference_references = [
            [
                "This is a reference sentence 1. This is a reference sentence 2.",
                "Another reference sentence."
            ]
        ]

        # When tokenize_sentences = False (tokenized input) and multi_references = True
        self.tokenized_multi_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
        ]
        self.tokenized_multi_reference_references = [
            [
                ["This is a reference sentence 1.", "This is a reference sentence 2."],
                ["Another reference sentence."]
            ]
        ]

    def test_tokenized_sentences_true_multi_references_true(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                True,
                self.tokenized_single_reference_predictions,
                self.tokenized_single_reference_references,
            )

        # Valid format should pass without error
        _validate_input_format(
            True,
            True,
            self.untokenized_multi_reference_predictions,
            self.untokenized_multi_reference_references,
        )

    def test_tokenized_sentences_false_multi_references_true(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                False,
                True,
                self.untokenized_single_reference_predictions,
                self.untokenized_multi_reference_references,
            )

        # Valid format should pass without error
        _validate_input_format(
            False,
            True,
            self.tokenized_multi_reference_predictions,
            self.tokenized_multi_reference_references,
        )

    def test_tokenized_sentences_true_multi_references_false(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                False,
                self.tokenized_single_reference_predictions,
                self.tokenized_single_reference_references,
            )

        # Valid format should pass without error
        _validate_input_format(
            True,
            False,
            self.untokenized_single_reference_predictions,
            self.untokenized_single_reference_references,
        )

    def test_tokenized_sentences_false_multi_references_false(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                False,
                False,
                self.untokenized_single_reference_predictions,
                self.untokenized_single_reference_references,
            )

        # Valid format should pass without error
        _validate_input_format(
            False,
            False,
            self.tokenized_single_reference_predictions,
            self.tokenized_single_reference_references,
        )

    def test_mismatched_lengths(self):
        # Length mismatch should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                True,
                self.untokenized_single_reference_predictions,
                [self.untokenized_single_reference_predictions[0], self.untokenized_single_reference_predictions[0]],
            )


if __name__ == '__main__':
    unittest.main(verbosity=2)
    # unittest.main()