import os
import re

import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.agents.openai_assistant import OpenAIAssistantRunnable

# Load environment variables
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
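
# The .env file is expected to provide both variables read above
# (placeholder values shown; OpenAI assistant IDs start with "asst_"):
#   OPENAI_API_KEY=sk-...
#   ASSISTANT_ID_SOLUTION_SPECIFIER_A=asst_...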

# Create the assistant
extractor_llm = OpenAIAssistantRunnable(
    assistant_id=extractor_agent,
    api_key=api_key,
    as_agent=True,
)


def remove_citation(text: str) -> str:
    """Replace citation markers such as 【12†source】 with a 📖 placeholder."""
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "📖", text)
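
# Quick sanity check (the marker format assumes the Assistants API's
# file-search citation annotations; adjust the pattern if yours differ):
#   remove_citation("See the spec 【12†source】 for details.")
#   -> "See the spec 📖 for details."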

# Initialize session state
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None
# A flag to indicate whether a request is already in progress
if "is_in_request" not in st.session_state:
    st.session_state["is_in_request"] = False

st.title("Solution Specifier A")


def predict(user_input: str) -> str:
    """
    Call the OpenAIAssistantRunnable to get a response.

    If st.session_state["thread_id"] is None, start a new thread;
    otherwise, continue the existing thread. If a concurrency error
    occurs ("Can't add messages to thread ... while a run ... is
    active"), reset the thread_id and retry once on a fresh thread.
    """
    try:
        if st.session_state["thread_id"] is None:
            # Start a new thread
            response = extractor_llm.invoke({"content": user_input})
            st.session_state["thread_id"] = response.thread_id
        else:
            # Continue the existing thread
            response = extractor_llm.invoke(
                {"content": user_input, "thread_id": st.session_state["thread_id"]}
            )
        output = response.return_values["output"]
        return remove_citation(output)
    # openai>=1.0 (required by OpenAIAssistantRunnable) exposes this error
    # at the top level; the old openai.error module no longer exists.
    except openai.BadRequestError as e:
        # If we get the specific concurrency error, reset the thread and retry once
        if "while a run" in str(e):
            st.session_state["thread_id"] = None
            # Now create a new thread for the same user input
            try:
                response = extractor_llm.invoke({"content": user_input})
                st.session_state["thread_id"] = response.thread_id
                output = response.return_values["output"]
                return remove_citation(output)
            except Exception as e2:
                st.error(f"Error after resetting thread: {e2}")
                return ""
        else:
            # Some other 400 error
            st.error(str(e))
            return ""
    except Exception as e:
        st.error(str(e))
        return ""

# Display any existing messages (role is either "user" or "assistant")
for msg in st.session_state["messages"]:
    with st.chat_message(msg["role"]):
        st.write(msg["content"])

# Chat input at the bottom of the page
user_input = st.chat_input("Type your message here...")

# Process the user input only if:
# 1) There is some text, and
# 2) We are not already handling a request (is_in_request == False)
if user_input and not st.session_state["is_in_request"]:
    # Lock to prevent duplicate requests
    st.session_state["is_in_request"] = True

    # Add the user message to session state
    st.session_state["messages"].append({"role": "user", "content": user_input})

    # Display the user's message
    with st.chat_message("user"):
        st.write(user_input)

    # Get the assistant response
    response_text = predict(user_input)

    # Add the assistant response to session state
    st.session_state["messages"].append({"role": "assistant", "content": response_text})

    # Display the assistant response
    with st.chat_message("assistant"):
        st.write(response_text)

    # Release the lock
    st.session_state["is_in_request"] = False