File size: 6,930 Bytes
5ba6f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import logging
import threading
import time
from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
from config import AppConfig, ConfigConstants
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
from generator.document_utils import get_logs, initialize_logging 

def launch_gradio(config: AppConfig):
    """
    Launch the Gradio app with pre-initialized objects.

    Builds a Blocks UI that lets the user ask questions against the RAG
    pipeline, compute metrics for the most recent answer, switch the
    generation/validation models, and watch the application logs, then
    calls ``interface.launch()``.

    Args:
        config: Application configuration. This function reads
            ``config.gen_llm``, ``config.val_llm`` and ``config.vector_store``,
            and replaces ``config.gen_llm`` / ``config.val_llm`` when the user
            picks a different model in the UI.
    """
    initialize_logging()

    def update_logs_periodically():
        """Stream the current log text to the log box, refreshing every 2 seconds."""
        # NOTE(review): this generator never finishes, so every open page keeps
        # one queued event running. Depending on the Gradio version's per-event
        # concurrency limit this may delay log streaming for other sessions —
        # confirm against the deployed Gradio version.
        while True:
            # Yield first so logs appear as soon as the page loads rather than
            # after an initial 2-second delay.
            yield get_logs()
            time.sleep(2)

    def answer_question(query, state):
        """Answer ``query`` with the RAG pipeline and record the result in ``state``.

        Returns:
            A ``(response_text, state)`` pair for the answer textbox and the
            session state. On failure, the stored response and source
            documents are cleared so "Compute metrics" cannot score a stale
            answer from an earlier query.
        """
        try:
            # Generate response using the passed objects
            response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
        except Exception as e:
            logging.exception("Error processing query: %s", e)
            state["query"] = query
            state["response"] = ""
            state["source_docs"] = {}
            return f"An error occurred: {e}", state

        # Remember the query and its answer so compute_metrics can score them later.
        state["query"] = query
        state["response"] = response
        state["source_docs"] = source_docs

        response_text = f"Response: {response}\n\n"
        return response_text, state

    def compute_metrics(state):
        """Score the most recent answer stored in ``state`` with the validation LLM.

        Returns:
            An ``(attributes_text, metrics_text)`` pair for the two output
            textboxes. On failure, the first element carries the error message
            and the second is empty.
        """
        try:
            logging.info("Computing metrics")

            # Retrieve response and source documents from state
            response = state.get("response", "")
            source_docs = state.get("source_docs", {})
            query = state.get("query", "")

            # Without a generated answer there is nothing to evaluate, so avoid
            # calling the validation LLM with empty inputs.
            if not response:
                return "Please submit a question and get a response before computing metrics.", ""

            # The trailing 1 is passed through unchanged from the original call;
            # its meaning is defined by generate_metrics.
            attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)

            attributes_text = get_attributes_text(attributes)

            # Skip the 'response' entry (presumably the generated answer, which
            # is already shown in the Response box).
            metrics_text = "Metrics:\n" + "".join(
                f"{key}: {value}\n" for key, value in metrics.items() if key != 'response'
            )

            return attributes_text, metrics_text
        except Exception as e:
            logging.exception("Error computing metrics: %s", e)
            return f"An error occurred: {e}", ""

    def get_updated_model_info():
        """Return the system-information text naming the embedding and both LLMs."""
        return (
            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
            f"Generation LLM: {getattr(config.gen_llm, 'name', 'Unknown')}\n"
            f"Validation LLM: {getattr(config.val_llm, 'name', 'Unknown')}\n"
        )

    def reinitialize_llm(model_type, model_name):
        """Reinitialize the specified LLM (generation or validation) and return updated model info.

        ``model_type`` is ``"generation"`` or ``"validation"``; any other value
        leaves both models unchanged. An empty or ``None`` ``model_name`` is
        ignored. If loading the new model fails, the current model is kept and
        the failure is appended to the returned info text.
        """
        # The dropdowns are created with value=None when their choice list is
        # empty, so guard against None before calling strip().
        if not model_name or not model_name.strip():
            return get_updated_model_info()

        try:
            if model_type == "generation":
                config.gen_llm = initialize_generation_llm(model_name)
            elif model_type == "validation":
                config.val_llm = initialize_validation_llm(model_name)
        except Exception as e:
            logging.exception("Error reinitializing %s LLM '%s': %s", model_type, model_name, e)
            return get_updated_model_info() + f"Failed to load {model_type} model '{model_name}': {e}\n"

        return get_updated_model_info()

    # Wrappers for event listeners
    def reinitialize_gen_llm(gen_llm_name):
        return reinitialize_llm("generation", gen_llm_name)

    def reinitialize_val_llm(val_llm_name):
        return reinitialize_llm("validation", val_llm_name)

    # Define Gradio Blocks layout
    with gr.Blocks(title="Real Time RAG Pipeline Q&A") as interface:
        gr.Markdown("# Real Time RAG Pipeline Q&A")  # Heading

        # Model selectors plus a read-only summary of the active models
        with gr.Row():
            new_gen_llm_input = gr.Dropdown(
                label="Generation Model",
                choices=ConfigConstants.GENERATION_MODELS,
                value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None,
                interactive=True
            )

            new_val_llm_input = gr.Dropdown(
                label="Validation Model",
                choices=ConfigConstants.VALIDATION_MODELS,
                value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None,
                interactive=True
            )

            model_info_display = gr.Textbox(
                value=get_updated_model_info(),
                label="System Information",
                interactive=False  # Read-only textbox
            )

        # Per-session state holding the last query, its response and source documents
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})
        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")  # Description
        with gr.Row():
            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary", scale=0)
            clear_query_button = gr.Button("Clear", scale=0)
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")

        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary", scale=0)
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        # Reload the selected model and refresh the model info when a dropdown changes
        new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
        new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)

        # Define button actions
        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state]
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output]
        )

        # Section to display logs. Refreshing is driven by the
        # update_logs_periodically generator attached via interface.load below.
        with gr.Row():
            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)

        # Enable the queue so the streaming (generator) log handler can run,
        # then start streaming logs into the log box when the page loads.
        interface.queue()
        interface.load(update_logs_periodically, outputs=log_section)

    interface.launch()