# Hugging Face Spaces app — status: Running
import fnmatch
import glob
import os
import random
import urllib.parse

import gradio as gr
import requests
from dotenv import load_dotenv
from openai import OpenAI
# Load environment variables from a .env file (if one is found) into os.environ,
# so keys such as OPENAI_API_KEY can later be read with os.getenv.
load_dotenv()
# ========== Default option pools and data ==========
# generate_prompt drops excluded tags from each pool, then samples from what remains.

# Facial expressions / attitudes.
EXPRESSIONS = [
    "smiling",
    "determined",
    "surprised",
    "serene",
    "smug",
    "thinking",
    "looking back",
    "laughing",
    "angry",
    "pensive",
    "confident",
    "grinning",
    "thoughtful",
    "sad tears",
    "bewildered",
]

# Props and accessories a character may hold or wear.
ITEMS = [
    "magic wand",
    "sword",
    "flower",
    "book of spells",
    "earrings",
    "loincloth",
    "slippers",
    "ancient scroll",
    "music instrument",
    "shield",
    "dagger",
    "headband",
    "leg ties",
    "staff",
    "potion",
    "crystal ball",
    "anklet",
    "ribbon",
    "lantern",
    "amulet",
    "ring",
]

# Ambient visual effects.
OTHER_DETAILS = [
    "sparkles",
    "magical aura",
    "lens flare",
    "fireworks in the background",
    "smoke effects",
    "light trails",
    "falling leaves",
    "glowing embers",
    "floating particles",
    "rays of light",
    "shimmering mist",
    "ethereal glow",
]

# Backgrounds / locations.
SCENES = [
    "sunset beach",
    "rainy city street at night",
    "floating ash land",
    "particles magic world",
    "high blue sky",
    "top of the building",
    "fantasy forest with glowing mushrooms",
    "futuristic skyline at dawn",
    "abandoned castle",
    "snowy mountain peak",
    "desert ruins",
    "underwater city",
    "enchanted meadow",
    "haunted mansion",
    "steampunk marketplace",
    "glacial cavern",
]

# Camera framing / shot types.
CAMERA_ANGLES = [
    "low-angle shot",
    "close-up shot",
    "bird's-eye view",
    "wide-angle shot",
    "over-the-shoulder shot",
    "extreme close-up",
    "panoramic view",
    "dynamic tracking shot",
    "fisheye view",
    "point-of-view shot",
]

# Generic quality / style boosters.
QUALITY_PROMPTS = [
    "cinematic lighting",
    "sharp shadow",
    "award-winning",
    "masterpiece",
    "vivid colors",
    "high dynamic range",
    "immersive",
    "studio quality",
    "fine art",
    "dreamlike",
    "8K",
    "HD",
    "high quality",
    "best quality",
    "artistic",
    "vibrant",
]

