# rainbow_media_x / app.py
# Last commit 7bb50ea by panelforge: "Update app.py" (file size 10.4 kB)
import gradio as gr
import torch
from diffusers import DiffusionPipeline
import numpy as np
import random
import spaces
from tags import participant_tags, tribe_tags, skin_tone_tags, body_type_tags, tattoo_tags, piercing_tags, expression_tags, eye_tags, hair_style_tags, position_tags, fetish_tags, location_tags, camera_tags, atmosphere_tags
# Choose the compute device once; half precision is only used when CUDA is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"  # Default model version

# Load the default checkpoint at import time and place it on the chosen device.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)

# Upper bound of the seed slider / random seed draw (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the width and height sliders, in pixels.
MAX_IMAGE_SIZE = 1024
# Fixed tags placed in front of every positive prompt.
_PROMPT_PREFIX = "score_9, score_8_up, score_7_up, source_anime"
# Terms always placed in front of the user's negative prompt.
_BASE_NEGATIVES = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"


def _join_nonempty(*parts):
    """Join the non-blank ``parts`` with ", ".

    Blank or missing parts (``""``, whitespace-only, ``None``) are skipped so an
    empty prompt or tag selection does not leave a dangling ", " behind.
    """
    return ", ".join(part for part in parts if part and part.strip())


@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
          selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
          selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
          selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
          selected_camera_tags, selected_atmosphere_tags, active_tab, progress=gr.Progress(track_tqdm=True)):
    """Run one text-to-image generation with the module-level ``pipe``.

    The positive prompt is ``_PROMPT_PREFIX`` followed by the prompt body.
    When ``active_tab == "Prompt Input"``, the body is the free-text ``prompt``.
    On any other tab, it is the dictionary values of every selected checkbox
    label, in category order. The negative prompt is ``_BASE_NEGATIVES``
    followed by the user's ``negative_prompt``. Blank parts are omitted.

    Args:
        prompt: Free-text prompt from the "Prompt Input" tab.
        negative_prompt: Extra negative terms from "Advanced Settings".
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels, passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        selected_*_tags: Labels checked in each tag category. Each label must
            be a key of the matching ``*_tags`` dictionary; ``None`` is
            treated as "nothing selected".
        active_tab: Last selected input tab ("Prompt Input" or "Tag Selection").
        progress: Gradio progress tracker; ``track_tqdm=True`` forwards tqdm bars.

    Returns:
        ``(image, seed, info)``. ``image`` is the first image of the pipeline
        output, ``seed`` is the seed actually used (an ``int``), and ``info``
        is a text summary of both prompts.

    Raises:
        KeyError: If a selected label is not a key of its tag dictionary.
    """
    if active_tab == "Prompt Input":
        body = prompt
    else:
        # (lookup table, selected labels) pairs, in the order tags appear in the prompt.
        categories = (
            (participant_tags, selected_participant_tags),
            (tribe_tags, selected_tribe_tags),
            (skin_tone_tags, selected_skin_tone_tags),
            (body_type_tags, selected_body_type_tags),
            (tattoo_tags, selected_tattoo_tags),
            (piercing_tags, selected_piercing_tags),
            (expression_tags, selected_expression_tags),
            (eye_tags, selected_eye_tags),
            (hair_style_tags, selected_hair_style_tags),
            (position_tags, selected_position_tags),
            (fetish_tags, selected_fetish_tags),
            (location_tags, selected_location_tags),
            (camera_tags, selected_camera_tags),
            (atmosphere_tags, selected_atmosphere_tags),
        )
        body = ", ".join(
            lookup[label]
            for lookup, labels in categories
            for label in (labels or ())
        )

    final_prompt = _join_nonempty(_PROMPT_PREFIX, body)
    full_negative_prompt = _join_nonempty(_BASE_NEGATIVES, negative_prompt)

    # Slider values arrive as JSON numbers and may be floats (e.g. via the API).
    # torch.Generator.manual_seed and the pipeline's size/step arguments need ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=final_prompt,
        negative_prompt=full_negative_prompt,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]
    return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
# Custom stylesheet passed to gr.Blocks(css=css). The #id selectors target the
# elem_id values given to components in the UI layout (left-column,
# right-column, run-button, prompt-info, result, prompt).
# NOTE(review): no component visible in this file uses elem_id="col-container",
# so that rule looks unused; confirm before removing.
css = """
#col-container {
margin: 0 auto;
max-width: 1280px;
}
#left-column {
width: 50%;
display: inline-block;
padding-right: 20px;
padding-left: 20px;
vertical-align: top;
}
#right-column {
width: 50%;
display: inline-block;
vertical-align: top;
padding-left: 20px;
margin-top: 53px;
}
#left-column > * {
margin-bottom: 20px;
}
#run-button {
width: 100%;
margin-top: 10px;
display: block;
}
#prompt-info {
margin-bottom: 20px;
}
#result {
margin-bottom: 20px;
}
.gradio-tabs > .tab-item {
margin-bottom: 20px;
}
#prompt {
margin-bottom: 20px;
}
"""
def _load_model(repo_id):
    """Replace the global ``pipe`` with the checkpoint at ``repo_id``.

    The new pipeline is fully loaded and moved to ``device`` before it is
    assigned. The global ``pipe`` therefore only ever references a ready
    pipeline. ``model_repo_id`` is updated only after loading succeeds.
    Requesting the checkpoint that is already active does nothing.

    NOTE(review): ``pipe`` is process-global, so a switch affects every user
    of the app, including generations queued by other sessions.
    """
    global model_repo_id, pipe
    if repo_id == model_repo_id:
        return
    new_pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
    pipe = new_pipe
    model_repo_id = repo_id


