ford442 committed on
Commit 29c63b7 · verified · 1 Parent(s): 2c472d5

Delete ip_adapter

ip_adapter/__init__.py DELETED
@@ -1 +0,0 @@
-from .ip_adapter import IPAdapter, IPAdapterXL, IPAdapterPlus
 
 
ip_adapter/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (234 Bytes)
 
ip_adapter/__pycache__/attention_processor.cpython-310.pyc DELETED
Binary file (9.71 kB)
 
ip_adapter/__pycache__/ip_adapter.cpython-310.pyc DELETED
Binary file (8.17 kB)
 
ip_adapter/__pycache__/resampler.cpython-310.pyc DELETED
Binary file (3.17 kB)
 
ip_adapter/__pycache__/utils.cpython-310.pyc DELETED
Binary file (362 Bytes)
 
ip_adapter/attention_processor.py DELETED
@@ -1,553 +0,0 @@
-# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AttnProcessor(nn.Module):
-    r"""
-    Default processor for performing attention-related computations.
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class IPAttnProcessor(nn.Module):
-    r"""
-    Attention processor for IP-Adapter.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            The weight scale of the image prompt.
-        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
-            The context length of the image features.
-    """
-
-    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-        self.num_tokens = num_tokens
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = attn.head_to_batch_dim(ip_key)
-        ip_value = attn.head_to_batch_dim(ip_value)
-
-        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
-        hidden_states = hidden_states + self.scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class AttnProcessor2_0(torch.nn.Module):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class IPAttnProcessor2_0(torch.nn.Module):
-    r"""
-    Attention processor for IP-Adapter for PyTorch 2.0.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            The weight scale of the image prompt.
-        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
-            The context length of the image features.
-    """
-
-    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-        self.num_tokens = num_tokens
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        ip_hidden_states = F.scaled_dot_product_attention(
-            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-        )
-
-        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-        hidden_states = hidden_states + self.scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-## for controlnet
-class CNAttnProcessor:
-    r"""
-    Default processor for performing attention-related computations.
-    """
-
-    def __init__(self, num_tokens=4):
-        self.num_tokens = num_tokens
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class CNAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self, num_tokens=4):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-        self.num_tokens = num_tokens
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
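Note on how the deleted processors were driven: IPAdapter.generate() (in ip_adapter.py below) concatenates the projected image-prompt tokens after the text tokens along the sequence dimension, and IPAttnProcessor / IPAttnProcessor2_0 recover them with end_pos = seq_len - num_tokens. A minimal sketch of that contract, with illustrative (assumed) shapes:

import torch

# Assumed shapes for illustration only: 77 text tokens + 4 image tokens, width 768.
num_tokens = 4
text_embeds = torch.randn(2, 77, 768)
image_embeds = torch.randn(2, num_tokens, 768)

# generate() passes the concatenation as encoder_hidden_states ...
encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)  # (2, 81, 768)

# ... and the IP processors split it back apart.
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_part = encoder_hidden_states[:, :end_pos, :]  # goes through attn.to_k / attn.to_v
ip_part = encoder_hidden_states[:, end_pos:, :]    # goes through self.to_k_ip / self.to_v_ip
assert text_part.shape == (2, 77, 768) and ip_part.shape == (2, num_tokens, 768)
# The two attention outputs are then blended as hidden_states + self.scale * ip_hidden_states.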
 
ip_adapter/ip_adapter.py DELETED
@@ -1,273 +0,0 @@
-import os
-from typing import List
-
-import torch
-from diffusers import StableDiffusionPipeline
-from diffusers.pipelines.controlnet import MultiControlNetModel
-from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
-from PIL import Image
-
-from .utils import is_torch2_available
-if is_torch2_available():
-    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, CNAttnProcessor2_0 as CNAttnProcessor
-else:
-    from .attention_processor import IPAttnProcessor, AttnProcessor, CNAttnProcessor
-from .resampler import Resampler
-
-
-class ImageProjModel(torch.nn.Module):
-    """Projection Model"""
-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
-        super().__init__()
-
-        self.cross_attention_dim = cross_attention_dim
-        self.clip_extra_context_tokens = clip_extra_context_tokens
-        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
-        self.norm = torch.nn.LayerNorm(cross_attention_dim)
-
-    def forward(self, image_embeds):
-        embeds = image_embeds
-        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
-        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
-        return clip_extra_context_tokens
-
-
-class IPAdapter:
-
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
-
-        self.device = device
-        self.image_encoder_path = image_encoder_path
-        self.ip_ckpt = ip_ckpt
-        self.num_tokens = num_tokens
-
-        self.pipe = sd_pipe.to(self.device)
-        self.set_ip_adapter()
-
-        # load image encoder
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.bfloat16)
-        self.clip_image_processor = CLIPImageProcessor()
-        # image proj model
-        self.image_proj_model = self.init_proj()
-        self.load_ip_adapter()
-
-    def init_proj(self):
-        image_proj_model = ImageProjModel(
-            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
-            clip_embeddings_dim=self.image_encoder.config.projection_dim,
-            clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.bfloat16)
-        return image_proj_model
-
-    def set_ip_adapter(self):
-        unet = self.pipe.unet
-        attn_procs = {}
-        for name in unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = unet.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = unet.config.block_out_channels[block_id]
-            if cross_attention_dim is None:
-                attn_procs[name] = AttnProcessor()
-            else:
-                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
-                                                   scale=1.0, num_tokens=self.num_tokens).to(self.device, dtype=torch.bfloat16)
-        unet.set_attn_processor(attn_procs)
-        if hasattr(self.pipe, "controlnet"):
-            if isinstance(self.pipe.controlnet, MultiControlNetModel):
-                for controlnet in self.pipe.controlnet.nets:
-                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
-            else:
-                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
-
-    def update_state_dict(self, state_dict):
-        image_proj_dict = {}
-        ip_adapter_dict = {}
-
-        for k in state_dict.keys():
-            if k.startswith("image_proj_model"):
-                image_proj_dict[k.replace("image_proj_model.", "")] = state_dict[k]
-            if k.startswith("adapter_modules"):
-                ip_adapter_dict[k.replace("adapter_modules.", "")] = state_dict[k]
-
-        dict = {'image_proj': image_proj_dict,
-                'ip_adapter': ip_adapter_dict
-                }
-        return dict
-
-    def load_ip_adapter(self):
-        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
-        if "image_proj_model.proj.weight" in state_dict.keys():
-            state_dict = self.update_state_dict(state_dict)
-        self.image_proj_model.load_state_dict(state_dict["image_proj"])
-        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
-        ip_layers.load_state_dict(state_dict["ip_adapter"])
-
-    @torch.inference_mode()
-    def get_image_embeds(self, pil_image):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
-        return image_prompt_embeds, uncond_image_prompt_embeds
-
-    def set_scale(self, scale):
-        for attn_processor in self.pipe.unet.attn_processors.values():
-            if isinstance(attn_processor, IPAttnProcessor):
-                attn_processor.scale = scale
-
-    def generate(
-        self,
-        pil_image,
-        prompt=None,
-        negative_prompt=None,
-        scale=1.0,
-        num_samples=4,
-        seed=-1,
-        guidance_scale=7.5,
-        num_inference_steps=30,
-        **kwargs,
-    ):
-        self.set_scale(scale)
-
-        if isinstance(pil_image, List):
-            num_prompts = len(pil_image)
-        else:
-            num_prompts = 1
-
-        # if isinstance(pil_image, Image.Image):
-        #     num_prompts = 1
-        # else:
-        #     num_prompts = len(pil_image)
-        # print("num promp", num_prompts)
-
-        if prompt is None:
-            prompt = "best quality, high quality"
-        if negative_prompt is None:
-            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-        if not isinstance(prompt, List):
-            prompt = [prompt] * num_prompts
-        if not isinstance(negative_prompt, List):
-            negative_prompt = [negative_prompt] * num_prompts
-
-        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-        bs_embed, seq_len, _ = image_prompt_embeds.shape
-        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-        with torch.inference_mode():
-            prompt_embeds = self.pipe._encode_prompt(
-                prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
-            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
-
-            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
-            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
-        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-        images = self.pipe(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            generator=generator,
-            **kwargs,
-        ).images
-
-        return images
-
-
-class IPAdapterXL(IPAdapter):
-    """SDXL"""
-
-    def generate(
-        self,
-        pil_image,
-        prompt=None,
-        negative_prompt=None,
-        scale=1.0,
-        num_samples=4,
-        seed=-1,
-        num_inference_steps=30,
-        **kwargs,
-    ):
-        self.set_scale(scale)
-
-        if isinstance(pil_image, Image.Image):
-            num_prompts = 1
-        else:
-            num_prompts = len(pil_image)
-
-        if prompt is None:
-            prompt = "best quality, high quality"
-        if negative_prompt is None:
-            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-        if not isinstance(prompt, List):
-            prompt = [prompt] * num_prompts
-        if not isinstance(negative_prompt, List):
-            negative_prompt = [negative_prompt] * num_prompts
-
-        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-        bs_embed, seq_len, _ = image_prompt_embeds.shape
-        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-        with torch.inference_mode():
-            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
-                prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
-            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
-            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-
-        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-        images = self.pipe(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            num_inference_steps=num_inference_steps,
-            generator=generator,
-            **kwargs,
-        ).images
-
-        return images
-
-
-class IPAdapterPlus(IPAdapter):
-    """IP-Adapter with fine-grained features"""
-
-    def init_proj(self):
-        image_proj_model = Resampler(
-            dim=self.pipe.unet.config.cross_attention_dim,
-            depth=4,
-            dim_head=64,
-            heads=12,
-            num_queries=self.num_tokens,
-            embedding_dim=self.image_encoder.config.hidden_size,
-            output_dim=self.pipe.unet.config.cross_attention_dim,
-            ff_mult=4
-        ).to(self.device, dtype=torch.bfloat16)
-        return image_proj_model
-
-    @torch.inference_mode()
-    def get_image_embeds(self, pil_image):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
-        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
-        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
-        return image_prompt_embeds, uncond_image_prompt_embeds
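For context, the IPAdapter class deleted here was typically driven along these lines; the model id, image-encoder path and checkpoint name below are placeholders for illustration, not files from this repo:

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from ip_adapter import IPAdapter  # the package removed by this commit

# Placeholder paths/ids for illustration only.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.bfloat16)
ip_model = IPAdapter(pipe, image_encoder_path="path/to/image_encoder", ip_ckpt="path/to/ip-adapter_sd15.bin", device="cuda")

reference = Image.open("reference.png")
images = ip_model.generate(pil_image=reference, prompt="best quality, high quality", scale=0.7, num_samples=2, seed=42)

Any downstream code still importing from ip_adapter (see __init__.py above) will break after this deletion.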
 
ip_adapter/resampler.py DELETED
@@ -1,121 +0,0 @@
-# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-import math
-
-import torch
-import torch.nn as nn
-
-
-# FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
-    return nn.Sequential(
-        nn.LayerNorm(dim),
-        nn.Linear(dim, inner_dim, bias=False),
-        nn.GELU(),
-        nn.Linear(inner_dim, dim, bias=False),
-    )
-
-
-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
-    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
-    x = x.view(bs, length, heads, -1)
-    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
-    x = x.transpose(1, 2)
-    # make contiguous, keeping shape (bs, n_heads, length, dim_per_head)
-    x = x.reshape(bs, heads, length, -1)
-    return x
-
-
-class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.dim_head = dim_head
-        self.heads = heads
-        inner_dim = dim_head * heads
-
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-
-        self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
-        self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-    def forward(self, x, latents):
-        """
-        Args:
-            x (torch.Tensor): image features
-                shape (b, n1, D)
-            latents (torch.Tensor): latent features
-                shape (b, n2, D)
-        """
-        x = self.norm1(x)
-        latents = self.norm2(latents)
-
-        b, l, _ = latents.shape
-
-        q = self.to_q(latents)
-        kv_input = torch.cat((x, latents), dim=-2)
-        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-
-        q = reshape_tensor(q, self.heads)
-        k = reshape_tensor(k, self.heads)
-        v = reshape_tensor(v, self.heads)
-
-        # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        out = weight @ v
-
-        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
-
-        return self.to_out(out)
-
-
-class Resampler(nn.Module):
-    def __init__(
-        self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
-    ):
-        super().__init__()
-
-        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
-
-        self.proj_in = nn.Linear(embedding_dim, dim)
-
-        self.proj_out = nn.Linear(dim, output_dim)
-        self.norm_out = nn.LayerNorm(output_dim)
-
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        FeedForward(dim=dim, mult=ff_mult),
-                    ]
-                )
-            )
-
-    def forward(self, x):
-
-        latents = self.latents.repeat(x.size(0), 1, 1)
-
-        x = self.proj_in(x)
-
-        for attn, ff in self.layers:
-            latents = attn(x, latents) + latents
-            latents = ff(latents) + latents
-
-        latents = self.proj_out(latents)
-        return self.norm_out(latents)
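As a quick shape check for the deleted Resampler: it maps a variable-length sequence of CLIP hidden states to a fixed number of query tokens in the UNet's cross-attention width. The dimensions below are assumptions matching how IPAdapterPlus.init_proj configures it for an SD-1.5-style UNet, not values taken from this repo:

import torch
from ip_adapter.resampler import Resampler  # removed by this commit

# Assumed: 257 CLIP ViT-H patch tokens of width 1280, resampled into 16 query
# tokens of width 768 (the UNet cross_attention_dim).
resampler = Resampler(dim=768, depth=4, dim_head=64, heads=12,
                      num_queries=16, embedding_dim=1280, output_dim=768, ff_mult=4)
clip_hidden_states = torch.randn(1, 257, 1280)  # e.g. hidden_states[-2] of the vision tower
tokens = resampler(clip_hidden_states)
print(tokens.shape)  # torch.Size([1, 16, 768])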
 
ip_adapter/utils.py DELETED
@@ -1,5 +0,0 @@
-import torch.nn.functional as F
-
-
-def is_torch2_available():
-    return hasattr(F, "scaled_dot_product_attention")