|
import os |
|
import re |
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
|
|
import io |
|
import time |
|
import json |
|
import queue |
|
import logging |
|
from PIL import Image |
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
import openai |
|
from langsmith.wrappers import wrap_openai |
|
from langsmith import traceable |
|
|
|
|
|
|
|
|
|
def init_logging():
    """
    Configure root logging for the app and return the root logger.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so re-running this on every Streamlit rerun does not stack
    duplicate handlers.
    """
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[console_handler],
    )
    root_logger = logging.getLogger()
    return root_logger


logger = init_logging()
|
|
|
|
|
|
|
|
|
# Pull settings from the process environment (plus a local .env file, if any).
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")

# Both values are required to talk to the assistant; halt this script run with
# a visible error rather than failing later on the first API call.
if not (api_key and assistant_id):
    logger.error("Environment variables OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A must be set.")
    st.error("Missing environment configuration. Please set the required environment variables.")
    st.stop()

# Raw OpenAI client, plus the LangSmith `wrap_openai` wrapper used for all calls
# below (for tracing).
openai_client = openai.Client(api_key=api_key)
client = wrap_openai(openai_client)
|
|
|
|
|
|
|
|
|
def _ensure_session_defaults():
    """Seed st.session_state with this app's keys, keeping any existing values."""
    # Factories (not values) so a fresh list / Queue is only built for a key
    # that this session does not have yet.
    default_factories = {
        "messages": list,
        "thread_id": lambda: None,
        "tool_requests": queue.Queue,
        "current_run": lambda: None,
    }
    for state_key, make_default in default_factories.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()


_ensure_session_defaults()

# Module-level handle on this session's queue of pending function-call events.
tool_requests = st.session_state["tool_requests"]
|
|
|
|
|
|
|
|
|
# File-citation markers the Assistants API embeds in message text, e.g.
# "【4:0†source】" or the older "【13†source】". Compiled once at import.
_CITATION_PATTERN = re.compile(r"【\d+(?::\d+)?†[^】]*】")


def remove_citation(text: str, replacement: str = "📚") -> str:
    """
    Replace every file-citation marker in *text* with *replacement*.

    Args:
        text: Assistant output that may contain markers such as
            ``【4:0†source】``.
        replacement: Text substituted for each marker; pass ``""`` to drop
            the markers entirely.

    Returns:
        *text* with every marker substituted. Text without markers is
        returned unchanged.
    """
    return _CITATION_PATTERN.sub(replacement, text)
|
|
|
|
|
|
|
|
|
def handle_tool_request(event):
    """
    Execute the function calls requested by a ``thread.run.requires_action`` event.

    Every tool call always gets an output: a call that fails (bad JSON
    arguments, unknown function, exception inside the function) produces a
    JSON ``{"status": "error", "message": ...}`` payload instead of raising.
    This means the run can always be resumed with ``submit_tool_outputs``.

    Args:
        event: Streaming event whose ``data`` is the run object carrying
            ``required_action.submit_tool_outputs.tool_calls``.

    Returns:
        Tuple ``(tool_outputs, thread_id, run_id)``. ``tool_outputs`` is a list of
        ``{"tool_call_id": ..., "output": str}`` dicts.
    """
    logger.info(f"Handling tool request: {event}")
    st.toast("Processing a function call...", icon=":hammer_and_wrench:")

    # Looked up at call time, so the functions may be defined later in the module.
    available_functions = {
        "hello_world": hello_world,
        "another_function": another_function,
    }

    data = event.data
    tool_outputs = []

    for tool_call in data.required_action.submit_tool_outputs.tool_calls:
        function_name = tool_call.function.name
        try:
            # Parsed inside the try: malformed model-generated JSON must become an
            # error output for this call, not an exception that strands the run.
            raw_arguments = tool_call.function.arguments
            arguments = json.loads(raw_arguments) if raw_arguments else {}
            if not isinstance(arguments, dict):
                raise ValueError(f"Arguments for '{function_name}' must be a JSON object.")

            logger.info(f"Executing function '{function_name}' with arguments {arguments}")

            function = available_functions.get(function_name)
            if function is None:
                raise ValueError(f"Unrecognized function name: {function_name}")

            output = function(**arguments)
            # The API only accepts string outputs; serialize anything else.
            if not isinstance(output, str):
                output = json.dumps(output)

            tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
            logger.info(f"Function '{function_name}' executed successfully.")
        except Exception as e:
            logger.exception(f"Error executing function '{function_name}': {e}")
            error_response = {"status": "error", "message": str(e)}
            tool_outputs.append({"tool_call_id": tool_call.id, "output": json.dumps(error_response)})

    st.toast("Function call completed.", icon=":white_check_mark:")
    return tool_outputs, data.thread_id, data.id
