Dmytro Vodianytskyi commited on
Commit
3219277
·
1 Parent(s): 504e122

space updated

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -7,7 +7,7 @@ TOKENIZER = T5Tokenizer.from_pretrained('werent4/mt5TranslatorLT')
7
  MODEL = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
8
  MODEL.to(DEVICE)
9
 
10
- def translate(text, device, max_length, num_beams, translation_way = "en-lt"):
11
  translations_ways = {
12
  "en-lt": "<EN2LT>",
13
  "lt-en": "<LT2EN>"
@@ -15,7 +15,7 @@ def translate(text, device, max_length, num_beams, translation_way = "en-lt"):
15
  if translation_way not in translations_ways:
16
  raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")
17
  input_text = f"{translations_ways[translation_way]} {text}"
18
- encoded_input = TOKENIZER(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
19
  with torch.no_grad():
20
  output_tokens = MODEL.generate(
21
  **encoded_input,
 
7
  MODEL = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
8
  MODEL.to(DEVICE)
9
 
10
+ def translate(text, max_length, num_beams, translation_way = "en-lt"):
11
  translations_ways = {
12
  "en-lt": "<EN2LT>",
13
  "lt-en": "<LT2EN>"
 
15
  if translation_way not in translations_ways:
16
  raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")
17
  input_text = f"{translations_ways[translation_way]} {text}"
18
+ encoded_input = TOKENIZER(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
19
  with torch.no_grad():
20
  output_tokens = MODEL.generate(
21
  **encoded_input,