# Hugging Face Space status: Running
import fnmatch
import glob
import os
import random
from urllib.parse import quote, urlsplit

import gradio as gr
import requests
from dotenv import load_dotenv
from openai import OpenAI
# Load environment variables from a .env file. override=False (the default)
# means variables already set in the process environment take precedence.
load_dotenv(override=False)
# ========== Default options and data ==========
# Built-in pools that generate_prompt samples tags from.
EXPRESSIONS = [
    "smiling", "determined", "surprised", "serene",
    "laughing", "angry", "pensive", "confident",
]
ITEMS = [
    "magic wand", "sword", "flower", "book of spells", "ancient scroll",
    "music instrument", "shield", "dagger", "staff", "potion",
]
OTHER_DETAILS = [
    "sparkles", "magical aura", "lens flare", "fireworks in the background",
    "smoke effects", "light trails", "falling leaves", "glowing embers",
]
SCENES = [
    "sunset beach",
    "rainy city street at night",
    "fantasy forest with glowing mushrooms",
    "futuristic skyline at dawn",
    "abandoned castle",
    "snowy mountain peak",
    "desert ruins",
    "underwater city",
]
CAMERA_ANGLES = [
    "low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot",
    "over-the-shoulder shot", "extreme close-up", "panoramic view", "dynamic tracking shot",
]
QUALITY_PROMPTS = [
    "cinematic lighting", "award-winning", "masterpiece", "vivid colors",
    "high dynamic range", "immersive", "studio quality", "fine art",
    "dreamlike", "8K", "HD", "high quality", "best quality",
]
# Hugging Face location of the DTR wildcard files; the last path segment is a
# filename glob rather than a literal file name.
DTR_DATASET_PATTERN = (
    "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
)
# ========== Utility functions ==========
def _resolve_upload_path(file):
    """Return a readable filesystem path for one uploaded-file entry.

    Accepts a plain path (``str`` / ``os.PathLike``) or an object that stores
    its path in a ``.name`` string attribute (as temp-file wrappers do).
    Returns ``None`` for anything else, including ``None`` itself.
    """
    if isinstance(file, (str, os.PathLike)):
        return file
    name = getattr(file, "name", None)
    return name if isinstance(name, str) else None


def load_candidates_from_files(files):
    """Load candidate entries (one per line) from several text files.

    Args:
        files: iterable of upload entries (paths, or objects carrying the path
            in ``.name``), or a falsy value meaning "no files". Entries with
            no usable path are skipped.

    Returns:
        list[str]: every non-blank line, stripped, in file order. A leading
        UTF-8 BOM is removed so it cannot leak into the first entry.
    """
    all_lines = []
    for file in files or ():
        path = _resolve_upload_path(file)
        if path is None:
            continue  # entry carries no path; nothing to read
        # utf-8-sig decodes plain UTF-8 identically but also strips a BOM.
        with open(path, "r", encoding="utf-8-sig") as f:
            all_lines.extend(stripped for stripped in (line.strip() for line in f) if stripped)
    return all_lines
def get_random_items(candidates, num_items=1):
    """Randomly pick up to ``num_items`` distinct entries from ``candidates``.

    ``num_items`` is coerced to ``int``, because slider widgets may deliver
    values such as ``2.0``, which ``random.sample`` rejects. It is then
    clamped to ``[0, len(candidates)]``, so oversized or negative counts never
    raise.

    Returns:
        list: the sampled entries, or ``[]`` when ``candidates`` is empty/None.
    """
    if not candidates:
        return []
    count = max(0, min(int(num_items), len(candidates)))
    return random.sample(candidates, count)
def _split_resolve_url(url):
    """Split ``<repo_url>/resolve/<revision>/<file_glob>`` into its parts.

    Returns:
        tuple[str, str, str]: (repo_url, revision, file_glob).

    Raises:
        ValueError: if ``url`` has no ``/resolve/<revision>/<file_glob>`` tail.
    """
    repo_url, marker, tail = url.partition("/resolve/")
    revision, slash, file_glob = tail.partition("/")
    if not (marker and slash and revision and file_glob):
        raise ValueError(f"Not a Hugging Face resolve URL: {url!r}")
    return repo_url, revision, file_glob