# Hugging Face DTR dataset location (example only; ignore it if unavailable).
# NOTE: the filename part is a shell-style wildcard (*DTR*.txt). A plain HTTP GET
# does not expand wildcards, so this exact URL cannot be downloaded as-is.
DTR_DATASET_PATTERN = (
    "https://huggingface.co/datasets/X779/Danbooruwildcards"
    "/resolve/main/*DTR*.txt"
)
# ========== Helper functions ==========
def load_candidates_from_files(files, excluded_tags=None):
    """Collect candidate tags, one per line, from several text sources.

    Surrounding whitespace is stripped, blank lines are dropped, and any line that
    exactly matches an entry of ``excluded_tags`` is skipped.

    Args:
        files: iterable of sources, or ``None``/empty. Each source is either a path
            (``str`` or ``os.PathLike``) or a readable file-like object whose
            ``read()`` returns ``bytes`` or ``str``. ``None`` entries are ignored.
        excluded_tags: collection of tags to omit (exact match). Defaults to none.

    Returns:
        list[str]: the surviving lines, in source order.

    Both paths and uploads are decoded the same way: UTF-8, undecodable bytes
    ignored, and a leading byte-order mark removed. Otherwise a BOM written by some
    editors would end up glued to the first tag, because ``str.strip`` does not
    remove U+FEFF.
    """
    if excluded_tags is None:
        excluded_tags = set()

    all_lines = []
    for source in files or []:
        if source is None:
            continue
        if isinstance(source, (str, os.PathLike)):
            # A filesystem path.
            with open(source, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()
        else:
            # An uploaded file-like object; read() may yield bytes or str.
            data = source.read()
            if isinstance(data, (bytes, bytearray)):
                text = bytes(data).decode("utf-8", errors="ignore")
            else:
                text = data
        if text.startswith("\ufeff"):
            text = text[1:]
        for line in text.splitlines():
            tag = line.strip()
            if tag and tag not in excluded_tags:
                all_lines.append(tag)
    return all_lines
def get_random_items(candidates, num_items=1):
    """Pick up to ``num_items`` distinct entries from ``candidates`` at random.

    Args:
        candidates: any finite iterable of options (list, tuple, set, ...).
            ``None`` or empty yields ``[]``.
        num_items: requested count. It is coerced with ``int()`` because Gradio
            sliders may deliver floats, and ``random.sample`` rejects a non-int
            ``k``. Negative values are treated as 0, and the count is capped at the
            pool size.

    Returns:
        list: the sampled entries, without repetition.
    """
    if not candidates:
        return []
    # random.sample needs a sequence; sets are rejected on Python 3.11+.
    pool = list(candidates)
    count = max(0, min(int(num_items), len(pool)))
    return random.sample(pool, count)
def load_dtr_from_huggingface(excluded_tags=None, timeout=15):
    """Load every line of the DTR wildcard files from the Hugging Face dataset.

    ``DTR_DATASET_PATTERN`` is a ``.../resolve/main/<file>`` URL whose file part may
    contain shell-style wildcards. HTTP servers do not expand wildcards, so when
    the pattern contains any, the dataset's file list is first fetched from the
    Hub API (``<host>/api/datasets/<repo_id>``, field ``siblings[].rfilename``).
    Every matching file is then downloaded. A pattern without wildcards is
    fetched directly.

    Args:
        excluded_tags: tags to drop (exact match after stripping whitespace).
        timeout: per-request timeout in seconds, so an unreachable Hub cannot hang
            the UI indefinitely.

    Returns:
        list[str]: non-empty, stripped lines. Returns ``[]`` on any failure: DTR is
        an optional, best-effort source.
    """
    if excluded_tags is None:
        excluded_tags = set()

    base, sep, file_glob = DTR_DATASET_PATTERN.rpartition("/resolve/main/")
    try:
        if sep and any(ch in file_glob for ch in "*?["):
            host, _, repo_id = base.partition("/datasets/")
            listing = requests.get(f"{host}/api/datasets/{repo_id}", timeout=timeout)
            listing.raise_for_status()
            siblings = listing.json().get("siblings", [])
            names = [entry.get("rfilename", "") for entry in siblings]
            urls = [
                f"{base}{sep}{urllib.parse.quote(name)}"
                for name in names
                if fnmatch.fnmatch(name, file_glob)
            ]
        else:
            urls = [DTR_DATASET_PATTERN]

        raw_lines = []
        for url in urls:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            raw_lines.extend(response.text.splitlines())
    except Exception as e:  # best-effort boundary: a DTR failure must not break generation
        print(f"Error loading DTR dataset: {e}")
        return []

    stripped = (line.strip() for line in raw_lines)
    # Only exact matches are filtered out.
    return [line for line in stripped if line and line not in excluded_tags]
def generate_natural_language_description(tags, api_key=None, base_url=None, model="gpt-4"):
    """Turn a tag mapping into a 3-5 sentence scene description via a chat API.

    Args:
        tags: mapping of category name -> list of tags (or a single value). Empty
            values are skipped when building the prompt.
        api_key: key for the target service. If it is empty and no ``base_url`` is
            given (so the request goes to OpenAI), ``OPENAI_API_KEY`` from the
            environment is used instead. The fallback is deliberately NOT applied
            when ``base_url`` points elsewhere (e.g. DeepSeek), so an OpenAI secret
            is never sent to a third-party endpoint.
        base_url: optional OpenAI-compatible endpoint.
        model: chat model name.

    Returns:
        str: the description, or a human-readable error message. Never raises.
    """
    if not api_key and not base_url:
        api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: No API Key provided and none found in environment variables."

    # One "category: a, b" line per non-empty entry.
    tag_descriptions = "\n".join(
        f"{key}: {', '.join(value) if isinstance(value, list) else value}"
        for key, value in tags.items()
        if value
    )
    try:
        # base_url=None means the client's default endpoint, same as omitting it.
        client = OpenAI(api_key=api_key, base_url=base_url or None)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
                        "Focus on the details provided and incorporate them into a cohesive narrative. "
                        "Use at least three sentences but no more than five sentences."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
                },
            ],
            model=model,
        )
        # The SDK types message.content as Optional[str]; normalize None to "".
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"GPT generation failed. Error: {e}"
# ========== Core function: randomly assemble a prompt ==========
def generate_prompt(
    action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
    expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
    artist_count, use_deepseek, deepseek_key, user_custom_tags, excluded_tags
):
    """Randomly assemble a tag prompt and ask an LLM to describe it.

    Args:
        action_file, style_file: a single uploaded ``.txt`` source (or ``None``),
            one candidate per line.
        artist_files, character_files: lists of uploaded ``.txt`` sources (or ``None``).
        dtr_enabled: also draw one line from the Hugging Face DTR dataset.
        api_key: OpenAI key, used when ``use_deepseek`` is false.
        selected_categories: character-category labels, joined into a single tag.
        expression_count ... artist_count: how many entries to draw from each pool.
            One character and one DTR line are always requested.
        use_deepseek, deepseek_key: route description generation to DeepSeek.
        user_custom_tags: comma-separated tags appended as-is.
        excluded_tags: comma-separated tags removed (exact match) from every sampled pool.

    Returns:
        tuple[str, str, str]: ``(final_tags, description, combined)``. ``final_tags``
        is the deduplicated, comma-joined tag list. ``description`` is the LLM
        text or an error string. ``combined`` is both, separated by a blank line.
    """
    # Comma-separated exclusions -> set. `or ""` tolerates None from non-UI callers.
    excluded_set = {tag.strip() for tag in (excluded_tags or "").split(",") if tag.strip()}

    def draw_from_files(files, count):
        # Nothing uploaded -> nothing drawn (no file I/O, no RNG use).
        if not files:
            return []
        return get_random_items(load_candidates_from_files(files, excluded_set), int(count))

    def draw_from_preset(pool, count):
        # Built-in pools minus the excluded tags.
        return get_random_items([entry for entry in pool if entry not in excluded_set], int(count))

    # Draw order is kept stable so results stay reproducible under a fixed seed.
    actions = draw_from_files([action_file] if action_file else None, action_count)
    styles = draw_from_files([style_file] if style_file else None, style_count)
    artists = draw_from_files(artist_files, artist_count)
    characters = draw_from_files(character_files, 1)
    dtr_candidates = get_random_items(load_dtr_from_huggingface(excluded_set) if dtr_enabled else [], 1)

    random_expression = draw_from_preset(EXPRESSIONS, expression_count)
    random_items = draw_from_preset(ITEMS, item_count)
    random_details = draw_from_preset(OTHER_DETAILS, detail_count)
    random_scenes = draw_from_preset(SCENES, scene_count)
    random_angles = draw_from_preset(CAMERA_ANGLES, angle_count)
    random_quality = draw_from_preset(QUALITY_PROMPTS, quality_count)

    category_tag = ", ".join(selected_categories) if selected_categories else ""

    # Every value is a list; the key order here is the tag order in the output.
    tags = {
        "number_of_characters": [category_tag] if category_tag else [],
        "character_name": characters,
        "artist_prompt": [f"(artist:{', '.join(artists)})"] if artists else [],
        "style": styles,
        "scene": random_scenes,
        "camera_angle": random_angles,
        "action": actions,
        "expression": random_expression,
        "items": random_items,
        "other_details": random_details,
        "quality_prompts": random_quality,
        "dtr": dtr_candidates,
    }
    custom = [t.strip() for t in (user_custom_tags or "").split(",") if t.strip()]
    if custom:
        tags["custom_tags"] = custom

    if use_deepseek:
        description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
    else:
        description = generate_natural_language_description(tags, api_key=api_key)

    # Flatten all categories, dropping blanks and duplicates (first occurrence wins).
    final_tags_list = list(dict.fromkeys(tag for values in tags.values() for tag in values if tag))
    final_tags = ", ".join(final_tags_list)

    # Default combined output = tags + description.
    combined_output = f"{final_tags}\n\n{description}"
    return final_tags, description, combined_output
