Spaces:
Sleeping
Sleeping
"""
BERT Score
---------------------

BERT Score is introduced in this paper (BERTScore: Evaluating Text Generation with BERT) `arxiv link`_.

.. _arxiv link: https://arxiv.org/abs/1904.09675

BERT Score measures token similarity between two texts using contextual embeddings.

To decide which two tokens to compare, it greedily chooses the most similar token from one text and matches it to a token in the second text.
"""
import bert_score | |
from textattack.constraints import Constraint | |
from textattack.shared import utils | |
class BERTScore(Constraint):
    """A constraint on the BERT-Score between two texts.

    A transformed text satisfies the constraint when its BERT-Score against
    the reference text is greater than or equal to ``min_bert_score``.

    Args:
        min_bert_score (float): minimum threshold value for BERT-Score,
            between 0.0 and 1.0 inclusive.
        model_name (str): name of model to use for scoring.
        num_layers (int): number of hidden layers in the model.
        score_type (str): one of the following three choices:

            - ``precision``: match words from candidate text to reference text
            - ``recall``: match words from reference text to candidate text
            - ``f1``: harmonic mean of precision and recall (recommended)
        compare_against_original (bool):
            If ``True``, compare new ``x_adv`` against the original ``x``.
            Otherwise, compare it against the previous ``x_adv``.

    Raises:
        TypeError: if ``min_bert_score`` is not a ``float``.
        ValueError: if ``min_bert_score`` is outside ``[0.0, 1.0]`` or
            ``score_type`` is not a key of ``SCORE_TYPE2IDX``.
    """

    # Position of each score inside the tuple indexed in `_check_constraint`.
    SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}

    def __init__(
        self,
        min_bert_score,
        model_name="bert-base-uncased",
        num_layers=None,
        score_type="f1",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        if min_bert_score < 0.0 or min_bert_score > 1.0:
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")
        # Fail fast, before the scorer is built, instead of surfacing a bare
        # KeyError on the first constraint check.
        if score_type not in BERTScore.SCORE_TYPE2IDX:
            raise ValueError(
                f"score_type must be one of {sorted(BERTScore.SCORE_TYPE2IDX)}, "
                f"got {score_type!r}"
            )

        self.min_bert_score = min_bert_score
        self.model = model_name
        self.score_type = score_type
        # Turn off idf-weighting scheme b/c reference sentence set is small
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model_name, idf=False, device=utils.device, num_layers=num_layers
        )

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if the BERT Score between ``transformed_text`` and
        ``reference_text`` is at least ``min_bert_score``, else ``False``."""
        cand = transformed_text.text
        ref = reference_text.text
        # One candidate/reference pair is scored; the configured score type
        # selects which entry of the result to read.
        result = self._bert_scorer.score([cand], [ref])
        score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        """Return this constraint's repr keys followed by the parent's."""
        return ["min_bert_score", "model", "score_type"] + super().extra_repr_keys()