def load_dtr_from_huggingface(pattern=None, timeout=30):
    """Load all lines of the Hugging Face dataset files that match a glob.

    HTTP has no wildcard expansion, so requesting the glob URL itself
    (``.../resolve/main/*DTR*.txt``) never returns the files. Instead, the
    repo's file list is read from the Hub API and filtered with ``fnmatch``,
    and each matching file is downloaded.

    Args:
        pattern: resolve URL whose last segment is a filename glob; defaults
            to ``DTR_DATASET_PATTERN``.
        timeout: per-request timeout in seconds.

    Returns:
        list[str]: stripped, non-blank lines from every matching file (files
        visited in sorted name order). On any failure, the error is printed
        and ``[]`` is returned, because DTR is optional.
    """
    if pattern is None:
        pattern = DTR_DATASET_PATTERN
    try:
        repo_url, revision, file_glob = _split_resolve_url(pattern)
        quoted_rev = quote(revision, safe="")
        parts = urlsplit(repo_url)
        # The Hub API mirrors the repo path under /api, e.g.
        # https://huggingface.co/api/datasets/<owner>/<name>/revision/<rev>
        info_url = f"{parts.scheme}://{parts.netloc}/api{parts.path}/revision/{quoted_rev}"
        info = requests.get(info_url, timeout=timeout)
        info.raise_for_status()
        filenames = sorted(
            entry["rfilename"]
            for entry in info.json().get("siblings", [])
            if fnmatch.fnmatchcase(entry["rfilename"], file_glob)
        )
        lines = []
        for filename in filenames:
            response = requests.get(
                f"{repo_url}/resolve/{quoted_rev}/{quote(filename)}", timeout=timeout
            )
            response.raise_for_status()
            # Decode explicitly: requests may guess a non-UTF-8 charset for text/plain.
            text = response.content.decode("utf-8-sig", errors="replace")
            lines.extend(line.strip() for line in text.splitlines() if line.strip())
        return lines
    except Exception as e:  # best-effort boundary: never break prompt generation
        print(f"Error loading DTR dataset: {e}")
        return []
def _format_tag_lines(tags):
    """Render ``tags`` as ``key: value`` lines for the GPT prompt.

    List/tuple values are joined with ", " instead of appearing as a Python
    list repr. Falsy values (``""``, ``[]``, ``None``) and falsy list members
    are omitted.
    """
    lines = []
    for key, value in tags.items():
        if isinstance(value, (list, tuple)):
            value = ", ".join(str(item) for item in value if item)
        if value:
            lines.append(f"{key}: {value}")
    return "\n".join(lines)


