# Image / app.py
# Staticaliza's picture
# Update app.py
# 1eea9d2 verified
# raw
# history blame
# 6.24 kB
# Imports
import gradio as gr
import random
import spaces
import torch
import uuid
import os
from transformers import pipeline
from diffusers import StableDiffusionXLPipeline, ControlNetModel
from diffusers.models import AutoencoderKL
from PIL import Image
# Pre-Initialize
# "auto" is resolved once at import time: CUDA when torch reports it, CPU otherwise.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")
# Variables
# Upper bound for randomly drawn seeds (2**53 - 1).
MAX_SEED = 2**53 - 1

# Prompt defaults.
DEFAULT_INPUT = ""
DEFAULT_NEGATIVE_INPUT = (
    "(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, "
    "wrong anatomy, extra limb, missing limb, floating limbs, "
    "(mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, "
    "ugly, disgusting, blurry, amputation, "
    "(exposed, explicit, porn, nude, nudity, naked, nsfw:1.25)"
)

# Key into repo_customs selected when no model is given.
DEFAULT_MODEL = "Default"

# Default output size in pixels.
DEFAULT_HEIGHT = DEFAULT_WIDTH = 1024

# Stylesheet handed to gr.Blocks.
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''
# Image classifier applied to generated images to score NSFW content.
repo_nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

# Components passed to every SDXL pipeline loaded below.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained("MakiPan/controlnet-encoded-hands-130k", torch_dtype=torch.float16)


def _load_sdxl(checkpoint):
    """Load an SDXL checkpoint in fp16 with the shared vae/controlnet arguments."""
    # NOTE(review): `controlnet` may not be a component StableDiffusionXLPipeline
    # accepts, in which case it is ignored at load time — confirm.
    return StableDiffusionXLPipeline.from_pretrained(
        checkpoint,
        vae=vae,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
    )


# "Default" style: base checkpoint plus a single LoRA adapter.
repo_default = _load_sdxl("fluently/Fluently-XL-Final")
repo_default.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="base")
repo_default.set_adapters("base")
# repo_default.set_adapters(["base"], adapter_weights=[0.7])

# "Pixel" style: base checkpoint plus two LoRA adapters, both at weight 1.
repo_pixel = _load_sdxl("sd-community/sdxl-flash")
repo_pixel.load_lora_weights("artificialguybr/PixelArtRedmond", adapter_name="base")
repo_pixel.load_lora_weights("nerijs/pixel-art-xl", adapter_name="base2")
repo_pixel.set_adapters(["base", "base2"], adapter_weights=[1, 1])

# Model name (as shown in the UI dropdown) -> loaded pipeline.
repo_customs = {
    "Default": repo_default,
    "Realistic": _load_sdxl("SG161222/RealVisXL_V4.0"),
    "Anime": _load_sdxl("cagliostrolab/animagine-xl-3.1"),
    "Pixel": repo_pixel,
}
# Functions
def save_image(img, seed):
    """Save *img* as a PNG in the working directory and return its file name.

    The name is "<seed>-<random uuid4>.png", so repeated calls with the same
    seed produce distinct files.
    """
    unique_part = uuid.uuid4()
    filename = "{}-{}.png".format(seed, unique_part)
    img.save(filename)
    return filename
def get_seed(seed):
    """Return an integer seed parsed from *seed*, or a random one.

    Args:
        seed: Usually the text of the seed box. ``None`` and non-string values
            are accepted too; non-strings are converted with ``str()`` first.

    Returns:
        ``int(seed)`` when the stripped text consists only of decimal digits,
        otherwise a random integer in ``[0, MAX_SEED]``.
    """
    text = "" if seed is None else str(seed).strip()
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as "²" that int() cannot parse, which would raise ValueError.
    if text.isdecimal():
        return int(text)
    return random.randint(0, MAX_SEED)
# Fallback (steps, guidance) per model, used when the caller leaves a value unset.
_SAMPLING_PRESETS = {
    "Realistic": (35, 6),
    "Anime": (35, 6),
    "Pixel": (15, 1.5),
}
_FALLBACK_PRESET = (25, 6)


def _resolve_sampling(model, steps, guidance):
    """Return (steps, guidance), replacing unset (falsy or negative) values with the model preset."""
    preset_steps, preset_guidance = _SAMPLING_PRESETS.get(model, _FALLBACK_PRESET)
    if not steps or steps < 0:
        steps = preset_steps
    if not guidance or guidance < 0:
        guidance = preset_guidance
    return steps, guidance


def _join_negative_prompt(filter_input, negative_input):
    """Prepend the user's filter terms to the negative prompt, comma-separated."""
    extra = filter_input.strip().rstrip(",").rstrip()
    if not extra:
        return negative_input
    return f"{extra}, {negative_input}"


