File size: 6,014 Bytes
613d518
 
78b9267
613d518
408db9c
6aaad65
408db9c
7b8a4a9
93714a0
 
 
 
 
613d518
408db9c
 
 
 
8d6f040
ac8d8ef
7b8a4a9
408db9c
 
 
 
 
 
7b8a4a9
408db9c
 
 
 
 
 
 
ec47174
7b8a4a9
a5df98e
 
 
 
408db9c
 
613d518
 
 
 
 
 
 
 
 
 
3e94094
00b6a77
 
 
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
78b9267
7b8a4a9
 
3e94094
 
8d6f040
408db9c
 
 
 
7b8a4a9
 
 
 
 
408db9c
c9cab76
7b8a4a9
 
 
c9cab76
613d518
 
7291f37
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b9267
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16fe1eb
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b9267
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import torch
import random
import spaces
import gradio as gr
from PIL import Image
from models_transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
import gc
import os
from huggingface_hub import login

# Hugging Face access token, read from the TOKEN environment variable.
TOKEN = os.getenv('TOKEN')
# Only authenticate when a token is actually configured: login(None) falls
# back to an interactive prompt, which cannot be answered when this script
# runs unattended.
if TOKEN:
    login(TOKEN)

# Model locations: base SD3.5 checkpoint (Hub id), local IP-Adapter weights,
# and the image encoder (Hub id) handed to init_ipadapter below.
model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"

device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): `dtype` is not used by any code visible in this file; the
# models below are loaded with torch.bfloat16 regardless of this value.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the transformer separately so it can be injected into the pipeline.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
) ## For ZeroGPU no .to("cuda")

# Attach the IP-Adapter weights and image encoder to the pipeline.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path, 
    image_encoder_path=image_encoder_path, 
    nb_token=64, 
)

