import gradio as gr
import logging
import threading
import time
from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
from config import AppConfig, ConfigConstants
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
from generator.document_utils import get_logs, initialize_logging
def launch_gradio(config: AppConfig):
    """
    Launch the Gradio app with pre-initialized objects.
    """
    initialize_logging()

    def update_logs_periodically():
        """Yield the latest logs every two seconds for streaming into the UI."""
        while True:
            time.sleep(2)  # Wait for 2 seconds
            yield get_logs()
    def answer_question(query, state):
        """Run the query through the RAG pipeline and store the results in the session state."""
        try:
            # Generate a response using the pre-initialized generation LLM and vector store
            response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)

            # Update state with the response and source documents
            state["query"] = query
            state["response"] = response
            state["source_docs"] = source_docs

            response_text = f"Response: {response}\n\n"
            return response_text, state
        except Exception as e:
            logging.error(f"Error processing query: {e}")
            return f"An error occurred: {e}", state
    def compute_metrics(state):
        """Compute attributes and metrics for the most recent response stored in state."""
        try:
            logging.info("Computing metrics")

            # Retrieve query, response and source documents from state
            response = state.get("response", "")
            source_docs = state.get("source_docs", {})
            query = state.get("query", "")

            # Generate metrics using the validation LLM
            attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)

            attributes_text = get_attributes_text(attributes)

            metrics_text = "Metrics:\n"
            for key, value in metrics.items():
                if key != 'response':
                    metrics_text += f"{key}: {value}\n"

            return attributes_text, metrics_text
        except Exception as e:
            logging.error(f"Error computing metrics: {e}")
            return f"An error occurred: {e}", ""
    def reinitialize_llm(model_type, model_name):
        """Reinitialize the specified LLM (generation or validation) and return updated model info."""
        if model_name.strip():  # Only update if input is not empty
            if model_type == "generation":
                config.gen_llm = initialize_generation_llm(model_name)
            elif model_type == "validation":
                config.val_llm = initialize_validation_llm(model_name)
        return get_updated_model_info()

    def get_updated_model_info():
        """Generate and return the updated model information string."""
        return (
            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
            f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
            f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
        )

    # Wrappers for event listeners
    def reinitialize_gen_llm(gen_llm_name):
        return reinitialize_llm("generation", gen_llm_name)

    def reinitialize_val_llm(val_llm_name):
        return reinitialize_llm("validation", val_llm_name)
    # Define Gradio Blocks layout
    with gr.Blocks() as interface:
        interface.title = "Real Time RAG Pipeline Q&A"
        gr.Markdown("# Real Time RAG Pipeline Q&A")  # Heading

        # Dropdowns for selecting the generation and validation models
        with gr.Row():
            new_gen_llm_input = gr.Dropdown(
                label="Generation Model",
                choices=ConfigConstants.GENERATION_MODELS,  # Directly use the list
                value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None,  # Default to the first entry
                interactive=True
            )
            new_val_llm_input = gr.Dropdown(
                label="Validation Model",
                choices=ConfigConstants.VALIDATION_MODELS,  # Directly use the list
                value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None,  # Default to the first entry
                interactive=True
            )
            model_info_display = gr.Textbox(
                value=get_updated_model_info(),  # Use the helper function
                label="System Information",
                interactive=False  # Read-only textbox
            )

        # State to store the query, response and source documents across events
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})

        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")  # Description
        with gr.Row():
            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary", scale=0)  # Submit button
            clear_query_button = gr.Button("Clear", scale=0)  # Clear button
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary", scale=0)
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
        # Attach event listeners to update model info on change
        new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
        new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)

        # Define button actions
        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state]
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output]
        )
        # Section to display logs
        with gr.Row():
            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)  # Log section

        # Stream log updates into the log section via the generator defined above
        interface.queue()
        interface.load(update_logs_periodically, outputs=log_section)

    interface.launch()
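
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): one way an
# entry-point script might wire up the config and launch this UI. How
# AppConfig is actually constructed and how the vector store is loaded are
# assumptions here -- the real project may do this differently in config.py
# or its own main module, so treat this only as a hint.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     app_config = AppConfig()  # hypothetical no-arg constructor
#     app_config.gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODELS[0])
#     app_config.val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODELS[0])
#     # app_config.vector_store must also be populated by the project's
#     # retriever/indexing setup before launching the UI.
#     launch_gradio(app_config)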