Spaces:
Sleeping
Sleeping
import os | |
import re | |
import io | |
import time | |
import json | |
import queue | |
import logging | |
from typing import Any, Generator, Optional, List, Dict, Tuple | |
from dataclasses import dataclass | |
import streamlit as st | |
from dotenv import load_dotenv | |
from PIL import Image | |
import openai | |
from langsmith.wrappers import wrap_openai | |
from langsmith import traceable | |
# ------------------------ | |
# Configuration and Types | |
# ------------------------ | |
@dataclass
class AppConfig:
    """Page settings passed to ``st.set_page_config`` and ``st.title``.

    Defaults reproduce the app's fixed configuration; being a dataclass, any
    field can also be overridden by keyword (``AppConfig(layout="wide")``).
    """

    page_title: str = "Solution Specifier A"
    # Browser-tab icon. The previous literal ("π€") was UTF-8 mis-decoded text,
    # not an emoji. NOTE(review): 🤖 is a byte-compatible reading — confirm intent.
    page_icon: str = "🤖"
    layout: str = "centered"  # st.set_page_config layout ("centered" or "wide")
@dataclass
class Message:
    """One chat turn as stored in ``st.session_state.messages``.

    ``@dataclass`` generates the keyword ``__init__`` that
    ``Message(role=..., content=...)`` relies on; without it construction
    raises ``TypeError``.
    """

    role: str  # speaker label handed to st.chat_message ("user" / "assistant")
    content: str  # body rendered with st.write in the chat history
class StreamingError(Exception):
    """Raised when an assistant run stream reports a failure."""
# ------------------------ | |
# Logging Configuration | |
# ------------------------ | |
def setup_logging() -> logging.Logger:
    """Apply a basic root logging configuration and return this module's logger.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so calling this more than once is harmless.
    """
    log_format = "[%(asctime)s] %(levelname)+8s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    return logging.getLogger(__name__)


# Module-wide logger used by the classes below.
logger: logging.Logger = setup_logging()
# ------------------------ | |
# Environment Setup | |
# ------------------------ | |
class EnvironmentManager:
    """Loads and validates the environment variables the app needs."""

    @staticmethod
    def load_environment() -> Tuple[str, str]:
        """Load ``.env`` and return the OpenAI credentials.

        ``@staticmethod`` lets this be called on the class or on an instance.

        Returns:
            ``(api_key, assistant_id)`` taken from ``OPENAI_API_KEY`` and
            ``ASSISTANT_ID_SOLUTION_SPECIFIER_A``.

        Raises:
            RuntimeError: if either variable is unset or empty.
        """
        # override=True: values from .env replace variables already present
        # in the process environment.
        load_dotenv(override=True)
        api_key = os.getenv("OPENAI_API_KEY")
        assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
        if not api_key or not assistant_id:
            raise RuntimeError(
                "Missing required environment variables. Please set "
                "OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A"
            )
        return api_key, assistant_id
# ------------------------ | |
# State Management | |
# ------------------------ | |
class StateManager:
    """Creates and updates the Streamlit session-state keys the app uses.

    Both methods are static: they keep no instance data and act on
    ``st.session_state`` directly. They can therefore be called on the class
    or on an instance; the application calls them through an instance.
    """

    @staticmethod
    def initialize_state() -> None:
        """Create any missing session-state keys with their default values.

        Keys that already exist are left untouched, so conversation state
        persists across Streamlit script reruns.
        """
        # key -> zero-arg factory, so each session gets fresh objects.
        defaults = {
            "messages": list,  # chat history of Message objects
            "thread": lambda: None,  # Assistants thread, created lazily
            "tool_requests": queue.Queue,  # queued requires_action events
            "run_stream": lambda: None,  # stream currently being consumed
        }
        for key, make_default in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = make_default()

    @staticmethod
    def add_message(role: str, content: str) -> None:
        """Append a ``Message`` to ``st.session_state.messages``.

        Args:
            role: speaker label later passed to ``st.chat_message``.
            content: message body.
        """
        st.session_state.messages.append(Message(role=role, content=content))
