Spaces:
Runtime error
Runtime error
Delete p2p.py
Browse files
p2p.py
DELETED
@@ -1,454 +0,0 @@
|
|
1 |
-
import torch.nn.functional as nnf
|
2 |
-
import torch
|
3 |
-
import abc
|
4 |
-
import numpy as np
|
5 |
-
import seq_aligner
|
6 |
-
|
7 |
-
from typing import Optional, Union, Tuple, List, Callable, Dict
|
8 |
-
|
9 |
-
MAX_NUM_WORDS = 77
|
10 |
-
LOW_RESOURCE = False
|
11 |
-
NUM_DDIM_STEPS = 50
|
12 |
-
device = 'cuda'
|
13 |
-
tokenizer = None
|
14 |
-
|
15 |
-
|
16 |
-
# Different attention controllers
|
17 |
-
# ----------------------------------------------------------------------
|
18 |
-
class LocalBlend:
|
19 |
-
|
20 |
-
def get_mask(self, maps, alpha, use_pool, x_t):
|
21 |
-
k = 1
|
22 |
-
maps = (maps * alpha).sum(-1).mean(1)
|
23 |
-
if use_pool:
|
24 |
-
maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
|
25 |
-
mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
|
26 |
-
mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
|
27 |
-
mask = mask.gt(self.th[1 - int(use_pool)])
|
28 |
-
mask = mask[:1] + mask
|
29 |
-
return mask
|
30 |
-
|
31 |
-
def __call__(self, x_t, attention_store):
|
32 |
-
self.counter += 1
|
33 |
-
if self.counter > self.start_blend:
|
34 |
-
|
35 |
-
maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
|
36 |
-
maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
|
37 |
-
maps = torch.cat(maps, dim=1)
|
38 |
-
mask = self.get_mask(maps, self.alpha_layers, True, x_t)
|
39 |
-
if self.substruct_layers is not None:
|
40 |
-
maps_sub = ~self.get_mask(maps, self.substruct_layers, False, x_t)
|
41 |
-
mask = mask * maps_sub
|
42 |
-
mask = mask.float()
|
43 |
-
x_t = x_t[:1] + mask * (x_t - x_t[:1])
|
44 |
-
return x_t
|
45 |
-
|
46 |
-
def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, start_blend=0.2,
|
47 |
-
th=(.3, .3)):
|
48 |
-
alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
|
49 |
-
for i, (prompt, words_) in enumerate(zip(prompts, words)):
|
50 |
-
if type(words_) is str:
|
51 |
-
words_ = [words_]
|
52 |
-
for word in words_:
|
53 |
-
ind = get_word_inds(prompt, word, tokenizer)
|
54 |
-
alpha_layers[i, :, :, :, :, ind] = 1
|
55 |
-
|
56 |
-
if substruct_words is not None:
|
57 |
-
substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
|
58 |
-
for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)):
|
59 |
-
if type(words_) is str:
|
60 |
-
words_ = [words_]
|
61 |
-
for word in words_:
|
62 |
-
ind = get_word_inds(prompt, word, tokenizer)
|
63 |
-
substruct_layers[i, :, :, :, :, ind] = 1
|
64 |
-
self.substruct_layers = substruct_layers.to(device)
|
65 |
-
else:
|
66 |
-
self.substruct_layers = None
|
67 |
-
self.alpha_layers = alpha_layers.to(device)
|
68 |
-
self.start_blend = int(start_blend * NUM_DDIM_STEPS)
|
69 |
-
self.counter = 0
|
70 |
-
self.th = th
|
71 |
-
|
72 |
-
|
73 |
-
class EmptyControl:
|
74 |
-
|
75 |
-
def step_callback(self, x_t):
|
76 |
-
return x_t
|
77 |
-
|
78 |
-
def between_steps(self):
|
79 |
-
return
|
80 |
-
|
81 |
-
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
82 |
-
return attn
|
83 |
-
|
84 |
-
|
85 |
-
class AttentionControl(abc.ABC):
|
86 |
-
|
87 |
-
def step_callback(self, x_t):
|
88 |
-
return x_t
|
89 |
-
|
90 |
-
def between_steps(self):
|
91 |
-
return
|
92 |
-
|
93 |
-
@property
|
94 |
-
def num_uncond_att_layers(self):
|
95 |
-
return self.num_att_layers if LOW_RESOURCE else 0
|
96 |
-
|
97 |
-
@abc.abstractmethod
|
98 |
-
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
99 |
-
raise NotImplementedError
|
100 |
-
|
101 |
-
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
102 |
-
if self.cur_att_layer >= self.num_uncond_att_layers:
|
103 |
-
if LOW_RESOURCE:
|
104 |
-
attn = self.forward(attn, is_cross, place_in_unet)
|
105 |
-
else:
|
106 |
-
h = attn.shape[0]
|
107 |
-
attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
|
108 |
-
self.cur_att_layer += 1
|
109 |
-
if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
|
110 |
-
self.cur_att_layer = 0
|
111 |
-
self.cur_step += 1
|
112 |
-
self.between_steps()
|
113 |
-
return attn
|
114 |
-
|
115 |
-
def reset(self):
|
116 |
-
self.cur_step = 0
|
117 |
-
self.cur_att_layer = 0
|
118 |
-
|
119 |
-
def __init__(self):
|
120 |
-
self.cur_step = 0
|
121 |
-
self.num_att_layers = -1
|
122 |
-
self.cur_att_layer = 0
|
123 |
-
|
124 |
-
|
125 |
-
class SpatialReplace(EmptyControl):
|
126 |
-
|
127 |
-
def step_callback(self, x_t):
|
128 |
-
if self.cur_step < self.stop_inject:
|
129 |
-
b = x_t.shape[0]
|
130 |
-
x_t = x_t[:1].expand(b, *x_t.shape[1:])
|
131 |
-
return x_t
|
132 |
-
|
133 |
-
def __init__(self, stop_inject: float):
|
134 |
-
super(SpatialReplace, self).__init__()
|
135 |
-
self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)
|
136 |
-
|
137 |
-
|
138 |
-
class AttentionStore(AttentionControl):
|
139 |
-
|
140 |
-
@staticmethod
|
141 |
-
def get_empty_store():
|
142 |
-
return {"down_cross": [], "mid_cross": [], "up_cross": [],
|
143 |
-
"down_self": [], "mid_self": [], "up_self": []}
|
144 |
-
|
145 |
-
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
146 |
-
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
147 |
-
if attn.shape[1] <= 32 ** 2: # avoid memory overhead
|
148 |
-
self.step_store[key].append(attn)
|
149 |
-
return attn
|
150 |
-
|
151 |
-
def between_steps(self):
|
152 |
-
if len(self.attention_store) == 0:
|
153 |
-
self.attention_store = self.step_store
|
154 |
-
else:
|
155 |
-
for key in self.attention_store:
|
156 |
-
for i in range(len(self.attention_store[key])):
|
157 |
-
self.attention_store[key][i] += self.step_store[key][i]
|
158 |
-
self.step_store = self.get_empty_store()
|
159 |
-
|
160 |
-
def get_average_attention(self):
|
161 |
-
average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
|
162 |
-
self.attention_store}
|
163 |
-
return average_attention
|
164 |
-
|
165 |
-
def reset(self):
|
166 |
-
super(AttentionStore, self).reset()
|
167 |
-
self.step_store = self.get_empty_store()
|
168 |
-
self.attention_store = {}
|
169 |
-
|
170 |
-
def __init__(self):
|
171 |
-
super(AttentionStore, self).__init__()
|
172 |
-
self.step_store = self.get_empty_store()
|
173 |
-
self.attention_store = {}
|
174 |
-
|
175 |
-
|
176 |
-
class AttentionControlEdit(AttentionStore, abc.ABC):
|
177 |
-
|
178 |
-
def step_callback(self, x_t):
|
179 |
-
if self.local_blend is not None:
|
180 |
-
x_t = self.local_blend(x_t, self.attention_store)
|
181 |
-
return x_t
|
182 |
-
|
183 |
-
def replace_self_attention(self, attn_base, att_replace, place_in_unet):
|
184 |
-
if att_replace.shape[2] <= 32 ** 2:
|
185 |
-
attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
|
186 |
-
return attn_base
|
187 |
-
else:
|
188 |
-
return att_replace
|
189 |
-
|
190 |
-
@abc.abstractmethod
|
191 |
-
def replace_cross_attention(self, attn_base, att_replace):
|
192 |
-
raise NotImplementedError
|
193 |
-
|
194 |
-
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
195 |
-
super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
|
196 |
-
if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
|
197 |
-
h = attn.shape[0] // (self.batch_size)
|
198 |
-
attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
|
199 |
-
attn_base, attn_repalce = attn[0], attn[1:]
|
200 |
-
if is_cross:
|
201 |
-
alpha_words = self.cross_replace_alpha[self.cur_step]
|
202 |
-
attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (
|
203 |
-
1 - alpha_words) * attn_repalce
|
204 |
-
attn[1:] = attn_repalce_new
|
205 |
-
else:
|
206 |
-
attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet)
|
207 |
-
attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
|
208 |
-
return attn
|
209 |
-
|
210 |
-
def __init__(self, prompts, num_steps: int,
|
211 |
-
cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
|
212 |
-
self_replace_steps: Union[float, Tuple[float, float]],
|
213 |
-
local_blend: Optional[LocalBlend]):
|
214 |
-
super(AttentionControlEdit, self).__init__()
|
215 |
-
self.batch_size = len(prompts)
|
216 |
-
self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
|
217 |
-
tokenizer).to(device)
|
218 |
-
if type(self_replace_steps) is float:
|
219 |
-
self_replace_steps = 0, self_replace_steps
|
220 |
-
self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
|
221 |
-
self.local_blend = local_blend
|
222 |
-
|
223 |
-
|
224 |
-
class AttentionReplace(AttentionControlEdit):
|
225 |
-
|
226 |
-
def replace_cross_attention(self, attn_base, att_replace):
|
227 |
-
return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
|
228 |
-
|
229 |
-
def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
|
230 |
-
local_blend: Optional[LocalBlend] = None):
|
231 |
-
super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
|
232 |
-
self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
|
233 |
-
|
234 |
-
|
235 |
-
class AttentionRefine(AttentionControlEdit):
|
236 |
-
|
237 |
-
def replace_cross_attention(self, attn_base, att_replace):
|
238 |
-
attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
|
239 |
-
attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
|
240 |
-
# attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
|
241 |
-
return attn_replace
|
242 |
-
|
243 |
-
def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
|
244 |
-
local_blend: Optional[LocalBlend] = None):
|
245 |
-
super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
|
246 |
-
self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
|
247 |
-
self.mapper, alphas = self.mapper.to(device), alphas.to(device)
|
248 |
-
self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
|
249 |
-
|
250 |
-
|
251 |
-
class AttentionReweight(AttentionControlEdit):
|
252 |
-
|
253 |
-
def replace_cross_attention(self, attn_base, att_replace):
|
254 |
-
if self.prev_controller is not None:
|
255 |
-
attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
|
256 |
-
attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
|
257 |
-
# attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
|
258 |
-
return attn_replace
|
259 |
-
|
260 |
-
def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
|
261 |
-
local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
|
262 |
-
super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
|
263 |
-
local_blend)
|
264 |
-
self.equalizer = equalizer.to(device)
|
265 |
-
self.prev_controller = controller
|
266 |
-
self.attn = []
|
267 |
-
# ----------------------------------------------------------------------
|
268 |
-
|
269 |
-
|
270 |
-
# Attention controller during sampling
|
271 |
-
# ----------------------------------------------------------------------
|
272 |
-
def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float],
|
273 |
-
self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit:
|
274 |
-
if blend_words is None:
|
275 |
-
lb = None
|
276 |
-
else:
|
277 |
-
lb = LocalBlend(prompts, blend_words, start_blend=0.0, th=(0.3, 0.3))
|
278 |
-
if is_replace_controller:
|
279 |
-
controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
|
280 |
-
self_replace_steps=self_replace_steps, local_blend=lb)
|
281 |
-
else:
|
282 |
-
controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
|
283 |
-
self_replace_steps=self_replace_steps, local_blend=lb)
|
284 |
-
if equilizer_params is not None:
|
285 |
-
eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"])
|
286 |
-
controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
|
287 |
-
self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb,
|
288 |
-
controller=controller)
|
289 |
-
return controller
|
290 |
-
|
291 |
-
def register_attention_control(model, controller):
|
292 |
-
def ca_forward(self, place_in_unet):
|
293 |
-
to_out = self.to_out
|
294 |
-
if type(to_out) is torch.nn.modules.container.ModuleList:
|
295 |
-
to_out = self.to_out[0]
|
296 |
-
else:
|
297 |
-
to_out = self.to_out
|
298 |
-
|
299 |
-
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ):
|
300 |
-
is_cross = encoder_hidden_states is not None
|
301 |
-
|
302 |
-
residual = hidden_states
|
303 |
-
|
304 |
-
if self.spatial_norm is not None:
|
305 |
-
hidden_states = self.spatial_norm(hidden_states, temb)
|
306 |
-
|
307 |
-
input_ndim = hidden_states.ndim
|
308 |
-
|
309 |
-
if input_ndim == 4:
|
310 |
-
batch_size, channel, height, width = hidden_states.shape
|
311 |
-
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
312 |
-
|
313 |
-
batch_size, sequence_length, _ = (
|
314 |
-
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
315 |
-
)
|
316 |
-
attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
317 |
-
|
318 |
-
if self.group_norm is not None:
|
319 |
-
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
320 |
-
|
321 |
-
query = self.to_q(hidden_states)
|
322 |
-
|
323 |
-
if encoder_hidden_states is None:
|
324 |
-
encoder_hidden_states = hidden_states
|
325 |
-
elif self.norm_cross:
|
326 |
-
encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
|
327 |
-
|
328 |
-
key = self.to_k(encoder_hidden_states)
|
329 |
-
value = self.to_v(encoder_hidden_states)
|
330 |
-
|
331 |
-
query = self.head_to_batch_dim(query)
|
332 |
-
key = self.head_to_batch_dim(key)
|
333 |
-
value = self.head_to_batch_dim(value)
|
334 |
-
|
335 |
-
attention_probs = self.get_attention_scores(query, key, attention_mask)
|
336 |
-
attention_probs = controller(attention_probs, is_cross, place_in_unet)
|
337 |
-
|
338 |
-
hidden_states = torch.bmm(attention_probs, value)
|
339 |
-
hidden_states = self.batch_to_head_dim(hidden_states)
|
340 |
-
|
341 |
-
# linear proj
|
342 |
-
hidden_states = to_out(hidden_states)
|
343 |
-
|
344 |
-
if input_ndim == 4:
|
345 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
346 |
-
|
347 |
-
if self.residual_connection:
|
348 |
-
hidden_states = hidden_states + residual
|
349 |
-
|
350 |
-
hidden_states = hidden_states / self.rescale_output_factor
|
351 |
-
|
352 |
-
return hidden_states
|
353 |
-
|
354 |
-
return forward
|
355 |
-
|
356 |
-
class DummyController:
|
357 |
-
|
358 |
-
def __call__(self, *args):
|
359 |
-
return args[0]
|
360 |
-
|
361 |
-
def __init__(self):
|
362 |
-
self.num_att_layers = 0
|
363 |
-
|
364 |
-
if controller is None:
|
365 |
-
controller = DummyController()
|
366 |
-
|
367 |
-
def register_recr(net_, count, place_in_unet):
|
368 |
-
if net_.__class__.__name__ == 'Attention':
|
369 |
-
net_.forward = ca_forward(net_, place_in_unet)
|
370 |
-
return count + 1
|
371 |
-
elif hasattr(net_, 'children'):
|
372 |
-
for net__ in net_.children():
|
373 |
-
count = register_recr(net__, count, place_in_unet)
|
374 |
-
return count
|
375 |
-
|
376 |
-
cross_att_count = 0
|
377 |
-
sub_nets = model.unet.named_children()
|
378 |
-
for net in sub_nets:
|
379 |
-
if "down" in net[0]:
|
380 |
-
cross_att_count += register_recr(net[1], 0, "down")
|
381 |
-
elif "up" in net[0]:
|
382 |
-
cross_att_count += register_recr(net[1], 0, "up")
|
383 |
-
elif "mid" in net[0]:
|
384 |
-
cross_att_count += register_recr(net[1], 0, "mid")
|
385 |
-
|
386 |
-
controller.num_att_layers = cross_att_count
|
387 |
-
# ----------------------------------------------------------------------
|
388 |
-
|
389 |
-
# Other
|
390 |
-
# ----------------------------------------------------------------------
|
391 |
-
def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
|
392 |
-
Tuple[float, ...]]):
|
393 |
-
if type(word_select) is int or type(word_select) is str:
|
394 |
-
word_select = (word_select,)
|
395 |
-
equalizer = torch.ones(1, 77)
|
396 |
-
|
397 |
-
for word, val in zip(word_select, values):
|
398 |
-
inds = get_word_inds(text, word, tokenizer)
|
399 |
-
equalizer[:, inds] = val
|
400 |
-
return equalizer
|
401 |
-
|
402 |
-
def get_time_words_attention_alpha(prompts, num_steps,
|
403 |
-
cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
|
404 |
-
tokenizer, max_num_words=77):
|
405 |
-
if type(cross_replace_steps) is not dict:
|
406 |
-
cross_replace_steps = {"default_": cross_replace_steps}
|
407 |
-
if "default_" not in cross_replace_steps:
|
408 |
-
cross_replace_steps["default_"] = (0., 1.)
|
409 |
-
alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
|
410 |
-
for i in range(len(prompts) - 1):
|
411 |
-
alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
|
412 |
-
i)
|
413 |
-
for key, item in cross_replace_steps.items():
|
414 |
-
if key != "default_":
|
415 |
-
inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
|
416 |
-
for i, ind in enumerate(inds):
|
417 |
-
if len(ind) > 0:
|
418 |
-
alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
|
419 |
-
alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
|
420 |
-
return alpha_time_words
|
421 |
-
|
422 |
-
def get_word_inds(text: str, word_place: int, tokenizer):
|
423 |
-
split_text = text.split(" ")
|
424 |
-
if type(word_place) is str:
|
425 |
-
word_place = [i for i, word in enumerate(split_text) if word_place == word]
|
426 |
-
elif type(word_place) is int:
|
427 |
-
word_place = [word_place]
|
428 |
-
out = []
|
429 |
-
if len(word_place) > 0:
|
430 |
-
words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
|
431 |
-
cur_len, ptr = 0, 0
|
432 |
-
|
433 |
-
for i in range(len(words_encode)):
|
434 |
-
cur_len += len(words_encode[i])
|
435 |
-
if ptr in word_place:
|
436 |
-
out.append(i + 1)
|
437 |
-
if cur_len >= len(split_text[ptr]):
|
438 |
-
ptr += 1
|
439 |
-
cur_len = 0
|
440 |
-
return np.array(out)
|
441 |
-
|
442 |
-
|
443 |
-
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
|
444 |
-
word_inds: Optional[torch.Tensor] = None):
|
445 |
-
if type(bounds) is float:
|
446 |
-
bounds = 0, bounds
|
447 |
-
start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
|
448 |
-
if word_inds is None:
|
449 |
-
word_inds = torch.arange(alpha.shape[2])
|
450 |
-
alpha[: start, prompt_ind, word_inds] = 0
|
451 |
-
alpha[start: end, prompt_ind, word_inds] = 1
|
452 |
-
alpha[end:, prompt_ind, word_inds] = 0
|
453 |
-
return alpha
|
454 |
-
# ----------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|