File size: 11,334 Bytes
630d1c8
 
 
decc71b
630d1c8
e7003c8
 
630d1c8
 
02074a8
0f8e37d
 
 
 
 
630d1c8
0f8e37d
1b50d93
630d1c8
 
 
 
e7003c8
 
 
 
 
 
 
 
 
 
 
 
 
decc71b
1b50d93
 
 
 
 
 
 
e7003c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b50d93
e7003c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bb4602
1b50d93
90d2a01
 
 
630d1c8
 
1b50d93
6d482fb
630d1c8
6d482fb
90d2a01
6d482fb
 
 
 
 
 
 
1b50d93
630d1c8
e7003c8
6d482fb
0f8e37d
 
 
 
 
630d1c8
 
 
1b50d93
6d482fb
03c18e7
d16d2c8
58d31c6
 
7819529
1b50d93
 
 
 
6d482fb
1b50d93
 
f54a073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b50d93
 
 
 
f54a073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b50d93
00e6a86
1b50d93
 
f54a073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b50d93
630d1c8
1b50d93
630d1c8
1b50d93
 
 
6d482fb
630d1c8
1b50d93
 
6d482fb
630d1c8
1b50d93
 
7819529
1b50d93
02074a8
 
66d1fcc
a3ecd5b
66d1fcc
 
58d31c6
02074a8
 
ce6ba71
6d482fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
import importlib  # to import tag modules dynamically

# Choose the compute device once; half precision is only used on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"  # Replace with your desired model
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load the pipeline at import time and move it to the chosen device.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe.to(device)

# Upper bound for seeds (largest int32) and for the width/height sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Function to load tags dynamically based on the selected tab
def load_tags(active_tab):
    """Import and return the tags module that backs a tag tab.

    Args:
        active_tab: Tab label; one of "Gay", "Straight" or "Lesbian".

    Returns:
        The imported ``tags_gay``, ``tags_straight`` or ``tags_lesbian`` module.

    Raises:
        ValueError: if ``active_tab`` is not one of the labels above.
    """
    # Ordered (label, module name) pairs, compared in the same order as before.
    tab_modules = (
        ("Gay", "tags_gay"),
        ("Straight", "tags_straight"),
        ("Lesbian", "tags_lesbian"),
    )
    for label, module_name in tab_modules:
        if active_tab == label:
            # import_module reuses sys.modules, so repeat calls do not re-import.
            return importlib.import_module(module_name)
    raise ValueError(f"Unknown tab: {active_tab}")

# Selection categories in the order `infer` receives them. Each name is also the
# attribute on a tags module that maps a selectable tag name to its prompt fragment.
_TAG_CATEGORIES = (
    "participant_tags",
    "tribe_tags",
    "role_tags",
    "skin_tone_tags",
    "body_type_tags",
    "tattoo_tags",
    "piercing_tags",
    "expression_tags",
    "eye_tags",
    "hair_style_tags",
    "position_tags",
    "fetish_tags",
    "location_tags",
    "camera_tags",
    "atmosphere_tags",
)

# Quality/style tags always prepended to the positive prompt.
_QUALITY_PREFIX = "score_9, score_8_up, score_7_up, source_anime"
# Negative terms always applied; the user's negative prompt is appended after them.
_BASE_NEGATIVE_PROMPT = (
    "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), "
    "monochrome, blurry, lowres, watermark"
)


def _collect_tag_prompts(tags_module, selections):
    """Return the prompt fragments for the selected tags, in category order.

    Args:
        tags_module: Module returned by ``load_tags``; must expose every
            attribute named in ``_TAG_CATEGORIES``.
        selections: One iterable of selected tag names per category, in
            ``_TAG_CATEGORIES`` order. ``None`` counts as no selection.

    Raises:
        AttributeError: if the module lacks one of the category attributes.
        KeyError: if a selected name is not in that category's mapping.
    """
    fragments = []
    for category, selected in zip(_TAG_CATEGORIES, selections):
        mapping = getattr(tags_module, category)
        fragments.extend(mapping[tag] for tag in selected or ())
    return fragments


def _join_prompt_parts(*parts):
    """Join the non-blank parts with ", " so empty parts leave no dangling comma."""
    return ", ".join(part.strip() for part in parts if part and part.strip())


