TiberiuCristianLeon committed
Commit 310e819 · verified · 1 Parent(s): 44aa6cb

Update app.py

Files changed (1)
app.py +13 -9
app.py CHANGED
@@ -7,6 +7,14 @@ langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spa
 options = list(langs.keys())
 models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/mbart-large-50-many-to-many-mmt"]
 
+def model_to_cuda(model):
+    # Move the model to GPU if available
+    if torch.cuda.is_available():
+        model = model.to('cuda')
+        print("CUDA is available! Using GPU.")
+    else:
+        print("CUDA not available! Using CPU.")
+    return model
 @spaces.GPU
 def translate_text(input_text, sselected_language, tselected_language, model_name):
     sl = langs[sselected_language]
@@ -16,14 +24,15 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
     try:
         model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
         tokenizer = AutoTokenizer.from_pretrained(model_name_full)
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
+        model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))
     except EnvironmentError:
         try :
             model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
             tokenizer = AutoTokenizer.from_pretrained(model_name_full)
-            model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
+            model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))
         except EnvironmentError as error:
             return f"Error finding model: {model_name_full}! Try other available language combination.", error
+
     if model_name.startswith('facebook/nllb'):
         from languagecodes import nllb_language_codes
         tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
@@ -31,6 +40,7 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
         translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
         translated_text = translator(input_text, max_length=512)
         return translated_text[0]['translation_text'], message_text
+
     if model_name.startswith('facebook/mbart-large'):
         from languagecodes import mbart_large_languages
         from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
@@ -43,18 +53,12 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
             **encoded,
             forced_bos_token_id=tokenizer.lang_code_to_id[mbart_large_languages[tselected_language]]
         )
-        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True), message_text
+        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0], message_text
 
     if model_name.startswith('t5'):
         tokenizer = T5Tokenizer.from_pretrained(model_name)
         model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
 
-        # Move the model to GPU if available
-        if torch.cuda.is_available():
-            model = model.to('cuda')
-            print("CUDA is available! Using GPU.")
-        else:
-            print("CUDA not available! Using CPU.")
     if model_name.startswith("Helsinki-NLP"):
         prompt = input_text
     else:
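
For reference, a minimal, self-contained sketch of the pattern this commit introduces: wrap model loading in a model_to_cuda helper so the model is moved to GPU when one is available, and return a single string from batch_decode by indexing the returned list with [0]. This is an illustration only, not app.py itself; the checkpoint Helsinki-NLP/opus-mt-de-en and the sample sentence are assumed values for demonstration.

# Illustrative sketch of the commit's pattern (assumes torch, transformers, sentencepiece installed).
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def model_to_cuda(model):
    # Move the model to GPU if available, as in the helper added by this commit.
    if torch.cuda.is_available():
        model = model.to('cuda')
        print("CUDA is available! Using GPU.")
    else:
        print("CUDA not available! Using CPU.")
    return model

if __name__ == "__main__":
    # Example checkpoint; any Helsinki-NLP/opus-mt-{src}-{tgt} pair is loaded the same way.
    model_name_full = "Helsinki-NLP/opus-mt-de-en"
    tokenizer = AutoTokenizer.from_pretrained(model_name_full)
    model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))

    # Tokenize on the same device as the model, then generate a translation.
    encoded = tokenizer("Hallo Welt", return_tensors="pt").to(model.device)
    generated_tokens = model.generate(**encoded, max_length=512)

    # batch_decode returns a list of strings; [0] picks the single translation,
    # which is what the updated mBART branch now returns as well.
    print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0])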