|
|
|
|
|
|
|
|
|
def hello_world(name: str = "World") -> str:
    """Demo tool: wait two seconds (simulated work), then greet *name*."""
    time.sleep(2)  # stand-in for a slow real-world operation
    greeting = f"Hello, {name}! This message is from a function call."
    return greeting
|
|
|
def another_function(param1: str, param2: int) -> str:
    """Demo tool: wait one second (simulated work), then echo both arguments."""
    time.sleep(1)  # stand-in for a slow real-world operation
    summary = f"Received param1: {param1} and param2: {param2}."
    return summary
|
|
|
|
|
|
|
|
|
def display_message(role: str, content: "str | Image.Image"):
    """
    Render one chat message inside a ``st.chat_message`` bubble for *role*.

    Args:
        role: Chat role passed to ``st.chat_message`` (e.g. "user", "assistant").
        content: Message text, or a PIL image produced by the assistant.
    """
    with st.chat_message(role):
        # Images go through st.image whatever the role; everything else is
        # handed to st.write (markdown for strings).
        if isinstance(content, Image.Image):
            st.image(content)
        else:
            st.write(content)
|
|
|
|
|
|
|
|
|
def data_streamer(): |
|
""" |
|
Streams data from the assistant run. Yields text or images |
|
and enqueues tool requests (function calls) to tool_requests. |
|
""" |
|
logger.info("Starting data streamer.") |
|
st.toast("Thinking...", icon=":hourglass_flowing_sand:") |
|
content_produced = False |
|
accumulated_content = "" |
|
|
|
try: |
|
for event in st.session_state["current_run"]: |
|
match event.event: |
|
case "thread.message.delta": |
|
content = event.data.delta.content[0] |
|
match content.type: |
|
case "text": |
|
text_value = content.text.value |
|
accumulated_content += text_value |
|
content_produced = True |
|
yield remove_citation(text_value) |
|
|
|
case "image_file": |
|
file_id = content.image_file.file_id |
|
logger.info(f"Received image file ID: {file_id}") |
|
image_content = io.BytesIO(client.files.content(file_id).read()) |
|
image = Image.open(image_content) |
|
yield image |
|
|
|
case "thread.run.requires_action": |
|
logger.info(f"Run requires action: {event}") |
|
tool_requests.put(event) |
|
if not content_produced: |
|
yield "[LLM is requesting a function call...]" |
|
return |
|
|
|
case "thread.run.failed": |
|
logger.error(f"Run failed: {event}") |
|
st.error("The assistant encountered an error and couldn't complete the request.") |
|
return |
|
|
|
except Exception as e: |
|
logger.exception(f"Exception in data_streamer: {e}") |
|
st.error(f"An unexpected error occurred: {e}") |
|
|
|
finally: |
|
st.toast("Completed", icon=":checkered_flag:") |
|
|
|
|
|
|
|
|
|
def _render_assistant_stream():
    """Draw data_streamer() output in the current container; return all text shown."""
    text_chunks = []
    segment = ""  # text shown in the current placeholder
    placeholder = st.empty()
    for chunk in data_streamer():
        if isinstance(chunk, Image.Image):
            st.image(chunk)
            # Text arriving after an image continues in a new placeholder below it.
            placeholder = st.empty()
            segment = ""
        else:
            text_chunks.append(chunk)
            segment += chunk
            # Re-render the whole segment so the reply grows in one block
            # instead of one widget per delta.
            placeholder.markdown(segment)
    return "".join(text_chunks)