# Click handlers keep their original names; Gradio derives default API endpoint
# names from the function name.
def update_model_version_v7():
    """Switch the app to the V7 checkpoint."""
    _load_model("John6666/wai-ani-nsfw-ponyxl-v7-sdxl")


def update_model_version_v8():
    """Switch the app to the V8 checkpoint (the default at startup)."""
    _load_model("John6666/wai-ani-nsfw-ponyxl-v8-sdxl")


def update_model_version_v11():
    """Switch the app to the V11 checkpoint."""
    _load_model("John6666/wai-ani-nsfw-ponyxl-v11-sdxl")


# UI layout. Left column: output image, prompt summary, advanced settings and
# Run button. Right column: prompt / tag-selection tabs and model switch buttons.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(elem_id="left-column"):
            gr.Markdown("""# Rainbow Media X""")
            result = gr.Image(label="Result", show_label=False, elem_id="result")
            prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False, elem_id="prompt-info")
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                    visible=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=7,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=35,
                    )
            run_button = gr.Button("Run", elem_id="run-button")

        with gr.Column(elem_id="right-column"):
            # Records which input tab was selected last; infer() uses it to
            # choose between the free-text prompt and the tag checkboxes.
            active_tab = gr.State("Prompt Input")
            with gr.Tabs() as tabs:
                with gr.TabItem("Prompt Input") as prompt_tab:
                    prompt = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                        container=False,
                        elem_id="prompt"
                    )
                prompt_tab.select(lambda: "Prompt Input", inputs=None, outputs=active_tab)
                with gr.TabItem("Tag Selection") as tag_tab:
                    # Each group shows the keys of one tag dictionary; infer()
                    # maps the checked keys back to their values.
                    selected_participant_tags = gr.CheckboxGroup(choices=list(participant_tags.keys()), label="Participant Tags")
                    selected_tribe_tags = gr.CheckboxGroup(choices=list(tribe_tags.keys()), label="Tribe Tags")
                    selected_skin_tone_tags = gr.CheckboxGroup(choices=list(skin_tone_tags.keys()), label="Skin Tone Tags")
                    selected_body_type_tags = gr.CheckboxGroup(choices=list(body_type_tags.keys()), label="Body Type Tags")
                    selected_tattoo_tags = gr.CheckboxGroup(choices=list(tattoo_tags.keys()), label="Tattoo Tags")
                    selected_piercing_tags = gr.CheckboxGroup(choices=list(piercing_tags.keys()), label="Piercing Tags")
                    selected_expression_tags = gr.CheckboxGroup(choices=list(expression_tags.keys()), label="Expression Tags")
                    selected_eye_tags = gr.CheckboxGroup(choices=list(eye_tags.keys()), label="Eye Tags")
                    selected_hair_style_tags = gr.CheckboxGroup(choices=list(hair_style_tags.keys()), label="Hair Style Tags")
                    selected_position_tags = gr.CheckboxGroup(choices=list(position_tags.keys()), label="Position Tags")
                    selected_fetish_tags = gr.CheckboxGroup(choices=list(fetish_tags.keys()), label="Fetish Tags")
                    selected_location_tags = gr.CheckboxGroup(choices=list(location_tags.keys()), label="Location Tags")
                    selected_camera_tags = gr.CheckboxGroup(choices=list(camera_tags.keys()), label="Camera Tags")
                    selected_atmosphere_tags = gr.CheckboxGroup(choices=list(atmosphere_tags.keys()), label="Atmosphere Tags")
                tag_tab.select(lambda: "Tag Selection", inputs=None, outputs=active_tab)

            # Model version buttons
            link_button_v7 = gr.Button("Use Model V7")
            link_button_v8 = gr.Button("Use Model V8")
            link_button_v11 = gr.Button("Use Model V11")

    # Link button actions
    link_button_v7.click(update_model_version_v7)
    link_button_v8.click(update_model_version_v8)
    link_button_v11.click(update_model_version_v11)

    # The order of `inputs` must match infer()'s positional parameters.
    run_button.click(
        infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
                selected_participant_tags, selected_tribe_tags, selected_skin_tone_tags, selected_body_type_tags,
                selected_tattoo_tags, selected_piercing_tags, selected_expression_tags, selected_eye_tags,
                selected_hair_style_tags, selected_position_tags, selected_fetish_tags, selected_location_tags,
                selected_camera_tags, selected_atmosphere_tags, active_tab],
        outputs=[result, seed, prompt_info]
    )
# Enable Gradio's request queue for the `demo` app.
# NOTE(review): no `demo.launch()` call is visible in this file; confirm the app
# is launched elsewhere (e.g. later in the file or by the hosting runtime).
demo.queue()