PSNbst commited on
Commit
a6ac76e
·
verified ·
1 Parent(s): a36cdfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -80
app.py CHANGED
@@ -17,20 +17,31 @@ SCENES = ["sunset beach", "rainy city street at night", "floating ash land", "pa
17
  CAMERA_ANGLES = ["low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot", "over-the-shoulder shot", "extreme close-up", "panoramic view", "dynamic tracking shot", "fisheye view", "point-of-view shot"]
18
  QUALITY_PROMPTS = ["cinematic lighting", "sharp shadow", "award-winning", "masterpiece", "vivid colors", "high dynamic range", "immersive", "studio quality", "fine art", "dreamlike", "8K", "HD", "high quality", "best quality", "artistic", "vibrant"]
19
 
20
- # Hugging Face DTR 数据集路径
21
  DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
22
 
23
  # ========== 工具函数 ==========
24
- def load_candidates_from_files(files):
25
  """
26
- 从多个文件中加载候选项。
27
  """
 
 
28
  all_lines = []
29
  if files:
30
  for file in files:
31
  if isinstance(file, str):
 
32
  with open(file, "r", encoding="utf-8") as f:
33
- all_lines.extend([line.strip() for line in f if line.strip()])
 
 
 
 
 
 
 
 
34
  return all_lines
35
 
36
  def get_random_items(candidates, num_items=1):
@@ -39,14 +50,19 @@ def get_random_items(candidates, num_items=1):
39
  """
40
  return random.sample(candidates, min(num_items, len(candidates))) if candidates else []
41
 
42
- def load_dtr_from_huggingface():
43
  """
44
- 从 Hugging Face 数据集中加载所有包含 "DTR" 的文件内容。
45
  """
 
 
46
  try:
47
  response = requests.get(DTR_DATASET_PATTERN)
48
  response.raise_for_status()
49
- return response.text.splitlines()
 
 
 
50
  except Exception as e:
51
  print(f"Error loading DTR dataset: {e}")
52
  return []
@@ -60,11 +76,14 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
60
  if not api_key:
61
  return "Error: No API Key provided and none found in environment variables."
62
 
63
- tag_descriptions = "\n".join([f"{key}: {value}" for key, value in tags.items() if value])
 
 
 
 
64
 
65
  try:
66
  client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
67
-
68
  response = client.chat.completions.create(
69
  messages=[
70
  {
@@ -72,7 +91,7 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
72
  "content": (
73
  "You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
74
  "Focus on the details provided and incorporate them into a cohesive narrative. "
75
- "Use at least three sentences but no more than five sentences"
76
  ),
77
  },
78
  {
@@ -89,114 +108,227 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
89
  def generate_prompt(
90
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
91
  expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
92
- artist_count, use_deepseek, deepseek_key
93
  ):
94
  """
95
  生成随机提示词和描述。
96
  """
97
- actions = get_random_items(load_candidates_from_files([action_file]) if action_file else [], action_count)
98
- styles = get_random_items(load_candidates_from_files([style_file]) if style_file else [], style_count)
99
- artists = get_random_items(load_candidates_from_files(artist_files) if artist_files else [], artist_count)
100
- characters = get_random_items(load_candidates_from_files(character_files) if character_files else [], 1)
101
- dtr_candidates = get_random_items(load_dtr_from_huggingface() if dtr_enabled else [], 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  number_of_characters = ", ".join(selected_categories) if selected_categories else []
104
 
 
105
  tags = {
106
- "number_of_characters": number_of_characters,
107
  "character_name": characters,
108
- "artist_prompt": f"(artist:{', '.join(artists)})",
109
  "style": styles,
110
- "scene": get_random_items(SCENES, scene_count),
111
- "camera_angle": get_random_items(CAMERA_ANGLES, angle_count),
112
  "action": actions,
113
- "expression": get_random_items(EXPRESSIONS, expression_count),
114
- "items": get_random_items(ITEMS, item_count),
115
- "other_details": get_random_items(OTHER_DETAILS, detail_count),
116
- "quality_prompts": get_random_items(QUALITY_PROMPTS, quality_count),
117
- "dtr": dtr_candidates
118
  }
119
 
 
 
 
 
 
120
  if use_deepseek:
121
  description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
122
  else:
123
  description = generate_natural_language_description(tags, api_key=api_key)
124
 
125
- tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
126
- unique_tags = list(dict.fromkeys(tags_list))
127
- final_tags = ", ".join(unique_tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  combined_output = f"{final_tags}\n\n{description}"
129
  return final_tags, description, combined_output
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # ========== Gradio 界面 ==========
132
  def gradio_interface():
133
  """
134
  定义 Gradio 应用界面。
135
  """
136
  with gr.Blocks() as demo:
137
- gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts")
138
 
139
- api_key_input = gr.Textbox(
140
- label="Enter your OpenAI API Key (Optional)",
141
- placeholder="sk-...",
142
- type="password"
143
- )
144
 
145
- deepseek_key_input = gr.Textbox(
146
- label="Enter your DeepSeek API Key (Optional)",
147
- placeholder="sk-...",
148
- type="password"
149
- )
 
 
 
150
 
151
- use_deepseek = gr.Checkbox(label="Use DeepSeek API - 用DeepSeek别忘了勾选这个")
 
 
 
 
152
 
153
- with gr.Row():
154
- action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
155
- style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
156
 
157
- with gr.Row():
158
- artist_files = gr.Files(label="Upload Artist Files (Multiple Allowed)", file_types=[".txt"])
159
- character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])
160
 
161
- dtr_enabled = gr.Checkbox(label="Enable DTR - 当前不可用2025-01-12")
 
 
 
 
 
 
162
 
163
- selected_categories = gr.CheckboxGroup(
164
- ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
165
- label="Choose Character Categories (Optional)"
166
- )
 
167
 
168
- with gr.Row():
169
- expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
170
- item_count = gr.Slider(label="Number of Items", minimum=0, maximum=10, step=1, value=1)
171
- detail_count = gr.Slider(label="Number of Other Details", minimum=0, maximum=10, step=1, value=1)
172
- scene_count = gr.Slider(label="Number of Scenes", minimum=0, maximum=10, step=1, value=1)
 
 
 
 
173
 
174
- with gr.Row():
175
- angle_count = gr.Slider(label="Number of Camera Angles", minimum=0, maximum=10, step=1, value=1)
176
- quality_count = gr.Slider(label="Number of Quality Prompts", minimum=0, maximum=10, step=1, value=1)
177
- action_count = gr.Slider(label="Number of Actions", minimum=1, maximum=10, step=1, value=1)
178
- style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
179
- artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
 
 
 
 
 
 
180
 
181
- with gr.Row():
182
- tags_output = gr.Textbox(label="Generated Tags")
183
- description_output = gr.Textbox(label="Generated Description")
184
- combined_output = gr.Textbox(label="Combined Output: Tags + Description")
185
-
186
- generate_button = gr.Button("Generate Prompt")
187
-
188
- generate_button.click(
189
- generate_prompt,
190
- inputs=[
191
- action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
192
- expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
193
- artist_count, use_deepseek, deepseek_key_input
194
- ],
195
- outputs=[tags_output, description_output, combined_output],
196
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
 
198
  return demo
199
 
200
- # 启动 Gradio 应用
201
  if __name__ == "__main__":
202
- gradio_interface().launch(share=True)
 
17
  CAMERA_ANGLES = ["low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot", "over-the-shoulder shot", "extreme close-up", "panoramic view", "dynamic tracking shot", "fisheye view", "point-of-view shot"]
18
  QUALITY_PROMPTS = ["cinematic lighting", "sharp shadow", "award-winning", "masterpiece", "vivid colors", "high dynamic range", "immersive", "studio quality", "fine art", "dreamlike", "8K", "HD", "high quality", "best quality", "artistic", "vibrant"]
19
 
20
+ # Hugging Face DTR 数据集路径(示例,若不可用请忽略)
21
  DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
22
 
23
  # ========== 工具函数 ==========
24
+ def load_candidates_from_files(files, excluded_tags=None):
25
  """
26
+ 从多个文件中加载候选项,同时排除用户不想要的标签(精确匹配)。
27
  """
28
+ if excluded_tags is None:
29
+ excluded_tags = set()
30
  all_lines = []
31
  if files:
32
  for file in files:
33
  if isinstance(file, str):
34
+ # 说明是路径字符串
35
  with open(file, "r", encoding="utf-8") as f:
36
+ lines = [line.strip() for line in f if line.strip()]
37
+ filtered = [l for l in lines if l not in excluded_tags]
38
+ all_lines.extend(filtered)
39
+ else:
40
+ # 说明是一个上传的 file-like 对象
41
+ file_data = file.read().decode("utf-8", errors="ignore")
42
+ lines = [line.strip() for line in file_data.splitlines() if line.strip()]
43
+ filtered = [l for l in lines if l not in excluded_tags]
44
+ all_lines.extend(filtered)
45
  return all_lines
46
 
47
  def get_random_items(candidates, num_items=1):
 
50
  """
51
  return random.sample(candidates, min(num_items, len(candidates))) if candidates else []
52
 
53
+ def load_dtr_from_huggingface(excluded_tags=None):
54
  """
55
+ 从 Hugging Face 数据集中加载所有包含 "DTR" 的文件内容,同时排除不需要的tag。
56
  """
57
+ if excluded_tags is None:
58
+ excluded_tags = set()
59
  try:
60
  response = requests.get(DTR_DATASET_PATTERN)
61
  response.raise_for_status()
62
+ lines = response.text.splitlines()
63
+ # 只过滤精确匹配
64
+ lines = [l for l in lines if l not in excluded_tags]
65
+ return lines
66
  except Exception as e:
67
  print(f"Error loading DTR dataset: {e}")
68
  return []
 
76
  if not api_key:
77
  return "Error: No API Key provided and none found in environment variables."
78
 
79
+ # 将dict转成可读字符串
80
+ tag_descriptions = "\n".join([
81
+ f"{key}: {', '.join(value) if isinstance(value, list) else value}"
82
+ for key, value in tags.items() if value
83
+ ])
84
 
85
  try:
86
  client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
 
87
  response = client.chat.completions.create(
88
  messages=[
89
  {
 
91
  "content": (
92
  "You are a creative assistant that generates detailed and imaginative scene descriptions for AI generation prompts. "
93
  "Focus on the details provided and incorporate them into a cohesive narrative. "
94
+ "Use at least three sentences but no more than five sentences."
95
  ),
96
  },
97
  {
 
108
  def generate_prompt(
109
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
110
  expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
111
+ artist_count, use_deepseek, deepseek_key, user_custom_tags, excluded_tags
112
  ):
113
  """
114
  生成随机提示词和描述。
115
  """
116
+ # 处理排除 Tags(逗号分隔 -> 去重 set)
117
+ excluded_set = set(
118
+ [tag.strip() for tag in excluded_tags.split(",") if tag.strip()]
119
+ ) if excluded_tags else set()
120
+
121
+ # 从文件中加载可选 action、style、artist、character
122
+ actions = get_random_items(load_candidates_from_files([action_file], excluded_set) if action_file else [], action_count)
123
+ styles = get_random_items(load_candidates_from_files([style_file], excluded_set) if style_file else [], style_count)
124
+ artists = get_random_items(load_candidates_from_files(artist_files, excluded_set) if artist_files else [], artist_count)
125
+ characters = get_random_items(load_candidates_from_files(character_files, excluded_set) if character_files else [], 1)
126
+
127
+ # 处理 DTR
128
+ dtr_candidates = get_random_items(load_dtr_from_huggingface(excluded_set) if dtr_enabled else [], 1)
129
+
130
+ # 处理预设列表中的随机筛选
131
+ filtered_expressions = [e for e in EXPRESSIONS if e not in excluded_set]
132
+ filtered_items = [i for i in ITEMS if i not in excluded_set]
133
+ filtered_details = [d for d in OTHER_DETAILS if d not in excluded_set]
134
+ filtered_scenes = [s for s in SCENES if s not in excluded_set]
135
+ filtered_angles = [c for c in CAMERA_ANGLES if c not in excluded_set]
136
+ filtered_quality = [q for q in QUALITY_PROMPTS if q not in excluded_set]
137
+
138
+ # 随机抽取
139
+ random_expression = get_random_items(filtered_expressions, expression_count)
140
+ random_items = get_random_items(filtered_items, item_count)
141
+ random_details = get_random_items(filtered_details, detail_count)
142
+ random_scenes = get_random_items(filtered_scenes, scene_count)
143
+ random_angles = get_random_items(filtered_angles, angle_count)
144
+ random_quality = get_random_items(filtered_quality, quality_count)
145
 
146
  number_of_characters = ", ".join(selected_categories) if selected_categories else []
147
 
148
+ # 整理为字典
149
  tags = {
150
+ "number_of_characters": [number_of_characters] if number_of_characters else [],
151
  "character_name": characters,
152
+ "artist_prompt": [f"(artist:{', '.join(artists)})"] if artists else [],
153
  "style": styles,
154
+ "scene": random_scenes,
155
+ "camera_angle": random_angles,
156
  "action": actions,
157
+ "expression": random_expression,
158
+ "items": random_items,
159
+ "other_details": random_details,
160
+ "quality_prompts": random_quality,
161
+ "dtr": dtr_candidates,
162
  }
163
 
164
+ # 如果用户有自定义输入
165
+ if user_custom_tags.strip():
166
+ tags["custom_tags"] = [t.strip() for t in user_custom_tags.split(",") if t.strip()]
167
+
168
+ # 生成自然语言描述
169
  if use_deepseek:
170
  description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
171
  else:
172
  description = generate_natural_language_description(tags, api_key=api_key)
173
 
174
+ # 整理最终 Tags(flatten 并去重)
175
+ tags_list = []
176
+ for v in tags.values():
177
+ if isinstance(v, list):
178
+ tags_list.extend(v)
179
+ else:
180
+ tags_list.append(v)
181
+
182
+ # 去重保持顺序
183
+ seen = set()
184
+ final_tags_list = []
185
+ for t in tags_list:
186
+ if t not in seen and t:
187
+ seen.add(t)
188
+ final_tags_list.append(t)
189
+
190
+ final_tags = ", ".join(final_tags_list)
191
+
192
+ # 默认 Combined = Tags + Description
193
  combined_output = f"{final_tags}\n\n{description}"
194
  return final_tags, description, combined_output
195
 
196
+ # ========== Favorite 相关函数 ==========
197
+ def add_to_favorites(combined_output, current_favorites):
198
+ """
199
+ 将当前生成的 combined_output 添加到收藏列表中(最多存 3 条)。
200
+ """
201
+ # current_favorites 是一个列表
202
+ current_favorites.append(combined_output)
203
+ # 如果超过3条,移除最早的一条
204
+ if len(current_favorites) > 3:
205
+ current_favorites.pop(0)
206
+ # 格式化输出
207
+ favorites_text = "\n\n".join(
208
+ [f"[Favorite {i+1}]\n{fav}" for i, fav in enumerate(current_favorites)]
209
+ )
210
+ return favorites_text, current_favorites
211
+
212
  # ========== Gradio 界面 ==========
213
  def gradio_interface():
214
  """
215
  定义 Gradio 应用界面。
216
  """
217
  with gr.Blocks() as demo:
218
+ gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts (Enhanced)")
219
 
220
+ # 用于存储收藏内容的状态(最多缓存3条)
221
+ favorites_state = gr.State([])
 
 
 
222
 
223
+ with gr.Row():
224
+ # 左侧:文件上传、参数选择、排除/自定义输入
225
+ with gr.Column(scale=1):
226
+ api_key_input = gr.Textbox(
227
+ label="OpenAI API Key (可选)",
228
+ placeholder="sk-...",
229
+ type="password"
230
+ )
231
 
232
+ deepseek_key_input = gr.Textbox(
233
+ label="DeepSeek API Key (可选)",
234
+ placeholder="sk-...",
235
+ type="password"
236
+ )
237
 
238
+ use_deepseek = gr.Checkbox(label="Use DeepSeek API")
 
 
239
 
240
+ dtr_enabled = gr.Checkbox(label="Enable DTR (如不可用请勿勾选)")
 
 
241
 
242
+ # 上传文件部分
243
+ with gr.Box():
244
+ gr.Markdown("**上传文件 (可选):**")
245
+ action_file = gr.File(label="Action File", file_types=[".txt"])
246
+ style_file = gr.File(label="Style File", file_types=[".txt"])
247
+ artist_files = gr.Files(label="Artist Files", file_types=[".txt"])
248
+ character_files = gr.Files(label="Character Files", file_types=[".txt"])
249
 
250
+ # 选择角色类型
251
+ selected_categories = gr.CheckboxGroup(
252
+ ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
253
+ label="Choose Character Categories"
254
+ )
255
 
256
+ # 输入排除和自定义
257
+ excluded_tags = gr.Textbox(
258
+ label="排除 Tags (逗号分隔)",
259
+ placeholder="如:angry, sword"
260
+ )
261
+ user_custom_tags = gr.Textbox(
262
+ label="自定义附加 Tags (逗号分隔)",
263
+ placeholder="如:glowing eyes, giant wings"
264
+ )
265
 
266
+ # 各种数量
267
+ with gr.Box():
268
+ gr.Markdown("**随机数量设置:**")
269
+ expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
270
+ item_count = gr.Slider(label="Number of Items", minimum=0, maximum=10, step=1, value=1)
271
+ detail_count = gr.Slider(label="Number of Other Details", minimum=0, maximum=10, step=1, value=1)
272
+ scene_count = gr.Slider(label="Number of Scenes", minimum=0, maximum=10, step=1, value=1)
273
+ angle_count = gr.Slider(label="Number of Camera Angles", minimum=0, maximum=10, step=1, value=1)
274
+ quality_count = gr.Slider(label="Number of Quality Prompts", minimum=0, maximum=10, step=1, value=1)
275
+ action_count = gr.Slider(label="Number of Actions", minimum=1, maximum=10, step=1, value=1)
276
+ style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
277
+ artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
278
 
279
+ # 右侧:生成按钮 + 生成结果 + 收藏
280
+ with gr.Column(scale=2):
281
+ generate_button = gr.Button("Generate Prompt", variant="primary")
282
+
283
+ # 生成的结果(可编辑)
284
+ tags_output = gr.Textbox(
285
+ label="Generated Tags",
286
+ placeholder="等待生成...",
287
+ lines=4,
288
+ interactive=True
289
+ )
290
+ description_output = gr.Textbox(
291
+ label="Generated Description",
292
+ placeholder="等待生成...",
293
+ lines=4,
294
+ interactive=True
295
+ )
296
+ combined_output = gr.Textbox(
297
+ label="Combined Output: Tags + Description",
298
+ placeholder="等待生成...",
299
+ lines=6
300
+ )
301
+
302
+ with gr.Row():
303
+ # 收藏操作
304
+ favorite_button = gr.Button("收藏本次结果")
305
+ favorites_box = gr.Textbox(
306
+ label="收藏夹 (最多 3 条)",
307
+ placeholder="暂无收藏",
308
+ lines=6
309
+ )
310
+
311
+ # 点击生成按钮
312
+ generate_button.click(
313
+ generate_prompt,
314
+ inputs=[
315
+ action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
316
+ expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
317
+ artist_count, use_deepseek, deepseek_key_input, user_custom_tags, excluded_tags
318
+ ],
319
+ outputs=[tags_output, description_output, combined_output],
320
+ )
321
+
322
+ # 收藏按钮点击事件
323
+ favorite_button.click(
324
+ fn=add_to_favorites,
325
+ inputs=[combined_output, favorites_state],
326
+ outputs=[favorites_box, favorites_state],
327
+ )
328
 
329
+ # 整个界面启动
330
  return demo
331
 
332
+ # 启动 Gradio 应用(在 Hugging Face Spaces 上只需要保留最后的 launch 即可)
333
  if __name__ == "__main__":
334
+ gradio_interface().launch()