Wootang01's picture
Update app.py
a184a4d
raw
history blame
1.44 kB
import streamlit as st
# --- Page header and user inputs -------------------------------------------
st.title("Grammar Corrector")
st.write(
    "Paste or type text, submit and the machine will attempt to correct your text's grammar."
)

# Sample sentence pre-filled into the text box.
default_text = "In conclusion,if anyone has some problem the customers must be returned."

# NOTE(review): newer Streamlit releases reportedly enforce a minimum
# text_area height; confirm that 40 is accepted by the deployed version.
sent = st.text_area(label="Text", value=default_text, height=40)

# How many alternative corrections to show (an integer from 1 to 3).
num_correct_options = st.number_input(
    label='Number of Correction Options',
    min_value=1,
    max_value=3,
    value=1,
    step=1,
)
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
torch_device = 'cpu'
if torch.cuda.is_available():
    torch_device = 'cuda'

# Hub identifier shared by the tokenizer and the model weights.
_CHECKPOINT = 'deep-learning-analytics/GrammarCorrector'

# NOTE(review): Streamlit re-executes the script on user interaction, so these
# loads presumably repeat on every rerun -- consider wrapping them in the
# caching decorator offered by the deployed Streamlit version.
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
model = model.to(torch_device)
def correct_grammar(input_text, num_return_sequences=num_correct_options):
    """Generate grammar-corrected candidates for ``input_text``.

    The text is tokenized (truncated/padded to 64 tokens), moved to
    ``torch_device`` and passed to ``model.generate`` using beam search
    with 4 beams.

    Args:
        input_text: The sentence(s) to correct, as a single string.
        num_return_sequences: Number of candidate corrections to request.
            Defaults to the value selected in the UI. It is forwarded
            together with ``num_beams=4``; the UI limits it to at most 3.

    Returns:
        Whatever ``model.generate`` returns for the batch; the caller
        iterates over it and decodes each item with ``tokenizer.decode``.
    """
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors='pt',
    ).to(torch_device)
    results = model.generate(
        **batch,
        max_length=64,
        num_beams=4,
        # Use the argument, not the module-level widget value, so that an
        # explicit value passed by the caller is actually honoured.
        num_return_sequences=num_return_sequences,
        # NOTE(review): temperature normally only affects sampling; with
        # do_sample unset it likely has no effect here -- confirm.
        temperature=1.5,
    )
    return results
# Run the model on the submitted text, requesting as many candidates as the
# user selected.
results = correct_grammar(sent, num_correct_options)

# Decode every generated sequence to plain text, dropping special tokens and
# tidying tokenization spacing.
generated_options = [
    tokenizer.decode(
        generated_option,
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    for generated_option in results
]

# Show the list of suggested corrections on the page.
st.write(generated_options)