panelforge commited on
Commit
2de8388
·
verified ·
1 Parent(s): 690fec0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -111
app.py CHANGED
@@ -4,7 +4,7 @@ import random
4
  import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
- import importlib # to import tag modules dynamically
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
@@ -22,18 +22,19 @@ MAX_IMAGE_SIZE = 1024
22
 
23
  # Function to load tags dynamically based on the selected tab
24
  def load_tags(active_tab):
25
- if active_tab == "Gay":
26
- tags_module = importlib.import_module('tags_gay') # dynamically import the tags_gay module
27
- elif active_tab == "Straight":
28
- tags_module = importlib.import_module('tags_straight') # dynamically import the tags_straight module
29
- elif active_tab == "Lesbian":
30
- tags_module = importlib.import_module('tags_lesbian') # dynamically import the tags_lesbian module
31
- else:
32
- raise ValueError(f"Unknown tab: {active_tab}")
33
-
34
- return tags_module
 
 
35
 
36
- @spaces.GPU
37
  @spaces.GPU
38
  def infer(
39
  prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
@@ -45,7 +46,7 @@ def infer(
45
  # Dynamically load the correct tags module based on active tab
46
  tags_module = load_tags(active_tab)
47
 
48
- # Now use the tags from the loaded module
49
  participant_tags = tags_module.participant_tags
50
  tribe_tags = tags_module.tribe_tags
51
  role_tags = tags_module.role_tags
@@ -62,74 +63,114 @@ def infer(
62
  camera_tags = tags_module.camera_tags
63
  atmosphere_tags = tags_module.atmosphere_tags
64
 
65
- # Debug: Check if the tag modules are loaded correctly
66
- print(f"Loaded tags for {active_tab}:")
67
- print(f"participant_tags: {list(participant_tags.keys())}")
68
- print(f"tribe_tags: {list(tribe_tags.keys())}")
69
- print(f"role_tags: {list(role_tags.keys())}")
70
- print(f"skin_tone_tags: {list(skin_tone_tags.keys())}")
71
-
72
- # Handle the active tab and generate the prompt accordingly
73
  tag_list = []
74
 
75
- # Dynamically build tag list based on selected checkboxes
76
- for tag_list_group, selected_tags in [
77
- (participant_tags, selected_participant_tags),
78
- (tribe_tags, selected_tribe_tags),
79
- (role_tags, selected_role_tags),
80
- (skin_tone_tags, selected_skin_tone_tags),
81
- (body_type_tags, selected_body_type_tags),
82
- (tattoo_tags, selected_tattoo_tags),
83
- (piercing_tags, selected_piercing_tags),
84
- (expression_tags, selected_expression_tags),
85
- (eye_tags, selected_eye_tags),
86
- (hair_style_tags, selected_hair_style_tags),
87
- (position_tags, selected_position_tags),
88
- (fetish_tags, selected_fetish_tags),
89
- (location_tags, selected_location_tags),
90
- (camera_tags, selected_camera_tags),
91
- (atmosphere_tags, selected_atmosphere_tags)
92
- ]:
93
- # Add the selected tags for this category
94
- for tag in selected_tags:
95
- if tag in tag_list_group:
96
- tag_list.append(tag_list_group[tag])
97
- else:
98
- print(f"Tag '{tag}' not found in {tag_list_group}")
99
-
100
- # Debug: Print the final prompt that is being generated
101
- print(f"Final Prompt: {final_prompt}")
102
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {', '.join(tag_list)}"
104
 
105
- # Concatenate additional negative prompts
106
  additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
107
  full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
108
 
 
109
  if randomize_seed:
110
  seed = random.randint(0, MAX_SEED)
111
  generator = torch.Generator(device=device).manual_seed(seed)
112
 
113
- # Debug: Ensure the final negative prompt is correct
114
- print(f"Negative Prompt: {full_negative_prompt}")
115
-
116
- image = pipe(
117
- prompt=final_prompt,
118
- negative_prompt=full_negative_prompt,
119
- guidance_scale=guidance_scale,
120
- num_inference_steps=num_inference_steps,
121
- width=width,
122
- height=height,
123
- generator=generator
124
- ).images[0]
125
-
126
- # Debug: Confirm the seed and image generation
127
- print(f"Seed used for generation: {seed}")
128
-
129
- return image, seed, f"Prompt: {final_prompt}\nNegative Prompt: {full_negative_prompt}"
130
 
 
131
 
132
- # CSS for the layout
133
  css = """
134
  #col-container {
135
  margin: 0 auto;
@@ -146,12 +187,11 @@ with gr.Blocks(css=css) as demo:
146
  active_tab = gr.State("Prompt Input")
147
 
148
  with gr.Tabs() as tabs:
149
- # Prompt Input Tab
150
  with gr.TabItem("Prompt Input"):
151
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your custom prompt")
152
  tabs.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
153
 
154
- # Straight Tab
155
  with gr.TabItem("Straight"):
156
  tags_module = load_tags("Straight")
157
  selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
@@ -171,47 +211,7 @@ with gr.Blocks(css=css) as demo:
171
  selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
172
  tabs.select(lambda: "Straight", inputs=None, outputs=active_tab)
173
 
174
- # Gay Tab
175
- with gr.TabItem("Gay"):
176
- tags_module = load_tags("Gay")
177
- selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
178
- selected_tribe_tags = gr.CheckboxGroup(choices=list(tags_module.tribe_tags.keys()), label="Tribe Tags")
179
- selected_role_tags = gr.CheckboxGroup(choices=list(tags_module.role_tags.keys()), label="Role Tags")
180
- selected_skin_tone_tags = gr.CheckboxGroup(choices=list(tags_module.skin_tone_tags.keys()), label="Skin Tone Tags")
181
- selected_body_type_tags = gr.CheckboxGroup(choices=list(tags_module.body_type_tags.keys()), label="Body Type Tags")
182
- selected_tattoo_tags = gr.CheckboxGroup(choices=list(tags_module.tattoo_tags.keys()), label="Tattoo Tags")
183
- selected_piercing_tags = gr.CheckboxGroup(choices=list(tags_module.piercing_tags.keys()), label="Piercing Tags")
184
- selected_expression_tags = gr.CheckboxGroup(choices=list(tags_module.expression_tags.keys()), label="Expression Tags")
185
- selected_eye_tags = gr.CheckboxGroup(choices=list(tags_module.eye_tags.keys()), label="Eye Tags")
186
- selected_hair_style_tags = gr.CheckboxGroup(choices=list(tags_module.hair_style_tags.keys()), label="Hair Style Tags")
187
- selected_position_tags = gr.CheckboxGroup(choices=list(tags_module.position_tags.keys()), label="Position Tags")
188
- selected_fetish_tags = gr.CheckboxGroup(choices=list(tags_module.fetish_tags.keys()), label="Fetish Tags")
189
- selected_location_tags = gr.CheckboxGroup(choices=list(tags_module.location_tags.keys()), label="Location Tags")
190
- selected_camera_tags = gr.CheckboxGroup(choices=list(tags_module.camera_tags.keys()), label="Camera Tags")
191
- selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
192
- tabs.select(lambda: "Gay", inputs=None, outputs=active_tab)
193
-
194
- # Lesbian Tab
195
- with gr.TabItem("Lesbian"):
196
- tags_module = load_tags("Lesbian")
197
- selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
198
- selected_tribe_tags = gr.CheckboxGroup(choices=list(tags_module.tribe_tags.keys()), label="Tribe Tags")
199
- selected_role_tags = gr.CheckboxGroup(choices=list(tags_module.role_tags.keys()), label="Role Tags")
200
- selected_skin_tone_tags = gr.CheckboxGroup(choices=list(tags_module.skin_tone_tags.keys()), label="Skin Tone Tags")
201
- selected_body_type_tags = gr.CheckboxGroup(choices=list(tags_module.body_type_tags.keys()), label="Body Type Tags")
202
- selected_tattoo_tags = gr.CheckboxGroup(choices=list(tags_module.tattoo_tags.keys()), label="Tattoo Tags")
203
- selected_piercing_tags = gr.CheckboxGroup(choices=list(tags_module.piercing_tags.keys()), label="Piercing Tags")
204
- selected_expression_tags = gr.CheckboxGroup(choices=list(tags_module.expression_tags.keys()), label="Expression Tags")
205
- selected_eye_tags = gr.CheckboxGroup(choices=list(tags_module.eye_tags.keys()), label="Eye Tags")
206
- selected_hair_style_tags = gr.CheckboxGroup(choices=list(tags_module.hair_style_tags.keys()), label="Hair Style Tags")
207
- selected_position_tags = gr.CheckboxGroup(choices=list(tags_module.position_tags.keys()), label="Position Tags")
208
- selected_fetish_tags = gr.CheckboxGroup(choices=list(tags_module.fetish_tags.keys()), label="Fetish Tags")
209
- selected_location_tags = gr.CheckboxGroup(choices=list(tags_module.location_tags.keys()), label="Location Tags")
210
- selected_camera_tags = gr.CheckboxGroup(choices=list(tags_module.camera_tags.keys()), label="Camera Tags")
211
- selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
212
- tabs.select(lambda: "Lesbian", inputs=None, outputs=active_tab)
213
-
214
- # Advanced Settings
215
  with gr.Accordion("Advanced Settings", open=False):
216
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
217
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
4
  import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
+ import importlib
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl" # Replace with your desired model
 
22
 
23
  # Function to load tags dynamically based on the selected tab
24
  def load_tags(active_tab):
25
+ try:
26
+ if active_tab == "Gay":
27
+ return importlib.import_module('tags_gay') # dynamically import the tags_gay module
28
+ elif active_tab == "Straight":
29
+ return importlib.import_module('tags_straight') # dynamically import the tags_straight module
30
+ elif active_tab == "Lesbian":
31
+ return importlib.import_module('tags_lesbian') # dynamically import the tags_lesbian module
32
+ else:
33
+ raise ValueError(f"Unknown tab: {active_tab}")
34
+ except Exception as e:
35
+ print(f"Error loading tags for {active_tab}: {str(e)}")
36
+ raise
37
 
 
38
  @spaces.GPU
39
  def infer(
40
  prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
 
46
  # Dynamically load the correct tags module based on active tab
47
  tags_module = load_tags(active_tab)
48
 
49
+ # Get the tag dictionaries from the loaded module
50
  participant_tags = tags_module.participant_tags
51
  tribe_tags = tags_module.tribe_tags
52
  role_tags = tags_module.role_tags
 
63
  camera_tags = tags_module.camera_tags
64
  atmosphere_tags = tags_module.atmosphere_tags
65
 
66
+ # Build the tag list using selected tags from each group
 
 
 
 
 
 
 
67
  tag_list = []
68
 
69
+ # Add selected participant tags
70
+ for tag in selected_participant_tags:
71
+ if tag in participant_tags:
72
+ tag_list.append(participant_tags[tag])
73
+
74
+ # Add selected tribe tags
75
+ for tag in selected_tribe_tags:
76
+ if tag in tribe_tags:
77
+ tag_list.append(tribe_tags[tag])
78
+
79
+ # Add selected role tags
80
+ for tag in selected_role_tags:
81
+ if tag in role_tags:
82
+ tag_list.append(role_tags[tag])
83
+
84
+ # Add selected skin tone tags
85
+ for tag in selected_skin_tone_tags:
86
+ if tag in skin_tone_tags:
87
+ tag_list.append(skin_tone_tags[tag])
88
+
89
+ # Add selected body type tags
90
+ for tag in selected_body_type_tags:
91
+ if tag in body_type_tags:
92
+ tag_list.append(body_type_tags[tag])
93
+
94
+ # Add selected tattoo tags
95
+ for tag in selected_tattoo_tags:
96
+ if tag in tattoo_tags:
97
+ tag_list.append(tattoo_tags[tag])
98
+
99
+ # Add selected piercing tags
100
+ for tag in selected_piercing_tags:
101
+ if tag in piercing_tags:
102
+ tag_list.append(piercing_tags[tag])
103
+
104
+ # Add selected expression tags
105
+ for tag in selected_expression_tags:
106
+ if tag in expression_tags:
107
+ tag_list.append(expression_tags[tag])
108
+
109
+ # Add selected eye tags
110
+ for tag in selected_eye_tags:
111
+ if tag in eye_tags:
112
+ tag_list.append(eye_tags[tag])
113
+
114
+ # Add selected hair style tags
115
+ for tag in selected_hair_style_tags:
116
+ if tag in hair_style_tags:
117
+ tag_list.append(hair_style_tags[tag])
118
+
119
+ # Add selected position tags
120
+ for tag in selected_position_tags:
121
+ if tag in position_tags:
122
+ tag_list.append(position_tags[tag])
123
+
124
+ # Add selected fetish tags
125
+ for tag in selected_fetish_tags:
126
+ if tag in fetish_tags:
127
+ tag_list.append(fetish_tags[tag])
128
+
129
+ # Add selected location tags
130
+ for tag in selected_location_tags:
131
+ if tag in location_tags:
132
+ tag_list.append(location_tags[tag])
133
+
134
+ # Add selected camera tags
135
+ for tag in selected_camera_tags:
136
+ if tag in camera_tags:
137
+ tag_list.append(camera_tags[tag])
138
+
139
+ # Add selected atmosphere tags
140
+ for tag in selected_atmosphere_tags:
141
+ if tag in atmosphere_tags:
142
+ tag_list.append(atmosphere_tags[tag])
143
+
144
+ # Construct final prompt
145
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {', '.join(tag_list)}"
146
 
147
+ # Negative prompt
148
  additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
149
  full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
150
 
151
+ # Handle random seed if needed
152
  if randomize_seed:
153
  seed = random.randint(0, MAX_SEED)
154
  generator = torch.Generator(device=device).manual_seed(seed)
155
 
156
+ # Generate the image
157
+ try:
158
+ image = pipe(
159
+ prompt=final_prompt,
160
+ negative_prompt=full_negative_prompt,
161
+ guidance_scale=guidance_scale,
162
+ num_inference_steps=num_inference_steps,
163
+ width=width,
164
+ height=height,
165
+ generator=generator
166
+ ).images[0]
167
+ except Exception as e:
168
+ print(f"Error generating image: {str(e)}")
169
+ raise
 
 
 
170
 
171
+ return image, seed, f"Prompt: {final_prompt}\nNegative Prompt: {full_negative_prompt}"
172
 
173
+ # Gradio UI setup
174
  css = """
175
  #col-container {
176
  margin: 0 auto;
 
187
  active_tab = gr.State("Prompt Input")
188
 
189
  with gr.Tabs() as tabs:
190
+ # Tab setup for different categories
191
  with gr.TabItem("Prompt Input"):
192
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your custom prompt")
193
  tabs.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
194
 
 
195
  with gr.TabItem("Straight"):
196
  tags_module = load_tags("Straight")
197
  selected_participant_tags = gr.CheckboxGroup(choices=list(tags_module.participant_tags.keys()), label="Participant Tags")
 
211
  selected_atmosphere_tags = gr.CheckboxGroup(choices=list(tags_module.atmosphere_tags.keys()), label="Atmosphere Tags")
212
  tabs.select(lambda: "Straight", inputs=None, outputs=active_tab)
213
 
214
+ # Advanced settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  with gr.Accordion("Advanced Settings", open=False):
216
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
217
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)