# Grammar-correction demo: wraps the "prithivida/grammar_error_correcter_v1"
# seq2seq model in a small Gradio UI.

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer and model once at startup (downloads on first run).
tokenizer = AutoTokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/grammar_error_correcter_v1")

# Use GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def correct_grammar(text):
    """Return *text* with its grammar corrected by the seq2seq model.

    Args:
        text: Input string to correct. Inputs longer than 1024 tokens are
            truncated by the tokenizer, so the "300 words" limit shown in
            the UI is enforced only approximately.

    Returns:
        The corrected text as a single string.
    """
    # Tokenize with a generous max_length so longer passages fit.
    inputs = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    ).to(device)

    # Inference only: disable autograd so generate() does not build a
    # useless computation graph (saves memory and time).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=1024,
            num_beams=5,
            early_stopping=True,
        )

    # Decode the single generated sequence back into plain text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def correct_grammar_interface(text):
    """Gradio callback: thin wrapper around correct_grammar()."""
    return correct_grammar(text)


# Gradio app layout.
with gr.Blocks() as grammar_app:
    # Proper markdown heading (the original string carried only stray
    # blank lines around the title — residue of stripped HTML markup).
    gr.Markdown("## Grammar Correction App (up to 300 words)")
    with gr.Row():
        input_box = gr.Textbox(
            label="Input Text",
            placeholder="Enter text (up to 300 words)",
            lines=10,
        )
        output_box = gr.Textbox(
            label="Corrected Text",
            placeholder="Corrected text will appear here",
            lines=10,
        )
    submit_button = gr.Button("Correct Grammar")

    # Bind the button click to the grammar correction function.
    submit_button.click(
        fn=correct_grammar_interface,
        inputs=input_box,
        outputs=output_box,
    )


# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    grammar_app.launch()