PeterPinetree committed
Commit 3552c2f · verified · 1 Parent(s): 164a4d7

Update app.py

Files changed (1)
  1. app.py +87 -94
app.py CHANGED
@@ -10,11 +10,11 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
 
 # Story genres with genre-specific example prompts
 GENRE_EXAMPLES = {
-    "fairy tale": [
-        "I follow the shimmer of fairy dust into a hidden forest",
-        "I meet a talking rabbit who claims to know a secret about the king's lost crown",
-        "A tiny dragon appears at my window, asking for help to find its mother",
-        "I step into a clearing where the trees whisper ancient riddles",
+    "fairy_tale": [
+        "I follow the shimmer of fairy dust into a hidden forest"
+        "I meet a talking rabbit who claims to know a secret about the kings lost crown"
+        "A tiny dragon appears at my window, asking for help to find its mother"
+        "I step into a clearing where the trees whisper ancient riddles"
         "A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
     ],
     "fantasy": [
@@ -73,7 +73,7 @@ GENRE_EXAMPLES = {
     ]
 }
 
-# Constants
+# 2. Add constants at the top for magic numbers
 MAX_HISTORY_LENGTH = 20
 MEMORY_WINDOW = 5  # Reduced from 10 to limit context
 MAX_TOKENS = 1024  # Reduced from 2048 for faster responses
@@ -110,81 +110,66 @@ IMPORTANT:
 Keep the story cohesive by referencing previous events and choices."""
     return system_message
 
+def create_story_summary(chat_history):
+    """Create a concise summary of the story so far if the history gets too long"""
+    if len(chat_history) <= 2:
+        return None
+
+    story_text = ""
+    for user_msg, bot_msg in chat_history:
+        story_text += f"User: {user_msg}\nStory: {bot_msg}\n\n"
+
+    summary_instruction = {
+        "role": "system",
+        "content": "The conversation history is getting long. Please create a brief summary of the key plot points and character development so far to help maintain context without exceeding token limits."
+    }
+    return summary_instruction
+
+def format_history_for_gradio(history_tuples):
+    """Convert chat history to Gradio's message format."""
+    return [(str(user_msg), str(bot_msg)) for user_msg, bot_msg in history_tuples]
+
+# 1. Add type hints for better code maintainability
+# 4. Add input validation
 def respond(
     message: str,
     chat_history: List[Tuple[str, str]],
     genre: Optional[str] = None,
     use_full_memory: bool = True
-) -> List[Tuple[str, str]]:
+) -> List[Dict[str, str]]:  # Changed return type
     """Generate a response based on the current message and conversation history."""
     if not message.strip():
-        return chat_history
+        return [{"role": "assistant", "content": "Please provide a message"}]
 
     try:
-        # Create a copy of chat history to avoid modifying the original
-        new_history = list(chat_history)
-
-        # Construct messages list for API call
-        messages = []
-
-        # Always start with system message
-        messages.append({
-            "role": "system",
-            "content": get_enhanced_system_prompt(genre)
-        })
+        # Format messages for API
+        api_messages = [{"role": "system", "content": get_enhanced_system_prompt(genre)}]
 
-        # Add relevant chat history if enabled
-        if use_full_memory and new_history:
-            history_to_include = new_history[-MEMORY_WINDOW:] if len(new_history) > MEMORY_WINDOW else new_history
-            for user_msg, bot_msg in history_to_include:
-                if user_msg:  # Only add if user message exists
-                    messages.append({
-                        "role": "user",
-                        "content": str(user_msg)
-                    })
-                if bot_msg:  # Only add if bot message exists
-                    messages.append({
-                        "role": "assistant",
-                        "content": str(bot_msg)
-                    })
+        # Add chat history
+        if chat_history and use_full_memory:
+            for user_msg, bot_msg in chat_history[-MEMORY_WINDOW:]:
+                api_messages.extend([
+                    {"role": "user", "content": str(user_msg)},
+                    {"role": "assistant", "content": str(bot_msg)}
+                ])
 
-        # Add current user message
-        messages.append({
-            "role": "user",
-            "content": str(message)
-        })
-
-        # Make API call with strict dictionary format
-        response = client.post(
-            model="HuggingFaceH4/zephyr-7b-beta",
-            json={
-                "inputs": messages,
-                "parameters": {
-                    "max_new_tokens": MAX_TOKENS,
-                    "temperature": TEMPERATURE,
-                    "top_p": TOP_P
-                }
-            }
+        # Add current message
+        api_messages.append({"role": "user", "content": str(message)})
+
+        # Make API call
+        response = client.chat_completion(
+            messages=api_messages,
+            max_tokens=MAX_TOKENS,
+            temperature=TEMPERATURE,
+            top_p=TOP_P
         )
 
-        # Extract response text
-        if isinstance(response, dict) and "generated_text" in response:
-            bot_message = response["generated_text"]
-        elif isinstance(response, list) and response and "generated_text" in response[0]:
-            bot_message = response[0]["generated_text"]
-        else:
-            # Fallback in case the response structure is different
-            bot_message = str(response)
-
-        # Add the new exchange to chat history
-        new_history.append((str(message), str(bot_message)))
-        return new_history
+        # Return properly formatted message
+        bot_message = response.choices[0].message.content
+        return [{"role": "assistant", "content": str(bot_message)}]
 
     except Exception as e:
-        # Print the full error message for debugging
-        print(f"Error: {str(e)}")
-        # Return existing history plus error message
-        return chat_history + [(str(message), f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})")]
+        return [{"role": "assistant", "content": f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"}]
 
 def save_story(chat_history):
     """Convert chat history to markdown for download"""
@@ -198,7 +183,6 @@ def save_story(chat_history):
 
     return story_text
 
-# Define the Gradio interface
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🔮 Interactive Story Time")
     with gr.Row():
@@ -209,15 +193,16 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Column(scale=3):
             # Chat window + user input
             chatbot = gr.Chatbot(
-                height=500,
-                bubble_full_width=True,
+                height=500,  # Increased height
+                bubble_full_width=True,  # Allow bubbles to use full width
                 show_copy_button=True,
                 avatar_images=(None, "🧙"),
                 type="messages",
                 container=True,
                 scale=1,
-                min_width=800,
-                value=[]
+                min_width=800,  # Ensure minimum width
+                value=[],  # Initialize with empty list
+                render=True
             )
             msg = gr.Textbox(
                 placeholder="Describe what you want to do next in the story...",
@@ -252,42 +237,50 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     starter_btn4 = gr.Button("Starter 4")
     starter_buttons = [starter_btn1, starter_btn2, starter_btn3, starter_btn4]
 
+    # 1) We'll return a list of 4 dicts, each dict updating 'value' & 'visible'
     def update_starter_buttons(selected_genre):
         """Update starter buttons with examples for the selected genre."""
         examples = get_examples_for_genre(selected_genre)
         results = []
         for i in range(4):
             if i < len(examples):
+                # Return just the string value instead of a dict
                 results.append(examples[i])
             else:
-                results.append("")
-        return tuple(results)
+                results.append("")  # Empty string for hidden buttons
+        return tuple(results)  # Return tuple of strings
+
+    # 2) Initialize them with "fantasy" so they don't stay "Starter X" on page load
+    # We'll just call the function and store the results in a variable, then apply them in a .load() event
+    initial_button_data = update_starter_buttons("fantasy")  # returns 4 dicts
 
-    # Initialize with default genre
-    initial_button_data = update_starter_buttons("fantasy")
+    # 3) We'll define a "pick_starter" function that sets msg to the chosen text
+    def pick_starter(starter_text, chat_history, selected_genre, memory_flag):
+        # Putting 'starter_text' into the msg
+        return starter_text
 
-    # Connect each starter button
+    # 4) Connect each starter button:
     for starter_button in starter_buttons:
-        # First, update chatbot with starter message
         starter_button.click(
-            fn=lambda x: [(str(x), "")],
+            # Format initial message correctly as ChatMessage
+            fn=lambda x: [{"role": "user", "content": str(x)}],
             inputs=[starter_button],
             outputs=[chatbot],
             queue=False
         ).success(
-            # Then process the message with respond function
-            fn=respond,
-            inputs=[
-                starter_button,
-                chatbot,
-                genre,
-                full_memory
-            ],
+            # Then process with properly formatted message
+            fn=lambda x, h, g, m: respond(
+                message=x["content"],  # Extract content from ChatMessage
+                chat_history=h if h else [],
+                genre=g,
+                use_full_memory=m
+            ),
+            inputs=[starter_button, chatbot, genre, full_memory],
             outputs=chatbot,
             queue=True
         )
 
-    # Update buttons when genre changes
+    # 5) Dynamically update the 4 buttons if the user changes the genre
    genre.change(
        fn=update_starter_buttons,
        inputs=[genre],
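
The starter buttons above chain two handlers with .click(...).success(...), so the second step only runs if the first finished without raising. A stripped-down sketch of that chaining pattern; the handler names and reply text are illustrative, not taken from the commit:

    import gradio as gr

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        starter = gr.Button("I follow the shimmer of fairy dust into a hidden forest")

        def show_user_turn(text):
            # Step 1: a Button used as an input passes its label string.
            return [{"role": "user", "content": text}]

        def add_story_reply(history):
            # Step 2: runs only if step 1 succeeded; append an assistant turn.
            return (history or []) + [{"role": "assistant", "content": "The forest parts before you..."}]

        starter.click(
            fn=show_user_turn, inputs=[starter], outputs=[chatbot], queue=False
        ).success(
            fn=add_story_reply, inputs=[chatbot], outputs=[chatbot]
        )

    if __name__ == "__main__":
        demo.launch()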
@@ -316,13 +309,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         queue=False
     )
 
-    # Load initial button data
-    demo.load(
-        fn=lambda: initial_button_data,
-        outputs=starter_buttons,
-        queue=False
-    )
+    # 6) Finally, run a "load" event to apply initial_button_data to the 4 button outputs on page load
+    def load_initial_buttons():
+        # Just return our precomputed tuple of 4 dicts
+        return initial_button_data
+
+    demo.load(fn=load_initial_buttons, outputs=starter_buttons, queue=False)
 
 # Run the app
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=7860)
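
The load event above pushes precomputed starter text onto the buttons when the page opens, and genre.change reuses the same updater. A compact, hypothetical sketch of this initialize-on-load pattern with simplified names:

    import gradio as gr

    EXAMPLES = {
        "fantasy": ["Enter the ruined tower", "Follow the old map"],
        "fairy_tale": ["Follow the fairy dust", "Knock on the witch's door"],
    }

    def update_buttons(genre):
        # Returning plain strings updates each Button's label.
        picks = EXAMPLES.get(genre, EXAMPLES["fantasy"])
        return picks[0], picks[1]

    with gr.Blocks() as demo:
        genre = gr.Dropdown(choices=list(EXAMPLES), value="fantasy", label="Genre")
        btn1, btn2 = gr.Button("Starter 1"), gr.Button("Starter 2")

        genre.change(fn=update_buttons, inputs=[genre], outputs=[btn1, btn2])
        demo.load(fn=lambda: update_buttons("fantasy"), outputs=[btn1, btn2], queue=False)

    if __name__ == "__main__":
        demo.launch()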
 