# NOTE: this must run BEFORE the `src.*` imports below. setup_root locates the
# project root (the directory containing a ".project-root" marker), adds it to
# sys.path (pythonpath=True) and loads a .env file (dotenv=True), so the script
# works when invoked from anywhere inside the repository.
import pyrootutils
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
import argparse
import logging
import os
import sys

import pandas as pd

from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
if __name__ == "__main__":
    # CLI: retrieve relevant spans via a span retriever dump and write them to
    # a JSON file. Query selection (most specific wins):
    #   --query_target_doc_id_pairs > --query_span_id > --query_doc_id > all docs.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
        help="Path to the retriever config file.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
        help="Path of the JSON file to write the results to.",
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Fail fast on unsupported output formats before any expensive loading.
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )

    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_target_doc_id_pairs is not None:
        all_spans_for_all_documents = None
        for doc_id_pair in args.query_target_doc_id_pairs:
            try:
                query_doc_id, target_doc_id = doc_id_pair.split(":")
            except ValueError:
                # split() yields != 2 parts for zero or multiple colons;
                # surface a message that names the offending argument.
                raise ValueError(
                    f"invalid query-target pair {doc_id_pair!r}: "
                    'expected exactly one ":" separator'
                ) from None
            # Build per-pair kwargs instead of passing doc_id_whitelist both
            # explicitly and via **search_kwargs, which would raise a
            # TypeError (duplicate keyword) when --doc_id_whitelist is also
            # set. The pair's target document takes precedence.
            pair_search_kwargs = {**search_kwargs, "doc_id_whitelist": [target_doc_id]}
            current_result = retrieve_all_relevant_spans(
                retriever=retriever,
                query_doc_id=query_doc_id,
                **pair_search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no relevant spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
            current_result["query_doc_id"] = query_doc_id
            if all_spans_for_all_documents is None:
                all_spans_for_all_documents = current_result
            else:
                all_spans_for_all_documents = pd.concat(
                    [all_spans_for_all_documents, current_result], ignore_index=True
                )
    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        sys.exit(0)

    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
    # os.path.dirname returns "" for a bare filename like "out.json";
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)
    logger.info("done")