# -*- coding: utf-8 -*-
from dataclasses import dataclass
import os
import pickle
from typing import List, Dict, Optional, Type, TypeVar, TypedDict
import re
import math
from collections import Counter

import gradio as gr
import nltk
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever
from nltk.corpus import stopwords as nltk_stopwords

# Check nltk stopwords data
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords", quiet=True)

# Tokenization and helper functions
LANGUAGE = "english"
stopwords = set(nltk_stopwords.words(LANGUAGE))
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall


def simple_tokenize(text: str) -> List[str]:
    """Lowercase, split into word tokens, and drop stopwords."""
    words = word_splitter(text.lower())
    tokenized = [word for word in words if word not in stopwords]
    return tokenized


@dataclass
class PostingList:
    term: str                      # the term itself
    docid_postings: List[int]      # doc ids of documents containing the term
    tweight_postings: List[float]  # per-document term weight (raw tf, later BM25)


T = TypeVar("T", bound="InvertedIndex")


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]        # term -> term id (position in posting_lists)
    cid2docid: Dict[str, int]    # collection id -> internal doc id
    collection_ids: List[str]    # internal doc id -> collection id
    doc_texts: Optional[List[str]] = None

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Replace raw term frequencies with precomputed BM25 term weights in place."""
        N = total_docs
        for tid, posting_list in enumerate(posting_lists):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i, docid in enumerate(posting_list.docid_postings):
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                posting_list.tweight_postings[i] = (
                    BM25Index.calc_regularized_tf(tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b)
                    * idf
                )

    @staticmethod
    def calc_regularized_tf(tf: float, dl: float, avgdl: float, k1: float, b: float) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type["BM25Index"],
        documents: List[Document],
        avgdl: float,
        total_docs: int,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> "BM25Index":
        # run_counting() is assumed to return a counting object holding posting
        # lists with raw term frequencies plus dfs/dls; a hypothetical sketch is
        # provided further below.
        counting = run_counting(documents, simple_tokenize)
        BM25Index.cache_term_weights(
            counting.posting_lists, total_docs, avgdl, counting.dfs, counting.dls, k1, b
        )
        return cls(
            counting.posting_lists,
            counting.vocab,
            counting.cid2docid,
            counting.collection_ids,
            counting.doc_texts,
        )


class BM25Retriever(BaseRetriever):

    def __init__(self, index_dir: str) -> None:
        self.index = BM25Index.from_saved(index_dir)

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = simple_tokenize(query)
        docid2score: Counter = Counter()
        for tok in toks:
            if tok in self.index.vocab:
                tid = self.index.vocab[tok]
                posting_list = self.index.posting_lists[tid]
                for docid, weight in zip(
                    posting_list.docid_postings, posting_list.tweight_postings
                ):
                    docid2score[docid] += weight
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.most_common(topk)
        }


# Gradio app setup
class Hit(TypedDict):
    cid: str
    score: float
    text: str


def search_sciq(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query)
    # doc_texts is Optional on the dataclass but required to render hits.
    assert bm25_retriever.index.doc_texts is not None
    hits = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits
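# ---------------------------------------------------------------------------
# The surrounding codebase is assumed to provide run_counting(); the code
# below is a minimal, hypothetical stand-in so the script can run end to end.
# `Counting` bundles raw-term-frequency posting lists with document
# frequencies (dfs) and document lengths (dls), which cache_term_weights()
# consumes. The name `Counting` and the `load_sciq()` attributes used here
# are assumptions, not part of the original file.
# ---------------------------------------------------------------------------
@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # document frequency per term id
    dls: List[int]  # token count (document length) per doc id
    doc_texts: Optional[List[str]] = None


def run_counting(documents: List[Document], tokenize_fn) -> Counting:
    """Sketch: accumulate raw term frequencies over `documents`."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []
    dls: List[int] = []
    doc_texts: List[str] = []
    for doc in documents:  # Document is assumed to carry collection_id and text
        docid = len(collection_ids)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)
        doc_texts.append(doc.text)
        toks = tokenize_fn(doc.text)
        dls.append(len(toks))
        for term, tf in Counter(toks).items():
            tid = vocab.setdefault(term, len(vocab))
            if tid == len(posting_lists):  # first occurrence of this term
                posting_lists.append(
                    PostingList(term=term, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(float(tf))
            dfs[tid] += 1
    return Counting(posting_lists, vocab, cid2docid, collection_ids, dfs, dls, doc_texts)


# Build the index once if it is missing; `load_sciq()` is assumed to return a
# dataset object exposing a `corpus` list of `Document`s (adapt if the actual
# loader's interface differs).
INDEX_DIR = "output/bm25_index"
if not os.path.exists(os.path.join(INDEX_DIR, "index.pkl")):
    corpus = load_sciq().corpus
    dls = [len(simple_tokenize(doc.text)) for doc in corpus]
    bm25_index = BM25Index.build_from_documents(
        documents=corpus,
        avgdl=sum(dls) / len(dls),
        total_docs=len(corpus),
    )
    bm25_index.save(INDEX_DIR)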
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

demo = gr.Interface(
    fn=search_sciq,
    inputs="textbox",
    outputs="json",
    description="BM25 Search Engine Demo on SciQ Dataset",
)

if __name__ == "__main__":
    demo.launch()
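# Example programmatic usage (bypassing the Gradio UI); the query string is
# illustrative only:
#
#   hits = search_sciq("mitochondria")
#   for hit in hits:
#       print(hit["cid"], round(hit["score"], 2), hit["text"][:80])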