|
from __future__ import annotations |
|
|
|
import logging |
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar |
|
|
|
from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation |
|
from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations |
|
from pie_modules.utils.span import have_overlap |
|
from pytorch_ie import AnnotationLayer |
|
from pytorch_ie.core import Document |
|
from pytorch_ie.core.document import Annotation, _enumerate_dependencies |
|
|
|
from src.utils import distance |
|
from src.utils.span_utils import get_overlap_len |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
D = TypeVar("D", bound=Document) |
|
|
|
|
|
def _remove_overlapping_entities( |
|
entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]] |
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: |
|
sorted_entities = sorted(entities, key=lambda span: span["start"]) |
|
entities_wo_overlap = [] |
|
skipped_entities = [] |
|
last_end = 0 |
|
for entity_dict in sorted_entities: |
|
if entity_dict["start"] < last_end: |
|
skipped_entities.append(entity_dict) |
|
else: |
|
entities_wo_overlap.append(entity_dict) |
|
last_end = entity_dict["end"] |
|
if len(skipped_entities) > 0: |
|
logger.warning(f"skipped overlapping entities: {skipped_entities}") |
|
valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap) |
|
valid_relations = [ |
|
relation_dict |
|
for relation_dict in relations |
|
if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids |
|
] |
|
return entities_wo_overlap, valid_relations |
|
|
|
|
|
def remove_overlapping_entities(
    doc: D,
    entity_layer_name: str = "entities",
    relation_layer_name: str = "relations",
) -> D:
    """Return a copy of ``doc`` with overlapping entities removed.

    Relations whose head or tail entity was removed are dropped as well, and
    all predictions in both layers are discarded.

    Args:
        doc: The document to process.
        entity_layer_name: Name of the entity (span) layer.
        relation_layer_name: Name of the relation layer.

    Returns:
        A new document of the same type without overlapping entities.
    """
    doc_dict = doc.asdict()
    kept_entities, kept_relations = _remove_overlapping_entities(
        entities=doc_dict[entity_layer_name]["annotations"],
        relations=doc_dict[relation_layer_name]["annotations"],
    )

    # rebuild both layers from the filtered annotations, with empty predictions
    for layer_name, annotations in (
        (entity_layer_name, kept_entities),
        (relation_layer_name, kept_relations),
    ):
        doc_dict[layer_name] = {"annotations": annotations, "predictions": []}

    return type(doc).fromdict(doc_dict)
|
|
|
|
|
|
|
def merge_spans_via_relation(
    document: D,
    relation_layer: str,
    link_relation_label: str,
    use_predicted_spans: bool = False,
    process_predictions: bool = True,
    create_multi_spans: bool = False,
) -> D:
    """Merge spans that are connected by relations with ``link_relation_label``.

    Gold spans and relations are always merged via ``_merge_spans_via_relation``.
    Predictions are merged too when ``process_predictions`` is True; otherwise the
    existing predictions are carried over unchanged.

    Args:
        document: The document to process.
        relation_layer: Name of the relation layer; its target layer provides the spans.
        link_relation_label: Label of the relations that link spans to be merged.
        use_predicted_spans: If True, merge predicted relations against predicted
            spans instead of gold spans. Requires ``process_predictions=True``.
        process_predictions: If False, predictions are not merged but re-attached as-is.
        create_multi_spans: If True, the result is a
            ``TextDocumentWithLabeledMultiSpansAndBinaryRelations`` holding multi-spans;
            otherwise the result has the same type as ``document``.

    Returns:
        The processed document.
    """

    rel_layer = document[relation_layer]
    # the span layer that the relations point into
    span_layer = rel_layer.target_layer
    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
        spans=span_layer,
        relations=rel_layer,
        link_relation_label=link_relation_label,
        create_multi_spans=create_multi_spans,
    )
    if process_predictions:
        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
            spans=span_layer.predictions if use_predicted_spans else span_layer,
            relations=rel_layer.predictions,
            link_relation_label=link_relation_label,
            create_multi_spans=create_multi_spans,
        )
    else:
        assert not use_predicted_spans
        # clear() detaches the current predictions from the document and returns
        # them, so they can be re-attached to the result below unchanged.
        # NOTE(review): wrapping in set() discards ordering (and would collapse
        # duplicates) — confirm this is intentional.
        new_pred_spans = set(span_layer.predictions.clear())
        new_pred_relations = set(rel_layer.predictions.clear())

    relation_layer_name = relation_layer
    span_layer_name = document[relation_layer].target_name
    if create_multi_spans:
        doc_dict = document.asdict()
        # drop all annotation data from the dict; the (merged) annotations are
        # added to the freshly constructed result below
        for f in document.annotation_fields():
            doc_dict.pop(f.name)

        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
        result.labeled_multi_spans.extend(new_gold_spans)
        result.labeled_multi_spans.predictions.extend(new_pred_spans)
        result.binary_relations.extend(new_gold_relations)
        result.binary_relations.predictions.extend(new_pred_relations)
    else:
        # keep the original document type; copy without annotations and re-add
        result = document.copy(with_annotations=False)
        result[span_layer_name].extend(new_gold_spans)
        result[span_layer_name].predictions.extend(new_pred_spans)
        result[relation_layer_name].extend(new_gold_relations)
        result[relation_layer_name].predictions.extend(new_pred_relations)

    return result
