Adapters
khulnasoft commited on
Commit
e2e1958
·
verified ·
1 Parent(s): 8e4b938

Create roundtrip_mutator.py

Browse files
prompt_injection/mutators/roundtrip_mutator.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from prompt_injection.mutators.base import PromptMutator
4
+ from transformers import MarianMTModel, MarianTokenizer
5
+
6
+
7
+ class RoundTripPromptMutator(PromptMutator):
8
+ def __init__(self,model_name_translate='Helsinki-NLP/opus-mt-en-zh',model_name_inv_translate='Helsinki-NLP/opus-mt-zh-en',label=None):
9
+ self.model_name_translate=model_name_translate
10
+ self.model_name_inv_translate=model_name_inv_translate
11
+
12
+ # Load the pre-trained model and tokenizer
13
+ self.model_translate = MarianMTModel.from_pretrained(model_name_translate)
14
+ self.tokenizer_translate = MarianTokenizer.from_pretrained(model_name_translate)
15
+
16
+ # Load the pre-trained model and tokenizer
17
+ self.model_inv_translate = MarianMTModel.from_pretrained(model_name_inv_translate)
18
+ self.tokenizer_inv_translate = MarianTokenizer.from_pretrained(model_name_inv_translate)
19
+ if label is None:
20
+ self.label= f'RoundTripPromptMutator-{self.model_name_translate}--{self.model_name_translate}'
21
+ else:
22
+ self.label= f'RoundTripPromptMutator-{label}'
23
+
24
+
25
+ def to_lang(self,text):
26
+ inputs = self.tokenizer_translate.encode(text, return_tensors='pt', padding=True, truncation=True)
27
+ translated_tokens = self.model_translate.generate(inputs, max_length=40, num_beams=4, early_stopping=True)
28
+ translated_text = self.tokenizer_translate.decode(translated_tokens[0], skip_special_tokens=True)
29
+ return translated_text
30
+
31
+ def from_lang(self,text):
32
+ inputs = self.tokenizer_inv_translate.encode(text, return_tensors='pt', padding=True, truncation=True)
33
+ translated_tokens = self.model_inv_translate.generate(inputs, max_length=40, num_beams=4, early_stopping=True)
34
+ translated_text = self.tokenizer_inv_translate.decode(translated_tokens[0], skip_special_tokens=True)
35
+ return translated_text
36
+ def mutate(self,sample:str)->str:
37
+ return self.from_lang(self.to_lang(sample))
38
+
39
+ def get_name(self):
40
+ return self.label