TiberiuCristianLeon commited on
Commit
a103fac
·
verified ·
1 Parent(s): aab3949

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import spaces
 
3
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
 
5
  langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it"}
@@ -15,18 +16,18 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
15
  try:
16
  model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name_full)
18
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
19
  except EnvironmentError:
20
  try :
21
  model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
22
  tokenizer = AutoTokenizer.from_pretrained(model_name_full)
23
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
24
  except EnvironmentError as error:
25
  return f"Error finding model: {model_name_full}! Try other available language combination.", error
26
  elif model_name.startswith('facebook/nllb'):
27
  from languagecodes import nllb_language_codes
28
  tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
29
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
30
  translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
31
  translated_text = translator(input_text, max_length=360)
32
  return translated_text[0]['translation_text'], message_text
@@ -34,6 +35,10 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
34
  tokenizer = T5Tokenizer.from_pretrained(model_name)
35
  model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
36
 
 
 
 
 
37
  if model_name.startswith("Helsinki-NLP"):
38
  prompt = input_text
39
  else:
 
1
  import gradio as gr
2
  import spaces
3
+ import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
 
6
  langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it"}
 
16
  try:
17
  model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
18
  tokenizer = AutoTokenizer.from_pretrained(model_name_full)
19
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full, device_map="auto")
20
  except EnvironmentError:
21
  try :
22
  model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
23
  tokenizer = AutoTokenizer.from_pretrained(model_name_full)
24
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full, device_map="auto")
25
  except EnvironmentError as error:
26
  return f"Error finding model: {model_name_full}! Try other available language combination.", error
27
  elif model_name.startswith('facebook/nllb'):
28
  from languagecodes import nllb_language_codes
29
  tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
30
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
31
  translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
32
  translated_text = translator(input_text, max_length=360)
33
  return translated_text[0]['translation_text'], message_text
 
35
  tokenizer = T5Tokenizer.from_pretrained(model_name)
36
  model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
37
 
38
+ # Move the model to GPU if available
39
+ if torch.cuda.is_available():
40
+ model = model.to('cuda')
41
+ print("CUDA is available! Using GPU.")
42
  if model_name.startswith("Helsinki-NLP"):
43
  prompt = input_text
44
  else: