primerz committed (verified)
Commit a0e6fd7 · 1 Parent(s): 0a1370c

Update ip_adapter/attention_processor.py

Files changed (1):
  1. ip_adapter/attention_processor.py +57 -354
ip_adapter/attention_processor.py CHANGED
@@ -1,4 +1,3 @@
-# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,33 +6,31 @@ try:
     import xformers
     import xformers.ops
     xformers_available = True
-except Exception as e:
+except Exception:
     xformers_available = False

-class RegionControler(object):
+
+# Region Controller (unchanged)
+class RegionControler:
     def __init__(self) -> None:
         self.prompt_image_conditioning = []
+
 region_control = RegionControler()

-class AttnProcessor(nn.Module):
-    r"""
-    Default processor for performing attention-related computations.
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
+
+# Helper function for weight initialization
+def init_weights(m):
+    if isinstance(m, nn.Linear):
+        nn.init.xavier_uniform_(m.weight)
+
+
+# Base Attention Processor
+class BaseAttnProcessor(nn.Module):
+    def __init__(self):
         super().__init__()

-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
+    def _process_input(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb):
+        """Handles preprocessing for both AttnProcessor and IPAttnProcessor"""
         residual = hidden_states

         if attn.spatial_norm is not None:
@@ -44,286 +41,60 @@ class AttnProcessor(nn.Module):
         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
+        else:
+            batch_size, sequence_length, _ = hidden_states.shape

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class IPAttnProcessor(nn.Module):
-    r"""
-    Attention processor for IP-Adapater.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            the weight scale of image prompt.
-        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
-            The context length of the image features.
-    """
-
-    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-        self.num_tokens = num_tokens
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
+        return hidden_states, encoder_hidden_states, residual, batch_size, input_ndim, height, width

+    def _apply_attention(self, attn, query, key, value, attention_mask):
+        """Handles the actual attention operation using either xformers or standard PyTorch"""
         if xformers_available:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+            return xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
         else:
             attention_probs = attn.get_attention_scores(query, key, attention_mask)
-            hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = attn.head_to_batch_dim(ip_key)
-        ip_value = attn.head_to_batch_dim(ip_value)
-
-        if xformers_available:
-            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
-        else:
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
-        # region control
-        if len(region_control.prompt_image_conditioning) == 1:
-            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
-            if region_mask is not None:
-                h, w = region_mask.shape[:2]
-                ratio = (h * w / query.shape[1]) ** 0.5
-                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
-            else:
-                mask = torch.ones_like(ip_hidden_states)
-            ip_hidden_states = ip_hidden_states * mask
-
-        hidden_states = hidden_states + self.scale * ip_hidden_states
+            return torch.bmm(attention_probs, value)

-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states

-
-    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
-        # TODO attention_mask
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-        return hidden_states
-
-
-class AttnProcessor2_0(torch.nn.Module):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+# Optimized AttnProcessor
+class AttnProcessor(BaseAttnProcessor):
+    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+        hidden_states, encoder_hidden_states, residual, batch_size, input_ndim, height, width = \
+            self._process_input(attn, hidden_states, encoder_hidden_states, attention_mask, temb)

         query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        query, key, value = map(attn.head_to_batch_dim, (query, key, value))
+        hidden_states = self._apply_attention(attn, query, key, value, attention_mask)
+        hidden_states = attn.batch_to_head_dim(hidden_states)

-        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
        hidden_states = attn.to_out[1](hidden_states)

         if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)

         if attn.residual_connection:
             hidden_states = hidden_states + residual

-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
+        return hidden_states / attn.rescale_output_factor

-class IPAttnProcessor2_0(torch.nn.Module):
-    r"""
-    Attention processor for IP-Adapater for PyTorch 2.0.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            the weight scale of image prompt.
-        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
-            The context length of the image features.
-    """

+# Optimized IPAttnProcessor
+class IPAttnProcessor(BaseAttnProcessor):
     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
         super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.scale = scale
@@ -331,117 +102,49 @@ class IPAttnProcessor2_0(torch.nn.Module):

         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.apply(init_weights)

-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+        hidden_states, encoder_hidden_states, residual, batch_size, input_ndim, height, width = \
+            self._process_input(attn, hidden_states, encoder_hidden_states, attention_mask, temb)

-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+        encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]

         query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = (
-                encoder_hidden_states[:, :end_pos, :],
-                encoder_hidden_states[:, end_pos:, :],
-            )
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        query, key, value = map(attn.head_to_batch_dim, (query, key, value))
+        hidden_states = self._apply_attention(attn, query, key, value, attention_mask)
+        hidden_states = attn.batch_to_head_dim(hidden_states)

-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        ip_hidden_states = F.scaled_dot_product_attention(
-            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-        )
-        with torch.no_grad():
-            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
-            #print(self.attn_map.shape)
+        # Image Prompt Attention
+        ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
+        ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))

-        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        ip_hidden_states = ip_hidden_states.to(query.dtype)
+        ip_hidden_states = self._apply_attention(attn, query, ip_key, ip_value, None)
+        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

-        # region control
+        # Region Control
         if len(region_control.prompt_image_conditioning) == 1:
-            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
+            region_mask = region_control.prompt_image_conditioning[0].get("region_mask", None)
             if region_mask is not None:
-                query = query.reshape([-1, query.shape[-2], query.shape[-1]])
-                h, w = region_mask.shape[:2]
-                ratio = (h * w / query.shape[1]) ** 0.5
-                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
+                mask = F.interpolate(region_mask[None, None], scale_factor=(ip_hidden_states.shape[1] / region_mask.shape[0]), mode="nearest").reshape([1, -1, 1])
             else:
                 mask = torch.ones_like(ip_hidden_states)
-            ip_hidden_states = ip_hidden_states * mask
+            ip_hidden_states *= mask

         hidden_states = hidden_states + self.scale * ip_hidden_states

-        # linear proj
+        # Linear projection and dropout
         hidden_states = attn.to_out[0](hidden_states)
-        # dropout
         hidden_states = attn.to_out[1](hidden_states)

         if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)

         if attn.residual_connection:
             hidden_states = hidden_states + residual

-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
+        return hidden_states / attn.rescale_output_factor
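For context, here is a minimal sketch of how the AttnProcessor and IPAttnProcessor classes from this file are commonly attached to a diffusers UNet. This wiring is not part of this file or this commit: the checkpoint name, the num_tokens value, and the omitted loading of the to_k_ip / to_v_ip weights are illustrative assumptions.

from diffusers import UNet2DConditionModel

from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, region_control

# Illustrative checkpoint; any UNet2DConditionModel is wired the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers ("attn1") never see prompt tokens; cross-attention layers ("attn2") do.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    else:
        # num_tokens=4 matches the plain IP-Adapter; the "plus" variants use 16.
        attn_procs[name] = IPAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=1.0,
            num_tokens=4,
        )

unet.set_attn_processor(attn_procs)
# Loading the to_k_ip / to_v_ip weights from an IP-Adapter checkpoint is omitted here.

# Optional region control: the cross-attention processors read a single entry whose
# 2-D "region_mask" tensor (H, W) gates where the image-prompt tokens are applied.
# region_control.prompt_image_conditioning.append({"region_mask": face_mask})

Only the cross-attention ("attn2") layers receive the IPAttnProcessor, because they are the ones that see the concatenated text and image-prompt tokens and therefore need the extra to_k_ip / to_v_ip projections; self-attention layers keep the plain AttnProcessor.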