|
|
|
|
|
def remove_partitions_by_labels(
    document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
) -> D:
    """Remove partitions whose label is in the blacklist.

    Args:
        document: The document to process.
        partition_layer: The name of the partition layer.
        label_blacklist: The list of labels to remove.
        span_layer: The name of the span layer to remove spans from if they are not fully
            contained in any remaining partition. Any dependent annotations will be removed as well.

    Returns:
        The processed document.
    """

    document = document.copy()
    partitions: AnnotationLayer = document[partition_layer]
    # clear() detaches and returns the current partitions; keep only allowed labels
    kept_partitions = [
        partition for partition in partitions.clear() if partition.label not in label_blacklist
    ]
    partitions.extend(kept_partitions)

    if span_layer is not None:
        result = document.copy(with_annotations=False)
        dropped_span_ids = set()
        for span in document[span_layer]:
            fully_contained = any(
                partition.start <= span.start and span.end <= partition.end
                for partition in kept_partitions
            )
            if fully_contained:
                result[span_layer].append(span.copy())
            else:
                dropped_span_ids.add(span._id)

        # carry over all remaining annotations; anything depending on a dropped
        # span is removed as well (strict=False tolerates the missing targets)
        result.add_all_annotations_from_other(
            document,
            removed_annotations={span_layer: dropped_span_ids},
            strict=False,
            verbose=False,
        )
        document = result

    return document
|
|
|
|
|
D_text = TypeVar("D_text", bound=Document) |
|
|
|
|
|
def replace_substrings_in_text( |
|
document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True |
|
) -> D_text: |
|
new_text = document.text |
|
for old_str, new_str in replacements.items(): |
|
if enforce_same_length and len(old_str) != len(new_str): |
|
raise ValueError( |
|
f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"' |
|
) |
|
new_text = new_text.replace(old_str, new_str) |
|
result_dict = document.asdict() |
|
result_dict["text"] = new_text |
|
result = type(document).fromdict(result_dict) |
|
result.text = new_text |
|
return result |
|
|
|
|
|
def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
    """Blank out each of ``substrings`` in the document text with spaces of the
    same length, so that all annotation offsets remain valid."""
    replacements: Dict[str, str] = {}
    for substring in substrings:
        replacements[substring] = " " * len(substring)
    return replace_substrings_in_text(document, replacements=replacements)
|
|
|
|
|
def relabel_annotations(
    document: D,
    label_mapping: Dict[str, Dict[str, str]],
) -> D:
    """
    Replace annotation labels in a document.

    Args:
        document: The document to process.
        label_mapping: A mapping from layer names to mappings from old labels to new labels.
            Every label occurring in a mapped layer must have an entry in that layer's
            mapping; layers not mentioned in ``label_mapping`` are copied unchanged.

    Returns:
        The processed document.

    Raises:
        ValueError: If an annotation's label has no entry in its layer's mapping.
    """

    # process layers in dependency order so that an annotation's targets are
    # already copied (and present in `store`) before the annotation itself
    dependency_ordered_fields: List[str] = []
    _enumerate_dependencies(
        dependency_ordered_fields,
        dependency_graph=document._annotation_graph,
        nodes=document._annotation_graph["_artificial_root"],
    )
    result = document.copy(with_annotations=False)
    # maps original annotation ids to their (possibly relabeled) copies
    store: Dict[int, Annotation] = {}

    invalid_annotation_ids: Set[int] = set()
    for field_name in dependency_ordered_fields:
        if field_name in document._annotation_fields:
            layer = document[field_name]
            # handle gold annotations and predictions of each layer
            for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
                for ann in anns:
                    # copy the annotation, re-wiring its targets to the copies in `store`
                    new_ann = ann.copy_with_store(
                        override_annotation_store=store,
                        invalid_annotation_ids=invalid_annotation_ids,
                    )
                    if field_name in label_mapping:
                        if ann.label in label_mapping[field_name]:
                            new_label = label_mapping[field_name][ann.label]
                            new_ann = new_ann.copy(label=new_label)
                        else:
                            raise ValueError(
                                f"Label {ann.label} not found in label mapping for {field_name}"
                            )
                    store[ann._id] = new_ann
                    target_layer = result[field_name]
                    if is_prediction:
                        target_layer.predictions.append(new_ann)
                    else:
                        target_layer.append(new_ann)

    return result
