import json
from typing import List, Literal, Optional, Protocol, Tuple, TypedDict, Union

from pyserini.analysis import get_lucene_analyzer
from pyserini.index import IndexReader
from pyserini.search import DenseSearchResult, JLuceneSearcherResult
from pyserini.search.faiss.__main__ import init_query_encoder
from pyserini.search.faiss import FaissSearcher
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneSearcher

EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]


class AnalyzerArgs(TypedDict):
    language: str
    stemming: bool
    stemmer: str
    stopwords: bool
    huggingFaceTokenizer: str
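
# Example analyzer arguments (illustrative values only; any subset of the
# keyword arguments accepted by pyserini's get_lucene_analyzer may be passed):
#
#   {"language": "en", "stemming": True, "stemmer": "porter",
#    "stopwords": True, "huggingFaceTokenizer": "bert-base-multilingual-cased"}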


class SearchResult(TypedDict):
    docid: str
    text: str
    score: float
    language: str


class Searcher(Protocol):
    def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
        ...


def init_searcher_and_reader(
    sparse_index_path: Optional[str] = None,
    bm25_k1: Optional[float] = None,
    bm25_b: Optional[float] = None,
    analyzer_args: Optional[AnalyzerArgs] = None,
    dense_index_path: Optional[str] = None,
    encoder_name_or_path: Optional[str] = None,
    encoder_class: Optional[EncoderClass] = None,
    tokenizer_name: Optional[str] = None,
    device: Optional[str] = None,
    prefix: Optional[str] = None
) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], Optional[IndexReader]]:
    """
    Initialize and return an approapriate searcher
    
    Parameters
    ----------
    sparse_index_path: str
        Path to sparse index
    dense_index_path: str
        Path to dense index
    encoder_name_or_path: str
        Path to query encoder checkpoint or encoder name
    encoder_class: str
        Query encoder class to use. If None, infer from `encoder`
    tokenizer_name: str
        Tokenizer name or path
    device: str
        Device to load Query encoder on. 
    prefix: str
        Query prefix if exists
    
    Returns
    -------
    Searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    """
    if not sparse_index_path and not dense_index_path:
        raise ValueError("At least one of `sparse_index_path` or `dense_index_path` must be provided")

    reader = None
    if sparse_index_path:
        ssearcher = LuceneSearcher(sparse_index_path)
        # Reader over the sparse index; also used to fetch raw documents for dense results
        reader = IndexReader(sparse_index_path)
        if analyzer_args:
            analyzer = get_lucene_analyzer(**analyzer_args)
            ssearcher.set_analyzer(analyzer)
        # Set BM25 parameters independently of the analyzer; compare against None so that 0.0 is honored
        if bm25_k1 is not None and bm25_b is not None:
            ssearcher.set_bm25(bm25_k1, bm25_b)

    if dense_index_path:
        encoder = init_query_encoder(
            encoder=encoder_name_or_path,
            encoder_class=encoder_class,
            tokenizer_name=tokenizer_name,
            topics_name=None,
            encoded_queries=None,
            device=device,
            prefix=prefix
        )

        dsearcher = FaissSearcher(dense_index_path, encoder)

        if sparse_index_path:
            hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher)
            return hsearcher, reader
        else:
            return dsearcher, reader
    
    return ssearcher, reader
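
# A minimal usage sketch (paths and model names below are placeholders, not
# part of this module): a BM25-only searcher, or a hybrid searcher when both
# a Lucene and a Faiss index are available.
#
#   ssearcher, reader = init_searcher_and_reader(
#       sparse_index_path="indexes/lucene-index.my-corpus",
#       bm25_k1=0.9,
#       bm25_b=0.4,
#   )
#
#   hsearcher, reader = init_searcher_and_reader(
#       sparse_index_path="indexes/lucene-index.my-corpus",
#       dense_index_path="indexes/faiss-index.my-corpus",
#       encoder_name_or_path="castorini/tct_colbert-v2-hnp-msmarco",
#       encoder_class="tct_colbert",
#       device="cpu",
#   )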


def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]:
    """
    Parameters:
    -----------
    searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    query: str
        Query for which to retrieve results
    num_results: int
        Maximum number of results to retrieve
    
    Returns:
    --------
    Dict:
    """
    def _get_dict(r: Union[DenseSearchResult, JLuceneSearcherResult]):
        if isinstance(r, JLuceneSearcherResult):
            return json.loads(r.raw)
        elif isinstance(r, DenseSearchResult):
            # Get document from sparse_index using index reader
            return json.loads(reader.doc(r.docid).raw())
    
    search_results = searcher.search(query, k=num_results)
    all_results = [
        SearchResult(
            docid=doc["id"],
            text=doc["contents"],
            score=hit.score,
            # "language" is filled from the stored document when the corpus
            # provides it; otherwise it defaults to an empty string
            language=doc.get("language", "")
        ) for hit, doc in zip(search_results, map(_get_dict, search_results))
    ]

    return all_results
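
# Example query flow (a sketch, assuming the hypothetical `hsearcher` and
# `reader` from the initialization example above):
#
#   results = _search(hsearcher, reader, "who invented the telephone?", num_results=5)
#   for r in results:
#       print(r["docid"], round(r["score"], 3), r["text"][:80])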