# ------------------------ | |
# Text Processing | |
# ------------------------ | |
class TextProcessor:
    """Text clean-up helpers for assistant output."""

    # OpenAI file-search citation markers, e.g. "【13†source】" or
    # "【4:0†report.pdf】". The previous pattern was mis-decoded UTF-8
    # ("γ\d+β \w+γ") and could never match real output.
    _CITATION_RE = re.compile(r"【\d+(?::\d+)?†[^】]+】")

    @staticmethod
    def remove_citations(text: str, replacement: str = "📚") -> str:
        """Replace every citation marker in *text* with *replacement*.

        Args:
            text: a (possibly partial) chunk of assistant text; falsy input
                (``""`` or ``None``) is returned unchanged.
            replacement: literal string substituted for each marker; pass
                ``""`` to drop markers entirely. NOTE(review): the original
                replacement glyph was mis-encoded and cannot be recovered
                unambiguously — 📚 is a stand-in, confirm the intended one.

        Returns:
            *text* with each marker replaced.
        """
        if not text:
            return text
        # A function replacement keeps *replacement* literal (no \1 / \g<> expansion).
        return TextProcessor._CITATION_RE.sub(lambda _match: replacement, text)
# ------------------------ | |
# Streaming Handler | |
# ------------------------ | |
class StreamHandler: | |
"""Handles streaming of assistant responses.""" | |
def __init__(self, client: Any): | |
self.client = client | |
self.text_processor = TextProcessor() | |
self.complete_response = [] | |
def stream_data(self) -> Generator[Any, None, None]: | |
"""Stream data from the assistant run.""" | |
st.toast("Thinking...", icon="π€") | |
content_produced = False | |
self.complete_response = [] # Reset for new stream | |
try: | |
for event in st.session_state.run_stream: | |
match event.event: | |
case "thread.message.delta": | |
yield from self._handle_message_delta(event, content_produced) | |
case "thread.run.requires_action": | |
yield from self._handle_action_request(event, content_produced) | |
case "thread.run.failed": | |
logger.error(f"Run failed: {event}") | |
raise StreamingError(f"Assistant run failed: {event}") | |
st.toast("Completed", icon="β ") | |
# Return the complete response for storage | |
return "".join(self.complete_response) | |
except Exception as e: | |
logger.error(f"Streaming error: {e}") | |
st.error(f"An error occurred while streaming: {str(e)}") | |
raise | |
def _handle_message_delta(self, event: Any, content_produced: bool) -> Generator[Any, None, None]: | |
"""Handle message delta events.""" | |
content = event.data.delta.content[0] | |
match content.type: | |
case "text": | |
processed_text = self.text_processor.remove_citations(content.text.value) | |
self.complete_response.append(processed_text) # Store the chunk | |
yield processed_text | |
case "image_file": | |
image_content = io.BytesIO(self.client.files.content(content.image_file.file_id).read()) | |
yield Image.open(image_content) | |
def _handle_action_request(self, event: Any, content_produced: bool) -> Generator[str, None, None]: | |
"""Handle action request events.""" | |
logger.info(f"[Tool Request] {event}") | |
st.session_state.tool_requests.put(event) | |
if not content_produced: | |
yield "[Processing function call...]" | |
# ------------------------ | |
# Tool Request Handler | |
# ------------------------ | |
class ToolRequestHandler: | |
"""Handles tool requests from the assistant.""" | |
def handle_request(event: Any) -> Tuple[List[Dict[str, str]], str, str]: | |
"""Process tool requests and return outputs.""" | |
st.toast("Processing function call...", icon="βοΈ") | |
tool_outputs = [] | |
data = event.data | |
for tool_call in data.required_action.submit_tool_outputs.tool_calls: | |
output = ToolRequestHandler._process_tool_call(tool_call) | |
tool_outputs.append(output) | |
return tool_outputs, data.thread_id, data.id | |
def _process_tool_call(tool_call: Any) -> Dict[str, str]: | |
"""Process individual tool calls.""" | |
function_args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} | |
match tool_call.function.name: | |
case "hello_world": | |
name = function_args.get("name", "anonymous") | |
output_val = f"Hello, {name}! This was from a local function." | |
case _: | |
output_val = json.dumps({"status": "error", "message": "Unknown function request."}) | |
return {"tool_call_id": tool_call.id, "output": output_val} | |
# ------------------------ | |
# Assistant Manager | |
# ------------------------ | |
class AssistantManager:
    """Runs one user turn against the OpenAI Assistant and renders the reply."""

    def __init__(self, client: Any, assistant_id: str):
        """Store the client and assistant id and create the stream/tool helpers.

        Args:
            client: OpenAI client exposing the ``beta.threads`` API.
            assistant_id: id of the Assistant that answers each turn.
        """
        self.client = client
        self.assistant_id = assistant_id
        self.stream_handler = StreamHandler(client)
        self.tool_handler = ToolRequestHandler()

    def generate_reply(self, user_input: str) -> str:
        """Post *user_input* to the session thread, stream the reply, answer tool calls.

        The thread is created on first use. All output for the turn, including
        text streamed after tool outputs are submitted, is rendered inside a
        single assistant chat bubble.

        Returns:
            The concatenated, citation-stripped text of every stream in the
            turn. Images and the tool-call placeholder are not included.
        """
        if not st.session_state.thread:
            st.session_state.thread = self.client.beta.threads.create()
        thread_id = st.session_state.thread.id

        self.client.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=user_input,
        )

        reply_parts: List[str] = []
        # One bubble for the whole turn. Previously, continuation streams after
        # tool calls were rendered outside it and were missing from the stored reply.
        with st.chat_message("assistant"):
            with self.client.beta.threads.runs.stream(
                thread_id=thread_id,
                assistant_id=self.assistant_id,
            ) as run_stream:
                self._display_stream(run_stream, create_context=False)
                reply_parts.append(self._streamed_text())
                self._process_tool_requests(reply_parts)
        return "".join(reply_parts)

    def _streamed_text(self) -> str:
        """Return the text collected by the stream handler's most recent stream."""
        return "".join(self.stream_handler.complete_response)

    def _display_stream(self, run_stream: Any, create_context: bool = True) -> Any:
        """Render *run_stream* with ``st.write_stream``.

        Args:
            run_stream: Assistants stream; published via
                ``st.session_state.run_stream`` for the stream handler.
            create_context: wrap the output in a new assistant chat bubble.

        Returns:
            Whatever ``st.write_stream`` returns for the streamed chunks.
        """
        st.session_state.run_stream = run_stream
        if not create_context:
            return st.write_stream(self.stream_handler.stream_data)
        with st.chat_message("assistant"):
            return st.write_stream(self.stream_handler.stream_data)

    def _process_tool_requests(self, reply_parts: Optional[List[str]] = None) -> None:
        """Answer queued tool calls and stream each continuation.

        Events queued by a continuation stream are handled by the same loop.

        Args:
            reply_parts: when given, each continuation's text is appended to it.
        """
        while not st.session_state.tool_requests.empty():
            event = st.session_state.tool_requests.get()
            tool_outputs, thread_id, run_id = self.tool_handler.handle_request(event)
            with self.client.beta.threads.runs.submit_tool_outputs_stream(
                thread_id=thread_id,
                run_id=run_id,
                tool_outputs=tool_outputs,
            ) as next_stream:
                self._display_stream(next_stream, create_context=False)
            if reply_parts is not None:
                reply_parts.append(self._streamed_text())
