File size: 7,860 Bytes
e234b58
 
5184c29
 
 
 
2889c96
 
e234b58
5485d7c
e234b58
 
 
5184c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e234b58
5184c29
5485d7c
e234b58
5184c29
 
 
 
e234b58
 
5184c29
e234b58
 
5184c29
e234b58
5184c29
e234b58
5184c29
 
 
 
 
 
 
 
5485d7c
5184c29
 
 
 
 
 
 
e234b58
5184c29
e234b58
5184c29
 
e234b58
2889c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e234b58
 
 
 
 
2889c96
 
 
 
 
 
 
5184c29
 
5485d7c
 
 
2889c96
5184c29
 
 
 
e234b58
 
 
5184c29
e234b58
 
 
5184c29
e234b58
5184c29
 
e234b58
5184c29
 
 
e234b58
5184c29
 
 
 
 
 
 
 
 
 
 
2889c96
 
 
 
 
 
5184c29
2889c96
 
 
 
 
 
5184c29
 
 
 
 
 
 
 
e234b58
5184c29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import gradio as gr
import logging
import threading
import time
from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
from config import AppConfig, ConfigConstants
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm 

def launch_gradio(config: AppConfig, max_log_lines: int = 50):
    """
    Build and launch the Gradio app using pre-initialized objects.

    The UI lets a user:
      * ask a question, answered by ``retrieve_and_generate_response`` with
        ``config.gen_llm`` over ``config.vector_store``;
      * compute metrics for the latest answer with ``config.val_llm``;
      * replace the generation / validation LLM at runtime by name;
      * view the most recent log records.

    Args:
        config: Application objects. ``config.gen_llm`` and ``config.val_llm``
            are reassigned in place when the user updates a model name.
        max_log_lines: How many of the most recent log records are kept in
            memory and shown in the "Logs" box (default 50).

    Side effects:
        Sets the root logger level to INFO, attaches an in-memory capture
        handler to the root logger (replacing one left by a previous call),
        then calls ``interface.launch()``.
    """
    from collections import deque

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    class LogHandler(logging.Handler):
        """Logging handler that keeps the most recent formatted records in memory."""

        def __init__(self, capacity):
            super().__init__()
            # Bounded buffer: only the tail is ever displayed, so keeping every
            # record would just leak memory on a long-running server.
            self.records = deque(maxlen=capacity)

        def emit(self, record):
            try:
                self.records.append(self.format(record))
            except Exception:
                # Standard logging contract: never let a logging failure
                # propagate into the code that logged.
                self.handleError(record)

        def snapshot(self):
            """Return a copy of the buffered lines.

            ``Handler.handle`` holds ``self.lock`` while calling ``emit``, so
            taking the same lock here avoids reading while a record is appended.
            """
            with self.lock:
                return list(self.records)

    # Remove a capture handler installed by an earlier launch in this process,
    # otherwise every relaunch would stack another handler on the root logger.
    handler_name = "gradio_ui_log_capture"
    for stale in [h for h in logger.handlers if h.name == handler_name]:
        logger.removeHandler(stale)
        stale.close()

    log_handler = LogHandler(max_log_lines)
    log_handler.name = handler_name
    log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(log_handler)

    def get_logs():
        """Return the buffered log lines joined for display."""
        return "\n".join(log_handler.snapshot())

    def describe_models():
        """Return the text shown in the "Model Information" box."""
        gen_name = getattr(config.gen_llm, 'name', 'Unknown')
        val_name = getattr(config.val_llm, 'name', 'Unknown')
        return (
            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
            f"Generation LLM: {gen_name}\n"
            f"Validation LLM: {val_name}\n"
        )

    def answer_question(query, state):
        """Answer *query* and record query/response/sources in the session *state*."""
        if not (query or "").strip():
            return "Please enter a question.", state
        try:
            response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
        except Exception as e:
            logger.exception("Error processing query: %s", e)
            # Drop the previous answer so "Compute metrics" cannot score a
            # stale response that no longer matches what the UI shows.
            state["query"] = query
            state["response"] = ""
            state["source_docs"] = {}
            return f"An error occurred: {e}", state

        state["query"] = query
        state["response"] = response
        state["source_docs"] = source_docs
        return f"Response: {response}\n\n", state

    def compute_metrics(state):
        """Compute metrics for the response stored in *state*.

        Returns:
            ``(attributes_text, metrics_text)`` for the two output boxes; on
            failure, an error message and an empty string.
        """
        response = state.get("response", "")
        source_docs = state.get("source_docs", {})
        query = state.get("query", "")
        if not response:
            return "No response to evaluate yet. Submit a question first.", ""

        try:
            logger.info("Computing metrics")
            # NOTE(review): the meaning of the trailing positional ``1`` is
            # defined by generate_metrics (not visible here) — confirm.
            attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
            attributes_text = get_attributes_text(attributes)
            metric_lines = [f"{key}: {value}" for key, value in metrics.items() if key != 'response']
            metrics_text = "\n".join(["Metrics:", *metric_lines]) + "\n"
            return attributes_text, metrics_text
        except Exception as e:
            logger.exception("Error computing metrics: %s", e)
            return f"An error occurred: {e}", ""

    def _reinitialize_llm(attr, llm_name, initializer, label):
        """Set ``config.<attr>`` to ``initializer(name)`` if a non-blank name is given.

        Returns the refreshed model-info text; if initialization fails the
        previous model is kept and the error is appended to the text.
        """
        name = (llm_name or "").strip()
        if not name:  # Only update if input is not empty
            return describe_models()
        try:
            setattr(config, attr, initializer(name))
        except Exception as e:
            logger.exception("Error initializing %s LLM %r: %s", label, name, e)
            return describe_models() + f"Failed to load {label} LLM '{name}': {e}\n"
        logger.info("%s LLM updated to %s", label, name)
        return describe_models()

    def reinitialize_gen_llm(gen_llm_name):
        """Reinitialize the generation LLM and return updated model info."""
        return _reinitialize_llm("gen_llm", gen_llm_name, initialize_generation_llm, "Generation")

    def reinitialize_val_llm(val_llm_name):
        """Reinitialize the validation LLM and return updated model info."""
        return _reinitialize_llm("val_llm", val_llm_name, initialize_validation_llm, "Validation")

    # Define Gradio Blocks layout
    with gr.Blocks() as interface:
        interface.title = "Real Time RAG Pipeline Q&A"
        gr.Markdown("### Real Time RAG Pipeline Q&A")  # Heading

        # Textboxes / buttons to swap the generation and validation LLMs
        with gr.Row():
            new_gen_llm_input = gr.Textbox(label="New Generation LLM Name", placeholder="Enter LLM name to update")
            update_gen_llm_button = gr.Button("Update Generation LLM")
            new_val_llm_input = gr.Textbox(label="New Validation LLM Name", placeholder="Enter LLM name to update")
            update_val_llm_button = gr.Button("Update Validation LLM")

        # Section to display LLM names
        with gr.Row():
            model_info_display = gr.Textbox(value=describe_models(), label="Model Information", interactive=False)  # Read-only textbox

        # Per-session state holding the last query, response and source documents
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})
        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")  # Description
        with gr.Row():
            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")  # Submit button
            clear_query_button = gr.Button("Clear")  # Clear button
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")

        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary")
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        # Define button actions
        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state]
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output]
        )

        update_gen_llm_button.click(
            fn=reinitialize_gen_llm,
            inputs=[new_gen_llm_input],
            outputs=[model_info_display]  # Update the displayed model info
        )

        update_val_llm_button.click(
            fn=reinitialize_val_llm,
            inputs=[new_val_llm_input],
            outputs=[model_info_display]  # Update the displayed model info
        )

        # Section to display logs
        with gr.Row():
            start_log_button = gr.Button("Start Log Update", elem_id="start_btn")  # Button to refresh the log view
        with gr.Row():
            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)  # Log section

        # Each click refreshes the log box with the currently buffered records
        start_log_button.click(fn=get_logs, outputs=log_section)

    interface.launch()