File size: 5,551 Bytes
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851fa96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import sys
sys.path.append('./')

import torch
import random
import spaces
import gradio as gr

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

# Runtime configuration shared by the whole module: run on CUDA with
# half-precision weights when a GPU is available, otherwise fall back to
# the CPU with full-precision weights.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dtype = torch.float16 if "cuda" in device else torch.float32

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for a generation run.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ``seed`` is ignored and a new value is
            drawn with ``random.randint(0, 2000)`` (both ends inclusive).

    Returns:
        The freshly drawn seed, or ``seed`` unchanged when randomization is
        switched off.
    """
    return random.randint(0, 2000) if randomize_seed else seed

# Build the SDXL text-to-image pipeline once at import time, move it to the
# selected device, and attach the SDXL IP-Adapter weights so that a style
# image can be passed alongside the text prompt.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
)
pipeline = pipeline.to(device)
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
)

@spaces.GPU(enable_queue=True)
def create_image(image_pil,
                 prompt,
                 n_prompt,
                 scale,
                 control_scale,
                 guidance_scale,
                 num_inference_steps,
                 seed,
                 target="Load only style blocks",
                 ):
    """Generate images from a text prompt, conditioned on a style image.

    Args:
        image_pil: Style image (anything ``diffusers.utils.load_image``
            accepts). Must not be ``None``.
        prompt: Text prompt.
        n_prompt: Negative text prompt.
        scale: Uniform IP-Adapter scale. Used when ``target`` is
            "Load original IP-Adapter" (or is not one of the known modes);
            the block-wise modes replace it with a per-block configuration.
        control_scale: Weight placed on the selected block(s) in the
            block-wise modes.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps; coerced to ``int``.
        seed: Seed for the random generator; coerced to ``int``.
        target: One of "Load only style blocks", "Load only layout blocks",
            "Load style+layout block" or "Load original IP-Adapter".

    Returns:
        The ``images`` list of the pipeline output, suitable for a
        ``gr.Gallery``.

    Raises:
        gr.Error: If no style image was provided.
    """
    if image_pil is None:
        # Fail early with a readable message instead of an opaque loader error.
        raise gr.Error("Please upload a style image first.")

    # Per-block IP-Adapter scale presets, keyed by the UI mode name.
    block_scales = {
        "Load only style blocks": {
            "up": {"block_0": [0.0, control_scale, 0.0]},
        },
        "Load only layout blocks": {
            "down": {"block_2": [0.0, control_scale]},
        },
        "Load style+layout block": {
            "down": {"block_2": [0.0, control_scale]},
            "up": {"block_0": [0.0, control_scale, 0.0]},
        },
    }
    # Always (re)apply a scale: the pipeline is shared across requests, so
    # skipping this call in "Load original IP-Adapter" mode would silently
    # reuse whatever the previous request configured and ignore ``scale``.
    pipeline.set_ip_adapter_scale(block_scales.get(target, scale))

    style_image = load_image(image_pil)

    # Slider widgets may deliver floats; both consumers need integers.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    result = pipeline(
        prompt=prompt,
        ip_adapter_image=style_image,
        negative_prompt=n_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    )
    # Hand the gallery the image list, not the pipeline output wrapper.
    return result.images
    

# Page header rendered as Markdown/HTML at the top of the UI.
title = '\n<h1 align="center">InstantStyle</h1>\n'

# Usage instructions rendered as Markdown below the title.
description = r"""
How to use:<br>
1. Upload a style image.
2. Set stylization mode, only use style block by default.
3. Enter a text prompt, as done in normal text-to-image models.
4. Click the <b>Generate Image</b> button to begin customization.
5. Share your stylized photo with your friends and enjoy! 😊


Advanced usage:<br>
1. Click advanced options.
2. Upload another source image for image-based stylization using ControlNet.
3. Enter negative content prompt to avoid content leakage.
"""

# Footer Markdown: a horizontal rule followed by the BibTeX citation for the
# InstantStyle paper, shown at the bottom of the page.
article = r"""
---
```bibtex
@article{wang2024instantstyle,
  title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
  author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
  journal={arXiv preprint arXiv:2404.02733},
  year={2024}
}
```
"""

# Build the Gradio UI. The request queue is enabled with at most 10 pending
# jobs and the API left open.
block = gr.Blocks().queue(max_size=10, api_open=True)
with block:
    
    # description
    gr.Markdown(title)
    gr.Markdown(description)
    
    with gr.Tabs():
        with gr.Row():
            # Left column: all inputs and the trigger button.
            with gr.Column():
                
                with gr.Row():
                    with gr.Column():
                        # Delivered to the callback as a PIL image.
                        image_pil = gr.Image(label="Style Image", type="pil")
                
                # The choice strings must match the mode names checked in create_image.
                target = gr.Radio(["Load only style blocks", "Load only layout blocks","Load style+layout block", "Load original IP-Adapter"], 
                                  value="Load only style blocks",
                                  label="Style mode")
                
                prompt = gr.Textbox(label="Prompt",
                                    value="a cat, masterpiece, best quality, high quality")
                
                # Passed to create_image as `scale`.
                scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale")
                
                with gr.Accordion(open=False, label="Advanced Options"):
                    
                    # Passed to create_image as `control_scale`.
                    control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Controlnet conditioning scale")
                    
                    n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
                    guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale")
                    num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps")
                    seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                generate_button = gr.Button("Generate Image")
                
            # Right column: the generated output.
            with gr.Column():
                generated_image = gr.Gallery(label="Generated Image")

        # Two-step click handler: first write the (possibly randomized) seed
        # back into the seed slider, then run generation with that seed.
        generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=create_image,
            # Order must match create_image's positional parameters.
            inputs=[image_pil,
                    prompt,
                    n_prompt,
                    scale, 
                    control_scale, 
                    guidance_scale,
                    num_inference_steps,
                    seed,
                    target], 
            outputs=[generated_image])
    
    gr.Markdown(article)

# Start the app; show_error=True is passed so errors are surfaced in the UI.
block.launch(show_error=True)