Translation / app.py
puppala13's picture
Update app.py
6608f42 verified
raw
history blame
2.42 kB
import re

import streamlit as st
import PyPDF2
import PyPDF2 as PDF
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
_MODEL_NAME = "facebook/mbart-large-50-one-to-many-mmt"

# Display name -> mBART-50 language code for the supported target languages.
_TARGET_LANGUAGES = {"Hindi": "hi_IN", "Tamil": "ta_IN", "Telugu": "te_IN"}

# mBART-50's positional-embedding limit (max_position_embeddings in its config);
# longer encoder inputs or decoder outputs cannot be represented by the model.
_MAX_POSITIONS = 1024


def main():
    """Render the translation app: supply text or a PDF, pick a language, translate."""
    st.title("Translation App")
    # Streamlit re-executes this script on every widget interaction; caching the
    # loader keeps the large model from being reloaded on each rerun.
    model, tokenizer = st.cache_resource(_load_translator)()
    # Input option: Text area or file upload
    input_option = st.radio("Select Input Option", ("Text", "PDF"))
    if input_option == "Text":
        source_text = st.text_area("Enter text to translate", "")
    else:
        pdf_file = st.file_uploader("Upload PDF file", type=['pdf'])
        if pdf_file is None:
            return
        source_text = extract_text_from_pdf(pdf_file)
        st.write("Extracted Text from PDF:")
        st.write(source_text)
    # The picker must be rendered on every run, not inside the button branch:
    # st.button is True only on the run right after the click, so a picker
    # created there would vanish on the rerun caused by changing its value.
    target_lang = _select_target_language()
    if st.button("Translate"):
        if not source_text.strip():
            st.warning("Nothing to translate: the input text is empty.")
            return
        with st.spinner("Translating..."):
            translated_text = translate_text(
                source_text, model, tokenizer, target_lang=target_lang
            )
        st.write("Translated Text:")
        st.write(translated_text)


def _load_translator():
    """Load the mBART-50 one-to-many model and its English-source tokenizer."""
    model = MBartForConditionalGeneration.from_pretrained(_MODEL_NAME)
    tokenizer = MBart50TokenizerFast.from_pretrained(_MODEL_NAME, src_lang="en_XX")
    return model, tokenizer


def _select_target_language():
    """Render the target-language picker and return the chosen mBART-50 code."""
    choice = st.selectbox("Select language to translate", tuple(_TARGET_LANGUAGES))
    return _TARGET_LANGUAGES[choice]


def extract_text_from_pdf(pdf_file):
    """Return the extracted text of every page of *pdf_file*.

    pdf_file: a path or binary file-like object accepted by ``PyPDF2.PdfReader``
    (in this app, the upload returned by ``st.file_uploader``).
    Pages are joined with a newline so words at page boundaries do not merge.
    """
    # PdfFileReader / numPages / getPage / extractText are deprecated in
    # PyPDF2 2.x and raise DeprecationError in 3.x; use the current API.
    reader = PyPDF2.PdfReader(pdf_file)
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def _chunk_text(text, tokenizer, max_tokens):
    """Split *text* into sentence-aligned chunks of roughly <= max_tokens tokens.

    All whitespace (including hard line breaks from PDF extraction) is collapsed
    to single spaces first. A sentence longer than *max_tokens* forms its own chunk.
    """
    sentences = re.split(r"(?<=[.!?])\s+", " ".join(text.split()))
    chunks, current, current_len = [], [], 0
    for sentence in sentences:
        n_tokens = len(tokenizer.tokenize(sentence))
        if current and current_len + n_tokens > max_tokens:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.append(sentence)
        current_len += n_tokens
    if current:
        chunks.append(" ".join(current))
    return chunks


def translate_text(input_text, model, tokenizer, target_lang=None, max_chunk_tokens=200):
    """Translate English *input_text* with an mBART-50 one-to-many model.

    Args:
        input_text: English source text; may exceed the model's input limit,
            since it is translated chunk by chunk.
        model: an ``MBartForConditionalGeneration`` (anything with ``generate``).
        tokenizer: the matching mBART-50 tokenizer (``main`` loads it with
            ``src_lang="en_XX"``).
        target_lang: mBART-50 language code such as ``"hi_IN"``. When ``None``,
            a target-language selectbox is rendered here (legacy behaviour).
        max_chunk_tokens: approximate number of source tokens per generate call.

    Returns:
        The translation, one line per translated chunk; ``""`` for blank input.

    Raises:
        ValueError: if *target_lang* is not a token in the tokenizer vocabulary.
    """
    if target_lang is None:
        target_lang = _select_target_language()
    if not input_text or not input_text.strip():
        return ""
    bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    # Unknown tokens map silently to <unk>; fail loudly instead of
    # generating in an unintended language.
    if bos_token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unsupported target language code: {target_lang!r}")
    translations = []
    for chunk in _chunk_text(input_text, tokenizer, max_chunk_tokens):
        # Truncation is a last-resort guard for a single over-long sentence.
        input_ids = tokenizer(
            chunk, return_tensors="pt", truncation=True, max_length=_MAX_POSITIONS
        ).input_ids
        # Bound the output explicitly; the generation default from the model
        # config may be shorter than a translated chunk.
        generated_tokens = model.generate(
            input_ids=input_ids,
            forced_bos_token_id=bos_token_id,
            max_length=_MAX_POSITIONS,
        )
        translations.append(
            tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        )
    return "\n".join(translations)
# Script entry point: build the app when this file is executed directly.
if __name__ == "__main__":
    main()