Spaces:
Runtime error
Runtime error
import fnmatch
import glob
import os
import random

import gradio as gr
import requests
from dotenv import load_dotenv
from openai import OpenAI
# Load environment variables (e.g. OPENAI_API_KEY, read later via os.getenv)
# from a .env file, if one is present, using python-dotenv.
load_dotenv()
# ========== Built-in option pools ==========
# Each list is a pool of candidate prompt fragments. generate_prompt() drops
# any user-excluded tags from a pool and then samples from what remains.

# Facial expressions / poses.
EXPRESSIONS = [
    "smiling",
    "determined",
    "surprised",
    "serene",
    "smug",
    "thinking",
    "looking back",
    "laughing",
    "angry",
    "pensive",
    "confident",
    "grinning",
    "thoughtful",
    "sad tears",
    "bewildered",
]

# Props and accessories.
ITEMS = [
    "magic wand",
    "sword",
    "flower",
    "book of spells",
    "earrings",
    "loincloth",
    "slippers",
    "ancient scroll",
    "music instrument",
    "shield",
    "dagger",
    "headband",
    "leg ties",
    "staff",
    "potion",
    "crystal ball",
    "anklet",
    "ribbon",
    "lantern",
    "amulet",
    "ring",
]

# Visual effects.
OTHER_DETAILS = [
    "sparkles",
    "magical aura",
    "lens flare",
    "fireworks in the background",
    "smoke effects",
    "light trails",
    "falling leaves",
    "glowing embers",
    "floating particles",
    "rays of light",
    "shimmering mist",
    "ethereal glow",
]

# Backgrounds / settings.
SCENES = [
    "sunset beach",
    "rainy city street at night",
    "floating ash land",
    "particles magic world",
    "high blue sky",
    "top of the building",
    "fantasy forest with glowing mushrooms",
    "futuristic skyline at dawn",
    "abandoned castle",
    "snowy mountain peak",
    "desert ruins",
    "underwater city",
    "enchanted meadow",
    "haunted mansion",
    "steampunk marketplace",
    "glacial cavern",
]

# Camera framing.
CAMERA_ANGLES = [
    "low-angle shot",
    "close-up shot",
    "bird's-eye view",
    "wide-angle shot",
    "over-the-shoulder shot",
    "extreme close-up",
    "panoramic view",
    "dynamic tracking shot",
    "fisheye view",
    "point-of-view shot",
]

# Generic quality boosters.
QUALITY_PROMPTS = [
    "cinematic lighting",
    "sharp shadow",
    "award-winning",
    "masterpiece",
    "vivid colors",
    "high dynamic range",
    "immersive",
    "studio quality",
    "fine art",
    "dreamlike",
    "8K",
    "HD",
    "high quality",
    "best quality",
    "artistic",
    "vibrant",
]

# Location of the Hugging Face DTR wildcard files (an example source; ignore it
# if unavailable). The file part is a glob (`*DTR*.txt`), not a literal
# filename, so it names a set of files rather than one downloadable URL.
DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
# ========== Helper functions ==========
def _read_text(file):
    """Return the decoded text of a path or uploaded file object (see load_candidates_from_files)."""
    if isinstance(file, (str, os.PathLike)):
        path = file
    else:
        # Some upload widgets hand back a wrapper whose stream is already
        # closed but whose .name points at the saved file; prefer that path.
        name = getattr(file, "name", None)
        path = name if isinstance(name, (str, os.PathLike)) and os.path.isfile(name) else None
    if path is not None:
        # utf-8-sig strips a leading BOM so the first tag still matches
        # exclusions exactly; undecodable bytes are dropped, as for uploads.
        with open(path, "r", encoding="utf-8-sig", errors="ignore") as f:
            return f.read()
    data = file.read()
    if isinstance(data, bytes):
        data = data.decode("utf-8-sig", errors="ignore")
    return data


