File size: 2,870 Bytes
009f8e2
23c0953
afb1063
d10fd10
 
 
 
 
f4b9a92
23c0953
 
 
 
6213036
23c0953
9b8cca7
23c0953
 
 
 
 
9b8cca7
23c0953
 
 
 
 
 
 
 
 
 
 
f4b9a92
23c0953
 
5aafc6e
 
afb1063
 
 
23c0953
 
fa07e23
afb1063
90fd9d9
 
21b62df
 
 
90fd9d9
21b62df
 
 
23c0953
90fd9d9
fa07e23
 
 
95f281d
 
fa07e23
 
9b8cca7
fa07e23
9b8cca7
fa07e23
90fd9d9
f4b9a92
afb1063
23c0953
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Load the model and tokenizer
# NOTE: from_pretrained downloads the checkpoint from the Hugging Face Hub on
# first run (network I/O at import time); later runs hit the local cache.
# "vennify/t5-base-grammar-correction" is a T5 fine-tuned for grammar fixing.
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")


def correct_text(text: str, genConfig) -> str:
    """Run the grammar-correction model on *text* and return the fixed string.

    Args:
        text: Raw input sentence to correct.
        genConfig: A ``transformers.GenerationConfig`` controlling decoding.

    Returns:
        The decoded model output with special tokens stripped.
    """
    # The checkpoint was fine-tuned with a "grammar: " task prefix, so it
    # must be prepended for the model to perform correction.
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")

    # Pass the config object directly instead of splatting to_dict() into
    # kwargs: to_dict() leaks bookkeeping keys (e.g. transformers_version)
    # that generate() may reject as unused model kwargs.
    outputs = model.generate(inputs, generation_config=genConfig)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def respond(text: str, max_length: int, min_length: int, max_new_tokens: int,
            min_new_tokens: int, num_beams: int, temperature: float,
            top_k: int, top_p: float):
    """Build a GenerationConfig from the UI sliders and yield the correction.

    Args:
        text: The prompt text to grammar-correct.
        max_length: Upper bound on total output length (tokens).
        min_length: Lower bound on total output length (tokens).
        max_new_tokens: If > 0, overrides max_length with a new-token cap.
        min_new_tokens: If > 0, enforces a minimum number of new tokens.
        num_beams: Beam-search width.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff (0 disables).
        top_p: Nucleus-sampling probability mass.

    Yields:
        The corrected text (a generator so Gradio can stream the result).
    """
    config = GenerationConfig(
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        early_stopping=True,
        do_sample=True,
    )

    # Only apply the new-token limits when the user moved the slider off 0;
    # a hard 0 would otherwise force a degenerate/empty generation.
    if max_new_tokens > 0:
        config.max_new_tokens = max_new_tokens
    if min_new_tokens > 0:
        config.min_new_tokens = min_new_tokens

    yield correct_text(text, config)

def update_prompt(prompt):
    """Return *prompt* unchanged.

    Used as a click handler so a sample button's label is copied
    into the prompt textbox.
    """
    new_value = prompt
    return new_value

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# Grammar Correction App""")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()

    # Sample prompts: clicking a button copies its label into the prompt box.
    with gr.Row():
        sample_buttons = [
            gr.Button("we shood buy an car"),
            gr.Button("she is more taller"),
            gr.Button("John and i saw a sheep over their."),
        ]
        for sample in sample_buttons:
            # A Button passed as an input component supplies its label string.
            sample.click(update_prompt, sample, prompt_box)
    submit_btn = gr.Button("Submit")

    # Decoding knobs forwarded positionally to respond().
    with gr.Accordion("Generation Parameters:", open=False):
        max_length = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
        min_length = gr.Slider(minimum=1, maximum=256, value=0,  step=1, label="Min Length")
        max_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens")
        min_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_k = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    submit_btn.click(
        respond,
        [prompt_box, max_length, min_length, max_tokens, min_tokens,
         num_beams, temperature, top_k, top_p],
        output_box,
    )

# Guard the launch so importing this module does not start the server.
if __name__ == "__main__":
    demo.launch()