alexnasa committed · Commit 9b96809 · verified · 1 Parent(s): aa0dca3

Update OmniAvatar/wan_video.py

Files changed (1)
  1. OmniAvatar/wan_video.py +344 -339
OmniAvatar/wan_video.py CHANGED
@@ -1,340 +1,345 @@
- import types
- from .models.model_manager import ModelManager
- from .models.wan_video_dit import WanModel
- from .models.wan_video_text_encoder import WanTextEncoder
- from .models.wan_video_vae import WanVideoVAE
- from .schedulers.flow_match import FlowMatchScheduler
- from .base import BasePipeline
- from .prompters import WanPrompter
- import torch, os
- from einops import rearrange
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
- from typing import Optional
- from .vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
- from .models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
- from .models.wan_video_dit import RMSNorm
- from .models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
-
-
- class WanVideoPipeline(BasePipeline):
-
-     def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
-         super().__init__(device=device, torch_dtype=torch_dtype)
-         self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
-         self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
-         self.text_encoder: WanTextEncoder = None
-         self.image_encoder = None
-         self.dit: WanModel = None
-         self.vae: WanVideoVAE = None
-         self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
-         self.height_division_factor = 16
-         self.width_division_factor = 16
-         self.use_unified_sequence_parallel = False
-         self.sp_size = 1
-
-
-     def enable_vram_management(self, num_persistent_param_in_dit=None):
-         dtype = next(iter(self.text_encoder.parameters())).dtype
-         enable_vram_management(
-             self.text_encoder,
-             module_map = {
-                 torch.nn.Linear: AutoWrappedLinear,
-                 torch.nn.Embedding: AutoWrappedModule,
-                 T5RelativeEmbedding: AutoWrappedModule,
-                 T5LayerNorm: AutoWrappedModule,
-             },
-             module_config = dict(
-                 offload_dtype=dtype,
-                 offload_device="cpu",
-                 onload_dtype=dtype,
-                 onload_device="cpu",
-                 computation_dtype=self.torch_dtype,
-                 computation_device=self.device,
-             ),
-         )
-         dtype = next(iter(self.dit.parameters())).dtype
-         enable_vram_management(
-             self.dit,
-             module_map = {
-                 torch.nn.Linear: AutoWrappedLinear,
-                 torch.nn.Conv3d: AutoWrappedModule,
-                 torch.nn.LayerNorm: AutoWrappedModule,
-                 RMSNorm: AutoWrappedModule,
-             },
-             module_config = dict(
-                 offload_dtype=dtype,
-                 offload_device="cpu",
-                 onload_dtype=dtype,
-                 onload_device=self.device,
-                 computation_dtype=self.torch_dtype,
-                 computation_device=self.device,
-             ),
-             max_num_param=num_persistent_param_in_dit,
-             overflow_module_config = dict(
-                 offload_dtype=dtype,
-                 offload_device="cpu",
-                 onload_dtype=dtype,
-                 onload_device="cpu",
-                 computation_dtype=self.torch_dtype,
-                 computation_device=self.device,
-             ),
-         )
-         dtype = next(iter(self.vae.parameters())).dtype
-         enable_vram_management(
-             self.vae,
-             module_map = {
-                 torch.nn.Linear: AutoWrappedLinear,
-                 torch.nn.Conv2d: AutoWrappedModule,
-                 RMS_norm: AutoWrappedModule,
-                 CausalConv3d: AutoWrappedModule,
-                 Upsample: AutoWrappedModule,
-                 torch.nn.SiLU: AutoWrappedModule,
-                 torch.nn.Dropout: AutoWrappedModule,
-             },
-             module_config = dict(
-                 offload_dtype=dtype,
-                 offload_device="cpu",
-                 onload_dtype=dtype,
-                 onload_device=self.device,
-                 computation_dtype=self.torch_dtype,
-                 computation_device=self.device,
-             ),
-         )
-         if self.image_encoder is not None:
-             dtype = next(iter(self.image_encoder.parameters())).dtype
-             enable_vram_management(
-                 self.image_encoder,
-                 module_map = {
-                     torch.nn.Linear: AutoWrappedLinear,
-                     torch.nn.Conv2d: AutoWrappedModule,
-                     torch.nn.LayerNorm: AutoWrappedModule,
-                 },
-                 module_config = dict(
-                     offload_dtype=dtype,
-                     offload_device="cpu",
-                     onload_dtype=dtype,
-                     onload_device="cpu",
-                     computation_dtype=dtype,
-                     computation_device=self.device,
-                 ),
-             )
-         self.enable_cpu_offload()
-
-
-     def fetch_models(self, model_manager: ModelManager):
-         text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
-         if text_encoder_model_and_path is not None:
-             self.text_encoder, tokenizer_path = text_encoder_model_and_path
-             self.prompter.fetch_models(self.text_encoder)
-             self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
-         self.dit = model_manager.fetch_model("wan_video_dit")
-         self.vae = model_manager.fetch_model("wan_video_vae")
-         self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
-
-
-     @staticmethod
-     def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False, infer=False):
-         if device is None: device = model_manager.device
-         if torch_dtype is None: torch_dtype = model_manager.torch_dtype
-         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
-         pipe.fetch_models(model_manager)
-         if use_usp:
-             from xfuser.core.distributed import get_sequence_parallel_world_size, get_sp_group
-             from OmniAvatar.distributed.xdit_context_parallel import usp_attn_forward
-             for block in pipe.dit.blocks:
-                 block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
-             pipe.sp_size = get_sequence_parallel_world_size()
-             pipe.use_unified_sequence_parallel = True
-             pipe.sp_group = get_sp_group()
-         return pipe
-
-
-     def denoising_model(self):
-         return self.dit
-
-
-     def encode_prompt(self, prompt, positive=True):
-         prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
-         return {"context": prompt_emb}
-
-
-     def encode_image(self, image, num_frames, height, width):
-         image = self.preprocess_image(image.resize((width, height))).to(self.device, dtype=self.torch_dtype)
-         clip_context = self.image_encoder.encode_image([image])
-         clip_context = clip_context.to(dtype=self.torch_dtype)
-         msk = torch.ones(1, num_frames, height//8, width//8, device=self.device, dtype=self.torch_dtype)
-         msk[:, 1:] = 0
-         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-         msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
-         msk = msk.transpose(1, 2)[0]
-
-         vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device, dtype=self.torch_dtype)], dim=1)
-         y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
-         y = torch.concat([msk, y])
-         y = y.unsqueeze(0)
-         clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
-         y = y.to(dtype=self.torch_dtype, device=self.device)
-         return {"clip_feature": clip_context, "y": y}
-
-
-     def tensor2video(self, frames):
-         frames = rearrange(frames, "C T H W -> T H W C")
-         frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
-         frames = [Image.fromarray(frame) for frame in frames]
-         return frames
-
-
-     def prepare_extra_input(self, latents=None):
-         return {}
-
-
-     def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
-         latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
-         return latents
-
-
-     def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
-         frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
-         return frames
-
-
-     def prepare_unified_sequence_parallel(self):
-         return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
-
-
-     @torch.no_grad()
-     def log_video(
-         self,
-         lat,
-         prompt,
-         fixed_frame=0, # lat frames
-         image_emb={},
-         audio_emb={},
-         negative_prompt="",
-         cfg_scale=5.0,
-         audio_cfg_scale=5.0,
-         num_inference_steps=50,
-         denoising_strength=1.0,
-         sigma_shift=5.0,
-         tiled=True,
-         tile_size=(30, 52),
-         tile_stride=(15, 26),
-         tea_cache_l1_thresh=None,
-         tea_cache_model_id="",
-         progress_bar_cmd=tqdm,
-         return_latent=False,
-     ):
-         tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
-         # Scheduler
-         self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
-
-         lat = lat.to(dtype=self.torch_dtype)
-         latents = lat.clone()
-         latents = torch.randn_like(latents, dtype=self.torch_dtype)
-
-         # Encode prompts
-         self.load_models_to_device(["text_encoder"])
-         prompt_emb_posi = self.encode_prompt(prompt, positive=True)
-         if cfg_scale != 1.0:
-             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
-
-         # Extra input
-         extra_input = self.prepare_extra_input(latents)
-
-         # TeaCache
-         tea_cache_posi = {"tea_cache": None}
-         tea_cache_nega = {"tea_cache": None}
-
-         # Denoise
-         self.load_models_to_device(["dit"])
-         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
-             if fixed_frame > 0: # new
-                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
-             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
-
-             # Inference
-             noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb, **tea_cache_posi, **extra_input)
-             print(f'noise_pred_posi:{noise_pred_posi.dtype}')
-             if cfg_scale != 1.0:
-                 audio_emb_uc = {}
-                 for key in audio_emb.keys():
-                     audio_emb_uc[key] = torch.zeros_like(audio_emb[key], dtype=self.torch_dtype)
-                 if audio_cfg_scale == cfg_scale:
-                     noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
-                     noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
-                 else:
-                     tea_cache_nega_audio = {"tea_cache": None}
-                     audio_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb_uc, **tea_cache_nega_audio, **extra_input)
-                     text_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
-                     noise_pred = text_noise_pred_nega + cfg_scale * (audio_noise_pred_nega - text_noise_pred_nega) + audio_cfg_scale * (noise_pred_posi - audio_noise_pred_nega)
-             else:
-                 noise_pred = noise_pred_posi
-             # Scheduler
-             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
-
-         if fixed_frame > 0: # new
-             latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
-         # Decode
-         self.load_models_to_device(['vae'])
-         frames = self.decode_video(latents, **tiler_kwargs)
-         recons = self.decode_video(lat, **tiler_kwargs)
-         self.load_models_to_device([])
-         frames = (frames.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
-         recons = (recons.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
-         if return_latent:
-             return frames, recons, latents
-         return frames, recons
-
-
- class TeaCache:
-     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
-         self.num_inference_steps = num_inference_steps
-         self.step = 0
-         self.accumulated_rel_l1_distance = 0
-         self.previous_modulated_input = None
-         self.rel_l1_thresh = rel_l1_thresh
-         self.previous_residual = None
-         self.previous_hidden_states = None
-
-         self.coefficients_dict = {
-             "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
-             "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
-             "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
-             "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
-         }
-         if model_id not in self.coefficients_dict:
-             supported_model_ids = ", ".join([i for i in self.coefficients_dict])
-             raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
-         self.coefficients = self.coefficients_dict[model_id]
-
-     def check(self, dit: WanModel, x, t_mod):
-         modulated_inp = t_mod.clone()
-         if self.step == 0 or self.step == self.num_inference_steps - 1:
-             should_calc = True
-             self.accumulated_rel_l1_distance = 0
-         else:
-             coefficients = self.coefficients
-             rescale_func = np.poly1d(coefficients)
-             self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
-             if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
-                 should_calc = False
-             else:
-                 should_calc = True
-                 self.accumulated_rel_l1_distance = 0
-         self.previous_modulated_input = modulated_inp
-         self.step += 1
-         if self.step == self.num_inference_steps:
-             self.step = 0
-         if should_calc:
-             self.previous_hidden_states = x.clone()
-         return not should_calc
-
-     def store(self, hidden_states):
-         self.previous_residual = hidden_states - self.previous_hidden_states
-         self.previous_hidden_states = None
-
-     def update(self, hidden_states):
-         hidden_states = hidden_states + self.previous_residual
          return hidden_states
 
+ import types
+ from .models.model_manager import ModelManager
+ from .models.wan_video_dit import WanModel
+ from .models.wan_video_text_encoder import WanTextEncoder
+ from .models.wan_video_vae import WanVideoVAE
+ from .schedulers.flow_match import FlowMatchScheduler
+
+ from .base import BasePipeline
+ from .prompters import WanPrompter
+ import torch, os
+ from einops import rearrange
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ from typing import Optional
+ from .vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+ from .models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
+ from .models.wan_video_dit import RMSNorm
+ from .models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+
+ from diffusers import UniPCMultistepScheduler
+
+ class WanVideoPipeline(BasePipeline):
+
+     def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
+         super().__init__(device=device, torch_dtype=torch_dtype)
+         self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
+
+         self.scheduler = UniPCMultistepScheduler.from_config(self.scheduler.config, flow_shift=8.0)
+
+         self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+         self.text_encoder: WanTextEncoder = None
+         self.image_encoder = None
+         self.dit: WanModel = None
+         self.vae: WanVideoVAE = None
+         self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
+         self.height_division_factor = 16
+         self.width_division_factor = 16
+         self.use_unified_sequence_parallel = False
+         self.sp_size = 1
+
+
+     def enable_vram_management(self, num_persistent_param_in_dit=None):
+         dtype = next(iter(self.text_encoder.parameters())).dtype
+         enable_vram_management(
+             self.text_encoder,
+             module_map = {
+                 torch.nn.Linear: AutoWrappedLinear,
+                 torch.nn.Embedding: AutoWrappedModule,
+                 T5RelativeEmbedding: AutoWrappedModule,
+                 T5LayerNorm: AutoWrappedModule,
+             },
+             module_config = dict(
+                 offload_dtype=dtype,
+                 offload_device="cpu",
+                 onload_dtype=dtype,
+                 onload_device="cpu",
+                 computation_dtype=self.torch_dtype,
+                 computation_device=self.device,
+             ),
+         )
+         dtype = next(iter(self.dit.parameters())).dtype
+         enable_vram_management(
+             self.dit,
+             module_map = {
+                 torch.nn.Linear: AutoWrappedLinear,
+                 torch.nn.Conv3d: AutoWrappedModule,
+                 torch.nn.LayerNorm: AutoWrappedModule,
+                 RMSNorm: AutoWrappedModule,
+             },
+             module_config = dict(
+                 offload_dtype=dtype,
+                 offload_device="cpu",
+                 onload_dtype=dtype,
+                 onload_device=self.device,
+                 computation_dtype=self.torch_dtype,
+                 computation_device=self.device,
+             ),
+             max_num_param=num_persistent_param_in_dit,
+             overflow_module_config = dict(
+                 offload_dtype=dtype,
+                 offload_device="cpu",
+                 onload_dtype=dtype,
+                 onload_device="cpu",
+                 computation_dtype=self.torch_dtype,
+                 computation_device=self.device,
+             ),
+         )
+         dtype = next(iter(self.vae.parameters())).dtype
+         enable_vram_management(
+             self.vae,
+             module_map = {
+                 torch.nn.Linear: AutoWrappedLinear,
+                 torch.nn.Conv2d: AutoWrappedModule,
+                 RMS_norm: AutoWrappedModule,
+                 CausalConv3d: AutoWrappedModule,
+                 Upsample: AutoWrappedModule,
+                 torch.nn.SiLU: AutoWrappedModule,
+                 torch.nn.Dropout: AutoWrappedModule,
+             },
+             module_config = dict(
+                 offload_dtype=dtype,
+                 offload_device="cpu",
+                 onload_dtype=dtype,
+                 onload_device=self.device,
+                 computation_dtype=self.torch_dtype,
+                 computation_device=self.device,
+             ),
+         )
+         if self.image_encoder is not None:
+             dtype = next(iter(self.image_encoder.parameters())).dtype
+             enable_vram_management(
+                 self.image_encoder,
+                 module_map = {
+                     torch.nn.Linear: AutoWrappedLinear,
+                     torch.nn.Conv2d: AutoWrappedModule,
+                     torch.nn.LayerNorm: AutoWrappedModule,
+                 },
+                 module_config = dict(
+                     offload_dtype=dtype,
+                     offload_device="cpu",
+                     onload_dtype=dtype,
+                     onload_device="cpu",
+                     computation_dtype=dtype,
+                     computation_device=self.device,
+                 ),
+             )
+         self.enable_cpu_offload()
+
+
+     def fetch_models(self, model_manager: ModelManager):
+         text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
+         if text_encoder_model_and_path is not None:
+             self.text_encoder, tokenizer_path = text_encoder_model_and_path
+             self.prompter.fetch_models(self.text_encoder)
+             self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
+         self.dit = model_manager.fetch_model("wan_video_dit")
+         self.vae = model_manager.fetch_model("wan_video_vae")
+         self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+
+
+     @staticmethod
+     def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False, infer=False):
+         if device is None: device = model_manager.device
+         if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+         pipe.fetch_models(model_manager)
+         if use_usp:
+             from xfuser.core.distributed import get_sequence_parallel_world_size, get_sp_group
+             from OmniAvatar.distributed.xdit_context_parallel import usp_attn_forward
+             for block in pipe.dit.blocks:
+                 block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+             pipe.sp_size = get_sequence_parallel_world_size()
+             pipe.use_unified_sequence_parallel = True
+             pipe.sp_group = get_sp_group()
+         return pipe
+
+
+     def denoising_model(self):
+         return self.dit
+
+
+     def encode_prompt(self, prompt, positive=True):
+         prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
+         return {"context": prompt_emb}
+
+
+     def encode_image(self, image, num_frames, height, width):
+         image = self.preprocess_image(image.resize((width, height))).to(self.device, dtype=self.torch_dtype)
+         clip_context = self.image_encoder.encode_image([image])
+         clip_context = clip_context.to(dtype=self.torch_dtype)
+         msk = torch.ones(1, num_frames, height//8, width//8, device=self.device, dtype=self.torch_dtype)
+         msk[:, 1:] = 0
+         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+         msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+         msk = msk.transpose(1, 2)[0]
+
+         vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device, dtype=self.torch_dtype)], dim=1)
+         y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+         y = torch.concat([msk, y])
+         y = y.unsqueeze(0)
+         clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+         y = y.to(dtype=self.torch_dtype, device=self.device)
+         return {"clip_feature": clip_context, "y": y}
+
+
+     def tensor2video(self, frames):
+         frames = rearrange(frames, "C T H W -> T H W C")
+         frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+         frames = [Image.fromarray(frame) for frame in frames]
+         return frames
+
+
+     def prepare_extra_input(self, latents=None):
+         return {}
+
+
+     def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+         latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+         return latents
+
+
+     def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+         frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+         return frames
+
+
+     def prepare_unified_sequence_parallel(self):
+         return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+     @torch.no_grad()
+     def log_video(
+         self,
+         lat,
+         prompt,
+         fixed_frame=0, # lat frames
+         image_emb={},
+         audio_emb={},
+         negative_prompt="",
+         cfg_scale=5.0,
+         audio_cfg_scale=5.0,
+         num_inference_steps=50,
+         denoising_strength=1.0,
+         sigma_shift=5.0,
+         tiled=True,
+         tile_size=(30, 52),
+         tile_stride=(15, 26),
+         tea_cache_l1_thresh=None,
+         tea_cache_model_id="",
+         progress_bar_cmd=tqdm,
+         return_latent=False,
+     ):
+         tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+         # Scheduler
+         self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+         lat = lat.to(dtype=self.torch_dtype)
+         latents = lat.clone()
+         latents = torch.randn_like(latents, dtype=self.torch_dtype)
+
+         # Encode prompts
+         self.load_models_to_device(["text_encoder"])
+         prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+         if cfg_scale != 1.0:
+             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+         # Extra input
+         extra_input = self.prepare_extra_input(latents)
+
+         # TeaCache
+         tea_cache_posi = {"tea_cache": None}
+         tea_cache_nega = {"tea_cache": None}
+
+         # Denoise
+         self.load_models_to_device(["dit"])
+         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+             if fixed_frame > 0: # new
+                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
+             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+             # Inference
+             noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb, **tea_cache_posi, **extra_input)
+             print(f'noise_pred_posi:{noise_pred_posi.dtype}')
+             if cfg_scale != 1.0:
+                 audio_emb_uc = {}
+                 for key in audio_emb.keys():
+                     audio_emb_uc[key] = torch.zeros_like(audio_emb[key], dtype=self.torch_dtype)
+                 if audio_cfg_scale == cfg_scale:
+                     noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
+                     noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+                 else:
+                     tea_cache_nega_audio = {"tea_cache": None}
+                     audio_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb_uc, **tea_cache_nega_audio, **extra_input)
+                     text_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
+                     noise_pred = text_noise_pred_nega + cfg_scale * (audio_noise_pred_nega - text_noise_pred_nega) + audio_cfg_scale * (noise_pred_posi - audio_noise_pred_nega)
+             else:
+                 noise_pred = noise_pred_posi
+             # Scheduler
+             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+         if fixed_frame > 0: # new
+             latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
+         # Decode
+         self.load_models_to_device(['vae'])
+         frames = self.decode_video(latents, **tiler_kwargs)
+         recons = self.decode_video(lat, **tiler_kwargs)
+         self.load_models_to_device([])
+         frames = (frames.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
+         recons = (recons.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
+         if return_latent:
+             return frames, recons, latents
+         return frames, recons
+
+
+ class TeaCache:
+     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+         self.num_inference_steps = num_inference_steps
+         self.step = 0
+         self.accumulated_rel_l1_distance = 0
+         self.previous_modulated_input = None
+         self.rel_l1_thresh = rel_l1_thresh
+         self.previous_residual = None
+         self.previous_hidden_states = None
+
+         self.coefficients_dict = {
+             "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+             "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+             "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+             "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+         }
+         if model_id not in self.coefficients_dict:
+             supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+             raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+         self.coefficients = self.coefficients_dict[model_id]
+
+     def check(self, dit: WanModel, x, t_mod):
+         modulated_inp = t_mod.clone()
+         if self.step == 0 or self.step == self.num_inference_steps - 1:
+             should_calc = True
+             self.accumulated_rel_l1_distance = 0
+         else:
+             coefficients = self.coefficients
+             rescale_func = np.poly1d(coefficients)
+             self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+             if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                 should_calc = False
+             else:
+                 should_calc = True
+                 self.accumulated_rel_l1_distance = 0
+         self.previous_modulated_input = modulated_inp
+         self.step += 1
+         if self.step == self.num_inference_steps:
+             self.step = 0
+         if should_calc:
+             self.previous_hidden_states = x.clone()
+         return not should_calc
+
+     def store(self, hidden_states):
+         self.previous_residual = hidden_states - self.previous_hidden_states
+         self.previous_hidden_states = None
+
+     def update(self, hidden_states):
+         hidden_states = hidden_states + self.previous_residual
          return hidden_states
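
The functional change in this commit is the scheduler swap in WanVideoPipeline.__init__: the pipeline still builds its FlowMatchScheduler, then immediately replaces it with diffusers' UniPCMultistepScheduler constructed via from_config with flow_shift=8.0. The sketch below isolates that pattern; it is a minimal illustration, not the committed code. It assumes a recent diffusers release in which UniPCMultistepScheduler exposes flow-matching options (use_flow_sigmas, flow_shift, prediction_type="flow_prediction"); the explicit constructor arguments and the 50-step setup are illustrative assumptions, since the commit itself only overrides flow_shift through from_config.

# Minimal sketch of the scheduler swap, assuming a diffusers version whose
# UniPCMultistepScheduler supports flow-matching sigmas (options marked below
# are assumptions; the commit only passes flow_shift=8.0 via from_config).
import torch
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction",  # assumption: flow-matching prediction target
    use_flow_sigmas=True,               # assumption: flow-matching sigma schedule
    flow_shift=8.0,                     # value used by the commit
)

# Standard diffusers API: pick the number of inference steps, then step through them.
scheduler.set_timesteps(num_inference_steps=50)

latents = torch.randn(1, 16, 1, 60, 104)   # dummy latent tensor for illustration
noise_pred = torch.randn_like(latents)     # stand-in for the DiT prediction
timestep = scheduler.timesteps[0]

# Diffusers schedulers return a SchedulerOutput; the updated latents are in .prev_sample.
latents = scheduler.step(noise_pred, timestep, latents).prev_sample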