File size: 10,817 Bytes
73d49e1
b8736af
73d49e1
 
b02dba2
713dd57
b02dba2
713dd57
 
 
b02dba2
f6d4f89
713dd57
b02dba2
 
 
 
 
 
713dd57
b02dba2
f6d4f89
b02dba2
 
 
f6d4f89
b02dba2
f6d4f89
 
 
b02dba2
 
713dd57
b02dba2
713dd57
b02dba2
 
 
 
 
 
713dd57
b02dba2
f6d4f89
 
 
73d49e1
b02dba2
 
 
 
 
73d49e1
b02dba2
f6d4f89
b02dba2
b8736af
 
b02dba2
f6d4f89
 
b02dba2
713dd57
 
73d49e1
f6d4f89
 
 
713dd57
73d49e1
b02dba2
f6d4f89
b02dba2
 
 
 
b8736af
f6d4f89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02dba2
 
 
 
 
 
 
 
f6d4f89
 
b02dba2
f6d4f89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02dba2
 
 
 
 
 
 
 
 
f6d4f89
 
 
b02dba2
b8736af
f6d4f89
 
b02dba2
f6d4f89
 
b8736af
f6d4f89
 
 
 
 
 
 
713dd57
b02dba2
f6d4f89
b02dba2
f6d4f89
b02dba2
713dd57
f6d4f89
 
713dd57
f6d4f89
 
b02dba2
f6d4f89
 
 
 
 
 
 
b02dba2
 
f6d4f89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02dba2
 
 
f6d4f89
 
 
 
 
 
 
 
 
 
 
b02dba2
 
 
 
713dd57
b02dba2
 
 
 
 
f6d4f89
b02dba2
 
713dd57
f6d4f89
 
b02dba2
f6d4f89
b02dba2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import os
import re
import streamlit as st
from dotenv import load_dotenv

import io
import time
import json
import queue
import logging
from PIL import Image
from typing import Optional

# ------------------------
# LangSmith imports
# ------------------------
import openai
from langsmith.wrappers import wrap_openai
from langsmith import traceable

# ------------------------
# Configure logging
# ------------------------
def init_logging():
    """Configure process-wide logging and return the root logger.

    Installs a single stream handler on the root logger with a timestamped
    ``[time] LEVEL: message`` layout at INFO level.

    Returns:
        The root ``logging.Logger`` instance.
    """
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[console_handler],
    )
    root_logger = logging.getLogger()
    return root_logger

# Module-wide logger; init_logging() returns the root logger.
logger = init_logging()

# ------------------------
# Load environment variables
# ------------------------
# Load values from a .env file (python-dotenv) before reading the environment.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")  # The assistant we want to call

# Both values are mandatory: report the problem in the log and in the UI, then
# stop this script run so nothing below executes with a missing key or id.
if not api_key or not assistant_id:
    logger.error("Environment variables OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A must be set.")
    st.error("Missing environment configuration. Please set the required environment variables.")
    st.stop()

# ------------------------
# Wrap the OpenAI client for LangSmith traceability
# ------------------------
openai_client = openai.Client(api_key=api_key)
# `client` is the wrapped handle the rest of this module uses for API calls.
client = wrap_openai(openai_client)

# ------------------------
# Streamlit session state initialization
# ------------------------
# Each default is only written when the key is absent, so values that already
# exist in the session are never overwritten here.

# Chat history: a list of {"role": ..., "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Id of the assistant thread for this session; None until one is created.
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None

# Queue of pending "requires_action" events (function calls to execute).
if "tool_requests" not in st.session_state:
    st.session_state["tool_requests"] = queue.Queue()

# The run stream currently being consumed by data_streamer().
if "current_run" not in st.session_state:
    st.session_state["current_run"] = None

# Module-level alias for the session's tool-request queue.
tool_requests = st.session_state["tool_requests"]

