AbenzaFran's picture
no check
959e7c7
raw
history blame
8.23 kB
import os
import re
import streamlit as st
from dotenv import load_dotenv
from langsmith import traceable
from langsmith.wrappers import wrap_openai
import openai
import asyncio
import threading
from PIL import Image
import io
import json
import queue
import logging
import time
# Load environment variables
load_dotenv()

# Initialize logging (first, so configuration problems below can be reported)
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# The key is read from the project's "openai.api_key" variable, falling back to
# the SDK's conventional OPENAI_API_KEY variable.
openai.api_key = os.getenv("openai.api_key") or os.getenv("OPENAI_API_KEY")
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")

# Warn (rather than raise) about missing settings so the UI can still load.
_missing_settings = [
    name
    for name, value in (
        ("openai.api_key", openai.api_key),
        ("LANGSMITH_API_KEY", LANGSMITH_API_KEY),
        ("ASSISTANT_ID_SOLUTION_SPECIFIER_A", ASSISTANT_ID),
    )
    if not value
]
if _missing_settings:
    logger.warning("Missing settings in environment/.env: %s", ", ".join(_missing_settings))

# Langsmith-traced OpenAI client. The openai>=1 client takes `base_url` (the
# legacy `api_base` keyword is rejected), and the endpoint includes the /v1
# path prefix that request paths such as "/threads" are appended to.
# Langsmith itself is assumed to be configured via environment variables.
wrapped_openai = wrap_openai(
    openai.Client(
        api_key=openai.api_key,
        base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
    )
)
# Define a traceable function to handle Assistant interactions
@traceable
def create_run(thread_id: str, assistant_id: str) -> "openai.Stream":
    """Start a streaming run of *assistant_id* on the thread *thread_id*.

    Returns the SDK's server-sent-event stream. Iterating it yields events that
    expose ``.event`` (e.g. ``"thread.message.delta"``) and ``.data``, which is
    the shape ``data_streamer`` consumes.
    """
    # runs.create(..., stream=True) returns a directly iterable event stream.
    # runs.stream() is a different helper: it accepts no `stream` argument and
    # returns a context manager that must be entered before it can be iterated.
    return wrapped_openai.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
        model="gpt-4o",  # Replace with your desired model
        stream=True,
    )
# Assistant file-search citation markers, e.g. 【4:0†source】 or 【12†notes】:
# a message index, an optional ":n" sub-index, the dagger, then the source name.
_CITATION_PATTERN = re.compile(r"【\d+(?::\d+)?†[^】]+】")


def remove_citation(text: str) -> str:
    """Replace every citation marker in *text* with a 📚 emoji.

    Args:
        text: Assistant output that may contain markers such as ``【4:0†source】``.

    Returns:
        *text* with each marker replaced by ``"📚"``; other text is unchanged.
    """
    return _CITATION_PATTERN.sub("📚", text)
# Initialize Streamlit page. Streamlit documents that st.set_page_config must be
# the first Streamlit command on a page, so it runs before any other st.* use.
st.set_page_config(page_title="Solution Specifier A", layout="centered")
st.title("Solution Specifier A")

# Initialize session state for messages, thread_id, and tool_requests
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None
if "tool_requests" not in st.session_state:
    # Holds "thread.run.requires_action" events produced while streaming,
    # until main() answers them.
    st.session_state["tool_requests"] = queue.Queue()
tool_requests = st.session_state["tool_requests"]

# Replay the conversation so far. Assistant turns may be text or PIL images.
for msg in st.session_state["messages"]:
    speaker = "user" if msg["role"] == "user" else "assistant"
    with st.chat_message(speaker):
        if speaker == "assistant" and isinstance(msg["content"], Image.Image):
            st.image(msg["content"])
        else:
            st.write(msg["content"])

# Chat input widget
user_input = st.chat_input("Type your message here...")
# Function to handle tool requests (function calls)
def handle_requires_action(tool_request):
st.toast("Running a function", icon=":material/function:")
tool_outputs = []
data = tool_request.data
for tool in data.required_action.submit_tool_outputs.tool_calls:
if tool.function.arguments:
function_arguments = json.loads(tool.function.arguments)
else:
function_arguments = {}
match tool.function.name:
case "hello_world":
logger.info("Calling hello_world function")
answer = hello_world(**function_arguments)
tool_outputs.append({"tool_call_id": tool.id, "output": answer})
case _:
logger.error(f"Unrecognized function name: {tool.function.name}. Tool: {tool}")
ret_val = {
"status": "error",
"message": f"Function name is not recognized. Ensure the correct function name and try again."
}
tool_outputs.append({"tool_call_id": tool.id, "output": json.dumps(ret_val)})
st.toast("Function completed", icon=":material/function:")
return tool_outputs, data.thread_id, data.id
# Example function that could be called by the Assistant
def hello_world(name: str, *, delay: float = 2.0) -> str:
    """Return a greeting for *name* after simulating a long-running task.

    Args:
        name: Who to greet.
        delay: Seconds to sleep before answering (default 2.0, as before).
            Non-positive values skip the sleep. This is keyword-only; the
            Assistant's tool schema (defined outside this file) presumably
            declares only ``name``.

    Returns:
        ``"Hello <name>!"``.
    """
    if delay > 0:
        time.sleep(delay)
    return f"Hello {name}!"
