marcel1997 committed on
Commit
5ae56e7
·
verified ·
1 Parent(s): 69404f8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -20
README.md CHANGED
@@ -102,7 +102,7 @@ formats = {
102
  "entity_swapping": """<|im_start|>system\nEntity Swapping<|im_end|>\n<|im_start|>user\nentities:{entities}\ntext:\n{text}<|im_end|>\n<|im_start|>assistant\n"""
103
  }
104
 
105
- def model_inference(text, mode="anonymization", max_new_tokens=50, config=None, entity_mapping=None):
106
  if mode not in formats and mode != "anonymization":
107
  raise ValueError("Invalid mode. Choose from 'sensitivity', 'complexity', 'entity_detection', 'anonymization'.")
108
 
@@ -154,7 +154,6 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
154
  # Step 2: Select entities based on config
155
  selected_entities = select_entities_based_on_json(detected_entities, config)
156
  entities_str = "\n".join([f"{entity} : {label}" for entity, label in selected_entities])
157
-
158
  # Step 3: Entity swapping for anonymization
159
  swapping_prompt = formats["entity_swapping"].format(entities=entities_str, text=text)
160
  swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
@@ -168,24 +167,25 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
168
  anonymized_text = tokenizer.decode(swapping_output[0], skip_special_tokens=True)
169
  anonymized_text = anonymized_text.split("assistant\n", 1)[-1].strip() # Extract only the assistant's response
170
 
171
- return anonymized_text, detected_entities
 
 
 
172
 
173
  # Entity Restoration Mode using entity_swapping
174
  elif mode == "entity_swapping" and entity_mapping:
175
- # Aggregate RANDOM and GENERAL replacements for restoration
176
- reversed_entities = []
177
- for original, details in entity_mapping.items():
178
- # Include RANDOM replacement
179
- reversed_entities.append(f"{details['RANDOM']} : {original}")
180
- # Include GENERAL replacements
181
- for general_label, _ in details["GENERAL"]:
182
- reversed_entities.append(f"{general_label} : {original}")
183
 
184
- # Combine all replacement mappings for the prompt
185
- reversed_entities_str = "\n".join(reversed_entities)
186
 
187
  # Create the swapping prompt with the aggregated reversed mappings
188
- swapping_prompt = formats["entity_swapping"].format(entities=reversed_entities_str, text=text)
189
  swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
