sashtech committed
Commit 3541ec5 · verified · parent: 7fa9ef9

Rename gingerit.py to gramformer.py

Files changed (2):
  1. gingerit.py +0 -77
  2. gramformer.py +128 -0
gingerit.py DELETED
@@ -1,77 +0,0 @@
-
-import requests
-import cloudscraper
-
-URL = "https://services.gingersoftware.com/Ginger/correct/jsonSecured/GingerTheTextFull"  # noqa
-API_KEY = "6ae0c3a0-afdc-4532-a810-82ded0054236"
-
-class GingerIt(object):
-    def __init__(self):
-        self.url = URL
-        self.api_key = API_KEY
-        self.api_version = "2.0"
-        self.lang = "US"
-
-    def parse(self, text, verify=True):
-        session = cloudscraper.create_scraper()
-        try:
-            request = session.get(
-                self.url,
-                params={
-                    "lang": self.lang,
-                    "apiKey": self.api_key,
-                    "clientVersion": self.api_version,
-                    "text": text,
-                },
-                verify=verify,
-            )
-
-            # Print the raw response text for debugging
-            print("Response Text:", request.text)
-
-            # Try parsing the response as JSON
-            try:
-                data = request.json()
-            except ValueError:
-                print("Failed to parse JSON response. Returning original text.")
-                return {"text": text, "result": text, "corrections": []}
-
-            # Process and return the corrected data
-            return self._process_data(text, data)
-
-        except requests.exceptions.RequestException as e:
-            print(f"An error occurred during the API request: {e}")
-            return {"text": text, "result": text, "corrections": []}
-
-    @staticmethod
-    def _change_char(original_text, from_position, to_position, change_with):
-        return "{}{}{}".format(
-            original_text[:from_position], change_with, original_text[to_position + 1 :]
-        )
-
-    def _process_data(self, text, data):
-        result = text
-        corrections = []
-
-        if "Corrections" not in data:
-            print("No corrections found in the API response.")
-            return {"text": text, "result": text, "corrections": corrections}
-
-        for suggestion in reversed(data["Corrections"]):
-            start = suggestion["From"]
-            end = suggestion["To"]
-
-            if suggestion["Suggestions"]:
-                suggest = suggestion["Suggestions"][0]
-                result = self._change_char(result, start, end, suggest["Text"])
-
-                corrections.append(
-                    {
-                        "start": start,
-                        "text": text[start : end + 1],
-                        "correct": suggest.get("Text", None),
-                        "definition": suggest.get("Definition", None),
-                    }
-                )
-
-        return {"text": text, "result": result, "corrections": corrections}
gramformer.py ADDED
@@ -0,0 +1,128 @@
+class Gramformer:
+
+    def __init__(self, models=1, use_gpu=False):
+        from transformers import AutoTokenizer
+        from transformers import AutoModelForSeq2SeqLM
+        # from lm_scorer.models.auto import AutoLMScorer as LMScorer
+        import errant
+        self.annotator = errant.load('en')
+
+        if use_gpu:
+            device = "cuda:0"
+        else:
+            device = "cpu"
+        batch_size = 1
+        # self.scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
+        self.device = device
+        correction_model_tag = "prithivida/grammar_error_correcter_v1"
+        self.model_loaded = False
+
+        if models == 1:
+            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
+            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
+            self.correction_model = self.correction_model.to(device)
+            self.model_loaded = True
+            print("[Gramformer] Grammar error correct/highlight model loaded..")
+        elif models == 2:
+            # TODO
+            print("TO BE IMPLEMENTED!!!")
+
+    def correct(self, input_sentence, max_candidates=1):
+        if self.model_loaded:
+            correction_prefix = "gec: "
+            input_sentence = correction_prefix + input_sentence
+            input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
+            input_ids = input_ids.to(self.device)
+
+            preds = self.correction_model.generate(
+                input_ids,
+                do_sample=True,
+                max_length=128,
+                # top_k=50,
+                # top_p=0.95,
+                num_beams=7,
+                early_stopping=True,
+                num_return_sequences=max_candidates)
+
+            corrected = set()
+            for pred in preds:
+                corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
+
+            # corrected = list(corrected)
+            # scores = self.scorer.sentence_score(corrected, log=True)
+            # ranked_corrected = [(c, s) for c, s in zip(corrected, scores)]
+            # ranked_corrected.sort(key=lambda x: x[1], reverse=True)
+            return corrected
+        else:
+            print("Model is not loaded")
+            return None
+
+    def highlight(self, orig, cor):
+        edits = self._get_edits(orig, cor)
+        orig_tokens = orig.split()
+
+        ignore_indexes = []
+
+        for edit in edits:
+            edit_type = edit[0]
+            edit_str_start = edit[1]
+            edit_spos = edit[2]
+            edit_epos = edit[3]
+            edit_str_end = edit[4]
+
+            # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
+            for i in range(edit_spos + 1, edit_epos):
+                ignore_indexes.append(i)
+
+            if edit_str_start == "":
+                if edit_spos - 1 >= 0:
+                    new_edit_str = orig_tokens[edit_spos - 1]
+                    edit_spos -= 1
+                else:
+                    new_edit_str = orig_tokens[edit_spos + 1]
+                    edit_spos += 1
+                if edit_type == "PUNCT":
+                    st = "<a type='" + edit_type + "' edit='" + \
+                        edit_str_end + "'>" + new_edit_str + "</a>"
+                else:
+                    st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
+                        " " + edit_str_end + "'>" + new_edit_str + "</a>"
+                orig_tokens[edit_spos] = st
+            elif edit_str_end == "":
+                st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
+                orig_tokens[edit_spos] = st
+            else:
+                st = "<c type='" + edit_type + "' edit='" + \
+                    edit_str_end + "'>" + edit_str_start + "</c>"
+                orig_tokens[edit_spos] = st
+
+        for i in sorted(ignore_indexes, reverse=True):
+            del orig_tokens[i]
+
+        return " ".join(orig_tokens)
+
+    def detect(self, input_sentence):
+        # TO BE IMPLEMENTED
+        pass
+
+    def _get_edits(self, orig, cor):
+        orig = self.annotator.parse(orig)
+        cor = self.annotator.parse(cor)
+        alignment = self.annotator.align(orig, cor)
+        edits = self.annotator.merge(alignment)
+
+        if len(edits) == 0:
+            return []
+
+        edit_annotations = []
+        for e in edits:
+            e = self.annotator.classify(e)
+            edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
+
+        if len(edit_annotations) > 0:
+            return edit_annotations
+        else:
+            return []
+
+    def get_edits(self, orig, cor):
+        return self._get_edits(orig, cor)
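
A minimal usage sketch for the new module (assumptions: torch, transformers, and errant plus the spaCy model errant loads are installed; the example sentence and printed output are illustrative, and results vary between runs because generate() uses do_sample=True):

from gramformer import Gramformer

gf = Gramformer(models=1, use_gpu=False)  # models=1 loads the corrector/highlighter

influent = "He are moving here."
for corrected in gf.correct(influent, max_candidates=1):
    print(corrected)                          # e.g. "He is moving here."
    print(gf.highlight(influent, corrected))  # original tokens wrapped in <c>/<a>/<d> edit tags
    print(gf.get_edits(influent, corrected))  # errant tuples: (type, o_str, o_start, o_end, c_str, c_start, c_end)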