# Streamlit app: Solution Specifier A — OpenAI Assistants chat with LangSmith tracing.
import os | |
import re | |
import streamlit as st | |
from dotenv import load_dotenv | |
import io | |
import time | |
import json | |
import queue | |
import logging | |
from PIL import Image | |
from typing import Optional | |
# ------------------------ | |
# LangSmith imports | |
# ------------------------ | |
import openai | |
from langsmith.wrappers import wrap_openai | |
from langsmith import traceable | |
# ------------------------ | |
# Configure logging | |
# ------------------------ | |
def init_logging():
    """Configure application logging and return a logger.

    Returns:
        logging.Logger: a module-scoped logger whose records propagate to
        the root handler installed here.

    Notes:
        Streamlit re-executes this script on every interaction;
        ``force=True`` makes the configuration deterministic by replacing
        any handlers already attached to the root logger, instead of the
        silent no-op ``basicConfig`` performs when the root is configured.
    """
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s: %(message)s",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler()
        ],
        force=True,
    )
    # Prefer a named logger over the root logger (PEP 282 convention).
    return logging.getLogger(__name__)
logger = init_logging()
# ------------------------
# Load environment variables
# ------------------------
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")  # The assistant we want to call
# Fail fast: the app cannot function without both the API key and the
# assistant id, so stop the Streamlit script immediately.
if not api_key or not assistant_id:
    logger.error("Environment variables OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A must be set.")
    st.error("Missing environment configuration. Please set the required environment variables.")
    st.stop()
# ------------------------
# Wrap the OpenAI client for LangSmith traceability
# ------------------------
# wrap_openai instruments the client so every OpenAI call made through
# `client` below is traced in LangSmith.
openai_client = openai.Client(api_key=api_key)
client = wrap_openai(openai_client)
# ------------------------
# Streamlit session state initialization
# ------------------------
# Streamlit re-runs this script on every interaction, so anything that must
# survive a rerun lives in st.session_state.
if "messages" not in st.session_state:
    st.session_state["messages"] = []  # chat history: list of {"role", "content"} dicts
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None  # OpenAI Assistants thread id, created lazily
if "tool_requests" not in st.session_state:
    st.session_state["tool_requests"] = queue.Queue()  # pending function-call events
if "current_run" not in st.session_state:
    st.session_state["current_run"] = None  # the run stream currently being consumed
# Module-level alias used by data_streamer() and generate_assistant_reply().
tool_requests = st.session_state["tool_requests"]
# ------------------------
# Utility to remove citation markers like 【12†source】
# ------------------------
def remove_citation(text: str) -> str:
    """Strip OpenAI file-citation markers (e.g. ``【12†source】``) from text.

    Args:
        text: raw text delta from the assistant stream.

    Returns:
        The text with citation markers removed.

    NOTE(review): the original pattern/replacement were mojibake
    (``γ\\d+β \\w+γ`` -> ``π``); this restores the conventional
    citation-removal behavior — confirm against the assistant's actual
    citation format.
    """
    pattern = r"【\d+†\s*\w+】"
    return re.sub(pattern, "", text)
# ------------------------
# Function to handle tool requests (function calls)
# ------------------------
def handle_tool_request(event):
    """
    Processes function call requests from the assistant.

    Args:
        event: a "thread.run.requires_action" stream event whose
            ``data.required_action.submit_tool_outputs.tool_calls`` lists
            the functions the assistant wants executed.

    Returns:
        tuple: (tool_outputs, thread_id, run_id), where tool_outputs is a
        list of {"tool_call_id", "output"} dicts ready for
        submit_tool_outputs_stream.
    """
    logger.info(f"Handling tool request: {event}")
    st.toast("Processing a function call...", icon=":hammer_and_wrench:")
    # Dispatch table: assistant tool name -> local implementation.
    # Extending the assistant with a new function is a one-line addition here.
    dispatch = {
        "hello_world": hello_world,
        "another_function": another_function,
    }
    tool_outputs = []
    data = event.data
    for tool_call in data.required_action.submit_tool_outputs.tool_calls:
        function_name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
        logger.info(f"Executing function '{function_name}' with arguments {arguments}")
        try:
            func = dispatch.get(function_name)
            if func is None:
                raise ValueError(f"Unrecognized function name: {function_name}")
            output = func(**arguments)
            tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
            logger.info(f"Function '{function_name}' executed successfully.")
        except Exception as e:
            # Report the failure back to the assistant instead of aborting the run.
            logger.error(f"Error executing function '{function_name}': {e}")
            error_response = {"status": "error", "message": str(e)}
            tool_outputs.append({"tool_call_id": tool_call.id, "output": json.dumps(error_response)})
    st.toast("Function call completed.", icon=":white_check_mark:")
    return tool_outputs, data.thread_id, data.id
