milyiyo commited on
Commit
81705b8
·
1 Parent(s): 7f6306d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained(
6
+ "prithivida/parrot_paraphraser_on_T5", use_auth_token=os.environ["AUTH_TOKEN"])
7
+ model = AutoModelForSeq2SeqLM.from_pretrained(
8
+ "prithivida/parrot_paraphraser_on_T5", use_auth_token=os.environ["AUTH_TOKEN"])
9
+
10
+ pln_es_to_en = pipeline('translation_es_to_en',
11
+ model=AutoModelForSeq2SeqLM.from_pretrained(
12
+ 'Helsinki-NLP/opus-mt-es-en'),
13
+ tokenizer=AutoTokenizer.from_pretrained(
14
+ 'Helsinki-NLP/opus-mt-es-en')
15
+ )
16
+
17
+ pln_en_to_es = pipeline('translation_en_to_es',
18
+ model=AutoModelForSeq2SeqLM.from_pretrained(
19
+ 'Helsinki-NLP/opus-mt-en-es'),
20
+ tokenizer=AutoTokenizer.from_pretrained(
21
+ 'Helsinki-NLP/opus-mt-en-es')
22
+ )
23
+
24
+
25
+ def paraphrase(sentence: str, lang: str, count: str):
26
+ p_count = int(count)
27
+ if p_count <= 0 or len(sentence.strip()) == 0:
28
+ return {'result': []}
29
+ sentence_input = sentence
30
+ if lang == 'ES':
31
+ sentence_input = pln_es_to_en(sentence_input)[0]['translation_text']
32
+ text = f"paraphrase: {sentence_input} </s>"
33
+ encoding = tokenizer.encode_plus(text, padding=True, return_tensors="pt")
34
+ input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]
35
+ outputs = model.generate(
36
+ input_ids=input_ids, attention_mask=attention_masks,
37
+ max_length=512, # 256,
38
+ do_sample=True,
39
+ top_k=120,
40
+ top_p=0.95,
41
+ early_stopping=True,
42
+ num_return_sequences=p_count
43
+ )
44
+ res = []
45
+ for output in outputs:
46
+ line = tokenizer.decode(
47
+ output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
48
+ res.append(line)
49
+ if lang == 'EN':
50
+ return {'result': res}
51
+ else:
52
+ res_es = [pln_en_to_es(x)[0]['translation_text']
53
+ for x in res]
54
+ return {'result': res_es}
55
+
56
+
57
+ iface = gr.Interface(fn=paraphrase,
58
+ inputs=[
59
+ gr.inputs.Textbox(
60
+ lines=2, placeholder=None, label='Sentence'),
61
+ gr.inputs.Dropdown(
62
+ ['ES', 'EN'], type="value", label='Language'),
63
+ gr.inputs.Number(
64
+ default=3, label='Paraphrases count'),
65
+ ],
66
+ outputs=[gr.outputs.JSON(label=None)])
67
+ iface.launch()