import logging
import random
import time
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Set, Union

import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.common.log import get_console_logger, get_logger
from relik.retriever.callbacks.base import PredictionCallback
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.datasets import GoldenRetrieverDataset
from relik.retriever.data.utils import HardNegativesManager
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.model import GoldenRetriever

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class GoldenRetrieverPredictionCallback(PredictionCallback):
    def __init__(
        self,
        k: Optional[int] = None,
        batch_size: int = 32,
        num_workers: int = 8,
        document_index: Optional[BaseDocumentIndex] = None,
        precision: Union[str, int] = 32,
        force_reindex: bool = True,
        retriever_dir: Optional[Path] = None,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[List[DictConfig]] = None,
        dataset: Optional[Union[DictConfig, BaseDataset]] = None,
        dataloader: Optional[DataLoader] = None,
        *args,
        **kwargs,
    ):
        super().__init__(batch_size, stages, other_callbacks, dataset, dataloader)
        self.k = k
        self.num_workers = num_workers
        self.document_index = document_index
        self.precision = precision
        self.force_reindex = force_reindex
        self.retriever_dir = retriever_dir

    @torch.no_grad()  # predictions do not need gradient tracking
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        datasets: Optional[
            Union[DictConfig, BaseDataset, List[DictConfig], List[BaseDataset]]
        ] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        *args,
        **kwargs,
    ) -> dict:
        stage = trainer.state.stage
        logger.info(f"Computing predictions for stage {stage.value}")
        if stage not in self.stages:
            raise ValueError(
                f"Stage `{stage}` not supported, only {self.stages} are supported"
            )

        self.datasets, self.dataloaders = self._get_datasets_and_dataloaders(
            datasets,
            dataloaders,
            trainer,
            dataloader_kwargs=dict(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=True,
                shuffle=False,
            ),
        )

        # set the model to eval mode
        pl_module.eval()
        # get the retriever
        retriever: GoldenRetriever = pl_module.model

        # here we will store the samples with predictions for each dataloader
        dataloader_predictions = {}
        # compute the passage embeddings index for each dataloader
        for dataloader_idx, dataloader in enumerate(self.dataloaders):
            current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx]
            logger.info(
                f"Computing passage embeddings for dataset {current_dataset.name}"
            )
            # passages = self._get_passages_dataloader(current_dataset, trainer)

            tokenizer = current_dataset.tokenizer

            def collate_fn(x):
                return ModelInputs(
                    tokenizer(
                        x,
                        truncation=True,
                        padding=True,
                        max_length=current_dataset.max_passage_length,
                        return_tensors="pt",
                    )
                )

            # check if we need to reindex the passages and
            # also if we need to load the retriever from disk
            if (self.retriever_dir is not None and trainer.current_epoch == 0) or (
                self.retriever_dir is not None and stage == RunningStage.TESTING
            ):
                force_reindex = False
            else:
                force_reindex = self.force_reindex

            if (
                not force_reindex
                and self.retriever_dir is not None
                and stage == RunningStage.TESTING
            ):
                retriever = retriever.from_pretrained(self.retriever_dir)

            # set the retriever to eval mode if we are loading it from disk
            # you never know :)
            retriever.eval()
            retriever.index(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                max_length=current_dataset.max_passage_length,
                collate_fn=collate_fn,
                precision=self.precision,
                compute_on_cpu=False,
                force_reindex=force_reindex,
            )

            # pl_module_original_device = pl_module.device
            # if (
            #     and pl_module.device.type == "cuda"
            # ):
            #     pl_module.to("cpu")

            # now compute the question embeddings and compute the top-k accuracy
            predictions = []
            start = time.time()
            for batch in tqdm(
                dataloader,
                desc=f"Computing predictions for dataset {current_dataset.name}",
            ):
                batch = batch.to(pl_module.device)
                # get the top-k indices
                retriever_output = retriever.retrieve(
                    **batch.questions, k=self.k, precision=self.precision
                )
                # compute recall at k
                for batch_idx, retrieved_samples in enumerate(retriever_output):
                    # get the positive passages
                    gold_passages = batch["positives"][batch_idx]
                    # get the index of the gold passages in the retrieved passages
                    gold_passage_indices = [
                        retriever.get_index_from_passage(passage)
                        for passage in gold_passages
                    ]
                    retrieved_indices = [r.index for r in retrieved_samples]
                    retrieved_passages = [r.label for r in retrieved_samples]
                    retrieved_scores = [r.score for r in retrieved_samples]
                    # correct predictions are the passages that are in the top-k and are gold
                    correct_indices = set(gold_passage_indices) & set(retrieved_indices)
                    # wrong predictions are the passages that are in the top-k and are not gold
                    wrong_indices = set(retrieved_indices) - set(gold_passage_indices)
                    # add the predictions to the list
                    prediction_output = dict(
                        sample_idx=batch.sample_idx[batch_idx],
                        gold=gold_passages,
                        predictions=retrieved_passages,
                        scores=retrieved_scores,
                        correct=[
                            retriever.get_passage_from_index(i) for i in correct_indices
                        ],
                        wrong=[
                            retriever.get_passage_from_index(i) for i in wrong_indices
                        ],
                    )
                    predictions.append(prediction_output)
            end = time.time()
            logger.info(f"Time to retrieve: {end - start}")

            dataloader_predictions[dataloader_idx] = predictions

        # if pl_module_original_device != pl_module.device:
        #     pl_module.to(pl_module_original_device)

        # return the predictions
        return dataloader_predictions

    # @staticmethod
    # def _get_passages_dataloader(
    #     indexer: Optional[BaseIndexer] = None,
    #     dataset: Optional[GoldenRetrieverDataset] = None,
    #     trainer: Optional[pl.Trainer] = None,
    # ):
    #     if indexer is None:
    #         logger.info(
    #             f"Indexer is None and no passages were found in dataset "
    #             f"{dataset.name}, computing them from the dataloaders"
    #         )
    #         # get the passages from all the dataloader passage ids
    #         passages = set()  # set to avoid duplicates
    #         for batch in trainer.train_dataloader:
    #             passages.update(
    #                 [
    #                     " ".join(map(str, [c for c in passage_ids.tolist() if c != 0]))
    #                     for passage_ids in batch["passages"]["input_ids"]
    #                 ]
    #             )
    #         for d in trainer.val_dataloaders:
    #             for batch in d:
    #                 passages.update(
    #                     [
    #                         " ".join(
    #                             map(str, [c for c in passage_ids.tolist() if c != 0])
    #                         )
    #                         for passage_ids in batch["passages"]["input_ids"]
    #                     ]
    #                 )
    #         for d in trainer.test_dataloaders:
    #             for batch in d:
    #                 passages.update(
    #                     [
    #                         " ".join(
    #                             map(str, [c for c in passage_ids.tolist() if c != 0])
    #                         )
    #                         for passage_ids in batch["passages"]["input_ids"]
    #                     ]
    #                 )
    #         passages = list(passages)
    #     else:
    #         passages = dataset.passages
    #     return passages
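

# A hedged usage sketch (an assumption, not part of the original module): the callback
# is meant to be invoked by the training loop with a configured `pl.Trainer` and
# `pl.LightningModule`, and returns, per dataloader index, a list of prediction dicts
# with the keys built above (`sample_idx`, `gold`, `predictions`, `scores`, `correct`,
# `wrong`).
#
# callback = GoldenRetrieverPredictionCallback(k=100, batch_size=32, num_workers=4)
# dataloader_predictions = callback(trainer, pl_module)
# first_prediction = dataloader_predictions[0][0]  # first sample of the first dataloader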


class NegativeAugmentationCallback(GoldenRetrieverPredictionCallback):
    """
    Callback that computes the predictions of a retriever model on a dataset and mines
    hard negative examples for the training set.

    Args:
        k (:obj:`int`, `optional`, defaults to 100):
            The number of top-k retrieved passages to consider for the evaluation.
        batch_size (:obj:`int`, `optional`, defaults to 32):
            The batch size to use for the evaluation.
        num_workers (:obj:`int`, `optional`, defaults to 0):
            The number of workers to use for the evaluation.
        force_reindex (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to force the reindexing of the dataset.
        retriever_dir (:obj:`Path`, `optional`):
            The path to the retriever directory. If not specified, the retriever will be
            initialized from scratch.
        stages (:obj:`Set[str]`, `optional`):
            The stages to run the callback on. If not specified, the callback will be run
            on train, validation and test.
        other_callbacks (:obj:`List[DictConfig]`, `optional`):
            A list of other callbacks to run on the same stages.
        dataset (:obj:`Union[DictConfig, BaseDataset]`, `optional`):
            The dataset to use for the evaluation. If not specified, the dataset will be
            initialized from scratch.
        metrics_to_monitor (:obj:`List[str]`, `optional`):
            The metrics to monitor for the evaluation.
        threshold (:obj:`float`, `optional`, defaults to 0.8):
            The threshold to consider. If the recall score of the retriever is above the
            threshold, the negative examples will be added to the training set.
        max_negatives (:obj:`int`, `optional`, defaults to 5):
            The maximum number of negative examples to add to the training set.
        add_with_probability (:obj:`float`, `optional`, defaults to 1.0):
            The probability with which to add the negative examples to the training set.
        refresh_every_n_epochs (:obj:`int`, `optional`, defaults to 1):
            The number of epochs after which to refresh the index.
    """

    def __init__(
        self,
        k: int = 100,
        batch_size: int = 32,
        num_workers: int = 0,
        force_reindex: bool = False,
        retriever_dir: Optional[Path] = None,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[List[DictConfig]] = None,
        dataset: Optional[Union[DictConfig, BaseDataset]] = None,
        metrics_to_monitor: Optional[List[str]] = None,
        threshold: float = 0.8,
        max_negatives: int = 5,
        add_with_probability: float = 1.0,
        refresh_every_n_epochs: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__(
            k=k,
            batch_size=batch_size,
            num_workers=num_workers,
            force_reindex=force_reindex,
            retriever_dir=retriever_dir,
            stages=stages,
            other_callbacks=other_callbacks,
            dataset=dataset,
            *args,
            **kwargs,
        )
        if metrics_to_monitor is None:
            metrics_to_monitor = ["val_loss"]
        self.metrics_to_monitor = metrics_to_monitor
        self.threshold = threshold
        self.max_negatives = max_negatives
        self.add_with_probability = add_with_probability
        self.refresh_every_n_epochs = refresh_every_n_epochs

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> dict:
        """
        Computes the predictions of a retriever model on a dataset and mines hard
        negative examples for the training set.

        Args:
            trainer (:obj:`pl.Trainer`):
                The trainer object.
            pl_module (:obj:`pl.LightningModule`):
                The lightning module.

        Returns:
            A dictionary containing the negative examples.
        """
        stage = trainer.state.stage
        if stage not in self.stages:
            return {}

        # `metrics_to_monitor` is a list of metric names: at least one of them
        # must have been logged by the trainer
        if all(
            metric not in trainer.logged_metrics for metric in self.metrics_to_monitor
        ):
            raise ValueError(
                f"None of the metrics in `{self.metrics_to_monitor}` were found in "
                f"trainer.logged_metrics. "
                f"Available metrics: {trainer.logged_metrics.keys()}"
            )

        # mine hard negatives only once at least one monitored metric is above the threshold
        if all(
            trainer.logged_metrics[metric] < self.threshold
            for metric in self.metrics_to_monitor
            if metric in trainer.logged_metrics
        ):
            return {}

        if trainer.current_epoch % self.refresh_every_n_epochs != 0:
            return {}

        # if all(
        #     [
        #         trainer.logged_metrics.get(metric) is None
        #         for metric in self.metrics_to_monitor
        #     ]
        # ):
        #     raise ValueError(
        #         f"No metric from {self.metrics_to_monitor} found in trainer.logged_metrics"
        #         f" Available metrics: {trainer.logged_metrics.keys()}"
        #     )
        # if all(
        #     [
        #         trainer.logged_metrics.get(metric) < self.threshold
        #         for metric in self.metrics_to_monitor
        #         if trainer.logged_metrics.get(metric) is not None
        #     ]
        # ):
        #     return {}

        logger.info(
            f"At least one metric from {self.metrics_to_monitor} is above threshold "
            f"{self.threshold}. Computing hard negatives."
        )

        # make a copy of the dataset to avoid modifying the original one
        trainer.datamodule.train_dataset.hn_manager = None
        dataset_copy = deepcopy(trainer.datamodule.train_dataset)
        predictions = super().__call__(
            trainer,
            pl_module,
            datasets=dataset_copy,
            dataloaders=DataLoader(
                dataset_copy.to_torch_dataset(),
                shuffle=False,
                batch_size=None,
                num_workers=self.num_workers,
                pin_memory=True,
                collate_fn=lambda x: x,
            ),
            *args,
            **kwargs,
        )
| logger.info(f"Computing hard negatives for epoch {trainer.current_epoch}") | |
| # predictions is a dict with the dataloader index as key and the predictions as value | |
| # since we only have one dataloader, we can get the predictions directly | |
| predictions = list(predictions.values())[0] | |
| # store the predictions in a dictionary for faster access based on the sample index | |
| hard_negatives_list = {} | |
| for prediction in tqdm(predictions, desc="Collecting hard negatives"): | |
| if random.random() < 1 - self.add_with_probability: | |
| continue | |
| top_k_passages = prediction["predictions"] | |
| gold_passages = prediction["gold"] | |
| # get the ids of the max_negatives wrong passages with the highest similarity | |
| wrong_passages = [ | |
| passage_id | |
| for passage_id in top_k_passages | |
| if passage_id not in gold_passages | |
| ][: self.max_negatives] | |
| hard_negatives_list[prediction["sample_idx"]] = wrong_passages | |
| trainer.datamodule.train_dataset.hn_manager = HardNegativesManager( | |
| tokenizer=trainer.datamodule.train_dataset.tokenizer, | |
| max_length=trainer.datamodule.train_dataset.max_passage_length, | |
| data=hard_negatives_list, | |
| ) | |
| # normalize predictions as in the original GoldenRetrieverPredictionCallback | |
| predictions = {0: predictions} | |
| return predictions | |
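

# A minimal sketch (not part of the original module) of the mapping handed to
# `HardNegativesManager` above: sample indices mapped to the hard-negative passages
# mined for them. The passage strings and the `train_dataset` name are invented
# placeholders for illustration only; the keyword arguments mirror the call in
# `NegativeAugmentationCallback.__call__`.
#
# hard_negatives_list = {
#     0: ["a high-scoring but non-gold passage", "another top-k distractor"],
#     1: ["a distractor passage retrieved in the top-k"],
# }
# train_dataset.hn_manager = HardNegativesManager(
#     tokenizer=train_dataset.tokenizer,  # assumed GoldenRetrieverDataset instance
#     max_length=train_dataset.max_passage_length,
#     data=hard_negatives_list,
# )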