Hatman committed on
Commit 65fe6f1 · verified · 1 Parent(s): 8e3f1b8

Upload 3 files

Files changed (2)
  1. app.py +159 -265
  2. requirements.txt +1 -11
app.py CHANGED
@@ -1,266 +1,160 @@
- import sys
- sys.path.append('./')
-
-
- import os
- import cv2
- import torch
- import random
- import numpy as np
- from PIL import Image
- from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
-
- import spaces
- import gradio as gr
- from huggingface_hub import hf_hub_download
-
- from ip_adapter import IPAdapterXL
-
- import os
- os.system("git lfs install")
- os.system("git clone https://huggingface.co/h94/IP-Adapter")
- os.system("mv IP-Adapter/sdxl_models sdxl_models")
-
- # global variable
- MAX_SEED = np.iinfo(np.int32).max
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
-
- # initialization
- base_model_path = "kandinsky-community/kandinsky-2-2-prior"
- image_encoder_path = "sdxl_models/image_encoder"
- ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
-
- controlnet_path = "kandinsky-community/kandinsky-2-2-controlnet-depth"
- controlnet = KandinskyV22ControlnetPipeline.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=torch.float16).to(device)
-
- # load SDXL pipeline
- pipe = KandinskyV22PriorPipeline.from_pretrained(
-     base_model_path,
-     controlnet=controlnet,
-     torch_dtype=torch.float16,
-     add_watermarker=False,
- )
-
- # load ip-adapter
- # target_blocks=["block"] for original IP-Adapter
- # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
- # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
- ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"])
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- def resize_img(
-     input_image,
-     max_side=1280,
-     min_side=1024,
-     size=None,
-     pad_to_max_side=False,
-     mode=Image.BILINEAR,
-     base_pixel_number=64,
- ):
-     w, h = input_image.size
-     if size is not None:
-         w_resize_new, h_resize_new = size
-     else:
-         ratio = min_side / min(h, w)
-         w, h = round(ratio * w), round(ratio * h)
-         ratio = max_side / max(h, w)
-         input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
-         w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
-         h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
-     input_image = input_image.resize([w_resize_new, h_resize_new], mode)
-
-     if pad_to_max_side:
-         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
-         offset_x = (max_side - w_resize_new) // 2
-         offset_y = (max_side - h_resize_new) // 2
-         res[
-             offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
-         ] = np.array(input_image)
-         input_image = Image.fromarray(res)
-     return input_image
-
- @spaces.GPU(enable_queue=True)
- def create_image(image_pil,
-                  input_image,
-                  prompt,
-                  n_prompt,
-                  scale,
-                  control_scale,
-                  guidance_scale,
-                  num_samples,
-                  num_inference_steps,
-                  seed,
-                  target="Load only style blocks",
-                  neg_content_prompt=None,
-                  neg_content_scale=0):
-
-     if isinstance(image_pil, np.ndarray):
-         image_pil = Image.fromarray(image_pil)
-
-     if target =="Load original IP-Adapter":
-         # target_blocks=["blocks"] for original IP-Adapter
-         ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"])
-     elif target=="Load only style blocks":
-         # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
-         ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"])
-     elif target=="Load only layout blocks":
-         # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
-         ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["down_blocks.2.attentions.1"])
-     elif target == "Load style+layout block":
-         # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
-         ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"])
-
-     if input_image is not None:
-         input_image = resize_img(input_image, max_side=1024)
-         cv_input_image = pil_to_cv2(input_image)
-         detected_map = cv2.Canny(cv_input_image, 50, 200)
-         canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
-     else:
-         canny_map = Image.new('RGB', (1024, 1024), color=(255, 255, 255))
-         control_scale = 0
-
-     if float(control_scale) == 0:
-         canny_map = canny_map.resize((1024,1024))
-
-     if len(neg_content_prompt) > 0 and neg_content_scale != 0:
-         images = ip_model.generate(pil_image=image_pil,
-                                    prompt=prompt,
-                                    negative_prompt=n_prompt,
-                                    scale=scale,
-                                    guidance_scale=guidance_scale,
-                                    num_samples=num_samples,
-                                    num_inference_steps=num_inference_steps,
-                                    seed=seed,
-                                    image=canny_map,
-                                    controlnet_conditioning_scale=float(control_scale),
-                                    neg_content_prompt=neg_content_prompt,
-                                    neg_content_scale=neg_content_scale
-                                    )
-     else:
-         images = ip_model.generate(pil_image=image_pil,
-                                    prompt=prompt,
-                                    negative_prompt=n_prompt,
-                                    scale=scale,
-                                    guidance_scale=guidance_scale,
-                                    num_samples=num_samples,
-                                    num_inference_steps=num_inference_steps,
-                                    seed=seed,
-                                    image=canny_map,
-                                    controlnet_conditioning_scale=float(control_scale),
-                                    )
-     return images
-
- def pil_to_cv2(image_pil):
-     image_np = np.array(image_pil)
-     image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
-     return image_cv2
-
- # Description
- title = r"""
- <h1 align="center">InstantStyle</h1>
- """
-
- description = r"""
- How to use:<br>
- 1. Upload a style image.
- 2. Set stylization mode, only use style block by default.
- 2. Enter a text prompt, as done in normal text-to-image models.
- 3. Click the <b>Submit</b> button to begin customization.
- 4. Share your stylized photo with your friends and enjoy! 😊
-
-
- Advanced usage:<br>
- 1. Click advanced options.
- 2. Upload another source image for image-based stylization using ControlNet.
- 3. Enter negative content prompt to avoid content leakage.
- """
-
- article = r"""
- ---
- ```bibtex
- @article{wang2024instantstyle,
-   title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
-   author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
-   journal={arXiv preprint arXiv:2404.02733},
-   year={2024}
- }
- ```
- """
-
- block = gr.Blocks().queue(max_size=10, api_open=True)
- with block:
-
-     # description
-     gr.Markdown(title)
-     gr.Markdown(description)
-
-     with gr.Tabs():
-         with gr.Row():
-             with gr.Column():
-
-                 with gr.Row():
-                     with gr.Column():
-                         image_pil = gr.Image(label="Style Image", type="numpy")
-
-                 target = gr.Radio(["Load only style blocks", "Load style+layout block", "Load original IP-Adapter"],
-                                   value="Load only style blocks",
-                                   label="Style mode")
-
-                 prompt = gr.Textbox(label="Prompt",
-                                     value="a cat, masterpiece, best quality, high quality")
-
-                 scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale")
-
-                 with gr.Accordion(open=False, label="Advanced Options"):
-
-                     with gr.Column():
-                         src_image_pil = gr.Image(label="Source Image (optional)", type='pil')
-                         control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Controlnet conditioning scale")
-
-                     n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
-
-                     neg_content_prompt = gr.Textbox(label="Neg Content Prompt", value="")
-                     neg_content_scale = gr.Slider(minimum=0, maximum=1.0, step=0.01,value=0.5, label="Neg Content Scale")
-
-                     guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale")
-                     num_samples= gr.Slider(minimum=1,maximum=4.0, step=1.0,value=1.0, label="num samples")
-                     num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps")
-                     seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
-                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-                 generate_button = gr.Button("Generate Image")
-
-             with gr.Column():
-                 generated_image = gr.Gallery(label="Generated Image")
-
-     generate_button.click(
-         fn=randomize_seed_fn,
-         inputs=[seed, randomize_seed],
-         outputs=seed,
-         queue=False,
-         api_name=False,
-     ).then(
-         fn=create_image,
-         inputs=[image_pil,
-                 src_image_pil,
-                 prompt,
-                 n_prompt,
-                 scale,
-                 control_scale,
-                 guidance_scale,
-                 num_samples,
-                 num_inference_steps,
-                 seed,
-                 target,
-                 neg_content_prompt,
-                 neg_content_scale],
-         outputs=[generated_image])
-
-     gr.Markdown(article)
-
+ import sys
+ sys.path.append('./')
+
+ import torch
+ import random
+ import spaces
+ import gradio as gr
+
+ from diffusers import AutoPipelineForText2Image
+ from diffusers.utils import load_image
+
+ # global variable
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, 2000)
+     return seed
+
+ pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype).to(device)
+ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+ @spaces.GPU(enable_queue=True)
+ def create_image(image_pil,
+                  prompt,
+                  n_prompt,
+                  scale,
+                  control_scale,
+                  guidance_scale,
+                  num_inference_steps,
+                  seed,
+                  target="Load only style blocks",
+                  ):
+
+
+     if target !="Load original IP-Adapter":
+         if target=="Load only style blocks":
+             scale = {
+                 "up": {"block_0": [0.0, control_scale, 0.0]},
+             }
+         elif target=="Load only layout blocks":
+             scale = {
+                 "down": {"block_2": [0.0, control_scale]},
+             }
+         elif target == "Load style+layout block":
+             scale = {
+                 "down": {"block_2": [0.0, control_scale]},
+                 "up": {"block_0": [0.0, control_scale, 0.0]},
+             }
+         pipeline.set_ip_adapter_scale(scale)
+
+
+     style_image = load_image(image_pil)
+
+     generator = torch.Generator(device="cpu").manual_seed(randomize_seed_fn(seed, False))
+     image = pipeline(
+         prompt=prompt,
+         ip_adapter_image=style_image,
+         negative_prompt=n_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         generator=generator,
+     )
+     return image
+
+
+ # Description
+ title = r"""
+ <h1 align="center">InstantStyle</h1>
+ """
+
+ description = r"""
+ How to use:<br>
+ 1. Upload a style image.
+ 2. Set stylization mode, only use style block by default.
+ 2. Enter a text prompt, as done in normal text-to-image models.
+ 3. Click the <b>Submit</b> button to begin customization.
+ 4. Share your stylized photo with your friends and enjoy! 😊
+
+
+ Advanced usage:<br>
+ 1. Click advanced options.
+ 2. Upload another source image for image-based stylization using ControlNet.
+ 3. Enter negative content prompt to avoid content leakage.
+ """
+
+ article = r"""
+ ---
+ ```bibtex
+ @article{wang2024instantstyle,
+   title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
+   author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
+   journal={arXiv preprint arXiv:2404.02733},
+   year={2024}
+ }
+ ```
+ """
+
+ block = gr.Blocks().queue(max_size=10, api_open=True)
+ with block:
+
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Tabs():
+         with gr.Row():
+             with gr.Column():
+
+                 with gr.Row():
+                     with gr.Column():
+                         image_pil = gr.Image(label="Style Image", type="numpy")
+
+                 target = gr.Radio(["Load only style blocks", "Load only layout blocks","Load style+layout block", "Load original IP-Adapter"],
+                                   value="Load only style blocks",
+                                   label="Style mode")
+
+                 prompt = gr.Textbox(label="Prompt",
+                                     value="a cat, masterpiece, best quality, high quality")
+
+                 scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale")
+
+                 with gr.Accordion(open=False, label="Advanced Options"):
+
+                     control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Controlnet conditioning scale")
+
+                     n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
+                     guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale")
+                     num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps")
+                     seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                 generate_button = gr.Button("Generate Image")
+
+             with gr.Column():
+                 generated_image = gr.Gallery(label="Generated Image")
+
+     generate_button.click(
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=create_image,
+         inputs=[image_pil,
+                 prompt,
+                 n_prompt,
+                 scale,
+                 control_scale,
+                 guidance_scale,
+                 num_inference_steps,
+                 seed,
+                 target],
+         outputs=[generated_image])
+
+     gr.Markdown(article)
+
  block.launch(show_error=True)
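
Note: the rewritten app.py drops the custom `IPAdapterXL` wrapper (and the Kandinsky + ControlNet depth path) in favor of diffusers' built-in IP-Adapter support, where the InstantStyle modes are expressed as a per-block scale dictionary passed to `set_ip_adapter_scale`. Below is a minimal standalone sketch of that pattern outside Gradio; the model and adapter repos match the commit, while the `0.5` strength, the `style.jpg` path, the CUDA device, and the assumption of a diffusers release recent enough to accept dict-valued IP-Adapter scales are illustrative, not part of the commit.

```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

# Load SDXL and attach the SDXL IP-Adapter weights (same repos as in app.py above).
# Assumes a CUDA device; use torch.float32 on CPU.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

# "Load only style blocks" mode: inject the style image only into the middle
# attention layer of the "up" block_0 (up_blocks.0.attentions.1); everything else stays at 0.0.
pipe.set_ip_adapter_scale({"up": {"block_0": [0.0, 0.5, 0.0]}})  # 0.5 is an illustrative strength

style_image = load_image("style.jpg")  # placeholder path to a style reference image
image = pipe(
    prompt="a cat, masterpiece, best quality, high quality",
    ip_adapter_image=style_image,
    negative_prompt="text, watermark, lowres, low quality",
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cpu").manual_seed(42),
).images[0]
image.save("styled_cat.png")
```

Switching modes in the app corresponds to swapping this dictionary, e.g. adding `"down": {"block_2": [0.0, 0.5]}` for the style+layout setting.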
requirements.txt CHANGED
@@ -1,16 +1,6 @@
  diffusers>=0.25.1
  torch>=2.0.0
- torchvision>=0.15.1
  transformers>=4.37.1
- accelerate
- safetensors
- einops
  spaces>=0.19.4
- omegaconf
- peft
  huggingface-hub>=0.20.2
- opencv-python
- gradio==4.38.0
- controlnet_aux
- gdown
- peft
+ gradio