Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -51,9 +51,9 @@ def load_dtr_from_huggingface():
|
|
51 |
print(f"Error loading DTR dataset: {e}")
|
52 |
return []
|
53 |
|
54 |
-
def generate_natural_language_description(tags, api_key=None):
|
55 |
"""
|
56 |
-
使用 OpenAI GPT 生成自然语言描述。
|
57 |
"""
|
58 |
if not api_key:
|
59 |
api_key = os.getenv("OPENAI_API_KEY")
|
@@ -63,7 +63,7 @@ def generate_natural_language_description(tags, api_key=None):
|
|
63 |
tag_descriptions = "\n".join([f"{key}: {value}" for key, value in tags.items() if value])
|
64 |
|
65 |
try:
|
66 |
-
client = OpenAI(api_key=api_key)
|
67 |
|
68 |
response = client.chat.completions.create(
|
69 |
messages=[
|
@@ -80,7 +80,7 @@ def generate_natural_language_description(tags, api_key=None):
|
|
80 |
"content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
|
81 |
},
|
82 |
],
|
83 |
-
model=
|
84 |
)
|
85 |
return response.choices[0].message.content.strip()
|
86 |
except Exception as e:
|
@@ -88,7 +88,8 @@ def generate_natural_language_description(tags, api_key=None):
|
|
88 |
|
89 |
def generate_prompt(
|
90 |
action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
|
91 |
-
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count
|
|
|
92 |
):
|
93 |
"""
|
94 |
生成随机提示词和描述。
|
@@ -116,7 +117,10 @@ def generate_prompt(
|
|
116 |
"dtr": dtr_candidates
|
117 |
}
|
118 |
|
119 |
-
|
|
|
|
|
|
|
120 |
|
121 |
tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
|
122 |
unique_tags = list(dict.fromkeys(tags_list))
|
@@ -138,6 +142,14 @@ def gradio_interface():
|
|
138 |
type="password"
|
139 |
)
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
with gr.Row():
|
142 |
action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
|
143 |
style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
|
@@ -176,7 +188,8 @@ def gradio_interface():
|
|
176 |
generate_prompt,
|
177 |
inputs=[
|
178 |
action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
|
179 |
-
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count
|
|
|
180 |
],
|
181 |
outputs=[tags_output, description_output, combined_output],
|
182 |
)
|
|
|
51 |
print(f"Error loading DTR dataset: {e}")
|
52 |
return []
|
53 |
|
54 |
+
def generate_natural_language_description(tags, api_key=None, base_url=None, model="gpt-4o"):
|
55 |
"""
|
56 |
+
使用 OpenAI GPT 或 DeepSeek API 生成自然语言描述。
|
57 |
"""
|
58 |
if not api_key:
|
59 |
api_key = os.getenv("OPENAI_API_KEY")
|
|
|
63 |
tag_descriptions = "\n".join([f"{key}: {value}" for key, value in tags.items() if value])
|
64 |
|
65 |
try:
|
66 |
+
client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
|
67 |
|
68 |
response = client.chat.completions.create(
|
69 |
messages=[
|
|
|
80 |
"content": f"Here are the tags and details:\n{tag_descriptions}\nPlease generate a vivid, imaginative scene description.",
|
81 |
},
|
82 |
],
|
83 |
+
model=model,
|
84 |
)
|
85 |
return response.choices[0].message.content.strip()
|
86 |
except Exception as e:
|
|
|
88 |
|
89 |
def generate_prompt(
|
90 |
action_file, style_file, artist_files, character_files, dtr_enabled, api_key, selected_categories,
|
91 |
+
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
|
92 |
+
use_deepseek, deepseek_key
|
93 |
):
|
94 |
"""
|
95 |
生成随机提示词和描述。
|
|
|
117 |
"dtr": dtr_candidates
|
118 |
}
|
119 |
|
120 |
+
if use_deepseek:
|
121 |
+
description = generate_natural_language_description(tags, api_key=deepseek_key, base_url="https://api.deepseek.com", model="deepseek-chat")
|
122 |
+
else:
|
123 |
+
description = generate_natural_language_description(tags, api_key=api_key)
|
124 |
|
125 |
tags_list = [item for sublist in tags.values() for item in (sublist if isinstance(sublist, list) else [sublist])] # Flatten
|
126 |
unique_tags = list(dict.fromkeys(tags_list))
|
|
|
142 |
type="password"
|
143 |
)
|
144 |
|
145 |
+
deepseek_key_input = gr.Textbox(
|
146 |
+
label="Enter your DeepSeek API Key (Optional)",
|
147 |
+
placeholder="sk-...",
|
148 |
+
type="password"
|
149 |
+
)
|
150 |
+
|
151 |
+
use_deepseek = gr.Checkbox(label="Use DeepSeek API")
|
152 |
+
|
153 |
with gr.Row():
|
154 |
action_file = gr.File(label="Upload Action File (Optional)", file_types=[".txt"])
|
155 |
style_file = gr.File(label="Upload Style File (Optional)", file_types=[".txt"])
|
|
|
188 |
generate_prompt,
|
189 |
inputs=[
|
190 |
action_file, style_file, artist_files, character_files, dtr_enabled, api_key_input, selected_categories,
|
191 |
+
expression_count, item_count, detail_count, scene_count, angle_count, quality_count, action_count, style_count,
|
192 |
+
use_deepseek, deepseek_key_input
|
193 |
],
|
194 |
outputs=[tags_output, description_output, combined_output],
|
195 |
)
|