# ------------------------
# Utility to remove citations like: 【12†source】
# ------------------------
def remove_citation(text: str) -> str:
    """Replace assistant citation markers in *text* with a book emoji.

    A citation marker is a span wrapped in lenticular brackets that contains
    a dagger separating the citation index from its source, for example
    an index such as ``12`` or ``4:0`` followed by ``source`` or a file name.

    Args:
        text: Text that may contain citation markers.

    Returns:
        ``text`` with every citation marker replaced by the books emoji.
        Text without markers is returned unchanged.
    """
    # The special characters are spelled as escapes so the source stays pure
    # ASCII; a previous revision stored them as mis-decoded bytes, which meant
    # the pattern could never match a real marker.
    open_bracket = "\u3010"   # LEFT BLACK LENTICULAR BRACKET
    close_bracket = "\u3011"  # RIGHT BLACK LENTICULAR BRACKET
    dagger = "\u2020"         # DAGGER
    books_emoji = "\U0001F4DA"

    # Anything except brackets may surround the dagger, so indices such as
    # "4:0" and sources such as "report.pdf" are matched as well.
    inner = f"[^{open_bracket}{close_bracket}]*"
    pattern = f"{open_bracket}{inner}{dagger}{inner}{close_bracket}"
    return re.sub(pattern, books_emoji, text)

# ------------------------
# Function to handle tool requests (function calls)
# ------------------------
def handle_tool_request(event):
    """Execute the function calls requested by a ``requires_action`` event.

    Every tool call in the event is dispatched to its local implementation.
    A failure in one call (malformed JSON arguments, unknown function name,
    or an exception raised by the function itself) is reported back to the
    assistant as a JSON error payload for that call only, so the remaining
    calls are still executed and every call receives an output.

    Args:
        event: Stream event whose ``data`` exposes ``required_action``,
            ``thread_id`` and ``id``.

    Returns:
        A ``(tool_outputs, thread_id, run_id)`` tuple, where ``tool_outputs``
        is a list of ``{"tool_call_id": ..., "output": ...}`` dicts.
    """
    logger.info(f"Handling tool request: {event}")
    # st.toast expects an emoji character for `icon`, not a ":shortcode:".
    st.toast("Processing a function call...", icon="\N{WRENCH}")

    # Map function names to actual implementations. The names are looked up
    # when this function runs, so they may be defined further down the module.
    available_functions = {
        "hello_world": hello_world,
        "another_function": another_function,
    }

    tool_outputs = []
    data = event.data

    for tool_call in data.required_action.submit_tool_outputs.tool_calls:
        function_name = tool_call.function.name

        try:
            # Parsed inside the try block so that malformed JSON from the
            # model becomes an error output instead of aborting the batch.
            raw_arguments = tool_call.function.arguments
            arguments = json.loads(raw_arguments) if raw_arguments else {}

            logger.info(f"Executing function '{function_name}' with arguments {arguments}")

            implementation = available_functions.get(function_name)
            if implementation is None:
                raise ValueError(f"Unrecognized function name: {function_name}")

            output = implementation(**arguments)
            # Tool outputs are submitted as strings; serialise anything else.
            if not isinstance(output, str):
                output = json.dumps(output)

            tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
            logger.info(f"Function '{function_name}' executed successfully.")

        except Exception as e:
            logger.error(f"Error executing function '{function_name}': {e}")
            error_response = {"status": "error", "message": str(e)}
            tool_outputs.append({"tool_call_id": tool_call.id, "output": json.dumps(error_response)})

    st.toast("Function call completed.", icon="\N{WHITE HEAVY CHECK MARK}")
    return tool_outputs, data.thread_id, data.id

# ------------------------
# Example function implementations
# ------------------------
def hello_world(name: str = "World") -> str:
    """Build a greeting for *name* after a short artificial pause.

    Args:
        name: Who to greet.

    Returns:
        The greeting sentence.
    """
    # Stand-in for a slow operation.
    time.sleep(2)
    greeting = f"Hello, {name}! This message is from a function call."
    return greeting

