File size: 3,761 Bytes
24a6868 52abc8d 24a6868 6095378 7de4c3e c7f68b4 0563fad 24a6868 e9947e2 24a6868 0563fad e9947e2 24a6868 0563fad 24a6868 d62c968 5a7e564 24a6868 c6a4957 e9947e2 c7f68b4 0563fad 24a6868 0563fad 24a6868 3eeb179 c7f68b4 ffaf784 24a6868 3389734 7de4c3e 6095378 e9947e2 6095378 3389734 24a6868 6095378 e7066e9 24a6868 e9947e2 3389734 24a6868 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
# Imports
import gradio as gr
import random
import spaces
import torch
import numpy
import uuid
import json
import os
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
# Pre-Initialize
# Compute device. Leave as "auto" to pick CUDA when torch reports it available
# (CPU otherwise); any other value is used as-is.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")
# Variables
# Upper bound for randomly drawn seeds: 2**53 - 1, the largest integer a
# JavaScript number holds exactly (presumably chosen so seeds survive the
# browser round-trip unchanged).
MAX_SEED = 2**53 - 1

# Defaults shown in the UI and used by generate().
DEFAULT_INPUT = ""
# NOTE(review): the "(term:weight)" emphasis syntax is an A1111/Compel
# convention; the plain diffusers pipeline appears to pass it through as
# literal text — confirm whether weighting is actually applied.
DEFAULT_NEGATIVE_INPUT = "(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW"
DEFAULT_HEIGHT = DEFAULT_WIDTH = 1024

# Hugging Face Hub sources: base SDXL pipeline, then the LoRA repo, the LoRA
# weight file inside it, and the adapter name it is registered under.
REPO = "sd-community/sdxl-flash"
REPO_WEIGHT = "ehristoforu/dalle-3-xl-v2"
WEIGHT = "dalle-3-xl-lora-v2.safetensors"
ADAPTER = "dalle"
# Build the SDXL pipeline: load the base model from REPO, swap its scheduler
# for Euler-Ancestral (reusing the existing scheduler config), attach the LoRA
# from REPO_WEIGHT/WEIGHT as adapter ADAPTER at weight 0.7, then move
# everything to DEVICE.
# float16 is only used on CUDA: DEVICE may fall back to "cpu", where many
# float16 kernels are missing or extremely slow, so use float32 there.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
model = StableDiffusionXLPipeline.from_pretrained(REPO, torch_dtype=DTYPE, use_safetensors=True, add_watermarker=False)
model.scheduler = EulerAncestralDiscreteScheduler.from_config(model.scheduler.config)
model.load_lora_weights(REPO_WEIGHT, weight_name=WEIGHT, adapter_name=ADAPTER)
model.set_adapters(ADAPTER, adapter_weights=[0.7])
model.to(DEVICE)
# Page styling passed to gr.Blocks: narrow the container to 560px, centre the
# <h1>, and hide the footer element.
css = """
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
visibility: hidden
}
"""
# Functions
def save_image(img, seed):
    """Save ``img`` in the working directory under a unique ``.png`` name.

    Args:
        img: Object with a ``save(path)`` method (the PIL images produced by
            ``generate``).
        seed: Value prefixed to the file name so outputs can be traced back
            to the seed that produced them.

    Returns:
        The file name that was written, ``"<seed>-<uuid4>.png"``.
    """
    filename = "{}-{}.png".format(seed, uuid.uuid4())
    img.save(filename)
    return filename
def get_seed(seed):
    """Turn a user-supplied seed into an int, or draw a random one.

    Args:
        seed: Seed text from the UI textbox. ``None`` (the default of
            ``generate``'s ``seed`` parameter) and ints are also accepted.

    Returns:
        ``int(seed)`` when the stripped text consists only of decimal digits;
        otherwise (blank, ``None``, signs, other text) a random integer in
        ``[0, MAX_SEED]``.
    """
    if seed is None:
        return random.randint(0, MAX_SEED)
    text = str(seed).strip()
    # isdecimal(), not isdigit(): isdigit() also accepts characters such as
    # superscripts ("²") that int() then rejects with ValueError.
    if text.isdecimal():
        return int(text)
    return random.randint(0, MAX_SEED)
