dbaranchuk committed on
Commit
8f3a280
·
verified ·
1 Parent(s): 3ad0e52

Delete p2p.py

Browse files
Files changed (1) hide show
  1. p2p.py +0 -454
p2p.py DELETED
@@ -1,454 +0,0 @@
1
- import torch.nn.functional as nnf
2
- import torch
3
- import abc
4
- import numpy as np
5
- import seq_aligner
6
-
7
- from typing import Optional, Union, Tuple, List, Callable, Dict
8
-
9
# Size of the token axis used for every per-word mask built in this module.
MAX_NUM_WORDS = 77
# When True, AttentionControl hands the whole attention tensor to `forward`
# instead of only its second half.
LOW_RESOURCE = False
# Number of sampling steps; step fractions are multiplied by this value.
NUM_DDIM_STEPS = 50
# Prefer the GPU, but fall back to the CPU so the module stays usable on
# machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Tokenizer shared by the controllers below.
# NOTE(review): presumably assigned by the importing code before any
# controller is built - confirm against the caller.
tokenizer = None
14
-
15
-
16
- # Different attention controllers
17
- # ----------------------------------------------------------------------
18
class LocalBlend:
    """Blend latents so that edits stay inside word-specific attention regions.

    A mask is derived from the stored cross-attention maps of the selected
    words. Inside the mask each sample keeps its own latent; outside the mask
    it takes the values of the first sample in the batch.
    """

    def __init__(self, prompts: List[str], words: List[Union[str, List[str]]], substruct_words=None,
                 start_blend=0.2, th=(.3, .3)):
        """Build the per-prompt word selectors.

        Args:
            prompts: One prompt per batch element.
            words: For each prompt, a word or a list of words whose attention
                defines the blended region.
            substruct_words: Optional words, in the same layout as ``words``,
                whose attention region is removed from the mask.
            start_blend: Fraction of ``NUM_DDIM_STEPS`` that must be exceeded
                by the call counter before blending starts.
            th: Two thresholds; index 0 is used for pooled masks and index 1
                for un-pooled masks.
        """
        self.alpha_layers = self._word_mask(prompts, words)
        if substruct_words is not None:
            self.substruct_layers = self._word_mask(prompts, substruct_words)
        else:
            self.substruct_layers = None
        self.start_blend = int(start_blend * NUM_DDIM_STEPS)
        self.counter = 0
        self.th = th

    @staticmethod
    def _word_mask(prompts, words):
        """Return a (len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 0/1 selector on ``device``."""
        layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            # A bare string means a single word for this prompt.
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                ind = get_word_inds(prompt, word, tokenizer)
                layers[i, :, :, :, :, ind] = 1
        return layers.to(device)

    def get_mask(self, maps, alpha, use_pool, x_t):
        """Turn attention maps into a boolean mask with the spatial size of ``x_t``.

        Args:
            maps: Attention maps whose last axis is the token axis.
            alpha: Word selector multiplied with ``maps`` along the token axis.
            use_pool: If True, dilate the map with a 3x3 max-pool and use
                ``th[0]``; otherwise use ``th[1]``.
            x_t: Tensor whose dimensions from index 2 onwards give the output size.

        Returns:
            Boolean mask; every entry is OR-ed with the first sample's mask.
        """
        k = 1
        # Keep only the selected words, then average over axis 1.
        maps = (maps * alpha).sum(-1).mean(1)
        if use_pool:
            maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
        # Normalise each map by its own spatial maximum before thresholding.
        mask = mask / mask.max(2, keepdim=True)[0].max(3, keepdim=True)[0]
        mask = mask.gt(self.th[1 - int(use_pool)])
        # Adding boolean tensors acts as a logical OR with the first sample's mask.
        mask = mask[:1] + mask
        return mask

    def __call__(self, x_t, attention_store):
        """Blend ``x_t`` using the accumulated attention in ``attention_store``.

        Counts calls and leaves ``x_t`` untouched until the counter exceeds
        ``start_blend``.
        """
        self.counter += 1
        if self.counter > self.start_blend:
            maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
            # The selected maps are reshaped to a fixed 16x16 spatial grid.
            maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
            maps = torch.cat(maps, dim=1)
            mask = self.get_mask(maps, self.alpha_layers, True, x_t)
            if self.substruct_layers is not None:
                # Remove the region attended by the subtracted words.
                maps_sub = ~self.get_mask(maps, self.substruct_layers, False, x_t)
                mask = mask * maps_sub
            mask = mask.float()
            x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t
