# etemkocaaslan's picture
# initial-dev
# 7621385 verified
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hugging Face Hub checkpoint used for both the model weights and the tokenizer
# (presumably a T5-base fine-tune for grammar correction, judging by its name).
_CHECKPOINT = "vennify/t5-base-grammar-correction"

# Loaded once at import time; correct_text() reuses these module-level objects
# for every call instead of reloading them per request.
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def correct_text(text, max_length, max_new_tokens, min_length, num_beams, temperature, top_p, do_sample=False):
    """Run the grammar-correction model on *text* and return the corrected string.

    Args:
        text: Input sentence(s). Empty or whitespace-only input returns "" without
            invoking the model.
        max_length: Upper bound on the generated sequence length. Only used when
            ``max_new_tokens`` is not positive.
        max_new_tokens: When > 0, caps the number of newly generated tokens and is
            used instead of ``max_length``. 0 means "not set".
        min_length: Minimum generated sequence length.
        num_beams: Beam width for beam search (1 means greedy decoding).
        temperature: Sampling temperature. Only forwarded when ``do_sample`` is True.
        top_p: Nucleus-sampling probability mass. Only forwarded when ``do_sample``
            is True.
        do_sample: Enable stochastic sampling. Defaults to False, which gives
            deterministic beam/greedy search. In that mode transformers ignores
            temperature/top_p, so they are not passed at all (this avoids
            generation-config warnings).

    Returns:
        The decoded model output with special tokens stripped.
    """
    if not text or not text.strip():
        return ""

    # Calling the tokenizer (rather than .encode) also yields the attention mask,
    # so generate() does not have to infer it.
    encoded = tokenizer("grammar: " + text, return_tensors="pt")

    # Gradio sliders may deliver floats; generation expects integer lengths/beams.
    beams = int(num_beams)
    new_tokens = int(max_new_tokens)

    generate_kwargs = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "min_length": int(min_length),
        "num_beams": beams,
    }
    if beams > 1:
        # early_stopping only applies to beam search; transformers warns if it is
        # set while decoding greedily.
        generate_kwargs["early_stopping"] = True

    if new_tokens > 0:
        # transformers lets max_new_tokens override max_length. Sending both only
        # produces a warning, so pass just the effective limit.
        generate_kwargs["max_new_tokens"] = new_tokens
    else:
        generate_kwargs["max_length"] = int(max_length)

    if do_sample:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = float(temperature)
        generate_kwargs["top_p"] = float(top_p)

    outputs = model.generate(**generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def respond(message, history, max_length, min_length, max_new_tokens, num_beams, temperature, top_p):
    """Yield the grammar-corrected form of *message* as a single result.

    Presumably intended as a chat-style callback (``message``, ``history``);
    ``history`` is accepted but unused. This function takes ``min_length``
    before ``max_new_tokens``, which is the reverse of ``correct_text``'s
    order. Arguments are therefore forwarded by keyword so they cannot be
    crossed.
    """
    corrected = correct_text(
        text=message,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
    )
    yield corrected
# Custom stylesheet passed to gr.Blocks(css=css). The #id selectors match the
# elem_id values set on components in the Blocks layout (#interface-container on
# the Row, #input-container / #output-container on the two Textboxes).
# NOTE(review): the .gr-button / .gr-slider / .gr-textbox class selectors look like
# Gradio 3.x class names; confirm they still match the installed Gradio version.
css = """
#interface-container {
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
max-width: 800px;
margin: auto;
font-family: 'Arial', sans-serif;
}
#input-container {
margin-bottom: 20px;
}
#output-container {
margin-top: 20px;
font-family: Arial, sans-serif;
font-size: 16px;
color: #333;
padding: 10px;
border-radius: 5px;
border: 1px solid #ddd;
}
.gr-button {
background-color: #007bff;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
font-size: 16px;
}
.gr-button:hover {
background-color: #0056b3;
}
.gr-slider .gr-slider-track {
background-color: #007bff;
}
.gr-slider .gr-slider-thumb {
background-color: #0056b3;
}
.gr-textbox input {
border: 1px solid #ddd;
border-radius: 5px;
padding: 10px;
}
.gr-textbox textarea {
border: 1px solid #ddd;
border-radius: 5px;
padding: 10px;
}
"""
# Gradio UI: generation controls and input in the left column, corrected text
# in the right column; the button runs correct_text().
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; color: #007bff;'>Grammar Correction Tool</h1>")
    with gr.Row(elem_id="interface-container"):
        with gr.Column():
            user_input = gr.Textbox(lines=2, placeholder="Enter a sentence with grammatical errors...", label="Input Text", elem_id="input-container")
            max_length = gr.Slider(minimum=1, maximum=256, value=100, step=1, label="Max Length")
            # The default of 0 means "no minimum". The slider floor must admit it;
            # a floor of 1 put the default value outside the slider's own range.
            min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
            # 0 means "unset"; correct_text then falls back to Max Length.
            max_new_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens (optional)")
            num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
            # NOTE: this UI never enables sampling, and transformers ignores
            # temperature/top_p unless do_sample=True, so these two sliders
            # currently have no effect on the output.
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            btn = gr.Button("Correct Grammar")
        with gr.Column():
            corrected_output = gr.Textbox(lines=2, placeholder="The corrected sentence will appear here...", label="Corrected Text", elem_id="output-container")

    # Event listeners must be registered inside the Blocks context. The inputs
    # are passed positionally, so this list must follow correct_text's parameter
    # order: (text, max_length, max_new_tokens, min_length, num_beams,
    # temperature, top_p).
    btn.click(
        fn=correct_text,
        inputs=[user_input, max_length, max_new_tokens, min_length, num_beams, temperature, top_p],
        outputs=corrected_output,
    )

if __name__ == "__main__":
    demo.launch()