Spaces:
Running
Running
File size: 1,355 Bytes
4ea2b30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from typing import Optional
import bm25s
import weave
from Stemmer import Stemmer
import wandb
# Maps the full language name stored on the retriever (also the name handed to
# the stemmer) to the short code passed as `stopwords=` to `bm25s.tokenize`.
LANGUAGE_DICT = {
    "english": "en",
    "french": "fr",
    "german": "de",
}
class BM25sRetriever(weave.Model):
    """BM25 retriever backed by a `bm25s.BM25` index, exposed as a `weave.Model`.

    Args:
        language: Full language name. Must be a key of `LANGUAGE_DICT`; it is
            used to pick the stopword list and is passed to `Stemmer`.
        use_stemmer: When true, corpus tokens are stemmed during tokenization.
        retriever: An existing `bm25s.BM25` instance to reuse. A fresh one is
            created when this is omitted.
    """

    # Full language name, e.g. "english"; looked up in LANGUAGE_DICT when indexing.
    language: str
    # Whether a Stemmer is applied while tokenizing the corpus.
    use_stemmer: bool
    # Underlying bm25s index object; always set by __init__.
    _retriever: Optional[bm25s.BM25]

    def __init__(
        self,
        language: str = "english",
        use_stemmer: bool = True,
        retriever: Optional[bm25s.BM25] = None,
    ):
        super().__init__(language=language, use_stemmer=use_stemmer)
        self._retriever = retriever or bm25s.BM25()

    def index(self, corpus_dataset_name: str, index_name: Optional[str] = None):
        """Tokenize a Weave dataset and build the BM25 index from it.

        The dataset is fetched with `weave.ref(corpus_dataset_name)`, and the
        `"text"` field of every row is tokenized and indexed.

        Args:
            corpus_dataset_name: Weave reference of the dataset to index. Each
                row must provide a `"text"` entry.
            index_name: Optional directory name under which the index is saved
                together with the corpus rows. When a W&B run is active, the
                saved directory is also uploaded as a `bm25s-index` artifact.
                When omitted, the index is only kept in memory.

        Raises:
            KeyError: If `self.language` is not a key of `LANGUAGE_DICT`.
        """
        corpus_dataset = weave.ref(corpus_dataset_name).get().rows
        corpus = [row["text"] for row in corpus_dataset]
        corpus_tokens = bm25s.tokenize(
            corpus,
            stopwords=LANGUAGE_DICT[self.language],
            stemmer=Stemmer(self.language) if self.use_stemmer else None,
        )
        self._retriever.index(corpus_tokens)

        # Persisting is opt-in: without a name there is no directory to write
        # to and nothing to upload, so stop after the in-memory index is built.
        if not index_name:
            return

        # Save once, with the corpus rows stored alongside the index.
        self._retriever.save(
            index_name, corpus=[dict(row) for row in corpus_dataset]
        )

        # Upload the saved directory only when a W&B run is active.
        if wandb.run:
            artifact = wandb.Artifact(name=index_name, type="bm25s-index")
            artifact.add_dir(index_name)
            artifact.save()
|