Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -10,18 +10,18 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
|
|
10 |
|
11 |
# Story genres with genre-specific example prompts
|
12 |
GENRE_EXAMPLES = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
"fantasy": [
|
14 |
"I enter the ancient forest seeking the wizard's tower",
|
15 |
"I approach the dragon cautiously with my shield raised",
|
16 |
"I examine the mysterious runes carved into the stone altar",
|
17 |
"I try to bargain with the elven council for safe passage"
|
18 |
-
],
|
19 |
-
"fairy tale": [
|
20 |
-
"I follow the shimmer of fairy dust into a hidden forest",
|
21 |
-
"I meet a talking rabbit who claims to know a secret about the king’s lost crown",
|
22 |
-
"A tiny dragon appears at my window, asking for help to find its mother",
|
23 |
-
"I step into a clearing where the trees whisper ancient riddles",
|
24 |
-
"A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
|
25 |
],
|
26 |
"sci-fi": [
|
27 |
"I hack into the space station's mainframe",
|
@@ -129,8 +129,10 @@ def format_history_for_gradio(history_tuples):
|
|
129 |
"""Convert (user, bot) tuples into Gradio 'messages' format (role/content dicts)."""
|
130 |
messages = []
|
131 |
for user_msg, bot_msg in history_tuples:
|
132 |
-
messages.
|
133 |
-
|
|
|
|
|
134 |
return messages
|
135 |
|
136 |
# 1. Add type hints for better code maintainability
|
@@ -144,45 +146,43 @@ def respond(
|
|
144 |
"""Generate a response based on the current message and conversation history."""
|
145 |
if not message.strip():
|
146 |
return chat_history
|
147 |
-
if genre and genre not in GENRE_EXAMPLES:
|
148 |
-
genre = "fantasy" # fallback to default
|
149 |
|
150 |
system_message = get_enhanced_system_prompt(genre)
|
151 |
-
|
152 |
-
#
|
153 |
formatted_history = []
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
157 |
|
158 |
-
|
|
|
|
|
|
|
159 |
|
160 |
-
#
|
161 |
if use_full_memory and formatted_history:
|
162 |
if len(formatted_history) > MAX_HISTORY_LENGTH:
|
163 |
summary_instruction = create_story_summary(chat_history[:len(chat_history)-5])
|
164 |
if summary_instruction:
|
165 |
api_messages.append(summary_instruction)
|
166 |
-
|
167 |
-
api_messages.append(msg)
|
168 |
else:
|
169 |
-
|
170 |
-
api_messages.append(msg)
|
171 |
else:
|
172 |
-
|
173 |
-
if formatted_history:
|
174 |
-
for msg in formatted_history[-memory_length*2:]:
|
175 |
-
api_messages.append(msg)
|
176 |
|
177 |
# Add current user message
|
178 |
-
api_messages.append({"role": "user", "content": message})
|
179 |
|
180 |
-
#
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
})
|
186 |
|
187 |
bot_message = ""
|
188 |
try:
|
@@ -193,18 +193,17 @@ def respond(
|
|
193 |
temperature=TEMPERATURE,
|
194 |
top_p=TOP_P,
|
195 |
):
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
yield format_history_for_gradio(new_history)
|
204 |
except Exception as e:
|
205 |
error_message = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
|
206 |
-
|
207 |
-
|
208 |
|
209 |
def save_story(chat_history):
|
210 |
"""Convert chat history to markdown for download"""
|
@@ -344,4 +343,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
344 |
|
345 |
# Run the app
|
346 |
if __name__ == "__main__":
|
347 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
10 |
|
11 |
# Story genres with genre-specific example prompts
|
12 |
GENRE_EXAMPLES = {
|
13 |
+
"fairy_tale": [
|
14 |
+
"I follow the shimmer of fairy dust into a hidden forest"
|
15 |
+
"I meet a talking rabbit who claims to know a secret about the king’s lost crown"
|
16 |
+
"A tiny dragon appears at my window, asking for help to find its mother"
|
17 |
+
"I step into a clearing where the trees whisper ancient riddles"
|
18 |
+
"A friendly witch invites me into her cozy cottage, offering a warm cup of tea"
|
19 |
+
],
|
20 |
"fantasy": [
|
21 |
"I enter the ancient forest seeking the wizard's tower",
|
22 |
"I approach the dragon cautiously with my shield raised",
|
23 |
"I examine the mysterious runes carved into the stone altar",
|
24 |
"I try to bargain with the elven council for safe passage"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
],
|
26 |
"sci-fi": [
|
27 |
"I hack into the space station's mainframe",
|
|
|
129 |
"""Convert (user, bot) tuples into Gradio 'messages' format (role/content dicts)."""
|
130 |
messages = []
|
131 |
for user_msg, bot_msg in history_tuples:
|
132 |
+
messages.extend([
|
133 |
+
{"role": "user", "content": str(user_msg)},
|
134 |
+
{"role": "assistant", "content": str(bot_msg)}
|
135 |
+
])
|
136 |
return messages
|
137 |
|
138 |
# 1. Add type hints for better code maintainability
|
|
|
146 |
"""Generate a response based on the current message and conversation history."""
|
147 |
if not message.strip():
|
148 |
return chat_history
|
|
|
|
|
149 |
|
150 |
system_message = get_enhanced_system_prompt(genre)
|
151 |
+
|
152 |
+
# Format history properly
|
153 |
formatted_history = []
|
154 |
+
if chat_history:
|
155 |
+
for user_msg, bot_msg in chat_history:
|
156 |
+
formatted_history.extend([
|
157 |
+
{"role": "user", "content": str(user_msg)},
|
158 |
+
{"role": "assistant", "content": str(bot_msg)}
|
159 |
+
])
|
160 |
|
161 |
+
# Construct API messages
|
162 |
+
api_messages = [
|
163 |
+
{"role": "system", "content": system_message}
|
164 |
+
]
|
165 |
|
166 |
+
# Add memory management
|
167 |
if use_full_memory and formatted_history:
|
168 |
if len(formatted_history) > MAX_HISTORY_LENGTH:
|
169 |
summary_instruction = create_story_summary(chat_history[:len(chat_history)-5])
|
170 |
if summary_instruction:
|
171 |
api_messages.append(summary_instruction)
|
172 |
+
api_messages.extend(formatted_history[-MEMORY_WINDOW*2:])
|
|
|
173 |
else:
|
174 |
+
api_messages.extend(formatted_history)
|
|
|
175 |
else:
|
176 |
+
api_messages.extend(formatted_history[-MEMORY_WINDOW*2:])
|
|
|
|
|
|
|
177 |
|
178 |
# Add current user message
|
179 |
+
api_messages.append({"role": "user", "content": str(message)})
|
180 |
|
181 |
+
# Add choice enforcement
|
182 |
+
api_messages.append({
|
183 |
+
"role": "system",
|
184 |
+
"content": "Remember to end your response with exactly three numbered choices, each starting with 'You'"
|
185 |
+
})
|
|
|
186 |
|
187 |
bot_message = ""
|
188 |
try:
|
|
|
193 |
temperature=TEMPERATURE,
|
194 |
top_p=TOP_P,
|
195 |
):
|
196 |
+
if hasattr(response_chunk.choices[0].delta, 'content'):
|
197 |
+
delta = response_chunk.choices[0].delta.content
|
198 |
+
if delta:
|
199 |
+
bot_message += delta
|
200 |
+
if len(bot_message.strip()) >= MIN_RESPONSE_LENGTH:
|
201 |
+
yield [{"role": "user", "content": message},
|
202 |
+
{"role": "assistant", "content": bot_message}]
|
|
|
203 |
except Exception as e:
|
204 |
error_message = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
|
205 |
+
yield [{"role": "user", "content": message},
|
206 |
+
{"role": "assistant", "content": error_message}]
|
207 |
|
208 |
def save_story(chat_history):
|
209 |
"""Convert chat history to markdown for download"""
|
|
|
343 |
|
344 |
# Run the app
|
345 |
if __name__ == "__main__":
|
346 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|