PSNbst commited on
Commit
3b3e0c8
·
verified ·
1 Parent(s): 3a4bd62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -31
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import random
3
  import glob
4
  import os
 
5
  from openai import OpenAI
6
  from dotenv import load_dotenv
7
 
@@ -29,28 +30,23 @@ def load_candidates_from_files(files):
29
  all_lines.extend([line.strip() for line in f if line.strip()])
30
  return all_lines
31
 
32
- def get_random_item(candidates):
33
  """
34
- 随机选取候选项。
35
  """
36
- return random.choice(candidates) if candidates else ""
37
 
38
- def load_dtr_from_directory(directory=".", pattern="*DTR*"):
39
  """
40
- 从指定目录中加载所有包含特定模式的文件内容。
41
- :param directory: 目标目录,默认为当前目录
42
- :param pattern: 匹配文件名的模式,默认是包含 "DTR" 的文件
43
- :return: 文件内容的列表
44
  """
45
- dtr_candidates = []
46
  try:
47
- files = glob.glob(os.path.join(directory, pattern))
48
- for file in files:
49
- with open(file, "r", encoding="utf-8") as f:
50
- dtr_candidates.extend([line.strip() for line in f if line.strip()])
51
  except Exception as e:
52
- print(f"Error loading DTR files: {e}")
53
- return dtr_candidates
54
 
55
  def generate_natural_language_description(tags, api_key=None):
56
  """
@@ -87,7 +83,10 @@ def generate_natural_language_description(tags, api_key=None):
87
  except Exception as e:
88
  return f"GPT generation failed. Error: {e}"
89
 
90
- def generate_prompt(action_file, style_file, artist_files, character_files, dtr_directory, api_key, selected_categories):
 
 
 
91
  """
92
  生成随机提示词和描述。
93
  """
@@ -95,28 +94,28 @@ def generate_prompt(action_file, style_file, artist_files, character_files, dtr_
95
  styles = load_candidates_from_files([style_file]) if style_file else []
96
  artists = load_candidates_from_files(artist_files) if artist_files else []
97
  characters = load_candidates_from_files(character_files) if character_files else []
98
- dtr_candidates = load_dtr_from_directory(dtr_directory) if dtr_directory else []
99
 
100
  number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])
101
 
102
  tags = {
103
  "number_of_characters": number_of_characters,
104
- "character_name": get_random_item(characters),
105
- "artist_prompt": f"(artist:{get_random_item(artists)})",
106
- "style": get_random_item(styles),
107
- "scene": get_random_item(SCENES),
108
- "camera_angle": get_random_item(CAMERA_ANGLES),
109
- "action": get_random_item(actions),
110
- "expression": get_random_item(EXPRESSIONS),
111
- "items": get_random_item(ITEMS),
112
- "other_details": get_random_item(OTHER_DETAILS),
113
- "quality_prompts": get_random_item(QUALITY_PROMPTS),
114
- "dtr": get_random_item(dtr_candidates)
115
  }
116
 
117
  description = generate_natural_language_description(tags, api_key)
118
 
119
- tags_list = [value for value in tags.values() if value]
120
  unique_tags = list(dict.fromkeys(tags_list))
121
  final_tags = ", ".join(unique_tags)
122
  combined_output = f"{final_tags}\n\n{description}"
@@ -128,7 +127,7 @@ def gradio_interface():
128
  定义 Gradio 应用界面。
129
  """
130
  with gr.Blocks() as demo:
131
- gr.Markdown("## Random Prompt Generator with User-Provided GPT API Key")
132
 
133
  api_key_input = gr.Textbox(
134
  label="Enter your OpenAI API Key (Optional)",
@@ -145,12 +144,29 @@ def gradio_interface():
145
  character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])
146
 
147
  dtr_enabled = gr.Checkbox(label="Enable DTR Directory")
 
 
 
 
 
148
 
149
  selected_categories = gr.CheckboxGroup(
150
  ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
151
  label="Choose Character Categories (Optional)"
152
  )
153
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  with gr.Row():
155
  tags_output = gr.Textbox(label="Generated Tags")
156
  description_output = gr.Textbox(label="Generated Description")
@@ -161,7 +177,8 @@ def gradio_interface():
161
  generate_button.click(
162
  generate_prompt,
163
  inputs=[
164
- action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories
 
165
  ],
166
  outputs=[tags_output, description_output, combined_output],
167
  )
 
2
  import random
3
  import glob
4
  import os
5
+ import requests
6
  from openai import OpenAI
7
  from dotenv import load_dotenv
8
 
 
30
  all_lines.extend([line.strip() for line in f if line.strip()])
31
  return all_lines
32
 
33
+ def get_random_items(candidates, num_items=1):
34
  """
35
+ 从候选项中随机选取指定数量的选项。
36
  """
37
+ return random.sample(candidates, min(num_items, len(candidates))) if candidates else []
38
 
