Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from huggingface_hub import InferenceClient
|
3 |
import time
|
@@ -13,9 +215,46 @@ load_dotenv()
|
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
# 初始化故事生成器的系统提示
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
STORY_STYLES = [
|
21 |
"奇幻",
|
@@ -41,131 +280,342 @@ def create_client() -> InferenceClient:
|
|
41 |
def generate_story(
|
42 |
scene: str,
|
43 |
style: str,
|
44 |
-
|
|
|
|
|
45 |
temperature: float = 0.7,
|
46 |
max_tokens: int = 512,
|
47 |
top_p: float = 0.95,
|
48 |
) -> Generator[str, None, None]:
|
|
|
|
|
|
|
49 |
if history is None:
|
50 |
history = []
|
51 |
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
messages = [
|
55 |
{"role": "system", "content": STORY_SYSTEM_PROMPT},
|
56 |
-
{"role": "user", "content":
|
57 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
max_tokens=max_tokens,
|
71 |
-
stream=True,
|
72 |
-
temperature=temperature,
|
73 |
-
top_p=top_p,
|
74 |
-
):
|
75 |
-
if hasattr(message.choices[0].delta, 'content'):
|
76 |
-
token = message.choices[0].delta.content
|
77 |
-
if token is not None:
|
78 |
-
response += token
|
79 |
-
yield response
|
80 |
-
break
|
81 |
-
except Exception as e:
|
82 |
-
retries += 1
|
83 |
-
logger.error(f"生成故事时发生错误 (尝试 {retries}/{MAX_RETRIES}): {str(e)}")
|
84 |
-
if retries < MAX_RETRIES:
|
85 |
-
time.sleep(RETRY_DELAY)
|
86 |
-
else:
|
87 |
-
yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
|
88 |
|
89 |
def create_demo():
|
90 |
-
with gr.Blocks() as demo:
|
91 |
-
gr.Markdown(
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
label="选择故事风格"
|
98 |
-
)
|
99 |
-
scene_input = gr.Textbox(
|
100 |
-
lines=3,
|
101 |
-
placeholder="请输入一个场景或角色描述...",
|
102 |
-
label="场景描述"
|
103 |
)
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
)
|
131 |
-
status_msg = gr.Markdown("")
|
132 |
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
|
|
|
136 |
def user_input(user_message, history):
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
if history is None:
|
138 |
history = []
|
139 |
-
history.append(
|
140 |
return "", history
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
try:
|
144 |
-
|
145 |
-
|
|
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
for text in generate_story(
|
148 |
-
|
149 |
style,
|
150 |
-
|
|
|
|
|
151 |
temperature,
|
152 |
max_tokens,
|
153 |
top_p
|
154 |
):
|
155 |
-
|
|
|
156 |
yield history
|
|
|
157 |
except Exception as e:
|
158 |
logger.error(f"处理响应时发生错误: {str(e)}")
|
159 |
-
|
|
|
160 |
yield history
|
|
|
161 |
|
|
|
|
|
|
|
|
|
|
|
162 |
scene_input.submit(
|
163 |
user_input,
|
164 |
[scene_input, chatbot],
|
165 |
[scene_input, chatbot]
|
166 |
).then(
|
167 |
bot_response,
|
168 |
-
[chatbot, style_select, temperature, max_tokens, top_p],
|
169 |
chatbot
|
170 |
)
|
171 |
|
@@ -175,13 +625,10 @@ def create_demo():
|
|
175 |
[scene_input, chatbot]
|
176 |
).then(
|
177 |
bot_response,
|
178 |
-
[chatbot, style_select, temperature, max_tokens, top_p],
|
179 |
chatbot
|
180 |
)
|
181 |
|
182 |
-
def clear_chat():
|
183 |
-
return [], ""
|
184 |
-
|
185 |
clear_btn.click(
|
186 |
clear_chat,
|
187 |
None,
|
@@ -190,6 +637,30 @@ def create_demo():
|
|
190 |
|
191 |
return demo
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
if __name__ == "__main__":
|
194 |
demo = create_demo()
|
195 |
demo.queue().launch(
|
@@ -197,3 +668,4 @@ if __name__ == "__main__":
|
|
197 |
server_port=7860,
|
198 |
share=False
|
199 |
)
|
|
|
|
1 |
+
# import gradio as gr
|
2 |
+
# from huggingface_hub import InferenceClient
|
3 |
+
# import time
|
4 |
+
# from typing import Optional, Generator
|
5 |
+
# import logging
|
6 |
+
# import os
|
7 |
+
# from dotenv import load_dotenv
|
8 |
+
|
9 |
+
# # 加载环境变量
|
10 |
+
# load_dotenv()
|
11 |
+
|
12 |
+
# # 设置日志
|
13 |
+
# logging.basicConfig(level=logging.INFO)
|
14 |
+
# logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
# # 初始化故事生成器的系统提示
|
17 |
+
# STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。
|
18 |
+
# 请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"""
|
19 |
+
|
20 |
+
# STORY_STYLES = [
|
21 |
+
# "奇幻",
|
22 |
+
# "科幻",
|
23 |
+
# "悬疑",
|
24 |
+
# "冒险",
|
25 |
+
# "爱情",
|
26 |
+
# "恐怖"
|
27 |
+
# ]
|
28 |
+
|
29 |
+
# MAX_RETRIES = 3
|
30 |
+
# RETRY_DELAY = 2
|
31 |
+
|
32 |
+
# def create_client() -> InferenceClient:
|
33 |
+
# hf_token = os.getenv('HF_TOKEN')
|
34 |
+
# if not hf_token:
|
35 |
+
# raise ValueError("HF_TOKEN 环境变量未设置")
|
36 |
+
# return InferenceClient(
|
37 |
+
# "HuggingFaceH4/zephyr-7b-beta",
|
38 |
+
# token=hf_token
|
39 |
+
# )
|
40 |
+
|
41 |
+
# def generate_story(
|
42 |
+
# scene: str,
|
43 |
+
# style: str,
|
44 |
+
# history: Optional[list[dict]] = None,
|
45 |
+
# temperature: float = 0.7,
|
46 |
+
# max_tokens: int = 512,
|
47 |
+
# top_p: float = 0.95,
|
48 |
+
# ) -> Generator[str, None, None]:
|
49 |
+
# if history is None:
|
50 |
+
# history = []
|
51 |
+
|
52 |
+
# style_prompt = f"请以{style}风格续写以下故事:"
|
53 |
+
|
54 |
+
# messages = [
|
55 |
+
# {"role": "system", "content": STORY_SYSTEM_PROMPT},
|
56 |
+
# {"role": "user", "content": f"{style_prompt}\n{scene}"}
|
57 |
+
# ]
|
58 |
+
|
59 |
+
# for msg in history:
|
60 |
+
# messages.append(msg)
|
61 |
+
|
62 |
+
# response = ""
|
63 |
+
# retries = 0
|
64 |
+
|
65 |
+
# while retries < MAX_RETRIES:
|
66 |
+
# try:
|
67 |
+
# client = create_client()
|
68 |
+
# for message in client.chat_completion(
|
69 |
+
# messages,
|
70 |
+
# max_tokens=max_tokens,
|
71 |
+
# stream=True,
|
72 |
+
# temperature=temperature,
|
73 |
+
# top_p=top_p,
|
74 |
+
# ):
|
75 |
+
# if hasattr(message.choices[0].delta, 'content'):
|
76 |
+
# token = message.choices[0].delta.content
|
77 |
+
# if token is not None:
|
78 |
+
# response += token
|
79 |
+
# yield response
|
80 |
+
# break
|
81 |
+
# except Exception as e:
|
82 |
+
# retries += 1
|
83 |
+
# logger.error(f"生成故事时发生错误 (尝试 {retries}/{MAX_RETRIES}): {str(e)}")
|
84 |
+
# if retries < MAX_RETRIES:
|
85 |
+
# time.sleep(RETRY_DELAY)
|
86 |
+
# else:
|
87 |
+
# yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
|
88 |
+
|
89 |
+
# def create_demo():
|
90 |
+
# with gr.Blocks() as demo:
|
91 |
+
# gr.Markdown("# 互动式故事生成器")
|
92 |
+
# gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")
|
93 |
+
|
94 |
+
# style_select = gr.Dropdown(
|
95 |
+
# choices=STORY_STYLES,
|
96 |
+
# value="奇幻",
|
97 |
+
# label="选择故事风格"
|
98 |
+
# )
|
99 |
+
# scene_input = gr.Textbox(
|
100 |
+
# lines=3,
|
101 |
+
# placeholder="请输入一个场景或角色描述...",
|
102 |
+
# label="场景描述"
|
103 |
+
# )
|
104 |
+
|
105 |
+
# temperature = gr.Slider(
|
106 |
+
# minimum=0.1,
|
107 |
+
# maximum=2.0,
|
108 |
+
# value=0.7,
|
109 |
+
# step=0.1,
|
110 |
+
# label="创意度(Temperature)"
|
111 |
+
# )
|
112 |
+
# max_tokens = gr.Slider(
|
113 |
+
# minimum=64,
|
114 |
+
# maximum=1024,
|
115 |
+
# value=512,
|
116 |
+
# step=64,
|
117 |
+
# label="最大生成长度"
|
118 |
+
# )
|
119 |
+
# top_p = gr.Slider(
|
120 |
+
# minimum=0.1,
|
121 |
+
# maximum=1.0,
|
122 |
+
# value=0.95,
|
123 |
+
# step=0.05,
|
124 |
+
# label="采样范围(Top-p)"
|
125 |
+
# )
|
126 |
+
|
127 |
+
# chatbot = gr.Chatbot(
|
128 |
+
# label="故事对话",
|
129 |
+
# type="messages"
|
130 |
+
# )
|
131 |
+
# status_msg = gr.Markdown("")
|
132 |
+
|
133 |
+
# submit_btn = gr.Button("生成故事")
|
134 |
+
# clear_btn = gr.Button("清除对话")
|
135 |
+
|
136 |
+
# def user_input(user_message, history):
|
137 |
+
# if history is None:
|
138 |
+
# history = []
|
139 |
+
# history.append({"role": "user", "content": user_message})
|
140 |
+
# return "", history
|
141 |
+
|
142 |
+
# def bot_response(history, style, temperature, max_tokens, top_p):
|
143 |
+
# try:
|
144 |
+
# current_message = {"role": "assistant", "content": ""}
|
145 |
+
# history.append(current_message)
|
146 |
+
|
147 |
+
# for text in generate_story(
|
148 |
+
# history[-2]["content"],
|
149 |
+
# style,
|
150 |
+
# history[:-2],
|
151 |
+
# temperature,
|
152 |
+
# max_tokens,
|
153 |
+
# top_p
|
154 |
+
# ):
|
155 |
+
# current_message["content"] = text
|
156 |
+
# yield history
|
157 |
+
# except Exception as e:
|
158 |
+
# logger.error(f"处理响应时发生错误: {str(e)}")
|
159 |
+
# current_message["content"] = f"抱歉,生成故事时遇到了问题。请稍后重试。"
|
160 |
+
# yield history
|
161 |
+
|
162 |
+
# scene_input.submit(
|
163 |
+
# user_input,
|
164 |
+
# [scene_input, chatbot],
|
165 |
+
# [scene_input, chatbot]
|
166 |
+
# ).then(
|
167 |
+
# bot_response,
|
168 |
+
# [chatbot, style_select, temperature, max_tokens, top_p],
|
169 |
+
# chatbot
|
170 |
+
# )
|
171 |
+
|
172 |
+
# submit_btn.click(
|
173 |
+
# user_input,
|
174 |
+
# [scene_input, chatbot],
|
175 |
+
# [scene_input, chatbot]
|
176 |
+
# ).then(
|
177 |
+
# bot_response,
|
178 |
+
# [chatbot, style_select, temperature, max_tokens, top_p],
|
179 |
+
# chatbot
|
180 |
+
# )
|
181 |
+
|
182 |
+
# def clear_chat():
|
183 |
+
# return [], ""
|
184 |
+
|
185 |
+
# clear_btn.click(
|
186 |
+
# clear_chat,
|
187 |
+
# None,
|
188 |
+
# [chatbot, status_msg],
|
189 |
+
# )
|
190 |
+
|
191 |
+
# return demo
|
192 |
+
|
193 |
+
# if __name__ == "__main__":
|
194 |
+
# demo = create_demo()
|
195 |
+
# demo.queue().launch(
|
196 |
+
# server_name="0.0.0.0",
|
197 |
+
# server_port=7860,
|
198 |
+
# share=False
|
199 |
+
# )
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
import gradio as gr
|
204 |
from huggingface_hub import InferenceClient
|
205 |
import time
|
|
|
215 |
logging.basicConfig(level=logging.INFO)
|
216 |
logger = logging.getLogger(__name__)
|
217 |
|
218 |
+
|
219 |
+
STORY_THEMES = [
|
220 |
+
"冒险",
|
221 |
+
"神秘",
|
222 |
+
"浪漫",
|
223 |
+
"历史",
|
224 |
+
"日常",
|
225 |
+
"童话"
|
226 |
+
]
|
227 |
+
|
228 |
+
CHARACTER_TEMPLATES = {
|
229 |
+
"冒险家": "一个勇敢无畏的探险家,热爱冒险与挑战。",
|
230 |
+
"侦探": "一个敏锐细心的侦探,善于观察和推理。",
|
231 |
+
"艺术家": "一个富有创造力的艺术家,对美有独特的见解。",
|
232 |
+
"科学家": "一个求知若渴的科学家,致力于探索未知。",
|
233 |
+
"普通人": "一个平凡但内心丰富的普通人。"
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
|
238 |
# 初始化故事生成器的系统提示
|
239 |
+
|
240 |
+
STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你的任务是根据用户提供的设定和实时输入,生成连贯且引人入胜的故事。
|
241 |
+
|
242 |
+
关键要求:
|
243 |
+
1. 故事必须具有连续性,每次回应都要基于之前的所有情节发展
|
244 |
+
2. 认真分析对话历史,保持人物性格、情节走向的一致性
|
245 |
+
3. 当用户补充新的细节或提供新的发展方向时,自然地将其整合到现有故事中
|
246 |
+
4. 注意因果关系,确保每个情节的发生都有合理的铺垫和解释
|
247 |
+
5. 通过环境描写、人物对话等手法,让故事更加生动
|
248 |
+
6. 在故事发展的关键节点,可以给出一些暗示,引导用户参与情节推进
|
249 |
+
|
250 |
+
你不应该:
|
251 |
+
1. 重新开始新的故事
|
252 |
+
2. 忽视之前提到的重要情节或细节
|
253 |
+
3. 生成与已建立设定相矛盾的内容
|
254 |
+
4. 突兀地引入未经铺垫的重大转折
|
255 |
+
|
256 |
+
请记住:你正在创作一个持续发展的故事,而不是独立的片段。"""
|
257 |
+
|
258 |
|
259 |
STORY_STYLES = [
|
260 |
"奇幻",
|
|
|
280 |
def generate_story(
|
281 |
scene: str,
|
282 |
style: str,
|
283 |
+
theme: str,
|
284 |
+
character_desc: str,
|
285 |
+
history: list = None,
|
286 |
temperature: float = 0.7,
|
287 |
max_tokens: int = 512,
|
288 |
top_p: float = 0.95,
|
289 |
) -> Generator[str, None, None]:
|
290 |
+
"""
|
291 |
+
生成连续性的故事情节
|
292 |
+
"""
|
293 |
if history is None:
|
294 |
history = []
|
295 |
|
296 |
+
# 构建上下文摘要
|
297 |
+
context_summary = ""
|
298 |
+
story_content = []
|
299 |
+
|
300 |
+
# 提取之前的故事内容
|
301 |
+
for msg in history:
|
302 |
+
if msg["role"] == "assistant":
|
303 |
+
story_content.append(msg["content"])
|
304 |
+
|
305 |
+
if story_content:
|
306 |
+
context_summary = "\n".join([
|
307 |
+
"已经发生的故事情节:",
|
308 |
+
"---",
|
309 |
+
"\n".join(story_content),
|
310 |
+
"---"
|
311 |
+
])
|
312 |
+
|
313 |
+
# 根据是否有历史记录使用不同的提示模板
|
314 |
+
if not history:
|
315 |
+
# 首次生成,使用完整设定
|
316 |
+
prompt = f"""
|
317 |
+
请基于以下设定开始讲述一个故事:
|
318 |
+
|
319 |
+
风格:{style}
|
320 |
+
主题:{theme}
|
321 |
+
角色:{character_desc}
|
322 |
+
初始场景:{scene}
|
323 |
+
|
324 |
+
请从这个场景开始,展开故事的开端。注意为后续发展留下铺垫。
|
325 |
+
"""
|
326 |
+
else:
|
327 |
+
# 后续生成,侧重情节延续
|
328 |
+
prompt = f"""
|
329 |
+
{context_summary}
|
330 |
+
|
331 |
+
故事设定提醒:
|
332 |
+
- 风格:{style}
|
333 |
+
- 主题:{theme}
|
334 |
+
- 主要角色:{character_desc}
|
335 |
+
|
336 |
+
用户新的输入:{scene}
|
337 |
+
|
338 |
+
请基于以上已发生的情节和用户新的输入���自然地继续发展故事。注意:
|
339 |
+
1. 新的发展必须与之前的情节保持连贯
|
340 |
+
2. 合理化用户提供的新元素
|
341 |
+
3. 注意人物性格的一致性
|
342 |
+
4. 为后续发展留下可能性
|
343 |
+
|
344 |
+
继续讲述:
|
345 |
+
"""
|
346 |
|
347 |
messages = [
|
348 |
{"role": "system", "content": STORY_SYSTEM_PROMPT},
|
349 |
+
{"role": "user", "content": prompt}
|
350 |
]
|
351 |
+
|
352 |
+
try:
|
353 |
+
client = create_client()
|
354 |
+
response = ""
|
355 |
+
|
356 |
+
for message in client.chat_completion(
|
357 |
+
messages,
|
358 |
+
max_tokens=max_tokens,
|
359 |
+
stream=True,
|
360 |
+
temperature=temperature,
|
361 |
+
top_p=top_p,
|
362 |
+
):
|
363 |
+
if hasattr(message.choices[0].delta, 'content'):
|
364 |
+
token = message.choices[0].delta.content
|
365 |
+
if token is not None:
|
366 |
+
response += token
|
367 |
+
yield response
|
368 |
+
except Exception as e:
|
369 |
+
logger.error(f"生成故事时发生错误: {str(e)}")
|
370 |
+
yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
|
371 |
+
|
372 |
+
|
373 |
+
|
374 |
+
def summarize_story_context(history: list) -> str:
|
375 |
+
"""
|
376 |
+
总结当前的故事上下文,用于辅助生成
|
377 |
+
"""
|
378 |
+
if not history:
|
379 |
+
return ""
|
380 |
|
381 |
+
summary_parts = []
|
382 |
+
key_elements = {
|
383 |
+
"characters": set(), # 出场人物
|
384 |
+
"locations": set(), # 场景地点
|
385 |
+
"events": [], # 关键事件
|
386 |
+
"objects": set() # 重要物品
|
387 |
+
}
|
388 |
|
389 |
+
for msg in history:
|
390 |
+
content = msg.get("content", "")
|
391 |
+
# TODO: 这里可以添加更复杂的NLP处理来提取关键信息
|
392 |
+
# 当前使用简单的文本累加
|
393 |
+
if content:
|
394 |
+
summary_parts.append(content)
|
395 |
|
396 |
+
return "\n".join(summary_parts)
|
397 |
+
|
398 |
+
|
399 |
+
|
400 |
+
# 创建故事生成器界面
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
|
402 |
def create_demo():
|
403 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
404 |
+
gr.Markdown(
|
405 |
+
"""
|
406 |
+
# 🎭 互动式故事生成器
|
407 |
+
让AI为您创造独特的故事体验。您可以选择故事风格、主题,添加角色设定,
|
408 |
+
然后描述一个场景开始您的故事。与AI互动来继续发展故事情节!
|
409 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
)
|
411 |
|
412 |
+
with gr.Tabs():
|
413 |
+
# 故事创作标签页
|
414 |
+
with gr.Tab("✍️ 故事创作"):
|
415 |
+
with gr.Row(equal_height=True):
|
416 |
+
# 左侧控制面板
|
417 |
+
with gr.Column(scale=1):
|
418 |
+
with gr.Group():
|
419 |
+
style_select = gr.Dropdown(
|
420 |
+
choices=STORY_STYLES,
|
421 |
+
value="奇幻",
|
422 |
+
label="选择故事风格",
|
423 |
+
info="选择一个整体风格来定义故事的基调"
|
424 |
+
)
|
425 |
+
|
426 |
+
theme_select = gr.Dropdown(
|
427 |
+
choices=STORY_THEMES,
|
428 |
+
value="冒险",
|
429 |
+
label="选择故事主题",
|
430 |
+
info="选择故事要重点表现的主题元素"
|
431 |
+
)
|
432 |
+
|
433 |
+
with gr.Group():
|
434 |
+
gr.Markdown("### 👤 角色设定")
|
435 |
+
character_select = gr.Dropdown(
|
436 |
+
choices=list(CHARACTER_TEMPLATES.keys()),
|
437 |
+
value="冒险家",
|
438 |
+
label="选择角色模板",
|
439 |
+
info="选择一个预设的角色类型,或自定义描述"
|
440 |
+
)
|
441 |
+
|
442 |
+
character_desc = gr.Textbox(
|
443 |
+
lines=3,
|
444 |
+
value=CHARACTER_TEMPLATES["冒险家"],
|
445 |
+
label="角色描述",
|
446 |
+
info="描述角色的性格、背景、特点等"
|
447 |
+
)
|
448 |
+
|
449 |
+
with gr.Group():
|
450 |
+
scene_input = gr.Textbox(
|
451 |
+
lines=3,
|
452 |
+
placeholder="在这里描述故事发生的场景、环境、时间等...",
|
453 |
+
label="场景描述",
|
454 |
+
info="详细的场景描述会让故事更加生动"
|
455 |
+
)
|
456 |
+
|
457 |
+
with gr.Row():
|
458 |
+
submit_btn = gr.Button("✨ 开始故事", variant="primary", scale=2)
|
459 |
+
clear_btn = gr.Button("🗑️ 清除对话", scale=1)
|
460 |
+
save_btn = gr.Button("💾 保存故事", scale=1)
|
461 |
+
|
462 |
+
# 右侧对话区域
|
463 |
+
with gr.Column(scale=2):
|
464 |
+
chatbot = gr.Chatbot(
|
465 |
+
label="故事对话",
|
466 |
+
height=600,
|
467 |
+
show_label=True
|
468 |
+
)
|
469 |
+
|
470 |
+
status_msg = gr.Markdown("")
|
471 |
+
|
472 |
+
# 设置标签页
|
473 |
+
with gr.Tab("⚙️ 高级设置"):
|
474 |
+
with gr.Group():
|
475 |
+
with gr.Row():
|
476 |
+
with gr.Column():
|
477 |
+
temperature = gr.Slider(
|
478 |
+
minimum=0.1,
|
479 |
+
maximum=2.0,
|
480 |
+
value=0.7,
|
481 |
+
step=0.1,
|
482 |
+
label="创意度(Temperature)",
|
483 |
+
info="较高的值会让故事更有创意但可能不够连贯"
|
484 |
+
)
|
485 |
+
|
486 |
+
max_tokens = gr.Slider(
|
487 |
+
minimum=64,
|
488 |
+
maximum=1024,
|
489 |
+
value=512,
|
490 |
+
step=64,
|
491 |
+
label="最大生成长度",
|
492 |
+
info="控制每次生成的文本长度"
|
493 |
+
)
|
494 |
+
|
495 |
+
top_p = gr.Slider(
|
496 |
+
minimum=0.1,
|
497 |
+
maximum=1.0,
|
498 |
+
value=0.95,
|
499 |
+
step=0.05,
|
500 |
+
label="采样范围(Top-p)",
|
501 |
+
info="控制词语选择的多样性"
|
502 |
+
)
|
503 |
|
504 |
+
# 帮助信息
|
505 |
+
with gr.Accordion("📖 使用帮助", open=False):
|
506 |
+
gr.Markdown(
|
507 |
+
"""
|
508 |
+
## 如何使用故事生成器
|
509 |
+
1. 选择故事风格和主题来确定故事的整体基调
|
510 |
+
2. 选择预设角色模板或自定义角色描述
|
511 |
+
3. 描述故事发生的场景和环境
|
512 |
+
4. 点击"开始故事"生成开篇
|
513 |
+
5. 继续输入内容与AI交互,推进故事发展
|
514 |
+
|
515 |
+
## 小提示
|
516 |
+
- 详细的场景和角色描述会让生成的故事更加丰富
|
517 |
+
- 可以使用"保存故事"功能保存精彩的故事情节
|
518 |
+
- 在设置中调整参数可以影响故事的创意程度和连贯性
|
519 |
+
- 遇到不满意的情节可以使用"清除对话"重新开始
|
520 |
+
|
521 |
+
## 参数说明
|
522 |
+
- 创意度: 控制故事的创意程度,值越高创意性越强
|
523 |
+
- 采样范围: 控制用词的丰富程��,值越高用词越多样
|
524 |
+
- 最大长度: 控制每次生成的文本长度
|
525 |
+
"""
|
526 |
+
)
|
527 |
+
|
528 |
+
# 更新角色描述
|
529 |
+
def update_character_desc(template):
|
530 |
+
return CHARACTER_TEMPLATES[template]
|
531 |
+
|
532 |
+
character_select.change(
|
533 |
+
update_character_desc,
|
534 |
+
character_select,
|
535 |
+
character_desc
|
536 |
)
|
|
|
537 |
|
538 |
+
# 保存故事对话
|
539 |
+
save_btn.click(
|
540 |
+
save_story,
|
541 |
+
chatbot,
|
542 |
+
status_msg,
|
543 |
+
)
|
544 |
|
545 |
+
# 用户输入处理
|
546 |
def user_input(user_message, history):
|
547 |
+
"""
|
548 |
+
处理用户输入
|
549 |
+
Args:
|
550 |
+
user_message: 用户输入的消息
|
551 |
+
history: 聊天历史记录 [(user_msg, bot_msg), ...]
|
552 |
+
"""
|
553 |
if history is None:
|
554 |
history = []
|
555 |
+
history.append([user_message, None]) # 添加用户消息,bot消息暂时为None
|
556 |
return "", history
|
557 |
+
|
558 |
+
# AI响应处理
|
559 |
+
def bot_response(history, style, theme, character_desc, temperature, max_tokens, top_p):
|
560 |
+
"""
|
561 |
+
生成AI响应
|
562 |
+
Args:
|
563 |
+
history: 聊天历史记录 [(user_msg, bot_msg), ...]
|
564 |
+
style: 故事风格
|
565 |
+
theme: 故事主题
|
566 |
+
character_desc: 角色描述
|
567 |
+
temperature: 生成参数
|
568 |
+
max_tokens: 生成参数
|
569 |
+
top_p: 生成参数
|
570 |
+
"""
|
571 |
try:
|
572 |
+
|
573 |
+
# 获取用户的最后一条消息
|
574 |
+
user_message = history[-1][0]
|
575 |
|
576 |
+
# 转换历史记录格式以传递给generate_story
|
577 |
+
message_history = []
|
578 |
+
for user_msg, bot_msg in history[:-1]: # 不包括最后一条
|
579 |
+
if user_msg:
|
580 |
+
message_history.append({"role": "user", "content": user_msg})
|
581 |
+
if bot_msg:
|
582 |
+
message_history.append({"role": "assistant", "content": bot_msg})
|
583 |
+
|
584 |
+
# 开始生成故事
|
585 |
+
current_response = ""
|
586 |
for text in generate_story(
|
587 |
+
user_message,
|
588 |
style,
|
589 |
+
theme,
|
590 |
+
character_desc,
|
591 |
+
message_history,
|
592 |
temperature,
|
593 |
max_tokens,
|
594 |
top_p
|
595 |
):
|
596 |
+
current_response = text
|
597 |
+
history[-1][1] = current_response # 更新最后一条消息的bot回复
|
598 |
yield history
|
599 |
+
|
600 |
except Exception as e:
|
601 |
logger.error(f"处理响应时发生错误: {str(e)}")
|
602 |
+
error_msg = f"抱歉,生成故事时遇到了问题。请稍后重试。"
|
603 |
+
history[-1][1] = error_msg
|
604 |
yield history
|
605 |
+
|
606 |
|
607 |
+
# 清除对话
|
608 |
+
def clear_chat():
|
609 |
+
return [], ""
|
610 |
+
|
611 |
+
# 绑定事件
|
612 |
scene_input.submit(
|
613 |
user_input,
|
614 |
[scene_input, chatbot],
|
615 |
[scene_input, chatbot]
|
616 |
).then(
|
617 |
bot_response,
|
618 |
+
[chatbot, style_select, theme_select, character_desc, temperature, max_tokens, top_p],
|
619 |
chatbot
|
620 |
)
|
621 |
|
|
|
625 |
[scene_input, chatbot]
|
626 |
).then(
|
627 |
bot_response,
|
628 |
+
[chatbot, style_select, theme_select, character_desc, temperature, max_tokens, top_p],
|
629 |
chatbot
|
630 |
)
|
631 |
|
|
|
|
|
|
|
632 |
clear_btn.click(
|
633 |
clear_chat,
|
634 |
None,
|
|
|
637 |
|
638 |
return demo
|
639 |
|
640 |
+
|
641 |
+
def save_story(chatbot):
|
642 |
+
"""保存故事对话记录"""
|
643 |
+
if not chatbot:
|
644 |
+
return "故事为空,无法保存"
|
645 |
+
|
646 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
647 |
+
filename = f"stories/story_{timestamp}.txt"
|
648 |
+
|
649 |
+
os.makedirs("stories", exist_ok=True)
|
650 |
+
|
651 |
+
try:
|
652 |
+
with open(filename, "w", encoding="utf-8") as f:
|
653 |
+
for user_msg, bot_msg in chatbot:
|
654 |
+
if user_msg:
|
655 |
+
f.write(f"用户: {user_msg}\n")
|
656 |
+
if bot_msg:
|
657 |
+
f.write(f"AI: {bot_msg}\n\n")
|
658 |
+
return f"故事已保存至 {filename}"
|
659 |
+
except Exception as e:
|
660 |
+
return f"保存失败: {str(e)}"
|
661 |
+
|
662 |
+
|
663 |
+
|
664 |
if __name__ == "__main__":
|
665 |
demo = create_demo()
|
666 |
demo.queue().launch(
|
|
|
668 |
server_port=7860,
|
669 |
share=False
|
670 |
)
|
671 |
+
|