# Function to add assistant messages to session state
def add_message_to_state_session(message):
    """Append *message* to the chat history as an assistant turn.

    *message* is text (``str``) or a ``PIL.Image.Image``.
    - ``None`` and empty text are skipped.
    - Objects without a length, such as PIL images, are always stored; calling
      ``len()`` on them would raise ``TypeError``.
    """
    if message is None:
        return
    if hasattr(message, "__len__") and len(message) == 0:
        return
    st.session_state["messages"].append({"role": "assistant", "content": message})
# Function to process streamed data
def data_streamer(stream):
    """Yield displayable content from an Assistant run event stream.

    - Text deltas are yielded as ``str``.
    - Image-file deltas are downloaded through the Files API and yielded as
      ``PIL.Image.Image``.
    - A ``thread.run.requires_action`` event is put on the module-level
      ``tool_requests`` queue and ends iteration. If nothing was yielded yet, a
      placeholder string is yielded first.
    - A ``thread.run.failed`` event yields ``"[Run failed]"`` and ends iteration.
    - Other events are ignored.
    """
    logger.info("Starting data streamer")
    st.toast("Thinking...", icon=":material/emoji_objects:")
    content_produced = False
    try:
        for response in stream:
            event = response.event
            if event == "thread.message.delta":
                # A delta may carry no content parts (None/empty) or several.
                for content in response.data.delta.content or []:
                    if content.type == "text":
                        # The SDK types text and text.value as optional on
                        # deltas; skip parts that carry no text.
                        value = content.text.value if content.text else None
                        if value:
                            content_produced = True
                            yield value
                    elif content.type == "image_file":
                        file_id = content.image_file.file_id if content.image_file else None
                        if not file_id:
                            continue
                        logger.info("Image file received")
                        raw_bytes = wrapped_openai.files.content(file_id).read()
                        img = Image.open(io.BytesIO(raw_bytes))
                        content_produced = True
                        yield img
            elif event == "thread.run.requires_action":
                logger.info("Run requires action")
                # The run pauses here; main() answers it from the queue.
                tool_requests.put(response)
                if not content_produced:
                    yield "[LLM requires a function call]"
                break
            elif event == "thread.run.failed":
                logger.error("Run failed")
                yield "[Run failed]"
                break
    finally:
        st.toast("Completed", icon=":material/emoji_objects:")
        logger.info("Finished data streamer")
# Function to display the streamed response
def display_stream(stream):
    """Render one assistant turn from *stream* and record it in the chat history.

    - Consecutive text deltas are accumulated into one markdown element, which
      is updated in place as tokens arrive.
    - Each run of text is stored as ONE history entry, rather than one element
      and one entry per delta.
    - Images are shown and stored as they arrive. Text that follows an image
      goes into a fresh element, so the on-screen order matches the stream.
    """
    with st.chat_message("assistant"):
        text_so_far = ""
        text_slot = st.empty()
        try:
            for content in data_streamer(stream):
                if isinstance(content, Image.Image):
                    if text_so_far:
                        add_message_to_state_session(text_so_far)
                        text_so_far = ""
                    st.image(content)
                    add_message_to_state_session(content)
                    text_slot = st.empty()
                else:
                    text_so_far += content
                    text_slot.markdown(text_so_far)
        finally:
            # Keep whatever text was shown, even if the stream raised mid-way.
            if text_so_far:
                add_message_to_state_session(text_so_far)
def _get_or_create_thread_id() -> str:
    """Return this session's Assistant thread id, creating the thread on first use."""
    if st.session_state["thread_id"] is None:
        logger.info("Creating new thread")
        thread = wrapped_openai.beta.threads.create()
        st.session_state["thread_id"] = thread.id
    return st.session_state["thread_id"]


def _drain_tool_requests():
    """Answer queued tool calls, then stream each resumed run until none remain.

    Submitting tool outputs resumes the EXISTING run, so its continuation is read
    from the submit call's own event stream. Creating a new run at this point
    would be rejected, because the original run is still active.
    """
    while not tool_requests.empty():
        logger.info("Handling tool requests")
        tool_request = tool_requests.get()
        tool_outputs, thread_id, run_id = handle_requires_action(tool_request)
        resumed_stream = wrapped_openai.beta.threads.runs.submit_tool_outputs(
            thread_id=thread_id,
            run_id=run_id,
            tool_outputs=tool_outputs,
            stream=True,
        )
        # The resumed run may request more tools; data_streamer enqueues them
        # and this loop picks them up.
        display_stream(resumed_stream)


# Main function to handle user input and assistant response
def main():
    """Send the new user message to the Assistant and render its streamed reply.

    Streaming happens on the Streamlit script thread. Streamlit calls made from
    a bare background thread lack a script run context and do not render, and
    waiting for the stream before checking ``tool_requests`` avoids the race
    where the queue is inspected before any tool request has arrived.
    """
    if not user_input:
        return

    # Add the user message to session state and display it
    st.session_state["messages"].append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    thread_id = _get_or_create_thread_id()

    # Add user message to the thread
    wrapped_openai.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=user_input,
    )

    logger.info("Creating a new run with streaming")
    display_stream(create_run(thread_id, ASSISTANT_ID))
    _drain_tool_requests()


# Run the main function
main()