import os
import re

import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.agents.openai_assistant import OpenAIAssistantRunnable

# Load environment variables (OPENAI_API_KEY, ASSISTANT_ID_SOLUTION_SPECIFIER_A)
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")

# Create the assistant runnable. as_agent=True makes invoke() return an
# agent-style result whose .return_values dict holds the "output" text.
extractor_llm = OpenAIAssistantRunnable(
    assistant_id=extractor_agent,
    api_key=api_key,
    as_agent=True,
)

# Assistants-API citation markers look like 【12†source】; compiled once
# at module level instead of on every call.
_CITATION_PATTERN = re.compile(r"【\d+†\w+】")


def remove_citation(text: str) -> str:
    """Replace Assistants-API citation markers (e.g. 【12†source】) with 📚."""
    return _CITATION_PATTERN.sub("📚", text)


# --- Session state -----------------------------------------------------------
# messages:      chat transcript as [{"role": ..., "content": ...}, ...]
# thread_id:     OpenAI Assistants thread to continue, or None to start fresh
# is_in_request: lock flag so a single script run cannot fire overlapping calls
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None
if "is_in_request" not in st.session_state:
    st.session_state["is_in_request"] = False

st.title("Solution Specifier A")


def predict(user_input: str) -> str:
    """Send *user_input* to the assistant and return its cleaned reply.

    Starts a new thread when st.session_state["thread_id"] is None,
    otherwise continues the existing thread. If the API rejects the call
    because a run is still active on the thread ("... while a run ..."),
    the thread id is reset and the request is retried once on a fresh
    thread. Any unrecoverable error is reported via st.error and ""
    is returned.
    """
    try:
        if st.session_state["thread_id"] is None:
            # Start a new thread
            response = extractor_llm.invoke({"content": user_input})
            st.session_state["thread_id"] = response.thread_id
        else:
            # Continue the existing thread
            response = extractor_llm.invoke(
                {"content": user_input, "thread_id": st.session_state["thread_id"]}
            )
        output = response.return_values["output"]
        return remove_citation(output)
    # BUG FIX: openai>=1.0 (required by OpenAIAssistantRunnable) removed the
    # openai.error module; the old `openai.error.BadRequestError` attribute
    # access would itself raise AttributeError, so the concurrency-retry
    # branch below could never run. The exception is now openai.BadRequestError.
    except openai.BadRequestError as e:
        if "while a run" in str(e):
            # Concurrency error: drop the busy thread and retry once fresh.
            st.session_state["thread_id"] = None
            try:
                response = extractor_llm.invoke({"content": user_input})
                st.session_state["thread_id"] = response.thread_id
                output = response.return_values["output"]
                return remove_citation(output)
            except Exception as e2:
                st.error(f"Error after resetting thread: {e2}")
                return ""
        # Some other 400 error
        st.error(str(e))
        return ""
    except Exception as e:
        st.error(str(e))
        return ""


# Replay the transcript: Streamlit reruns the whole script on every
# interaction, so prior messages must be redrawn from session state.
for msg in st.session_state["messages"]:
    with st.chat_message(msg["role"]):
        st.write(msg["content"])

# Chat input at the bottom of the page
user_input = st.chat_input("Type your message here...")

# Process the user input only if:
# 1) There is some text, and
# 2) We are not already handling a request (is_in_request == False)
if user_input and not st.session_state["is_in_request"]:
    # Lock to prevent duplicate requests
    st.session_state["is_in_request"] = True

    # Record and display the user's message
    st.session_state["messages"].append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    # Get assistant response
    response_text = predict(user_input)

    # Record the assistant response
    st.session_state["messages"].append(
        {"role": "assistant", "content": response_text}
    )
    # BUG FIX: the flattened source lost the `with` keyword into the trailing
    # comment "# Display assistant response with", leaving invalid syntax;
    # reconstructed as the intended context-manager block.
    with st.chat_message("assistant"):
        st.write(response_text)

    # Release the lock
    st.session_state["is_in_request"] = False