"""Dump relevant spans retrieved for one or more documents to a JSON file.

Loads a ``DocumentAwareSpanRetrieverWithRelations`` from a config file and a
data dump, runs one of four retrieval modes (explicit query:target document
pairs, a single query span, a single query document, or all documents), and
writes the resulting pandas DataFrame as JSON.
"""

import pyrootutils

# NOTE: pyrootutils must run *before* the `src.*` imports below — it locates
# the project root via the ".project-root" indicator file and prepends it to
# sys.path (pythonpath=True) so those imports resolve. dotenv=True also loads
# a .env file from the root if present.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
import sys

import pandas as pd

from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
    return parser


def _retrieve_for_doc_id_pairs(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    doc_id_pairs: list,
    search_kwargs: dict,
):
    """Retrieve spans for explicit ``query:target`` document-ID pairs.

    Each pair restricts retrieval to exactly one target document. Returns a
    single DataFrame with an added ``query_doc_id`` column, or ``None`` if no
    pair produced any result.
    """
    frames = []
    for doc_id_pair in doc_id_pairs:
        query_doc_id, target_doc_id = doc_id_pair.split(":")
        # Merge so the per-pair whitelist *overrides* any --doc_id_whitelist
        # already present in search_kwargs. Passing doc_id_whitelist both
        # explicitly and via **search_kwargs would raise a duplicate-keyword
        # TypeError (bug in the original code).
        pair_kwargs = {**search_kwargs, "doc_id_whitelist": [target_doc_id]}
        current_result = retrieve_all_relevant_spans(
            retriever=retriever,
            query_doc_id=query_doc_id,
            **pair_kwargs,
        )
        if current_result is None:
            logger.warning(
                f"no relevant spans found for query_doc_id={query_doc_id} and "
                f"target_doc_id={target_doc_id}"
            )
            continue
        logger.info(
            f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
            f"and target_doc_id={target_doc_id}"
        )
        current_result["query_doc_id"] = query_doc_id
        frames.append(current_result)
    if not frames:
        return None
    # A single frame is returned unchanged (preserves its index); multiple
    # frames are concatenated once, instead of O(n^2) concat-in-a-loop.
    if len(frames) == 1:
        return frames[0]
    return pd.concat(frames, ignore_index=True)


def main(args: argparse.Namespace) -> None:
    """Run retrieval according to ``args`` and dump the result as JSON.

    Raises:
        ValueError: if ``args.output_path`` does not end in ``.json``.
    """
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")

    # Mode precedence: explicit doc-id pairs > single span > single document
    # > all documents.
    if args.query_target_doc_id_pairs is not None:
        all_spans_for_all_documents = _retrieve_for_doc_id_pairs(
            retriever=retriever,
            doc_id_pairs=args.query_target_doc_id_pairs,
            search_kwargs=search_kwargs,
        )
    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        # Not an error: a query that simply matches nothing exits cleanly.
        sys.exit(0)

    logger.info(
        f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}..."
    )
    # Guard against a bare filename: os.path.dirname("out.json") == "" and
    # os.makedirs("") raises FileNotFoundError.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")


if __name__ == "__main__":
    parsed_args = _build_arg_parser().parse_args()
    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main(parsed_args)