def another_function(param1: str, param2: int) -> str:
    """Echo both parameters back in a sentence after a one-second pause.

    Args:
        param1: First value to echo.
        param2: Second value to echo.

    Returns:
        A sentence containing both values.
    """
    time.sleep(1)
    summary = f"Received param1: {param1} and param2: {param2}."
    return summary

# ------------------------
# Streamlit UI Components
# ------------------------
def display_message(role: str, content: str):
    """Render one chat entry inside its own Streamlit chat message block.

    Args:
        role: Chat role used for the message block.
        content: What to show. A PIL image sent by the assistant is drawn
            as an image; everything else goes through ``st.write``.
    """
    is_assistant_image = isinstance(content, Image.Image) and role == "assistant"
    with st.chat_message(role):
        renderer = st.image if is_assistant_image else st.write
        renderer(content)

# ------------------------
# Helper: data streamer for text & images
# ------------------------
def data_streamer():
    """
    Streams data from the assistant run. Yields text or images
    and enqueues tool requests (function calls) to tool_requests.
    """
    logger.info("Starting data streamer.")
    st.toast("Thinking...", icon=":hourglass_flowing_sand:")
    content_produced = False
    accumulated_content = ""

    try:
        for event in st.session_state["current_run"]:
            match event.event:
                case "thread.message.delta":
                    content = event.data.delta.content[0]
                    match content.type:
                        case "text":
                            text_value = content.text.value
                            accumulated_content += text_value
                            content_produced = True
                            yield remove_citation(text_value)

                        case "image_file":
                            file_id = content.image_file.file_id
                            logger.info(f"Received image file ID: {file_id}")
                            image_content = io.BytesIO(client.files.content(file_id).read())
                            image = Image.open(image_content)
                            yield image

                case "thread.run.requires_action":
                    logger.info(f"Run requires action: {event}")
                    tool_requests.put(event)
                    if not content_produced:
                        yield "[LLM is requesting a function call...]"
                    return

                case "thread.run.failed":
                    logger.error(f"Run failed: {event}")
                    st.error("The assistant encountered an error and couldn't complete the request.")
                    return

    except Exception as e:
        logger.exception(f"Exception in data_streamer: {e}")
        st.error(f"An unexpected error occurred: {e}")

    finally:
        st.toast("Completed", icon=":checkered_flag:")

# ------------------------
# Helper: display the streaming content
# ------------------------
def display_stream(run_stream, create_context=True):
    """Render a run stream live and save its text to the chat history.

    Chunks produced by ``data_streamer()`` are rendered as they arrive:
    consecutive text chunks grow a single placeholder, and an image is drawn
    on its own, after which following text starts a new placeholder below it.
    Once the stream ends, the collected text (if any) is appended to
    ``st.session_state["messages"]`` as one assistant message.

    Args:
        run_stream: The stream to consume; stored in
            ``st.session_state["current_run"]`` for ``data_streamer()``.
        create_context: When True, output is rendered inside a new assistant
            chat message block. When False, output is rendered directly into
            whatever container is currently active.
    """
    logger.info("Displaying stream.")
    st.session_state["current_run"] = run_stream

    # Status line data_streamer() yields when a run pauses for a function call
    # before producing any text. It is shown live but is not part of the
    # assistant's answer, so it is kept out of the saved history.
    tool_call_notice = "[LLM is requesting a function call...]"

    # Text is collected here: data_streamer()'s locals are not visible from
    # this function, so the history has to be built from the yielded chunks.
    history_parts = []

    def render_chunks():
        """Draw every streamed chunk into the currently active container."""
        placeholder = None
        segment = ""
        for chunk in data_streamer():
            if isinstance(chunk, Image.Image):
                st.image(chunk)
                # Text arriving after the image belongs in a fresh element.
                placeholder, segment = None, ""
                continue

            if placeholder is None:
                placeholder = st.empty()
            segment += chunk
            # Rewrite one element in place rather than opening a new chat
            # message for every token.
            placeholder.write(segment)

            if chunk != tool_call_notice:
                history_parts.append(chunk)

    if create_context:
        with st.chat_message("assistant"):
            render_chunks()
    else:
        render_chunks()

    # Images are already on screen; only text is persisted. Running the
    # citation filter over the joined text also covers a marker that was
    # split across two streamed chunks.
    if accumulated_text := remove_citation("".join(history_parts).strip()):
        st.session_state["messages"].append({"role": "assistant", "content": accumulated_text})