|
|
|
|
|
DWithSpans = TypeVar("DWithSpans", bound=Document) |
|
|
|
|
|
def align_predicted_span_annotations(
    document: DWithSpans, span_layer: str, distance_type: str = "center", verbose: bool = False
) -> DWithSpans:
    """
    Aligns predicted span annotations with the closest gold spans in a document.

    First, calculates the distance between each predicted span and each gold span. Then,
    for each predicted span, the gold span with the smallest distance is selected. If the
    predicted span and the gold span have an overlap of at least half of the maximum length
    of the two spans, the predicted span is aligned with the gold span.

    Args:
        document: The document to process.
        span_layer: The name of the span layer.
        distance_type: The type of distance to calculate. One of: center, inner, outer
        verbose: Whether to print debug information.

    Returns:
        The processed document.
    """
    gold_spans = document[span_layer]
    # nothing to align against: return an unchanged copy
    if len(gold_spans) == 0:
        return document.copy()

    pred_spans = document[span_layer].predictions
    # maps predicted-span ids to their aligned replacement spans
    old2new_pred_span = {}
    # maps predicted-span ids to the gold span they were aligned to (for logging)
    span_id2gold_span = {}
    for pred_span in pred_spans:

        gold_spans_with_distance = [
            (
                gold_span,
                distance(
                    start_end=(pred_span.start, pred_span.end),
                    other_start_end=(gold_span.start, gold_span.end),
                    distance_type=distance_type,
                ),
            )
            for gold_span in gold_spans
        ]

        closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])

        # distance 0 means the prediction already matches the closest gold span
        # (per the chosen distance_type), so no re-alignment is needed
        if min_distance == 0.0:
            continue

        if have_overlap(
            start_end=(pred_span.start, pred_span.end),
            other_start_end=(closest_gold_span.start, closest_gold_span.end),
        ):
            overlap_len = get_overlap_len(
                (pred_span.start, pred_span.end), (closest_gold_span.start, closest_gold_span.end)
            )

            l_max = max(
                pred_span.end - pred_span.start, closest_gold_span.end - closest_gold_span.start
            )

            # only align if the overlap covers at least half of the longer span
            valid_match = overlap_len >= (l_max / 2)
        else:
            valid_match = False

        if valid_match:
            aligned_pred_span = pred_span.copy(
                start=closest_gold_span.start, end=closest_gold_span.end
            )
            old2new_pred_span[pred_span._id] = aligned_pred_span
            span_id2gold_span[pred_span._id] = closest_gold_span

    result = document.copy(with_annotations=False)

    # re-add predictions, replacing aligned ones and skipping duplicates that
    # can arise when several predictions align onto the same gold span
    added_pred_span_ids = dict()
    for pred_span in pred_spans:

        if pred_span._id not in old2new_pred_span:

            if pred_span._id not in added_pred_span_ids:
                keep_pred_span = pred_span.copy()
                result[span_layer].predictions.append(keep_pred_span)
                added_pred_span_ids[pred_span._id] = keep_pred_span
            elif verbose:
                print(f"Skipping duplicate predicted span. pred_span='{str(pred_span)}'")
        else:
            aligned_pred_span = old2new_pred_span[pred_span._id]

            if aligned_pred_span._id not in added_pred_span_ids:
                result[span_layer].predictions.append(aligned_pred_span)
                added_pred_span_ids[aligned_pred_span._id] = aligned_pred_span
            elif verbose:
                prev_pred_span = added_pred_span_ids[aligned_pred_span._id]
                gold_span = span_id2gold_span[pred_span._id]
                print(
                    f"Skipping duplicate aligned predicted span. aligned gold_span='{str(gold_span)}', "
                    f"prev_pred_span='{str(prev_pred_span)}', current_pred_span='{str(pred_span)}'"
                )

    # gold spans are carried over unchanged
    result[span_layer].extend([span.copy() for span in gold_spans])

    # carry over all other annotation layers, re-targeting anything that pointed
    # at a replaced predicted span
    _aligned_spans = result.add_all_annotations_from_other(
        document, override_annotations={span_layer: old2new_pred_span}
    )

    return result
|
|