import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Load the T5 grammar-correction model and its tokenizer once at import time,
# so every request served by the app reuses the same in-memory weights.
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")

def correct_text(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
    """Yield a grammar-corrected version of *text* from the T5 model.

    Args:
        text: Raw input sentence to correct.
        max_length / min_length: Absolute output-length bounds (prompt + output),
            used only when the corresponding new-token bound is not set (<= 0).
        max_new_tokens / min_new_tokens: Bounds on newly generated tokens;
            a value > 0 takes precedence over the matching absolute bound.
        num_beams: Beam-search width.
        temperature / top_p: Sampling parameters (sampling is always enabled).

    Yields:
        The decoded corrected string (generator form, as Gradio expects).
    """
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")

    # Settings common to every generation call.
    gen_kwargs = {
        "num_beams": num_beams,
        "temperature": temperature,
        "top_p": top_p,
        "early_stopping": True,
        "do_sample": True,
    }
    # The original code enumerated all four combinations of the two flags in
    # duplicated branches; the per-flag choice below is equivalent:
    # each new-token bound (> 0) replaces its absolute-length counterpart.
    if max_new_tokens > 0:
        gen_kwargs["max_new_tokens"] = max_new_tokens
    else:
        gen_kwargs["max_length"] = max_length
    if min_new_tokens > 0:
        gen_kwargs["min_new_tokens"] = min_new_tokens
    else:
        gen_kwargs["min_length"] = min_length

    outputs = model.generate(inputs, **gen_kwargs)

    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    yield corrected_text

def correct_text2(text, genConfig):
    """Yield a grammar-corrected version of *text*.

    Args:
        text: Raw input sentence to correct.
        genConfig: A transformers GenerationConfig; its fields are expanded
            into keyword arguments for ``model.generate``.

    Yields:
        The decoded corrected string (generator form, as Gradio expects).
    """
    prompt_ids = tokenizer.encode("grammar: " + text, return_tensors="pt")
    generated = model.generate(prompt_ids, **genConfig.to_dict())
    yield tokenizer.decode(generated[0], skip_special_tokens=True)

def respond(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
    """Gradio callback: build a GenerationConfig from the UI sliders and
    stream the corrected text back to the output box.

    Args mirror the slider components wired up in the Blocks UI below.

    Yields:
        The corrected string produced by ``correct_text2``.
    """
    config = GenerationConfig(
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
        early_stopping=True,
        do_sample=True,
    )

    # Only apply the new-token bounds when the user moved those sliders
    # above 0 (0 means "unset" in this UI).
    if max_new_tokens > 0:
        config.max_new_tokens = max_new_tokens
    if min_new_tokens > 0:
        config.min_new_tokens = min_new_tokens

    # BUG FIX: correct_text2 is a generator function, so the original
    # `corrected = correct_text2(...); yield corrected` yielded the generator
    # OBJECT (its repr would be shown in the UI), not the corrected string.
    # `yield from` forwards the actual string value(s).
    yield from correct_text2(text, config)

    

def update_prompt(prompt):
    """Identity pass-through: copies a sample button's label into the prompt box."""
    return prompt

# Create the Gradio interface: prompt in, corrected text out, with sample
# buttons and an accordion of generation hyper-parameters.
with gr.Blocks() as demo:
    gr.Markdown("""# Grammar Correction App""")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()

    # Sample prompts: clicking a button copies its label into the prompt box
    # (a gr.Button used as an input component supplies its label as the value).
    with gr.Row():
        samp1 = gr.Button("we shood buy an car")
        samp2 = gr.Button("she is more taller")
        samp3 = gr.Button("John and i saw a sheep over their.")

        samp1.click(update_prompt, samp1, prompt_box)
        samp2.click(update_prompt, samp2, prompt_box)
        samp3.click(update_prompt, samp3, prompt_box)

    submitBtn = gr.Button("Submit")

    with gr.Accordion("Generation Parameters:", open=False):
        max_length = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
        # BUG FIX: the default value 0 was below the slider's former minimum
        # of 1; minimum=0 makes "no minimum length" a representable setting.
        min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
        # For the new-token sliders, 0 means "unset" (see respond()).
        max_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens")
        min_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    submitBtn.click(respond, [prompt_box, max_length, min_length, max_tokens, min_tokens, num_beams, temperature, top_p], output_box)

demo.launch()