"""Gradio demo that corrects grammar with a T5 sequence-to-sequence model."""

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "vennify/t5-base-grammar-correction"

# Loaded once at import time so every request reuses the same weights.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def correct_text(text, max_length, max_new_tokens, min_length, num_beams, temperature, top_p):
    """Return a grammar-corrected version of *text*.

    Args:
        text: Sentence(s) to correct. Blank input yields an empty string
            without invoking the model.
        max_length: Upper bound on the generated sequence length; used only
            when ``max_new_tokens`` is not positive.
        max_new_tokens: When greater than 0, limits the number of generated
            tokens and replaces ``max_length``.
        min_length: Lower bound on the generated sequence length.
        num_beams: Number of beams for beam search.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The decoded model output with special tokens removed.
    """
    if not text or not text.strip():
        return ""

    # The prompt is prefixed with "grammar: " before being tokenized.
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")

    # UI sliders may deliver floats; the length and beam settings are counts.
    generate_kwargs = {
        "inputs": inputs,
        "min_length": int(min_length),
        "num_beams": int(num_beams),
        # NOTE(review): generate() is called without do_sample=True, so these
        # two presumably do not influence the beam-search result -- confirm
        # whether sampling was intended before enabling it.
        "temperature": temperature,
        "top_p": top_p,
        "early_stopping": True,
    }

    # max_length and max_new_tokens are alternative limits; supplying both
    # triggers a warning (or an error, depending on the transformers
    # release), so pass exactly one of them.
    max_new_tokens = int(max_new_tokens)
    if max_new_tokens > 0:
        generate_kwargs["max_new_tokens"] = max_new_tokens
    else:
        generate_kwargs["max_length"] = int(max_length)

    outputs = model.generate(**generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def respond(message, history, max_length, min_length, max_new_tokens, num_beams, temperature, top_p):
    """Yield the corrected form of *message* as a single response.

    ``history`` is accepted but not used. Note that the parameter order here
    (``min_length`` before ``max_new_tokens``) differs from ``correct_text``;
    the values are reordered in the call below.
    """
    yield correct_text(
        message, max_length, max_new_tokens, min_length, num_beams, temperature, top_p
    )


# Stylesheet handed to gr.Blocks. Written as adjacent literals so the text
# stays exactly as before while remaining readable.
css = (
    " #interface-container {"
    " padding: 20px; border-radius: 10px;"
    " box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);"
    " max-width: 800px; margin: auto;"
    " font-family: 'Arial', sans-serif; }"
    " #input-container { margin-bottom: 20px; }"
    " #output-container {"
    " margin-top: 20px; font-family: Arial, sans-serif; font-size: 16px;"
    " color: #333; padding: 10px; border-radius: 5px; border: 1px solid #ddd; }"
    " .gr-button {"
    " background-color: #007bff; color: white; border: none;"
    " padding: 10px 20px; border-radius: 5px; cursor: pointer; font-size: 16px; }"
    " .gr-button:hover { background-color: #0056b3; }"
    " .gr-slider .gr-slider-track { background-color: #007bff; }"
    " .gr-slider .gr-slider-thumb { background-color: #0056b3; }"
    " .gr-textbox input { border: 1px solid #ddd; border-radius: 5px; padding: 10px; }"
    " .gr-textbox textarea { border: 1px solid #ddd; border-radius: 5px; padding: 10px; }"
    " "
)

with gr.Blocks(css=css) as demo:
    # Page title; kept on one line so the string literal is valid Python.
    gr.HTML("<h1>Grammar Correction Tool</h1>")

    with gr.Row(elem_id="interface-container"):
        # Left column: input text, generation settings and the trigger button.
        with gr.Column():
            user_input = gr.Textbox(
                lines=2,
                placeholder="Enter a sentence with grammatical errors...",
                label="Input Text",
                elem_id="input-container",
            )
            max_length = gr.Slider(minimum=1, maximum=256, value=100, step=1, label="Max Length")
            # Minimum is 0 so the default value of 0 lies inside the slider range.
            min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
            max_new_tokens = gr.Slider(
                minimum=0,
                maximum=256,
                value=0,
                step=1,
                label="Max New Tokens (optional)",
            )
            num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
            temperature = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            )
            btn = gr.Button("Correct Grammar")

        # Right column: the corrected sentence.
        with gr.Column():
            corrected_output = gr.Textbox(
                lines=2,
                placeholder="The corrected sentence will appear here...",
                label="Corrected Text",
                elem_id="output-container",
            )

    # The input order must match the positional parameters of correct_text.
    btn.click(
        fn=correct_text,
        inputs=[user_input, max_length, max_new_tokens, min_length, num_beams, temperature, top_p],
        outputs=corrected_output,
    )

if __name__ == "__main__":
    demo.launch()