import sys
sys.path.append('./')

import torch
import random
import spaces
import gradio as gr

from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
from torchvision import transforms
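
# Assumed environment: the per-block IP-Adapter scale dicts used in
# create_image below require a fairly recent diffusers release (roughly
# >= 0.26, which added InstantStyle support); older versions accept only a
# single float scale.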

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=dtype
)
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    torch_dtype=dtype,
).to(device)
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
)
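# The "plus" adapter weights above were trained with the ViT-H image encoder,
# which is why the encoder is loaded explicitly from "models/image_encoder"
# rather than relying on the pipeline's default SDXL encoder.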

def prepare_image(image_path_or_url):
    # Load the image
    image = load_image(image_path_or_url)
    
    # Convert to tensor and move to correct device and dtype
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC)
    ])
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return image_tensor.to(device=device, dtype=dtype)
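
# Note: diffusers' ip_adapter_image argument also accepts a plain PIL image;
# the explicit tensor conversion above just pins resolution, device, and dtype.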

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, 1_000_000)  # match the Seed slider's upper bound
    return seed
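
# In the UI wiring below, randomize_seed_fn runs first in the click chain and
# writes its result back into the Seed slider, so create_image receives the
# final seed value directly.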

@spaces.GPU(enable_queue=True)
def create_image(image_pil,
                 prompt,
                 n_prompt,
                 scale, 
                 control_scale, 
                 guidance_scale,
                 num_inference_steps,
                 seed,
                 target="Load only style blocks",
                 ):
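    # InstantStyle enables the IP-Adapter only in selected attention blocks:
    # per the paper, style is mainly captured by the first up-block
    # ("up.block_0") and layout by the second down-block ("down.block_2").
    # set_ip_adapter_scale accepts a nested dict keyed by block name, with one
    # scale value per attention layer in that block.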
    if target != "Load original IP-Adapter":
        if target == "Load only style blocks":
            scale = {
                "up": {"block_0": [0.0, control_scale, 0.0]},
            }
        elif target == "Load only layout blocks":
            scale = {
                "down": {"block_2": [0.0, control_scale]},
            }
        elif target == "Load style+layout block":
            scale = {
                "down": {"block_2": [0.0, control_scale]},
                "up": {"block_0": [0.0, control_scale, 0.0]},
            }
        pipeline.set_ip_adapter_scale(scale)
    else:
        # Re-apply the plain float scale so a block-wise setting from a
        # previous call does not leak into the original IP-Adapter mode.
        pipeline.set_ip_adapter_scale(scale)

    style_image = prepare_image(image_pil)
    generator = torch.Generator(device=device).manual_seed(seed)

    output = pipeline(
        prompt=prompt,
        ip_adapter_image=style_image,
        negative_prompt=n_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    # Return the list of PIL images, which is what gr.Gallery expects.
    return output.images
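
# Minimal sketch of invoking create_image outside the UI (the style-image
# path and parameter values here are hypothetical):
#
#   from diffusers.utils import load_image
#   images = create_image(load_image("style.jpg"), "a cat, masterpiece",
#                         "lowres, blurry", 1.0, 0.6, 5.0, 20, 42)
#   images[0].save("stylized.png")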
    

# Description
title = r"""
<h1 align="center">InstantStyle</h1>
"""

description = r"""
How to use:<br>
1. Upload a style image.
2. Choose a stylization mode; only the style blocks are used by default.
3. Enter a text prompt, as you would with a normal text-to-image model.
4. Click the <b>Generate Image</b> button to begin stylization.
5. Share your stylized photo with your friends and enjoy! 😊


Advanced usage:<br>
1. Open the advanced options.
2. Adjust the block conditioning scale to trade style strength against layout preservation.
3. Enter a negative prompt to avoid content leakage from the style image.
"""

article = r"""
---
```bibtex
@article{wang2024instantstyle,
  title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
  author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
  journal={arXiv preprint arXiv:2404.02733},
  year={2024}
}
```
"""

block = gr.Blocks().queue(max_size=10, api_open=True)
with block:
    
    # description
    gr.Markdown(title)
    gr.Markdown(description)
    
    with gr.Tabs():
        with gr.Row():
            with gr.Column():
                
                with gr.Row():
                    with gr.Column():
                        image_pil = gr.Image(label="Style Image", type="pil")
                
                target = gr.Radio(
                    ["Load only style blocks", "Load only layout blocks", "Load style+layout block", "Load original IP-Adapter"],
                    value="Load only style blocks",
                    label="Style mode")
                
                prompt = gr.Textbox(label="Prompt",
                                    value="a cat, masterpiece, best quality, high quality")
                
                scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Scale (used in original IP-Adapter mode)")
                
                with gr.Accordion(open=False, label="Advanced Options"):
                    
                    control_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Block conditioning scale")

                    n_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.01, value=5.0, label="Guidance scale")
                    num_inference_steps = gr.Slider(minimum=5, maximum=50, step=1, value=20, label="Number of inference steps")
                    seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label="Seed")
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                generate_button = gr.Button("Generate Image")
                
            with gr.Column():
                generated_image = gr.Gallery(label="Generated Image")

        generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=create_image,
            inputs=[image_pil,
                    prompt,
                    n_prompt,
                    scale, 
                    control_scale, 
                    guidance_scale,
                    num_inference_steps,
                    seed,
                    target], 
            outputs=[generated_image])
    
    gr.Markdown(article)

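# show_error surfaces Python exceptions in the browser; outside a hosted
# Space, launch(share=True) can additionally create a temporary public URL.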
block.launch(show_error=True)