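# NOTE: hosting-page metadata captured together with the file (not part of the module):
#   ArneBinder's picture
#   update from https://github.com/ArneBinder/pie-document-level/pull/397
#   ced4316 verified
#   raw / history / blame
#   8.74 kB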
from __future__ import annotations
import logging
from functools import partial
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
overload,
)
from pie_modules.utils import resolve_type
from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
from pytorch_ie.core import Document
logger = logging.getLogger(__name__)
D = TypeVar("D", bound=Document)
def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
    """Empty the named annotation layers of ``doc`` in place.

    Args:
        doc: Document whose layers are accessed via ``doc[layer_name]``.
        layer_names: Names of the layers to empty.
        predictions: If True, only the ``predictions`` container of each layer
            is cleared; otherwise the layer itself is cleared.
    """
    for name in layer_names:
        layer = doc[name]
        # pick exactly one container per layer; the other one is left untouched
        container = layer.predictions if predictions else layer
        container.clear()
def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
    """Replace each layer's content with its current predictions, in place.

    For every named layer the predictions are collected, the layer and its
    predictions are both emptied, and the collected annotations are then added
    to the layer itself. Afterwards the ``predictions`` container is empty.

    Args:
        doc: Document whose layers are accessed via ``doc[layer_name]``.
        layer_names: Names of the layers to process.
    """
    for name in layer_names:
        layer = doc[name]
        predicted = list(layer.predictions)
        # drop whatever the layer held before
        layer.clear()
        # an annotation can be attached to just one annotation container, so the
        # predictions have to be emptied before re-attaching them to the layer
        layer.predictions.clear()
        layer.extend(predicted)
def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
    """Replace each layer's predictions with the layer's current content, in place.

    For every named layer the annotations are collected, the layer and its
    predictions are both emptied, and the collected annotations are then added
    to the ``predictions`` container. Afterwards the layer itself is empty.

    Args:
        doc: Document whose layers are accessed via ``doc[layer_name]``.
        layer_names: Names of the layers to process.
    """
    for name in layer_names:
        layer = doc[name]
        moved = list(layer)
        # an annotation can be attached to just one annotation container, so the
        # layer has to be emptied before re-attaching its content elsewhere
        layer.clear()
        # drop whatever the predictions held before
        layer.predictions.clear()
        layer.predictions.extend(moved)
def add_annotations_from_other_documents(
    docs: Iterable[D],
    other_docs: Sequence[Document],
    layer_names: List[str],
    from_predictions: bool = False,
    to_predictions: bool = False,
    clear_before: bool = True,
) -> None:
    """Copy annotations of selected layers from ``other_docs`` into ``docs``, in place.

    The i-th entry of ``docs`` receives annotations from ``other_docs[i]``. Each
    source document is first duplicated via ``fromdict(asdict())`` so that the
    entries of ``other_docs`` are not modified.

    Args:
        docs: Target documents; they are modified in place.
        other_docs: Source documents, indexed by the position of the target.
        layer_names: Names of the layers to transfer.
        from_predictions: If True, read from the source layer's ``predictions``
            instead of the layer itself.
        to_predictions: If True, append to the target layer's ``predictions``
            instead of the layer itself.
        clear_before: If True, empty the target layer before adding.
            NOTE(review): this always clears the layer itself, never its
            ``predictions``, even when ``to_predictions`` is True - confirm
            that this is intended.
    """
    for position, target_doc in enumerate(docs):
        source_doc = other_docs[position]
        # work on a duplicate to not modify the input
        source_copy = type(source_doc).fromdict(source_doc.asdict())

        for layer_name in layer_names:
            if clear_before:
                target_doc[layer_name].clear()

            source_layer = source_copy[layer_name]
            if from_predictions:
                source_layer = source_layer.predictions

            # take the annotations out of the duplicated source container
            # before they are added to the target
            transferred = list(source_layer)
            source_layer.clear()

            target_layer = target_doc[layer_name]
            if to_predictions:
                target_layer = target_layer.predictions
            target_layer.extend(transferred)
def process_pipeline_steps(
    documents: Sequence[Document],
    processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
    verbose: bool = False,
) -> Sequence[Document]:
    """Run ``documents`` through a series of named processing steps.

    The steps are executed in the iteration order of ``processors``. A step may
    return a new sequence, which is then passed to the following step, or
    ``None``, in which case the current sequence is passed on unchanged
    (this supports steps that work in place).

    Args:
        documents: The documents to process.
        processors: Mapping from step name to a callable taking the documents.
        verbose: If True, log each step name before executing it.

    Returns:
        The sequence returned by the last step that returned something, or the
        input ``documents`` if no step did.
    """
    current = documents
    for step_name, processor in processors.items():
        if verbose:
            logger.info(f"process {step_name} ...")
        outcome = processor(current)
        # a step that returns None is treated as not replacing the documents
        current = current if outcome is None else outcome
    return current
def process_documents(
    documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
) -> List[Document]:
    """Apply ``processor`` to every document and collect the results.

    Args:
        documents: The documents to process, one at a time.
        processor: Callable invoked as ``processor(doc, **kwargs)``. If it
            returns ``None``, the (possibly in-place modified) input document
            is kept in the result instead.
        **kwargs: Extra keyword arguments forwarded to every ``processor`` call.

    Returns:
        A new list with one entry per input document, in input order.
    """

    def _apply(doc):
        outcome = processor(doc, **kwargs)
        # None signals "no replacement": keep the input document
        return doc if outcome is None else outcome

    return [_apply(doc) for doc in documents]
class DummyTaskmodule(WithDocumentTypeMixin):
    """Minimal object that only exposes a ``document_type``.

    Args:
        document_type: A document class, a string that is resolved to a
            ``Document`` subclass via ``resolve_type``, or ``None``.
    """

    def __init__(self, document_type: Optional[Union[Type[Document], str]]):
        # strings are resolved to the actual class; anything else is stored as is
        self._document_type = (
            resolve_type(document_type, expected_super_type=Document)
            if isinstance(document_type, str)
            else document_type
        )

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """The document type given at construction (resolved if it was a string)."""
        return self._document_type
class NerRePipeline:
    """Chain two pretrained pipelines: one for the entity layer, one for the relation layer.

    Both pipelines are loaded with ``AutoPipeline.from_pretrained``. Calling the
    object runs a fixed series of steps (see ``__call__``).

    Args:
        ner_model_path: Passed to ``AutoPipeline.from_pretrained`` for the first pipeline.
        re_model_path: Passed to ``AutoPipeline.from_pretrained`` for the second pipeline.
        entity_layer: Name of the document layer handled as entities.
        relation_layer: Name of the document layer handled as relations.
        device: If not None, used as ``device`` for both pipelines unless already
            set in the respective entry of ``processor_kwargs``.
        batch_size: Same as ``device``, for the ``batch_size`` option.
        show_progress_bar: Same as ``device``, for the ``show_progress_bar`` option.
        document_type: Document type (class or resolvable string) exposed via
            ``self.taskmodule.document_type``.
        verbose: If True, every processing step is logged when called.
        **processor_kwargs: Per-step keyword arguments, keyed by step name
            (``ner_pipeline``, ``re_pipeline``, ``clear_annotations``,
            ``use_predicted_entities``, ``clear_candidate_relations``,
            ``move_entities_to_predictions``, ``re_add_gold_data``). Note that the
            ``ner_pipeline`` / ``re_pipeline`` entries are updated in place with
            the defaults described above.
    """

    def __init__(
        self,
        ner_model_path: str,
        re_model_path: str,
        entity_layer: str,
        relation_layer: str,
        device: Optional[int] = None,
        batch_size: Optional[int] = None,
        show_progress_bar: Optional[bool] = None,
        document_type: Optional[Union[Type[Document], str]] = None,
        verbose: bool = True,
        **processor_kwargs,
    ):
        self.taskmodule = DummyTaskmodule(document_type)
        self.ner_model_path = ner_model_path
        self.re_model_path = re_model_path
        self.processor_kwargs = processor_kwargs or {}
        self.entity_layer = entity_layer
        self.relation_layer = relation_layer
        self.verbose = verbose

        # values shared by both inference pipelines; explicit per-pipeline
        # settings take precedence and None means "not provided"
        shared_defaults = {
            "device": device,
            "batch_size": batch_size,
            "show_progress_bar": show_progress_bar,
        }
        for pipeline_key in ("ner_pipeline", "re_pipeline"):
            pipeline_kwargs = self.processor_kwargs.setdefault(pipeline_key, {})
            for option, value in shared_defaults.items():
                if option not in pipeline_kwargs and value is not None:
                    pipeline_kwargs[option] = value

        self.ner_pipeline = AutoPipeline.from_pretrained(
            self.ner_model_path, **self.processor_kwargs["ner_pipeline"]
        )
        self.re_pipeline = AutoPipeline.from_pretrained(
            self.re_model_path, **self.processor_kwargs["re_pipeline"]
        )

    def _per_document_step(
        self,
        step_name: str,
        processor: Callable[..., Optional[Document]],
        layer_names: List[str],
    ) -> Callable[..., List[Document]]:
        """Build a step that applies ``processor`` to each document for ``layer_names``.

        Extra keyword arguments configured for ``step_name`` are forwarded.
        """
        return partial(
            process_documents,
            processor=processor,
            layer_names=layer_names,
            **self.processor_kwargs.get(step_name, {}),
        )

    @overload
    def __call__(
        self, documents: Sequence[Document], inplace: bool = False
    ) -> Sequence[Document]: ...

    @overload
    def __call__(self, documents: Document, inplace: bool = False) -> Document: ...

    def __call__(
        self, documents: Union[Sequence[Document], Document], inplace: bool = False
    ) -> Union[Sequence[Document], Document]:
        """Run both pipelines on one document or a sequence of documents.

        Steps, in order: clear the entity and relation layers, run the first
        pipeline, move the predicted entities into the entity layer, run the
        second pipeline, clear the relation layer, move the entities back to
        the predictions, and finally re-add the annotations of the entity and
        relation layers from the original documents.

        Args:
            documents: A single document or a sequence of documents.
            inplace: If True, the passed documents are modified; otherwise
                copies are processed and the inputs serve as source of the
                original annotations.

        Returns:
            A single document if a single document was passed, otherwise the
            sequence of processed documents.
        """
        is_single_doc = isinstance(documents, Document)
        if is_single_doc:
            documents = [documents]

        copies = [doc.copy() for doc in documents]
        input_docs: Sequence[Document]
        # we need to keep the original documents to add the gold data back
        original_docs: Sequence[Document]
        if inplace:
            input_docs, original_docs = documents, copies
        else:
            input_docs, original_docs = copies, documents

        entity_layers = [self.entity_layer]
        relation_layers = [self.relation_layer]

        # the dict order defines the execution order
        steps = {
            "clear_annotations": self._per_document_step(
                "clear_annotations",
                clear_annotation_layers,
                [self.entity_layer, self.relation_layer],
            ),
            "ner_pipeline": self.ner_pipeline,
            "use_predicted_entities": self._per_document_step(
                "use_predicted_entities", move_annotations_from_predictions, entity_layers
            ),
            "re_pipeline": self.re_pipeline,
            # otherwise we can not move the entities back to predictions
            "clear_candidate_relations": self._per_document_step(
                "clear_candidate_relations", clear_annotation_layers, relation_layers
            ),
            "move_entities_to_predictions": self._per_document_step(
                "move_entities_to_predictions", move_annotations_to_predictions, entity_layers
            ),
            "re_add_gold_data": partial(
                add_annotations_from_other_documents,
                other_docs=original_docs,
                layer_names=[self.entity_layer, self.relation_layer],
                **self.processor_kwargs.get("re_add_gold_data", {}),
            ),
        }

        docs_with_predictions = process_pipeline_steps(
            documents=input_docs,
            processors=steps,
            verbose=self.verbose,
        )
        return docs_with_predictions[0] if is_single_doc else docs_with_predictions