# ========== Partial update: regenerate description + combined output from the user-edited tags ==========
def update_description(tags_text, api_key, use_deepseek, deepseek_key):
    """Regenerate only the description and combined output from ``tags_text``.

    Tags are NOT re-sampled, so manual edits the user made to the tags box survive.

    Args:
        tags_text: the (possibly user-edited) tag string.
        api_key: OpenAI key. Falls back to ``OPENAI_API_KEY`` when empty.
        use_deepseek: when true, call DeepSeek with ``deepseek_key`` instead of OpenAI.
        deepseek_key: DeepSeek key. Required when ``use_deepseek`` is set.

    Returns:
        tuple[str, str]: ``(description, combined)``. ``combined`` is ``tags_text``,
        a blank line, then the description. Errors are reported inside the
        description string.
    """
    no_key_msg = "(No API Key provided)"
    # Use only the key that belongs to the selected provider. Handing api_key=None
    # to the OpenAI client makes it fall back to OPENAI_API_KEY, which would send an
    # OpenAI secret to DeepSeek.
    if use_deepseek:
        key, base_url, model = deepseek_key, "https://api.deepseek.com", "deepseek-chat"
    else:
        # Same env fallback as generate_natural_language_description.
        key, base_url, model = api_key or os.getenv("OPENAI_API_KEY"), None, "gpt-4"
    if not key:
        return no_key_msg, f"{tags_text}\n\n{no_key_msg}"

    user_prompt = (
        "You are a creative assistant that generates detailed, imaginative scene descriptions for AI generation.\n"
        "Below is the user's current tags (prompt elements). "
        "Generate a new descriptive text (3-5 sentences) that incorporates these tags.\n\n"
        f"User Tags: {tags_text}\n"
        "Please generate a vivid, imaginative scene description."
    )
    try:
        client = OpenAI(api_key=key, base_url=base_url)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": "You are a creative assistant that generates imaginative scene descriptions..."
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            model=model,
        )
        # message.content is Optional[str] in the SDK; normalize None to "".
        new_description = (response.choices[0].message.content or "").strip()
    except Exception as e:
        new_description = f"(GPT generation failed: {e})"
    new_combined_output = f"{tags_text}\n\n{new_description}"
    return new_description, new_combined_output
