PeterPinetree commited on
Commit
c7d0530
·
verified ·
1 Parent(s): cd10677

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -24
app.py CHANGED
@@ -136,46 +136,39 @@ def respond(
136
  chat_history: List[Tuple[str, str]],
137
  genre: Optional[str] = None,
138
  use_full_memory: bool = True
139
- ) -> Generator[List[Tuple[str, str]], None, None]:
140
  """Generate a response based on the current message and conversation history."""
141
  if not message.strip():
142
  return chat_history
143
 
144
- # Format messages for the API
145
- formatted_messages = [{
146
- "role": "system",
147
- "content": get_enhanced_system_prompt(genre)
148
- }]
149
 
150
- # Add chat history
151
  if chat_history and use_full_memory:
152
  for user_msg, bot_msg in chat_history[-MEMORY_WINDOW:]:
153
- formatted_messages.extend([
154
- {"role": "user", "content": str(user_msg)},
155
- {"role": "assistant", "content": str(bot_msg)}
156
- ])
157
-
158
- # Add current message
159
- formatted_messages.append({
160
- "role": "user",
161
- "content": str(message)
162
- })
163
 
164
  try:
 
165
  response = client.chat_completion(
166
- formatted_messages,
167
  max_tokens=MAX_TOKENS,
168
  temperature=TEMPERATURE,
169
  top_p=TOP_P
170
  )
 
 
171
  bot_message = response.choices[0].message.content
172
- new_history = list(chat_history) + [(message, bot_message)]
173
- return new_history
174
 
175
  except Exception as e:
176
  error_message = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
177
- error_history = list(chat_history) + [(message, error_message)]
178
- return error_history
179
 
180
  def save_story(chat_history):
181
  """Convert chat history to markdown for download"""
@@ -268,7 +261,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
268
  # 4) Connect each starter button:
269
  for starter_button in starter_buttons:
270
  starter_button.click(
271
- fn=lambda x: {"role": "user", "content": str(x)}, # Format as message dict
272
  inputs=[starter_button],
273
  outputs=[chatbot],
274
  queue=False
@@ -318,4 +311,3 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
318
  # Run the app
319
  if __name__ == "__main__":
320
  demo.launch(server_name="0.0.0.0", server_port=7860)
321
-
 
136
  chat_history: List[Tuple[str, str]],
137
  genre: Optional[str] = None,
138
  use_full_memory: bool = True
139
+ ) -> List[Tuple[str, str]]: # Changed return type
140
  """Generate a response based on the current message and conversation history."""
141
  if not message.strip():
142
  return chat_history
143
 
144
+ # Ensure formatted_messages is correctly structured
145
+ formatted_messages = [{"role": "system", "content": get_enhanced_system_prompt(genre)}]
 
 
 
146
 
147
+ # Add chat history correctly
148
  if chat_history and use_full_memory:
149
  for user_msg, bot_msg in chat_history[-MEMORY_WINDOW:]:
150
+ formatted_messages.append({"role": "user", "content": str(user_msg)})
151
+ formatted_messages.append({"role": "assistant", "content": str(bot_msg)})
152
+
153
+ # Append user message
154
+ formatted_messages.append({"role": "user", "content": str(message)})
 
 
 
 
 
155
 
156
  try:
157
+ # Make API call with correct message structure
158
  response = client.chat_completion(
159
+ messages=formatted_messages,
160
  max_tokens=MAX_TOKENS,
161
  temperature=TEMPERATURE,
162
  top_p=TOP_P
163
  )
164
+
165
+ # Extract and format bot response
166
  bot_message = response.choices[0].message.content
167
+ return chat_history + [(str(message), str(bot_message))]
 
168
 
169
  except Exception as e:
170
  error_message = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
171
+ return chat_history + [(str(message), str(error_message))]
 
172
 
173
  def save_story(chat_history):
174
  """Convert chat history to markdown for download"""
 
261
  # 4) Connect each starter button:
262
  for starter_button in starter_buttons:
263
  starter_button.click(
264
+ fn=lambda x: [(str(x), "")], # Return initial history tuple
265
  inputs=[starter_button],
266
  outputs=[chatbot],
267
  queue=False
 
311
  # Run the app
312
  if __name__ == "__main__":
313
  demo.launch(server_name="0.0.0.0", server_port=7860)