Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -17,20 +17,31 @@ SCENES = ["sunset beach", "rainy city street at night", "floating ash land", "pa
|
|
17 |
CAMERA_ANGLES = ["low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot", "over-the-shoulder shot", "extreme close-up", "panoramic view", "dynamic tracking shot", "fisheye view", "point-of-view shot"]
|
18 |
QUALITY_PROMPTS = ["cinematic lighting", "sharp shadow", "award-winning", "masterpiece", "vivid colors", "high dynamic range", "immersive", "studio quality", "fine art", "dreamlike", "8K", "HD", "high quality", "best quality", "artistic", "vibrant"]
|
19 |
|
20 |
-
# Hugging Face DTR
|
21 |
DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
|
22 |
|
23 |
# ========== 工具函数 ==========
|
24 |
-
def load_candidates_from_files(files):
|
25 |
"""
|
26 |
-
|
27 |
"""
|
|
|
|
|
28 |
all_lines = []
|
29 |
if files:
|
30 |
for file in files:
|
31 |
if isinstance(file, str):
|
|
|
32 |
with open(file, "r", encoding="utf-8") as f:
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
return all_lines
|
35 |
|
36 |
def get_random_items(candidates, num_items=1):
|
@@ -39,14 +50,19 @@ def get_random_items(candidates, num_items=1):
|
|
39 |
"""
|
40 |
return random.sample(candidates, min(num_items, len(candidates))) if candidates else []
|
41 |
|
42 |
-
def load_dtr_from_huggingface():
|
43 |
"""
|
44 |
-
从 Hugging Face 数据集中加载所有包含 "DTR"
|
45 |
"""
|
|
|
|
|
46 |
try:
|
47 |
response = requests.get(DTR_DATASET_PATTERN)
|
48 |
response.raise_for_status()
|
49 |
-
|
|
|
|
|
|
|
50 |
except Exception as e:
|
51 |
print(f"Error loading DTR dataset: {e}")
|
52 |
return []
|
@@ -60,11 +76,14 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
|
|
60 |
if not api_key:
|
61 |
return "Error: No API Key provided and none found in environment variables."
|
62 |
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
try:
|
66 |
client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
|
67 |
-
|
68 |
response = client.chat.completions.create(
|
69 |
messages=[
|
70 |
{
|
@@ -72,7 +91,7 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
|
|
72 |
"content": (
|
73 |
"You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
|
74 |
"Focus on the details provided and incorporate them into a cohesive narrative. "
|
75 |
-
"Use at least three sentences but no more than five sentences"
|
76 |
),
|
77 |
},
|
78 |
{
|
@@ -89,114 +108,227 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
|
|
89 |
def generate_prompt(
|
90 |
action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
|
91 |
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
|
92 |
-
artist_count, use_deepseek, deepseek_key
|
93 |
):
|
94 |
"""
|
95 |
生成随机提示词和描述。
|
96 |
"""
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
number_of_characters = ", ".join(selected_categories) if selected_categories else []
|
104 |
|
|
|
105 |
tags = {
|
106 |
-
"number_of_characters": number_of_characters,
|
107 |
"character_name": characters,
|
108 |
-
"artist_prompt": f"(artist:{', '.join(artists)})",
|
109 |
"style": styles,
|
110 |
-
"scene":
|
111 |
-
"camera_angle":
|
112 |
"action": actions,
|
113 |
-
"expression":
|
114 |
-
"items":
|
115 |
-
"other_details":
|
116 |
-
"quality_prompts":
|
117 |
-
"dtr": dtr_candidates
|
118 |
}
|
119 |
|
|
|
|
|
|
|
|
|
|
|
120 |
if use_deepseek:
|
121 |
description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
|
122 |
else:
|
123 |
description = generate_natural_language_description(tags, api_key=api_key)
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
combined_output = f"{final_tags}\n\n{description}"
|
129 |
return final_tags, description, combined_output
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
# ========== Gradio 界面 ==========
|
132 |
def gradio_interface():
|
133 |
"""
|
134 |
定义 Gradio 应用界面。
|
135 |
"""
|
136 |
with gr.Blocks() as demo:
|
137 |
-
gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts")
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
placeholder="sk-...",
|
142 |
-
type="password"
|
143 |
-
)
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
150 |
|
151 |
-
|
|
|
|
|
|
|
|
|
152 |
|
153 |
-
|
154 |
-
action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
|
155 |
-
style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
|
156 |
|
157 |
-
|
158 |
-
artist_files = gr.Files(label="Upload Artist Files (Multiple Allowed)", file_types=[".txt"])
|
159 |
-
character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])
|
160 |
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
|
|
198 |
return demo
|
199 |
|
200 |
-
# 启动 Gradio
|
201 |
if __name__ == "__main__":
|
202 |
-
gradio_interface().launch(
|
|
|
17 |
CAMERA_ANGLES = ["low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot", "over-the-shoulder shot", "extreme close-up", "panoramic view", "dynamic tracking shot", "fisheye view", "point-of-view shot"]
|
18 |
QUALITY_PROMPTS = ["cinematic lighting", "sharp shadow", "award-winning", "masterpiece", "vivid colors", "high dynamic range", "immersive", "studio quality", "fine art", "dreamlike", "8K", "HD", "high quality", "best quality", "artistic", "vibrant"]
|
19 |
|
20 |
+
# Hugging Face DTR 数据集路径(示例,若不可用请忽略)
|
21 |
DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
|
22 |
|
23 |
# ========== 工具函数 ==========
|
24 |
+
def load_candidates_from_files(files, excluded_tags=None):
|
25 |
"""
|
26 |
+
从多个文件中加载候选项,同时排除用户不想要的标签(精确匹配)。
|
27 |
"""
|
28 |
+
if excluded_tags is None:
|
29 |
+
excluded_tags = set()
|
30 |
all_lines = []
|
31 |
if files:
|
32 |
for file in files:
|
33 |
if isinstance(file, str):
|
34 |
+
# 说明是路径字符串
|
35 |
with open(file, "r", encoding="utf-8") as f:
|
36 |
+
lines = [line.strip() for line in f if line.strip()]
|
37 |
+
filtered = [l for l in lines if l not in excluded_tags]
|
38 |
+
all_lines.extend(filtered)
|
39 |
+
else:
|
40 |
+
# 说明是一个上传的 file-like 对象
|
41 |
+
file_data = file.read().decode("utf-8", errors="ignore")
|
42 |
+
lines = [line.strip() for line in file_data.splitlines() if line.strip()]
|
43 |
+
filtered = [l for l in lines if l not in excluded_tags]
|
44 |
+
all_lines.extend(filtered)
|
45 |
return all_lines
|
46 |
|
47 |
def get_random_items(candidates, num_items=1):
|
|
|
50 |
"""
|
51 |
return random.sample(candidates, min(num_items, len(candidates))) if candidates else []
|
52 |
|
53 |
+
def load_dtr_from_huggingface(excluded_tags=None):
|
54 |
"""
|
55 |
+
从 Hugging Face 数据集中加载所有包含 "DTR" 的文件内容,同时排除不需要的tag。
|
56 |
"""
|
57 |
+
if excluded_tags is None:
|
58 |
+
excluded_tags = set()
|
59 |
try:
|
60 |
response = requests.get(DTR_DATASET_PATTERN)
|
61 |
response.raise_for_status()
|
62 |
+
lines = response.text.splitlines()
|
63 |
+
# 只过滤精确匹配
|
64 |
+
lines = [l for l in lines if l not in excluded_tags]
|
65 |
+
return lines
|
66 |
except Exception as e:
|
67 |
print(f"Error loading DTR dataset: {e}")
|
68 |
return []
|
|
|
76 |
if not api_key:
|
77 |
return "Error: No API Key provided and none found in environment variables."
|
78 |
|
79 |
+
# 将dict转成可读字符串
|
80 |
+
tag_descriptions = "\n".join([
|
81 |
+
f"{key}: {', '.join(value) if isinstance(value, list) else value}"
|
82 |
+
for key, value in tags.items() if value
|
83 |
+
])
|
84 |
|
85 |
try:
|
86 |
client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
|
|
|
87 |
response = client.chat.completions.create(
|
88 |
messages=[
|
89 |
{
|
|
|
91 |
"content": (
|
92 |
"You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
|
93 |
"Focus on the details provided and incorporate them into a cohesive narrative. "
|
94 |
+
"Use at least three sentences but no more than five sentences."
|
95 |
),
|
96 |
},
|
97 |
{
|
|
|
108 |
def generate_prompt(
|
109 |
action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
|
110 |
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
|
111 |
+
artist_count, use_deepseek, deepseek_key, user_custom_tags, excluded_tags
|
112 |
):
|
113 |
"""
|
114 |
生成随机提示词和描述。
|
115 |
"""
|
116 |
+
# 处理排除 Tags(逗号分隔 -> 去重 set)
|
117 |
+
excluded_set = set(
|
118 |
+
[tag.strip() for tag in excluded_tags.split(",") if tag.strip()]
|
119 |
+
) if excluded_tags else set()
|
120 |
+
|
121 |
+
# 从文件中加载可选 action、style、artist、character
|
122 |
+
actions = get_random_items(load_candidates_from_files([action_file], excluded_set) if action_file else [], action_count)
|
123 |
+
styles = get_random_items(load_candidates_from_files([style_file], excluded_set) if style_file else [], style_count)
|
124 |
+
artists = get_random_items(load_candidates_from_files(artist_files, excluded_set) if artist_files else [], artist_count)
|
125 |
+
characters = get_random_items(load_candidates_from_files(character_files, excluded_set) if character_files else [], 1)
|
126 |
+
|
127 |
+
# 处理 DTR
|
128 |
+
dtr_candidates = get_random_items(load_dtr_from_huggingface(excluded_set) if dtr_enabled else [], 1)
|
129 |
+
|
130 |
+
# 处理预设列表中的随机筛选
|
131 |
+
filtered_expressions = [e for e in EXPRESSIONS if e not in excluded_set]
|
132 |
+
filtered_items = [i for i in ITEMS if i not in excluded_set]
|
133 |
+
filtered_details = [d for d in OTHER_DETAILS if d not in excluded_set]
|
134 |
+
filtered_scenes = [s for s in SCENES if s not in excluded_set]
|
135 |
+
filtered_angles = [c for c in CAMERA_ANGLES if c not in excluded_set]
|
136 |
+
filtered_quality = [q for q in QUALITY_PROMPTS if q not in excluded_set]
|
137 |
+
|
138 |
+
# 随机抽取
|
139 |
+
random_expression = get_random_items(filtered_expressions, expression_count)
|
140 |
+
random_items = get_random_items(filtered_items, item_count)
|
141 |
+
random_details = get_random_items(filtered_details, detail_count)
|
142 |
+
random_scenes = get_random_items(filtered_scenes, scene_count)
|
143 |
+
random_angles = get_random_items(filtered_angles, angle_count)
|
144 |
+
random_quality = get_random_items(filtered_quality, quality_count)
|
145 |
|
146 |
number_of_characters = ", ".join(selected_categories) if selected_categories else []
|
147 |
|
148 |
+
# 整理为字典
|
149 |
tags = {
|
150 |
+
"number_of_characters": [number_of_characters] if number_of_characters else [],
|
151 |
"character_name": characters,
|
152 |
+
"artist_prompt": [f"(artist:{', '.join(artists)})"] if artists else [],
|
153 |
"style": styles,
|
154 |
+
"scene": random_scenes,
|
155 |
+
"camera_angle": random_angles,
|
156 |
"action": actions,
|
157 |
+
"expression": random_expression,
|
158 |
+
"items": random_items,
|
159 |
+
"other_details": random_details,
|
160 |
+
"quality_prompts": random_quality,
|
161 |
+
"dtr": dtr_candidates,
|
162 |
}
|
163 |
|
164 |
+
# 如果用户有自定义输入
|
165 |
+
if user_custom_tags.strip():
|
166 |
+
tags["custom_tags"] = [t.strip() for t in user_custom_tags.split(",") if t.strip()]
|
167 |
+
|
168 |
+
# 生成自然语言描述
|
169 |
if use_deepseek:
|
170 |
description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
|
171 |
else:
|
172 |
description = generate_natural_language_description(tags, api_key=api_key)
|
173 |
|
174 |
+
# 整理最终 Tags(flatten 并去重)
|
175 |
+
tags_list = []
|
176 |
+
for v in tags.values():
|
177 |
+
if isinstance(v, list):
|
178 |
+
tags_list.extend(v)
|
179 |
+
else:
|
180 |
+
tags_list.append(v)
|
181 |
+
|
182 |
+
# 去重保持顺序
|
183 |
+
seen = set()
|
184 |
+
final_tags_list = []
|
185 |
+
for t in tags_list:
|
186 |
+
if t not in seen and t:
|
187 |
+
seen.add(t)
|
188 |
+
final_tags_list.append(t)
|
189 |
+
|
190 |
+
final_tags = ", ".join(final_tags_list)
|
191 |
+
|
192 |
+
# 默认 Combined = Tags + Description
|
193 |
combined_output = f"{final_tags}\n\n{description}"
|
194 |
return final_tags, description, combined_output
|
195 |
|
196 |
+
# ========== Favorite 相关函数 ==========
|
197 |
+
def add_to_favorites(combined_output, current_favorites):
|
198 |
+
"""
|
199 |
+
将当前生成的 combined_output 添加到收藏列表中(最多存 3 条)。
|
200 |
+
"""
|
201 |
+
# current_favorites 是一个列表
|
202 |
+
current_favorites.append(combined_output)
|
203 |
+
# 如果超过3条,移除最早的一条
|
204 |
+
if len(current_favorites) > 3:
|
205 |
+
current_favorites.pop(0)
|
206 |
+
# 格式化输出
|
207 |
+
favorites_text = "\n\n".join(
|
208 |
+
[f"[Favorite {i+1}]\n{fav}" for i, fav in enumerate(current_favorites)]
|
209 |
+
)
|
210 |
+
return favorites_text, current_favorites
|
211 |
+
|
212 |
# ========== Gradio 界面 ==========
|
213 |
def gradio_interface():
|
214 |
"""
|
215 |
定义 Gradio 应用界面。
|
216 |
"""
|
217 |
with gr.Blocks() as demo:
|
218 |
+
gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts (Enhanced)")
|
219 |
|
220 |
+
# 用于存储收藏内容的状态(最多缓存3条)
|
221 |
+
favorites_state = gr.State([])
|
|
|
|
|
|
|
222 |
|
223 |
+
with gr.Row():
|
224 |
+
# 左侧:文件上传、参数选择、排除/自定义输入
|
225 |
+
with gr.Column(scale=1):
|
226 |
+
api_key_input = gr.Textbox(
|
227 |
+
label="OpenAI API Key (可选)",
|
228 |
+
placeholder="sk-...",
|
229 |
+
type="password"
|
230 |
+
)
|
231 |
|
232 |
+
deepseek_key_input = gr.Textbox(
|
233 |
+
label="DeepSeek API Key (可选)",
|
234 |
+
placeholder="sk-...",
|
235 |
+
type="password"
|
236 |
+
)
|
237 |
|
238 |
+
use_deepseek = gr.Checkbox(label="Use DeepSeek API")
|
|
|
|
|
239 |
|
240 |
+
dtr_enabled = gr.Checkbox(label="Enable DTR (如不可用请勿勾选)")
|
|
|
|
|
241 |
|
242 |
+
# 上传文件部分
|
243 |
+
with gr.Box():
|
244 |
+
gr.Markdown("**上传文件 (可选):**")
|
245 |
+
action_file = gr.File(label="Action File", file_types=[".txt"])
|
246 |
+
style_file = gr.File(label="Style File", file_types=[".txt"])
|
247 |
+
artist_files = gr.Files(label="Artist Files", file_types=[".txt"])
|
248 |
+
character_files = gr.Files(label="Character Files", file_types=[".txt"])
|
249 |
|
250 |
+
# 选择角色类型
|
251 |
+
selected_categories = gr.CheckboxGroup(
|
252 |
+
["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
|
253 |
+
label="Choose Character Categories"
|
254 |
+
)
|
255 |
|
256 |
+
# 输入排除和自定义
|
257 |
+
excluded_tags = gr.Textbox(
|
258 |
+
label="排除 Tags (逗号分隔)",
|
259 |
+
placeholder="如:angry, sword"
|
260 |
+
)
|
261 |
+
user_custom_tags = gr.Textbox(
|
262 |
+
label="自定义附加 Tags (逗号分隔)",
|
263 |
+
placeholder="如:glowing eyes, giant wings"
|
264 |
+
)
|
265 |
|
266 |
+
# 各种数量
|
267 |
+
with gr.Box():
|
268 |
+
gr.Markdown("**随机数量设置:**")
|
269 |
+
expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
|
270 |
+
item_count = gr.Slider(label="Number of Items", minimum=0, maximum=10, step=1, value=1)
|
271 |
+
detail_count = gr.Slider(label="Number of Other Details", minimum=0, maximum=10, step=1, value=1)
|
272 |
+
scene_count = gr.Slider(label="Number of Scenes", minimum=0, maximum=10, step=1, value=1)
|
273 |
+
angle_count = gr.Slider(label="Number of Camera Angles", minimum=0, maximum=10, step=1, value=1)
|
274 |
+
quality_count = gr.Slider(label="Number of Quality Prompts", minimum=0, maximum=10, step=1, value=1)
|
275 |
+
action_count = gr.Slider(label="Number of Actions", minimum=1, maximum=10, step=1, value=1)
|
276 |
+
style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
|
277 |
+
artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
|
278 |
|
279 |
+
# 右侧:生成按钮 + 生成结果 + 收藏
|
280 |
+
with gr.Column(scale=2):
|
281 |
+
generate_button = gr.Button("Generate Prompt", variant="primary")
|
282 |
+
|
283 |
+
# 生成的结果(可编辑)
|
284 |
+
tags_output = gr.Textbox(
|
285 |
+
label="Generated Tags",
|
286 |
+
placeholder="等待生成...",
|
287 |
+
lines=4,
|
288 |
+
interactive=True
|
289 |
+
)
|
290 |
+
description_output = gr.Textbox(
|
291 |
+
label="Generated Description",
|
292 |
+
placeholder="等待生成...",
|
293 |
+
lines=4,
|
294 |
+
interactive=True
|
295 |
+
)
|
296 |
+
combined_output = gr.Textbox(
|
297 |
+
label="Combined Output: Tags + Description",
|
298 |
+
placeholder="等待生成...",
|
299 |
+
lines=6
|
300 |
+
)
|
301 |
+
|
302 |
+
with gr.Row():
|
303 |
+
# 收藏操作
|
304 |
+
favorite_button = gr.Button("收藏本次结果")
|
305 |
+
favorites_box = gr.Textbox(
|
306 |
+
label="收藏夹 (最多 3 条)",
|
307 |
+
placeholder="暂无收藏",
|
308 |
+
lines=6
|
309 |
+
)
|
310 |
+
|
311 |
+
# 点击生成按钮
|
312 |
+
generate_button.click(
|
313 |
+
generate_prompt,
|
314 |
+
inputs=[
|
315 |
+
action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
|
316 |
+
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
|
317 |
+
artist_count, use_deepseek, deepseek_key_input, user_custom_tags, excluded_tags
|
318 |
+
],
|
319 |
+
outputs=[tags_output, description_output, combined_output],
|
320 |
+
)
|
321 |
+
|
322 |
+
# 收藏按钮点击事件
|
323 |
+
favorite_button.click(
|
324 |
+
fn=add_to_favorites,
|
325 |
+
inputs=[combined_output, favorites_state],
|
326 |
+
outputs=[favorites_box, favorites_state],
|
327 |
+
)
|
328 |
|
329 |
+
# 整个界面启动
|
330 |
return demo
|
331 |
|
332 |
+
# 启动 Gradio 应用(在 Hugging Face Spaces 上只需要保留最后的 launch 即可)
|
333 |
if __name__ == "__main__":
|
334 |
+
gradio_interface().launch()
|