def load_candidates_from_files(files, excluded_tags=None):
    """
    Load candidate tags (one per line) from several text files.

    Args:
        files: iterable whose entries are filesystem paths (``str`` /
            ``os.PathLike``), objects exposing a ``.name`` that points to an
            existing file, or readable file-like objects yielding ``bytes`` or
            ``str``. ``None`` entries (and ``files=None``) are skipped.
        excluded_tags: container of tags to drop (exact match after
            stripping). Defaults to no exclusions.

    Returns:
        list[str]: stripped, non-empty, non-excluded lines in file order.

    Raises:
        FileNotFoundError: if a path entry does not exist.
    """
    if excluded_tags is None:
        excluded_tags = set()
    all_lines = []
    for file in files or []:
        if file is None:
            continue
        stripped = (line.strip() for line in _read_text(file).splitlines())
        all_lines.extend(line for line in stripped if line and line not in excluded_tags)
    return all_lines
def get_random_items(candidates, num_items=1):
    """
    Randomly pick up to ``num_items`` distinct entries from ``candidates``.

    ``num_items`` may arrive as a float (e.g. from a UI slider), as ``None``,
    or as a negative number. It is truncated to an int and clamped to
    ``[0, len(candidates)]`` so ``random.sample`` never raises.

    Returns:
        list: the sampled entries; empty when there are no candidates.
    """
    if not candidates:
        return []
    # random.sample needs a sequence (sets are rejected on newer Pythons).
    pool = list(candidates)
    count = max(0, min(int(num_items or 0), len(pool)))
    return random.sample(pool, count)
def load_dtr_from_huggingface(excluded_tags=None, pattern=None, timeout=10):
    """
    Load DTR tag lines from the Hugging Face dataset files matching ``pattern``.

    ``pattern`` defaults to ``DTR_DATASET_PATTERN``, a
    ``<repo-url>/resolve/<revision>/<file-glob>`` URL. The Hub cannot serve a
    glob directly: there is no file literally named ``*DTR*.txt``, so
    requesting that URL is expected to 404. When the file part contains
    wildcard characters, the repository's file list is read from the Hub API
    and every matching file is downloaded. A pattern without wildcards is
    fetched as a plain URL.

    Args:
        excluded_tags: tags to drop (exact match after stripping).
        pattern: URL or URL glob of the files to load.
        timeout: per-request timeout in seconds, so a stalled request cannot
            hang the caller indefinitely.

    Returns:
        list[str]: stripped, non-empty, non-excluded lines from all matched
        files, in file-name order. Returns ``[]`` if anything fails, because
        DTR is an optional source.
    """
    if excluded_tags is None:
        excluded_tags = set()
    if pattern is None:
        pattern = DTR_DATASET_PATTERN
    try:
        raw_lines = []
        for url in _expand_hub_file_glob(pattern, timeout):
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            raw_lines.extend(response.text.splitlines())
    except Exception as e:
        # Best effort: DTR is optional, so report and fall back to no candidates.
        print(f"Error loading DTR dataset: {e}")
        return []
    stripped = (line.strip() for line in raw_lines)
    return [line for line in stripped if line and line not in excluded_tags]


def _expand_hub_file_glob(pattern, timeout):
    """Expand a ``<repo-url>/resolve/<rev>/<glob>`` Hub URL into concrete file URLs."""
    repo_url, sep, rest = pattern.partition("/resolve/")
    revision, _, file_glob = rest.partition("/")
    if not sep or not any(ch in file_glob for ch in "*?["):
        # Not a glob: treat the pattern as a single downloadable URL.
        return [pattern]
    # https://huggingface.co/datasets/<repo> -> https://huggingface.co/api/datasets/<repo>
    api_url = repo_url.replace("https://huggingface.co/", "https://huggingface.co/api/", 1)
    info = requests.get(f"{api_url}/revision/{revision}", timeout=timeout)
    info.raise_for_status()
    # The repo-info JSON lists every file under "siblings" as {"rfilename": ...}.
    names = [entry.get("rfilename", "") for entry in info.json().get("siblings", [])]
    matches = sorted(name for name in names if name and fnmatch.fnmatchcase(name, file_glob))
    return [f"{repo_url}/resolve/{revision}/{name}" for name in matches]