# NOTE(review): the comment on the pipeline load says "no .to(\"cuda\")" for
# ZeroGPU, yet the pipeline is moved to `device` here — confirm which is
# intended for the deployment target.
pipe.to(device)

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for a generation run.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ignore ``seed`` and draw a fresh one.

    Returns:
        A random integer in ``[0, 2000]`` when ``randomize_seed`` is truthy,
        otherwise ``seed`` unchanged.
    """
    return random.randint(0, 2000) if randomize_seed else seed

@spaces.GPU() ## For ZeroGPU
def create_image(image_pil,
                 prompt,
                 n_prompt,
                 scale, 
                 control_scale, 
                 guidance_scale,
                 num_inference_steps,
                 seed,
                 target="Load only style blocks",
                 ):
    """Generate a 1024x1024 image from ``prompt`` styled after a reference image.

    Args:
        image_pil: Style reference. Either an already-loaded ``PIL.Image.Image``
            or anything ``PIL.Image.open`` accepts (path / file object).
        prompt: Text prompt forwarded to the pipeline.
        n_prompt: Negative prompt. When empty or ``None`` the built-in default
            ``"lowres, low quality, worst quality"`` is used.
        scale: IP-Adapter scale forwarded as ``ipadapter_scale``. It is
            replaced by a per-block dict built from ``control_scale`` when
            ``target`` is one of the three block-selective modes.
        control_scale: Weight placed in the per-block scale dict.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps; ``None`` falls back
            to 24.
        seed: Seed for the torch generator; ``None`` draws a random one via
            ``randomize_seed_fn``.
        target: Style mode label selecting which scale configuration is used.

    Returns:
        The first image of the pipeline output, or ``None`` when no style
        image was supplied.
    """
    if image_pil is None:
        return None

    if target != "Load original IP-Adapter":
        # NOTE(review): these dicts use "up"/"down" block keys and are passed
        # to both set_ip_adapter_scale and ipadapter_scale — confirm the
        # custom SD3 pipeline accepts this layout.
        if target == "Load only style blocks":
            scale = {
                "up": {"block_0": [0.0, control_scale, 0.0]},
            }
        elif target == "Load only layout blocks":
            scale = {
                "down": {"block_2": [0.0, control_scale]},
            }
        elif target == "Load style+layout block":
            scale = {
                "down": {"block_2": [0.0, control_scale]},
                "up": {"block_0": [0.0, control_scale, 0.0]},
            }
        pipe.set_ip_adapter_scale(scale)

    # The UI component delivers a PIL image (type="pil"); Image.open() only
    # accepts paths / file objects, so handle both kinds of input.
    if isinstance(image_pil, Image.Image):
        style_image = image_pil.convert('RGB')
    else:
        with Image.open(image_pil) as opened:
            style_image = opened.convert('RGB')

    # Honour the caller-supplied settings instead of hard-coded values.
    negative_prompt = n_prompt or "lowres, low quality, worst quality"
    steps = 24 if num_inference_steps is None else int(num_inference_steps)
    # The seed has already been (optionally) randomized by the caller, so it
    # is used as-is; only a missing seed is replaced by a random one.
    run_seed = randomize_seed_fn(0, True) if seed is None else int(seed)
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        image = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(generator_device).manual_seed(run_seed),
            clip_image=style_image,
            ipadapter_scale=scale,
        ).images[0]
        return image
    finally:
        # Release GPU memory even when generation fails. Collect Python
        # garbage first so the freed tensors can be returned by empty_cache.
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
    
    

# Static Markdown/HTML snippets rendered in the UI: page title, usage
# instructions, and the citation footer.
title = r"""
<h1 align="center">InstantStyle</h1>
"""

# Step numbers are sequential and the last step names the button exactly as
# it is labelled in the interface ("Generate Image").
description = r"""
How to use:<br>
1. Upload a style image.
2. Set stylization mode, only use style block by default.
3. Enter a text prompt, as done in normal text-to-image models.
4. Click the <b>Generate Image</b> button to begin customization.
"""

article = r"""
---
```bibtex
@article{wang2024instantstyle,
  title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
  author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
  journal={arXiv preprint arXiv:2404.02733},
  year={2024}
}
```
"""

# Build the Gradio interface; the request queue holds at most 10 jobs.
block = gr.Blocks().queue(max_size=10, api_open=True)
with block:
    # Header: title followed by the usage instructions.
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.Row():
            # Left column: all inputs.
            with gr.Column():
                with gr.Row(), gr.Column():
                    image_pil = gr.Image(
                        label="Style Image",
                        type="pil",
                    )

                target = gr.Radio(
                    [
                        "Load only style blocks",
                        "Load only layout blocks",
                        "Load style+layout block",
                        "Load original IP-Adapter",
                    ],
                    value="Load only style blocks",
                    label="Style mode",
                )

                prompt = gr.Textbox(
                    label="Prompt",
                    value="a cat, masterpiece, best quality, high quality",
                )

                scale = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.01,
                    value=1.0,
                    label="Scale",
                )

                # Less frequently changed settings, collapsed by default.
                with gr.Accordion(open=False, label="Advanced Options"):
                    control_scale = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Controlnet conditioning scale",
                    )
                    n_prompt = gr.Textbox(
                        label="Neg Prompt",
                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                    )
                    guidance_scale = gr.Slider(
                        minimum=1,
                        maximum=15.0,
                        step=0.01,
                        value=5.0,
                        label="guidance scale",
                    )
                    num_inference_steps = gr.Slider(
                        minimum=5,
                        maximum=50.0,
                        step=1.0,
                        value=20,
                        label="num inference steps",
                    )
                    seed = gr.Slider(
                        minimum=-1000000,
                        maximum=1000000,
                        value=1,
                        step=1,
                        label="Seed Value",
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                    )

                generate_button = gr.Button("Generate Image")

            # Right column: the generated result.
            with gr.Column():
                generated_image = gr.Image(
                    label="Generated Image",
                    show_label=False,
                )

        # Two-step event chain: first write the (optionally randomized) seed
        # back into the slider, then run generation with the updated inputs.
        _seed_ready = generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        )
        _seed_ready.then(
            fn=create_image,
            inputs=[
                image_pil,
                prompt,
                n_prompt,
                scale,
                control_scale,
                guidance_scale,
                num_inference_steps,
                seed,
                target,
            ],
            outputs=[generated_image],
        )

    # Footer: citation.
    gr.Markdown(article)

block.launch(show_error=True)