panelforge committed on
Commit
e2f757b
·
verified ·
1 Parent(s): c8b1eef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -38
app.py CHANGED
@@ -7,7 +7,9 @@ import torch
7
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
 
 
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
@@ -20,6 +22,14 @@ pipe = pipe.to(device)
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
 
 
 
 
 
 
 
 
23
  @spaces.GPU # [uncomment to use ZeroGPU]
24
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
25
  selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
@@ -28,10 +38,8 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
28
  selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
29
 
30
  if active_tab == "Prompt Input":
31
- # Use the user-provided prompt
32
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
33
  else:
34
- # Use tags from the "Tag Selection" tab
35
  selected_tags = (
36
  [participant_tags[tag] for tag in selected_participant_tags] +
37
  [tribe_tags[tag] for tag in selected_tribe_tags] +
@@ -51,7 +59,6 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
51
  tags_text = ', '.join(selected_tags)
52
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {tags_text}'
53
 
54
- # Concatenate user-provided negative prompt with additional restrictions
55
  additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
56
  full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
57
 
@@ -71,10 +78,9 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
71
  generator=generator
72
  ).images[0]
73
 
74
- # Return image, seed, and the used prompts
75
  return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
76
 
77
-
78
  css = """
79
  #col-container {
80
  margin: 0 auto;
@@ -123,45 +129,25 @@ css = """
123
  margin-bottom: 20px;
124
  }
125
 
126
- #external-links {
127
- margin-bottom: 20px;
128
  display: flex;
129
- gap: 10px;
130
  }
131
 
132
- #external-links .gradio-button {
133
  flex: 1;
134
- display: block;
135
- width: 100%;
136
  }
137
  """
138
 
 
139
  with gr.Blocks(css=css) as demo:
140
 
141
  with gr.Row():
142
  with gr.Column(elem_id="left-column"):
143
  gr.Markdown("""# Rainbow Media X""")
144
-
145
- # Add buttons for external links above the prompt using Markdown
146
- with gr.Row(elem_id="external-links"):
147
- gr.Markdown(
148
- """
149
- <a href="https://example.com/space1" target="_blank">
150
- <button class="gradio-button" style="width: 100%;">Go to Space 1</button>
151
- </a>
152
- <a href="https://example.com/space2" target="_blank">
153
- <button class="gradio-button" style="width: 100%;">Go to Space 2</button>
154
- </a>
155
- <a href="https://example.com/space3" target="_blank">
156
- <button class="gradio-button" style="width: 100%;">Go to Space 3</button>
157
- </a>
158
- """
159
- )
160
-
161
- # Display result image at the top
162
  result = gr.Image(label="Result", show_label=False, elem_id="result")
163
-
164
- # Add a textbox to display the prompts used for generation
165
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False, elem_id="prompt-info")
166
 
167
  # Advanced Settings and Run Button
@@ -217,15 +203,10 @@ with gr.Blocks(css=css) as demo:
217
  value=35,
218
  )
219
 
220
- # Full-width "Run" button
221
  run_button = gr.Button("Run", elem_id="run-button")
222
 
223
  with gr.Column(elem_id="right-column"):
224
- # Removed the Prompt / Tag Input title here
225
- # State to track active tab
226
  active_tab = gr.State("Prompt Input")
227
-
228
- # Tabbed interface to select either Prompt or Tags
229
  with gr.Tabs() as tabs:
230
  with gr.TabItem("Prompt Input") as prompt_tab:
231
  prompt = gr.Textbox(
@@ -239,7 +220,6 @@ with gr.Blocks(css=css) as demo:
239
  prompt_tab.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
240
 
241
  with gr.TabItem("Tag Selection") as tag_tab:
242
- # Tag selection checkboxes for each tag group
243
  selected_participant_tags = gr.CheckboxGroup(choices=list(participant_tags.keys()), label="Participant Tags")
244
  selected_tribe_tags = gr.CheckboxGroup(choices=list(tribe_tags.keys()), label="Tribe Tags")
245
  selected_skin_tone_tags = gr.CheckboxGroup(choices=list(skin_tone_tags.keys()), label="Skin Tone Tags")
@@ -256,6 +236,16 @@ with gr.Blocks(css=css) as demo:
256
  selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
257
  tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
258
 
 
 
 
 
 
 
 
 
 
 
259
  run_button.click(
260
  infer,
261
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
 
7
  from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Default model version
12
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Default model V8
13
 
14
  if torch.cuda.is_available():
15
  torch_dtype = torch.float16
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  MAX_IMAGE_SIZE = 1024
24
 
25
def update_model_version(version):
    """Switch the global diffusion pipeline to another model version.

    Parameters
    ----------
    version : str
        Version label such as "V7", "V8" or "V11" (case-insensitive).

    Side effects: rebinds the module-level ``model_repo_id`` and ``pipe``
    by downloading/loading the selected checkpoint and moving it to ``device``.
    """
    global model_repo_id, pipe
    # Hugging Face repo ids are case-sensitive and the default repo id uses a
    # lowercase version tag ("...-v8-sdxl"), so normalise the label before
    # building the id — the UI buttons pass uppercase labels like "V8".
    model_repo_id = f"John6666/wai-ani-nsfw-ponyxl-{version.lower()}-sdxl"
    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)
    print(f"Model switched to {model_repo_id}")
32
+
33
  @spaces.GPU # [uncomment to use ZeroGPU]
34
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
35
  selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
 
38
  selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
39
 
40
  if active_tab == "Prompt Input":
 
41
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {prompt}'
42
  else:
 
43
  selected_tags = (
44
  [participant_tags[tag] for tag in selected_participant_tags] +
45
  [tribe_tags[tag] for tag in selected_tribe_tags] +
 
59
  tags_text = ', '.join(selected_tags)
60
  final_prompt = f'score_9, score_8_up, score_7_up, source_anime, {tags_text}'
61
 
 
62
  additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
63
  full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
64
 
 
78
  generator=generator
79
  ).images[0]
80
 
 
81
  return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
82
 
83
+ # CSS for button styling and horizontal layout
84
  css = """
85
  #col-container {
86
  margin: 0 auto;
 
129
  margin-bottom: 20px;
130
  }
131
 
132
+ .button-group {
 
133
  display: flex;
134
+ justify-content: space-between;
135
  }
136
 
137
+ .button-group .gradio-button {
138
  flex: 1;
139
+ margin: 0 10px;
140
+ text-align: center;
141
  }
142
  """
143
 
144
+ # Gradio interface setup
145
  with gr.Blocks(css=css) as demo:
146
 
147
  with gr.Row():
148
  with gr.Column(elem_id="left-column"):
149
  gr.Markdown("""# Rainbow Media X""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  result = gr.Image(label="Result", show_label=False, elem_id="result")
 
 
151
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False, elem_id="prompt-info")
152
 
153
  # Advanced Settings and Run Button
 
203
  value=35,
204
  )
205
 
 
206
  run_button = gr.Button("Run", elem_id="run-button")
207
 
208
  with gr.Column(elem_id="right-column"):
 
 
209
  active_tab = gr.State("Prompt Input")
 
 
210
  with gr.Tabs() as tabs:
211
  with gr.TabItem("Prompt Input") as prompt_tab:
212
  prompt = gr.Textbox(
 
220
  prompt_tab.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
221
 
222
  with gr.TabItem("Tag Selection") as tag_tab:
 
223
  selected_participant_tags = gr.CheckboxGroup(choices=list(participant_tags.keys()), label="Participant Tags")
224
  selected_tribe_tags = gr.CheckboxGroup(choices=list(tribe_tags.keys()), label="Tribe Tags")
225
  selected_skin_tone_tags = gr.CheckboxGroup(choices=list(skin_tone_tags.keys()), label="Skin Tone Tags")
 
236
  selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
237
  tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)
238
 
239
# Buttons for switching between model versions.
with gr.Row(elem_id="button-group"):
    link_button_v7 = gr.Button("V7 Model", elem_id="link-v7", variant="primary")
    link_button_v8 = gr.Button("V8 Model", elem_id="link-v8", variant="primary")
    link_button_v11 = gr.Button("V11 Model", elem_id="link-v11", variant="primary")

# Gradio's `inputs` must be components, not literal values — passing the
# string "V7" as an input is invalid. Bind the version string through a
# keyword-default lambda instead, so each handler takes no UI inputs.
link_button_v7.click(lambda v="V7": update_model_version(v), inputs=[], outputs=[])
link_button_v8.click(lambda v="V8": update_model_version(v), inputs=[], outputs=[])
link_button_v11.click(lambda v="V11": update_model_version(v), inputs=[], outputs=[])
249
  run_button.click(
250
  infer,
251
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,