Update README.md

README.md (changed)
```diff
@@ -102,7 +102,7 @@ formats = {
     "entity_swapping": """<|im_start|>system\nEntity Swapping<|im_end|>\n<|im_start|>user\nentities:{entities}\ntext:\n{text}<|im_end|>\n<|im_start|>assistant\n"""
 }
 
-def model_inference(text, mode="anonymization", max_new_tokens=50, config=None, entity_mapping=None):
+def model_inference(text, mode="anonymization", max_new_tokens=2028, config=None, entity_mapping=None, return_entities=False, reverse_mapping=False):
     if mode not in formats and mode != "anonymization":
         raise ValueError("Invalid mode. Choose from 'sensitivity', 'complexity', 'entity_detection', 'anonymization'.")
 
```
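For reference, this is how the `entity_swapping` template gets filled before generation; the entity list and input text below are made up purely for illustration:

```python
# Illustrative only: filling the entity_swapping prompt template.
entity_swapping_template = """<|im_start|>system\nEntity Swapping<|im_end|>\n<|im_start|>user\nentities:{entities}\ntext:\n{text}<|im_end|>\n<|im_start|>assistant\n"""

# Hypothetical mapping lines and input text
entities = "John Doe : PERSON_1\nParis : LOCATION_1"
text = "John Doe flew to Paris last week."

prompt = entity_swapping_template.format(entities=entities, text=text)
print(prompt)
```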
```diff
@@ -154,7 +154,6 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
         # Step 2: Select entities based on config
         selected_entities = select_entities_based_on_json(detected_entities, config)
         entities_str = "\n".join([f"{entity} : {label}" for entity, label in selected_entities])
-
         # Step 3: Entity swapping for anonymization
         swapping_prompt = formats["entity_swapping"].format(entities=entities_str, text=text)
         swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
```
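As a rough sketch of the data flow (the pairs below are hypothetical), `select_entities_based_on_json` is assumed to yield `(entity, label)` tuples, which step 2 turns into a newline-separated `entity : label` listing:

```python
# Hypothetical output of select_entities_based_on_json(detected_entities, config)
selected_entities = [("John Doe", "PERSON_1"), ("Paris", "LOCATION_1")]

entities_str = "\n".join([f"{entity} : {label}" for entity, label in selected_entities])
print(entities_str)
# John Doe : PERSON_1
# Paris : LOCATION_1
```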
```diff
@@ -168,24 +167,25 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
         anonymized_text = tokenizer.decode(swapping_output[0], skip_special_tokens=True)
         anonymized_text = anonymized_text.split("assistant\n", 1)[-1].strip()  # Extract only the assistant's response
 
+        if return_entities:
+            return anonymized_text, entities_str
+
+        return anonymized_text
 
     # Entity Restoration Mode using entity_swapping
     elif mode == "entity_swapping" and entity_mapping:
+        # Reverse the entity mapping
+        if reverse_mapping:
+            reversed_mapping = []
+            for line in entity_mapping.splitlines():
+                if ':' in line:  # Ensure the line contains a colon
+                    left, right = map(str.strip, line.split(":", 1))  # Split and strip spaces
+                    reversed_mapping.append(f"{right} : {left}")  # Reverse and format
+            entity_mapping = "\n".join(reversed_mapping)
 
-        # Combine all replacement mappings for the prompt
-        reversed_entities_str = "\n".join(reversed_entities)
 
         # Create the swapping prompt with the aggregated reversed mappings
-        swapping_prompt = formats["entity_swapping"].format(entities=reversed_entities_str, text=text)
+        swapping_prompt = formats["entity_swapping"].format(entities=entity_mapping, text=text)
         swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
         swapping_output = model.generate(
             **swapping_inputs,
```
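To see what `reverse_mapping=True` does in isolation, the same loop can be run on a made-up mapping string:

```python
# Illustrative only: reversing a hypothetical "original : replacement" mapping.
entity_mapping = "John Doe : PERSON_1\nParis : LOCATION_1"

reversed_mapping = []
for line in entity_mapping.splitlines():
    if ':' in line:  # skip malformed lines
        left, right = map(str.strip, line.split(":", 1))
        reversed_mapping.append(f"{right} : {left}")

print("\n".join(reversed_mapping))
# PERSON_1 : John Doe
# LOCATION_1 : Paris
```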
```diff
@@ -206,7 +206,7 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
         model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
         generation_output = model.generate(
             **model_inputs,
-            max_new_tokens=
+            max_new_tokens=5,
             use_cache=True,
             eos_token_id=151645
         )
```
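The hard-coded `eos_token_id=151645` corresponds to the `<|im_end|>` token in Qwen-style tokenizers, which matches the `<|im_start|>`/`<|im_end|>` prompt format above. A small check along these lines (assuming `tokenizer` is already loaded; not part of the README) can catch a tokenizer mismatch:

```python
# Sanity check: make sure the hard-coded eos_token_id is really <|im_end|>.
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
assert im_end_id == 151645, f"unexpected <|im_end|> id: {im_end_id}"
```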
```diff
@@ -224,7 +224,7 @@ def postprocess_entity_recognition(detection_output: str) -> dict:
     entity_pattern = re.compile(
         r'(?P<entity>[\w\s]+)--(?P<type>[\w]+)--(?P<random>[\w\s]+)--(?P<generalizations>.+)'
     )
-    generalization_pattern = re.compile(r'(\
+    generalization_pattern = re.compile(r'([\w\s]+)::([\w\s]+)')
 
     lines = detection_output.strip().split("\n")
     for line in lines:
```
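For illustration, here is what the two patterns extract from a single detection line; the line itself is hypothetical and only mirrors the `entity--TYPE--replacement--generalizations` layout implied by `entity_pattern`:

```python
import re

entity_pattern = re.compile(
    r'(?P<entity>[\w\s]+)--(?P<type>[\w]+)--(?P<random>[\w\s]+)--(?P<generalizations>.+)'
)
generalization_pattern = re.compile(r'([\w\s]+)::([\w\s]+)')

line = "John Doe--PERSON--Mark Smith--man::4,person::5"  # made-up example line
match = entity_pattern.match(line)
print(match.group("entity"), "|", match.group("type"), "|", match.group("random"))
# John Doe | PERSON | Mark Smith
print(generalization_pattern.findall(match.group("generalizations")))
# [('man', '4'), ('person', '5')]
```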
```diff
@@ -236,8 +236,21 @@ def postprocess_entity_recognition(detection_output: str) -> dict:
 
             generalizations = []
             for gen_match in generalization_pattern.findall(match.group("generalizations")):
+                first, second = gen_match
+
+                # Check if the first part is a digit (score) and swap if needed
+                if first.isdigit() and not second.isdigit():
+                    score = first
+                    label = second
+                    generalizations.append([label.strip(), score.strip()])
+
+                elif not first.isdigit() and second.isdigit():
+                    label = first
+                    score = second
+                    generalizations.append([label.strip(), score.strip()])
+
+
+
 
             output_json[entity_name] = {
                 "TYPE": entity_type,
```
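The digit check exists so that the order inside each pair does not matter; both orderings of a made-up `(label, score)` pair normalize to the same entry:

```python
# Illustrative only: the isdigit() check makes ('man', '4') and ('4', 'man') equivalent.
for first, second in [("man", "4"), ("4", "man")]:
    if first.isdigit() and not second.isdigit():
        score, label = first, second
    elif not first.isdigit() and second.isdigit():
        label, score = first, second
    print([label.strip(), score.strip()])
# ['man', '4']
# ['man', '4']
```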
````diff
@@ -304,11 +317,16 @@ To protect sensitive information, the model detects specific entities in the text
 
 ```python
 # Anonymize the text
-anonymized_text
+anonymized_text = model_inference(text, mode="anonymization")
 print(f"Anonymized Text: {anonymized_text}\n")
+```
 
+```python
 # Restore the original text
+anonymized_text, entity_mapping = model_inference(text, mode="anonymization", return_entities=True)
+print(f"Entity Mapping:\n{entity_mapping}\n")
+print(f"Anonymized Text: {anonymized_text}\n")
+restored_text = model_inference(anonymized_text, mode="entity_swapping", entity_mapping=entity_mapping, reverse_mapping=True)
 print(f"Restored Text: {restored_text}")
 ```
 
````
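Combining the two snippets above, a small helper (hypothetical, not part of the README) keeps the entity mapping next to both versions of the text:

```python
def anonymize_roundtrip(text):
    """Anonymize `text`, keep the entity mapping, and restore the original."""
    anonymized, mapping = model_inference(text, mode="anonymization", return_entities=True)
    restored = model_inference(
        anonymized, mode="entity_swapping", entity_mapping=mapping, reverse_mapping=True
    )
    return anonymized, mapping, restored
```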