def generate_natural_language_description(tags, api_key=None, base_url=None, model="gpt-4"):
    """
    Turn a tag dictionary into a 3-5 sentence scene description via a chat model.

    Args:
        tags: mapping of category name -> list of tags (or a single value).
            Empty values are skipped.
        api_key: key for the chosen endpoint. When empty and ``base_url`` is
            not given (the default OpenAI endpoint), ``OPENAI_API_KEY`` from
            the environment is used instead. There is deliberately no such
            fallback for a custom ``base_url`` (e.g. DeepSeek), so the OpenAI
            key is never sent to a third-party host.
        base_url: optional OpenAI-compatible endpoint.
        model: chat model name.

    Returns:
        str: the description, or an error message string on a missing key or
        API failure (this function does not raise).
    """
    if not api_key and not base_url:
        api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: No API Key provided and none found in environment variables."
    # Render the dict as readable "category: a, b" lines for the model.
    tag_descriptions = "\n".join(
        f"{key}: {', '.join(value) if isinstance(value, list) else value}"
        for key, value in tags.items()
        if value
    )
    try:
        client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
                        "Focus on the details provided and incorporate them into a cohesive narrative. "
                        "Use at least three sentences but no more than five sentences."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
                },
            ],
            model=model,
        )
        # message.content can be None (e.g. filtered output); treat it as empty.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"GPT generation failed. Error: {e}"
# ========== Core: randomly assemble a prompt ==========
def _split_tag_text(text):
    """Split comma-separated ``text`` into stripped, non-empty tags (``None`` or "" -> [])."""
    return [tag.strip() for tag in (text or "").split(",") if tag.strip()]


def generate_prompt(
    action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
    expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
    artist_count, use_deepseek, deepseek_key, user_custom_tags, excluded_tags
):
    """
    Randomly assemble a tag prompt and a natural-language description.

    Tags are sampled from three kinds of source, after removing every tag
    listed in ``excluded_tags`` (comma-separated, exact match):
    - the uploaded action/style/artist/character files;
    - the optional DTR dataset (when ``dtr_enabled``);
    - the built-in pools.
    The ``*_count`` arguments set how many entries are drawn from each pool.
    At most one character and one DTR entry are drawn. Tags from
    ``user_custom_tags`` (comma-separated) are appended as given.
    ``excluded_tags`` and ``user_custom_tags`` may be ``None`` or empty.

    The description comes from DeepSeek (``deepseek_key``) when
    ``use_deepseek`` is set, otherwise from OpenAI (``api_key``).

    Returns:
        tuple[str, str, str]: (comma-joined de-duplicated tags, description,
        tags + blank line + description).
    """
    excluded_set = set(_split_tag_text(excluded_tags))

    def from_files(files, count):
        # Sample from a user-uploaded pool; no files means an empty pool.
        pool = load_candidates_from_files(files, excluded_set) if files else []
        return get_random_items(pool, count)

    def from_preset(pool, count):
        # Sample from a built-in pool after dropping excluded tags.
        return get_random_items([entry for entry in pool if entry not in excluded_set], count)

    actions = from_files([action_file] if action_file else None, action_count)
    styles = from_files([style_file] if style_file else None, style_count)
    artists = from_files(artist_files, artist_count)
    characters = from_files(character_files, 1)
    dtr_candidates = get_random_items(load_dtr_from_huggingface(excluded_set) if dtr_enabled else [], 1)

    random_expression = from_preset(EXPRESSIONS, expression_count)
    random_items = from_preset(ITEMS, item_count)
    random_details = from_preset(OTHER_DETAILS, detail_count)
    random_scenes = from_preset(SCENES, scene_count)
    random_angles = from_preset(CAMERA_ANGLES, angle_count)
    random_quality = from_preset(QUALITY_PROMPTS, quality_count)

    # The selected categories become a single combined tag, e.g. "1boy, 1girl".
    number_of_characters = ", ".join(selected_categories) if selected_categories else ""

    tags = {
        "number_of_characters": [number_of_characters] if number_of_characters else [],
        "character_name": characters,
        "artist_prompt": [f"(artist:{', '.join(artists)})"] if artists else [],
        "style": styles,
        "scene": random_scenes,
        "camera_angle": random_angles,
        "action": actions,
        "expression": random_expression,
        "items": random_items,
        "other_details": random_details,
        "quality_prompts": random_quality,
        "dtr": dtr_candidates,
    }
    custom_tags = _split_tag_text(user_custom_tags)
    if custom_tags:
        tags["custom_tags"] = custom_tags

    if use_deepseek:
        description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
    else:
        description = generate_natural_language_description(tags, api_key=api_key)

    flat_tags = []
    for value in tags.values():
        flat_tags.extend(value if isinstance(value, list) else [value])
    # dict.fromkeys de-duplicates while keeping first-seen order; falsy tags are dropped.
    final_tags = ", ".join(dict.fromkeys(tag for tag in flat_tags if tag))

    combined_output = f"{final_tags}\n\n{description}"
    return final_tags, description, combined_output
