Spaces:
Sleeping
Sleeping
import os | |
import re | |
import streamlit as st | |
from dotenv import load_dotenv | |
from langsmith import traceable | |
from langsmith.wrappers import wrap_openai | |
import openai | |
import asyncio | |
import threading | |
from PIL import Image | |
import io | |
import json | |
import queue | |
import logging | |
import time | |
# Load environment variables from a local .env file (if present).
load_dotenv()
# NOTE: the API key is read from a variable literally named "openai.api_key",
# not the conventional OPENAI_API_KEY.
openai.api_key = os.getenv("openai.api_key")
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
# if not all([openai.api_key, LANGSMITH_API_KEY, ASSISTANT_ID]):
#     raise ValueError("Please set openai.api_key, LANGSMITH_API_KEY, and ASSISTANT_ID in your .env file.")

# Initialize logging
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the LangSmith-wrapped OpenAI client.
# The v1 SDK client takes `base_url` (the legacy `api_base` keyword raises a
# TypeError), and the URL must include the `/v1` prefix because resource paths
# such as `/threads` are appended to it directly.
wrapped_openai = wrap_openai(
    openai.Client(api_key=openai.api_key, base_url="https://api.openai.com/v1")
)
# LangSmith itself is not configured explicitly here.
# NOTE(review): presumably it picks up its settings from the environment —
# confirm, and initialize it here if that is not the case.
# Define a function to start a streamed Assistant run
def create_run(thread_id: str, assistant_id: str) -> "openai.Stream":
    """
    Start a run of the Assistant on a thread and return its event stream.

    The run is created with ``stream=True`` so the returned object can be
    iterated directly; each item exposes ``.event`` (the event name) and
    ``.data`` (the payload), which is what ``data_streamer`` consumes.
    The ``runs.stream()`` helper is deliberately not used: it returns a
    context manager rather than an iterable and does not accept ``stream=``.

    Args:
        thread_id: ID of the thread to run the Assistant on.
        assistant_id: ID of the Assistant to execute.

    Returns:
        An iterable stream of Assistant events for the new run.
    """
    # The return annotation is a string so that it is not evaluated when the
    # function is defined.
    return wrapped_openai.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
        model="gpt-4o",  # Replace with your desired model
        stream=True,
    )
# Citation markers look like "【12†source】". The pattern is spelled with
# \u escapes so the non-ASCII brackets and dagger cannot be corrupted by a
# wrong file encoding (the previous literal had degraded into Greek letters
# and therefore never matched a real marker).
_CITATION_PATTERN = re.compile(r"\u3010\d+\u2020\w+\u3011")
# U+1F4DA (books emoji) stands in for each removed citation.
_CITATION_REPLACEMENT = "\U0001F4DA"


# Function to replace citation markers in Assistant output
def remove_citation(text: str) -> str:
    """
    Replace every citation marker in ``text`` with a books emoji.

    A marker is a left lenticular bracket, one or more digits, a dagger,
    one or more word characters, and a right lenticular bracket.

    Args:
        text: Text that may contain citation markers.

    Returns:
        ``text`` with each marker replaced; unchanged if none are present.
    """
    return _CITATION_PATTERN.sub(_CITATION_REPLACEMENT, text)
