import spacy.cli
import errant

class Gramformer:

    def __init__(self, models=1, use_gpu=False):
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        
        # Ensure the spaCy model 'en_core_web_sm' (used internally by errant) is downloaded
        spacy.cli.download("en_core_web_sm")

        # errant.load expects a language code ('en'), not a spaCy model name
        self.annotator = errant.load('en')
        
        if use_gpu:
            device = "cuda:0"
        else:
            device = "cpu"
            
        batch_size = 1    
        self.device = device
        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO: Implement this part
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        if self.model_loaded:
            correction_prefix = "gec: "
            input_sentence = correction_prefix + input_sentence
            input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
            input_ids = input_ids.to(self.device)

            preds = self.correction_model.generate(
                input_ids,
                do_sample=True, 
                max_length=128, 
                num_beams=7,
                early_stopping=True,
                num_return_sequences=max_candidates
            )

            corrected = set()
            for pred in preds:  
                corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())

            return corrected
        else:
            print("Model is not loaded")  
            return None

    def highlight(self, orig, cor):
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()

        ignore_indexes = []

        for edit in edits:
            edit_type = edit[0]       # ERRANT error type (e.g. 'VERB', 'PUNCT')
            edit_str_start = edit[1]  # token span in the original sentence
            edit_spos = edit[2]       # start index of that span
            edit_epos = edit[3]       # end index of that span
            edit_str_end = edit[4]    # replacement span from the corrected sentence

            # if the original span covers more than one token, mark every token after the first for deletion
            for i in range(edit_spos + 1, edit_epos):
                ignore_indexes.append(i)

            if edit_str_start == "":
                if edit_spos - 1 >= 0:
                    new_edit_str = orig_tokens[edit_spos - 1]
                    edit_spos -= 1
                else:
                    new_edit_str = orig_tokens[edit_spos + 1]
                    edit_spos += 1
                if edit_type == "PUNCT":
                    st = f"<a type='{edit_type}' edit='{edit_str_end}'>{new_edit_str}</a>"
                else:
                    st = f"<a type='{edit_type}' edit='{new_edit_str} {edit_str_end}'>{new_edit_str}</a>"
                orig_tokens[edit_spos] = st
            elif edit_str_end == "":
                st = f"<d type='{edit_type}' edit=''>{edit_str_start}</d>"
                orig_tokens[edit_spos] = st
            else:
                st = f"<c type='{edit_type}' edit='{edit_str_end}'>{edit_str_start}</c>"
                orig_tokens[edit_spos] = st

        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]

        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        if len(edits) == 0:  
            return []

        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end,  e.c_str, e.c_start, e.c_end))
                
        return edit_annotations

    def get_edits(self, orig, cor):
        return self._get_edits(orig, cor)
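
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the class itself):
    # loads the correction model, then prints corrections and highlighted edits.
    # Output may vary between runs because generation uses sampling.
    gf = Gramformer(models=1, use_gpu=False)
    influent_sentence = "He are moving here"
    for corrected in gf.correct(influent_sentence, max_candidates=2):
        print("[Input]     ", influent_sentence)
        print("[Correction]", corrected)
        print("[Highlight] ", gf.highlight(influent_sentence, corrected))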