# Spaces: Sleeping
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Checkpoint id shared by the model and its tokenizer; keeping it in one place
# prevents the two from_pretrained calls from drifting apart.
MODEL_NAME = "vennify/t5-base-grammar-correction"

# Both objects are created at import time as module-level globals and are
# reused by every call to correct_text.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def correct_text(text, max_length, max_new_tokens, min_length, num_beams, temperature, top_p):
    """Return a grammar-corrected version of ``text``.

    The input is prefixed with ``"grammar: "``, encoded with the module-level
    ``tokenizer``, passed to ``model.generate`` and the first returned
    sequence is decoded with special tokens stripped.

    Args:
        text: Sentence(s) to correct. ``None`` or whitespace-only input
            short-circuits to an empty string without touching the model.
        max_length: Upper bound on the generated sequence length; only used
            when ``max_new_tokens`` is not positive.
        max_new_tokens: When greater than zero, used as the length limit
            instead of ``max_length``.
        min_length: Lower bound on the generated sequence length.
        num_beams: Number of beams passed to ``generate``.
        temperature: Forwarded to ``generate`` unchanged.
        top_p: Forwarded to ``generate`` unchanged.

    Returns:
        The decoded correction as a string (``""`` for blank input).
    """
    # Nothing to correct: avoid running the model on the bare task prefix.
    if not text or not text.strip():
        return ""

    # UI sliders may deliver numbers as floats, while the length and beam
    # arguments of generate() must be integers.
    new_token_budget = int(max_new_tokens or 0)

    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")
    generate_kwargs = {
        "inputs": inputs,
        "min_length": int(min_length),
        "num_beams": int(num_beams),
        # NOTE(review): per the transformers generation docs these two only
        # influence sampling-based decoding, and do_sample is not enabled
        # here - confirm whether sampling was intended.
        "temperature": float(temperature),
        "top_p": float(top_p),
        "early_stopping": True,
    }

    # Pass exactly one length limit. Supplying both makes transformers warn
    # and let max_new_tokens win, so selecting it explicitly keeps the same
    # effective limit without the conflicting pair.
    if new_token_budget > 0:
        generate_kwargs["max_new_tokens"] = new_token_budget
    else:
        generate_kwargs["max_length"] = int(max_length)

    outputs = model.generate(**generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def respond(message, history, max_length, min_length, max_new_tokens, num_beams, temperature, top_p):
    """Yield the grammar-corrected form of ``message`` as a single item.

    Generator wrapper around ``correct_text``. ``history`` is accepted but
    not used. Note that this signature lists ``min_length`` before
    ``max_new_tokens`` while ``correct_text`` takes them the other way
    round, so the values are forwarded by keyword.
    """
    yield correct_text(
        text=message,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
    )
# Custom stylesheet handed to gr.Blocks(css=...). The id selectors match the
# elem_id values assigned to components in the layout.
css = """
#interface-container {
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    max-width: 800px;
    margin: auto;
    font-family: 'Arial', sans-serif;
}
#input-container {
    margin-bottom: 20px;
}
#output-container {
    margin-top: 20px;
    font-family: Arial, sans-serif;
    font-size: 16px;
    color: #333;
    padding: 10px;
    border-radius: 5px;
    border: 1px solid #ddd;
}
.gr-button {
    background-color: #007bff;
    color: white;
    border: none;
    padding: 10px 20px;
    border-radius: 5px;
    cursor: pointer;
    font-size: 16px;
}
.gr-button:hover {
    background-color: #0056b3;
}
.gr-slider .gr-slider-track {
    background-color: #007bff;
}
.gr-slider .gr-slider-thumb {
    background-color: #0056b3;
}
.gr-textbox input {
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
}
.gr-textbox textarea {
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
}
"""
# UI definition: a heading, then one row with two columns - input text,
# generation controls and the trigger button on the left, the corrected text
# on the right.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; color: #007bff;'>Grammar Correction Tool</h1>")
    with gr.Row(elem_id="interface-container"):
        with gr.Column():
            user_input = gr.Textbox(
                lines=2,
                placeholder="Enter a sentence with grammatical errors...",
                label="Input Text",
                elem_id="input-container",
            )
            max_length = gr.Slider(minimum=1, maximum=256, value=100, step=1, label="Max Length")
            # The default of 0 must lie inside the slider's range, so the
            # lower bound is 0 rather than 1.
            min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
            # 0 means "not set": correct_text only applies this limit when it
            # is greater than zero.
            max_new_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens (optional)")
            num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            btn = gr.Button("Correct Grammar")
        with gr.Column():
            corrected_output = gr.Textbox(
                lines=2,
                placeholder="The corrected sentence will appear here...",
                label="Corrected Text",
                elem_id="output-container",
            )
    # The input order mirrors correct_text's positional parameters:
    # (text, max_length, max_new_tokens, min_length, num_beams, temperature, top_p).
    btn.click(
        fn=correct_text,
        inputs=[user_input, max_length, max_new_tokens, min_length, num_beams, temperature, top_p],
        outputs=corrected_output,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()