hebrew-dentsit / reranker.py
borodache's picture
Upload 6 files
fb0495b verified
raw
history blame
900 Bytes
from sklearn.metrics.pairwise import cosine_similarity
from text_embedder_encoder import TextEmbedder
class Reranker:
    """Re-rank retrieved documents by cosine similarity to the query.

    Embeddings are produced by the ``TextEmbedder`` instance created in
    ``__init__``.
    """

    def __init__(self):
        self.text_embedder = TextEmbedder()

    def rerank(self, query, retrieved_docs, top_n=5, min_score=0.7):
        """Return up to ``top_n`` documents most similar to ``query``.

        Args:
            query: Query text, passed to ``TextEmbedder.encode``.
            retrieved_docs: Indexable sequence of candidate documents, passed
                to ``TextEmbedder.encode_many``.
            top_n: Maximum number of documents to return. ``None`` keeps all
                candidates; ``0`` or a negative value returns an empty list.
            min_score: Minimum cosine similarity a document must reach to be
                kept. Defaults to 0.7, the previously hard-coded cut-off.

        Returns:
            A list of items from ``retrieved_docs`` ordered by descending
            similarity; documents with equal scores keep their original
            retrieval order. It may contain fewer than ``top_n`` items when
            few documents reach ``min_score``.
        """
        # Nothing to rank: skip the embedding calls entirely, and avoid
        # cosine_similarity raising on an empty document matrix.
        # len() is used instead of truthiness so array-like inputs also work.
        if len(retrieved_docs) == 0 or (top_n is not None and top_n <= 0):
            return []

        # NOTE(review): assumes encode() returns a single 1-D vector and
        # encode_many() returns one row per document — confirm in TextEmbedder.
        query_embedding = self.text_embedder.encode(query)
        doc_embeddings = self.text_embedder.encode_many(retrieved_docs)
        similarity_scores = cosine_similarity([query_embedding], doc_embeddings)[0]

        # sorted() stays stable with reverse=True, so ties keep retrieval order
        # (sorting (score, idx) tuples would instead favour later documents).
        ranked_idxes = sorted(
            range(len(similarity_scores)),
            key=lambda idx: similarity_scores[idx],
            reverse=True,
        )
        # Scores are in descending order, so thresholding after the top_n cut
        # yields the same documents as thresholding first.
        return [
            retrieved_docs[idx]
            for idx in ranked_idxes[:top_n]
            if similarity_scores[idx] >= min_score
        ]