# app.py — last commit 88aafac "Update app.py" by PSNbst (verified), 6.83 kB
import gradio as gr
import random
import glob
import os
from openai import OpenAI
from dotenv import load_dotenv
# Load environment variables via python-dotenv (e.g. from a local .env file) so that
# os.getenv("OPENAI_API_KEY") can find a key when none is typed into the UI.
load_dotenv()
# ========== Default options and data ==========
# Built-in pools that generate_prompt samples one random entry from on every
# generation (these tags have no user-upload counterpart).
EXPRESSIONS = [
    "smiling",
    "determined",
    "surprised",
    "serene",
]
ITEMS = [
    "magic wand",
    "sword",
    "flower",
    "book of spells",
    "ancient scroll",
    "music instrument",
]
OTHER_DETAILS = [
    "sparkles",
    "magical aura",
    "lens flare",
    "fireworks in the background",
]
SCENES = [
    "sunset beach",
    "rainy city street at night",
    "fantasy forest with glowing mushrooms",
    "futuristic skyline at dawn",
]
CAMERA_ANGLES = [
    "low-angle shot",
    "close-up shot",
    "bird's-eye view",
    "wide-angle shot",
]
QUALITY_PROMPTS = [
    "8k",
    "ultra-realistic",
    "high detail",
    "cinematic lighting",
    "award-winning",
    "masterpiece",
]
# ========== Utility functions ==========
def load_candidates_from_files(files):
    """
    Load candidate entries from one or more text files.

    Each file is read as UTF-8; every non-blank line (surrounding whitespace
    stripped) becomes one candidate, in file order.

    :param files: None/empty, a single path, or an iterable of uploads. Each
        upload may be a path string, an ``os.PathLike`` (e.g. ``pathlib.Path``),
        or a file wrapper object whose ``.name`` attribute is the on-disk path
        (some Gradio versions return tempfile wrappers instead of strings).
        ``None`` entries and objects without a usable path are skipped.
    :return: list of candidate strings (possibly empty)
    """
    if not files:
        return []
    # A lone path would otherwise be iterated character by character.
    if isinstance(files, (str, os.PathLike)):
        files = [files]

    all_lines = []
    for file in files:
        if isinstance(file, (str, os.PathLike)):
            path = file
        else:
            path = getattr(file, "name", None)
            if not isinstance(path, str):
                continue
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    all_lines.append(stripped)
    return all_lines
def get_random_item(candidates):
    """
    Pick one entry from ``candidates`` uniformly at random.

    Returns an empty string when ``candidates`` is None or empty, so a missing
    pool simply contributes no value.
    """
    if not candidates:
        return ""
    return random.choice(candidates)
def load_dtr_from_directory(directory=".", pattern="*DTR*"):
    """
    Load candidate lines from every file in ``directory`` whose name matches ``pattern``.

    Matching files are processed in sorted path order, so results are
    reproducible (``glob`` itself returns paths in arbitrary order).
    Directories that happen to match are skipped. A file that cannot be read
    or decoded as UTF-8 is reported and skipped without stopping the others.

    :param directory: directory to search (matched literally, not as a glob);
        defaults to the current directory
    :param pattern: glob pattern for file names; defaults to names containing "DTR"
    :return: list of stripped, non-blank lines from all readable matching files
    """
    dtr_candidates = []
    try:
        # Escape the directory so "[", "]", "*" or "?" in its name are literal;
        # only ``pattern`` is meant to carry wildcards.
        paths = sorted(glob.glob(os.path.join(glob.escape(directory), pattern)))
    except (TypeError, ValueError) as e:
        # e.g. a non-path ``directory``; keep the old "report and return empty" behavior.
        print(f"Error loading DTR files: {e}")
        return dtr_candidates

    for path in paths:
        if not os.path.isfile(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                # Collect per file first so a mid-file decode error adds nothing partial.
                lines = [line.strip() for line in f if line.strip()]
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error loading DTR file {path}: {e}")
            continue
        dtr_candidates.extend(lines)
    return dtr_candidates
def generate_natural_language_description(tags, api_key=None, model="gpt-4o"):
    """
    Ask an OpenAI chat model to turn prompt tags into a natural-language scene description.

    :param tags: mapping of tag name -> value; entries with falsy values are omitted
    :param api_key: OpenAI API key; None, empty or whitespace-only values fall
        back to the ``OPENAI_API_KEY`` environment variable
    :param model: chat model name passed to the API (default "gpt-4o")
    :return: the generated description, or an error string. API failures are
        reported through the return value instead of being raised.
    """
    # A Gradio Textbox yields "" or stray whitespace rather than None, so normalize it.
    api_key = (api_key or "").strip() or os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: No API Key provided and none found in environment variables."

    tag_descriptions = "\n".join(f"{key}: {value}" for key, value in tags.items() if value)
    try:
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a creative assistant that generates vivid and imaginative scene descriptions for painting prompts. "
                        "Focus on the details provided and incorporate them into a cohesive narrative. "
                        "Use at least three sentences."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
                },
            ],
            model=model,
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Top-level UI boundary: surface the failure as text in the output box.
        return f"GPT generation failed. Error: {e}"

    # The SDK types message content as optional (e.g. refusals carry no text).
    if content is None:
        return "GPT generation failed. Error: the model returned no text content."
    return content.strip()
