import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
# Load the model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
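# This T5 checkpoint expects a task prefix, so both helpers below prepend
# "grammar: " to the input before encoding it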
def correct_text(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")
    if max_new_tokens > 0 or min_new_tokens > 0:
        if max_new_tokens > 0 and min_new_tokens > 0:
            # Both new-token bounds set: they replace the total-length limits
            outputs = model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                num_beams=num_beams,
                temperature=temperature,
                top_p=top_p,
                early_stopping=True,
                do_sample=True,
            )
        elif max_new_tokens > 0:
            # Only the upper new-token bound set: keep min_length as the lower bound
            outputs = model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                min_length=min_length,
                num_beams=num_beams,
                temperature=temperature,
                top_p=top_p,
                early_stopping=True,
                do_sample=True,
            )
        else:
            # Only the lower new-token bound set: keep max_length as the upper bound
            outputs = model.generate(
                inputs,
                max_length=max_length,
                min_new_tokens=min_new_tokens,
                num_beams=num_beams,
                temperature=temperature,
                top_p=top_p,
                early_stopping=True,
                do_sample=True,
            )
    else:
        # Neither new-token bound set: fall back to the total-length limits
        outputs = model.generate(
            inputs,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            early_stopping=True,
            do_sample=True,
        )
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    yield corrected_text
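# correct_text above spells out each generate() call explicitly, but it is not
# wired to the UI; the Submit button goes through respond() -> correct_text2(),
# which takes a prebuilt GenerationConfig instead of individual keyword arguments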
def correct_text2(text, genConfig):
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")
    # Pass the config through the documented generation_config argument;
    # unpacking genConfig.to_dict() would also forward bookkeeping keys
    # such as transformers_version that generate() does not expect
    outputs = model.generate(inputs, generation_config=genConfig)
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    yield corrected_text
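# Standalone usage (a minimal sketch, outside the Gradio app):
#   config = GenerationConfig(max_length=80, num_beams=5, do_sample=True,
#                             temperature=0.7, top_p=0.95, early_stopping=True)
#   print(next(correct_text2("we shood buy an car", config)))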
def respond(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
    config = GenerationConfig(
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
        early_stopping=True,
        do_sample=True,
    )
    # Only set the new-token bounds when their sliders are non-zero,
    # so zeroed sliders don't override the length limits above
    if max_new_tokens > 0:
        config.max_new_tokens = max_new_tokens
    if min_new_tokens > 0:
        config.min_new_tokens = min_new_tokens
    # correct_text2 is a generator: re-yield its output rather than
    # yielding the generator object itself
    yield from correct_text2(text, config)
def update_prompt(prompt):
    # A clicked sample button is passed as the input component,
    # so Gradio supplies its label text as `prompt`
    return prompt
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# Grammar Correction App""")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()
    # Sample prompts: clicking a button copies its label into the prompt box
    with gr.Row():
        samp1 = gr.Button("we shood buy an car")
        samp2 = gr.Button("she is more taller")
        samp3 = gr.Button("John and i saw a sheep over their.")
    samp1.click(update_prompt, samp1, prompt_box)
    samp2.click(update_prompt, samp2, prompt_box)
    samp3.click(update_prompt, samp3, prompt_box)
    submitBtn = gr.Button("Submit")
    with gr.Accordion("Generation Parameters:", open=False):
        max_length = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
        # minimum is 0 so the default value of 0 (disabled) is valid
        min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
        max_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens")
        min_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    submitBtn.click(respond, [prompt_box, max_length, min_length, max_tokens, min_tokens, num_beams, temperature, top_p], output_box)
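# Note: some older Gradio 3.x releases require the queue for generator
# handlers, i.e. demo.queue().launch(); newer releases enable it by default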
demo.launch()