Gourisankar Padihary committed
Commit 5184c29 · Parent: e234b58

Multiple data set support
app.py CHANGED
@@ -1,71 +1,134 @@
 import gradio as gr
 import logging
-from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
+import threading
+import time
+from generator.compute_metrics import get_attributes_text
+from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
+from io import StringIO
 
-def launch_gradio(vector_store, dataset, gen_llm, val_llm):
+def launch_gradio(vector_store, gen_llm, val_llm):
     """
     Launch the Gradio app with pre-initialized objects.
     """
-    def answer_question_with_metrics(query):
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    # Create a list to store logs
+    logs = []
+
+    # Custom log handler to capture logs and add them to the logs list
+    class LogHandler(logging.Handler):
+        def emit(self, record):
+            log_entry = self.format(record)
+            logs.append(log_entry)
+
+    # Add custom log handler to the logger
+    log_handler = LogHandler()
+    log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
+    logger.addHandler(log_handler)
+
+    def log_updater():
+        """Background function to add logs."""
+        while True:
+            time.sleep(2) # Update logs every 2 seconds
+            pass # Log capture is now handled by the logging system
+
+    def get_logs():
+        """Retrieve logs for display."""
+        return "\n".join(logs[-50:]) # Only show the last 50 logs for example
+
+    # Start the logging thread
+    threading.Thread(target=log_updater, daemon=True).start()
+
+    def answer_question(query, state):
         try:
-            logging.info(f"Processing query: {query}")
+            # Generate response using the passed objects
+            response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)
 
-            # Generate metrics using the passed objects
-            from main import generate_metrics
-            response, metrics = generate_metrics(gen_llm, val_llm, vector_store, query, 1)
+            # Update state with the response and source documents
+            state["query"] = query
+            state["response"] = response
+            state["source_docs"] = source_docs
 
             response_text = f"Response: {response}\n\n"
-            metrics_text = "Metrics:\n"
-            for key, value in metrics.items():
-                if key != 'response':
-                    metrics_text += f"{key}: {value}\n"
-
-            return response_text, metrics_text
+            return response_text, state
         except Exception as e:
             logging.error(f"Error processing query: {e}")
-            return f"An error occurred: {e}"
+            return f"An error occurred: {e}", state
 
-    def compute_and_display_metrics():
+    def compute_metrics(state):
         try:
-            # Call the function to compute metrics
-            relevance_rmse, utilization_rmse, adherence_auc = compute_rmse_auc_roc_metrics(
-                gen_llm, val_llm, dataset, vector_store, 10
-            )
+            logging.info(f"Computing metrics")
+
+            # Retrieve response and source documents from state
+            response = state.get("response", "")
+            source_docs = state.get("source_docs", {})
+            query = state.get("query", "")
+
+            # Generate metrics using the passed objects
+            attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 1)
+
+            attributes_text = get_attributes_text(attributes)
+
+            metrics_text = "Metrics:\n"
+            for key, value in metrics.items():
+                if key != 'response':
+                    metrics_text += f"{key}: {value}\n"
 
-            # Format the result for display
-            result = (
-                f"Relevance RMSE Score: {relevance_rmse}\n"
-                f"Utilization RMSE Score: {utilization_rmse}\n"
-                f"Overall Adherence AUC-ROC: {adherence_auc}\n"
-            )
-            return result
+            return attributes_text, metrics_text
         except Exception as e:
-            logging.error(f"Error during metrics computation: {e}")
-            return f"An error occurred: {e}"
+            logging.error(f"Error computing metrics: {e}")
+            return f"An error occurred: {e}", ""
 
     # Define Gradio Blocks layout
     with gr.Blocks() as interface:
         interface.title = "Real Time RAG Pipeline Q&A"
         gr.Markdown("### Real Time RAG Pipeline Q&A") # Heading
-        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
 
+        # Section to display LLM names
+        with gr.Row():
+            model_info = f"Generation LLM: {gen_llm.name if hasattr(gen_llm, 'name') else 'Unknown'}\n"
+            model_info += f"Validation LLM: {val_llm.name if hasattr(val_llm, 'name') else 'Unknown'}\n"
+            gr.Textbox(value=model_info, label="Model Information", interactive=False) # Read-only textbox
+
+        # State to store response and source documents
+        state = gr.State(value={"query": "","response": "", "source_docs": {}})
+        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
        with gr.Row():
             query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
         with gr.Row():
+            submit_button = gr.Button("Submit", variant="primary") # Submit button
             clear_query_button = gr.Button("Clear") # Clear button
-            submit_button = gr.Button("Submit", variant="primary") # Submit button
         with gr.Row():
             answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
+
         with gr.Row():
+            compute_metrics_button = gr.Button("Compute metrics", variant="primary")
+            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
             metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
-        with gr.Row():
-            compute_rmse_button = gr.Button("Compute RMSE & AU-ROC", variant="primary")
-            rmse_output = gr.Textbox(label="RMSE & AU-ROC Score", placeholder="RMSE & AU-ROC score will appear here")
-
-
+
+        #with gr.Row():
+
         # Define button actions
-        submit_button.click(fn=answer_question_with_metrics, inputs=[query_input], outputs=[answer_output, metrics_output])
-        clear_query_button.click(fn=lambda: "", outputs=[query_input]) # Clear query input
-        compute_rmse_button.click(fn=compute_and_display_metrics, outputs=[rmse_output])
+        submit_button.click(
+            fn=answer_question,
+            inputs=[query_input, state],
+            outputs=[answer_output, state]
+        )
+        clear_query_button.click(fn=lambda: "", outputs=[query_input]) # Clear query input
+        compute_metrics_button.click(
+            fn=compute_metrics,
+            inputs=[state],
+            outputs=[attr_output, metrics_output]
+        )
+
+        # Section to display logs
+        with gr.Row():
+            start_log_button = gr.Button("Start Log Update", elem_id="start_btn") # Button to start log updates
+        with gr.Row():
+            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10) # Log section
+
+        # Set button click to trigger log updates
+        start_log_button.click(fn=get_logs, outputs=log_section)
 
-    interface.launch()
+    interface.launch()
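
The reworked app.py no longer recomputes anything when the metrics button is pressed: the Submit handler stores the query, response, and source documents in a gr.State dictionary, and the Compute metrics handler reads them back. A minimal, self-contained sketch of that pattern (the echo-style answer/metrics functions below are placeholders, not the repository's RAG calls):

    import gradio as gr

    def answer(query, state):
        # Stand-in for retrieve_and_generate_response: remember what was asked and answered
        state["query"] = query
        state["response"] = f"echo: {query}"
        return state["response"], state

    def metrics(state):
        # Stand-in for generate_metrics: reuse whatever the Submit handler stored
        return f"metrics computed for query: {state.get('query', '')!r}"

    with gr.Blocks() as demo:
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})
        query_input = gr.Textbox(label="Ask a question")
        answer_output = gr.Textbox(label="Response")
        metrics_output = gr.Textbox(label="Metrics")
        submit = gr.Button("Submit")
        compute = gr.Button("Compute metrics")

        submit.click(fn=answer, inputs=[query_input, state], outputs=[answer_output, state])
        compute.click(fn=metrics, inputs=[state], outputs=[metrics_output])

    demo.launch()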
generator/compute_metrics.py CHANGED
@@ -32,18 +32,53 @@ def compute_metrics(attributes, total_sentences):
 
 def get_metrics(attributes, total_sentences):
     if attributes.content:
-        #print(attributes)
-        result_content = attributes.content # Access the content attribute
-        # Extract the JSON part from the result_content
-        json_start = result_content.find("{")
-        json_end = result_content.rfind("}") + 1
-        json_str = result_content[json_start:json_end]
-
         try:
+            result_content = attributes.content # Access the content attribute
+            # Extract the JSON part from the result_content
+            json_start = result_content.find("{")
+            json_end = result_content.rfind("}") + 1
+            json_str = result_content[json_start:json_end]
             result_json = json.loads(json_str)
             # Compute metrics using the extracted attributes
             metrics = compute_metrics(result_json, total_sentences)
             logging.info(metrics)
+
             return metrics
         except json.JSONDecodeError as e:
-            logging.error(f"JSONDecodeError: {e}")
+            logging.error(f"JSONDecodeError: {e}")
+
+def get_attributes_text(attributes):
+    try:
+        result_content = attributes.content # Access the content attribute
+        # Extract the JSON part from the result_content
+        json_start = result_content.find("{")
+        json_end = result_content.rfind("}") + 1
+        json_str = result_content[json_start:json_end]
+        result_json = json.loads(json_str)
+
+        # Extract the required fields from json
+        relevance_explanation = result_json.get("relevance_explanation", "N/A")
+        all_relevant_sentence_keys = result_json.get("all_relevant_sentence_keys", [])
+        overall_supported_explanation = result_json.get("overall_supported_explanation", "N/A")
+        overall_supported = result_json.get("overall_supported", "N/A")
+        sentence_support_information = result_json.get("sentence_support_information", [])
+        all_utilized_sentence_keys = result_json.get("all_utilized_sentence_keys", [])
+
+        # Format the attributes for display
+        attributes_text = "Attributes:\n"
+        attributes_text += f"### Relevance Explanation:\n{relevance_explanation}\n\n"
+        attributes_text += f"### All Relevant Sentence Keys:\n{', '.join(all_relevant_sentence_keys)}\n\n"
+        attributes_text += f"### Overall Supported Explanation:\n{overall_supported_explanation}\n\n"
+        attributes_text += f"### Overall Supported:\n{overall_supported}\n\n"
+        attributes_text += "### Sentence Support Information:\n"
+        for info in sentence_support_information:
+            attributes_text += f"- Response Sentence Key: {info.get('response_sentence_key', 'N/A')}\n"
+            attributes_text += f"  Explanation: {info.get('explanation', 'N/A')}\n"
+            attributes_text += f"  Supporting Sentence Keys: {', '.join(info.get('supporting_sentence_keys', []))}\n"
+            attributes_text += f"  Fully Supported: {info.get('fully_supported', 'N/A')}\n"
+        attributes_text += f"\n### All Utilized Sentence Keys:\n{', '.join(all_utilized_sentence_keys)}"
+
+        return attributes_text
+    except Exception as e:
+        logging.error(f"Error extracting attributes: {e}")
+        return f"An error occurred while extracting attributes: {e}"
generator/compute_rmse_auc_roc_metrics.py CHANGED
@@ -1,6 +1,6 @@
 
 from sklearn.metrics import roc_auc_score, root_mean_squared_error
-from generator.generate_metrics import generate_metrics
+from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
 import logging
 
 def compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, num_question):
@@ -25,7 +25,8 @@ def compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, num_qu
         query = document['question']
         logging.info(f'Query number: {i + 1}')
         # Call the generate_metrics for each query
-        response, metrics = generate_metrics(gen_llm, val_llm, vector_store, query, 15)
+        response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)
+        attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 25)
 
         # Extract predicted metrics (ensure these are continuous if possible)
         predicted_relevance = metrics.get('Context Relevance', 0) if metrics else 0
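
The scores themselves come straight from scikit-learn: root_mean_squared_error for the continuous relevance and utilization predictions, roc_auc_score for the binary adherence labels. A toy example with invented ground-truth and predicted values (root_mean_squared_error requires scikit-learn 1.4 or newer):

    from sklearn.metrics import roc_auc_score, root_mean_squared_error

    true_relevance = [0.8, 0.5, 0.9, 0.2]
    predicted_relevance = [0.7, 0.6, 0.8, 0.1]
    print(root_mean_squared_error(true_relevance, predicted_relevance))  # ~0.1

    true_adherence = [1, 0, 1, 1]
    predicted_adherence = [0.9, 0.2, 0.7, 0.6]
    print(roc_auc_score(true_adherence, predicted_adherence))  # 1.0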
generator/generate_metrics.py CHANGED
@@ -5,7 +5,7 @@ from retriever.retrieve_documents import retrieve_top_k_documents
 from generator.compute_metrics import get_metrics
 from generator.extract_attributes import extract_attributes
 
-def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
+def retrieve_and_generate_response(gen_llm, vector_store, query):
     logging.info(f'Query: {query}')
 
     # Step 1: Retrieve relevant documents for given query
@@ -21,6 +21,10 @@ def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
 
     logging.info(f"Response from LLM: {response}")
 
+    return response, source_docs
+
+def generate_metrics(val_llm, response, source_docs, query, time_to_wait):
+
     # Add a sleep interval to avoid hitting the rate limit
     time.sleep(time_to_wait) # Adjust the sleep time as needed
 
@@ -28,8 +32,8 @@ def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
     logging.info(f"Extracting attributes through validation LLM")
     attributes, total_sentences = extract_attributes(val_llm, query, source_docs, response)
     logging.info(f"Extracted attributes successfully")
-
+
     # Step 4 : Call the get metrics calculate metrics
     metrics = get_metrics(attributes, total_sentences)
 
-    return response, metrics
+    return attributes, metrics
generator/initialize_llm.py CHANGED
@@ -4,8 +4,9 @@ from langchain_groq import ChatGroq
 
 def initialize_generation_llm():
     os.environ["GROQ_API_KEY"] = "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"
-    model_name = "llama3-8b-8192"
+    model_name = "mixtral-8x7b-32768"
     llm = ChatGroq(model=model_name, temperature=0.7)
+    llm.name = model_name
     logging.info(f'Generation LLM {model_name} initialized')
     return llm
 
@@ -13,5 +14,6 @@ def initialize_validation_llm():
     os.environ["GROQ_API_KEY"] = "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"
     model_name = "llama3-70b-8192"
     llm = ChatGroq(model=model_name, temperature=0.7)
+    llm.name = model_name
     logging.info(f'Validation LLM {model_name} initialized')
     return llm
main.py CHANGED
@@ -3,7 +3,6 @@ from data.load_dataset import load_data
 from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
 from retriever.chunk_documents import chunk_documents
 from retriever.embed_documents import embed_documents
-from generator.generate_metrics import generate_metrics
 from generator.initialize_llm import initialize_generation_llm
 from generator.initialize_llm import initialize_validation_llm
 from app import launch_gradio
@@ -13,21 +12,43 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 def main():
     logging.info("Starting the RAG pipeline")
-    data_set_name = 'covidqa'
+
+
+    # Load single dataset
+    #dataset = load_data(data_set_name)
+    #logging.info("Dataset loaded")
+    # List of datasets to load
+    data_set_names = ['covidqa', 'techqa', 'cuad']
 
-    # Load the dataset
-    dataset = load_data(data_set_name)
-    logging.info("Dataset loaded")
+    default_chunk_size = 1000
+    chunk_overlap = 200
 
-    # Chunk the dataset
-    chunk_size = 1000 # default value
-    if data_set_name == 'cuad':
-        chunk_size = 3000
-    documents = chunk_documents(dataset, chunk_size)
-    logging.info("Documents chunked")
+    # List to store chunked documents
+    all_chunked_documents = []
+    # Load multiple datasets
+    datasets = {}
+    for data_set_name in data_set_names:
+        logging.info(f"Loading dataset: {data_set_name}")
+        datasets[data_set_name] = load_data(data_set_name)
 
+        # Set chunk size based on dataset name
+        chunk_size = default_chunk_size
+        if data_set_name == 'cuad':
+            chunk_size = 4000 # Custom chunk size for 'cuad'
+
+        # Chunk documents
+        chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        all_chunked_documents.extend(chunked_documents) # Combine all chunks
+
+    # Access individual datasets
+    #for name, dataset in datasets.items():
+        #logging.info(f"Loaded {name} with {dataset.num_rows} rows")
+
+    # Logging final count
+    logging.info(f"Total chunked documents: {len(all_chunked_documents)}")
+
     # Embed the documents
-    vector_store = embed_documents(documents)
+    vector_store = embed_documents(all_chunked_documents)
     logging.info("Documents embedded")
 
     # Initialize the Generation LLM
@@ -36,18 +57,12 @@ def main():
     # Initialize the Validation LLM
     val_llm = initialize_validation_llm()
 
-    # Sample question
-    #row_num = 30
-    #query = dataset[row_num]['question']
-
-    # Call generate_metrics for above sample question
-    #generate_metrics(gen_llm, val_llm, vector_store, query)
-
     #Compute RMSE and AUC-ROC for entire dataset
-    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
+    data_set_name = 'covidqa'
+    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
 
     # Launch the Gradio app
-    launch_gradio(vector_store, dataset, gen_llm, val_llm)
+    launch_gradio(vector_store, gen_llm, val_llm)
 
     logging.info("Finished!!!")
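
main.py now calls chunk_documents with an explicit chunk_overlap keyword, but that helper is not part of this diff. A sketch of what a compatible implementation might look like, assuming LangChain's RecursiveCharacterTextSplitter and RAGBench-style rows with 'documents' and 'question' fields (the field names and the splitter choice are assumptions, not taken from this commit):

    from langchain_core.documents import Document
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    def chunk_documents(dataset, chunk_size=1000, chunk_overlap=200):
        """Split each row's source passages into overlapping chunks (illustrative only)."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        chunks = []
        for row in dataset:
            for text in row.get("documents", []):  # assumed field name
                for piece in splitter.split_text(text):
                    chunks.append(Document(page_content=piece, metadata={"question": row.get("question")}))
        return chunks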
 
retriever/retrieve_documents.py CHANGED
@@ -1,2 +1,78 @@
+import numpy as np
+from transformers import pipeline
+
 def retrieve_top_k_documents(vector_store, query, top_k=5):
-    return vector_store.similarity_search(query, k=top_k)
+    documents = vector_store.similarity_search(query, k=top_k)
+    documents = rerank_documents(query, documents)
+    return documents
+
+# Reranking: Cross-Encoder for refining top-k results
+def rerank_documents(query, documents, reranker_model_name="cross-encoder/ms-marco-electra-base"):
+    """
+    Re-rank documents using a cross-encoder model.
+
+    Parameters:
+        query (str): The user's query.
+        documents (list): List of LangChain Document objects.
+        reranker_model_name (str): Hugging Face model name for re-ranking.
+
+    Returns:
+        list: Re-ranked list of Document objects with updated scores.
+    """
+    # Initialize the cross-encoder model
+    reranker = pipeline("text-classification", model=reranker_model_name, return_all_scores=False)
+
+    # Pair the query with each document's text
+    rerank_inputs = [{"text": query, "text_pair": doc.page_content} for doc in documents]
+
+    # Get relevance scores for each query-document pair
+    scores = reranker(rerank_inputs)
+
+    # Attach the new scores to the documents
+    for doc, score in zip(documents, scores):
+        doc.metadata["rerank_score"] = score["score"] # Add score to document metadata
+
+    # Sort documents by the rerank_score in descending order
+    documents = sorted(documents, key=lambda x: x.metadata.get("rerank_score", 0), reverse=True)
+    return documents
+
+
+# Query handling: retrieve top-k candidates using the FAISS (IVF) index directly; not used, kept only for learning
+def retrieve_top_k_documents_manual(vector_store, query, top_k=5):
+    """
+    Retrieve top-k documents using FAISS index and optionally rerank them.
+
+    Parameters:
+        vector_store (FAISS): The vector store containing the FAISS index and docstore.
+        query (str): The user's query string.
+        top_k (int): The number of top results to retrieve.
+        reranker_model_name (str): The Hugging Face model name for cross-encoder reranking.
+
+    Returns:
+        list: Top-k retrieved and reranked documents.
+    """
+    # Encode the query into a dense vector
+    embedding_model = vector_store.embedding_function
+    query_vector = embedding_model.embed_query(query) # Encode the query
+    query_vector = np.array([query_vector]).astype('float32')
+
+    # Search the FAISS index for top_k results
+    distances, indices = vector_store.index.search(query_vector, top_k)
+
+    # Retrieve documents from the docstore
+    documents = []
+    for idx in indices.flatten():
+        if idx == -1: # FAISS can return -1 for invalid indices
+            continue
+        doc_id = vector_store.index_to_docstore_id[idx]
+
+        # Access the internal dictionary of InMemoryDocstore
+        internal_docstore = getattr(vector_store.docstore, "_dict", None)
+        if internal_docstore and doc_id in internal_docstore: # Check if doc_id exists
+            document = internal_docstore[doc_id]
+            documents.append(document)
+
+    # Rerank the documents
+    documents = rerank_documents(query, documents)
+
+    return documents
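
rerank_documents scores each (query, passage) pair with a Hugging Face text-classification pipeline and sorts by that score, so retrieve_top_k_documents now returns cross-encoder-ordered results. A small usage sketch with two in-memory LangChain documents (the texts are made up, and the first call downloads the cross-encoder/ms-marco-electra-base model):

    from langchain_core.documents import Document
    from retriever.retrieve_documents import rerank_documents

    docs = [
        Document(page_content="COVID-19 vaccines reduce the risk of severe illness.", metadata={}),
        Document(page_content="The agreement term is five years from the effective date.", metadata={}),
    ]

    reranked = rerank_documents("How effective are COVID-19 vaccines?", docs)
    for doc in reranked:
        print(round(doc.metadata["rerank_score"], 3), doc.page_content)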