File size: 794 Bytes
7fdb8e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os

from huggingface_hub import InferenceClient

from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from rag_demo.preprocessing.embed import EmbeddedChunk


class Reranker(RAGStep):
    """RAG step that reorders retrieved chunks by similarity to the query.

    Scores are obtained from the Hugging Face Inference API
    sentence-similarity task using the
    ``intfloat/multilingual-e5-large-instruct`` model.
    """

    def generate(
        self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        """Score ``chunks`` against ``query`` and return the best ones.

        Every scored chunk has its ``similarity`` attribute overwritten in
        place with the value returned by the API.

        Args:
            query: Query whose ``content`` is used as the reference sentence.
            chunks: Candidate chunks; each chunk's ``content`` is sent for
                scoring.
            keep_top_k: Upper bound on the number of chunks returned (applied
                as a slice bound on the sorted list).

        Returns:
            The chunks sorted by descending similarity, truncated to
            ``keep_top_k`` entries. An empty list when ``chunks`` is empty.

        Raises:
            ValueError: If the API returns a number of scores that differs
                from the number of chunks.
        """
        if not chunks:
            # Nothing to score: skip the remote call entirely instead of
            # sending a request with no sentences to compare against.
            return []

        # The access token is read from the HF_API_TOKEN environment variable;
        # os.getenv yields None when it is unset.
        api = InferenceClient(
            model="intfloat/multilingual-e5-large-instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        similarity = api.sentence_similarity(
            query.content, [chunk.content for chunk in chunks]
        )

        # zip() would silently drop unmatched items, leaving some chunks with
        # a missing or stale score; fail loudly instead.
        if len(similarity) != len(chunks):
            raise ValueError(
                f"Expected {len(chunks)} similarity scores, "
                f"got {len(similarity)}"
            )

        for chunk, sim in zip(chunks, similarity):
            chunk.similarity = sim

        # sorted() is stable, so chunks with equal scores keep their input order.
        return sorted(chunks, key=lambda x: x.similarity, reverse=True)[:keep_top_k]