def generate_prompt(action_file, style_file, artist_files, character_files, dtr_directory, api_key, selected_categories):
    """
    Generate a random prompt: a comma-separated tag string plus a GPT description.

    :param action_file: optional single file of action candidates (one per line)
    :param style_file: optional single file of style candidates
    :param artist_files: optional list of files with artist candidates
    :param character_files: optional list of files with character candidates
    :param dtr_directory: directory to scan for ``*DTR*`` files, or a bool.
        The UI wires its "Enable DTR Directory" checkbox here, so ``True``
        means "scan the default directory" and ``False``/``None`` disables DTR.
    :param api_key: optional OpenAI API key forwarded to the description generator
    :param selected_categories: optional list of character-category tags; when
        empty, one of "1girl"/"1boy" is picked at random
    :return: tuple ``(final_tags, description, combined_output)``
    """
    actions = load_candidates_from_files([action_file]) if action_file else []
    styles = load_candidates_from_files([style_file]) if style_file else []
    artists = load_candidates_from_files(artist_files) if artist_files else []
    characters = load_candidates_from_files(character_files) if character_files else []

    # The checkbox delivers a bool. Passing True straight through is not a
    # valid path, so DTR loading always came back empty; map it to the default.
    if dtr_directory is True:
        dtr_candidates = load_dtr_from_directory()
    elif dtr_directory:
        dtr_candidates = load_dtr_from_directory(dtr_directory)
    else:
        dtr_candidates = []

    number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])
    # Insertion order here is the order tags appear in the final prompt.
    tags = {
        "number_of_characters": number_of_characters,
        "character_name": get_random_item(characters),
        "artist_prompt": get_random_item(artists),
        "style": get_random_item(styles),
        "scene": get_random_item(SCENES),
        "camera_angle": get_random_item(CAMERA_ANGLES),
        "action": get_random_item(actions),
        "expression": get_random_item(EXPRESSIONS),
        "items": get_random_item(ITEMS),
        "other_details": get_random_item(OTHER_DETAILS),
        "quality_prompts": get_random_item(QUALITY_PROMPTS),
        "dtr": get_random_item(dtr_candidates),
    }
    description = generate_natural_language_description(tags, api_key)

    # Drop empty values, then de-duplicate while keeping first-seen order.
    unique_tags = list(dict.fromkeys(value for value in tags.values() if value))
    final_tags = ", ".join(unique_tags)
    combined_output = f"{final_tags}\n\n{description}"
    return final_tags, description, combined_output
# ========== Gradio UI ==========
def gradio_interface():
    """
    Build and return the Gradio Blocks UI (the caller launches it).

    Components render in creation order, so the statement order below is the
    page layout.
    """
    category_choices = ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"]

    with gr.Blocks() as demo:
        gr.Markdown("## Random Prompt Generator with User-Provided GPT API Key")
        key_box = gr.Textbox(
            label="Enter your OpenAI API Key (Optional)",
            placeholder="sk-...",
            type="password",
        )

        # Single-file uploads.
        with gr.Row():
            action_upload = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
            style_upload = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])

        # Multi-file uploads.
        with gr.Row():
            artist_uploads = gr.Files(label="Upload Artist Files (Multiple Allowed)", file_types=[".txt"])
            character_uploads = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])

        use_dtr = gr.Checkbox(label="Enable DTR Directory")
        categories = gr.CheckboxGroup(
            category_choices,
            label="Choose Character Categories (Optional)",
        )

        with gr.Row():
            tags_box = gr.Textbox(label="Generated Tags")
            description_box = gr.Textbox(label="Generated Description")
            combined_box = gr.Textbox(label="Combined Output: Tags + Description")

        run_button = gr.Button("Generate Prompt")
        # Order must match generate_prompt's positional parameters / return tuple.
        prompt_inputs = [action_upload, style_upload, artist_uploads, character_uploads, use_dtr, key_box, categories]
        prompt_outputs = [tags_box, description_box, combined_box]
        run_button.click(generate_prompt, inputs=prompt_inputs, outputs=prompt_outputs)
    return demo
# Launch the Gradio app when this file is run as a script.
if __name__ == "__main__":
    app = gradio_interface()
    # NOTE(review): share=True requests a public share link. Because
    # generate_natural_language_description falls back to the server's
    # OPENAI_API_KEY, anyone with the link may spend the host's key — confirm intended.
    app.launch(share=True)