71
-
72
-
73
class EmptyControl:
    """Controller that leaves latents and attention maps unchanged."""

    def step_callback(self, x_t):
        """Return the latent ``x_t`` as received."""
        return x_t

    def between_steps(self):
        """Do nothing between sampling steps."""
        return None

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Return ``attn`` as received; the other arguments are ignored."""
        return attn
83
-
84
-
85
class AttentionControl(abc.ABC):
    """Base class that counts attention calls and dispatches them to ``forward``.

    ``num_att_layers`` starts at -1 and is expected to be overwritten with the
    number of hooked attention layers before the controller is called.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Restart the step and layer counters."""
        self.cur_step = 0
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Hook applied to the latent; the base version returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook run once all layers of a step were seen; no-op here."""
        return

    @property
    def num_uncond_att_layers(self):
        """Number of leading calls per step that bypass ``forward``."""
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention tensor; must be provided by subclasses."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Run ``forward`` on ``attn`` (or on its second half) and advance the counters."""
        past_uncond = self.cur_att_layer >= self.num_uncond_att_layers
        if past_uncond and LOW_RESOURCE:
            attn = self.forward(attn, is_cross, place_in_unet)
        elif past_uncond:
            # Only the second half of the leading axis is edited, in place.
            half = attn.shape[0] // 2
            attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            # Every layer of this step was visited: move on to the next step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
123
-
124
-
125
class SpatialReplace(EmptyControl):
    """Overwrite every latent with the first sample's latent for the early steps."""

    def __init__(self, stop_inject: float):
        """Store the injection window.

        Args:
            stop_inject: Fraction used to compute the window; injection covers
                the first ``int((1 - stop_inject) * NUM_DDIM_STEPS)`` callbacks.
        """
        super(SpatialReplace, self).__init__()
        # EmptyControl keeps no step counter, so it has to be tracked here;
        # without it `step_callback` fails with AttributeError.
        self.cur_step = 0
        self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)

    def step_callback(self, x_t):
        """Return ``x_t`` with all samples replaced by the first one while injecting.

        NOTE(review): assumes this callback runs exactly once per sampling
        step - confirm against the sampling loop.
        """
        if self.cur_step < self.stop_inject:
            b = x_t.shape[0]
            x_t = x_t[:1].expand(b, *x_t.shape[1:])
        self.cur_step += 1
        return x_t
136
-
137
-
138
class AttentionStore(AttentionControl):
    """Controller that records attention maps and sums them over steps."""

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per location/kind key."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def reset(self):
        """Reset the counters and drop everything recorded so far."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step and return it unchanged."""
        kind = 'cross' if is_cross else 'self'
        key = f"{place_in_unet}_{kind}"
        # Larger maps are skipped to avoid memory overhead.
        if attn.shape[1] <= 32 ** 2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Fold the maps of the finished step into the running sums."""
        if not self.attention_store:
            # First finished step: its store becomes the accumulator.
            self.attention_store = self.step_store
        else:
            for key, accumulated in self.attention_store.items():
                for i in range(len(accumulated)):
                    accumulated[i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of finished steps."""
        average_attention = {}
        for key, items in self.attention_store.items():
            average_attention[key] = [item / self.cur_step for item in items]
        return average_attention
174
-
175
-
176
class AttentionControlEdit(AttentionStore, abc.ABC):
    """Base class for controllers that inject the first prompt's attention into the others."""

    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend]):
        """Set up the replacement schedules.

        Args:
            prompts: Prompts of the batch; the first one is the base prompt.
            num_steps: Number of sampling steps.
            cross_replace_steps: Schedule forwarded to
                ``get_time_words_attention_alpha``.
            self_replace_steps: Either an end fraction (start is then 0) or a
                ``(start, end)`` pair of fractions of ``num_steps``.
            local_blend: Optional blender applied to the latent in ``step_callback``.
        """
        super(AttentionControlEdit, self).__init__()
        self.batch_size = len(prompts)
        self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
                                                                  tokenizer).to(device)
        # Accept any plain number (e.g. the int 1), not only floats, as the end fraction.
        if isinstance(self_replace_steps, (int, float)):
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend

    def step_callback(self, x_t):
        """Apply the local blend, if one was given, to the latent."""
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t

    def replace_self_attention(self, attn_base, att_replace, place_in_unet):
        """Return the base self-attention expanded over the edited prompts.

        Maps whose axis 2 exceeds ``32 ** 2`` are returned unchanged.
        """
        if att_replace.shape[2] <= 32 ** 2:
            attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
            return attn_base
        else:
            return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        """Compute the edited cross-attention; must be provided by subclasses."""
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Store ``attn``, then overwrite the edited prompts' attention in place."""
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            # Split the leading axis into (prompt, remaining) so prompts can be addressed.
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                # Per-token weights for the current step.
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_replace_new = (self.replace_cross_attention(attn_base, attn_replace) * alpha_words
                                    + (1 - alpha_words) * attn_replace)
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace, place_in_unet)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn
222
-
223
-
224
class AttentionReplace(AttentionControlEdit):
    """Edit controller that maps base cross-attention through a replacement mapper."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        """Initialise the shared edit state and fetch the mapper from ``seq_aligner``."""
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        replacement_mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer)
        self.mapper = replacement_mapper.to(device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Contract the base attention's token axis with the mapper, once per edited prompt."""
        mapped = torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
        return mapped