# ------------------------
# Example function implementations
# ------------------------
def hello_world(name: str = "World") -> str:
    """Return a greeting for *name* after a simulated 2-second delay."""
    time.sleep(2)  # pretend this is a long-running task
    greeting = f"Hello, {name}! This message is from a function call."
    return greeting
def another_function(param1: str, param2: int) -> str:
    """Echo back both parameters after a short simulated delay."""
    time.sleep(1)  # simulate work
    return "Received param1: {} and param2: {}.".format(param1, param2)
# ------------------------
# Streamlit UI Components
# ------------------------
def display_message(role: str, content: "str | Image.Image"):
    """
    Displays a message in the Streamlit chat interface.

    Args:
        role: chat role ("user" or "assistant").
        content: either text or a PIL image (the original annotation said
            ``str`` only, but the assistant branch below explicitly handles
            PIL images streamed from image_file content).
    """
    with st.chat_message(role):
        if role == "assistant" and isinstance(content, Image.Image):
            st.image(content)
        else:
            st.write(content)
# ------------------------
# Helper: data streamer for text & images
# ------------------------
def data_streamer():
    """
    Streams data from the assistant run. Yields text or images
    and enqueues tool requests (function calls) to tool_requests.

    Reads the active run stream from st.session_state["current_run"].

    Yields:
        str: citation-stripped text deltas, or
        PIL.Image.Image: images produced by the assistant.

    Side effects:
        Puts "thread.run.requires_action" events on the module-level
        tool_requests queue and returns immediately afterwards.
    """
    logger.info("Starting data streamer.")
    st.toast("Thinking...", icon=":hourglass_flowing_sand:")
    content_produced = False
    # Running concatenation of all raw text deltas seen so far.
    # NOTE(review): this accumulator is local and is never returned or
    # stored; display_stream() references a variable of the same spelling,
    # which is undefined there — confirm the intended hand-off.
    accumulated_content = ""
    try:
        for event in st.session_state["current_run"]:
            match event.event:
                case "thread.message.delta":
                    # Only the first content part of each delta is inspected.
                    content = event.data.delta.content[0]
                    match content.type:
                        case "text":
                            text_value = content.text.value
                            accumulated_content += text_value
                            content_produced = True
                            yield remove_citation(text_value)
                        case "image_file":
                            # Fetch the rendered image bytes via the files API.
                            file_id = content.image_file.file_id
                            logger.info(f"Received image file ID: {file_id}")
                            image_content = io.BytesIO(client.files.content(file_id).read())
                            image = Image.open(image_content)
                            yield image
                case "thread.run.requires_action":
                    # Defer execution of the function call to the caller via the queue.
                    logger.info(f"Run requires action: {event}")
                    tool_requests.put(event)
                    if not content_produced:
                        yield "[LLM is requesting a function call...]"
                    return
                case "thread.run.failed":
                    logger.error(f"Run failed: {event}")
                    st.error("The assistant encountered an error and couldn't complete the request.")
                    return
    except Exception as e:
        logger.exception(f"Exception in data_streamer: {e}")
        st.error(f"An unexpected error occurred: {e}")
    finally:
        st.toast("Completed", icon=":checkered_flag:")