# ========== Translation: translate combined_output into the user's chosen language ==========
def translate_combined_output(combined_text, target_language, api_key, use_deepseek, deepseek_key):
    """Translate ``combined_text`` into ``target_language`` via OpenAI or DeepSeek.

    Args:
        combined_text: source text (tags + description).
        target_language: language name interpolated into the prompt, e.g. "Chinese".
        api_key: OpenAI key. Falls back to ``OPENAI_API_KEY`` when empty.
        use_deepseek: when true, call DeepSeek with ``deepseek_key`` instead of OpenAI.
        deepseek_key: DeepSeek key. Required when ``use_deepseek`` is set.

    Returns:
        str: the translation, or a parenthesized error message. Never raises.
    """
    # Only the selected provider's key is used, so an OpenAI secret is never sent
    # to DeepSeek via the client's OPENAI_API_KEY fallback.
    if use_deepseek:
        key, base_url, model = deepseek_key, "https://api.deepseek.com", "deepseek-chat"
    else:
        key, base_url, model = api_key or os.getenv("OPENAI_API_KEY"), None, "gpt-3.5-turbo"
    if not key:
        return "(No API Key provided)"

    # A plain chat prompt is used for translation; a dedicated translation API could replace it.
    translation_prompt = (
        f"You are a professional translator. Please translate the following text into {target_language}.\n\n"
        f"{combined_text}"
    )
    try:
        client = OpenAI(api_key=key, base_url=base_url)
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a professional translator."},
                {"role": "user", "content": translation_prompt},
            ],
            model=model,
        )
        translated_text = (response.choices[0].message.content or "").strip()
    except Exception as e:
        translated_text = f"(Translation failed: {e})"
    return translated_text
