import gradio as gr
from diffusers import StableDiffusionXLPipeline
import numpy as np
import math
import spaces
import torch
import sys
import random

from gradio_imageslider import ImageSlider

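# Gradio theme using the Libre Franklin / Public Sans fonts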
theme = gr.themes.Base(
    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
)

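# Load SDXL base with the community Perturbed-Attention Guidance (PAG) pipeline in fp16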
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="multimodalart/sdxl_perturbed_attention_guidance",
    torch_dtype=torch.float16
)

device = "cuda"
pipe = pipe.to(device)

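# Run the same prompt and seed twice, with and without PAG, so the slider can compare the two results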
@spaces.GPU
def run(prompt, negative_prompt="", guidance_scale=7.0, pag_scale=3.0, pag_layers=["mid"], randomize_seed=True, seed=42, progress=gr.Progress(track_tqdm=True)):
    prompt = prompt.strip()
    negative_prompt = negative_prompt.strip()
    if randomize_seed:
        seed = random.randint(0, sys.maxsize)
    # With no prompt and no negative prompt, drop CFG entirely for unconditional generation
    if prompt == "" and negative_prompt == "":
        guidance_scale = 0.0

    # Re-seed the generator for each run so the only difference between the two images is PAG
    generator = torch.Generator(device=device).manual_seed(seed)
    image_pag = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=guidance_scale, pag_scale=pag_scale, pag_applied_layers=pag_layers, generator=generator, num_inference_steps=25).images[0]

    generator = torch.Generator(device=device).manual_seed(seed)
    image_normal = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=25).images[0]
    return (image_pag, image_normal), seed

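# Keep the demo centered and no wider than 768px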
css = '''
.gradio-container {
    max-width: 768px !important;
    margin: 0 auto;
}
'''

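# UI: prompt box, generate button, before/after image slider, and advanced settings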
with gr.Blocks(css=css, theme=theme) as demo:
    gr.Markdown('''# Perturbed-Attention Guidance SDXL
    SDXL 🧨 [diffusers implementation](https://huggingface.co/multimodalart/sdxl_perturbed_attention_guidance) of [Perturbed-Attention Guidance](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
    ''')
    with gr.Group():
      with gr.Row():
        prompt = gr.Textbox(show_label=False, scale=4, placeholder="Your prompt", info="Leave blank to test unconditional generation")
        button = gr.Button("Generate", min_width=120)
      output = ImageSlider(label="Left: PAG, Right: No PAG", interactive=False)
      with gr.Accordion("Advanced Settings", open=False):
        guidance_scale = gr.Number(label="Guidance Scale", value=7.0)
        negative_prompt = gr.Textbox(label="Negative prompt", info="Only applied to the CFG part; leave blank (along with the prompt) for unconditional generation")
        pag_scale = gr.Number(label="PAG Scale", value=3.0)
        pag_layers = gr.Dropdown(label="Model layers to apply PAG to", info="mid is the one used in the paper; up and down blocks seem unstable", choices=["up", "mid", "down"], multiselect=True, value="mid")
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        seed = gr.Slider(label="Seed", minimum=0, maximum=sys.maxsize, step=1, randomize=True)
    gr.Examples(fn=run, examples=[" ", "an insect robot preparing a delicious meal, anime style", "a photo of a group of friends at an amusement park"], inputs=prompt, outputs=[output, seed], cache_examples=True)
    gr.on(
        triggers=[
            button.click,
            prompt.submit
        ],
        fn=run,
        inputs=[prompt, negative_prompt, guidance_scale, pag_scale, pag_layers, randomize_seed, seed],
        outputs=[output, seed],
    )
if __name__ == "__main__":
    demo.launch(share=True)