import torch
import torch.nn.functional as F
from einops import rearrange

from .attention import CrossAttention
from .positional_encoding import PositionalEncoding


class StreamTemporalAttention(CrossAttention):
    """Temporal self-attention with a streaming KV-cache.

    * window_size: the maximum length of the attention window.
    * sink_size: the number of sink tokens kept at the front of the window.
    * positional_rule: either `absolute` or `relative` positional encoding.

    The sequence length of the temporal self-attention is therefore:
    sink_size + cache_size (i.e. window_size).
    """

    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
        window_size=8,
        sink_size=0,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.attention_mode = self._orig_attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.0,
            max_len=temporal_position_encoding_max_len,
        )

        # The attention window is split into `sink_size` sink frames plus
        # `cache_size` streaming frames.
        self.window_size = window_size
        self.sink_size = sink_size
        self.cache_size = self.window_size - self.sink_size
        assert self.cache_size >= 0, (
            "cache_size must be greater than or equal to 0. Please check your configuration. "
            f"window_size: {window_size}, sink_size: {sink_size}, "
            f"cache_size: {self.cache_size}"
        )

        self.motion_module_idx = None

    def set_index(self, idx):
        self.motion_module_idx = idx

    def set_cache(self, denoising_steps_num: int):
        """Allocate the temporal KV-cache.

        A larger buffer index corresponds to a cleaner (less noisy) latent.
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        # [t, 2, hw, L, c]: t = denoising_steps_num, 2 = (k, v),
        # hw = spatial tokens, L = window_size, c = kv channels
        kv_cache = torch.zeros(
            denoising_steps_num,
            2,
            self.h * self.w,
            self.window_size,
            self.kv_channels,
            device=device,
            dtype=dtype,
        )
        self.denoising_steps_num = denoising_steps_num
        return kv_cache
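
    # Rough size check (illustrative numbers only, assuming h = w = 32,
    # window_size = 8, kv_channels = 320, denoising_steps_num = 4, fp16):
    # 4 * 2 * 1024 * 8 * 320 ~= 21M elements ~= 40 MiB per motion module.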

    def prepare_pe_buffer(self):
        """In AnimateDiff, temporal self-attention uses absolute positional encoding:

            q = w_q * (x + pe) + bias
            k = w_k * (x + pe) + bias
            v = w_v * (x + pe) + bias

        To support relative positional encoding with a kv-cache, we pre-compute
        `w_q/k/v * pe` here and cache only `w_q/k/v * x + bias`.
        """
        pe_list = self.pos_encoder.pe[:, : self.window_size]  # [1, window_size, ch]

        q_pe = F.linear(pe_list, self.to_q.weight)
        k_pe = F.linear(pe_list, self.to_k.weight)
        v_pe = F.linear(pe_list, self.to_v.weight)
        self.register_buffer("q_pe", q_pe)
        self.register_buffer("k_pe", k_pe)
        self.register_buffer("v_pe", v_pe)

    def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx):
        """
        hidden_states: [(N * bhw), F, c]
        kv_cache: [N, 2, hw, L, c]

        * warmup case: `N` is 1 and `F` is the warmup size (`sink_size`)
        * streaming case: `N` is `denoising_steps_num` and `F` is `chunk_size`
        """
        q_layer = self.to_q(hidden_states)
        k_layer = self.to_k(hidden_states)
        v_layer = self.to_v(hidden_states)

        q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)
        k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)
        v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)

        # Write the newest frame's k/v into each denoising level's cache slot.
        # onnx & trt friendly indexing
        for idx in range(self.denoising_steps_num):
            kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0]
            kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0]

        k_full = kv_cache[:, 0]
        v_full = kv_cache[:, 1]

        # Gather positional-encoding terms: the query uses the newly written
        # position, the key/value use the full window.
        kv_idx = pe_idx
        q_idx = torch.stack(
            [kv_idx[idx, update_idx[idx]] for idx in range(self.denoising_steps_num)]
        ).unsqueeze_(1)  # [timesteps, 1]

        pe_k = torch.cat(
            [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, window_size, c]
        pe_v = torch.cat(
            [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, window_size, c]
        pe_q = torch.cat(
            [self.q_pe.index_select(1, q_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, 1, c]

        q_layer = q_layer + pe_q.unsqueeze(1)
        k_full = k_full + pe_k.unsqueeze(1)
        v_full = v_full + pe_v.unsqueeze(1)

        q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c")
        k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c")
        v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c")

        return q_layer, k_full, v_full
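
    # Shape walk-through for one streaming step (illustrative, assuming
    # batch 1, denoising_steps_num = 4, h = w = 32, chunk_size F = 1, c = 320):
    # hidden_states [(4 * 1024), 1, 320] -> q/k/v_layer [4, 1024, 1, 320];
    # each denoising level writes its new frame at update_idx[idx], and the
    # returned q_layer is [(4 * 1024), 1, 320] while k_full / v_full span the
    # whole window: [(4 * 1024), window_size, 320].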

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        video_length=None,
        temporal_attention_mask=None,
        kv_cache=None,
        pe_idx=None,
        update_idx=None,
        *args,
        **kwargs,
    ):
        """
        temporal_attention_mask: attention mask specific to the temporal self-attention.
        """
        d = hidden_states.shape[1]
        hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query_layer, key_full, value_full = self.prepare_qkv_full_and_cache(
            hidden_states, kv_cache, pe_idx, update_idx
        )

        # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head]
        query_layer = self.reshape_heads_to_batch_dim(query_layer)
        key_full = self.reshape_heads_to_batch_dim(key_full)
        value_full = self.reshape_heads_to_batch_dim(value_full)

        if temporal_attention_mask is not None:
            q_size = query_layer.shape[1]
            # [n, self.window_size] -> [n, hw, q_size, window_size]
            temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(
                1, self.h * self.w, q_size, 1
            )
            temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV")
            temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0)
        else:
            temporal_attention_mask_ = None

        # attention, what we cannot get enough of
        if hasattr(F, "scaled_dot_product_attention"):
            hidden_states = self._memory_efficient_attention_pt20(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
        elif self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
            # Some versions of xformers return output in fp32; cast it back to the dtype of the input.
            hidden_states = hidden_states.to(query_layer.dtype)
        else:
            hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
        return hidden_states
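

# ---------------------------------------------------------------------------
# Minimal usage sketch (comments only). The constructor arguments, the way
# `h` / `w` / `kv_channels` get set, and the index tensors below are
# illustrative assumptions, not the pipeline's actual wiring, which lives in
# the caller that builds the motion modules:
#
#   attn = StreamTemporalAttention(
#       query_dim=320, heads=8, dim_head=40, cross_attention_dim=None,
#       window_size=8, sink_size=0,
#   )
#   attn.h, attn.w = 32, 32                 # spatial size of this UNet block
#   attn.set_index(0)
#   kv_cache = attn.set_cache(denoising_steps_num=4)
#   attn.prepare_pe_buffer()
#
#   # One streaming step: every denoising level attends over the full window
#   # and writes its newest frame into the cache slot given by update_idx.
#   pe_idx = torch.arange(attn.window_size).repeat(4, 1)      # [N, window_size]
#   update_idx = torch.full((4,), attn.window_size - 1)       # write position per level
#   out = attn(hidden_states, video_length=1, kv_cache=kv_cache,
#              pe_idx=pe_idx, update_idx=update_idx)
# ---------------------------------------------------------------------------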