233
-
234
-
235
class AttentionRefine(AttentionControlEdit):
    """Edit controller that mixes re-indexed base attention with the edited attention."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None):
        """Initialise the shared edit state and fetch mapper and weights from ``seq_aligner``."""
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
        self.mapper = mapper.to(device)
        alphas = alphas.to(device)
        # Insert two singleton axes so the weights broadcast over the attention maps.
        n_rows, n_cols = alphas.shape[0], alphas.shape[1]
        self.alphas = alphas.reshape(n_rows, 1, 1, n_cols)

    def replace_cross_attention(self, attn_base, att_replace):
        """Blend base attention gathered via the mapper with ``att_replace`` using ``alphas``."""
        gathered = attn_base[:, :, self.mapper]
        attn_base_replace = gathered.permute(2, 0, 1, 3)
        keep = self.alphas
        return attn_base_replace * keep + att_replace * (1 - keep)
249
-
250
-
251
class AttentionReweight(AttentionControlEdit):
    """Edit controller that rescales base cross-attention per token."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
                 local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
        """Initialise the shared edit state.

        Args:
            equalizer: Per-token scale factors; moved to ``device``.
            controller: Optional controller whose ``replace_cross_attention``
                is applied before the rescaling.
        """
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
        self.equalizer = equalizer.to(device)
        self.prev_controller = controller
        self.attn = []

    def replace_cross_attention(self, attn_base, att_replace):
        """Scale the (optionally pre-edited) base attention along the token axis."""
        previous = self.prev_controller
        if previous is not None:
            attn_base = previous.replace_cross_attention(attn_base, att_replace)
        # Broadcast the factors over the middle axes and the attention over the first axis.
        scale = self.equalizer[:, None, None, :]
        return attn_base[None, :, :, :] * scale
267
- # ----------------------------------------------------------------------
268
-
269
-
270
- # Attention controller during sampling
271
- # ----------------------------------------------------------------------
272
def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float],
                    self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit:
    """Assemble the edit controller for a set of prompts.

    Args:
        prompts: Prompts of the batch.
        is_replace_controller: Pick ``AttentionReplace`` when True,
            ``AttentionRefine`` otherwise.
        cross_replace_steps: Cross-attention schedule passed to the controller.
        self_replace_steps: Self-attention schedule passed to the controller.
        blend_words: If given, words used to build a ``LocalBlend``.
        equilizer_params: If given, a mapping with ``"words"`` and ``"values"``;
            the controller is then wrapped in an ``AttentionReweight``.

    Returns:
        The configured controller.
    """
    lb = None if blend_words is None else LocalBlend(prompts, blend_words, start_blend=0.0, th=(0.3, 0.3))

    controller_cls = AttentionReplace if is_replace_controller else AttentionRefine
    controller = controller_cls(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
                                self_replace_steps=self_replace_steps, local_blend=lb)
    if equilizer_params is None:
        return controller

    # The equalizer is built from the second prompt.
    eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"])
    return AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
                             self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb,
                             controller=controller)
