PeterPinetree committed on
Commit
d021d32
·
verified ·
1 Parent(s): 6898cb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -44
app.py CHANGED
@@ -10,18 +10,18 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
10
 
11
  # Story genres with genre-specific example prompts
12
  GENRE_EXAMPLES = {
 
 
 
 
 
 
 
13
  "fantasy": [
14
  "I enter the ancient forest seeking the wizard's tower",
15
  "I approach the dragon cautiously with my shield raised",
16
  "I examine the mysterious runes carved into the stone altar",
17
  "I try to bargain with the elven council for safe passage"
18
- ],
19
- "fairy tale": [
20
- "I follow the shimmer of fairy dust into a hidden forest",
21
- "I meet a talking rabbit who claims to know a secret about the king’s lost crown",
22
- "A tiny dragon appears at my window, asking for help to find its mother",
23
- "I step into a clearing where the trees whisper ancient riddles",
24
- "A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
25
  ],
26
  "sci-fi": [
27
  "I hack into the space station's mainframe",
@@ -129,8 +129,10 @@ def format_history_for_gradio(history_tuples):
129
  """Convert (user, bot) tuples into Gradio 'messages' format (role/content dicts)."""
130
  messages = []
131
  for user_msg, bot_msg in history_tuples:
132
- messages.append({"role": "user", "content": user_msg})
133
- messages.append({"role": "assistant", "content": bot_msg})
 
 
134
  return messages
135
 
136
  # 1. Add type hints for better code maintainability
@@ -144,45 +146,43 @@ def respond(
144
  """Generate a response based on the current message and conversation history."""
145
  if not message.strip():
146
  return chat_history
147
- if genre and genre not in GENRE_EXAMPLES:
148
- genre = "fantasy" # fallback to default
149
 
150
  system_message = get_enhanced_system_prompt(genre)
151
-
152
- # Convert your existing (user, bot) history into a format for the API request
153
  formatted_history = []
154
- for user_msg, bot_msg in chat_history:
155
- formatted_history.append({"role": "user", "content": user_msg})
156
- formatted_history.append({"role": "assistant", "content": bot_msg})
 
 
 
157
 
158
- api_messages = [{"role": "system", "content": system_message}]
 
 
 
159
 
160
- # Use full memory or partial memory
161
  if use_full_memory and formatted_history:
162
  if len(formatted_history) > MAX_HISTORY_LENGTH:
163
  summary_instruction = create_story_summary(chat_history[:len(chat_history)-5])
164
  if summary_instruction:
165
  api_messages.append(summary_instruction)
166
- for msg in formatted_history[-MEMORY_WINDOW:]:
167
- api_messages.append(msg)
168
  else:
169
- for msg in formatted_history:
170
- api_messages.append(msg)
171
  else:
172
- memory_length = MEMORY_WINDOW
173
- if formatted_history:
174
- for msg in formatted_history[-memory_length*2:]:
175
- api_messages.append(msg)
176
 
177
  # Add current user message
178
- api_messages.append({"role": "user", "content": message})
179
 
180
- # Special handling for story initialization
181
- if not chat_history or message.lower() in ["start", "begin", "begin my adventure"]:
182
- api_messages.append({
183
- "role": "system",
184
- "content": f"Begin a new {genre or 'fantasy'} adventure with an intriguing opening scene. Introduce the protagonist without assuming too much about them."
185
- })
186
 
187
  bot_message = ""
188
  try:
@@ -193,18 +193,17 @@ def respond(
193
  temperature=TEMPERATURE,
194
  top_p=TOP_P,
195
  ):
196
- delta = response_chunk.choices[0].delta.content
197
- if delta:
198
- bot_message += delta
199
- # Yield more frequently
200
- if len(bot_message.strip()) >= MIN_RESPONSE_LENGTH:
201
- new_history = chat_history.copy()
202
- new_history.append((message, bot_message))
203
- yield format_history_for_gradio(new_history)
204
  except Exception as e:
205
  error_message = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
206
- broken_history = chat_history + [(message, error_message)]
207
- yield format_history_for_gradio(broken_history)
208
 
209
  def save_story(chat_history):
210
  """Convert chat history to markdown for download"""
@@ -344,4 +343,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
344
 
345
  # Run the app
346
  if __name__ == "__main__":
347
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
10
 
11
  # Story genres with genre-specific example prompts
12
  GENRE_EXAMPLES = {
13
+ "fairy_tale": [
14
+ "I follow the shimmer of fairy dust into a hidden forest"
15
+ "I meet a talking rabbit who claims to know a secret about the king’s lost crown"
16
+ "A tiny dragon appears at my window, asking for help to find its mother"
17
+ "I step into a clearing where the trees whisper ancient riddles"
18
+ "A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
19
+ ],
20
  "fantasy": [
21
  "I enter the ancient forest seeking the wizard's tower",
22
  "I approach the dragon cautiously with my shield raised",
23
  "I examine the mysterious runes carved into the stone altar",
24
  "I try to bargain with the elven council for safe passage"
 
 
 
 
 
 
 
25
  ],
26
  "sci-fi": [
27
  "I hack into the space station's mainframe",
 
129
  """Convert (user, bot) tuples into Gradio 'messages' format (role/content dicts)."""
130
  messages = []
131
  for user_msg, bot_msg in history_tuples:
132
+ messages.extend([
133
+ {"role": "user", "content": str(user_msg)},
134
+ {"role": "assistant", "content": str(bot_msg)}
135
+ ])
136
  return messages
137
 
138
  # 1. Add type hints for better code maintainability
 
146
  """Generate a response based on the current message and conversation history."""
147
  if not message.strip():
148
  return chat_history
 
 
149
 
150
  system_message = get_enhanced_system_prompt(genre)
151
+
152
+ # Format history properly
153
  formatted_history = []
154
+ if chat_history:
155
+ for user_msg, bot_msg in chat_history:
156
+ formatted_history.extend([
157
+ {"role": "user", "content": str(user_msg)},
158
+ {"role": "assistant", "content": str(bot_msg)}
159
+ ])
160
 
161
+ # Construct API messages
162
+ api_messages = [
163
+ {"role": "system", "content": system_message}
164
+ ]
165
 
166
+ # Add memory management
167
  if use_full_memory and formatted_history:
168
  if len(formatted_history) > MAX_HISTORY_LENGTH:
169
  summary_instruction = create_story_summary(chat_history[:len(chat_history)-5])
170
  if summary_instruction:
171
  api_messages.append(summary_instruction)
172
+ api_messages.extend(formatted_history[-MEMORY_WINDOW*2:])
 
173
  else:
174
+ api_messages.extend(formatted_history)
 
175
  else:
176
+ api_messages.extend(formatted_history[-MEMORY_WINDOW*2:])
 
 
 
177
 
178
  # Add current user message
179
+ api_messages.append({"role": "user", "content": str(message)})
180
 
181
+ # Add choice enforcement
182
+ api_messages.append({
183
+ "role": "system",
184
+ "content": "Remember to end your response with exactly three numbered choices, each starting with 'You'"
185
+ })
 
186
 
187
  bot_message = ""
188
  try:
 
193
  temperature=TEMPERATURE,
194
  top_p=TOP_P,
195
  ):
196
+ if hasattr(response_chunk.choices[0].delta, 'content'):
197
+ delta = response_chunk.choices[0].delta.content
198
+ if delta:
199
+ bot_message += delta
200
+ if len(bot_message.strip()) >= MIN_RESPONSE_LENGTH:
201
+ yield [{"role": "user", "content": message},
202
+ {"role": "assistant", "content": bot_message}]
 
203
  except Exception as e:
204
  error_message = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
205
+ yield [{"role": "user", "content": message},
206
+ {"role": "assistant", "content": error_message}]
207
 
208
  def save_story(chat_history):
209
  """Convert chat history to markdown for download"""
 
343
 
344
  # Run the app
345
  if __name__ == "__main__":
346
+ demo.launch(server_name="0.0.0.0", server_port=7860)