Spaces:
Sleeping
Sleeping
File size: 1,925 Bytes
4a1df2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
"""
Goal function attempting to minimize the BLEU score
-------------------------------------------------------
"""
import functools
import nltk
import textattack
from .text_to_text_goal_function import TextToTextGoalFunction
class MinimizeBleu(TextToTextGoalFunction):
    """Attempts to minimize the BLEU score between the current output
    translation (the hypothesis) and the reference translation (the ground
    truth output).

    BLEU score was defined in (BLEU: a Method for Automatic Evaluation of
    Machine Translation).

    `ArxivURL`_

    .. _ArxivURL: https://www.aclweb.org/anthology/P02-1040.pdf

    This goal function is defined in (It’s Morphin’ Time! Combating
    Linguistic Discrimination with Inflectional Perturbations).

    `ArxivURL2`_

    .. _ArxivURL2: https://www.aclweb.org/anthology/2020.acl-main.263

    Args:
        target_bleu (float): BLEU threshold; the goal counts as complete
            once the BLEU score is at or below this value (within ``EPS``).
            Defaults to ``0.0``.
    """

    # Tolerance added to ``target_bleu`` when testing goal completion, so
    # that floating-point noise does not prevent a match at the threshold.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        """Store ``target_bleu`` and forward everything else to the base class."""
        # Assigned before delegating, as in the original ordering.
        # NOTE(review): presumably so the attribute exists if the base
        # initializer touches it -- confirm against TextToTextGoalFunction.
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Empty the model-call cache (when caching is on) and the BLEU cache."""
        if self.use_cache:
            self._call_model_cache.clear()
        # The memoized BLEU scores live on the module-level ``get_bleu``.
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        """Return True when the BLEU score has dropped to ``target_bleu``.

        The second argument is unused here and only forwarded to
        ``_get_score``.
        """
        # ``_get_score`` returns ``1 - BLEU``; invert it to recover BLEU.
        bleu_score = 1.0 - self._get_score(model_output, _)
        return bleu_score <= (self.target_bleu + MinimizeBleu.EPS)

    def _get_score(self, model_output, _):
        """Return ``1 - BLEU`` for ``model_output`` against the ground truth.

        A higher return value means a lower BLEU score. The second argument
        is unused.
        """
        model_output_at = textattack.shared.AttackedText(model_output)
        ground_truth_at = textattack.shared.AttackedText(self.ground_truth_output)
        # ``get_bleu(a, b)`` treats ``a`` as the reference and ``b`` as the
        # hypothesis. BLEU is asymmetric, so the ground truth must be passed
        # first and the model output second.
        bleu_score = get_bleu(ground_truth_at, model_output_at)
        return 1.0 - bleu_score

    def extra_repr_keys(self):
        """Return the attribute names to include in the representation.

        ``target_bleu`` is listed only when ``self.maximizable`` is falsy.
        """
        if self.maximizable:
            return ["maximizable"]
        else:
            return ["maximizable", "target_bleu"]
@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
    """Return the sentence-level BLEU score of ``b`` measured against ``a``.

    ``a.words`` is used as the single reference token list and ``b.words``
    as the hypothesis token list passed to NLTK's ``sentence_bleu``.

    Results are memoized in an LRU cache holding up to ``2**12`` entries,
    so both arguments must be hashable.
    """
    references = [a.words]
    hypothesis = b.words
    return nltk.translate.bleu_score.sentence_bleu(references, hypothesis)
|