# https://www.mixedbread.ai/blog/mxbai-embed-large-v1
# https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1

import os

# Point the Hugging Face caches at a custom location. Do this before importing
# any Hugging Face libraries, which resolve their cache paths at import time.
os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"

import pandas as pd

from accelerate import Accelerator
from sentence_transformers import CrossEncoder

# Initialize the Accelerator and use the device it manages
accelerator = Accelerator()
device = accelerator.device
print("Using accelerator device =", device)

# Load the mixedbread.ai cross-encoder reranker
model_sf_mxbai = CrossEncoder("mixedbread-ai/mxbai-rerank-large-v1", device=device)


def RAG_retrieval_Base(queryText, passages, min_threshold=0.0, max_num_passages=None):
    """Rerank `passages` against `queryText` with the mxbai cross-encoder.

    Returns a DataFrame with columns `corpus_id`, `score`, and `Passage`.
    If `max_num_passages` is not given, roughly the top 10% of the passages
    (at least one) are kept; passages scoring below `min_threshold` are dropped.
    """
    try:
        if max_num_passages:
            top_k = max_num_passages
        else:
            # Default to roughly the top 10% of the passages, but at least one
            top_k = max(1, int(0.1 * len(passages)))

        # rank() returns a list of dicts {'corpus_id': ..., 'score': ...},
        # sorted by score in descending order
        result_rerank = model_sf_mxbai.rank(queryText, passages, return_documents=False, top_k=top_k)

        if not result_rerank:
            return pd.DataFrame()

        df_filtered = pd.DataFrame(result_rerank)  # columns: corpus_id, score

        if min_threshold > 0:
            df_filtered = df_filtered[df_filtered['score'] >= min_threshold].copy()

        # Attach the selected passages as a new "Passage" column and drop duplicates
        df_filtered['Passage'] = [passages[i] for i in df_filtered['corpus_id']]
        df_filtered = df_filtered.drop_duplicates(subset='Passage', keep='first')

        # rank() output is already sorted by score, so no re-sorting is needed
        return df_filtered

    except Exception as e:
        # Log the exception and fall back to an empty DataFrame
        print(f"An error occurred: {e}")
        return pd.DataFrame()
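

# For reference, a minimal sketch of the same scoring done with the lower-level
# CrossEncoder.predict() API, which returns one relevance score per
# (query, passage) pair with no top-k selection. The helper name
# `score_all_passages` is illustrative and not part of the original script.
def score_all_passages(queryText, passages):
    pairs = [(queryText, p) for p in passages]
    scores = model_sf_mxbai.predict(pairs)  # one float score per pair
    df = pd.DataFrame({'corpus_id': range(len(passages)),
                       'score': scores,
                       'Passage': passages})
    return df.sort_values(by='score', ascending=False)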


if __name__ == '__main__':

    queryText = 'A man is eating a piece of bread'

    # Define the passages list
    passages = [
        "A man is eating food.",
        "A man is eating pasta.",
        "The girl is carrying a baby.",
        "A man is riding a horse.",
    ]

    df_retrieved = RAG_retrieval_Base(queryText, passages, min_threshold=0, max_num_passages=3)

    print(df_retrieved)
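
    # Illustrative variant: rely only on the score threshold (0.5 here is an
    # assumed example value, not from the original script), so top_k falls
    # back to roughly the top 10% of the passages.
    df_thresholded = RAG_retrieval_Base(queryText, passages, min_threshold=0.5)
    print(df_thresholded)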

    print("end of computations")