File size: 6,237 Bytes
613d518
 
408db9c
613d518
408db9c
6aaad65
408db9c
93714a0
 
 
 
 
613d518
408db9c
 
00b6a77
408db9c
 
 
 
 
 
 
 
00b6a77
408db9c
 
 
 
 
 
 
a5df98e
 
 
 
408db9c
00b6a77
 
 
 
 
 
 
408db9c
613d518
 
 
 
 
 
 
 
 
 
3e94094
00b6a77
 
 
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
d62b763
c9cab76
00b6a77
 
88fc2e5
00b6a77
f0da3a3
3e94094
8d6f040
408db9c
 
 
 
88fc2e5
00b6a77
 
408db9c
c9cab76
613d518
 
7291f37
228ed52
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16fe1eb
613d518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164e1f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import torch
import random
import spaces  ## For ZeroGPU
import gradio as gr
from PIL import Image
from models_transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
import os
from huggingface_hub import login

# Authenticate against the Hugging Face Hub with the token stored in the
# TOKEN environment variable. Only log in when a token is actually set:
# calling login() without one falls back to an interactive prompt, which
# cannot be answered in a headless deployment.
TOKEN = os.getenv('TOKEN')
if TOKEN:
    login(TOKEN)

# Base diffusion model, local IP-Adapter weights and image encoder.
model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
# NOTE: the adapter weights are expected next to this script. Alternative (not
# enabled): fetch "ip-adapter.bin" from the "InstantX/SD3.5-Large-IP-Adapter"
# repo with huggingface_hub.hf_hub_download.
image_encoder_path = "google/siglip-so400m-patch14-384"

# Load the project-specific transformer first so it can be injected into the
# pipeline below; both are loaded in bfloat16.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Attach the IP-Adapter weights and the image encoder to the pipeline.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next generation.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ignore ``seed`` and draw a new one.

    Returns:
        A random integer in the inclusive range 0..2000 when
        ``randomize_seed`` is truthy, otherwise ``seed`` unchanged.
    """
    return random.randint(0, 2000) if randomize_seed else seed

def resize_img(image, max_size=1024):
    """Scale ``image`` to fit a ``max_size`` x ``max_size`` box, keeping aspect ratio.

    The same factor is applied to both sides, so images smaller than
    ``max_size`` are enlarged and larger ones are shrunk.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (a PIL image in this app).
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        The result of ``image.resize`` using Lanczos resampling.
    """
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # Clamp to 1 px: for very elongated inputs the truncated short side would
    # otherwise become 0, which resize() rejects.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)

@spaces.GPU() ## For ZeroGPU
def create_image(image_pil,
                 prompt,
                 n_prompt,
                 scale,
                 control_scale,
                 guidance_scale,
                 num_inference_steps,
                 seed,
                 target="Load only style blocks",
                 ):
    """Generate a 1024x1024 image for ``prompt`` conditioned on a style image.

    Args:
        image_pil: Style reference; a PIL image, or an array that is converted
            with ``Image.fromarray``. ``None`` short-circuits to ``None``.
        prompt: Text prompt passed to the pipeline.
        n_prompt: Negative prompt. ``None`` falls back to the default
            "lowres, low quality, worst quality".
        scale: IP-Adapter strength, forwarded as ``ipadapter_scale``.
        control_scale: Currently unused (see note below).
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        seed: Seed for the random generator. It is used as given; any
            randomization is expected to have been applied by the caller.
        target: Selected style mode. Currently unused (see note below).

    Returns:
        The first image of the pipeline result, or ``None`` when no style
        image was supplied.
    """
    if image_pil is None:
        return None

    # NOTE: per-block scaling selected through `target` / `control_scale`
    # depends on pipe.set_ip_adapter_scale, which is still waiting for the SD3
    # Diffusers integration. Until then both parameters are accepted but have
    # no effect, and `scale` is applied globally.

    if not isinstance(image_pil, Image.Image):  # e.g. a raw array
        image_pil = Image.fromarray(image_pil)

    image_pil = resize_img(image_pil)

    # Use the seed we were given so the value shown to the user is the one
    # that actually drives generation.
    generator = torch.Generator().manual_seed(int(seed))

    negative_prompt = (
        n_prompt if n_prompt is not None
        else "lowres, low quality, worst quality"
    )

    image = pipe(
        width=1024,
        height=1024,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        generator=generator,  ## For ZeroGPU no device="cpu"
        clip_image=image_pil,
        ipadapter_scale=float(scale),
    ).images[0]

    return image
    
    
    

# Static Markdown/HTML snippets rendered in the UI.

# Page heading.
title = r"""
<h1 align="center">InstantStyle</h1>
"""

# Step-by-step usage instructions shown under the heading.
description = r"""
How to use:<br>
1. Upload a style image.
2. Set stylization mode, only use style block by default.
3. Enter a text prompt, as done in normal text-to-image models.
4. Click the <b>Generate Image</b> button to begin customization.
"""

# Citation footer rendered at the bottom of the page.
article = r"""
---
```bibtex
@article{wang2024instantstyle,
  title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
  author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
  journal={arXiv preprint arXiv:2404.02733},
  year={2024}
}
```
"""

# Assemble the Gradio interface. The request queue is configured with
# max_size=10 and api_open=True.
block = gr.Blocks().queue(max_size=10, api_open=True)
with block:

    # Heading and usage instructions.
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.Row():

            # Left column: every user-controlled input.
            with gr.Column():

                with gr.Row(), gr.Column():
                    image_pil = gr.Image(type="pil", label="Style Image")

                target = gr.Radio(
                    [
                        "Load only style blocks",
                        "Load only layout blocks",
                        "Load style+layout block",
                        "Load original IP-Adapter",
                    ],
                    value="Load only style blocks",
                    label="Style mode",
                )

                prompt = gr.Textbox(
                    label="Prompt",
                    value="a cat, masterpiece, best quality, high quality",
                )

                scale = gr.Slider(
                    label="Scale",
                    minimum=0,
                    maximum=2.0,
                    step=0.01,
                    value=1.0,
                )

                # Less frequently used knobs, collapsed by default.
                with gr.Accordion(label="Advanced Options", open=False):
                    control_scale = gr.Slider(
                        label="Controlnet conditioning scale",
                        minimum=0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                    )
                    n_prompt = gr.Textbox(
                        label="Neg Prompt",
                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                    )
                    guidance_scale = gr.Slider(
                        label="guidance scale",
                        minimum=1,
                        maximum=15.0,
                        step=0.01,
                        value=5.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="num inference steps",
                        minimum=5,
                        maximum=50.0,
                        step=1.0,
                        value=20,
                    )
                    seed = gr.Slider(
                        label="Seed Value",
                        minimum=-1000000,
                        maximum=1000000,
                        step=1,
                        value=1,
                    )
                    randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

                generate_button = gr.Button("Generate Image")

            # Right column: the generated result.
            with gr.Column():
                generated_image = gr.Image(label="Generated Image", show_label=False)

        # Components forwarded to create_image, in its positional order.
        generation_inputs = [
            image_pil,
            prompt,
            n_prompt,
            scale,
            control_scale,
            guidance_scale,
            num_inference_steps,
            seed,
            target,
        ]

        # Step 1: update the seed component (unqueued, hidden from the API).
        seed_event = generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        )
        # Step 2: run generation once the seed has been written back.
        seed_event.then(
            fn=create_image,
            inputs=generation_inputs,
            outputs=[generated_image],
        )

    # Citation footer.
    gr.Markdown(article)

# Start the app; errors are surfaced in the UI and a share link is requested.
block.launch(show_error=True, share=True)