# medrag/medrag_multi_modal/retrieval/bm25s_retrieval.py
# Last commit by geekyrakshit: "add: BM25sRetriever" (4ea2b30), 1.36 kB
from typing import Optional
import bm25s
import weave
from Stemmer import Stemmer
import wandb
# Supported languages. Each key is the language name handed to PyStemmer's
# ``Stemmer``; each value is the stopword-list code given to ``bm25s.tokenize``.
LANGUAGE_DICT = dict(
    english="en",
    french="fr",
    german="de",
)
class BM25sRetriever(weave.Model):
    """Sparse BM25 retriever built on the ``bm25s`` library.

    Attributes:
        language: Language name used for stopword removal and (optionally)
            stemming; must be a key of ``LANGUAGE_DICT``.
        use_stemmer: Whether tokens are stemmed with PyStemmer during indexing.
    """

    language: str
    use_stemmer: bool
    # Wrapped bm25s index object that does the actual indexing/saving.
    _retriever: Optional[bm25s.BM25]

    def __init__(
        self,
        language: str = "english",
        use_stemmer: bool = True,
        retriever: Optional[bm25s.BM25] = None,
    ):
        """Create the retriever.

        Args:
            language: Language name; must be a key of ``LANGUAGE_DICT``.
            use_stemmer: Stem tokens with PyStemmer when indexing.
            retriever: Existing ``bm25s.BM25`` instance to wrap. A new, empty
                ``bm25s.BM25()`` is created when this is ``None``.
        """
        super().__init__(language=language, use_stemmer=use_stemmer)
        # Explicit None check: only substitute a fresh index when none was given.
        self._retriever = retriever if retriever is not None else bm25s.BM25()

    def index(self, corpus_dataset_name: str, index_name: Optional[str] = None):
        """Build a BM25 index over a Weave dataset and optionally persist it.

        Args:
            corpus_dataset_name: Weave ref string (passed to ``weave.ref``) of a
                dataset whose rows each contain a ``"text"`` field.
            index_name: Directory to save the index into, together with the
                corpus rows. If a W&B run is active, that directory is also
                uploaded as an artifact of type ``"bm25s-index"`` named
                ``index_name``. When falsy, the index stays in memory only.

        Raises:
            KeyError: If ``self.language`` is not a key of ``LANGUAGE_DICT`` or
                a row has no ``"text"`` field.
        """
        # Materialize the rows once: they are read for tokenization and again
        # when the corpus is saved next to the index.
        corpus_rows = list(weave.ref(corpus_dataset_name).get().rows)
        corpus = [row["text"] for row in corpus_rows]
        corpus_tokens = bm25s.tokenize(
            corpus,
            stopwords=LANGUAGE_DICT[self.language],
            stemmer=Stemmer(self.language) if self.use_stemmer else None,
        )
        self._retriever.index(corpus_tokens)

        # Nothing to persist without a destination directory.
        if not index_name:
            return

        # Save exactly once, including the corpus rows, then optionally
        # upload the saved directory to W&B.
        self._retriever.save(
            index_name, corpus=[dict(row) for row in corpus_rows]
        )
        if wandb.run:
            artifact = wandb.Artifact(name=index_name, type="bm25s-index")
            artifact.add_dir(index_name)
            artifact.save()