# ------------------------
# Main chat logic with traceability
# ------------------------
@traceable  # Enable LangSmith traceability
def generate_assistant_reply(user_input: str):
    """Send *user_input* to the assistant and stream its reply into the UI.

    Steps:
        1. Reuse the thread id stored in the session, or create a new thread.
        2. Add the user's message to the thread.
        3. Start a run and stream the response through ``display_stream``.
        4. Execute any function calls the run queued on ``tool_requests``,
           submit their outputs and stream the follow-up response. Follow-up
           streams may queue further function calls, which are handled by
           the same loop.

    API failures are logged and reported in the UI instead of being raised.

    Args:
        user_input: The text the user typed.
    """
    logger.info(f"User input received: {user_input}")

    # Create or reuse the thread. Only the id is needed afterwards, so an
    # existing thread is not fetched again from the API.
    thread_id = st.session_state["thread_id"]
    if thread_id:
        logger.info(f"Using existing thread ID: {thread_id}")
    else:
        logger.info("Creating a new thread.")
        try:
            thread_id = client.beta.threads.create().id
        except Exception as e:
            logger.exception(f"Failed to create thread: {e}")
            st.error("Failed to send your message. Please try again.")
            return
        st.session_state["thread_id"] = thread_id

    # Add user message to the thread
    try:
        client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input
        )
        logger.info("User message added to thread.")
    except Exception as e:
        logger.exception(f"Failed to add user message to thread: {e}")
        st.error("Failed to send your message. Please try again.")
        return

    # Create and stream assistant response. display_stream() records the
    # stream in session state itself, so it is not stored here as well.
    try:
        with client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=assistant_id,
        ) as run_stream:
            display_stream(run_stream)
    except Exception as e:
        logger.exception(f"Failed to stream assistant response: {e}")
        st.error("Failed to receive a response from the assistant. Please try again.")
        # Deliberately no return: a function call queued before the failure
        # still has to be answered below.

    # Handle any function calls requested by the assistant
    while not tool_requests.empty():
        event = tool_requests.get()
        tool_outputs, t_id, run_id = handle_tool_request(event)

        try:
            with client.beta.threads.runs.submit_tool_outputs_stream(
                thread_id=t_id,
                run_id=run_id,
                tool_outputs=tool_outputs
            ) as tool_stream:
                display_stream(tool_stream, create_context=False)
        except Exception as e:
            logger.exception(f"Failed to submit tool outputs: {e}")
            st.error("Failed to process a function call from the assistant.")
            # Best effort: a run left waiting for tool outputs would keep the
            # thread busy, so try to cancel it. This fails harmlessly when
            # the run has already finished.
            try:
                client.beta.threads.runs.cancel(run_id=run_id, thread_id=t_id)
            except Exception as cancel_error:
                logger.warning(f"Could not cancel run {run_id}: {cancel_error}")

# ------------------------
# Streamlit UI
# ------------------------
def main():
    """Draw the chat page and process one user message, if any was entered.

    Sets the page configuration and title, replays the stored conversation,
    and, when the chat input holds text, shows it, stores it in the session
    history and requests the assistant's reply.
    """
    st.set_page_config(page_title="Solution Specifier A", layout="centered")
    st.title("Solution Specifier A")

    # Replay the conversation stored in the session.
    for entry in st.session_state["messages"]:
        display_message(entry["role"], entry["content"])

    user_input = st.chat_input("Type your message here...")
    if not user_input:
        # Nothing was submitted on this run.
        return

    # Show the new message, remember it, then ask the assistant.
    display_message("user", user_input)
    st.session_state["messages"].append({"role": "user", "content": user_input})
    generate_assistant_reply(user_input)

if __name__ == "__main__":
    main()