190
  swapping_output = model.generate(
191
  **swapping_inputs,
@@ -206,7 +206,7 @@ def model_inference(text, mode="anonymization", max_new_tokens=50, config=None,
206
  model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
207
  generation_output = model.generate(
208
  **model_inputs,
209
- max_new_tokens=max_new_tokens,
210
  use_cache=True,
211
  eos_token_id=151645
212
  )
@@ -224,7 +224,7 @@ def postprocess_entity_recognition(detection_output: str) -> dict:
224
  entity_pattern = re.compile(
225
  r'(?P<entity>[\w\s]+)--(?P<type>[\w]+)--(?P<random>[\w\s]+)--(?P<generalizations>.+)'
226
  )
227
- generalization_pattern = re.compile(r'(\d+)::([\w\s]+)')
228
 
229
  lines = detection_output.strip().split("\n")
230
  for line in lines:
@@ -236,8 +236,21 @@ def postprocess_entity_recognition(detection_output: str) -> dict:
236
 
237
  generalizations = []
238
  for gen_match in generalization_pattern.findall(match.group("generalizations")):
239
- score, label = gen_match
240
- generalizations.append([label.strip(), score.strip()])
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  output_json[entity_name] = {
243
  "TYPE": entity_type,
@@ -304,11 +317,16 @@ To protect sensitive information, the model detects specific entities in the tex
304
 
305
  ```python
306
  # Anonymize the text
307
- anonymized_text, entity_mapping = model_inference(text, mode="anonymization")
308
  print(f"Anonymized Text: {anonymized_text}\n")
 
309
 
 
310
  # Restore the original text
311
- restored_text = model_inference(anonymized_text, mode="entity_swapping", entity_mapping=entity_mapping)
 
 
 
312
  print(f"Restored Text: {restored_text}")
313
  ```
314
 
 
102
  "entity_swapping": """<|im_start|>system\nEntity Swapping<|im_end|>\n<|im_start|>user\nentities:{entities}\ntext:\n{text}<|im_end|>\n<|im_start|>assistant\n"""
103
  }
104
 
105
+ def model_inference(text, mode="anonymization", max_new_tokens=2028, config=None, entity_mapping=None, return_entities=False, reverse_mapping=False):
106
  if mode not in formats and mode != "anonymization":
107
  raise ValueError("Invalid mode. Choose from 'sensitivity', 'complexity', 'entity_detection', 'anonymization'.")
108
 
 
154
  # Step 2: Select entities based on config
155
  selected_entities = select_entities_based_on_json(detected_entities, config)
156
  entities_str = "\n".join([f"{entity} : {label}" for entity, label in selected_entities])
 
157
  # Step 3: Entity swapping for anonymization
158
  swapping_prompt = formats["entity_swapping"].format(entities=entities_str, text=text)
159
  swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
 
167
  anonymized_text = tokenizer.decode(swapping_output[0], skip_special_tokens=True)
168
  anonymized_text = anonymized_text.split("assistant\n", 1)[-1].strip() # Extract only the assistant's response
169
 
170
+ if return_entities:
171
+ return anonymized_text, entities_str
172
+
173
+ return anonymized_text
174
 
175
  # Entity Restoration Mode using entity_swapping
176
  elif mode == "entity_swapping" and entity_mapping:
177
+ # Reverse the entity mapping
178
+ if reverse_mapping:
179
+ reversed_mapping = []
180
+ for line in entity_mapping.splitlines():
181
+ if ':' in line: # Ensure the line contains a colon
182
+ left, right = map(str.strip, line.split(":", 1)) # Split and strip spaces
183
+ reversed_mapping.append(f"{right} : {left}") # Reverse and format
184
+ entity_mapping = "\n".join(reversed_mapping)
185
 
 
 
186
 
187
  # Create the swapping prompt with the aggregated reversed mappings
188
+ swapping_prompt = formats["entity_swapping"].format(entities=entity_mapping, text=text)
189
  swapping_inputs = tokenizer(swapping_prompt, return_tensors="pt").to(device)
190
  swapping_output = model.generate(
191
  **swapping_inputs,
 
206
  model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
207
  generation_output = model.generate(
208
  **model_inputs,
209
+ max_new_tokens=5,
210
  use_cache=True,
211
  eos_token_id=151645
212
  )
 
224
  entity_pattern = re.compile(
225
  r'(?P<entity>[\w\s]+)--(?P<type>[\w]+)--(?P<random>[\w\s]+)--(?P<generalizations>.+)'
226
  )
227
+ generalization_pattern = re.compile(r'([\w\s]+)::([\w\s]+)')
228
 
229
  lines = detection_output.strip().split("\n")
230
  for line in lines:
 
236
 
237
  generalizations = []
238
  for gen_match in generalization_pattern.findall(match.group("generalizations")):
239
+ first, second = gen_match
240
+
241
+ # Check if the first part is a digit (score) and swap if needed
242
+ if first.isdigit() and not second.isdigit():
243
+ score = first
244
+ label = second
245
+ generalizations.append([label.strip(), score.strip()])
246
+
247
+ elif not first.isdigit() and second.isdigit():
248
+ label = first
249
+ score = second
250
+ generalizations.append([label.strip(), score.strip()])
251
+
252
+
253
+
254
 
255
  output_json[entity_name] = {
256
  "TYPE": entity_type,
 
317
 
318
  ```python
319
  # Anonymize the text
320
+ anonymized_text = model_inference(text, mode="anonymization")
321
  print(f"Anonymized Text: {anonymized_text}\n")
322
+ ```
323
 
324
+ ```python
325
  # Restore the original text
326
+ anonymized_text, entity_mapping = model_inference(text, mode="anonymization", return_entities=True)
327
+ print(f"Entity Mapping:\n{entity_mapping}\n")
328
+ print(f"Anonymized Text: {anonymized_text}\n")
329
+ restored_text = model_inference(anonymized_text, mode="entity_swapping", entity_mapping=entity_mapping, reverse_mapping=True)
330
  print(f"Restored Text: {restored_text}")
331
  ```
332