File size: 3,011 Bytes
73d49e1
b8736af
73d49e1
 
24b6129
b02dba2
 
 
 
24b6129
713dd57
24b6129
 
 
 
 
 
73d49e1
24b6129
 
 
73d49e1
7e727b6
b8736af
 
f6d4f89
 
7e727b6
 
b02dba2
24b6129
73d49e1
24b6129
f6d4f89
24b6129
 
 
b02dba2
24b6129
 
 
f6d4f89
24b6129
 
f6d4f89
24b6129
 
 
 
 
 
 
 
 
 
 
 
 
7e727b6
 
 
 
 
24b6129
 
7e727b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24b6129
7e727b6
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import re
import streamlit as st
from dotenv import load_dotenv
from langchain.agents.openai_assistant import OpenAIAssistantRunnable

# Pull settings from a local .env file (when present) into the process
# environment, then read the OpenAI key and the assistant identifier.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
extractor_agent = os.environ.get("ASSISTANT_ID_SOLUTION_SPECIFIER_A")

# LangChain wrapper around the pre-configured OpenAI assistant. With
# as_agent=True, predict() reads `.thread_id` and `.return_values` from the
# result (NOTE(review): agent-style finish object — confirm for tool calls).
extractor_llm = OpenAIAssistantRunnable(
    as_agent=True,
    api_key=api_key,
    assistant_id=extractor_agent,
)

# Inline citation markers in assistant output: an opening "【", a numeric
# index (optionally with further ":"-separated numbers, e.g. "4:0"), a
# dagger "†", a source label up to the closing "】". This is a superset of
# the plain "【13†source】" form; NOTE(review): the "【4:0†file.pdf】" form is
# the Assistants v2 file-search style — confirm against the OpenAI docs.
# Compiled once because it is applied to every assistant reply.
_CITATION_PATTERN = re.compile(r"【\d+(?::\d+)*†[^】]*】")

# Placeholder left where each citation was: U+1F4DA BOOKS ("📚"). Written as
# an escape so the file's encoding cannot garble it again.
_CITATION_MARKER = "\U0001F4DA"


def remove_citation(text: str) -> str:
    """Replace inline citation markers in *text* with a books emoji.

    Args:
        text: Assistant output that may contain markers such as
            "【13†source】" or "【4:0†report.pdf】".

    Returns:
        *text* with every citation marker replaced by "📚"; text without
        markers is returned unchanged.
    """
    return _CITATION_PATTERN.sub(_CITATION_MARKER, text)

# Seed per-session defaults. Streamlit re-executes this script on every
# interaction, so a key is only set when it is not already present:
#   messages      - chat history as {"role", "content"} dicts
#   thread_id     - assistant thread id, None until the first reply
#   is_processing - True while a reply is being generated
for _state_key, _initial_value in (
    ("messages", []),
    ("thread_id", None),
    ("is_processing", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

st.title("Solution Specifier A")

def predict(user_input: str) -> str:
    """Send *user_input* to the assistant and return its cleaned reply.

    The conversation thread id lives in ``st.session_state["thread_id"]``.
    While it is still ``None`` the request carries no thread id (so a new
    thread is started) and the id from the response is stored for later
    turns; otherwise the message is posted to the remembered thread.

    Args:
        user_input: The user's chat message.

    Returns:
        The assistant's output text with citation markers replaced by
        ``remove_citation``.
    """
    known_thread = st.session_state["thread_id"]

    request = {"content": user_input}
    if known_thread is not None:
        request["thread_id"] = known_thread

    reply = extractor_llm.invoke(request)

    if known_thread is None:
        # First turn: remember the thread the assistant just created.
        st.session_state["thread_id"] = reply.thread_id

    return remove_citation(reply.return_values["output"])

def _rerun() -> None:
    """Abort the current script run and start a fresh one.

    Uses ``st.rerun`` when available and falls back to the older
    ``st.experimental_rerun`` name on earlier Streamlit releases.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


# Redraw the whole conversation: Streamlit re-executes this script from the
# top on every interaction, so the history lives only in session state.
for msg in st.session_state["messages"]:
    role = "user" if msg["role"] == "user" else "assistant"
    with st.chat_message(role):
        st.write(msg["content"])

# Show (once) an exception captured by the previous run's reply phase.
pending_error = st.session_state.pop("last_error", None)
if pending_error is not None:
    st.exception(pending_error)

# The chat box is drawn disabled while a reply is being generated. This only
# works because is_processing is set *before* a rerun (phase 1 below);
# setting it later in the same run would be too late, as the widget has
# already been rendered by then.
user_input = st.chat_input(
    "Type your message here...",
    disabled=st.session_state["is_processing"]
)

if user_input and not st.session_state["is_processing"]:
    # Phase 1: store the user's message, lock the input and rerun so the page
    # is redrawn with the message in history and the chat box disabled.
    st.session_state["messages"].append({"role": "user", "content": user_input})
    st.session_state["is_processing"] = True
    _rerun()
elif st.session_state["is_processing"]:
    # Phase 2: the input is disabled now; answer the message stored in phase 1.
    try:
        history = st.session_state["messages"]
        if history and history[-1]["role"] == "user":
            with st.chat_message("assistant"):
                response_text = predict(history[-1]["content"])
                st.write(response_text)
            history.append({"role": "assistant", "content": response_text})
    except Exception as err:  # UI boundary: keep it for display after the rerun
        st.session_state["last_error"] = err
    finally:
        # Always unlock; otherwise the disabled chat box could never be used again.
        st.session_state["is_processing"] = False
    # Rerun so the chat box is drawn enabled again (and any error is shown).
    _rerun()