# NOTE(review): the following is a Hugging Face Spaces page header captured
# during extraction, kept verbatim as a comment — "Spaces: Runtime error".
# The Space was reporting a runtime-error status when this file was scraped.
import gradio as gr | |
from huggingface_hub import InferenceClient | |
import time | |
from typing import Optional, Generator | |
import logging | |
import os | |
from dotenv import load_dotenv | |
# Load environment variables (e.g. HF_TOKEN) from a local .env file.
load_dotenv()

# Module-level logger configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# System prompt (Chinese) for the story generator: instructs the model to
# produce coherent, creative continuations of the user's scene/characters.
STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。
请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"""

# Selectable story styles: fantasy, sci-fi, mystery, adventure, romance, horror.
STORY_STYLES = ["奇幻", "科幻", "悬疑", "冒险", "爱情", "恐怖"]

# Retry policy for transient inference-API failures.
MAX_RETRIES = 3
RETRY_DELAY = 2  # seconds to wait between retries
def create_client() -> InferenceClient:
    """Construct an InferenceClient for the zephyr-7b-beta chat model.

    Returns:
        An authenticated ``InferenceClient`` instance.

    Raises:
        ValueError: if the ``HF_TOKEN`` environment variable is unset/empty.
    """
    token = os.getenv('HF_TOKEN')
    if not token:
        raise ValueError("HF_TOKEN 环境变量未设置")
    return InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
def generate_story(
    scene: str,
    style: str,
    history: Optional[list[dict]] = None,
    temperature: float = 0.7,
    max_tokens: int = 512,
    top_p: float = 0.95,
) -> Generator[str, None, None]:
    """Stream a story continuation for *scene* in the requested *style*.

    Args:
        scene: The user's current scene/character description.
        style: A style label (see STORY_STYLES) interpolated into the prompt.
        history: Prior chat turns as ``{"role", "content"}`` dicts, oldest
            first. Defaults to no history.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling cutoff forwarded to the API.

    Yields:
        The cumulative response text after each streamed token. After
        MAX_RETRIES consecutive API failures, yields a single Chinese
        error message instead of raising.
    """
    if history is None:
        history = []
    style_prompt = f"请以{style}风格续写以下故事:"
    # BUG FIX: prior conversation turns must precede the current user
    # message. The original code appended `history` AFTER the new user
    # message, so the model saw the dialogue out of chronological order.
    messages = [{"role": "system", "content": STORY_SYSTEM_PROMPT}]
    messages.extend(history)
    messages.append({"role": "user", "content": f"{style_prompt}\n{scene}"})

    response = ""
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            # Recreate the client per attempt so a stale connection cannot
            # poison subsequent retries.
            client = create_client()
            for chunk in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                delta = chunk.choices[0].delta
                token = getattr(delta, 'content', None)
                if token is not None:
                    response += token
                    yield response
            return  # success: stop retrying
        except Exception as e:
            logger.error(f"生成故事时发生错误 (尝试 {attempt}/{MAX_RETRIES}): {str(e)}")
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
            else:
                yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
def create_demo():
    """Build the interactive story-generator Gradio UI.

    Returns:
        A ``gr.Blocks`` app with style/sampling controls, a messages-format
        chatbot, and submit/clear event wiring.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 互动式故事生成器")
        gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")
        style_select = gr.Dropdown(
            choices=STORY_STYLES,
            value="奇幻",
            label="选择故事风格"
        )
        scene_input = gr.Textbox(
            lines=3,
            placeholder="请输入一个场景或角色描述...",
            label="场景描述"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="创意度(Temperature)"
        )
        max_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="最大生成长度"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="采样范围(Top-p)"
        )
        chatbot = gr.Chatbot(
            label="故事对话",
            type="messages"
        )
        status_msg = gr.Markdown("")
        submit_btn = gr.Button("生成故事")
        clear_btn = gr.Button("清除对话")

        def user_input(user_message, history):
            """Append the user's message to the chat history; clear the textbox."""
            if history is None:
                history = []
            history.append({"role": "user", "content": user_message})
            return "", history

        def bot_response(history, style, temperature, max_tokens, top_p):
            """Stream the assistant's story continuation into the chat history."""
            # Placeholder assistant message, mutated in place as text streams.
            # BUG FIX: defined before the try block so the except handler can
            # never hit a NameError if an early statement raises.
            current_message = {"role": "assistant", "content": ""}
            try:
                history.append(current_message)
                # history[-2] is the user's latest message; history[:-2] is
                # the earlier conversation context passed through as-is.
                for text in generate_story(
                    history[-2]["content"],
                    style,
                    history[:-2],
                    temperature,
                    max_tokens,
                    top_p
                ):
                    current_message["content"] = text
                    yield history
            except Exception as e:
                logger.error(f"处理响应时发生错误: {str(e)}")
                # BUG FIX: was an f-string with no placeholders.
                current_message["content"] = "抱歉,生成故事时遇到了问题。请稍后重试。"
                yield history

        scene_input.submit(
            user_input,
            [scene_input, chatbot],
            [scene_input, chatbot]
        ).then(
            bot_response,
            [chatbot, style_select, temperature, max_tokens, top_p],
            chatbot
        )
        submit_btn.click(
            user_input,
            [scene_input, chatbot],
            [scene_input, chatbot]
        ).then(
            bot_response,
            [chatbot, style_select, temperature, max_tokens, top_p],
            chatbot
        )

        def clear_chat():
            """Reset the chatbot history and the status message."""
            return [], ""

        clear_btn.click(
            clear_chat,
            None,
            [chatbot, status_msg],
        )
    return demo
if __name__ == "__main__":
    # Launch the Gradio app with request queuing enabled, listening on all
    # interfaces at port 7860 (the standard HF Spaces port), no public share.
    app = create_demo()
    app.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)