import gradio as gr
import logging
import threading
import time

from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response


def launch_gradio(vector_store, gen_llm, val_llm):
    """
    Launch the Gradio app with pre-initialized objects.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # In-memory buffer of log entries displayed in the UI.
    logs = []

    class LogHandler(logging.Handler):
        """Capture formatted log records into the in-memory buffer."""

        def emit(self, record):
            log_entry = self.format(record)
            logs.append(log_entry)

    log_handler = LogHandler()
    log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(log_handler)

    def log_updater():
        """Background keep-alive loop; log entries are pulled on demand via get_logs()."""
        while True:
            time.sleep(2)

    def get_logs():
        """Return the most recent log entries for display."""
        return "\n".join(logs[-50:])

    threading.Thread(target=log_updater, daemon=True).start()

    def answer_question(query, state):
        try:
            response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)

            # Keep the query, response, and retrieved documents in state so
            # metrics can be computed later without re-running retrieval.
            state["query"] = query
            state["response"] = response
            state["source_docs"] = source_docs

            response_text = f"Response: {response}\n\n"
            return response_text, state
        except Exception as e:
            logging.error(f"Error processing query: {e}")
            return f"An error occurred: {e}", state

    def compute_metrics(state):
        try:
            logging.info("Computing metrics")

            response = state.get("response", "")
            source_docs = state.get("source_docs", {})
            query = state.get("query", "")

            attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 1)

            attributes_text = get_attributes_text(attributes)

            metrics_text = "Metrics:\n"
            for key, value in metrics.items():
                if key != 'response':
                    metrics_text += f"{key}: {value}\n"

            return attributes_text, metrics_text
        except Exception as e:
            logging.error(f"Error computing metrics: {e}")
            return f"An error occurred: {e}", ""

    with gr.Blocks(title="Real Time RAG Pipeline Q&A") as interface:
        gr.Markdown("### Real Time RAG Pipeline Q&A")

        with gr.Row():
            model_info = f"Generation LLM: {getattr(gen_llm, 'name', 'Unknown')}\n"
            model_info += f"Validation LLM: {getattr(val_llm, 'name', 'Unknown')}\n"
            gr.Textbox(value=model_info, label="Model Information", interactive=False)

        # Shared state carrying the latest query, response, and retrieved documents.
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})
        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")
        with gr.Row():
            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            clear_query_button = gr.Button("Clear")
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")

        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary")
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state]
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output]
        )

        with gr.Row():
            start_log_button = gr.Button("Start Log Update", elem_id="start_btn")
        with gr.Row():
            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)

        start_log_button.click(fn=get_logs, outputs=log_section)

    interface.launch()
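

# Example usage (a minimal sketch): launch_gradio expects the vector store and
# the generation/validation LLM objects to be initialized elsewhere in the
# project before this module is used. The helper below is hypothetical; replace
# it with whatever setup your pipeline actually performs.
#
# if __name__ == "__main__":
#     vector_store, gen_llm, val_llm = build_pipeline_components()  # hypothetical helper
#     launch_gradio(vector_store, gen_llm, val_llm)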