# ========== Favorites: keep at most 3 entries ==========
def add_to_favorites(combined_output, current_favorites, max_items=3):
    """Save ``combined_output`` to the favorites, keeping only the newest ``max_items``.

    Args:
        combined_output: the "tags + description" text to save. Blank or ``None``
            input is ignored, so clicking before generating stores nothing.
        current_favorites: previous favorites list (e.g. from ``gr.State``);
            ``None`` is treated as empty. It is not mutated.
        max_items: capacity. The oldest entries are dropped first.

    Returns:
        tuple[str, list[str]]: the numbered display text and the new favorites list.
    """
    favorites = list(current_favorites or [])
    if combined_output and combined_output.strip():
        favorites.append(combined_output)
    # Keep only the newest entries.
    favorites = favorites[-max_items:] if max_items > 0 else []
    favorites_text = "\n\n".join(
        f"[Favorite {i}]\n{fav}" for i, fav in enumerate(favorites, start=1)
    )
    return favorites_text, favorites
# ========== Gradio UI ==========
def gradio_interface():
    """Build the Gradio Blocks UI and wire its event handlers.

    The left column collects API keys, toggles, optional uploads, tag
    exclusions/additions and per-pool sample counts. The right column shows the
    generated tags/description/combined text plus translation and favorites
    controls.

    Returns:
        gr.Blocks: the assembled (not yet launched) app. The caller must call
        ``launch()``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 【RPGAT】Random Prompt Generator with Adjustable Tags (V2)")
        # Per-session state holding saved favorites (add_to_favorites caps it at 3).
        favorites_state = gr.State([])
        with gr.Row():
            # Left column: keys, file uploads, parameter selection, excluded/custom tag inputs.
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="OpenAI API Key (可选)",
                    placeholder="sk-...",
                    type="password"
                )
                deepseek_key_input = gr.Textbox(
                    label="DeepSeek API Key (可选)",
                    placeholder="sk-...",
                    type="password"
                )
                use_deepseek = gr.Checkbox(label="Use DeepSeek API")
                # Optional remote source; the label asks users not to tick it if it is unavailable.
                dtr_enabled = gr.Checkbox(label="Enable DTR (如不可用请勿勾选)")
                with gr.Group():
                    gr.Markdown("**上传文件 (可选):**")
                    # Single-file inputs for actions/styles; multi-file inputs for artists/characters.
                    action_file = gr.File(label="Action File", file_types=[".txt"])
                    style_file = gr.File(label="Style File", file_types=[".txt"])
                    artist_files = gr.Files(label="Artist Files", file_types=[".txt"])
                    character_files = gr.Files(label="Character Files", file_types=[".txt"])
                selected_categories = gr.CheckboxGroup(
                    ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
                    label="Choose Character Categories"
                )
                # Comma-separated; matching is exact (see generate_prompt).
                excluded_tags = gr.Textbox(
                    label="排除excluded Tags (逗号分隔)",
                    placeholder="As如:angry, sword"
                )
                user_custom_tags = gr.Textbox(
                    label="自定义附加custom addition Tags (逗号分隔)",
                    placeholder="As如:glowing eyes, giant wings"
                )
                with gr.Group():
                    gr.Markdown("**随机数量设置:**")
                    # Built-in pools: 0 disables a pool entirely.
                    expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
                    item_count = gr.Slider(label="Number of Items", minimum=0, maximum=10, step=1, value=1)
                    detail_count = gr.Slider(label="Number of Other Details", minimum=0, maximum=10, step=1, value=1)
                    scene_count = gr.Slider(label="Number of Scenes", minimum=0, maximum=10, step=1, value=1)
                    angle_count = gr.Slider(label="Number of Camera Angles", minimum=0, maximum=10, step=1, value=1)
                    quality_count = gr.Slider(label="Number of Quality Prompts", minimum=0, maximum=10, step=1, value=1)
                    # File-backed pools start at 1; generate_prompt only draws from them when a file is uploaded.
                    action_count = gr.Slider(label="Number of Actions", minimum=1, maximum=10, step=1, value=1)
                    style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
                    artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
            # Right column: generate button, results, translation and favorites.
            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Prompt", variant="primary")
                # interactive=True lets users edit the tags/description; the
                # "Update Description Only" handler reads the edited tags box.
                tags_output = gr.Textbox(
                    label="Generated Tags",
                    placeholder="等待生成...",
                    lines=4,
                    interactive=True
                )
                description_output = gr.Textbox(
                    label="Generated Description",
                    placeholder="等待生成...",
                    lines=4,
                    interactive=True
                )
                combined_output = gr.Textbox(
                    label="Combined Output: Tags + Description",
                    placeholder="等待生成...",
                    lines=6
                )
                # Regenerates only the description and combined output (tags are not re-sampled).
                update_desc_button = gr.Button("Update Description Only")
                # Translation controls.
                with gr.Row():
                    # The chosen string is interpolated verbatim into the translation prompt.
                    target_language = gr.Dropdown(
                        choices=["English", "Chinese", "Arabic (language)", "Japanese", "Persian (language)", "Italian (language)",
                                 "Dutch (language)","Russian (language)","German (language)"],
                        value="English",
                        label="Target Language-目标语言"
                    )
                    translate_button = gr.Button("Translate to selected language")
                translated_output = gr.Textbox(
                    label="Translated Output",
                    placeholder="Waiting for the result of Translation--等待翻译...",
                    lines=6
                )
                # Favorites.
                with gr.Row():
                    favorite_button = gr.Button("Favorites this result-收藏本次结果")
                favorites_box = gr.Textbox(
                    label="Favorites Folder (MAX-Save 3 tags-- 最多 3 条)",
                    placeholder="暂无收藏",
                    lines=6
                )
        # "Generate Prompt": the inputs list order must match generate_prompt's positional parameters.
        generate_button.click(
            generate_prompt,
            inputs=[
                action_file, style_file, artist_files, character_files,
                dtr_enabled, api_key_input, selected_categories,
                expression_count, item_count, detail_count, scene_count,
                angle_count, quality_count, action_count, style_count,
                artist_count, use_deepseek, deepseek_key_input,
                user_custom_tags, excluded_tags
            ],
            outputs=[tags_output, description_output, combined_output],
        )
        # "Update Description Only": regenerate from the current tags text.
        update_desc_button.click(
            update_description,
            inputs=[
                tags_output,  # the tags as (possibly) edited by the user in the textbox
                api_key_input,
                use_deepseek,
                deepseek_key_input,
            ],
            outputs=[description_output, combined_output],
        )
        # "Translate to selected language".
        translate_button.click(
            fn=translate_combined_output,
            inputs=[
                combined_output,  # source text to translate
                target_language,
                api_key_input,
                use_deepseek,
                deepseek_key_input
            ],
            outputs=[translated_output],
        )
        # Favorites button: the updated list is written back into favorites_state.
        favorite_button.click(
            fn=add_to_favorites,
            inputs=[combined_output, favorites_state],
            outputs=[favorites_box, favorites_state],
        )
    return demo
# Launch the Gradio app only when this file is executed as a script (not on import).
if __name__ == "__main__":
    app = gradio_interface()
    app.launch()