# Imports
import gradio as gr
import subprocess
import random
import spaces
import torch
import uuid
import os
from diffusers import StableDiffusionXLPipeline, ControlNetModel
from diffusers.models import AutoencoderKL
# Pre-Initialize
# Install flash-attn at runtime, skipping the CUDA build so startup does not stall compiling kernels.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},  # extend, don't replace, the environment so pip keeps PATH
    shell=True,
)
DEVICE = "auto"
if DEVICE == "auto":
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")
# Variables
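# 2**53 - 1, the largest integer the JavaScript frontend can represent exactly, so seeds survive the round trip.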
MAX_SEED = 9007199254740991
DEFAULT_INPUT = ""
DEFAULT_NEGATIVE_INPUT = "deformed, distorted, disfigured, disconnected, disgusting, mutation, mutated, blur, blurry, scribble, abstract, ugly, amputation, limb, limbs, leg, legs, foot, feet, toe, toes, arm, arms, hand, hands, finger, fingers, head, heads, exposed, porn, nude, nudity, naked, nsfw, NSFW"
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
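# sdxl-flash is a few-step SDXL variant, meant to run with low step counts and low guidance.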
REPO = "sd-community/sdxl-flash"
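# Patched SDXL VAE that avoids NaN/black-image artifacts when decoding in float16.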
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained("MakiPan/controlnet-encoded-hands-130k", torch_dtype=torch.float16)
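# NOTE: StableDiffusionXLPipeline has no ControlNet component, so the `controlnet` kwarg below is
# dropped with a warning by diffusers. Actually applying the hands ControlNet would require
# StableDiffusionXLControlNetPipeline and a conditioning image at call time.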
model = StableDiffusionXLPipeline.from_pretrained(REPO, vae=vae, controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False)
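# Blend in a DALL-E 3 style LoRA at 70% strength.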
model.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="base")
model.set_adapters(["base"], adapter_weights=[0.7])
model.to(DEVICE)
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''
# Functions
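# Save a generated image under a unique, seed-prefixed filename and return its path.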
def save_image(img, seed):
name = f"{seed}-{uuid.uuid4()}.png"
img.save(name)
return name
# Use the given seed if it is a digit string; otherwise fall back to a random seed.
def get_seed(seed):
    seed = (seed or "").strip()
    if seed.isdigit():
        return int(seed)
    else:
        return random.randint(0, MAX_SEED)
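# Request a ZeroGPU slot for up to 30 seconds per call.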
@spaces.GPU(duration=30)
def generate(input=DEFAULT_INPUT, negative_input=DEFAULT_NEGATIVE_INPUT, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None):
seed = get_seed(seed)
print(input, negative_input, height, width, steps, guidance, number, seed)
model.to(DEVICE)
parameters = {
"prompt": input,
"negative_prompt": negative_input,
"height": height,
"width": width,
"num_inference_steps": steps,
"guidance_scale": guidance,
"num_images_per_prompt": number,
"controlnet_conditioning_scale": 1,
"cross_attention_kwargs": {"scale": 1},
"generator": torch.Generator().manual_seed(seed),
"use_resolution_binning": True,
"output_type":"pil",
}
images = model(**parameters).images
image_paths = [save_image(img, seed) for img in images]
print(image_paths)
return image_paths
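# No-op endpoint used as a keep-alive ping for the Space.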
def cloud():
print("[CLOUD] | Space maintained.")
# Initialize
with gr.Blocks(css=css) as main:
with gr.Column():
        gr.Markdown("🪄 Generate high-quality images in any style.")
with gr.Column():
input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        negative_input = gr.Textbox(lines=1, value=DEFAULT_NEGATIVE_INPUT, label="Negative Input")
        height = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_HEIGHT, label="Height")  # SDXL needs sizes divisible by 8
        width = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_WIDTH, label="Width")  # SDXL needs sizes divisible by 8
        steps = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Steps")  # 0 steps is invalid, so start at 1
        guidance = gr.Slider(minimum=0, maximum=100, step=0.001, value=3, label="Guidance")
number = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number")
seed = gr.Textbox(lines=1, value="", label="Seed (Blank for random)")
submit = gr.Button("▶")
maintain = gr.Button("☁️")
with gr.Column():
images = gr.Gallery(columns=1, label="Image")
submit.click(generate, inputs=[input, negative_input, height, width, steps, guidance, number, seed], outputs=[images], queue=False)
maintain.click(cloud, inputs=[], outputs=[], queue=False)
main.launch(show_api=True)