def _snap_dimension(value):
    """Round a pixel dimension down to a multiple of 8 (minimum 8); pass None through."""
    if value is None:
        return None
    return max(8, int(value) // 8 * 8)


@spaces.GPU(duration=60)
def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATIVE_INPUT, model=DEFAULT_MODEL, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None):
    """Generate images for a prompt and NSFW-score the first one.

    Args:
        input: Prompt text.
        filter_input: Extra negative terms; prepended to ``negative_input``.
        negative_input: Negative prompt; falls back to DEFAULT_NEGATIVE_INPUT when empty.
        model: Key into ``repo_customs``; a falsy value selects "Default".
        height: Output height in pixels; rounded down to a multiple of 8.
        width: Output width in pixels; rounded down to a multiple of 8.
        steps: Inference steps; falsy or negative selects the model preset.
        guidance: Guidance scale; falsy or negative selects the model preset.
        number: Images per prompt.
        seed: Seed text; anything that is not a plain number yields a random seed.

    Returns:
        ``(image_paths, scores)``: the saved PNG file names, and a dict mapping
        each classifier label to its score (rounded to 3 decimals) for the
        first image only.

    Raises:
        KeyError: If ``model`` is not a key of ``repo_customs``.
    """
    repo = repo_customs[model or "Default"]
    filter_input = filter_input or ""
    negative_input = negative_input or DEFAULT_NEGATIVE_INPUT
    # Normalise to text first so the default seed=None reaches get_seed safely.
    seed = get_seed("" if seed is None else str(seed))
    print(input, filter_input, negative_input, model, height, width, steps, guidance, number, seed)

    steps, guidance = _resolve_sampling(model, steps, guidance)
    print(steps, guidance)

    repo.to(DEVICE)

    # NOTE(review): controlnet_conditioning_scale and use_resolution_binning are
    # not documented StableDiffusionXLPipeline call arguments — confirm they
    # have any effect.
    parameters = {
        "prompt": input,
        "negative_prompt": _join_negative_prompt(filter_input, negative_input),
        # The SDXL pipeline rejects sizes that are not divisible by 8.
        "height": _snap_dimension(height),
        "width": _snap_dimension(width),
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "num_images_per_prompt": number,
        "controlnet_conditioning_scale": 1,
        "cross_attention_kwargs": {"scale": 1},
        "generator": torch.Generator().manual_seed(seed),
        "use_resolution_binning": True,
        "output_type": "pil",
    }

    images = repo(**parameters).images
    image_paths = [save_image(img, seed) for img in images]
    print(image_paths)

    # Only the first image is classified; close the file handle once scored.
    with Image.open(image_paths[0]) as first_image:
        nsfw_prediction = repo_nsfw_classifier(first_image)
    print(nsfw_prediction)

    scores = {item['label']: round(item['score'], 3) for item in nsfw_prediction}
    return image_paths, scores
def cloud():
    """Log a keep-alive message; wired to the maintenance button."""
    message = "[CLOUD] | Space maintained."
    print(message)
# Initialize
# Build the UI: an intro column, an input column and an output column, then
# wire the buttons to generate() and cloud() and start the app.
with gr.Blocks(css=css) as main:
    with gr.Column():
        gr.Markdown("🪄 Generate high quality images on all styles between 10 to 20 seconds.")

    with gr.Column():
        input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        filter_input = gr.Textbox(lines=1, value="", label="Input Filter")
        negative_input = gr.Textbox(lines=1, value=DEFAULT_NEGATIVE_INPUT, label="Input Negative")
        # Pass a real list of names rather than a dict view.
        model = gr.Dropdown(label="Models", choices=list(repo_customs), value=DEFAULT_MODEL)
        # step=8: the SDXL pipeline rejects sizes that are not divisible by 8.
        height = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_HEIGHT, label="Height")
        width = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_WIDTH, label="Width")
        # -1 means "use the preset for the selected model" (resolved in generate()).
        steps = gr.Slider(minimum=-1, maximum=100, step=1, value=-1, label="Steps")
        guidance = gr.Slider(minimum=-1, maximum=100, step=0.001, value=-1, label="Guidance")
        number = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number")
        seed = gr.Textbox(lines=1, value="", label="Seed (Blank for random)")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    with gr.Column():
        images = gr.Gallery(columns=1, label="Image")
        nsfw_classifier = gr.Label()

    submit.click(
        generate,
        inputs=[input, filter_input, negative_input, model, height, width, steps, guidance, number, seed],
        outputs=[images, nsfw_classifier],
        queue=False,
    )
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)