Anupam251272 committed on
Commit
c3c912b
·
verified ·
1 Parent(s): 7811719

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import logging
3
+ import os
4
+ import re
5
+
6
+
7
+ import spaces
8
+
9
+ import torch
10
+ from cleantext import clean
11
+ import gradio as gr
12
+ from tqdm.auto import tqdm
13
+ from transformers import pipeline
14
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
15
+
16
# Configure root logging and record the torch build in use.
logging.basicConfig(level=logging.INFO)
logging.info("torch version:\t%s", torch.__version__)

# Hugging Face Hub identifiers of the two models used by the app.
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# Text-classification pipeline; its label/score output decides whether a
# passage is sent to the corrector.
checker = pipeline(
    task="text-classification",
    model=checker_model_name,
    device_map="cuda",
)

# Text-to-text generation pipeline that rewrites a passage.
corrector = pipeline(
    task="text2text-generation",
    model=corrector_model_name,
    device_map="cuda",
)
35
+
36
def split_text(text: str) -> list:
    """Split *text* into sentences and group them into batches of two.

    A sentence boundary is one or more spaces that follow a ``.`` or ``?``
    and precede an uppercase ASCII letter; the regex lookbehind also
    requires that the character two positions before the punctuation is
    not an uppercase ASCII letter.

    Args:
        text: The raw input text.

    Returns:
        A list of batches, each a list of consecutive sentences. Every
        batch holds two sentences except possibly the last, which holds
        one when the sentence count is odd. Empty or whitespace-only
        input yields an empty list.
    """
    # Nothing to correct: avoid handing an empty string to the models.
    if not text.strip():
        return []

    sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)

    # Slice by position rather than comparing each sentence with the last
    # one by value, so a repeated sentence cannot end a batch early.
    batch_size = 2
    return [
        sentences[start:start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]
52
+
53
+
54
@spaces.GPU(duration=60)
def correct_text(text: str, separator: str = " ") -> str:
    """Run grammar correction over *text*, one sentence batch at a time.

    Each batch from ``split_text`` is joined with a single space and
    classified by ``checker``. The passage is kept unchanged only when the
    top result is ``LABEL_1`` with a score that is not below 0.9;
    otherwise it is replaced by the ``generated_text`` of ``corrector``.

    Args:
        text: The text to process.
        separator: String placed between the processed passages.

    Returns:
        The processed passages joined with *separator*.
    """
    chunks = split_text(text)
    pieces = []

    for chunk in tqdm(chunks, total=len(chunks), desc="correcting text.."):
        passage = " ".join(chunk)
        verdict = checker(passage)[0]

        # Rewrite unless the checker returned LABEL_1 with a high score.
        needs_fix = verdict["label"] != "LABEL_1" or verdict["score"] < 0.9
        if not needs_fix:
            pieces.append(passage)
            continue

        pieces.append(corrector(passage)[0]["generated_text"])

    return separator.join(pieces)
83
+
84
+
85
def update(text: str):
    """Gradio callback: truncate, clean and grammar-correct the input.

    Args:
        text: Contents of the input textbox.

    Returns:
        The result of ``correct_text`` on the cleaned text.
    """
    # Only the first 4000 characters are processed.
    truncated = text[:4000]
    # lower=False is passed so ``clean`` is asked not to lowercase the text.
    cleaned = clean(truncated, lower=False)
    return correct_text(cleaned)
89
+
90
+
91
+ # Create the Gradio interface
92
# Static text shown in the interface.
_TITLE_MD = "# <center>Robust Grammar Correction with FLAN-T5</center>"
_INSTRUCTIONS_MD = "**Instructions:** Enter the text you want to correct in the textbox below (_text will be truncated to 4000 characters_). Click 'Process' to run."
_MODELS_MD = """Models:
- `textattack/roberta-base-CoLA` for grammar quality detection
- `pszemraj/flan-t5-large-grammar-synthesis` for grammar correction
"""
_EXAMPLE_INPUT = "I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff."
_FOOTER_MD = "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"

# Build the page: heading, instructions, model list, input/output row,
# the button that triggers `update`, and a footer link.
with gr.Blocks() as demo:
    gr.Markdown(_TITLE_MD)
    gr.Markdown(_INSTRUCTIONS_MD)
    gr.Markdown(_MODELS_MD)
    with gr.Row():
        inp = gr.Textbox(
            label="input",
            placeholder="Enter text to check & correct",
            value=_EXAMPLE_INPUT,
        )
        out = gr.Textbox(label="output", interactive=False)
    btn = gr.Button("Process")
    # Clicking the button sends the input textbox to `update` and writes
    # the returned value to the output textbox.
    btn.click(fn=update, inputs=inp, outputs=out)
    gr.Markdown("---")
    gr.Markdown(_FOOTER_MD)

# Start the app.
demo.launch(debug=True)