schuler committed
Commit 0539125 · verified · 1 Parent(s): 233c632

Update app.py

Implement the "Continue" button logic.

Files changed (1)
  1. app.py +42 -4
app.py CHANGED
@@ -37,6 +37,10 @@ if "system_message" not in st.session_state:
 if "starter_message" not in st.session_state:
     st.session_state.starter_message = "Hello, there! How can I help you today?"
 
+# Initialize session state for continue action
+if "need_continue" not in st.session_state:
+    st.session_state.need_continue = False
+
 # Sidebar for settings
 with st.sidebar:
     st.header("System Settings")
@@ -72,7 +76,7 @@ with st.sidebar:
 if "chat_history" not in st.session_state or reset_history:
     st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
 
-def get_response(system_message, chat_history, user_text, max_new_tokens=256):
+def get_response(system_message, chat_history, user_text, max_new_tokens=256, continue_last=False):
     """
     Generates a response from the chatbot model.
 
@@ -81,6 +85,7 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
         chat_history (list): The list of previous chat messages.
         user_text (str): The user's input text.
         max_new_tokens (int): The maximum number of new tokens to generate.
+        continue_last (bool): Whether to continue the last assistant response.
 
     Returns:
         tuple: A tuple containing the generated response and the updated chat history.
@@ -93,6 +98,12 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
         prompt += f"\n{role}\n{message['content']}\n"
     prompt += f"\n<|user|>\n{user_text}\n<|assistant|>\n"
 
+    if continue_last:
+        # We want to continue the last assistant response
+        prompt += "Assistant:"
+    else:
+        prompt += f"User: {user_text}\nAssistant:"
+
     # Generate the response
     response_output = generator(
         prompt,
@@ -108,9 +119,13 @@ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
     # Extract the assistant's response
     assistant_response = generated_text[len(prompt):].strip()
 
-    # Update the chat history
-    chat_history.append({'role': 'user', 'content': user_text})
-    chat_history.append({'role': 'assistant', 'content': assistant_response})
+    if continue_last:
+        # Append the continued text to the last assistant message
+        st.session_state.chat_history[-1]['content'] += assistant_response
+    else:
+        # Update the chat history
+        chat_history.append({'role': 'user', 'content': user_text})
+        chat_history.append({'role': 'assistant', 'content': assistant_response})
 
     return assistant_response, chat_history
 
@@ -126,6 +141,10 @@ with output_container:
             continue
         with st.chat_message(message['role'], avatar=st.session_state.avatars[message['role']]):
            st.markdown(message['content'])
+        # If this is the last assistant message, add the "Continue" button
+        if idx == len(st.session_state.chat_history) - 1 and message['role'] == 'assistant':
+            if st.button("Continue"):
+                st.session_state.need_continue = True
 
 # User input area (moved to the bottom)
 st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
@@ -145,8 +164,27 @@ if st.session_state.user_text:
                 user_text=st.session_state.user_text,
                 chat_history=st.session_state.chat_history,
                 max_new_tokens=st.session_state.max_response_length,
+                continue_last=False
             )
             st.markdown(response)
 
     # Clear the user input
     st.session_state.user_text = None
+
+# If "Continue" button was pressed
+if st.session_state.get('need_continue', False):
+    # Display a spinner while generating the continuation
+    with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
+        with st.spinner("Continuing..."):
+            # Generate the continuation of the assistant's last response
+            response, st.session_state.chat_history = get_response(
+                system_message=st.session_state.system_message,
+                user_text=None,
+                chat_history=st.session_state.chat_history,
+                max_new_tokens=st.session_state.max_response_length,
+                continue_last=True
+            )
+            st.markdown(response)
+
+    # Reset the flag
+    st.session_state.need_continue = False
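
For context on how the pieces of this commit fit together, below is a minimal, self-contained sketch (not taken from this repository) of the same pattern: a need_continue flag in st.session_state, a "Continue" button rendered only under the last assistant message, and a branch that extends that message in place. The fake_generate helper and the widget labels are placeholders standing in for app.py's real generator pipeline and prompt formatting.

# Minimal standalone sketch of the "Continue" pattern; fake_generate is a
# placeholder for the real text-generation call used in app.py.
import streamlit as st


def fake_generate(prompt: str) -> str:
    """Stand-in for the model call; returns canned continuation text."""
    return " ...and here is some more generated text."


if "chat_history" not in st.session_state:
    st.session_state.chat_history = [{"role": "assistant", "content": "Hello, there!"}]
if "need_continue" not in st.session_state:
    st.session_state.need_continue = False

# Render the history; offer "Continue" only under the last assistant message.
for idx, message in enumerate(st.session_state.chat_history):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if idx == len(st.session_state.chat_history) - 1 and message["role"] == "assistant":
            if st.button("Continue"):
                st.session_state.need_continue = True

user_text = st.chat_input(placeholder="Enter your text here.")
if user_text:
    # Normal turn: append a new user/assistant pair.
    st.session_state.chat_history.append({"role": "user", "content": user_text})
    st.session_state.chat_history.append(
        {"role": "assistant", "content": fake_generate(user_text)}
    )
    st.rerun()

if st.session_state.need_continue:
    # Continue turn: extend the last assistant message in place
    # instead of appending a new turn, then reset the flag.
    st.session_state.chat_history[-1]["content"] += fake_generate("")
    st.session_state.need_continue = False
    st.rerun()

This sketch calls st.rerun() after mutating the history so the extended message immediately replaces the stale one in the rendered transcript; the committed code instead renders the continuation inside a spinner before clearing the flag.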