Wootang01's picture
Update app.py
d8fbcce
raw
history blame
1.47 kB
# Streamlit grammar-correction demo. NOTE: Streamlit re-executes this whole
# script top-to-bottom on every widget interaction.
import streamlit as st
st.title("Grammar Corrector")
st.write("Paste or type text, submit and the machine will attempt to correct your text's grammar.")
# Example sentence pre-filled in the text area on first load.
default_text = " Lastly, another reason we will need clear procedures so that customers do not exploit our refund policy."
# User's input text; read below when generating corrections.
sent = st.text_area("Text", default_text, height=40)
# How many candidate corrections to generate (clamped to 1-3 by the widget);
# forwarded to model.generate as num_return_sequences.
num_correct_options = st.number_input('Number of Correction Options', min_value=1, max_value=3, value=1, step=1)
# NOTE(review): these imports sit mid-file because the original script was
# written top-to-bottom; conventionally they belong at the top of the file.
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

# Run inference on GPU when available, otherwise fall back to CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache_resource
def _load_grammar_model():
    """Load and cache the T5 grammar-correction tokenizer and model.

    Streamlit reruns the whole script on every widget interaction; without
    caching, the multi-hundred-MB checkpoint would be re-loaded each time.
    `st.cache_resource` keeps one shared instance across reruns/sessions.
    (Requires Streamlit >= 1.18; on older versions use `@st.cache` instead.)

    Returns:
        (tokenizer, model): the T5Tokenizer and T5ForConditionalGeneration
        already moved to `torch_device`.
    """
    tok = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
    mdl = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector').to(torch_device)
    return tok, mdl


tokenizer, model = _load_grammar_model()
def correct_grammar(input_text, num_return_sequences=num_correct_options):
    """Generate grammar-corrected candidates for *input_text* with the T5 model.

    Args:
        input_text: Raw text to correct (truncated/padded to 64 tokens).
        num_return_sequences: Number of beam-search candidates to return;
            must be <= num_beams (4 here). Defaults to the current value of
            the `num_correct_options` widget.

    Returns:
        A tensor of generated token-id sequences, one row per candidate;
        decode with `tokenizer.decode`/`batch_decode`.
    """
    batch = tokenizer([input_text], truncation=True, padding='max_length',
                      max_length=64, return_tensors='pt').to(torch_device)
    # BUG FIX: the original ignored this function's num_return_sequences
    # parameter and always read the global num_correct_options instead.
    # NOTE(review): temperature has no effect under pure beam search
    # (do_sample defaults to False) — kept to avoid changing generate kwargs;
    # confirm intent before removing or adding do_sample=True.
    results = model.generate(**batch, max_length=64, num_beams=4,
                             num_return_sequences=num_return_sequences,
                             temperature=1.5)
    return results
# Generate candidates for the user's text and decode each token-id sequence
# back to a readable string. (Comprehension replaces the original
# for-and-append loop, whose enumerate index was never used.)
results = correct_grammar(sent, num_correct_options)
generated_options = [
    tokenizer.decode(option_ids,
                     clean_up_tokenization_spaces=True,
                     skip_special_tokens=True)
    for option_ids in results
]
st.write(generated_options)