Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import random
|
|
4 |
import torch
|
5 |
import spaces
|
6 |
from diffusers import DiffusionPipeline
|
7 |
-
import importlib
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
|
@@ -22,18 +22,19 @@ MAX_IMAGE_SIZE = 1024
|
|
22 |
|
23 |
# Function to load tags dynamically based on the selected tab
|
24 |
def load_tags(active_tab):
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
35 |
|
36 |
-
@spaces.GPU
|
37 |
@spaces.GPU
|
38 |
def infer(
|
39 |
prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
|
@@ -45,7 +46,7 @@ def infer(
|
|
45 |
# Dynamically load the correct tags module based on active tab
|
46 |
tags_module = load_tags(active_tab)
|
47 |
|
48 |
-
#
|
49 |
participant_tags = tags_module.participant_tags
|
50 |
tribe_tags = tags_module.tribe_tags
|
51 |
role_tags = tags_module.role_tags
|
@@ -62,74 +63,114 @@ def infer(
|
|
62 |
camera_tags = tags_module.camera_tags
|
63 |
atmosphere_tags = tags_module.atmosphere_tags
|
64 |
|
65 |
-
#
|
66 |
-
print(f"Loaded tags for {active_tab}:")
|
67 |
-
print(f"participant_tags: {list(participant_tags.keys())}")
|
68 |
-
print(f"tribe_tags: {list(tribe_tags.keys())}")
|
69 |
-
print(f"role_tags: {list(role_tags.keys())}")
|
70 |
-
print(f"skin_tone_tags: {list(skin_tone_tags.keys())}")
|
71 |
-
|
72 |
-
# Handle the active tab and generate the prompt accordingly
|
73 |
tag_list = []
|
74 |
|
75 |
-
#
|
76 |
-
for
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
#
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {', '.join(tag_list)}"
|
104 |
|
105 |
-
#
|
106 |
additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
|
107 |
full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
|
108 |
|
|
|
109 |
if randomize_seed:
|
110 |
seed = random.randint(0, MAX_SEED)
|
111 |
generator = torch.Generator(device=device).manual_seed(seed)
|
112 |
|
113 |
-
#
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
print(f"Seed used for generation: {seed}")
|
128 |
-
|
129 |
-
return image, seed, f"Prompt: {final_prompt}\nNegative Prompt: {full_negative_prompt}"
|
130 |
|
|
|
131 |
|
132 |
-
#
|
133 |
css = """
|
134 |
#col-container {
|
135 |
margin: 0 auto;
|
@@ -146,12 +187,11 @@ with gr.Blocks(css=css) as demo:
|
|
146 |
active_tab = gr.State("Prompt Input")
|
147 |
|
148 |
with gr.Tabs() as tabs:
|
149 |
-
#
|
150 |
with gr.TabItem("Prompt Input"):
|
151 |
prompt = gr.Textbox(label="Prompt", placeholder="Enter your custom prompt")
|
152 |
tabs.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
|
153 |
|
154 |
-
# Straight Tab
|
155 |
with gr.TabItem("Straight"):
|
156 |
tags_module = load_tags("Straight")
|
157 |
selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
|
@@ -171,47 +211,7 @@ with gr.Blocks(css=css) as demo:
|
|
171 |
selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
|
172 |
tabs.select(lambda: "Straight", inputs=None, outputs=active_tab)
|
173 |
|
174 |
-
|
175 |
-
with gr.TabItem("Gay"):
|
176 |
-
tags_module = load_tags("Gay")
|
177 |
-
selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
|
178 |
-
selected_tribe_tags = gr.CheckboxGroup(choices=list(tags_module.tribe_tags.keys()), label="Tribe Tags")
|
179 |
-
selected_role_tags = gr.CheckboxGroup(choices=list(tags_module.role_tags.keys()), label="Role Tags")
|
180 |
-
selected_skin_tone_tags = gr.CheckboxGroup(choices=list(tags_module.skin_tone_tags.keys()), label="Skin Tone Tags")
|
181 |
-
selected_body_type_tags = gr.CheckboxGroup(choices=list(tags_module.body_type_tags.keys()), label="Body Type Tags")
|
182 |
-
selected_tattoo_tags = gr.CheckboxGroup(choices=list(tags_module.tattoo_tags.keys()), label="Tattoo Tags")
|
183 |
-
selected_piercing_tags = gr.CheckboxGroup(choices=list(tags_module.piercing_tags.keys()), label="Piercing Tags")
|
184 |
-
selected_expression_tags = gr.CheckboxGroup(choices=list(tags_module.expression_tags.keys()), label="Expression Tags")
|
185 |
-
selected_eye_tags = gr.CheckboxGroup(choices=list(tags_module.eye_tags.keys()), label="Eye Tags")
|
186 |
-
selected_hair_style_tags = gr.CheckboxGroup(choices=list(tags_module.hair_style_tags.keys()), label="Hair Style Tags")
|
187 |
-
selected_position_tags = gr.CheckboxGroup(choices=list(tags_module.position_tags.keys()), label="Position Tags")
|
188 |
-
selected_fetish_tags = gr.CheckboxGroup(choices=list(tags_module.fetish_tags.keys()), label="Fetish Tags")
|
189 |
-
selected_location_tags = gr.CheckboxGroup(choices=list(tags_module.location_tags.keys()), label="Location Tags")
|
190 |
-
selected_camera_tags = gr.CheckboxGroup(choices=list(tags_module.camera_tags.keys()), label="Camera Tags")
|
191 |
-
selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
|
192 |
-
tabs.select(lambda: "Gay", inputs=None, outputs=active_tab)
|
193 |
-
|
194 |
-
# Lesbian Tab
|
195 |
-
with gr.TabItem("Lesbian"):
|
196 |
-
tags_module = load_tags("Lesbian")
|
197 |
-
selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
|
198 |
-
selected_tribe_tags = gr.CheckboxGroup(choices=list(tags_module.tribe_tags.keys()), label="Tribe Tags")
|
199 |
-
selected_role_tags = gr.CheckboxGroup(choices=list(tags_module.role_tags.keys()), label="Role Tags")
|
200 |
-
selected_skin_tone_tags = gr.CheckboxGroup(choices=list(tags_module.skin_tone_tags.keys()), label="Skin Tone Tags")
|
201 |
-
selected_body_type_tags = gr.CheckboxGroup(choices=list(tags_module.body_type_tags.keys()), label="Body Type Tags")
|
202 |
-
selected_tattoo_tags = gr.CheckboxGroup(choices=list(tags_module.tattoo_tags.keys()), label="Tattoo Tags")
|
203 |
-
selected_piercing_tags = gr.CheckboxGroup(choices=list(tags_module.piercing_tags.keys()), label="Piercing Tags")
|
204 |
-
selected_expression_tags = gr.CheckboxGroup(choices=list(tags_module.expression_tags.keys()), label="Expression Tags")
|
205 |
-
selected_eye_tags = gr.CheckboxGroup(choices=list(tags_module.eye_tags.keys()), label="Eye Tags")
|
206 |
-
selected_hair_style_tags = gr.CheckboxGroup(choices=list(tags_module.hair_style_tags.keys()), label="Hair Style Tags")
|
207 |
-
selected_position_tags = gr.CheckboxGroup(choices=list(tags_module.position_tags.keys()), label="Position Tags")
|
208 |
-
selected_fetish_tags = gr.CheckboxGroup(choices=list(tags_module.fetish_tags.keys()), label="Fetish Tags")
|
209 |
-
selected_location_tags = gr.CheckboxGroup(choices=list(tags_module.location_tags.keys()), label="Location Tags")
|
210 |
-
selected_camera_tags = gr.CheckboxGroup(choices=list(tags_module.camera_tags.keys()), label="Camera Tags")
|
211 |
-
selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
|
212 |
-
tabs.select(lambda: "Lesbian", inputs=None, outputs=active_tab)
|
213 |
-
|
214 |
-
# Advanced Settings
|
215 |
with gr.Accordion("Advanced Settings", open=False):
|
216 |
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
|
217 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
|
|
4 |
import torch
|
5 |
import spaces
|
6 |
from diffusers import DiffusionPipeline
|
7 |
+
import importlib
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
|
|
|
22 |
|
23 |
# Function to load tags dynamically based on the selected tab
|
24 |
def load_tags(active_tab):
|
25 |
+
try:
|
26 |
+
if active_tab == "Gay":
|
27 |
+
return importlib.import_module('tags_gay') # dynamically import the tags_gay module
|
28 |
+
elif active_tab == "Straight":
|
29 |
+
return importlib.import_module('tags_straight') # dynamically import the tags_straight module
|
30 |
+
elif active_tab == "Lesbian":
|
31 |
+
return importlib.import_module('tags_lesbian') # dynamically import the tags_lesbian module
|
32 |
+
else:
|
33 |
+
raise ValueError(f"Unknown tab: {active_tab}")
|
34 |
+
except Exception as e:
|
35 |
+
print(f"Error loading tags for {active_tab}: {str(e)}")
|
36 |
+
raise
|
37 |
|
|
|
38 |
@spaces.GPU
|
39 |
def infer(
|
40 |
prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
|
|
|
46 |
# Dynamically load the correct tags module based on active tab
|
47 |
tags_module = load_tags(active_tab)
|
48 |
|
49 |
+
# Get the tag dictionaries from the loaded module
|
50 |
participant_tags = tags_module.participant_tags
|
51 |
tribe_tags = tags_module.tribe_tags
|
52 |
role_tags = tags_module.role_tags
|
|
|
63 |
camera_tags = tags_module.camera_tags
|
64 |
atmosphere_tags = tags_module.atmosphere_tags
|
65 |
|
66 |
+
# Build the tag list using selected tags from each group
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
tag_list = []
|
68 |
|
69 |
+
# Add selected participant tags
|
70 |
+
for tag in selected_participant_tags:
|
71 |
+
if tag in participant_tags:
|
72 |
+
tag_list.append(participant_tags[tag])
|
73 |
+
|
74 |
+
# Add selected tribe tags
|
75 |
+
for tag in selected_tribe_tags:
|
76 |
+
if tag in tribe_tags:
|
77 |
+
tag_list.append(tribe_tags[tag])
|
78 |
+
|
79 |
+
# Add selected role tags
|
80 |
+
for tag in selected_role_tags:
|
81 |
+
if tag in role_tags:
|
82 |
+
tag_list.append(role_tags[tag])
|
83 |
+
|
84 |
+
# Add selected skin tone tags
|
85 |
+
for tag in selected_skin_tone_tags:
|
86 |
+
if tag in skin_tone_tags:
|
87 |
+
tag_list.append(skin_tone_tags[tag])
|
88 |
+
|
89 |
+
# Add selected body type tags
|
90 |
+
for tag in selected_body_type_tags:
|
91 |
+
if tag in body_type_tags:
|
92 |
+
tag_list.append(body_type_tags[tag])
|
93 |
+
|
94 |
+
# Add selected tattoo tags
|
95 |
+
for tag in selected_tattoo_tags:
|
96 |
+
if tag in tattoo_tags:
|
97 |
+
tag_list.append(tattoo_tags[tag])
|
98 |
+
|
99 |
+
# Add selected piercing tags
|
100 |
+
for tag in selected_piercing_tags:
|
101 |
+
if tag in piercing_tags:
|
102 |
+
tag_list.append(piercing_tags[tag])
|
103 |
+
|
104 |
+
# Add selected expression tags
|
105 |
+
for tag in selected_expression_tags:
|
106 |
+
if tag in expression_tags:
|
107 |
+
tag_list.append(expression_tags[tag])
|
108 |
+
|
109 |
+
# Add selected eye tags
|
110 |
+
for tag in selected_eye_tags:
|
111 |
+
if tag in eye_tags:
|
112 |
+
tag_list.append(eye_tags[tag])
|
113 |
+
|
114 |
+
# Add selected hair style tags
|
115 |
+
for tag in selected_hair_style_tags:
|
116 |
+
if tag in hair_style_tags:
|
117 |
+
tag_list.append(hair_style_tags[tag])
|
118 |
+
|
119 |
+
# Add selected position tags
|
120 |
+
for tag in selected_position_tags:
|
121 |
+
if tag in position_tags:
|
122 |
+
tag_list.append(position_tags[tag])
|
123 |
+
|
124 |
+
# Add selected fetish tags
|
125 |
+
for tag in selected_fetish_tags:
|
126 |
+
if tag in fetish_tags:
|
127 |
+
tag_list.append(fetish_tags[tag])
|
128 |
+
|
129 |
+
# Add selected location tags
|
130 |
+
for tag in selected_location_tags:
|
131 |
+
if tag in location_tags:
|
132 |
+
tag_list.append(location_tags[tag])
|
133 |
+
|
134 |
+
# Add selected camera tags
|
135 |
+
for tag in selected_camera_tags:
|
136 |
+
if tag in camera_tags:
|
137 |
+
tag_list.append(camera_tags[tag])
|
138 |
+
|
139 |
+
# Add selected atmosphere tags
|
140 |
+
for tag in selected_atmosphere_tags:
|
141 |
+
if tag in atmosphere_tags:
|
142 |
+
tag_list.append(atmosphere_tags[tag])
|
143 |
+
|
144 |
+
# Construct final prompt
|
145 |
final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {', '.join(tag_list)}"
|
146 |
|
147 |
+
# Negative prompt
|
148 |
additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
|
149 |
full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
|
150 |
|
151 |
+
# Handle random seed if needed
|
152 |
if randomize_seed:
|
153 |
seed = random.randint(0, MAX_SEED)
|
154 |
generator = torch.Generator(device=device).manual_seed(seed)
|
155 |
|
156 |
+
# Generate the image
|
157 |
+
try:
|
158 |
+
image = pipe(
|
159 |
+
prompt=final_prompt,
|
160 |
+
negative_prompt=full_negative_prompt,
|
161 |
+
guidance_scale=guidance_scale,
|
162 |
+
num_inference_steps=num_inference_steps,
|
163 |
+
width=width,
|
164 |
+
height=height,
|
165 |
+
generator=generator
|
166 |
+
).images[0]
|
167 |
+
except Exception as e:
|
168 |
+
print(f"Error generating image: {str(e)}")
|
169 |
+
raise
|
|
|
|
|
|
|
170 |
|
171 |
+
return image, seed, f"Prompt: {final_prompt}\nNegative Prompt: {full_negative_prompt}"
|
172 |
|
173 |
+
# Gradio UI setup
|
174 |
css = """
|
175 |
#col-container {
|
176 |
margin: 0 auto;
|
|
|
187 |
active_tab = gr.State("Prompt Input")
|
188 |
|
189 |
with gr.Tabs() as tabs:
|
190 |
+
# Tab setup for different categories
|
191 |
with gr.TabItem("Prompt Input"):
|
192 |
prompt = gr.Textbox(label="Prompt", placeholder="Enter your custom prompt")
|
193 |
tabs.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
|
194 |
|
|
|
195 |
with gr.TabItem("Straight"):
|
196 |
tags_module = load_tags("Straight")
|
197 |
selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
|
|
|
211 |
selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
|
212 |
tabs.select(lambda: "Straight", inputs=None, outputs=active_tab)
|
213 |
|
214 |
+
# Advanced settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
with gr.Accordion("Advanced Settings", open=False):
|
216 |
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
|
217 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|