# ========== Partial update: rebuild description/combined output from edited tags ==========
def update_description(tags_text, api_key, use_deepseek, deepseek_key):
    """
    Regenerate only the description (and combined output) from ``tags_text``.

    Nothing is re-sampled, so manual edits the user made to the tags survive.
    Only the key of the *selected* provider is used:
    - ``deepseek_key`` when ``use_deepseek`` is set;
    - otherwise ``api_key``, falling back to ``OPENAI_API_KEY`` from the
      environment (the same fallback the Generate path uses).

    Returns:
        tuple[str, str]: (description, tags_text + blank line + description).
        On a missing key or API failure, the description is a parenthesised
        message.
    """
    tags_text = tags_text or ""
    if use_deepseek:
        key, base_url, model = deepseek_key, "https://api.deepseek.com", "deepseek-chat"
    else:
        key, base_url, model = api_key or os.getenv("OPENAI_API_KEY"), None, "gpt-4"
    if not key:
        # No usable key for the provider that would actually be called.
        return "(No API Key provided)", f"{tags_text}\n\n(No API Key provided)"

    user_prompt = (
        "You are a creative assistant that generates detailed, imaginative scene descriptions for AI generation.\n"
        "Below is the user's current tags (prompt elements). "
        "Generate a new descriptive text (3-5 sentences) that incorporates these tags.\n\n"
        f"User Tags: {tags_text}\n"
        "Please generate a vivid, imaginative scene description."
    )
    try:
        client = OpenAI(api_key=key, base_url=base_url) if base_url else OpenAI(api_key=key)
        response = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    # Same system prompt as generate_natural_language_description
                    # (the previous text was a truncated "..." placeholder).
                    "content": (
                        "You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
                        "Focus on the details provided and incorporate them into a cohesive narrative. "
                        "Use at least three sentences but no more than five sentences."
                    ),
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            model=model,
        )
        new_description = (response.choices[0].message.content or "").strip()
    except Exception as e:
        new_description = f"(GPT generation failed: {e})"
    new_combined_output = f"{tags_text}\n\n{new_description}"
    return new_description, new_combined_output
