TiberiuCristianLeon commited on
Commit
d2894fa
·
verified ·
1 Parent(s): 08aee1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -24,14 +24,28 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
24
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
25
  except EnvironmentError as error:
26
  return f"Error finding model: {model_name_full}! Try other available language combination.", error
27
- elif model_name.startswith('facebook/nllb'):
28
  from languagecodes import nllb_language_codes
29
  tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
30
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
31
  translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
32
  translated_text = translator(input_text, max_length=512)
33
  return translated_text[0]['translation_text'], message_text
34
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  tokenizer = T5Tokenizer.from_pretrained(model_name)
36
  model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
37
 
 
24
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
25
  except EnvironmentError as error:
26
  return f"Error finding model: {model_name_full}! Try other available language combination.", error
27
+ if model_name.startswith('facebook/nllb'):
28
  from languagecodes import nllb_language_codes
29
  tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
30
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
31
  translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
32
  translated_text = translator(input_text, max_length=512)
33
  return translated_text[0]['translation_text'], message_text
34
+ if model_name.startswith('facebook/mbart-large'):
35
+ from languagecodes import mbart_large_languages
36
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
37
+ model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
38
+ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
39
+ # translate source to target
40
+ tokenizer.src_lang = mbart_large_languages[sselected_language]
41
+ encoded = tokenizer(article_hi, return_tensors="pt")
42
+ generated_tokens = model.generate(
43
+ **encoded,
44
+ forced_bos_token_id=tokenizer.lang_code_to_id[mbart_large_languages[tselected_language]]
45
+ )
46
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True), message_text
47
+
48
+ if model_name.startswith('t5'):
49
  tokenizer = T5Tokenizer.from_pretrained(model_name)
50
  model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
51