File size: 2,746 Bytes
73d49e1
b8736af
73d49e1
 
7c11bd6
 
73d49e1
b8736af
73d49e1
b8736af
7c11bd6
73d49e1
7c11bd6
b8736af
 
 
73d49e1
b8736af
 
 
 
 
73d49e1
b8736af
73d49e1
7c11bd6
 
 
73d49e1
7c11bd6
 
 
73d49e1
7c11bd6
 
 
 
 
 
 
1de6e3a
7c11bd6
 
 
b8736af
7c11bd6
 
 
 
 
 
 
 
 
 
b8736af
 
 
 
1de6e3a
b8736af
 
 
 
1de6e3a
b8736af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import re
import streamlit as st
from dotenv import load_dotenv
import openai
from langsmith import traceable

# Pull settings from a local .env file (if present) into the process
# environment, then hand the OpenAI key to the client library.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
openai.api_key = api_key

# Helper function to remove citations
def remove_citation(text: str) -> str:
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "πŸ“š", text)

# Session state survives Streamlit's top-to-bottom reruns. "messages" holds the
# chat transcript; "thread_id" holds the id captured from the first API
# response (None until the first reply arrives). Only missing keys are seeded,
# so existing values are never overwritten on rerun.
for _key, _default in (("messages", []), ("thread_id", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("Solution Specifier A")

# Traceable function for predict logic
@traceable
def get_response(user_input: str, thread_id: str = None, *, history=None):
    """Send one user turn to the OpenAI Chat Completions API and return the reply.

    The Chat Completions endpoint is stateless: the model only sees the
    messages included in this request. Earlier turns must therefore be passed
    explicitly via ``history``. ``thread_id`` alone does not provide context.

    Args:
        user_input: Text of the new user message.
        thread_id: Optional identifier. When truthy it is forwarded as the
            request's ``user`` field; it does not carry conversation state.
        history: Optional earlier messages, each a mapping with ``"role"`` and
            ``"content"`` keys. They are sent, in order, ahead of
            ``user_input``. Defaults to no history (a single-turn request).

    Returns:
        A ``(reply_text, response_id)`` tuple: the content of the first choice
        and the response's ``id`` field.
    """
    # Copy only the keys the API expects so extra UI bookkeeping never leaks
    # into the request.
    messages = [
        {"role": item["role"], "content": item["content"]}
        for item in (history or [])
    ]
    messages.append({"role": "user", "content": user_input})

    request = {"model": "gpt-3.5-turbo", "messages": messages}
    if thread_id:
        request["user"] = thread_id
    response = openai.ChatCompletion.create(**request)
    return response["choices"][0]["message"]["content"], response["id"]

# Streamlit app logic
def predict(user_input: str) -> str:
    """Return the assistant's reply to ``user_input`` with citation markers replaced.

    Prior turns stored in ``st.session_state["messages"]`` are sent as context,
    so the model can follow the conversation. The UI code appends the current
    user message to session state *before* calling this function, so a
    trailing entry identical to this turn is dropped to avoid sending it twice.

    On the first call (``thread_id`` still None), the response id is stored in
    ``st.session_state["thread_id"]``. Later calls forward it to
    ``get_response``.
    """
    history = list(st.session_state["messages"])
    if history and history[-1] == {"role": "user", "content": user_input}:
        history.pop()

    thread_id = st.session_state["thread_id"]
    response_text, response_id = get_response(
        user_input, thread_id=thread_id, history=history
    )
    if thread_id is None:
        st.session_state["thread_id"] = response_id
    return remove_citation(response_text)

# Streamlit reruns the whole script on every interaction, so the stored
# transcript is redrawn first. Any role other than "user" is shown as the
# assistant.
for msg in st.session_state["messages"]:
    with st.chat_message("user" if msg["role"] == "user" else "assistant"):
        st.write(msg["content"])

# Chat box pinned to the bottom of the page; yields the submitted text, or a
# falsy value when nothing was submitted on this run.
user_input = st.chat_input("Type your message here...")

if user_input:
    # Record and echo the user's turn before asking the model for a reply.
    st.session_state["messages"].append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    response_text = predict(user_input)

    # Persist the reply so it is redrawn on the next rerun, then show it now.
    st.session_state["messages"].append(
        {"role": "assistant", "content": response_text}
    )
    with st.chat_message("assistant"):
        st.write(response_text)