jbilcke-hf (HF staff) committed a6d4ba6 (verified) · 1 parent: a1227fd

Upload 6 files
enhance_a_video/__init__.py ADDED
@@ -0,0 +1,23 @@
+ from .enhance import enhance_score
+ from .globals import (
+     enable_enhance,
+     get_enhance_weight,
+     get_num_frames,
+     is_enhance_enabled,
+     set_enhance_weight,
+     set_num_frames,
+ )
+ from .models.cogvideox import inject_enhance_for_cogvideox
+ from .models.hunyuanvideo import inject_enhance_for_hunyuanvideo
+
+ __all__ = [
+     "inject_enhance_for_cogvideox",
+     "inject_enhance_for_hunyuanvideo",
+     "enhance_score",
+     "get_num_frames",
+     "set_num_frames",
+     "get_enhance_weight",
+     "set_enhance_weight",
+     "enable_enhance",
+     "is_enhance_enabled",
+ ]
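For orientation, here is a minimal usage sketch of the public API exported above. It is not part of this commit: the CogVideoXPipeline setup is standard diffusers usage, and the checkpoint id and weight value are illustrative.

import torch
from diffusers import CogVideoXPipeline
from enhance_a_video import enable_enhance, inject_enhance_for_cogvideox, set_enhance_weight

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")

inject_enhance_for_cogvideox(pipe.transformer)  # patch attention processors + register the frame hook
enable_enhance()                                # turn the mechanism on
set_enhance_weight(4.0)                         # illustrative value; tune per model/prompt

video = pipe(prompt="a corgi running on a beach", num_frames=49).frames[0]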
enhance_a_video/enhance.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+
+ from enhance_a_video.globals import get_enhance_weight
+
+
+ def enhance_score(query_image, key_image, head_dim, num_frames):
+     scale = head_dim**-0.5
+     query_image = query_image * scale
+     attn_temp = query_image @ key_image.transpose(-2, -1)
+     attn_temp = attn_temp.to(torch.float32)  # cast to float32 for a numerically stable softmax
+     attn_temp = attn_temp.softmax(dim=-1)
+
+     # Reshape to [batch_size * num_tokens, num_frames, num_frames]
+     attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
+
+     # Create a mask for diagonal elements
+     diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
+     diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
+
+     # Zero out diagonal elements
+     attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
+
+     # Calculate the mean of each token's attention matrix;
+     # the number of off-diagonal elements per matrix is n*n - n
+     num_off_diag = num_frames * num_frames - num_frames
+     mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
+
+     enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
+     enhance_scores = enhance_scores.clamp(min=1)
+     return enhance_scores
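To see the shapes enhance_score expects, a small self-contained check (all sizes illustrative; the weight must be set first, since get_enhance_weight() would otherwise return None):

import torch
from enhance_a_video.enhance import enhance_score
from enhance_a_video.globals import set_enhance_weight

set_enhance_weight(4.0)                 # illustrative value
B, S, N, T, C = 1, 8, 4, 2, 64          # batch, spatial tokens, heads, frames, head_dim
q = torch.randn(B * S, N, T, C)         # [(B S) N T C], the layout produced by the processors
k = torch.randn(B * S, N, T, C)
score = enhance_score(q, k, head_dim=C, num_frames=T)
print(score)                            # scalar tensor >= 1: mean cross-frame attention, scaled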
enhance_a_video/globals.py ADDED
@@ -0,0 +1,30 @@
+ NUM_FRAMES = None
+ ENHANCE_WEIGHT = None
+ ENABLE_ENHANCE = False
+
+
+ def set_num_frames(num_frames: int):
+     global NUM_FRAMES
+     NUM_FRAMES = num_frames
+
+
+ def get_num_frames() -> int:
+     return NUM_FRAMES
+
+
+ def enable_enhance():
+     global ENABLE_ENHANCE
+     ENABLE_ENHANCE = True
+
+
+ def is_enhance_enabled() -> bool:
+     return ENABLE_ENHANCE
+
+
+ def set_enhance_weight(enhance_weight: float):
+     global ENHANCE_WEIGHT
+     ENHANCE_WEIGHT = enhance_weight
+
+
+ def get_enhance_weight() -> float:
+     return ENHANCE_WEIGHT
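These module-level globals act as a tiny process-wide configuration store shared between the user-facing API and the attention processors. A sketch of the intended call order (enable_enhance and set_enhance_weight are user-facing; set_num_frames is normally called by the forward pre-hooks, not by hand):

from enhance_a_video.globals import (
    enable_enhance, get_enhance_weight, get_num_frames,
    is_enhance_enabled, set_enhance_weight, set_num_frames,
)

enable_enhance()
set_enhance_weight(4.0)   # illustrative value
set_num_frames(13)        # shown only for illustration; the hooks do this at runtime
assert is_enhance_enabled()
assert get_enhance_weight() == 4.0 and get_num_frames() == 13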
enhance_a_video/models/__init__.py ADDED
(empty file)
enhance_a_video/models/cogvideox.py ADDED
@@ -0,0 +1,148 @@
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from diffusers.models.attention import Attention
+ from einops import rearrange
+ from torch import nn
+
+ from enhance_a_video.enhance import enhance_score
+ from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames
+
+
+ def inject_enhance_for_cogvideox(model: nn.Module) -> None:
+     """
+     Inject the enhance score into the CogVideoX model.
+     1. register a forward pre-hook that records the number of frames
+     2. replace the attention processors with enhance processors that weight the attention output
+     """
+     # register hook to update num frames
+     model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)
+     # replace attention processors with Enhance-A-Video processors
+     for name, module in model.named_modules():
+         if "attn" in name and isinstance(module, Attention):
+             module.set_processor(EnhanceCogVideoXAttnProcessor2_0())
+
+
+ def num_frames_hook(_, args, kwargs):
+     """
+     Hook to update the number of frames automatically.
+     """
+     if "hidden_states" in kwargs:
+         hidden_states = kwargs["hidden_states"]
+     else:
+         hidden_states = args[0]
+     num_frames = hidden_states.shape[1]
+     set_num_frames(num_frames)
+     return args, kwargs
+
+
+ class EnhanceCogVideoXAttnProcessor2_0:
+     r"""
+     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+     query and key vectors, but does not include spatial normalization.
+     """
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+     def _get_enhance_scores(
+         self,
+         attn: Attention,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         head_dim: int,
+         text_seq_length: int,
+     ) -> torch.Tensor:
+         num_frames = get_num_frames()
+         spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)
+
+         query_image = rearrange(
+             query[:, :, text_seq_length:],
+             "B N (T S) C -> (B S) N T C",
+             N=attn.heads,
+             T=num_frames,
+             S=spatial_dim,
+             C=head_dim,
+         )
+         key_image = rearrange(
+             key[:, :, text_seq_length:],
+             "B N (T S) C -> (B S) N T C",
+             N=attn.heads,
+             T=num_frames,
+             S=spatial_dim,
+             C=head_dim,
+         )
+         return enhance_score(query_image, key_image, head_dim, num_frames)
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         text_seq_length = encoder_hidden_states.size(1)
+
+         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # Apply RoPE if needed
+         if image_rotary_emb is not None:
+             from diffusers.models.embeddings import apply_rotary_emb
+
+             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+             if not attn.is_cross_attention:
+                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+         # ========== Enhance-A-Video ==========
+         if is_enhance_enabled():
+             enhance_scores = self._get_enhance_scores(attn, query, key, head_dim, text_seq_length)
+         # ========== Enhance-A-Video ==========
+
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         encoder_hidden_states, hidden_states = hidden_states.split(
+             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+         )
+
+         # ========== Enhance-A-Video ==========
+         if is_enhance_enabled():
+             hidden_states = hidden_states * enhance_scores
+         # ========== Enhance-A-Video ==========
+
+         return hidden_states, encoder_hidden_states
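The rearrange in _get_enhance_scores assumes CogVideoX lays its visual tokens out frame-major, i.e. the sequence axis factors as (T S), with the text tokens prepended. A small sketch of what that reshaping does (shapes illustrative):

import torch
from einops import rearrange

B, N, T, S, C = 1, 2, 3, 4, 8           # batch, heads, frames, spatial tokens, head_dim
q_visual = torch.randn(B, N, T * S, C)  # query with the text_seq_length prefix already sliced off
q_image = rearrange(q_visual, "B N (T S) C -> (B S) N T C", T=T, S=S)
assert q_image.shape == (B * S, N, T, C)  # each spatial location gets its own T-frame axis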
enhance_a_video/models/hunyuanvideo.py ADDED
@@ -0,0 +1,165 @@
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from diffusers.models.attention import Attention
+ from einops import rearrange
+ from torch import nn
+
+ from enhance_a_video.enhance import enhance_score
+ from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames
+
+
+ def inject_enhance_for_hunyuanvideo(model: nn.Module) -> None:
+     """
+     Inject the enhance score into the HunyuanVideo model.
+     1. register a forward pre-hook that records the number of frames
+     2. replace the attention processors with enhance processors that weight the attention output
+     """
+     # register hook to update num frames
+     model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)
+     # replace attention processors with Enhance-A-Video processors
+     for name, module in model.named_modules():
+         if "attn" in name and isinstance(module, Attention) and "transformer_blocks" in name:
+             module.set_processor(EnhanceHunyuanVideoAttnProcessor2_0())
+
+
+ def num_frames_hook(module, args, kwargs):
+     """
+     Hook to update the number of frames automatically.
+     """
+     if "hidden_states" in kwargs:
+         hidden_states = kwargs["hidden_states"]
+     else:
+         hidden_states = args[0]
+     num_frames = hidden_states.shape[2]
+     p_t = module.config.patch_size_t
+     post_patch_num_frames = num_frames // p_t
+     set_num_frames(post_patch_num_frames)
+     return args, kwargs
+
+
+ class EnhanceHunyuanVideoAttnProcessor2_0:
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError(
+                 "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+             )
+
+     def _get_enhance_scores(self, attn, query, key, encoder_hidden_states):
+         if attn.add_q_proj is None and encoder_hidden_states is not None:
+             img_q, img_k = query[:, :, : -encoder_hidden_states.shape[1]], key[:, :, : -encoder_hidden_states.shape[1]]
+         else:
+             img_q, img_k = query, key
+
+         num_frames = get_num_frames()
+         _, num_heads, ST, head_dim = img_q.shape
+         spatial_dim = ST // num_frames
+
+         query_image = rearrange(
+             img_q, "B N (T S) C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+         )
+         key_image = rearrange(img_k, "B N (T S) C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim)
+
+         return enhance_score(query_image, key_image, head_dim, num_frames)
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if attn.add_q_proj is None and encoder_hidden_states is not None:
+             hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+         # 1. QKV projections
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+         # 2. QK normalization
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # 3. Rotational positional embeddings applied to latent stream
+         if image_rotary_emb is not None:
+             from diffusers.models.embeddings import apply_rotary_emb
+
+             if attn.add_q_proj is None and encoder_hidden_states is not None:
+                 query = torch.cat(
+                     [
+                         apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                         query[:, :, -encoder_hidden_states.shape[1] :],
+                     ],
+                     dim=2,
+                 )
+                 key = torch.cat(
+                     [
+                         apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                         key[:, :, -encoder_hidden_states.shape[1] :],
+                     ],
+                     dim=2,
+                 )
+             else:
+                 query = apply_rotary_emb(query, image_rotary_emb)
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+         # ========== Enhance-A-Video ==========
+         if is_enhance_enabled():
+             enhance_scores = self._get_enhance_scores(attn, query, key, encoder_hidden_states)
+         # ========== Enhance-A-Video ==========
+
+         # 4. Encoder condition QKV projection and normalization
+         if attn.add_q_proj is not None and encoder_hidden_states is not None:
+             encoder_query = attn.add_q_proj(encoder_hidden_states)
+             encoder_key = attn.add_k_proj(encoder_hidden_states)
+             encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+             encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+             encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+             encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+             if attn.norm_added_q is not None:
+                 encoder_query = attn.norm_added_q(encoder_query)
+             if attn.norm_added_k is not None:
+                 encoder_key = attn.norm_added_k(encoder_key)
+
+             query = torch.cat([query, encoder_query], dim=2)
+             key = torch.cat([key, encoder_key], dim=2)
+             value = torch.cat([value, encoder_value], dim=2)
+
+         # 5. Attention
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # 6. Output projection
+         if encoder_hidden_states is not None:
+             hidden_states, encoder_hidden_states = (
+                 hidden_states[:, : -encoder_hidden_states.shape[1]],
+                 hidden_states[:, -encoder_hidden_states.shape[1] :],
+             )
+
+             if getattr(attn, "to_out", None) is not None:
+                 hidden_states = attn.to_out[0](hidden_states)
+                 hidden_states = attn.to_out[1](hidden_states)
+
+             if getattr(attn, "to_add_out", None) is not None:
+                 encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+         # ========== Enhance-A-Video ==========
+         if is_enhance_enabled():
+             hidden_states = hidden_states * enhance_scores
+         # ========== Enhance-A-Video ==========
+
+         return hidden_states, encoder_hidden_states
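Two HunyuanVideo-specific details worth noting: the transformer consumes 5D latents, so num_frames_hook reads the frame count from dim 2 and divides by the temporal patch size; and the text tokens are appended after the visual tokens (the reverse of CogVideoX), which is why _get_enhance_scores slices them off the end. A sketch of the hook arithmetic (all sizes illustrative):

import torch

latents = torch.randn(1, 16, 16, 30, 52)   # [B, C, F, H, W], illustrative
patch_size_t = 2                            # stands in for module.config.patch_size_t
post_patch_num_frames = latents.shape[2] // patch_size_t
assert post_patch_num_frames == 8           # this is the value passed to set_num_frames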