Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
-
import spaces
|
5 |
-
from diffusers import DiffusionPipeline
|
6 |
-
import torch
|
7 |
from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -20,7 +20,7 @@ pipe = pipe.to(device)
|
|
20 |
MAX_SEED = np.iinfo(np.int32).max
|
21 |
MAX_IMAGE_SIZE = 1024
|
22 |
|
23 |
-
@spaces.GPU
|
24 |
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
|
25 |
selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
|
26 |
selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
|
@@ -214,6 +214,37 @@ with gr.Blocks(css=css) as demo:
|
|
214 |
selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
|
215 |
tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
|
216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
run_button.click(
|
218 |
infer,
|
219 |
inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
|
@@ -224,23 +255,4 @@ with gr.Blocks(css=css) as demo:
|
|
224 |
outputs=[result, seed, prompt_info]
|
225 |
)
|
226 |
|
227 |
-
|
228 |
-
link_button_v8 = gr.Button("V8", elem_id="link-v8", size="sm")
|
229 |
-
link_button_v11 = gr.Button("V11", elem_id="link-v11", size="sm")
|
230 |
-
|
231 |
-
def update_model_version(version):
|
232 |
-
global model_repo_id
|
233 |
-
if version == "v7":
|
234 |
-
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v7-sdxl"
|
235 |
-
elif version == "v8":
|
236 |
-
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
|
237 |
-
elif version == "v11":
|
238 |
-
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v11-sdxl"
|
239 |
-
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
240 |
-
pipe = pipe.to(device)
|
241 |
-
|
242 |
-
link_button_v7.click(update_model_version, inputs=["v7"], outputs=[])
|
243 |
-
link_button_v8.click(update_model_version, inputs=["v8"], outputs=[])
|
244 |
-
link_button_v11.click(update_model_version, inputs=["v11"], outputs=[])
|
245 |
-
|
246 |
-
demo.queue().launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from diffusers import DiffusionPipeline
|
4 |
import numpy as np
|
5 |
import random
|
6 |
+
import spaces
|
|
|
|
|
7 |
from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
20 |
MAX_SEED = np.iinfo(np.int32).max
|
21 |
MAX_IMAGE_SIZE = 1024
|
22 |
|
23 |
+
@spaces.GPU
|
24 |
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
|
25 |
selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
|
26 |
selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
|
|
|
214 |
selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
|
215 |
tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
|
216 |
|
217 |
+
# Model version buttons
|
218 |
+
link_button_v7 = gr.Button("Use Model V7")
|
219 |
+
link_button_v8 = gr.Button("Use Model V8")
|
220 |
+
link_button_v11 = gr.Button("Use Model V11")
|
221 |
+
|
222 |
+
def update_model_version_v7():
|
223 |
+
global model_repo_id
|
224 |
+
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v7-sdxl"
|
225 |
+
global pipe
|
226 |
+
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
227 |
+
pipe = pipe.to(device)
|
228 |
+
|
229 |
+
def update_model_version_v8():
|
230 |
+
global model_repo_id
|
231 |
+
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
|
232 |
+
global pipe
|
233 |
+
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
234 |
+
pipe = pipe.to(device)
|
235 |
+
|
236 |
+
def update_model_version_v11():
|
237 |
+
global model_repo_id
|
238 |
+
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v11-sdxl"
|
239 |
+
global pipe
|
240 |
+
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
241 |
+
pipe = pipe.to(device)
|
242 |
+
|
243 |
+
# Link button actions
|
244 |
+
link_button_v7.click(update_model_version_v7)
|
245 |
+
link_button_v8.click(update_model_version_v8)
|
246 |
+
link_button_v11.click(update_model_version_v11)
|
247 |
+
|
248 |
run_button.click(
|
249 |
infer,
|
250 |
inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
|
|
|
255 |
outputs=[result, seed, prompt_info]
|
256 |
)
|
257 |
|
258 |
+
demo.queue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|