import gradio as gr
import spaces
from clip_slider_pipeline import CLIPSliderFlux
from diffusers import FluxPipeline, AutoencoderTiny
import torch
from PIL import Image
from diffusers.utils import export_to_gif
import random

# load pipelines
base_model = "black-forest-labs/FLUX.1-schnell"

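# taef1 is madebyollin's Tiny AutoEncoder for FLUX.1, a small drop-in VAE for much faster latent decoding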
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxPipeline.from_pretrained(base_model,
                                    vae=taef1,
                                    torch_dtype=torch.bfloat16)

pipe.transformer.to(memory_format=torch.channels_last)
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
# pipe.enable_model_cpu_offload()
clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))


MAX_SEED = 2**32-1

def convert_to_centered_scale(num):
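    """Return a zero-centered integer range of length num, e.g. 5 -> (-2, -1, 0, 1, 2)."""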
    if num <= 0:
        raise ValueError("Input must be a positive integer")
    
    if num % 2 == 0:  # even
        start = -(num // 2 - 1)
        end = num // 2
    else:  # odd
        start = -(num // 2)
        end = num // 2
    
    return tuple(range(start, end + 1))

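# on Hugging Face ZeroGPU Spaces, this decorator requests a GPU for up to `duration` seconds per call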
@spaces.GPU(duration=200)
def generate(concept_1, concept_2, scale, prompt, randomize_seed=True, seed=42, recalc_directions=True, iterations=200, steps=4, interm_steps=9, guidance_scale=3.5,
             x_concept_1="", x_concept_2="", 
             avg_diff_x=None, 
             total_images=[],
             progress=gr.Progress(track_tqdm=True)
             ):
    slider_x = [concept_2, concept_1]
    # check if the avg-diff latent directions need to be re-calculated
    print("slider_x", slider_x)
    print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if sorted(slider_x) != sorted([x_concept_1, x_concept_2]) or recalc_directions:
        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
        x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
    else:
        # same concepts as the previous run: reuse the direction cached in state (stored on CPU between calls)
        avg_diff = avg_diff_x.to("cuda")

    images = []
    high_scale = scale
    low_scale = -1 * scale
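    # sweep the edit scale linearly from -scale to +scale across the intermediate frames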
    for i in range(interm_steps):
        cur_scale = low_scale + (high_scale - low_scale) * i / (interm_steps - 1)
        image = clip_slider.generate(prompt,
                                     width=768,
                                     height=768,
                                     guidance_scale=guidance_scale,
                                     scale=cur_scale,
                                     seed=seed,
                                     num_inference_steps=steps,
                                     avg_diff=avg_diff)
        images.append(image)
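    # paste 256x256 thumbnails side by side into a single horizontal strip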
    canvas = Image.new('RGB', (256*interm_steps, 256))
    for i, im in enumerate(images):
        canvas.paste(im.resize((256,256)), (256 * i, 0))

    comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"

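    # relabel the post-generation slider with a centered integer scale, so value 0 is the middle (unedited) frame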
    scale_total = convert_to_centered_scale(interm_steps)
    scale_min = scale_total[0]
    scale_max = scale_total[-1]
    scale_middle = scale_total.index(0)
    post_generation_slider_update = gr.update(label=comma_concepts_x, value=0, minimum=scale_min, maximum=scale_max, interactive=True)
    avg_diff_x = avg_diff.cpu()
    
    return x_concept_1, x_concept_2, avg_diff_x, export_to_gif(images, "clip.gif", fps=5), canvas, images, images[scale_middle], post_generation_slider_update, seed

def update_pre_generated_images(slider_value, total_images):
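    """Return the pre-generated image matching the slider position, or None if nothing has been generated yet."""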
    number_images = len(total_images)
    if number_images > 0:
        scale_tuple = convert_to_centered_scale(number_images)
        return total_images[scale_tuple.index(slider_value)]
    else:
        return None
    
def reset_recalc_directions():
    return True


intro = """
<div style="display: flex;align-items: center;justify-content: center">
    <img src="https://huggingface.co/spaces/LatentNavigation/latentnavigation-flux/resolve/main/Group 4-16.png" width="120" style="display: inline-block">
    <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block;font-size:1.1em">Latent Navigation</h1>
</div>
<div style="display: flex;align-items: center;justify-content: center">
    <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Exploring CLIP text space with FLUX.1 schnell 🪐</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    <a href="https://github.com/linoytsaban/semantic-sliders" target="_blank">code</a>
     | 
    <a href="https://huggingface.co/spaces/LatentNavigation/latentnavigation-flux?duplicate=true" target="_blank" style="
        display: inline-block;
    ">
    <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
css='''
#strip, #gif{min-height: 50px}
'''
examples = [["winter", "summer", 1.25, "a dog in the park"], ["USA suburb", "Europe", 2, "a house"], ["rotten", "super fresh", 2, "a tomato"]]
image_seq = gr.Image(label="Strip", elem_id="strip")
output_image = gr.Image(label="Gif", elem_id="gif")
post_generation_image = gr.Image(label="Generated Images")
post_generation_slider = gr.Slider(minimum=-2, maximum=2, value=0, step=1, interactive=False)
seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, label="Seed", interactive=True, randomize=True)

with gr.Blocks(css=css) as demo:

    gr.HTML(intro)
    
    x_concept_1 = gr.State("")
    x_concept_2 = gr.State("")
    total_images = gr.State([])

    avg_diff_x = gr.State()

    recalc_directions = gr.State(False)
    
    with gr.Row():
        with gr.Column():
            with gr.Row():
                concept_1 = gr.Textbox(label="1st direction to steer", placeholder="winter")
                concept_2 = gr.Textbox(label="2nd direction to steer", placeholder="summer")
            prompt = gr.Textbox(label="Prompt", info="Describe what you want to be steered by the directions", placeholder="A dog in the park")
            x = gr.Slider(minimum=0, value=1.5, step=0.1, maximum=4.0, label="Strength", info="maximum strength in each direction (unstable beyond 2.5)")
            submit = gr.Button("Generate directions")
            gr.Examples(
                examples=examples,
                inputs=[concept_1, concept_2, x, prompt],
                fn=generate,
                outputs=[x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq, total_images, post_generation_image, post_generation_slider, seed],
                cache_examples="lazy"
            )
        with gr.Column():
            with gr.Group(elem_id="group"):
                post_generation_image.render()
                post_generation_slider.render()
            with gr.Row():
                with gr.Column(scale=4, min_width=50):
                    image_seq.render()
                    
                with gr.Column(scale=2, min_width=50):
                    output_image.render()
    
    with gr.Accordion(label="Advanced options", open=False):
        iterations = gr.Slider(label="Num iterations for CLIP directions", minimum=0, value=200, maximum=500, step=1)
        steps = gr.Slider(label="Num inference steps", minimum=1, value=3, maximum=8, step=1)
        interm_steps = gr.Slider(label="Number of intermediate images", minimum=3, value=21, maximum=65, step=2)
        guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.1,
                maximum=10.0,
                step=0.1,
                value=3.5,
            )
        randomize_seed = gr.Checkbox(True, label="Randomize seed")
        seed.render()
     
    submit.click(fn=generate,
                 inputs=[concept_1, concept_2, x, prompt, randomize_seed, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale, x_concept_1, x_concept_2, avg_diff_x, total_images],
                 outputs=[x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq, total_images, post_generation_image, post_generation_slider, seed])

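    # changing the iteration count or seed invalidates the cached direction, forcing a recompute on the next run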
    iterations.change(fn=reset_recalc_directions, outputs=[recalc_directions])
    seed.change(fn=reset_recalc_directions, outputs=[recalc_directions])
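    # scrubbing the slider only swaps in an already-generated frame, so it bypasses the queue entirely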
    post_generation_slider.change(fn=update_pre_generated_images, inputs=[post_generation_slider, total_images], outputs=[post_generation_image], queue=False, show_progress="hidden", concurrency_limit=None)
        
if __name__ == "__main__":
    demo.launch()