TiberiuCristianLeon committed on
Commit
cae0132
·
verified ·
1 Parent(s): 0d00ebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -2,12 +2,13 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
5
 
6
- favourite_langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "-----": "-----"}
7
- all_langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it", "Hungarian": "hu"}
8
- langs = {**favourite_langs, **all_langs}
9
 
10
- options = list(langs.keys())
 
11
  models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/mbart-large-50-many-to-many-mmt"]
12
 
13
  def model_to_cuda(model):
@@ -37,24 +38,22 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
37
  return f"Error finding model: {model_name}! Try other available language combination.", error
38
 
39
  if model_name.startswith('facebook/nllb'):
40
- from languagecodes import nllb_language_codes
41
- tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
42
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
43
- translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
44
  translated_text = translator(input_text, max_length=512)
45
  return translated_text[0]['translation_text'], message_text
46
 
47
- if model_name.startswith('facebook/mbart-large'):
48
- from languagecodes import mbart_large_languages
49
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
50
  model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
51
  tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
52
  # translate source to target
53
- tokenizer.src_lang = mbart_large_languages[sselected_language]
54
  encoded = tokenizer(input_text, return_tensors="pt")
55
  generated_tokens = model.generate(
56
  **encoded,
57
- forced_bos_token_id=tokenizer.lang_code_to_id[mbart_large_languages[tselected_language]]
58
  )
59
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0], message_text
60
 
 
2
  import spaces
3
  import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
+ import languagecodes
6
 
7
+ favourite_langs = {"German": "de", "Romanian": "ro", "English": "en", "-----": "-----"}
8
+ langs = languagecodes.iso_languages
 
9
 
10
+ # options = list(langs.keys())
11
+ options = [(k, v) for k,v in favourite_langs.items()].extend([(k, v) for k,v in langs.items()])
12
  models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/mbart-large-50-many-to-many-mmt"]
13
 
14
  def model_to_cuda(model):
 
38
  return f"Error finding model: {model_name}! Try other available language combination.", error
39
 
40
  if model_name.startswith('facebook/nllb'):
41
+ tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=languagecodes.nllb_language_codes[sselected_language])
 
42
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
43
+ translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=languagecodes.nllb_language_codes[sselected_language], tgt_lang=languagecodes.nllb_language_codes[tselected_language])
44
  translated_text = translator(input_text, max_length=512)
45
  return translated_text[0]['translation_text'], message_text
46
 
47
+ if model_name.startswith('facebook/mbart-large'):
 
48
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
49
  model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
50
  tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
51
  # translate source to target
52
+ tokenizer.src_lang = languagecodes.mbart_large_languages[sselected_language]
53
  encoded = tokenizer(input_text, return_tensors="pt")
54
  generated_tokens = model.generate(
55
  **encoded,
56
+ forced_bos_token_id=tokenizer.lang_code_to_id[languagecodes.mbart_large_languages[tselected_language]]
57
  )
58
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0], message_text
59