File size: 3,911 Bytes
5c4b5eb
6b566c6
5c4b5eb
 
 
 
 
 
 
 
aeccac3
 
 
 
 
 
 
 
 
5c4b5eb
 
 
 
cedad44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeccac3
 
 
 
5c4b5eb
 
e07cbb0
 
 
 
8bfad22
 
 
 
 
5c4b5eb
dda75dc
cedad44
 
 
5c4b5eb
 
 
cedad44
5c4b5eb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import time
import spaces
import gradio as gr
import torch
import diffusers
from utils import patch_attention_proc
import math
import numpy as np
from PIL import Image

# Globals
# CSS injected into the Gradio Blocks UI: centers the top-level Markdown title.
css = """
h1 {
  text-align: center;
  display: block;
}
"""

# Pipeline
# Load DreamShaper (an SD 1.5 fine-tune) in fp16 on CUDA. Module-level side
# effect: model weights are downloaded/loaded at import time.
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
# Swap the default scheduler for Euler (discrete), keeping the original config.
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Disable the NSFW safety checker so outputs are never blacked out in the demo.
pipe.safety_checker = None

@spaces.GPU
def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Render the same prompt twice — once unpatched, once with token merging —
    and report the wall-clock time of each pass.

    Args:
        prompt: Text prompt for the diffusion pipeline.
        seed: Integer seed; applied via ``torch.manual_seed`` before each pass
            so both images share the same initial noise.
        steps: Number of denoising steps.
        height_width: Square output resolution (1024, 1536, or 2048).
        negative_prompt: Negative text prompt.
        guidance_scale: Classifier-free guidance scale.
        method: ``"todo"`` (token downsampling) or ``"tome"`` (token merging).

    Returns:
        Tuple of (baseline PIL image, merged PIL image, timing summary string).
    """
    # ToDo downsamples keys/values; ToMe merges all tokens by similarity.
    merge_method = "downsample" if method == "todo" else "similarity"
    merge_tokens = "keys/values" if method == "todo" else "all"

    # Per-resolution merging hyperparameters:
    # resolution -> (downsample_factor, ratio, downsample_factor_level_2, ratio_level_2)
    resolution_params = {
        1024: (2, 0.75, 1, 0.0),
        1536: (3, 0.89, 1, 0.0),
        2048: (4, 0.9375, 2, 0.75),
    }
    # Fall back to the 1024 settings for any unexpected resolution. The
    # original if/elif chain left the level-2 values unbound in that case,
    # which raised NameError when building token_merge_args below.
    (downsample_factor, ratio,
     downsample_factor_level_2, ratio_level_2) = resolution_params.get(
        height_width, resolution_params[1024])

    token_merge_args = {"ratio": ratio,
                        "merge_tokens": merge_tokens,
                        "merge_method": merge_method,
                        "downsample_method": "nearest",
                        "downsample_factor": downsample_factor,
                        "timestep_threshold_switch": 0.0,
                        "timestep_threshold_stop": 0.0,
                        "downsample_factor_level_2": downsample_factor_level_2,
                        "ratio_level_2": ratio_level_2
                        }

    # Baseline pass: same seed as the merged pass so the images are comparable.
    torch.manual_seed(seed)
    start_time_base = time.time()
    base_img = pipe(prompt,
                    num_inference_steps=steps, height=height_width, width=height_width,
                    negative_prompt=negative_prompt,
                    guidance_scale=guidance_scale).images[0]
    end_time_base = time.time()

    # NOTE(review): the patch is applied in place and never reverted, so the
    # "baseline" pass of every call after the first runs on an already-patched
    # UNet. Reverting would need an unpatch hook from utils — confirm upstream.
    patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)

    torch.manual_seed(seed)
    start_time_merge = time.time()
    merged_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
    end_time_merge = time.time()

    result = f"Baseline image: {end_time_base-start_time_base:.2f} sec  |  {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"

    return base_img, merged_img, result



# UI: prompt inputs on top, method/resolution and sampler settings in rows,
# then the timing readout and the side-by-side baseline/merged images.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images")
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")

    with gr.Row():
        method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)")
        height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)")

    with gr.Row():
        guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
        steps = gr.Number(label="steps", value=20, precision=0)
        seed = gr.Number(label="seed", value=1, precision=0)

    result = gr.Textbox(label="Result")
    with gr.Row():
        # Plain string labels: the originals were f-strings with no placeholders.
        base_image = gr.Image(label="baseline_image", type="pil", interactive=False)
        output_image = gr.Image(label="output_image", type="pil", interactive=False)

    gen = gr.Button("generate")
    # Wire the button to generate(); input order must match its signature.
    gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                guidance_scale, method], outputs=[base_image, output_image, result])

# share=True exposes a public Gradio link in addition to the local server.
demo.launch(share=True)