File size: 3,761 Bytes
24a6868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52abc8d
24a6868
 
 
 
 
 
 
 
 
 
 
 
 
 
6095378
 
 
 
 
 
 
 
7de4c3e
c7f68b4
 
 
 
0563fad
24a6868
 
 
 
 
 
 
 
e9947e2
24a6868
0563fad
 
e9947e2
24a6868
0563fad
24a6868
d62c968
5a7e564
24a6868
 
 
c6a4957
e9947e2
c7f68b4
0563fad
24a6868
0563fad
24a6868
 
3eeb179
c7f68b4
ffaf784
 
24a6868
3389734
 
 
7de4c3e
6095378
 
 
 
 
 
 
 
e9947e2
6095378
 
3389734
24a6868
6095378
e7066e9
24a6868
e9947e2
3389734
24a6868
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Imports
import gradio as gr
import random
import spaces
import torch
import numpy
import uuid
import json
import os
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from PIL import Image

# Pre-Initialize
# Compute device selection: "auto" resolves to CUDA when torch reports it as
# available and falls back to the CPU otherwise.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

# Variables
# Upper bound (inclusive) for randomly drawn seeds: 9007199254740991.
MAX_SEED = 2**53 - 1

# Default prompt texts shown in the UI and used as function defaults.
DEFAULT_INPUT = ""
DEFAULT_NEGATIVE_INPUT = (
    "(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, "
    "wrong anatomy, extra limb, missing limb, floating limbs, "
    "(mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, "
    "ugly, disgusting, blurry, amputation, NSFW"
)

# Default output resolution in pixels.
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024

# Identifiers handed to the pipeline loader and the LoRA loader below.
REPO = "sd-community/sdxl-flash"
REPO_WEIGHT = "ehristoforu/dalle-3-xl-v2"
WEIGHT = "dalle-3-xl-lora-v2.safetensors"
ADAPTER = "dalle"

# Build the SDXL pipeline once at import time: half-precision weights loaded
# from safetensors, with the watermarker turned off.
# NOTE(review): float16 is used even when DEVICE resolved to "cpu" - confirm
# that the CPU fallback is expected to work with half precision.
model = StableDiffusionXLPipeline.from_pretrained(
    REPO,
    torch_dtype=torch.float16,
    use_safetensors=True,
    add_watermarker=False,
)

# Swap in an Euler-ancestral scheduler built from the pipeline's own
# scheduler configuration.
model.scheduler = EulerAncestralDiscreteScheduler.from_config(
    model.scheduler.config
)

# Attach the LoRA weights under a named adapter and activate it at weight 0.7.
model.load_lora_weights(
    REPO_WEIGHT,
    weight_name=WEIGHT,
    adapter_name=ADAPTER,
)
model.set_adapters(ADAPTER, adapter_weights=[0.7])

model.to(DEVICE)

# Custom stylesheet for the Gradio page: caps the container width, centres
# h1 headings and hides the footer.
css = """
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
    visibility: hidden
}
"""

# Functions
def save_image(img, seed):
    """Write *img* to a uniquely named PNG file and return the file name.

    The name is built from *seed* followed by a random UUID, so repeated calls
    with the same seed never collide. The path is relative, so the file lands
    in the current working directory.

    Args:
        img: Object exposing a ``save(path)`` method (a PIL image in this app).
        seed: Seed value embedded at the start of the file name.

    Returns:
        The file name the image was saved under.
    """
    unique_id = uuid.uuid4()
    filename = f"{seed}-{unique_id}.png"
    img.save(filename)
    return filename
    
def get_seed(seed):
    """Turn user-supplied seed input into an integer seed.

    Args:
        seed: Seed as typed by the user. A string of decimal digits
            (surrounding whitespace is ignored) or a non-negative ``int`` is
            used as-is. Anything else - ``None``, an empty string, negative
            numbers, non-numeric text - requests a random seed.

    Returns:
        The parsed seed, or a random integer in ``[0, MAX_SEED]`` when no
        usable seed was supplied.
    """
    # Accept ready-made integers (e.g. from programmatic callers). bool is a
    # subclass of int, so it is excluded explicitly.
    if isinstance(seed, int) and not isinstance(seed, bool) and seed >= 0:
        return seed

    # None and other non-string values have no .strip(); treat them as blank.
    text = seed.strip() if isinstance(seed, str) else ""

    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits that int() refuses to parse.
    if text.isdecimal():
        return int(text)
    return random.randint(0, MAX_SEED)