# ------------------------ | |
# Main Application | |
# ------------------------ | |
class ChatApplication:
    """Top-level Streamlit chat app wiring configuration, state and the assistant."""

    def __init__(self):
        """Load credentials and build the wrapped OpenAI client and collaborators.

        Raises:
            RuntimeError: propagated from ``EnvironmentManager.load_environment``
                when required environment variables are missing.
        """
        self.config = AppConfig()
        api_key, assistant_id = EnvironmentManager.load_environment()
        # LangSmith's wrapper around the raw OpenAI client.
        self.client = wrap_openai(openai.Client(api_key=api_key))
        self.state_manager = StateManager()
        self.assistant_manager = AssistantManager(self.client, assistant_id)

    def setup_page(self) -> None:
        """Apply page metadata from the config and render the page title."""
        cfg = self.config
        st.set_page_config(page_title=cfg.page_title, page_icon=cfg.page_icon, layout=cfg.layout)
        st.title(cfg.page_title)

    def display_chat_history(self) -> None:
        """Re-render every stored message inside its role's chat bubble."""
        for message in st.session_state.messages:
            with st.chat_message(message.role):
                st.write(message.content)

    def run(self) -> None:
        """Render one script pass: page, history, then handle any new input."""
        self.setup_page()
        self.state_manager.initialize_state()
        self.display_chat_history()

        prompt = st.chat_input("Type your message here...")
        if not prompt:
            return

        with st.chat_message("user"):
            st.write(prompt)
        self.state_manager.add_message("user", prompt)

        try:
            reply = self.assistant_manager.generate_reply(prompt)
            self.state_manager.add_message("assistant", reply)
        except Exception as exc:
            st.error(f"Error generating response: {exc!s}")
            logger.exception("Error in assistant reply generation")
def main() -> None:
    """Build and run the chat app, reporting any failure in the UI and the log."""
    try:
        ChatApplication().run()
    except Exception as exc:
        st.error(f"Application error: {exc!s}")
        logger.exception("Fatal application error")


if __name__ == "__main__":
    main()