File size: 5,444 Bytes
613d518
 
78b9267
613d518
408db9c
c3ac009
 
408db9c
8d6f040
ac8d8ef
c3ac009
 
ec47174
a5df98e
 
 
 
c3ac009
 
613d518
 
 
 
 
 
 
 
 
 
3e94094
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
78b9267
c3ac009
 
 
3e94094
 
8d6f040
c3ac009
 
 
 
 
 
 
c9cab76
613d518
 
7291f37
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b9267
613d518
 
 
 
 
86f92e3
 
 
 
 
 
 
 
 
 
 
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16fe1eb
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b9267
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch
import random
import spaces
import gradio as gr
from PIL import Image
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

# Run in fp16 on a CUDA device when one is present; otherwise fall back to
# fp32 on the CPU.
_has_cuda = torch.cuda.is_available()
device = "cuda" if _has_cuda else "cpu"
dtype = torch.float16 if _has_cuda else torch.float32

# SDXL base text-to-image pipeline, with the SDXL IP-Adapter weights attached
# so that calls can pass `ip_adapter_image`.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
)
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
)
pipe.to(device)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for the next generation.

    Args:
        seed: Seed currently set in the UI.
        randomize_seed: When truthy, ignore *seed* and draw a new one.

    Returns:
        A value from ``random.randint(0, 2000)`` (both bounds inclusive) when
        *randomize_seed* is truthy, otherwise *seed* unchanged.
    """
    return random.randint(0, 2000) if randomize_seed else seed


def _ip_adapter_scale_for(target, scale, control_scale):
    """Map a UI style mode to the value passed to ``pipe.set_ip_adapter_scale``.

    Args:
        target: Style-mode label selected in the UI.
        scale: Global IP-Adapter weight (used for "Load original IP-Adapter").
        control_scale: Weight applied to the selected style/layout blocks.

    Returns:
        A nested per-block dict (``"down"``/``"up"`` -> block name -> per-layer
        weights) for the block-wise modes, or the float *scale* for the
        original IP-Adapter mode and any unrecognised label.
    """
    if target == "Load only style blocks":
        return {"up": {"block_0": [0.0, control_scale, 0.0]}}
    if target == "Load only layout blocks":
        return {"down": {"block_2": [0.0, control_scale]}}
    if target == "Load style+layout block":
        return {
            "down": {"block_2": [0.0, control_scale]},
            "up": {"block_0": [0.0, control_scale, 0.0]},
        }
    # "Load original IP-Adapter" (or an unknown label): one global weight.
    return scale


@spaces.GPU()
def create_image(image_pil,
                 prompt,
                 n_prompt,
                 scale,
                 control_scale,
                 guidance_scale,
                 num_inference_steps,
                 seed,
                 target="Load only style blocks",
                 ):
    """Generate one image from *prompt*, conditioned on the style image.

    Args:
        image_pil: Style image from the UI (``None`` if nothing was uploaded).
        prompt: Positive text prompt.
        n_prompt: Negative text prompt.
        scale: Global IP-Adapter weight, used in "Load original IP-Adapter" mode.
        control_scale: Block weight used in the style/layout block modes.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        seed: Seed for the torch generator. The UI's first step has already
            applied the "Randomize seed" checkbox, so it is used as-is.
        target: Style-mode label selecting which IP-Adapter blocks are active.

    Returns:
        The first generated image from the pipeline output.

    Raises:
        gr.Error: If no style image was provided.
    """
    if image_pil is None:
        raise gr.Error("Please upload a style image first.")

    # Always (re)apply the scale. The pipeline keeps whatever scale was set
    # last, so skipping this in "original" mode would silently reuse the
    # previous mode's block-wise weights and ignore the Scale slider.
    pipe.set_ip_adapter_scale(_ip_adapter_scale_for(target, scale, control_scale))

    style_image = load_image(image_pil)
    # Use the seed exactly as received so the seed shown in the UI reproduces
    # the image. int() guards against float values coming from the slider.
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        ip_adapter_image=style_image,
        negative_prompt=n_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    ).images[0]

    return image
    
    

# Markdown/HTML text shown on the demo page.
title = r"""
<h1 align="center">InstantStyle</h1>
"""

description = r"""
How to use:<br>
1. Upload a style image.
2. Choose a style mode (by default, only the style blocks are used).
3. Enter a text prompt, as done in normal text-to-image models.
4. Click the <b>Generate Image</b> button to begin customization.
"""

article = r"""
---
```bibtex
@article{wang2024instantstyle,
  title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
  author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
  journal={arXiv preprint arXiv:2404.02733},
  year={2024}
}
```
"""

# Build the Gradio UI. Component creation order inside the `with` contexts
# determines the on-page layout, so statement order here matters.
block = gr.Blocks().queue(max_size=10, api_open=True)
with block:
    # Page-level CSS: hide scrollbars and center the root container.
    gr.HTML("""
    <style>
        ::-webkit-scrollbar {
            display: none; 
        }
        #component-0 {
            max-width: 800px;
            margin: 0 auto; 
        }
    </style>
    """)
    # description
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():

                with gr.Row():
                    with gr.Column():
                        image_pil = gr.Image(label="Style Image", type="pil")

                # These labels must match the strings create_image compares against.
                target = gr.Radio(["Load only style blocks", "Load only layout blocks","Load style+layout block", "Load original IP-Adapter"],
                                  value="Load only style blocks",
                                  label="Style mode")

                prompt = gr.Textbox(label="Prompt",
                                    value="a cat, masterpiece, best quality, high quality")

                scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale")

                with gr.Accordion(open=False, label="Advanced Options"):

                    # Weight for the selected style/layout blocks. There is no
                    # ControlNet in this app, so the old label was misleading.
                    control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Style/layout block scale")

                    n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
                    guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale")
                    num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps")
                    seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                generate_button = gr.Button("Generate Image")

            # Right column: output.
            with gr.Column():
                generated_image = gr.Image(label="Generated Image", show_label=False)

        # Two-step handler: first resolve the seed (honouring the checkbox)
        # and write it back into the seed slider, then generate with that seed.
        generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=create_image,
            inputs=[image_pil,
                    prompt,
                    n_prompt,
                    scale,
                    control_scale,
                    guidance_scale,
                    num_inference_steps,
                    seed,
                    target],
            outputs=[generated_image])

    gr.Markdown(article)

block.launch(show_error=True)