sashtech committed on
Commit
38ba5af
·
verified ·
1 Parent(s): 5dca2dc

Update gramformer.py

Browse files
Files changed (1) hide show
  1. gramformer.py +103 -105
gramformer.py CHANGED
@@ -1,111 +1,109 @@
 
 
 
1
  class Gramformer:
2
 
3
- def __init__(self, models=1, use_gpu=False):
4
- from transformers import AutoTokenizer
5
- from transformers import AutoModelForSeq2SeqLM
6
- #from lm_scorer.models.auto import AutoLMScorer as LMScorer
7
- import errant
8
- self.annotator = errant.load('en')
9
-
10
- if use_gpu:
11
- device= "cuda:0"
12
- else:
13
- device = "cpu"
14
- batch_size = 1
15
- #self.scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
16
- self.device = device
17
- correction_model_tag = "prithivida/grammar_error_correcter_v1"
18
- self.model_loaded = False
19
-
20
- if models == 1:
21
- self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
22
- self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
23
- self.correction_model = self.correction_model.to(device)
24
- self.model_loaded = True
25
- print("[Gramformer] Grammar error correct/highlight model loaded..")
26
- elif models == 2:
27
- # TODO
28
- print("TO BE IMPLEMENTED!!!")
29
-
30
- def correct(self, input_sentence, max_candidates=1):
31
- if self.model_loaded:
32
- correction_prefix = "gec: "
33
- input_sentence = correction_prefix + input_sentence
34
- input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
35
- input_ids = input_ids.to(self.device)
36
-
37
- preds = self.correction_model.generate(
38
- input_ids,
39
- do_sample=True,
40
- max_length=128,
41
- # top_k=50,
42
- # top_p=0.95,
43
- num_beams=7,
44
- early_stopping=True,
45
- num_return_sequences=max_candidates)
46
-
47
- corrected = set()
48
- for pred in preds:
49
- corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
50
-
51
- #corrected = list(corrected)
52
- #scores = self.scorer.sentence_score(corrected, log=True)
53
- #ranked_corrected = [(c,s) for c, s in zip(corrected, scores)]
54
- #ranked_corrected.sort(key = lambda x:x[1], reverse=True)
55
- return corrected
56
- else:
57
- print("Model is not loaded")
58
- return None
59
-
60
- def highlight(self, orig, cor):
61
- edits = self._get_edits(orig, cor)
62
- orig_tokens = orig.split()
63
-
64
- ignore_indexes = []
65
-
66
- for edit in edits:
67
- edit_type = edit[0]
68
- edit_str_start = edit[1]
69
- edit_spos = edit[2]
70
- edit_epos = edit[3]
71
- edit_str_end = edit[4]
72
-
73
- # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
74
- for i in range(edit_spos+1, edit_epos):
75
- ignore_indexes.append(i)
76
-
77
- if edit_str_start == "":
78
- if edit_spos - 1 >= 0:
79
- new_edit_str = orig_tokens[edit_spos - 1]
80
- edit_spos -= 1
81
- else:
82
- new_edit_str = orig_tokens[edit_spos + 1]
83
- edit_spos += 1
84
- if edit_type == "PUNCT":
85
- st = "<a type='" + edit_type + "' edit='" + \
86
- edit_str_end + "'>" + new_edit_str + "</a>"
87
- else:
88
- st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
89
- " " + edit_str_end + "'>" + new_edit_str + "</a>"
90
- orig_tokens[edit_spos] = st
91
- elif edit_str_end == "":
92
- st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
93
- orig_tokens[edit_spos] = st
94
- else:
95
- st = "<c type='" + edit_type + "' edit='" + \
96
- edit_str_end + "'>" + edit_str_start + "</c>"
97
- orig_tokens[edit_spos] = st
98
-
99
- for i in sorted(ignore_indexes, reverse=True):
100
- del(orig_tokens[i])
101
-
102
- return(" ".join(orig_tokens))
103
-
104
- def detect(self, input_sentence):
105
  # TO BE IMPLEMENTED
106
  pass
107
 
108
- def _get_edits(self, orig, cor):
109
  orig = self.annotator.parse(orig)
110
  cor = self.annotator.parse(cor)
111
  alignment = self.annotator.align(orig, cor)
@@ -124,5 +122,5 @@ class Gramformer:
124
  else:
125
  return []
126
 
