# Spaces: Running
import gradio as gr | |
import src.srf_bot as sb | |
import prompts.system_prompts as sp | |
from langchain_core.messages import HumanMessage | |
# Passage texts keyed by their display identifier (e.g. "Passage 1")
retrieved_passages = {}

# Chatbot instance used by the callbacks in this module
chatbot = sb.SRFChatbot()
def respond(query, history):
    """Answer ``query`` with the chatbot and append the exchange to the history.

    The assistant's reply is followed by a ``References:`` line that names
    each passage, and the passages themselves are stored in the module-level
    ``retrieved_passages`` dict so they can be looked up by identifier.

    Args:
        query: Text the user typed into the question box.
        history: Existing chat history as a sequence of (user, assistant)
            pairs. ``None`` is treated as an empty history.

    Returns:
        A ``(history, "")`` tuple: the updated history and an empty string.
        Blank queries return the history unchanged without calling the model.
    """
    # Work on a copy so the caller's list is not mutated; tolerates None too.
    history = list(history or [])

    # A blank question would only trigger a pointless model invocation.
    if not query or not query.strip():
        return history, ""

    result = chatbot.graph.invoke(
        {"messages": [HumanMessage(content=query)]},
        chatbot.config,
    )
    # The assistant's reply is the last message in the returned state.
    response = result["messages"][-1].content

    # TODO: retrieve passages from the vector database based on the query;
    # dummy passages are used for now.
    passages = [
        "This is the full text of Passage 1.",
        "This is the full text of Passage 2.",
        "This is the full text of Passage 3.",
    ]

    numbered = {
        f"Passage {number}": text
        for number, text in enumerate(passages, start=1)
    }
    # Identifiers are reused for every answer, so drop the previous answer's
    # entries first; otherwise stale passages could linger under old ids.
    retrieved_passages.clear()
    retrieved_passages.update(numbered)

    references = "".join(f" [{pid}]" for pid in numbered)
    linked_response = f"{response}\n\nReferences:{references}"

    history.append((query, linked_response))
    return history, ""
def get_passage_content(passage_id):
    """Return the stored text for ``passage_id``.

    Looks the identifier up in the module-level ``retrieved_passages`` dict
    and falls back to a fixed "not found" message when it is absent.
    """
    try:
        return retrieved_passages[passage_id]
    except KeyError:
        return "Passage not found."
def update_system_prompt(selected_prompt):
    """Switch the chatbot to ``selected_prompt`` and return its template text.

    Args:
        selected_prompt: Key into ``sp.system_prompt_templates``.

    Returns:
        The template stored under ``selected_prompt``.
    """
    # Reset the chatbot first, then look up the text to display.
    chatbot.reset_system_prompt(selected_prompt)
    template_text = sp.system_prompt_templates[selected_prompt]
    return template_text
# Gradio interface
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from the statement order. Confirm the column placement.
with gr.Blocks() as demo:
    gr.Markdown("# SRF Chatbot")

    with gr.Row():
        # Main column: conversation, question box and submit button.
        with gr.Column(scale=4):
            chatbot_output = gr.Chatbot()
            user_input = gr.Textbox(
                placeholder="Type your question here...",
                label="Your Question",
            )
            submit_button = gr.Button("Submit")

        # Side column: system-prompt picker and passage viewer.
        with gr.Column(scale=1):
            prompt_names = list(sp.system_prompt_templates.keys())
            default_prompt = prompt_names[0]

            system_prompt_dropdown = gr.Dropdown(
                choices=prompt_names,
                label="Select Chatbot Instructions",
                value=default_prompt,
            )
            # Read-only view of the prompt that is currently selected.
            system_prompt_display = gr.Textbox(
                value=sp.system_prompt_templates[default_prompt],
                label="Current Chatbot Instructions",
                lines=5,
                interactive=False,
            )
            # Picking another prompt resets the chatbot and refreshes the text.
            system_prompt_dropdown.change(
                update_system_prompt,
                [system_prompt_dropdown],
                [system_prompt_display],
            )

            gr.Markdown("### References")
            passage_selector = gr.Dropdown(
                label="Select a passage to view",
                choices=[],
            )
            passage_display = gr.Markdown()

    # Submit: send the question, update the conversation, clear the box.
    submit_button.click(
        respond,
        [user_input, chatbot_output],
        [chatbot_output, user_input],
    )

    def update_passage_selector(chat_history):
        """Return a dropdown update listing the currently stored passage ids."""
        return gr.update(choices=list(retrieved_passages.keys()))

    # Refresh the selectable passages whenever the conversation changes.
    chatbot_output.change(
        update_passage_selector,
        [chatbot_output],
        [passage_selector],
    )

    # Show the text of whichever passage is selected.
    passage_selector.change(
        get_passage_content,
        [passage_selector],
        [passage_display],
    )

demo.launch()