# NOTE: Hugging Face Spaces page header ("Spaces: Running Running") removed —
# it was web-scrape residue, not part of the program.
# References:
# https://www.mixedbread.ai/blog/mxbai-embed-large-v1
# https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
import os
import time
import pandas as pd
import numpy as np
from typing import Dict
import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import CrossEncoder
from sentence_transformers.util import cos_sim
from accelerate import Accelerator
from scipy.stats import zscore

# Point the Hugging Face cache at a shared project location. All three
# variables are set for compatibility across huggingface_hub versions.
os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"

# Let Accelerate choose the best available device (GPU if present, else CPU).
accelerator = Accelerator()
device = accelerator.device
print("Using accelerator device =", device)

# Cross-encoder reranker used by RAG_retrieval_Base (downloaded/cached on first use).
model_sf_mxbai = CrossEncoder("mixedbread-ai/mxbai-rerank-large-v1", device=device)
def RAG_retrieval_Base(queryText, passages, min_threshold=0.0, max_num_passages=None):
    """Rerank *passages* against *queryText* with the mxbai cross-encoder.

    Args:
        queryText: The query string.
        passages: List of candidate passage strings.
        min_threshold: Keep only passages whose rerank score is >= this value.
            The filter is applied only when the threshold is > 0.
        max_num_passages: If truthy, ask the reranker for at most this many
            results; otherwise keep the top 10% of the input (at least one).

    Returns:
        pandas.DataFrame with columns 'corpus_id', 'score' and 'Passage'
        (duplicate passages dropped, first occurrence kept), sorted by the
        reranker's own score order. An empty DataFrame is returned when the
        reranker yields nothing or an error occurs.
    """
    try:
        # Compute top_k once instead of duplicating the rank() call per branch.
        if max_num_passages:
            top_k = max_num_passages
        else:
            # Default: top 10% of the candidate pool, but never fewer than one.
            top_k = max(1, int(0.1 * len(passages)))

        result_rerank = model_sf_mxbai.rank(
            queryText, passages, return_documents=False, top_k=top_k
        )
        if not result_rerank:
            return pd.DataFrame()

        df = pd.DataFrame(result_rerank)  # columns: corpus_id, score
        if min_threshold > 0:
            # .copy() prevents pandas SettingWithCopyWarning when the
            # 'Passage' column is added to the slice below.
            df_filtered = df[df['score'] >= min_threshold].copy()
        else:
            df_filtered = df.copy()

        # Map corpus ids back to passage texts, then drop duplicate passages.
        df_filtered['Passage'] = [passages[i] for i in df_filtered['corpus_id']]
        df_filtered = df_filtered.drop_duplicates(subset='Passage', keep='first')
        return df_filtered
    except Exception as e:
        # Boundary handler: log the failure and degrade to an empty result.
        print(f"An error occurred: {e}")
        return pd.DataFrame()
if __name__ == '__main__':
    # Smoke test: rerank four passages against a simple query and keep top 3.
    queryText = 'A man is eating a piece of bread'
    passages = [
        "A man is eating food.",
        "A man is eating pasta.",
        "The girl is carrying a baby.",
        "A man is riding a horse.",
    ]
    df_retrieved = RAG_retrieval_Base(queryText, passages, min_threshold=0, max_num_passages=3)
    print(df_retrieved)
    print("end of computations")