TiberiuCristianLeon commited on
Commit
978158a
·
verified ·
1 Parent(s): 3a7e27a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import httpx
6
 
7
  logging.set_verbosity_error()
 
8
 
9
  def download_argos_model(from_code, to_code):
10
  import argostranslate.package
@@ -53,7 +54,8 @@ if model_name == 'Helsinki-NLP':
53
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
54
  if model_name.startswith('t5'):
55
  tokenizer = T5Tokenizer.from_pretrained(model_name)
56
- model = T5ForConditionalGeneration.from_pretrained(model_name)
 
57
 
58
  st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
59
  submit_button = st.button("Translate")
@@ -77,7 +79,7 @@ if submit_button:
77
  elif model_name.startswith('t5'):
78
  prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
79
  print(prompt)
80
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
81
  # Perform translation
82
  output_ids = model.generate(input_ids)
83
  # Decode the translated text
@@ -104,8 +106,6 @@ if submit_button:
104
  translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
105
  except Exception as error:
106
  translated_text = error
107
- # download_argos_model(sl, tl)
108
- # translated_text = argostranslate.translate.translate(input_text, sl, tl)
109
  # Display the translated text
110
  print(translated_text)
111
  st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
 
5
  import httpx
6
 
7
  logging.set_verbosity_error()
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  def download_argos_model(from_code, to_code):
11
  import argostranslate.package
 
54
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
55
  if model_name.startswith('t5'):
56
  tokenizer = T5Tokenizer.from_pretrained(model_name)
57
+ model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
58
+
59
 
60
  st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
61
  submit_button = st.button("Translate")
 
79
  elif model_name.startswith('t5'):
80
  prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
81
  print(prompt)
82
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
83
  # Perform translation
84
  output_ids = model.generate(input_ids)
85
  # Decode the translated text
 
106
  translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
107
  except Exception as error:
108
  translated_text = error
 
 
109
  # Display the translated text
110
  print(translated_text)
111
  st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")