290
-
291
def register_attention_control(model, controller):
    """Patch every ``Attention`` module under ``model.unet`` to pass through ``controller``.

    Children of ``model.unet`` whose name contains "down", "up" or "mid" are
    searched recursively; each module whose class is named ``Attention`` gets
    a replacement ``forward``. The number of patched modules is written to
    ``controller.num_att_layers``. A ``None`` controller is replaced by a
    pass-through object.
    """

    class DummyController:
        """Pass-through controller used when none is supplied."""

        def __init__(self):
            self.num_att_layers = 0

        def __call__(self, *args):
            return args[0]

    if controller is None:
        controller = DummyController()

    def ca_forward(self, place_in_unet):
        """Build the replacement forward for attention module ``self``."""
        to_out = self.to_out
        # For a ModuleList only its first entry is used as the output projection.
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ):
            is_cross = encoder_hidden_states is not None
            residual = hidden_states

            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                # Flatten the two spatial axes into one sequence axis.
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            shape_source = encoder_hidden_states if is_cross else hidden_states
            batch_size, sequence_length, _ = shape_source.shape
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = self.to_q(hidden_states)

            if not is_cross:
                # Self-attention: keys and values come from the hidden states.
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

            key = self.to_k(encoder_hidden_states)
            value = self.to_v(encoder_hidden_states)

            query = self.head_to_batch_dim(query)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            attention_probs = self.get_attention_scores(query, key, attention_mask)
            # Let the controller record and/or edit the attention probabilities.
            attention_probs = controller(attention_probs, is_cross, place_in_unet)

            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = self.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = to_out(hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if self.residual_connection:
                hidden_states = hidden_states + residual

            return hidden_states / self.rescale_output_factor

        return forward

    def register_recr(net_, count, place_in_unet):
        """Patch ``net_`` or its descendants; return ``count`` plus the number patched."""
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, sub_net in model.unet.named_children():
        # The first matching label wins, checked in the order down, up, mid.
        for place in ("down", "up", "mid"):
            if place in name:
                cross_att_count += register_recr(sub_net, 0, place)
                break

    controller.num_att_layers = cross_att_count
387
- # ----------------------------------------------------------------------
388
-
389
- # Other
390
- # ----------------------------------------------------------------------
391
def get_equalizer(text: str, word_select: Union[int, str, Tuple[Union[int, str], ...]],
                  values: Union[List[float], Tuple[float, ...]]):
    """Build a per-token scaling vector for ``text``.

    Args:
        text: Prompt whose tokens are rescaled.
        word_select: A word, a word position, or a tuple of those.
        values: One scale factor per selected word, paired with
            ``word_select`` in order.

    Returns:
        Tensor of shape ``(1, MAX_NUM_WORDS)`` holding 1 everywhere except at
        the token positions of the selected words.
    """
    # A single selector is wrapped so it can be zipped with ``values``.
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)
    equalizer = torch.ones(1, MAX_NUM_WORDS)

    for word, val in zip(word_select, values):
        inds = get_word_inds(text, word, tokenizer)
        equalizer[:, inds] = val
    return equalizer
401
-
402
def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    """Build the per-step, per-token cross-attention replacement weights.

    Args:
        prompts: Prompts of the batch; the first one is the base prompt.
        num_steps: Number of sampling steps.
        cross_replace_steps: Either bounds applied to every token, or a dict
            mapping words to bounds, with the optional key ``"default_"``
            covering all other tokens (``(0., 1.)`` when absent).
        tokenizer: Tokenizer forwarded to ``get_word_inds``.
        max_num_words: Size of the token axis.

    Returns:
        Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``.
    """
    if isinstance(cross_replace_steps, dict):
        # Work on a copy so the caller's dict is not modified.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    cross_replace_steps.setdefault("default_", (0., 1.))

    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    # Apply the default bounds to every edited prompt first.
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
                                                  i)
    # Then override the tokens of explicitly listed words.
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
421
-
422
def get_word_inds(text: str, word_place: Union[int, str, List[int]], tokenizer):
    """Return the token positions that belong to the selected word(s) of ``text``.

    Args:
        text: Prompt; words are obtained by splitting on single spaces.
        word_place: A word (every occurrence is selected), a word position, or
            a collection of word positions.
        tokenizer: Object providing ``encode(text)`` and ``decode([token])``.
            The first and last encoded tokens are dropped.

    Returns:
        Integer array of token positions, offset by one for the dropped first
        token. It is empty, but still integer-typed, when nothing matches, so
        it can always be used as an index.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif isinstance(word_place, int):
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i, piece in enumerate(words_encode):
            cur_len += len(piece)
            if ptr in word_place:
                out.append(i + 1)
            # Advance to the next word once the pieces cover the current word.
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    # An explicit dtype keeps an empty result integer-typed instead of float64.
    return np.array(out, dtype=np.int64)
441
-
442
-
443
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
    """Set the step window in which the selected tokens of one prompt are active.

    Args:
        alpha: Tensor indexed as ``[step, prompt, token]``; modified in place.
        bounds: Either an end fraction (start is then 0) or a ``(start, end)``
            pair, both as fractions of ``alpha.shape[0]``.
        prompt_ind: Index along the prompt axis.
        word_inds: Token indices to update; all tokens when omitted.

    Returns:
        The same ``alpha`` tensor, with 1 inside ``[start, end)`` and 0 outside
        for the selected entries.
    """
    # Accept any plain number (e.g. the int 1), not only floats, as the end fraction.
    if isinstance(bounds, (int, float)):
        bounds = 0, bounds
    n_steps = alpha.shape[0]
    start, end = int(bounds[0] * n_steps), int(bounds[1] * n_steps)
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
454
- # ----------------------------------------------------------------------