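"""Retrieve relevant spans with a DocumentAwareSpanRetrieverWithRelations and
dump the results to a JSON file.

Hypothetical invocation (the script name and paths are placeholders):

    python retrieve_relevant_spans.py \
        --data_path data/retriever_dump \
        --output_path results/relevant_spans.json
"""
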
import pyrootutils

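# Locate the project root before importing from `src` so the package resolves
# on sys.path and variables from .env are available.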
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
import sys

import pandas as pd

from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if not args.output_path.endswith(".json"):
        raise ValueError("only JSON output is supported")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

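    # Assemble the retrieval settings; the optional document ID filters are
    # only forwarded when they were given on the command line.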
    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")

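    # Dispatch on the query arguments: explicit query/target document pairs
    # take precedence, then a single query span, then a single query document;
    # otherwise spans are retrieved for all documents.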
    if args.query_target_doc_id_pairs is not None:
        # The per-pair target document overrides any global --doc_id_whitelist;
        # drop it from the kwargs so doc_id_whitelist is not passed twice.
        pair_search_kwargs = {
            k: v for k, v in search_kwargs.items() if k != "doc_id_whitelist"
        }
        all_spans_for_all_documents = None
        for doc_id_pair in args.query_target_doc_id_pairs:
            query_doc_id, target_doc_id = doc_id_pair.split(":")
            current_result = retrieve_all_relevant_spans(
                retriever=retriever,
                query_doc_id=query_doc_id,
                doc_id_whitelist=[target_doc_id],
                **pair_search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no relevant spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
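            # Tag each row with its query document so results remain
            # attributable after the per-pair frames are concatenated.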
            current_result["query_doc_id"] = query_doc_id
            if all_spans_for_all_documents is None:
                all_spans_for_all_documents = current_result
            else:
                all_spans_for_all_documents = pd.concat(
                    [all_spans_for_all_documents, current_result], ignore_index=True
                )

    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

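    # Every retrieval path yields None (or leaves the accumulator as None)
    # when nothing matched the query and threshold.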
    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        sys.exit(0)

    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")