SettW committed
Commit 967e6ad · verified · 1 Parent(s): ed52dfa

Create cross_frame_attention.py

Files changed (1)
  1. cross_frame_attention.py +121 -0
cross_frame_attention.py ADDED
@@ -0,0 +1,121 @@
+ # Adapted from https://github.com/Picsart-AI-Research/Text2Video-Zero
+ import torch
+ from einops import rearrange
+
+
+ class CrossFrameAttnProcessor:
+     def __init__(self, unet_chunk_size=2):
+         # Number of UNet batch chunks (e.g. 2 under classifier-free guidance,
+         # where the batch holds an unconditional and a conditional copy).
+         self.unet_chunk_size = unet_chunk_size
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         **kwargs,
+     ):
+         batch_size, sequence_length, _ = hidden_states.shape
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+         query = attn.to_q(hidden_states)
+
+         is_cross_attention = encoder_hidden_states is not None
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         # Sparse attention: in self-attention layers, every frame attends to
+         # the keys/values of the first frame instead of its own.
+         if not is_cross_attention:
+             video_length = key.size()[0] // self.unet_chunk_size
+             # Variant that attends to the previous frame instead of the first:
+             # former_frame_index = torch.arange(video_length) - 1
+             # former_frame_index[0] = 0
+             former_frame_index = [0] * video_length
+             key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+             key = key[:, former_frame_index]
+             key = rearrange(key, "b f d c -> (b f) d c")
+             value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+             value = value[:, former_frame_index]
+             value = rearrange(value, "b f d c -> (b f) d c")
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
+
+
+ class AttnProcessorX:
+     r"""
+     Default processor for performing attention-related computations.
+     """
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+         scale=1.0,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states, scale=scale)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states, scale=scale)
+         value = attn.to_v(encoder_hidden_states, scale=scale)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states, scale=scale)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
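
A minimal usage sketch, assuming a standard diffusers Stable Diffusion pipeline (the checkpoint name, prompt, and frame count below are illustrative and not part of this commit): `set_attn_processor` replaces the processor of every attention layer in the UNet, so each self-attention block attends to the first frame's keys and values as implemented above.

```python
import torch
from diffusers import StableDiffusionPipeline

from cross_frame_attention import CrossFrameAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Route all UNet attention layers through the cross-frame processor.
# unet_chunk_size=2 matches classifier-free guidance, where the latent
# batch holds an unconditional and a conditional copy of each frame.
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(unet_chunk_size=2))

# Generate a batch of frames from one prompt; the shared first-frame
# keys/values in self-attention keep the frames visually consistent.
video_length = 8
prompt = "a panda surfing a wave, cartoon style"  # illustrative
frames = pipe([prompt] * video_length, num_inference_steps=25).images
```

If guidance were disabled, so the batch held a single latent copy per frame, `unet_chunk_size=1` would be the matching setting.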