# app.py -- copied from the Hugging Face file viewer
# (commit 00e31d3 "Update app.py" by PSNbst, verified; 8.48 kB).
import gradio as gr
import random
import glob
import os
import requests
from openai import OpenAI
from dotenv import load_dotenv
# Load environment variables through python-dotenv's load_dotenv() so that
# settings such as OPENAI_API_KEY (read later with os.getenv) are available.
load_dotenv()
# ========== Default options and data ==========
# Built-in tag pools; the prompt generator samples a few entries from each.
EXPRESSIONS = [
    "smiling",
    "determined",
    "surprised",
    "serene",
    "laughing",
    "angry",
    "pensive",
    "confident",
]
ITEMS = [
    "magic wand",
    "sword",
    "flower",
    "book of spells",
    "ancient scroll",
    "music instrument",
    "shield",
    "dagger",
    "staff",
    "potion",
]
OTHER_DETAILS = [
    "sparkles",
    "magical aura",
    "lens flare",
    "fireworks in the background",
    "smoke effects",
    "light trails",
    "falling leaves",
    "glowing embers",
]
SCENES = [
    "sunset beach",
    "rainy city street at night",
    "fantasy forest with glowing mushrooms",
    "futuristic skyline at dawn",
    "abandoned castle",
    "snowy mountain peak",
    "desert ruins",
    "underwater city",
]
CAMERA_ANGLES = [
    "low-angle shot",
    "close-up shot",
    "bird's-eye view",
    "wide-angle shot",
    "over-the-shoulder shot",
    "extreme close-up",
    "panoramic view",
    "dynamic tracking shot",
]
QUALITY_PROMPTS = [
    "cinematic lighting",
    "award-winning",
    "masterpiece",
    "vivid colors",
    "high dynamic range",
    "immersive",
    "studio quality",
    "fine art",
    "dreamlike",
    "8K",
    "HD",
    "high quality",
    "best quality",
]

# Location of the DTR text files in the Hugging Face dataset repository.
# Note that the file-name part contains `*` glob wildcards.
DTR_DATASET_PATTERN = (
    "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
)
# ========== Utility functions ==========
def load_candidates_from_files(files):
    """
    Load candidate lines from one or more UTF-8 text files.

    Args:
        files: A single path, or an iterable of entries. Each entry may be a
            path string, an ``os.PathLike``, or an object that exposes its
            path through a ``name`` attribute. ``None`` and other
            unrecognised entries are skipped.
            NOTE(review): some Gradio versions are believed to pass uploads
            as temp-file objects with ``.name`` rather than plain strings --
            confirm against the installed version.

    Returns:
        A list with the stripped, non-empty lines of every file, in order.
        Empty when ``files`` is falsy.

    Raises:
        OSError: If a resolved path cannot be opened.
    """
    if not files:
        return []

    # A bare path would otherwise be iterated character by character.
    if isinstance(files, (str, os.PathLike)) or hasattr(files, "name"):
        files = [files]

    all_lines = []
    for entry in files:
        path = entry
        if not isinstance(path, (str, os.PathLike)):
            # Upload wrappers carry the on-disk location in `.name`.
            path = getattr(entry, "name", None)
        if not isinstance(path, (str, os.PathLike)):
            continue
        with open(path, "r", encoding="utf-8") as handle:
            stripped = (line.strip() for line in handle)
            all_lines.extend(text for text in stripped if text)
    return all_lines
def get_random_items(candidates, num_items=1):
    """
    Pick up to ``num_items`` distinct entries from ``candidates`` at random.

    Args:
        candidates: Sequence to sample from; may be empty or ``None``.
        num_items: Requested number of entries. It is converted with
            ``int()`` and clamped to ``0 .. len(candidates)``, so a
            whole-number float or a negative value no longer makes
            ``random.sample`` raise.

    Returns:
        A new list of sampled entries (without replacement); empty when
        there are no candidates or the clamped count is zero.
    """
    if not candidates:
        return []
    count = max(0, min(int(num_items), len(candidates)))
    return random.sample(candidates, count)