# ------------------------
# Helper: display the streaming content
# ------------------------
def display_stream(run_stream, create_context=True):
    """
    Grabs tokens from data_streamer() and displays them in real-time.

    Args:
        run_stream: the active assistant run stream to consume.
        create_context: when True, messages are rendered inside a fresh
            assistant chat bubble; otherwise into the caller's context.

    Bug fix: the original referenced ``accumulated_content`` here, but that
    variable is local to data_streamer(), so every stream ended in a
    NameError. The text pieces are now collected in this function while
    they are rendered.
    """
    logger.info("Displaying stream.")
    st.session_state["current_run"] = run_stream

    def _render_and_collect() -> str:
        # Render each streamed piece and keep the text parts so the final
        # reply can be stored in the chat history (images are displayed
        # immediately and are not persisted).
        pieces = []
        for content in data_streamer():
            display_message("assistant", content)
            if isinstance(content, str):
                pieces.append(content)
        return "".join(pieces)

    if create_context:
        with st.chat_message("assistant"):
            accumulated_content = _render_and_collect()
    else:
        accumulated_content = _render_and_collect()

    # Persist the assistant's text reply. Citations were already stripped
    # by data_streamer(); the extra remove_citation pass is harmless.
    if accumulated_text := remove_citation(accumulated_content.strip()):
        st.session_state["messages"].append({"role": "assistant", "content": accumulated_text})
# ------------------------
# Main chat logic with traceability
# ------------------------
# Enable LangSmith traceability
def generate_assistant_reply(user_input: str):
    """
    Handles user input by creating or continuing a thread,
    sending the message to the assistant, and streaming the response.

    Args:
        user_input: the text the user typed into the chat box.

    Side effects:
        May create a new OpenAI Assistants thread (stored in
        st.session_state["thread_id"]), streams the assistant's reply into
        the UI, and executes/submits any function calls the assistant
        requests via the tool_requests queue.
    """
    logger.info(f"User input received: {user_input}")
    # Create or retrieve thread
    if not st.session_state["thread_id"]:
        logger.info("Creating a new thread.")
        thread = client.beta.threads.create()
        st.session_state["thread_id"] = thread.id
    else:
        thread = client.beta.threads.retrieve(thread_id=st.session_state["thread_id"])
        logger.info(f"Using existing thread ID: {thread.id}")
    # Add user message to the thread
    try:
        client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=user_input
        )
        logger.info("User message added to thread.")
    except Exception as e:
        logger.exception(f"Failed to add user message to thread: {e}")
        st.error("Failed to send your message. Please try again.")
        return
    # Create and stream assistant response
    try:
        with client.beta.threads.runs.stream(
            thread_id=thread.id,
            assistant_id=assistant_id,
        ) as run_stream:
            st.session_state["current_run"] = run_stream
            display_stream(run_stream)
    except Exception as e:
        logger.exception(f"Failed to stream assistant response: {e}")
        st.error("Failed to receive a response from the assistant. Please try again.")
    # Handle any function calls requested by the assistant.
    # data_streamer() enqueues requires_action events; each is executed
    # locally and its outputs are streamed back to continue the run.
    while not tool_requests.empty():
        event = tool_requests.get()
        tool_outputs, t_id, run_id = handle_tool_request(event)
        try:
            with client.beta.threads.runs.submit_tool_outputs_stream(
                thread_id=t_id,
                run_id=run_id,
                tool_outputs=tool_outputs
            ) as tool_stream:
                # create_context=False: continue rendering into the
                # assistant bubble opened by the first display_stream call.
                display_stream(tool_stream, create_context=False)
        except Exception as e:
            logger.exception(f"Failed to submit tool outputs: {e}")
            st.error("Failed to process a function call from the assistant.")
# ------------------------
# Streamlit UI
# ------------------------
def main():
    """Render the chat page: history, input box, and assistant replies."""
    st.set_page_config(page_title="Solution Specifier A", layout="centered")
    st.title("Solution Specifier A")

    # Replay the stored conversation so it survives Streamlit reruns.
    for msg in st.session_state["messages"]:
        display_message(msg["role"], msg["content"])

    user_input = st.chat_input("Type your message here...")
    if not user_input:
        return

    # Echo the user's message, persist it, then ask the assistant.
    display_message("user", user_input)
    st.session_state["messages"].append({"role": "user", "content": user_input})
    generate_assistant_reply(user_input)

if __name__ == "__main__":
    main()