wangqixun committed on
Commit
8310116
·
verified ·
1 Parent(s): e9750d9

Upload 2 files

Browse files
Files changed (2)
  1. app.py +302 -0
  2. requirements.txt +16 -0
app.py ADDED
@@ -0,0 +1,302 @@
+ import sys
+ sys.path.append('../')
+
+ import torch
+ import random
+ import numpy as np
+ from PIL import Image
+
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoModelForImageSegmentation
+ from torchvision import transforms
+
+ from pipeline import InstantCharacterFluxPipeline
+
+ # global variables
+ MAX_SEED = np.iinfo(np.int32).max
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if "cuda" in device else torch.float32
+
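+ # NOTE: `dtype` is not used below -- the FLUX pipeline is loaded in bfloat16,
+ # and the matting model runs in its default precision.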
+ # pre-trained weights
+ ip_adapter_path = hf_hub_download(repo_id="InstantX/InstantCharacter", filename="instantcharacter_ip-adapter.bin")
+ base_model = 'black-forest-labs/FLUX.1-dev'
+ image_encoder_path = 'google/siglip-so400m-patch14-384'
+ image_encoder_2_path = 'facebook/dinov2-giant'
+ birefnet_path = 'ZhengPeng7/BiRefNet'
+ makoto_style_lora_path = hf_hub_download(repo_id="InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai", filename="Makoto_Shinkai_style.safetensors")
+ ghibli_style_lora_path = hf_hub_download(repo_id="InstantX/FLUX.1-dev-LoRA-Ghibli", filename="ghibli_style.safetensors")
+
+ # init InstantCharacter pipeline
+ pipe = InstantCharacterFluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+ pipe.to(device)
+
+ # load the InstantCharacter adapter
+ pipe.init_adapter(
+     image_encoder_path=image_encoder_path,
+     image_encoder_2_path=image_encoder_2_path,
+     subject_ipadapter_cfg=dict(subject_ip_adapter_path=ip_adapter_path, nb_token=1024),
+ )
+
+ # load matting model (moved to `device` so the demo also starts on CPU-only machines)
+ birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path, trust_remote_code=True)
+ birefnet.to(device)
+ birefnet.eval()
+ birefnet_transform_image = transforms.Compose([
+     transforms.Resize((1024, 1024)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet statistics
+ ])
+
+
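+ # Salient-object matting: predict a foreground mask with BiRefNet, composite the
+ # subject onto white, crop to its bounding box, and pad the crop to a square.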
+ def remove_bkg(subject_image):
+
+     def infer_matting(img_pil):
+         input_images = birefnet_transform_image(img_pil).unsqueeze(0).to(device)
+
+         with torch.no_grad():
+             preds = birefnet(input_images)[-1].sigmoid().cpu()
+         pred = preds[0].squeeze()
+         pred_pil = transforms.ToPILImage()(pred)
+         mask = pred_pil.resize(img_pil.size)
+         mask = np.array(mask)
+         mask = mask[..., None]
+         return mask
+
+     def get_bbox_from_mask(mask, th=128):
+         # scan inwards from each border for the first row/column whose max reaches `th`
+         height, width = mask.shape[:2]
+         x1, y1, x2, y2 = 0, 0, width - 1, height - 1
+
+         sample = np.max(mask, axis=0)
+         for idx in range(width):
+             if sample[idx] >= th:
+                 x1 = idx
+                 break
+
+         sample = np.max(mask[:, ::-1], axis=0)
+         for idx in range(width):
+             if sample[idx] >= th:
+                 x2 = width - 1 - idx
+                 break
+
+         sample = np.max(mask, axis=1)
+         for idx in range(height):
+             if sample[idx] >= th:
+                 y1 = idx
+                 break
+
+         sample = np.max(mask[::-1], axis=1)
+         for idx in range(height):
+             if sample[idx] >= th:
+                 y2 = height - 1 - idx
+                 break
+
+         x1 = np.clip(x1, 0, width - 1).round().astype(np.int32)
+         y1 = np.clip(y1, 0, height - 1).round().astype(np.int32)
+         x2 = np.clip(x2, 0, width - 1).round().astype(np.int32)
+         y2 = np.clip(y2, 0, height - 1).round().astype(np.int32)
+
+         return [x1, y1, x2, y2]
+
+     # `random_pad` renamed from `random` to avoid shadowing the stdlib module
+     def pad_to_square(image, pad_value=255, random_pad=False):
+         '''
+         image: np.array [h, w, 3]
+         '''
+         H, W = image.shape[0], image.shape[1]
+         if H == W:
+             return image
+
+         padd = abs(H - W)
+         if random_pad:
+             padd_1 = int(np.random.randint(0, padd))
+         else:
+             padd_1 = int(padd / 2)
+         padd_2 = padd - padd_1
+
+         if H > W:
+             pad_param = ((0, 0), (padd_1, padd_2), (0, 0))
+         else:
+             pad_param = ((padd_1, padd_2), (0, 0), (0, 0))
+
+         image = np.pad(image, pad_param, 'constant', constant_values=pad_value)
+         return image
+
+     salient_object_mask = infer_matting(subject_image)[..., 0]
+     x1, y1, x2, y2 = get_bbox_from_mask(salient_object_mask)
+     subject_image = np.array(subject_image)
+     salient_object_mask[salient_object_mask > 128] = 255
+     salient_object_mask[salient_object_mask < 128] = 0
+     sample_mask = np.concatenate([salient_object_mask[..., None]] * 3, axis=2)
+     obj_image = sample_mask / 255 * subject_image + (1 - sample_mask / 255) * 255
+     crop_obj_image = obj_image[y1:y2, x1:x2]
+     crop_pad_obj_image = pad_to_square(crop_obj_image, 255)
+     subject_image = Image.fromarray(crop_pad_obj_image.astype(np.uint8))
+     return subject_image
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ def get_example():
+     case = [
+         [
+             "./assets/girl.jpg",
+             "A girl is playing a guitar in street",
+             0.9,
+             'Makoto Shinkai style',
+         ],
+         [
+             "./assets/boy.jpg",
+             "A boy is riding a bike in snow",
+             0.9,
+             'Makoto Shinkai style',
+         ],
+     ]
+     return case
+
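+ # Fixed guidance/steps/seed so the cached examples below render deterministically.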
+ def run_for_examples(source_image, prompt, scale, style_mode):
+     return create_image(
+         input_image=source_image,
+         prompt=prompt,
+         scale=scale,
+         guidance_scale=3.5,
+         num_inference_steps=28,
+         seed=123456,
+         style_mode=style_mode,
+     )
+
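+ # Main inference entry point: matte the subject, then run InstantCharacter,
+ # optionally through a style LoRA selected in the UI.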
+ def create_image(input_image,
+                  prompt,
+                  scale,
+                  guidance_scale,
+                  num_inference_steps,
+                  seed,
+                  style_mode=None):
+
+     input_image = remove_bkg(input_image)
+
+     if style_mode is None:
+         images = pipe(
+             prompt=prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             width=1024,
+             height=1024,
+             subject_image=input_image,
+             subject_scale=scale,
+             generator=torch.manual_seed(seed),
+         ).images
+     else:
+         if style_mode == 'Makoto Shinkai style':
+             lora_file_path = makoto_style_lora_path
+             trigger = 'Makoto Shinkai style'
+         elif style_mode == 'Ghibli style':
+             lora_file_path = ghibli_style_lora_path
+             trigger = 'ghibli style'
+         else:
+             # guard against an unrecognized style leaving `lora_file_path` undefined
+             raise ValueError(f"Unknown style: {style_mode}")
+
+         images = pipe.with_style_lora(
+             lora_file_path=lora_file_path,
+             trigger=trigger,
+             prompt=prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             width=1024,
+             height=1024,
+             subject_image=input_image,
+             subject_scale=scale,
+             generator=torch.manual_seed(seed),
+         ).images
+
+     return images
+
+ # Description
+ title = r"""
+ <h1 align="center">InstantCharacter: Personalize Any Characters with a Scalable Diffusion Transformer Framework</h1>
+ """
+
+ description = r"""
+ <b>Official 🤗 Gradio demo</b> for <a href='https://instantcharacter.github.io/' target='_blank'><b>InstantCharacter: Personalize Any Characters with a Scalable Diffusion Transformer Framework</b></a>.<br>
+ How to use:<br>
+ 1. Upload a character image; removing the background first is preferred.
+ 2. Enter a text prompt describing what you want the character to do.
+ 3. Click the <b>Submit</b> button to begin customization.
+ 4. Share your customized photo with your friends and enjoy! 😊
+ """
+
+ article = r"""
+ ---
+ 📝 **Citation**
+ <br>
+ If our work is helpful for your research or applications, please cite us via:
+ ```bibtex
+ TBD
+ ```
+ 📧 **Contact**
+ <br>
+ If you have any questions, please feel free to open an issue.
+ """
+
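+ # Build the UI: the CSS hides Gradio's footer; the queue caps pending jobs at 10
+ # and keeps the programmatic API closed.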
+ block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
+ with block:
+
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Tabs():
+         with gr.Row():
+             with gr.Column():
+
+                 with gr.Row():
+                     with gr.Column():
+                         image_pil = gr.Image(label="Source Image", type='pil')
+
+                 prompt = gr.Textbox(label="Prompt", value="a character is riding a bike in snow")
+
+                 scale = gr.Slider(minimum=0, maximum=1.5, step=0.01, value=1.0, label="Scale")
+                 style_mode = gr.Dropdown(label='Style', choices=[None, 'Makoto Shinkai style', 'Ghibli style'], value=None)
+
+                 with gr.Accordion(open=False, label="Advanced Options"):
+                     guidance_scale = gr.Slider(minimum=1, maximum=7.0, step=0.01, value=3.5, label="Guidance scale")
+                     num_inference_steps = gr.Slider(minimum=5, maximum=50, step=1, value=28, label="Num inference steps")
+                     seed = gr.Slider(minimum=-1000000, maximum=1000000, value=123456, step=1, label="Seed Value")
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                 generate_button = gr.Button("Generate Image")
+
+             with gr.Column():
+                 generated_image = gr.Gallery(label="Generated Image")
+
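+     # Randomize the seed first (when requested), then generate with the updated value.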
+     generate_button.click(
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=create_image,
+         inputs=[image_pil,
+                 prompt,
+                 scale,
+                 guidance_scale,
+                 num_inference_steps,
+                 seed,
+                 style_mode,
+                 ],
+         outputs=[generated_image])
+
+     gr.Examples(
+         examples=get_example(),
+         inputs=[image_pil, prompt, scale, style_mode],
+         fn=run_for_examples,
+         outputs=[generated_image],
+         cache_examples=True,
+     )
+
+     gr.Markdown(article)
+
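+ # Listen on all interfaces; binding port 80 usually requires elevated privileges
+ # outside a container.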
+ block.launch(server_name="0.0.0.0", server_port=80)
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ diffusers>=0.32.2
+ torch>=2.0.0
+ torchvision>=0.15.1
+ transformers>=4.37.1
+ accelerate
+ safetensors
+ einops
+ spaces>=0.19.4
+ omegaconf
+ peft
+ huggingface-hub>=0.20.2
+ opencv-python
+ gradio
+ controlnet_aux
+ gdown