@spaces.GPU(duration=30)
def generate(input=DEFAULT_INPUT, negative_input=DEFAULT_NEGATIVE_INPUT, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None):
    """Run the SDXL pipeline and save the resulting images to disk.

    Args:
        input: Prompt text.
        negative_input: Negative prompt text.
        height: Requested image height in pixels; rounded down to a multiple
            of 8 (minimum 8).
        width: Requested image width in pixels; rounded down to a multiple
            of 8 (minimum 8).
        steps: Number of inference steps.
        guidance: Guidance scale.
        number: Number of images to generate for the prompt.
        seed: Seed as text or number; ``None``, blank or non-numeric values
            select a random seed.

    Returns:
        List of file names of the saved PNG images, one per generated image.
    """
    # get_seed() calls .strip() on its argument, so the declared default of
    # None (or a numeric seed) has to be converted to text before the call.
    seed = get_seed("" if seed is None else str(seed))

    # The SDXL pipeline rejects heights/widths that are not divisible by 8;
    # snap down instead of failing the whole request.
    height = max(8, int(height) // 8 * 8)
    width = max(8, int(width) // 8 * 8)
    # Counts must be integers; callers may presumably deliver them as floats.
    steps = int(steps)
    number = int(number)

    print(input, negative_input, height, width, steps, guidance, number, seed)

    model.to(DEVICE)
    parameters = {
        "prompt": input,
        "negative_prompt": negative_input,
        "height": height,
        "width": width,
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "num_images_per_prompt": number,
        # NOTE(review): presumably this scales the LoRA contribution; 0.01
        # looks very low next to the adapter weight of 0.7 - confirm intent.
        "cross_attention_kwargs": {"scale": 0.01},
        # Seeded generator so a given seed reproduces the same images.
        "generator": torch.Generator().manual_seed(seed),
        # NOTE(review): not obviously a StableDiffusionXLPipeline argument;
        # presumably ignored by the pipeline - confirm before relying on it.
        "use_resolution_binning": True,
        "output_type": "pil",
    }

    images = model(**parameters).images
    image_paths = [save_image(img, seed) for img in images]
    print(image_paths)
    return image_paths

def cloud():
    """Print a keep-alive log line; takes no input and returns nothing."""
    message = "[CLOUD] | Space maintained."
    print(message)

# Initialize
# Build the single-page UI: a column of generation controls, a gallery for
# the results, and the event wiring between them.
with gr.Blocks(css=css) as main:
    with gr.Column():
        # Named prompt_input rather than `input` so the builtin input() is
        # not shadowed at module scope.
        prompt_input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        negative_input = gr.Textbox(lines=1, value=DEFAULT_NEGATIVE_INPUT, label="Input Negative")
        # The SDXL pipeline only accepts heights/widths divisible by 8, so
        # the sliders start at 8 and move in steps of 8.
        height = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_HEIGHT, label="Height")
        width = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_WIDTH, label="Width")
        # At least one inference step is needed for any denoising to happen.
        steps = gr.Slider(minimum=1, maximum=100, step=1, value=8, label="Steps")
        guidance = gr.Slider(minimum=0, maximum=100, step=0.001, value=3, label="Guidance")
        number = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number")
        seed = gr.Textbox(lines=1, value="", label="Seed (Blank for random)")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    with gr.Column():
        images = gr.Gallery(columns=1, label="Image")

    # The input order must match the positional parameters of generate().
    submit.click(
        generate,
        inputs=[prompt_input, negative_input, height, width, steps, guidance, number, seed],
        outputs=[images],
        queue=False,
    )
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)