TiberiuCristianLeon commited on
Commit
8b76c73
·
verified ·
1 Parent(s): 681fbf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -3,7 +3,10 @@ import spaces
3
  import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
 
6
- langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it"}
 
 
 
7
  options = list(langs.keys())
8
  models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/mbart-large-50-many-to-many-mmt"]
9
 
@@ -22,16 +25,16 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
22
  message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
23
  if model_name == "Helsinki-NLP":
24
  try:
25
- model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
26
- tokenizer = AutoTokenizer.from_pretrained(model_name_full)
27
- model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))
28
  except EnvironmentError:
29
- try :
30
- model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
31
- tokenizer = AutoTokenizer.from_pretrained(model_name_full)
32
- model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))
33
  except EnvironmentError as error:
34
- return f"Error finding model: {model_name_full}! Try other available language combination.", error
35
 
36
  if model_name.startswith('facebook/nllb'):
37
  from languagecodes import nllb_language_codes
@@ -91,8 +94,8 @@ def create_interface():
91
  model_name = gr.Dropdown(choices=models, label="Select a model", value = models[4], interactive=True)
92
  translate_button = gr.Button("Translate")
93
 
94
- translated_text = gr.Textbox(label="Translated text:", interactive=False, show_copy_button=True)
95
- message_text = gr.Textbox(label="Messages:", placeholder="Display status and error messages", interactive=False)
96
 
97
  translate_button.click(
98
  translate_text,
@@ -102,6 +105,5 @@ def create_interface():
102
 
103
  return interface
104
 
105
- # Launch the Gradio interface
106
  interface = create_interface()
107
  interface.launch()
 
3
  import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
 
6
+ favourite_langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "-----": "-----"}
7
+ all_langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it", "Hungarian": "hu"}
8
+ langs = {**favourite_langs, **all_langs}
9
+
10
  options = list(langs.keys())
11
  models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/mbart-large-50-many-to-many-mmt"]
12
 
 
25
  message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
26
  if model_name == "Helsinki-NLP":
27
  try:
28
+ model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name))
31
  except EnvironmentError:
32
+ try:
33
+ model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+ model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name))
36
  except EnvironmentError as error:
37
+ return f"Error finding model: {model_name}! Try other available language combination.", error
38
 
39
  if model_name.startswith('facebook/nllb'):
40
  from languagecodes import nllb_language_codes
 
94
  model_name = gr.Dropdown(choices=models, label="Select a model", value = models[4], interactive=True)
95
  translate_button = gr.Button("Translate")
96
 
97
+ translated_text = gr.Textbox(label="Translated text:", placeholder="Display field for translation", interactive=False, show_copy_button=True)
98
+ message_text = gr.Textbox(label="Messages:", placeholder="Display field for status and error messages", interactive=False)
99
 
100
  translate_button.click(
101
  translate_text,
 
105
 
106
  return interface
107
 
 
108
  interface = create_interface()
109
  interface.launch()