def load_dtr_from_huggingface():
    """
    Download the DTR candidate lines described by ``DTR_DATASET_PATTERN``.

    The pattern's file name contains glob wildcards. An HTTP GET sends them
    literally, so the previous single request could not match any file. When
    wildcards are present, the repository's file list is fetched first, the
    names are matched with ``fnmatch``, and every matching file is
    downloaded. Without wildcards the URL is fetched directly, as before.

    NOTE(review): the listing request uses
    ``/api/datasets/{repo_id}/revision/{revision}`` and reads
    ``siblings[*].rfilename`` from the JSON reply -- confirm against the
    current Hugging Face Hub API documentation.

    Returns:
        A list of stripped, non-empty lines from all matching files, or an
        empty list if anything fails (the error is printed; this loader is
        best-effort).
    """
    import fnmatch
    from urllib.parse import quote

    request_timeout = 15  # seconds; without it a stalled connection blocks forever

    try:
        repo_url, _, remainder = DTR_DATASET_PATTERN.partition("/resolve/")
        host, _, repo_id = repo_url.partition("/datasets/")
        revision, _, name_pattern = remainder.partition("/")
        has_wildcard = any(symbol in name_pattern for symbol in "*?[")

        if has_wildcard and repo_id and revision:
            listing = requests.get(
                f"{host}/api/datasets/{repo_id}/revision/{revision}",
                timeout=request_timeout,
            )
            listing.raise_for_status()
            repo_files = [
                entry.get("rfilename", "")
                for entry in listing.json().get("siblings", [])
            ]
            urls = [
                f"{repo_url}/resolve/{revision}/{quote(name)}"
                for name in repo_files
                if fnmatch.fnmatchcase(name, name_pattern)
            ]
        else:
            urls = [DTR_DATASET_PATTERN]

        lines = []
        for url in urls:
            response = requests.get(url, timeout=request_timeout)
            response.raise_for_status()
            stripped = (raw.strip() for raw in response.text.splitlines())
            # Blank lines would otherwise be selectable as empty tags.
            lines.extend(text for text in stripped if text)
        return lines
    except Exception as e:
        print(f"Error loading DTR dataset: {e}")
        return []
def generate_natural_language_description(tags, api_key=None):
    """
    Ask OpenAI's chat completion API (model ``gpt-4o``) for a scene description.

    Args:
        tags: Mapping of tag category to value. Values may be strings or
            lists/tuples of strings; empty values are left out of the prompt.
        api_key: OpenAI API key. When missing or blank, the
            ``OPENAI_API_KEY`` environment variable is used instead.

    Returns:
        The generated description, or a human-readable error string when no
        API key is available or the request fails. This function does not
        raise for API errors.
    """
    # A whitespace-only key from the UI counts as "not provided".
    api_key = (api_key or "").strip() or os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: No API Key provided and none found in environment variables."

    detail_lines = []
    for key, value in tags.items():
        if not value:
            continue
        if isinstance(value, (list, tuple)):
            # Join sequences so the model sees "scene: a, b" instead of a
            # Python repr such as "scene: ['a', 'b']".
            value = ", ".join(str(item) for item in value)
        detail_lines.append(f"{key}: {value}")
    tag_descriptions = "\n".join(detail_lines)

    try:
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a creative assistant that generates vivid and imaginative scene descriptions for painting prompts. "
                        "Focus on the details provided and incorporate them into a cohesive narrative. "
                        "Use at least three sentences."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
                },
            ],
            model="gpt-4o",
        )
        # Guard against a missing message body before stripping.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"GPT generation failed. Error: {e}"
