|
import gradio as gr |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
# Hugging Face Hub checkpoint used for grammar correction; the model and its
# tokenizer are both loaded from this single identifier so they cannot drift apart.
MODEL_ID = "vennify/t5-base-grammar-correction"

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
def correct_text(text, max_length, max_new_tokens, min_length, num_beams, temperature, top_p, do_sample=False):
    """Run the grammar-correction model on *text* and return the corrected string.

    The input is prefixed with ``"grammar: "`` before encoding, as the original
    implementation did.

    Args:
        text: Text to correct. Blank (or ``None``) input returns ``""`` without
            invoking the model.
        max_length: Upper bound on the generated sequence length in tokens.
            Only used when ``max_new_tokens`` is not positive.
        max_new_tokens: If > 0, caps the number of newly generated tokens and is
            used instead of ``max_length``; 0 means "not set".
        min_length: Minimum generated sequence length in tokens.
        num_beams: Beam width; ``early_stopping`` is enabled when it is > 1.
        temperature: Sampling temperature; only applied when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold; only applied when ``do_sample`` is True.
        do_sample: Enable sampling. Defaults to False, which keeps the original
            deterministic (beam/greedy) decoding.

    Returns:
        The decoded first output sequence with special tokens removed.
    """
    if text is None or not text.strip():
        return ""

    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")

    # UI sliders may deliver numbers as floats; generate() expects integer
    # lengths and beam counts.
    num_beams = int(num_beams)
    max_new_tokens = int(max_new_tokens)

    generate_kwargs = {
        "inputs": inputs,
        "min_length": int(min_length),
        "num_beams": num_beams,
    }

    # generate() gives max_new_tokens precedence over max_length; passing only
    # one of them keeps the same result without the "both were set" warning.
    if max_new_tokens > 0:
        generate_kwargs["max_new_tokens"] = max_new_tokens
    else:
        generate_kwargs["max_length"] = int(max_length)

    # early_stopping only influences beam search (num_beams > 1).
    if num_beams > 1:
        generate_kwargs["early_stopping"] = True

    # temperature/top_p are only applied by generate() when sampling; without
    # do_sample they change nothing and merely trigger validation warnings.
    if do_sample:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = float(temperature)
        generate_kwargs["top_p"] = float(top_p)

    outputs = model.generate(**generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def respond(message, history, max_length, min_length, max_new_tokens, num_beams, temperature, top_p):
    """Yield the grammar-corrected version of *message* as a single item.

    ``history`` is accepted for a (message, history, *settings) callback shape
    but is not used. The settings are forwarded to ``correct_text`` by keyword,
    because this function takes ``min_length`` before ``max_new_tokens`` while
    ``correct_text`` takes them in the opposite order.

    NOTE(review): this callback is not wired to any event in the visible UI;
    presumably left over from a chat-style interface — confirm before removing.
    """
    yield correct_text(
        text=message,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
    )
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...). The "#..." ID selectors target
# the elem_id values used in the Blocks layout ("interface-container",
# "input-container", "output-container").
# NOTE(review): the ".gr-button", ".gr-slider ..." and ".gr-textbox ..." class
# selectors appear to target an older Gradio DOM; confirm they still match the
# installed Gradio version, otherwise those rules silently apply to nothing.
css = """
#interface-container {
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    max-width: 800px;
    margin: auto;
    font-family: 'Arial', sans-serif;
}
#input-container {
    margin-bottom: 20px;
}
#output-container {
    margin-top: 20px;
    font-family: Arial, sans-serif;
    font-size: 16px;
    color: #333;
    padding: 10px;
    border-radius: 5px;
    border: 1px solid #ddd;
}
.gr-button {
    background-color: #007bff;
    color: white;
    border: none;
    padding: 10px 20px;
    border-radius: 5px;
    cursor: pointer;
    font-size: 16px;
}
.gr-button:hover {
    background-color: #0056b3;
}
.gr-slider .gr-slider-track {
    background-color: #007bff;
}
.gr-slider .gr-slider-thumb {
    background-color: #0056b3;
}
.gr-textbox input {
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
}
.gr-textbox textarea {
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
}
"""
|
|
|
# Build the two-column UI: input text and generation settings on the left,
# corrected text on the right.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; color: #007bff;'>Grammar Correction Tool</h1>")

    with gr.Row(elem_id="interface-container"):
        # Left column: input text plus generation settings.
        with gr.Column():
            user_input = gr.Textbox(
                lines=2,
                placeholder="Enter a sentence with grammatical errors...",
                label="Input Text",
                elem_id="input-container",
            )
            max_length = gr.Slider(minimum=1, maximum=256, value=100, step=1, label="Max Length")
            # minimum is 0 so the default value of 0 (no minimum length) lies
            # inside the slider's range; the previous minimum=1 excluded it.
            min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
            # 0 means "unset": correct_text only forwards max_new_tokens when > 0.
            max_new_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens (optional)")
            num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
            # NOTE: transformers only applies temperature/top_p when sampling
            # (do_sample=True); the click handler below does not enable sampling,
            # so these two sliders currently do not affect the output.
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

            btn = gr.Button("Correct Grammar")

        # Right column: model output.
        with gr.Column():
            corrected_output = gr.Textbox(
                lines=2,
                placeholder="The corrected sentence will appear here...",
                label="Corrected Text",
                elem_id="output-container",
            )

    # The inputs list is passed positionally, so its order must match
    # correct_text(text, max_length, max_new_tokens, min_length, num_beams,
    # temperature, top_p) — note max_new_tokens precedes min_length.
    btn.click(
        fn=correct_text,
        inputs=[user_input, max_length, max_new_tokens, min_length, num_beams, temperature, top_p],
        outputs=corrected_output,
    )
|
|
|
def main():
    """Launch the Gradio app defined as ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()