Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -10,11 +10,11 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
|
|
10 |
|
11 |
# Story genres with genre-specific example prompts
|
12 |
GENRE_EXAMPLES = {
|
13 |
-
"
|
14 |
-
"I follow the shimmer of fairy dust into a hidden forest"
|
15 |
-
"I meet a talking rabbit who claims to know a secret about the king
|
16 |
-
"A tiny dragon appears at my window, asking for help to find its mother"
|
17 |
-
"I step into a clearing where the trees whisper ancient riddles"
|
18 |
"A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
|
19 |
],
|
20 |
"fantasy": [
|
@@ -73,7 +73,7 @@ GENRE_EXAMPLES = {
|
|
73 |
]
|
74 |
}
|
75 |
|
76 |
-
#
|
77 |
MAX_HISTORY_LENGTH = 20
|
78 |
MEMORY_WINDOW = 5 # Reduced from 10 to limit context
|
79 |
MAX_TOKENS = 1024 # Reduced from 2048 for faster responses
|
@@ -110,81 +110,66 @@ IMPORTANT:
|
|
110 |
Keep the story cohesive by referencing previous events and choices."""
|
111 |
return system_message
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
def respond(
|
114 |
message: str,
|
115 |
chat_history: List[Tuple[str, str]],
|
116 |
genre: Optional[str] = None,
|
117 |
use_full_memory: bool = True
|
118 |
-
) -> List[
|
119 |
"""Generate a response based on the current message and conversation history."""
|
120 |
if not message.strip():
|
121 |
-
return
|
122 |
|
123 |
try:
|
124 |
-
#
|
125 |
-
|
126 |
-
|
127 |
-
# Construct messages list for API call
|
128 |
-
messages = []
|
129 |
-
|
130 |
-
# Always start with system message
|
131 |
-
messages.append({
|
132 |
-
"role": "system",
|
133 |
-
"content": get_enhanced_system_prompt(genre)
|
134 |
-
})
|
135 |
|
136 |
-
# Add
|
137 |
-
if
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
"content": str(user_msg)
|
144 |
-
})
|
145 |
-
if bot_msg: # Only add if bot message exists
|
146 |
-
messages.append({
|
147 |
-
"role": "assistant",
|
148 |
-
"content": str(bot_msg)
|
149 |
-
})
|
150 |
|
151 |
-
# Add current
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
json={
|
161 |
-
"inputs": messages,
|
162 |
-
"parameters": {
|
163 |
-
"max_new_tokens": MAX_TOKENS,
|
164 |
-
"temperature": TEMPERATURE,
|
165 |
-
"top_p": TOP_P
|
166 |
-
}
|
167 |
-
}
|
168 |
)
|
169 |
|
170 |
-
#
|
171 |
-
|
172 |
-
|
173 |
-
elif isinstance(response, list) and response and "generated_text" in response[0]:
|
174 |
-
bot_message = response[0]["generated_text"]
|
175 |
-
else:
|
176 |
-
# Fallback in case the response structure is different
|
177 |
-
bot_message = str(response)
|
178 |
-
|
179 |
-
# Add the new exchange to chat history
|
180 |
-
new_history.append((str(message), str(bot_message)))
|
181 |
-
return new_history
|
182 |
|
183 |
except Exception as e:
|
184 |
-
|
185 |
-
print(f"Error: {str(e)}")
|
186 |
-
# Return existing history plus error message
|
187 |
-
return chat_history + [(str(message), f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})")]
|
188 |
|
189 |
def save_story(chat_history):
|
190 |
"""Convert chat history to markdown for download"""
|
@@ -198,7 +183,6 @@ def save_story(chat_history):
|
|
198 |
|
199 |
return story_text
|
200 |
|
201 |
-
# Define the Gradio interface
|
202 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
203 |
gr.Markdown("# 🔮 Interactive Story Time")
|
204 |
with gr.Row():
|
@@ -209,15 +193,16 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
209 |
with gr.Column(scale=3):
|
210 |
# Chat window + user input
|
211 |
chatbot = gr.Chatbot(
|
212 |
-
height=500,
|
213 |
-
bubble_full_width=True,
|
214 |
show_copy_button=True,
|
215 |
avatar_images=(None, "🧙"),
|
216 |
type="messages",
|
217 |
container=True,
|
218 |
scale=1,
|
219 |
-
min_width=800,
|
220 |
-
value=[]
|
|
|
221 |
)
|
222 |
msg = gr.Textbox(
|
223 |
placeholder="Describe what you want to do next in the story...",
|
@@ -252,42 +237,50 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
252 |
starter_btn4 = gr.Button("Starter 4")
|
253 |
starter_buttons = [starter_btn1, starter_btn2, starter_btn3, starter_btn4]
|
254 |
|
|
|
255 |
def update_starter_buttons(selected_genre):
|
256 |
"""Update starter buttons with examples for the selected genre."""
|
257 |
examples = get_examples_for_genre(selected_genre)
|
258 |
results = []
|
259 |
for i in range(4):
|
260 |
if i < len(examples):
|
|
|
261 |
results.append(examples[i])
|
262 |
else:
|
263 |
-
results.append("")
|
264 |
-
return tuple(results)
|
|
|
|
|
|
|
|
|
265 |
|
266 |
-
#
|
267 |
-
|
|
|
|
|
268 |
|
269 |
-
# Connect each starter button
|
270 |
for starter_button in starter_buttons:
|
271 |
-
# First, update chatbot with starter message
|
272 |
starter_button.click(
|
273 |
-
|
|
|
274 |
inputs=[starter_button],
|
275 |
outputs=[chatbot],
|
276 |
queue=False
|
277 |
).success(
|
278 |
-
# Then process
|
279 |
-
fn=respond
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
],
|
286 |
outputs=chatbot,
|
287 |
queue=True
|
288 |
)
|
289 |
|
290 |
-
#
|
291 |
genre.change(
|
292 |
fn=update_starter_buttons,
|
293 |
inputs=[genre],
|
@@ -316,13 +309,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
316 |
queue=False
|
317 |
)
|
318 |
|
319 |
-
#
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
)
|
325 |
|
326 |
# Run the app
|
327 |
if __name__ == "__main__":
|
328 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
10 |
|
11 |
# Story genres with genre-specific example prompts
|
12 |
GENRE_EXAMPLES = {
|
13 |
+
"fairy_tale": [
|
14 |
+
"I follow the shimmer of fairy dust into a hidden forest"
|
15 |
+
"I meet a talking rabbit who claims to know a secret about the king’s lost crown"
|
16 |
+
"A tiny dragon appears at my window, asking for help to find its mother"
|
17 |
+
"I step into a clearing where the trees whisper ancient riddles"
|
18 |
"A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
|
19 |
],
|
20 |
"fantasy": [
|
|
|
73 |
]
|
74 |
}
|
75 |
|
76 |
+
# 2. Add constants at the top for magic numbers
|
77 |
MAX_HISTORY_LENGTH = 20
|
78 |
MEMORY_WINDOW = 5 # Reduced from 10 to limit context
|
79 |
MAX_TOKENS = 1024 # Reduced from 2048 for faster responses
|
|
|
110 |
Keep the story cohesive by referencing previous events and choices."""
|
111 |
return system_message
|
112 |
|
113 |
+
def create_story_summary(chat_history):
|
114 |
+
"""Create a concise summary of the story so far if the history gets too long"""
|
115 |
+
if len(chat_history) <= 2:
|
116 |
+
return None
|
117 |
+
|
118 |
+
story_text = ""
|
119 |
+
for user_msg, bot_msg in chat_history:
|
120 |
+
story_text += f"User: {user_msg}\nStory: {bot_msg}\n\n"
|
121 |
+
|
122 |
+
summary_instruction = {
|
123 |
+
"role": "system",
|
124 |
+
"content": "The conversation history is getting long. Please create a brief summary of the key plot points and character development so far to help maintain context without exceeding token limits."
|
125 |
+
}
|
126 |
+
return summary_instruction
|
127 |
+
|
128 |
+
def format_history_for_gradio(history_tuples):
|
129 |
+
"""Convert chat history to Gradio's message format."""
|
130 |
+
return [(str(user_msg), str(bot_msg)) for user_msg, bot_msg in history_tuples]
|
131 |
+
|
132 |
+
# 1. Add type hints for better code maintainability
|
133 |
+
# 4. Add input validation
|
134 |
def respond(
|
135 |
message: str,
|
136 |
chat_history: List[Tuple[str, str]],
|
137 |
genre: Optional[str] = None,
|
138 |
use_full_memory: bool = True
|
139 |
+
) -> List[Dict[str, str]]: # Changed return type
|
140 |
"""Generate a response based on the current message and conversation history."""
|
141 |
if not message.strip():
|
142 |
+
return [{"role": "assistant", "content": "Please provide a message"}]
|
143 |
|
144 |
try:
|
145 |
+
# Format messages for API
|
146 |
+
api_messages = [{"role": "system", "content": get_enhanced_system_prompt(genre)}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
+
# Add chat history
|
149 |
+
if chat_history and use_full_memory:
|
150 |
+
for user_msg, bot_msg in chat_history[-MEMORY_WINDOW:]:
|
151 |
+
api_messages.extend([
|
152 |
+
{"role": "user", "content": str(user_msg)},
|
153 |
+
{"role": "assistant", "content": str(bot_msg)}
|
154 |
+
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
+
# Add current message
|
157 |
+
api_messages.append({"role": "user", "content": str(message)})
|
158 |
+
|
159 |
+
# Make API call
|
160 |
+
response = client.chat_completion(
|
161 |
+
messages=api_messages,
|
162 |
+
max_tokens=MAX_TOKENS,
|
163 |
+
temperature=TEMPERATURE,
|
164 |
+
top_p=TOP_P
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
)
|
166 |
|
167 |
+
# Return properly formatted message
|
168 |
+
bot_message = response.choices[0].message.content
|
169 |
+
return [{"role": "assistant", "content": str(bot_message)}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
except Exception as e:
|
172 |
+
return [{"role": "assistant", "content": f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"}]
|
|
|
|
|
|
|
173 |
|
174 |
def save_story(chat_history):
|
175 |
"""Convert chat history to markdown for download"""
|
|
|
183 |
|
184 |
return story_text
|
185 |
|
|
|
186 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
187 |
gr.Markdown("# 🔮 Interactive Story Time")
|
188 |
with gr.Row():
|
|
|
193 |
with gr.Column(scale=3):
|
194 |
# Chat window + user input
|
195 |
chatbot = gr.Chatbot(
|
196 |
+
height=500, # Increased height
|
197 |
+
bubble_full_width=True, # Allow bubbles to use full width
|
198 |
show_copy_button=True,
|
199 |
avatar_images=(None, "🧙"),
|
200 |
type="messages",
|
201 |
container=True,
|
202 |
scale=1,
|
203 |
+
min_width=800, # Ensure minimum width
|
204 |
+
value=[], # Initialize with empty list
|
205 |
+
render=True
|
206 |
)
|
207 |
msg = gr.Textbox(
|
208 |
placeholder="Describe what you want to do next in the story...",
|
|
|
237 |
starter_btn4 = gr.Button("Starter 4")
|
238 |
starter_buttons = [starter_btn1, starter_btn2, starter_btn3, starter_btn4]
|
239 |
|
240 |
+
# 1) We'll return a list of 4 dicts, each dict updating 'value' & 'visible'
|
241 |
def update_starter_buttons(selected_genre):
|
242 |
"""Update starter buttons with examples for the selected genre."""
|
243 |
examples = get_examples_for_genre(selected_genre)
|
244 |
results = []
|
245 |
for i in range(4):
|
246 |
if i < len(examples):
|
247 |
+
# Return just the string value instead of a dict
|
248 |
results.append(examples[i])
|
249 |
else:
|
250 |
+
results.append("") # Empty string for hidden buttons
|
251 |
+
return tuple(results) # Return tuple of strings
|
252 |
+
|
253 |
+
# 2) Initialize them with "fantasy" so they don't stay "Starter X" on page load
|
254 |
+
# We'll just call the function and store the results in a variable, then apply them in a .load() event
|
255 |
+
initial_button_data = update_starter_buttons("fantasy") # returns 4 dicts
|
256 |
|
257 |
+
# 3) We'll define a "pick_starter" function that sets msg to the chosen text
|
258 |
+
def pick_starter(starter_text, chat_history, selected_genre, memory_flag):
|
259 |
+
# Putting 'starter_text' into the msg
|
260 |
+
return starter_text
|
261 |
|
262 |
+
# 4) Connect each starter button:
|
263 |
for starter_button in starter_buttons:
|
|
|
264 |
starter_button.click(
|
265 |
+
# Format initial message correctly as ChatMessage
|
266 |
+
fn=lambda x: [{"role": "user", "content": str(x)}],
|
267 |
inputs=[starter_button],
|
268 |
outputs=[chatbot],
|
269 |
queue=False
|
270 |
).success(
|
271 |
+
# Then process with properly formatted message
|
272 |
+
fn=lambda x, h, g, m: respond(
|
273 |
+
message=x["content"], # Extract content from ChatMessage
|
274 |
+
chat_history=h if h else [],
|
275 |
+
genre=g,
|
276 |
+
use_full_memory=m
|
277 |
+
),
|
278 |
+
inputs=[starter_button, chatbot, genre, full_memory],
|
279 |
outputs=chatbot,
|
280 |
queue=True
|
281 |
)
|
282 |
|
283 |
+
# 5) Dynamically update the 4 buttons if the user changes the genre
|
284 |
genre.change(
|
285 |
fn=update_starter_buttons,
|
286 |
inputs=[genre],
|
|
|
309 |
queue=False
|
310 |
)
|
311 |
|
312 |
+
# 6) Finally, run a "load" event to apply initial_button_data to the 4 button outputs on page load
|
313 |
+
def load_initial_buttons():
|
314 |
+
# Just return our precomputed tuple of 4 dicts
|
315 |
+
return initial_button_data
|
316 |
+
|
317 |
+
demo.load(fn=load_initial_buttons, outputs=starter_buttons, queue=False)
|
318 |
|
319 |
# Run the app
|
320 |
if __name__ == "__main__":
|
321 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|