# -*- coding: utf-8 -*-
from dataclasses import dataclass
import os
import pickle
from typing import List, Dict, Optional, Type, TypeVar, TypedDict
import re
import math
from collections import Counter
import gradio as gr
import nltk
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever
from nltk.corpus import stopwords as nltk_stopwords

# Ensure the NLTK stopword list is available; download it quietly on first run
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords", quiet=True)

# Tokenization and helper functions
LANGUAGE = "english"
stopwords = set(nltk_stopwords.words(LANGUAGE))
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall

def simple_tokenize(text: str) -> List[str]:
    words = word_splitter(text.lower())
    tokenized = [word for word in words if word not in stopwords]
    return tokenized
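
# Illustrative example: stopwords and single-character tokens are dropped, e.g.
#   simple_tokenize("What is the speed of light?") -> ["speed", "light"]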

@dataclass
class PostingList:
    term: str                          # the term itself
    docid_postings: List[int]          # internal ids of the documents containing the term
    tweight_postings: List[float]      # per-document term weight (raw tf, later replaced by BM25 weight)

T = TypeVar("T", bound="InvertedIndex")

@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]   # one posting list per term, indexed by term id
    vocab: Dict[str, int]              # term -> term id (index into posting_lists)
    cid2docid: Dict[str, int]          # collection id -> internal doc id
    collection_ids: List[str]          # internal doc id -> collection id
    doc_texts: Optional[List[str]] = None  # internal doc id -> raw document text (if stored)

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
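    # save()/from_saved() round-trip the whole index through a single pickle file, e.g.
    # index.save("output/bm25_index") writes output/bm25_index/index.pkl, which
    # InvertedIndex.from_saved("output/bm25_index") (or a subclass such as BM25Index) loads back.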

@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList], total_docs: int, avgdl: float, dfs: List[int], dls: List[int], k1: float, b: float,
    ) -> None:
        """Replace the raw term frequencies in the posting lists with pre-computed BM25 term weights (in place)."""
        N = total_docs
        for tid, posting_list in enumerate(posting_lists):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i, docid in enumerate(posting_list.docid_postings):
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                posting_list.tweight_postings[i] = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                ) * idf

    @staticmethod
    def calc_regularized_tf(tf: float, dl: float, avgdl: float, k1: float, b: float) -> float:
        # Length-normalized term-frequency component of (Okapi) BM25.
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))
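    # Combined, the weight cached for term t in document d is
    #   w(t, d) = idf(t) * tf / (tf + k1 * (1 - b + b * dl / avgdl)),
    # with idf(t) = ln(1 + (N - df + 0.5) / (df + 0.5)), i.e. the standard BM25 score contribution.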

    @classmethod
    def build_from_documents(
        cls: Type["BM25Index"], documents: List[Document], avgdl: float, total_docs: int, k1: float = 0.9, b: float = 0.4
    ) -> "BM25Index":
        # run_counting() is assumed to be provided by the surrounding codebase/notebook; it
        # tokenizes the documents and returns posting lists, vocab, dfs, dls, and doc metadata.
        counting = run_counting(documents, simple_tokenize)
        BM25Index.cache_term_weights(counting.posting_lists, total_docs, avgdl, counting.dfs, counting.dls, k1, b)
        return cls(counting.posting_lists, counting.vocab, counting.cid2docid, counting.collection_ids, counting.doc_texts)

class BM25Retriever(BaseRetriever):
    def __init__(self, index_dir: str) -> None:
        self.index = BM25Index.from_saved(index_dir)

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        # Score each document by summing the cached BM25 weights of the query terms it contains.
        toks = simple_tokenize(query)
        docid2score = Counter()
        for tok in toks:
            if tok in self.index.vocab:
                tid = self.index.vocab[tok]
                posting_list = self.index.posting_lists[tid]
                for docid, weight in zip(posting_list.docid_postings, posting_list.tweight_postings):
                    docid2score[docid] += weight
        return {
            self.index.collection_ids[docid]: score for docid, score in docid2score.most_common(topk)
        }
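
# Illustrative usage (ids/scores are placeholders):
#   retriever = BM25Retriever(index_dir="output/bm25_index")
#   retriever.retrieve("photosynthesis", topk=3)  # -> {"<collection-id>": <BM25 score>, ...}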

# Gradio app setup
class Hit(TypedDict):
    cid: str
    score: float
    text: str

def search_sciq(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query)
    hits = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        # doc_texts is Optional; fall back to an empty string if raw texts were not stored.
        text = bm25_retriever.index.doc_texts[docid] if bm25_retriever.index.doc_texts else ""
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits
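
# The retriever below expects an index previously saved to output/bm25_index. A minimal
# build-and-save sketch (assumptions: load_sciq() returns a dataset whose .corpus is a
# List[Document] with a .text field, and run_counting() is available in scope as noted above):
if not os.path.exists(os.path.join("output/bm25_index", "index.pkl")):
    sciq = load_sciq()
    doc_lengths = [len(simple_tokenize(doc.text)) for doc in sciq.corpus]
    bm25_index = BM25Index.build_from_documents(
        documents=sciq.corpus,
        avgdl=sum(doc_lengths) / len(doc_lengths),
        total_docs=len(sciq.corpus),
    )
    bm25_index.save("output/bm25_index")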

bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

demo = gr.Interface(
    fn=search_sciq,
    inputs="textbox",
    outputs="json",
    description="BM25 Search Engine Demo on SciQ Dataset"
)

if __name__ == "__main__":
    demo.launch()