import json
from typing import List, Literal, Optional, Protocol, Tuple, TypedDict, Union
from pyserini.analysis import get_lucene_analyzer
from pyserini.index import IndexReader
from pyserini.search import DenseSearchResult, JLuceneSearcherResult
from pyserini.search.faiss.__main__ import init_query_encoder
from pyserini.search.faiss import FaissSearcher
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneSearcher

EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]


class AnalyzerArgs(TypedDict):
    language: str
    stemming: bool
    stemmer: str
    stopwords: bool
    huggingFaceTokenizer: str
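
# Illustrative example (not from the original app): a possible AnalyzerArgs
# value for an English index; every field value below is a placeholder.
#
#   analyzer_args: AnalyzerArgs = {
#       "language": "en",
#       "stemming": True,
#       "stemmer": "porter",
#       "stopwords": True,
#       "huggingFaceTokenizer": "bert-base-uncased",
#   }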


class SearchResult(TypedDict):
    docid: str
    text: str
    score: float
    language: str  # not set by `_search` below; left to callers


class Searcher(Protocol):
    def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
        ...


def init_searcher_and_reader(
    sparse_index_path: Optional[str] = None,
    bm25_k1: Optional[float] = None,
    bm25_b: Optional[float] = None,
    analyzer_args: Optional[AnalyzerArgs] = None,
    dense_index_path: Optional[str] = None,
    encoder_name_or_path: Optional[str] = None,
    encoder_class: Optional[EncoderClass] = None,
    tokenizer_name: Optional[str] = None,
    device: Optional[str] = None,
    prefix: Optional[str] = None
) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], Optional[IndexReader]]:
"""
Initialize and return an approapriate searcher
Parameters
----------
sparse_index_path: str
Path to sparse index
dense_index_path: str
Path to dense index
encoder_name_or_path: str
Path to query encoder checkpoint or encoder name
encoder_class: str
Query encoder class to use. If None, infer from `encoder`
tokenizer_name: str
Tokenizer name or path
device: str
Device to load Query encoder on.
prefix: str
Query prefix if exists
Returns
-------
Searcher: FaissSearcher | HybridSearcher | LuceneSearcher
A sparse, dense or hybrid searcher
"""
    if not sparse_index_path and not dense_index_path:
        raise ValueError("At least one of `sparse_index_path` or `dense_index_path` is required")
    reader = None
    if sparse_index_path:
        ssearcher = LuceneSearcher(sparse_index_path)
        if analyzer_args:
            analyzer = get_lucene_analyzer(**analyzer_args)
            ssearcher.set_analyzer(analyzer)
        # `bm25_k1 and bm25_b` would wrongly skip legitimate 0.0 values,
        # so test explicitly against None.
        if bm25_k1 is not None and bm25_b is not None:
            ssearcher.set_bm25(bm25_k1, bm25_b)
    if dense_index_path:
        encoder = init_query_encoder(
            encoder=encoder_name_or_path,
            encoder_class=encoder_class,
            tokenizer_name=tokenizer_name,
            topics_name=None,
            encoded_queries=None,
            device=device,
            prefix=prefix
        )
        # The reader resolves dense hits (docids) back to stored documents in
        # `_search`; it is backed by the sparse index, so it stays None when
        # no sparse index is given.
        reader = IndexReader(sparse_index_path) if sparse_index_path else None
        dsearcher = FaissSearcher(dense_index_path, encoder)
        if sparse_index_path:
            # Both indexes present: combine sparse and dense scores.
            hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher)
            return hsearcher, reader
        else:
            return dsearcher, reader
    return ssearcher, reader
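
# Usage sketch (illustrative): a sparse-only BM25 searcher. The index path and
# parameter values are placeholders, not artifacts shipped with this Space.
#
#   bm25_searcher, _ = init_searcher_and_reader(
#       sparse_index_path="indexes/lucene-index",
#       bm25_k1=0.9,
#       bm25_b=0.4,
#   )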


def _search(searcher: Searcher, reader: Optional[IndexReader], query: str, num_results: int = 10) -> List[SearchResult]:
    """
    Retrieve the top documents for a query and normalize them into SearchResults.
    Parameters
    ----------
    searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    reader: IndexReader
        Reader over the sparse index, used to fetch raw documents for dense
        hits; may be None for a sparse-only searcher
    query: str
        Query for which to retrieve results
    num_results: int
        Maximum number of results to retrieve
    Returns
    -------
    List[SearchResult]
        One entry per hit, with its docid, text, and score
    """
    def _get_dict(r: Union[DenseSearchResult, JLuceneSearcherResult]) -> dict:
        if isinstance(r, JLuceneSearcherResult):
            # Sparse hits carry the raw document JSON directly.
            return json.loads(r.raw)
        elif isinstance(r, DenseSearchResult):
            # Dense hits only carry a docid; fetch the stored document from
            # the sparse index via the index reader.
            return json.loads(reader.doc(r.docid).raw())
    search_results = searcher.search(query, k=num_results)
    # Documents follow Pyserini's JsonCollection schema: {"id": ..., "contents": ...}.
    all_results = [
        SearchResult(
            docid=doc["id"],
            text=doc["contents"],
            score=search_results[idx].score
        ) for idx, doc in enumerate(map(_get_dict, search_results))
    ]
    return all_results
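

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original app): the index
    # paths and encoder name below are placeholders; point them at real
    # artifacts before running.
    searcher, reader = init_searcher_and_reader(
        sparse_index_path="indexes/lucene-index",   # placeholder path
        dense_index_path="indexes/faiss-index",     # placeholder path
        encoder_name_or_path="facebook/dpr-question_encoder-multiset-base",
        encoder_class="dpr",
        device="cpu",
    )
    for hit in _search(searcher, reader, "who invented the telephone?", num_results=3):
        print(hit["docid"], round(hit["score"], 4), hit["text"][:80])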