def _snap_to_multiple_of_8(value):
    """Round a pixel dimension down to a multiple of 8, never below 8."""
    value = int(value)
    return max(8, value - value % 8)


@spaces.GPU(duration=30)
def generate(input=DEFAULT_INPUT, negative_input=DEFAULT_NEGATIVE_INPUT, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None):
    """Run the SDXL pipeline and save the resulting images as PNG files.

    Args:
        input: Positive prompt text.
        negative_input: Negative prompt text.
        height, width: Output size in pixels. SDXL rejects sizes that are not
            divisible by 8, so both are rounded down to a multiple of 8
            (minimum 8).
        steps: Number of denoising steps (clamped to at least 1).
        guidance: Passed to the pipeline as ``guidance_scale``.
        number: Images to generate for the prompt (clamped to at least 1).
        seed: Seed text; ``None``, blank or non-numeric means random
            (see ``get_seed``).

    Returns:
        List of file names written by ``save_image``, one per image.
    """
    # str() so None/int seeds work whatever get_seed accepts.
    seed = get_seed("" if seed is None else str(seed))
    # Gradio sliders may deliver floats; the pipeline needs ints.
    height = _snap_to_multiple_of_8(height)
    width = _snap_to_multiple_of_8(width)
    steps = max(1, int(steps))
    number = max(1, int(number))
    print(input, negative_input, height, width, steps, guidance, number, seed)
    model.to(DEVICE)
    parameters = {
        "prompt": input,
        "negative_prompt": negative_input,
        "height": height,
        "width": width,
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "num_images_per_prompt": number,
        # NOTE(review): in recent diffusers this "scale" multiplies the LoRA
        # adapter weight set at load time (0.7), leaving the LoRA at ~1%
        # strength — confirm that is intended.
        "cross_attention_kwargs": {"scale": 0.01},
        # CPU generator so a given seed reproduces regardless of device.
        "generator": torch.Generator().manual_seed(seed),
        # "use_resolution_binning" was dropped: StableDiffusionXLPipeline has
        # no such parameter and silently swallowed it via **kwargs.
        "output_type": "pil",
    }
    images = model(**parameters).images
    image_paths = [save_image(img, seed) for img in images]
    print(image_paths)
    return image_paths
def cloud():
    """Print a keep-alive log line; wired to the "☁️" button in the UI."""
    message = "[CLOUD] | Space maintained."
    print(message)
# Initialize
# Build the Gradio UI: prompt/settings column with the action buttons, then an
# output gallery. Finally, launch the app.
with gr.Blocks(css=css) as main:
    with gr.Column():
        # Named prompt_input rather than `input` to avoid shadowing the builtin.
        prompt_input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        negative_input = gr.Textbox(lines=1, value=DEFAULT_NEGATIVE_INPUT, label="Input Negative")
        # SDXL requires height/width divisible by 8. Stepping by 8 from a
        # multiple-of-8 minimum keeps every selectable value valid
        # (1024 and 2160 are both reachable).
        height = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_HEIGHT, label="Height")
        width = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_WIDTH, label="Width")
        # 0 steps would skip denoising entirely, so start at 1.
        steps = gr.Slider(minimum=1, maximum=100, step=1, value=8, label="Steps")
        guidance = gr.Slider(minimum=0, maximum=100, step=0.001, value=3, label="Guidance")
        number = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number")
        seed = gr.Textbox(lines=1, value="", label="Seed (Blank for random)")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")
    with gr.Column():
        images = gr.Gallery(columns=1, label="Image")
    # Input order must match generate's positional parameters.
    submit.click(generate, inputs=[prompt_input, negative_input, height, width, steps, guidance, number, seed], outputs=[images], queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)
main.launch(show_api=True)