import os
import re
import streamlit as st
import openai
from dotenv import load_dotenv
from langchain.agents.openai_assistant import OpenAIAssistantRunnable

# Load environment variables
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
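# Both variables are expected in the .env file, e.g.:
#   OPENAI_API_KEY=sk-...
#   ASSISTANT_ID_SOLUTION_SPECIFIER_A=asst_...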

# Create the assistant
extractor_llm = OpenAIAssistantRunnable(
    assistant_id=extractor_agent,
    api_key=api_key,
    as_agent=True
)
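# With as_agent=True, invoke() returns an agent-style finish object: the reply
# text lives in response.return_values["output"], and the conversation thread
# is exposed as response.thread_id (both are used in predict() below).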

def remove_citation(text: str) -> str:
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "πŸ“š", text)
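# Example: remove_citation("See 【0†source】 for details") -> "See 📚 for details"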

# Initialize session state
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None
# A flag to indicate if a request is in progress
if "is_in_request" not in st.session_state:
    st.session_state["is_in_request"] = False
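# Streamlit re-runs this entire script on every interaction; st.session_state is
# the only state that persists across those reruns, which is why the messages,
# thread_id, and the is_in_request lock all live there.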

st.title("Solution Specifier A")

def predict(user_input: str) -> str:
    """
    This function calls our OpenAIAssistantRunnable to get a response.
    If st.session_state["thread_id"] is None, we start a new thread.
    Otherwise, we continue the existing thread.

    If a concurrency error occurs ("Can't add messages to thread..."), we reset
    the thread_id and try again once on a fresh thread.
    """
    try:
        if st.session_state["thread_id"] is None:
            # Start a new thread
            response = extractor_llm.invoke({"content": user_input})
            st.session_state["thread_id"] = response.thread_id
        else:
            # Continue existing thread
            response = extractor_llm.invoke(
                {"content": user_input, "thread_id": st.session_state["thread_id"]}
            )

        output = response.return_values["output"]
        return remove_citation(output)

    except openai.BadRequestError as e:  # openai>=1.x style; the legacy openai.error module no longer exists
        # If we get the specific concurrency error, reset the thread and try once more
        if "while a run" in str(e):
            st.session_state["thread_id"] = None
            # Now create a new thread for the same user input
            try:
                response = extractor_llm.invoke({"content": user_input})
                st.session_state["thread_id"] = response.thread_id
                output = response.return_values["output"]
                return remove_citation(output)
            except Exception as e2:
                st.error(f"Error after resetting thread: {e2}")
                return ""
        else:
            # Some other 400 error
            st.error(str(e))
            return ""
    except Exception as e:
        st.error(str(e))
        return ""

# Display any existing messages
for msg in st.session_state["messages"]:
    if msg["role"] == "user":
        with st.chat_message("user"):
            st.write(msg["content"])
    else:
        with st.chat_message("assistant"):
            st.write(msg["content"])

# Chat input at the bottom of the page
user_input = st.chat_input("Type your message here...")

# Process the user input only if:
# 1) There is some text, and
# 2) We are not already handling a request (is_in_request == False)
if user_input and not st.session_state["is_in_request"]:
    # Lock to prevent duplicate requests
    st.session_state["is_in_request"] = True

    # Add the user message to session state
    st.session_state["messages"].append({"role": "user", "content": user_input})

    # Display the user's message
    with st.chat_message("user"):
        st.write(user_input)

    # Get assistant response
    response_text = predict(user_input)

    # Add assistant response to session state
    st.session_state["messages"].append({"role": "assistant", "content": response_text})

    # Display assistant response
    with st.chat_message("assistant"):
        st.write(response_text)

    # Release the lock
    st.session_state["is_in_request"] = False