# Last commit 5184c29 by Gourisankar Padihary: "Multiple data set support" (file size 5.39 kB).
import logging
import threading
import time
from collections import deque
from io import StringIO

import gradio as gr

from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
def launch_gradio(vector_store, gen_llm, val_llm):
    """
    Launch the Gradio app with pre-initialized objects.

    Builds a Blocks UI that answers questions via
    ``retrieve_and_generate_response`` and computes metrics for the last
    answer via ``generate_metrics``. It also shows recently captured log
    lines on demand, then calls ``interface.launch()``.

    Side effects: sets the root logger level to INFO and attaches a
    log-capturing handler to the root logger.

    Args:
        vector_store: Retrieval store passed through to
            ``retrieve_and_generate_response``.
        gen_llm: LLM used to generate answers. Its ``name`` attribute, if
            present, is displayed in the UI.
        val_llm: LLM passed to ``generate_metrics``. Its ``name`` attribute,
            if present, is displayed in the UI.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Bounded buffer of formatted log lines. Only the most recent entries are
    # ever displayed, so older ones are discarded instead of accumulating for
    # the lifetime of the process.
    max_log_lines = 50
    logs = deque(maxlen=max_log_lines)

    class LogHandler(logging.Handler):
        """Logging handler that appends each formatted record to ``logs``."""

        def emit(self, record):
            try:
                logs.append(self.format(record))
            except Exception:
                # logging.Handler convention: a failure while formatting must
                # not propagate into the code that emitted the log record.
                self.handleError(record)

    log_handler = LogHandler()
    log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(log_handler)

    def get_logs():
        """Return the captured log lines (oldest first) joined by newlines."""
        # Handler.handle() holds the handler lock while calling emit(), so
        # taking the same lock gives a consistent snapshot of the buffer.
        with log_handler.lock:
            return "\n".join(logs)

    def answer_question(query, state):
        """Answer ``query`` and record query/response/sources in ``state``.

        Returns:
            A tuple of (text for the Response box, state). On failure the
            error message is returned instead and ``state`` is left unchanged.
        """
        try:
            response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)
        except Exception as e:
            # UI boundary: show the error in the Response box instead of
            # failing the event handler; logger.exception keeps the traceback.
            logger.exception("Error processing query: %s", e)
            return f"An error occurred: {e}", state
        state["query"] = query
        state["response"] = response
        state["source_docs"] = source_docs
        return f"Response: {response}\n\n", state

    def compute_metrics(state):
        """Compute metrics for the query/response currently stored in ``state``.

        Returns:
            A tuple of (attributes text, metrics text). If there is no stored
            response yet, or metric generation fails, a message is returned in
            the first slot and the second slot is empty.
        """
        response = state.get("response", "")
        if not response:
            # Nothing has been answered yet in this session; avoid calling
            # the validation LLM with an empty response.
            return "Please submit a query before computing metrics.", ""
        source_docs = state.get("source_docs", {})
        query = state.get("query", "")
        logger.info("Computing metrics")
        try:
            attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 1)
            attributes_text = get_attributes_text(attributes)
            # Skip the 'response' key (presumably the generated answer, which
            # is already shown in the Response box).
            metrics_text = "Metrics:\n" + "".join(
                f"{key}: {value}\n" for key, value in metrics.items() if key != 'response'
            )
        except Exception as e:
            logger.exception("Error computing metrics: %s", e)
            return f"An error occurred: {e}", ""
        return attributes_text, metrics_text

    title = "Real Time RAG Pipeline Q&A"
    with gr.Blocks(title=title) as interface:
        gr.Markdown(f"### {title}")  # Heading

        # Section to display LLM names
        with gr.Row():
            model_info = (
                f"Generation LLM: {getattr(gen_llm, 'name', 'Unknown')}\n"
                f"Validation LLM: {getattr(val_llm, 'name', 'Unknown')}\n"
            )
            gr.Textbox(value=model_info, label="Model Information", interactive=False)  # Read-only

        # Per-session state holding the last query, response and source documents
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})

        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")
        with gr.Row():
            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            clear_query_button = gr.Button("Clear")
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary")
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        # Button actions
        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state],
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output],
        )

        # Section to display logs
        with gr.Row():
            start_log_button = gr.Button("Start Log Update", elem_id="start_btn")
        with gr.Row():
            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)
        # NOTE(review): despite the button label, each click is a one-shot
        # refresh of the log box with the latest captured lines.
        start_log_button.click(fn=get_logs, outputs=log_section)

    interface.launch()