Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
import random
|
3 |
import glob
|
4 |
import os
|
|
|
5 |
from openai import OpenAI
|
6 |
from dotenv import load_dotenv
|
7 |
|
@@ -29,28 +30,23 @@ def load_candidates_from_files(files):
|
|
29 |
all_lines.extend([line.strip() for line in f if line.strip()])
|
30 |
return all_lines
|
31 |
|
32 |
-
def
|
33 |
"""
|
34 |
-
|
35 |
"""
|
36 |
-
return random.
|
37 |
|
38 |
-
def
|
39 |
"""
|
40 |
-
|
41 |
-
:param directory: 目标目录,默认为当前目录
|
42 |
-
:param pattern: 匹配文件名的模式,默认是包含 "DTR" 的文件
|
43 |
-
:return: 文件内容的列表
|
44 |
"""
|
45 |
-
dtr_candidates = []
|
46 |
try:
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
dtr_candidates.extend([line.strip() for line in f if line.strip()])
|
51 |
except Exception as e:
|
52 |
-
print(f"Error loading DTR
|
53 |
-
|
54 |
|
55 |
def generate_natural_language_description(tags, api_key=None):
|
56 |
"""
|
@@ -87,7 +83,10 @@ def generate_natural_language_description(tags, api_key=None):
|
|
87 |
except Exception as e:
|
88 |
return f"GPT generation failed. Error: {e}"
|
89 |
|
90 |
-
def generate_prompt(
|
|
|
|
|
|
|
91 |
"""
|
92 |
生成随机提示词和描述。
|
93 |
"""
|
@@ -95,28 +94,28 @@ def generate_prompt(action_file, style_file, artist_files, character_files, dtr_
|
|
95 |
styles = load_candidates_from_files([style_file]) if style_file else []
|
96 |
artists = load_candidates_from_files(artist_files) if artist_files else []
|
97 |
characters = load_candidates_from_files(character_files) if character_files else []
|
98 |
-
dtr_candidates =
|
99 |
|
100 |
number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])
|
101 |
|
102 |
tags = {
|
103 |
"number_of_characters": number_of_characters,
|
104 |
-
"character_name":
|
105 |
-
"artist_prompt": f"(artist:{
|
106 |
-
"style":
|
107 |
-
"scene":
|
108 |
-
"camera_angle":
|
109 |
-
"action":
|
110 |
-
"expression":
|
111 |
-
"items":
|
112 |
-
"other_details":
|
113 |
-
"quality_prompts":
|
114 |
-
"dtr":
|
115 |
}
|
116 |
|
117 |
description = generate_natural_language_description(tags, api_key)
|
118 |
|
119 |
-
tags_list = [
|
120 |
unique_tags = list(dict.fromkeys(tags_list))
|
121 |
final_tags = ", ".join(unique_tags)
|
122 |
combined_output = f"{final_tags}\n\n{description}"
|
@@ -128,7 +127,7 @@ def gradio_interface():
|
|
128 |
定义 Gradio 应用界面。
|
129 |
"""
|
130 |
with gr.Blocks() as demo:
|
131 |
-
gr.Markdown("## Random Prompt Generator with
|
132 |
|
133 |
api_key_input = gr.Textbox(
|
134 |
label="Enter your OpenAI API Key (Optional)",
|
@@ -145,12 +144,29 @@ def gradio_interface():
|
|
145 |
character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])
|
146 |
|
147 |
dtr_enabled = gr.Checkbox(label="Enable DTR Directory")
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
selected_categories = gr.CheckboxGroup(
|
150 |
["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
|
151 |
label="Choose Character Categories (Optional)"
|
152 |
)
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
with gr.Row():
|
155 |
tags_output = gr.Textbox(label="Generated Tags")
|
156 |
description_output = gr.Textbox(label="Generated Description")
|
@@ -161,7 +177,8 @@ def gradio_interface():
|
|
161 |
generate_button.click(
|
162 |
generate_prompt,
|
163 |
inputs=[
|
164 |
-
action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories
|
|
|
165 |
],
|
166 |
outputs=[tags_output, description_output, combined_output],
|
167 |
)
|
|
|
2 |
import random
|
3 |
import glob
|
4 |
import os
|
5 |
+
import requests
|
6 |
from openai import OpenAI
|
7 |
from dotenv import load_dotenv
|
8 |
|
|
|
30 |
all_lines.extend([line.strip() for line in f if line.strip()])
|
31 |
return all_lines
|
32 |
|
33 |
+
def get_random_items(candidates, num_items=1):
|
34 |
"""
|
35 |
+
从候选项中随机选取指定数量的选项。
|
36 |
"""
|
37 |
+
return random.sample(candidates, min(num_items, len(candidates))) if candidates else []
|
38 |
|
39 |
+
def load_dtr_from_huggingface(dataset_url):
|
40 |
"""
|
41 |
+
从 Hugging Face 数据集中加载 DTR 内容。
|
|
|
|
|
|
|
42 |
"""
|
|
|
43 |
try:
|
44 |
+
response = requests.get(dataset_url)
|
45 |
+
response.raise_for_status()
|
46 |
+
return response.text.splitlines()
|
|
|
47 |
except Exception as e:
|
48 |
+
print(f"Error loading DTR dataset: {e}")
|
49 |
+
return []
|
50 |
|
51 |
def generate_natural_language_description(tags, api_key=None):
|
52 |
"""
|
|
|
83 |
except Exception as e:
|
84 |
return f"GPT generation failed. Error: {e}"
|
85 |
|
86 |
+
def generate_prompt(
|
87 |
+
action_file, style_file, artist_files, character_files, dtr_enabled, dtr_url, api_key, selected_categories,
|
88 |
+
expression_count, item_count, detail_count, scene_count, angle_count, quality_count
|
89 |
+
):
|
90 |
"""
|
91 |
生成随机提示词和描述。
|
92 |
"""
|
|
|
94 |
styles = load_candidates_from_files([style_file]) if style_file else []
|
95 |
artists = load_candidates_from_files(artist_files) if artist_files else []
|
96 |
characters = load_candidates_from_files(character_files) if character_files else []
|
97 |
+
dtr_candidates = load_dtr_from_huggingface(dtr_url) if dtr_enabled else []
|
98 |
|
99 |
number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])
|
100 |
|
101 |
tags = {
|
102 |
"number_of_characters": number_of_characters,
|
103 |
+
"character_name": get_random_items(characters, 1),
|
104 |
+
"artist_prompt": f"(artist:{get_random_items(artists, 1)})",
|
105 |
+
"style": get_random_items(styles, 1),
|
106 |
+
"scene": get_random_items(SCENES, scene_count),
|
107 |
+
"camera_angle": get_random_items(CAMERA_ANGLES, angle_count),
|
108 |
+
"action": get_random_items(actions, 1),
|
109 |
+
"expression": get_random_items(EXPRESSIONS, expression_count),
|
110 |
+
"items": get_random_items(ITEMS, item_count),
|
111 |
+
"other_details": get_random_items(OTHER_DETAILS, detail_count),
|
112 |
+
"quality_prompts": get_random_items(QUALITY_PROMPTS, quality_count),
|
113 |
+
"dtr": get_random_items(dtr_candidates, 1)
|
114 |
}
|
115 |
|
116 |
description = generate_natural_language_description(tags, api_key)
|
117 |
|
118 |
+
tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
|
119 |
unique_tags = list(dict.fromkeys(tags_list))
|
120 |
final_tags = ", ".join(unique_tags)
|
121 |
combined_output = f"{final_tags}\n\n{description}"
|
|
|
127 |
定义 Gradio 应用界面。
|
128 |
"""
|
129 |
with gr.Blocks() as demo:
|
130 |
+
gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts")
|
131 |
|
132 |
api_key_input = gr.Textbox(
|
133 |
label="Enter your OpenAI API Key (Optional)",
|
|
|
144 |
character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])
|
145 |
|
146 |
dtr_enabled = gr.Checkbox(label="Enable DTR Directory")
|
147 |
+
dtr_url = gr.Textbox(
|
148 |
+
label="Hugging Face DTR Dataset URL",
|
149 |
+
placeholder="https://huggingface.co/datasets/X779/Danbooruwildcards/tree/main",
|
150 |
+
value="https://huggingface.co/datasets/X779/Danbooruwildcards/tree/main"
|
151 |
+
)
|
152 |
|
153 |
selected_categories = gr.CheckboxGroup(
|
154 |
["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
|
155 |
label="Choose Character Categories (Optional)"
|
156 |
)
|
157 |
|
158 |
+
with gr.Row():
|
159 |
+
expression_count = gr.Slider(label="Number of Expressions", minimum=1, maximum=5, step=1, value=1)
|
160 |
+
item_count = gr.Slider(label="Number of Items", minimum=1, maximum=5, step=1, value=1)
|
161 |
+
|
162 |
+
with gr.Row():
|
163 |
+
detail_count = gr.Slider(label="Number of Other Details", minimum=1, maximum=5, step=1, value=1)
|
164 |
+
scene_count = gr.Slider(label="Number of Scenes", minimum=1, maximum=5, step=1, value=1)
|
165 |
+
|
166 |
+
with gr.Row():
|
167 |
+
angle_count = gr.Slider(label="Number of Camera Angles", minimum=1, maximum=5, step=1, value=1)
|
168 |
+
quality_count = gr.Slider(label="Number of Quality Prompts", minimum=1, maximum=5, step=1, value=1)
|
169 |
+
|
170 |
with gr.Row():
|
171 |
tags_output = gr.Textbox(label="Generated Tags")
|
172 |
description_output = gr.Textbox(label="Generated Description")
|
|
|
177 |
generate_button.click(
|
178 |
generate_prompt,
|
179 |
inputs=[
|
180 |
+
action_file, style_file, artist_files, character_files, dtr_enabled, dtr_url, api_key_input, selected_categories,
|
181 |
+
expression_count, item_count, detail_count, scene_count, angle_count, quality_count
|
182 |
],
|
183 |
outputs=[tags_output, description_output, combined_output],
|
184 |
)
|