PeterPinetree commited on
Commit
a036936
·
verified ·
1 Parent(s): 568f8af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -64
app.py CHANGED
@@ -10,11 +10,11 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
10
 
11
  # Story genres with genre-specific example prompts
12
  GENRE_EXAMPLES = {
13
- "fairy_tale": [
14
- "I follow the shimmer of fairy dust into a hidden forest"
15
- "I meet a talking rabbit who claims to know a secret about the kings lost crown"
16
- "A tiny dragon appears at my window, asking for help to find its mother"
17
- "I step into a clearing where the trees whisper ancient riddles"
18
  "A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
19
  ],
20
  "fantasy": [
@@ -73,7 +73,7 @@ GENRE_EXAMPLES = {
73
  ]
74
  }
75
 
76
- # 2. Add constants at the top for magic numbers
77
  MAX_HISTORY_LENGTH = 20
78
  MEMORY_WINDOW = 5 # Reduced from 10 to limit context
79
  MAX_TOKENS = 1024 # Reduced from 2048 for faster responses
@@ -110,33 +110,12 @@ IMPORTANT:
110
  Keep the story cohesive by referencing previous events and choices."""
111
  return system_message
112
 
113
- def create_story_summary(chat_history):
114
- """Create a concise summary of the story so far if the history gets too long"""
115
- if len(chat_history) <= 2:
116
- return None
117
-
118
- story_text = ""
119
- for user_msg, bot_msg in chat_history:
120
- story_text += f"User: {user_msg}\nStory: {bot_msg}\n\n"
121
-
122
- summary_instruction = {
123
- "role": "system",
124
- "content": "The conversation history is getting long. Please create a brief summary of the key plot points and character development so far to help maintain context without exceeding token limits."
125
- }
126
- return summary_instruction
127
-
128
- def format_history_for_gradio(history_tuples):
129
- """Convert chat history to Gradio's message format."""
130
- return [(str(user_msg), str(bot_msg)) for user_msg, bot_msg in history_tuples]
131
-
132
- # 1. Add type hints for better code maintainability
133
- # 4. Add input validation
134
  def respond(
135
  message: str,
136
  chat_history: List[Tuple[str, str]],
137
  genre: Optional[str] = None,
138
  use_full_memory: bool = True
139
- ) -> List[Tuple[str, str]]: # Changed return type to match Gradio's chatbot format
140
  """Generate a response based on the current message and conversation history."""
141
  if not message.strip():
142
  return chat_history
@@ -148,17 +127,18 @@ def respond(
148
  # Add chat history in correct format
149
  if chat_history and use_full_memory:
150
  for user_msg, bot_msg in chat_history[-MEMORY_WINDOW:]:
151
- api_messages.extend([
152
- {"role": "user", "content": str(user_msg)},
153
- {"role": "assistant", "content": str(bot_msg)}
154
- ])
155
 
156
  # Add current message
157
  api_messages.append({"role": "user", "content": str(message)})
158
 
159
  # Make API call with properly formatted messages
160
  response = client.chat_completion(
161
- messages=api_messages, # Use formatted messages
 
162
  max_tokens=MAX_TOKENS,
163
  temperature=TEMPERATURE,
164
  top_p=TOP_P
@@ -186,6 +166,7 @@ def save_story(chat_history):
186
 
187
  return story_text
188
 
 
189
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
190
  gr.Markdown("# 🔮 Interactive Story Time")
191
  with gr.Row():
@@ -196,16 +177,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
196
  with gr.Column(scale=3):
197
  # Chat window + user input
198
  chatbot = gr.Chatbot(
199
- height=500, # Increased height
200
- bubble_full_width=True, # Allow bubbles to use full width
201
  show_copy_button=True,
202
  avatar_images=(None, "🧙"),
203
  type="messages",
204
  container=True,
205
  scale=1,
206
- min_width=800, # Ensure minimum width
207
- value=[], # Initialize with empty list
208
- render=True
209
  )
210
  msg = gr.Textbox(
211
  placeholder="Describe what you want to do next in the story...",
@@ -240,32 +220,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
240
  starter_btn4 = gr.Button("Starter 4")
241
  starter_buttons = [starter_btn1, starter_btn2, starter_btn3, starter_btn4]
242
 
243
- # 1) We'll return a list of 4 dicts, each dict updating 'value' & 'visible'
244
  def update_starter_buttons(selected_genre):
245
  """Update starter buttons with examples for the selected genre."""
246
  examples = get_examples_for_genre(selected_genre)
247
  results = []
248
  for i in range(4):
249
  if i < len(examples):
250
- # Return just the string value instead of a dict
251
  results.append(examples[i])
252
  else:
253
- results.append("") # Empty string for hidden buttons
254
- return tuple(results) # Return tuple of strings
255
 
256
- # 2) Initialize them with "fantasy" so they don't stay "Starter X" on page load
257
- # We'll just call the function and store the results in a variable, then apply them in a .load() event
258
- initial_button_data = update_starter_buttons("fantasy") # returns 4 dicts
259
 
260
- # 3) We'll define a "pick_starter" function that sets msg to the chosen text
261
- def pick_starter(starter_text, chat_history, selected_genre, memory_flag):
262
- # Putting 'starter_text' into the msg
263
- return starter_text
264
-
265
- # 4) Connect each starter button:
266
  for starter_button in starter_buttons:
 
267
  starter_button.click(
268
- # Initial click creates empty chat history with starter message
269
  fn=lambda x: [(str(x), "")],
270
  inputs=[starter_button],
271
  outputs=[chatbot],
@@ -274,16 +246,16 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
274
  # Then process the message with respond function
275
  fn=respond,
276
  inputs=[
277
- starter_button, # The starter text
278
- chatbot, # Current chat history
279
- genre, # Selected genre
280
- full_memory # Memory flag
281
  ],
282
  outputs=chatbot,
283
  queue=True
284
  )
285
 
286
- # 5) Dynamically update the 4 buttons if the user changes the genre
287
  genre.change(
288
  fn=update_starter_buttons,
289
  inputs=[genre],
@@ -312,13 +284,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
312
  queue=False
313
  )
314
 
315
- # 6) Finally, run a "load" event to apply initial_button_data to the 4 button outputs on page load
316
- def load_initial_buttons():
317
- # Just return our precomputed tuple of 4 dicts
318
- return initial_button_data
319
-
320
- demo.load(fn=load_initial_buttons, outputs=starter_buttons, queue=False)
321
 
322
  # Run the app
323
  if __name__ == "__main__":
324
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
10
 
11
  # Story genres with genre-specific example prompts
12
  GENRE_EXAMPLES = {
13
+ "fairy tale": [
14
+ "I follow the shimmer of fairy dust into a hidden forest",
15
+ "I meet a talking rabbit who claims to know a secret about the king's lost crown",
16
+ "A tiny dragon appears at my window, asking for help to find its mother",
17
+ "I step into a clearing where the trees whisper ancient riddles",
18
  "A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
19
  ],
20
  "fantasy": [
 
73
  ]
74
  }
75
 
76
+ # Constants
77
  MAX_HISTORY_LENGTH = 20
78
  MEMORY_WINDOW = 5 # Reduced from 10 to limit context
79
  MAX_TOKENS = 1024 # Reduced from 2048 for faster responses
 
110
  Keep the story cohesive by referencing previous events and choices."""
111
  return system_message
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def respond(
114
  message: str,
115
  chat_history: List[Tuple[str, str]],
116
  genre: Optional[str] = None,
117
  use_full_memory: bool = True
118
+ ) -> List[Tuple[str, str]]:
119
  """Generate a response based on the current message and conversation history."""
120
  if not message.strip():
121
  return chat_history
 
127
  # Add chat history in correct format
128
  if chat_history and use_full_memory:
129
  for user_msg, bot_msg in chat_history[-MEMORY_WINDOW:]:
130
+ # Only add completed exchanges (where both user and bot messages exist)
131
+ if user_msg and bot_msg:
132
+ api_messages.append({"role": "user", "content": str(user_msg)})
133
+ api_messages.append({"role": "assistant", "content": str(bot_msg)})
134
 
135
  # Add current message
136
  api_messages.append({"role": "user", "content": str(message)})
137
 
138
  # Make API call with properly formatted messages
139
  response = client.chat_completion(
140
+ model="HuggingFaceH4/zephyr-7b-beta", # Explicitly specify model
141
+ messages=api_messages,
142
  max_tokens=MAX_TOKENS,
143
  temperature=TEMPERATURE,
144
  top_p=TOP_P
 
166
 
167
  return story_text
168
 
169
+ # Define the Gradio interface
170
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
171
  gr.Markdown("# 🔮 Interactive Story Time")
172
  with gr.Row():
 
177
  with gr.Column(scale=3):
178
  # Chat window + user input
179
  chatbot = gr.Chatbot(
180
+ height=500,
181
+ bubble_full_width=True,
182
  show_copy_button=True,
183
  avatar_images=(None, "🧙"),
184
  type="messages",
185
  container=True,
186
  scale=1,
187
+ min_width=800,
188
+ value=[]
 
189
  )
190
  msg = gr.Textbox(
191
  placeholder="Describe what you want to do next in the story...",
 
220
  starter_btn4 = gr.Button("Starter 4")
221
  starter_buttons = [starter_btn1, starter_btn2, starter_btn3, starter_btn4]
222
 
 
223
  def update_starter_buttons(selected_genre):
224
  """Update starter buttons with examples for the selected genre."""
225
  examples = get_examples_for_genre(selected_genre)
226
  results = []
227
  for i in range(4):
228
  if i < len(examples):
 
229
  results.append(examples[i])
230
  else:
231
+ results.append("")
232
+ return tuple(results)
233
 
234
+ # Initialize with default genre
235
+ initial_button_data = update_starter_buttons("fantasy")
 
236
 
237
+ # Connect each starter button
 
 
 
 
 
238
  for starter_button in starter_buttons:
239
+ # First, update chatbot with starter message
240
  starter_button.click(
 
241
  fn=lambda x: [(str(x), "")],
242
  inputs=[starter_button],
243
  outputs=[chatbot],
 
246
  # Then process the message with respond function
247
  fn=respond,
248
  inputs=[
249
+ starter_button,
250
+ chatbot,
251
+ genre,
252
+ full_memory
253
  ],
254
  outputs=chatbot,
255
  queue=True
256
  )
257
 
258
+ # Update buttons when genre changes
259
  genre.change(
260
  fn=update_starter_buttons,
261
  inputs=[genre],
 
284
  queue=False
285
  )
286
 
287
+ # Load initial button data
288
+ demo.load(
289
+ fn=lambda: initial_button_data,
290
+ outputs=starter_buttons,
291
+ queue=False
292
+ )
293
 
294
  # Run the app
295
  if __name__ == "__main__":
296
+ demo.launch(server_name="0.0.0.0", server_port=7860)