puppala13 commited on
Commit
3e35f55
·
verified ·
1 Parent(s): dfb8d4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import streamlit as st
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
3
 
4
  def main():
5
  st.title("Translation App")
6
 
7
  # Load model and tokenizer
8
- model_name = "facebook/mbart-large-50-one-to-many-mmt"
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
11
 
12
  # Input text area
13
  input_text = st.text_area("Enter text to translate", "")
@@ -22,9 +22,9 @@ def main():
22
 
23
  def translate_text(input_text, model, tokenizer):
24
  # Tokenize input text
25
- model_inputs = tokenizer(input_text, return_tensors="pt").input_ids
26
  generated_tokens = model.generate(
27
- **model_inputs,
28
  forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"]
29
  )
30
 
 
1
  import streamlit as st
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
4
 
5
  def main():
6
  st.title("Translation App")
7
 
8
  # Load model and tokenizer
9
+ model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
10
+ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")
 
11
 
12
  # Input text area
13
  input_text = st.text_area("Enter text to translate", "")
 
22
 
23
  def translate_text(input_text, model, tokenizer):
24
  # Tokenize input text
25
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids
26
  generated_tokens = model.generate(
27
+ **input_ids,
28
  forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"]
29
  )
30