gourisankar85 committed
Commit e73ca1c · verified · Parent(s): d0b15a7

Upload 3 files

Files changed (3):
  1. app.py +123 -37
  2. config.py +36 -18
  3. main.py +33 -63
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
 import logging
-import threading
 import time
 from generator.compute_metrics import get_attributes_text
 from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
 from config import AppConfig, ConfigConstants
 from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
-from generator.document_utils import get_logs, initialize_logging
+from generator.document_utils import get_logs, initialize_logging
+from retriever.load_selected_datasets import load_selected_datasets
 
 def launch_gradio(config : AppConfig):
     """
@@ -14,6 +14,9 @@ def launch_gradio(config : AppConfig):
     """
     initialize_logging()
 
+    # **🔹 Always get the latest loaded datasets**
+    config.detect_loaded_datasets()
+
     def update_logs_periodically():
         while True:
             time.sleep(2)  # Wait for 2 seconds
@@ -21,6 +24,10 @@ def launch_gradio(config : AppConfig):
 
     def answer_question(query, state):
         try:
+            # Ensure vector store is updated before use
+            if config.vector_store is None:
+                return "Please load a dataset first.", state
+
             # Generate response using the passed objects
             response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
 
@@ -29,7 +36,7 @@ def launch_gradio(config : AppConfig):
             state["response"] = response
             state["source_docs"] = source_docs
 
-            response_text = f"Response: {response}\n\n"
+            response_text = f"Response from Model : {response}\n\n"
             return response_text, state
         except Exception as e:
             logging.error(f"Error processing query: {e}")
@@ -49,7 +56,7 @@ def launch_gradio(config : AppConfig):
 
         attributes_text = get_attributes_text(attributes)
 
-        metrics_text = "Metrics:\n"
+        metrics_text = ""
         for key, value in metrics.items():
             if key != 'response':
                 metrics_text += f"{key}: {value}\n"
@@ -70,11 +77,14 @@ def launch_gradio(config : AppConfig):
         return get_updated_model_info()
 
     def get_updated_model_info():
+        loaded_datasets_str = ", ".join(config.loaded_datasets) if config.loaded_datasets else "None"
         """Generate and return the updated model information string."""
         return (
             f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
             f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
+            f"Re-ranking LLM: {ConfigConstants.RE_RANKER_MODEL_NAME}\n"
             f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
+            f"Loaded Datasets: {loaded_datasets_str}\n"
         )
 
     # Wrappers for event listeners
@@ -84,50 +94,125 @@ def launch_gradio(config : AppConfig):
     def reinitialize_val_llm(val_llm_name):
         return reinitialize_llm("validation", val_llm_name)
 
+    # Function to update query input when a question is selected from the dropdown
+    def update_query_input(selected_question):
+        return selected_question
+
     # Define Gradio Blocks layout
     with gr.Blocks() as interface:
         interface.title = "Real Time RAG Pipeline Q&A"
-        gr.Markdown("# Real Time RAG Pipeline Q&A") # Heading
-
-        # Textbox for new generation LLM name
-        with gr.Row():
-            new_gen_llm_input = gr.Dropdown(
-                label="Generation Model",
-                choices=ConfigConstants.GENERATION_MODELS, # Directly use the list
-                value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None, # First value dynamically
-                interactive=True
-            )
-
-            new_val_llm_input = gr.Dropdown(
-                label="Validation Model",
-                choices=ConfigConstants.VALIDATION_MODELS, # Directly use the list
-                value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None, # First value dynamically
-                interactive=True
-            )
-
-            model_info_display = gr.Textbox(
-                value=get_updated_model_info(), # Use the helper function
-                label="System Information",
-                interactive=False # Read-only textbox
-            )
-
-        # State to store response and source documents
-        state = gr.State(value={"query": "","response": "", "source_docs": {}})
-        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
-        with gr.Row():
-            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
-        with gr.Row():
-            submit_button = gr.Button("Submit", variant="primary", scale = 0) # Submit button
-            clear_query_button = gr.Button("Clear", scale = 0) # Clear button
-        with gr.Row():
-            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
+        gr.Markdown("""
+        # Real Time RAG Pipeline Q&A
+        The **Retrieval-Augmented Generation (RAG) Pipeline** combines retrieval-based and generative AI models to provide accurate and context-aware answers to your questions.
+        It retrieves relevant documents from a dataset (e.g., COVIDQA, TechQA, FinQA) and uses a generative model to synthesize a response.
+        Metrics are computed to evaluate the quality of the response and the retrieval process.
+        """)
+        # Model Configuration
+        with gr.Accordion("System Information", open=False):
+            with gr.Accordion("DataSet", open=False):
+                with gr.Row():
+                    dataset_selector = gr.CheckboxGroup(ConfigConstants.DATA_SET_NAMES, label="Select Datasets to Load")
+                    load_button = gr.Button("Load", scale= 0)
+
+            with gr.Row():
+                # Column for Generation Model Dropdown
+                with gr.Column(scale=1):
+                    new_gen_llm_input = gr.Dropdown(
+                        label="Generation Model",
+                        choices=ConfigConstants.GENERATION_MODELS,
+                        value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None,
+                        interactive=True,
+                        info="Select the generative model for response generation."
+                    )
+
+                # Column for Validation Model Dropdown
+                with gr.Column(scale=1):
+                    new_val_llm_input = gr.Dropdown(
+                        label="Validation Model",
+                        choices=ConfigConstants.VALIDATION_MODELS,
+                        value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None,
+                        interactive=True,
+                        info="Select the model for validating the response quality."
+                    )
+
+                # Column for Model Information
+                with gr.Column(scale=2):
+                    model_info_display = gr.Textbox(
+                        value=get_updated_model_info(), # Use the helper function
+                        label="Model Configuration",
+                        interactive=False, # Read-only textbox
+                        lines=5
+                    )
+
+        # Query Section
+        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")
+        all_questions = [
+            "When was the first case of COVID-19 identified?",
+            "What are the ages of the patients in this study?",
+            "Is one party required to deposit its source code into escrow with a third party, which can be released to the counterparty upon the occurrence of certain events (bankruptcy, insolvency, etc.)?",
+            "Explain the concept of blockchain.",
+            "What is the capital of France?",
+            "Do Surface Porosity and Pore Size Influence Mechanical Properties and Cellular Response to PEEK??",
+            "How does a vaccine work?",
+            "What is the difference between RNA and DNA?",
+            "What are the risk factors for heart disease?",
+            "What is the role of insulin in the body?",
+            # Add more questions as needed
+        ]
+
+        # Subset of questions to display as examples
+        example_questions = [
+            "When was the first case of COVID-19 identified?",
+            "What are the ages of the patients in this study?",
+            "What is the Hepatitis C virus?",
+            "Explain the concept of blockchain.",
+            "What is the capital of France?"
+        ]
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    query_input = gr.Textbox(
+                        label="Ask a question ",
+                        placeholder="Type your query here or select from examples/dropdown",
+                        lines=2
+                    )
+                with gr.Row():
+                    submit_button = gr.Button("Submit", variant="primary", scale=0)
+                    clear_query_button = gr.Button("Clear", scale=0)
+            with gr.Column():
+                gr.Examples(
+                    examples=example_questions, # Make sure the variable name matches
+                    inputs=query_input,
+                    label="Try these examples:"
+                )
+                question_dropdown = gr.Dropdown(
+                    label="",
+                    choices=all_questions,
+                    interactive=True,
+                    info="Choose a question from the dropdown to populate the query box."
+                )
+
+        # Attach event listener to dropdown
+        question_dropdown.change(
+            fn=update_query_input,
+            inputs=question_dropdown,
+            outputs=query_input
+        )
+
+        # Response and Metrics
+        with gr.Row():
+            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here", lines=2)
 
         with gr.Row():
             compute_metrics_button = gr.Button("Compute metrics", variant="primary" , scale = 0)
             attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
             metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
 
-        #with gr.Row():
+        # State to store response and source documents
+        state = gr.State(value={"query": "","response": "", "source_docs": {}})
+
+        # Pass config to update vector store
+        load_button.click(lambda datasets: (load_selected_datasets(datasets, config), get_updated_model_info()), inputs=dataset_selector)
         # Attach event listeners to update model info on change
         new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
         new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)
@@ -146,8 +231,9 @@ def launch_gradio(config : AppConfig):
         )
 
         # Section to display logs
-        with gr.Row():
-            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10, every=2) # Log section
+        with gr.Accordion("View Live Logs", open=False):
+            with gr.Row():
+                log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10, every=2) # Log section
 
         # Update UI when logs_state changes
         interface.queue()
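
Note on the log panel: the `every=2` argument leans on Gradio's polling support, where a component whose `value` is a callable is re-run every N seconds through the queue (hence the `interface.queue()` call). A minimal, self-contained sketch of that pattern, assuming Gradio 4.x; the buffer and handler below are illustrative stand-ins for the repo's `generator.document_utils.get_logs`, not code from this commit:

import logging
import gradio as gr

log_buffer = []  # in-memory sink for log records

class BufferHandler(logging.Handler):
    def emit(self, record):
        log_buffer.append(self.format(record))

logging.getLogger().addHandler(BufferHandler())
logging.getLogger().setLevel(logging.INFO)

def get_logs():
    # Called by Gradio on each poll; return the last 20 log lines as one string.
    return "\n".join(log_buffer[-20:])

with gr.Blocks() as demo:
    # A callable value plus every=2 makes Gradio refresh this textbox every 2 seconds.
    gr.Textbox(value=get_logs, label="Logs", lines=10, interactive=False, every=2)

demo.queue()   # polling runs through the queue, as in the commit's interface.queue()
demo.launch()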
config.py CHANGED
@@ -1,18 +1,36 @@
-
-class ConfigConstants:
-    # Constants related to datasets and models
-    DATA_SET_NAMES = ['covidqa', 'cuad']#, 'techqa' 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa']
-    EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
-    RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
-    GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
-    VALIDATION_MODEL_NAME = 'llama3-70b-8192'
-    GENERATION_MODELS = ["llama3-8b-8192", "qwen-2.5-32b", "mixtral-8x7b-32768", "gemma2-9b-it" ]
-    VALIDATION_MODELS = ["llama3-70b-8192", "deepseek-r1-distill-llama-70b" ]
-    DEFAULT_CHUNK_SIZE = 1000
-    CHUNK_OVERLAP = 200
-
-class AppConfig:
-    def __init__(self, vector_store, gen_llm, val_llm):
-        self.vector_store = vector_store
-        self.gen_llm = gen_llm
-        self.val_llm = val_llm
+
+import os
+
+class ConfigConstants:
+    # Constants related to datasets and models
+    DATA_SET_PATH= '/persistent/local_datasets'
+    DATA_SET_NAMES = ['covidqa', 'cuad', 'techqa','delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa']
+    EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
+    RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
+    GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
+    VALIDATION_MODEL_NAME = 'llama3-70b-8192'
+    GENERATION_MODELS = ["llama3-8b-8192", "qwen-2.5-32b", "mixtral-8x7b-32768", "gemma2-9b-it" ]
+    VALIDATION_MODELS = ["llama3-70b-8192", "deepseek-r1-distill-llama-70b" ]
+    DEFAULT_CHUNK_SIZE = 1000
+    CHUNK_OVERLAP = 200
+
+class AppConfig:
+    def __init__(self, vector_store, gen_llm, val_llm):
+        self.vector_store = vector_store
+        self.gen_llm = gen_llm
+        self.val_llm = val_llm
+        self.loaded_datasets = self.detect_loaded_datasets() # Auto-detect loaded datasets
+
+    @staticmethod
+    def detect_loaded_datasets():
+        print('Calling detect_loaded_datasets')
+        """Check which datasets are already stored locally."""
+        local_path = ConfigConstants.DATA_SET_PATH
+        if not os.path.exists(local_path):
+            return set()
+
+        dataset_files = os.listdir(local_path)
+        loaded_datasets = {
+            file.replace("_test.pkl", "") for file in dataset_files if file.endswith("_test.pkl")
+        }
+        return loaded_datasets
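
The new `detect_loaded_datasets` infers which datasets are available by scanning `DATA_SET_PATH` for files named `<dataset>_test.pkl`. A quick standalone illustration of that file-name convention (a temporary directory stands in for `/persistent/local_datasets`; not part of the commit):

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Fake dataset directory: two pickled test splits plus an unrelated file.
    for fname in ("covidqa_test.pkl", "cuad_test.pkl", "readme.txt"):
        open(os.path.join(tmp, fname), "w").close()

    # Same set comprehension as AppConfig.detect_loaded_datasets().
    loaded = {f.replace("_test.pkl", "") for f in os.listdir(tmp) if f.endswith("_test.pkl")}
    print(loaded)  # {'covidqa', 'cuad'} -- readme.txt is ignored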
main.py CHANGED
@@ -1,64 +1,34 @@
-import logging
-from config import AppConfig, ConfigConstants
-from data.load_dataset import load_data
-from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
-from retriever.chunk_documents import chunk_documents
-from retriever.embed_documents import embed_documents
-from generator.initialize_llm import initialize_generation_llm
-from generator.initialize_llm import initialize_validation_llm
-from app import launch_gradio
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-def main():
-    logging.info("Starting the RAG pipeline")
-
-    # Dictionary to store chunked documents
-    all_chunked_documents = []
-    datasets = {}
-
-    # Load multiple datasets
-    for data_set_name in ConfigConstants.DATA_SET_NAMES:
-        logging.info(f"Loading dataset: {data_set_name}")
-        datasets[data_set_name] = load_data(data_set_name)
-
-        # Set chunk size based on dataset name
-        chunk_size = ConfigConstants.DEFAULT_CHUNK_SIZE
-        if data_set_name == 'cuad':
-            chunk_size = 4000 # Custom chunk size for 'cuad'
-
-        # Chunk documents
-        chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=ConfigConstants.CHUNK_OVERLAP)
-        all_chunked_documents.extend(chunked_documents) # Combine all chunks
-
-    # Access individual datasets
-    #for name, dataset in datasets.items():
-        #logging.info(f"Loaded {name} with {dataset.num_rows} rows")
-
-    # Logging final count
-    logging.info(f"Total chunked documents: {len(all_chunked_documents)}")
-
-    # Embed the documents
-    vector_store = embed_documents(all_chunked_documents)
-    logging.info("Documents embedded")
-
-    # Initialize the Generation LLM
-    gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)
-
-    # Initialize the Validation LLM
-    val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)
-
-    #Compute RMSE and AUC-ROC for entire dataset
-    #Enable below code for calculation
-    #data_set_name = 'covidqa'
-    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
-
-    # Launch the Gradio app
-    config = AppConfig(vector_store= vector_store, gen_llm = gen_llm, val_llm = val_llm)
-    launch_gradio(config)
-
-    logging.info("Finished!!!")
-
-if __name__ == "__main__":
+import logging
+from config import AppConfig, ConfigConstants
+from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
+from retriever.load_selected_datasets import load_selected_datasets
+from generator.initialize_llm import initialize_generation_llm
+from generator.initialize_llm import initialize_validation_llm
+from app import launch_gradio
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+def main():
+    logging.info("Starting the RAG pipeline")
+
+    # Initialize the Generation LLM
+    gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)
+
+    # Initialize the Validation LLM
+    val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)
+
+    #Compute RMSE and AUC-ROC for entire dataset
+    #Enable below code for calculation
+    #data_set_name = 'covidqa'
+    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
+
+    # Launch the Gradio app
+    config = AppConfig(vector_store = None, gen_llm = gen_llm, val_llm = val_llm)
+    load_selected_datasets(['covidqa'], config)
+    launch_gradio(config)
+
+    logging.info("Finished!!!")
+
+if __name__ == "__main__":
     main()
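
With this change, `main()` no longer chunks and embeds every dataset up front: it constructs `AppConfig` with `vector_store=None` and defers loading to `load_selected_datasets`, which is referenced here but not included in the commit. A plausible sketch of that function's contract, inferred from its call sites and the old `main()` pipeline; the body is an assumption, not the repository's actual implementation:

# NOTE: hypothetical reconstruction of retriever/load_selected_datasets.py.
from config import AppConfig, ConfigConstants
from data.load_dataset import load_data                # helpers used by the old main()
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents

def load_selected_datasets(dataset_names, config: AppConfig):
    """Load, chunk, and embed the chosen datasets, then update config in place."""
    all_chunks = []
    for name in dataset_names:
        dataset = load_data(name)
        # Mirrors the per-dataset chunk-size override removed from main().
        chunk_size = 4000 if name == 'cuad' else ConfigConstants.DEFAULT_CHUNK_SIZE
        all_chunks.extend(chunk_documents(dataset, chunk_size=chunk_size,
                                          chunk_overlap=ConfigConstants.CHUNK_OVERLAP))
    config.vector_store = embed_documents(all_chunks)
    config.loaded_datasets = AppConfig.detect_loaded_datasets()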