PSNbst commited on
Commit
46b912b
·
verified ·
1 Parent(s): 63c2525

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -23
app.py CHANGED
@@ -10,12 +10,40 @@ from dotenv import load_dotenv
10
  load_dotenv()
11
 
12
  # ========== 默认选项和数据 ==========
13
- EXPRESSIONS = ["smiling", "determined", "surprised", "serene", "smug", "thinking", "looking back", "laughing", "angry", "pensive", "confident", "grinning", "thoughtful", "sad tears", "bewildered"]
14
- ITEMS = ["magic wand", "sword", "flower", "book of spells", "earrings", "loincloth", "slippers", "ancient scroll", "music instrument", "shield", "dagger", "headband", "leg ties", "staff", "potion", "crystal ball", "anklet", "ribbon", "lantern", "amulet", "ring"]
15
- OTHER_DETAILS = ["sparkles", "magical aura", "lens flare", "fireworks in the background", "smoke effects", "light trails", "falling leaves", "glowing embers", "floating particles", "rays of light", "shimmering mist", "ethereal glow"]
16
- SCENES = ["sunset beach", "rainy city street at night", "floating ash land", "particles magic world", "high blue sky", "top of the building", "fantasy forest with glowing mushrooms", "futuristic skyline at dawn", "abandoned castle", "snowy mountain peak", "desert ruins", "underwater city", "enchanted meadow", "haunted mansion", "steampunk marketplace", "glacial cavern"]
17
- CAMERA_ANGLES = ["low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot", "over-the-shoulder shot", "extreme close-up", "panoramic view", "dynamic tracking shot", "fisheye view", "point-of-view shot"]
18
- QUALITY_PROMPTS = ["cinematic lighting", "sharp shadow", "award-winning", "masterpiece", "vivid colors", "high dynamic range", "immersive", "studio quality", "fine art", "dreamlike", "8K", "HD", "high quality", "best quality", "artistic", "vibrant"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Hugging Face DTR 数据集路径(示例,若不可用请忽略)
21
  DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
@@ -68,7 +96,7 @@ def load_dtr_from_huggingface(excluded_tags=None):
68
  print(f"Error loading DTR dataset: {e}")
69
  return []
70
 
71
- def generate_natural_language_description(tags, api_key=None, base_url=None, model="gpt-4o"):
72
  """
73
  使用 OpenAI GPT 或 DeepSeek API 生成自然语言描述。
74
  """
@@ -77,7 +105,7 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
77
  if not api_key:
78
  return "Error: No API Key provided and none found in environment variables."
79
 
80
- # 将dict转成可读字符串
81
  tag_descriptions = "\n".join([
82
  f"{key}: {', '.join(value) if isinstance(value, list) else value}"
83
  for key, value in tags.items() if value
@@ -106,6 +134,8 @@ def generate_natural_language_description(tags, api_key=None, base_url=None, mod
106
  except Exception as e:
107
  return f"GPT generation failed. Error: {e}"
108
 
 
 
109
  def generate_prompt(
110
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
111
  expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
@@ -194,12 +224,98 @@ def generate_prompt(
194
  combined_output = f"{final_tags}\n\n{description}"
195
  return final_tags, description, combined_output
196
 
197
- # ========== Favorite 相关函数 ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def add_to_favorites(combined_output, current_favorites):
199
  """
200
  将当前生成的 combined_output 添加到收藏列表中(最多存 3 条)。
201
  """
202
- # current_favorites 是一个列表
203
  current_favorites.append(combined_output)
204
  # 如果超过3条,移除最早的一条
205
  if len(current_favorites) > 3:
@@ -210,6 +326,7 @@ def add_to_favorites(combined_output, current_favorites):
210
  )
211
  return favorites_text, current_favorites
212
 
 
213
  # ========== Gradio 界面 ==========
214
  def gradio_interface():
215
  """
@@ -237,10 +354,8 @@ def gradio_interface():
237
  )
238
 
239
  use_deepseek = gr.Checkbox(label="Use DeepSeek API")
240
-
241
  dtr_enabled = gr.Checkbox(label="Enable DTR (如不可用请勿勾选)")
242
 
243
- # 上传文件部分
244
  with gr.Group():
245
  gr.Markdown("**上传文件 (可选):**")
246
  action_file = gr.File(label="Action File", file_types=[".txt"])
@@ -248,13 +363,11 @@ def gradio_interface():
248
  artist_files = gr.Files(label="Artist Files", file_types=[".txt"])
249
  character_files = gr.Files(label="Character Files", file_types=[".txt"])
250
 
251
- # 选择角色类型
252
  selected_categories = gr.CheckboxGroup(
253
  ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
254
  label="Choose Character Categories"
255
  )
256
 
257
- # 输入排除和自定义
258
  excluded_tags = gr.Textbox(
259
  label="排除 Tags (逗号分隔)",
260
  placeholder="如:angry, sword"
@@ -264,7 +377,6 @@ def gradio_interface():
264
  placeholder="如:glowing eyes, giant wings"
265
  )
266
 
267
- # 各种数量
268
  with gr.Group():
269
  gr.Markdown("**随机数量设置:**")
270
  expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
@@ -277,11 +389,10 @@ def gradio_interface():
277
  style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
278
  artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
279
 
280
- # 右侧:生成按钮 + 生成结果 + 收藏
281
  with gr.Column(scale=2):
282
  generate_button = gr.Button("Generate Prompt", variant="primary")
283
 
284
- # 生成的结果(可编辑)
285
  tags_output = gr.Textbox(
286
  label="Generated Tags",
287
  placeholder="等待生成...",
@@ -300,8 +411,25 @@ def gradio_interface():
300
  lines=6
301
  )
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  with gr.Row():
304
- # 收藏操作
305
  favorite_button = gr.Button("收藏本次结果")
306
  favorites_box = gr.Textbox(
307
  label="收藏夹 (最多 3 条)",
@@ -309,17 +437,45 @@ def gradio_interface():
309
  lines=6
310
  )
311
 
312
- # 点击生成按钮
313
  generate_button.click(
314
  generate_prompt,
315
  inputs=[
316
- action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
317
- expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
318
- artist_count, use_deepseek, deepseek_key_input, user_custom_tags, excluded_tags
 
 
 
319
  ],
320
  outputs=[tags_output, description_output, combined_output],
321
  )
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  # 收藏按钮点击事件
324
  favorite_button.click(
325
  fn=add_to_favorites,
@@ -329,6 +485,7 @@ def gradio_interface():
329
 
330
  return demo
331
 
332
- # 启动 Gradio 应用(在 Hugging Face Spaces 上只需要保留最后的 launch 即可)
 
333
  if __name__ == "__main__":
334
  gradio_interface().launch()
 
10
  load_dotenv()
11
 
12
  # ========== 默认选项和数据 ==========
13
+ EXPRESSIONS = [
14
+ "smiling", "determined", "surprised", "serene", "smug", "thinking",
15
+ "looking back", "laughing", "angry", "pensive", "confident",
16
+ "grinning", "thoughtful", "sad tears", "bewildered"
17
+ ]
18
+ ITEMS = [
19
+ "magic wand", "sword", "flower", "book of spells", "earrings", "loincloth",
20
+ "slippers", "ancient scroll", "music instrument", "shield", "dagger",
21
+ "headband", "leg ties", "staff", "potion", "crystal ball", "anklet",
22
+ "ribbon", "lantern", "amulet", "ring"
23
+ ]
24
+ OTHER_DETAILS = [
25
+ "sparkles", "magical aura", "lens flare", "fireworks in the background",
26
+ "smoke effects", "light trails", "falling leaves", "glowing embers",
27
+ "floating particles", "rays of light", "shimmering mist", "ethereal glow"
28
+ ]
29
+ SCENES = [
30
+ "sunset beach", "rainy city street at night", "floating ash land",
31
+ "particles magic world", "high blue sky", "top of the building",
32
+ "fantasy forest with glowing mushrooms", "futuristic skyline at dawn",
33
+ "abandoned castle", "snowy mountain peak", "desert ruins", "underwater city",
34
+ "enchanted meadow", "haunted mansion", "steampunk marketplace", "glacial cavern"
35
+ ]
36
+ CAMERA_ANGLES = [
37
+ "low-angle shot", "close-up shot", "bird's-eye view", "wide-angle shot",
38
+ "over-the-shoulder shot", "extreme close-up", "panoramic view",
39
+ "dynamic tracking shot", "fisheye view", "point-of-view shot"
40
+ ]
41
+ QUALITY_PROMPTS = [
42
+ "cinematic lighting", "sharp shadow", "award-winning", "masterpiece",
43
+ "vivid colors", "high dynamic range", "immersive", "studio quality",
44
+ "fine art", "dreamlike", "8K", "HD", "high quality", "best quality",
45
+ "artistic", "vibrant"
46
+ ]
47
 
48
  # Hugging Face DTR 数据集路径(示例,若不可用请忽略)
49
  DTR_DATASET_PATTERN = "https://huggingface.co/datasets/X779/Danbooruwildcards/resolve/main/*DTR*.txt"
 
96
  print(f"Error loading DTR dataset: {e}")
97
  return []
98
 
99
+ def generate_natural_language_description(tags, api_key=None, base_url=None, model="gpt-4"):
100
  """
101
  使用 OpenAI GPT 或 DeepSeek API 生成自然语言描述。
102
  """
 
105
  if not api_key:
106
  return "Error: No API Key provided and none found in environment variables."
107
 
108
+ # 将 dict 转成可读字符串
109
  tag_descriptions = "\n".join([
110
  f"{key}: {', '.join(value) if isinstance(value, list) else value}"
111
  for key, value in tags.items() if value
 
134
  except Exception as e:
135
  return f"GPT generation failed. Error: {e}"
136
 
137
+
138
+ # ========== 核心函数:随机生成 prompt ==========
139
  def generate_prompt(
140
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
141
  expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
 
224
  combined_output = f"{final_tags}\n\n{description}"
225
  return final_tags, description, combined_output
226
 
227
+
228
+ # ========== 部分更新:只根据用户修改后的 tags_text 生成新的描述和合并输出 ==========
229
+ def update_description(tags_text, api_key, use_deepseek, deepseek_key):
230
+ """
231
+ 只根据用户提供的 tags_text 生成描述和合并输出。
232
+ 不再重新随机抽取,以免破坏用户手动修改过的 Tags。
233
+ """
234
+ if not api_key and not deepseek_key:
235
+ # 没有提供任意可用 API Key
236
+ return "(No API Key provided)", f"{tags_text}\n\n(No API Key provided)"
237
+
238
+ # 构造给 GPT 的 prompt
239
+ user_prompt = (
240
+ "You are a creative assistant that generates detailed, imaginative scene descriptions for AI generation.\n"
241
+ "Below is the user's current tags (prompt elements). "
242
+ "Generate a new descriptive text (3-5 sentences) that incorporates these tags.\n\n"
243
+ f"User Tags: {tags_text}\n"
244
+ "Please generate a vivid, imaginative scene description."
245
+ )
246
+
247
+ try:
248
+ if use_deepseek:
249
+ # 调用 DeepSeek
250
+ client = OpenAI(api_key=deepseek_key, base_url="https://api.deepseek.com")
251
+ model = "deepseek-chat"
252
+ else:
253
+ # 调用 OpenAI
254
+ client = OpenAI(api_key=api_key)
255
+ model = "gpt-4" # 或其他可用模型,比如 "gpt-3.5-turbo"
256
+ response = client.chat.completions.create(
257
+ messages=[
258
+ {
259
+ "role": "system",
260
+ "content": "You are a creative assistant that generates imaginative scene descriptions..."
261
+ },
262
+ {
263
+ "role": "user",
264
+ "content": user_prompt,
265
+ },
266
+ ],
267
+ model=model,
268
+ )
269
+ new_description = response.choices[0].message.content.strip()
270
+ except Exception as e:
271
+ new_description = f"(GPT generation failed: {e})"
272
+
273
+ new_combined_output = f"{tags_text}\n\n{new_description}"
274
+ return new_description, new_combined_output
275
+
276
+
277
+ # ========== 翻译功能:将 combined_output 翻译成用户选定语言 ==========
278
+ def translate_combined_output(combined_text, target_language, api_key, use_deepseek, deepseek_key):
279
+ """
280
+ 使用 GPT 或 DeepSeek API,将 combined_text 翻译成 target_language。
281
+ """
282
+ if not api_key and not deepseek_key:
283
+ return "(No API Key provided)"
284
+
285
+ # 简单用 GPT 做翻译,也可改成其他翻译 API
286
+ translation_prompt = (
287
+ f"You are a professional translator. Please translate the following text into {target_language}.\n\n"
288
+ f"{combined_text}"
289
+ )
290
+
291
+ try:
292
+ if use_deepseek:
293
+ # 调用 DeepSeek
294
+ client = OpenAI(api_key=deepseek_key, base_url="https://api.deepseek.com")
295
+ model = "deepseek-chat"
296
+ else:
297
+ # 调用 OpenAI
298
+ client = OpenAI(api_key=api_key)
299
+ model = "gpt-3.5-turbo" # 或者别的模型
300
+ response = client.chat.completions.create(
301
+ messages=[
302
+ {"role": "system", "content": "You are a professional translator."},
303
+ {"role": "user", "content": translation_prompt},
304
+ ],
305
+ model=model,
306
+ )
307
+ translated_text = response.choices[0].message.content.strip()
308
+ except Exception as e:
309
+ translated_text = f"(Translation failed: {e})"
310
+
311
+ return translated_text
312
+
313
+
314
+ # ========== 收藏功能:最多存 3 条 ==========
315
  def add_to_favorites(combined_output, current_favorites):
316
  """
317
  将当前生成的 combined_output 添加到收藏列表中(最多存 3 条)。
318
  """
 
319
  current_favorites.append(combined_output)
320
  # 如果超过3条,移除最早的一条
321
  if len(current_favorites) > 3:
 
326
  )
327
  return favorites_text, current_favorites
328
 
329
+
330
  # ========== Gradio 界面 ==========
331
  def gradio_interface():
332
  """
 
354
  )
355
 
356
  use_deepseek = gr.Checkbox(label="Use DeepSeek API")
 
357
  dtr_enabled = gr.Checkbox(label="Enable DTR (如不可用请勿勾选)")
358
 
 
359
  with gr.Group():
360
  gr.Markdown("**上传文件 (可选):**")
361
  action_file = gr.File(label="Action File", file_types=[".txt"])
 
363
  artist_files = gr.Files(label="Artist Files", file_types=[".txt"])
364
  character_files = gr.Files(label="Character Files", file_types=[".txt"])
365
 
 
366
  selected_categories = gr.CheckboxGroup(
367
  ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
368
  label="Choose Character Categories"
369
  )
370
 
 
371
  excluded_tags = gr.Textbox(
372
  label="排除 Tags (逗号分隔)",
373
  placeholder="如:angry, sword"
 
377
  placeholder="如:glowing eyes, giant wings"
378
  )
379
 
 
380
  with gr.Group():
381
  gr.Markdown("**随机数量设置:**")
382
  expression_count = gr.Slider(label="Number of Expressions", minimum=0, maximum=10, step=1, value=1)
 
389
  style_count = gr.Slider(label="Number of Styles", minimum=1, maximum=10, step=1, value=1)
390
  artist_count = gr.Slider(label="Number of Artists", minimum=1, maximum=10, step=1, value=1)
391
 
392
+ # 右侧:生成按钮 + 生成结果 + 收藏 + 翻译
393
  with gr.Column(scale=2):
394
  generate_button = gr.Button("Generate Prompt", variant="primary")
395
 
 
396
  tags_output = gr.Textbox(
397
  label="Generated Tags",
398
  placeholder="等待生成...",
 
411
  lines=6
412
  )
413
 
414
+ # 新增一个按钮,只更新 description 和 combined
415
+ update_desc_button = gr.Button("Update Description Only")
416
+
417
+ # 翻译相关
418
+ with gr.Row():
419
+ target_language = gr.Dropdown(
420
+ choices=["English", "Chinese", "Japanese"],
421
+ value="English",
422
+ label="Target Language"
423
+ )
424
+ translate_button = gr.Button("Translate to selected language")
425
+ translated_output = gr.Textbox(
426
+ label="Translated Output",
427
+ placeholder="等待翻译...",
428
+ lines=6
429
+ )
430
+
431
+ # 收藏
432
  with gr.Row():
 
433
  favorite_button = gr.Button("收藏本次结果")
434
  favorites_box = gr.Textbox(
435
  label="收藏夹 (最多 3 条)",
 
437
  lines=6
438
  )
439
 
440
+ # 点击“Generate Prompt”按钮
441
  generate_button.click(
442
  generate_prompt,
443
  inputs=[
444
+ action_file, style_file, artist_files, character_files,
445
+ dtr_enabled, api_key_input, selected_categories,
446
+ expression_count, item_count, detail_count, scene_count,
447
+ angle_count, quality_count, action_count, style_count,
448
+ artist_count, use_deepseek, deepseek_key_input,
449
+ user_custom_tags, excluded_tags
450
  ],
451
  outputs=[tags_output, description_output, combined_output],
452
  )
453
 
454
+ # 点击“Update Description Only”按钮
455
+ update_desc_button.click(
456
+ update_description,
457
+ inputs=[
458
+ tags_output, # 用户在文本框里编辑后的 Tags
459
+ api_key_input,
460
+ use_deepseek,
461
+ deepseek_key_input,
462
+ ],
463
+ outputs=[description_output, combined_output],
464
+ )
465
+
466
+ # 点击“Translate to selected language”按钮
467
+ translate_button.click(
468
+ fn=translate_combined_output,
469
+ inputs=[
470
+ combined_output, # 要翻译的源文本
471
+ target_language,
472
+ api_key_input,
473
+ use_deepseek,
474
+ deepseek_key_input
475
+ ],
476
+ outputs=[translated_output],
477
+ )
478
+
479
  # 收藏按钮点击事件
480
  favorite_button.click(
481
  fn=add_to_favorites,
 
485
 
486
  return demo
487
 
488
+
489
+ # 启动 Gradio 应用
490
  if __name__ == "__main__":
491
  gradio_interface().launch()