127
- def get_edits(self, orig, cor):
128
- return self._get_edits(orig, cor)
 
1
+ import spacy.cli
2
+ import errant
3
+
4
  class Gramformer:
5
 
6
+ def __init__(self, models=1, use_gpu=False):
7
+ from transformers import AutoTokenizer
8
+ from transformers import AutoModelForSeq2SeqLM
9
+
10
+ # Ensure the SpaCy model 'en_core_web_sm' is downloaded
11
+ spacy.cli.download("en_core_web_sm")
12
+
13
+ # Load the correct SpaCy model for errant
14
+ self.annotator = errant.load('en_core_web_sm')
15
+
16
+ if use_gpu:
17
+ device = "cuda:0"
18
+ else:
19
+ device = "cpu"
20
+
21
+ batch_size = 1
22
+ self.device = device
23
+ correction_model_tag = "prithivida/grammar_error_correcter_v1"
24
+ self.model_loaded = False
25
+
26
+ if models == 1:
27
+ self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
28
+ self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
29
+ self.correction_model = self.correction_model.to(device)
30
+ self.model_loaded = True
31
+ print("[Gramformer] Grammar error correct/highlight model loaded..")
32
+ elif models == 2:
33
+ # TODO: Implement this part
34
+ print("TO BE IMPLEMENTED!!!")
35
+
36
+ def correct(self, input_sentence, max_candidates=1):
37
+ if self.model_loaded:
38
+ correction_prefix = "gec: "
39
+ input_sentence = correction_prefix + input_sentence
40
+ input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
41
+ input_ids = input_ids.to(self.device)
42
+
43
+ preds = self.correction_model.generate(
44
+ input_ids,
45
+ do_sample=True,
46
+ max_length=128,
47
+ num_beams=7,
48
+ early_stopping=True,
49
+ num_return_sequences=max_candidates
50
+ )
51
+
52
+ corrected = set()
53
+ for pred in preds:
54
+ corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
55
+
56
+ return corrected
57
+ else:
58
+ print("Model is not loaded")
59
+ return None
60
+
61
+ def highlight(self, orig, cor):
62
+ edits = self._get_edits(orig, cor)
63
+ orig_tokens = orig.split()
64
+
65
+ ignore_indexes = []
66
+
67
+ for edit in edits:
68
+ edit_type = edit[0]
69
+ edit_str_start = edit[1]
70
+ edit_spos = edit[2]
71
+ edit_epos = edit[3]
72
+ edit_str_end = edit[4]
73
+
74
+ # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
75
+ for i in range(edit_spos + 1, edit_epos):
76
+ ignore_indexes.append(i)
77
+
78
+ if edit_str_start == "":
79
+ if edit_spos - 1 >= 0:
80
+ new_edit_str = orig_tokens[edit_spos - 1]
81
+ edit_spos -= 1
82
+ else:
83
+ new_edit_str = orig_tokens[edit_spos + 1]
84
+ edit_spos += 1
85
+ if edit_type == "PUNCT":
86
+ st = f"<a type='{edit_type}' edit='{edit_str_end}'>{new_edit_str}</a>"
87
+ else:
88
+ st = f"<a type='{edit_type}' edit='{new_edit_str} {edit_str_end}'>{new_edit_str}</a>"
89
+ orig_tokens[edit_spos] = st
90
+ elif edit_str_end == "":
91
+ st = f"<d type='{edit_type}' edit=''>{edit_str_start}</d>"
92
+ orig_tokens[edit_spos] = st
93
+ else:
94
+ st = f"<c type='{edit_type}' edit='{edit_str_end}'>{edit_str_start}</c>"
95
+ orig_tokens[edit_spos] = st
96
+
97
+ for i in sorted(ignore_indexes, reverse=True):
98
+ del orig_tokens[i]
99
+
100
+ return " ".join(orig_tokens)
101
+
102
    def detect(self, input_sentence):
        """Detect grammatical errors in *input_sentence*.

        NOTE(review): stub — not yet implemented; currently returns None.
        """
        # TO BE IMPLEMENTED
        pass
105
 
106
+ def _get_edits(self, orig, cor):
107
  orig = self.annotator.parse(orig)
108
  cor = self.annotator.parse(cor)
109
  alignment = self.annotator.align(orig, cor)
 
122
  else:
123
  return []
124
 
125
+ def get_edits(self, orig, cor):
126
+ return self._get_edits(orig, cor)