# ========== Translation: translate the combined output into the chosen language ==========
def translate_combined_output(combined_text, target_language, api_key, use_deepseek, deepseek_key):
    """
    Translate ``combined_text`` into ``target_language`` with a chat model.

    Only the key of the *selected* provider is used:
    - ``deepseek_key`` when ``use_deepseek`` is set;
    - otherwise ``api_key``, falling back to ``OPENAI_API_KEY`` from the
      environment, consistent with the Generate path.

    Returns:
        str: the translation, ``""`` when there is nothing to translate, or a
        parenthesised message on a missing key or API failure.
    """
    if use_deepseek:
        key, base_url, model = deepseek_key, "https://api.deepseek.com", "deepseek-chat"
    else:
        key, base_url, model = api_key or os.getenv("OPENAI_API_KEY"), None, "gpt-3.5-turbo"
    if not key:
        return "(No API Key provided)"
    if not combined_text or not combined_text.strip():
        # Nothing generated yet; skip the API call.
        return ""

    translation_prompt = (
        f"You are a professional translator. Please translate the following text into {target_language}.\n\n"
        f"{combined_text}"
    )
    try:
        client = OpenAI(api_key=key, base_url=base_url) if base_url else OpenAI(api_key=key)
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a professional translator."},
                {"role": "user", "content": translation_prompt},
            ],
            model=model,
        )
        translated_text = (response.choices[0].message.content or "").strip()
    except Exception as e:
        translated_text = f"(Translation failed: {e})"
    return translated_text
# ========== Favorites: keep at most 3 entries ==========
def add_to_favorites(combined_output, current_favorites, max_items=3):
    """
    Append ``combined_output`` to the favorites, keeping only the newest ``max_items``.

    A new list is returned rather than mutating ``current_favorites`` in
    place, so the caller's (session-state) list is never modified behind its
    back. Blank output, such as before anything has been generated, is not
    added. ``current_favorites`` may be ``None``.

    Returns:
        tuple[str, list[str]]: (favorites formatted as "[Favorite i]" sections
        separated by blank lines, the updated favorites list).
    """
    favorites = list(current_favorites or [])
    if combined_output and combined_output.strip():
        favorites.append(combined_output)
    # Drop the oldest entries beyond the cap (slice guards the max_items <= 0 case).
    favorites = favorites[-max_items:] if max_items > 0 else []
    favorites_text = "\n\n".join(
        f"[Favorite {i + 1}]\n{fav}" for i, fav in enumerate(favorites)
    )
    return favorites_text, favorites
