PSNbst commited on
Commit
38f52fd
·
verified ·
1 Parent(s): 00e31d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -51,9 +51,9 @@ def load_dtr_from_huggingface():
51
  print(f"Error loading DTR dataset: {e}")
52
  return []
53
 
54
- def generate_natural_language_description(tags, api_key=None):
55
  """
56
- 使用 OpenAI GPT 生成自然语言描述。
57
  """
58
  if not api_key:
59
  api_key = os.getenv("OPENAI_API_KEY")
@@ -63,7 +63,7 @@ def generate_natural_language_description(tags, api_key=None):
63
  tag_descriptions = "\n".join([f"{key}: {value}" for key, value in tags.items() if value])
64
 
65
  try:
66
- client = OpenAI(api_key=api_key)
67
 
68
  response = client.chat.completions.create(
69
  messages=[
@@ -80,7 +80,7 @@ def generate_natural_language_description(tags, api_key=None):
80
  "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
81
  },
82
  ],
83
- model="gpt-4o",
84
  )
85
  return response.choices[0].message.content.strip()
86
  except Exception as e:
@@ -88,7 +88,8 @@ def generate_natural_language_description(tags, api_key=None):
88
 
89
  def generate_prompt(
90
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
91
- expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count
 
92
  ):
93
  """
94
  生成随机提示词和描述。
@@ -116,7 +117,10 @@ def generate_prompt(
116
  "dtr": dtr_candidates
117
  }
118
 
119
- description = generate_natural_language_description(tags, api_key)
 
 
 
120
 
121
  tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
122
  unique_tags = list(dict.fromkeys(tags_list))
@@ -138,6 +142,14 @@ def gradio_interface():
138
  type="password"
139
  )
140
 
 
 
 
 
 
 
 
 
141
  with gr.Row():
142
  action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
143
  style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
@@ -176,7 +188,8 @@ def gradio_interface():
176
  generate_prompt,
177
  inputs=[
178
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
179
- expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count
 
180
  ],
181
  outputs=[tags_output, description_output, combined_output],
182
  )
 
51
  print(f"Error loading DTR dataset: {e}")
52
  return []
53
 
54
+ def generate_natural_language_description(tags, api_key=None, base_url=None, model="gpt-4o"):
55
  """
56
+ 使用 OpenAI GPT 或 DeepSeek API 生成自然语言描述。
57
  """
58
  if not api_key:
59
  api_key = os.getenv("OPENAI_API_KEY")
 
63
  tag_descriptions = "\n".join([f"{key}: {value}" for key, value in tags.items() if value])
64
 
65
  try:
66
+ client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
67
 
68
  response = client.chat.completions.create(
69
  messages=[
 
80
  "content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
81
  },
82
  ],
83
+ model=model,
84
  )
85
  return response.choices[0].message.content.strip()
86
  except Exception as e:
 
88
 
89
  def generate_prompt(
90
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
91
+ expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
92
+ use_deepseek, deepseek_key
93
  ):
94
  """
95
  生成随机提示词和描述。
 
117
  "dtr": dtr_candidates
118
  }
119
 
120
+ if use_deepseek:
121
+ description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
122
+ else:
123
+ description = generate_natural_language_description(tags, api_key=api_key)
124
 
125
  tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
126
  unique_tags = list(dict.fromkeys(tags_list))
 
142
  type="password"
143
  )
144
 
145
+ deepseek_key_input = gr.Textbox(
146
+ label="Enter your DeepSeek API Key (Optional)",
147
+ placeholder="sk-...",
148
+ type="password"
149
+ )
150
+
151
+ use_deepseek = gr.Checkbox(label="Use DeepSeek API")
152
+
153
  with gr.Row():
154
  action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
155
  style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
 
188
  generate_prompt,
189
  inputs=[
190
  action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
191
+ expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
192
+ use_deepseek, deepseek_key_input
193
  ],
194
  outputs=[tags_output, description_output, combined_output],
195
  )