def generate_natural_language_description(tags, api_key=None):
    """Ask OpenAI GPT for a vivid scene description built from ``tags``.

    Args:
        tags: mapping of tag category -> string or list of strings.
        api_key: OpenAI API key; when falsy, ``OPENAI_API_KEY`` from the
            environment is used.

    Returns:
        str: the generated description, or a human-readable error message
        (no key available, or the API call failed). This function never raises.
    """
    if not api_key:
        api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: No API Key provided and none found in environment variables."
    tag_descriptions = _format_tag_lines(tags)
    try:
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a creative assistant that generates vivid and imaginative scene descriptions for painting prompts. "
                        "Focus on the details provided and incorporate them into a cohesive narrative. "
                        "Use at least three sentences."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
                },
            ],
            model="gpt-4o",
        )
        # message.content can be None (e.g. a refusal); avoid AttributeError on .strip().
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"GPT generation failed. Error: {e}"
def generate_prompt(
    action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
    expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
):
    """Assemble a random tag prompt plus a GPT scene description.

    Args:
        action_file, style_file: a single uploaded text file (or None) with
            one candidate action / style per line.
        artist_files, character_files: uploaded text files (or None); one
            artist and one character are drawn from their combined lines.
        dtr_enabled: when truthy, one extra entry is drawn from the DTR
            dataset on Hugging Face.
        api_key: OpenAI key; falsy means the environment key is used.
        selected_categories: chosen character categories; when empty, a
            random "1girl"/"1boy" is used instead.
        expression_count ... style_count: how many entries to draw per category.

    Returns:
        tuple[str, str, str]: (comma-separated unique tags, description,
        tags and description separated by a blank line).
    """
    actions = get_random_items(load_candidates_from_files([action_file]) if action_file else [], action_count)
    styles = get_random_items(load_candidates_from_files([style_file]) if style_file else [], style_count)
    artists = get_random_items(load_candidates_from_files(artist_files) if artist_files else [], 1)
    characters = get_random_items(load_candidates_from_files(character_files) if character_files else [], 1)
    dtr_candidates = get_random_items(load_dtr_from_huggingface() if dtr_enabled else [], 1)

    number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])
    # Format each drawn artist as "(artist:NAME)". Interpolating the list itself
    # yields "(artist:['NAME'])", or "(artist:[])" when no artist file is given.
    artist_prompt = ", ".join(f"(artist:{artist})" for artist in artists)

    tags = {
        "number_of_characters": number_of_characters,
        "character_name": characters,
        "artist_prompt": artist_prompt,
        "style": styles,
        "scene": get_random_items(SCENES, scene_count),
        "camera_angle": get_random_items(CAMERA_ANGLES, angle_count),
        "action": actions,
        "expression": get_random_items(EXPRESSIONS, expression_count),
        "items": get_random_items(ITEMS, item_count),
        "other_details": get_random_items(OTHER_DETAILS, detail_count),
        "quality_prompts": get_random_items(QUALITY_PROMPTS, quality_count),
        "dtr": dtr_candidates,
    }
    description = generate_natural_language_description(tags, api_key)

    # Flatten list-valued tags, drop empty entries, and de-duplicate while
    # keeping first-seen order.
    flat_tags = []
    for value in tags.values():
        flat_tags.extend(value if isinstance(value, list) else [value])
    unique_tags = list(dict.fromkeys(tag for tag in flat_tags if tag))
    final_tags = ", ".join(unique_tags)
    combined_output = f"{final_tags}\n\n{description}"
    return final_tags, description, combined_output
# ========== Gradio UI ==========
def gradio_interface():
    """Build (but do not launch) the Gradio Blocks app.

    The "Generate Prompt" button passes every input widget to
    ``generate_prompt`` and writes its three return values to the tags,
    description and combined output boxes.
    """

    def count_slider(label):
        # All per-category count sliders share the same 1..5 integer range.
        return gr.Slider(label=label, minimum=1, maximum=5, step=1, value=1)

    with gr.Blocks() as demo:
        gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts")
        api_key_input = gr.Textbox(
            label="Enter your OpenAI API Key (Optional)",
            placeholder="sk-...",
            type="password",
        )

        with gr.Row():
            action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
            style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
        with gr.Row():
            artist_files = gr.Files(label="Upload Artist Files (Multiple Allowed)", file_types=[".txt"])
            character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])

        dtr_enabled = gr.Checkbox(label="Enable DTR")
        selected_categories = gr.CheckboxGroup(
            ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
            label="Choose Character Categories (Optional)",
        )

        with gr.Row():
            expression_count = count_slider("Number of Expressions")
            item_count = count_slider("Number of Items")
            detail_count = count_slider("Number of Other Details")
            scene_count = count_slider("Number of Scenes")
        with gr.Row():
            angle_count = count_slider("Number of Camera Angles")
            quality_count = count_slider("Number of Quality Prompts")
            action_count = count_slider("Number of Actions")
            style_count = count_slider("Number of Styles")

        with gr.Row():
            tags_output = gr.Textbox(label="Generated Tags")
            description_output = gr.Textbox(label="Generated Description")
            combined_output = gr.Textbox(label="Combined Output: Tags + Description")

        generate_button = gr.Button("Generate Prompt")
        # Input order must match generate_prompt's positional parameters.
        prompt_inputs = [
            action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
            expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
        ]
        generate_button.click(
            generate_prompt,
            inputs=prompt_inputs,
            outputs=[tags_output, description_output, combined_output],
        )
    return demo
# Launch the Gradio app when this file is run as a script.
if __name__ == "__main__":
    app = gradio_interface()
    app.launch(share=True)