def generate_prompt(
    action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
    expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count
):
    """
    Build a random tag prompt plus a natural-language description.

    Args:
        action_file: Optional uploaded file with candidate actions.
        style_file: Optional uploaded file with candidate styles.
        artist_files: Optional uploaded files with candidate artists (one is picked).
        character_files: Optional uploaded files with candidate characters (one is picked).
        dtr_enabled: When truthy, one DTR line is fetched via
            ``load_dtr_from_huggingface`` and added.
        api_key: OpenAI API key forwarded to the description generator.
        selected_categories: Chosen character categories; when empty, one of
            ``"1girl"`` / ``"1boy"`` is picked at random.
        expression_count, item_count, detail_count, scene_count, angle_count,
        quality_count, action_count, style_count: Number of entries to sample
            from the corresponding pool.

    Returns:
        A ``(final_tags, description, combined_output)`` tuple of strings,
        where ``combined_output`` is the tags, a blank line, then the
        description.
    """
    actions = get_random_items(load_candidates_from_files([action_file]) if action_file else [], action_count)
    styles = get_random_items(load_candidates_from_files([style_file]) if style_file else [], style_count)
    artists = get_random_items(load_candidates_from_files(artist_files) if artist_files else [], 1)
    characters = get_random_items(load_candidates_from_files(character_files) if character_files else [], 1)
    dtr_candidates = get_random_items(load_dtr_from_huggingface() if dtr_enabled else [], 1)

    number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])

    # Bug fix: `artists` is a list, so f"(artist:{artists})" produced
    # "(artist:['name'])", or "(artist:[])" when no artist was supplied.
    # Format each picked name on its own and emit nothing when there is none.
    artist_prompt = [f"(artist:{name})" for name in artists]

    tags = {
        "number_of_characters": number_of_characters,
        "character_name": characters,
        "artist_prompt": artist_prompt,
        "style": styles,
        "scene": get_random_items(SCENES, scene_count),
        "camera_angle": get_random_items(CAMERA_ANGLES, angle_count),
        "action": actions,
        "expression": get_random_items(EXPRESSIONS, expression_count),
        "items": get_random_items(ITEMS, item_count),
        "other_details": get_random_items(OTHER_DETAILS, detail_count),
        "quality_prompts": get_random_items(QUALITY_PROMPTS, quality_count),
        "dtr": dtr_candidates,
    }

    description = generate_natural_language_description(tags, api_key)

    # Flatten list and scalar values into one sequence, skipping empty entries.
    tags_list = []
    for value in tags.values():
        entries = value if isinstance(value, list) else [value]
        tags_list.extend(entry for entry in entries if entry)

    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique_tags = list(dict.fromkeys(tags_list))
    final_tags = ", ".join(unique_tags)
    combined_output = f"{final_tags}\n\n{description}"
    return final_tags, description, combined_output
# ========== Gradio interface ==========
def gradio_interface():
    """
    Build the Gradio Blocks UI and wire the button to ``generate_prompt``.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    # NOTE(review): the nesting of components inside each gr.Row() was
    # reconstructed, because the source's indentation was not available --
    # confirm the intended layout.

    def count_slider(caption):
        # Every count control shares the same 1..5 integer range, default 1.
        return gr.Slider(label=caption, minimum=1, maximum=5, step=1, value=1)

    with gr.Blocks() as demo:
        gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts")

        key_box = gr.Textbox(
            label="Enter your OpenAI API Key (Optional)",
            placeholder="sk-...",
            type="password",
        )

        with gr.Row():
            action_upload = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
            style_upload = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])

        with gr.Row():
            artist_uploads = gr.Files(label="Upload Artist Files (Multiple Allowed)", file_types=[".txt"])
            character_uploads = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])

        dtr_toggle = gr.Checkbox(label="Enable DTR")
        category_choices = gr.CheckboxGroup(
            ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
            label="Choose Character Categories (Optional)",
        )

        with gr.Row():
            expressions_n = count_slider("Number of Expressions")
            items_n = count_slider("Number of Items")
            details_n = count_slider("Number of Other Details")
            scenes_n = count_slider("Number of Scenes")

        with gr.Row():
            angles_n = count_slider("Number of Camera Angles")
            quality_n = count_slider("Number of Quality Prompts")
            actions_n = count_slider("Number of Actions")
            styles_n = count_slider("Number of Styles")

        with gr.Row():
            tags_box = gr.Textbox(label="Generated Tags")
            description_box = gr.Textbox(label="Generated Description")
            combined_box = gr.Textbox(label="Combined Output: Tags + Description")

        run_button = gr.Button("Generate Prompt")

        # Order must match generate_prompt's positional parameters.
        handler_inputs = [
            action_upload, style_upload, artist_uploads, character_uploads,
            dtr_toggle, key_box, category_choices,
            expressions_n, items_n, details_n, scenes_n,
            angles_n, quality_n, actions_n, styles_n,
        ]
        handler_outputs = [tags_box, description_box, combined_box]
        run_button.click(generate_prompt, inputs=handler_inputs, outputs=handler_outputs)

    return demo
# Start the Gradio app when this file is executed as a script.
if __name__ == "__main__":
    app = gradio_interface()
    app.launch(share=True)