Gourisankar Padihary committed
Commit e234b58 · 1 Parent(s): bcc15bd

Added Gradio UI

app.py ADDED
@@ -0,0 +1,71 @@
+ import gradio as gr
+ import logging
+ from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
+
+ def launch_gradio(vector_store, dataset, gen_llm, val_llm):
+     """
+     Launch the Gradio app with pre-initialized objects.
+     """
+     def answer_question_with_metrics(query):
+         try:
+             logging.info(f"Processing query: {query}")
+
+             # Generate metrics using the passed objects (imported here to avoid a circular import with main.py)
+             from main import generate_metrics
+             response, metrics = generate_metrics(gen_llm, val_llm, vector_store, query, 1)
+
+             response_text = f"Response: {response}\n\n"
+             metrics_text = "Metrics:\n"
+             for key, value in metrics.items():
+                 if key != 'response':
+                     metrics_text += f"{key}: {value}\n"
+
+             return response_text, metrics_text
+         except Exception as e:
+             logging.error(f"Error processing query: {e}")
+             return f"An error occurred: {e}", ""  # return a value for both outputs expected by the click handler
+
+     def compute_and_display_metrics():
+         try:
+             # Call the function to compute metrics
+             relevance_rmse, utilization_rmse, adherence_auc = compute_rmse_auc_roc_metrics(
+                 gen_llm, val_llm, dataset, vector_store, 10
+             )
+
+             # Format the result for display
+             result = (
+                 f"Relevance RMSE Score: {relevance_rmse}\n"
+                 f"Utilization RMSE Score: {utilization_rmse}\n"
+                 f"Overall Adherence AUC-ROC: {adherence_auc}\n"
+             )
+             return result
+         except Exception as e:
+             logging.error(f"Error during metrics computation: {e}")
+             return f"An error occurred: {e}"
+
+     # Define Gradio Blocks layout
+     with gr.Blocks() as interface:
+         interface.title = "Real Time RAG Pipeline Q&A"
+         gr.Markdown("### Real Time RAG Pipeline Q&A")  # Heading
+         gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")  # Description
+
+         with gr.Row():
+             query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
+         with gr.Row():
+             clear_query_button = gr.Button("Clear")  # Clear button
+             submit_button = gr.Button("Submit", variant="primary")  # Submit button
+         with gr.Row():
+             answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
+         with gr.Row():
+             metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
+         with gr.Row():
+             compute_rmse_button = gr.Button("Compute RMSE & AUC-ROC", variant="primary")
+             rmse_output = gr.Textbox(label="RMSE & AUC-ROC Score", placeholder="RMSE & AUC-ROC score will appear here")
+
+
+         # Define button actions
+         submit_button.click(fn=answer_question_with_metrics, inputs=[query_input], outputs=[answer_output, metrics_output])
+         clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
+         compute_rmse_button.click(fn=compute_and_display_metrics, outputs=[rmse_output])
+
+     interface.launch()
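
Note on the formatting loop above: answer_question_with_metrics assumes generate_metrics returns a plain response string plus a flat dict keyed by metric name. A minimal, illustrative sketch of that step follows; only the 'Context Relevance' key is confirmed elsewhere in this commit, the other entries and values are hypothetical.

    # Illustrative sketch only: a hypothetical metrics dict shaped like the one
    # answer_question_with_metrics expects back from generate_metrics.
    sample_metrics = {
        'Context Relevance': 0.75,    # key also read in compute_rmse_auc_roc_metrics.py
        'Context Utilization': 0.60,  # assumed additional metric, for illustration
        'response': 'Paris is the capital of France.',
    }

    metrics_text = "Metrics:\n"
    for key, value in sample_metrics.items():
        if key != 'response':  # the response is rendered in its own textbox
            metrics_text += f"{key}: {value}\n"
    print(metrics_text)
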
generator/compute_rmse_auc_roc_metrics.py CHANGED
@@ -25,7 +25,7 @@ def compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, num_qu
         query = document['question']
         logging.info(f'Query number: {i + 1}')
         # Call the generate_metrics for each query
-        metrics = generate_metrics(gen_llm, val_llm, vector_store, query)
+        response, metrics = generate_metrics(gen_llm, val_llm, vector_store, query, 15)
 
         # Extract predicted metrics (ensure these are continuous if possible)
         predicted_relevance = metrics.get('Context Relevance', 0) if metrics else 0
@@ -69,3 +69,5 @@ def compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, num_qu
     logging.info(f"Relevance RMSE score: {relevance_rmse}")
     logging.info(f"Utilization RMSE score: {utilization_rmse}")
     logging.info(f"Overall Adherence AUC-ROC: {adherence_auc}")
+
+    return relevance_rmse, utilization_rmse, adherence_auc
generator/generate_metrics.py CHANGED
@@ -5,7 +5,7 @@ from retriever.retrieve_documents import retrieve_top_k_documents
 from generator.compute_metrics import get_metrics
 from generator.extract_attributes import extract_attributes
 
-def generate_metrics(gen_llm, val_llm, vector_store, query):
+def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
     logging.info(f'Query: {query}')
 
     # Step 1: Retrieve relevant documents for given query
@@ -22,7 +22,7 @@ def generate_metrics(gen_llm, val_llm, vector_store, query):
     logging.info(f"Response from LLM: {response}")
 
     # Add a sleep interval to avoid hitting the rate limit
-    time.sleep(25)  # Adjust the sleep time as needed
+    time.sleep(time_to_wait)  # Adjust the sleep time as needed
 
     # Step 3: Extract attributes and total sentences for each query
     logging.info(f"Extracting attributes through validation LLM")
@@ -32,4 +32,4 @@ def generate_metrics(gen_llm, val_llm, vector_store, query):
     # Step 4 : Call the get metrics calculate metrics
     metrics = get_metrics(attributes, total_sentences)
 
-    return metrics
+    return response, metrics
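
The new time_to_wait parameter lets each caller choose its own rate-limit pause, and callers now unpack a (response, metrics) tuple. A minimal sketch of the two call patterns introduced by this commit, assuming the pre-initialized pipeline objects from main.py are in scope:

    # Sketch only: both call sites added in this commit; gen_llm, val_llm,
    # vector_store and query are assumed to come from the pipeline in main.py.
    from generator.generate_metrics import generate_metrics

    def _example(gen_llm, val_llm, vector_store, query):
        # Interactive Gradio path: 1-second pause between LLM calls (app.py).
        response, metrics = generate_metrics(gen_llm, val_llm, vector_store, query, 1)

        # Batch evaluation path: 15-second pause to respect the provider's
        # rate limit (compute_rmse_auc_roc_metrics.py).
        response, metrics = generate_metrics(gen_llm, val_llm, vector_store, query, 15)
        return response, metrics
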
main.py CHANGED
@@ -6,6 +6,7 @@ from retriever.embed_documents import embed_documents
 from generator.generate_metrics import generate_metrics
 from generator.initialize_llm import initialize_generation_llm
 from generator.initialize_llm import initialize_validation_llm
+from app import launch_gradio
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -36,15 +37,18 @@ def main():
     val_llm = initialize_validation_llm()
 
     # Sample question
-    row_num = 2
-    query = dataset[row_num]['question']
+    #row_num = 30
+    #query = dataset[row_num]['question']
 
     # Call generate_metrics for above sample question
     #generate_metrics(gen_llm, val_llm, vector_store, query)
 
     #Compute RMSE and AUC-ROC for entire dataset
-    compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
+    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
 
+    # Launch the Gradio app
+    launch_gradio(vector_store, dataset, gen_llm, val_llm)
+
     logging.info("Finished!!!")
 
 if __name__ == "__main__":
requirements.txt CHANGED
@@ -7,4 +7,5 @@ langchain
 llama-index
 langchain-community
 langchain_groq
-langchain-huggingface
+langchain-huggingface
+gradio
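
With gradio added to requirements.txt, the UI can presumably be started with `pip install -r requirements.txt` followed by `python main.py`, assuming the dataset loading and vector-store setup earlier in main.py complete successfully; main.py now ends by calling launch_gradio, which serves the interface until the Gradio server is stopped.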