Spaces:
Sleeping
Sleeping
File size: 4,971 Bytes
4a1df2e |
|
"""
BackTranslation class
-----------------------------------
"""
import random
from transformers import MarianMTModel, MarianTokenizer
from textattack.shared import AttackedText
from .sentence_transformation import SentenceTransformation
class BackTranslation(SentenceTransformation):
    """A type of sentence level transformation that takes in a text input,
    translates it into target language and translates it back to source
    language.

    Args:
        src_lang (string): source language code
        target_lang (string): target language code; for the list of supported
            languages, see the bottom of this file
        src_model: huggingface model identifier (passed to ``from_pretrained``)
            of the model that translates from the target language back into
            the source language
        target_model: huggingface model identifier (passed to
            ``from_pretrained``) of the model that translates from the source
            language into the target language
        chained_back_translation (int): number of target languages, sampled at
            random from the target tokenizer's supported language codes, to
            run back translation through in a chain for more perturbation
            (for example, en-es-en-fr-en). A falsy value (the default ``0``)
            performs a single back translation through ``target_lang``.

    Example::

        >>> from textattack.transformations.sentence_transformations import BackTranslation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter

        >>> transformation = BackTranslation()
        >>> constraints = [RepeatModification(), StopwordModification()]
        >>> augmenter = Augmenter(transformation = transformation, constraints = constraints)
        >>> s = 'What on earth are you doing here.'

        >>> augmenter.augment(s)
    """

    def __init__(
        self,
        src_lang="en",
        target_lang="es",
        src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
        target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
        chained_back_translation=0,
    ):
        self.src_lang = src_lang
        self.target_lang = target_lang
        # Model/tokenizer pair used for the source -> target direction.
        self.target_model = MarianMTModel.from_pretrained(target_model)
        self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
        # Model/tokenizer pair used for the target -> source direction.
        self.src_model = MarianMTModel.from_pretrained(src_model)
        self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
        self.chained_back_translation = chained_back_translation

    def translate(self, input, model, tokenizer, lang="es"):
        """Translate ``input[0]`` with the given model and tokenizer.

        Args:
            input: sequence of strings; only the first element is translated
            model: translation model whose ``generate`` method is called
            tokenizer: tokenizer used to encode the text and decode the output
            lang (string): language code to translate into. Unless it is
                ``"en"``, it is prepended to the text as a ``>>lang<<`` marker.

        Returns:
            The list of decoded strings produced by ``tokenizer.batch_decode``.
        """
        # change the text to model's format
        if lang == "en":
            # no language marker is prepended when the language is "en"
            src_texts = [input[0]]
        else:
            # Wrap a bare code such as "es" into ">>es<< ", but leave a code
            # that already carries a marker untouched. Both markers must be
            # tested explicitly: `">>" and "<<" not in lang` only tests "<<"
            # because the non-empty literal ">>" is always truthy.
            if ">>" not in lang and "<<" not in lang:
                lang = ">>" + lang + "<< "
            src_texts = [lang + input[0]]

        # tokenize the input
        encoded_input = tokenizer.prepare_seq2seq_batch(src_texts, return_tensors="pt")

        # translate the input
        translated = model.generate(**encoded_input)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)

    def _round_trip(self, text, target_lang):
        """Translate ``text`` into ``target_lang`` and back into ``self.src_lang``."""
        target_language_text = self.translate(
            [text], self.target_model, self.target_tokenizer, target_lang
        )
        src_language_text = self.translate(
            target_language_text, self.src_model, self.src_tokenizer, self.src_lang
        )
        return src_language_text[0]

    def _get_transformations(self, current_text, indices_to_modify):
        """Return a one-element list holding the back-translated text.

        Args:
            current_text: object whose ``text`` attribute is the string to
                transform
            indices_to_modify: not used by this transformation

        Returns:
            A list containing a single ``AttackedText`` built from the
            back-translated string.
        """
        text = current_text.text

        # to perform chained back translation, a random list of target languages are selected from the provided model
        if self.chained_back_translation:
            list_of_target_lang = random.sample(
                self.target_tokenizer.supported_language_codes,
                self.chained_back_translation,
            )
            # each round trip feeds its output into the next one
            for target_lang in list_of_target_lang:
                text = self._round_trip(text, target_lang)
            return [AttackedText(text)]

        # translates source to target language and back to source language (single back translation)
        return [AttackedText(self._round_trip(text, self.target_lang))]
"""
List of supported languages
['fr',
'es',
'it',
'pt',
'pt_br',
'ro',
'ca',
'gl',
'pt_BR<<',
'la<<',
'wa<<',
'fur<<',
'oc<<',
'fr_CA<<',
'sc<<',
'es_ES',
'es_MX',
'es_AR',
'es_PR',
'es_UY',
'es_CL',
'es_CO',
'es_CR',
'es_GT',
'es_HN',
'es_NI',
'es_PA',
'es_PE',
'es_VE',
'es_DO',
'es_EC',
'es_SV',
'an',
'pt_PT',
'frp',
'lad',
'vec',
'fr_FR',
'co',
'it_IT',
'lld',
'lij',
'lmo',
'nap',
'rm',
'scn',
'mwl']
"""
|