|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
import time |
|
from typing import Optional, Generator |
|
import logging |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a local .env file (e.g. HF_TOKEN) into the process
# environment before anything reads them.
load_dotenv()


# Module-level logger; INFO level so the retry/error messages emitted
# during generation are visible on the console.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


# System prompt (Chinese): "You are a professional story generator.
# Generate engaging plots from the user's scene/character description;
# keep the story coherent and creative — each reply continues the plot."
STORY_SYSTEM_PROMPT = """你是一个专业的故事生成器。你需要根据用户提供的场景或角色描述,生成引人入胜的故事情节。
请确保故事具有连贯性和创意性。每次回应都应该是故事情节的自然延续。"""

# Story styles offered in the UI dropdown.
STORY_STYLES = [
    "奇幻",  # fantasy
    "科幻",  # sci-fi
    "悬疑",  # mystery
    "冒险",  # adventure
    "爱情",  # romance
    "恐怖",  # horror
]

# Generation retry policy: attempts per request, and seconds to wait
# between attempts.
MAX_RETRIES = 3

RETRY_DELAY = 2
|
|
|
def create_client() -> InferenceClient: |
|
hf_token = os.getenv('HF_TOKEN') |
|
if not hf_token: |
|
raise ValueError("HF_TOKEN 环境变量未设置") |
|
return InferenceClient( |
|
"HuggingFaceH4/zephyr-7b-beta", |
|
token=hf_token |
|
) |
|
|
|
def generate_story(
    scene: str,
    style: str,
    history: Optional[list[dict]] = None,
    temperature: float = 0.7,
    max_tokens: int = 512,
    top_p: float = 0.95,
) -> Generator[str, None, None]:
    """Stream a story continuation for *scene* in the given *style*.

    Yields the accumulated story text after each streamed token so the
    caller can render incremental updates.  Retries up to ``MAX_RETRIES``
    times on any error, sleeping ``RETRY_DELAY`` seconds between
    attempts; after the final failure it yields a Chinese error message
    instead of raising.

    Args:
        scene: Scene/character description — the current user message.
        style: Story style label (one of ``STORY_STYLES``), interpolated
            into the user prompt.
        history: Prior chat messages as ``{"role", "content"}`` dicts in
            chronological order.  ``None`` means no history.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling cutoff forwarded to the model.
    """
    if history is None:
        history = []

    style_prompt = f"请以{style}风格续写以下故事:"

    # Build the messages in chronological order: system prompt, then the
    # prior conversation, then the new user message.  (The original code
    # appended the history *after* the current message, presenting the
    # turns to the model out of order.)
    messages = [{"role": "system", "content": STORY_SYSTEM_PROMPT}]
    messages.extend(history)
    messages.append({"role": "user", "content": f"{style_prompt}\n{scene}"})

    for attempt in range(1, MAX_RETRIES + 1):
        # Reset the accumulator on every attempt so a retry after a
        # mid-stream failure does not duplicate partial output.
        response = ""
        try:
            client = create_client()
            for chunk in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                # Streamed deltas may omit content (e.g. role-only or
                # final chunks); skip those.
                token = getattr(chunk.choices[0].delta, 'content', None)
                if token is not None:
                    response += token
                    yield response
            return
        except Exception as e:
            logger.error(f"生成故事时发生错误 (尝试 {attempt}/{MAX_RETRIES}): {str(e)}")
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
            else:
                # Out of retries: surface a user-facing message rather
                # than raising into the UI layer.
                yield f"抱歉,生成故事时遇到了问题:{str(e)}\n请稍后重试。"
|
|
|
def create_demo():
    """Build the Gradio Blocks UI for the interactive story generator.

    Component creation order determines the on-screen layout, and the
    event wiring below references the components defined here, so the
    structure is kept linear.  Returns the (unlaunched) Blocks app.
    """
    with gr.Blocks() as demo:
        # Title and usage hint (Chinese UI text).
        gr.Markdown("# 互动式故事生成器")
        gr.Markdown("请输入一个场景或角色描述,AI将为您生成一个有趣的故事。您可以继续输入来推进故事情节的发展。")

        # Story style selector; choices come from STORY_STYLES.
        style_select = gr.Dropdown(
            choices=STORY_STYLES,
            value="奇幻",  # default style: fantasy
            label="选择故事风格"
        )
        # Free-text scene / character description box.
        scene_input = gr.Textbox(
            lines=3,
            placeholder="请输入一个场景或角色描述...",
            label="场景描述"
        )

        # Generation hyper-parameters, forwarded to generate_story().
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="创意度(Temperature)"
        )
        max_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="最大生成长度"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="采样范围(Top-p)"
        )

        # Conversation display using the "messages" format
        # ({"role", "content"} dicts), matching what the handlers build.
        chatbot = gr.Chatbot(
            label="故事对话",
            type="messages"
        )
        # Status line; only ever written (cleared) by clear_chat below.
        status_msg = gr.Markdown("")

        submit_btn = gr.Button("生成故事")
        clear_btn = gr.Button("清除对话")

        def user_input(user_message, history):
            # Append the user's message to the chat history and clear
            # the input textbox (first return value).
            if history is None:
                history = []
            history.append({"role": "user", "content": user_message})
            return "", history

        def bot_response(history, style, temperature, max_tokens, top_p):
            # Stream the assistant's reply into a placeholder message,
            # yielding the whole history after each update so the chat
            # view refreshes incrementally.
            try:
                current_message = {"role": "assistant", "content": ""}
                history.append(current_message)

                # history[-2] is the user message just added by
                # user_input; history[:-2] is everything before it.
                # NOTE(review): this presumes user_input always ran
                # immediately before — confirm the .then() chaining
                # guarantees that ordering.
                for text in generate_story(
                    history[-2]["content"],
                    style,
                    history[:-2],
                    temperature,
                    max_tokens,
                    top_p
                ):
                    current_message["content"] = text
                    yield history
            except Exception as e:
                logger.error(f"处理响应时发生错误: {str(e)}")
                current_message["content"] = f"抱歉,生成故事时遇到了问题。请稍后重试。"
                yield history

        # Pressing Enter in the textbox: record the user turn, then
        # stream the bot's reply into the chat view.
        scene_input.submit(
            user_input,
            [scene_input, chatbot],
            [scene_input, chatbot]
        ).then(
            bot_response,
            [chatbot, style_select, temperature, max_tokens, top_p],
            chatbot
        )

        # The submit button triggers the same two-step pipeline.
        submit_btn.click(
            user_input,
            [scene_input, chatbot],
            [scene_input, chatbot]
        ).then(
            bot_response,
            [chatbot, style_select, temperature, max_tokens, top_p],
            chatbot
        )

        def clear_chat():
            # Reset the chat history and blank the status line.
            return [], ""

        clear_btn.click(
            clear_chat,
            None,
            [chatbot, status_msg],
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it with request queuing enabled, which is
    # required for the streaming (generator) handlers to update live.
    create_demo().queue().launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        share=False,            # no public Gradio share link
    )
|
|