panelforge committed on
Commit 7bb50ea · verified · 1 Parent(s): 3e820bd

Update app.py

Files changed (1): app.py (+36 -24)
app.py CHANGED
@@ -1,9 +1,9 @@
 import gradio as gr
+import torch
+from diffusers import DiffusionPipeline
 import numpy as np
 import random
-import spaces # Restored import for spaces
-from diffusers import DiffusionPipeline
-import torch
+import spaces
 from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,7 +20,7 @@ pipe = pipe.to(device)
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-@spaces.GPU # Restored decorator to enable GPU use
+@spaces.GPU
 def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
           selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
           selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
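
The restored `@spaces.GPU` decorator is what lets the `infer` function run on ZeroGPU hardware. A minimal sketch of that pattern, assuming a `torch_dtype` choice and using one of the checkpoints referenced in this commit as a stand-in for the app's full setup:

import torch
import spaces
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # assumed dtype choice

# One of the checkpoints referenced in this commit; the app's actual default may differ.
pipe = DiffusionPipeline.from_pretrained("John6666/wai-ani-nsfw-ponyxl-v11-sdxl", torch_dtype=torch_dtype)
pipe = pipe.to(device)

@spaces.GPU  # on ZeroGPU Spaces, requests a GPU for the duration of each call
def infer(prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0):
    # The decorated function runs with CUDA available; the rest of the script stays on CPU.
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]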
@@ -214,6 +214,37 @@ with gr.Blocks(css=css) as demo:
         selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
         tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
 
+    # Model version buttons
+    link_button_v7 = gr.Button("Use Model V7")
+    link_button_v8 = gr.Button("Use Model V8")
+    link_button_v11 = gr.Button("Use Model V11")
+
+    def update_model_version_v7():
+        global model_repo_id
+        model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v7-sdxl"
+        global pipe
+        pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = pipe.to(device)
+
+    def update_model_version_v8():
+        global model_repo_id
+        model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
+        global pipe
+        pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = pipe.to(device)
+
+    def update_model_version_v11():
+        global model_repo_id
+        model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v11-sdxl"
+        global pipe
+        pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = pipe.to(device)
+
+    # Link button actions
+    link_button_v7.click(update_model_version_v7)
+    link_button_v8.click(update_model_version_v8)
+    link_button_v11.click(update_model_version_v11)
+
     run_button.click(
         infer,
         inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
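
The three `update_model_version_*` helpers added above differ only in the repository id. A hedged alternative sketch (illustrative names, same global-reload behaviour) that generates them from a single factory:

# Hypothetical consolidation of the per-version handlers; MODEL_REPOS and
# make_model_switcher are illustrative names, not part of the committed file.
MODEL_REPOS = {
    "v7": "John6666/wai-ani-nsfw-ponyxl-v7-sdxl",
    "v8": "John6666/wai-ani-nsfw-ponyxl-v8-sdxl",
    "v11": "John6666/wai-ani-nsfw-ponyxl-v11-sdxl",
}

def make_model_switcher(version):
    def switch():
        # Reload the pipeline for the chosen checkpoint, as the committed
        # handlers do, and rebind the module-level globals.
        global model_repo_id, pipe
        model_repo_id = MODEL_REPOS[version]
        pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
        pipe = pipe.to(device)
    return switch

# Wiring inside the gr.Blocks context would then mirror the diff:
#     link_button_v7.click(make_model_switcher("v7"))
#     link_button_v8.click(make_model_switcher("v8"))
#     link_button_v11.click(make_model_switcher("v11"))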
@@ -224,23 +255,4 @@ with gr.Blocks(css=css) as demo:
             outputs=[result, seed, prompt_info]
     )
 
-    link_button_v7 = gr.Button("V7", elem_id="link-v7", size="sm")
-    link_button_v8 = gr.Button("V8", elem_id="link-v8", size="sm")
-    link_button_v11 = gr.Button("V11", elem_id="link-v11", size="sm")
-
-    def update_model_version(version):
-        global model_repo_id
-        if version == "v7":
-            model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v7-sdxl"
-        elif version == "v8":
-            model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
-        elif version == "v11":
-            model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v11-sdxl"
-        pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-        pipe = pipe.to(device)
-
-    link_button_v7.click(update_model_version, inputs=["v7"], outputs=[])
-    link_button_v8.click(update_model_version, inputs=["v8"], outputs=[])
-    link_button_v11.click(update_model_version, inputs=["v11"], outputs=[])
-
-demo.queue().launch()
+demo.queue()
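
For reference, a Gradio Blocks app is normally both queued and launched at the end of the script; a minimal sketch of the usual closing lines:

# Usual ending for a Gradio app.py; the committed revision keeps only demo.queue().
demo.queue()   # enable request queuing
demo.launch()  # start the server (equivalent to chaining demo.queue().launch())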
 