# ========== Gradio UI ==========
def gradio_interface():
    """
    Build the Gradio Blocks app and wire up its event handlers.

    Layout:
    - Left column: API keys, the DTR toggle, optional .txt uploads, character
      categories, excluded/custom tag inputs and per-pool count sliders.
    - Right column: the generated tags, description and combined output, an
      "Update Description Only" button, translation controls and a favorites
      box.

    Returns:
        gr.Blocks: the app, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 【RPGAT】Random Prompt Generator with Adjustable Tags (V2)")
        # Per-session list of favorites (add_to_favorites caps it at 3 entries).
        favorites_state = gr.State([])
        with gr.Row():
            # Left column: keys, file uploads, parameter selection, excluded/custom tag inputs
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="OpenAI API Key (可选)",
                    placeholder="sk-...",
                    type="password"
                )
                deepseek_key_input = gr.Textbox(
                    label="DeepSeek API Key (可选)",
                    placeholder="sk-...",
                    type="password"
                )
                use_deepseek = gr.Checkbox(label="Use DeepSeek API")
                dtr_enabled = gr.Checkbox(label="Enable DTR (如不可用请勿勾选)")
                with gr.Group():
                    gr.Markdown("**上传文件 (可选):**")
                    action_file = gr.File(label="Action File", file_types=[".txt"])
                    style_file = gr.File(label="Style File", file_types=[".txt"])
                    artist_files = gr.Files(label="Artist Files", file_types=[".txt"])
                    character_files = gr.Files(label="Character Files", file_types=[".txt"])
                selected_categories = gr.CheckboxGroup(
                    ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
                    label="Choose Character Categories"
                )
                excluded_tags = gr.Textbox(
                    label="排除excluded Tags (逗号分隔)",
                    placeholder="As如:angry, sword"
                )
                user_custom_tags = gr.Textbox(
                    label="自定义附加custom addition Tags (逗号分隔)",
                    placeholder="As如:glowing eyes, giant wings"
                )
                with gr.Group():
                    gr.Markdown("**随机数量设置:**")
                    # Built-in pools allow 0 (skip the pool); file-based pools start at 1.
                    expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
                    item_count = gr.Slider(label="Number of Items", minimum=0, maximum=10, step=1, value=1)
                    detail_count = gr.Slider(label="Number of Other Details", minimum=0, maximum=10, step=1, value=1)
                    scene_count = gr.Slider(label="Number of Scenes", minimum=0, maximum=10, step=1, value=1)
                    angle_count = gr.Slider(label="Number of Camera Angles", minimum=0, maximum=10, step=1, value=1)
                    quality_count = gr.Slider(label="Number of Quality Prompts", minimum=0, maximum=10, step=1, value=1)
                    action_count = gr.Slider(label="Number of Actions", minimum=1, maximum=10, step=1, value=1)
                    style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
                    artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
            # Right column: generate button, generated results, translation, favorites
            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Prompt", variant="primary")
                # Editable so the user can tweak tags before "Update Description Only".
                tags_output = gr.Textbox(
                    label="Generated Tags",
                    placeholder="等待生成...",
                    lines=4,
                    interactive=True
                )
                description_output = gr.Textbox(
                    label="Generated Description",
                    placeholder="等待生成...",
                    lines=4,
                    interactive=True
                )
                combined_output = gr.Textbox(
                    label="Combined Output: Tags + Description",
                    placeholder="等待生成...",
                    lines=6
                )
                # Regenerates only the description and combined output from the current tags.
                update_desc_button = gr.Button("Update Description Only")
                # Translation controls
                with gr.Row():
                    target_language = gr.Dropdown(
                        choices=["English", "Chinese", "Arabic (language)", "Japanese", "Persian (language)", "Italian (language)",
                                 "Dutch (language)","Russian (language)","German (language)"],
                        value="English",
                        label="Target Language-目标语言"
                    )
                    translate_button = gr.Button("Translate to selected language")
                translated_output = gr.Textbox(
                    label="Translated Output",
                    placeholder="Waiting for the result of Translation--等待翻译...",
                    lines=6
                )
                # Favorites
                with gr.Row():
                    favorite_button = gr.Button("Favorites this result-收藏本次结果")
                favorites_box = gr.Textbox(
                    label="Favorites Folder (MAX-Save 3 tags-- 最多 3 条)",
                    placeholder="暂无收藏",
                    lines=6
                )
        # "Generate Prompt": sample new tags and fill tags, description and combined output.
        # The inputs list order must match generate_prompt's positional parameters.
        generate_button.click(
            generate_prompt,
            inputs=[
                action_file, style_file, artist_files, character_files,
                dtr_enabled, api_key_input, selected_categories,
                expression_count, item_count, detail_count, scene_count,
                angle_count, quality_count, action_count, style_count,
                artist_count, use_deepseek, deepseek_key_input,
                user_custom_tags, excluded_tags
            ],
            outputs=[tags_output, description_output, combined_output],
        )
        # "Update Description Only": re-describe the current (possibly edited) tags without re-sampling.
        update_desc_button.click(
            update_description,
            inputs=[
                tags_output,  # tags as currently shown (and possibly edited) in the textbox
                api_key_input,
                use_deepseek,
                deepseek_key_input,
            ],
            outputs=[description_output, combined_output],
        )
        # "Translate to selected language": translate the combined output.
        translate_button.click(
            fn=translate_combined_output,
            inputs=[
                combined_output,  # source text to translate
                target_language,
                api_key_input,
                use_deepseek,
                deepseek_key_input
            ],
            outputs=[translated_output],
        )
        # Favorites button: append the combined output to the per-session favorites state.
        favorite_button.click(
            fn=add_to_favorites,
            inputs=[combined_output, favorites_state],
            outputs=[favorites_box, favorites_state],
        )
    return demo
# Entry point: build the UI and start the Gradio server when run as a script.
if __name__ == "__main__":
    app = gradio_interface()
    app.launch()