ChanMeng666 committed on
Commit d74a6c7 · verified · 1 Parent(s): 6941b44

Update app.py

Files changed (1)
  1. app.py +182 -48
app.py CHANGED
@@ -1,64 +1,198 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content

-         response += token
-         yield response


- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )


  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
  from huggingface_hub import InferenceClient
+ import time
+ from typing import Optional, Generator
+ import logging
+ import os
+ from dotenv import load_dotenv

+ # Load environment variables (HF_TOKEN is read from the environment or a local .env file)
+ load_dotenv()

+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # System prompt for the story generator (Chinese): it asks the model to continue the
+ # user's scene as an engaging, coherent, creative story, one natural continuation per reply.
+ STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。
+ 请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"""

+ # Story styles: fantasy, sci-fi, mystery, adventure, romance, horror
+ STORY_STYLES = [
+     "奇幻",
+     "科幻",
+     "悬疑",
+     "冒险",
+     "爱情",
+     "恐怖"
+ ]

+ MAX_RETRIES = 3
+ RETRY_DELAY = 2

+ def create_client() -> InferenceClient:
+     hf_token = os.getenv('HF_TOKEN')
+     if not hf_token:
+         raise ValueError("HF_TOKEN 环境变量未设置")  # "HF_TOKEN environment variable is not set"
+     return InferenceClient(
+         "HuggingFaceH4/zephyr-7b-beta",
+         token=hf_token
+     )

+ def generate_story(
+     scene: str,
+     style: str,
+     history: Optional[list[dict]] = None,
+     temperature: float = 0.7,
+     max_tokens: int = 512,
+     top_p: float = 0.95,
+ ) -> Generator[str, None, None]:
+     if history is None:
+         history = []
+
+     style_prompt = f"请以{style}风格续写以下故事:"  # "Continue the following story in the chosen style:"
+
+     messages = [
+         {"role": "system", "content": STORY_SYSTEM_PROMPT},
+         {"role": "user", "content": f"{style_prompt}\n{scene}"}
+     ]
+
+     for msg in history:
+         messages.append(msg)
+
+     response = ""
+     retries = 0
+
+     while retries < MAX_RETRIES:
+         try:
+             client = create_client()
+             for message in client.chat_completion(
+                 messages,
+                 max_tokens=max_tokens,
+                 stream=True,
+                 temperature=temperature,
+                 top_p=top_p,
+             ):
+                 if hasattr(message.choices[0].delta, 'content'):
+                     token = message.choices[0].delta.content
+                     if token is not None:
+                         response += token
+                         yield response
+             break
+         except Exception as e:
+             retries += 1
+             logger.error(f"生成故事时发生错误 (尝试 {retries}/{MAX_RETRIES}): {str(e)}")
+             if retries < MAX_RETRIES:
+                 time.sleep(RETRY_DELAY)
+             else:
+                 yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"

+ def create_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# 互动式故事生成器")  # "Interactive Story Generator"
+         gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")
+
+         style_select = gr.Dropdown(
+             choices=STORY_STYLES,
+             value="奇幻",
+             label="选择故事风格"
+         )
+         scene_input = gr.Textbox(
+             lines=3,
+             placeholder="请输入一个场景或角色描述...",
+             label="场景描述"
+         )
+
+         temperature = gr.Slider(
+             minimum=0.1,
+             maximum=2.0,
+             value=0.7,
+             step=0.1,
+             label="创意度(Temperature)"
+         )
+         max_tokens = gr.Slider(
+             minimum=64,
+             maximum=1024,
+             value=512,
+             step=64,
+             label="最大生成长度"
+         )
+         top_p = gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
+             label="采样范围(Top-p)"
+         )
+
+         chatbot = gr.Chatbot(
+             label="故事对话",
+             type="messages"
+         )
+         status_msg = gr.Markdown("")
+
+         submit_btn = gr.Button("生成故事")
+         clear_btn = gr.Button("清除对话")
+
+         def user_input(user_message, history):
+             if history is None:
+                 history = []
+             history.append({"role": "user", "content": user_message})
+             return "", history
+
+         def bot_response(history, style, temperature, max_tokens, top_p):
+             try:
+                 current_message = {"role": "assistant", "content": ""}
+                 history.append(current_message)
+
+                 # history[-2] is the latest user message; history[:-2] is the earlier conversation
+                 for text in generate_story(
+                     history[-2]["content"],
+                     style,
+                     history[:-2],
+                     temperature,
+                     max_tokens,
+                     top_p
+                 ):
+                     current_message["content"] = text
+                     yield history
+             except Exception as e:
+                 logger.error(f"处理响应时发生错误: {str(e)}")
+                 current_message["content"] = "抱歉,生成故事时遇到了问题。请稍后重试。"
+                 yield history
+
+         scene_input.submit(
+             user_input,
+             [scene_input, chatbot],
+             [scene_input, chatbot]
+         ).then(
+             bot_response,
+             [chatbot, style_select, temperature, max_tokens, top_p],
+             chatbot
+         )
+
+         submit_btn.click(
+             user_input,
+             [scene_input, chatbot],
+             [scene_input, chatbot]
+         ).then(
+             bot_response,
+             [chatbot, style_select, temperature, max_tokens, top_p],
+             chatbot
+         )
+
+         def clear_chat():
+             return [], ""
+
+         clear_btn.click(
+             clear_chat,
+             None,
+             [chatbot, status_msg],
+         )
+
+     return demo

  if __name__ == "__main__":
+     demo = create_demo()
+     demo.queue().launch(
+         server_port=7861,
+         share=False
+     )
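
For anyone reviewing the change locally, the short sketch below exercises the new streaming path without launching the Gradio UI. It is not part of the commit: it assumes app.py sits in the working directory (saved next to, say, a hypothetical smoke_test.py), that gradio, huggingface_hub and python-dotenv are installed (the new `from dotenv import load_dotenv` import suggests python-dotenv presumably needs to be listed in the Space's requirements.txt), that HF_TOKEN is available in the environment or a local .env file, and that the HuggingFaceH4/zephyr-7b-beta Inference API endpoint is reachable. The seed scene is a hypothetical example.

# smoke_test.py - a sketch for driving generate_story outside Gradio (not part of app.py)
from app import generate_story

def main() -> None:
    scene = "一个灯塔守护者在退潮后的沙滩上捡到一只漂流瓶"  # hypothetical seed scene
    previous = ""
    # generate_story yields the accumulated story text after every streamed token,
    # so printing only the new suffix gives a simple console "stream".
    for partial in generate_story(scene, style="奇幻", max_tokens=256):
        print(partial[len(previous):], end="", flush=True)
        previous = partial
    print()

if __name__ == "__main__":
    main()

Because each value yielded by generate_story is the full text so far, a consumer only has to remember how much it has already printed; the same property is what lets bot_response overwrite the last assistant message in place and re-yield the whole history list for the type="messages" Chatbot to re-render.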