etemkocaaslan commited on
Commit
7621385
·
verified ·
1 Parent(s): 80e1b70

initial-dev

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py CHANGED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+
4
+ model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
5
+ tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
6
+
7
+ def correct_text(text, max_length, max_new_tokens, min_length, num_beams, temperature, top_p):
8
+ inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")
9
+ generate_kwargs = {
10
+ "inputs": inputs,
11
+ "max_length": max_length,
12
+ "min_length": min_length,
13
+ "num_beams": num_beams,
14
+ "temperature": temperature,
15
+ "top_p": top_p,
16
+ "early_stopping": True
17
+ }
18
+
19
+ if max_new_tokens > 0:
20
+ generate_kwargs["max_new_tokens"] = max_new_tokens
21
+
22
+ outputs = model.generate(**generate_kwargs)
23
+ corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
+ return corrected_text
25
+
26
+ def respond(message, history, max_length, min_length, max_new_tokens, num_beams, temperature, top_p):
27
+ response = correct_text(message, max_length, max_new_tokens, min_length, num_beams, temperature, top_p)
28
+ yield response
29
+
30
+ css = """
31
+ #interface-container {
32
+ padding: 20px;
33
+ border-radius: 10px;
34
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
35
+ max-width: 800px;
36
+ margin: auto;
37
+ font-family: 'Arial', sans-serif;
38
+ }
39
+ #input-container {
40
+ margin-bottom: 20px;
41
+ }
42
+ #output-container {
43
+ margin-top: 20px;
44
+ font-family: Arial, sans-serif;
45
+ font-size: 16px;
46
+ color: #333;
47
+ padding: 10px;
48
+ border-radius: 5px;
49
+ border: 1px solid #ddd;
50
+ }
51
+ .gr-button {
52
+ background-color: #007bff;
53
+ color: white;
54
+ border: none;
55
+ padding: 10px 20px;
56
+ border-radius: 5px;
57
+ cursor: pointer;
58
+ font-size: 16px;
59
+ }
60
+ .gr-button:hover {
61
+ background-color: #0056b3;
62
+ }
63
+ .gr-slider .gr-slider-track {
64
+ background-color: #007bff;
65
+ }
66
+ .gr-slider .gr-slider-thumb {
67
+ background-color: #0056b3;
68
+ }
69
+ .gr-textbox input {
70
+ border: 1px solid #ddd;
71
+ border-radius: 5px;
72
+ padding: 10px;
73
+ }
74
+ .gr-textbox textarea {
75
+ border: 1px solid #ddd;
76
+ border-radius: 5px;
77
+ padding: 10px;
78
+ }
79
+ """
80
+
81
+ with gr.Blocks(css=css) as demo:
82
+ gr.HTML("<h1 style='text-align: center; color: #007bff;'>Grammar Correction Tool</h1>")
83
+
84
+ with gr.Row(elem_id="interface-container"):
85
+ with gr.Column():
86
+ user_input = gr.Textbox(lines=2, placeholder="Enter a sentence with grammatical errors...", label="Input Text", elem_id="input-container")
87
+ max_length = gr.Slider(minimum=1, maximum=256, value=100, step=1, label="Max Length")
88
+ min_length = gr.Slider(minimum=1, maximum=256, value=0, step=1, label="Min Length")
89
+ max_new_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens (optional)")
90
+ num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
91
+ temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
92
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
93
+
94
+ btn = gr.Button("Correct Grammar")
95
+
96
+ with gr.Column():
97
+ corrected_output = gr.Textbox(lines=2, placeholder="The corrected sentence will appear here...", label="Corrected Text", elem_id="output-container")
98
+
99
+ btn.click(
100
+ fn=correct_text,
101
+ inputs=[user_input, max_length, max_new_tokens, min_length, num_beams, temperature, top_p],
102
+ outputs=corrected_output
103
+ )
104
+
105
+ if __name__ == "__main__":
106
+ demo.launch()