ChanMeng666 commited on
Commit
d9d201b
·
verified ·
1 Parent(s): 88f8e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +558 -86
app.py CHANGED
@@ -1,3 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import time
@@ -13,9 +215,46 @@ load_dotenv()
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # 初始化故事生成器的系统提示
17
- STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。
18
- 请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  STORY_STYLES = [
21
  "奇幻",
@@ -41,131 +280,342 @@ def create_client() -> InferenceClient:
41
  def generate_story(
42
  scene: str,
43
  style: str,
44
- history: Optional[list[dict]] = None,
 
 
45
  temperature: float = 0.7,
46
  max_tokens: int = 512,
47
  top_p: float = 0.95,
48
  ) -> Generator[str, None, None]:
 
 
 
49
  if history is None:
50
  history = []
51
 
52
- style_prompt = f"请以{style}风格续写以下故事:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  messages = [
55
  {"role": "system", "content": STORY_SYSTEM_PROMPT},
56
- {"role": "user", "content": f"{style_prompt}\n{scene}"}
57
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- for msg in history:
60
- messages.append(msg)
 
 
 
 
 
61
 
62
- response = ""
63
- retries = 0
 
 
 
 
64
 
65
- while retries < MAX_RETRIES:
66
- try:
67
- client = create_client()
68
- for message in client.chat_completion(
69
- messages,
70
- max_tokens=max_tokens,
71
- stream=True,
72
- temperature=temperature,
73
- top_p=top_p,
74
- ):
75
- if hasattr(message.choices[0].delta, 'content'):
76
- token = message.choices[0].delta.content
77
- if token is not None:
78
- response += token
79
- yield response
80
- break
81
- except Exception as e:
82
- retries += 1
83
- logger.error(f"生成故事时发生错误 (尝试 {retries}/{MAX_RETRIES}): {str(e)}")
84
- if retries < MAX_RETRIES:
85
- time.sleep(RETRY_DELAY)
86
- else:
87
- yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
88
 
89
  def create_demo():
90
- with gr.Blocks() as demo:
91
- gr.Markdown("# 互动式故事生成器")
92
- gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")
93
-
94
- style_select = gr.Dropdown(
95
- choices=STORY_STYLES,
96
- value="奇幻",
97
- label="选择故事风格"
98
- )
99
- scene_input = gr.Textbox(
100
- lines=3,
101
- placeholder="请输入一个场景或角色描述...",
102
- label="场景描述"
103
  )
104
 
105
- temperature = gr.Slider(
106
- minimum=0.1,
107
- maximum=2.0,
108
- value=0.7,
109
- step=0.1,
110
- label="创意度(Temperature)"
111
- )
112
- max_tokens = gr.Slider(
113
- minimum=64,
114
- maximum=1024,
115
- value=512,
116
- step=64,
117
- label="最大生成长度"
118
- )
119
- top_p = gr.Slider(
120
- minimum=0.1,
121
- maximum=1.0,
122
- value=0.95,
123
- step=0.05,
124
- label="采样范围(Top-p)"
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- chatbot = gr.Chatbot(
128
- label="故事对话",
129
- type="messages"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
- status_msg = gr.Markdown("")
132
 
133
- submit_btn = gr.Button("生成故事")
134
- clear_btn = gr.Button("清除对话")
 
 
 
 
135
 
 
136
  def user_input(user_message, history):
 
 
 
 
 
 
137
  if history is None:
138
  history = []
139
- history.append({"role": "user", "content": user_message})
140
  return "", history
141
-
142
- def bot_response(history, style, temperature, max_tokens, top_p):
 
 
 
 
 
 
 
 
 
 
 
 
143
  try:
144
- current_message = {"role": "assistant", "content": ""}
145
- history.append(current_message)
 
146
 
 
 
 
 
 
 
 
 
 
 
147
  for text in generate_story(
148
- history[-2]["content"],
149
  style,
150
- history[:-2],
 
 
151
  temperature,
152
  max_tokens,
153
  top_p
154
  ):
155
- current_message["content"] = text
 
156
  yield history
 
157
  except Exception as e:
158
  logger.error(f"处理响应时发生错误: {str(e)}")
159
- current_message["content"] = f"抱歉,生成故事时遇到了问题。请稍后重试。"
 
160
  yield history
 
161
 
 
 
 
 
 
162
  scene_input.submit(
163
  user_input,
164
  [scene_input, chatbot],
165
  [scene_input, chatbot]
166
  ).then(
167
  bot_response,
168
- [chatbot, style_select, temperature, max_tokens, top_p],
169
  chatbot
170
  )
171
 
@@ -175,13 +625,10 @@ def create_demo():
175
  [scene_input, chatbot]
176
  ).then(
177
  bot_response,
178
- [chatbot, style_select, temperature, max_tokens, top_p],
179
  chatbot
180
  )
181
 
182
- def clear_chat():
183
- return [], ""
184
-
185
  clear_btn.click(
186
  clear_chat,
187
  None,
@@ -190,6 +637,30 @@ def create_demo():
190
 
191
  return demo
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  if __name__ == "__main__":
194
  demo = create_demo()
195
  demo.queue().launch(
@@ -197,3 +668,4 @@ if __name__ == "__main__":
197
  server_port=7860,
198
  share=False
199
  )
 
 
1
+ # import gradio as gr
2
+ # from huggingface_hub import InferenceClient
3
+ # import time
4
+ # from typing import Optional, Generator
5
+ # import logging
6
+ # import os
7
+ # from dotenv import load_dotenv
8
+
9
+ # # 加载环境变量
10
+ # load_dotenv()
11
+
12
+ # # 设置日志
13
+ # logging.basicConfig(level=logging.INFO)
14
+ # logger = logging.getLogger(__name__)
15
+
16
+ # # 初始化故事生成器的系统提示
17
+ # STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。
18
+ # 请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"""
19
+
20
+ # STORY_STYLES = [
21
+ # "奇幻",
22
+ # "科幻",
23
+ # "悬疑",
24
+ # "冒险",
25
+ # "爱情",
26
+ # "恐怖"
27
+ # ]
28
+
29
+ # MAX_RETRIES = 3
30
+ # RETRY_DELAY = 2
31
+
32
+ # def create_client() -> InferenceClient:
33
+ # hf_token = os.getenv('HF_TOKEN')
34
+ # if not hf_token:
35
+ # raise ValueError("HF_TOKEN 环境变量未设置")
36
+ # return InferenceClient(
37
+ # "HuggingFaceH4/zephyr-7b-beta",
38
+ # token=hf_token
39
+ # )
40
+
41
+ # def generate_story(
42
+ # scene: str,
43
+ # style: str,
44
+ # history: Optional[list[dict]] = None,
45
+ # temperature: float = 0.7,
46
+ # max_tokens: int = 512,
47
+ # top_p: float = 0.95,
48
+ # ) -> Generator[str, None, None]:
49
+ # if history is None:
50
+ # history = []
51
+
52
+ # style_prompt = f"请以{style}风格续写以下故事:"
53
+
54
+ # messages = [
55
+ # {"role": "system", "content": STORY_SYSTEM_PROMPT},
56
+ # {"role": "user", "content": f"{style_prompt}\n{scene}"}
57
+ # ]
58
+
59
+ # for msg in history:
60
+ # messages.append(msg)
61
+
62
+ # response = ""
63
+ # retries = 0
64
+
65
+ # while retries < MAX_RETRIES:
66
+ # try:
67
+ # client = create_client()
68
+ # for message in client.chat_completion(
69
+ # messages,
70
+ # max_tokens=max_tokens,
71
+ # stream=True,
72
+ # temperature=temperature,
73
+ # top_p=top_p,
74
+ # ):
75
+ # if hasattr(message.choices[0].delta, 'content'):
76
+ # token = message.choices[0].delta.content
77
+ # if token is not None:
78
+ # response += token
79
+ # yield response
80
+ # break
81
+ # except Exception as e:
82
+ # retries += 1
83
+ # logger.error(f"生成故事时发生错误 (尝试 {retries}/{MAX_RETRIES}): {str(e)}")
84
+ # if retries < MAX_RETRIES:
85
+ # time.sleep(RETRY_DELAY)
86
+ # else:
87
+ # yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
88
+
89
+ # def create_demo():
90
+ # with gr.Blocks() as demo:
91
+ # gr.Markdown("# 互动式故事生成器")
92
+ # gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")
93
+
94
+ # style_select = gr.Dropdown(
95
+ # choices=STORY_STYLES,
96
+ # value="奇幻",
97
+ # label="选择故事风格"
98
+ # )
99
+ # scene_input = gr.Textbox(
100
+ # lines=3,
101
+ # placeholder="请输入一个场景或角色描述...",
102
+ # label="场景描述"
103
+ # )
104
+
105
+ # temperature = gr.Slider(
106
+ # minimum=0.1,
107
+ # maximum=2.0,
108
+ # value=0.7,
109
+ # step=0.1,
110
+ # label="创意度(Temperature)"
111
+ # )
112
+ # max_tokens = gr.Slider(
113
+ # minimum=64,
114
+ # maximum=1024,
115
+ # value=512,
116
+ # step=64,
117
+ # label="最大生成长度"
118
+ # )
119
+ # top_p = gr.Slider(
120
+ # minimum=0.1,
121
+ # maximum=1.0,
122
+ # value=0.95,
123
+ # step=0.05,
124
+ # label="采样范围(Top-p)"
125
+ # )
126
+
127
+ # chatbot = gr.Chatbot(
128
+ # label="故事对话",
129
+ # type="messages"
130
+ # )
131
+ # status_msg = gr.Markdown("")
132
+
133
+ # submit_btn = gr.Button("生成故事")
134
+ # clear_btn = gr.Button("清除对话")
135
+
136
+ # def user_input(user_message, history):
137
+ # if history is None:
138
+ # history = []
139
+ # history.append({"role": "user", "content": user_message})
140
+ # return "", history
141
+
142
+ # def bot_response(history, style, temperature, max_tokens, top_p):
143
+ # try:
144
+ # current_message = {"role": "assistant", "content": ""}
145
+ # history.append(current_message)
146
+
147
+ # for text in generate_story(
148
+ # history[-2]["content"],
149
+ # style,
150
+ # history[:-2],
151
+ # temperature,
152
+ # max_tokens,
153
+ # top_p
154
+ # ):
155
+ # current_message["content"] = text
156
+ # yield history
157
+ # except Exception as e:
158
+ # logger.error(f"处理响应时发生错误: {str(e)}")
159
+ # current_message["content"] = f"抱歉,生成故事时遇到了问题。请稍后重试。"
160
+ # yield history
161
+
162
+ # scene_input.submit(
163
+ # user_input,
164
+ # [scene_input, chatbot],
165
+ # [scene_input, chatbot]
166
+ # ).then(
167
+ # bot_response,
168
+ # [chatbot, style_select, temperature, max_tokens, top_p],
169
+ # chatbot
170
+ # )
171
+
172
+ # submit_btn.click(
173
+ # user_input,
174
+ # [scene_input, chatbot],
175
+ # [scene_input, chatbot]
176
+ # ).then(
177
+ # bot_response,
178
+ # [chatbot, style_select, temperature, max_tokens, top_p],
179
+ # chatbot
180
+ # )
181
+
182
+ # def clear_chat():
183
+ # return [], ""
184
+
185
+ # clear_btn.click(
186
+ # clear_chat,
187
+ # None,
188
+ # [chatbot, status_msg],
189
+ # )
190
+
191
+ # return demo
192
+
193
+ # if __name__ == "__main__":
194
+ # demo = create_demo()
195
+ # demo.queue().launch(
196
+ # server_name="0.0.0.0",
197
+ # server_port=7860,
198
+ # share=False
199
+ # )
200
+
201
+
202
+
203
  import gradio as gr
204
  from huggingface_hub import InferenceClient
205
  import time
 
215
  logging.basicConfig(level=logging.INFO)
216
  logger = logging.getLogger(__name__)
217
 
218
+
219
+ STORY_THEMES = [
220
+ "冒险",
221
+ "神秘",
222
+ "浪漫",
223
+ "历史",
224
+ "日常",
225
+ "童话"
226
+ ]
227
+
228
+ CHARACTER_TEMPLATES = {
229
+ "冒险家": "一个勇敢无畏的探险家,热爱冒险与挑战。",
230
+ "侦探": "一个敏锐细心的侦探,善于观察和推理。",
231
+ "艺术家": "一个富有创造力的艺术家,对美有独特的见解。",
232
+ "科学家": "一个求知若渴的科学家,致力于探索未知。",
233
+ "普通人": "一个平凡但内心丰富的普通人。"
234
+ }
235
+
236
+
237
+
238
  # 初始化故事生成器的系统提示
239
+
240
+ STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你的任务是根据用户提供的设定和实时输入,生成连贯且引人入胜的故事。
241
+
242
+ 关键要求:
243
+ 1. 故事必须具有连续性,每次回应都要基于之前的所有情节发展
244
+ 2. 认真分析对话历史,保持人物性格、情节走向的一致性
245
+ 3. 当用户补充新的细节或提供新的发展方向时,自然地将其整合到现有故事中
246
+ 4. 注意因果关系,确保每个情节的发生都有合理的铺垫和解释
247
+ 5. 通过环境描写、人物对话等手法,让故事更加生动
248
+ 6. 在故事发展的关键节点,可以给出一些暗示,引导用户参与情节推进
249
+
250
+ 你不应该:
251
+ 1. 重新开始新的故事
252
+ 2. 忽视之前提到的重要情节或细节
253
+ 3. 生成与已建立设定相矛盾的内容
254
+ 4. 突兀地引入未经铺垫的重大转折
255
+
256
+ 请记住:你正在创作一个持续发展的故事,而不是独立的片段。"""
257
+
258
 
259
  STORY_STYLES = [
260
  "奇幻",
 
280
  def generate_story(
281
  scene: str,
282
  style: str,
283
+ theme: str,
284
+ character_desc: str,
285
+ history: list = None,
286
  temperature: float = 0.7,
287
  max_tokens: int = 512,
288
  top_p: float = 0.95,
289
  ) -> Generator[str, None, None]:
290
+ """
291
+ 生成连续性的故事情节
292
+ """
293
  if history is None:
294
  history = []
295
 
296
+ # 构建上下文摘要
297
+ context_summary = ""
298
+ story_content = []
299
+
300
+ # 提取之前的故事内容
301
+ for msg in history:
302
+ if msg["role"] == "assistant":
303
+ story_content.append(msg["content"])
304
+
305
+ if story_content:
306
+ context_summary = "\n".join([
307
+ "已经发生的故事情节:",
308
+ "---",
309
+ "\n".join(story_content),
310
+ "---"
311
+ ])
312
+
313
+ # 根据是否有历史记录使用不同的提示模板
314
+ if not history:
315
+ # 首次生成,使用完整设定
316
+ prompt = f"""
317
+ 请基于以下设定开始讲述一个故事:
318
+
319
+ 风格:{style}
320
+ 主题:{theme}
321
+ 角色:{character_desc}
322
+ 初始场景:{scene}
323
+
324
+ 请从这个场景开始,展开故事的开端。注意为后续发展留下铺垫。
325
+ """
326
+ else:
327
+ # 后续生成,侧重情节延续
328
+ prompt = f"""
329
+ {context_summary}
330
+
331
+ 故事设定提醒:
332
+ - 风格:{style}
333
+ - 主题:{theme}
334
+ - 主要角色:{character_desc}
335
+
336
+ 用户新的输入:{scene}
337
+
338
+ 请基于以上已发生的情节和用户新的输入���自然地继续发展故事。注意:
339
+ 1. 新的发展必须与之前的情节保持连贯
340
+ 2. 合理化用户提供的新元素
341
+ 3. 注意人物性格的一致性
342
+ 4. 为后续发展留下可能性
343
+
344
+ 继续讲述:
345
+ """
346
 
347
  messages = [
348
  {"role": "system", "content": STORY_SYSTEM_PROMPT},
349
+ {"role": "user", "content": prompt}
350
  ]
351
+
352
+ try:
353
+ client = create_client()
354
+ response = ""
355
+
356
+ for message in client.chat_completion(
357
+ messages,
358
+ max_tokens=max_tokens,
359
+ stream=True,
360
+ temperature=temperature,
361
+ top_p=top_p,
362
+ ):
363
+ if hasattr(message.choices[0].delta, 'content'):
364
+ token = message.choices[0].delta.content
365
+ if token is not None:
366
+ response += token
367
+ yield response
368
+ except Exception as e:
369
+ logger.error(f"生成故事时发生错误: {str(e)}")
370
+ yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
371
+
372
+
373
+
374
+ def summarize_story_context(history: list) -> str:
375
+ """
376
+ 总结当前的故事上下文,用于辅助生成
377
+ """
378
+ if not history:
379
+ return ""
380
 
381
+ summary_parts = []
382
+ key_elements = {
383
+ "characters": set(), # 出场人物
384
+ "locations": set(), # 场景地点
385
+ "events": [], # 关键事件
386
+ "objects": set() # 重要物品
387
+ }
388
 
389
+ for msg in history:
390
+ content = msg.get("content", "")
391
+ # TODO: 这里可以添加更复杂的NLP处理来提取关键信息
392
+ # 当前使用简单的文本累加
393
+ if content:
394
+ summary_parts.append(content)
395
 
396
+ return "\n".join(summary_parts)
397
+
398
+
399
+
400
+ # 创建故事生成器界面
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
  def create_demo():
403
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
404
+ gr.Markdown(
405
+ """
406
+ # 🎭 互动式故事生成器
407
+ 让AI为您创造独特的故事体验。您可以选择故事风格、主题,添加角色设定,
408
+ 然后描述一个场景开始您的故事。与AI互动来继续发展故事情节!
409
+ """
 
 
 
 
 
 
410
  )
411
 
412
+ with gr.Tabs():
413
+ # 故事创作标签页
414
+ with gr.Tab("✍️ 故事创作"):
415
+ with gr.Row(equal_height=True):
416
+ # 左侧控制面板
417
+ with gr.Column(scale=1):
418
+ with gr.Group():
419
+ style_select = gr.Dropdown(
420
+ choices=STORY_STYLES,
421
+ value="奇幻",
422
+ label="选择故事风格",
423
+ info="选择一个整体风格来定义故事的基调"
424
+ )
425
+
426
+ theme_select = gr.Dropdown(
427
+ choices=STORY_THEMES,
428
+ value="冒险",
429
+ label="选择故事主题",
430
+ info="选择故事要重点表现的主题元素"
431
+ )
432
+
433
+ with gr.Group():
434
+ gr.Markdown("### 👤 角色设定")
435
+ character_select = gr.Dropdown(
436
+ choices=list(CHARACTER_TEMPLATES.keys()),
437
+ value="冒险家",
438
+ label="选择角色模板",
439
+ info="选择一个预设的角色类型,或自定义描述"
440
+ )
441
+
442
+ character_desc = gr.Textbox(
443
+ lines=3,
444
+ value=CHARACTER_TEMPLATES["冒险家"],
445
+ label="角色描述",
446
+ info="描述角色的性格、背景、特点等"
447
+ )
448
+
449
+ with gr.Group():
450
+ scene_input = gr.Textbox(
451
+ lines=3,
452
+ placeholder="在这里描述故事发生的场景、环境、时间等...",
453
+ label="场景描述",
454
+ info="详细的场景描述会让故事更加生动"
455
+ )
456
+
457
+ with gr.Row():
458
+ submit_btn = gr.Button("✨ 开始故事", variant="primary", scale=2)
459
+ clear_btn = gr.Button("🗑️ 清除对话", scale=1)
460
+ save_btn = gr.Button("💾 保存故事", scale=1)
461
+
462
+ # 右侧对话区域
463
+ with gr.Column(scale=2):
464
+ chatbot = gr.Chatbot(
465
+ label="故事对话",
466
+ height=600,
467
+ show_label=True
468
+ )
469
+
470
+ status_msg = gr.Markdown("")
471
+
472
+ # 设置标签页
473
+ with gr.Tab("⚙️ 高级设置"):
474
+ with gr.Group():
475
+ with gr.Row():
476
+ with gr.Column():
477
+ temperature = gr.Slider(
478
+ minimum=0.1,
479
+ maximum=2.0,
480
+ value=0.7,
481
+ step=0.1,
482
+ label="创意度(Temperature)",
483
+ info="较高的值会让故事更有创意但可能不够连贯"
484
+ )
485
+
486
+ max_tokens = gr.Slider(
487
+ minimum=64,
488
+ maximum=1024,
489
+ value=512,
490
+ step=64,
491
+ label="最大生成长度",
492
+ info="控制每次生成的文本长度"
493
+ )
494
+
495
+ top_p = gr.Slider(
496
+ minimum=0.1,
497
+ maximum=1.0,
498
+ value=0.95,
499
+ step=0.05,
500
+ label="采样范围(Top-p)",
501
+ info="控制词语选择的多样性"
502
+ )
503
 
504
+ # 帮助信息
505
+ with gr.Accordion("📖 使用帮助", open=False):
506
+ gr.Markdown(
507
+ """
508
+ ## 如何使用故事生成器
509
+ 1. 选择故事风格和主题来确定故事的整体基调
510
+ 2. 选择预设角色模板或自定义角色描述
511
+ 3. 描述故事发生的场景和环境
512
+ 4. 点击"开始故事"生成开篇
513
+ 5. 继续输入内容与AI交互,推进故事发展
514
+
515
+ ## 小提示
516
+ - 详细的场景和角色描述会让生成的故事更加丰富
517
+ - 可以使用"保存故事"功能保存精彩的故事情节
518
+ - 在设置中调整参数可以影响故事的创意程度和连贯性
519
+ - 遇到不满意的情节可以使用"清除对话"重新开始
520
+
521
+ ## 参数说明
522
+ - 创意度: 控制故事的创意程度,值越高创意性越强
523
+ - 采样范围: 控制用词的丰富程��,值越高用词越多样
524
+ - 最大长度: 控制每次生成的文本长度
525
+ """
526
+ )
527
+
528
+ # 更新角色描述
529
+ def update_character_desc(template):
530
+ return CHARACTER_TEMPLATES[template]
531
+
532
+ character_select.change(
533
+ update_character_desc,
534
+ character_select,
535
+ character_desc
536
  )
 
537
 
538
+ # 保存故事对话
539
+ save_btn.click(
540
+ save_story,
541
+ chatbot,
542
+ status_msg,
543
+ )
544
 
545
+ # 用户输入处理
546
  def user_input(user_message, history):
547
+ """
548
+ 处理用户输入
549
+ Args:
550
+ user_message: 用户输入的消息
551
+ history: 聊天历史记录 [(user_msg, bot_msg), ...]
552
+ """
553
  if history is None:
554
  history = []
555
+ history.append([user_message, None]) # 添加用户消息,bot消息暂时为None
556
  return "", history
557
+
558
+ # AI响应处理
559
+ def bot_response(history, style, theme, character_desc, temperature, max_tokens, top_p):
560
+ """
561
+ 生成AI响应
562
+ Args:
563
+ history: 聊天历史记录 [(user_msg, bot_msg), ...]
564
+ style: 故事风格
565
+ theme: 故事主题
566
+ character_desc: 角色描述
567
+ temperature: 生成参数
568
+ max_tokens: 生成参数
569
+ top_p: 生成参数
570
+ """
571
  try:
572
+
573
+ # 获取用户的最后一条消息
574
+ user_message = history[-1][0]
575
 
576
+ # 转换历史记录格式以传递给generate_story
577
+ message_history = []
578
+ for user_msg, bot_msg in history[:-1]: # 不包括最后一条
579
+ if user_msg:
580
+ message_history.append({"role": "user", "content": user_msg})
581
+ if bot_msg:
582
+ message_history.append({"role": "assistant", "content": bot_msg})
583
+
584
+ # 开始生成故事
585
+ current_response = ""
586
  for text in generate_story(
587
+ user_message,
588
  style,
589
+ theme,
590
+ character_desc,
591
+ message_history,
592
  temperature,
593
  max_tokens,
594
  top_p
595
  ):
596
+ current_response = text
597
+ history[-1][1] = current_response # 更新最后一条消息的bot回复
598
  yield history
599
+
600
  except Exception as e:
601
  logger.error(f"处理响应时发生错误: {str(e)}")
602
+ error_msg = f"抱歉,生成故事时遇到了问题。请稍后重试。"
603
+ history[-1][1] = error_msg
604
  yield history
605
+
606
 
607
+ # 清除对话
608
+ def clear_chat():
609
+ return [], ""
610
+
611
+ # 绑定事件
612
  scene_input.submit(
613
  user_input,
614
  [scene_input, chatbot],
615
  [scene_input, chatbot]
616
  ).then(
617
  bot_response,
618
+ [chatbot, style_select, theme_select, character_desc, temperature, max_tokens, top_p],
619
  chatbot
620
  )
621
 
 
625
  [scene_input, chatbot]
626
  ).then(
627
  bot_response,
628
+ [chatbot, style_select, theme_select, character_desc, temperature, max_tokens, top_p],
629
  chatbot
630
  )
631
 
 
 
 
632
  clear_btn.click(
633
  clear_chat,
634
  None,
 
637
 
638
  return demo
639
 
640
+
641
+ def save_story(chatbot):
642
+ """保存故事对话记录"""
643
+ if not chatbot:
644
+ return "故事为空,无法保存"
645
+
646
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
647
+ filename = f"stories/story_{timestamp}.txt"
648
+
649
+ os.makedirs("stories", exist_ok=True)
650
+
651
+ try:
652
+ with open(filename, "w", encoding="utf-8") as f:
653
+ for user_msg, bot_msg in chatbot:
654
+ if user_msg:
655
+ f.write(f"用户: {user_msg}\n")
656
+ if bot_msg:
657
+ f.write(f"AI: {bot_msg}\n\n")
658
+ return f"故事已保存至 {filename}"
659
+ except Exception as e:
660
+ return f"保存失败: {str(e)}"
661
+
662
+
663
+
664
  if __name__ == "__main__":
665
  demo = create_demo()
666
  demo.queue().launch(
 
668
  server_port=7860,
669
  share=False
670
  )
671
+