Update app.py
Browse filesContinue button coding.
app.py
CHANGED
@@ -37,6 +37,10 @@ if "system_message" not in st.session_state:
|
|
37 |
if "starter_message" not in st.session_state:
|
38 |
st.session_state.starter_message = "Hello, there! How can I help you today?"
|
39 |
|
|
|
|
|
|
|
|
|
40 |
# Sidebar for settings
|
41 |
with st.sidebar:
|
42 |
st.header("System Settings")
|
@@ -72,7 +76,7 @@ with st.sidebar:
|
|
72 |
if "chat_history" not in st.session_state or reset_history:
|
73 |
st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
|
74 |
|
75 |
-
def get_response(system_message, chat_history, user_text, max_new_tokens=256):
|
76 |
"""
|
77 |
Generates a response from the chatbot model.
|
78 |
|
@@ -81,6 +85,7 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
|
|
81 |
chat_history (list): The list of previous chat messages.
|
82 |
user_text (str): The user's input text.
|
83 |
max_new_tokens (int): The maximum number of new tokens to generate.
|
|
|
84 |
|
85 |
Returns:
|
86 |
tuple: A tuple containing the generated response and the updated chat history.
|
@@ -93,6 +98,12 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
|
|
93 |
prompt += f"\n{role}\n{message['content']}\n"
|
94 |
prompt += f"\n<|user|>\n{user_text}\n<|assistant|>\n"
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
# Generate the response
|
97 |
response_output = generator(
|
98 |
prompt,
|
@@ -108,9 +119,13 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
|
|
108 |
# Extract the assistant's response
|
109 |
assistant_response = generated_text[len(prompt):].strip()
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
114 |
|
115 |
return assistant_response, chat_history
|
116 |
|
@@ -126,6 +141,10 @@ with output_container:
|
|
126 |
continue
|
127 |
with st.chat_message(message['role'], avatar=st.session_state.avatars[message['role']]):
|
128 |
st.markdown(message['content'])
|
|
|
|
|
|
|
|
|
129 |
|
130 |
# User input area (moved to the bottom)
|
131 |
st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
|
@@ -145,8 +164,27 @@ if st.session_state.user_text:
|
|
145 |
user_text=st.session_state.user_text,
|
146 |
chat_history=st.session_state.chat_history,
|
147 |
max_new_tokens=st.session_state.max_response_length,
|
|
|
148 |
)
|
149 |
st.markdown(response)
|
150 |
|
151 |
# Clear the user input
|
152 |
st.session_state.user_text = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
if "starter_message" not in st.session_state:
|
38 |
st.session_state.starter_message = "Hello, there! How can I help you today?"
|
39 |
|
40 |
+
# Initialize session state for continue action
|
41 |
+
if "need_continue" not in st.session_state:
|
42 |
+
st.session_state.need_continue = False
|
43 |
+
|
44 |
# Sidebar for settings
|
45 |
with st.sidebar:
|
46 |
st.header("System Settings")
|
|
|
76 |
if "chat_history" not in st.session_state or reset_history:
|
77 |
st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
|
78 |
|
79 |
+
def get_response(system_message, chat_history, user_text, max_new_tokens=256, continue_last=False):
|
80 |
"""
|
81 |
Generates a response from the chatbot model.
|
82 |
|
|
|
85 |
chat_history (list): The list of previous chat messages.
|
86 |
user_text (str): The user's input text.
|
87 |
max_new_tokens (int): The maximum number of new tokens to generate.
|
88 |
+
continue_last (bool): Whether to continue the last assistant response.
|
89 |
|
90 |
Returns:
|
91 |
tuple: A tuple containing the generated response and the updated chat history.
|
|
|
98 |
prompt += f"\n{role}\n{message['content']}\n"
|
99 |
prompt += f"\n<|user|>\n{user_text}\n<|assistant|>\n"
|
100 |
|
101 |
+
if continue_last:
|
102 |
+
# We want to continue the last assistant response
|
103 |
+
prompt += "Assistant:"
|
104 |
+
else:
|
105 |
+
prompt += f"User: {user_text}\nAssistant:"
|
106 |
+
|
107 |
# Generate the response
|
108 |
response_output = generator(
|
109 |
prompt,
|
|
|
119 |
# Extract the assistant's response
|
120 |
assistant_response = generated_text[len(prompt):].strip()
|
121 |
|
122 |
+
if continue_last:
|
123 |
+
# Append the continued text to the last assistant message
|
124 |
+
st.session_state.chat_history[-1]['content'] += assistant_response
|
125 |
+
else:
|
126 |
+
# Update the chat history
|
127 |
+
chat_history.append({'role': 'user', 'content': user_text})
|
128 |
+
chat_history.append({'role': 'assistant', 'content': assistant_response})
|
129 |
|
130 |
return assistant_response, chat_history
|
131 |
|
|
|
141 |
continue
|
142 |
with st.chat_message(message['role'], avatar=st.session_state.avatars[message['role']]):
|
143 |
st.markdown(message['content'])
|
144 |
+
# If this is the last assistant message, add the "Continue" button
|
145 |
+
if idx == len(st.session_state.chat_history) - 1 and message['role'] == 'assistant':
|
146 |
+
if st.button("Continue"):
|
147 |
+
st.session_state.need_continue = True
|
148 |
|
149 |
# User input area (moved to the bottom)
|
150 |
st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
|
|
|
164 |
user_text=st.session_state.user_text,
|
165 |
chat_history=st.session_state.chat_history,
|
166 |
max_new_tokens=st.session_state.max_response_length,
|
167 |
+
continue_last=False
|
168 |
)
|
169 |
st.markdown(response)
|
170 |
|
171 |
# Clear the user input
|
172 |
st.session_state.user_text = None
|
173 |
+
|
174 |
+
# If "Continue" button was pressed
|
175 |
+
if st.session_state.get('need_continue', False):
|
176 |
+
# Display a spinner while generating the continuation
|
177 |
+
with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
|
178 |
+
with st.spinner("Continuing..."):
|
179 |
+
# Generate the continuation of the assistant's last response
|
180 |
+
response, st.session_state.chat_history = get_response(
|
181 |
+
system_message=st.session_state.system_message,
|
182 |
+
user_text=None,
|
183 |
+
chat_history=st.session_state.chat_history,
|
184 |
+
max_new_tokens=st.session_state.max_response_length,
|
185 |
+
continue_last=True
|
186 |
+
)
|
187 |
+
st.markdown(response)
|
188 |
+
|
189 |
+
# Reset the flag
|
190 |
+
st.session_state.need_continue = False
|