# update from https://github.com/ArneBinder/pie-document-level/pull/397
from __future__ import annotations
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation
from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
from pie_modules.utils.span import have_overlap
from pytorch_ie import AnnotationLayer
from pytorch_ie.core import Document
from pytorch_ie.core.document import Annotation, _enumerate_dependencies
from src.utils import distance
from src.utils.span_utils import get_overlap_len

logger = logging.getLogger(__name__)

D = TypeVar("D", bound=Document)


def _remove_overlapping_entities(
entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
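    """Remove entities (in dict representation) that overlap with an already accepted entity.

    Entities are sorted by start offset and kept greedily: an entity is skipped if it starts
    before the end of the last kept entity. Relations that reference a skipped entity are
    dropped as well.
    """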
sorted_entities = sorted(entities, key=lambda span: span["start"])
entities_wo_overlap = []
skipped_entities = []
last_end = 0
for entity_dict in sorted_entities:
if entity_dict["start"] < last_end:
skipped_entities.append(entity_dict)
else:
entities_wo_overlap.append(entity_dict)
last_end = entity_dict["end"]
if len(skipped_entities) > 0:
logger.warning(f"skipped overlapping entities: {skipped_entities}")
valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap)
valid_relations = [
relation_dict
for relation_dict in relations
if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
]
return entities_wo_overlap, valid_relations


def remove_overlapping_entities(
doc: D,
entity_layer_name: str = "entities",
relation_layer_name: str = "relations",
) -> D:
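    """Remove overlapping entities and all relations referencing them from the document.

    Note: this operates on the dict representation of the document and discards existing
    predictions in the entity and relation layers.
    """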
# TODO: use document.add_all_annotations_from_other()
document_dict = doc.asdict()
entities_wo_overlap, valid_relations = _remove_overlapping_entities(
entities=document_dict[entity_layer_name]["annotations"],
relations=document_dict[relation_layer_name]["annotations"],
)
document_dict[entity_layer_name] = {
"annotations": entities_wo_overlap,
"predictions": [],
}
document_dict[relation_layer_name] = {
"annotations": valid_relations,
"predictions": [],
}
new_doc = type(doc).fromdict(document_dict)
return new_doc


# TODO: remove and use pie_modules.document.processing.SpansViaRelationMerger instead
def merge_spans_via_relation(
document: D,
relation_layer: str,
link_relation_label: str,
use_predicted_spans: bool = False,
process_predictions: bool = True,
create_multi_spans: bool = False,
) -> D:
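    """Merge spans that are linked via relations with the given link_relation_label.

    If create_multi_spans is True, the merged spans become labeled multi spans and the result is
    a TextDocumentWithLabeledMultiSpansAndBinaryRelations; otherwise the merged (single) spans
    replace the original spans in a copy of the input document.
    """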
rel_layer = document[relation_layer]
span_layer = rel_layer.target_layer
new_gold_spans, new_gold_relations = _merge_spans_via_relation(
spans=span_layer,
relations=rel_layer,
link_relation_label=link_relation_label,
create_multi_spans=create_multi_spans,
)
if process_predictions:
new_pred_spans, new_pred_relations = _merge_spans_via_relation(
spans=span_layer.predictions if use_predicted_spans else span_layer,
relations=rel_layer.predictions,
link_relation_label=link_relation_label,
create_multi_spans=create_multi_spans,
)
else:
assert not use_predicted_spans
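        # layer.clear() detaches the prediction annotations from the document and returns them,
        # so the (unprocessed) predictions can be re-added to the result document below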
new_pred_spans = set(span_layer.predictions.clear())
new_pred_relations = set(rel_layer.predictions.clear())
relation_layer_name = relation_layer
span_layer_name = document[relation_layer].target_name
if create_multi_spans:
doc_dict = document.asdict()
for f in document.annotation_fields():
doc_dict.pop(f.name)
result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
result.labeled_multi_spans.extend(new_gold_spans)
result.labeled_multi_spans.predictions.extend(new_pred_spans)
result.binary_relations.extend(new_gold_relations)
result.binary_relations.predictions.extend(new_pred_relations)
else:
result = document.copy(with_annotations=False)
result[span_layer_name].extend(new_gold_spans)
result[span_layer_name].predictions.extend(new_pred_spans)
result[relation_layer_name].extend(new_gold_relations)
result[relation_layer_name].predictions.extend(new_pred_relations)
return result


def remove_partitions_by_labels(
document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
) -> D:
"""Remove partitions with labels in the blacklist from a document.
Args:
document: The document to process.
partition_layer: The name of the partition layer.
label_blacklist: The list of labels to remove.
span_layer: The name of the span layer to remove spans from if they are not fully
contained in any remaining partition. Any dependent annotations will be removed as well.
Returns:
The processed document.
"""
document = document.copy()
p_layer: AnnotationLayer = document[partition_layer]
new_partitions = []
for partition in p_layer.clear():
if partition.label not in label_blacklist:
new_partitions.append(partition)
p_layer.extend(new_partitions)
if span_layer is not None:
result = document.copy(with_annotations=False)
removed_span_ids = set()
for span in document[span_layer]:
# keep spans fully contained in any partition
if any(
partition.start <= span.start and span.end <= partition.end
for partition in new_partitions
):
result[span_layer].append(span.copy())
else:
removed_span_ids.add(span._id)
result.add_all_annotations_from_other(
document,
removed_annotations={span_layer: removed_span_ids},
strict=False,
verbose=False,
)
document = result
return document


D_text = TypeVar("D_text", bound=Document)


def replace_substrings_in_text(
document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
) -> D_text:
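    """Replace substrings in the document text.

    If enforce_same_length is True, every replacement must have the same length as the string it
    replaces so that the start/end offsets of existing annotations remain valid.
    """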
new_text = document.text
for old_str, new_str in replacements.items():
if enforce_same_length and len(old_str) != len(new_str):
raise ValueError(
f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
)
new_text = new_text.replace(old_str, new_str)
result_dict = document.asdict()
result_dict["text"] = new_text
result = type(document).fromdict(result_dict)
result.text = new_text
return result


def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
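    """Replace every occurrence of the given substrings with spaces of the same length."""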
replacements = {substring: " " * len(substring) for substring in substrings}
return replace_substrings_in_text(document, replacements=replacements)


def relabel_annotations(
document: D,
label_mapping: Dict[str, Dict[str, str]],
) -> D:
"""
Replace annotation labels in a document.
Args:
document: The document to process.
label_mapping: A mapping from layer names to mappings from old labels to new labels.
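            For example (hypothetical layer and labels): {"entities": {"PER": "person"}}.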
Returns:
The processed document.
"""
dependency_ordered_fields: List[str] = []
_enumerate_dependencies(
dependency_ordered_fields,
dependency_graph=document._annotation_graph,
nodes=document._annotation_graph["_artificial_root"],
)
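    # process the layers in dependency order so that referenced annotations (e.g. the spans a
    # relation points to) are already re-created and available in the store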
result = document.copy(with_annotations=False)
store: Dict[int, Annotation] = {}
# not yet used
invalid_annotation_ids: Set[int] = set()
for field_name in dependency_ordered_fields:
if field_name in document._annotation_fields:
layer = document[field_name]
for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
for ann in anns:
new_ann = ann.copy_with_store(
override_annotation_store=store,
invalid_annotation_ids=invalid_annotation_ids,
)
if field_name in label_mapping:
if ann.label in label_mapping[field_name]:
new_label = label_mapping[field_name][ann.label]
new_ann = new_ann.copy(label=new_label)
else:
raise ValueError(
f"Label {ann.label} not found in label mapping for {field_name}"
)
store[ann._id] = new_ann
target_layer = result[field_name]
if is_prediction:
target_layer.predictions.append(new_ann)
else:
target_layer.append(new_ann)
return result


DWithSpans = TypeVar("DWithSpans", bound=Document)


def align_predicted_span_annotations(
document: DWithSpans, span_layer: str, distance_type: str = "center", verbose: bool = False
) -> DWithSpans:
"""
Aligns predicted span annotations with the closest gold spans in a document.
First, calculates the distance between each predicted span and each gold span. Then,
for each predicted span, the gold span with the smallest distance is selected. If the
predicted span and the gold span have an overlap of at least half of the maximum length
of the two spans, the predicted span is aligned with the gold span.
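
    For example (hypothetical offsets): a predicted span (12, 24) whose closest gold span is
    (10, 20) overlaps it by 8 characters; the maximum span length is 12, so 8 >= 12 / 2 and the
    predicted span is realigned to (10, 20).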
Args:
document: The document to process.
span_layer: The name of the span layer.
distance_type: The type of distance to calculate. One of: center, inner, outer
verbose: Whether to print debug information.
Returns:
The processed document.
"""
gold_spans = document[span_layer]
if len(gold_spans) == 0:
return document.copy()
pred_spans = document[span_layer].predictions
old2new_pred_span = {}
span_id2gold_span = {}
for pred_span in pred_spans:
gold_spans_with_distance = [
(
gold_span,
distance(
start_end=(pred_span.start, pred_span.end),
other_start_end=(gold_span.start, gold_span.end),
distance_type=distance_type,
),
)
for gold_span in gold_spans
]
closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
# if the closest gold span is the same as the predicted span, we don't need to align
if min_distance == 0.0:
continue
if have_overlap(
start_end=(pred_span.start, pred_span.end),
other_start_end=(closest_gold_span.start, closest_gold_span.end),
):
overlap_len = get_overlap_len(
(pred_span.start, pred_span.end), (closest_gold_span.start, closest_gold_span.end)
)
# get the maximum length of the two spans
l_max = max(
pred_span.end - pred_span.start, closest_gold_span.end - closest_gold_span.start
)
# if the overlap is at least half of the maximum length, we consider it a valid match for alignment
valid_match = overlap_len >= (l_max / 2)
else:
valid_match = False
if valid_match:
aligned_pred_span = pred_span.copy(
start=closest_gold_span.start, end=closest_gold_span.end
)
old2new_pred_span[pred_span._id] = aligned_pred_span
span_id2gold_span[pred_span._id] = closest_gold_span
result = document.copy(with_annotations=False)
# multiple predicted spans can be aligned with the same gold span,
# so we need to keep track of the added spans
added_pred_span_ids = dict()
for pred_span in pred_spans:
# just add the predicted span if it was not aligned with a gold span
if pred_span._id not in old2new_pred_span:
# if this was not added before (e.g. as aligned span), add it
if pred_span._id not in added_pred_span_ids:
keep_pred_span = pred_span.copy()
result[span_layer].predictions.append(keep_pred_span)
added_pred_span_ids[pred_span._id] = keep_pred_span
elif verbose:
print(f"Skipping duplicate predicted span. pred_span='{str(pred_span)}'")
else:
aligned_pred_span = old2new_pred_span[pred_span._id]
# if this was not added before (e.g. as aligned or original pred span), add it
if aligned_pred_span._id not in added_pred_span_ids:
result[span_layer].predictions.append(aligned_pred_span)
added_pred_span_ids[aligned_pred_span._id] = aligned_pred_span
elif verbose:
prev_pred_span = added_pred_span_ids[aligned_pred_span._id]
gold_span = span_id2gold_span[pred_span._id]
print(
f"Skipping duplicate aligned predicted span. aligned gold_span='{str(gold_span)}', "
f"prev_pred_span='{str(prev_pred_span)}', current_pred_span='{str(pred_span)}'"
)
# print("bbb")
result[span_layer].extend([span.copy() for span in gold_spans])
# add remaining gold and predicted spans (the result, _aligned_spans, is just for debugging)
_aligned_spans = result.add_all_annotations_from_other(
document, override_annotations={span_layer: old2new_pred_span}
)
return result