@spaces.GPU
def infer(
    prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
    selected_participant_tags, selected_tribe_tags, selected_role_tags, selected_skin_tone_tags, selected_body_type_tags,
    selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags, selected_hair_style_tags,
    selected_position_tags, selected_fetish_tags, selected_location_tags, selected_camera_tags, selected_atmosphere_tags,
    active_tab, progress=gr.Progress(track_tqdm=True)
):
    """Generate one image from either the free-text prompt or the selected tags.

    On the "Prompt Input" tab the user's ``prompt`` is the prompt body. On a
    tag tab ("Straight", "Gay", "Lesbian") the body is built from the selected
    tags via that tab's tags module. In both cases ``_QUALITY_PREFIX`` is
    prepended, and ``_BASE_NEGATIVE_PROMPT`` is combined with the user's
    negative prompt.

    Returns:
        ``(image, seed_used, description)``. ``description`` shows the final
        positive and negative prompts.

    Raises:
        ValueError: for a tab label ``load_tags`` does not recognise.
    """
    if active_tab == "Prompt Input":
        # The prompt tab has no tags module; calling load_tags here would raise.
        body = prompt
    else:
        tags_module = load_tags(active_tab)
        selections = (
            selected_participant_tags, selected_tribe_tags, selected_role_tags,
            selected_skin_tone_tags, selected_body_type_tags, selected_tattoo_tags,
            selected_piercing_tags, selected_expression_tags, selected_eye_tags,
            selected_hair_style_tags, selected_position_tags, selected_fetish_tags,
            selected_location_tags, selected_camera_tags, selected_atmosphere_tags,
        )
        body = ", ".join(_collect_tag_prompts(tags_module, selections))

    final_prompt = _join_prompt_parts(_QUALITY_PREFIX, body)
    full_negative_prompt = _join_prompt_parts(_BASE_NEGATIVE_PROMPT, negative_prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; manual_seed and the pipeline sizes need ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    image = pipe(
        prompt=final_prompt,
        negative_prompt=full_negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator
    ).images[0]

    return image, seed, f"Prompt: {final_prompt}\nNegative Prompt: {full_negative_prompt}"

# CSS for the layout: center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Tabs whose checkboxes are filled from a tags module (see load_tags). The
# "Prompt Input" tab uses the free-text prompt instead.
_TAG_TABS = ("Straight", "Gay", "Lesbian")

# (tags-module attribute, checkbox label), in the exact order `infer` expects
# its selected_*_tags arguments.
_TAG_GROUPS = (
    ("participant_tags", "Participant Tags"),
    ("tribe_tags", "Tribe Tags"),
    ("role_tags", "Role Tags"),
    ("skin_tone_tags", "Skin Tone Tags"),
    ("body_type_tags", "Body Type Tags"),
    ("tattoo_tags", "Tattoo Tags"),
    ("piercing_tags", "Piercing Tags"),
    ("expression_tags", "Expression Tags"),
    ("eye_tags", "Eye Tags"),
    ("hair_style_tags", "Hair Style Tags"),
    ("position_tags", "Position Tags"),
    ("fetish_tags", "Fetish Tags"),
    ("location_tags", "Location Tags"),
    ("camera_tags", "Camera Tags"),
    ("atmosphere_tags", "Atmosphere Tags"),
)


def _tag_choices(tab_name, attribute):
    """Return the selectable tag names for one category of a tag tab.

    ``infer`` looks tags up as ``tags_module.<category>[name]``, so the
    mapping's keys are the names offered to the user. If the module or the
    category is missing, warn and offer no choices so the UI still starts.
    """
    try:
        return list(getattr(load_tags(tab_name), attribute))
    except (ImportError, AttributeError) as exc:
        import warnings
        warnings.warn(f"No tag choices for {tab_name}/{attribute}: {exc}")
        return []


def _tab_selector(tab_name):
    """Build a zero-argument callback that reports ``tab_name`` as the active tab."""
    return lambda: tab_name


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Image Generator with Tags and Prompts")

        result = gr.Image(label="Result", show_label=False)
        prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)
        active_tab = gr.State("Prompt Input")

        # Tab name -> that tab's own CheckboxGroups, ordered like _TAG_GROUPS.
        tag_checkboxes = {}

        with gr.Tabs() as tabs:
            # Prompt Input Tab
            with gr.TabItem("Prompt Input") as prompt_tab:
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your custom prompt")
            # Listen on each TabItem. A Tabs-level `select` fires for every tab,
            # so constant listeners attached there all run and race each other.
            prompt_tab.select(_tab_selector("Prompt Input"), inputs=None, outputs=active_tab)

            # Straight / Gay / Lesbian tabs, each with its own checkbox set.
            for tab_name in _TAG_TABS:
                with gr.TabItem(tab_name) as tag_tab:
                    tag_checkboxes[tab_name] = [
                        gr.CheckboxGroup(choices=_tag_choices(tab_name, attribute), label=label)
                        for attribute, label in _TAG_GROUPS
                    ]
                tag_tab.select(_tab_selector(tab_name), inputs=None, outputs=active_tab)

        # Advanced Settings
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=7)
                num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=35)

        def _run_active_tab(values, progress=gr.Progress(track_tqdm=True)):
            """Call `infer` with the active tab's tag selections.

            ``values`` maps each input component to its current value. Gradio
            passes a single dict when ``inputs`` is a set. Tabs without
            checkboxes (the prompt tab) contribute empty selections.
            """
            tab_name = values[active_tab]
            if tab_name in tag_checkboxes:
                selections = [values[box] for box in tag_checkboxes[tab_name]]
            else:
                selections = [[] for _ in _TAG_GROUPS]
            return infer(
                values[prompt], values[negative_prompt], values[seed], values[randomize_seed],
                values[width], values[height], values[guidance_scale], values[num_inference_steps],
                *selections, tab_name, progress,
            )

        run_button = gr.Button("Run")
        # Send every tab's checkboxes, so the handler can pick the active tab's
        # set at click time.
        run_inputs = {
            prompt, negative_prompt, seed, randomize_seed, width, height,
            guidance_scale, num_inference_steps, active_tab,
        }
        for boxes in tag_checkboxes.values():
            run_inputs.update(boxes)
        run_button.click(
            _run_active_tab,
            inputs=run_inputs,
            outputs=[result, seed, prompt_info]
        )

demo.queue().launch()