Spaces:
Sleeping
Sleeping
Dmytro Vodianytskyi
committed on
Commit
·
5cfd806
1
Parent(s):
2b29e41
space updated
Browse files
app.py
CHANGED
@@ -3,23 +3,29 @@ import torch
|
|
3 |
from transformers import T5Tokenizer, MT5ForConditionalGeneration
|
4 |
|
5 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
6 |
-
TOKENIZER = T5Tokenizer.from_pretrained('
|
7 |
MODEL = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
|
8 |
MODEL.to(DEVICE)
|
9 |
|
10 |
-
def translate(text,
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
with torch.no_grad():
|
14 |
-
output_tokens =
|
15 |
**encoded_input,
|
16 |
-
max_length=
|
17 |
-
num_beams=
|
18 |
no_repeat_ngram_size=2,
|
19 |
early_stopping=True
|
20 |
)
|
|
|
21 |
|
22 |
-
return TOKENIZER.decode(output_tokens[0], skip_special_tokens=True)
|
23 |
|
24 |
|
25 |
with gr.Blocks() as interface:
|
@@ -30,18 +36,26 @@ with gr.Blocks() as interface:
|
|
30 |
with gr.Row():
|
31 |
input_text = gr.Textbox(label="Text input", placeholder="Enter your text here")
|
32 |
with gr.Column():
|
33 |
-
mode = gr.Dropdown(label="Mode", choices=["
|
34 |
translate_button = gr.Button("Translate")
|
35 |
output_text = gr.Textbox(label="Translated text")
|
36 |
with gr.Accordion("How to run the model locally:", open=False):
|
37 |
gr.Code("""import torch
|
38 |
-
from transformers import T5Tokenizer, MT5ForConditionalGeneration
|
39 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
40 |
-
|
|
|
|
|
|
|
41 |
model = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
|
42 |
model.to(device)
|
43 |
-
def translate(text, model, tokenizer, device):
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
encoded_input = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
|
46 |
with torch.no_grad():
|
47 |
output_tokens = model.generate(
|
@@ -51,6 +65,7 @@ def translate(text, model, tokenizer, device):
|
|
51 |
no_repeat_ngram_size=2,
|
52 |
early_stopping=True
|
53 |
)
|
|
|
54 |
translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
|
55 |
return translated_text
|
56 |
text = "I live in Kaunas"
|
|
|
3 |
from transformers import T5Tokenizer, MT5ForConditionalGeneration
|
4 |
|
5 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
6 |
+
TOKENIZER = T5Tokenizer.from_pretrained('werent4/mt5TranslatorLT')
|
7 |
MODEL = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
|
8 |
MODEL.to(DEVICE)
|
9 |
|
10 |
+
def translate(text, model, tokenizer, device, translation_way="en-lt"):
    """Translate ``text`` in the direction selected by ``translation_way``.

    The input is prefixed with a direction tag (``<EN2LT>`` or ``<LT2EN>``),
    tokenized, run through ``model.generate`` with beam search, and the
    first generated sequence is decoded back to a string.

    Args:
        text: Source string to translate.
        model: Object exposing ``generate(**encoded_input, ...)``.
        tokenizer: Currently unused; the module-level ``TOKENIZER`` is used
            for both encoding and decoding (see the review note below).
        device: Device the encoded inputs are moved to before generation.
        translation_way: ``"en-lt"`` or ``"lt-en"``.

    Returns:
        The decoded first output sequence, with special tokens skipped.

    Raises:
        ValueError: If ``translation_way`` is not a supported direction.
    """
    translations_ways = {
        "en-lt": "<EN2LT>",
        "lt-en": "<LT2EN>",
    }
    if translation_way not in translations_ways:
        raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")

    # Prepend the direction tag so the prompt carries the requested direction.
    text = f"{translations_ways[translation_way]} {text}"

    # Fix: tokenize the tagged ``text``. The previous code passed
    # ``input_text``, which is not a parameter of this function and resolved
    # to the module-level Gradio Textbox component instead of the user's
    # string, so the tagged prompt was never used.
    # NOTE(review): the ``tokenizer`` argument is ignored in favour of the
    # global ``TOKENIZER``; kept as-is because the call site is not visible
    # here -- confirm and switch to the parameter if callers pass one.
    encoded_input = TOKENIZER(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        output_tokens = model.generate(
            **encoded_input,
            max_length=128,
            num_beams=5,
            no_repeat_ngram_size=2,
            early_stopping=True
        )
    return TOKENIZER.decode(output_tokens[0], skip_special_tokens=True)
|
28 |
|
|
|
29 |
|
30 |
|
31 |
with gr.Blocks() as interface:
|
|
|
36 |
with gr.Row():
|
37 |
input_text = gr.Textbox(label="Text input", placeholder="Enter your text here")
|
38 |
with gr.Column():
|
39 |
+
mode = gr.Dropdown(label="Mode", choices=["en-lt", "lt-en"])
|
40 |
translate_button = gr.Button("Translate")
|
41 |
output_text = gr.Textbox(label="Translated text")
|
42 |
with gr.Accordion("How to run the model locally:", open=False):
|
43 |
gr.Code("""import torch
|
|
|
44 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
45 |
+
|
46 |
+
from transformers import T5Tokenizer, MT5ForConditionalGeneration
|
47 |
+
|
48 |
+
tokenizer = T5Tokenizer.from_pretrained('werent4/mt5TranslatorLT')
|
49 |
model = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
|
50 |
model.to(device)
|
51 |
+
def translate(text, model, tokenizer, device, translation_way = "en-lt"):
|
52 |
+
translations_ways = {
|
53 |
+
"en-lt": "<EN2LT>",
|
54 |
+
"lt-en": "<LT2EN>"
|
55 |
+
}
|
56 |
+
if translation_way not in translations_ways:
|
57 |
+
raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")
|
58 |
+
input_text = f"{translations_ways[translation_way]} {text}"
|
59 |
encoded_input = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
|
60 |
with torch.no_grad():
|
61 |
output_tokens = model.generate(
|
|
|
65 |
no_repeat_ngram_size=2,
|
66 |
early_stopping=True
|
67 |
)
|
68 |
+
|
69 |
translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
|
70 |
return translated_text
|
71 |
text = "I live in Kaunas"
|