File size: 4,115 Bytes
aeb12b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
from typing import List, Literal, Protocol, Tuple, TypedDict, Union

from pyserini.analysis import get_lucene_analyzer
from pyserini.index import IndexReader
from pyserini.search import DenseSearchResult, JLuceneSearcherResult
from pyserini.search.faiss.__main__ import init_query_encoder
from pyserini.search.faiss import FaissSearcher
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneSearcher

EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]


class AnalyzerArgs(TypedDict):
    """Keyword arguments forwarded verbatim to ``pyserini.analysis.get_lucene_analyzer``."""

    language: str  # analyzer language code, e.g. "en"
    stemming: bool  # whether to apply stemming
    stemmer: str  # stemmer name, e.g. "porter"
    stopwords: bool  # whether to filter stopwords
    huggingFaceTokenizer: str  # HuggingFace tokenizer name/path (camelCase to match pyserini's parameter)


class SearchResult(TypedDict):
    """One retrieved document as returned by ``_search``."""

    docid: str  # document id ("id" field of the stored JSON document)
    text: str  # document body ("contents" field of the stored JSON document)
    score: float  # retrieval score from the underlying searcher
    # NOTE(review): `_search` constructs SearchResult without this key, although the
    # TypedDict declares it as required — confirm whether callers expect `language`
    # or whether this field should be removed / made optional.
    language: str


class Searcher(Protocol):
    """Structural interface shared by Lucene, Faiss and hybrid searchers:
    anything with a ``search(query, **kwargs)`` returning ranked results."""

    def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
        ...


def init_searcher_and_reader(
    sparse_index_path: str = None,
    bm25_k1: float = None,
    bm25_b: float = None,
    analyzer_args: AnalyzerArgs = None,
    dense_index_path: str = None,
    encoder_name_or_path: str = None,
    encoder_class: EncoderClass = None,
    tokenizer_name: str = None,
    device: str = None,
    prefix: str = None
) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], IndexReader]:
    """
    Initialize and return an appropriate searcher.

    Parameters
    ----------
    sparse_index_path: str
        Path to sparse (Lucene) index
    bm25_k1: float
        BM25 k1 parameter; applied together with `bm25_b` when both are given
    bm25_b: float
        BM25 b parameter; applied together with `bm25_k1` when both are given
    analyzer_args: AnalyzerArgs
        Arguments forwarded to `get_lucene_analyzer` for the sparse searcher
    dense_index_path: str
        Path to dense (Faiss) index
    encoder_name_or_path: str
        Path to query encoder checkpoint or encoder name
    encoder_class: str
        Query encoder class to use. If None, infer from `encoder_name_or_path`
    tokenizer_name: str
        Tokenizer name or path
    device: str
        Device to load query encoder on
    prefix: str
        Query prefix if it exists

    Returns
    -------
    Tuple[FaissSearcher | HybridSearcher | LuceneSearcher, IndexReader | None]
        A sparse, dense or hybrid searcher, plus an index reader for fetching
        raw documents (None when no sparse index is available).

    Raises
    ------
    ValueError
        If neither `sparse_index_path` nor `dense_index_path` is provided.
    """
    # Fail fast with a clear message instead of a NameError on `ssearcher` below.
    if not sparse_index_path and not dense_index_path:
        raise ValueError("At least one of `sparse_index_path` or `dense_index_path` is required.")

    reader = None
    ssearcher = None
    if sparse_index_path:
        ssearcher = LuceneSearcher(sparse_index_path)
        if analyzer_args:
            ssearcher.set_analyzer(get_lucene_analyzer(**analyzer_args))
        # BM25 tuning is independent of a custom analyzer, so it is applied
        # unconditionally; `is not None` lets legitimate 0.0 values through.
        if bm25_k1 is not None and bm25_b is not None:
            ssearcher.set_bm25(bm25_k1, bm25_b)

    if dense_index_path:
        encoder = init_query_encoder(
            encoder=encoder_name_or_path,
            encoder_class=encoder_class,
            tokenizer_name=tokenizer_name,
            topics_name=None,
            encoded_queries=None,
            device=device,
            prefix=prefix
        )

        # The reader (used to fetch raw docs for dense hits) lives in the sparse
        # index; guard against `IndexReader(None)` in a dense-only configuration.
        # NOTE(review): with a dense-only setup the returned reader is None, so
        # callers cannot resolve DenseSearchResult docids — confirm intended use.
        if sparse_index_path:
            reader = IndexReader(sparse_index_path)
        dsearcher = FaissSearcher(dense_index_path, encoder)

        if ssearcher is not None:
            return HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher), reader
        return dsearcher, reader

    return ssearcher, reader


def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]:
    """
    Run a query against `searcher` and return the stored documents of the hits.

    Parameters
    ----------
    searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    reader: IndexReader
        Reader over the sparse index; used to fetch raw documents for dense hits
    query: str
        Query for which to retrieve results
    num_results: int
        Maximum number of results to retrieve

    Returns
    -------
    List[SearchResult]
        One entry per hit with docid, text and score, in rank order.

    Raises
    ------
    TypeError
        If the searcher yields a result type this function cannot resolve.
    """
    def _load_doc(hit: Union[DenseSearchResult, JLuceneSearcherResult]) -> dict:
        """Resolve a raw hit to its stored JSON document."""
        if isinstance(hit, JLuceneSearcherResult):
            return json.loads(hit.raw)
        if isinstance(hit, DenseSearchResult):
            # Dense hits carry only a docid; fetch the document via the reader.
            return json.loads(reader.doc(hit.docid).raw())
        # Previously an unknown type silently produced None and failed later
        # with an opaque subscript error; fail loudly at the source instead.
        raise TypeError(f"Unsupported search result type: {type(hit).__name__}")

    hits = searcher.search(query, k=num_results)
    results: List[SearchResult] = []
    # Pair each hit directly with its document instead of re-indexing the hit
    # list through enumerate(map(...)), which was fragile and harder to read.
    for hit in hits:
        doc = _load_doc(hit)
        results.append(
            SearchResult(
                docid=doc["id"],
                text=doc["contents"],
                score=hit.score
            )
        )
    return results