Spaces:
Sleeping
Sleeping
File size: 4,971 Bytes
4a1df2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
"""
BackTranslation class
-----------------------------------
"""
import random
from transformers import MarianMTModel, MarianTokenizer
from textattack.shared import AttackedText
from .sentence_transformation import SentenceTransformation
class BackTranslation(SentenceTransformation):
    """A type of sentence level transformation that takes in a text input,
    translates it into target language and translates it back to source
    language.

    Args:
        src_lang (string): source language
        target_lang (string): target language, for the list of supported
            language check bottom of this page
        src_model: translation model from huggingface that translates from
            the target language back to the source language
        target_model: translation model from huggingface that translates
            from the source language to the target language
        chained_back_translation (int): when non-zero, the number of target
            languages sampled at random from the target tokenizer; back
            translation is run once per sampled language in a chain for
            more perturbation (for example, en-es-en-fr-en). ``0`` performs
            a single back translation through ``target_lang``.

    Example::

        >>> from textattack.transformations.sentence_transformations import BackTranslation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter

        >>> transformation = BackTranslation()
        >>> constraints = [RepeatModification(), StopwordModification()]
        >>> augmenter = Augmenter(transformation = transformation, constraints = constraints)
        >>> s = 'What on earth are you doing here.'

        >>> augmenter.augment(s)
    """

    def __init__(
        self,
        src_lang="en",
        target_lang="es",
        src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
        target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
        chained_back_translation=0,
    ):
        self.src_lang = src_lang
        self.target_lang = target_lang
        # ``target_*`` translate source -> target; ``src_*`` translate back.
        self.target_model = MarianMTModel.from_pretrained(target_model)
        self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
        self.src_model = MarianMTModel.from_pretrained(src_model)
        self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
        self.chained_back_translation = chained_back_translation

    @staticmethod
    def _language_prefix(lang):
        """Return the ``>>lang<<`` tag to prepend to a text for ``lang``.

        A code that already carries both markers is returned unchanged. A
        bare or half-tagged code (``"es"``, ``"la<<"``) has stray marker
        characters stripped and is wrapped as ``">>code<< "``.
        """
        # The previous check, `">>" and "<<" not in lang`, only ever tested
        # for "<<" because the non-empty literal ">>" is always truthy.
        if ">>" in lang and "<<" in lang:
            return lang
        return ">>" + lang.strip("<>") + "<< "

    def translate(self, input, model, tokenizer, lang="es"):
        """Translate the first element of ``input`` with ``model``.

        Args:
            input: sequence of strings; only ``input[0]`` is translated.
            model: model whose ``generate`` method produces the translation.
            tokenizer: tokenizer matching ``model``.
            lang (string): language to translate into. ``"en"`` sends the
                text as-is; any other value is prepended to the text as a
                ``>>lang<<`` tag.

        Returns:
            The list of strings decoded from the generated sequences.
        """
        # change the text to model's format
        text = input[0]
        if lang != "en":
            text = self._language_prefix(lang) + text

        # tokenize the input; calling the tokenizer directly replaces the
        # deprecated ``prepare_seq2seq_batch`` helper and keeps its defaults
        # (pad to the longest sequence, truncate to the model's max length)
        encoded_input = tokenizer(
            [text], return_tensors="pt", padding=True, truncation=True
        )

        # translate the input
        translated = model.generate(**encoded_input)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)

    def _back_translate(self, text, target_lang):
        """Translate ``text`` into ``target_lang`` and back to ``src_lang``."""
        target_language_text = self.translate(
            [text], self.target_model, self.target_tokenizer, target_lang
        )
        src_language_text = self.translate(
            target_language_text, self.src_model, self.src_tokenizer, self.src_lang
        )
        return src_language_text[0]

    def _get_transformations(self, current_text, indices_to_modify):
        """Return a one-element list holding the back-translated text.

        Args:
            current_text: object whose ``text`` attribute is the string to
                transform.
            indices_to_modify: unused by this transformation.

        Returns:
            ``[AttackedText(...)]`` built from the back-translated string.
        """
        text = current_text.text

        # to perform chained back translation, a random list of target
        # languages are selected from the provided model; otherwise a single
        # round trip through ``self.target_lang`` is performed
        if self.chained_back_translation:
            target_langs = random.sample(
                self.target_tokenizer.supported_language_codes,
                self.chained_back_translation,
            )
        else:
            target_langs = [self.target_lang]

        # each round trip feeds its result into the next one
        for target_lang in target_langs:
            text = self._back_translate(text, target_lang)

        return [AttackedText(text)]
"""
List of supported languages
['fr',
'es',
'it',
'pt',
'pt_br',
'ro',
'ca',
'gl',
'pt_BR',
'la',
'wa',
'fur',
'oc',
'fr_CA',
'sc',
'es_ES',
'es_MX',
'es_AR',
'es_PR',
'es_UY',
'es_CL',
'es_CO',
'es_CR',
'es_GT',
'es_HN',
'es_NI',
'es_PA',
'es_PE',
'es_VE',
'es_DO',
'es_EC',
'es_SV',
'an',
'pt_PT',
'frp',
'lad',
'vec',
'fr_FR',
'co',
'it_IT',
'lld',
'lij',
'lmo',
'nap',
'rm',
'scn',
'mwl']
"""
|