diff --git a/requirements.txt b/requirements.txt index 9d7d7956f0008e648a082160989527beca6a1e90..2ed355b01fe97496ca6ea7add34b04e9d274a3b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,3 @@ -# this requires python>=3.10 -gradio~=5.4.0 -prettytable==3.10.0 -beautifulsoup4==4.12.3 -# numpy 2.0.0 breaks the code -numpy==1.25.2 -scipy==1.13.0 -arxiv==2.1.3 -pyrootutils>=1.0.0,<1.1.0 - -########## from root requirements ########## # --------- pytorch-ie --------- # pytorch-ie>=0.29.6,<0.32.0 pie-datasets>=0.10.5,<0.11.0 @@ -16,20 +5,51 @@ pie-modules>=0.14.0,<0.15.0 # --------- models -------- # adapters>=0.1.2,<0.2.0 -# ADU retrieval (and demo, in future): +pytorch-crf~=0.7.2 +# --------- retriever -------- # langchain>=0.3.0,<0.4.0 langchain-core>=0.3.0,<0.4.0 langchain-community>=0.3.0,<0.4.0 # we use QDrant as vectorstore backend langchain-qdrant>=0.1.0,<0.2.0 qdrant-client>=1.12.0,<2.0.0 -# 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748 -huggingface_hub<0.26.0 # 0.26 seems to be broken -# to to handle segmented entities (if HANDLE_PARTS_OF_SAME=True) -networkx>=3.0.0,<4.0.0 -# --------- config --------- # +# --------- demo -------- # +gradio~=5.4.0 +arxiv~=2.1.3 + +# --------- hydra --------- # hydra-core>=1.3.0 +hydra-colorlog>=1.2.0 +hydra-optuna-sweeper>=1.2.0 + +# --------- loggers --------- # +wandb +# neptune-client +# mlflow +# comet-ml +# tensorboard +# aim -# --------- dev --------- # +# --------- linters --------- # pre-commit # hooks for applying linters on commit +black # code formatting +isort # import sorting +flake8 # code analysis +nbstripout # remove output from jupyter notebooks + +# --------- others --------- # +pyrootutils # standardizing the project root setup +python-dotenv # loading env variables from .env file +rich # beautiful text formatting in terminal +pytest # tests +pytest-cov # test coverageataset +sh # for running bash commands in some tests +pudb # debugger +tabulate # show statistics as markdown +plotext # show statistics as plots +prettytable # rendering annotated docs as table (demo) +beautifulsoup4 # rendering annotated docs with displacy + highlighted relations (demo) +# 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748 +huggingface_hub<0.26.0 # interaction with HF hub +networkx~=3.2.1 # to handle segmented entities (e.g if HANDLE_PARTS_OF_SAME=True in demo) diff --git a/src/datamodules/__init__.py b/src/datamodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8636dd7e5d9fd23e304ed3f2926c6020fe7f8f --- /dev/null +++ b/src/datamodules/__init__.py @@ -0,0 +1 @@ +from .datamodule import PieDataModule diff --git a/src/datamodules/components/__init__.py b/src/datamodules/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/datamodules/components/sampler.py b/src/datamodules/components/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f69ce514db0b02cac579f8784b9fdb578f65e3 --- /dev/null +++ b/src/datamodules/components/sampler.py @@ -0,0 +1,67 @@ +"""This is a slightly modified version of https://github.com/ufoym/imbalanced-dataset-sampler.""" + +from typing import Callable, List, Optional + +import pandas as pd +import torch +import torch.utils.data + + +class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): + """Samples elements randomly from a given list 
of indices for imbalanced dataset. + + Arguments: + indices: a list of indices + num_samples: number of samples to draw + callback_get_label: a callback-like function which takes one argument - the dataset + """ + + def __init__( + self, + dataset, + labels: Optional[List] = None, + indices: Optional[List] = None, + num_samples: Optional[int] = None, + callback_get_label: Optional[Callable] = None, + ): + # if indices is not provided, all elements in the dataset will be considered + self.indices = list(range(len(dataset))) if indices is None else indices + + # define custom callback + self.callback_get_label = callback_get_label + + # if num_samples is not provided, draw `len(indices)` samples in each iteration + self.num_samples = len(self.indices) if num_samples is None else num_samples + + # distribution of classes in the dataset + df = pd.DataFrame() + df["label"] = self._get_labels(dataset) if labels is None else labels + df.index = self.indices + df = df.sort_index() + + label_to_count = df["label"].value_counts() + + weights = 1.0 / label_to_count[df["label"]] + + self.weights = torch.DoubleTensor(weights.to_list()) + + def _get_labels(self, dataset): + if self.callback_get_label: + return self.callback_get_label(dataset) + elif isinstance(dataset, torch.utils.data.TensorDataset): + return dataset.tensors[1] + elif isinstance(dataset, torch.utils.data.Subset): + return dataset.dataset.imgs[:][1] + elif isinstance(dataset, torch.utils.data.Dataset): + return dataset.get_labels() + else: + raise NotImplementedError + + def __iter__(self): + return ( + self.indices[i] + for i in torch.multinomial(self.weights, self.num_samples, replacement=True) + ) + + def __len__(self): + return self.num_samples diff --git a/src/datamodules/datamodule.py b/src/datamodules/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..ecaf0131c4ac9e7ecc11cdfa434eef82708e0705 --- /dev/null +++ b/src/datamodules/datamodule.py @@ -0,0 +1,154 @@ +from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union + +from pytorch_ie.core import Document +from pytorch_ie.core.taskmodule import ( + IterableTaskEncodingDataset, + TaskEncoding, + TaskEncodingDataset, + TaskModule, +) +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader, Sampler +from typing_extensions import TypeAlias + +from .components.sampler import ImbalancedDatasetSampler + +DocumentType = TypeVar("DocumentType", bound=Document) +InputEncoding = TypeVar("InputEncoding") +TargetEncoding = TypeVar("TargetEncoding") +DatasetType: TypeAlias = Union[ + TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]], + IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]], +] + + +class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]): + """A simple LightningDataModule for PIE document datasets. + + A DataModule implements 5 key methods: + - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) + - setup (things to do on every accelerator in distributed mode) + - train_dataloader (the training dataloader) + - val_dataloader (the validation dataloader(s)) + - test_dataloader (the test dataloader(s)) + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. 
+ + Read the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html + """ + + def __init__( + self, + taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any], + dataset: Dict[str, Sequence[DocumentType]], + data_config_path: Optional[str] = None, + train_split: Optional[str] = "train", + val_split: Optional[str] = "validation", + test_split: Optional[str] = "test", + show_progress_for_encode: bool = False, + train_sampler: Optional[str] = None, + **dataloader_kwargs, + ): + super().__init__() + + self.taskmodule = taskmodule + self.config_path = data_config_path + self.dataset = dataset + self.train_split = train_split + self.val_split = val_split + self.test_split = test_split + self.show_progress_for_encode = show_progress_for_encode + self.train_sampler_name = train_sampler + self.dataloader_kwargs = dataloader_kwargs + + self._data: Dict[str, DatasetType] = {} + + @property + def num_train(self) -> int: + if self.train_split is None: + raise ValueError("no train_split assigned") + data_train = self._data.get(self.train_split, None) + if data_train is None: + raise ValueError("can not get train size if setup() was not yet called") + if isinstance(data_train, IterableTaskEncodingDataset): + raise TypeError("IterableTaskEncodingDataset has no length") + return len(data_train) + + def setup(self, stage: str): + if stage == "fit": + split_names = [self.train_split, self.val_split] + elif stage == "validate": + split_names = [self.val_split] + elif stage == "test": + split_names = [self.test_split] + else: + raise NotImplementedError(f"not implemented for stage={stage} ") + + for split in split_names: + if split is None or split not in self.dataset: + continue + task_encoding_dataset = self.taskmodule.encode( + self.dataset[split], + encode_target=True, + as_dataset=True, + show_progress=self.show_progress_for_encode, + ) + if not isinstance( + task_encoding_dataset, + (TaskEncodingDataset, IterableTaskEncodingDataset), + ): + raise TypeError( + f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}" + ) + self._data[split] = task_encoding_dataset + + def data_split(self, split: Optional[str] = None) -> DatasetType: + if split is None or split not in self._data: + raise ValueError(f"data for split={split} not available") + return self._data[split] + + def get_train_sampler( + self, + sampler_name: str, + dataset: DatasetType, + ) -> Sampler: + if sampler_name == "imbalanced_dataset": + # for now, this work only with targets that have a single entry + return ImbalancedDatasetSampler( + dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds] + ) + else: + raise ValueError(f"unknown sampler name: {sampler_name}") + + def train_dataloader(self): + ds = self.data_split(self.train_split) + if self.train_sampler_name is not None: + sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds) + else: + sampler = None + return DataLoader( + dataset=ds, + sampler=sampler, + collate_fn=self.taskmodule.collate, + # don't shuffle streamed datasets or if we use a sampler + shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None), + **self.dataloader_kwargs, + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.data_split(self.val_split), + collate_fn=self.taskmodule.collate, + shuffle=False, + **self.dataloader_kwargs, + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.data_split(self.test_split), + 
collate_fn=self.taskmodule.collate, + shuffle=False, + **self.dataloader_kwargs, + ) diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/dataset/processing.py b/src/dataset/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..74f8a70b890fc47be85bfc2624105301f7e3fa44 --- /dev/null +++ b/src/dataset/processing.py @@ -0,0 +1,29 @@ +from typing import Callable, Type, Union + +from pie_datasets import Dataset, DatasetDict +from pytorch_ie import Document +from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target + + +# TODO: simply use use DatasetDict.map() with set_batch_size_to_split_size=True and +# batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged +def apply_func_to_splits( + dataset: DatasetDict, + function: Union[str, Callable], + result_document_type: Type[Document], + **kwargs +): + resolved_func = resolve_target(function) + resolved_document_type = resolve_optional_document_type(document_type=result_document_type) + result_dict = dict() + split: Dataset + for split_name, split in dataset.items(): + converted_dataset = split.map( + function=resolved_func, + batched=True, + batch_size=len(split), + result_document_type=resolved_document_type, + **kwargs + ) + result_dict[split_name] = converted_dataset + return DatasetDict(result_dict) diff --git a/src/demo/__init__.py b/src/demo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/demo/annotation_utils.py b/src/demo/annotation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..638c224d7d9ecadb6973e11e2e0effb637f115bf --- /dev/null +++ b/src/demo/annotation_utils.py @@ -0,0 +1,137 @@ +import logging +from typing import Optional, Sequence, Union + +import gradio as gr +from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger + +# this is required to dynamically load the PIE models +from pie_modules.models import * # noqa: F403 +from pie_modules.taskmodules import * # noqa: F403 +from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE +from pytorch_ie import Pipeline +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.auto import AutoPipeline +from pytorch_ie.documents import ( + TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) + +# this is required to dynamically load the PIE models +from pytorch_ie.models import * # noqa: F403 +from pytorch_ie.taskmodules import * # noqa: F403 + +logger = logging.getLogger(__name__) + + +def annotate_document( + document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, + argumentation_model: Pipeline, + handle_parts_of_same: bool = False, +) -> Union[ + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, + TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, +]: + """Annotate a document with the provided pipeline. + + Args: + document: The document to annotate. + argumentation_model: The pipeline to use for annotation. + handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span. 
+ """ + + # execute prediction pipeline + argumentation_model(document) + + if handle_parts_of_same: + merger = SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="parts_of_same", + create_multi_spans=True, + result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, + result_field_mapping={ + "labeled_spans": "labeled_multi_spans", + "binary_relations": "binary_relations", + "labeled_partitions": "labeled_partitions", + }, + ) + document = merger(document) + + return document + + +def create_document( + text: str, doc_id: str, split_regex: Optional[str] = None +) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions: + """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided + text. + + Parameters: + text: The text to process. + doc_id: The ID of the document. + split_regex: A regular expression pattern to use for splitting the text into partitions. + + Returns: + The processed document. + """ + + document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + id=doc_id, text=text, metadata={} + ) + if split_regex is not None: + partitioner = RegexPartitioner( + pattern=split_regex, partition_layer_name="labeled_partitions" + ) + document = partitioner(document) + else: + # add single partition from the whole text (the model only considers text in partitions) + document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text")) + return document + + +def load_argumentation_model( + model_name: str, + revision: Optional[str] = None, + device: str = "cpu", +) -> Pipeline: + try: + # the Pipeline class expects an integer for the device + if device == "cuda": + pipeline_device = 0 + elif device.startswith("cuda:"): + pipeline_device = int(device.split(":")[1]) + elif device == "cpu": + pipeline_device = -1 + else: + raise gr.Error(f"Invalid device: {device}") + + model = AutoPipeline.from_pretrained( + model_name, + device=pipeline_device, + num_workers=0, + taskmodule_kwargs=dict(revision=revision), + model_kwargs=dict(revision=revision), + ) + gr.Info( + f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}" + ) + except Exception as e: + raise gr.Error(f"Failed to load argumentation model: {e}") + + return model + + +def set_relation_types( + argumentation_model: Pipeline, + default: Optional[Sequence[str]] = None, +) -> gr.Dropdown: + if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE): + relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"] + else: + raise gr.Error("Unsupported taskmodule for relation types") + + return gr.Dropdown( + choices=relation_types, + label="Argumentative Relation Types", + value=default, + multiselect=True, + ) diff --git a/src/demo/backend_utils.py b/src/demo/backend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7619dfcc01e08961ffa48494a9162795b44299b1 --- /dev/null +++ b/src/demo/backend_utils.py @@ -0,0 +1,221 @@ +import json +import logging +import os +import tempfile +from typing import Iterable, List, Optional, Sequence + +import gradio as gr +import pandas as pd +from pie_datasets import Dataset, IterableDataset, load_dataset +from pytorch_ie import Pipeline +from pytorch_ie.documents import ( + TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, +) + +from src.demo.annotation_utils import annotate_document, create_document +from src.demo.data_utils import 
load_text_from_arxiv +from src.demo.rendering_utils import ( + RENDER_WITH_DISPLACY, + RENDER_WITH_PRETTY_TABLE, + render_displacy, + render_pretty_table, +) +from src.demo.retriever_utils import get_text_spans_and_relations_from_document +from src.langchain_modules import ( + DocumentAwareSpanRetriever, + DocumentAwareSpanRetrieverWithRelations, +) + +logger = logging.getLogger(__name__) + + +def add_annotated_pie_documents( + retriever: DocumentAwareSpanRetriever, + pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions], + use_predicted_annotations: bool, + verbose: bool = False, +) -> None: + if verbose: + gr.Info(f"Create span embeddings for {len(pie_documents)} documents...") + num_docs_before = len(retriever.docstore) + retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations) + # number of documents that were overwritten + num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore) + # warn if documents were overwritten + if num_overwritten_docs > 0: + gr.Warning(f"{num_overwritten_docs} documents were overwritten.") + + +def process_texts( + texts: Iterable[str], + doc_ids: Iterable[str], + argumentation_model: Pipeline, + retriever: DocumentAwareSpanRetriever, + split_regex_escaped: Optional[str], + handle_parts_of_same: bool = False, + verbose: bool = False, +) -> None: + # check that doc_ids are unique + if len(set(doc_ids)) != len(list(doc_ids)): + raise gr.Error("Document IDs must be unique.") + pie_documents = [ + create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped) + for text, doc_id in zip(texts, doc_ids) + ] + if verbose: + gr.Info(f"Annotate {len(pie_documents)} documents...") + pie_documents = [ + annotate_document( + document=pie_document, + argumentation_model=argumentation_model, + handle_parts_of_same=handle_parts_of_same, + ) + for pie_document in pie_documents + ] + add_annotated_pie_documents( + retriever=retriever, + pie_documents=pie_documents, + use_predicted_annotations=True, + verbose=verbose, + ) + + +def add_annotated_pie_documents_from_dataset( + retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs +) -> None: + try: + gr.Info( + "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2) + ) + dataset = load_dataset(**load_dataset_kwargs) + if not isinstance(dataset, (Dataset, IterableDataset)): + raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.") + dataset_converted = dataset.to_document_type( + TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions + ) + add_annotated_pie_documents( + retriever=retriever, + pie_documents=dataset_converted, + use_predicted_annotations=False, + verbose=verbose, + ) + except Exception as e: + raise gr.Error(f"Failed to load dataset: {e}") + + +def wrapped_process_text( + doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs +) -> str: + try: + process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs) + except Exception as e: + raise gr.Error(f"Failed to process text: {e}") + # Return as dict and document to avoid serialization issues + return doc_id + + +def process_uploaded_files( + file_names: List[str], + retriever: DocumentAwareSpanRetriever, + layer_captions: dict[str, str], + **kwargs, +) -> pd.DataFrame: + try: + doc_ids = [] + texts = [] + for file_name in file_names: + if file_name.lower().endswith(".txt"): + # read the file content + with open(file_name, "r", 
encoding="utf-8") as f: + text = f.read() + base_file_name = os.path.basename(file_name) + doc_ids.append(base_file_name) + texts.append(text) + else: + raise gr.Error(f"Unsupported file format: {file_name}") + process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs) + except Exception as e: + raise gr.Error(f"Failed to process uploaded files: {e}") + + return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True) + + +def wrapped_add_annotated_pie_documents_from_dataset( + retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs +) -> pd.DataFrame: + try: + add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs) + except Exception as e: + raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}") + return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True) + + +def download_processed_documents( + retriever: DocumentAwareSpanRetriever, + file_name: str = "retriever_store", +) -> Optional[str]: + if len(retriever.docstore) == 0: + gr.Warning("No documents to download.") + return None + + # zip the directory + file_path = os.path.join(tempfile.gettempdir(), file_name) + + gr.Info(f"Zipping the retriever store to '{file_name}' ...") + result_file_path = retriever.save_to_archive(base_name=file_path, format="zip") + + return result_file_path + + +def upload_processed_documents( + file_name: str, + retriever: DocumentAwareSpanRetriever, + layer_captions: dict[str, str], +) -> pd.DataFrame: + # load the documents from the zip file or directory + retriever.load_from_disc(file_name) + # return the overview of the document store + return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True) + + +def process_text_from_arxiv( + arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs +) -> str: + try: + text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only) + except Exception as e: + raise gr.Error(f"Failed to load text from arXiv: {e}") + return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs) + + +def render_annotated_document( + retriever: DocumentAwareSpanRetrieverWithRelations, + document_id: str, + render_with: str, + render_kwargs_json: str, +) -> str: + text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document( + retriever=retriever, document_id=document_id + ) + + render_kwargs = json.loads(render_kwargs_json) + if render_with == RENDER_WITH_PRETTY_TABLE: + html = render_pretty_table( + text=text, + spans=spans, + span_id2idx=span_id2idx, + binary_relations=relations, + **render_kwargs, + ) + elif render_with == RENDER_WITH_DISPLACY: + html = render_displacy( + text=text, + spans=spans, + span_id2idx=span_id2idx, + binary_relations=relations, + **render_kwargs, + ) + else: + raise ValueError(f"Unknown render_with value: {render_with}") + + return html diff --git a/src/demo/data_utils.py b/src/demo/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..980c05677ff36bbaa6e9556a0d5d8e9ab75d9d19 --- /dev/null +++ b/src/demo/data_utils.py @@ -0,0 +1,63 @@ +import logging +import re +from typing import Tuple + +import arxiv +import gradio as gr +import requests +from bs4 import BeautifulSoup + +logger = logging.getLogger(__name__) + + +def clean_spaces(text: str) -> str: + # replace all multiple spaces with a single space + text = re.sub(" +", " ", text) + # 
reduce more than two newlines to two newlines + text = re.sub("\n\n+", "\n\n", text) + # remove leading and trailing whitespaces + text = text.strip() + return text + + +def get_cleaned_arxiv_paper_text(html_content: str) -> str: + # parse the HTML content with BeautifulSoup + soup = BeautifulSoup(html_content, "html.parser") + # get alerts (this is one div with classes "package-alerts" and "ltx_document") + alerts = soup.find("div", class_="package-alerts ltx_document") + # get the "article" html element + article = soup.find("article") + article_text = article.get_text() + # cleanup the text + article_text_clean = clean_spaces(article_text) + return article_text_clean + + +def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[str, str]: + + search_by_id = arxiv.Search(id_list=[arxiv_id]) + try: + result = list(arxiv.Client().results(search_by_id)) + except arxiv.HTTPError as e: + raise gr.Error(f"Failed to fetch arXiv data: {e}") + if len(result) == 0: + raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'") + first_result = result[0] + if abstract_only: + abstract_clean = first_result.summary.replace("\n", " ") + return abstract_clean, first_result.entry_id + if "/abs/" not in first_result.entry_id: + raise gr.Error( + f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has " + f"an unexpected format: {first_result.entry_id}" + ) + html_url = first_result.entry_id.replace("/abs/", "/html/") + request_result = requests.get(html_url) + if request_result.status_code != 200: + raise gr.Error( + f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: " + f"{request_result.status_code}" + ) + html_content = request_result.text + text_clean = get_cleaned_arxiv_paper_text(html_content) + return text_clean, html_url diff --git a/src/demo/frontend_utils.py b/src/demo/frontend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01b17cb247225c6284957d04bf8911aba90f8286 --- /dev/null +++ b/src/demo/frontend_utils.py @@ -0,0 +1,56 @@ +from typing import Any, Union + +import gradio as gr +import pandas as pd + + +# see https://github.com/gradio-app/gradio/issues/9288#issuecomment-2356163329 +def get_fix_df_height_css(css_class: str, max_height: int) -> str: + # return ".qa-pairs .table-wrap {min-height: 170px; max-height: 170px;}" + return "." + css_class + " .table-wrap {max-height: " + str(max_height) + "px;}" + + +def escape_regex(regex: str) -> str: + # "double escape" the backslashes + result = regex.encode("unicode_escape").decode("utf-8") + return result + + +def unescape_regex(regex: str) -> str: + # reverse of escape_regex + result = regex.encode("utf-8").decode("unicode_escape") + return result + + +def open_accordion(): + return gr.Accordion(open=True) + + +def close_accordion(): + return gr.Accordion(open=False) + + +def change_tab(id: Union[int, str]): + return gr.Tabs(selected=id) + + +def get_cell_for_fixed_column_from_df( + evt: gr.SelectData, + df: pd.DataFrame, + column: str, +) -> Any: + """Get the value of the fixed column for the selected row in the DataFrame. + This is required can *not* with a lambda function because that will not get + the evt parameter. + + Args: + evt: The event object. + df: The DataFrame. + column: The name of the column. + + Returns: + The value of the fixed column for the selected row. 
+ """ + row_idx, col_idx = evt.index + doc_id = df.iloc[row_idx][column] + return doc_id diff --git a/src/demo/rendering_utils.py b/src/demo/rendering_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd46410b62e254e9bf89bf99cab6de584c5e1954 --- /dev/null +++ b/src/demo/rendering_utils.py @@ -0,0 +1,296 @@ +import json +import logging +from collections import defaultdict +from typing import Any, Dict, List, Optional, Sequence, Union + +from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan + +from .rendering_utils_displacy import EntityRenderer + +logger = logging.getLogger(__name__) + +RENDER_WITH_DISPLACY = "displacy" +RENDER_WITH_PRETTY_TABLE = "pretty_table" +AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE] + +# adjusted from rendering_utils_displacy.TPL_ENT +TPL_ENT_WITH_ID = """ + + {text} + {label} + +""" + +HIGHLIGHT_SPANS_JS = """ +() => { + function maybeSetColor(entity, colorAttributeKey, colorDictKey) { + var color = entity.getAttribute('data-color-' + colorAttributeKey); + // if color is a json string, parse it and use the value at colorDictKey + try { + const colors = JSON.parse(color); + color = colors[colorDictKey]; + } catch (e) {} + if (color) { + entity.style.backgroundColor = color; + entity.style.color = '#000'; + } + } + + function highlightRelationArguments(entityId) { + const entities = document.querySelectorAll('.entity'); + // reset all entities + entities.forEach(entity => { + const color = entity.getAttribute('data-color-original'); + entity.style.backgroundColor = color; + entity.style.color = ''; + }); + + if (entityId !== null) { + var visitedEntities = new Set(); + // highlight selected entity + // get all elements with attribute data-entity-id==entityId + const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`); + selectedEntityParts.forEach(selectedEntityPart => { + const label = selectedEntityPart.getAttribute('data-label'); + maybeSetColor(selectedEntityPart, 'selected', label); + visitedEntities.add(selectedEntityPart); + }); // <-- Corrected closing parenthesis here + // if there is at least one part, get the first one and ... + if (selectedEntityParts.length > 0) { + const selectedEntity = selectedEntityParts[0]; + + // ... highlight tails and ... + const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails')); + relationTailsAndLabels.forEach(relationTail => { + const tailEntityId = relationTail['entity-id']; + const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`); + tailEntityParts.forEach(tailEntity => { + const label = relationTail['label']; + maybeSetColor(tailEntity, 'tail', label); + visitedEntities.add(tailEntity); + }); // <-- Corrected closing parenthesis here + }); // <-- Corrected closing parenthesis here + // .. 
highlight heads + const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads')); + relationHeadsAndLabels.forEach(relationHead => { + const headEntityId = relationHead['entity-id']; + const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`); + headEntityParts.forEach(headEntity => { + const label = relationHead['label']; + maybeSetColor(headEntity, 'head', label); + visitedEntities.add(headEntity); + }); // <-- Corrected closing parenthesis here + }); // <-- Corrected closing parenthesis here + } + + // highlight other entities + entities.forEach(entity => { + if (!visitedEntities.has(entity)) { + const label = entity.getAttribute('data-label'); + maybeSetColor(entity, 'other', label); + } + }); + } + } + function setHoverAduId(entityId) { + // get the textarea element that holds the reference adu id + let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea'); + // set the value of the input field + hoverAduIdDiv.value = entityId; + // trigger an input event to update the state + var event = new Event('input'); + hoverAduIdDiv.dispatchEvent(event); + } + function setReferenceAduIdFromHover() { + // get the hover adu id + const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea'); + // get the value of the input field + const entityId = hoverAduIdDiv.value; + // get the textarea element that holds the reference adu id + let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea'); + // set the value of the input field + referenceAduIdDiv.value = entityId; + // trigger an input event to update the state + var event = new Event('input'); + referenceAduIdDiv.dispatchEvent(event); + } + + const entities = document.querySelectorAll('.entity'); + entities.forEach(entity => { + // make the cursor a pointer + entity.style.cursor = 'pointer'; + const alreadyHasListener = entity.getAttribute('data-has-listener'); + if (alreadyHasListener) { + return; + } + entity.addEventListener('mouseover', () => { + const entityId = entity.getAttribute('data-entity-id'); + highlightRelationArguments(entityId); + setHoverAduId(entityId); + }); + entity.addEventListener('mouseout', () => { + highlightRelationArguments(null); + }); + entity.setAttribute('data-has-listener', 'true'); + }); + const entityContainer = document.querySelector('.entities'); + if (entityContainer) { + entityContainer.addEventListener('click', () => { + setReferenceAduIdFromHover(); + }); + // make the cursor a pointer + // entityContainer.style.cursor = 'pointer'; + } +} +""" + + +def render_pretty_table( + text: str, + spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]], + span_id2idx: Dict[str, int], + binary_relations: Sequence[BinaryRelation], + **render_kwargs, +): + from prettytable import PrettyTable + + t = PrettyTable() + t.field_names = ["head", "tail", "relation"] + t.align = "l" + for relation in list(binary_relations) + list(binary_relations): + t.add_row([str(relation.head), str(relation.tail), relation.label]) + + html = t.get_html_string(format=True) + html = "
" + html + "
" + + return html + + +def render_displacy( + text: str, + spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]], + span_id2idx: Dict[str, int], + binary_relations: Sequence[BinaryRelation], + inject_relations=True, + colors_hover=None, + entity_options={}, + **render_kwargs, +): + + ents: List[Dict[str, Any]] = [] + for entity_id, idx in span_id2idx.items(): + labeled_span = spans[idx] + # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations + # on hover and to inject the relation data. + if isinstance(labeled_span, LabeledSpan): + ents.append( + { + "start": labeled_span.start, + "end": labeled_span.end, + "label": labeled_span.label, + "params": {"entity_id": entity_id, "slice_idx": 0}, + } + ) + elif isinstance(labeled_span, LabeledMultiSpan): + for i, (start, end) in enumerate(labeled_span.slices): + ents.append( + { + "start": start, + "end": end, + "label": labeled_span.label, + "params": {"entity_id": entity_id, "slice_idx": i}, + } + ) + else: + raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}") + + ents_sorted = sorted(ents, key=lambda x: (x["start"], x["end"])) + spacy_doc = { + "text": text, + # the ents MUST be sorted by start and end + "ents": ents_sorted, + "title": None, + } + + # copy to avoid modifying the original options + entity_options = entity_options.copy() + # use the custom template with the entity ID + entity_options["template"] = TPL_ENT_WITH_ID + renderer = EntityRenderer(options=entity_options) + html = renderer.render([spacy_doc], page=True, minify=True).strip() + + html = "
" + html + "
" + if inject_relations: + html = inject_relation_data( + html, + spans=spans, + span_id2idx=span_id2idx, + binary_relations=binary_relations, + additional_colors=colors_hover, + ) + return html + + +def inject_relation_data( + html: str, + spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]], + span_id2idx: Dict[str, int], + binary_relations: Sequence[BinaryRelation], + additional_colors: Optional[Dict[str, Union[str, dict]]] = None, +) -> str: + from bs4 import BeautifulSoup + + # Parse the HTML using BeautifulSoup + soup = BeautifulSoup(html, "html.parser") + + entity2tails = defaultdict(list) + entity2heads = defaultdict(list) + for relation in binary_relations: + entity2heads[relation.tail].append((relation.head, relation.label)) + entity2tails[relation.head].append((relation.tail, relation.label)) + + annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()} + # Add unique IDs to each entity + entities = soup.find_all(class_="entity") + for entity in entities: + original_color = entity["style"].split("background:")[1].split(";")[0].strip() + entity["data-color-original"] = original_color + if additional_colors is not None: + for key, color in additional_colors.items(): + entity[f"data-color-{key}"] = ( + json.dumps(color) if isinstance(color, dict) else color + ) + + entity_annotation = spans[span_id2idx[entity["data-entity-id"]]] + + # sanity check. + if isinstance(entity_annotation, LabeledSpan): + annotation_text = entity_annotation.resolve()[1] + elif isinstance(entity_annotation, LabeledMultiSpan): + slice_idx = int(entity["data-slice-idx"]) + annotation_text = entity_annotation.resolve()[1][slice_idx] + else: + raise ValueError(f"Unsupported entity type: {type(entity_annotation)}") + annotation_text_without_newline = annotation_text.replace("\n", "") + # Just check the start, because the text has the label attached to the end + if not entity.text.startswith(annotation_text_without_newline): + logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}") + + entity["data-label"] = entity_annotation.label + entity["data-relation-tails"] = json.dumps( + [ + {"entity-id": annotation2id[tail], "label": label} + for tail, label in entity2tails.get(entity_annotation, []) + if tail in annotation2id + ] + ) + entity["data-relation-heads"] = json.dumps( + [ + {"entity-id": annotation2id[head], "label": label} + for head, label in entity2heads.get(entity_annotation, []) + if head in annotation2id + ] + ) + + # Return the modified HTML as a string + return str(soup) diff --git a/src/demo/rendering_utils_displacy.py b/src/demo/rendering_utils_displacy.py new file mode 100644 index 0000000000000000000000000000000000000000..2c69cb26749ec18df50c2157003e1d36b24f72b2 --- /dev/null +++ b/src/demo/rendering_utils_displacy.py @@ -0,0 +1,217 @@ +# This code is mainly taken from +# https://github.com/explosion/spaCy/blob/master/spacy/displacy/templates.py, and from +# https://github.com/explosion/spaCy/blob/master/spacy/displacy/render.py. + +# Setting explicit height and max-width: none on the SVG is required for +# Jupyter to render it properly in a cell + +TPL_DEP_SVG = """ +{content} +""" + + +TPL_DEP_WORDS = """ + + {text} + {tag} + +""" + + +TPL_DEP_WORDS_LEMMA = """ + + {text} + {lemma} + {tag} + +""" + + +TPL_DEP_ARCS = """ + + + + {label} + + + +""" + + +TPL_FIGURE = """ +
<figure style="margin-bottom: 6rem">{content}</figure>
+""" + +TPL_TITLE = """ +

<h2 style="margin: 0">{title}</h2>
+""" + + +TPL_ENTS = """ +
<div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
+""" + + +TPL_ENT = """ + + {text} + {label} + +""" + +TPL_ENT_RTL = """ + + {text} + {label} + +""" + + +TPL_PAGE = """ + + + + displaCy + + + {content} + +""" + + +DEFAULT_LANG = "en" +DEFAULT_DIR = "ltr" + + +def minify_html(html): + """Perform a template-specific, rudimentary HTML minification for displaCy. + Disclaimer: NOT a general-purpose solution, only removes indentation and + newlines. + + html (unicode): Markup to minify. + RETURNS (unicode): "Minified" HTML. + """ + return html.strip().replace(" ", "").replace("\n", "") + + +def escape_html(text): + """Replace <, >, &, " with their HTML encoded representation. Intended to prevent HTML errors + in rendered displaCy markup. + + text (unicode): The original text. RETURNS (unicode): Equivalent text to be safely used within + HTML. + """ + text = text.replace("&", "&") + text = text.replace("<", "<") + text = text.replace(">", ">") + text = text.replace('"', """) + return text + + +class EntityRenderer(object): + """Render named entities as HTML.""" + + style = "ent" + + def __init__(self, options={}): + """Initialise dependency renderer. + + options (dict): Visualiser-specific options (colors, ents) + """ + colors = { + "ORG": "#7aecec", + "PRODUCT": "#bfeeb7", + "GPE": "#feca74", + "LOC": "#ff9561", + "PERSON": "#aa9cfc", + "NORP": "#c887fb", + "FACILITY": "#9cc9cc", + "EVENT": "#ffeb80", + "LAW": "#ff8197", + "LANGUAGE": "#ff8197", + "WORK_OF_ART": "#f0d0ff", + "DATE": "#bfe1d9", + "TIME": "#bfe1d9", + "MONEY": "#e4e7d2", + "QUANTITY": "#e4e7d2", + "ORDINAL": "#e4e7d2", + "CARDINAL": "#e4e7d2", + "PERCENT": "#e4e7d2", + } + # user_colors = registry.displacy_colors.get_all() + # for user_color in user_colors.values(): + # colors.update(user_color) + colors.update(options.get("colors", {})) + self.default_color = "#ddd" + self.colors = colors + self.ents = options.get("ents", None) + self.direction = DEFAULT_DIR + self.lang = DEFAULT_LANG + + template = options.get("template") + if template: + self.ent_template = template + else: + if self.direction == "rtl": + self.ent_template = TPL_ENT_RTL + else: + self.ent_template = TPL_ENT + + def render(self, parsed, page=False, minify=False): + """Render complete markup. + + parsed (list): Dependency parses to render. page (bool): Render parses wrapped as full HTML + page. minify (bool): Minify HTML markup. RETURNS (unicode): Rendered HTML markup. + """ + rendered = [] + for i, p in enumerate(parsed): + if i == 0: + settings = p.get("settings", {}) + self.direction = settings.get("direction", DEFAULT_DIR) + self.lang = settings.get("lang", DEFAULT_LANG) + rendered.append(self.render_ents(p["text"], p["ents"], p.get("title"))) + if page: + docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered]) + markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction) + else: + markup = "".join(rendered) + if minify: + return minify_html(markup) + return markup + + def render_ents(self, text, spans, title): + """Render entities in text. + + text (unicode): Original text. spans (list): Individual entity spans and their start, end + and label. title (unicode or None): Document title set in Doc.user_data['title']. 
+        """
+        markup = ""
+        offset = 0
+        for span in spans:
+            label = span["label"]
+            start = span["start"]
+            end = span["end"]
+            additional_params = span.get("params", {})
+            entity = escape_html(text[start:end])
+            fragments = text[offset:start].split("\n")
+            for i, fragment in enumerate(fragments):
+                markup += escape_html(fragment)
+                if len(fragments) > 1 and i != len(fragments) - 1:
+                    markup += "</br>"
+            if self.ents is None or label.upper() in self.ents:
+                color = self.colors.get(label.upper(), self.default_color)
+                ent_settings = {"label": label, "text": entity, "bg": color}
+                ent_settings.update(additional_params)
+                markup += self.ent_template.format(**ent_settings)
+            else:
+                markup += entity
+            offset = end
+        fragments = text[offset:].split("\n")
+        for i, fragment in enumerate(fragments):
+            markup += escape_html(fragment)
+            if len(fragments) > 1 and i != len(fragments) - 1:
+                markup += "</br>
" + markup = TPL_ENTS.format(content=markup, dir=self.direction) + if title: + markup = TPL_TITLE.format(title=title) + markup + return markup diff --git a/src/demo/retrieve_and_dump_all_relevant.py b/src/demo/retrieve_and_dump_all_relevant.py new file mode 100644 index 0000000000000000000000000000000000000000..7105e122a1b7e34066b3cc127ece5891d62ed4a0 --- /dev/null +++ b/src/demo/retrieve_and_dump_all_relevant.py @@ -0,0 +1,101 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".project-root"], + pythonpath=True, + dotenv=True, +) + +import argparse +import logging + +from src.demo.retriever_utils import ( + retrieve_all_relevant_spans, + retrieve_all_relevant_spans_for_all_documents, + retrieve_relevant_spans, +) +from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config_path", + type=str, + default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml", + ) + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to a zip or directory containing a retriever dump.", + ) + parser.add_argument("-k", "--top_k", type=int, default=10) + parser.add_argument("-t", "--threshold", type=float, default=0.95) + parser.add_argument( + "-o", + "--output_path", + type=str, + required=True, + ) + parser.add_argument( + "--query_doc_id", + type=str, + default=None, + help="If provided, retrieve all spans for only this query document.", + ) + parser.add_argument( + "--query_span_id", + type=str, + default=None, + help="If provided, retrieve all spans for only this query span.", + ) + args = parser.parse_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + + if not args.output_path.endswith(".json"): + raise ValueError("only support json output") + + logger.info(f"instantiating retriever from {args.config_path}...") + retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file( + args.config_path + ) + logger.info(f"loading data from {args.data_path}...") + retriever.load_from_disc(args.data_path) + + search_kwargs = {"k": args.top_k, "score_threshold": args.threshold} + logger.info(f"use search_kwargs: {search_kwargs}") + + if args.query_span_id is not None: + logger.warning(f"retrieving results for single span: {args.query_span_id}") + all_spans_for_all_documents = retrieve_relevant_spans( + retriever=retriever, query_span_id=args.query_span_id, **search_kwargs + ) + elif args.query_doc_id is not None: + logger.warning(f"retrieving results for single document: {args.query_doc_id}") + all_spans_for_all_documents = retrieve_all_relevant_spans( + retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs + ) + else: + all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents( + retriever=retriever, **search_kwargs + ) + + if all_spans_for_all_documents is None: + logger.warning("no relevant spans found in any document") + exit(0) + + logger.info(f"dumping results to {args.output_path}...") + all_spans_for_all_documents.to_json(args.output_path) + + logger.info("done") diff --git a/src/demo/retriever_utils.py b/src/demo/retriever_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc47eea8944c2302491eb421d729ed9059d1a9f --- /dev/null +++ b/src/demo/retriever_utils.py @@ -0,0 +1,313 @@ 
+import logging +from typing import Dict, Optional, Sequence, Tuple, Union + +import gradio as gr +import pandas as pd +from pytorch_ie import Annotation +from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from typing_extensions import Protocol + +from src.langchain_modules import DocumentAwareSpanRetriever +from src.langchain_modules.span_retriever import ( + DocumentAwareSpanRetrieverWithRelations, + _parse_config, +) + +logger = logging.getLogger(__name__) + + +def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict: + document = retriever.get_document(doc_id=doc_id) + return retriever.docstore.as_dict(document) + + +def load_retriever( + retriever_config_str: str, + config_format: str, + device: str = "cpu", + previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None, +) -> DocumentAwareSpanRetrieverWithRelations: + try: + retriever_config = _parse_config(retriever_config_str, format=config_format) + # set device for the embeddings pipeline + retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device + result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config) + # if a previous retriever is provided, load all documents and vectors from the previous retriever + if previous_retriever is not None: + # documents + all_doc_ids = list(previous_retriever.docstore.yield_keys()) + gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...") + all_docs = previous_retriever.docstore.mget(all_doc_ids) + result.docstore.mset([(doc.id, doc) for doc in all_docs]) + # spans (with vectors) + all_span_ids = list(previous_retriever.vectorstore.yield_keys()) + all_spans = previous_retriever.vectorstore.mget(all_span_ids) + result.vectorstore.mset([(span.id, span) for span in all_spans]) + + gr.Info("Retriever loaded successfully.") + return result + except Exception as e: + raise gr.Error(f"Failed to load retriever: {e}") + + +def retrieve_similar_spans( + retriever: DocumentAwareSpanRetriever, + query_span_id: str, + **kwargs, +) -> pd.DataFrame: + if not query_span_id.strip(): + raise gr.Error("No query span selected.") + try: + retrieval_result = retriever.invoke(input=query_span_id, **kwargs) + records = [] + for similar_span_doc in retrieval_result: + pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc) + span_ann = metadata["attached_span"] + records.append( + { + "doc_id": pie_doc.id, + "span_id": similar_span_doc.id, + "score": metadata["relevance_score"], + "label": span_ann.label, + "text": str(span_ann), + } + ) + return ( + pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"]) + .sort_values(by="score", ascending=False) + .round(3) + ) + except Exception as e: + raise gr.Error(f"Failed to retrieve similar ADUs: {e}") + + +def retrieve_relevant_spans( + retriever: DocumentAwareSpanRetriever, + query_span_id: str, + relation_label_mapping: Optional[dict[str, str]] = None, + **kwargs, +) -> pd.DataFrame: + if not query_span_id.strip(): + raise gr.Error("No query span selected.") + try: + relation_label_mapping = relation_label_mapping or {} + retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs) + records = [] + for relevant_span_doc in retrieval_result: + pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc) + span_ann = metadata["attached_span"] + tail_span_ann = metadata["attached_tail_span"] + mapped_relation_label = 
relation_label_mapping.get( + metadata["relation_label"], metadata["relation_label"] + ) + records.append( + { + "doc_id": pie_doc.id, + "type": mapped_relation_label, + "rel_score": metadata["relation_score"], + "text": str(tail_span_ann), + "span_id": relevant_span_doc.id, + "label": tail_span_ann.label, + "ref_score": metadata["relevance_score"], + "ref_label": span_ann.label, + "ref_text": str(span_ann), + "ref_span_id": metadata["head_id"], + } + ) + return ( + pd.DataFrame( + records, + columns=[ + "type", + # omitted for now, we get no valid relation scores for the generative model + # "rel_score", + "ref_score", + "label", + "text", + "ref_label", + "ref_text", + "doc_id", + "span_id", + "ref_span_id", + ], + ) + .sort_values(by=["ref_score"], ascending=False) + .round(3) + ) + except Exception as e: + raise gr.Error(f"Failed to retrieve relevant ADUs: {e}") + + +class RetrieverCallable(Protocol): + def __call__( + self, + retriever: DocumentAwareSpanRetriever, + query_span_id: str, + **kwargs, + ) -> Optional[pd.DataFrame]: + pass + + +def _retrieve_for_all_spans( + retriever: DocumentAwareSpanRetriever, + query_doc_id: str, + retrieve_func: RetrieverCallable, + query_span_id_column: str = "query_span_id", + **kwargs, +) -> Optional[pd.DataFrame]: + if not query_doc_id.strip(): + raise gr.Error("No query document selected.") + try: + span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id) + gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...") + span_results = { + query_span_id: retrieve_func( + retriever=retriever, + query_span_id=query_span_id, + **kwargs, + ) + for query_span_id in span_id2idx.keys() + } + span_results_not_empty = { + query_span_id: df + for query_span_id, df in span_results.items() + if df is not None and not df.empty + } + + # add column with query_span_id + for query_span_id, query_span_result in span_results_not_empty.items(): + query_span_result[query_span_id_column] = query_span_id + + if len(span_results_not_empty) == 0: + gr.Info(f"No results found for any ADU in document {query_doc_id}.") + return None + else: + result = pd.concat(span_results_not_empty.values(), ignore_index=True) + gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.") + return result + except Exception as e: + raise gr.Error( + f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}' + ) + + +def retrieve_all_similar_spans( + retriever: DocumentAwareSpanRetriever, + query_doc_id: str, + **kwargs, +) -> Optional[pd.DataFrame]: + return _retrieve_for_all_spans( + retriever=retriever, + query_doc_id=query_doc_id, + retrieve_func=retrieve_similar_spans, + **kwargs, + ) + + +def retrieve_all_relevant_spans( + retriever: DocumentAwareSpanRetriever, + query_doc_id: str, + **kwargs, +) -> Optional[pd.DataFrame]: + return _retrieve_for_all_spans( + retriever=retriever, + query_doc_id=query_doc_id, + retrieve_func=retrieve_relevant_spans, + **kwargs, + ) + + +class RetrieverForAllSpansCallable(Protocol): + def __call__( + self, + retriever: DocumentAwareSpanRetriever, + query_doc_id: str, + **kwargs, + ) -> Optional[pd.DataFrame]: + pass + + +def _retrieve_for_all_documents( + retriever: DocumentAwareSpanRetriever, + retrieve_func: RetrieverForAllSpansCallable, + query_doc_id_column: str = "query_doc_id", + **kwargs, +) -> Optional[pd.DataFrame]: + try: + all_doc_ids = list(retriever.docstore.yield_keys()) + gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...") + doc_results = { + doc_id: 
retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs) + for doc_id in all_doc_ids + } + doc_results_not_empty = { + doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty + } + # add column with query_doc_id + for doc_id, doc_result in doc_results_not_empty.items(): + doc_result[query_doc_id_column] = doc_id + + if len(doc_results_not_empty) == 0: + gr.Info("No results found for any document.") + return None + else: + result = pd.concat(doc_results_not_empty, ignore_index=True) + gr.Info(f"Retrieved {len(result)} ADUs for all documents.") + return result + except Exception as e: + raise gr.Error(f"Failed to retrieve results for all documents: {e}") + + +def retrieve_all_similar_spans_for_all_documents( + retriever: DocumentAwareSpanRetriever, + **kwargs, +) -> Optional[pd.DataFrame]: + return _retrieve_for_all_documents( + retriever=retriever, + retrieve_func=retrieve_all_similar_spans, + **kwargs, + ) + + +def retrieve_all_relevant_spans_for_all_documents( + retriever: DocumentAwareSpanRetriever, + **kwargs, +) -> Optional[pd.DataFrame]: + return _retrieve_for_all_documents( + retriever=retriever, + retrieve_func=retrieve_all_relevant_spans, + **kwargs, + ) + + +def get_text_spans_and_relations_from_document( + retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str +) -> Tuple[ + str, + Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]], + Dict[str, int], + Sequence[BinaryRelation], +]: + document = retriever.get_document(doc_id=document_id) + pie_document = retriever.docstore.unwrap(document) + use_predicted_annotations = retriever.use_predicted_annotations(document) + spans = retriever.get_base_layer( + pie_document=pie_document, use_predicted_annotations=use_predicted_annotations + ) + relations = retriever.get_relation_layer( + pie_document=pie_document, use_predicted_annotations=use_predicted_annotations + ) + span_id2idx = retriever.get_span_id2idx_from_doc(document) + return pie_document.text, spans, span_id2idx, relations + + +def get_span_annotation( + retriever: DocumentAwareSpanRetriever, + span_id: str, +) -> Annotation: + if span_id.strip() == "": + raise gr.Error("No span selected.") + try: + return retriever.get_span_by_id(span_id=span_id) + except Exception as e: + raise gr.Error(f"Failed to retrieve span annotation: {e}") diff --git a/src/document/__init__.py b/src/document/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/document/processing.py b/src/document/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..651418aca6ff09f621e02281ced1fd5dc1a0f703 --- /dev/null +++ b/src/document/processing.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, TypeVar, Union + +import networkx as nx +from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations +from pytorch_ie import AnnotationLayer +from pytorch_ie.core import Document + +logger = logging.getLogger(__name__) + + +D = TypeVar("D", bound=Document) + + +def _remove_overlapping_entities( + entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]] +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + sorted_entities = sorted(entities, key=lambda span: span["start"]) + entities_wo_overlap = [] + skipped_entities = [] + last_end = 0 
+ for entity_dict in sorted_entities: + if entity_dict["start"] < last_end: + skipped_entities.append(entity_dict) + else: + entities_wo_overlap.append(entity_dict) + last_end = entity_dict["end"] + if len(skipped_entities) > 0: + logger.warning(f"skipped overlapping entities: {skipped_entities}") + valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap) + valid_relations = [ + relation_dict + for relation_dict in relations + if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids + ] + return entities_wo_overlap, valid_relations + + +def remove_overlapping_entities( + doc: D, + entity_layer_name: str = "entities", + relation_layer_name: str = "relations", +) -> D: + # TODO: use document.add_all_annotations_from_other() + document_dict = doc.asdict() + entities_wo_overlap, valid_relations = _remove_overlapping_entities( + entities=document_dict[entity_layer_name]["annotations"], + relations=document_dict[relation_layer_name]["annotations"], + ) + + document_dict[entity_layer_name] = { + "annotations": entities_wo_overlap, + "predictions": [], + } + document_dict[relation_layer_name] = { + "annotations": valid_relations, + "predictions": [], + } + new_doc = type(doc).fromdict(document_dict) + + return new_doc + + +def _merge_spans_via_relation( + spans: Sequence[LabeledSpan], + relations: Sequence[BinaryRelation], + link_relation_label: str, + create_multi_spans: bool = True, +) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]: + # convert list of relations to a graph to easily calculate connected components to merge + g = nx.Graph() + link_relations = [] + other_relations = [] + for rel in relations: + if rel.label == link_relation_label: + link_relations.append(rel) + # never merge spans that have not the same label + if ( + not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan)) + or rel.head.label == rel.tail.label + ): + g.add_edge(rel.head, rel.tail) + else: + logger.debug( + f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}" + ) + else: + other_relations.append(rel) + + span_mapping = {} + connected_components: Set[LabeledSpan] + for connected_components in nx.connected_components(g): + # all spans in a connected component have the same label + label = list(span.label for span in connected_components)[0] + connected_components_sorted = sorted(connected_components, key=lambda span: span.start) + if create_multi_spans: + new_span = LabeledMultiSpan( + slices=tuple((span.start, span.end) for span in connected_components_sorted), + label=label, + ) + else: + new_span = LabeledSpan( + start=min(span.start for span in connected_components_sorted), + end=max(span.end for span in connected_components_sorted), + label=label, + ) + for span in connected_components_sorted: + span_mapping[span] = new_span + for span in spans: + if span not in span_mapping: + if create_multi_spans: + span_mapping[span] = LabeledMultiSpan( + slices=((span.start, span.end),), label=span.label, score=span.score + ) + else: + span_mapping[span] = LabeledSpan( + start=span.start, end=span.end, label=span.label, score=span.score + ) + + new_spans = set(span_mapping.values()) + new_relations = set( + BinaryRelation( + head=span_mapping[rel.head], + tail=span_mapping[rel.tail], + label=rel.label, + score=rel.score, + ) + for rel in other_relations + ) + + return new_spans, new_relations + + +def merge_spans_via_relation( + document: D, + relation_layer: str, + 
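+    # label of the relations whose head and tail spans should be merged into a single span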
link_relation_label: str, + use_predicted_spans: bool = False, + process_predictions: bool = True, + create_multi_spans: bool = False, +) -> D: + + rel_layer = document[relation_layer] + span_layer = rel_layer.target_layer + new_gold_spans, new_gold_relations = _merge_spans_via_relation( + spans=span_layer, + relations=rel_layer, + link_relation_label=link_relation_label, + create_multi_spans=create_multi_spans, + ) + if process_predictions: + new_pred_spans, new_pred_relations = _merge_spans_via_relation( + spans=span_layer.predictions if use_predicted_spans else span_layer, + relations=rel_layer.predictions, + link_relation_label=link_relation_label, + create_multi_spans=create_multi_spans, + ) + else: + assert not use_predicted_spans + new_pred_spans = set(span_layer.predictions.clear()) + new_pred_relations = set(rel_layer.predictions.clear()) + + relation_layer_name = relation_layer + span_layer_name = document[relation_layer].target_name + if create_multi_spans: + doc_dict = document.asdict() + for f in document.annotation_fields(): + doc_dict.pop(f.name) + + result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict) + result.labeled_multi_spans.extend(new_gold_spans) + result.labeled_multi_spans.predictions.extend(new_pred_spans) + result.binary_relations.extend(new_gold_relations) + result.binary_relations.predictions.extend(new_pred_relations) + else: + result = document.copy(with_annotations=False) + result[span_layer_name].extend(new_gold_spans) + result[span_layer_name].predictions.extend(new_pred_spans) + result[relation_layer_name].extend(new_gold_relations) + result[relation_layer_name].predictions.extend(new_pred_relations) + + return result + + +def remove_partitions_by_labels( + document: D, partition_layer: str, label_blacklist: List[str] +) -> D: + document = document.copy() + layer: AnnotationLayer = document[partition_layer] + new_partitions = [] + for partition in layer.clear(): + if partition.label not in label_blacklist: + new_partitions.append(partition) + layer.extend(new_partitions) + return document + + +D_text = TypeVar("D_text", bound=Document) + + +def replace_substrings_in_text( + document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True +) -> D_text: + new_text = document.text + for old_str, new_str in replacements.items(): + if enforce_same_length and len(old_str) != len(new_str): + raise ValueError( + f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"' + ) + new_text = new_text.replace(old_str, new_str) + result_dict = document.asdict() + result_dict["text"] = new_text + result = type(document).fromdict(result_dict) + result.text = new_text + return result + + +def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text: + replacements = {substring: " " * len(substring) for substring in substrings} + return replace_substrings_in_text(document, replacements=replacements) diff --git a/src/evaluate.py b/src/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..e65b43288a9a05c8ca962b0d26e9584a4589f178 --- /dev/null +++ b/src/evaluate.py @@ -0,0 +1,137 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".project-root"], + pythonpath=True, + dotenv=True, +) + +# ------------------------------------------------------------------------------------ # +# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file +# that helps to make the environment more robust 
and convenient +# +# the main advantages are: +# - allows you to keep all entry files in "src/" without installing project as a package +# - makes paths and scripts always work no matter where is your current work dir +# - automatically loads environment variables from ".env" file if exists +# +# how it works: +# - the line above recursively searches for either ".git" or "pyproject.toml" in present +# and parent dirs, to determine the project root dir +# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from +# any place without installing project as a package +# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" +# to make all paths always relative to the project root +# - loads environment variables from ".env" file in root dir (if `dotenv=True`) +# +# you can remove `pyrootutils.setup_root(...)` if you: +# 1. either install project as a package or move each entry file to the project root dir +# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" +# 3. always run entry files from the project root dir +# +# https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + +from typing import Tuple + +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig +from pie_datasets import DatasetDict +from pie_modules.models import * # noqa: F403 +from pie_modules.taskmodules import * # noqa: F403 +from pytorch_ie.core import PyTorchIEModel, TaskModule +from pytorch_ie.models import * # noqa: F403 +from pytorch_ie.taskmodules import * # noqa: F403 +from pytorch_lightning import Trainer + +from src import utils +from src.datamodules import PieDataModule +from src.models import * # noqa: F403 +from src.taskmodules import * # noqa: F403 + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: + """Evaluates given checkpoint on a datamodule testset. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 
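+
+    Example (hypothetical invocation; assumes a trained checkpoint and the default
+    "evaluate.yaml" Hydra config):
+
+        python src/evaluate.py ckpt_path=/path/to/checkpoint.ckpt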
+ """ + + # Set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + pl.seed_everything(cfg.seed, workers=True) + + # Init pytorch-ie dataset + log.info(f"Instantiating dataset <{cfg.dataset._target_}>") + dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial") + + # Init pytorch-ie taskmodule + log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>") + taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial") + + # auto-convert the dataset if the metric specifies a document type + dataset = taskmodule.convert_dataset(dataset) + + # Init pytorch-ie datamodule + log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") + datamodule: PieDataModule = hydra.utils.instantiate( + cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial" + ) + + # Init pytorch-ie model + log.info(f"Instantiating model <{cfg.model._target_}>") + model: PyTorchIEModel = hydra.utils.instantiate(cfg.model, _convert_="partial") + + # Init lightning loggers + logger = utils.instantiate_dict_entries(cfg, "logger") + + # Init lightning trainer + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, _convert_="partial") + + object_dict = { + "cfg": cfg, + "taskmodule": taskmodule, + "dataset": dataset, + "model": model, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg) + + log.info("Starting testing!") + trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) + + # for predictions use trainer.predict(...) + # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) + + metric_dict = trainer.callback_metrics + + return metric_dict, object_dict + + +@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="evaluate.yaml") +def main(cfg: DictConfig) -> None: + metric_dict, _ = evaluate(cfg) + + return metric_dict + + +if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() + utils.prepare_omegaconf() + main() diff --git a/src/evaluate_documents.py b/src/evaluate_documents.py new file mode 100644 index 0000000000000000000000000000000000000000..e2bc5e4c6d39085a13f8274f2a2ddab86f55cc9f --- /dev/null +++ b/src/evaluate_documents.py @@ -0,0 +1,116 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".project-root"], + pythonpath=True, + dotenv=True, +) + +# ------------------------------------------------------------------------------------ # +# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file +# that helps to make the environment more robust and convenient +# +# the main advantages are: +# - allows you to keep all entry files in "src/" without installing project as a package +# - makes paths and scripts always work no matter where is your current work dir +# - automatically loads environment variables from ".env" file if exists +# +# how it works: +# - the line above recursively searches for either ".git" or "pyproject.toml" in present +# and parent dirs, to determine the project root dir +# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from +# any place without installing project as a package +# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" +# 
to make all paths always relative to the project root +# - loads environment variables from ".env" file in root dir (if `dotenv=True`) +# +# you can remove `pyrootutils.setup_root(...)` if you: +# 1. either install project as a package or move each entry file to the project root dir +# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" +# 3. always run entry files from the project root dir +# +# https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + +from typing import Any, Tuple + +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig +from pie_datasets import DatasetDict +from pytorch_ie.core import DocumentMetric +from pytorch_ie.metrics import * # noqa: F403 + +from src import utils +from src.metrics import * # noqa: F403 + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]: + """Evaluates serialized PIE documents. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + Args: + cfg (DictConfig): Configuration composed by Hydra. + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + + # Set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + pl.seed_everything(cfg.seed, workers=True) + + # Init pytorch-ie dataset + log.info(f"Instantiating dataset <{cfg.dataset._target_}>") + dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial") + + # Init pytorch-ie taskmodule + log.info(f"Instantiating metric <{cfg.metric._target_}>") + metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial") + + # auto-convert the dataset if the metric specifies a document type + dataset = metric.convert_dataset(dataset) + + # Init lightning loggers + loggers = utils.instantiate_dict_entries(cfg, "logger") + + object_dict = { + "cfg": cfg, + "dataset": dataset, + "metric": metric, + "logger": loggers, + } + + if loggers: + log.info("Logging hyperparameters!") + # send hparams to all loggers + for logger in loggers: + logger.log_hyperparams(cfg) + + splits = cfg.get("splits", None) + if splits is None: + documents = dataset + else: + documents = type(dataset)({k: v for k, v in dataset.items() if k in splits}) + + metric_dict = metric(documents) + + return metric_dict, object_dict + + +@hydra.main( + version_base="1.2", config_path=str(root / "configs"), config_name="evaluate_documents.yaml" +) +def main(cfg: DictConfig) -> Any: + metric_dict, _ = evaluate_documents(cfg) + return metric_dict + + +if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() + utils.prepare_omegaconf() + main() diff --git a/src/hydra_callbacks/__init__.py b/src/hydra_callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..293137642e82b75fa7dda2748f0d5835cef8f7b4 --- /dev/null +++ b/src/hydra_callbacks/__init__.py @@ -0,0 +1 @@ +from .save_job_return_value import SaveJobReturnValueCallback diff --git a/src/hydra_callbacks/save_job_return_value.py b/src/hydra_callbacks/save_job_return_value.py new file mode 100644 index 0000000000000000000000000000000000000000..0225c979efdc71114e0d8d030f7e6d1b88895c4e --- /dev/null +++ b/src/hydra_callbacks/save_job_return_value.py @@ -0,0 +1,261 @@ +import json +import logging +import os +import pickle +from pathlib import Path +from typing import Any, 
Dict, Generator, List, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from hydra.core.utils import JobReturn +from hydra.experimental.callback import Callback +from omegaconf import DictConfig + + +def to_py_obj(obj): + """Convert a PyTorch tensor, Numpy array or python list to a python list. + + Modified version of transformers.utils.generic.to_py_obj. + """ + if isinstance(obj, dict): + return {k: to_py_obj(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(to_py_obj(o) for o in obj) + elif isinstance(obj, torch.Tensor): + return obj.detach().cpu().tolist() + elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays + return obj.tolist() + else: + return obj + + +def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts): + """Convert a list of dicts to a dict of lists recursively. + + Example: + # works with nested dicts + >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}]) + {'b': {'c': [2, 4]}, 'a': [1, 3]} + # works with incomplete dicts + >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}]) + {'b': [2, None], 'a': [1, 3]} + + Args: + list_of_dicts (List[dict]): A list of dicts. + + Returns: + dict: A dict of lists. + """ + if isinstance(list_of_dicts, list): + if len(list_of_dicts) == 0: + return {} + elif isinstance(list_of_dicts[0], dict): + keys = set() + for d in list_of_dicts: + if not isinstance(d, dict): + raise ValueError("Not all elements of the list are dicts.") + keys.update(d.keys()) + return { + k: list_of_dicts_to_dict_of_lists_recursive( + [d.get(k, None) for d in list_of_dicts] + ) + for k in keys + } + else: + return list_of_dicts + else: + return list_of_dicts + + +def _flatten_dict_gen(d, parent_key: Tuple[str, ...] = ()) -> Generator: + for k, v in d.items(): + new_key = parent_key + (k,) + if isinstance(v, dict): + yield from dict(_flatten_dict_gen(v, new_key)).items() + else: + yield new_key, v + + +def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]: + return dict(_flatten_dict_gen(d)) + + +def unflatten_dict(d: Dict[Tuple[str, ...], Any]) -> Union[Dict[str, Any], Any]: + """Unflattens a dictionary with nested keys. + + Example: + >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3} + >>> unflatten_dict(d) + {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} + """ + result: Dict[str, Any] = {} + for k, v in d.items(): + if len(k) == 0: + if len(result) > 1: + raise ValueError("Cannot unflatten dictionary with multiple root keys.") + return v + current = result + for key in k[:-1]: + current = current.setdefault(key, {}) + current[k[-1]] = v + return result + + +def overrides_to_identifiers(overrides_per_result: List[List[str]], sep: str = "-") -> List[str]: + """Converts a list of lists of overrides to a list of identifiers. But takes only the overrides + into account, that are not identical for all results. + + Example: + >>> overrides_per_result = [ + ... ["a=1", "b=2", "c=3"], + ... ["a=1", "b=2", "c=4"], + ... ["a=1", "b=3", "c=3"], + ] + >>> overrides_to_identifiers(overrides_per_result) + ['b=2-c=3', 'b=2-c=4', 'b=3-c=3'] + + Args: + overrides_per_result (List[List[str]]): A list of lists of overrides. + sep (str, optional): The separator to use between the overrides. Defaults to "-". + + Returns: + List[str]: A list of identifiers. 
+ """ + # get the overrides that are not identical for all results + overrides_per_result_transposed = np.array(overrides_per_result).T.tolist() + indices = [ + i for i, entries in enumerate(overrides_per_result_transposed) if len(set(entries)) > 1 + ] + # convert the overrides to identifiers + identifiers = [ + sep.join([overrides[idx] for idx in indices]) for overrides in overrides_per_result + ] + return identifiers + + +class SaveJobReturnValueCallback(Callback): + """Save the job return-value in ${output_dir}/{job_return_value_filename}. + + This also works for multi-runs (e.g. sweeps for hyperparameter search). In this case, the result will be saved + additionally in a common file in the multi-run log directory. If integrate_multirun_result=True, the + job return-values are also aggregated (e.g. mean, min, max) and saved in another file. + + params: + ------- + filenames: str or List[str] (default: "job_return_value.json") + The filename(s) of the file(s) to save the job return-value to. If it ends with ".json", + the return-value will be saved as a json file. If it ends with ".pkl", the return-value will be + saved as a pickle file, if it ends with ".md", the return-value will be saved as a markdown file. + integrate_multirun_result: bool (default: True) + If True, the job return-values of all jobs from a multi-run will be rearranged into a dict of lists (maybe + nested), where the keys are the keys of the job return-values and the values are lists of the corresponding + values of all jobs. This is useful if you want to access specific values of all jobs in a multi-run all at once. + Also, aggregated values (e.g. mean, min, max) are created for all numeric values and saved in another file. + """ + + def __init__( + self, + filenames: Union[str, List[str]] = "job_return_value.json", + integrate_multirun_result: bool = False, + ) -> None: + self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + self.filenames = [filenames] if isinstance(filenames, str) else filenames + self.integrate_multirun_result = integrate_multirun_result + self.job_returns: List[JobReturn] = [] + + def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None: + self.job_returns.append(job_return) + output_dir = Path(config.hydra.runtime.output_dir) # / Path(config.hydra.output_subdir) + for filename in self.filenames: + self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir) + + def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None: + if self.integrate_multirun_result: + # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested), + obj = list_of_dicts_to_dict_of_lists_recursive( + [jr.return_value for jr in self.job_returns] + ) + # also create an aggregated result + # convert to python object to allow selecting numeric columns + obj_py = to_py_obj(obj) + obj_flat = flatten_dict(obj_py) + # create dataframe from flattened dict + df_flat = pd.DataFrame(obj_flat) + # select only the numeric values + df_numbers_only = df_flat.select_dtypes(["number"]) + cols_removed = set(df_flat.columns) - set(df_numbers_only.columns) + if len(cols_removed) > 0: + self.log.warning( + f"Removed the following columns from the aggregated result because they are not numeric: " + f"{cols_removed}" + ) + if len(df_numbers_only.columns) == 0: + obj_aggregated = None + else: + # aggregate the numeric values + df_described = df_numbers_only.describe() + # add the aggregation keys (e.g. mean, min, ...) 
as most inner keys and convert back to dict + obj_flat_aggregated = df_described.T.stack().to_dict() + # unflatten because _save() works better with nested dicts + obj_aggregated = unflatten_dict(obj_flat_aggregated) + else: + # create a dict of the job return-values of all jobs from a multi-run + # (_save() works better with nested dicts) + ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns]) + obj = {identifier: jr.return_value for identifier, jr in zip(ids, self.job_returns)} + obj_aggregated = None + output_dir = Path(config.hydra.sweep.dir) + for filename in self.filenames: + self._save( + obj=obj, + filename=filename, + output_dir=output_dir, + multi_run_result=self.integrate_multirun_result, + ) + # if available, also save the aggregated result + if obj_aggregated is not None: + file_base_name, ext = os.path.splitext(filename) + filename_aggregated = f"{file_base_name}.aggregated{ext}" + self._save(obj=obj_aggregated, filename=filename_aggregated, output_dir=output_dir) + + def _save( + self, obj: Any, filename: str, output_dir: Path, multi_run_result: bool = False + ) -> None: + self.log.info(f"Saving job_return in {output_dir / filename}") + output_dir.mkdir(parents=True, exist_ok=True) + assert output_dir is not None + if filename.endswith(".pkl"): + with open(str(output_dir / filename), "wb") as file: + pickle.dump(obj, file, protocol=4) + elif filename.endswith(".json"): + # Convert PyTorch tensors and numpy arrays to native python types + obj_py = to_py_obj(obj) + with open(str(output_dir / filename), "w") as file: + json.dump(obj_py, file, indent=2) + elif filename.endswith(".md"): + # Convert PyTorch tensors and numpy arrays to native python types + obj_py = to_py_obj(obj) + obj_py_flat = flatten_dict(obj_py) + + if multi_run_result: + # In the case of multi-run, we expect to have multiple values for each key. + # We therefore just convert the dict to a pandas DataFrame. + result = pd.DataFrame(obj_py_flat) + else: + # In the case of a single job, we expect to have only one value for each key. + # We therefore convert the dict to a pandas Series and ... + series = pd.Series(obj_py_flat) + if len(series.index.levels) > 1: + # ... if the Series has multiple index levels, we create a DataFrame by unstacking the last level. + result = series.unstack(-1) + else: + # ... otherwise we just unpack the one-entry index values and save the resulting Series. 
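+                    # (the flattened keys are tuples, so even a flat result yields a one-level
+                    # MultiIndex; get_level_values(0) collapses it back to plain keys)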
+ series.index = series.index.get_level_values(0) + result = series + + with open(str(output_dir / filename), "w") as file: + file.write(result.to_markdown()) + + else: + raise ValueError("Unknown file extension") diff --git a/src/langchain_modules/span_retriever.py b/src/langchain_modules/span_retriever.py index 8c4192843dec908efa63a15525a6553c7cb20ca8..09f2674909a8ab53d26ae320fc314bd4cc065bf2 100644 --- a/src/langchain_modules/span_retriever.py +++ b/src/langchain_modules/span_retriever.py @@ -4,7 +4,7 @@ import uuid from collections import defaultdict from copy import copy from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union from langchain_core.callbacks import ( AsyncCallbackManagerForRetrieverRun, @@ -674,14 +674,14 @@ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore): def add_pie_documents( self, - documents: List[TextBasedDocument], + documents: Iterable[TextBasedDocument], use_predicted_annotations: bool, metadata: Optional[Dict[str, Any]] = None, ) -> None: """Add pie documents to the retriever. Args: - documents: List of pie documents to add + documents: Iterable of pie documents to add use_predicted_annotations: Whether to use the predicted annotations or the gold annotations metadata: Optional metadata to add to each document """ diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cd2e1657a85b5c421f21b37520a9646ec8da2b --- /dev/null +++ b/src/metrics/__init__.py @@ -0,0 +1,2 @@ +from .coref_sklearn import CorefMetricsSKLearn +from .coref_torchmetrics import CorefMetricsTorchmetrics diff --git a/src/metrics/annotation_processor.py b/src/metrics/annotation_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..9aa5e0bce5cc54b8014c54386cb619f11d690c11 --- /dev/null +++ b/src/metrics/annotation_processor.py @@ -0,0 +1,23 @@ +from typing import Tuple, List + +from pytorch_ie.annotations import BinaryRelation, Span, LabeledMultiSpan +from pytorch_ie.core import Annotation + + +def decode_span_without_label(ann: Annotation) -> Tuple[Tuple[int, int], ...]: + if isinstance(ann, Span): + return (ann.start, ann.end), + elif isinstance(ann, LabeledMultiSpan): + return ann.slices + else: + raise ValueError("Annotation must be a Span or LabeledMultiSpan") + + +def to_binary_relation_without_argument_labels(ann: Annotation) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], str]: + if not isinstance(ann, BinaryRelation): + raise ValueError("Annotation must be a BinaryRelation") + return ( + decode_span_without_label(ann.head), + decode_span_without_label(ann.tail), + ann.label, + ) diff --git a/src/metrics/coref_sklearn.py b/src/metrics/coref_sklearn.py new file mode 100644 index 0000000000000000000000000000000000000000..2db953a2df2c97c7fe2af8ef9947673089908956 --- /dev/null +++ b/src/metrics/coref_sklearn.py @@ -0,0 +1,162 @@ +import logging +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from pandas import MultiIndex +from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations +from pytorch_ie import DocumentMetric +from pytorch_ie.core.metric import T +from pytorch_ie.utils.hydra import resolve_target +from torchmetrics import Metric, MetricCollection + +from src.hydra_callbacks.save_job_return_value import to_py_obj + 
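+# Document-level metric that collects binary coreference target labels and prediction scores
+# from TextPair documents and evaluates them with sklearn-style metric callables; prediction
+# scores can optionally be discretized with one or several thresholds before the metrics are computed.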
+logger = logging.getLogger(__name__) + + +def get_num_total(targets: List[int], preds: List[float]): + return len(targets) + + +def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1): + return len([v for v in targets if v == positive_idx]) + + +def discretize( + values: List[float], threshold: Union[float, List[float], dict] +) -> Union[List[float], Dict[Any, List[float]]]: + if isinstance(threshold, float): + result = (np.array(values) >= threshold).astype(int).tolist() + return result + if isinstance(threshold, list): + return {t: discretize(values=values, threshold=t) for t in threshold} # type: ignore + if isinstance(threshold, dict): + thresholds = ( + np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist() + ) + return discretize(values, threshold=thresholds) + raise TypeError(f"threshold has unknown type: {threshold}") + + +class CorefMetricsSKLearn(DocumentMetric): + DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations + + def __init__( + self, + metrics: Dict[str, str], + thresholds: Optional[Dict[str, float]] = None, + default_target_idx: int = 0, + default_prediction_score: float = 0.0, + show_as_markdown: bool = False, + markdown_precision: int = 4, + plot: bool = False, + ): + self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()} + self.thresholds = thresholds or {} + thresholds_not_in_metrics = { + name: t for name, t in self.thresholds.items() if name not in self.metrics + } + if len(thresholds_not_in_metrics) > 0: + logger.warning( + f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}" + ) + self.default_target_idx = default_target_idx + self.default_prediction_score = default_prediction_score + self.show_as_markdown = show_as_markdown + self.markdown_precision = markdown_precision + self.plot = plot + + super().__init__() + + def reset(self) -> None: + self._preds: List[float] = [] + self._targets: List[int] = [] + + def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None: + target_args2idx = { + (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations + } + prediction_args2score = { + (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions + } + all_args = set(target_args2idx) | set(prediction_args2score) + all_targets: List[int] = [] + all_predictions: List[float] = [] + for args in all_args: + target_idx = target_args2idx.get(args, self.default_target_idx) + prediction_score = prediction_args2score.get(args, self.default_prediction_score) + all_targets.append(target_idx) + all_predictions.append(prediction_score) + # prediction_scores = torch.tensor(all_predictions) + # target_indices = torch.tensor(all_targets) + # self.metrics.update(preds=prediction_scores, target=target_indices) + self._preds.extend(all_predictions) + self._targets.extend(all_targets) + + def do_plot(self): + raise NotImplementedError() + + from matplotlib import pyplot as plt + + # Get the number of metrics + num_metrics = len(self.metrics) + + # Calculate rows and columns for subplots (aim for a square-like layout) + ncols = math.ceil(math.sqrt(num_metrics)) + nrows = math.ceil(num_metrics / ncols) + + # Create the subplots + fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10)) + + # Flatten the ax_list if necessary (in case of multiple rows/columns) + ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary + + # 
Ensure that we pass exactly the number of axes required by metrics + ax_list = ax_list[:num_metrics] + + # Plot the metrics using the list of axes + self.metrics.plot(ax=ax_list, together=False) + + # Adjust layout to avoid overlapping plots + plt.tight_layout() + plt.show() + + def _compute(self) -> T: + + if self.plot: + self.do_plot() + + result = {} + for name, metric in self.metrics.items(): + + if name in self.thresholds: + preds = discretize(values=self._preds, threshold=self.thresholds[name]) + else: + preds = self._preds + if isinstance(preds, dict): + metric_results = { + t: metric(self._targets, t_preds) for t, t_preds in preds.items() + } + # just get the max + max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1]) + result[f"{name}-{max_t}"] = max_v + else: + result[name] = metric(self._targets, preds) + + result = to_py_obj(result) + if self.show_as_markdown: + import pandas as pd + + series = pd.Series(result) + if isinstance(series.index, MultiIndex): + if len(series.index.levels) > 1: + # in fact, this is not a series anymore + series = series.unstack(-1) + else: + series.index = series.index.get_level_values(0) + logger.info( + f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}" + ) + return result diff --git a/src/metrics/coref_torchmetrics.py b/src/metrics/coref_torchmetrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b19830ac5040bb8a17271c8ff158a3a150c7e8ab --- /dev/null +++ b/src/metrics/coref_torchmetrics.py @@ -0,0 +1,107 @@ +import logging +import math +from typing import Dict + +import torch +from pandas import MultiIndex +from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations +from pytorch_ie import DocumentMetric +from pytorch_ie.core.metric import T +from torchmetrics import Metric, MetricCollection + +from src.hydra_callbacks.save_job_return_value import to_py_obj + +logger = logging.getLogger(__name__) + + +class CorefMetricsTorchmetrics(DocumentMetric): + DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations + + def __init__( + self, + metrics: Dict[str, Metric], + default_target_idx: int = 0, + default_prediction_score: float = 0.0, + show_as_markdown: bool = False, + markdown_precision: int = 4, + plot: bool = False, + ): + self.metrics = MetricCollection(metrics) + self.default_target_idx = default_target_idx + self.default_prediction_score = default_prediction_score + self.show_as_markdown = show_as_markdown + self.markdown_precision = markdown_precision + self.plot = plot + + super().__init__() + + def reset(self) -> None: + self.metrics.reset() + + def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None: + target_args2idx = { + (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations + } + prediction_args2score = { + (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions + } + all_args = set(target_args2idx) | set(prediction_args2score) + all_targets = [] + all_predictions = [] + for args in all_args: + target_idx = target_args2idx.get(args, self.default_target_idx) + prediction_score = prediction_args2score.get(args, self.default_prediction_score) + all_targets.append(target_idx) + all_predictions.append(prediction_score) + prediction_scores = torch.tensor(all_predictions) + target_indices = torch.tensor(all_targets) + self.metrics.update(preds=prediction_scores, target=target_indices) + + def do_plot(self): + from matplotlib import pyplot as 
plt + + # Get the number of metrics + num_metrics = len(self.metrics) + + # Calculate rows and columns for subplots (aim for a square-like layout) + ncols = math.ceil(math.sqrt(num_metrics)) + nrows = math.ceil(num_metrics / ncols) + + # Create the subplots + fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10)) + + # Flatten the ax_list if necessary (in case of multiple rows/columns) + ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary + + # Ensure that we pass exactly the number of axes required by metrics + ax_list = ax_list[:num_metrics] + + # Plot the metrics using the list of axes + self.metrics.plot(ax=ax_list, together=False) + + # Adjust layout to avoid overlapping plots + plt.tight_layout() + plt.show() + + def _compute(self) -> T: + + if self.plot: + self.do_plot() + + result = self.metrics.compute() + + result = to_py_obj(result) + if self.show_as_markdown: + import pandas as pd + + series = pd.Series(result) + if isinstance(series.index, MultiIndex): + if len(series.index.levels) > 1: + # in fact, this is not a series anymore + series = series.unstack(-1) + else: + series.index = series.index.get_level_values(0) + logger.info( + f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}" + ) + return result diff --git a/src/models/__init__.py b/src/models/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f99145a08dfdbb616d7be60678f6f1753ccf1728 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -0,0 +1,5 @@ +from .sequence_classification_with_pooler import ( + SequencePairSimilarityModelWithMaxCosineSim, + SequencePairSimilarityModelWithPooler2, + SequencePairSimilarityModelWithPoolerAndAdapter, +) diff --git a/src/models/components/__init__.py b/src/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/components/pooler.py b/src/models/components/pooler.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5d9a18ebb6acacea4532b5d0d757ad6b7c1017 --- /dev/null +++ b/src/models/components/pooler.py @@ -0,0 +1,79 @@ +import torch +from torch import Tensor, cat, nn + + +class SpanMeanPooler(nn.Module): + """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a + learned embedding is used. The indices are expected to have the shape [batch_size, + num_indices]. + + The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim]. + Note this a slightly modified version of the pie_modules.models.components.pooler.SpanMaxPooler, + i.e. we changed the aggregation method from torch.amax to torch.mean. + + Args: + input_dim: The input dimension of the hidden state. + num_indices: The number of indices to pool. + + Returns: + The pooled hidden states with shape [batch_size, num_indices * input_dim]. 
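+
+    Example (illustrative sketch, not a doctest; the tensor values are random, only the shapes matter):
+
+        pooler = SpanMeanPooler(input_dim=4, num_indices=2)
+        hidden_state = torch.randn(1, 6, 4)
+        start_indices = torch.tensor([[0, 3]])
+        end_indices = torch.tensor([[2, 5]])
+        pooled = pooler(hidden_state, start_indices=start_indices, end_indices=end_indices)
+        # pooled.shape == (1, 8): the means over tokens 0..1 and 3..4, concatenated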
+ """ + + def __init__(self, input_dim: int, num_indices: int = 2, **kwargs): + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_indices = num_indices + self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim)) + nn.init.normal_(self.missing_embeddings) + + def forward( + self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs + ) -> Tensor: + batch_size, seq_len, hidden_size = hidden_state.shape + if start_indices.shape[1] != self.num_indices: + raise ValueError( + f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" + ) + + if end_indices.shape[1] != self.num_indices: + raise ValueError( + f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" + ) + + # check that start_indices are before end_indices + mask_both_positive = (start_indices >= 0) & (end_indices >= 0) + mask_start_before_end = start_indices < end_indices + mask_valid = mask_start_before_end | ~mask_both_positive + if not torch.all(mask_valid): + raise ValueError( + f"values in start_indices have to be smaller than respective values in " + f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}" + ) + + # times num_indices due to concat + result = torch.zeros( + batch_size, hidden_size * self.num_indices, device=hidden_state.device + ) + for batch_idx in range(batch_size): + current_start_indices = start_indices[batch_idx] + current_end_indices = end_indices[batch_idx] + current_embeddings = [ + ( + torch.mean( + hidden_state[ + batch_idx, current_start_indices[i] : current_end_indices[i], : + ], + dim=0, + ) + if current_start_indices[i] >= 0 and current_end_indices[i] >= 0 + else self.missing_embeddings[i] + ) + for i in range(self.num_indices) + ] + result[batch_idx] = cat(current_embeddings, 0) + + return result + + @property + def output_dim(self) -> int: + return self.input_dim * self.num_indices diff --git a/src/models/sequence_classification_with_pooler.py b/src/models/sequence_classification_with_pooler.py new file mode 100644 index 0000000000000000000000000000000000000000..52ae4e99f3beb93dfe710057d527149448234a8c --- /dev/null +++ b/src/models/sequence_classification_with_pooler.py @@ -0,0 +1,166 @@ +import abc +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from adapters import AutoAdapterModel +from pie_modules.models import SequencePairSimilarityModelWithPooler +from pie_modules.models.components.pooler import MENTION_POOLING +from pie_modules.models.sequence_classification_with_pooler import ( + InputType, + OutputType, + SequenceClassificationModelWithPooler, + SequenceClassificationModelWithPoolerBase, + TargetType, + separate_arguments_by_prefix, +) +from pytorch_ie import PyTorchIEModel +from torch import FloatTensor, Tensor +from transformers import AutoConfig, PreTrainedModel +from transformers.modeling_outputs import SequenceClassifierOutput + +from src.models.components.pooler import SpanMeanPooler + +logger = logging.getLogger(__name__) + + +class SequenceClassificationModelWithPoolerBase2( + SequenceClassificationModelWithPoolerBase, abc.ABC +): + def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]: + aggregate = self.pooler_config.get("aggregate", "max") + if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max": + if aggregate == "mean": + pooler_config = dict(self.pooler_config) + 
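+                # remove the selector entries ("type" and "aggregate") so that only the remaining
+                # pooler arguments are forwarded to SpanMeanPooler below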
pooler_config.pop("type") + pooler_config.pop("aggregate") + pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config) + return pooler, pooler.output_dim + else: + raise ValueError(f"Unknown aggregation method: {aggregate}") + else: + return super().setup_pooler(input_dim) + + +class SequenceClassificationModelWithPoolerAndAdapterBase( + SequenceClassificationModelWithPoolerBase2, abc.ABC +): + def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs): + self.adapter_name_or_path = adapter_name_or_path + super().__init__(**kwargs) + + def setup_base_model(self) -> PreTrainedModel: + if self.adapter_name_or_path is None: + return super().setup_base_model() + else: + config = AutoConfig.from_pretrained(self.model_name_or_path) + if self.is_from_pretrained: + model = AutoAdapterModel.from_config(config=config) + else: + model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config) + # load the adapter in any case (it looks like it is not saved in the state or loaded + # from a serialized state) + logger.info(f"load adapter: {self.adapter_name_or_path}") + model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True) + return model + + +@PyTorchIEModel.register() +class SequencePairSimilarityModelWithPooler2( + SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2 +): + pass + + +@PyTorchIEModel.register() +class SequencePairSimilarityModelWithPoolerAndAdapter( + SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase +): + pass + + +@PyTorchIEModel.register() +class SequenceClassificationModelWithPoolerAndAdapter( + SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase +): + pass + + +def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor: + # Normalize the embeddings + embeddings_normalized = F.normalize(embeddings, p=2, dim=1) # Shape: (n, k) + embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1) # Shape: (m, k) + + # Compute the cosine similarity matrix + cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T) # Shape: (n, m) + + # Get the overall maximum cosine similarity value + max_cosine_sim = torch.max(cosine_sim) # This will return a scalar + return max_cosine_sim + + +def get_span_embeddings( + embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor +) -> List[FloatTensor]: + result = [] + for embeds, starts, ends in zip(embeddings, start_indices, end_indices): + span_embeds = embeds[starts[0] : ends[0]] + result.append(span_embeds) + return result + + +@PyTorchIEModel.register() +class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler): + def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]: + output = self.model(**model_inputs) + hidden_state = output.last_hidden_state + # pooled_output = self.pooler(hidden_state, **pooler_inputs) + # pooled_output = self.dropout(pooled_output) + span_embeds = get_span_embeddings(hidden_state, **pooler_inputs) + return span_embeds + + def forward( + self, + inputs: InputType, + targets: Optional[TargetType] = None, + return_hidden_states: bool = False, + ) -> OutputType: + sanitized_inputs = separate_arguments_by_prefix( + # Note that the order of the prefixes is important because one is a prefix of the other, + # so we need to start with the longer! 
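+            #   (here: "pooler_pair_" has to be matched before "pooler_")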
+ arguments=inputs, + prefixes=["pooler_pair_", "pooler_"], + ) + + span_embeddings = self.get_pooled_output( + model_inputs=sanitized_inputs["remaining"]["encoding"], + pooler_inputs=sanitized_inputs["pooler_"], + ) + span_embeddings_pair = self.get_pooled_output( + model_inputs=sanitized_inputs["remaining"]["encoding_pair"], + pooler_inputs=sanitized_inputs["pooler_pair_"], + ) + + logits_list = [ + get_max_cosine_sim(span_embeds, span_embeds_pair) + for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair) + ] + logits = torch.stack(logits_list) + + result = {"logits": logits} + if targets is not None: + labels = targets["scores"] + loss = self.loss_fct(logits, labels) + result["loss"] = loss + if return_hidden_states: + raise NotImplementedError("return_hidden_states is not yet implemented") + + return SequenceClassifierOutput(**result) + + +@PyTorchIEModel.register() +class SequencePairSimilarityModelWithMaxCosineSimAndAdapter( + SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter +): + pass diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py index 8c9b0b43fc643aa8e1ed8e963cd653082366b6fa..4accfd9be9cfebfd8e08ff86f4e19f101bca9f27 100644 --- a/src/models/utils/__init__.py +++ b/src/models/utils/__init__.py @@ -1 +1,5 @@ -from .loading import load_model_from_pie_model, load_model_with_adapter, load_tokenizer_from_pie_taskmodule \ No newline at end of file +from .loading import ( + load_model_from_pie_model, + load_model_with_adapter, + load_tokenizer_from_pie_taskmodule, +) diff --git a/src/models/utils/loading.py b/src/models/utils/loading.py index 585152823313bd43a06aeeeb928374b0c31c313c..268d92d90ed029e4962c80486954a92b5b1e322e 100644 --- a/src/models/utils/loading.py +++ b/src/models/utils/loading.py @@ -23,10 +23,10 @@ def load_tokenizer_from_pie_taskmodule(taskmodule_kwargs: Dict[str, Any]) -> Pre def load_model_with_adapter( - model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any] -) -> "ModelAdaptersMixin": - from adapters import AutoAdapterModel, ModelAdaptersMixin + model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any] +) -> PreTrainedModel: + from adapters import AutoAdapterModel model = AutoAdapterModel.from_pretrained(**model_kwargs) model.load_adapter(set_active=True, **adapter_kwargs) - return model \ No newline at end of file + return model diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..019c9b0ead0b6c39c94441748e81c6f1774e294a --- /dev/null +++ b/src/pipeline/__init__.py @@ -0,0 +1,2 @@ +from .ner_re_pipeline import NerRePipeline +from .span_retrieval_based_re_pipeline import SpanRetrievalBasedRelationExtractionPipeline diff --git a/src/pipeline/ner_re_pipeline.py b/src/pipeline/ner_re_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5d40627e8bc6b5bf813f252112c226fc15817af5 --- /dev/null +++ b/src/pipeline/ner_re_pipeline.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import logging +from functools import partial +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, TypeVar, Union + +from pie_modules.utils import resolve_type +from pytorch_ie import AutoPipeline, WithDocumentTypeMixin +from pytorch_ie.core import Document + +logger = logging.getLogger(__name__) + + +D = TypeVar("D", bound=Document) + + +def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None: + for layer_name in 
layer_names: + if predictions: + doc[layer_name].predictions.clear() + else: + doc[layer_name].clear() + + +def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None: + for layer_name in layer_names: + annotations = list(doc[layer_name].predictions) + # remove any previous annotations + doc[layer_name].clear() + # each annotation can be attached to just one annotation container, so we need to clear the predictions + doc[layer_name].predictions.clear() + doc[layer_name].extend(annotations) + + +def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None: + for layer_name in layer_names: + annotations = list(doc[layer_name]) + # each annotation can be attached to just one annotation container, so we need to clear the layer + doc[layer_name].clear() + # remove any previous annotations + doc[layer_name].predictions.clear() + doc[layer_name].predictions.extend(annotations) + + +def add_annotations_from_other_documents( + docs: Iterable[D], + other_docs: Sequence[Document], + layer_names: List[str], + from_predictions: bool = False, + to_predictions: bool = False, + clear_before: bool = True, +) -> None: + for i, doc in enumerate(docs): + other_doc = other_docs[i] + # copy to not modify the input + other_doc = type(other_doc).fromdict(other_doc.asdict()) + + for layer_name in layer_names: + if clear_before: + doc[layer_name].clear() + other_layer = other_doc[layer_name] + if from_predictions: + other_layer = other_layer.predictions + other_annotations = list(other_layer) + other_layer.clear() + if to_predictions: + doc[layer_name].predictions.extend(other_annotations) + else: + doc[layer_name].extend(other_annotations) + + +def process_pipeline_steps( + documents: Sequence[Document], + processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]], +) -> Sequence[Document]: + + # call the processors in the order they are provided + for step_name, processor in processors.items(): + logger.info(f"process {step_name} ...") + processed_documents = processor(documents) + if processed_documents is not None: + documents = processed_documents + + return documents + + +def process_documents( + documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs +) -> List[Document]: + result = [] + for doc in documents: + processed_doc = processor(doc, **kwargs) + if processed_doc is not None: + result.append(processed_doc) + else: + result.append(doc) + return result + + +class DummyTaskmodule(WithDocumentTypeMixin): + def __init__(self, document_type: Optional[Union[Type[Document], str]]): + if isinstance(document_type, str): + self._document_type = resolve_type(document_type, expected_super_type=Document) + else: + self._document_type = document_type + + @property + def document_type(self) -> Optional[Type[Document]]: + return self._document_type + + +class NerRePipeline: + def __init__( + self, + ner_model_path: str, + re_model_path: str, + entity_layer: str, + relation_layer: str, + device: Optional[int] = None, + batch_size: Optional[int] = None, + show_progress_bar: Optional[bool] = None, + document_type: Optional[Union[Type[Document], str]] = None, + **processor_kwargs, + ): + self.taskmodule = DummyTaskmodule(document_type) + self.ner_model_path = ner_model_path + self.re_model_path = re_model_path + self.processor_kwargs = processor_kwargs or {} + self.entity_layer = entity_layer + self.relation_layer = relation_layer + # set some values for the inference processors, if provided + for inference_pipeline in ["ner_pipeline", 
"re_pipeline"]: + if inference_pipeline not in self.processor_kwargs: + self.processor_kwargs[inference_pipeline] = {} + if "device" not in self.processor_kwargs[inference_pipeline] and device is not None: + self.processor_kwargs[inference_pipeline]["device"] = device + if ( + "batch_size" not in self.processor_kwargs[inference_pipeline] + and batch_size is not None + ): + self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size + if ( + "show_progress_bar" not in self.processor_kwargs[inference_pipeline] + and show_progress_bar is not None + ): + self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar + + def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]: + + input_docs: Sequence[Document] + # we need to keep the original documents to add the gold data back + original_docs: Sequence[Document] + if inplace: + input_docs = documents + original_docs = [doc.copy() for doc in documents] + else: + input_docs = [doc.copy() for doc in documents] + original_docs = documents + + docs_with_predictions = process_pipeline_steps( + documents=input_docs, + processors={ + "clear_annotations": partial( + process_documents, + processor=clear_annotation_layers, + layer_names=[self.entity_layer, self.relation_layer], + **self.processor_kwargs.get("clear_annotations", {}), + ), + "ner_pipeline": AutoPipeline.from_pretrained( + self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {}) + ), + "use_predicted_entities": partial( + process_documents, + processor=move_annotations_from_predictions, + layer_names=[self.entity_layer], + **self.processor_kwargs.get("use_predicted_entities", {}), + ), + # "create_candidate_relations": partial( + # process_documents, + # processor=CandidateRelationAdder( + # **self.processor_kwargs.get("create_candidate_relations", {}) + # ), + # ), + "re_pipeline": AutoPipeline.from_pretrained( + self.re_model_path, **self.processor_kwargs.get("re_pipeline", {}) + ), + # otherwise we can not move the entities back to predictions + "clear_candidate_relations": partial( + process_documents, + processor=clear_annotation_layers, + layer_names=[self.relation_layer], + **self.processor_kwargs.get("clear_candidate_relations", {}), + ), + "move_entities_to_predictions": partial( + process_documents, + processor=move_annotations_to_predictions, + layer_names=[self.entity_layer], + **self.processor_kwargs.get("move_entities_to_predictions", {}), + ), + "re_add_gold_data": partial( + add_annotations_from_other_documents, + other_docs=original_docs, + layer_names=[self.entity_layer, self.relation_layer], + **self.processor_kwargs.get("re_add_gold_data", {}), + ), + }, + ) + return docs_with_predictions diff --git a/src/pipeline/span_retrieval_based_re_pipeline.py b/src/pipeline/span_retrieval_based_re_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..77f86fdd8bd69cae190b57992ae683b12accfb72 --- /dev/null +++ b/src/pipeline/span_retrieval_based_re_pipeline.py @@ -0,0 +1,130 @@ +import logging +from typing import Optional, Sequence, Type + +from langchain_core.documents import Document as LCDocument +from pie_datasets import Dataset, IterableDataset +from pytorch_ie import Document, WithDocumentTypeMixin +from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.documents import TextBasedDocument + +from src.langchain_modules import DocumentAwareSpanRetriever + +logger = logging.getLogger(__name__) + + +class DummyTaskmodule(WithDocumentTypeMixin): + def 
__init__(self, document_type: Type[Document]): + self._document_type = document_type + + @property + def document_type(self) -> Optional[Type[Document]]: + return self._document_type + + +class SpanRetrievalBasedRelationExtractionPipeline: + """Pipeline for adding binary relations between spans based on span retrieval within the same document. + + This pipeline retrieves spans for all existing spans as query and adds binary relations between the + query spans and the retrieved spans. + + Args: + retriever: The span retriever to use for retrieving spans. + relation_label: The label to use for the binary relations. + relation_layer_name: The name of the annotation layer to add the binary relations to. + load_store_path: If provided, the retriever store(s) will be loaded from this path before processing. + save_store_path: If provided, the retriever store(s) will be saved to this path after processing. + fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the first 2 documents. + """ + + def __init__( + self, + retriever: DocumentAwareSpanRetriever, + relation_label: str, + relation_layer_name: str = "binary_relations", + use_predicted_annotations: bool = False, + load_store_path: Optional[str] = None, + save_store_path: Optional[str] = None, + fast_dev_run: bool = False, + ): + self.retriever = retriever + if not self.retriever.retrieve_from_same_document: + raise NotImplementedError("Retriever must retrieve from the same document") + self.relation_label = relation_label + self.relation_layer_name = relation_layer_name + self.use_predicted_annotations = use_predicted_annotations + self.load_store_path = load_store_path + self.save_store_path = save_store_path + if self.load_store_path is not None: + self.retriever.load_from_directory(path=self.load_store_path) + + self.fast_dev_run = fast_dev_run + + # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type + # from the dataset + @property + def taskmodule(self) -> DummyTaskmodule: + return DummyTaskmodule(self.retriever.pie_document_type) + + def _construct_similarity_relations( + self, + query_results: list[LCDocument], + query_span: LabeledSpan, + ) -> list[BinaryRelation]: + return [ + BinaryRelation( + head=query_span, + tail=lc_doc.metadata["attached_span"], + label=self.relation_label, + score=float(lc_doc.metadata["relevance_score"]), + ) + for lc_doc in query_results + ] + + def _process_single_document( + self, + document: Document, + ) -> TextBasedDocument: + if not isinstance(document, TextBasedDocument): + raise ValueError("Document must be a TextBasedDocument") + + self.retriever.add_pie_documents( + [document], use_predicted_annotations=self.use_predicted_annotations + ) + + all_new_rels = [] + spans = self.retriever.get_base_layer( + document, use_predicted_annotations=self.use_predicted_annotations + ) + span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id) + for span_id, span_idx in span_id2idx.items(): + query_span = spans[span_idx] + query_result = self.retriever.invoke(input=span_id) + query_rels = self._construct_similarity_relations(query_result, query_span=query_span) + all_new_rels.extend(query_rels) + + if self.relation_layer_name not in document: + raise ValueError(f"Document does not have a layer named {self.relation_layer_name}") + document[self.relation_layer_name].predictions.extend(all_new_rels) + + if self.retriever.retrieve_from_same_document and self.save_store_path is None: + self.retriever.delete_documents([document.id]) + + 
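+        # at this point the document carries the newly created similarity-based relations in
+        # document[self.relation_layer_name].predictions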
return document + + def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]: + if inplace: + raise NotImplementedError("Inplace processing is not supported yet") + + if self.fast_dev_run: + logger.warning("Fast dev run enabled, only processing the first 2 documents") + documents = documents[:2] + + if not isinstance(documents, (Dataset, IterableDataset)): + documents = Dataset.from_documents(documents) + + mapped_documents = documents.map(self._process_single_document) + + if self.save_store_path is not None: + self.retriever.save_to_directory(path=self.save_store_path) + + return mapped_documents diff --git a/src/predict.py b/src/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c0f8b8d77566078e9fdbcb83c0fa0f2dc0e1bb --- /dev/null +++ b/src/predict.py @@ -0,0 +1,183 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".project-root"], + pythonpath=True, + dotenv=True, +) + +# ------------------------------------------------------------------------------------ # +# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file +# that helps to make the environment more robust and convenient +# +# the main advantages are: +# - allows you to keep all entry files in "src/" without installing project as a package +# - makes paths and scripts always work no matter where is your current work dir +# - automatically loads environment variables from ".env" file if exists +# +# how it works: +# - the line above recursively searches for either ".git" or "pyproject.toml" in present +# and parent dirs, to determine the project root dir +# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from +# any place without installing project as a package +# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" +# to make all paths always relative to the project root +# - loads environment variables from ".env" file in root dir (if `dotenv=True`) +# +# you can remove `pyrootutils.setup_root(...)` if you: +# 1. either install project as a package or move each entry file to the project root dir +# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" +# 3. 
always run entry files from the project root dir +# +# https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + +import os +import timeit +from collections.abc import Iterable, Sequence +from typing import Any, Dict, Optional, Tuple, Union + +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig, OmegaConf +from pie_datasets import Dataset, DatasetDict +from pie_modules.models import * # noqa: F403 +from pie_modules.taskmodules import * # noqa: F403 +from pytorch_ie import Document, Pipeline +from pytorch_ie.models import * # noqa: F403 +from pytorch_ie.taskmodules import * # noqa: F403 + +from src import utils +from src.models import * # noqa: F403 +from src.serializer.interface import DocumentSerializer +from src.taskmodules import * # noqa: F403 + +log = utils.get_pylogger(__name__) + + +def document_batch_iter( + dataset: Union[Sequence[Document], Iterable[Document]], batch_size: int +) -> Iterable[Sequence[Document]]: + if isinstance(dataset, Sequence): + for i in range(0, len(dataset), batch_size): + yield dataset[i : i + batch_size] + elif isinstance(dataset, Iterable): + docs = [] + for doc in dataset: + docs.append(doc) + if len(docs) == batch_size: + yield docs + docs = [] + if docs: + yield docs + else: + raise ValueError(f"Unsupported dataset type: {type(dataset)}") + + +@utils.task_wrapper +def predict(cfg: DictConfig) -> Tuple[dict, dict]: + """Contains minimal example of the prediction pipeline. Uses a pretrained model to annotate + documents from a dataset and serializes them. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + None + """ + + # Set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + pl.seed_everything(cfg.seed, workers=True) + + # Init pytorch-ie dataset + log.info(f"Instantiating dataset <{cfg.dataset._target_}>") + dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial") + + # Init pytorch-ie pipeline + # The pipeline, and therefore the inference step, is optional to allow for easy testing + # of the dataset creation and processing. + pipeline: Optional[Pipeline] = None + if cfg.get("pipeline") and cfg.pipeline.get("_target_"): + log.info(f"Instantiating pipeline <{cfg.pipeline._target_}> from {cfg.model_name_or_path}") + pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial") + + # Per default, the model is loaded with .from_pretrained() which already loads the weights. + # However, ckpt_path can be used to load different weights from any checkpoint. 
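+    # Illustrative Hydra override (assuming ckpt_path is exposed in configs/predict.yaml):
+    #   python src/predict.py ckpt_path=/path/to/checkpoint.ckpt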
+ if cfg.ckpt_path is not None: + pipeline.model = pipeline.model.load_from_checkpoint(checkpoint_path=cfg.ckpt_path).to( + pipeline.device + ) + + # auto-convert the dataset if the metric specifies a document type + dataset = pipeline.taskmodule.convert_dataset(dataset) + + # Init the serializer + serializer: Optional[DocumentSerializer] = None + if cfg.get("serializer") and cfg.serializer.get("_target_"): + log.info(f"Instantiating serializer <{cfg.serializer._target_}>") + serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial") + + # select the dataset split for prediction + dataset_predict = dataset[cfg.dataset_split] + + object_dict = { + "cfg": cfg, + "dataset": dataset, + "pipeline": pipeline, + "serializer": serializer, + } + result: Dict[str, Any] = {} + if pipeline is not None: + log.info("Starting inference!") + prediction_time = 0.0 + else: + log.warning("No prediction pipeline is defined, skip inference!") + prediction_time = None + document_batch_size = cfg.get("document_batch_size", None) + for docs_batch in ( + document_batch_iter(dataset_predict, document_batch_size) + if document_batch_size + else [dataset_predict] + ): + if pipeline is not None: + t_start = timeit.default_timer() + docs_batch = pipeline(docs_batch, inplace=False) + prediction_time += timeit.default_timer() - t_start # type: ignore + + # serialize the documents + if serializer is not None: + # the serializer should not return the serialized documents, but write them to disk + # and instead return some metadata such as the path to the serialized documents + serializer_result = serializer(docs_batch) + if "serializer" in result and result["serializer"] != serializer_result: + log.warning( + f"serializer result changed from {result['serializer']} to {serializer_result}" + " during prediction. Only the last result is returned." + ) + result["serializer"] = serializer_result + + if prediction_time is not None: + result["prediction_time"] = prediction_time + + # serialize config with resolved paths + if cfg.get("config_out_path"): + config_out_dir = os.path.dirname(cfg.config_out_path) + os.makedirs(config_out_dir, exist_ok=True) + OmegaConf.save(config=cfg, f=cfg.config_out_path) + result["config"] = cfg.config_out_path + + return result, object_dict + + +@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="predict.yaml") +def main(cfg: DictConfig) -> None: + result_dict, _ = predict(cfg) + return result_dict + + +if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() + utils.prepare_omegaconf() + main() diff --git a/src/serializer/__init__.py b/src/serializer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9df4f1769cfa62649f6905d6fee22a14aab8a2e6 --- /dev/null +++ b/src/serializer/__init__.py @@ -0,0 +1 @@ +from .json import JsonSerializer, JsonSerializer2 diff --git a/src/serializer/interface.py b/src/serializer/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba07401bad2d1983ca1b07e6b0fa4fd027ded9d --- /dev/null +++ b/src/serializer/interface.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import Any, Sequence + +from pytorch_ie.core import Document + + +class DocumentSerializer(ABC): + """This defines the interface for a document serializer. + + The serializer should not return the serialized documents, but write them to disk and instead + return some metadata such as the path to the serialized documents. 
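+
+    For example, the JsonSerializer implementation returns a dict with the output path and the
+    file names of the written documents and metadata.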
+ """ + + @abstractmethod + def __call__(self, documents: Sequence[Document]) -> Any: + pass diff --git a/src/serializer/json.py b/src/serializer/json.py new file mode 100644 index 0000000000000000000000000000000000000000..c293ae3691cf402e93dc7c3ba2f0fdf0d25ad0b7 --- /dev/null +++ b/src/serializer/json.py @@ -0,0 +1,179 @@ +import json +import os +from typing import Dict, List, Optional, Sequence, Type, TypeVar + +from pie_datasets import Dataset, DatasetDict, IterableDataset +from pie_datasets.core.dataset_dict import METADATA_FILE_NAME +from pytorch_ie.core import Document +from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type + +from src.serializer.interface import DocumentSerializer +from src.utils import get_pylogger + +log = get_pylogger(__name__) + +D = TypeVar("D", bound=Document) + + +def as_json_lines(file_name: str) -> bool: + if file_name.lower().endswith(".jsonl"): + return True + elif file_name.lower().endswith(".json"): + return False + else: + raise Exception(f"unknown file extension: {file_name}") + + +class JsonSerializer(DocumentSerializer): + def __init__(self, **kwargs): + self.default_kwargs = kwargs + + @classmethod + def write( + cls, + documents: Sequence[Document], + path: str, + file_name: str = "documents.jsonl", + metadata_file_name: str = METADATA_FILE_NAME, + split: Optional[str] = None, + **kwargs, + ) -> Dict[str, str]: + realpath = os.path.realpath(path) + log.info(f'serialize documents to "{realpath}" ...') + os.makedirs(realpath, exist_ok=True) + + # dump metadata including the document_type + if len(documents) == 0: + raise Exception("cannot serialize empty list of documents") + document_type = type(documents[0]) + metadata = {"document_type": serialize_document_type(document_type)} + full_metadata_file_name = os.path.join(realpath, metadata_file_name) + if os.path.exists(full_metadata_file_name): + # load previous metadata + with open(full_metadata_file_name) as f: + previous_metadata = json.load(f) + if previous_metadata != metadata: + raise ValueError( + f"metadata file {full_metadata_file_name} already exists, " + "but the content does not match the current metadata" + "\nprevious metadata: {previous_metadata}" + "\ncurrent metadata: {metadata}" + ) + else: + with open(full_metadata_file_name, "w") as f: + json.dump(metadata, f, indent=2) + + if split is not None: + realpath = os.path.join(realpath, split) + os.makedirs(realpath, exist_ok=True) + full_file_name = os.path.join(realpath, file_name) + if as_json_lines(file_name): + # if the file already exists, append to it + mode = "a" if os.path.exists(full_file_name) else "w" + with open(full_file_name, mode) as f: + for doc in documents: + f.write(json.dumps(doc.asdict(), **kwargs) + "\n") + else: + docs_list = [doc.asdict() for doc in documents] + if os.path.exists(full_file_name): + # load previous documents + with open(full_file_name) as f: + previous_doc_list = json.load(f) + docs_list = previous_doc_list + docs_list + with open(full_file_name, "w") as f: + json.dump(docs_list, fp=f, **kwargs) + return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name} + + @classmethod + def read( + cls, + path: str, + document_type: Optional[Type[D]] = None, + file_name: str = "documents.jsonl", + metadata_file_name: str = METADATA_FILE_NAME, + split: Optional[str] = None, + ) -> List[D]: + realpath = os.path.realpath(path) + log.info(f'load documents from "{realpath}" ...') + + # try to load metadata including the document_type + 
full_metadata_file_name = os.path.join(realpath, metadata_file_name) + if os.path.exists(full_metadata_file_name): + with open(full_metadata_file_name) as f: + metadata = json.load(f) + document_type = resolve_optional_document_type(metadata.get("document_type")) + + if document_type is None: + raise Exception("document_type is required to load serialized documents") + + if split is not None: + realpath = os.path.join(realpath, split) + full_file_name = os.path.join(realpath, file_name) + documents = [] + if as_json_lines(str(file_name)): + with open(full_file_name) as f: + for line in f: + json_dict = json.loads(line) + documents.append(document_type.fromdict(json_dict)) + else: + with open(full_file_name) as f: + json_list = json.load(f) + for json_dict in json_list: + documents.append(document_type.fromdict(json_dict)) + return documents + + def read_with_defaults(self, **kwargs) -> List[D]: + all_kwargs = {**self.default_kwargs, **kwargs} + return self.read(**all_kwargs) + + def write_with_defaults(self, **kwargs) -> Dict[str, str]: + all_kwargs = {**self.default_kwargs, **kwargs} + return self.write(**all_kwargs) + + def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]: + return self.write_with_defaults(documents=documents, **kwargs) + + +class JsonSerializer2(DocumentSerializer): + def __init__(self, **kwargs): + self.default_kwargs = kwargs + + @classmethod + def write( + cls, + documents: Sequence[Document], + path: str, + split: str = "train", + ) -> Dict[str, str]: + if not isinstance(documents, (Dataset, IterableDataset)): + documents = Dataset.from_documents(documents) + dataset_dict = DatasetDict({split: documents}) + dataset_dict.to_json(path=path) + return {"path": path, "split": split} + + @classmethod + def read( + cls, + path: str, + document_type: Optional[Type[D]] = None, + split: Optional[str] = None, + ) -> Dataset[Document]: + dataset_dict = DatasetDict.from_json( + data_dir=path, document_type=document_type, split=split + ) + if split is not None: + return dataset_dict[split] + if len(dataset_dict) == 1: + return dataset_dict[list(dataset_dict.keys())[0]] + raise ValueError(f"multiple splits found in dataset_dict: {list(dataset_dict.keys())}") + + def read_with_defaults(self, **kwargs) -> Sequence[D]: + all_kwargs = {**self.default_kwargs, **kwargs} + return self.read(**all_kwargs) + + def write_with_defaults(self, **kwargs) -> Dict[str, str]: + all_kwargs = {**self.default_kwargs, **kwargs} + return self.write(**all_kwargs) + + def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]: + return self.write_with_defaults(documents=documents, **kwargs) diff --git a/src/start_demo.py b/src/start_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..942c82094d857fd211e3774c44341c44bbe7d582 --- /dev/null +++ b/src/start_demo.py @@ -0,0 +1,578 @@ +import hydra +import pyrootutils +from omegaconf import DictConfig, OmegaConf, SCMode + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".project-root"], + pythonpath=True, + dotenv=True, +) + +import json +import logging + +import gradio as gr +import torch +import yaml + +from src.demo.annotation_utils import load_argumentation_model +from src.demo.backend_utils import ( + download_processed_documents, + process_text_from_arxiv, + process_uploaded_files, + render_annotated_document, + upload_processed_documents, + wrapped_add_annotated_pie_documents_from_dataset, + wrapped_process_text, +) +from src.demo.frontend_utils import ( + 
change_tab, + escape_regex, + get_cell_for_fixed_column_from_df, + get_fix_df_height_css, + open_accordion, + unescape_regex, +) +from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS +from src.demo.retriever_utils import ( + get_document_as_dict, + get_span_annotation, + load_retriever, + retrieve_all_relevant_spans, + retrieve_all_similar_spans, + retrieve_relevant_spans, + retrieve_similar_spans, +) + + +def load_yaml_config(path: str) -> str: + with open(path, "r") as file: + yaml_string = file.read() + config = yaml.safe_load(yaml_string) + return yaml.dump(config) + + +def resolve_config(cfg) -> dict: + return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT) + + +@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml") +def main(cfg: DictConfig) -> None: + + # configure logging + logging.basicConfig() + + # resolve everything in the config to prevent any issues with to json serialization etc. + cfg = resolve_config(cfg) + + example_text = cfg["example_text"] + + default_device = "cuda" if torch.cuda.is_available() else "cpu" + + default_retriever_config_str = load_yaml_config(cfg["default_retriever_config_path"]) + + default_model_name = cfg["default_model_name"] + default_model_revision = cfg["default_model_revision"] + handle_parts_of_same = cfg["handle_parts_of_same"] + + default_arxiv_id = cfg["default_arxiv_id"] + default_load_pie_dataset_kwargs_str = json.dumps( + cfg["default_load_pie_dataset_kwargs"], indent=2 + ) + + default_render_mode = cfg["default_render_mode"] + if default_render_mode not in AVAILABLE_RENDER_MODES: + raise ValueError( + f"Invalid default render mode '{default_render_mode}'. " + f"Choose one of {AVAILABLE_RENDER_MODES}." + ) + default_render_kwargs = cfg["default_render_kwargs"] + + # captions for better readability + default_split_regex = cfg["default_split_regex"] + # map from render mode to the corresponding caption + render_mode2caption = { + render_mode: cfg["render_mode_captions"].get(render_mode, render_mode) + for render_mode in AVAILABLE_RENDER_MODES + } + render_caption2mode = {v: k for k, v in render_mode2caption.items()} + default_min_similarity = cfg["default_min_similarity"] + layer_caption_mapping = cfg["layer_caption_mapping"] + relation_name_mapping = cfg["relation_name_mapping"] + + gr.Info("Loading models ...") + argumentation_model = load_argumentation_model( + model_name=default_model_name, + revision=default_model_revision, + device=default_device, + ) + retriever = load_retriever( + default_retriever_config_str, device=default_device, config_format="yaml" + ) + + with gr.Blocks(css=get_fix_df_height_css(css_class="df-docstore", max_height=300)) as demo: + # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called + # models_state = gr.State((argumentation_model, embedding_model)) + argumentation_model_state = gr.State((argumentation_model,)) + retriever_state = gr.State((retriever,)) + + with gr.Row(): + with gr.Tabs() as left_tabs: + with gr.Tab("User Input", id="user_input") as user_input_tab: + doc_id = gr.Textbox( + label="Document ID", + value="user_input", + ) + doc_text = gr.Textbox( + label="Text", + lines=20, + value=example_text, + ) + + with gr.Accordion("Model Configuration", open=False): + with gr.Accordion("argumentation structure", open=True): + model_name = gr.Textbox( + label="Model Name", + value=default_model_name, + ) + model_revision = gr.Textbox( + label="Model Revision", + 
value=default_model_revision, + ) + load_arg_model_btn = gr.Button("Load Argumentation Model") + + with gr.Accordion("retriever", open=True): + retriever_config = gr.Code( + language="yaml", + label="Retriever Configuration", + value=default_retriever_config_str, + lines=len(default_retriever_config_str.split("\n")), + ) + load_retriever_btn = gr.Button("Load Retriever") + + device = gr.Textbox( + label="Device (e.g. 'cuda' or 'cpu')", + value=default_device, + ) + load_arg_model_btn.click( + fn=lambda _model_name, _model_revision, _device: ( + load_argumentation_model( + model_name=_model_name, + revision=_model_revision, + device=_device, + ), + ), + inputs=[model_name, model_revision, device], + outputs=argumentation_model_state, + ) + load_retriever_btn.click( + fn=lambda _retriever_config, _device, _previous_retriever: ( + load_retriever( + retriever_config_str=_retriever_config, + device=_device, + previous_retriever=_previous_retriever[0], + config_format="yaml", + ), + ), + inputs=[retriever_config, device, retriever_state], + outputs=retriever_state, + ) + + split_regex_escaped = gr.Textbox( + label="Regex to partition the text", + placeholder="Regular expression pattern to split the text into partitions", + value=escape_regex(default_split_regex), + ) + + predict_btn = gr.Button("Analyse") + + with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab: + selected_document_id = gr.Textbox( + label="Document ID", max_lines=1, interactive=False + ) + rendered_output = gr.HTML(label="Rendered Output") + + with gr.Accordion("Render Options", open=False): + render_as = gr.Dropdown( + label="Render with", + choices=list(render_mode2caption.values()), + value=render_mode2caption[default_render_mode], + ) + render_kwargs = gr.Code( + language="json", + label="Render Arguments", + lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")), + value=json.dumps(default_render_kwargs, indent=2), + ) + render_btn = gr.Button("Re-render") + + with gr.Accordion("See plain result ...", open=False): + get_document_json_btn = gr.Button("Fetch annotated document as JSON") + document_json = gr.JSON(label="Model Output") + + with gr.Tabs() as right_tabs: + with gr.Tab("Retrieval", id="retrieval") as retrieval_tab: + with gr.Accordion( + "Indexed Documents", open=False + ) as processed_documents_accordion: + processed_documents_df = gr.DataFrame( + headers=["id", "num_adus", "num_relations"], + interactive=False, + elem_classes="df-docstore", + ) + gr.Markdown("Data Snapshot:") + with gr.Row(): + download_processed_documents_btn = gr.DownloadButton("Download") + upload_processed_documents_btn = gr.UploadButton( + "Upload", file_types=["file"] + ) + + # currently not used + # relation_types = set_relation_types( + # argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"] + # ) + + # Dummy textbox to hold the hover adu id. On click on the rendered output, + # its content will be copied to selected_adu_id which will trigger the retrieval. 
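+                    # (the client-side hover/click handling for the rendered spans is wired up in
+                    # JavaScript, see the HIGHLIGHT_SPANS_JS hook on rendered_output further below)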
+ hover_adu_id = gr.Textbox( + label="ID (hover)", + elem_id="hover_adu_id", + interactive=False, + visible=False, + ) + selected_adu_id = gr.Textbox( + label="ID (selected)", + elem_id="selected_adu_id", + interactive=False, + visible=False, + ) + selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False) + + with gr.Accordion("Relevant ADUs from other documents", open=True): + relevant_adus_df = gr.DataFrame( + headers=[ + "relation", + "adu", + "reference_adu", + "doc_id", + "sim_score", + "rel_score", + ], + interactive=False, + ) + + with gr.Accordion("Retrieval Configuration", open=False): + min_similarity = gr.Slider( + label="Minimum Similarity", + minimum=0.0, + maximum=1.0, + step=0.01, + value=default_min_similarity, + ) + top_k = gr.Slider( + label="Top K", + minimum=2, + maximum=50, + step=1, + value=10, + ) + retrieve_similar_adus_btn = gr.Button( + "Retrieve *similar* ADUs for *selected* ADU" + ) + similar_adus_df = gr.DataFrame( + headers=["doc_id", "adu_id", "score", "text"], interactive=False + ) + retrieve_all_similar_adus_btn = gr.Button( + "Retrieve *similar* ADUs for *all* ADUs in the document" + ) + all_similar_adus_df = gr.DataFrame( + headers=["doc_id", "query_adu_id", "adu_id", "score", "text"], + interactive=False, + ) + retrieve_all_relevant_adus_btn = gr.Button( + "Retrieve *relevant* ADUs for *all* ADUs in the document" + ) + all_relevant_adus_df = gr.DataFrame( + headers=["doc_id", "adu_id", "score", "text"], interactive=False + ) + + with gr.Tab("Import Documents", id="import_documents") as import_documents_tab: + upload_btn = gr.UploadButton( + "Batch Analyse Texts", + file_types=["text"], + file_count="multiple", + ) + + with gr.Accordion("Import text from arXiv", open=False): + arxiv_id = gr.Textbox( + label="arXiv paper ID", + placeholder=f"e.g. 
{default_arxiv_id}", + max_lines=1, + ) + load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False) + load_arxiv_btn = gr.Button( + "Load & Analyse from arXiv", variant="secondary" + ) + + with gr.Accordion( + "Import argument structure annotated PIE dataset", open=False + ): + load_pie_dataset_kwargs_str = gr.Code( + language="json", + label="Parameters for Loading the PIE Dataset", + value=default_load_pie_dataset_kwargs_str, + lines=len(default_load_pie_dataset_kwargs_str.split("\n")), + ) + load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset") + + render_event_kwargs = dict( + fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document( + retriever=_retriever[0], + document_id=_document_id, + render_with=render_caption2mode[_render_as], + render_kwargs_json=_render_kwargs, + ), + inputs=[retriever_state, selected_document_id, render_as, render_kwargs], + outputs=rendered_output, + ) + + show_overview_kwargs = dict( + fn=lambda _retriever: _retriever[0].docstore.overview( + layer_captions=layer_caption_mapping, use_predictions=True + ), + inputs=[retriever_state], + outputs=[processed_documents_df], + ) + predict_btn.click( + fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs] + ).then( + fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text( + text=_doc_text, + doc_id=_doc_id, + argumentation_model=_argumentation_model[0], + retriever=_retriever[0], + split_regex_escaped=( + unescape_regex(_split_regex_escaped) if _split_regex_escaped else None + ), + handle_parts_of_same=handle_parts_of_same, + ), + inputs=[ + doc_text, + doc_id, + argumentation_model_state, + retriever_state, + split_regex_escaped, + ], + outputs=[selected_document_id], + api_name="predict", + ).success( + **show_overview_kwargs + ).success( + **render_event_kwargs + ) + render_btn.click(**render_event_kwargs, api_name="render") + + load_arxiv_btn.click( + fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs] + ).then( + fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv( + arxiv_id=_arxiv_id.strip() or default_arxiv_id, + abstract_only=_load_arxiv_only_abstract, + argumentation_model=_argumentation_model[0], + retriever=_retriever[0], + split_regex_escaped=( + unescape_regex(_split_regex_escaped) if _split_regex_escaped else None + ), + handle_parts_of_same=handle_parts_of_same, + ), + inputs=[ + arxiv_id, + load_arxiv_only_abstract, + argumentation_model_state, + retriever_state, + split_regex_escaped, + ], + outputs=[selected_document_id], + api_name="predict", + ).success( + **show_overview_kwargs + ) + + load_pie_dataset_btn.click( + fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs] + ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then( + fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset( + retriever=_retriever[0], + verbose=True, + layer_captions=layer_caption_mapping, + **json.loads(_load_pie_dataset_kwargs_str), + ), + inputs=[retriever_state, load_pie_dataset_kwargs_str], + outputs=[processed_documents_df], + ) + + selected_document_id.change( + fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs] + ).then(**render_event_kwargs) + + get_document_json_btn.click( + fn=lambda _retriever, _document_id: get_document_as_dict( + 
retriever=_retriever[0], doc_id=_document_id + ), + inputs=[retriever_state, selected_document_id], + outputs=[document_json], + ) + + upload_btn.upload( + fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs] + ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then( + fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files( + file_names=_file_names, + argumentation_model=_argumentation_model[0], + retriever=_retriever[0], + split_regex_escaped=unescape_regex(_split_regex_escaped), + handle_parts_of_same=handle_parts_of_same, + layer_captions=layer_caption_mapping, + ), + inputs=[ + upload_btn, + argumentation_model_state, + retriever_state, + split_regex_escaped, + ], + outputs=[processed_documents_df], + ) + processed_documents_df.select( + fn=get_cell_for_fixed_column_from_df, + inputs=[processed_documents_df, gr.State("doc_id")], + outputs=[selected_document_id], + ) + + download_processed_documents_btn.click( + fn=lambda _retriever: download_processed_documents( + _retriever[0], file_name="processed_documents" + ), + inputs=[retriever_state], + outputs=[download_processed_documents_btn], + ) + upload_processed_documents_btn.upload( + fn=lambda file_name, _retriever: upload_processed_documents( + file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping + ), + inputs=[upload_processed_documents_btn, retriever_state], + outputs=[processed_documents_df], + ) + + retrieve_relevant_adus_event_kwargs = dict( + fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans( + retriever=_retriever[0], + query_span_id=_selected_adu_id, + k=_top_k, + score_threshold=_min_similarity, + relation_label_mapping=relation_name_mapping, + # columns=relevant_adus.headers + ), + inputs=[ + retriever_state, + selected_adu_id, + min_similarity, + top_k, + ], + outputs=[relevant_adus_df], + ) + relevant_adus_df.select( + fn=get_cell_for_fixed_column_from_df, + inputs=[relevant_adus_df, gr.State("doc_id")], + outputs=[selected_document_id], + ) + + selected_adu_id.change( + fn=lambda _retriever, _selected_adu_id: get_span_annotation( + retriever=_retriever[0], span_id=_selected_adu_id + ), + inputs=[retriever_state, selected_adu_id], + outputs=[selected_adu_text], + ).success(**retrieve_relevant_adus_event_kwargs) + + retrieve_similar_adus_btn.click( + fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k: retrieve_similar_spans( + retriever=_retriever[0], + query_span_id=_selected_adu_id, + k=_tok_k, + score_threshold=_min_similarity, + ), + inputs=[ + retriever_state, + selected_adu_id, + min_similarity, + top_k, + ], + outputs=[similar_adus_df], + ) + similar_adus_df.select( + fn=get_cell_for_fixed_column_from_df, + inputs=[similar_adus_df, gr.State("doc_id")], + outputs=[selected_document_id], + ) + + retrieve_all_similar_adus_btn.click( + fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_similar_spans( + retriever=_retriever[0], + query_doc_id=_document_id, + k=_tok_k, + score_threshold=_min_similarity, + query_span_id_column="query_span_id", + ), + inputs=[ + retriever_state, + selected_document_id, + min_similarity, + top_k, + ], + outputs=[all_similar_adus_df], + ) + + retrieve_all_relevant_adus_btn.click( + fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_relevant_spans( + retriever=_retriever[0], + query_doc_id=_document_id, + k=_tok_k, + score_threshold=_min_similarity, + 
query_span_id_column="query_span_id", + ), + inputs=[ + retriever_state, + selected_document_id, + min_similarity, + top_k, + ], + outputs=[all_relevant_adus_df], + ) + + # select query span id from the "retrieve all" result data frames + all_similar_adus_df.select( + fn=get_cell_for_fixed_column_from_df, + inputs=[all_similar_adus_df, gr.State("query_span_id")], + outputs=[selected_adu_id], + ) + all_relevant_adus_df.select( + fn=get_cell_for_fixed_column_from_df, + inputs=[all_relevant_adus_df, gr.State("query_span_id")], + outputs=[selected_adu_id], + ) + + # argumentation_model_state.change( + # fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]), + # inputs=[argumentation_model_state], + # outputs=[relation_types], + # ) + + rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[]) + + demo.launch() + + +if __name__ == "__main__": + + main() diff --git a/src/taskmodules/__init__.py b/src/taskmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2045b56af4b0f8cdb9962cce2559bf1b18cdd966 --- /dev/null +++ b/src/taskmodules/__init__.py @@ -0,0 +1,8 @@ +from .cross_text_binary_coref import CrossTextBinaryCorefTaskModuleWithOptionalContext +from .cross_text_binary_coref_nli import CrossTextBinaryCorefTaskModuleByNli +from .re_text_classification_with_indices import ( + CrossTextBinaryCorefByRETextClassificationTaskModule, + RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers, +) + +CrossTextBinaryCorefTaskModule2 = CrossTextBinaryCorefByRETextClassificationTaskModule diff --git a/src/taskmodules/components/__init__.py b/src/taskmodules/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/taskmodules/cross_text_binary_coref.py b/src/taskmodules/cross_text_binary_coref.py new file mode 100644 index 0000000000000000000000000000000000000000..04b32bb63ac92e3c140ad30f0e9319f7960832f8 --- /dev/null +++ b/src/taskmodules/cross_text_binary_coref.py @@ -0,0 +1,116 @@ +import logging +from typing import Optional, Sequence, TypeVar, Union + +from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule +from pie_modules.taskmodules.cross_text_binary_coref import ( + DocumentType, + SpanDoesNotFitIntoAvailableWindow, + TaskEncodingType, +) +from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException +from pytorch_ie.annotations import Span +from pytorch_ie.core import TaskEncoding, TaskModule + +logger = logging.getLogger(__name__) + + +S = TypeVar("S", bound=Span) + + +def shift_span(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +@TaskModule.register() +class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule): + """Same as CrossTextBinaryCorefTaskModule, but: + - optionally without context. 
+ """ + + def __init__( + self, + without_context: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.without_context = without_context + + def encode_input( + self, + document: DocumentType, + is_training: bool = False, + ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: + if self.without_context: + return self.encode_input_without_context(document) + else: + return super().encode_input(document) + + def encode_input_without_context( + self, document: DocumentType + ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: + self.collect_all_relations(kind="available", relations=document.binary_coref_relations) + tokenizer_kwargs = dict( + padding=False, + truncation=False, + add_special_tokens=False, + ) + + task_encodings = [] + for coref_rel in document.binary_coref_relations: + + # TODO: This can miss instances if both texts are the same. We could check that + # coref_rel.head is in document.labeled_spans (same for the tail), but would this + # slow down the encoding? + if not ( + coref_rel.head.target == document.text + or coref_rel.tail.target == document.text_pair + ): + raise ValueError( + f"It is expected that coref relations go from (head) spans over 'text' " + f"to (tail) spans over 'text_pair', but this is not the case for this " + f"relation (i.e. it points into the other direction): {coref_rel.resolve()}" + ) + encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs) + encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs) + + try: + current_encoding, token_span = self.truncate_encoding_around_span( + encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start) + ) + current_encoding_pair, token_span_pair = self.truncate_encoding_around_span( + encoding=encoding_pair, + char_span=shift_span(coref_rel.tail, -coref_rel.tail.start), + ) + except SpanNotAlignedWithTokenException as e: + logger.warning( + f"Could not get token offsets for argument ({e.span}) of coref relation: " + f"{coref_rel.resolve()}. Skip it." + ) + self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel) + continue + except SpanDoesNotFitIntoAvailableWindow as e: + logger.warning( + f"Argument span [{e.span}] does not fit into available token window " + f"({self.available_window}). Skip it." 
+ ) + self.collect_relation( + kind="skipped_span_does_not_fit_into_window", relation=coref_rel + ) + continue + + task_encodings.append( + TaskEncoding( + document=document, + inputs={ + "encoding": current_encoding, + "encoding_pair": current_encoding_pair, + "pooler_start_indices": token_span.start, + "pooler_end_indices": token_span.end, + "pooler_pair_start_indices": token_span_pair.start, + "pooler_pair_end_indices": token_span_pair.end, + }, + metadata={"candidate_annotation": coref_rel}, + ) + ) + self.collect_relation("used", coref_rel) + return task_encodings diff --git a/src/taskmodules/cross_text_binary_coref_nli.py b/src/taskmodules/cross_text_binary_coref_nli.py new file mode 100644 index 0000000000000000000000000000000000000000..ad662031e6d90055fa0f3694267da2b1d721b187 --- /dev/null +++ b/src/taskmodules/cross_text_binary_coref_nli.py @@ -0,0 +1,166 @@ +import logging +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, Union + +import torch +from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations +from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin +from pytorch_ie import Annotation +from pytorch_ie.core import TaskEncoding, TaskModule +from transformers import AutoTokenizer +from typing_extensions import TypeAlias + +logger = logging.getLogger(__name__) + +InputEncodingType: TypeAlias = Dict[str, Any] +TargetEncodingType: TypeAlias = Sequence[float] +DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations + +TaskEncodingType: TypeAlias = TaskEncoding[ + DocumentType, + InputEncodingType, + TargetEncodingType, +] + + +class TaskOutputType(TypedDict, total=False): + label_pair: Tuple[str, str] + entailment_probability_pair: Tuple[float, float] + + +ModelInputType: TypeAlias = Dict[str, torch.Tensor] +ModelTargetType: TypeAlias = Dict[str, torch.Tensor] +ModelOutputType: TypeAlias = Dict[str, torch.Tensor] + +TaskModuleType: TypeAlias = TaskModule[ + # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput + DocumentType, + InputEncodingType, + TargetEncodingType, + Tuple[ModelInputType, Optional[ModelTargetType]], + ModelTargetType, + TaskOutputType, +] + + +@TaskModule.register() +class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleType): + """This taskmodule processes documents of type + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a sequence + classification model trained for NLI. The assumption is that if the entailment class is + predicted for both directions, a coreference relation exists between the two spans. + + It simply tokenizes and encodes the head and tail texts of the coreference relations as text + pairs, i.e. no context of head and tail is considered. During decoding, coreference relations + are created if the entailment class (see parameter entailment_label) is predicted for both + directions and the average probability is used as the score. 
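+
+    For example, entailment probabilities of 0.9 and 0.7 for the two directions result in a
+    coreference score of (0.9 + 0.7) / 2 = 0.8.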
+ """ + + DOCUMENT_TYPE = DocumentType + + def __init__( + self, + tokenizer_name_or_path: str, + labels: List[str], + entailment_label: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.save_hyperparameters() + + self.labels = labels + self.entailment_label = entailment_label + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + + def _post_prepare(self): + self.id_to_label = dict(enumerate(self.labels)) + self.label_to_id = {v: k for k, v in self.id_to_label.items()} + self.entailment_idx = self.label_to_id[self.entailment_label] + + def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs): + self.reset_statistics() + result = super().encode(documents=documents, **kwargs) + self.show_statistics() + return result + + def encode_input( + self, + document: DocumentType, + is_training: bool = False, + ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: + self.collect_all_relations(kind="available", relations=document.binary_coref_relations) + result = [] + for coref_rel in document.binary_coref_relations: + head_text = str(coref_rel.head) + tail_text = str(coref_rel.tail) + task_encoding = TaskEncoding( + document=document, + inputs={"text": [head_text, tail_text], "text_pair": [tail_text, head_text]}, + metadata={"candidate_annotation": coref_rel}, + ) + result.append(task_encoding) + self.collect_relation("used", coref_rel) + return result + + def encode_target( + self, + task_encoding: TaskEncodingType, + ) -> Optional[TargetEncodingType]: + raise NotImplementedError() + + def collate( + self, + task_encodings: Sequence[ + TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType] + ], + ) -> Tuple[ModelInputType, Optional[ModelTargetType]]: + all_texts = [] + all_texts_pair = [] + for task_encoding in task_encodings: + all_texts.extend(task_encoding.inputs["text"]) + all_texts_pair.extend(task_encoding.inputs["text_pair"]) + inputs = self.tokenizer( + text=all_texts, + text_pair=all_texts_pair, + truncation=True, + padding=True, + return_tensors="pt", + ) + if not task_encodings[0].has_targets: + return inputs, None + raise NotImplementedError() + + def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: + probs_tensor = model_output["probabilities"] + labels_tensor = model_output["labels"] + + bs, num_classes = probs_tensor.size() + # Reshape the probs tensor to (bs/2, 2, num_classes) + probs_paired = probs_tensor.view(bs // 2, 2, num_classes).detach().cpu().tolist() + + # Reshape the labels tensor to (bs/2, 2) + labels_paired = labels_tensor.view(bs // 2, 2).detach().cpu().tolist() + + result = [] + for (label_id, label_id_pair), (probs_list, probs_list_pair) in zip( + labels_paired, probs_paired + ): + task_output: TaskOutputType = { + "label_pair": (self.id_to_label[label_id], self.id_to_label[label_id_pair]), + "entailment_probability_pair": ( + probs_list[self.entailment_idx], + probs_list_pair[self.entailment_idx], + ), + } + result.append(task_output) + return result + + def create_annotations_from_output( + self, + task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], + task_output: TaskOutputType, + ) -> Iterator[Tuple[str, Annotation]]: + if all(label == self.entailment_label for label in task_output["label_pair"]): + probs = task_output["entailment_probability_pair"] + score = (probs[0] + probs[1]) / 2 + new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score) + yield "binary_coref_relations", new_coref_rel diff 
--git a/src/taskmodules/re_text_classification_with_indices.py b/src/taskmodules/re_text_classification_with_indices.py new file mode 100644 index 0000000000000000000000000000000000000000..5993cb21709a02c253851d9ddd2c09be4b03dcd4 --- /dev/null +++ b/src/taskmodules/re_text_classification_with_indices.py @@ -0,0 +1,176 @@ +import copy +from itertools import chain +from typing import Dict, Optional, Sequence, Type + +import torch +from pie_modules.annotations import BinaryCorefRelation +from pie_modules.document.processing.text_pair import shift_span +from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations +from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule +from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter +from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory +from pie_modules.taskmodules.re_text_classification_with_indices import ( + ModelTargetType as REModelTargetType, +) +from pie_modules.taskmodules.re_text_classification_with_indices import ( + TaskOutputType as RETaskOutputType, +) +from pytorch_ie import Document, TaskModule +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations + + +class SharpBracketMarkerFactory(MarkerFactory): + def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str: + result = "<" + if not is_start: + result += "/" + result += self._get_role_marker(role) + if label is not None: + result += f":{label}" + result += ">" + return result + + def get_append_marker(self, role: str, label: Optional[str] = None) -> str: + role_marker = self._get_role_marker(role) + if label is None: + return f"<{role_marker}>" + else: + return f"<{role_marker}={label}>" + + +@TaskModule.register() +class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers( + RETextClassificationWithIndicesTaskModule +): + def __init__(self, use_sharp_marker: bool = False, **kwargs): + super().__init__(**kwargs) + self.use_sharp_marker = use_sharp_marker + + def get_marker_factory(self) -> MarkerFactory: + if self.use_sharp_marker: + return SharpBracketMarkerFactory(role_to_marker=self.argument_role_to_marker) + else: + return MarkerFactory(role_to_marker=self.argument_role_to_marker) + + +def construct_text_document_from_text_pair_coref_document( + document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, + glue_text: str, + no_relation_label: str, + relation_label_mapping: Optional[Dict[str, str]] = None, + add_span_mapping_to_metadata: bool = False, +) -> TextDocumentWithLabeledSpansAndBinaryRelations: + if document.text == document.text_pair: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text + ) + old2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + new2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + for old_span in chain(document.labeled_spans, document.labeled_spans_pair): + new_span = old_span.copy() + # when detaching / copying the span, it may be the same as a previous span from the other + new_span = new2new_spans.get(new_span, new_span) + new2new_spans[new_span] = new_span + old2new_spans[old_span] = new_span + else: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + text=document.text + glue_text + document.text_pair, + id=document.id, + metadata=copy.deepcopy(document.metadata), + ) + old2new_spans = {} + old2new_spans.update({span: span.copy() for 
span in document.labeled_spans}) + offset = len(document.text) + len(glue_text) + old2new_spans.update( + {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair} + ) + + # sort to make order deterministic + new_doc.labeled_spans.extend( + sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label)) + ) + for old_rel in document.binary_coref_relations: + label = old_rel.label if old_rel.score > 0.0 else no_relation_label + if relation_label_mapping is not None: + label = relation_label_mapping.get(label, label) + new_rel = old_rel.copy( + head=old2new_spans[old_rel.head], + tail=old2new_spans[old_rel.tail], + label=label, + score=1.0, + ) + new_doc.binary_relations.append(new_rel) + + if add_span_mapping_to_metadata: + new_doc.metadata["span_mapping"] = old2new_spans + return new_doc + + +@TaskModule.register() +class CrossTextBinaryCorefByRETextClassificationTaskModule( + TaskModuleWithDocumentConverter, + RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers, +): + def __init__( + self, + coref_relation_label: str, + relation_annotation: str = "binary_relations", + probability_threshold: float = 0.0, + **kwargs, + ): + if relation_annotation != "binary_relations": + raise ValueError( + f"{type(self).__name__} requires relation_annotation='binary_relations', " + f"but it is: {relation_annotation}" + ) + super().__init__(relation_annotation=relation_annotation, **kwargs) + self.coref_relation_label = coref_relation_label + self.probability_threshold = probability_threshold + + @property + def document_type(self) -> Optional[Type[Document]]: + return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations + + def _get_glue_text(self) -> str: + result = self.tokenizer.decode(self._get_glue_token_ids()) + return result + + def _convert_document( + self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations + ) -> TextDocumentWithLabeledSpansAndBinaryRelations: + return construct_text_document_from_text_pair_coref_document( + document, + glue_text=self._get_glue_text(), + relation_label_mapping={"coref": self.coref_relation_label}, + no_relation_label=self.none_label, + add_span_mapping_to_metadata=True, + ) + + def _integrate_predictions_from_converted_document( + self, + document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, + converted_document: TextDocumentWithLabeledSpansAndBinaryRelations, + ) -> None: + original2converted_span = converted_document.metadata["span_mapping"] + new2original_span = { + converted_s: orig_s for orig_s, converted_s in original2converted_span.items() + } + + for rel in converted_document.binary_relations.predictions: + original_head = new2original_span[rel.head] + original_tail = new2original_span[rel.tail] + if rel.label != self.coref_relation_label: + raise ValueError(f"unexpected label: {rel.label}") + if rel.score >= self.probability_threshold: + original_predicted_rel = BinaryCorefRelation( + head=original_head, tail=original_tail, label="coref", score=rel.score + ) + document.binary_coref_relations.predictions.append(original_predicted_rel) + + def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]: + coref_relation_idx = self.label_to_id[self.coref_relation_label] + # we are just concerned with the coref class, so we overwrite the labels field + model_output = copy.copy(model_output) + model_output["labels"] = torch.ones_like(model_output["labels"]) * coref_relation_idx + return super().unbatch_output(model_output=model_output) diff --git 
a/src/train.py b/src/train.py new file mode 100644 index 0000000000000000000000000000000000000000..31c55f55927d1ea25f5f1d70dd3c810191ef2b9c --- /dev/null +++ b/src/train.py @@ -0,0 +1,294 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".project-root"], + pythonpath=True, + dotenv=True, +) + +# ------------------------------------------------------------------------------------ # +# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file +# that helps to make the environment more robust and convenient +# +# the main advantages are: +# - allows you to keep all entry files in "src/" without installing project as a package +# - makes paths and scripts always work no matter where is your current work dir +# - automatically loads environment variables from ".env" file if exists +# +# how it works: +# - the line above recursively searches for either ".git" or "pyproject.toml" in present +# and parent dirs, to determine the project root dir +# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from +# any place without installing project as a package +# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" +# to make all paths always relative to the project root +# - loads environment variables from ".env" file in root dir (if `dotenv=True`) +# +# you can remove `pyrootutils.setup_root(...)` if you: +# 1. either install project as a package or move each entry file to the project root dir +# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" +# 3. always run entry files from the project root dir +# +# https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + +import os.path +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig +from pie_datasets import DatasetDict +from pie_modules.models import * # noqa: F403 +from pie_modules.models import SimpleGenerativeModel +from pie_modules.models.interface import RequiresTaskmoduleConfig +from pie_modules.taskmodules import * # noqa: F403 +from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE +from pytorch_ie.core import PyTorchIEModel, TaskModule +from pytorch_ie.models import * # noqa: F403 +from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses +from pytorch_ie.taskmodules import * # noqa: F403 +from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.loggers import Logger + +from src import utils +from src.datamodules import PieDataModule +from src.models import * # noqa: F403 +from src.taskmodules import * # noqa: F403 + +log = utils.get_pylogger(__name__) + + +def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +@utils.task_wrapper +def train(cfg: DictConfig) -> Tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + pl.seed_everything(cfg.seed, workers=True) + + # Init pytorch-ie taskmodule + log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>") + taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial") + + # Init pytorch-ie dataset + log.info(f"Instantiating dataset <{cfg.dataset._target_}>") + dataset: DatasetDict = hydra.utils.instantiate( + cfg.dataset, + _convert_="partial", + ) + + # auto-convert the dataset if the taskmodule specifies a document type + dataset = taskmodule.convert_dataset(dataset) + + # Init pytorch-ie datamodule + log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") + datamodule: PieDataModule = hydra.utils.instantiate( + cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial" + ) + # Use the train dataset split to prepare the taskmodule + taskmodule.prepare(dataset[datamodule.train_split]) + + # Init the pytorch-ie model + log.info(f"Instantiating model <{cfg.model._target_}>") + # get additional model arguments + additional_model_kwargs: Dict[str, Any] = {} + model_cls = hydra.utils.get_class(cfg.model["_target_"]) + # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE! + # SEE EXAMPLES BELOW. + if issubclass(model_cls, RequiresNumClasses): + additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id) + if issubclass(model_cls, RequiresModelNameOrPath): + if "model_name_or_path" not in cfg.model: + raise Exception( + f"Please specify model_name_or_path in the model config for {model_cls.__name__}." + ) + if isinstance(taskmodule, ChangesTokenizerVocabSize): + additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer) + + pooler_config = cfg["model"].get("pooler") + if pooler_config is not None: + if isinstance(pooler_config, str): + pooler_config = {"type": pooler_config} + pooler_config = dict(pooler_config) + if pooler_config["type"] in ["start_tokens", "mention_pooling"]: + # NOTE: This is very hacky, we should create a new interface class, e.g. RequiresPoolerNumIndices + if hasattr(taskmodule, "argument_role2idx"): + pooler_config["num_indices"] = len(taskmodule.argument_role2idx) + else: + pooler_config["num_indices"] = 1 + elif pooler_config["type"] == "cls_token": + pass + else: + raise Exception( + f"unknown pooler type: {pooler_config['type']}. Please adjust the train.py script for that type." + ) + additional_model_kwargs["pooler"] = pooler_config + + if issubclass(model_cls, RequiresTaskmoduleConfig): + additional_model_kwargs["taskmodule_config"] = taskmodule.config + + if model_cls == SimpleGenerativeModel: + # There may be already some base_model_config entries in the model config. Also need to convert the + # base_model_config to a dict, because it is a OmegaConf object which does not accept additional entries. 
+ base_model_config = ( + dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {} + ) + if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE): + base_model_config.update( + dict( + bos_token_id=taskmodule.bos_id, + eos_token_id=taskmodule.eos_id, + pad_token_id=taskmodule.eos_id, + target_token_ids=taskmodule.target_token_ids, + embedding_weight_mapping=taskmodule.label_embedding_weight_mapping, + ) + ) + additional_model_kwargs["base_model_config"] = base_model_config + + # initialize the model + model: PyTorchIEModel = hydra.utils.instantiate( + cfg.model, _convert_="partial", **additional_model_kwargs + ) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks") + + log.info("Instantiating loggers...") + logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger") + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "dataset": dataset, + "taskmodule": taskmodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg) + + if cfg.model_save_dir is not None: + log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]") + taskmodule.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub) + else: + log.warning("the taskmodule is not saved because no save_dir is specified") + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + best_ckpt_path = trainer.checkpoint_callback.best_model_path + if best_ckpt_path != "": + log.info(f"Best ckpt path: {best_ckpt_path}") + best_checkpoint_file = os.path.basename(best_ckpt_path) + utils.log_hyperparameters( + logger=logger, + best_checkpoint=best_checkpoint_file, + checkpoint_dir=trainer.checkpoint_callback.dirpath, + ) + + if not cfg.trainer.get("fast_dev_run"): + if cfg.model_save_dir is not None: + if best_ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for saving...") + else: + model = type(model).load_from_checkpoint(best_ckpt_path) + + log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]") + model.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub) + else: + log.warning("the model is not saved because no save_dir is specified") + + if cfg.get("validate"): + log.info("Starting validation!") + if best_ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for validation...") + trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None) + elif cfg.get("train"): + log.warning( + "Validation after training is skipped! That means, the finally reported validation scores are " + "the values from the *last* checkpoint, not from the *best* checkpoint (which is saved)!" + ) + + if cfg.get("test"): + log.info("Starting testing!") + if best_ckpt_path == "": + log.warning("Best ckpt not found! 
Using current weights for testing...") + trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None) + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + # add model_save_dir to the result so that it gets dumped to job_return_value.json + # if we use hydra_callbacks.SaveJobReturnValueCallback + if cfg.get("model_save_dir") is not None: + metric_dict["model_save_dir"] = cfg.model_save_dir + + return metric_dict, object_dict + + +@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + if cfg.get("optimized_metric") is not None: + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + else: + return metric_dict + + +if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() + utils.prepare_omegaconf() + main() diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e526c3452c5530f02449d50aa20350b04cc08d --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,6 @@ +from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf +from .data_utils import download_and_unzip, filter_dataframe_and_get_column +from .logging_utils import close_loggers, get_pylogger, log_hyperparameters +from .rich_utils import enforce_tags, print_config_tree +from .span_utils import distance +from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper diff --git a/src/utils/config_utils.py b/src/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..290faf6b07e55d71b1cee26831747ddb50cd5c6d --- /dev/null +++ b/src/utils/config_utils.py @@ -0,0 +1,71 @@ +from copy import copy +from typing import Any, List, Optional + +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from src.utils.logging_utils import get_pylogger + +logger = get_pylogger(__name__) + + +def execute_pipeline( + input: Any, + setup: Optional[Any] = None, + **processors, +) -> Any: + if setup is not None and callable(setup): + setup() + result = input + for processor_name, processor_config in processors.items(): + if not isinstance(processor_config, dict) or "_processor_" not in processor_config: + continue + logger.info(f"call processor: {processor_name}") + config = copy(processor_config) + if not config.pop("_enabled_", True): + logger.warning(f"skip processor because it is disabled: {processor_name}") + continue + # rename key "_processor_" to "_target_" + if "_target_" in config: + raise ValueError( + f"processor {processor_name} has a key '_target_', which is not allowed" + ) + config["_target_"] = config.pop("_processor_") + # IMPORTANT: We pass result as the first argument after the config in contrast to adding it to the config. 
+ # By doing so, we prevent that it gets converted into a OmegaConf object which would be converted back to + # a simple dict breaking all the DatasetDict methods + tmp_result = instantiate(config, result, _convert_="partial") + if tmp_result is not None: + result = tmp_result + else: + logger.warning(f'processor "{processor_name}" did not return a result') + return result + + +def instantiate_dict_entries( + config: DictConfig, key: str, entry_description: Optional[str] = None +) -> List: + entries: List = [] + key_config = config.get(key) + + if not key_config: + logger.warning(f"{key} config is empty.") + return entries + + if not isinstance(key_config, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, entry_conf in key_config.items(): + if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf: + logger.info(f"Instantiating {entry_description or key} <{entry_conf._target_}>") + entries.append(instantiate(entry_conf, _convert_="partial")) + + return entries + + +def prepare_omegaconf(): + # register replace resolver (used to replace "/" with "-" in names to use them as e.g. wandb project names) + if not OmegaConf.has_resolver("replace"): + OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y)) + else: + logger.warning("OmegaConf resolver 'replace' is already registered") diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a89c82ffdd6b9f154c5c0429296616247361198 --- /dev/null +++ b/src/utils/data_utils.py @@ -0,0 +1,33 @@ +import os +from typing import List +from urllib.request import urlretrieve +from zipfile import ZipFile + +import pandas as pd + +from src.utils.logging_utils import get_pylogger + +log = get_pylogger(__name__) + + +def filter_dataframe_and_get_column( + dataframe: pd.DataFrame, filter_column: str, filter_value: str, select_column: str +) -> List[str]: + return dataframe[dataframe[filter_column] == filter_value][select_column].tolist() + + +def download_and_unzip( + url: str, target_path: str, force_download: bool = False, remove_tmp_file: bool = False +): + log.warning(f"download zip file from {url} to {target_path} ...") + if not (url.startswith("http://") or url.startswith("https://")): + raise ValueError(f"url needs to point to a http(s) address, but it is: {url}") + tmp_file = os.path.join(target_path, os.path.basename(url)) + if os.path.exists(tmp_file) and not force_download: + log.warning(f"tmp file {tmp_file} already exists, skip downloading {url}") + else: + urlretrieve(url, tmp_file) # nosec + with ZipFile(tmp_file, "r") as zfile: + zfile.extractall(target_path) + if remove_tmp_file: + os.remove(tmp_file) diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2dfa9d208d2c399a423f83eacfcd6735fd9c6b74 --- /dev/null +++ b/src/utils/logging_utils.py @@ -0,0 +1,94 @@ +import logging +from importlib.util import find_spec +from typing import List, Optional, Union + +from omegaconf import DictConfig, OmegaConf +from pie_modules.models.interface import RequiresTaskmoduleConfig +from pytorch_ie import PyTorchIEModel, TaskModule +from pytorch_lightning.loggers import Logger +from pytorch_lightning.utilities import rank_zero_only + + +def get_pylogger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank 
zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +log = get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters( + logger: Optional[List[Logger]] = None, + config: Optional[Union[dict, DictConfig]] = None, + model: Optional[PyTorchIEModel] = None, + taskmodule: Optional[TaskModule] = None, + key_prefix: str = "_", + **kwargs, +) -> None: + """Controls which config parts are saved by lightning loggers. + + Additional saves: + - Number of model parameters + """ + + hparams = {} + + if not logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + # this is just for backwards compatibility: usually, the taskmodule_config should be passed to + # the model and, thus, be logged there automatically + if model is not None and not isinstance(model, RequiresTaskmoduleConfig): + if taskmodule is None: + raise ValueError( + "If model is not an instance of RequiresTaskmoduleConfig, taskmodule must be passed!" + ) + # here we use the taskmodule/model config how it is after preparation/initialization + hparams["taskmodule_config"] = taskmodule.config + + if model is not None: + # save number of model parameters + hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters()) + hparams[f"{key_prefix}num_params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams[f"{key_prefix}num_params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + if config is not None: + hparams[f"{key_prefix}config"] = ( + OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config + ) + + # add additional hparams + for k, v in kwargs.items(): + hparams[f"{key_prefix}{k}"] = v + + # send hparams to all loggers + for current_logger in logger: + current_logger.log_hyperparams(hparams) + + +def close_loggers() -> None: + """Makes sure all loggers closed properly (prevents logging failure during multirun).""" + + log.info("Closing loggers...") + + if find_spec("wandb"): # if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..706df1676e4e3def6b6522fe86883d5894178af0 --- /dev/null +++ b/src/utils/rich_utils.py @@ -0,0 +1,110 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only +from rich.prompt import Prompt + +from src.utils.logging_utils import get_pylogger + +log = get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "datamodule", + "taskmodule", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. 
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) + + +if __name__ == "__main__": + from hydra import compose, initialize + + with initialize(version_base="1.2", config_path="../../configs"): + cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) + print_config_tree(cfg, resolve=False, save_to_file=False) diff --git a/src/utils/span_utils.py b/src/utils/span_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1880c75503640b3dd017bbab342f292e7945648d --- /dev/null +++ b/src/utils/span_utils.py @@ -0,0 +1,60 @@ +from typing import Tuple + + +def get_overlap_len(indices_1: Tuple[int, int], indices_2: Tuple[int, int]) -> int: + if indices_1[0] > indices_2[0]: + tmp = indices_1 + indices_1 = indices_2 + indices_2 = tmp + if indices_1[1] <= indices_2[0]: + return 0 + return min(indices_1[1] - indices_2[0], indices_2[1] - indices_2[0]) + + +def have_overlap(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool: + other_start_overlaps = start_end[0] <= other_start_end[0] < start_end[1] + other_end_overlaps = start_end[0] < other_start_end[1] <= start_end[1] + start_overlaps_other = other_start_end[0] <= start_end[0] < other_start_end[1] + end_overlaps_other = other_start_end[0] < start_end[1] <= other_start_end[1] + return other_start_overlaps or other_end_overlaps or start_overlaps_other or end_overlaps_other + + +def is_contained_in(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool: + return other_start_end[0] <= start_end[0] and start_end[1] <= other_start_end[1] + + +def distance_center(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float: + center = 
(start_end[0] + start_end[1]) / 2 + center_other = (other_start_end[0] + other_start_end[1]) / 2 + return abs(center - center_other) + + +def distance_outer(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float: + _max = max(start_end[0], start_end[1], other_start_end[0], other_start_end[1]) + _min = min(start_end[0], start_end[1], other_start_end[0], other_start_end[1]) + return float(_max - _min) + + +def distance_inner(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float: + dist_start_other_end = abs(start_end[0] - other_start_end[1]) + dist_end_other_start = abs(start_end[1] - other_start_end[0]) + dist = float(min(dist_start_other_end, dist_end_other_start)) + if not have_overlap(start_end, other_start_end): + return dist + else: + return -dist + + +def distance( + start_end: Tuple[int, int], other_start_end: Tuple[int, int], distance_type: str +) -> float: + if distance_type == "center": + return distance_center(start_end, other_start_end) + elif distance_type == "inner": + return distance_inner(start_end, other_start_end) + elif distance_type == "outer": + return distance_outer(start_end, other_start_end) + else: + raise ValueError( + f"unknown distance_type={distance_type}. use one of: center, inner, outer" + ) diff --git a/src/utils/task_utils.py b/src/utils/task_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fb35d3dadd93e8fa12aa98a16bc89f18a421b335 --- /dev/null +++ b/src/utils/task_utils.py @@ -0,0 +1,178 @@ +import json +import os +import sys +import time +import warnings +from pathlib import Path +from typing import Callable, Dict, Optional + +from omegaconf import DictConfig +from pytorch_lightning.utilities import rank_zero_only + +from src.utils.logging_utils import close_loggers, get_pylogger +from src.utils.rich_utils import enforce_tags, print_config_tree + +log = get_pylogger(__name__) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that wraps the task function in extra utilities. + + Makes multirun more resistant to failure. + + Utilities: + - Calling the `utils.extras()` before the task is started + - Calling the `utils.close_loggers()` after the task is finished + - Logging the exception if occurs + - Logging the task total execution time + - Logging the output dir + """ + + def wrap(cfg: DictConfig): + + # apply extra utilities + extras(cfg) + + # execute the task + start_time = time.time() + try: + task_result = task_func(cfg=cfg) + except Exception as ex: + log.exception("") # save exception to `.log` file + raise ex + finally: + path = Path(cfg.paths.output_dir, "exec_time.log") + content = f"'{cfg.pipeline_type}' execution time: {time.time() - start_time} (s)" + save_file(path, content) # save task execution time (even if exception occurs) + close_loggers() # close loggers (even if exception occurs so multirun won't fail) + + log.info(f"Output dir: {cfg.paths.output_dir}") + + return task_result + + return wrap + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! 
") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + print_config_tree(cfg, resolve=True, save_to_file=True) + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def load_value_from_file(path: str, split_path_key: str = ":", split_key_parts: str = "/") -> Dict: + """Load a value from a file. The path can point to elements within the file (see split_path_key + parameter) and that can be nested (see split_key_parts parameter). For now, only .json files + are supported. + + Args: + path: path to the file (and data within the file) + split_path_key: split the path on this value to get the path to the file and the key within the file + split_key_parts: the value to split the key on to get the nested keys + """ + + parts_path = path.split(split_path_key, maxsplit=1) + file_extension = os.path.splitext(parts_path[0])[1] + if file_extension == ".json": + with open(parts_path[0], "r") as f: + data = json.load(f) + else: + raise ValueError(f"Expected .json file, got {file_extension}") + + if len(parts_path) == 1: + return data + + keys = parts_path[1].split(split_key_parts) + for key in keys: + data = data[key] + return data + + +def replace_sys_args_with_values_from_files( + load_prefix: str = "LOAD_ARG:", + load_multi_prefix: str = "LOAD_MULTI_ARG:", + **load_value_from_file_kwargs, +) -> None: + """Replaces arguments in sys.argv with values loaded from files. + + Examples: + # config.json contains {"a": 1, "b": 2} + python train.py LOAD_ARG:job_return_value.json + # this will pass "{a:1,b:2}" as the first argument to train.py + + # config.json contains [1, 2, 3] + python train.py LOAD_MULTI_ARG:job_return_value.json + # this will pass "1,2,3" as the first argument to train.py + + # config.json contains {"model": {"ouput_dir": ["path1", "path2"], f1: [0.7, 0.6]}} + python train.py load_model=LOAD_ARG:job_return_value.json:model/output_dir + # this will pass "load_model=path1,path2" to train.py + + Args: + load_prefix: the prefix to use for loading a single value from a file + load_multi_prefix: the prefix to use for loading a list of values from a file + **load_value_from_file_kwargs: additional kwargs to pass to load_value_from_file + """ + + updated_args = [] + for arg in sys.argv[1:]: + is_multirun_arg = False + if load_prefix in arg: + parts = arg.split(load_prefix, maxsplit=1) + elif load_multi_prefix in arg: + parts = arg.split(load_multi_prefix, maxsplit=1) + is_multirun_arg = True + else: + updated_args.append(arg) + continue + if len(parts) == 2: + log.warning(f'Replacing argument value for "{parts[0]}" with content from {parts[1]}') + json_value = load_value_from_file(parts[1], **load_value_from_file_kwargs) + json_value_str = json.dumps(json_value) + # replace quotes and spaces + json_value_str = json_value_str.replace('"', "").replace(" ", "") + # remove outer brackets + if is_multirun_arg: + if not isinstance(json_value, list): + raise ValueError( + f"Expected list for multirun argument, got {type(json_value)}. If you just want " + f"to set a single value, use {load_prefix} instead of {load_multi_prefix}." 
+ ) + json_value_str = json_value_str[1:-1] + # re-assemble the argument from its original prefix and the loaded value + modified_arg = f"{parts[0]}{json_value_str}" + updated_args.append(modified_arg) + else: + updated_args.append(arg) + # Set sys.argv to the updated arguments + sys.argv = [sys.argv[0]] + updated_args diff --git a/src/vendor/__init__.py b/src/vendor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..203cb9c9fccf814b835460da1181764ea6ae8fc4 --- /dev/null +++ b/src/vendor/__init__.py @@ -0,0 +1 @@ +# use this folder for storing third party code that cannot be installed using pip/conda
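A minimal usage sketch of load_value_from_file and the LOAD_ARG path syntax it backs (the file job_return_value.json and its contents are hypothetical, only for illustration):

from src.utils.task_utils import load_value_from_file

# assume job_return_value.json contains {"model": {"output_dir": ["path1", "path2"]}}
# the part before ":" names the file, the part after is a "/"-separated key path into the JSON
value = load_value_from_file("job_return_value.json:model/output_dir")
print(value)  # -> ["path1", "path2"]

Passing the same path on the command line, e.g. load_model=LOAD_ARG:job_return_value.json:model/output_dir, is expanded by replace_sys_args_with_values_from_files in the __main__ block of train.py before Hydra parses the arguments.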