AbenzaFran committed
Commit 88f878e · Parents: a2f3e2b 6853d31

Merge branch 'streaming-1st-shot' into main

Files changed (1): app.py (+315 -107)
app.py CHANGED
@@ -1,116 +1,324 @@
  import os
  import re
+ import io
+ import time
+ import json
+ import queue
+ import logging
+ from typing import Any, Generator, Optional, List, Dict, Tuple
+ from dataclasses import dataclass
+
  import streamlit as st
- import openai
  from dotenv import load_dotenv
- from langchain.agents.openai_assistant import OpenAIAssistantRunnable
-
- # Load environment variables
- load_dotenv()
- api_key = os.getenv("OPENAI_API_KEY")
- extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
-
- # Create the assistant
- extractor_llm = OpenAIAssistantRunnable(
-     assistant_id=extractor_agent,
-     api_key=api_key,
-     as_agent=True
- )
-
- def remove_citation(text: str) -> str:
-     pattern = r"【\d+†\w+】"
-     return re.sub(pattern, "📚", text)
-
- # Initialize session state
- if "messages" not in st.session_state:
-     st.session_state["messages"] = []
- if "thread_id" not in st.session_state:
-     st.session_state["thread_id"] = None
- # A flag to indicate if a request is in progress
- if "is_in_request" not in st.session_state:
-     st.session_state["is_in_request"] = False
-
- st.title("Solution Specifier A")
-
- def predict(user_input: str) -> str:
-     """
-     This function calls our OpenAIAssistantRunnable to get a response.
-     If st.session_state["thread_id"] is None, we start a new thread.
-     Otherwise, we continue the existing thread.
-
-     If a concurrency error occurs ("Can't add messages to thread..."), we reset
-     the thread_id and try again once on a fresh thread.
-     """
-     try:
-         if st.session_state["thread_id"] is None:
-             # Start a new thread
-             response = extractor_llm.invoke({"content": user_input})
-             st.session_state["thread_id"] = response.thread_id
-         else:
-             # Continue existing thread
-             response = extractor_llm.invoke(
-                 {"content": user_input, "thread_id": st.session_state["thread_id"]}
+ from PIL import Image
+ import openai
+ from langsmith.wrappers import wrap_openai
+ from langsmith import traceable
+
+ # ------------------------
+ # Configuration and Types
+ # ------------------------
+ @dataclass
+ class AppConfig:
+     """Application configuration settings."""
+     page_title: str = "Solution Specifier A"
+     page_icon: str = "🤖"
+     layout: str = "centered"
+
+ @dataclass
+ class Message:
+     """Chat message structure."""
+     role: str
+     content: str
+
+ class StreamingError(Exception):
+     """Custom exception for streaming-related errors."""
+     pass
+
+ # ------------------------
+ # Logging Configuration
+ # ------------------------
+ def setup_logging() -> logging.Logger:
+     """Configure and return the application logger."""
+     logging.basicConfig(
+         format="[%(asctime)s] %(levelname)+8s: %(message)s",
+         level=logging.INFO,
+     )
+     return logging.getLogger(__name__)
+
+ logger = setup_logging()
+
+ # ------------------------
+ # Environment Setup
+ # ------------------------
+ class EnvironmentManager:
+     """Manages environment variables and configuration."""
+
+     @staticmethod
+     def load_environment() -> Tuple[str, str]:
+         """Load and validate environment variables."""
+         load_dotenv(override=True)
+         api_key = os.getenv("OPENAI_API_KEY")
+         assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
+
+         if not api_key or not assistant_id:
+             raise RuntimeError(
+                 "Missing required environment variables. Please set "
+                 "OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A"
              )
+
+         return api_key, assistant_id

-         output = response.return_values["output"]
-         return remove_citation(output)
+ # ------------------------
+ # State Management
+ # ------------------------
+ class StateManager:
+     """Manages Streamlit session state."""
+
+     @staticmethod
+     def initialize_state() -> None:
+         """Initialize session state variables."""
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+         if "thread" not in st.session_state:
+             st.session_state.thread = None
+         if "tool_requests" not in st.session_state:
+             st.session_state.tool_requests = queue.Queue()
+         if "run_stream" not in st.session_state:
+             st.session_state.run_stream = None

-     except openai.error.BadRequestError as e:
-         # If we get the specific concurrency error, reset thread and try once more
-         if "while a run" in str(e):
-             st.session_state["thread_id"] = None
-             # Now create a new thread for the same user input
-             try:
-                 response = extractor_llm.invoke({"content": user_input})
-                 st.session_state["thread_id"] = response.thread_id
-                 output = response.return_values["output"]
-                 return remove_citation(output)
-             except Exception as e2:
-                 st.error(f"Error after resetting thread: {e2}")
-                 return ""
+     @staticmethod
+     def add_message(role: str, content: str) -> None:
+         """Add a message to the conversation history."""
+         st.session_state.messages.append(Message(role=role, content=content))
+
+ # ------------------------
+ # Text Processing
+ # ------------------------
+ class TextProcessor:
+     """Handles text processing and formatting."""
+
+     @staticmethod
+     def remove_citations(text: str) -> str:
+         """Remove citation markers from text."""
+         pattern = r"【\d+†\w+】"
+         return re.sub(pattern, "📚", text)
+
+ # ------------------------
+ # Streaming Handler
+ # ------------------------
+ class StreamHandler:
+     """Handles streaming of assistant responses."""
+
+     def __init__(self, client: Any):
+         self.client = client
+         self.text_processor = TextProcessor()
+         self.complete_response = []
+
+     def stream_data(self) -> Generator[Any, None, None]:
+         """Stream data from the assistant run."""
+         st.toast("Thinking...", icon="🤔")
+         content_produced = False
+         self.complete_response = []  # Reset for new stream
+
+         try:
+             for event in st.session_state.run_stream:
+                 match event.event:
+                     case "thread.message.delta":
+                         yield from self._handle_message_delta(event, content_produced)
+                     case "thread.run.requires_action":
+                         yield from self._handle_action_request(event, content_produced)
+                     case "thread.run.failed":
+                         logger.error(f"Run failed: {event}")
+                         raise StreamingError(f"Assistant run failed: {event}")
+
+             st.toast("Completed", icon="✅")
+             # Return the complete response for storage
+             return "".join(self.complete_response)
+         except Exception as e:
+             logger.error(f"Streaming error: {e}")
+             st.error(f"An error occurred while streaming: {str(e)}")
+             raise
+
+     def _handle_message_delta(self, event: Any, content_produced: bool) -> Generator[Any, None, None]:
+         """Handle message delta events."""
+         content = event.data.delta.content[0]
+         match content.type:
+             case "text":
+                 processed_text = self.text_processor.remove_citations(content.text.value)
+                 self.complete_response.append(processed_text)  # Store the chunk
+                 yield processed_text
+             case "image_file":
+                 image_content = io.BytesIO(self.client.files.content(content.image_file.file_id).read())
+                 yield Image.open(image_content)
+
+     def _handle_action_request(self, event: Any, content_produced: bool) -> Generator[str, None, None]:
+         """Handle action request events."""
+         logger.info(f"[Tool Request] {event}")
+         st.session_state.tool_requests.put(event)
+         if not content_produced:
+             yield "[Processing function call...]"
+
+ # ------------------------
+ # Tool Request Handler
+ # ------------------------
+ class ToolRequestHandler:
+     """Handles tool requests from the assistant."""
+
+     @staticmethod
+     def handle_request(event: Any) -> Tuple[List[Dict[str, str]], str, str]:
+         """Process tool requests and return outputs."""
+         st.toast("Processing function call...", icon="⚙️")
+         tool_outputs = []
+         data = event.data
+
+         for tool_call in data.required_action.submit_tool_outputs.tool_calls:
+             output = ToolRequestHandler._process_tool_call(tool_call)
+             tool_outputs.append(output)
+
+         return tool_outputs, data.thread_id, data.id
+
+     @staticmethod
+     def _process_tool_call(tool_call: Any) -> Dict[str, str]:
+         """Process individual tool calls."""
+         function_args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
+
+         match tool_call.function.name:
+             case "hello_world":
+                 name = function_args.get("name", "anonymous")
+                 output_val = f"Hello, {name}! This was from a local function."
+             case _:
+                 output_val = json.dumps({"status": "error", "message": "Unknown function request."})
+
+         return {"tool_call_id": tool_call.id, "output": output_val}
+
+ # ------------------------
+ # Assistant Manager
+ # ------------------------
+ class AssistantManager:
+     """Manages interactions with the OpenAI Assistant."""
+
+     def __init__(self, client: Any, assistant_id: str):
+         self.client = client
+         self.assistant_id = assistant_id
+         self.stream_handler = StreamHandler(client)
+         self.tool_handler = ToolRequestHandler()
+
+     @traceable
+     def generate_reply(self, user_input: str) -> str:
+         """Generate and stream assistant's reply."""
+         # Ensure thread exists
+         if not st.session_state.thread:
+             st.session_state.thread = self.client.beta.threads.create()
+
+         # Add user message
+         self.client.beta.threads.messages.create(
+             thread_id=st.session_state.thread.id,
+             role="user",
+             content=user_input
+         )
+
+         complete_response = ""
+
+         # Stream initial response
+         with self.client.beta.threads.runs.stream(
+             thread_id=st.session_state.thread.id,
+             assistant_id=self.assistant_id,
+         ) as run_stream:
+             complete_response = self._display_stream(run_stream)
+
+         # Handle any tool requests
+         self._process_tool_requests()
+
+         return complete_response
+
+     def _display_stream(self, run_stream: Any, create_context: bool = True) -> str:
+         """Display streaming content."""
+         st.session_state.run_stream = run_stream
+         if create_context:
+             with st.chat_message("assistant"):
+                 return st.write_stream(self.stream_handler.stream_data)
          else:
-             # Some other 400 error
-             st.error(str(e))
-             return ""
+             return st.write_stream(self.stream_handler.stream_data)
+
+     def _process_tool_requests(self) -> None:
+         """Process any pending tool requests."""
+         while not st.session_state.tool_requests.empty():
+             event = st.session_state.tool_requests.get()
+             tool_outputs, thread_id, run_id = self.tool_handler.handle_request(event)
+
+             with self.client.beta.threads.runs.submit_tool_outputs_stream(
+                 thread_id=thread_id,
+                 run_id=run_id,
+                 tool_outputs=tool_outputs
+             ) as next_stream:
+                 self._display_stream(next_stream, create_context=False)
+
+ # ------------------------
+ # Main Application
+ # ------------------------
+ class ChatApplication:
+     """Main chat application class."""
+
+     def __init__(self):
+         self.config = AppConfig()
+         api_key, assistant_id = EnvironmentManager.load_environment()
+
+         # Initialize OpenAI client
+         openai_client = openai.Client(api_key=api_key)
+         self.client = wrap_openai(openai_client)
+
+         # Initialize components
+         self.state_manager = StateManager()
+         self.assistant_manager = AssistantManager(self.client, assistant_id)
+
+     def setup_page(self) -> None:
+         """Configure the Streamlit page."""
+         st.set_page_config(
+             page_title=self.config.page_title,
+             page_icon=self.config.page_icon,
+             layout=self.config.layout
+         )
+         st.title(self.config.page_title)
+
+     def display_chat_history(self) -> None:
+         """Display the chat history."""
+         for msg in st.session_state.messages:
+             with st.chat_message(msg.role):
+                 st.write(msg.content)
+
+     def run(self) -> None:
+         """Run the chat application."""
+         self.setup_page()
+         self.state_manager.initialize_state()
+         self.display_chat_history()
+
+         user_input = st.chat_input("Type your message here...")
+         if user_input:
+             # Display and store user message
+             with st.chat_message("user"):
+                 st.write(user_input)
+             self.state_manager.add_message("user", user_input)
+
+             # Generate and display assistant reply
+             try:
+                 complete_response = self.assistant_manager.generate_reply(user_input)
+                 self.state_manager.add_message(
+                     "assistant",
+                     complete_response
+                 )
+             except Exception as e:
+                 st.error(f"Error generating response: {str(e)}")
+                 logger.exception("Error in assistant reply generation")
+
+ def main():
+     """Application entry point."""
+     try:
+         app = ChatApplication()
+         app.run()
      except Exception as e:
-         st.error(str(e))
-         return ""
-
- # Display any existing messages
- for msg in st.session_state["messages"]:
-     if msg["role"] == "user":
-         with st.chat_message("user"):
-             st.write(msg["content"])
-     else:
-         with st.chat_message("assistant"):
-             st.write(msg["content"])
-
- # Chat input at the bottom of the page
- user_input = st.chat_input("Type your message here...")
-
- # Process the user input only if:
- # 1) There is some text, and
- # 2) We are not already handling a request (is_in_request == False)
- if user_input and not st.session_state["is_in_request"]:
-     # Lock to prevent duplicate requests
-     st.session_state["is_in_request"] = True
-
-     # Add the user message to session state
-     st.session_state["messages"].append({"role": "user", "content": user_input})
-
-     # Display the user's message
-     with st.chat_message("user"):
-         st.write(user_input)
-
-     # Get assistant response
-     response_text = predict(user_input)
-
-     # Add assistant response to session state
-     st.session_state["messages"].append({"role": "assistant", "content": response_text})
-
-     # Display assistant response
-     with st.chat_message("assistant"):
-         st.write(response_text)
-
-     # Release the lock
-     st.session_state["is_in_request"] = False
+         st.error(f"Application error: {str(e)}")
+         logger.exception("Fatal application error")
+
+ if __name__ == "__main__":
+     main()
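
For reference, the streaming loop and tool-output round trip that StreamHandler, ToolRequestHandler, and AssistantManager implement above can be exercised outside Streamlit. The following is a minimal console sketch under the same assumptions the commit makes (openai>=1.x Assistants beta API, an assistant that exposes a hello_world function tool); the assistant ID and the prompt are placeholders, not values from this repository.

import json
import os

import openai

client = openai.Client(api_key=os.environ["OPENAI_API_KEY"])
ASSISTANT_ID = "asst_..."  # placeholder: your assistant ID

thread = client.beta.threads.create()
client.beta.threads.messages.create(
    thread_id=thread.id, role="user", content="Say hello to Ada."
)

def drain(stream):
    """Print text deltas as they arrive; return a requires_action event, if any."""
    pending = None
    for event in stream:
        if event.event == "thread.message.delta":
            part = event.data.delta.content[0]
            if part.type == "text":
                print(part.text.value, end="", flush=True)
        elif event.event == "thread.run.requires_action":
            pending = event
    return pending

with client.beta.threads.runs.stream(
    thread_id=thread.id, assistant_id=ASSISTANT_ID
) as stream:
    action = drain(stream)

while action is not None:
    outputs = []
    for call in action.data.required_action.submit_tool_outputs.tool_calls:
        args = json.loads(call.function.arguments or "{}")
        # Local stand-in for the commit's hello_world tool
        outputs.append({
            "tool_call_id": call.id,
            "output": f"Hello, {args.get('name', 'anonymous')}!",
        })
    # Submit the tool results and stream the continuation of the same run
    with client.beta.threads.runs.submit_tool_outputs_stream(
        thread_id=action.data.thread_id,
        run_id=action.data.id,
        tool_outputs=outputs,
    ) as stream:
        action = drain(stream)
print()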