beki commited on
Commit
d91e05d
·
1 Parent(s): 292fc7e

Delete transformers_recognizer.py

Browse files
Files changed (1) hide show
  1. transformers_recognizer.py +0 -252
transformers_recognizer.py DELETED
@@ -1,252 +0,0 @@
1
- import logging
2
- from typing import Optional, List, Tuple, Set
3
-
4
- from presidio_analyzer import (
5
- RecognizerResult,
6
- EntityRecognizer,
7
- AnalysisExplanation,
8
- )
9
- from presidio_analyzer.nlp_engine import NlpArtifacts
10
-
11
- logger = logging.getLogger("presidio-analyzer")
12
-
13
- try:
14
- from transformers import (
15
- AutoTokenizer,
16
- AutoModelForTokenClassification,
17
- pipeline,
18
- models,
19
- )
20
- from transformers.models.bert.modeling_bert import BertForTokenClassification
21
- except ImportError:
22
- logger.error("transformers is not installed")
23
-
24
-
25
-
26
- class TransformersRecognizer(EntityRecognizer):
27
- """
28
- Wrapper for a transformers model, if needed to be used within Presidio Analyzer.
29
-
30
- :example:
31
- >from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
32
-
33
- >transformers_recognizer = TransformersRecognizer()
34
-
35
- >registry = RecognizerRegistry()
36
- >registry.add_recognizer(transformers_recognizer)
37
-
38
- >analyzer = AnalyzerEngine(registry=registry)
39
-
40
- >results = analyzer.analyze(
41
- > "My name is Christopher and I live in Irbid.",
42
- > language="en",
43
- > return_decision_process=True,
44
- >)
45
- >for result in results:
46
- > print(result)
47
- > print(result.analysis_explanation)
48
-
49
-
50
- """
51
-
52
- ENTITIES = [
53
- "LOCATION",
54
- "PERSON",
55
- "ORGANIZATION",
56
- "AGE",
57
- "ID",
58
- "PHONE_NUMBER",
59
- "EMAIL",
60
- "DATE",
61
-
62
- ]
63
-
64
- DEFAULT_EXPLANATION = "Identified as {} by transformers's Named Entity Recognition"
65
-
66
- CHECK_LABEL_GROUPS = [
67
- ({"LOCATION"}, {"LOC", "HOSP"}),
68
- ({"PERSON"}, {"PER", "PERSON", "STAFF","PATIENT"}),
69
- ({"ORGANIZATION"}, {"ORGANIZATION", "ORG", "PATORG"}),
70
- ({"AGE"}, {"AGE"}),
71
- ({"ID"}, {"ID"}),
72
- ({"EMAIL"}, {"EMAIL"}),
73
- ({"DATE"}, {"DATE"}),
74
- ({"PHONE_NUMBER"}, {"PHONE"}),
75
-
76
- ]
77
-
78
- PRESIDIO_EQUIVALENCES = {
79
- "PER": "PERSON",
80
- "LOC": "LOCATION",
81
- "ORG": "ORGANIZATION",
82
- "AGE": "AGE",
83
- "ID": "ID",
84
- "EMAIL": "EMAIL",
85
- "PATIENT": "PERSON",
86
- "STAFF": "PERSON",
87
- "HOSP": "LOCATION",
88
- "PATORG": "ORGANIZATION",
89
- "DATE": "DATE_TIME",
90
- "PHONE": "PHONE_NUMBER",
91
- }
92
-
93
- DEFAULT_MODEL_PATH = "obi/deid_roberta_i2b2"
94
-
95
- def __init__(
96
- self,
97
- supported_entities: Optional[List[str]] = None,
98
- check_label_groups: Optional[Tuple[Set, Set]] = None,
99
- model: Optional[BertForTokenClassification] = None,
100
- model_path: Optional[str] = None,
101
- ):
102
- if not model and not model_path:
103
- model_path = self.DEFAULT_MODEL_PATH
104
- logger.warning(
105
- f"Both 'model' and 'model_path' arguments are None. Using default model_path={model_path}"
106
- )
107
-
108
- if model and model_path:
109
- logger.warning(
110
- f"Both 'model' and 'model_path' arguments were provided. Ignoring the model_path"
111
- )
112
-
113
- self.check_label_groups = (
114
- check_label_groups if check_label_groups else self.CHECK_LABEL_GROUPS
115
- )
116
-
117
- supported_entities = supported_entities if supported_entities else self.ENTITIES
118
- self.model = (
119
- model
120
- if model
121
- else pipeline(
122
- "ner",
123
- model=AutoModelForTokenClassification.from_pretrained(model_path),
124
- tokenizer=AutoTokenizer.from_pretrained(model_path),
125
- aggregation_strategy="simple",
126
- )
127
- )
128
-
129
- super().__init__(
130
- supported_entities=supported_entities, name="transformers Analytics",
131
- )
132
-
133
- def load(self) -> None:
134
- """Load the model, not used. Model is loaded during initialization."""
135
- pass
136
-
137
- def get_supported_entities(self) -> List[str]:
138
- """
139
- Return supported entities by this model.
140
-
141
- :return: List of the supported entities.
142
- """
143
- return self.supported_entities
144
-
145
- # Class to use transformers with Presidio as an external recognizer.
146
- def analyze(
147
- self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts = None
148
- ) -> List[RecognizerResult]:
149
- """
150
- Analyze text using Text Analytics.
151
-
152
- :param text: The text for analysis.
153
- :param entities: Not working properly for this recognizer.
154
- :param nlp_artifacts: Not used by this recognizer.
155
- :return: The list of Presidio RecognizerResult constructed from the recognized
156
- transformers detections.
157
- """
158
-
159
- results = []
160
- ner_results = self.model(text)
161
-
162
- # If there are no specific list of entities, we will look for all of it.
163
- if not entities:
164
- entities = self.supported_entities
165
-
166
- for entity in entities:
167
- if entity not in self.supported_entities:
168
- continue
169
-
170
- for res in ner_results:
171
- if not self.__check_label(
172
- entity, res["entity_group"], self.check_label_groups
173
- ):
174
- continue
175
- textual_explanation = self.DEFAULT_EXPLANATION.format(
176
- res["entity_group"]
177
- )
178
- explanation = self.build_transformers_explanation(
179
- round(res["score"], 2), textual_explanation
180
- )
181
- transformers_result = self._convert_to_recognizer_result(
182
- res, explanation
183
- )
184
-
185
- results.append(transformers_result)
186
-
187
- return results
188
-
189
- def _convert_to_recognizer_result(self, res, explanation) -> RecognizerResult:
190
-
191
- entity_type = self.PRESIDIO_EQUIVALENCES.get(
192
- res["entity_group"], res["entity_group"]
193
- )
194
- transformers_score = round(res["score"], 2)
195
-
196
- transformers_results = RecognizerResult(
197
- entity_type=entity_type,
198
- start=res["start"],
199
- end=res["end"],
200
- score=transformers_score,
201
- analysis_explanation=explanation,
202
- )
203
-
204
- return transformers_results
205
-
206
- def build_transformers_explanation(
207
- self, original_score: float, explanation: str
208
- ) -> AnalysisExplanation:
209
- """
210
- Create explanation for why this result was detected.
211
-
212
- :param original_score: Score given by this recognizer
213
- :param explanation: Explanation string
214
- :return:
215
- """
216
- explanation = AnalysisExplanation(
217
- recognizer=self.__class__.__name__,
218
- original_score=original_score,
219
- textual_explanation=explanation,
220
- )
221
- return explanation
222
-
223
- @staticmethod
224
- def __check_label(
225
- entity: str, label: str, check_label_groups: Tuple[Set, Set]
226
- ) -> bool:
227
- return any(
228
- [entity in egrp and label in lgrp for egrp, lgrp in check_label_groups]
229
- )
230
-
231
-
232
- if __name__ == "__main__":
233
-
234
- from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
235
-
236
- transformers_recognizer = (
237
- TransformersRecognizer()
238
- ) # This would download a large (~500Mb) model on the first run
239
-
240
- registry = RecognizerRegistry()
241
- registry.add_recognizer(transformers_recognizer)
242
-
243
- analyzer = AnalyzerEngine(registry=registry)
244
-
245
- results = analyzer.analyze(
246
- "My name is Christopher and I live in Irbid.",
247
- language="en",
248
- return_decision_process=True,
249
- )
250
- for result in results:
251
- print(result)
252
- print(result.analysis_explanation)