TiberiuCristianLeon commited on
Commit
4d3b257
·
verified ·
1 Parent(s): cfebeae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -70,12 +70,12 @@ st.session_state["model_name"] = model_name
70
  if model_name == 'Helsinki-NLP':
71
  try:
72
  model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
73
- tokenizer = AutoTokenizer.from_pretrained(model_name)
74
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
75
  except EnvironmentError:
76
  model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
77
- tokenizer = AutoTokenizer.from_pretrained(model_name)
78
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
79
  if model_name.startswith('t5'):
80
  tokenizer = T5Tokenizer.from_pretrained(model_name)
81
  model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
@@ -87,13 +87,18 @@ translated_textarea = st.text("")
87
  # Handle the submit button click
88
  if submit_button:
89
  if model_name.startswith('Helsinki-NLP'):
90
- prompt = input_text
91
- print(prompt)
92
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
93
- # Perform translation
94
- output_ids = model.generate(input_ids)
95
- # Decode the translated text
96
- translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
 
 
97
  elif model_name.startswith('Google'):
98
  url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={input_text}'
99
  response = httpx.get(url)
 
70
  if model_name == 'Helsinki-NLP':
71
  try:
72
  model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
73
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
74
+ # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
75
  except EnvironmentError:
76
  model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
77
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
78
+ # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
79
  if model_name.startswith('t5'):
80
  tokenizer = T5Tokenizer.from_pretrained(model_name)
81
  model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
 
87
  # Handle the submit button click
88
  if submit_button:
89
  if model_name.startswith('Helsinki-NLP'):
90
+ # prompt = input_text
91
+ # print(prompt)
92
+ # input_ids = tokenizer.encode(prompt, return_tensors='pt')
93
+ # # Perform translation
94
+ # output_ids = model.generate(input_ids)
95
+ # # Decode the translated text
96
+ # translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
97
+ # Use a pipeline as a high-level helper
98
+ pipe = pipeline("translation", model=model_name)
99
+ translation = pipe(input_text)
100
+ translated_text = translation[0]['translation_text']
101
+
102
  elif model_name.startswith('Google'):
103
  url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={input_text}'
104
  response = httpx.get(url)