import os
import re
import streamlit as st
import openai
from dotenv import load_dotenv
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
# Load environment variables
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
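# The .env file is expected to provide both variables read above; the values
# below are placeholder examples following OpenAI's key/ID prefixes:
#   OPENAI_API_KEY=sk-...
#   ASSISTANT_ID_SOLUTION_SPECIFIER_A=asst_...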
# Create the assistant
extractor_llm = OpenAIAssistantRunnable(
    assistant_id=extractor_agent,
    api_key=api_key,
    as_agent=True,
)
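# Note: with as_agent=True, invoke() returns an agent-style finish object
# that carries a thread_id and a return_values dict; predict() below relies
# on both.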
def remove_citation(text: str) -> str:
    """Replace Assistants-style citation markers (e.g. 【12†source】) with a book emoji."""
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "📚", text)
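# For illustration: remove_citation("see 【12†source】") -> "see 📚"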
# Initialize session state
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None
# A flag to indicate whether a request is already in progress
if "is_in_request" not in st.session_state:
    st.session_state["is_in_request"] = False
st.title("Solution Specifier A")
def predict(user_input: str) -> str:
    """
    Call the OpenAIAssistantRunnable to get a response.
    If st.session_state["thread_id"] is None, start a new thread;
    otherwise, continue the existing thread.
    If a concurrency error occurs ("Can't add messages to thread..."),
    reset the thread_id and retry once on a fresh thread.
    """
    try:
        if st.session_state["thread_id"] is None:
            # Start a new thread
            response = extractor_llm.invoke({"content": user_input})
            st.session_state["thread_id"] = response.thread_id
        else:
            # Continue the existing thread
            response = extractor_llm.invoke(
                {"content": user_input, "thread_id": st.session_state["thread_id"]}
            )
        output = response.return_values["output"]
        return remove_citation(output)
    except openai.BadRequestError as e:
        # openai>=1.0 exposes openai.BadRequestError (the openai.error module
        # was removed). On the specific concurrency error, reset the thread
        # and try once more.
        if "while a run" in str(e):
            st.session_state["thread_id"] = None
            try:
                # Create a new thread for the same user input
                response = extractor_llm.invoke({"content": user_input})
                st.session_state["thread_id"] = response.thread_id
                output = response.return_values["output"]
                return remove_citation(output)
            except Exception as e2:
                st.error(f"Error after resetting thread: {e2}")
                return ""
        else:
            # Some other 400 error
            st.error(str(e))
            return ""
    except Exception as e:
        st.error(str(e))
        return ""
# Display any existing messages
for msg in st.session_state["messages"]:
    if msg["role"] == "user":
        with st.chat_message("user"):
            st.write(msg["content"])
    else:
        with st.chat_message("assistant"):
            st.write(msg["content"])
# Chat input at the bottom of the page
user_input = st.chat_input("Type your message here...")
# Process the user input only if:
#   1) there is some text, and
#   2) we are not already handling a request (is_in_request == False)
if user_input and not st.session_state["is_in_request"]:
    # Lock to prevent duplicate requests
    st.session_state["is_in_request"] = True

    # Add the user message to session state
    st.session_state["messages"].append({"role": "user", "content": user_input})

    # Display the user's message
    with st.chat_message("user"):
        st.write(user_input)

    # Get the assistant response
    response_text = predict(user_input)

    # Add the assistant response to session state
    st.session_state["messages"].append({"role": "assistant", "content": response_text})

    # Display the assistant response
    with st.chat_message("assistant"):
        st.write(response_text)

    # Release the lock
    st.session_state["is_in_request"] = False