# Scraped Hugging Face file-view header (not part of app.py):
# ChanMeng666's picture · "Update app.py" · commit 88f8e60 (verified) · raw · history · blame · 5.83 kB
import gradio as gr
from huggingface_hub import InferenceClient
import time
from typing import Optional, Generator
import logging
import os
from dotenv import load_dotenv
# Load environment variables (create_client reads HF_TOKEN from the environment)
load_dotenv()
# Configure logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# System prompt sent as the first message of every story-generation request.
STORY_SYSTEM_PROMPT = (
    "你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。\n"
    "请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"
)

# Story genres offered in the style dropdown.
STORY_STYLES = ["奇幻", "科幻", "悬疑", "冒险", "爱情", "恐怖"]

# Retry policy for model calls: number of attempts, and seconds to wait
# between consecutive attempts.
MAX_RETRIES, RETRY_DELAY = 3, 2
def create_client() -> InferenceClient:
    """Build an inference client for the HuggingFaceH4/zephyr-7b-beta model.

    The access token is read from the ``HF_TOKEN`` environment variable.

    Returns:
        An ``InferenceClient`` configured with the model id and the token.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    token_value = os.getenv('HF_TOKEN')
    if token_value:
        return InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token_value)
    raise ValueError("HF_TOKEN 环境变量未设置")
def generate_story(
    scene: str,
    style: str,
    history: Optional[list[dict]] = None,
    temperature: float = 0.7,
    max_tokens: int = 512,
    top_p: float = 0.95,
) -> Generator[str, None, None]:
    """Stream a story continuation for ``scene`` written in ``style``.

    Args:
        scene: The user's scene or character description for this turn.
        style: Story genre inserted into the instruction sent to the model.
        history: Earlier conversation turns as ``{"role", "content"}`` dicts,
            oldest first. ``None`` is treated as an empty history. The list
            is not modified.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold forwarded to the model.

    Yields:
        The accumulated story text after each received token. If every
        attempt fails, a single user-facing error message is yielded
        instead; exceptions are logged rather than propagated.
    """
    style_prompt = f"请以{style}风格续写以下故事:"
    # Keep the conversation chronological: system prompt, earlier turns, and
    # only then the new request, so the model answers the latest user message.
    messages = [
        {"role": "system", "content": STORY_SYSTEM_PROMPT},
        *(history or []),
        {"role": "user", "content": f"{style_prompt}\n{scene}"},
    ]

    for attempt in range(1, MAX_RETRIES + 1):
        # Start every attempt from scratch so a stream that failed midway
        # does not leave duplicated partial text in the retried output.
        response = ""
        try:
            client = create_client()
            for chunk in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                # Skip stream chunks that carry no choices instead of
                # failing with IndexError.
                if not chunk.choices:
                    continue
                token = getattr(chunk.choices[0].delta, 'content', None)
                if token is not None:
                    response += token
                    yield response
            return
        except Exception as e:
            logger.error(f"生成故事时发生错误 (尝试 {attempt}/{MAX_RETRIES}): {str(e)}")
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
            else:
                yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
def create_demo():
    """Build the Gradio Blocks UI for the interactive story generator.

    The layout holds a style dropdown, a scene text box, three sampling
    sliders, a messages-format chatbot, a status area, and submit/clear
    buttons. Submitting (Enter in the text box or the submit button) first
    records the user turn, then streams the generated story into the chatbot.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 互动式故事生成器")
        gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")

        style_select = gr.Dropdown(
            choices=STORY_STYLES,
            value="奇幻",
            label="选择故事风格"
        )
        scene_input = gr.Textbox(
            lines=3,
            placeholder="请输入一个场景或角色描述...",
            label="场景描述"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="创意度(Temperature)"
        )
        max_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="最大生成长度"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="采样范围(Top-p)"
        )
        chatbot = gr.Chatbot(
            label="故事对话",
            type="messages"
        )
        status_msg = gr.Markdown("")
        submit_btn = gr.Button("生成故事")
        clear_btn = gr.Button("清除对话")

        def user_input(user_message, history):
            """Record the submitted text as a user turn and clear the text box.

            Blank submissions leave the history untouched so that no empty
            prompt is sent to the model.
            """
            if history is None:
                history = []
            if not (user_message or "").strip():
                return "", history
            history.append({"role": "user", "content": user_message})
            return "", history

        def bot_response(history, style, temperature, max_tokens, top_p):
            """Stream the assistant's reply to the latest user turn into the history."""
            history = history or []
            # Nothing to answer when the previous step added no user turn
            # (for example after a blank submission).
            if not history or history[-1].get("role") != "user":
                yield history
                return

            # The same dict object is updated in place while tokens stream in.
            current_message = {"role": "assistant", "content": ""}
            history.append(current_message)
            try:
                for text in generate_story(
                    history[-2]["content"],  # the user turn being answered
                    style,
                    history[:-2],  # everything before that user turn
                    temperature,
                    max_tokens,
                    top_p
                ):
                    current_message["content"] = text
                    yield history
            except Exception as e:
                logger.error(f"处理响应时发生错误: {str(e)}")
                current_message["content"] = "抱歉,生成故事时遇到了问题。请稍后重试。"
                yield history

        # Pressing Enter in the text box and clicking the button run the same
        # two-step chain: store the user turn, then stream the reply.
        for trigger in (scene_input.submit, submit_btn.click):
            trigger(
                user_input,
                [scene_input, chatbot],
                [scene_input, chatbot]
            ).then(
                bot_response,
                [chatbot, style_select, temperature, max_tokens, top_p],
                chatbot
            )

        def clear_chat():
            """Reset the chatbot history and the status message."""
            return [], ""

        clear_btn.click(
            clear_chat,
            None,
            [chatbot, status_msg],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, and serve
    # on 0.0.0.0:7860 without creating a public share link.
    demo = create_demo()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.queue().launch(**launch_options)