import os
import re
import streamlit as st
from dotenv import load_dotenv
import io
import time
import json
import queue
import logging
from PIL import Image

# ------------------------
# LangSmith imports
# ------------------------
import openai
from langsmith.wrappers import wrap_openai
from langsmith import traceable

# ------------------------
# Configure logging (optional but recommended)
# ------------------------
def init_logging():
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)+8s: %(message)s",
        level=logging.INFO,
    )
    return logging.getLogger()

logger = init_logging()

# ------------------------
# Load environment variables
# ------------------------
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")  # The assistant we want to call
if not api_key or not assistant_id:
    raise RuntimeError("Please set OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A in your environment")
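# Note (assumption): for runs to actually appear in LangSmith, the environment
# typically also needs a LangSmith API key (LANGSMITH_API_KEY) and tracing
# enabled (e.g. LANGSMITH_TRACING="true"). Without them the wrapped client
# still works; the calls just aren't recorded.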

# ------------------------
# Wrap the OpenAI client for LangSmith traceability
# ------------------------
openai_client = openai.Client(api_key=api_key)
client = wrap_openai(openai_client)
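# wrap_openai instruments the client's completion endpoints so those calls are
# logged as runs in LangSmith; the rest of the script routes all API calls
# through this wrapped `client`. (Assistants-API calls may not all be traced
# individually; the @traceable decorator below covers the overall reply flow.)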

# ------------------------
# Streamlit session state
# ------------------------
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread" not in st.session_state:
    st.session_state["thread"] = None
if "tool_requests" not in st.session_state:
    st.session_state["tool_requests"] = queue.Queue()
tool_requests = st.session_state["tool_requests"]
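# A Queue is used because data_streamer() runs inside st.write_stream and
# cannot pause to service a tool call mid-stream; requires_action events are
# queued here and drained in generate_assistant_reply() after the stream ends.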

# ------------------------
# Utility to remove citations like 【12†somefile】.
# Adapt this to your own "annotations" handling if needed.
# ------------------------
def remove_citation(text: str) -> str:
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "📚", text)
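# Example: remove_citation("See 【12†source】 for details") -> "See 📚 for details"
# (matches the 【N†name】 pattern handled above)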

# ------------------------
# Helper: data streamer for text & images
# Adapted from the Medium article approach
# to handle text deltas, images, or function calls
# ------------------------
def data_streamer():
    """
    Streams data from the assistant run. Yields text or images
    and enqueues tool requests (function calls) to tool_requests.
    """
    st.toast("Thinking...", icon=":material/emoji_objects:")
    content_produced = False
    for event in st.session_state["run_stream"]:
        match event.event:
            case "thread.message.delta":
                # A chunk of text or an image
                content = event.data.delta.content[0]
                match content.type:
                    case "text":
                        text_value = content.text.value
                        content_produced = True
                        # Optionally remove citations, etc.
                        yield remove_citation(text_value)
                    case "image_file":
                        # If the assistant returns an image
                        file_id = content.image_file.file_id
                        content_produced = True
                        image_content = io.BytesIO(client.files.content(file_id).read())
                        yield Image.open(image_content)
            case "thread.run.requires_action":
                # The assistant is requesting a function call
                logger.info(f"[Tool Request] {event}")
                tool_requests.put(event)
                if not content_produced:
                    # We can yield a placeholder if the model hasn't said anything yet
                    yield "[LLM is requesting a function call]"
                return
            case "thread.run.failed":
                # The run failed for some reason
                logger.error(f"Run failed: {event}")
                return
    # If we successfully streamed everything
    st.toast("Completed", icon=":material/emoji_objects:")

# ------------------------
# Helper: display the streaming content
# This wraps data_streamer in st.write_stream
# so you can see partial tokens in real time
# ------------------------
def display_stream(run_stream, create_context=True):
    """
    Grabs tokens from data_streamer() and displays them in real time.
    If `create_context=True`, the output is rendered inside an assistant
    chat bubble; otherwise it continues in the current context.
    """
    st.session_state["run_stream"] = run_stream
    if create_context:
        with st.chat_message("assistant"):
            streamed_result = st.write_stream(data_streamer)
    else:
        streamed_result = st.write_stream(data_streamer)
    # Return the final streamed result
    return streamed_result
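# Note: st.write_stream returns the concatenated text when only strings were
# yielded, or a list of the yielded objects when images were in the stream;
# main() below relies on this to decide what to persist.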

# ------------------------
# Example of handling a function call (requires_action)
# If your Assistant uses function calling (i.e. custom tools you define),
# you'd parse the arguments, run the function, and return the output here.
# ------------------------
def handle_tool_request(event):
    """
    Demonstrates how you might handle a function call.
    In practice, you'd parse the arguments from the event,
    run your custom logic, then return the outputs as JSON.
    """
    st.toast("Running a function (this is user-defined code)", icon=":material/function:")
    tool_outputs = []
    data = event.data
    for tool_call in data.required_action.submit_tool_outputs.tool_calls:
        if tool_call.function.arguments:
            function_args = json.loads(tool_call.function.arguments)
        else:
            function_args = {}
        match tool_call.function.name:
            case "hello_world":
                # Example: implement a user-defined function
                name = function_args.get("name", "anonymous")
                time.sleep(2)  # Simulate a long-running function
                output_val = f"Hello, {name}! This was from a local function."
                tool_outputs.append({"tool_call_id": tool_call.id, "output": output_val})
            case _:
                # Unknown function name
                msg = {"status": "error", "message": "Unknown function request."}
                tool_outputs.append({"tool_call_id": tool_call.id, "output": json.dumps(msg)})
    return tool_outputs, data.thread_id, data.id
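# For the "hello_world" branch above to ever fire, the Assistant must be
# configured with a matching function tool (in the OpenAI dashboard or via
# the API), e.g. something like:
#   {"type": "function",
#    "function": {"name": "hello_world",
#                 "description": "Greets a user by name.",
#                 "parameters": {"type": "object",
#                                "properties": {"name": {"type": "string"}}}}}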

# ------------------------
# Main chat logic
# ------------------------
# Make this function traceable via LangSmith
@traceable
def generate_assistant_reply(user_input: str):
""" | |
1. If no thread exists, create a new one. | |
2. Insert user message into the thread. | |
3. Use the Assistants API to create a run + stream the response. | |
4. If the assistant requests a function call, handle it and stream again. | |
""" | |
# Create or retrieve thread | |
if not st.session_state["thread"]: | |
st.session_state["thread"] = client.beta.threads.create() | |
thread = st.session_state["thread"] | |
# Add user message to the thread | |
client.beta.threads.messages.create( | |
thread_id=thread.id, | |
role="user", | |
content=user_input | |
) | |
# Start streaming assistant response | |
with client.beta.threads.runs.stream( | |
thread_id=thread.id, | |
assistant_id=assistant_id, | |
) as run_stream: | |
display_stream(run_stream) | |
# If the assistant requested any tool calls, handle them now | |
while not tool_requests.empty(): | |
event = tool_requests.get() | |
tool_outputs, t_id, run_id = handle_tool_request(event) | |
# Submit tool outputs | |
with client.beta.threads.runs.submit_tool_outputs_stream( | |
thread_id=t_id, run_id=run_id, tool_outputs=tool_outputs | |
) as next_stream: | |
display_stream(next_stream, create_context=False) | |
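# Note: submit_tool_outputs_stream continues the *same* run that paused with
# requires_action, which is why handle_tool_request() returns the thread and
# run IDs taken from the event.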

# ------------------------
# Streamlit UI
# ------------------------
def main():
    st.set_page_config(page_title="Solution Specifier A", layout="centered")
    st.title("Solution Specifier A")
    # Display the existing conversation
    for msg in st.session_state["messages"]:
        with st.chat_message(msg["role"]):
            st.write(msg["content"])
    user_input = st.chat_input("Type your message here...")
    if user_input:
        # Show the user's message
        with st.chat_message("user"):
            st.write(user_input)
        # Keep it in session state
        st.session_state["messages"].append({"role": "user", "content": user_input})
        # Generate the assistant reply and persist the streamed text.
        # If the stream contained images, st.write_stream returns a list
        # instead of a string, so fall back to a placeholder in that case.
        reply = generate_assistant_reply(user_input)
        st.session_state["messages"].append(
            {"role": "assistant",
             "content": reply if isinstance(reply, str) else "[assistant reply streamed above]"}
        )

if __name__ == "__main__":
    main()
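# To try this locally (assuming this file is saved as app.py and the two
# environment variables above are set):
#   streamlit run app.py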