def display_stream(run_stream, create_context=True):
    """
    Render a run stream live and record its text in the chat history.

    Points ``st.session_state["current_run"]`` at *run_stream* (where
    ``data_streamer`` reads it), draws the streamed text and images, and
    appends the cleaned text, if any, to ``st.session_state["messages"]`` as
    an assistant message.

    Args:
        run_stream: Run event stream opened by the caller (e.g. via
            ``client.beta.threads.runs.stream``).
        create_context: When True, render inside a new assistant
            ``st.chat_message`` bubble; otherwise render in the current container.
    """
    logger.info("Displaying stream.")
    st.session_state["current_run"] = run_stream

    if create_context:
        with st.chat_message("assistant"):
            streamed_text = _render_assistant_stream()
    else:
        streamed_text = _render_assistant_stream()

    # data_streamer cleans citations per chunk; a marker split across chunks
    # only becomes matchable once the chunks are joined here.
    if accumulated_text := remove_citation(streamed_text.strip()):
        st.session_state["messages"].append({"role": "assistant", "content": accumulated_text})
|
|
|
|
|
|
|
|
|
@traceable
def generate_assistant_reply(user_input: str):
    """
    Send *user_input* to the assistant and render its streamed reply.

    Reuses the thread id in ``st.session_state["thread_id"]``, creating a
    thread on first use. It then adds the user message and streams the run.
    Finally it services the function-call events left on ``tool_requests``,
    submitting each one's outputs and streaming the continuation, until none
    remain. Failures are logged and reported with ``st.error`` instead of
    being raised.
    """
    logger.info(f"User input received: {user_input}")

    thread_id = st.session_state["thread_id"]
    try:
        if not thread_id:
            logger.info("Creating a new thread.")
            thread_id = client.beta.threads.create().id
            st.session_state["thread_id"] = thread_id
        else:
            # Only the id is needed, so no extra round trip to retrieve the thread.
            logger.info(f"Using existing thread ID: {thread_id}")

        client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input,
        )
        logger.info("User message added to thread.")
    except Exception as e:
        logger.exception(f"Failed to add user message to thread: {e}")
        st.error("Failed to send your message. Please try again.")
        return

    try:
        with client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=assistant_id,
        ) as run_stream:
            display_stream(run_stream)
    except Exception as e:
        logger.exception(f"Failed to stream assistant response: {e}")
        st.error("Failed to receive a response from the assistant. Please try again.")

    # Deliberately reached even if streaming raised: a requires_action event may
    # already be queued, and its run keeps waiting until outputs are submitted.
    while not tool_requests.empty():
        event = tool_requests.get()
        try:
            tool_outputs, t_id, run_id = handle_tool_request(event)
            with client.beta.threads.runs.submit_tool_outputs_stream(
                thread_id=t_id,
                run_id=run_id,
                tool_outputs=tool_outputs,
            ) as tool_stream:
                # The first bubble has already closed here; give the continuation
                # its own assistant bubble, matching how history replays it.
                display_stream(tool_stream)
        except Exception as e:
            logger.exception(f"Failed to submit tool outputs: {e}")
            st.error("Failed to process a function call from the assistant.")
|
|
|
|
|
|
|
|
|
def main():
    """
    Render the chat page: title, conversation so far, and the input box.

    A submitted prompt is echoed, stored in the session history, then sent to
    the assistant for a reply.
    """
    st.set_page_config(page_title="Solution Specifier A", layout="centered")
    st.title("Solution Specifier A")

    # Replay stored history first so the conversation survives Streamlit reruns.
    for entry in st.session_state["messages"]:
        display_message(entry["role"], entry["content"])

    prompt = st.chat_input("Type your message here...")
    if not prompt:
        return

    display_message("user", prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})
    generate_assistant_reply(prompt)


if __name__ == "__main__":
    main()