39
+ def load_dtr_from_huggingface(dataset_url):
40
  """
41
+ 从 Hugging Face 数据集中加载 DTR 内容。
 
 
 
42
  """
 
43
  try:
44
+ response = requests.get(dataset_url)
45
+ response.raise_for_status()
46
+ return response.text.splitlines()
 
47
  except Exception as e:
48
+ print(f"Error loading DTR dataset: {e}")
49
+ return []
50
 
51
  def generate_natural_language_description(tags, api_key=None):
52
  """
 
83
  except Exception as e:
84
  return f"GPT generation failed. Error: {e}"
85
 
86
+ def generate_prompt(
87
+ action_file, style_file, artist_files, character_files, dtr_enabled, dtr_url, api_key, selected_categories,
88
+ expression_count, item_count, detail_count, scene_count, angle_count, quality_count
89
+ ):
90
  """
91
  生成随机提示词和描述。
92
  """
 
94
  styles = load_candidates_from_files([style_file]) if style_file else []
95
  artists = load_candidates_from_files(artist_files) if artist_files else []
96
  characters = load_candidates_from_files(character_files) if character_files else []
97
+ dtr_candidates = load_dtr_from_huggingface(dtr_url) if dtr_enabled else []
98
 
99
  number_of_characters = ", ".join(selected_categories) if selected_categories else random.choice(["1girl", "1boy"])
100
 
101
  tags = {
102
  "number_of_characters": number_of_characters,
103
+ "character_name": get_random_items(characters, 1),
104
+ "artist_prompt": f"(artist:{get_random_items(artists, 1)})",
105
+ "style": get_random_items(styles, 1),
106
+ "scene": get_random_items(SCENES, scene_count),
107
+ "camera_angle": get_random_items(CAMERA_ANGLES, angle_count),
108
+ "action": get_random_items(actions, 1),
109
+ "expression": get_random_items(EXPRESSIONS, expression_count),
110
+ "items": get_random_items(ITEMS, item_count),
111
+ "other_details": get_random_items(OTHER_DETAILS, detail_count),
112
+ "quality_prompts": get_random_items(QUALITY_PROMPTS, quality_count),
113
+ "dtr": get_random_items(dtr_candidates, 1)
114
  }
115
 
116
  description = generate_natural_language_description(tags, api_key)
117
 
118
+ tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
119
  unique_tags = list(dict.fromkeys(tags_list))
120
  final_tags = ", ".join(unique_tags)
121
  combined_output = f"{final_tags}\n\n{description}"
 
127
  定义 Gradio 应用界面。
128
  """
129
  with gr.Blocks() as demo:
130
+ gr.Markdown("## Random Prompt Generator with Adjustable Tag Counts")
131
 
132
  api_key_input = gr.Textbox(
133
  label="Enter your OpenAI API Key (Optional)",
 
144
  character_files = gr.Files(label="Upload Character Files (Multiple Allowed)", file_types=[".txt"])
145
 
146
  dtr_enabled = gr.Checkbox(label="Enable DTR Directory")
147
+ dtr_url = gr.Textbox(
148
+ label="Hugging Face DTR Dataset URL",
149
+ placeholder="https://huggingface.co/datasets/X779/Danbooruwildcards/tree/main",
150
+ value="https://huggingface.co/datasets/X779/Danbooruwildcards/tree/main"
151
+ )
152
 
153
  selected_categories = gr.CheckboxGroup(
154
  ["1boy", "1girl", "furry", "mecha", "fantasy monster", "animal", "still life"],
155
  label="Choose Character Categories (Optional)"
156
  )
157
 
158
+ with gr.Row():
159
+ expression_count = gr.Slider(label="Number of Expressions", minimum=1, maximum=5, step=1, value=1)
160
+ item_count = gr.Slider(label="Number of Items", minimum=1, maximum=5, step=1, value=1)
161
+
162
+ with gr.Row():
163
+ detail_count = gr.Slider(label="Number of Other Details", minimum=1, maximum=5, step=1, value=1)
164
+ scene_count = gr.Slider(label="Number of Scenes", minimum=1, maximum=5, step=1, value=1)
165
+
166
+ with gr.Row():
167
+ angle_count = gr.Slider(label="Number of Camera Angles", minimum=1, maximum=5, step=1, value=1)
168
+ quality_count = gr.Slider(label="Number of Quality Prompts", minimum=1, maximum=5, step=1, value=1)
169
+
170
  with gr.Row():
171
  tags_output = gr.Textbox(label="Generated Tags")
172
  description_output = gr.Textbox(label="Generated Description")
 
177
  generate_button.click(
178
  generate_prompt,
179
  inputs=[
180
+ action_file, style_file, artist_files, character_files, dtr_enabled, dtr_url, api_key_input, selected_categories,
181
+ expression_count, item_count, detail_count, scene_count, angle_count, quality_count
182
  ],
183
  outputs=[tags_output, description_output, combined_output],
184
  )