diff --git a/requirements.txt b/requirements.txt
index 9d7d7956f0008e648a082160989527beca6a1e90..2ed355b01fe97496ca6ea7add34b04e9d274a3b5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,3 @@
-# this requires python>=3.10
-gradio~=5.4.0
-prettytable==3.10.0
-beautifulsoup4==4.12.3
-# numpy 2.0.0 breaks the code
-numpy==1.25.2
-scipy==1.13.0
-arxiv==2.1.3
-pyrootutils>=1.0.0,<1.1.0
-
-########## from root requirements ##########
# --------- pytorch-ie --------- #
pytorch-ie>=0.29.6,<0.32.0
pie-datasets>=0.10.5,<0.11.0
@@ -16,20 +5,51 @@ pie-modules>=0.14.0,<0.15.0
# --------- models -------- #
adapters>=0.1.2,<0.2.0
-# ADU retrieval (and demo, in future):
+pytorch-crf~=0.7.2
+# --------- retriever -------- #
langchain>=0.3.0,<0.4.0
langchain-core>=0.3.0,<0.4.0
langchain-community>=0.3.0,<0.4.0
# we use QDrant as vectorstore backend
langchain-qdrant>=0.1.0,<0.2.0
qdrant-client>=1.12.0,<2.0.0
-# 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
-huggingface_hub<0.26.0 # 0.26 seems to be broken
-# to to handle segmented entities (if HANDLE_PARTS_OF_SAME=True)
-networkx>=3.0.0,<4.0.0
-# --------- config --------- #
+# --------- demo -------- #
+gradio~=5.4.0
+arxiv~=2.1.3
+
+# --------- hydra --------- #
hydra-core>=1.3.0
+hydra-colorlog>=1.2.0
+hydra-optuna-sweeper>=1.2.0
+
+# --------- loggers --------- #
+wandb
+# neptune-client
+# mlflow
+# comet-ml
+# tensorboard
+# aim
-# --------- dev --------- #
+# --------- linters --------- #
pre-commit # hooks for applying linters on commit
+black # code formatting
+isort # import sorting
+flake8 # code analysis
+nbstripout # remove output from jupyter notebooks
+
+# --------- others --------- #
+pyrootutils # standardizing the project root setup
+python-dotenv # loading env variables from .env file
+rich # beautiful text formatting in terminal
+pytest # tests
+pytest-cov # test coverage
+sh # for running bash commands in some tests
+pudb # debugger
+tabulate # show statistics as markdown
+plotext # show statistics as plots
+prettytable # rendering annotated docs as table (demo)
+beautifulsoup4 # rendering annotated docs with displacy + highlighted relations (demo)
+# 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
+huggingface_hub<0.26.0 # interaction with HF hub
+networkx~=3.2.1 # to handle segmented entities (e.g. if HANDLE_PARTS_OF_SAME=True in demo)
diff --git a/src/datamodules/__init__.py b/src/datamodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d8636dd7e5d9fd23e304ed3f2926c6020fe7f8f
--- /dev/null
+++ b/src/datamodules/__init__.py
@@ -0,0 +1 @@
+from .datamodule import PieDataModule
diff --git a/src/datamodules/components/__init__.py b/src/datamodules/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/datamodules/components/sampler.py b/src/datamodules/components/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1f69ce514db0b02cac579f8784b9fdb578f65e3
--- /dev/null
+++ b/src/datamodules/components/sampler.py
@@ -0,0 +1,67 @@
+"""This is a slightly modified version of https://github.com/ufoym/imbalanced-dataset-sampler."""
+
+from typing import Callable, List, Optional
+
+import pandas as pd
+import torch
+import torch.utils.data
+
+
+class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
+ """Samples elements randomly from a given list of indices for imbalanced dataset.
+
+ Arguments:
+ indices: a list of indices
+ num_samples: number of samples to draw
+ callback_get_label: a callback-like function which takes one argument - the dataset
+ """
+
+ def __init__(
+ self,
+ dataset,
+ labels: Optional[List] = None,
+ indices: Optional[List] = None,
+ num_samples: Optional[int] = None,
+ callback_get_label: Optional[Callable] = None,
+ ):
+ # if indices is not provided, all elements in the dataset will be considered
+ self.indices = list(range(len(dataset))) if indices is None else indices
+
+ # define custom callback
+ self.callback_get_label = callback_get_label
+
+ # if num_samples is not provided, draw `len(indices)` samples in each iteration
+ self.num_samples = len(self.indices) if num_samples is None else num_samples
+
+ # distribution of classes in the dataset
+ df = pd.DataFrame()
+ df["label"] = self._get_labels(dataset) if labels is None else labels
+ df.index = self.indices
+ df = df.sort_index()
+
+ label_to_count = df["label"].value_counts()
+
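+        # weight each sample by the inverse frequency of its label so that rare labels are drawn more often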
+ weights = 1.0 / label_to_count[df["label"]]
+
+ self.weights = torch.DoubleTensor(weights.to_list())
+
+ def _get_labels(self, dataset):
+ if self.callback_get_label:
+ return self.callback_get_label(dataset)
+ elif isinstance(dataset, torch.utils.data.TensorDataset):
+ return dataset.tensors[1]
+ elif isinstance(dataset, torch.utils.data.Subset):
+ return dataset.dataset.imgs[:][1]
+ elif isinstance(dataset, torch.utils.data.Dataset):
+ return dataset.get_labels()
+ else:
+ raise NotImplementedError
+
+ def __iter__(self):
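+        # draw num_samples indices with replacement, weighted to balance the label distribution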
+ return (
+ self.indices[i]
+ for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
+ )
+
+ def __len__(self):
+ return self.num_samples
diff --git a/src/datamodules/datamodule.py b/src/datamodules/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecaf0131c4ac9e7ecc11cdfa434eef82708e0705
--- /dev/null
+++ b/src/datamodules/datamodule.py
@@ -0,0 +1,154 @@
+from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union
+
+from pytorch_ie.core import Document
+from pytorch_ie.core.taskmodule import (
+ IterableTaskEncodingDataset,
+ TaskEncoding,
+ TaskEncodingDataset,
+ TaskModule,
+)
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Sampler
+from typing_extensions import TypeAlias
+
+from .components.sampler import ImbalancedDatasetSampler
+
+DocumentType = TypeVar("DocumentType", bound=Document)
+InputEncoding = TypeVar("InputEncoding")
+TargetEncoding = TypeVar("TargetEncoding")
+DatasetType: TypeAlias = Union[
+ TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
+ IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
+]
+
+
+class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
+ """A simple LightningDataModule for PIE document datasets.
+
+ A DataModule implements 5 key methods:
+ - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
+ - setup (things to do on every accelerator in distributed mode)
+ - train_dataloader (the training dataloader)
+ - val_dataloader (the validation dataloader(s))
+ - test_dataloader (the test dataloader(s))
+
+ This allows you to share a full dataset without explaining how to download,
+ split, transform and process the data.
+
+ Read the docs:
+ https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
+ """
+
+ def __init__(
+ self,
+ taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
+ dataset: Dict[str, Sequence[DocumentType]],
+ data_config_path: Optional[str] = None,
+ train_split: Optional[str] = "train",
+ val_split: Optional[str] = "validation",
+ test_split: Optional[str] = "test",
+ show_progress_for_encode: bool = False,
+ train_sampler: Optional[str] = None,
+ **dataloader_kwargs,
+ ):
+ super().__init__()
+
+ self.taskmodule = taskmodule
+ self.config_path = data_config_path
+ self.dataset = dataset
+ self.train_split = train_split
+ self.val_split = val_split
+ self.test_split = test_split
+ self.show_progress_for_encode = show_progress_for_encode
+ self.train_sampler_name = train_sampler
+ self.dataloader_kwargs = dataloader_kwargs
+
+ self._data: Dict[str, DatasetType] = {}
+
+ @property
+ def num_train(self) -> int:
+ if self.train_split is None:
+ raise ValueError("no train_split assigned")
+ data_train = self._data.get(self.train_split, None)
+ if data_train is None:
+ raise ValueError("can not get train size if setup() was not yet called")
+ if isinstance(data_train, IterableTaskEncodingDataset):
+ raise TypeError("IterableTaskEncodingDataset has no length")
+ return len(data_train)
+
+ def setup(self, stage: str):
+ if stage == "fit":
+ split_names = [self.train_split, self.val_split]
+ elif stage == "validate":
+ split_names = [self.val_split]
+ elif stage == "test":
+ split_names = [self.test_split]
+ else:
+ raise NotImplementedError(f"not implemented for stage={stage} ")
+
+ for split in split_names:
+ if split is None or split not in self.dataset:
+ continue
+ task_encoding_dataset = self.taskmodule.encode(
+ self.dataset[split],
+ encode_target=True,
+ as_dataset=True,
+ show_progress=self.show_progress_for_encode,
+ )
+ if not isinstance(
+ task_encoding_dataset,
+ (TaskEncodingDataset, IterableTaskEncodingDataset),
+ ):
+ raise TypeError(
+ f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
+ )
+ self._data[split] = task_encoding_dataset
+
+ def data_split(self, split: Optional[str] = None) -> DatasetType:
+ if split is None or split not in self._data:
+ raise ValueError(f"data for split={split} not available")
+ return self._data[split]
+
+ def get_train_sampler(
+ self,
+ sampler_name: str,
+ dataset: DatasetType,
+ ) -> Sampler:
+ if sampler_name == "imbalanced_dataset":
+            # for now, this works only with targets that have a single entry
+ return ImbalancedDatasetSampler(
+ dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
+ )
+ else:
+ raise ValueError(f"unknown sampler name: {sampler_name}")
+
+ def train_dataloader(self):
+ ds = self.data_split(self.train_split)
+ if self.train_sampler_name is not None:
+ sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
+ else:
+ sampler = None
+ return DataLoader(
+ dataset=ds,
+ sampler=sampler,
+ collate_fn=self.taskmodule.collate,
+ # don't shuffle streamed datasets or if we use a sampler
+ shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
+ **self.dataloader_kwargs,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ dataset=self.data_split(self.val_split),
+ collate_fn=self.taskmodule.collate,
+ shuffle=False,
+ **self.dataloader_kwargs,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ dataset=self.data_split(self.test_split),
+ collate_fn=self.taskmodule.collate,
+ shuffle=False,
+ **self.dataloader_kwargs,
+ )
diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/dataset/processing.py b/src/dataset/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f8a70b890fc47be85bfc2624105301f7e3fa44
--- /dev/null
+++ b/src/dataset/processing.py
@@ -0,0 +1,29 @@
+from typing import Callable, Type, Union
+
+from pie_datasets import Dataset, DatasetDict
+from pytorch_ie import Document
+from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
+
+
+# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
+# batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
+def apply_func_to_splits(
+ dataset: DatasetDict,
+ function: Union[str, Callable],
+ result_document_type: Type[Document],
+ **kwargs
+):
+ resolved_func = resolve_target(function)
+ resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
+ result_dict = dict()
+ split: Dataset
+ for split_name, split in dataset.items():
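+        # map the whole split as a single batch (batch_size=len(split)) so the function sees all documents of the split at once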
+ converted_dataset = split.map(
+ function=resolved_func,
+ batched=True,
+ batch_size=len(split),
+ result_document_type=resolved_document_type,
+ **kwargs
+ )
+ result_dict[split_name] = converted_dataset
+ return DatasetDict(result_dict)
diff --git a/src/demo/__init__.py b/src/demo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/demo/annotation_utils.py b/src/demo/annotation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..638c224d7d9ecadb6973e11e2e0effb637f115bf
--- /dev/null
+++ b/src/demo/annotation_utils.py
@@ -0,0 +1,137 @@
+import logging
+from typing import Optional, Sequence, Union
+
+import gradio as gr
+from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
+
+# this is required to dynamically load the PIE models
+from pie_modules.models import * # noqa: F403
+from pie_modules.taskmodules import * # noqa: F403
+from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
+from pytorch_ie import Pipeline
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.auto import AutoPipeline
+from pytorch_ie.documents import (
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+)
+
+# this is required to dynamically load the PIE models
+from pytorch_ie.models import * # noqa: F403
+from pytorch_ie.taskmodules import * # noqa: F403
+
+logger = logging.getLogger(__name__)
+
+
+def annotate_document(
+ document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+ argumentation_model: Pipeline,
+ handle_parts_of_same: bool = False,
+) -> Union[
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+]:
+ """Annotate a document with the provided pipeline.
+
+ Args:
+ document: The document to annotate.
+ argumentation_model: The pipeline to use for annotation.
+ handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
+ """
+
+ # execute prediction pipeline
+ argumentation_model(document)
+
+ if handle_parts_of_same:
+ merger = SpansViaRelationMerger(
+ relation_layer="binary_relations",
+ link_relation_label="parts_of_same",
+ create_multi_spans=True,
+ result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+ result_field_mapping={
+ "labeled_spans": "labeled_multi_spans",
+ "binary_relations": "binary_relations",
+ "labeled_partitions": "labeled_partitions",
+ },
+ )
+ document = merger(document)
+
+ return document
+
+
+def create_document(
+ text: str, doc_id: str, split_regex: Optional[str] = None
+) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+ """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
+ text.
+
+ Parameters:
+ text: The text to process.
+ doc_id: The ID of the document.
+ split_regex: A regular expression pattern to use for splitting the text into partitions.
+
+ Returns:
+ The processed document.
+ """
+
+ document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+ id=doc_id, text=text, metadata={}
+ )
+ if split_regex is not None:
+ partitioner = RegexPartitioner(
+ pattern=split_regex, partition_layer_name="labeled_partitions"
+ )
+ document = partitioner(document)
+ else:
+        # add a single partition covering the whole text (the model only considers text inside partitions)
+ document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+ return document
+
+
+def load_argumentation_model(
+ model_name: str,
+ revision: Optional[str] = None,
+ device: str = "cpu",
+) -> Pipeline:
+ try:
+ # the Pipeline class expects an integer for the device
+ if device == "cuda":
+ pipeline_device = 0
+ elif device.startswith("cuda:"):
+ pipeline_device = int(device.split(":")[1])
+ elif device == "cpu":
+ pipeline_device = -1
+ else:
+ raise gr.Error(f"Invalid device: {device}")
+
+ model = AutoPipeline.from_pretrained(
+ model_name,
+ device=pipeline_device,
+ num_workers=0,
+ taskmodule_kwargs=dict(revision=revision),
+ model_kwargs=dict(revision=revision),
+ )
+ gr.Info(
+ f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
+ )
+ except Exception as e:
+ raise gr.Error(f"Failed to load argumentation model: {e}")
+
+ return model
+
+
+def set_relation_types(
+ argumentation_model: Pipeline,
+ default: Optional[Sequence[str]] = None,
+) -> gr.Dropdown:
+ if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
+ relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
+ else:
+ raise gr.Error("Unsupported taskmodule for relation types")
+
+ return gr.Dropdown(
+ choices=relation_types,
+ label="Argumentative Relation Types",
+ value=default,
+ multiselect=True,
+ )
diff --git a/src/demo/backend_utils.py b/src/demo/backend_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7619dfcc01e08961ffa48494a9162795b44299b1
--- /dev/null
+++ b/src/demo/backend_utils.py
@@ -0,0 +1,221 @@
+import json
+import logging
+import os
+import tempfile
+from typing import Iterable, List, Optional, Sequence
+
+import gradio as gr
+import pandas as pd
+from pie_datasets import Dataset, IterableDataset, load_dataset
+from pytorch_ie import Pipeline
+from pytorch_ie.documents import (
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+)
+
+from src.demo.annotation_utils import annotate_document, create_document
+from src.demo.data_utils import load_text_from_arxiv
+from src.demo.rendering_utils import (
+ RENDER_WITH_DISPLACY,
+ RENDER_WITH_PRETTY_TABLE,
+ render_displacy,
+ render_pretty_table,
+)
+from src.demo.retriever_utils import get_text_spans_and_relations_from_document
+from src.langchain_modules import (
+ DocumentAwareSpanRetriever,
+ DocumentAwareSpanRetrieverWithRelations,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def add_annotated_pie_documents(
+ retriever: DocumentAwareSpanRetriever,
+ pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
+ use_predicted_annotations: bool,
+ verbose: bool = False,
+) -> None:
+ if verbose:
+ gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
+ num_docs_before = len(retriever.docstore)
+ retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
+ # number of documents that were overwritten
+ num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
+ # warn if documents were overwritten
+ if num_overwritten_docs > 0:
+ gr.Warning(f"{num_overwritten_docs} documents were overwritten.")
+
+
+def process_texts(
+ texts: Iterable[str],
+ doc_ids: Iterable[str],
+ argumentation_model: Pipeline,
+ retriever: DocumentAwareSpanRetriever,
+ split_regex_escaped: Optional[str],
+ handle_parts_of_same: bool = False,
+ verbose: bool = False,
+) -> None:
+ # check that doc_ids are unique
+ if len(set(doc_ids)) != len(list(doc_ids)):
+ raise gr.Error("Document IDs must be unique.")
+ pie_documents = [
+ create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
+ for text, doc_id in zip(texts, doc_ids)
+ ]
+ if verbose:
+ gr.Info(f"Annotate {len(pie_documents)} documents...")
+ pie_documents = [
+ annotate_document(
+ document=pie_document,
+ argumentation_model=argumentation_model,
+ handle_parts_of_same=handle_parts_of_same,
+ )
+ for pie_document in pie_documents
+ ]
+ add_annotated_pie_documents(
+ retriever=retriever,
+ pie_documents=pie_documents,
+ use_predicted_annotations=True,
+ verbose=verbose,
+ )
+
+
+def add_annotated_pie_documents_from_dataset(
+ retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
+) -> None:
+ try:
+ gr.Info(
+ "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
+ )
+ dataset = load_dataset(**load_dataset_kwargs)
+ if not isinstance(dataset, (Dataset, IterableDataset)):
+ raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
+ dataset_converted = dataset.to_document_type(
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
+ )
+ add_annotated_pie_documents(
+ retriever=retriever,
+ pie_documents=dataset_converted,
+ use_predicted_annotations=False,
+ verbose=verbose,
+ )
+ except Exception as e:
+ raise gr.Error(f"Failed to load dataset: {e}")
+
+
+def wrapped_process_text(
+ doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
+) -> str:
+ try:
+ process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
+ except Exception as e:
+ raise gr.Error(f"Failed to process text: {e}")
+    # return only the doc_id (not the processed document) to avoid serialization issues
+ return doc_id
+
+
+def process_uploaded_files(
+ file_names: List[str],
+ retriever: DocumentAwareSpanRetriever,
+ layer_captions: dict[str, str],
+ **kwargs,
+) -> pd.DataFrame:
+ try:
+ doc_ids = []
+ texts = []
+ for file_name in file_names:
+ if file_name.lower().endswith(".txt"):
+ # read the file content
+ with open(file_name, "r", encoding="utf-8") as f:
+ text = f.read()
+ base_file_name = os.path.basename(file_name)
+ doc_ids.append(base_file_name)
+ texts.append(text)
+ else:
+ raise gr.Error(f"Unsupported file format: {file_name}")
+ process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
+ except Exception as e:
+ raise gr.Error(f"Failed to process uploaded files: {e}")
+
+ return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def wrapped_add_annotated_pie_documents_from_dataset(
+ retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
+) -> pd.DataFrame:
+ try:
+ add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
+ except Exception as e:
+ raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
+ return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def download_processed_documents(
+ retriever: DocumentAwareSpanRetriever,
+ file_name: str = "retriever_store",
+) -> Optional[str]:
+ if len(retriever.docstore) == 0:
+ gr.Warning("No documents to download.")
+ return None
+
+ # zip the directory
+ file_path = os.path.join(tempfile.gettempdir(), file_name)
+
+ gr.Info(f"Zipping the retriever store to '{file_name}' ...")
+ result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
+
+ return result_file_path
+
+
+def upload_processed_documents(
+ file_name: str,
+ retriever: DocumentAwareSpanRetriever,
+ layer_captions: dict[str, str],
+) -> pd.DataFrame:
+ # load the documents from the zip file or directory
+ retriever.load_from_disc(file_name)
+ # return the overview of the document store
+ return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def process_text_from_arxiv(
+ arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
+) -> str:
+ try:
+ text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
+ except Exception as e:
+ raise gr.Error(f"Failed to load text from arXiv: {e}")
+ return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)
+
+
+def render_annotated_document(
+ retriever: DocumentAwareSpanRetrieverWithRelations,
+ document_id: str,
+ render_with: str,
+ render_kwargs_json: str,
+) -> str:
+ text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
+ retriever=retriever, document_id=document_id
+ )
+
+ render_kwargs = json.loads(render_kwargs_json)
+ if render_with == RENDER_WITH_PRETTY_TABLE:
+ html = render_pretty_table(
+ text=text,
+ spans=spans,
+ span_id2idx=span_id2idx,
+ binary_relations=relations,
+ **render_kwargs,
+ )
+ elif render_with == RENDER_WITH_DISPLACY:
+ html = render_displacy(
+ text=text,
+ spans=spans,
+ span_id2idx=span_id2idx,
+ binary_relations=relations,
+ **render_kwargs,
+ )
+ else:
+ raise ValueError(f"Unknown render_with value: {render_with}")
+
+ return html
diff --git a/src/demo/data_utils.py b/src/demo/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..980c05677ff36bbaa6e9556a0d5d8e9ab75d9d19
--- /dev/null
+++ b/src/demo/data_utils.py
@@ -0,0 +1,63 @@
+import logging
+import re
+from typing import Tuple
+
+import arxiv
+import gradio as gr
+import requests
+from bs4 import BeautifulSoup
+
+logger = logging.getLogger(__name__)
+
+
+def clean_spaces(text: str) -> str:
+ # replace all multiple spaces with a single space
+ text = re.sub(" +", " ", text)
+ # reduce more than two newlines to two newlines
+ text = re.sub("\n\n+", "\n\n", text)
+ # remove leading and trailing whitespaces
+ text = text.strip()
+ return text
+
+
+def get_cleaned_arxiv_paper_text(html_content: str) -> str:
+ # parse the HTML content with BeautifulSoup
+ soup = BeautifulSoup(html_content, "html.parser")
+ # get alerts (this is one div with classes "package-alerts" and "ltx_document")
+ alerts = soup.find("div", class_="package-alerts ltx_document")
+ # get the "article" html element
+ article = soup.find("article")
+ article_text = article.get_text()
+ # cleanup the text
+ article_text_clean = clean_spaces(article_text)
+ return article_text_clean
+
+
+def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[str, str]:
+
+ search_by_id = arxiv.Search(id_list=[arxiv_id])
+ try:
+ result = list(arxiv.Client().results(search_by_id))
+ except arxiv.HTTPError as e:
+ raise gr.Error(f"Failed to fetch arXiv data: {e}")
+ if len(result) == 0:
+ raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
+ first_result = result[0]
+ if abstract_only:
+ abstract_clean = first_result.summary.replace("\n", " ")
+ return abstract_clean, first_result.entry_id
+ if "/abs/" not in first_result.entry_id:
+ raise gr.Error(
+ f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
+ f"an unexpected format: {first_result.entry_id}"
+ )
+ html_url = first_result.entry_id.replace("/abs/", "/html/")
+ request_result = requests.get(html_url)
+ if request_result.status_code != 200:
+ raise gr.Error(
+ f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
+ f"{request_result.status_code}"
+ )
+ html_content = request_result.text
+ text_clean = get_cleaned_arxiv_paper_text(html_content)
+ return text_clean, html_url
diff --git a/src/demo/frontend_utils.py b/src/demo/frontend_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b17cb247225c6284957d04bf8911aba90f8286
--- /dev/null
+++ b/src/demo/frontend_utils.py
@@ -0,0 +1,56 @@
+from typing import Any, Union
+
+import gradio as gr
+import pandas as pd
+
+
+# see https://github.com/gradio-app/gradio/issues/9288#issuecomment-2356163329
+def get_fix_df_height_css(css_class: str, max_height: int) -> str:
+ # return ".qa-pairs .table-wrap {min-height: 170px; max-height: 170px;}"
+ return "." + css_class + " .table-wrap {max-height: " + str(max_height) + "px;}"
+
+
+def escape_regex(regex: str) -> str:
+ # "double escape" the backslashes
+ result = regex.encode("unicode_escape").decode("utf-8")
+ return result
+
+
+def unescape_regex(regex: str) -> str:
+ # reverse of escape_regex
+ result = regex.encode("utf-8").decode("unicode_escape")
+ return result
+
+
+def open_accordion():
+ return gr.Accordion(open=True)
+
+
+def close_accordion():
+ return gr.Accordion(open=False)
+
+
+def change_tab(id: Union[int, str]):
+ return gr.Tabs(selected=id)
+
+
+def get_cell_for_fixed_column_from_df(
+ evt: gr.SelectData,
+ df: pd.DataFrame,
+ column: str,
+) -> Any:
+ """Get the value of the fixed column for the selected row in the DataFrame.
+ This is required can *not* with a lambda function because that will not get
+ the evt parameter.
+
+ Args:
+ evt: The event object.
+ df: The DataFrame.
+ column: The name of the column.
+
+ Returns:
+ The value of the fixed column for the selected row.
+ """
+ row_idx, col_idx = evt.index
+ doc_id = df.iloc[row_idx][column]
+ return doc_id
diff --git a/src/demo/rendering_utils.py b/src/demo/rendering_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd46410b62e254e9bf89bf99cab6de584c5e1954
--- /dev/null
+++ b/src/demo/rendering_utils.py
@@ -0,0 +1,296 @@
+import json
+import logging
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+
+from .rendering_utils_displacy import EntityRenderer
+
+logger = logging.getLogger(__name__)
+
+RENDER_WITH_DISPLACY = "displacy"
+RENDER_WITH_PRETTY_TABLE = "pretty_table"
+AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]
+
+# adjusted from rendering_utils_displacy.TPL_ENT
+TPL_ENT_WITH_ID = """
+
+ {text}
+ {label}
+
+"""
+
+HIGHLIGHT_SPANS_JS = """
+() => {
+ function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
+ var color = entity.getAttribute('data-color-' + colorAttributeKey);
+ // if color is a json string, parse it and use the value at colorDictKey
+ try {
+ const colors = JSON.parse(color);
+ color = colors[colorDictKey];
+ } catch (e) {}
+ if (color) {
+ entity.style.backgroundColor = color;
+ entity.style.color = '#000';
+ }
+ }
+
+ function highlightRelationArguments(entityId) {
+ const entities = document.querySelectorAll('.entity');
+ // reset all entities
+ entities.forEach(entity => {
+ const color = entity.getAttribute('data-color-original');
+ entity.style.backgroundColor = color;
+ entity.style.color = '';
+ });
+
+ if (entityId !== null) {
+ var visitedEntities = new Set();
+ // highlight selected entity
+ // get all elements with attribute data-entity-id==entityId
+ const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
+ selectedEntityParts.forEach(selectedEntityPart => {
+ const label = selectedEntityPart.getAttribute('data-label');
+ maybeSetColor(selectedEntityPart, 'selected', label);
+ visitedEntities.add(selectedEntityPart);
+            });
+ // if there is at least one part, get the first one and ...
+ if (selectedEntityParts.length > 0) {
+ const selectedEntity = selectedEntityParts[0];
+
+ // ... highlight tails and ...
+ const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
+ relationTailsAndLabels.forEach(relationTail => {
+ const tailEntityId = relationTail['entity-id'];
+ const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
+ tailEntityParts.forEach(tailEntity => {
+ const label = relationTail['label'];
+ maybeSetColor(tailEntity, 'tail', label);
+ visitedEntities.add(tailEntity);
+                });
+            });
+ // .. highlight heads
+ const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
+ relationHeadsAndLabels.forEach(relationHead => {
+ const headEntityId = relationHead['entity-id'];
+ const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
+ headEntityParts.forEach(headEntity => {
+ const label = relationHead['label'];
+ maybeSetColor(headEntity, 'head', label);
+ visitedEntities.add(headEntity);
+                });
+            });
+ }
+
+ // highlight other entities
+ entities.forEach(entity => {
+ if (!visitedEntities.has(entity)) {
+ const label = entity.getAttribute('data-label');
+ maybeSetColor(entity, 'other', label);
+ }
+ });
+ }
+ }
+ function setHoverAduId(entityId) {
+ // get the textarea element that holds the reference adu id
+ let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+ // set the value of the input field
+ hoverAduIdDiv.value = entityId;
+ // trigger an input event to update the state
+ var event = new Event('input');
+ hoverAduIdDiv.dispatchEvent(event);
+ }
+ function setReferenceAduIdFromHover() {
+ // get the hover adu id
+ const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+ // get the value of the input field
+ const entityId = hoverAduIdDiv.value;
+ // get the textarea element that holds the reference adu id
+ let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
+ // set the value of the input field
+ referenceAduIdDiv.value = entityId;
+ // trigger an input event to update the state
+ var event = new Event('input');
+ referenceAduIdDiv.dispatchEvent(event);
+ }
+
+ const entities = document.querySelectorAll('.entity');
+ entities.forEach(entity => {
+ // make the cursor a pointer
+ entity.style.cursor = 'pointer';
+ const alreadyHasListener = entity.getAttribute('data-has-listener');
+ if (alreadyHasListener) {
+ return;
+ }
+ entity.addEventListener('mouseover', () => {
+ const entityId = entity.getAttribute('data-entity-id');
+ highlightRelationArguments(entityId);
+ setHoverAduId(entityId);
+ });
+ entity.addEventListener('mouseout', () => {
+ highlightRelationArguments(null);
+ });
+ entity.setAttribute('data-has-listener', 'true');
+ });
+ const entityContainer = document.querySelector('.entities');
+ if (entityContainer) {
+ entityContainer.addEventListener('click', () => {
+ setReferenceAduIdFromHover();
+ });
+ // make the cursor a pointer
+ // entityContainer.style.cursor = 'pointer';
+ }
+}
+"""
+
+
+def render_pretty_table(
+ text: str,
+ spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+ span_id2idx: Dict[str, int],
+ binary_relations: Sequence[BinaryRelation],
+ **render_kwargs,
+):
+ from prettytable import PrettyTable
+
+ t = PrettyTable()
+ t.field_names = ["head", "tail", "relation"]
+ t.align = "l"
+    for relation in binary_relations:
+ t.add_row([str(relation.head), str(relation.tail), relation.label])
+
+ html = t.get_html_string(format=True)
+ html = "
" + html + "
"
+
+ return html
+
+
+def render_displacy(
+ text: str,
+ spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+ span_id2idx: Dict[str, int],
+ binary_relations: Sequence[BinaryRelation],
+ inject_relations=True,
+ colors_hover=None,
+ entity_options={},
+ **render_kwargs,
+):
+
+ ents: List[Dict[str, Any]] = []
+ for entity_id, idx in span_id2idx.items():
+ labeled_span = spans[idx]
+ # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
+ # on hover and to inject the relation data.
+ if isinstance(labeled_span, LabeledSpan):
+ ents.append(
+ {
+ "start": labeled_span.start,
+ "end": labeled_span.end,
+ "label": labeled_span.label,
+ "params": {"entity_id": entity_id, "slice_idx": 0},
+ }
+ )
+ elif isinstance(labeled_span, LabeledMultiSpan):
+ for i, (start, end) in enumerate(labeled_span.slices):
+ ents.append(
+ {
+ "start": start,
+ "end": end,
+ "label": labeled_span.label,
+ "params": {"entity_id": entity_id, "slice_idx": i},
+ }
+ )
+ else:
+ raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
+
+ ents_sorted = sorted(ents, key=lambda x: (x["start"], x["end"]))
+ spacy_doc = {
+ "text": text,
+ # the ents MUST be sorted by start and end
+ "ents": ents_sorted,
+ "title": None,
+ }
+
+ # copy to avoid modifying the original options
+ entity_options = entity_options.copy()
+ # use the custom template with the entity ID
+ entity_options["template"] = TPL_ENT_WITH_ID
+ renderer = EntityRenderer(options=entity_options)
+ html = renderer.render([spacy_doc], page=True, minify=True).strip()
+
+ html = "" + html + "
"
+ if inject_relations:
+ html = inject_relation_data(
+ html,
+ spans=spans,
+ span_id2idx=span_id2idx,
+ binary_relations=binary_relations,
+ additional_colors=colors_hover,
+ )
+ return html
+
+
+def inject_relation_data(
+ html: str,
+ spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+ span_id2idx: Dict[str, int],
+ binary_relations: Sequence[BinaryRelation],
+ additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
+) -> str:
+ from bs4 import BeautifulSoup
+
+ # Parse the HTML using BeautifulSoup
+ soup = BeautifulSoup(html, "html.parser")
+
+ entity2tails = defaultdict(list)
+ entity2heads = defaultdict(list)
+ for relation in binary_relations:
+ entity2heads[relation.tail].append((relation.head, relation.label))
+ entity2tails[relation.head].append((relation.tail, relation.label))
+
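+    # reverse mapping from span annotation to its span id, used to reference relation arguments by id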
+ annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}
+ # Add unique IDs to each entity
+ entities = soup.find_all(class_="entity")
+ for entity in entities:
+ original_color = entity["style"].split("background:")[1].split(";")[0].strip()
+ entity["data-color-original"] = original_color
+ if additional_colors is not None:
+ for key, color in additional_colors.items():
+ entity[f"data-color-{key}"] = (
+ json.dumps(color) if isinstance(color, dict) else color
+ )
+
+ entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]
+
+ # sanity check.
+ if isinstance(entity_annotation, LabeledSpan):
+ annotation_text = entity_annotation.resolve()[1]
+ elif isinstance(entity_annotation, LabeledMultiSpan):
+ slice_idx = int(entity["data-slice-idx"])
+ annotation_text = entity_annotation.resolve()[1][slice_idx]
+ else:
+ raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
+ annotation_text_without_newline = annotation_text.replace("\n", "")
+ # Just check the start, because the text has the label attached to the end
+ if not entity.text.startswith(annotation_text_without_newline):
+ logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
+
+ entity["data-label"] = entity_annotation.label
+ entity["data-relation-tails"] = json.dumps(
+ [
+ {"entity-id": annotation2id[tail], "label": label}
+ for tail, label in entity2tails.get(entity_annotation, [])
+ if tail in annotation2id
+ ]
+ )
+ entity["data-relation-heads"] = json.dumps(
+ [
+ {"entity-id": annotation2id[head], "label": label}
+ for head, label in entity2heads.get(entity_annotation, [])
+ if head in annotation2id
+ ]
+ )
+
+ # Return the modified HTML as a string
+ return str(soup)
diff --git a/src/demo/rendering_utils_displacy.py b/src/demo/rendering_utils_displacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c69cb26749ec18df50c2157003e1d36b24f72b2
--- /dev/null
+++ b/src/demo/rendering_utils_displacy.py
@@ -0,0 +1,217 @@
+# This code is mainly taken from
+# https://github.com/explosion/spaCy/blob/master/spacy/displacy/templates.py, and from
+# https://github.com/explosion/spaCy/blob/master/spacy/displacy/render.py.
+
+# Setting explicit height and max-width: none on the SVG is required for
+# Jupyter to render it properly in a cell
+
+TPL_DEP_SVG = """
+<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
+"""
+
+
+TPL_DEP_WORDS = """
+<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
+    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
+    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
+</text>
+"""
+
+
+TPL_DEP_WORDS_LEMMA = """
+<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
+    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
+    <tspan class="displacy-lemma" dy="2em" fill="currentColor" x="{x}">{lemma}</tspan>
+    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
+</text>
+"""
+
+
+TPL_DEP_ARCS = """
+<g class="displacy-arrow">
+    <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
+    <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
+        <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" fill="currentColor" text-anchor="middle">{label}</textPath>
+    </text>
+    <path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
+</g>
+"""
+
+
+TPL_FIGURE = """
+<figure style="margin-bottom: 6rem">{content}</figure>
+"""
+
+TPL_TITLE = """
+<h2 style="margin: 0">{title}</h2>
+"""
+
+
+TPL_ENTS = """
+<div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
+"""
+
+
+TPL_ENT = """
+<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em">
+    {text}
+    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem">{label}</span>
+</mark>
+"""
+
+TPL_ENT_RTL = """
+<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em">
+    {text}
+    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; margin-right: 0.5rem">{label}</span>
+</mark>
+"""
+
+
+TPL_PAGE = """
+<!DOCTYPE html>
+<html lang="{lang}">
+    <head>
+        <title>displaCy</title>
+    </head>
+
+    <body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif; padding: 4rem 2rem; direction: {dir}">{content}</body>
+</html>
+"""
+
+
+DEFAULT_LANG = "en"
+DEFAULT_DIR = "ltr"
+
+
+def minify_html(html):
+ """Perform a template-specific, rudimentary HTML minification for displaCy.
+ Disclaimer: NOT a general-purpose solution, only removes indentation and
+ newlines.
+
+ html (unicode): Markup to minify.
+ RETURNS (unicode): "Minified" HTML.
+ """
+ return html.strip().replace(" ", "").replace("\n", "")
+
+
+def escape_html(text):
+ """Replace <, >, &, " with their HTML encoded representation. Intended to prevent HTML errors
+ in rendered displaCy markup.
+
+ text (unicode): The original text. RETURNS (unicode): Equivalent text to be safely used within
+ HTML.
+ """
+ text = text.replace("&", "&")
+ text = text.replace("<", "<")
+ text = text.replace(">", ">")
+ text = text.replace('"', """)
+ return text
+
+
+class EntityRenderer(object):
+ """Render named entities as HTML."""
+
+ style = "ent"
+
+ def __init__(self, options={}):
+ """Initialise dependency renderer.
+
+ options (dict): Visualiser-specific options (colors, ents)
+ """
+ colors = {
+ "ORG": "#7aecec",
+ "PRODUCT": "#bfeeb7",
+ "GPE": "#feca74",
+ "LOC": "#ff9561",
+ "PERSON": "#aa9cfc",
+ "NORP": "#c887fb",
+ "FACILITY": "#9cc9cc",
+ "EVENT": "#ffeb80",
+ "LAW": "#ff8197",
+ "LANGUAGE": "#ff8197",
+ "WORK_OF_ART": "#f0d0ff",
+ "DATE": "#bfe1d9",
+ "TIME": "#bfe1d9",
+ "MONEY": "#e4e7d2",
+ "QUANTITY": "#e4e7d2",
+ "ORDINAL": "#e4e7d2",
+ "CARDINAL": "#e4e7d2",
+ "PERCENT": "#e4e7d2",
+ }
+ # user_colors = registry.displacy_colors.get_all()
+ # for user_color in user_colors.values():
+ # colors.update(user_color)
+ colors.update(options.get("colors", {}))
+ self.default_color = "#ddd"
+ self.colors = colors
+ self.ents = options.get("ents", None)
+ self.direction = DEFAULT_DIR
+ self.lang = DEFAULT_LANG
+
+ template = options.get("template")
+ if template:
+ self.ent_template = template
+ else:
+ if self.direction == "rtl":
+ self.ent_template = TPL_ENT_RTL
+ else:
+ self.ent_template = TPL_ENT
+
+ def render(self, parsed, page=False, minify=False):
+ """Render complete markup.
+
+ parsed (list): Dependency parses to render. page (bool): Render parses wrapped as full HTML
+ page. minify (bool): Minify HTML markup. RETURNS (unicode): Rendered HTML markup.
+ """
+ rendered = []
+ for i, p in enumerate(parsed):
+ if i == 0:
+ settings = p.get("settings", {})
+ self.direction = settings.get("direction", DEFAULT_DIR)
+ self.lang = settings.get("lang", DEFAULT_LANG)
+ rendered.append(self.render_ents(p["text"], p["ents"], p.get("title")))
+ if page:
+ docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
+ markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
+ else:
+ markup = "".join(rendered)
+ if minify:
+ return minify_html(markup)
+ return markup
+
+ def render_ents(self, text, spans, title):
+ """Render entities in text.
+
+ text (unicode): Original text. spans (list): Individual entity spans and their start, end
+ and label. title (unicode or None): Document title set in Doc.user_data['title'].
+ """
+ markup = ""
+ offset = 0
+ for span in spans:
+ label = span["label"]
+ start = span["start"]
+ end = span["end"]
+ additional_params = span.get("params", {})
+ entity = escape_html(text[start:end])
+ fragments = text[offset:start].split("\n")
+ for i, fragment in enumerate(fragments):
+ markup += escape_html(fragment)
+ if len(fragments) > 1 and i != len(fragments) - 1:
+ markup += "
"
+ if self.ents is None or label.upper() in self.ents:
+ color = self.colors.get(label.upper(), self.default_color)
+ ent_settings = {"label": label, "text": entity, "bg": color}
+ ent_settings.update(additional_params)
+ markup += self.ent_template.format(**ent_settings)
+ else:
+ markup += entity
+ offset = end
+ fragments = text[offset:].split("\n")
+ for i, fragment in enumerate(fragments):
+ markup += escape_html(fragment)
+ if len(fragments) > 1 and i != len(fragments) - 1:
+ markup += "
"
+ markup = TPL_ENTS.format(content=markup, dir=self.direction)
+ if title:
+ markup = TPL_TITLE.format(title=title) + markup
+ return markup
diff --git a/src/demo/retrieve_and_dump_all_relevant.py b/src/demo/retrieve_and_dump_all_relevant.py
new file mode 100644
index 0000000000000000000000000000000000000000..7105e122a1b7e34066b3cc127ece5891d62ed4a0
--- /dev/null
+++ b/src/demo/retrieve_and_dump_all_relevant.py
@@ -0,0 +1,101 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+import argparse
+import logging
+
+from src.demo.retriever_utils import (
+ retrieve_all_relevant_spans,
+ retrieve_all_relevant_spans_for_all_documents,
+ retrieve_relevant_spans,
+)
+from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
+
+logger = logging.getLogger(__name__)
+
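+# Example invocation (paths are illustrative, not part of the repository):
+#   python src/demo/retrieve_and_dump_all_relevant.py \
+#     --data_path retriever_dump.zip --output_path relevant_spans.json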
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-c",
+ "--config_path",
+ type=str,
+ default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
+ )
+ parser.add_argument(
+ "--data_path",
+ type=str,
+ required=True,
+ help="Path to a zip or directory containing a retriever dump.",
+ )
+ parser.add_argument("-k", "--top_k", type=int, default=10)
+ parser.add_argument("-t", "--threshold", type=float, default=0.95)
+ parser.add_argument(
+ "-o",
+ "--output_path",
+ type=str,
+ required=True,
+ )
+ parser.add_argument(
+ "--query_doc_id",
+ type=str,
+ default=None,
+ help="If provided, retrieve all spans for only this query document.",
+ )
+ parser.add_argument(
+ "--query_span_id",
+ type=str,
+ default=None,
+ help="If provided, retrieve all spans for only this query span.",
+ )
+ args = parser.parse_args()
+
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s",
+ level=logging.INFO,
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ if not args.output_path.endswith(".json"):
+ raise ValueError("only support json output")
+
+ logger.info(f"instantiating retriever from {args.config_path}...")
+ retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
+ args.config_path
+ )
+ logger.info(f"loading data from {args.data_path}...")
+ retriever.load_from_disc(args.data_path)
+
+ search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
+ logger.info(f"use search_kwargs: {search_kwargs}")
+
+ if args.query_span_id is not None:
+ logger.warning(f"retrieving results for single span: {args.query_span_id}")
+ all_spans_for_all_documents = retrieve_relevant_spans(
+ retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
+ )
+ elif args.query_doc_id is not None:
+ logger.warning(f"retrieving results for single document: {args.query_doc_id}")
+ all_spans_for_all_documents = retrieve_all_relevant_spans(
+ retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
+ )
+ else:
+ all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
+ retriever=retriever, **search_kwargs
+ )
+
+ if all_spans_for_all_documents is None:
+ logger.warning("no relevant spans found in any document")
+ exit(0)
+
+ logger.info(f"dumping results to {args.output_path}...")
+ all_spans_for_all_documents.to_json(args.output_path)
+
+ logger.info("done")
diff --git a/src/demo/retriever_utils.py b/src/demo/retriever_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc47eea8944c2302491eb421d729ed9059d1a9f
--- /dev/null
+++ b/src/demo/retriever_utils.py
@@ -0,0 +1,313 @@
+import logging
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import gradio as gr
+import pandas as pd
+from pytorch_ie import Annotation
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from typing_extensions import Protocol
+
+from src.langchain_modules import DocumentAwareSpanRetriever
+from src.langchain_modules.span_retriever import (
+ DocumentAwareSpanRetrieverWithRelations,
+ _parse_config,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
+ document = retriever.get_document(doc_id=doc_id)
+ return retriever.docstore.as_dict(document)
+
+
+def load_retriever(
+ retriever_config_str: str,
+ config_format: str,
+ device: str = "cpu",
+ previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
+) -> DocumentAwareSpanRetrieverWithRelations:
+ try:
+ retriever_config = _parse_config(retriever_config_str, format=config_format)
+ # set device for the embeddings pipeline
+ retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
+ result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
+ # if a previous retriever is provided, load all documents and vectors from the previous retriever
+ if previous_retriever is not None:
+ # documents
+ all_doc_ids = list(previous_retriever.docstore.yield_keys())
+ gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
+ all_docs = previous_retriever.docstore.mget(all_doc_ids)
+ result.docstore.mset([(doc.id, doc) for doc in all_docs])
+ # spans (with vectors)
+ all_span_ids = list(previous_retriever.vectorstore.yield_keys())
+ all_spans = previous_retriever.vectorstore.mget(all_span_ids)
+ result.vectorstore.mset([(span.id, span) for span in all_spans])
+
+ gr.Info("Retriever loaded successfully.")
+ return result
+ except Exception as e:
+ raise gr.Error(f"Failed to load retriever: {e}")
+
+
+def retrieve_similar_spans(
+ retriever: DocumentAwareSpanRetriever,
+ query_span_id: str,
+ **kwargs,
+) -> pd.DataFrame:
+ if not query_span_id.strip():
+ raise gr.Error("No query span selected.")
+ try:
+ retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
+ records = []
+ for similar_span_doc in retrieval_result:
+ pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
+ span_ann = metadata["attached_span"]
+ records.append(
+ {
+ "doc_id": pie_doc.id,
+ "span_id": similar_span_doc.id,
+ "score": metadata["relevance_score"],
+ "label": span_ann.label,
+ "text": str(span_ann),
+ }
+ )
+ return (
+ pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
+ .sort_values(by="score", ascending=False)
+ .round(3)
+ )
+ except Exception as e:
+ raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
+
+
+def retrieve_relevant_spans(
+ retriever: DocumentAwareSpanRetriever,
+ query_span_id: str,
+ relation_label_mapping: Optional[dict[str, str]] = None,
+ **kwargs,
+) -> pd.DataFrame:
+ if not query_span_id.strip():
+ raise gr.Error("No query span selected.")
+ try:
+ relation_label_mapping = relation_label_mapping or {}
+ retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
+ records = []
+ for relevant_span_doc in retrieval_result:
+ pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
+ span_ann = metadata["attached_span"]
+ tail_span_ann = metadata["attached_tail_span"]
+ mapped_relation_label = relation_label_mapping.get(
+ metadata["relation_label"], metadata["relation_label"]
+ )
+ records.append(
+ {
+ "doc_id": pie_doc.id,
+ "type": mapped_relation_label,
+ "rel_score": metadata["relation_score"],
+ "text": str(tail_span_ann),
+ "span_id": relevant_span_doc.id,
+ "label": tail_span_ann.label,
+ "ref_score": metadata["relevance_score"],
+ "ref_label": span_ann.label,
+ "ref_text": str(span_ann),
+ "ref_span_id": metadata["head_id"],
+ }
+ )
+ return (
+ pd.DataFrame(
+ records,
+ columns=[
+ "type",
+ # omitted for now, we get no valid relation scores for the generative model
+ # "rel_score",
+ "ref_score",
+ "label",
+ "text",
+ "ref_label",
+ "ref_text",
+ "doc_id",
+ "span_id",
+ "ref_span_id",
+ ],
+ )
+ .sort_values(by=["ref_score"], ascending=False)
+ .round(3)
+ )
+ except Exception as e:
+ raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")
+
+
+class RetrieverCallable(Protocol):
+ def __call__(
+ self,
+ retriever: DocumentAwareSpanRetriever,
+ query_span_id: str,
+ **kwargs,
+ ) -> Optional[pd.DataFrame]:
+ pass
+
+
+def _retrieve_for_all_spans(
+ retriever: DocumentAwareSpanRetriever,
+ query_doc_id: str,
+ retrieve_func: RetrieverCallable,
+ query_span_id_column: str = "query_span_id",
+ **kwargs,
+) -> Optional[pd.DataFrame]:
+ if not query_doc_id.strip():
+ raise gr.Error("No query document selected.")
+ try:
+ span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
+ gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
+ span_results = {
+ query_span_id: retrieve_func(
+ retriever=retriever,
+ query_span_id=query_span_id,
+ **kwargs,
+ )
+ for query_span_id in span_id2idx.keys()
+ }
+ span_results_not_empty = {
+ query_span_id: df
+ for query_span_id, df in span_results.items()
+ if df is not None and not df.empty
+ }
+
+ # add column with query_span_id
+ for query_span_id, query_span_result in span_results_not_empty.items():
+ query_span_result[query_span_id_column] = query_span_id
+
+ if len(span_results_not_empty) == 0:
+ gr.Info(f"No results found for any ADU in document {query_doc_id}.")
+ return None
+ else:
+ result = pd.concat(span_results_not_empty.values(), ignore_index=True)
+ gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
+ return result
+ except Exception as e:
+ raise gr.Error(
+ f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
+ )
+
+
+def retrieve_all_similar_spans(
+ retriever: DocumentAwareSpanRetriever,
+ query_doc_id: str,
+ **kwargs,
+) -> Optional[pd.DataFrame]:
+ return _retrieve_for_all_spans(
+ retriever=retriever,
+ query_doc_id=query_doc_id,
+ retrieve_func=retrieve_similar_spans,
+ **kwargs,
+ )
+
+
+def retrieve_all_relevant_spans(
+ retriever: DocumentAwareSpanRetriever,
+ query_doc_id: str,
+ **kwargs,
+) -> Optional[pd.DataFrame]:
+ return _retrieve_for_all_spans(
+ retriever=retriever,
+ query_doc_id=query_doc_id,
+ retrieve_func=retrieve_relevant_spans,
+ **kwargs,
+ )
+
+
+class RetrieverForAllSpansCallable(Protocol):
+ def __call__(
+ self,
+ retriever: DocumentAwareSpanRetriever,
+ query_doc_id: str,
+ **kwargs,
+ ) -> Optional[pd.DataFrame]:
+ pass
+
+
+def _retrieve_for_all_documents(
+ retriever: DocumentAwareSpanRetriever,
+ retrieve_func: RetrieverForAllSpansCallable,
+ query_doc_id_column: str = "query_doc_id",
+ **kwargs,
+) -> Optional[pd.DataFrame]:
+ try:
+ all_doc_ids = list(retriever.docstore.yield_keys())
+ gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
+ doc_results = {
+ doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
+ for doc_id in all_doc_ids
+ }
+ doc_results_not_empty = {
+ doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
+ }
+ # add column with query_doc_id
+ for doc_id, doc_result in doc_results_not_empty.items():
+ doc_result[query_doc_id_column] = doc_id
+
+ if len(doc_results_not_empty) == 0:
+ gr.Info("No results found for any document.")
+ return None
+ else:
+ result = pd.concat(doc_results_not_empty, ignore_index=True)
+ gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
+ return result
+ except Exception as e:
+ raise gr.Error(f"Failed to retrieve results for all documents: {e}")
+
+
+def retrieve_all_similar_spans_for_all_documents(
+ retriever: DocumentAwareSpanRetriever,
+ **kwargs,
+) -> Optional[pd.DataFrame]:
+ return _retrieve_for_all_documents(
+ retriever=retriever,
+ retrieve_func=retrieve_all_similar_spans,
+ **kwargs,
+ )
+
+
+def retrieve_all_relevant_spans_for_all_documents(
+ retriever: DocumentAwareSpanRetriever,
+ **kwargs,
+) -> Optional[pd.DataFrame]:
+ return _retrieve_for_all_documents(
+ retriever=retriever,
+ retrieve_func=retrieve_all_relevant_spans,
+ **kwargs,
+ )
+
+
+def get_text_spans_and_relations_from_document(
+ retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
+) -> Tuple[
+ str,
+ Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+ Dict[str, int],
+ Sequence[BinaryRelation],
+]:
+ document = retriever.get_document(doc_id=document_id)
+ pie_document = retriever.docstore.unwrap(document)
+ use_predicted_annotations = retriever.use_predicted_annotations(document)
+ spans = retriever.get_base_layer(
+ pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
+ )
+ relations = retriever.get_relation_layer(
+ pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
+ )
+ span_id2idx = retriever.get_span_id2idx_from_doc(document)
+ return pie_document.text, spans, span_id2idx, relations
+
+
+def get_span_annotation(
+ retriever: DocumentAwareSpanRetriever,
+ span_id: str,
+) -> Annotation:
+ if span_id.strip() == "":
+ raise gr.Error("No span selected.")
+ try:
+ return retriever.get_span_by_id(span_id=span_id)
+ except Exception as e:
+ raise gr.Error(f"Failed to retrieve span annotation: {e}")
diff --git a/src/document/__init__.py b/src/document/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/document/processing.py b/src/document/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..651418aca6ff09f621e02281ced1fd5dc1a0f703
--- /dev/null
+++ b/src/document/processing.py
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, TypeVar, Union
+
+import networkx as nx
+from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
+from pytorch_ie import AnnotationLayer
+from pytorch_ie.core import Document
+
+logger = logging.getLogger(__name__)
+
+
+D = TypeVar("D", bound=Document)
+
+
+def _remove_overlapping_entities(
+ entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
+) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
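+    """Greedily keep entities that do not overlap any previously kept entity (processed in order
+    of their start offset) and drop all relations whose head or tail entity was removed.
+    """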
+ sorted_entities = sorted(entities, key=lambda span: span["start"])
+ entities_wo_overlap = []
+ skipped_entities = []
+ last_end = 0
+ for entity_dict in sorted_entities:
+ if entity_dict["start"] < last_end:
+ skipped_entities.append(entity_dict)
+ else:
+ entities_wo_overlap.append(entity_dict)
+ last_end = entity_dict["end"]
+ if len(skipped_entities) > 0:
+ logger.warning(f"skipped overlapping entities: {skipped_entities}")
+ valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap)
+ valid_relations = [
+ relation_dict
+ for relation_dict in relations
+ if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
+ ]
+ return entities_wo_overlap, valid_relations
+
+
+def remove_overlapping_entities(
+ doc: D,
+ entity_layer_name: str = "entities",
+ relation_layer_name: str = "relations",
+) -> D:
+ # TODO: use document.add_all_annotations_from_other()
+ document_dict = doc.asdict()
+ entities_wo_overlap, valid_relations = _remove_overlapping_entities(
+ entities=document_dict[entity_layer_name]["annotations"],
+ relations=document_dict[relation_layer_name]["annotations"],
+ )
+
+ document_dict[entity_layer_name] = {
+ "annotations": entities_wo_overlap,
+ "predictions": [],
+ }
+ document_dict[relation_layer_name] = {
+ "annotations": valid_relations,
+ "predictions": [],
+ }
+ new_doc = type(doc).fromdict(document_dict)
+
+ return new_doc
+
+
+def _merge_spans_via_relation(
+ spans: Sequence[LabeledSpan],
+ relations: Sequence[BinaryRelation],
+ link_relation_label: str,
+ create_multi_spans: bool = True,
+) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]:
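+    """Merge spans that are connected via relations with `link_relation_label` by computing the
+    connected components over these link relations. Each component becomes a single
+    LabeledMultiSpan (or, if `create_multi_spans=False`, a LabeledSpan covering the full range),
+    and all remaining relations are remapped onto the merged spans.
+    """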
+ # convert list of relations to a graph to easily calculate connected components to merge
+ g = nx.Graph()
+ link_relations = []
+ other_relations = []
+ for rel in relations:
+ if rel.label == link_relation_label:
+ link_relations.append(rel)
+            # never merge spans that do not have the same label
+ if (
+ not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan))
+ or rel.head.label == rel.tail.label
+ ):
+ g.add_edge(rel.head, rel.tail)
+ else:
+ logger.debug(
+ f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}"
+ )
+ else:
+ other_relations.append(rel)
+
+ span_mapping = {}
+ connected_components: Set[LabeledSpan]
+ for connected_components in nx.connected_components(g):
+ # all spans in a connected component have the same label
+ label = list(span.label for span in connected_components)[0]
+ connected_components_sorted = sorted(connected_components, key=lambda span: span.start)
+ if create_multi_spans:
+ new_span = LabeledMultiSpan(
+ slices=tuple((span.start, span.end) for span in connected_components_sorted),
+ label=label,
+ )
+ else:
+ new_span = LabeledSpan(
+ start=min(span.start for span in connected_components_sorted),
+ end=max(span.end for span in connected_components_sorted),
+ label=label,
+ )
+ for span in connected_components_sorted:
+ span_mapping[span] = new_span
+ for span in spans:
+ if span not in span_mapping:
+ if create_multi_spans:
+ span_mapping[span] = LabeledMultiSpan(
+ slices=((span.start, span.end),), label=span.label, score=span.score
+ )
+ else:
+ span_mapping[span] = LabeledSpan(
+ start=span.start, end=span.end, label=span.label, score=span.score
+ )
+
+ new_spans = set(span_mapping.values())
+ new_relations = set(
+ BinaryRelation(
+ head=span_mapping[rel.head],
+ tail=span_mapping[rel.tail],
+ label=rel.label,
+ score=rel.score,
+ )
+ for rel in other_relations
+ )
+
+ return new_spans, new_relations
+
+
+def merge_spans_via_relation(
+ document: D,
+ relation_layer: str,
+ link_relation_label: str,
+ use_predicted_spans: bool = False,
+ process_predictions: bool = True,
+ create_multi_spans: bool = False,
+) -> D:
+
+ rel_layer = document[relation_layer]
+ span_layer = rel_layer.target_layer
+ new_gold_spans, new_gold_relations = _merge_spans_via_relation(
+ spans=span_layer,
+ relations=rel_layer,
+ link_relation_label=link_relation_label,
+ create_multi_spans=create_multi_spans,
+ )
+ if process_predictions:
+ new_pred_spans, new_pred_relations = _merge_spans_via_relation(
+ spans=span_layer.predictions if use_predicted_spans else span_layer,
+ relations=rel_layer.predictions,
+ link_relation_label=link_relation_label,
+ create_multi_spans=create_multi_spans,
+ )
+ else:
+ assert not use_predicted_spans
+ new_pred_spans = set(span_layer.predictions.clear())
+ new_pred_relations = set(rel_layer.predictions.clear())
+
+ relation_layer_name = relation_layer
+ span_layer_name = document[relation_layer].target_name
+ if create_multi_spans:
+ doc_dict = document.asdict()
+ for f in document.annotation_fields():
+ doc_dict.pop(f.name)
+
+ result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
+ result.labeled_multi_spans.extend(new_gold_spans)
+ result.labeled_multi_spans.predictions.extend(new_pred_spans)
+ result.binary_relations.extend(new_gold_relations)
+ result.binary_relations.predictions.extend(new_pred_relations)
+ else:
+ result = document.copy(with_annotations=False)
+ result[span_layer_name].extend(new_gold_spans)
+ result[span_layer_name].predictions.extend(new_pred_spans)
+ result[relation_layer_name].extend(new_gold_relations)
+ result[relation_layer_name].predictions.extend(new_pred_relations)
+
+ return result
+
+
+def remove_partitions_by_labels(
+ document: D, partition_layer: str, label_blacklist: List[str]
+) -> D:
+ document = document.copy()
+ layer: AnnotationLayer = document[partition_layer]
+ new_partitions = []
+ for partition in layer.clear():
+ if partition.label not in label_blacklist:
+ new_partitions.append(partition)
+ layer.extend(new_partitions)
+ return document
+
+
+D_text = TypeVar("D_text", bound=Document)
+
+
+def replace_substrings_in_text(
+ document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
+) -> D_text:
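+    """Replace substrings in the document text. By default, each replacement must have the same
+    length as the original substring so that the character offsets of existing span annotations
+    remain valid.
+    """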
+ new_text = document.text
+ for old_str, new_str in replacements.items():
+ if enforce_same_length and len(old_str) != len(new_str):
+ raise ValueError(
+ f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
+ )
+ new_text = new_text.replace(old_str, new_str)
+ result_dict = document.asdict()
+ result_dict["text"] = new_text
+ result = type(document).fromdict(result_dict)
+ result.text = new_text
+ return result
+
+
+def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
+ replacements = {substring: " " * len(substring) for substring in substrings}
+ return replace_substrings_in_text(document, replacements=replacements)
diff --git a/src/evaluate.py b/src/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e65b43288a9a05c8ca962b0d26e9584a4589f178
--- /dev/null
+++ b/src/evaluate.py
@@ -0,0 +1,137 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing project as a package
+# - makes paths and scripts always work no matter where is your current work dir
+# - automatically loads environment variables from ".env" file if exists
+#
+# how it works:
+# - the line above recursively searches for the indicator file (here ".project-root") in the
+#   present and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+# any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+# to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+from typing import Any, Tuple
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pie_datasets import DatasetDict
+from pie_modules.models import * # noqa: F403
+from pie_modules.taskmodules import * # noqa: F403
+from pytorch_ie.core import PyTorchIEModel, TaskModule
+from pytorch_ie.models import * # noqa: F403
+from pytorch_ie.taskmodules import * # noqa: F403
+from pytorch_lightning import Trainer
+
+from src import utils
+from src.datamodules import PieDataModule
+from src.models import * # noqa: F403
+from src.taskmodules import * # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+@utils.task_wrapper
+def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Evaluates given checkpoint on a datamodule testset.
+
+    This method is wrapped in an optional @task_wrapper decorator, which applies extra utilities
+    before and after the call.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """
+
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ pl.seed_everything(cfg.seed, workers=True)
+
+ # Init pytorch-ie dataset
+ log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+ dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+ # Init pytorch-ie taskmodule
+ log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
+ taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
+
+    # auto-convert the dataset if the taskmodule specifies a document type
+ dataset = taskmodule.convert_dataset(dataset)
+
+ # Init pytorch-ie datamodule
+ log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
+ datamodule: PieDataModule = hydra.utils.instantiate(
+ cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
+ )
+
+ # Init pytorch-ie model
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: PyTorchIEModel = hydra.utils.instantiate(cfg.model, _convert_="partial")
+
+ # Init lightning loggers
+ logger = utils.instantiate_dict_entries(cfg, "logger")
+
+ # Init lightning trainer
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, _convert_="partial")
+
+ object_dict = {
+ "cfg": cfg,
+ "taskmodule": taskmodule,
+ "dataset": dataset,
+ "model": model,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
+
+ log.info("Starting testing!")
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+ # for predictions use trainer.predict(...)
+ # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
+
+ metric_dict = trainer.callback_metrics
+
+ return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="evaluate.yaml")
+def main(cfg: DictConfig) -> Any:
+ metric_dict, _ = evaluate(cfg)
+
+ return metric_dict
+
+
+if __name__ == "__main__":
+ utils.replace_sys_args_with_values_from_files()
+ utils.prepare_omegaconf()
+ main()
diff --git a/src/evaluate_documents.py b/src/evaluate_documents.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2bc5e4c6d39085a13f8274f2a2ddab86f55cc9f
--- /dev/null
+++ b/src/evaluate_documents.py
@@ -0,0 +1,116 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing project as a package
+# - makes paths and scripts always work no matter where is your current work dir
+# - automatically loads environment variables from ".env" file if exists
+#
+# how it works:
+# - the line above recursively searches for the indicator file (here ".project-root") in the
+#   present and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+# any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+# to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+from typing import Any, Tuple
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pie_datasets import DatasetDict
+from pytorch_ie.core import DocumentMetric
+from pytorch_ie.metrics import * # noqa: F403
+
+from src import utils
+from src.metrics import * # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+@utils.task_wrapper
+def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Evaluates serialized PIE documents.
+
+    This method is wrapped in an optional @task_wrapper decorator, which applies extra utilities
+    before and after the call.
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """
+
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ pl.seed_everything(cfg.seed, workers=True)
+
+ # Init pytorch-ie dataset
+ log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+ dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+ # Init pytorch-ie taskmodule
+ log.info(f"Instantiating metric <{cfg.metric._target_}>")
+ metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")
+
+ # auto-convert the dataset if the metric specifies a document type
+ dataset = metric.convert_dataset(dataset)
+
+ # Init lightning loggers
+ loggers = utils.instantiate_dict_entries(cfg, "logger")
+
+ object_dict = {
+ "cfg": cfg,
+ "dataset": dataset,
+ "metric": metric,
+ "logger": loggers,
+ }
+
+ if loggers:
+ log.info("Logging hyperparameters!")
+ # send hparams to all loggers
+ for logger in loggers:
+ logger.log_hyperparams(cfg)
+
+ splits = cfg.get("splits", None)
+ if splits is None:
+ documents = dataset
+ else:
+ documents = type(dataset)({k: v for k, v in dataset.items() if k in splits})
+
+ metric_dict = metric(documents)
+
+ return metric_dict, object_dict
+
+
+@hydra.main(
+ version_base="1.2", config_path=str(root / "configs"), config_name="evaluate_documents.yaml"
+)
+def main(cfg: DictConfig) -> Any:
+ metric_dict, _ = evaluate_documents(cfg)
+ return metric_dict
+
+
+if __name__ == "__main__":
+ utils.replace_sys_args_with_values_from_files()
+ utils.prepare_omegaconf()
+ main()
diff --git a/src/hydra_callbacks/__init__.py b/src/hydra_callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..293137642e82b75fa7dda2748f0d5835cef8f7b4
--- /dev/null
+++ b/src/hydra_callbacks/__init__.py
@@ -0,0 +1 @@
+from .save_job_return_value import SaveJobReturnValueCallback
diff --git a/src/hydra_callbacks/save_job_return_value.py b/src/hydra_callbacks/save_job_return_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..0225c979efdc71114e0d8d030f7e6d1b88895c4e
--- /dev/null
+++ b/src/hydra_callbacks/save_job_return_value.py
@@ -0,0 +1,261 @@
+import json
+import logging
+import os
+import pickle
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import torch
+from hydra.core.utils import JobReturn
+from hydra.experimental.callback import Callback
+from omegaconf import DictConfig
+
+
+def to_py_obj(obj):
+ """Convert a PyTorch tensor, Numpy array or python list to a python list.
+
+ Modified version of transformers.utils.generic.to_py_obj.
+ """
+ if isinstance(obj, dict):
+ return {k: to_py_obj(v) for k, v in obj.items()}
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(to_py_obj(o) for o in obj)
+ elif isinstance(obj, torch.Tensor):
+ return obj.detach().cpu().tolist()
+ elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
+ return obj.tolist()
+ else:
+ return obj
+
+
+def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts):
+ """Convert a list of dicts to a dict of lists recursively.
+
+ Example:
+ # works with nested dicts
+ >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
+ {'b': {'c': [2, 4]}, 'a': [1, 3]}
+ # works with incomplete dicts
+ >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}])
+ {'b': [2, None], 'a': [1, 3]}
+
+ Args:
+ list_of_dicts (List[dict]): A list of dicts.
+
+ Returns:
+ dict: A dict of lists.
+ """
+ if isinstance(list_of_dicts, list):
+ if len(list_of_dicts) == 0:
+ return {}
+ elif isinstance(list_of_dicts[0], dict):
+ keys = set()
+ for d in list_of_dicts:
+ if not isinstance(d, dict):
+ raise ValueError("Not all elements of the list are dicts.")
+ keys.update(d.keys())
+ return {
+ k: list_of_dicts_to_dict_of_lists_recursive(
+ [d.get(k, None) for d in list_of_dicts]
+ )
+ for k in keys
+ }
+ else:
+ return list_of_dicts
+ else:
+ return list_of_dicts
+
+
+def _flatten_dict_gen(d, parent_key: Tuple[str, ...] = ()) -> Generator:
+ for k, v in d.items():
+ new_key = parent_key + (k,)
+ if isinstance(v, dict):
+ yield from dict(_flatten_dict_gen(v, new_key)).items()
+ else:
+ yield new_key, v
+
+
+def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]:
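+    """Flatten a nested dict into a flat dict keyed by key tuples (the inverse of `unflatten_dict`).
+
+    Illustrative example (not part of the original code):
+    >>> flatten_dict({"a": {"b": 1}, "c": 2})
+    {('a', 'b'): 1, ('c',): 2}
+    """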
+ return dict(_flatten_dict_gen(d))
+
+
+def unflatten_dict(d: Dict[Tuple[str, ...], Any]) -> Union[Dict[str, Any], Any]:
+ """Unflattens a dictionary with nested keys.
+
+ Example:
+ >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
+ >>> unflatten_dict(d)
+ {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
+ """
+ result: Dict[str, Any] = {}
+ for k, v in d.items():
+ if len(k) == 0:
+ if len(result) > 1:
+ raise ValueError("Cannot unflatten dictionary with multiple root keys.")
+ return v
+ current = result
+ for key in k[:-1]:
+ current = current.setdefault(key, {})
+ current[k[-1]] = v
+ return result
+
+
+def overrides_to_identifiers(overrides_per_result: List[List[str]], sep: str = "-") -> List[str]:
+ """Converts a list of lists of overrides to a list of identifiers. But takes only the overrides
+ into account, that are not identical for all results.
+
+ Example:
+ >>> overrides_per_result = [
+ ... ["a=1", "b=2", "c=3"],
+ ... ["a=1", "b=2", "c=4"],
+ ... ["a=1", "b=3", "c=3"],
+    ... ]
+ >>> overrides_to_identifiers(overrides_per_result)
+ ['b=2-c=3', 'b=2-c=4', 'b=3-c=3']
+
+ Args:
+ overrides_per_result (List[List[str]]): A list of lists of overrides.
+ sep (str, optional): The separator to use between the overrides. Defaults to "-".
+
+ Returns:
+ List[str]: A list of identifiers.
+ """
+ # get the overrides that are not identical for all results
+ overrides_per_result_transposed = np.array(overrides_per_result).T.tolist()
+ indices = [
+ i for i, entries in enumerate(overrides_per_result_transposed) if len(set(entries)) > 1
+ ]
+ # convert the overrides to identifiers
+ identifiers = [
+ sep.join([overrides[idx] for idx in indices]) for overrides in overrides_per_result
+ ]
+ return identifiers
+
+
+class SaveJobReturnValueCallback(Callback):
+ """Save the job return-value in ${output_dir}/{job_return_value_filename}.
+
+ This also works for multi-runs (e.g. sweeps for hyperparameter search). In this case, the result will be saved
+ additionally in a common file in the multi-run log directory. If integrate_multirun_result=True, the
+ job return-values are also aggregated (e.g. mean, min, max) and saved in another file.
+
+ params:
+ -------
+ filenames: str or List[str] (default: "job_return_value.json")
+ The filename(s) of the file(s) to save the job return-value to. If it ends with ".json",
+ the return-value will be saved as a json file. If it ends with ".pkl", the return-value will be
+ saved as a pickle file, if it ends with ".md", the return-value will be saved as a markdown file.
+    integrate_multirun_result: bool (default: False)
+ If True, the job return-values of all jobs from a multi-run will be rearranged into a dict of lists (maybe
+ nested), where the keys are the keys of the job return-values and the values are lists of the corresponding
+ values of all jobs. This is useful if you want to access specific values of all jobs in a multi-run all at once.
+ Also, aggregated values (e.g. mean, min, max) are created for all numeric values and saved in another file.
+ """
+
+ def __init__(
+ self,
+ filenames: Union[str, List[str]] = "job_return_value.json",
+ integrate_multirun_result: bool = False,
+ ) -> None:
+ self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+ self.filenames = [filenames] if isinstance(filenames, str) else filenames
+ self.integrate_multirun_result = integrate_multirun_result
+ self.job_returns: List[JobReturn] = []
+
+ def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None:
+ self.job_returns.append(job_return)
+ output_dir = Path(config.hydra.runtime.output_dir) # / Path(config.hydra.output_subdir)
+ for filename in self.filenames:
+ self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir)
+
+ def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
+ if self.integrate_multirun_result:
+ # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
+ obj = list_of_dicts_to_dict_of_lists_recursive(
+ [jr.return_value for jr in self.job_returns]
+ )
+ # also create an aggregated result
+ # convert to python object to allow selecting numeric columns
+ obj_py = to_py_obj(obj)
+ obj_flat = flatten_dict(obj_py)
+ # create dataframe from flattened dict
+ df_flat = pd.DataFrame(obj_flat)
+ # select only the numeric values
+ df_numbers_only = df_flat.select_dtypes(["number"])
+ cols_removed = set(df_flat.columns) - set(df_numbers_only.columns)
+ if len(cols_removed) > 0:
+ self.log.warning(
+ f"Removed the following columns from the aggregated result because they are not numeric: "
+ f"{cols_removed}"
+ )
+ if len(df_numbers_only.columns) == 0:
+ obj_aggregated = None
+ else:
+ # aggregate the numeric values
+ df_described = df_numbers_only.describe()
+ # add the aggregation keys (e.g. mean, min, ...) as most inner keys and convert back to dict
+ obj_flat_aggregated = df_described.T.stack().to_dict()
+ # unflatten because _save() works better with nested dicts
+ obj_aggregated = unflatten_dict(obj_flat_aggregated)
+ else:
+ # create a dict of the job return-values of all jobs from a multi-run
+ # (_save() works better with nested dicts)
+ ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
+ obj = {identifier: jr.return_value for identifier, jr in zip(ids, self.job_returns)}
+ obj_aggregated = None
+ output_dir = Path(config.hydra.sweep.dir)
+ for filename in self.filenames:
+ self._save(
+ obj=obj,
+ filename=filename,
+ output_dir=output_dir,
+ multi_run_result=self.integrate_multirun_result,
+ )
+ # if available, also save the aggregated result
+ if obj_aggregated is not None:
+ file_base_name, ext = os.path.splitext(filename)
+ filename_aggregated = f"{file_base_name}.aggregated{ext}"
+ self._save(obj=obj_aggregated, filename=filename_aggregated, output_dir=output_dir)
+
+ def _save(
+ self, obj: Any, filename: str, output_dir: Path, multi_run_result: bool = False
+ ) -> None:
+        assert output_dir is not None
+        self.log.info(f"Saving job_return in {output_dir / filename}")
+        output_dir.mkdir(parents=True, exist_ok=True)
+ if filename.endswith(".pkl"):
+ with open(str(output_dir / filename), "wb") as file:
+ pickle.dump(obj, file, protocol=4)
+ elif filename.endswith(".json"):
+ # Convert PyTorch tensors and numpy arrays to native python types
+ obj_py = to_py_obj(obj)
+ with open(str(output_dir / filename), "w") as file:
+ json.dump(obj_py, file, indent=2)
+ elif filename.endswith(".md"):
+ # Convert PyTorch tensors and numpy arrays to native python types
+ obj_py = to_py_obj(obj)
+ obj_py_flat = flatten_dict(obj_py)
+
+ if multi_run_result:
+ # In the case of multi-run, we expect to have multiple values for each key.
+ # We therefore just convert the dict to a pandas DataFrame.
+ result = pd.DataFrame(obj_py_flat)
+ else:
+ # In the case of a single job, we expect to have only one value for each key.
+ # We therefore convert the dict to a pandas Series and ...
+ series = pd.Series(obj_py_flat)
+ if len(series.index.levels) > 1:
+ # ... if the Series has multiple index levels, we create a DataFrame by unstacking the last level.
+ result = series.unstack(-1)
+ else:
+ # ... otherwise we just unpack the one-entry index values and save the resulting Series.
+ series.index = series.index.get_level_values(0)
+ result = series
+
+ with open(str(output_dir / filename), "w") as file:
+ file.write(result.to_markdown())
+
+ else:
+ raise ValueError("Unknown file extension")
diff --git a/src/langchain_modules/span_retriever.py b/src/langchain_modules/span_retriever.py
index 8c4192843dec908efa63a15525a6553c7cb20ca8..09f2674909a8ab53d26ae320fc314bd4cc065bf2 100644
--- a/src/langchain_modules/span_retriever.py
+++ b/src/langchain_modules/span_retriever.py
@@ -4,7 +4,7 @@ import uuid
from collections import defaultdict
from copy import copy
from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@@ -674,14 +674,14 @@ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
def add_pie_documents(
self,
- documents: List[TextBasedDocument],
+ documents: Iterable[TextBasedDocument],
use_predicted_annotations: bool,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Add pie documents to the retriever.
Args:
- documents: List of pie documents to add
+ documents: Iterable of pie documents to add
use_predicted_annotations: Whether to use the predicted annotations or the gold annotations
metadata: Optional metadata to add to each document
"""
diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7cd2e1657a85b5c421f21b37520a9646ec8da2b
--- /dev/null
+++ b/src/metrics/__init__.py
@@ -0,0 +1,2 @@
+from .coref_sklearn import CorefMetricsSKLearn
+from .coref_torchmetrics import CorefMetricsTorchmetrics
diff --git a/src/metrics/annotation_processor.py b/src/metrics/annotation_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aa5e0bce5cc54b8014c54386cb619f11d690c11
--- /dev/null
+++ b/src/metrics/annotation_processor.py
@@ -0,0 +1,23 @@
+from typing import Tuple
+
+from pytorch_ie.annotations import BinaryRelation, Span, LabeledMultiSpan
+from pytorch_ie.core import Annotation
+
+
+def decode_span_without_label(ann: Annotation) -> Tuple[Tuple[int, int], ...]:
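+    """Reduce a span annotation to its character offsets, dropping the label.
+
+    Illustrative example (not part of the original code):
+    >>> decode_span_without_label(Span(start=0, end=3))
+    ((0, 3),)
+    """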
+ if isinstance(ann, Span):
+        return ((ann.start, ann.end),)
+ elif isinstance(ann, LabeledMultiSpan):
+ return ann.slices
+ else:
+ raise ValueError("Annotation must be a Span or LabeledMultiSpan")
+
+
+def to_binary_relation_without_argument_labels(
+    ann: Annotation,
+) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], str]:
+ if not isinstance(ann, BinaryRelation):
+ raise ValueError("Annotation must be a BinaryRelation")
+ return (
+ decode_span_without_label(ann.head),
+ decode_span_without_label(ann.tail),
+ ann.label,
+ )
diff --git a/src/metrics/coref_sklearn.py b/src/metrics/coref_sklearn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db953a2df2c97c7fe2af8ef9947673089908956
--- /dev/null
+++ b/src/metrics/coref_sklearn.py
@@ -0,0 +1,162 @@
+import logging
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from pandas import MultiIndex
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pytorch_ie import DocumentMetric
+from pytorch_ie.core.metric import T
+from pytorch_ie.utils.hydra import resolve_target
+from torchmetrics import Metric, MetricCollection
+
+from src.hydra_callbacks.save_job_return_value import to_py_obj
+
+logger = logging.getLogger(__name__)
+
+
+def get_num_total(targets: List[int], preds: List[float]):
+ return len(targets)
+
+
+def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
+ return len([v for v in targets if v == positive_idx])
+
+
+def discretize(
+ values: List[float], threshold: Union[float, List[float], dict]
+) -> Union[List[float], Dict[Any, List[float]]]:
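+    """Discretize continuous scores into 0/1 decisions.
+
+    A float threshold yields a single binary list; a list of thresholds (or a dict with
+    "start"/"end"/"step") yields a mapping from each threshold to its binary list.
+    Illustrative example (not part of the original code):
+    >>> discretize([0.2, 0.8], threshold=0.5)
+    [0, 1]
+    """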
+ if isinstance(threshold, float):
+ result = (np.array(values) >= threshold).astype(int).tolist()
+ return result
+ if isinstance(threshold, list):
+ return {t: discretize(values=values, threshold=t) for t in threshold} # type: ignore
+ if isinstance(threshold, dict):
+ thresholds = (
+ np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
+ )
+ return discretize(values, threshold=thresholds)
+ raise TypeError(f"threshold has unknown type: {threshold}")
+
+
+class CorefMetricsSKLearn(DocumentMetric):
+ DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+ def __init__(
+ self,
+ metrics: Dict[str, str],
+ thresholds: Optional[Dict[str, float]] = None,
+ default_target_idx: int = 0,
+ default_prediction_score: float = 0.0,
+ show_as_markdown: bool = False,
+ markdown_precision: int = 4,
+ plot: bool = False,
+ ):
+ self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
+ self.thresholds = thresholds or {}
+ thresholds_not_in_metrics = {
+ name: t for name, t in self.thresholds.items() if name not in self.metrics
+ }
+ if len(thresholds_not_in_metrics) > 0:
+ logger.warning(
+ f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
+ )
+ self.default_target_idx = default_target_idx
+ self.default_prediction_score = default_prediction_score
+ self.show_as_markdown = show_as_markdown
+ self.markdown_precision = markdown_precision
+ self.plot = plot
+
+ super().__init__()
+
+ def reset(self) -> None:
+ self._preds: List[float] = []
+ self._targets: List[int] = []
+
+ def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
+ target_args2idx = {
+ (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
+ }
+ prediction_args2score = {
+ (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
+ }
+ all_args = set(target_args2idx) | set(prediction_args2score)
+ all_targets: List[int] = []
+ all_predictions: List[float] = []
+ for args in all_args:
+ target_idx = target_args2idx.get(args, self.default_target_idx)
+ prediction_score = prediction_args2score.get(args, self.default_prediction_score)
+ all_targets.append(target_idx)
+ all_predictions.append(prediction_score)
+ # prediction_scores = torch.tensor(all_predictions)
+ # target_indices = torch.tensor(all_targets)
+ # self.metrics.update(preds=prediction_scores, target=target_indices)
+ self._preds.extend(all_predictions)
+ self._targets.extend(all_targets)
+
+ def do_plot(self):
+ raise NotImplementedError()
+
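+        # NOTE: the code below is unreachable; it mirrors the plotting logic of the
+        # torchmetrics-based variant, but `self.metrics` is a plain dict here, so
+        # `MetricCollection.plot()` is not available.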
+ from matplotlib import pyplot as plt
+
+ # Get the number of metrics
+ num_metrics = len(self.metrics)
+
+ # Calculate rows and columns for subplots (aim for a square-like layout)
+ ncols = math.ceil(math.sqrt(num_metrics))
+ nrows = math.ceil(num_metrics / ncols)
+
+ # Create the subplots
+ fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
+
+ # Flatten the ax_list if necessary (in case of multiple rows/columns)
+ ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary
+
+ # Ensure that we pass exactly the number of axes required by metrics
+ ax_list = ax_list[:num_metrics]
+
+ # Plot the metrics using the list of axes
+ self.metrics.plot(ax=ax_list, together=False)
+
+ # Adjust layout to avoid overlapping plots
+ plt.tight_layout()
+ plt.show()
+
+ def _compute(self) -> T:
+
+ if self.plot:
+ self.do_plot()
+
+ result = {}
+ for name, metric in self.metrics.items():
+
+ if name in self.thresholds:
+ preds = discretize(values=self._preds, threshold=self.thresholds[name])
+ else:
+ preds = self._preds
+ if isinstance(preds, dict):
+ metric_results = {
+ t: metric(self._targets, t_preds) for t, t_preds in preds.items()
+ }
+ # just get the max
+ max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
+ result[f"{name}-{max_t}"] = max_v
+ else:
+ result[name] = metric(self._targets, preds)
+
+ result = to_py_obj(result)
+ if self.show_as_markdown:
+ import pandas as pd
+
+ series = pd.Series(result)
+ if isinstance(series.index, MultiIndex):
+ if len(series.index.levels) > 1:
+ # in fact, this is not a series anymore
+ series = series.unstack(-1)
+ else:
+ series.index = series.index.get_level_values(0)
+ logger.info(
+ f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
+ )
+ return result
diff --git a/src/metrics/coref_torchmetrics.py b/src/metrics/coref_torchmetrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b19830ac5040bb8a17271c8ff158a3a150c7e8ab
--- /dev/null
+++ b/src/metrics/coref_torchmetrics.py
@@ -0,0 +1,107 @@
+import logging
+import math
+from typing import Dict
+
+import torch
+from pandas import MultiIndex
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pytorch_ie import DocumentMetric
+from pytorch_ie.core.metric import T
+from torchmetrics import Metric, MetricCollection
+
+from src.hydra_callbacks.save_job_return_value import to_py_obj
+
+logger = logging.getLogger(__name__)
+
+
+class CorefMetricsTorchmetrics(DocumentMetric):
+ DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+ def __init__(
+ self,
+ metrics: Dict[str, Metric],
+ default_target_idx: int = 0,
+ default_prediction_score: float = 0.0,
+ show_as_markdown: bool = False,
+ markdown_precision: int = 4,
+ plot: bool = False,
+ ):
+ self.metrics = MetricCollection(metrics)
+ self.default_target_idx = default_target_idx
+ self.default_prediction_score = default_prediction_score
+ self.show_as_markdown = show_as_markdown
+ self.markdown_precision = markdown_precision
+ self.plot = plot
+
+ super().__init__()
+
+ def reset(self) -> None:
+ self.metrics.reset()
+
+ def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
+ target_args2idx = {
+ (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
+ }
+ prediction_args2score = {
+ (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
+ }
+ all_args = set(target_args2idx) | set(prediction_args2score)
+ all_targets = []
+ all_predictions = []
+ for args in all_args:
+ target_idx = target_args2idx.get(args, self.default_target_idx)
+ prediction_score = prediction_args2score.get(args, self.default_prediction_score)
+ all_targets.append(target_idx)
+ all_predictions.append(prediction_score)
+ prediction_scores = torch.tensor(all_predictions)
+ target_indices = torch.tensor(all_targets)
+ self.metrics.update(preds=prediction_scores, target=target_indices)
+
+ def do_plot(self):
+ from matplotlib import pyplot as plt
+
+ # Get the number of metrics
+ num_metrics = len(self.metrics)
+
+ # Calculate rows and columns for subplots (aim for a square-like layout)
+ ncols = math.ceil(math.sqrt(num_metrics))
+ nrows = math.ceil(num_metrics / ncols)
+
+ # Create the subplots
+ fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
+
+ # Flatten the ax_list if necessary (in case of multiple rows/columns)
+ ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary
+
+ # Ensure that we pass exactly the number of axes required by metrics
+ ax_list = ax_list[:num_metrics]
+
+ # Plot the metrics using the list of axes
+ self.metrics.plot(ax=ax_list, together=False)
+
+ # Adjust layout to avoid overlapping plots
+ plt.tight_layout()
+ plt.show()
+
+ def _compute(self) -> T:
+
+ if self.plot:
+ self.do_plot()
+
+ result = self.metrics.compute()
+
+ result = to_py_obj(result)
+ if self.show_as_markdown:
+ import pandas as pd
+
+ series = pd.Series(result)
+ if isinstance(series.index, MultiIndex):
+ if len(series.index.levels) > 1:
+ # in fact, this is not a series anymore
+ series = series.unstack(-1)
+ else:
+ series.index = series.index.get_level_values(0)
+ logger.info(
+ f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
+ )
+ return result
diff --git a/src/models/__init__.py b/src/models/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f99145a08dfdbb616d7be60678f6f1753ccf1728 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -0,0 +1,5 @@
+from .sequence_classification_with_pooler import (
+ SequencePairSimilarityModelWithMaxCosineSim,
+ SequencePairSimilarityModelWithPooler2,
+ SequencePairSimilarityModelWithPoolerAndAdapter,
+)
diff --git a/src/models/components/__init__.py b/src/models/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/components/pooler.py b/src/models/components/pooler.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc5d9a18ebb6acacea4532b5d0d757ad6b7c1017
--- /dev/null
+++ b/src/models/components/pooler.py
@@ -0,0 +1,79 @@
+import torch
+from torch import Tensor, cat, nn
+
+
+class SpanMeanPooler(nn.Module):
+ """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a
+ learned embedding is used. The indices are expected to have the shape [batch_size,
+ num_indices].
+
+ The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim].
+    Note that this is a slightly modified version of pie_modules.models.components.pooler.SpanMaxPooler,
+    i.e. the aggregation method was changed from torch.amax to torch.mean.
+
+ Args:
+ input_dim: The input dimension of the hidden state.
+ num_indices: The number of indices to pool.
+
+ Returns:
+ The pooled hidden states with shape [batch_size, num_indices * input_dim].
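+
+    Illustrative usage (shapes only, variable names are not from the original code):
+        pooler = SpanMeanPooler(input_dim=768, num_indices=2)
+        # hidden_state: [batch_size, seq_len, 768], start_indices/end_indices: [batch_size, 2]
+        pooled = pooler(hidden_state, start_indices=start_indices, end_indices=end_indices)
+        # pooled: [batch_size, 2 * 768]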
+ """
+
+ def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_indices = num_indices
+ self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
+ nn.init.normal_(self.missing_embeddings)
+
+ def forward(
+ self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
+ ) -> Tensor:
+ batch_size, seq_len, hidden_size = hidden_state.shape
+ if start_indices.shape[1] != self.num_indices:
+ raise ValueError(
+ f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
+ )
+
+ if end_indices.shape[1] != self.num_indices:
+ raise ValueError(
+ f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
+ )
+
+ # check that start_indices are before end_indices
+ mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
+ mask_start_before_end = start_indices < end_indices
+ mask_valid = mask_start_before_end | ~mask_both_positive
+ if not torch.all(mask_valid):
+ raise ValueError(
+ f"values in start_indices have to be smaller than respective values in "
+ f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
+ )
+
+ # times num_indices due to concat
+ result = torch.zeros(
+ batch_size, hidden_size * self.num_indices, device=hidden_state.device
+ )
+ for batch_idx in range(batch_size):
+ current_start_indices = start_indices[batch_idx]
+ current_end_indices = end_indices[batch_idx]
+ current_embeddings = [
+ (
+ torch.mean(
+ hidden_state[
+ batch_idx, current_start_indices[i] : current_end_indices[i], :
+ ],
+ dim=0,
+ )
+ if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
+ else self.missing_embeddings[i]
+ )
+ for i in range(self.num_indices)
+ ]
+ result[batch_idx] = cat(current_embeddings, 0)
+
+ return result
+
+ @property
+ def output_dim(self) -> int:
+ return self.input_dim * self.num_indices
diff --git a/src/models/sequence_classification_with_pooler.py b/src/models/sequence_classification_with_pooler.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ae4e99f3beb93dfe710057d527149448234a8c
--- /dev/null
+++ b/src/models/sequence_classification_with_pooler.py
@@ -0,0 +1,166 @@
+import abc
+import logging
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from adapters import AutoAdapterModel
+from pie_modules.models import SequencePairSimilarityModelWithPooler
+from pie_modules.models.components.pooler import MENTION_POOLING
+from pie_modules.models.sequence_classification_with_pooler import (
+ InputType,
+ OutputType,
+ SequenceClassificationModelWithPooler,
+ SequenceClassificationModelWithPoolerBase,
+ TargetType,
+ separate_arguments_by_prefix,
+)
+from pytorch_ie import PyTorchIEModel
+from torch import FloatTensor, Tensor
+from transformers import AutoConfig, PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+from src.models.components.pooler import SpanMeanPooler
+
+logger = logging.getLogger(__name__)
+
+
+class SequenceClassificationModelWithPoolerBase2(
+ SequenceClassificationModelWithPoolerBase, abc.ABC
+):
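+    """Variant of the base sequence-classification model that supports mean aggregation for
+    mention pooling: if the pooler is configured with type MENTION_POOLING and aggregate="mean",
+    a `SpanMeanPooler` is used instead of the default max pooling.
+    """
+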
+ def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
+ aggregate = self.pooler_config.get("aggregate", "max")
+ if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
+ if aggregate == "mean":
+ pooler_config = dict(self.pooler_config)
+ pooler_config.pop("type")
+ pooler_config.pop("aggregate")
+ pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
+ return pooler, pooler.output_dim
+ else:
+ raise ValueError(f"Unknown aggregation method: {aggregate}")
+ else:
+ return super().setup_pooler(input_dim)
+
+
+class SequenceClassificationModelWithPoolerAndAdapterBase(
+ SequenceClassificationModelWithPoolerBase2, abc.ABC
+):
+ def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
+ self.adapter_name_or_path = adapter_name_or_path
+ super().__init__(**kwargs)
+
+ def setup_base_model(self) -> PreTrainedModel:
+ if self.adapter_name_or_path is None:
+ return super().setup_base_model()
+ else:
+ config = AutoConfig.from_pretrained(self.model_name_or_path)
+ if self.is_from_pretrained:
+ model = AutoAdapterModel.from_config(config=config)
+ else:
+ model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
+ # load the adapter in any case (it looks like it is not saved in the state or loaded
+ # from a serialized state)
+ logger.info(f"load adapter: {self.adapter_name_or_path}")
+ model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
+ return model
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithPooler2(
+ SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
+):
+ pass
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithPoolerAndAdapter(
+ SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
+):
+ pass
+
+
+@PyTorchIEModel.register()
+class SequenceClassificationModelWithPoolerAndAdapter(
+ SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
+):
+ pass
+
+
+def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
+ # Normalize the embeddings
+ embeddings_normalized = F.normalize(embeddings, p=2, dim=1) # Shape: (n, k)
+ embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1) # Shape: (m, k)
+
+ # Compute the cosine similarity matrix
+ cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T) # Shape: (n, m)
+
+ # Get the overall maximum cosine similarity value
+ max_cosine_sim = torch.max(cosine_sim) # This will return a scalar
+ return max_cosine_sim
+
+
+def get_span_embeddings(
+ embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
+) -> List[FloatTensor]:
+ result = []
+ for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
+ span_embeds = embeds[starts[0] : ends[0]]
+ result.append(span_embeds)
+ return result
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
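+    """Scores a span pair by the maximum cosine similarity between any token embedding of the
+    first span and any token embedding of the second span, instead of comparing pooled span
+    representations.
+    """
+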
+ def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
+ output = self.model(**model_inputs)
+ hidden_state = output.last_hidden_state
+ # pooled_output = self.pooler(hidden_state, **pooler_inputs)
+ # pooled_output = self.dropout(pooled_output)
+ span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
+ return span_embeds
+
+ def forward(
+ self,
+ inputs: InputType,
+ targets: Optional[TargetType] = None,
+ return_hidden_states: bool = False,
+ ) -> OutputType:
+ sanitized_inputs = separate_arguments_by_prefix(
+ # Note that the order of the prefixes is important because one is a prefix of the other,
+ # so we need to start with the longer!
+ arguments=inputs,
+ prefixes=["pooler_pair_", "pooler_"],
+ )
+
+ span_embeddings = self.get_pooled_output(
+ model_inputs=sanitized_inputs["remaining"]["encoding"],
+ pooler_inputs=sanitized_inputs["pooler_"],
+ )
+ span_embeddings_pair = self.get_pooled_output(
+ model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
+ pooler_inputs=sanitized_inputs["pooler_pair_"],
+ )
+
+ logits_list = [
+ get_max_cosine_sim(span_embeds, span_embeds_pair)
+ for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
+ ]
+ logits = torch.stack(logits_list)
+
+ result = {"logits": logits}
+ if targets is not None:
+ labels = targets["scores"]
+ loss = self.loss_fct(logits, labels)
+ result["loss"] = loss
+ if return_hidden_states:
+ raise NotImplementedError("return_hidden_states is not yet implemented")
+
+ return SequenceClassifierOutput(**result)
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
+ SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
+):
+ pass
diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py
index 8c9b0b43fc643aa8e1ed8e963cd653082366b6fa..4accfd9be9cfebfd8e08ff86f4e19f101bca9f27 100644
--- a/src/models/utils/__init__.py
+++ b/src/models/utils/__init__.py
@@ -1 +1,5 @@
-from .loading import load_model_from_pie_model, load_model_with_adapter, load_tokenizer_from_pie_taskmodule
\ No newline at end of file
+from .loading import (
+ load_model_from_pie_model,
+ load_model_with_adapter,
+ load_tokenizer_from_pie_taskmodule,
+)
diff --git a/src/models/utils/loading.py b/src/models/utils/loading.py
index 585152823313bd43a06aeeeb928374b0c31c313c..268d92d90ed029e4962c80486954a92b5b1e322e 100644
--- a/src/models/utils/loading.py
+++ b/src/models/utils/loading.py
@@ -23,10 +23,10 @@ def load_tokenizer_from_pie_taskmodule(taskmodule_kwargs: Dict[str, Any]) -> Pre
def load_model_with_adapter(
- model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any]
-) -> "ModelAdaptersMixin":
- from adapters import AutoAdapterModel, ModelAdaptersMixin
+ model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any]
+) -> PreTrainedModel:
+ from adapters import AutoAdapterModel
model = AutoAdapterModel.from_pretrained(**model_kwargs)
model.load_adapter(set_active=True, **adapter_kwargs)
- return model
\ No newline at end of file
+ return model
diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..019c9b0ead0b6c39c94441748e81c6f1774e294a
--- /dev/null
+++ b/src/pipeline/__init__.py
@@ -0,0 +1,2 @@
+from .ner_re_pipeline import NerRePipeline
+from .span_retrieval_based_re_pipeline import SpanRetrievalBasedRelationExtractionPipeline
diff --git a/src/pipeline/ner_re_pipeline.py b/src/pipeline/ner_re_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d40627e8bc6b5bf813f252112c226fc15817af5
--- /dev/null
+++ b/src/pipeline/ner_re_pipeline.py
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+import logging
+from functools import partial
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, TypeVar, Union
+
+from pie_modules.utils import resolve_type
+from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
+from pytorch_ie.core import Document
+
+logger = logging.getLogger(__name__)
+
+
+D = TypeVar("D", bound=Document)
+
+
+def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
+ for layer_name in layer_names:
+ if predictions:
+ doc[layer_name].predictions.clear()
+ else:
+ doc[layer_name].clear()
+
+
+def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
+ for layer_name in layer_names:
+ annotations = list(doc[layer_name].predictions)
+ # remove any previous annotations
+ doc[layer_name].clear()
+ # each annotation can be attached to just one annotation container, so we need to clear the predictions
+ doc[layer_name].predictions.clear()
+ doc[layer_name].extend(annotations)
+
+
+def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
+ for layer_name in layer_names:
+ annotations = list(doc[layer_name])
+ # each annotation can be attached to just one annotation container, so we need to clear the layer
+ doc[layer_name].clear()
+ # remove any previous annotations
+ doc[layer_name].predictions.clear()
+ doc[layer_name].predictions.extend(annotations)
+
+
+def add_annotations_from_other_documents(
+ docs: Iterable[D],
+ other_docs: Sequence[Document],
+ layer_names: List[str],
+ from_predictions: bool = False,
+ to_predictions: bool = False,
+ clear_before: bool = True,
+) -> None:
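+    """Copy annotations from `other_docs` into `docs` (matched by position), optionally taking
+    them from the prediction layer of the source documents and/or adding them to the prediction
+    layer of the target documents. The source documents are copied first so that they are not
+    modified.
+    """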
+ for i, doc in enumerate(docs):
+ other_doc = other_docs[i]
+ # copy to not modify the input
+ other_doc = type(other_doc).fromdict(other_doc.asdict())
+
+ for layer_name in layer_names:
+ if clear_before:
+ doc[layer_name].clear()
+ other_layer = other_doc[layer_name]
+ if from_predictions:
+ other_layer = other_layer.predictions
+ other_annotations = list(other_layer)
+ other_layer.clear()
+ if to_predictions:
+ doc[layer_name].predictions.extend(other_annotations)
+ else:
+ doc[layer_name].extend(other_annotations)
+
+
+def process_pipeline_steps(
+ documents: Sequence[Document],
+ processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
+) -> Sequence[Document]:
+
+ # call the processors in the order they are provided
+ for step_name, processor in processors.items():
+ logger.info(f"process {step_name} ...")
+ processed_documents = processor(documents)
+ if processed_documents is not None:
+ documents = processed_documents
+
+ return documents
+
+
+def process_documents(
+ documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
+) -> List[Document]:
+ result = []
+ for doc in documents:
+ processed_doc = processor(doc, **kwargs)
+ if processed_doc is not None:
+ result.append(processed_doc)
+ else:
+ result.append(doc)
+ return result
+
+
+class DummyTaskmodule(WithDocumentTypeMixin):
+ def __init__(self, document_type: Optional[Union[Type[Document], str]]):
+ if isinstance(document_type, str):
+ self._document_type = resolve_type(document_type, expected_super_type=Document)
+ else:
+ self._document_type = document_type
+
+ @property
+ def document_type(self) -> Optional[Type[Document]]:
+ return self._document_type
+
+
+class NerRePipeline:
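+    """Chains a NER model and a relation extraction (RE) model over PIE documents: existing
+    annotations are cleared, entities are predicted and moved into the gold entity layer so that
+    the RE model can build on them, relations are predicted, the entities are moved back to the
+    prediction layer, and finally the original gold annotations are re-added for evaluation.
+    """
+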
+ def __init__(
+ self,
+ ner_model_path: str,
+ re_model_path: str,
+ entity_layer: str,
+ relation_layer: str,
+ device: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ show_progress_bar: Optional[bool] = None,
+ document_type: Optional[Union[Type[Document], str]] = None,
+ **processor_kwargs,
+ ):
+ self.taskmodule = DummyTaskmodule(document_type)
+ self.ner_model_path = ner_model_path
+ self.re_model_path = re_model_path
+ self.processor_kwargs = processor_kwargs or {}
+ self.entity_layer = entity_layer
+ self.relation_layer = relation_layer
+ # set some values for the inference processors, if provided
+ for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
+ if inference_pipeline not in self.processor_kwargs:
+ self.processor_kwargs[inference_pipeline] = {}
+ if "device" not in self.processor_kwargs[inference_pipeline] and device is not None:
+ self.processor_kwargs[inference_pipeline]["device"] = device
+ if (
+ "batch_size" not in self.processor_kwargs[inference_pipeline]
+ and batch_size is not None
+ ):
+ self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size
+ if (
+ "show_progress_bar" not in self.processor_kwargs[inference_pipeline]
+ and show_progress_bar is not None
+ ):
+ self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
+
+ def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
+
+ input_docs: Sequence[Document]
+ # we need to keep the original documents to add the gold data back
+ original_docs: Sequence[Document]
+ if inplace:
+ input_docs = documents
+ original_docs = [doc.copy() for doc in documents]
+ else:
+ input_docs = [doc.copy() for doc in documents]
+ original_docs = documents
+
+ docs_with_predictions = process_pipeline_steps(
+ documents=input_docs,
+ processors={
+ "clear_annotations": partial(
+ process_documents,
+ processor=clear_annotation_layers,
+ layer_names=[self.entity_layer, self.relation_layer],
+ **self.processor_kwargs.get("clear_annotations", {}),
+ ),
+ "ner_pipeline": AutoPipeline.from_pretrained(
+ self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
+ ),
+ "use_predicted_entities": partial(
+ process_documents,
+ processor=move_annotations_from_predictions,
+ layer_names=[self.entity_layer],
+ **self.processor_kwargs.get("use_predicted_entities", {}),
+ ),
+ # "create_candidate_relations": partial(
+ # process_documents,
+ # processor=CandidateRelationAdder(
+ # **self.processor_kwargs.get("create_candidate_relations", {})
+ # ),
+ # ),
+ "re_pipeline": AutoPipeline.from_pretrained(
+ self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
+ ),
+ # otherwise we can not move the entities back to predictions
+ "clear_candidate_relations": partial(
+ process_documents,
+ processor=clear_annotation_layers,
+ layer_names=[self.relation_layer],
+ **self.processor_kwargs.get("clear_candidate_relations", {}),
+ ),
+ "move_entities_to_predictions": partial(
+ process_documents,
+ processor=move_annotations_to_predictions,
+ layer_names=[self.entity_layer],
+ **self.processor_kwargs.get("move_entities_to_predictions", {}),
+ ),
+ "re_add_gold_data": partial(
+ add_annotations_from_other_documents,
+ other_docs=original_docs,
+ layer_names=[self.entity_layer, self.relation_layer],
+ **self.processor_kwargs.get("re_add_gold_data", {}),
+ ),
+ },
+ )
+ return docs_with_predictions
diff --git a/src/pipeline/span_retrieval_based_re_pipeline.py b/src/pipeline/span_retrieval_based_re_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f86fdd8bd69cae190b57992ae683b12accfb72
--- /dev/null
+++ b/src/pipeline/span_retrieval_based_re_pipeline.py
@@ -0,0 +1,130 @@
+import logging
+from typing import Optional, Sequence, Type
+
+from langchain_core.documents import Document as LCDocument
+from pie_datasets import Dataset, IterableDataset
+from pytorch_ie import Document, WithDocumentTypeMixin
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+from pytorch_ie.documents import TextBasedDocument
+
+from src.langchain_modules import DocumentAwareSpanRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class DummyTaskmodule(WithDocumentTypeMixin):
+ def __init__(self, document_type: Type[Document]):
+ self._document_type = document_type
+
+ @property
+ def document_type(self) -> Optional[Type[Document]]:
+ return self._document_type
+
+
+class SpanRetrievalBasedRelationExtractionPipeline:
+ """Pipeline for adding binary relations between spans based on span retrieval within the same document.
+
+    This pipeline uses each existing span as a query to retrieve similar spans and adds binary
+    relations between the query spans and the retrieved spans.
+
+ Args:
+ retriever: The span retriever to use for retrieving spans.
+ relation_label: The label to use for the binary relations.
+        relation_layer_name: The name of the annotation layer to add the binary relations to.
+        use_predicted_annotations: Whether to use predicted span annotations instead of gold
+            annotations as queries.
+        load_store_path: If provided, the retriever store(s) will be loaded from this path before processing.
+ save_store_path: If provided, the retriever store(s) will be saved to this path after processing.
+ fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the first 2 documents.
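+
+    Example:
+        A minimal usage sketch (illustrative only; the retriever construction and the relation
+        label are placeholders, not prescribed by this class):
+
+            retriever = ...  # a DocumentAwareSpanRetriever with retrieve_from_same_document=True
+            pipeline = SpanRetrievalBasedRelationExtractionPipeline(
+                retriever=retriever,
+                relation_label="semantically_same",
+            )
+            docs_with_predictions = pipeline(documents)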
+ """
+
+ def __init__(
+ self,
+ retriever: DocumentAwareSpanRetriever,
+ relation_label: str,
+ relation_layer_name: str = "binary_relations",
+ use_predicted_annotations: bool = False,
+ load_store_path: Optional[str] = None,
+ save_store_path: Optional[str] = None,
+ fast_dev_run: bool = False,
+ ):
+ self.retriever = retriever
+ if not self.retriever.retrieve_from_same_document:
+ raise NotImplementedError("Retriever must retrieve from the same document")
+ self.relation_label = relation_label
+ self.relation_layer_name = relation_layer_name
+ self.use_predicted_annotations = use_predicted_annotations
+ self.load_store_path = load_store_path
+ self.save_store_path = save_store_path
+ if self.load_store_path is not None:
+ self.retriever.load_from_directory(path=self.load_store_path)
+
+ self.fast_dev_run = fast_dev_run
+
+ # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type
+ # from the dataset
+ @property
+ def taskmodule(self) -> DummyTaskmodule:
+ return DummyTaskmodule(self.retriever.pie_document_type)
+
+ def _construct_similarity_relations(
+ self,
+ query_results: list[LCDocument],
+ query_span: LabeledSpan,
+ ) -> list[BinaryRelation]:
+ return [
+ BinaryRelation(
+ head=query_span,
+ tail=lc_doc.metadata["attached_span"],
+ label=self.relation_label,
+ score=float(lc_doc.metadata["relevance_score"]),
+ )
+ for lc_doc in query_results
+ ]
+
+ def _process_single_document(
+ self,
+ document: Document,
+ ) -> TextBasedDocument:
+ if not isinstance(document, TextBasedDocument):
+ raise ValueError("Document must be a TextBasedDocument")
+
+ self.retriever.add_pie_documents(
+ [document], use_predicted_annotations=self.use_predicted_annotations
+ )
+
+ all_new_rels = []
+ spans = self.retriever.get_base_layer(
+ document, use_predicted_annotations=self.use_predicted_annotations
+ )
+ span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)
+ for span_id, span_idx in span_id2idx.items():
+ query_span = spans[span_idx]
+ query_result = self.retriever.invoke(input=span_id)
+ query_rels = self._construct_similarity_relations(query_result, query_span=query_span)
+ all_new_rels.extend(query_rels)
+
+ if self.relation_layer_name not in document:
+ raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")
+ document[self.relation_layer_name].predictions.extend(all_new_rels)
+
+ if self.retriever.retrieve_from_same_document and self.save_store_path is None:
+ self.retriever.delete_documents([document.id])
+
+ return document
+
+ def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
+ if inplace:
+ raise NotImplementedError("Inplace processing is not supported yet")
+
+ if self.fast_dev_run:
+ logger.warning("Fast dev run enabled, only processing the first 2 documents")
+ documents = documents[:2]
+
+ if not isinstance(documents, (Dataset, IterableDataset)):
+ documents = Dataset.from_documents(documents)
+
+ mapped_documents = documents.map(self._process_single_document)
+
+ if self.save_store_path is not None:
+ self.retriever.save_to_directory(path=self.save_store_path)
+
+ return mapped_documents
diff --git a/src/predict.py b/src/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4c0f8b8d77566078e9fdbcb83c0fa0f2dc0e1bb
--- /dev/null
+++ b/src/predict.py
@@ -0,0 +1,183 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing the project as a package
+# - makes paths and scripts always work no matter what your current work dir is
+# - automatically loads environment variables from ".env" file if it exists
+#
+# how it works:
+# - the line above recursively searches for the ".project-root" indicator file in the present
+#   and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+# any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+# to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+import os
+import timeit
+from collections.abc import Iterable, Sequence
+from typing import Any, Dict, Optional, Tuple, Union
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig, OmegaConf
+from pie_datasets import Dataset, DatasetDict
+from pie_modules.models import * # noqa: F403
+from pie_modules.taskmodules import * # noqa: F403
+from pytorch_ie import Document, Pipeline
+from pytorch_ie.models import * # noqa: F403
+from pytorch_ie.taskmodules import * # noqa: F403
+
+from src import utils
+from src.models import * # noqa: F403
+from src.serializer.interface import DocumentSerializer
+from src.taskmodules import * # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+def document_batch_iter(
+ dataset: Union[Sequence[Document], Iterable[Document]], batch_size: int
+) -> Iterable[Sequence[Document]]:
+ if isinstance(dataset, Sequence):
+ for i in range(0, len(dataset), batch_size):
+ yield dataset[i : i + batch_size]
+ elif isinstance(dataset, Iterable):
+ docs = []
+ for doc in dataset:
+ docs.append(doc)
+ if len(docs) == batch_size:
+ yield docs
+ docs = []
+ if docs:
+ yield docs
+ else:
+ raise ValueError(f"Unsupported dataset type: {type(dataset)}")
+
+
+@utils.task_wrapper
+def predict(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Contains minimal example of the prediction pipeline. Uses a pretrained model to annotate
+ documents from a dataset and serializes them.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+
+    Returns:
+        Tuple[dict, dict]: Dict with the prediction results (e.g. the serializer output and the
+        prediction time) and dict with all instantiated objects.
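+
+    Example:
+        A hypothetical command line invocation via Hydra (override values are placeholders):
+
+            python src/predict.py dataset_split=test ckpt_path=path/to/checkpoint.ckpt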
+ """
+
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ pl.seed_everything(cfg.seed, workers=True)
+
+ # Init pytorch-ie dataset
+ log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+ dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+ # Init pytorch-ie pipeline
+ # The pipeline, and therefore the inference step, is optional to allow for easy testing
+ # of the dataset creation and processing.
+ pipeline: Optional[Pipeline] = None
+ if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
+ log.info(f"Instantiating pipeline <{cfg.pipeline._target_}> from {cfg.model_name_or_path}")
+ pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
+
+    # By default, the model is loaded with .from_pretrained(), which already loads the weights.
+ # However, ckpt_path can be used to load different weights from any checkpoint.
+ if cfg.ckpt_path is not None:
+ pipeline.model = pipeline.model.load_from_checkpoint(checkpoint_path=cfg.ckpt_path).to(
+ pipeline.device
+ )
+
+    # auto-convert the dataset if the taskmodule specifies a document type
+ dataset = pipeline.taskmodule.convert_dataset(dataset)
+
+ # Init the serializer
+ serializer: Optional[DocumentSerializer] = None
+ if cfg.get("serializer") and cfg.serializer.get("_target_"):
+ log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
+ serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")
+
+ # select the dataset split for prediction
+ dataset_predict = dataset[cfg.dataset_split]
+
+ object_dict = {
+ "cfg": cfg,
+ "dataset": dataset,
+ "pipeline": pipeline,
+ "serializer": serializer,
+ }
+ result: Dict[str, Any] = {}
+ if pipeline is not None:
+ log.info("Starting inference!")
+ prediction_time = 0.0
+ else:
+ log.warning("No prediction pipeline is defined, skip inference!")
+ prediction_time = None
+ document_batch_size = cfg.get("document_batch_size", None)
+ for docs_batch in (
+ document_batch_iter(dataset_predict, document_batch_size)
+ if document_batch_size
+ else [dataset_predict]
+ ):
+ if pipeline is not None:
+ t_start = timeit.default_timer()
+ docs_batch = pipeline(docs_batch, inplace=False)
+ prediction_time += timeit.default_timer() - t_start # type: ignore
+
+ # serialize the documents
+ if serializer is not None:
+            # the serializer should not return the serialized documents; instead, it should write
+            # them to disk and return some metadata such as the path to the serialized documents
+ serializer_result = serializer(docs_batch)
+ if "serializer" in result and result["serializer"] != serializer_result:
+ log.warning(
+ f"serializer result changed from {result['serializer']} to {serializer_result}"
+ " during prediction. Only the last result is returned."
+ )
+ result["serializer"] = serializer_result
+
+ if prediction_time is not None:
+ result["prediction_time"] = prediction_time
+
+ # serialize config with resolved paths
+ if cfg.get("config_out_path"):
+ config_out_dir = os.path.dirname(cfg.config_out_path)
+ os.makedirs(config_out_dir, exist_ok=True)
+ OmegaConf.save(config=cfg, f=cfg.config_out_path)
+ result["config"] = cfg.config_out_path
+
+ return result, object_dict
+
+
+@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="predict.yaml")
+def main(cfg: DictConfig) -> dict:
+ result_dict, _ = predict(cfg)
+ return result_dict
+
+
+if __name__ == "__main__":
+ utils.replace_sys_args_with_values_from_files()
+ utils.prepare_omegaconf()
+ main()
diff --git a/src/serializer/__init__.py b/src/serializer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df4f1769cfa62649f6905d6fee22a14aab8a2e6
--- /dev/null
+++ b/src/serializer/__init__.py
@@ -0,0 +1 @@
+from .json import JsonSerializer, JsonSerializer2
diff --git a/src/serializer/interface.py b/src/serializer/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba07401bad2d1983ca1b07e6b0fa4fd027ded9d
--- /dev/null
+++ b/src/serializer/interface.py
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import Any, Sequence
+
+from pytorch_ie.core import Document
+
+
+class DocumentSerializer(ABC):
+ """This defines the interface for a document serializer.
+
+    The serializer should not return the serialized documents; instead, it should write them to
+    disk and return some metadata such as the path to the serialized documents.
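+
+    A minimal conforming implementation could look like this (illustrative sketch; the path is a
+    placeholder):
+
+        class NoOpSerializer(DocumentSerializer):
+            def __call__(self, documents: Sequence[Document]) -> Any:
+                # pretend to write the documents and return where they would have gone
+                return {"path": "/tmp/documents.jsonl", "num_documents": len(documents)}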
+ """
+
+ @abstractmethod
+ def __call__(self, documents: Sequence[Document]) -> Any:
+ pass
diff --git a/src/serializer/json.py b/src/serializer/json.py
new file mode 100644
index 0000000000000000000000000000000000000000..c293ae3691cf402e93dc7c3ba2f0fdf0d25ad0b7
--- /dev/null
+++ b/src/serializer/json.py
@@ -0,0 +1,179 @@
+import json
+import os
+from typing import Dict, List, Optional, Sequence, Type, TypeVar
+
+from pie_datasets import Dataset, DatasetDict, IterableDataset
+from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
+from pytorch_ie.core import Document
+from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
+
+from src.serializer.interface import DocumentSerializer
+from src.utils import get_pylogger
+
+log = get_pylogger(__name__)
+
+D = TypeVar("D", bound=Document)
+
+
+def as_json_lines(file_name: str) -> bool:
+ if file_name.lower().endswith(".jsonl"):
+ return True
+ elif file_name.lower().endswith(".json"):
+ return False
+ else:
+ raise Exception(f"unknown file extension: {file_name}")
+
+
+class JsonSerializer(DocumentSerializer):
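+    """Serializes documents to a JSON (".json") or JSON Lines (".jsonl") file and reads them back.
+
+    Example (illustrative sketch; the target path is a placeholder):
+
+        serializer = JsonSerializer(path="predictions/my_run")
+        metadata = serializer(documents)  # writes "documents.jsonl" plus a metadata file
+        loaded = JsonSerializer.read(**metadata)  # restores the documents from disk
+    """
+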
+ def __init__(self, **kwargs):
+ self.default_kwargs = kwargs
+
+ @classmethod
+ def write(
+ cls,
+ documents: Sequence[Document],
+ path: str,
+ file_name: str = "documents.jsonl",
+ metadata_file_name: str = METADATA_FILE_NAME,
+ split: Optional[str] = None,
+ **kwargs,
+ ) -> Dict[str, str]:
+ realpath = os.path.realpath(path)
+ log.info(f'serialize documents to "{realpath}" ...')
+ os.makedirs(realpath, exist_ok=True)
+
+ # dump metadata including the document_type
+ if len(documents) == 0:
+ raise Exception("cannot serialize empty list of documents")
+ document_type = type(documents[0])
+ metadata = {"document_type": serialize_document_type(document_type)}
+ full_metadata_file_name = os.path.join(realpath, metadata_file_name)
+ if os.path.exists(full_metadata_file_name):
+ # load previous metadata
+ with open(full_metadata_file_name) as f:
+ previous_metadata = json.load(f)
+ if previous_metadata != metadata:
+ raise ValueError(
+ f"metadata file {full_metadata_file_name} already exists, "
+ "but the content does not match the current metadata"
+ "\nprevious metadata: {previous_metadata}"
+ "\ncurrent metadata: {metadata}"
+ )
+ else:
+ with open(full_metadata_file_name, "w") as f:
+ json.dump(metadata, f, indent=2)
+
+ if split is not None:
+ realpath = os.path.join(realpath, split)
+ os.makedirs(realpath, exist_ok=True)
+ full_file_name = os.path.join(realpath, file_name)
+ if as_json_lines(file_name):
+ # if the file already exists, append to it
+ mode = "a" if os.path.exists(full_file_name) else "w"
+ with open(full_file_name, mode) as f:
+ for doc in documents:
+ f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
+ else:
+ docs_list = [doc.asdict() for doc in documents]
+ if os.path.exists(full_file_name):
+ # load previous documents
+ with open(full_file_name) as f:
+ previous_doc_list = json.load(f)
+ docs_list = previous_doc_list + docs_list
+ with open(full_file_name, "w") as f:
+ json.dump(docs_list, fp=f, **kwargs)
+ return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}
+
+ @classmethod
+ def read(
+ cls,
+ path: str,
+ document_type: Optional[Type[D]] = None,
+ file_name: str = "documents.jsonl",
+ metadata_file_name: str = METADATA_FILE_NAME,
+ split: Optional[str] = None,
+ ) -> List[D]:
+ realpath = os.path.realpath(path)
+ log.info(f'load documents from "{realpath}" ...')
+
+ # try to load metadata including the document_type
+ full_metadata_file_name = os.path.join(realpath, metadata_file_name)
+ if os.path.exists(full_metadata_file_name):
+ with open(full_metadata_file_name) as f:
+ metadata = json.load(f)
+ document_type = resolve_optional_document_type(metadata.get("document_type"))
+
+ if document_type is None:
+ raise Exception("document_type is required to load serialized documents")
+
+ if split is not None:
+ realpath = os.path.join(realpath, split)
+ full_file_name = os.path.join(realpath, file_name)
+ documents = []
+ if as_json_lines(str(file_name)):
+ with open(full_file_name) as f:
+ for line in f:
+ json_dict = json.loads(line)
+ documents.append(document_type.fromdict(json_dict))
+ else:
+ with open(full_file_name) as f:
+ json_list = json.load(f)
+ for json_dict in json_list:
+ documents.append(document_type.fromdict(json_dict))
+ return documents
+
+ def read_with_defaults(self, **kwargs) -> List[D]:
+ all_kwargs = {**self.default_kwargs, **kwargs}
+ return self.read(**all_kwargs)
+
+ def write_with_defaults(self, **kwargs) -> Dict[str, str]:
+ all_kwargs = {**self.default_kwargs, **kwargs}
+ return self.write(**all_kwargs)
+
+ def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
+ return self.write_with_defaults(documents=documents, **kwargs)
+
+
+class JsonSerializer2(DocumentSerializer):
+ def __init__(self, **kwargs):
+ self.default_kwargs = kwargs
+
+ @classmethod
+ def write(
+ cls,
+ documents: Sequence[Document],
+ path: str,
+ split: str = "train",
+ ) -> Dict[str, str]:
+ if not isinstance(documents, (Dataset, IterableDataset)):
+ documents = Dataset.from_documents(documents)
+ dataset_dict = DatasetDict({split: documents})
+ dataset_dict.to_json(path=path)
+ return {"path": path, "split": split}
+
+ @classmethod
+ def read(
+ cls,
+ path: str,
+ document_type: Optional[Type[D]] = None,
+ split: Optional[str] = None,
+ ) -> Dataset[Document]:
+ dataset_dict = DatasetDict.from_json(
+ data_dir=path, document_type=document_type, split=split
+ )
+ if split is not None:
+ return dataset_dict[split]
+ if len(dataset_dict) == 1:
+ return dataset_dict[list(dataset_dict.keys())[0]]
+ raise ValueError(f"multiple splits found in dataset_dict: {list(dataset_dict.keys())}")
+
+ def read_with_defaults(self, **kwargs) -> Sequence[D]:
+ all_kwargs = {**self.default_kwargs, **kwargs}
+ return self.read(**all_kwargs)
+
+ def write_with_defaults(self, **kwargs) -> Dict[str, str]:
+ all_kwargs = {**self.default_kwargs, **kwargs}
+ return self.write(**all_kwargs)
+
+ def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
+ return self.write_with_defaults(documents=documents, **kwargs)
diff --git a/src/start_demo.py b/src/start_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..942c82094d857fd211e3774c44341c44bbe7d582
--- /dev/null
+++ b/src/start_demo.py
@@ -0,0 +1,578 @@
+import hydra
+import pyrootutils
+from omegaconf import DictConfig, OmegaConf, SCMode
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+import json
+import logging
+
+import gradio as gr
+import torch
+import yaml
+
+from src.demo.annotation_utils import load_argumentation_model
+from src.demo.backend_utils import (
+ download_processed_documents,
+ process_text_from_arxiv,
+ process_uploaded_files,
+ render_annotated_document,
+ upload_processed_documents,
+ wrapped_add_annotated_pie_documents_from_dataset,
+ wrapped_process_text,
+)
+from src.demo.frontend_utils import (
+ change_tab,
+ escape_regex,
+ get_cell_for_fixed_column_from_df,
+ get_fix_df_height_css,
+ open_accordion,
+ unescape_regex,
+)
+from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
+from src.demo.retriever_utils import (
+ get_document_as_dict,
+ get_span_annotation,
+ load_retriever,
+ retrieve_all_relevant_spans,
+ retrieve_all_similar_spans,
+ retrieve_relevant_spans,
+ retrieve_similar_spans,
+)
+
+
+def load_yaml_config(path: str) -> str:
+ with open(path, "r") as file:
+ yaml_string = file.read()
+ config = yaml.safe_load(yaml_string)
+ return yaml.dump(config)
+
+
+def resolve_config(cfg) -> dict:
+ return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT)
+
+
+@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml")
+def main(cfg: DictConfig) -> None:
+
+ # configure logging
+ logging.basicConfig()
+
+    # resolve everything in the config to prevent any issues with JSON serialization etc.
+ cfg = resolve_config(cfg)
+
+ example_text = cfg["example_text"]
+
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ default_retriever_config_str = load_yaml_config(cfg["default_retriever_config_path"])
+
+ default_model_name = cfg["default_model_name"]
+ default_model_revision = cfg["default_model_revision"]
+ handle_parts_of_same = cfg["handle_parts_of_same"]
+
+ default_arxiv_id = cfg["default_arxiv_id"]
+ default_load_pie_dataset_kwargs_str = json.dumps(
+ cfg["default_load_pie_dataset_kwargs"], indent=2
+ )
+
+ default_render_mode = cfg["default_render_mode"]
+ if default_render_mode not in AVAILABLE_RENDER_MODES:
+ raise ValueError(
+ f"Invalid default render mode '{default_render_mode}'. "
+ f"Choose one of {AVAILABLE_RENDER_MODES}."
+ )
+ default_render_kwargs = cfg["default_render_kwargs"]
+
+    default_split_regex = cfg["default_split_regex"]
+    # captions for better readability: map from render mode to the corresponding caption
+ render_mode2caption = {
+ render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
+ for render_mode in AVAILABLE_RENDER_MODES
+ }
+ render_caption2mode = {v: k for k, v in render_mode2caption.items()}
+ default_min_similarity = cfg["default_min_similarity"]
+ layer_caption_mapping = cfg["layer_caption_mapping"]
+ relation_name_mapping = cfg["relation_name_mapping"]
+
+ gr.Info("Loading models ...")
+ argumentation_model = load_argumentation_model(
+ model_name=default_model_name,
+ revision=default_model_revision,
+ device=default_device,
+ )
+ retriever = load_retriever(
+ default_retriever_config_str, device=default_device, config_format="yaml"
+ )
+
+ with gr.Blocks(css=get_fix_df_height_css(css_class="df-docstore", max_height=300)) as demo:
+        # wrap the model and the retriever in (single-element) tuples so that gr.State does not
+        # call them
+ # models_state = gr.State((argumentation_model, embedding_model))
+ argumentation_model_state = gr.State((argumentation_model,))
+ retriever_state = gr.State((retriever,))
+
+ with gr.Row():
+ with gr.Tabs() as left_tabs:
+ with gr.Tab("User Input", id="user_input") as user_input_tab:
+ doc_id = gr.Textbox(
+ label="Document ID",
+ value="user_input",
+ )
+ doc_text = gr.Textbox(
+ label="Text",
+ lines=20,
+ value=example_text,
+ )
+
+ with gr.Accordion("Model Configuration", open=False):
+ with gr.Accordion("argumentation structure", open=True):
+ model_name = gr.Textbox(
+ label="Model Name",
+ value=default_model_name,
+ )
+ model_revision = gr.Textbox(
+ label="Model Revision",
+ value=default_model_revision,
+ )
+ load_arg_model_btn = gr.Button("Load Argumentation Model")
+
+ with gr.Accordion("retriever", open=True):
+ retriever_config = gr.Code(
+ language="yaml",
+ label="Retriever Configuration",
+ value=default_retriever_config_str,
+ lines=len(default_retriever_config_str.split("\n")),
+ )
+ load_retriever_btn = gr.Button("Load Retriever")
+
+ device = gr.Textbox(
+ label="Device (e.g. 'cuda' or 'cpu')",
+ value=default_device,
+ )
+ load_arg_model_btn.click(
+ fn=lambda _model_name, _model_revision, _device: (
+ load_argumentation_model(
+ model_name=_model_name,
+ revision=_model_revision,
+ device=_device,
+ ),
+ ),
+ inputs=[model_name, model_revision, device],
+ outputs=argumentation_model_state,
+ )
+ load_retriever_btn.click(
+ fn=lambda _retriever_config, _device, _previous_retriever: (
+ load_retriever(
+ retriever_config_str=_retriever_config,
+ device=_device,
+ previous_retriever=_previous_retriever[0],
+ config_format="yaml",
+ ),
+ ),
+ inputs=[retriever_config, device, retriever_state],
+ outputs=retriever_state,
+ )
+
+ split_regex_escaped = gr.Textbox(
+ label="Regex to partition the text",
+ placeholder="Regular expression pattern to split the text into partitions",
+ value=escape_regex(default_split_regex),
+ )
+
+ predict_btn = gr.Button("Analyse")
+
+ with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
+ selected_document_id = gr.Textbox(
+ label="Document ID", max_lines=1, interactive=False
+ )
+ rendered_output = gr.HTML(label="Rendered Output")
+
+ with gr.Accordion("Render Options", open=False):
+ render_as = gr.Dropdown(
+ label="Render with",
+ choices=list(render_mode2caption.values()),
+ value=render_mode2caption[default_render_mode],
+ )
+ render_kwargs = gr.Code(
+ language="json",
+ label="Render Arguments",
+ lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")),
+ value=json.dumps(default_render_kwargs, indent=2),
+ )
+ render_btn = gr.Button("Re-render")
+
+ with gr.Accordion("See plain result ...", open=False):
+ get_document_json_btn = gr.Button("Fetch annotated document as JSON")
+ document_json = gr.JSON(label="Model Output")
+
+ with gr.Tabs() as right_tabs:
+ with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
+ with gr.Accordion(
+ "Indexed Documents", open=False
+ ) as processed_documents_accordion:
+ processed_documents_df = gr.DataFrame(
+ headers=["id", "num_adus", "num_relations"],
+ interactive=False,
+ elem_classes="df-docstore",
+ )
+ gr.Markdown("Data Snapshot:")
+ with gr.Row():
+ download_processed_documents_btn = gr.DownloadButton("Download")
+ upload_processed_documents_btn = gr.UploadButton(
+ "Upload", file_types=["file"]
+ )
+
+ # currently not used
+ # relation_types = set_relation_types(
+ # argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
+ # )
+
+ # Dummy textbox to hold the hover adu id. On click on the rendered output,
+ # its content will be copied to selected_adu_id which will trigger the retrieval.
+ hover_adu_id = gr.Textbox(
+ label="ID (hover)",
+ elem_id="hover_adu_id",
+ interactive=False,
+ visible=False,
+ )
+ selected_adu_id = gr.Textbox(
+ label="ID (selected)",
+ elem_id="selected_adu_id",
+ interactive=False,
+ visible=False,
+ )
+ selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)
+
+ with gr.Accordion("Relevant ADUs from other documents", open=True):
+ relevant_adus_df = gr.DataFrame(
+ headers=[
+ "relation",
+ "adu",
+ "reference_adu",
+ "doc_id",
+ "sim_score",
+ "rel_score",
+ ],
+ interactive=False,
+ )
+
+ with gr.Accordion("Retrieval Configuration", open=False):
+ min_similarity = gr.Slider(
+ label="Minimum Similarity",
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=default_min_similarity,
+ )
+ top_k = gr.Slider(
+ label="Top K",
+ minimum=2,
+ maximum=50,
+ step=1,
+ value=10,
+ )
+ retrieve_similar_adus_btn = gr.Button(
+ "Retrieve *similar* ADUs for *selected* ADU"
+ )
+ similar_adus_df = gr.DataFrame(
+ headers=["doc_id", "adu_id", "score", "text"], interactive=False
+ )
+ retrieve_all_similar_adus_btn = gr.Button(
+ "Retrieve *similar* ADUs for *all* ADUs in the document"
+ )
+ all_similar_adus_df = gr.DataFrame(
+ headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
+ interactive=False,
+ )
+ retrieve_all_relevant_adus_btn = gr.Button(
+ "Retrieve *relevant* ADUs for *all* ADUs in the document"
+ )
+ all_relevant_adus_df = gr.DataFrame(
+ headers=["doc_id", "adu_id", "score", "text"], interactive=False
+ )
+
+ with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
+ upload_btn = gr.UploadButton(
+ "Batch Analyse Texts",
+ file_types=["text"],
+ file_count="multiple",
+ )
+
+ with gr.Accordion("Import text from arXiv", open=False):
+ arxiv_id = gr.Textbox(
+ label="arXiv paper ID",
+ placeholder=f"e.g. {default_arxiv_id}",
+ max_lines=1,
+ )
+ load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
+ load_arxiv_btn = gr.Button(
+ "Load & Analyse from arXiv", variant="secondary"
+ )
+
+ with gr.Accordion(
+ "Import argument structure annotated PIE dataset", open=False
+ ):
+ load_pie_dataset_kwargs_str = gr.Code(
+ language="json",
+ label="Parameters for Loading the PIE Dataset",
+ value=default_load_pie_dataset_kwargs_str,
+ lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
+ )
+ load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
+
+ render_event_kwargs = dict(
+ fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
+ retriever=_retriever[0],
+ document_id=_document_id,
+ render_with=render_caption2mode[_render_as],
+ render_kwargs_json=_render_kwargs,
+ ),
+ inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
+ outputs=rendered_output,
+ )
+
+ show_overview_kwargs = dict(
+ fn=lambda _retriever: _retriever[0].docstore.overview(
+ layer_captions=layer_caption_mapping, use_predictions=True
+ ),
+ inputs=[retriever_state],
+ outputs=[processed_documents_df],
+ )
+ predict_btn.click(
+ fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
+ ).then(
+ fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
+ text=_doc_text,
+ doc_id=_doc_id,
+ argumentation_model=_argumentation_model[0],
+ retriever=_retriever[0],
+ split_regex_escaped=(
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
+ ),
+ handle_parts_of_same=handle_parts_of_same,
+ ),
+ inputs=[
+ doc_text,
+ doc_id,
+ argumentation_model_state,
+ retriever_state,
+ split_regex_escaped,
+ ],
+ outputs=[selected_document_id],
+ api_name="predict",
+ ).success(
+ **show_overview_kwargs
+ ).success(
+ **render_event_kwargs
+ )
+ render_btn.click(**render_event_kwargs, api_name="render")
+
+ load_arxiv_btn.click(
+ fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
+ ).then(
+ fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
+ arxiv_id=_arxiv_id.strip() or default_arxiv_id,
+ abstract_only=_load_arxiv_only_abstract,
+ argumentation_model=_argumentation_model[0],
+ retriever=_retriever[0],
+ split_regex_escaped=(
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
+ ),
+ handle_parts_of_same=handle_parts_of_same,
+ ),
+ inputs=[
+ arxiv_id,
+ load_arxiv_only_abstract,
+ argumentation_model_state,
+ retriever_state,
+ split_regex_escaped,
+ ],
+ outputs=[selected_document_id],
+ api_name="predict",
+ ).success(
+ **show_overview_kwargs
+ )
+
+ load_pie_dataset_btn.click(
+ fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
+ ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
+ fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
+ retriever=_retriever[0],
+ verbose=True,
+ layer_captions=layer_caption_mapping,
+ **json.loads(_load_pie_dataset_kwargs_str),
+ ),
+ inputs=[retriever_state, load_pie_dataset_kwargs_str],
+ outputs=[processed_documents_df],
+ )
+
+ selected_document_id.change(
+ fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
+ ).then(**render_event_kwargs)
+
+ get_document_json_btn.click(
+ fn=lambda _retriever, _document_id: get_document_as_dict(
+ retriever=_retriever[0], doc_id=_document_id
+ ),
+ inputs=[retriever_state, selected_document_id],
+ outputs=[document_json],
+ )
+
+ upload_btn.upload(
+ fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
+ ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
+ fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
+ file_names=_file_names,
+ argumentation_model=_argumentation_model[0],
+ retriever=_retriever[0],
+ split_regex_escaped=unescape_regex(_split_regex_escaped),
+ handle_parts_of_same=handle_parts_of_same,
+ layer_captions=layer_caption_mapping,
+ ),
+ inputs=[
+ upload_btn,
+ argumentation_model_state,
+ retriever_state,
+ split_regex_escaped,
+ ],
+ outputs=[processed_documents_df],
+ )
+ processed_documents_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[processed_documents_df, gr.State("doc_id")],
+ outputs=[selected_document_id],
+ )
+
+ download_processed_documents_btn.click(
+ fn=lambda _retriever: download_processed_documents(
+ _retriever[0], file_name="processed_documents"
+ ),
+ inputs=[retriever_state],
+ outputs=[download_processed_documents_btn],
+ )
+ upload_processed_documents_btn.upload(
+ fn=lambda file_name, _retriever: upload_processed_documents(
+ file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
+ ),
+ inputs=[upload_processed_documents_btn, retriever_state],
+ outputs=[processed_documents_df],
+ )
+
+ retrieve_relevant_adus_event_kwargs = dict(
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
+ retriever=_retriever[0],
+ query_span_id=_selected_adu_id,
+ k=_top_k,
+ score_threshold=_min_similarity,
+ relation_label_mapping=relation_name_mapping,
+ # columns=relevant_adus.headers
+ ),
+ inputs=[
+ retriever_state,
+ selected_adu_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[relevant_adus_df],
+ )
+ relevant_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[relevant_adus_df, gr.State("doc_id")],
+ outputs=[selected_document_id],
+ )
+
+ selected_adu_id.change(
+ fn=lambda _retriever, _selected_adu_id: get_span_annotation(
+ retriever=_retriever[0], span_id=_selected_adu_id
+ ),
+ inputs=[retriever_state, selected_adu_id],
+ outputs=[selected_adu_text],
+ ).success(**retrieve_relevant_adus_event_kwargs)
+
+ retrieve_similar_adus_btn.click(
+            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
+                retriever=_retriever[0],
+                query_span_id=_selected_adu_id,
+                k=_top_k,
+ score_threshold=_min_similarity,
+ ),
+ inputs=[
+ retriever_state,
+ selected_adu_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[similar_adus_df],
+ )
+ similar_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[similar_adus_df, gr.State("doc_id")],
+ outputs=[selected_document_id],
+ )
+
+ retrieve_all_similar_adus_btn.click(
+            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
+                retriever=_retriever[0],
+                query_doc_id=_document_id,
+                k=_top_k,
+ score_threshold=_min_similarity,
+ query_span_id_column="query_span_id",
+ ),
+ inputs=[
+ retriever_state,
+ selected_document_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[all_similar_adus_df],
+ )
+
+ retrieve_all_relevant_adus_btn.click(
+            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_relevant_spans(
+                retriever=_retriever[0],
+                query_doc_id=_document_id,
+                k=_top_k,
+ score_threshold=_min_similarity,
+ query_span_id_column="query_span_id",
+ ),
+ inputs=[
+ retriever_state,
+ selected_document_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[all_relevant_adus_df],
+ )
+
+ # select query span id from the "retrieve all" result data frames
+ all_similar_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[all_similar_adus_df, gr.State("query_span_id")],
+ outputs=[selected_adu_id],
+ )
+ all_relevant_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[all_relevant_adus_df, gr.State("query_span_id")],
+ outputs=[selected_adu_id],
+ )
+
+ # argumentation_model_state.change(
+ # fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
+ # inputs=[argumentation_model_state],
+ # outputs=[relation_types],
+ # )
+
+ rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
+
+ demo.launch()
+
+
+if __name__ == "__main__":
+
+ main()
diff --git a/src/taskmodules/__init__.py b/src/taskmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2045b56af4b0f8cdb9962cce2559bf1b18cdd966
--- /dev/null
+++ b/src/taskmodules/__init__.py
@@ -0,0 +1,8 @@
+from .cross_text_binary_coref import CrossTextBinaryCorefTaskModuleWithOptionalContext
+from .cross_text_binary_coref_nli import CrossTextBinaryCorefTaskModuleByNli
+from .re_text_classification_with_indices import (
+ CrossTextBinaryCorefByRETextClassificationTaskModule,
+ RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
+)
+
+CrossTextBinaryCorefTaskModule2 = CrossTextBinaryCorefByRETextClassificationTaskModule
diff --git a/src/taskmodules/components/__init__.py b/src/taskmodules/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/taskmodules/cross_text_binary_coref.py b/src/taskmodules/cross_text_binary_coref.py
new file mode 100644
index 0000000000000000000000000000000000000000..04b32bb63ac92e3c140ad30f0e9319f7960832f8
--- /dev/null
+++ b/src/taskmodules/cross_text_binary_coref.py
@@ -0,0 +1,116 @@
+import logging
+from typing import Optional, Sequence, TypeVar, Union
+
+from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
+from pie_modules.taskmodules.cross_text_binary_coref import (
+ DocumentType,
+ SpanDoesNotFitIntoAvailableWindow,
+ TaskEncodingType,
+)
+from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException
+from pytorch_ie.annotations import Span
+from pytorch_ie.core import TaskEncoding, TaskModule
+
+logger = logging.getLogger(__name__)
+
+
+S = TypeVar("S", bound=Span)
+
+
+def shift_span(span: S, offset: int) -> S:
+ return span.copy(start=span.start + offset, end=span.end + offset)
+
+
+@TaskModule.register()
+class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
+ """Same as CrossTextBinaryCorefTaskModule, but:
+ - optionally without context.
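+
+    Example:
+        An illustrative construction (the tokenizer name and any further keyword arguments are
+        placeholders that are simply forwarded to the base taskmodule):
+
+            taskmodule = CrossTextBinaryCorefTaskModuleWithOptionalContext(
+                tokenizer_name_or_path="bert-base-cased",
+                without_context=True,
+            )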
+ """
+
+ def __init__(
+ self,
+ without_context: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.without_context = without_context
+
+ def encode_input(
+ self,
+ document: DocumentType,
+ is_training: bool = False,
+ ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
+ if self.without_context:
+ return self.encode_input_without_context(document)
+ else:
+ return super().encode_input(document)
+
+ def encode_input_without_context(
+ self, document: DocumentType
+ ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
+ self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
+ tokenizer_kwargs = dict(
+ padding=False,
+ truncation=False,
+ add_special_tokens=False,
+ )
+
+ task_encodings = []
+ for coref_rel in document.binary_coref_relations:
+
+ # TODO: This can miss instances if both texts are the same. We could check that
+ # coref_rel.head is in document.labeled_spans (same for the tail), but would this
+ # slow down the encoding?
+ if not (
+ coref_rel.head.target == document.text
+ or coref_rel.tail.target == document.text_pair
+ ):
+ raise ValueError(
+ f"It is expected that coref relations go from (head) spans over 'text' "
+ f"to (tail) spans over 'text_pair', but this is not the case for this "
+ f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
+ )
+ encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
+ encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)
+
+ try:
+ current_encoding, token_span = self.truncate_encoding_around_span(
+ encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start)
+ )
+ current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
+ encoding=encoding_pair,
+ char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
+ )
+ except SpanNotAlignedWithTokenException as e:
+ logger.warning(
+ f"Could not get token offsets for argument ({e.span}) of coref relation: "
+ f"{coref_rel.resolve()}. Skip it."
+ )
+ self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
+ continue
+ except SpanDoesNotFitIntoAvailableWindow as e:
+ logger.warning(
+ f"Argument span [{e.span}] does not fit into available token window "
+ f"({self.available_window}). Skip it."
+ )
+ self.collect_relation(
+ kind="skipped_span_does_not_fit_into_window", relation=coref_rel
+ )
+ continue
+
+ task_encodings.append(
+ TaskEncoding(
+ document=document,
+ inputs={
+ "encoding": current_encoding,
+ "encoding_pair": current_encoding_pair,
+ "pooler_start_indices": token_span.start,
+ "pooler_end_indices": token_span.end,
+ "pooler_pair_start_indices": token_span_pair.start,
+ "pooler_pair_end_indices": token_span_pair.end,
+ },
+ metadata={"candidate_annotation": coref_rel},
+ )
+ )
+ self.collect_relation("used", coref_rel)
+ return task_encodings
diff --git a/src/taskmodules/cross_text_binary_coref_nli.py b/src/taskmodules/cross_text_binary_coref_nli.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad662031e6d90055fa0f3694267da2b1d721b187
--- /dev/null
+++ b/src/taskmodules/cross_text_binary_coref_nli.py
@@ -0,0 +1,166 @@
+import logging
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, Union
+
+import torch
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
+from pytorch_ie import Annotation
+from pytorch_ie.core import TaskEncoding, TaskModule
+from transformers import AutoTokenizer
+from typing_extensions import TypeAlias
+
+logger = logging.getLogger(__name__)
+
+InputEncodingType: TypeAlias = Dict[str, Any]
+TargetEncodingType: TypeAlias = Sequence[float]
+DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+TaskEncodingType: TypeAlias = TaskEncoding[
+ DocumentType,
+ InputEncodingType,
+ TargetEncodingType,
+]
+
+
+class TaskOutputType(TypedDict, total=False):
+ label_pair: Tuple[str, str]
+ entailment_probability_pair: Tuple[float, float]
+
+
+ModelInputType: TypeAlias = Dict[str, torch.Tensor]
+ModelTargetType: TypeAlias = Dict[str, torch.Tensor]
+ModelOutputType: TypeAlias = Dict[str, torch.Tensor]
+
+TaskModuleType: TypeAlias = TaskModule[
+ # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
+ DocumentType,
+ InputEncodingType,
+ TargetEncodingType,
+ Tuple[ModelInputType, Optional[ModelTargetType]],
+ ModelTargetType,
+ TaskOutputType,
+]
+
+
+@TaskModule.register()
+class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleType):
+ """This taskmodule processes documents of type
+ TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a sequence
+ classification model trained for NLI. The assumption is that if the entailment class is
+ predicted for both directions, a coreference relation exists between the two spans.
+
+ It simply tokenizes and encodes the head and tail texts of the coreference relations as text
+ pairs, i.e. no context of head and tail is considered. During decoding, coreference relations
+ are created if the entailment class (see parameter entailment_label) is predicted for both
+ directions and the average probability is used as the score.
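+
+    Example:
+        An illustrative construction (the tokenizer name is a placeholder for any NLI sequence
+        classification model; the label list and its order must match the model's output classes):
+
+            taskmodule = CrossTextBinaryCorefTaskModuleByNli(
+                tokenizer_name_or_path="some-org/some-nli-model",
+                labels=["contradiction", "entailment", "neutral"],
+                entailment_label="entailment",
+            )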
+ """
+
+ DOCUMENT_TYPE = DocumentType
+
+ def __init__(
+ self,
+ tokenizer_name_or_path: str,
+ labels: List[str],
+ entailment_label: str,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.save_hyperparameters()
+
+ self.labels = labels
+ self.entailment_label = entailment_label
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+
+ def _post_prepare(self):
+ self.id_to_label = dict(enumerate(self.labels))
+ self.label_to_id = {v: k for k, v in self.id_to_label.items()}
+ self.entailment_idx = self.label_to_id[self.entailment_label]
+
+ def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs):
+ self.reset_statistics()
+ result = super().encode(documents=documents, **kwargs)
+ self.show_statistics()
+ return result
+
+ def encode_input(
+ self,
+ document: DocumentType,
+ is_training: bool = False,
+ ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
+ self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
+ result = []
+ for coref_rel in document.binary_coref_relations:
+ head_text = str(coref_rel.head)
+ tail_text = str(coref_rel.tail)
+ task_encoding = TaskEncoding(
+ document=document,
+ inputs={"text": [head_text, tail_text], "text_pair": [tail_text, head_text]},
+ metadata={"candidate_annotation": coref_rel},
+ )
+ result.append(task_encoding)
+ self.collect_relation("used", coref_rel)
+ return result
+
+ def encode_target(
+ self,
+ task_encoding: TaskEncodingType,
+ ) -> Optional[TargetEncodingType]:
+ raise NotImplementedError()
+
+ def collate(
+ self,
+ task_encodings: Sequence[
+ TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType]
+ ],
+ ) -> Tuple[ModelInputType, Optional[ModelTargetType]]:
+ all_texts = []
+ all_texts_pair = []
+ for task_encoding in task_encodings:
+ all_texts.extend(task_encoding.inputs["text"])
+ all_texts_pair.extend(task_encoding.inputs["text_pair"])
+ inputs = self.tokenizer(
+ text=all_texts,
+ text_pair=all_texts_pair,
+ truncation=True,
+ padding=True,
+ return_tensors="pt",
+ )
+ if not task_encodings[0].has_targets:
+ return inputs, None
+ raise NotImplementedError()
+
+ def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
+ probs_tensor = model_output["probabilities"]
+ labels_tensor = model_output["labels"]
+
+ bs, num_classes = probs_tensor.size()
+ # Reshape the probs tensor to (bs/2, 2, num_classes)
+ probs_paired = probs_tensor.view(bs // 2, 2, num_classes).detach().cpu().tolist()
+
+ # Reshape the labels tensor to (bs/2, 2)
+ labels_paired = labels_tensor.view(bs // 2, 2).detach().cpu().tolist()
+
+ result = []
+ for (label_id, label_id_pair), (probs_list, probs_list_pair) in zip(
+ labels_paired, probs_paired
+ ):
+ task_output: TaskOutputType = {
+ "label_pair": (self.id_to_label[label_id], self.id_to_label[label_id_pair]),
+ "entailment_probability_pair": (
+ probs_list[self.entailment_idx],
+ probs_list_pair[self.entailment_idx],
+ ),
+ }
+ result.append(task_output)
+ return result
+
+ def create_annotations_from_output(
+ self,
+ task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
+ task_output: TaskOutputType,
+ ) -> Iterator[Tuple[str, Annotation]]:
+ if all(label == self.entailment_label for label in task_output["label_pair"]):
+ probs = task_output["entailment_probability_pair"]
+ score = (probs[0] + probs[1]) / 2
+ new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
+ yield "binary_coref_relations", new_coref_rel
diff --git a/src/taskmodules/re_text_classification_with_indices.py b/src/taskmodules/re_text_classification_with_indices.py
new file mode 100644
index 0000000000000000000000000000000000000000..5993cb21709a02c253851d9ddd2c09be4b03dcd4
--- /dev/null
+++ b/src/taskmodules/re_text_classification_with_indices.py
@@ -0,0 +1,176 @@
+import copy
+from itertools import chain
+from typing import Dict, Optional, Sequence, Type
+
+import torch
+from pie_modules.annotations import BinaryCorefRelation
+from pie_modules.document.processing.text_pair import shift_span
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
+from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
+from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory
+from pie_modules.taskmodules.re_text_classification_with_indices import (
+ ModelTargetType as REModelTargetType,
+)
+from pie_modules.taskmodules.re_text_classification_with_indices import (
+ TaskOutputType as RETaskOutputType,
+)
+from pytorch_ie import Document, TaskModule
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
+
+
+class SharpBracketMarkerFactory(MarkerFactory):
+ def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
+ result = "<"
+ if not is_start:
+ result += "/"
+ result += self._get_role_marker(role)
+ if label is not None:
+ result += f":{label}"
+ result += ">"
+ return result
+
+ def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
+ role_marker = self._get_role_marker(role)
+ if label is None:
+ return f"<{role_marker}>"
+ else:
+ return f"<{role_marker}={label}>"
+
+
+@TaskModule.register()
+class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
+ RETextClassificationWithIndicesTaskModule
+):
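+    """Same as RETextClassificationWithIndicesTaskModule, but can wrap the relation arguments in
+    sharp bracket markers.
+
+    With use_sharp_marker=True, the SharpBracketMarkerFactory produces markers such as "<H>",
+    "</H>", or "<H:label>" (where "H" stands for whichever role marker is configured via
+    argument_role_to_marker); otherwise the default marker factory of the base class is used.
+    """
+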
+ def __init__(self, use_sharp_marker: bool = False, **kwargs):
+ super().__init__(**kwargs)
+ self.use_sharp_marker = use_sharp_marker
+
+ def get_marker_factory(self) -> MarkerFactory:
+ if self.use_sharp_marker:
+ return SharpBracketMarkerFactory(role_to_marker=self.argument_role_to_marker)
+ else:
+ return MarkerFactory(role_to_marker=self.argument_role_to_marker)
+
+
+def construct_text_document_from_text_pair_coref_document(
+ document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
+ glue_text: str,
+ no_relation_label: str,
+ relation_label_mapping: Optional[Dict[str, str]] = None,
+ add_span_mapping_to_metadata: bool = False,
+) -> TextDocumentWithLabeledSpansAndBinaryRelations:
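+    """Merges a text pair coref document into a single-text document with binary relations.
+
+    If text and text_pair are identical, the labeled spans of both layers are merged and
+    deduplicated; otherwise, text_pair is appended to text (separated by glue_text) and its spans
+    are shifted accordingly. Each coref relation becomes a binary relation whose label is the
+    (optionally mapped) coref label, or no_relation_label if the coref score is 0.0; the score of
+    the new relation is set to 1.0.
+    """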
+ if document.text == document.text_pair:
+ new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
+ id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
+ )
+ old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
+ new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
+ for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
+ new_span = old_span.copy()
+            # when detaching / copying the span, it may be equal to a previously copied span
+            # from the other text, so we deduplicate via new2new_spans
+ new_span = new2new_spans.get(new_span, new_span)
+ new2new_spans[new_span] = new_span
+ old2new_spans[old_span] = new_span
+ else:
+ new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
+ text=document.text + glue_text + document.text_pair,
+ id=document.id,
+ metadata=copy.deepcopy(document.metadata),
+ )
+ old2new_spans = {}
+ old2new_spans.update({span: span.copy() for span in document.labeled_spans})
+ offset = len(document.text) + len(glue_text)
+ old2new_spans.update(
+ {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
+ )
+
+ # sort to make order deterministic
+ new_doc.labeled_spans.extend(
+ sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
+ )
+ for old_rel in document.binary_coref_relations:
+ label = old_rel.label if old_rel.score > 0.0 else no_relation_label
+ if relation_label_mapping is not None:
+ label = relation_label_mapping.get(label, label)
+ new_rel = old_rel.copy(
+ head=old2new_spans[old_rel.head],
+ tail=old2new_spans[old_rel.tail],
+ label=label,
+ score=1.0,
+ )
+ new_doc.binary_relations.append(new_rel)
+
+ if add_span_mapping_to_metadata:
+ new_doc.metadata["span_mapping"] = old2new_spans
+ return new_doc
+
+
+@TaskModule.register()
+class CrossTextBinaryCorefByRETextClassificationTaskModule(
+ TaskModuleWithDocumentConverter,
+ RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
+):
+ def __init__(
+ self,
+ coref_relation_label: str,
+ relation_annotation: str = "binary_relations",
+ probability_threshold: float = 0.0,
+ **kwargs,
+ ):
+ if relation_annotation != "binary_relations":
+ raise ValueError(
+ f"{type(self).__name__} requires relation_annotation='binary_relations', "
+ f"but it is: {relation_annotation}"
+ )
+ super().__init__(relation_annotation=relation_annotation, **kwargs)
+ self.coref_relation_label = coref_relation_label
+ self.probability_threshold = probability_threshold
+
+ @property
+ def document_type(self) -> Optional[Type[Document]]:
+ return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+ def _get_glue_text(self) -> str:
+ result = self.tokenizer.decode(self._get_glue_token_ids())
+ return result
+
+ def _convert_document(
+ self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+ ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
+ return construct_text_document_from_text_pair_coref_document(
+ document,
+ glue_text=self._get_glue_text(),
+ relation_label_mapping={"coref": self.coref_relation_label},
+ no_relation_label=self.none_label,
+ add_span_mapping_to_metadata=True,
+ )
+
+ def _integrate_predictions_from_converted_document(
+ self,
+ document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
+ converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
+ ) -> None:
+ original2converted_span = converted_document.metadata["span_mapping"]
+ new2original_span = {
+ converted_s: orig_s for orig_s, converted_s in original2converted_span.items()
+ }
+
+ for rel in converted_document.binary_relations.predictions:
+ original_head = new2original_span[rel.head]
+ original_tail = new2original_span[rel.tail]
+ if rel.label != self.coref_relation_label:
+ raise ValueError(f"unexpected label: {rel.label}")
+ if rel.score >= self.probability_threshold:
+ original_predicted_rel = BinaryCorefRelation(
+ head=original_head, tail=original_tail, label="coref", score=rel.score
+ )
+ document.binary_coref_relations.predictions.append(original_predicted_rel)
+
+ def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
+ coref_relation_idx = self.label_to_id[self.coref_relation_label]
+ # we are just concerned with the coref class, so we overwrite the labels field
+ model_output = copy.copy(model_output)
+ model_output["labels"] = torch.ones_like(model_output["labels"]) * coref_relation_idx
+ return super().unbatch_output(model_output=model_output)
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c55f55927d1ea25f5f1d70dd3c810191ef2b9c
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,294 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing the project as a package
+# - makes paths and scripts always work no matter what your current work dir is
+# - automatically loads environment variables from ".env" file if it exists
+#
+# how it works:
+# - the line above recursively searches for the ".project-root" indicator file in the present
+#   and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+# any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+# to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+import os.path
+from typing import Any, Dict, List, Optional, Tuple
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pie_datasets import DatasetDict
+from pie_modules.models import * # noqa: F403
+from pie_modules.models import SimpleGenerativeModel
+from pie_modules.models.interface import RequiresTaskmoduleConfig
+from pie_modules.taskmodules import * # noqa: F403
+from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
+from pytorch_ie.core import PyTorchIEModel, TaskModule
+from pytorch_ie.models import * # noqa: F403
+from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
+from pytorch_ie.taskmodules import * # noqa: F403
+from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
+from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.loggers import Logger
+
+from src import utils
+from src.datamodules import PieDataModule
+from src.models import * # noqa: F403
+from src.taskmodules import * # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+
+ This method is wrapped in optional @task_wrapper decorator which applies extra utilities
+ before and after the call.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ pl.seed_everything(cfg.seed, workers=True)
+
+ # Init pytorch-ie taskmodule
+ log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
+ taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
+
+ # Init pytorch-ie dataset
+ log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+ dataset: DatasetDict = hydra.utils.instantiate(
+ cfg.dataset,
+ _convert_="partial",
+ )
+
+ # auto-convert the dataset if the taskmodule specifies a document type
+ dataset = taskmodule.convert_dataset(dataset)
+
+ # Init pytorch-ie datamodule
+ log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
+ datamodule: PieDataModule = hydra.utils.instantiate(
+ cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
+ )
+ # Use the train dataset split to prepare the taskmodule
+ taskmodule.prepare(dataset[datamodule.train_split])
+
+ # Init the pytorch-ie model
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ # get additional model arguments
+ additional_model_kwargs: Dict[str, Any] = {}
+ model_cls = hydra.utils.get_class(cfg.model["_target_"])
+ # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
+ # SEE EXAMPLES BELOW.
+ if issubclass(model_cls, RequiresNumClasses):
+ additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
+ if issubclass(model_cls, RequiresModelNameOrPath):
+ if "model_name_or_path" not in cfg.model:
+ raise Exception(
+ f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
+ )
+ if isinstance(taskmodule, ChangesTokenizerVocabSize):
+ additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)
+
+ pooler_config = cfg["model"].get("pooler")
+ if pooler_config is not None:
+ if isinstance(pooler_config, str):
+ pooler_config = {"type": pooler_config}
+ pooler_config = dict(pooler_config)
+ if pooler_config["type"] in ["start_tokens", "mention_pooling"]:
+ # NOTE: This is very hacky, we should create a new interface class, e.g. RequiresPoolerNumIndices
+ if hasattr(taskmodule, "argument_role2idx"):
+ pooler_config["num_indices"] = len(taskmodule.argument_role2idx)
+ else:
+ pooler_config["num_indices"] = 1
+ elif pooler_config["type"] == "cls_token":
+ pass
+ else:
+ raise Exception(
+ f"unknown pooler type: {pooler_config['type']}. Please adjust the train.py script for that type."
+ )
+ additional_model_kwargs["pooler"] = pooler_config
+
+ if issubclass(model_cls, RequiresTaskmoduleConfig):
+ additional_model_kwargs["taskmodule_config"] = taskmodule.config
+
+ if model_cls == SimpleGenerativeModel:
+        # There may already be some base_model_config entries in the model config. We also need to convert the
+        # base_model_config to a dict, because it is an OmegaConf object which does not accept additional entries.
+ base_model_config = (
+ dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {}
+ )
+ if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
+ base_model_config.update(
+ dict(
+ bos_token_id=taskmodule.bos_id,
+ eos_token_id=taskmodule.eos_id,
+ pad_token_id=taskmodule.eos_id,
+ target_token_ids=taskmodule.target_token_ids,
+ embedding_weight_mapping=taskmodule.label_embedding_weight_mapping,
+ )
+ )
+ additional_model_kwargs["base_model_config"] = base_model_config
+
+ # initialize the model
+ model: PyTorchIEModel = hydra.utils.instantiate(
+ cfg.model, _convert_="partial", **additional_model_kwargs
+ )
+
+ log.info("Instantiating callbacks...")
+ callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")
+
+ log.info("Instantiating loggers...")
+ logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger")
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+ object_dict = {
+ "cfg": cfg,
+ "dataset": dataset,
+ "taskmodule": taskmodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
+
+ if cfg.model_save_dir is not None:
+ log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
+ taskmodule.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
+ else:
+ log.warning("the taskmodule is not saved because no save_dir is specified")
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+ train_metrics = trainer.callback_metrics
+
+ best_ckpt_path = trainer.checkpoint_callback.best_model_path
+ if best_ckpt_path != "":
+ log.info(f"Best ckpt path: {best_ckpt_path}")
+ best_checkpoint_file = os.path.basename(best_ckpt_path)
+ utils.log_hyperparameters(
+ logger=logger,
+ best_checkpoint=best_checkpoint_file,
+ checkpoint_dir=trainer.checkpoint_callback.dirpath,
+ )
+
+ if not cfg.trainer.get("fast_dev_run"):
+ if cfg.model_save_dir is not None:
+ if best_ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for saving...")
+ else:
+ model = type(model).load_from_checkpoint(best_ckpt_path)
+
+ log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
+ model.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
+ else:
+ log.warning("the model is not saved because no save_dir is specified")
+
+ if cfg.get("validate"):
+ log.info("Starting validation!")
+ if best_ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for validation...")
+ trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
+ elif cfg.get("train"):
+ log.warning(
+ "Validation after training is skipped! That means, the finally reported validation scores are "
+ "the values from the *last* checkpoint, not from the *best* checkpoint (which is saved)!"
+ )
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ if best_ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ # add model_save_dir to the result so that it gets dumped to job_return_value.json
+ # if we use hydra_callbacks.SaveJobReturnValueCallback
+ if cfg.get("model_save_dir") is not None:
+ metric_dict["model_save_dir"] = cfg.model_save_dir
+
+ return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="train.yaml")
+def main(cfg: DictConfig) -> Optional[float]:
+ # train the model
+ metric_dict, _ = train(cfg)
+
+ # safely retrieve metric value for hydra-based hyperparameter optimization
+ if cfg.get("optimized_metric") is not None:
+ metric_value = get_metric_value(
+ metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
+ )
+
+ # return optimized metric
+ return metric_value
+ else:
+ return metric_dict
+
+
+if __name__ == "__main__":
+ utils.replace_sys_args_with_values_from_files()
+ utils.prepare_omegaconf()
+ main()
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e526c3452c5530f02449d50aa20350b04cc08d
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1,6 @@
+from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf
+from .data_utils import download_and_unzip, filter_dataframe_and_get_column
+from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .span_utils import distance
+from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
diff --git a/src/utils/config_utils.py b/src/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..290faf6b07e55d71b1cee26831747ddb50cd5c6d
--- /dev/null
+++ b/src/utils/config_utils.py
@@ -0,0 +1,71 @@
+from copy import copy
+from typing import Any, List, Optional
+
+from hydra.utils import instantiate
+from omegaconf import DictConfig, OmegaConf
+
+from src.utils.logging_utils import get_pylogger
+
+logger = get_pylogger(__name__)
+
+
+def execute_pipeline(
+ input: Any,
+ setup: Optional[Any] = None,
+ **processors,
+) -> Any:
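+    """Run a sequence of Hydra-instantiable processors over `input`.
+
+    Every keyword argument that is a dict containing a "_processor_" key is instantiated via Hydra
+    with the current result passed as the first positional argument; its return value (if not None)
+    becomes the input of the next processor. Minimal usage sketch with a hypothetical processor
+    (the target name is illustrative, not from this repo):
+
+        result = execute_pipeline(
+            input=dataset,
+            convert={"_processor_": "some_module.convert_documents", "_enabled_": True},
+        )
+    """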
+ if setup is not None and callable(setup):
+ setup()
+ result = input
+ for processor_name, processor_config in processors.items():
+ if not isinstance(processor_config, dict) or "_processor_" not in processor_config:
+ continue
+ logger.info(f"call processor: {processor_name}")
+ config = copy(processor_config)
+ if not config.pop("_enabled_", True):
+ logger.warning(f"skip processor because it is disabled: {processor_name}")
+ continue
+ # rename key "_processor_" to "_target_"
+ if "_target_" in config:
+ raise ValueError(
+ f"processor {processor_name} has a key '_target_', which is not allowed"
+ )
+ config["_target_"] = config.pop("_processor_")
+        # IMPORTANT: We pass result as the first positional argument after the config instead of adding it to
+        # the config. This prevents it from being converted into an OmegaConf object, which would later be
+        # converted back into a plain dict and thereby break all the DatasetDict methods.
+ tmp_result = instantiate(config, result, _convert_="partial")
+ if tmp_result is not None:
+ result = tmp_result
+ else:
+ logger.warning(f'processor "{processor_name}" did not return a result')
+ return result
+
+
+def instantiate_dict_entries(
+ config: DictConfig, key: str, entry_description: Optional[str] = None
+) -> List:
+ entries: List = []
+ key_config = config.get(key)
+
+ if not key_config:
+ logger.warning(f"{key} config is empty.")
+ return entries
+
+ if not isinstance(key_config, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, entry_conf in key_config.items():
+ if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf:
+ logger.info(f"Instantiating {entry_description or key} <{entry_conf._target_}>")
+ entries.append(instantiate(entry_conf, _convert_="partial"))
+
+ return entries
+
+
+def prepare_omegaconf():
+ # register replace resolver (used to replace "/" with "-" in names to use them as e.g. wandb project names)
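+    # illustrative use in a config (assuming some key `dataset_name` exists):
+    #   project: ${replace:${dataset_name},/,-}
+    # would e.g. resolve "pie/sciarg" to "pie-sciarg" once this resolver is registered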
+ if not OmegaConf.has_resolver("replace"):
+ OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y))
+ else:
+ logger.warning("OmegaConf resolver 'replace' is already registered")
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a89c82ffdd6b9f154c5c0429296616247361198
--- /dev/null
+++ b/src/utils/data_utils.py
@@ -0,0 +1,33 @@
+import os
+from typing import List
+from urllib.request import urlretrieve
+from zipfile import ZipFile
+
+import pandas as pd
+
+from src.utils.logging_utils import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+def filter_dataframe_and_get_column(
+ dataframe: pd.DataFrame, filter_column: str, filter_value: str, select_column: str
+) -> List[str]:
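+    """Return the values of `select_column` for all rows where `filter_column` equals `filter_value`.
+
+    Illustrative example: for a dataframe with columns "split" and "doc_id",
+    filter_dataframe_and_get_column(df, "split", "test", "doc_id") returns the doc ids of the test split.
+    """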
+ return dataframe[dataframe[filter_column] == filter_value][select_column].tolist()
+
+
+def download_and_unzip(
+ url: str, target_path: str, force_download: bool = False, remove_tmp_file: bool = False
+):
+ log.warning(f"download zip file from {url} to {target_path} ...")
+ if not (url.startswith("http://") or url.startswith("https://")):
+ raise ValueError(f"url needs to point to a http(s) address, but it is: {url}")
+ tmp_file = os.path.join(target_path, os.path.basename(url))
+ if os.path.exists(tmp_file) and not force_download:
+ log.warning(f"tmp file {tmp_file} already exists, skip downloading {url}")
+ else:
+ urlretrieve(url, tmp_file) # nosec
+ with ZipFile(tmp_file, "r") as zfile:
+ zfile.extractall(target_path)
+ if remove_tmp_file:
+ os.remove(tmp_file)
diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfa9d208d2c399a423f83eacfcd6735fd9c6b74
--- /dev/null
+++ b/src/utils/logging_utils.py
@@ -0,0 +1,94 @@
+import logging
+from importlib.util import find_spec
+from typing import List, Optional, Union
+
+from omegaconf import DictConfig, OmegaConf
+from pie_modules.models.interface import RequiresTaskmoduleConfig
+from pytorch_ie import PyTorchIEModel, TaskModule
+from pytorch_lightning.loggers import Logger
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def get_pylogger(name=__name__) -> logging.Logger:
+ """Initializes multi-GPU-friendly python command line logger."""
+
+ logger = logging.getLogger(name)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+ for level in logging_levels:
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+ return logger
+
+
+log = get_pylogger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(
+ logger: Optional[List[Logger]] = None,
+ config: Optional[Union[dict, DictConfig]] = None,
+ model: Optional[PyTorchIEModel] = None,
+ taskmodule: Optional[TaskModule] = None,
+ key_prefix: str = "_",
+ **kwargs,
+) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additional saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ if not logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ # this is just for backwards compatibility: usually, the taskmodule_config should be passed to
+ # the model and, thus, be logged there automatically
+ if model is not None and not isinstance(model, RequiresTaskmoduleConfig):
+ if taskmodule is None:
+ raise ValueError(
+ "If model is not an instance of RequiresTaskmoduleConfig, taskmodule must be passed!"
+ )
+        # here we use the taskmodule config as it is after preparation/initialization
+ hparams["taskmodule_config"] = taskmodule.config
+
+ if model is not None:
+ # save number of model parameters
+ hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters())
+ hparams[f"{key_prefix}num_params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams[f"{key_prefix}num_params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ if config is not None:
+ hparams[f"{key_prefix}config"] = (
+ OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config
+ )
+
+ # add additional hparams
+ for k, v in kwargs.items():
+ hparams[f"{key_prefix}{k}"] = v
+
+ # send hparams to all loggers
+ for current_logger in logger:
+ current_logger.log_hyperparams(hparams)
+
+
+def close_loggers() -> None:
+ """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
+
+ log.info("Closing loggers...")
+
+ if find_spec("wandb"): # if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..706df1676e4e3def6b6522fe86883d5894178af0
--- /dev/null
+++ b/src/utils/rich_utils.py
@@ -0,0 +1,110 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf, open_dict
+from pytorch_lightning.utilities import rank_zero_only
+from rich.prompt import Prompt
+
+from src.utils.logging_utils import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "datamodule",
+ "taskmodule",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+        if field in cfg:
+            queue.append(field)
+        else:
+            log.warning(
+                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+            )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config."""
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
+
+
+if __name__ == "__main__":
+ from hydra import compose, initialize
+
+ with initialize(version_base="1.2", config_path="../../configs"):
+ cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[])
+ print_config_tree(cfg, resolve=False, save_to_file=False)
diff --git a/src/utils/span_utils.py b/src/utils/span_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1880c75503640b3dd017bbab342f292e7945648d
--- /dev/null
+++ b/src/utils/span_utils.py
@@ -0,0 +1,60 @@
+from typing import Tuple
+
+
+def get_overlap_len(indices_1: Tuple[int, int], indices_2: Tuple[int, int]) -> int:
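+    """Return the length of the overlap of two (start, end) index pairs.
+
+    Illustrative example: get_overlap_len((0, 5), (3, 10)) == 2, since only positions 3 and 4 are shared.
+    """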
+ if indices_1[0] > indices_2[0]:
+        indices_1, indices_2 = indices_2, indices_1
+ if indices_1[1] <= indices_2[0]:
+ return 0
+ return min(indices_1[1] - indices_2[0], indices_2[1] - indices_2[0])
+
+
+def have_overlap(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool:
+ other_start_overlaps = start_end[0] <= other_start_end[0] < start_end[1]
+ other_end_overlaps = start_end[0] < other_start_end[1] <= start_end[1]
+ start_overlaps_other = other_start_end[0] <= start_end[0] < other_start_end[1]
+ end_overlaps_other = other_start_end[0] < start_end[1] <= other_start_end[1]
+ return other_start_overlaps or other_end_overlaps or start_overlaps_other or end_overlaps_other
+
+
+def is_contained_in(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool:
+ return other_start_end[0] <= start_end[0] and start_end[1] <= other_start_end[1]
+
+
+def distance_center(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float:
+ center = (start_end[0] + start_end[1]) / 2
+ center_other = (other_start_end[0] + other_start_end[1]) / 2
+ return abs(center - center_other)
+
+
+def distance_outer(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float:
+ _max = max(start_end[0], start_end[1], other_start_end[0], other_start_end[1])
+ _min = min(start_end[0], start_end[1], other_start_end[0], other_start_end[1])
+ return float(_max - _min)
+
+
+def distance_inner(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float:
+ dist_start_other_end = abs(start_end[0] - other_start_end[1])
+ dist_end_other_start = abs(start_end[1] - other_start_end[0])
+ dist = float(min(dist_start_other_end, dist_end_other_start))
+ if not have_overlap(start_end, other_start_end):
+ return dist
+ else:
+ return -dist
+
+
+def distance(
+ start_end: Tuple[int, int], other_start_end: Tuple[int, int], distance_type: str
+) -> float:
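+    """Compute the distance between two (start, end) spans using the given distance_type.
+
+    Illustrative values (following the definitions above): for spans (0, 5) and (10, 20),
+    "center" gives 12.5, "inner" gives 5.0, and "outer" gives 20.0.
+    """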
+ if distance_type == "center":
+ return distance_center(start_end, other_start_end)
+ elif distance_type == "inner":
+ return distance_inner(start_end, other_start_end)
+ elif distance_type == "outer":
+ return distance_outer(start_end, other_start_end)
+ else:
+ raise ValueError(
+ f"unknown distance_type={distance_type}. use one of: center, inner, outer"
+ )
diff --git a/src/utils/task_utils.py b/src/utils/task_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb35d3dadd93e8fa12aa98a16bc89f18a421b335
--- /dev/null
+++ b/src/utils/task_utils.py
@@ -0,0 +1,178 @@
+import json
+import os
+import sys
+import time
+import warnings
+from pathlib import Path
+from typing import Any, Callable, Optional
+
+from omegaconf import DictConfig
+from pytorch_lightning.utilities import rank_zero_only
+
+from src.utils.logging_utils import close_loggers, get_pylogger
+from src.utils.rich_utils import enforce_tags, print_config_tree
+
+log = get_pylogger(__name__)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that wraps the task function in extra utilities.
+
+ Makes multirun more resistant to failure.
+
+ Utilities:
+ - Calling the `utils.extras()` before the task is started
+ - Calling the `utils.close_loggers()` after the task is finished
+    - Logging the exception if one occurs
+ - Logging the task total execution time
+ - Logging the output dir
+ """
+
+ def wrap(cfg: DictConfig):
+
+ # apply extra utilities
+ extras(cfg)
+
+ # execute the task
+ start_time = time.time()
+ try:
+ task_result = task_func(cfg=cfg)
+ except Exception as ex:
+ log.exception("") # save exception to `.log` file
+ raise ex
+ finally:
+ path = Path(cfg.paths.output_dir, "exec_time.log")
+ content = f"'{cfg.pipeline_type}' execution time: {time.time() - start_time} (s)"
+ save_file(path, content) # save task execution time (even if exception occurs)
+ close_loggers() # close loggers (even if exception occurs so multirun won't fail)
+
+ log.info(f"Output dir: {cfg.paths.output_dir}")
+
+ return task_result
+
+ return wrap
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+@rank_zero_only
+def save_file(path: str, content: str) -> None:
+ """Save file in rank zero mode (only on one process in multi-GPU setup)."""
+ with open(path, "w+") as file:
+ file.write(content)
+
+
+def load_value_from_file(path: str, split_path_key: str = ":", split_key_parts: str = "/") -> Any:
+ """Load a value from a file. The path can point to elements within the file (see split_path_key
+ parameter) and that can be nested (see split_key_parts parameter). For now, only .json files
+ are supported.
+
+ Args:
+ path: path to the file (and data within the file)
+ split_path_key: split the path on this value to get the path to the file and the key within the file
+ split_key_parts: the value to split the key on to get the nested keys
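+
+    Example (illustrative): if "job_return_value.json" contains {"model": {"f1": 0.8}}, then
+    load_value_from_file("job_return_value.json:model/f1") returns 0.8.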
+ """
+
+ parts_path = path.split(split_path_key, maxsplit=1)
+ file_extension = os.path.splitext(parts_path[0])[1]
+ if file_extension == ".json":
+ with open(parts_path[0], "r") as f:
+ data = json.load(f)
+ else:
+ raise ValueError(f"Expected .json file, got {file_extension}")
+
+ if len(parts_path) == 1:
+ return data
+
+ keys = parts_path[1].split(split_key_parts)
+ for key in keys:
+ data = data[key]
+ return data
+
+
+def replace_sys_args_with_values_from_files(
+ load_prefix: str = "LOAD_ARG:",
+ load_multi_prefix: str = "LOAD_MULTI_ARG:",
+ **load_value_from_file_kwargs,
+) -> None:
+ """Replaces arguments in sys.argv with values loaded from files.
+
+ Examples:
+ # config.json contains {"a": 1, "b": 2}
+ python train.py LOAD_ARG:job_return_value.json
+ # this will pass "{a:1,b:2}" as the first argument to train.py
+
+ # config.json contains [1, 2, 3]
+ python train.py LOAD_MULTI_ARG:job_return_value.json
+ # this will pass "1,2,3" as the first argument to train.py
+
+ # config.json contains {"model": {"ouput_dir": ["path1", "path2"], f1: [0.7, 0.6]}}
+ python train.py load_model=LOAD_ARG:job_return_value.json:model/output_dir
+ # this will pass "load_model=path1,path2" to train.py
+
+ Args:
+ load_prefix: the prefix to use for loading a single value from a file
+ load_multi_prefix: the prefix to use for loading a list of values from a file
+ **load_value_from_file_kwargs: additional kwargs to pass to load_value_from_file
+ """
+
+ updated_args = []
+ for arg in sys.argv[1:]:
+ is_multirun_arg = False
+ if load_prefix in arg:
+ parts = arg.split(load_prefix, maxsplit=1)
+ elif load_multi_prefix in arg:
+ parts = arg.split(load_multi_prefix, maxsplit=1)
+ is_multirun_arg = True
+ else:
+ updated_args.append(arg)
+ continue
+ if len(parts) == 2:
+ log.warning(f'Replacing argument value for "{parts[0]}" with content from {parts[1]}')
+ json_value = load_value_from_file(parts[1], **load_value_from_file_kwargs)
+ json_value_str = json.dumps(json_value)
+ # replace quotes and spaces
+ json_value_str = json_value_str.replace('"', "").replace(" ", "")
+ # remove outer brackets
+ if is_multirun_arg:
+ if not isinstance(json_value, list):
+ raise ValueError(
+ f"Expected list for multirun argument, got {type(json_value)}. If you just want "
+ f"to set a single value, use {load_prefix} instead of {load_multi_prefix}."
+ )
+ json_value_str = json_value_str[1:-1]
+            # re-assemble the argument with the loaded value
+ modified_arg = f"{parts[0]}{json_value_str}"
+ updated_args.append(modified_arg)
+ else:
+ updated_args.append(arg)
+ # Set sys.argv to the updated arguments
+ sys.argv = [sys.argv[0]] + updated_args
diff --git a/src/vendor/__init__.py b/src/vendor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..203cb9c9fccf814b835460da1181764ea6ae8fc4
--- /dev/null
+++ b/src/vendor/__init__.py
@@ -0,0 +1 @@
+# use this folder for storing third party code that cannot be installed using pip/conda