# Seed the per-session state on first run: chat history, the Assistant thread
# id, and the queue of pending tool requests. Factories (not values) are used
# so a fresh object is only built when the key is actually missing.
_SESSION_DEFAULTS = (
    ("messages", list),
    ("thread_id", lambda: None),
    ("tool_requests", queue.Queue),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Module-level alias used by the streaming and tool-handling functions.
tool_requests = st.session_state["tool_requests"]
# Configure the page and render its title
st.set_page_config(page_title="Solution Specifier A", layout="centered")
st.title("Solution Specifier A")

# Replay the stored conversation. Anything that is not a user message is shown
# in an assistant bubble; assistant content may be a PIL image or text.
for msg in st.session_state["messages"]:
    _from_user = msg["role"] == "user"
    with st.chat_message("user" if _from_user else "assistant"):
        if not _from_user and isinstance(msg["content"], Image.Image):
            st.image(msg["content"])
        else:
            st.write(msg["content"])

# Chat input widget; holds the submitted text for this script run, if any
user_input = st.chat_input("Type your message here...")
# Function to handle tool requests (function calls) | |
def handle_requires_action(tool_request): | |
st.toast("Running a function", icon=":material/function:") | |
tool_outputs = [] | |
data = tool_request.data | |
for tool in data.required_action.submit_tool_outputs.tool_calls: | |
if tool.function.arguments: | |
function_arguments = json.loads(tool.function.arguments) | |
else: | |
function_arguments = {} | |
match tool.function.name: | |
case "hello_world": | |
logger.info("Calling hello_world function") | |
answer = hello_world(**function_arguments) | |
tool_outputs.append({"tool_call_id": tool.id, "output": answer}) | |
case _: | |
logger.error(f"Unrecognized function name: {tool.function.name}. Tool: {tool}") | |
ret_val = { | |
"status": "error", | |
"message": f"Function name is not recognized. Ensure the correct function name and try again." | |
} | |
tool_outputs.append({"tool_call_id": tool.id, "output": json.dumps(ret_val)}) | |
st.toast("Function completed", icon=":material/function:") | |
return tool_outputs, data.thread_id, data.id | |
# Example function that could be called by the Assistant
def hello_world(name: str, *, delay: float = 2.0) -> str:
    """
    Return a greeting for ``name`` after an artificial delay.

    Args:
        name: Name to greet.
        delay: Seconds to sleep before answering, simulating a long-running
            task. Defaults to 2 seconds; zero or negative values skip the
            sleep entirely.

    Returns:
        The string ``"Hello <name>!"``.
    """
    if delay > 0:
        time.sleep(delay)  # Simulate a long-running task
    return f"Hello {name}!"
# Function to add assistant messages to session state
def add_message_to_state_session(message):
    """
    Append an assistant message (text or image) to the stored chat history.

    Empty sized values such as ``""`` are ignored, as is ``None``. Objects
    that have no length (for example PIL images, which the streaming code
    passes in) are always stored; calling ``len()`` on them would raise
    ``TypeError``.

    Args:
        message: The content to store under the "assistant" role.
    """
    if message is None:
        return
    try:
        is_empty = len(message) == 0
    except TypeError:
        # Not a sized object (e.g. an image): treat it as non-empty.
        is_empty = False
    if not is_empty:
        st.session_state["messages"].append({"role": "assistant", "content": message})
# Function to process streamed data
def data_streamer(stream):
    """
    Translate Assistant stream events into displayable content.

    Text fragments are yielded as ``str`` and image files as PIL images.
    A ``thread.run.requires_action`` event is put on the ``tool_requests``
    queue and ends the stream; a ``thread.run.failed`` event yields a failure
    marker and ends the stream.

    Args:
        stream: Iterable of events, each exposing ``.event`` and ``.data``.

    Yields:
        ``str`` text fragments or ``PIL.Image.Image`` objects, in stream order.
    """
    logger.info("Starting data streamer")
    st.toast("Thinking...", icon=":material/emoji_objects:")
    content_produced = False
    try:
        for response in stream:
            event = response.event
            if event == "thread.message.delta":
                # A delta may carry no content or several parts; handle every
                # part instead of assuming exactly one at index 0.
                for part in response.data.delta.content or []:
                    if part.type == "text":
                        value = part.text.value if part.text else None
                        if value:  # skip deltas that carry no text
                            content_produced = True
                            yield value
                    elif part.type == "image_file":
                        logger.info("Image file received")
                        raw_image = wrapped_openai.files.content(part.image_file.file_id).read()
                        img = Image.open(io.BytesIO(raw_image))
                        content_produced = True
                        yield img
            elif event == "thread.run.requires_action":
                logger.info("Run requires action")
                tool_requests.put(response)
                if not content_produced:
                    # Give the user some feedback when the run produced
                    # nothing before asking for a tool call.
                    yield "[LLM requires a function call]"
                break
            elif event == "thread.run.failed":
                logger.error("Run failed")
                yield "[Run failed]"
                break
    finally:
        # Runs on normal completion, on break, and on errors alike.
        st.toast("Completed", icon=":material/emoji_objects:")
        logger.info("Finished data streamer")
# Function to display the streamed response
def display_stream(stream):
    """
    Render a streamed Assistant response in one assistant chat bubble.

    Text fragments are accumulated and redrawn in a single placeholder, then
    stored as one history entry. Storing each fragment separately would make
    the conversation replay as one message per streamed chunk. Images are
    rendered and stored as their own entries, and text seen before an image
    is committed first so the history keeps the original order.

    Args:
        stream: Event stream to pass to ``data_streamer``.
    """
    with st.chat_message("assistant"):
        placeholder = st.empty()
        pending_text = ""
        try:
            for content in data_streamer(stream):
                if isinstance(content, Image.Image):
                    if pending_text:
                        add_message_to_state_session(pending_text)
                        pending_text = ""
                    st.image(content)
                    add_message_to_state_session(content)
                    # Later text belongs below the image, in a new slot.
                    placeholder = st.empty()
                else:
                    pending_text += content
                    placeholder.write(pending_text)
        finally:
            # Keep whatever text was received even if the stream failed.
            if pending_text:
                add_message_to_state_session(pending_text)
# Main function to handle user input and assistant response
def main():
    """
    Process one submitted chat message end to end.

    Stores and shows the user's message, makes sure an Assistant thread
    exists, posts the message to it, streams the Assistant's reply, and then
    resolves any tool calls the run requested, streaming each continuation.
    Does nothing when no message was submitted on this script run.
    """
    if not user_input:
        return

    # Add user message to session state and display it
    st.session_state["messages"].append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    # Create a new thread if it doesn't exist. Only the id is needed, so an
    # existing thread is not re-fetched from the API.
    thread_id = st.session_state["thread_id"]
    if thread_id is None:
        logger.info("Creating new thread")
        thread_id = wrapped_openai.beta.threads.create().id
        st.session_state["thread_id"] = thread_id

    # Add user message to the thread
    wrapped_openai.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=user_input,
    )

    # Stream the run on the script thread. A background thread would race the
    # tool-request loop below (the queue is still empty when it is checked)
    # and would issue Streamlit calls outside the script's run context.
    logger.info("Creating a new run with streaming")
    display_stream(create_run(thread_id, ASSISTANT_ID))

    # Handle tool requests if any. Each continuation may queue further
    # requests, which the loop picks up on the next iteration.
    while not tool_requests.empty():
        logger.info("Handling tool requests")
        tool_request = tool_requests.get()
        tool_outputs, run_thread_id, run_id = handle_requires_action(tool_request)
        # Submit the outputs to the SAME run and stream its continuation.
        # Creating a new run here would fail while the original run is still
        # waiting for these outputs.
        continuation = wrapped_openai.beta.threads.runs.submit_tool_outputs(
            thread_id=run_thread_id,
            run_id=run_id,
            tool_outputs=tool_outputs,
            stream=True,
        )
        display_stream(continuation)


# Run the main function
main()