from __future__ import annotations

import math
import os
import pickle
import re
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TypeVar, TypedDict

import gradio as gr
import nltk
from nltk.corpus import stopwords as nltk_stopwords

from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever

# Download the NLTK stopword list on first use if it is not already cached.
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords", quiet=True)

LANGUAGE = "english"
stopwords = set(nltk_stopwords.words(LANGUAGE))
# Same token pattern as sklearn's default: runs of two or more word characters.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall


def simple_tokenize(text: str) -> List[str]:
    """Lowercase the text, split it into words, and drop stopwords."""
    words = word_splitter(text.lower())
    tokenized = [word for word in words if word not in stopwords]
    return tokenized
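

# Example behavior of the tokenizer (illustrative):
#   simple_tokenize("The symbol for gold is Au.") -> ["symbol", "gold", "au"]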


@dataclass
class PostingList:
    term: str  # The term itself.
    docid_postings: List[int]  # Internal ids of the documents containing the term.
    tweight_postings: List[float]  # Term weight per document, parallel to docid_postings.


T = TypeVar("T", bound="InvertedIndex")


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # term id -> posting list
    vocab: Dict[str, int]  # term -> term id
    cid2docid: Dict[str, int]  # collection id -> internal docid
    collection_ids: List[str]  # internal docid -> collection id
    doc_texts: Optional[List[str]] = None  # internal docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
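

# Note: save()/from_saved() round-trip the whole index via pickle; only load
# index files from a trusted source, since unpickling can execute arbitrary code.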


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Precompute the BM25 weight of every (term, document) pair in place,
        overwriting the raw term frequencies stored in tweight_postings."""
        N = total_docs
        for tid, posting_list in enumerate(posting_lists):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i, docid in enumerate(posting_list.docid_postings):
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                posting_list.tweight_postings[i] = (
                    BM25Index.calc_regularized_tf(tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b)
                    * idf
                )

    @staticmethod
    def calc_regularized_tf(tf: float, dl: float, avgdl: float, k1: float, b: float) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))
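
    # Together these two helpers implement the BM25 term weight (the variant
    # used by Lucene, without the (k1 + 1) factor in the numerator):
    #
    #   weight(t, d) = idf(t) * tf(t, d) / (tf(t, d) + k1 * (1 - b + b * dl / avgdl))
    #   idf(t)       = ln(1 + (N - df(t) + 0.5) / (df(t) + 0.5))
    #
    # so query-time scoring reduces to summing cached weights per query term.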

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: List[Document],
        avgdl: float,
        total_docs: int,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # `run_counting` (which tallies term/document frequencies and document
        # lengths over the tokenized collection) comes from the accompanying
        # exercise code and is assumed to be in scope here.
        counting = run_counting(documents, simple_tokenize)
        BM25Index.cache_term_weights(
            counting.posting_lists, total_docs, avgdl, counting.dfs, counting.dls, k1, b
        )
        return cls(
            counting.posting_lists,
            counting.vocab,
            counting.cid2docid,
            counting.collection_ids,
            counting.doc_texts,
        )


class BM25Retriever(BaseRetriever):

    def __init__(self, index_dir: str) -> None:
        self.index = BM25Index.from_saved(index_dir)

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the topk documents for the query as {collection id: BM25 score}."""
        toks = simple_tokenize(query)
        docid2score: Counter = Counter()
        for tok in toks:
            if tok in self.index.vocab:
                tid = self.index.vocab[tok]
                posting_list = self.index.posting_lists[tid]
                # Sum the cached term weights for every document containing the term.
                for docid, weight in zip(
                    posting_list.docid_postings, posting_list.tweight_postings
                ):
                    docid2score[docid] += weight
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.most_common(topk)
        }


class Hit(TypedDict):
    cid: str
    score: float
    text: str


def search_sciq(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query)
    hits = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        # doc_texts is Optional on InvertedIndex; the index built for this demo
        # stores the document texts.
        text = bm25_retriever.index.doc_texts[docid]
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits
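

# The demo expects a prebuilt index at output/bm25_index. A minimal sketch of how
# it could be produced (assuming, per the nlp4web_codebase setup, that load_sciq()
# returns a dataset whose `corpus` is a list of Documents with a `text` field):
#
#   sciq = load_sciq()
#   dls = [len(simple_tokenize(doc.text)) for doc in sciq.corpus]
#   avgdl = sum(dls) / len(dls)
#   index = BM25Index.build_from_documents(
#       documents=sciq.corpus, avgdl=avgdl, total_docs=len(dls)
#   )
#   index.save("output/bm25_index")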

bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

demo = gr.Interface(
    fn=search_sciq,
    inputs="textbox",
    outputs="json",
    description="BM25 Search Engine Demo on SciQ Dataset",
)

if __name__ == "__main__":
    demo.launch()
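    # Or, for a quick check without the browser UI (illustrative query):
    #   print(search_sciq("What is the smallest unit of life?")[:3])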