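"""
Streamlit chat front-end for the "Solution Specifier A" OpenAI Assistant.

The app wraps the OpenAI client with LangSmith for tracing, streams assistant
replies (text, images, tool calls) into the chat UI, and handles function-call
requests locally before submitting their outputs back to the run.
"""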
import os
import re
import streamlit as st
from dotenv import load_dotenv
import io
import time
import json
import queue
import logging
from PIL import Image
# ------------------------
# LangSmith imports
# ------------------------
import openai
from langsmith.wrappers import wrap_openai
from langsmith import traceable
# ------------------------
# Configure logging (optional but recommended)
# ------------------------
def init_logging():
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)+8s: %(message)s",
        level=logging.INFO,
    )
    return logging.getLogger()
logger = init_logging()
# ------------------------
# Load environment variables
# ------------------------
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A") # The assistant we want to call
if not api_key or not assistant_id:
    raise RuntimeError("Please set OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A in your environment")
# ------------------------
# Wrap the OpenAI client for LangSmith traceability
# ------------------------
openai_client = openai.Client(api_key=api_key)
client = wrap_openai(openai_client)
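# Note: for traces to actually be recorded, LangSmith also expects its own
# environment variables (e.g. LANGCHAIN_TRACING_V2=true and LANGCHAIN_API_KEY),
# typically loaded from the same .env file.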
# ------------------------
# Streamlit session state
# ------------------------
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread" not in st.session_state:
    st.session_state["thread"] = None
if "tool_requests" not in st.session_state:
    st.session_state["tool_requests"] = queue.Queue()
tool_requests = st.session_state["tool_requests"]
# ------------------------
# Utility to remove citations like: 【12†somefile】
# You can adapt to your own "annotations" handling if needed
# ------------------------
def remove_citation(text: str) -> str:
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "📚", text)
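# A minimal sketch of the substitution, assuming a typical citation marker:
#   remove_citation("See the spec 【12†somefile】 for details.")
#   -> "See the spec 📚 for details."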
# ------------------------
# Helper: data streamer for text & images
# Adapted from the Medium article approach
# to handle text deltas, images, or function calls
# ------------------------
def data_streamer():
    """
    Streams data from the assistant run. Yields text or images
    and enqueues tool requests (function calls) to tool_requests.
    """
    st.toast("Thinking...", icon=":material/emoji_objects:")
    content_produced = False
    for event in st.session_state["run_stream"]:
        match event.event:
            case "thread.message.delta":
                # A chunk of text or an image
                content = event.data.delta.content[0]
                match content.type:
                    case "text":
                        text_value = content.text.value
                        content_produced = True
                        # Optionally remove citations, etc.
                        yield remove_citation(text_value)
                    case "image_file":
                        # If the assistant returns an image
                        file_id = content.image_file.file_id
                        content_produced = True
                        image_content = io.BytesIO(client.files.content(file_id).read())
                        yield Image.open(image_content)
            case "thread.run.requires_action":
                # The assistant is requesting a function call
                logger.info(f"[Tool Request] {event}")
                tool_requests.put(event)
                if not content_produced:
                    # We can yield a placeholder if the model hasn't said anything yet
                    yield "[LLM is requesting a function call]"
                return
            case "thread.run.failed":
                # The run failed for some reason
                logger.error(f"Run failed: {event}")
                return
    # If we successfully streamed everything
    st.toast("Completed", icon=":material/emoji_objects:")
# ------------------------
# Helper: display the streaming content
# This wraps data_streamer in st.write_stream
# so you can see partial tokens in real-time
# ------------------------
def display_stream(run_stream, create_context=True):
    """
    Grabs tokens from data_streamer() and displays them in real-time.
    If `create_context=True`, messages are displayed as an assistant block.
    """
    st.session_state["run_stream"] = run_stream
    if create_context:
        with st.chat_message("assistant"):
            streamed_result = st.write_stream(data_streamer)
    else:
        streamed_result = st.write_stream(data_streamer)
    # Return whatever the final token stream is
    return streamed_result
# ------------------------
# Example of handling a function call (requires_action)
# If your Assistant uses function calling (e.g. code interpreter),
# you'd parse arguments, run the function, and return output here.
# ------------------------
def handle_tool_request(event):
    """
    Demonstrates how you might handle a function call.
    In practice, you'd parse the arguments from the event
    and run your custom logic. Then return outputs as JSON.
    """
    st.toast("Running a function (this is user-defined code)", icon=":material/function:")
    tool_outputs = []
    data = event.data
    for tool_call in data.required_action.submit_tool_outputs.tool_calls:
        if tool_call.function.arguments:
            function_args = json.loads(tool_call.function.arguments)
        else:
            function_args = {}
        match tool_call.function.name:
            case "hello_world":
                # Example: implement a user-defined function
                name = function_args.get("name", "anonymous")
                time.sleep(2)  # Simulate a long-running function
                output_val = f"Hello, {name}! This was from a local function."
                tool_outputs.append({"tool_call_id": tool_call.id, "output": output_val})
            case _:
                # If unknown function name
                msg = {"status": "error", "message": "Unknown function request."}
                tool_outputs.append({"tool_call_id": tool_call.id, "output": json.dumps(msg)})
    return tool_outputs, data.thread_id, data.id
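# The "hello_world" branch above only fires if the assistant itself is
# configured with a matching function tool. An assumed schema, for illustration:
#   {"name": "hello_world",
#    "parameters": {"type": "object",
#                   "properties": {"name": {"type": "string"}}}}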
# ------------------------
# Main chat logic
# ------------------------
@traceable # Make this function traceable via LangSmith
def generate_assistant_reply(user_input: str):
    """
    1. If no thread exists, create a new one.
    2. Insert user message into the thread.
    3. Use the Assistants API to create a run + stream the response.
    4. If the assistant requests a function call, handle it and stream again.
    """
    # Create or retrieve thread
    if not st.session_state["thread"]:
        st.session_state["thread"] = client.beta.threads.create()
    thread = st.session_state["thread"]
    # Add user message to the thread
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=user_input,
    )
    # Start streaming assistant response
    with client.beta.threads.runs.stream(
        thread_id=thread.id,
        assistant_id=assistant_id,
    ) as run_stream:
        display_stream(run_stream)
    # If the assistant requested any tool calls, handle them now
    while not tool_requests.empty():
        event = tool_requests.get()
        tool_outputs, t_id, run_id = handle_tool_request(event)
        # Submit tool outputs
        with client.beta.threads.runs.submit_tool_outputs_stream(
            thread_id=t_id, run_id=run_id, tool_outputs=tool_outputs
        ) as next_stream:
            display_stream(next_stream, create_context=False)
# ------------------------
# Streamlit UI
# ------------------------
def main():
    st.set_page_config(page_title="Solution Specifier A", layout="centered")
    st.title("Solution Specifier A")
    # Display existing conversation
    for msg in st.session_state["messages"]:
        with st.chat_message(msg["role"]):
            st.write(msg["content"])
    user_input = st.chat_input("Type your message here...")
    if user_input:
        # Show user's message
        with st.chat_message("user"):
            st.write(user_input)
        # Keep in session state
        st.session_state["messages"].append({"role": "user", "content": user_input})
        # Generate assistant reply
        generate_assistant_reply(user_input)
        # In a real app, you might keep track of the final text
        # from the streamed tokens. For simplicity, we store
        # the entire streamed result as one block in session state:
        st.session_state["messages"].append(
            {"role": "assistant", "content": "[assistant reply streamed above]"}
        )


if __name__ == "__main__":
    main()
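# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py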