hatmanstack committed

Commit 39b6987 · 1 parent: 65fe6f1

AutoIP Adapter
app.py CHANGED
@@ -3,7 +3,7 @@ sys.path.append('./')
 
 import torch
 import random
-import spaces
+
 import gradio as gr
 
 from diffusers import AutoPipelineForText2Image
@@ -21,7 +21,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype).to(device)
 pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
 
-@spaces.GPU(enable_queue=True)
+#### Don't Forget Spaces GPU
 def create_image(image_pil,
 prompt,
 n_prompt,
@@ -50,7 +50,8 @@ def create_image(image_pil,
 }
 pipeline.set_ip_adapter_scale(scale)
 
-
+print(image_pil)
+
 style_image = load_image(image_pil)
 
 generator = torch.Generator(device="cpu").manual_seed(randomize_seed_fn(seed, False))
@@ -63,7 +64,7 @@ def create_image(image_pil,
 generator=generator,
 )
 return image
-
+
 
 # Description
 title = r"""
@@ -110,7 +111,7 @@ with block:
 
 with gr.Row():
 with gr.Column():
-image_pil = gr.Image(label="Style Image", type="numpy")
+image_pil = gr.Image(label="Style Image", type="pil")
 
 target = gr.Radio(["Load only style blocks", "Load only layout blocks","Load style+layout block", "Load original IP-Adapter"],
 value="Load only style blocks",
assets/0.jpg DELETED
Binary file (610 kB)
 
assets/1.jpg DELETED
Binary file (283 kB)
 
assets/2.jpg DELETED
Binary file (88.3 kB)
 
assets/3.jpg DELETED
Binary file (404 kB)
 
assets/yann-lecun.jpg DELETED
Binary file (30.6 kB)
 
ip_adapter/__init__.py DELETED
@@ -1,9 +0,0 @@
-from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
-
-__all__ = [
-    "IPAdapter",
-    "IPAdapterPlus",
-    "IPAdapterPlusXL",
-    "IPAdapterXL",
-    "IPAdapterFull",
-]
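
Deleting the whole ip_adapter package matches the move in app.py to diffusers' built-in IP-Adapter support (load_ip_adapter and set_ip_adapter_scale). A rough sketch of how the style/layout choices from the target radio map onto that API, assuming the per-block scale dictionary documented for diffusers' InstantStyle integration; the block names and weights below are illustrative, not taken from this repo:

# Sketch: per-block IP-Adapter scaling with diffusers instead of the deleted
# custom classes. The dict layout follows diffusers' InstantStyle documentation;
# exact block names/weights here are assumptions, not values from app.py.
scale_style_only = {"up": {"block_0": [0.0, 1.0, 0.0]}}
scale_style_and_layout = {
    "down": {"block_2": [0.0, 1.0]},
    "up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale_style_only)  # or a plain float for the original IP-Adapter behaviour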
ip_adapter/attention_processor.py DELETED
@@ -1,562 +0,0 @@
1
- # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
-
7
- class AttnProcessor(nn.Module):
8
- r"""
9
- Default processor for performing attention-related computations.
10
- """
11
-
12
- def __init__(
13
- self,
14
- hidden_size=None,
15
- cross_attention_dim=None,
16
- ):
17
- super().__init__()
18
-
19
- def __call__(
20
- self,
21
- attn,
22
- hidden_states,
23
- encoder_hidden_states=None,
24
- attention_mask=None,
25
- temb=None,
26
- ):
27
- residual = hidden_states
28
-
29
- if attn.spatial_norm is not None:
30
- hidden_states = attn.spatial_norm(hidden_states, temb)
31
-
32
- input_ndim = hidden_states.ndim
33
-
34
- if input_ndim == 4:
35
- batch_size, channel, height, width = hidden_states.shape
36
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
-
38
- batch_size, sequence_length, _ = (
39
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
- )
41
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
-
43
- if attn.group_norm is not None:
44
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
-
46
- query = attn.to_q(hidden_states)
47
-
48
- if encoder_hidden_states is None:
49
- encoder_hidden_states = hidden_states
50
- elif attn.norm_cross:
51
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
-
53
- key = attn.to_k(encoder_hidden_states)
54
- value = attn.to_v(encoder_hidden_states)
55
-
56
- query = attn.head_to_batch_dim(query)
57
- key = attn.head_to_batch_dim(key)
58
- value = attn.head_to_batch_dim(value)
59
-
60
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
- hidden_states = torch.bmm(attention_probs, value)
62
- hidden_states = attn.batch_to_head_dim(hidden_states)
63
-
64
- # linear proj
65
- hidden_states = attn.to_out[0](hidden_states)
66
- # dropout
67
- hidden_states = attn.to_out[1](hidden_states)
68
-
69
- if input_ndim == 4:
70
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
-
72
- if attn.residual_connection:
73
- hidden_states = hidden_states + residual
74
-
75
- hidden_states = hidden_states / attn.rescale_output_factor
76
-
77
- return hidden_states
78
-
79
-
80
- class IPAttnProcessor(nn.Module):
81
- r"""
82
- Attention processor for IP-Adapter.
83
- Args:
84
- hidden_size (`int`):
85
- The hidden size of the attention layer.
86
- cross_attention_dim (`int`):
87
- The number of channels in the `encoder_hidden_states`.
88
- scale (`float`, defaults to 1.0):
89
- The weight scale of the image prompt.
90
- num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
91
- The context length of the image features.
92
- """
93
-
94
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
95
- super().__init__()
96
-
97
- self.hidden_size = hidden_size
98
- self.cross_attention_dim = cross_attention_dim
99
- self.scale = scale
100
- self.num_tokens = num_tokens
101
- self.skip = skip
102
-
103
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
-
106
- def __call__(
107
- self,
108
- attn,
109
- hidden_states,
110
- encoder_hidden_states=None,
111
- attention_mask=None,
112
- temb=None,
113
- ):
114
- residual = hidden_states
115
-
116
- if attn.spatial_norm is not None:
117
- hidden_states = attn.spatial_norm(hidden_states, temb)
118
-
119
- input_ndim = hidden_states.ndim
120
-
121
- if input_ndim == 4:
122
- batch_size, channel, height, width = hidden_states.shape
123
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
-
125
- batch_size, sequence_length, _ = (
126
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
127
- )
128
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
-
130
- if attn.group_norm is not None:
131
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
-
133
- query = attn.to_q(hidden_states)
134
-
135
- if encoder_hidden_states is None:
136
- encoder_hidden_states = hidden_states
137
- else:
138
- # get encoder_hidden_states, ip_hidden_states
139
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
140
- encoder_hidden_states, ip_hidden_states = (
141
- encoder_hidden_states[:, :end_pos, :],
142
- encoder_hidden_states[:, end_pos:, :],
143
- )
144
- if attn.norm_cross:
145
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
146
-
147
- key = attn.to_k(encoder_hidden_states)
148
- value = attn.to_v(encoder_hidden_states)
149
-
150
- query = attn.head_to_batch_dim(query)
151
- key = attn.head_to_batch_dim(key)
152
- value = attn.head_to_batch_dim(value)
153
-
154
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
155
- hidden_states = torch.bmm(attention_probs, value)
156
- hidden_states = attn.batch_to_head_dim(hidden_states)
157
-
158
- if not self.skip:
159
- # for ip-adapter
160
- ip_key = self.to_k_ip(ip_hidden_states)
161
- ip_value = self.to_v_ip(ip_hidden_states)
162
-
163
- ip_key = attn.head_to_batch_dim(ip_key)
164
- ip_value = attn.head_to_batch_dim(ip_value)
165
-
166
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
167
- self.attn_map = ip_attention_probs
168
- ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
169
- ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
170
-
171
- hidden_states = hidden_states + self.scale * ip_hidden_states
172
-
173
- # linear proj
174
- hidden_states = attn.to_out[0](hidden_states)
175
- # dropout
176
- hidden_states = attn.to_out[1](hidden_states)
177
-
178
- if input_ndim == 4:
179
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
180
-
181
- if attn.residual_connection:
182
- hidden_states = hidden_states + residual
183
-
184
- hidden_states = hidden_states / attn.rescale_output_factor
185
-
186
- return hidden_states
187
-
188
-
189
- class AttnProcessor2_0(torch.nn.Module):
190
- r"""
191
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
192
- """
193
-
194
- def __init__(
195
- self,
196
- hidden_size=None,
197
- cross_attention_dim=None,
198
- ):
199
- super().__init__()
200
- if not hasattr(F, "scaled_dot_product_attention"):
201
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
202
-
203
- def __call__(
204
- self,
205
- attn,
206
- hidden_states,
207
- encoder_hidden_states=None,
208
- attention_mask=None,
209
- temb=None,
210
- ):
211
- residual = hidden_states
212
-
213
- if attn.spatial_norm is not None:
214
- hidden_states = attn.spatial_norm(hidden_states, temb)
215
-
216
- input_ndim = hidden_states.ndim
217
-
218
- if input_ndim == 4:
219
- batch_size, channel, height, width = hidden_states.shape
220
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
221
-
222
- batch_size, sequence_length, _ = (
223
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
224
- )
225
-
226
- if attention_mask is not None:
227
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
228
- # scaled_dot_product_attention expects attention_mask shape to be
229
- # (batch, heads, source_length, target_length)
230
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
231
-
232
- if attn.group_norm is not None:
233
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
234
-
235
- query = attn.to_q(hidden_states)
236
-
237
- if encoder_hidden_states is None:
238
- encoder_hidden_states = hidden_states
239
- elif attn.norm_cross:
240
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
241
-
242
- key = attn.to_k(encoder_hidden_states)
243
- value = attn.to_v(encoder_hidden_states)
244
-
245
- inner_dim = key.shape[-1]
246
- head_dim = inner_dim // attn.heads
247
-
248
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
-
250
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
-
253
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
254
- # TODO: add support for attn.scale when we move to Torch 2.1
255
- hidden_states = F.scaled_dot_product_attention(
256
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
257
- )
258
-
259
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
260
- hidden_states = hidden_states.to(query.dtype)
261
-
262
- # linear proj
263
- hidden_states = attn.to_out[0](hidden_states)
264
- # dropout
265
- hidden_states = attn.to_out[1](hidden_states)
266
-
267
- if input_ndim == 4:
268
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
269
-
270
- if attn.residual_connection:
271
- hidden_states = hidden_states + residual
272
-
273
- hidden_states = hidden_states / attn.rescale_output_factor
274
-
275
- return hidden_states
276
-
277
-
278
- class IPAttnProcessor2_0(torch.nn.Module):
279
- r"""
280
- Attention processor for IP-Adapter for PyTorch 2.0.
281
- Args:
282
- hidden_size (`int`):
283
- The hidden size of the attention layer.
284
- cross_attention_dim (`int`):
285
- The number of channels in the `encoder_hidden_states`.
286
- scale (`float`, defaults to 1.0):
287
- The weight scale of the image prompt.
288
- num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
289
- The context length of the image features.
290
- """
291
-
292
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
293
- super().__init__()
294
-
295
- if not hasattr(F, "scaled_dot_product_attention"):
296
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
297
-
298
- self.hidden_size = hidden_size
299
- self.cross_attention_dim = cross_attention_dim
300
- self.scale = scale
301
- self.num_tokens = num_tokens
302
- self.skip = skip
303
-
304
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
305
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
306
-
307
- def __call__(
308
- self,
309
- attn,
310
- hidden_states,
311
- encoder_hidden_states=None,
312
- attention_mask=None,
313
- temb=None,
314
- ):
315
- residual = hidden_states
316
-
317
- if attn.spatial_norm is not None:
318
- hidden_states = attn.spatial_norm(hidden_states, temb)
319
-
320
- input_ndim = hidden_states.ndim
321
-
322
- if input_ndim == 4:
323
- batch_size, channel, height, width = hidden_states.shape
324
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
325
-
326
- batch_size, sequence_length, _ = (
327
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
328
- )
329
-
330
- if attention_mask is not None:
331
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
332
- # scaled_dot_product_attention expects attention_mask shape to be
333
- # (batch, heads, source_length, target_length)
334
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
335
-
336
- if attn.group_norm is not None:
337
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
338
-
339
- query = attn.to_q(hidden_states)
340
-
341
- if encoder_hidden_states is None:
342
- encoder_hidden_states = hidden_states
343
- else:
344
- # get encoder_hidden_states, ip_hidden_states
345
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
346
- encoder_hidden_states, ip_hidden_states = (
347
- encoder_hidden_states[:, :end_pos, :],
348
- encoder_hidden_states[:, end_pos:, :],
349
- )
350
- if attn.norm_cross:
351
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
352
-
353
- key = attn.to_k(encoder_hidden_states)
354
- value = attn.to_v(encoder_hidden_states)
355
-
356
- inner_dim = key.shape[-1]
357
- head_dim = inner_dim // attn.heads
358
-
359
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360
-
361
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
362
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
363
-
364
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
365
- # TODO: add support for attn.scale when we move to Torch 2.1
366
- hidden_states = F.scaled_dot_product_attention(
367
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
368
- )
369
-
370
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
371
- hidden_states = hidden_states.to(query.dtype)
372
-
373
- if not self.skip:
374
- # for ip-adapter
375
- ip_key = self.to_k_ip(ip_hidden_states)
376
- ip_value = self.to_v_ip(ip_hidden_states)
377
-
378
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
-
381
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
382
- # TODO: add support for attn.scale when we move to Torch 2.1
383
- ip_hidden_states = F.scaled_dot_product_attention(
384
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
385
- )
386
- with torch.no_grad():
387
- self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
388
- #print(self.attn_map.shape)
389
-
390
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
391
- ip_hidden_states = ip_hidden_states.to(query.dtype)
392
-
393
- hidden_states = hidden_states + self.scale * ip_hidden_states
394
-
395
- # linear proj
396
- hidden_states = attn.to_out[0](hidden_states)
397
- # dropout
398
- hidden_states = attn.to_out[1](hidden_states)
399
-
400
- if input_ndim == 4:
401
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
402
-
403
- if attn.residual_connection:
404
- hidden_states = hidden_states + residual
405
-
406
- hidden_states = hidden_states / attn.rescale_output_factor
407
-
408
- return hidden_states
409
-
410
-
411
- ## for controlnet
412
- class CNAttnProcessor:
413
- r"""
414
- Default processor for performing attention-related computations.
415
- """
416
-
417
- def __init__(self, num_tokens=4):
418
- self.num_tokens = num_tokens
419
-
420
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
421
- residual = hidden_states
422
-
423
- if attn.spatial_norm is not None:
424
- hidden_states = attn.spatial_norm(hidden_states, temb)
425
-
426
- input_ndim = hidden_states.ndim
427
-
428
- if input_ndim == 4:
429
- batch_size, channel, height, width = hidden_states.shape
430
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
431
-
432
- batch_size, sequence_length, _ = (
433
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
434
- )
435
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
436
-
437
- if attn.group_norm is not None:
438
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
439
-
440
- query = attn.to_q(hidden_states)
441
-
442
- if encoder_hidden_states is None:
443
- encoder_hidden_states = hidden_states
444
- else:
445
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
446
- encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
447
- if attn.norm_cross:
448
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
449
-
450
- key = attn.to_k(encoder_hidden_states)
451
- value = attn.to_v(encoder_hidden_states)
452
-
453
- query = attn.head_to_batch_dim(query)
454
- key = attn.head_to_batch_dim(key)
455
- value = attn.head_to_batch_dim(value)
456
-
457
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
458
- hidden_states = torch.bmm(attention_probs, value)
459
- hidden_states = attn.batch_to_head_dim(hidden_states)
460
-
461
- # linear proj
462
- hidden_states = attn.to_out[0](hidden_states)
463
- # dropout
464
- hidden_states = attn.to_out[1](hidden_states)
465
-
466
- if input_ndim == 4:
467
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
468
-
469
- if attn.residual_connection:
470
- hidden_states = hidden_states + residual
471
-
472
- hidden_states = hidden_states / attn.rescale_output_factor
473
-
474
- return hidden_states
475
-
476
-
477
- class CNAttnProcessor2_0:
478
- r"""
479
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
480
- """
481
-
482
- def __init__(self, num_tokens=4):
483
- if not hasattr(F, "scaled_dot_product_attention"):
484
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
485
- self.num_tokens = num_tokens
486
-
487
- def __call__(
488
- self,
489
- attn,
490
- hidden_states,
491
- encoder_hidden_states=None,
492
- attention_mask=None,
493
- temb=None,
494
- ):
495
- residual = hidden_states
496
-
497
- if attn.spatial_norm is not None:
498
- hidden_states = attn.spatial_norm(hidden_states, temb)
499
-
500
- input_ndim = hidden_states.ndim
501
-
502
- if input_ndim == 4:
503
- batch_size, channel, height, width = hidden_states.shape
504
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
505
-
506
- batch_size, sequence_length, _ = (
507
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
508
- )
509
-
510
- if attention_mask is not None:
511
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
512
- # scaled_dot_product_attention expects attention_mask shape to be
513
- # (batch, heads, source_length, target_length)
514
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
515
-
516
- if attn.group_norm is not None:
517
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
518
-
519
- query = attn.to_q(hidden_states)
520
-
521
- if encoder_hidden_states is None:
522
- encoder_hidden_states = hidden_states
523
- else:
524
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
525
- encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
526
- if attn.norm_cross:
527
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
528
-
529
- key = attn.to_k(encoder_hidden_states)
530
- value = attn.to_v(encoder_hidden_states)
531
-
532
- inner_dim = key.shape[-1]
533
- head_dim = inner_dim // attn.heads
534
-
535
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
536
-
537
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
538
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539
-
540
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
541
- # TODO: add support for attn.scale when we move to Torch 2.1
542
- hidden_states = F.scaled_dot_product_attention(
543
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
544
- )
545
-
546
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
547
- hidden_states = hidden_states.to(query.dtype)
548
-
549
- # linear proj
550
- hidden_states = attn.to_out[0](hidden_states)
551
- # dropout
552
- hidden_states = attn.to_out[1](hidden_states)
553
-
554
- if input_ndim == 4:
555
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
556
-
557
- if attn.residual_connection:
558
- hidden_states = hidden_states + residual
559
-
560
- hidden_states = hidden_states / attn.rescale_output_factor
561
-
562
- return hidden_states
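
Every processor deleted above reduces to the same step: attend over the text tokens and the appended image-prompt tokens separately, then add the image branch scaled by scale. A toy, self-contained illustration of that mixing step; shapes are arbitrary and the projections are omitted, so this summarises the mechanism rather than reproducing the module:

# Toy illustration of the core IPAttnProcessor step: separate attention over
# text tokens and image-prompt tokens, mixed as text_out + scale * ip_out.
import torch
import torch.nn.functional as F

batch, dim, num_tokens = 2, 64, 4
query = torch.randn(batch, 16, dim)                    # toy latent queries
encoder_hidden_states = torch.randn(batch, 77 + num_tokens, dim)

text_ctx = encoder_hidden_states[:, :-num_tokens, :]   # text tokens
ip_ctx = encoder_hidden_states[:, -num_tokens:, :]     # image-prompt tokens

text_out = F.scaled_dot_product_attention(query, text_ctx, text_ctx)
ip_out = F.scaled_dot_product_attention(query, ip_ctx, ip_ctx)

scale = 1.0                                            # what set_ip_adapter_scale controls
hidden_states = text_out + scale * ip_out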
ip_adapter/ip_adapter.py DELETED
@@ -1,461 +0,0 @@
1
- import os
2
- from typing import List
3
-
4
- import torch
5
- from diffusers import StableDiffusionPipeline
6
- from diffusers.pipelines.controlnet import MultiControlNetModel
7
- from PIL import Image
8
- from safetensors import safe_open
9
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
-
11
- from .utils import is_torch2_available, get_generator
12
-
13
- if is_torch2_available():
14
- from .attention_processor import (
15
- AttnProcessor2_0 as AttnProcessor,
16
- )
17
- from .attention_processor import (
18
- CNAttnProcessor2_0 as CNAttnProcessor,
19
- )
20
- from .attention_processor import (
21
- IPAttnProcessor2_0 as IPAttnProcessor,
22
- )
23
- else:
24
- from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25
- from .resampler import Resampler
26
-
27
-
28
- class ImageProjModel(torch.nn.Module):
29
- """Projection Model"""
30
-
31
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
32
- super().__init__()
33
-
34
- self.generator = None
35
- self.cross_attention_dim = cross_attention_dim
36
- self.clip_extra_context_tokens = clip_extra_context_tokens
37
- self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
38
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
39
-
40
- def forward(self, image_embeds):
41
- embeds = image_embeds
42
- clip_extra_context_tokens = self.proj(embeds).reshape(
43
- -1, self.clip_extra_context_tokens, self.cross_attention_dim
44
- )
45
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
46
- return clip_extra_context_tokens
47
-
48
-
49
- class MLPProjModel(torch.nn.Module):
50
- """SD model with image prompt"""
51
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
52
- super().__init__()
53
-
54
- self.proj = torch.nn.Sequential(
55
- torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
56
- torch.nn.GELU(),
57
- torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
58
- torch.nn.LayerNorm(cross_attention_dim)
59
- )
60
-
61
- def forward(self, image_embeds):
62
- clip_extra_context_tokens = self.proj(image_embeds)
63
- return clip_extra_context_tokens
64
-
65
-
66
- class IPAdapter:
67
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
68
- self.device = device
69
- self.image_encoder_path = image_encoder_path
70
- self.ip_ckpt = ip_ckpt
71
- self.num_tokens = num_tokens
72
- self.target_blocks = target_blocks
73
-
74
- self.pipe = sd_pipe.to(self.device)
75
- self.set_ip_adapter()
76
-
77
- # load image encoder
78
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
79
- self.device, dtype=torch.float16
80
- )
81
- self.clip_image_processor = CLIPImageProcessor()
82
- # image proj model
83
- self.image_proj_model = self.init_proj()
84
-
85
- self.load_ip_adapter()
86
-
87
-
88
- def init_proj(self):
89
- image_proj_model = ImageProjModel(
90
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
91
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
92
- clip_extra_context_tokens=self.num_tokens,
93
- ).to(self.device, dtype=torch.float16)
94
- return image_proj_model
95
-
96
- def set_ip_adapter(self):
97
- unet = self.pipe.unet
98
- attn_procs = {}
99
- for name in unet.attn_processors.keys():
100
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
101
- if name.startswith("mid_block"):
102
- hidden_size = unet.config.block_out_channels[-1]
103
- elif name.startswith("up_blocks"):
104
- block_id = int(name[len("up_blocks.")])
105
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
106
- elif name.startswith("down_blocks"):
107
- block_id = int(name[len("down_blocks.")])
108
- hidden_size = unet.config.block_out_channels[block_id]
109
- if cross_attention_dim is None:
110
- attn_procs[name] = AttnProcessor()
111
- else:
112
- selected = False
113
- for block_name in self.target_blocks:
114
- if block_name in name:
115
- selected = True
116
- break
117
- if selected:
118
- attn_procs[name] = IPAttnProcessor(
119
- hidden_size=hidden_size,
120
- cross_attention_dim=cross_attention_dim,
121
- scale=1.0,
122
- num_tokens=self.num_tokens,
123
- ).to(self.device, dtype=torch.float16)
124
- else:
125
- attn_procs[name] = IPAttnProcessor(
126
- hidden_size=hidden_size,
127
- cross_attention_dim=cross_attention_dim,
128
- scale=1.0,
129
- num_tokens=self.num_tokens,
130
- skip=True
131
- ).to(self.device, dtype=torch.float16)
132
- unet.set_attn_processor(attn_procs)
133
- if hasattr(self.pipe, "controlnet"):
134
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
135
- for controlnet in self.pipe.controlnet.nets:
136
- controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
137
- else:
138
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
139
-
140
- def load_ip_adapter(self):
141
- if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
142
- state_dict = {"image_proj": {}, "ip_adapter": {}}
143
- with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
144
- for key in f.keys():
145
- if key.startswith("image_proj."):
146
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
147
- elif key.startswith("ip_adapter."):
148
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
149
- else:
150
- state_dict = torch.load(self.ip_ckpt, map_location="cpu")
151
- self.image_proj_model.load_state_dict(state_dict["image_proj"])
152
- ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
153
- ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
154
-
155
- @torch.inference_mode()
156
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
157
- if pil_image is not None:
158
- if isinstance(pil_image, Image.Image):
159
- pil_image = [pil_image]
160
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
161
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
162
- else:
163
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
164
-
165
- if content_prompt_embeds is not None:
166
- clip_image_embeds = clip_image_embeds - content_prompt_embeds
167
-
168
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
169
- uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
170
- return image_prompt_embeds, uncond_image_prompt_embeds
171
-
172
- def set_scale(self, scale):
173
- for attn_processor in self.pipe.unet.attn_processors.values():
174
- if isinstance(attn_processor, IPAttnProcessor):
175
- attn_processor.scale = scale
176
-
177
- def generate(
178
- self,
179
- pil_image=None,
180
- clip_image_embeds=None,
181
- prompt=None,
182
- negative_prompt=None,
183
- scale=1.0,
184
- num_samples=4,
185
- seed=None,
186
- guidance_scale=7.5,
187
- num_inference_steps=30,
188
- neg_content_emb=None,
189
- **kwargs,
190
- ):
191
- self.set_scale(scale)
192
-
193
- if pil_image is not None:
194
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
195
- else:
196
- num_prompts = clip_image_embeds.size(0)
197
-
198
- if prompt is None:
199
- prompt = "best quality, high quality"
200
- if negative_prompt is None:
201
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
202
-
203
- if not isinstance(prompt, List):
204
- prompt = [prompt] * num_prompts
205
- if not isinstance(negative_prompt, List):
206
- negative_prompt = [negative_prompt] * num_prompts
207
-
208
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
209
- pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
210
- )
211
- bs_embed, seq_len, _ = image_prompt_embeds.shape
212
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
213
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
214
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
215
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
216
-
217
- with torch.inference_mode():
218
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
219
- prompt,
220
- device=self.device,
221
- num_images_per_prompt=num_samples,
222
- do_classifier_free_guidance=True,
223
- negative_prompt=negative_prompt,
224
- )
225
- prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
226
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
227
-
228
- generator = get_generator(seed, self.device)
229
-
230
- images = self.pipe(
231
- prompt_embeds=prompt_embeds,
232
- negative_prompt_embeds=negative_prompt_embeds,
233
- guidance_scale=guidance_scale,
234
- num_inference_steps=num_inference_steps,
235
- generator=generator,
236
- **kwargs,
237
- ).images
238
-
239
- return images
240
-
241
-
242
- class IPAdapterXL(IPAdapter):
243
- """SDXL"""
244
-
245
- def generate(
246
- self,
247
- pil_image,
248
- prompt=None,
249
- negative_prompt=None,
250
- scale=1.0,
251
- num_samples=4,
252
- seed=None,
253
- num_inference_steps=30,
254
- neg_content_emb=None,
255
- neg_content_prompt=None,
256
- neg_content_scale=1.0,
257
- **kwargs,
258
- ):
259
- self.set_scale(scale)
260
-
261
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
262
-
263
- if prompt is None:
264
- prompt = "best quality, high quality"
265
- if negative_prompt is None:
266
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
267
-
268
- if not isinstance(prompt, List):
269
- prompt = [prompt] * num_prompts
270
- if not isinstance(negative_prompt, List):
271
- negative_prompt = [negative_prompt] * num_prompts
272
-
273
- if neg_content_emb is None:
274
- if neg_content_prompt is not None:
275
- with torch.inference_mode():
276
- (
277
- prompt_embeds_, # torch.Size([1, 77, 2048])
278
- negative_prompt_embeds_,
279
- pooled_prompt_embeds_, # torch.Size([1, 1280])
280
- negative_pooled_prompt_embeds_,
281
- ) = self.pipe.encode_prompt(
282
- neg_content_prompt,
283
- num_images_per_prompt=num_samples,
284
- do_classifier_free_guidance=True,
285
- negative_prompt=negative_prompt,
286
- )
287
- pooled_prompt_embeds_ *= neg_content_scale
288
- else:
289
- pooled_prompt_embeds_ = neg_content_emb
290
- else:
291
- pooled_prompt_embeds_ = None
292
-
293
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
294
- bs_embed, seq_len, _ = image_prompt_embeds.shape
295
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
296
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
297
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
298
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
299
-
300
- with torch.inference_mode():
301
- (
302
- prompt_embeds,
303
- negative_prompt_embeds,
304
- pooled_prompt_embeds,
305
- negative_pooled_prompt_embeds,
306
- ) = self.pipe.encode_prompt(
307
- prompt,
308
- num_images_per_prompt=num_samples,
309
- do_classifier_free_guidance=True,
310
- negative_prompt=negative_prompt,
311
- )
312
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
313
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
314
-
315
- self.generator = get_generator(seed, self.device)
316
-
317
- images = self.pipe(
318
- prompt_embeds=prompt_embeds,
319
- negative_prompt_embeds=negative_prompt_embeds,
320
- pooled_prompt_embeds=pooled_prompt_embeds,
321
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
322
- num_inference_steps=num_inference_steps,
323
- generator=self.generator,
324
- **kwargs,
325
- ).images
326
-
327
- return images
328
-
329
-
330
- class IPAdapterPlus(IPAdapter):
331
- """IP-Adapter with fine-grained features"""
332
-
333
- def init_proj(self):
334
- image_proj_model = Resampler(
335
- dim=self.pipe.unet.config.cross_attention_dim,
336
- depth=4,
337
- dim_head=64,
338
- heads=12,
339
- num_queries=self.num_tokens,
340
- embedding_dim=self.image_encoder.config.hidden_size,
341
- output_dim=self.pipe.unet.config.cross_attention_dim,
342
- ff_mult=4,
343
- ).to(self.device, dtype=torch.float16)
344
- return image_proj_model
345
-
346
- @torch.inference_mode()
347
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
348
- if isinstance(pil_image, Image.Image):
349
- pil_image = [pil_image]
350
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
351
- clip_image = clip_image.to(self.device, dtype=torch.float16)
352
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
353
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
354
- uncond_clip_image_embeds = self.image_encoder(
355
- torch.zeros_like(clip_image), output_hidden_states=True
356
- ).hidden_states[-2]
357
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
358
- return image_prompt_embeds, uncond_image_prompt_embeds
359
-
360
-
361
- class IPAdapterFull(IPAdapterPlus):
362
- """IP-Adapter with full features"""
363
-
364
- def init_proj(self):
365
- image_proj_model = MLPProjModel(
366
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
367
- clip_embeddings_dim=self.image_encoder.config.hidden_size,
368
- ).to(self.device, dtype=torch.float16)
369
- return image_proj_model
370
-
371
-
372
- class IPAdapterPlusXL(IPAdapter):
373
- """SDXL"""
374
-
375
- def init_proj(self):
376
- image_proj_model = Resampler(
377
- dim=1280,
378
- depth=4,
379
- dim_head=64,
380
- heads=20,
381
- num_queries=self.num_tokens,
382
- embedding_dim=self.image_encoder.config.hidden_size,
383
- output_dim=self.pipe.unet.config.cross_attention_dim,
384
- ff_mult=4,
385
- ).to(self.device, dtype=torch.float16)
386
- return image_proj_model
387
-
388
- @torch.inference_mode()
389
- def get_image_embeds(self, pil_image):
390
- if isinstance(pil_image, Image.Image):
391
- pil_image = [pil_image]
392
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
393
- clip_image = clip_image.to(self.device, dtype=torch.float16)
394
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
395
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
396
- uncond_clip_image_embeds = self.image_encoder(
397
- torch.zeros_like(clip_image), output_hidden_states=True
398
- ).hidden_states[-2]
399
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
400
- return image_prompt_embeds, uncond_image_prompt_embeds
401
-
402
- def generate(
403
- self,
404
- pil_image,
405
- prompt=None,
406
- negative_prompt=None,
407
- scale=1.0,
408
- num_samples=4,
409
- seed=None,
410
- num_inference_steps=30,
411
- **kwargs,
412
- ):
413
- self.set_scale(scale)
414
-
415
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
416
-
417
- if prompt is None:
418
- prompt = "best quality, high quality"
419
- if negative_prompt is None:
420
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
421
-
422
- if not isinstance(prompt, List):
423
- prompt = [prompt] * num_prompts
424
- if not isinstance(negative_prompt, List):
425
- negative_prompt = [negative_prompt] * num_prompts
426
-
427
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
428
- bs_embed, seq_len, _ = image_prompt_embeds.shape
429
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
430
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
431
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
432
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
433
-
434
- with torch.inference_mode():
435
- (
436
- prompt_embeds,
437
- negative_prompt_embeds,
438
- pooled_prompt_embeds,
439
- negative_pooled_prompt_embeds,
440
- ) = self.pipe.encode_prompt(
441
- prompt,
442
- num_images_per_prompt=num_samples,
443
- do_classifier_free_guidance=True,
444
- negative_prompt=negative_prompt,
445
- )
446
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
447
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
448
-
449
- generator = get_generator(seed, self.device)
450
-
451
- images = self.pipe(
452
- prompt_embeds=prompt_embeds,
453
- negative_prompt_embeds=negative_prompt_embeds,
454
- pooled_prompt_embeds=pooled_prompt_embeds,
455
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
456
- num_inference_steps=num_inference_steps,
457
- generator=generator,
458
- **kwargs,
459
- ).images
460
-
461
- return images
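
The deleted IPAdapterXL.generate() flow (encode the style image, concatenate image-prompt embeddings, call the SDXL pipeline) is what the diffusers-native path in app.py replaces. A rough equivalent using the built-in API, assuming the standard ip_adapter_image argument available once load_ip_adapter has been called; the prompt defaults come from the deleted class, while the URL and seed are placeholders:

# Sketch: diffusers-native equivalent of the deleted IPAdapterXL.generate().
# Assumes `pipeline` is the AutoPipelineForText2Image from app.py with
# load_ip_adapter() already applied; ip_adapter_image is the documented argument.
import torch
from diffusers.utils import load_image

style_image = load_image("https://example.com/style.jpg")  # placeholder URL
image = pipeline(
    prompt="best quality, high quality",                   # default prompt from the deleted class
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    ip_adapter_image=style_image,
    num_inference_steps=30,
    generator=torch.Generator(device="cpu").manual_seed(42),
).images[0]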
ip_adapter/resampler.py DELETED
@@ -1,158 +0,0 @@
1
- # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
- # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
-
4
- import math
5
-
6
- import torch
7
- import torch.nn as nn
8
- from einops import rearrange
9
- from einops.layers.torch import Rearrange
10
-
11
-
12
- # FFN
13
- def FeedForward(dim, mult=4):
14
- inner_dim = int(dim * mult)
15
- return nn.Sequential(
16
- nn.LayerNorm(dim),
17
- nn.Linear(dim, inner_dim, bias=False),
18
- nn.GELU(),
19
- nn.Linear(inner_dim, dim, bias=False),
20
- )
21
-
22
-
23
- def reshape_tensor(x, heads):
24
- bs, length, width = x.shape
25
- # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
- x = x.view(bs, length, heads, -1)
27
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
- x = x.transpose(1, 2)
29
- # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
- x = x.reshape(bs, heads, length, -1)
31
- return x
32
-
33
-
34
- class PerceiverAttention(nn.Module):
35
- def __init__(self, *, dim, dim_head=64, heads=8):
36
- super().__init__()
37
- self.scale = dim_head**-0.5
38
- self.dim_head = dim_head
39
- self.heads = heads
40
- inner_dim = dim_head * heads
41
-
42
- self.norm1 = nn.LayerNorm(dim)
43
- self.norm2 = nn.LayerNorm(dim)
44
-
45
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
-
49
- def forward(self, x, latents):
50
- """
51
- Args:
52
- x (torch.Tensor): image features
53
- shape (b, n1, D)
54
- latents (torch.Tensor): latent features
55
- shape (b, n2, D)
56
- """
57
- x = self.norm1(x)
58
- latents = self.norm2(latents)
59
-
60
- b, l, _ = latents.shape
61
-
62
- q = self.to_q(latents)
63
- kv_input = torch.cat((x, latents), dim=-2)
64
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
-
66
- q = reshape_tensor(q, self.heads)
67
- k = reshape_tensor(k, self.heads)
68
- v = reshape_tensor(v, self.heads)
69
-
70
- # attention
71
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
- out = weight @ v
75
-
76
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
-
78
- return self.to_out(out)
79
-
80
-
81
- class Resampler(nn.Module):
82
- def __init__(
83
- self,
84
- dim=1024,
85
- depth=8,
86
- dim_head=64,
87
- heads=16,
88
- num_queries=8,
89
- embedding_dim=768,
90
- output_dim=1024,
91
- ff_mult=4,
92
- max_seq_len: int = 257, # CLIP tokens + CLS token
93
- apply_pos_emb: bool = False,
94
- num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
- ):
96
- super().__init__()
97
- self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
-
99
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
-
101
- self.proj_in = nn.Linear(embedding_dim, dim)
102
-
103
- self.proj_out = nn.Linear(dim, output_dim)
104
- self.norm_out = nn.LayerNorm(output_dim)
105
-
106
- self.to_latents_from_mean_pooled_seq = (
107
- nn.Sequential(
108
- nn.LayerNorm(dim),
109
- nn.Linear(dim, dim * num_latents_mean_pooled),
110
- Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
- )
112
- if num_latents_mean_pooled > 0
113
- else None
114
- )
115
-
116
- self.layers = nn.ModuleList([])
117
- for _ in range(depth):
118
- self.layers.append(
119
- nn.ModuleList(
120
- [
121
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
- FeedForward(dim=dim, mult=ff_mult),
123
- ]
124
- )
125
- )
126
-
127
- def forward(self, x):
128
- if self.pos_emb is not None:
129
- n, device = x.shape[1], x.device
130
- pos_emb = self.pos_emb(torch.arange(n, device=device))
131
- x = x + pos_emb
132
-
133
- latents = self.latents.repeat(x.size(0), 1, 1)
134
-
135
- x = self.proj_in(x)
136
-
137
- if self.to_latents_from_mean_pooled_seq:
138
- meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
- meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
- latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
-
142
- for attn, ff in self.layers:
143
- latents = attn(x, latents) + latents
144
- latents = ff(latents) + latents
145
-
146
- latents = self.proj_out(latents)
147
- return self.norm_out(latents)
148
-
149
-
150
- def masked_mean(t, *, dim, mask=None):
151
- if mask is None:
152
- return t.mean(dim=dim)
153
-
154
- denom = mask.sum(dim=dim, keepdim=True)
155
- mask = rearrange(mask, "b n -> b n 1")
156
- masked_t = t.masked_fill(~mask, 0.0)
157
-
158
- return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
ip_adapter/utils.py DELETED
@@ -1,93 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from PIL import Image
5
-
6
- attn_maps = {}
7
- def hook_fn(name):
8
- def forward_hook(module, input, output):
9
- if hasattr(module.processor, "attn_map"):
10
- attn_maps[name] = module.processor.attn_map
11
- del module.processor.attn_map
12
-
13
- return forward_hook
14
-
15
- def register_cross_attention_hook(unet):
16
- for name, module in unet.named_modules():
17
- if name.split('.')[-1].startswith('attn2'):
18
- module.register_forward_hook(hook_fn(name))
19
-
20
- return unet
21
-
22
- def upscale(attn_map, target_size):
23
- attn_map = torch.mean(attn_map, dim=0)
24
- attn_map = attn_map.permute(1,0)
25
- temp_size = None
26
-
27
- for i in range(0,5):
28
- scale = 2 ** i
29
- if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
30
- temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31
- break
32
-
33
- assert temp_size is not None, "temp_size cannot be None"
34
-
35
- attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36
-
37
- attn_map = F.interpolate(
38
- attn_map.unsqueeze(0).to(dtype=torch.float32),
39
- size=target_size,
40
- mode='bilinear',
41
- align_corners=False
42
- )[0]
43
-
44
- attn_map = torch.softmax(attn_map, dim=0)
45
- return attn_map
46
- def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47
-
48
- idx = 0 if instance_or_negative else 1
49
- net_attn_maps = []
50
-
51
- for name, attn_map in attn_maps.items():
52
- attn_map = attn_map.cpu() if detach else attn_map
53
- attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54
- attn_map = upscale(attn_map, image_size)
55
- net_attn_maps.append(attn_map)
56
-
57
- net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58
-
59
- return net_attn_maps
60
-
61
- def attnmaps2images(net_attn_maps):
62
-
63
- #total_attn_scores = 0
64
- images = []
65
-
66
- for attn_map in net_attn_maps:
67
- attn_map = attn_map.cpu().numpy()
68
- #total_attn_scores += attn_map.mean().item()
69
-
70
- normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71
- normalized_attn_map = normalized_attn_map.astype(np.uint8)
72
- #print("norm: ", normalized_attn_map.shape)
73
- image = Image.fromarray(normalized_attn_map)
74
-
75
- #image = fix_save_attn_map(attn_map)
76
- images.append(image)
77
-
78
- #print(total_attn_scores)
79
- return images
80
- def is_torch2_available():
81
- return hasattr(F, "scaled_dot_product_attention")
82
-
83
- def get_generator(seed, device):
84
-
85
- if seed is not None:
86
- if isinstance(seed, list):
87
- generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
88
- else:
89
- generator = torch.Generator(device).manual_seed(seed)
90
- else:
91
- generator = None
92
-
93
- return generator
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch>=2.0.0
 transformers>=4.37.1
 spaces>=0.19.4
 huggingface-hub>=0.20.2
-gradio
+gradio
+accelerate
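
Adding accelerate is most plausibly for faster, lower-memory model loading in diffusers: with it installed, from_pretrained can load the SDXL weights without materialising them fully in CPU memory first. A small sketch mirroring the pipeline setup already in app.py; low_cpu_mem_usage=True is an assumption about the intended benefit, not something stated in the commit:

# Sketch: the loading path that benefits from `accelerate` (assumption).
# Mirrors the pipeline construction in app.py.
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,  # requires accelerate
).to("cuda")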