# -*- coding: utf-8 -*-
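"""Gradio demo: BM25 search over the SciQ corpus using a prebuilt inverted index."""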
from dataclasses import dataclass
import os
import pickle
from typing import List, Dict, Optional, Type, TypeVar, TypedDict
import re
import math
from collections import Counter
import gradio as gr
import nltk
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever
from nltk.corpus import stopwords as nltk_stopwords
# Check nltk stopwords data
try:
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download("stopwords", quiet=True)
# Tokenization and helper functions
LANGUAGE = "english"
stopwords = set(nltk_stopwords.words(LANGUAGE))
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
def simple_tokenize(text: str) -> List[str]:
words = word_splitter(text.lower())
tokenized = [word for word in words if word not in stopwords]
return tokenized
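# Example: simple_tokenize("What is the boiling point of water?")
# returns ["boiling", "point", "water"]: lowercased, NLTK stopwords removed,
# and single-character tokens dropped by the \w\w+ pattern.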
@dataclass
class PostingList:
term: str
docid_postings: List[int]
tweight_postings: List[float]
T = TypeVar("T", bound="InvertedIndex")
@dataclass
class InvertedIndex:
posting_lists: List[PostingList]
vocab: Dict[str, int]
cid2docid: Dict[str, int]
collection_ids: List[str]
doc_texts: Optional[List[str]] = None
def save(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
pickle.dump(self, f)
@classmethod
def from_saved(cls: Type[T], saved_dir: str) -> T:
with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
return pickle.load(f)
@dataclass
class BM25Index(InvertedIndex):
@staticmethod
def cache_term_weights(
posting_lists: List[PostingList], total_docs: int, avgdl: float, dfs: List[int], dls: List[int], k1: float, b: float,
) -> None:
N = total_docs
for tid, posting_list in enumerate(posting_lists):
idf = BM25Index.calc_idf(df=dfs[tid], N=N)
for i, docid in enumerate(posting_list.docid_postings):
tf = posting_list.tweight_postings[i]
dl = dls[docid]
posting_list.tweight_postings[i] = BM25Index.calc_regularized_tf(
tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
) * idf
@staticmethod
    def calc_regularized_tf(tf: float, dl: float, avgdl: float, k1: float, b: float) -> float:
return tf / (tf + k1 * (1 - b + b * dl / avgdl))
@staticmethod
    def calc_idf(df: int, N: int) -> float:
return math.log(1 + (N - df + 0.5) / (df + 0.5))
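    # Note: cache_term_weights() above overwrites tweight_postings in place,
    # replacing raw term frequencies with the final per-(term, document) weight
    #     idf(term) * tf / (tf + k1 * (1 - b + b * dl / avgdl)),
    # i.e. BM25 without the optional (k1 + 1) numerator factor; that factor is
    # a constant rescaling and does not affect the ranking.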
@classmethod
def build_from_documents(
cls: Type["BM25Index"], documents: List[Document], avgdl: float, total_docs: int, k1: float = 0.9, b: float = 0.4
    ) -> "BM25Index":
        # run_counting() is assumed to be provided elsewhere (it is not defined
        # in this file). It must return a counting object exposing
        # posting_lists, vocab, cid2docid, collection_ids, doc_texts, dfs
        # (document frequencies per term id) and dls (document lengths).
counting = run_counting(documents, simple_tokenize)
BM25Index.cache_term_weights(counting.posting_lists, total_docs, avgdl, counting.dfs, counting.dls, k1, b)
return cls(counting.posting_lists, counting.vocab, counting.cid2docid, counting.collection_ids, counting.doc_texts)
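
# Offline index construction (sketch). The retriever below expects a prebuilt
# index under "output/bm25_index". A minimal build script could look like the
# commented example here, which is also where the load_sciq import above would
# be used. It assumes load_sciq() returns a dataset object whose .corpus
# attribute is a List[Document] with .text fields; these attribute names are
# assumptions about the nlp4web_codebase API, not verified in this file.
#
#     dataset = load_sciq()
#     docs = dataset.corpus
#     dls = [len(simple_tokenize(doc.text)) for doc in docs]
#     index = BM25Index.build_from_documents(
#         documents=docs,
#         avgdl=sum(dls) / len(dls),
#         total_docs=len(docs),
#     )
#     index.save("output/bm25_index")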
class BM25Retriever(BaseRetriever):
def __init__(self, index_dir: str) -> None:
self.index = BM25Index.from_saved(index_dir)
def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
toks = simple_tokenize(query)
docid2score = Counter()
for tok in toks:
if tok in self.index.vocab:
tid = self.index.vocab[tok]
posting_list = self.index.posting_lists[tid]
for docid, weight in zip(posting_list.docid_postings, posting_list.tweight_postings):
docid2score[docid] += weight
return {
self.index.collection_ids[docid]: score for docid, score in docid2score.most_common(topk)
}
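
# Example usage (assuming the index has been built and saved as sketched above):
#     retriever = BM25Retriever(index_dir="output/bm25_index")
#     retriever.retrieve("what causes the seasons", topk=3)
#     -> {"<collection id>": <BM25 score>, ...}, highest-scoring documents first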
# Gradio app setup
class Hit(TypedDict):
cid: str
score: float
text: str
def search_sciq(query: str) -> List[Hit]:
results = bm25_retriever.retrieve(query)
hits = []
for cid, score in results.items():
docid = bm25_retriever.index.cid2docid[cid]
text = bm25_retriever.index.doc_texts[docid]
hits.append(Hit(cid=cid, score=score, text=text))
return hits
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
demo = gr.Interface(
fn=search_sciq,
inputs="textbox",
outputs="json",
description="BM25 Search Engine Demo on SciQ Dataset"
)
if __name__ == "__main__":
demo.launch()