alexnasa committed
Commit 49206c4 · verified · 1 Parent(s): 542f3d9

Error handling with info

Files changed (1)
  1. app.py +787 -766
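
In short, this commit adds input validation with user-facing notices: the ZeroGPU duration estimator now shows a gr.Info message and returns 0 when the reference image or the audio clip is missing, and the GPU-decorated infer() returns early instead of failing inside the pipeline. A minimal sketch of the added checks, condensed from the new app.py shown below (the elided bodies are unchanged):

import gradio as gr
import spaces

def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
    # New guards: notify the user and request no GPU time if an input is missing.
    if image_path is None:
        gr.Info("Step1: Please Provide an Image or Choose from Image Samples")
        return 0
    if audio_path is None:
        gr.Info("Step2: Please Provide an Audio or Choose from Audio Samples")
        return 0
    ...  # unchanged: estimate runtime from the number of audio chunks and num_steps

@spaces.GPU(duration=get_duration)
def infer(image_path, audio_path, text, num_steps, session_id=None, progress=gr.Progress(track_tqdm=True)):
    # New guards: bail out early so the pipeline never runs on incomplete inputs.
    if image_path is None or audio_path is None:
        return None
    ...  # unchanged: run the OmniAvatar inference pipeline and return the video path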
app.py CHANGED
@@ -1,767 +1,788 @@
1
- import spaces
2
- import subprocess
3
- import gradio as gr
4
-
5
- import os, sys
6
- from glob import glob
7
- from datetime import datetime
8
- import math
9
- import random
10
- import librosa
11
- import numpy as np
12
- import uuid
13
- import shutil
14
- from tqdm import tqdm
15
-
16
- import importlib, site, sys
17
- from huggingface_hub import hf_hub_download, snapshot_download
18
-
19
- # Re-discover all .pth/.egg-link files
20
- for sitedir in site.getsitepackages():
21
- site.addsitedir(sitedir)
22
-
23
- # Clear caches so importlib will pick up new modules
24
- importlib.invalidate_caches()
25
-
26
- def sh(cmd): subprocess.check_call(cmd, shell=True)
27
-
28
- flash_attention_installed = False
29
-
30
- try:
31
- print("Attempting to download and install FlashAttention wheel...")
32
- flash_attention_wheel = hf_hub_download(
33
- repo_id="alexnasa/flash-attn-3",
34
- repo_type="model",
35
- filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
36
- )
37
-
38
- sh(f"pip install {flash_attention_wheel}")
39
-
40
- # tell Python to re-scan site-packages now that the egg-link exists
41
- import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
42
-
43
- flash_attention_installed = True
44
- print("FlashAttention installed successfully.")
45
-
46
- except Exception as e:
47
- print(f"⚠️ Could not install FlashAttention: {e}")
48
- print("Continuing without FlashAttention...")
49
-
50
- import torch
51
- print(f"Torch version: {torch.__version__}")
52
- print(f"FlashAttention available: {flash_attention_installed}")
53
-
54
-
55
- import torch.nn as nn
56
- from tqdm import tqdm
57
- from functools import partial
58
- from omegaconf import OmegaConf
59
- from argparse import Namespace
60
-
61
- # load the one true config you dumped
62
- _args_cfg = OmegaConf.load("args_config.yaml")
63
- args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
64
-
65
- from OmniAvatar.utils.args_config import set_global_args
66
-
67
- set_global_args(args)
68
- # args = parse_args()
69
-
70
- from OmniAvatar.utils.io_utils import load_state_dict
71
- from peft import LoraConfig, inject_adapter_in_model
72
- from OmniAvatar.models.model_manager import ModelManager
73
- from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
74
- from OmniAvatar.wan_video import WanVideoPipeline
75
- from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
76
- import torchvision.transforms as TT
77
- from transformers import Wav2Vec2FeatureExtractor
78
- import torchvision.transforms as transforms
79
- import torch.nn.functional as F
80
- from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
81
-
82
-
83
- os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
84
-
85
- def tensor_to_pil(tensor):
86
- """
87
- Args:
88
- tensor: torch.Tensor with shape like
89
- (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
90
- values in [-1, 1], on any device.
91
- Returns:
92
- A PIL.Image in RGB mode.
93
- """
94
- # 1) Remove batch dim if it exists
95
- if tensor.dim() > 3 and tensor.shape[0] == 1:
96
- tensor = tensor[0]
97
-
98
- # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
99
- tensor = tensor.squeeze()
100
-
101
- # Now we should have exactly 3 dims: (C, H, W)
102
- if tensor.dim() != 3:
103
- raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
104
-
105
- # 3) Move to CPU float32
106
- tensor = tensor.cpu().float()
107
-
108
- # 4) Undo normalization from [-1,1] -> [0,1]
109
- tensor = (tensor + 1.0) / 2.0
110
-
111
- # 5) Clamp to [0,1]
112
- tensor = torch.clamp(tensor, 0.0, 1.0)
113
-
114
- # 6) To NumPy H×W×C in [0,255]
115
- np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
116
-
117
- # 7) Build PIL Image
118
- return Image.fromarray(np_img)
119
-
120
-
121
- def set_seed(seed: int = 42):
122
- random.seed(seed)
123
- np.random.seed(seed)
124
- torch.manual_seed(seed)
125
- torch.cuda.manual_seed(seed) # seed the current GPU
126
- torch.cuda.manual_seed_all(seed) # seed all GPUs
127
-
128
- def read_from_file(p):
129
- with open(p, "r") as fin:
130
- for l in fin:
131
- yield l.strip()
132
-
133
- def match_size(image_size, h, w):
134
- ratio_ = 9999
135
- size_ = 9999
136
- select_size = None
137
- for image_s in image_size:
138
- ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
139
- size_tmp = abs(max(image_s) - max(w, h))
140
- if ratio_tmp < ratio_:
141
- ratio_ = ratio_tmp
142
- size_ = size_tmp
143
- select_size = image_s
144
- if ratio_ == ratio_tmp:
145
- if size_ == size_tmp:
146
- select_size = image_s
147
- return select_size
148
-
149
- def resize_pad(image, ori_size, tgt_size):
150
- h, w = ori_size
151
- scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
152
- scale_h = int(h * scale_ratio)
153
- scale_w = int(w * scale_ratio)
154
-
155
- image = transforms.Resize(size=[scale_h, scale_w])(image)
156
-
157
- padding_h = tgt_size[0] - scale_h
158
- padding_w = tgt_size[1] - scale_w
159
- pad_top = padding_h // 2
160
- pad_bottom = padding_h - pad_top
161
- pad_left = padding_w // 2
162
- pad_right = padding_w - pad_left
163
-
164
- image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
165
- return image
166
-
167
- class WanInferencePipeline(nn.Module):
168
- def __init__(self, args):
169
- super().__init__()
170
- self.args = args
171
- self.device = torch.device(f"cuda")
172
- self.dtype = torch.bfloat16
173
- self.pipe = self.load_model()
174
- chained_trainsforms = []
175
- chained_trainsforms.append(TT.ToTensor())
176
- self.transform = TT.Compose(chained_trainsforms)
177
-
178
- if self.args.use_audio:
179
- from OmniAvatar.models.wav2vec import Wav2VecModel
180
- self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
181
- self.args.wav2vec_path
182
- )
183
- self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
184
- self.audio_encoder.feature_extractor._freeze_parameters()
185
-
186
-
187
- def load_model(self):
188
- ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
189
- assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
190
- if self.args.train_architecture == 'lora':
191
- self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
192
- else:
193
- resume_path = ckpt_path
194
-
195
- self.step = 0
196
-
197
- # Load models
198
- model_manager = ModelManager(device="cuda", infer=True)
199
-
200
- model_manager.load_models(
201
- [
202
- self.args.dit_path.split(","),
203
- self.args.vae_path,
204
- self.args.text_encoder_path
205
- ],
206
- torch_dtype=self.dtype,
207
- device='cuda',
208
- )
209
-
210
- pipe = WanVideoPipeline.from_model_manager(model_manager,
211
- torch_dtype=self.dtype,
212
- device="cuda",
213
- use_usp=False,
214
- infer=True)
215
-
216
- if self.args.train_architecture == "lora":
217
- print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
218
- self.add_lora_to_model(
219
- pipe.denoising_model(),
220
- lora_rank=self.args.lora_rank,
221
- lora_alpha=self.args.lora_alpha,
222
- lora_target_modules=self.args.lora_target_modules,
223
- init_lora_weights=self.args.init_lora_weights,
224
- pretrained_lora_path=pretrained_lora_path,
225
- )
226
- print(next(pipe.denoising_model().parameters()).device)
227
- else:
228
- missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
229
- print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
230
- pipe.requires_grad_(False)
231
- pipe.eval()
232
- # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
233
- return pipe
234
-
235
- def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
236
- # Add LoRA to UNet
237
-
238
- self.lora_alpha = lora_alpha
239
- if init_lora_weights == "kaiming":
240
- init_lora_weights = True
241
-
242
- lora_config = LoraConfig(
243
- r=lora_rank,
244
- lora_alpha=lora_alpha,
245
- init_lora_weights=init_lora_weights,
246
- target_modules=lora_target_modules.split(","),
247
- )
248
- model = inject_adapter_in_model(lora_config, model)
249
-
250
- # Lora pretrained lora weights
251
- if pretrained_lora_path is not None:
252
- state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
253
- if state_dict_converter is not None:
254
- state_dict = state_dict_converter(state_dict)
255
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
256
- all_keys = [i for i, _ in model.named_parameters()]
257
- num_updated_keys = len(all_keys) - len(missing_keys)
258
- num_unexpected_keys = len(unexpected_keys)
259
-
260
- print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
261
-
262
- def get_times(self, prompt,
263
- image_path=None,
264
- audio_path=None,
265
- seq_len=101, # not used while audio_path is not None
266
- height=720,
267
- width=720,
268
- overlap_frame=None,
269
- num_steps=None,
270
- negative_prompt=None,
271
- guidance_scale=None,
272
- audio_scale=None):
273
-
274
- overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
275
- num_steps = num_steps if num_steps is not None else self.args.num_steps
276
- negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
277
- guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
278
- audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
279
-
280
- if image_path is not None:
281
- from PIL import Image
282
- image = Image.open(image_path).convert("RGB")
283
-
284
- image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
285
-
286
- _, _, h, w = image.shape
287
- select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
288
- image = resize_pad(image, (h, w), select_size)
289
- image = image * 2.0 - 1.0
290
- image = image[:, :, None]
291
-
292
- else:
293
- image = None
294
- select_size = [height, width]
295
- num = self.args.max_tokens * 16 * 16 * 4
296
- den = select_size[0] * select_size[1]
297
- L0 = num // den
298
- diff = (L0 - 1) % 4
299
- L = L0 - diff
300
- if L < 1:
301
- L = 1
302
- T = (L + 3) // 4
303
-
304
-
305
- if self.args.random_prefix_frames:
306
- fixed_frame = overlap_frame
307
- assert fixed_frame % 4 == 1
308
- else:
309
- fixed_frame = 1
310
- prefix_lat_frame = (3 + fixed_frame) // 4
311
- first_fixed_frame = 1
312
-
313
-
314
- audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
315
-
316
- input_values = np.squeeze(
317
- self.wav_feature_extractor(audio, sampling_rate=16000).input_values
318
- )
319
- input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
320
- audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
321
-
322
- if audio_len < L - first_fixed_frame:
323
- audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
324
- elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
325
- audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
326
-
327
- seq_len = audio_len
328
-
329
- times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
330
- if times * (L-fixed_frame) + fixed_frame < seq_len:
331
- times += 1
332
-
333
- return times
334
-
335
- @torch.no_grad()
336
- def forward(self, prompt,
337
- image_path=None,
338
- audio_path=None,
339
- seq_len=101, # not used while audio_path is not None
340
- height=720,
341
- width=720,
342
- overlap_frame=None,
343
- num_steps=None,
344
- negative_prompt=None,
345
- guidance_scale=None,
346
- audio_scale=None):
347
- overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
348
- num_steps = num_steps if num_steps is not None else self.args.num_steps
349
- negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
350
- guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
351
- audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
352
-
353
- if image_path is not None:
354
- from PIL import Image
355
- image = Image.open(image_path).convert("RGB")
356
-
357
- image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
358
-
359
- _, _, h, w = image.shape
360
- select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
361
- image = resize_pad(image, (h, w), select_size)
362
- image = image * 2.0 - 1.0
363
- image = image[:, :, None]
364
-
365
- else:
366
- image = None
367
- select_size = [height, width]
368
- # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
369
- # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
370
- # T = (L + 3) // 4 # latent frames
371
-
372
- # step 1: numerator and denominator as ints
373
- num = args.max_tokens * 16 * 16 * 4
374
- den = select_size[0] * select_size[1]
375
-
376
- # step 2: integer division
377
- L0 = num // den # exact floor division, no float in sight
378
-
379
- # step 3: make it ≡ 1 mod 4
380
- # if L0 % 4 == 1, keep L0;
381
- # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
382
- # but ensure the result stays positive.
383
- diff = (L0 - 1) % 4
384
- L = L0 - diff
385
- if L < 1:
386
- L = 1 # or whatever your minimal frame count is
387
-
388
- # step 4: latent frames
389
- T = (L + 3) // 4
390
-
391
-
392
- if self.args.i2v:
393
- if self.args.random_prefix_frames:
394
- fixed_frame = overlap_frame
395
- assert fixed_frame % 4 == 1
396
- else:
397
- fixed_frame = 1
398
- prefix_lat_frame = (3 + fixed_frame) // 4
399
- first_fixed_frame = 1
400
- else:
401
- fixed_frame = 0
402
- prefix_lat_frame = 0
403
- first_fixed_frame = 0
404
-
405
-
406
- if audio_path is not None and self.args.use_audio:
407
- audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
408
- input_values = np.squeeze(
409
- self.wav_feature_extractor(audio, sampling_rate=16000).input_values
410
- )
411
- input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
412
- ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
413
- input_values = input_values.unsqueeze(0)
414
- # padding audio
415
- if audio_len < L - first_fixed_frame:
416
- audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
417
- elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
418
- audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
419
- input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
420
- with torch.no_grad():
421
- hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
422
- audio_embeddings = hidden_states.last_hidden_state
423
- for mid_hidden_states in hidden_states.hidden_states:
424
- audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
425
- seq_len = audio_len
426
- audio_embeddings = audio_embeddings.squeeze(0)
427
- audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
428
- else:
429
- audio_embeddings = None
430
-
431
- # loop
432
- times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
433
- if times * (L-fixed_frame) + fixed_frame < seq_len:
434
- times += 1
435
- video = []
436
- image_emb = {}
437
- img_lat = None
438
- if self.args.i2v:
439
- self.pipe.load_models_to_device(['vae'])
440
- img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
441
-
442
- msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
443
- image_cat = img_lat.repeat(1, 1, T, 1, 1)
444
- msk[:, :, 1:] = 1
445
- image_emb["y"] = torch.cat([image_cat, msk], dim=1)
446
-
447
- total_iterations = times * num_steps
448
-
449
- with tqdm(total=total_iterations) as pbar:
450
- for t in range(times):
451
- print(f"[{t+1}/{times}]")
452
- audio_emb = {}
453
- if t == 0:
454
- overlap = first_fixed_frame
455
- else:
456
- overlap = fixed_frame
457
- image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # first inference: only 1 frame is masked; afterwards the overlap frames are masked
458
- prefix_overlap = (3 + overlap) // 4
459
- if audio_embeddings is not None:
460
- if t == 0:
461
- audio_tensor = audio_embeddings[
462
- :min(L - overlap, audio_embeddings.shape[0])
463
- ]
464
- else:
465
- audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
466
- audio_tensor = audio_embeddings[
467
- audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
468
- ]
469
-
470
- audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
471
- audio_prefix = audio_tensor[-fixed_frame:]
472
- audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
473
- audio_emb["audio_emb"] = audio_tensor
474
- else:
475
- audio_prefix = None
476
- if image is not None and img_lat is None:
477
- self.pipe.load_models_to_device(['vae'])
478
- img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
479
- assert img_lat.shape[2] == prefix_overlap
480
- img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
481
- frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
482
- negative_prompt, num_inference_steps=num_steps,
483
- cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
484
- return_latent=True,
485
- tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
486
-
487
- torch.cuda.empty_cache()
488
- img_lat = None
489
- image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
490
-
491
- if t == 0:
492
- video.append(frames)
493
- else:
494
- video.append(frames[:, overlap:])
495
- video = torch.cat(video, dim=1)
496
- video = video[:, :ori_audio_len + 1]
497
-
498
- return video
499
-
500
-
501
- snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
502
- snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
503
- snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
504
-
505
-
506
- # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
507
- # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
508
- # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
509
-
510
- import tempfile
511
-
512
- from PIL import Image
513
-
514
-
515
- set_seed(args.seed)
516
- seq_len = args.seq_len
517
- inferpipe = WanInferencePipeline(args)
518
-
519
-
520
- def update_generate_button(image_path, audio_path, text, num_steps):
521
-
522
- if image_path is None or audio_path is None:
523
- return gr.update(value="⌚ Zero GPU Required: --")
524
-
525
- duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
526
- duration_m = duration_s / 60
527
-
528
- return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
529
-
530
- def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
531
-
532
- audio_chunks = inferpipe.get_times(
533
- prompt=text,
534
- image_path=image_path,
535
- audio_path=audio_path,
536
- seq_len=args.seq_len,
537
- num_steps=num_steps
538
- )
539
-
540
- warmup_s = 30
541
- duration_s = (20 * num_steps) + warmup_s
542
-
543
- if audio_chunks > 1:
544
- duration_s = (20 * num_steps * audio_chunks) + warmup_s
545
-
546
- print(f'for {audio_chunks} times, might take {duration_s}')
547
-
548
- return int(duration_s)
549
-
550
- def preprocess_img(image_path, session_id = None):
551
-
552
- if session_id is None:
553
- session_id = uuid.uuid4().hex
554
-
555
- image = Image.open(image_path).convert("RGB")
556
-
557
- image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
558
-
559
- _, _, h, w = image.shape
560
- select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
561
- image = resize_pad(image, (h, w), select_size)
562
- image = image * 2.0 - 1.0
563
- image = image[:, :, None]
564
-
565
- output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
566
-
567
- img_dir = output_dir + '/image'
568
- os.makedirs(img_dir, exist_ok=True)
569
- input_img_path = os.path.join(img_dir, f"img_input.jpg")
570
-
571
- image = tensor_to_pil(image)
572
- image.save(input_img_path)
573
-
574
- return input_img_path
575
-
576
-
577
- @spaces.GPU(duration=get_duration)
578
- def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
579
-
580
- if session_id is None:
581
- session_id = uuid.uuid4().hex
582
-
583
- output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
584
-
585
- audio_dir = output_dir + '/audio'
586
- os.makedirs(audio_dir, exist_ok=True)
587
- if args.silence_duration_s > 0:
588
- input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
589
- else:
590
- input_audio_path = audio_path
591
- prompt_dir = output_dir + '/prompt'
592
- os.makedirs(prompt_dir, exist_ok=True)
593
-
594
- if args.silence_duration_s > 0:
595
- add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
596
-
597
- tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
598
- prompt_path = os.path.join(prompt_dir, f"prompt.txt")
599
-
600
- video = inferpipe(
601
- prompt=text,
602
- image_path=image_path,
603
- audio_path=input_audio_path,
604
- seq_len=args.seq_len,
605
- num_steps=num_steps
606
- )
607
-
608
- torch.cuda.empty_cache()
609
-
610
- add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
611
- video_paths = save_video_as_grid_and_mp4(video,
612
- output_dir,
613
- args.fps,
614
- prompt=text,
615
- prompt_path = prompt_path,
616
- audio_path=tmp2_audio_path if args.use_audio else None,
617
- prefix=f'result')
618
-
619
- return video_paths[0]
620
-
621
- def apply(request):
622
-
623
- return request
624
-
625
- def cleanup(request: gr.Request):
626
-
627
- sid = request.session_hash
628
- if sid:
629
- d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
630
- shutil.rmtree(d1, ignore_errors=True)
631
-
632
- def start_session(request: gr.Request):
633
-
634
- return request.session_hash
635
-
636
- css = """
637
- #col-container {
638
- margin: 0 auto;
639
- max-width: 1560px;
640
- }
641
- """
642
-
643
- with gr.Blocks(css=css) as demo:
644
-
645
- session_state = gr.State()
646
- demo.load(start_session, outputs=[session_state])
647
-
648
-
649
- with gr.Column(elem_id="col-container"):
650
- gr.HTML(
651
- """
652
- <div style="text-align: left;">
653
- <p style="font-size:16px; display: inline; margin: 0;">
654
- <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
655
- </p>
656
- <a href="https://huggingface.co/OmniAvatar/OmniAvatar-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
657
- [model]
658
- </a>
659
- </div>
660
- <div style="text-align: left;">
661
- <strong>HF Space by:</strong>
662
- <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
663
- <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
664
- </a>
665
- </div>
666
-
667
- """
668
- )
669
-
670
- with gr.Row():
671
-
672
- with gr.Column():
673
-
674
- image_input = gr.Image(label="Reference Image", type="filepath", height=512)
675
- audio_input = gr.Audio(label="Input Audio", type="filepath")
676
-
677
-
678
- with gr.Column():
679
-
680
- output_video = gr.Video(label="Avatar", height=512)
681
- num_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")
682
- time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
683
- infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
684
- with gr.Accordion("Advanced Settings", open=False):
685
- text_input = gr.Textbox(label="Video Prompt", lines=6, value="A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
686
-
687
- with gr.Column():
688
-
689
- cached_examples = gr.Examples(
690
- examples=[
691
- [
692
- "examples/images/male-001.png",
693
- "examples/audios/denial.wav",
694
- "A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
695
- 12
696
- ],
697
- [
698
- "examples/images/female-001.png",
699
- "examples/audios/script.wav",
700
- "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
701
- 14
702
- ],
703
- [
704
- "examples/images/female-002.png",
705
- "examples/audios/nature.wav",
706
- "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, standing in the woods, with dynamic and rhythmic and extensive hand gestures that complement his speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
707
- 10
708
- ],
709
- ],
710
- label="Cached Examples",
711
- inputs=[image_input, audio_input, text_input, num_steps],
712
- outputs=[output_video],
713
- fn=infer,
714
- cache_examples=True
715
- )
716
-
717
- image_examples = gr.Examples(
718
- examples=[
719
- [
720
- "examples/images/female-009.png",
721
- ],
722
- [
723
- "examples/images/male-005.png",
724
- ],
725
- [
726
- "examples/images/female-003.png",
727
- ],
728
- ],
729
- label="Image Samples",
730
- inputs=[image_input],
731
- outputs=[image_input],
732
- fn=apply
733
- )
734
-
735
- audio_examples = gr.Examples(
736
- examples=[
737
- [
738
- "examples/audios/londoners.wav",
739
- ],
740
- [
741
- "examples/audios/matcha.wav",
742
- ],
743
- [
744
- "examples/audios/keen.wav",
745
- ],
746
- ],
747
- label="Audio Samples",
748
- inputs=[audio_input],
749
- outputs=[audio_input],
750
- fn=apply
751
- )
752
-
753
- infer_btn.click(
754
- fn=infer,
755
- inputs=[image_input, audio_input, text_input, num_steps, session_state],
756
- outputs=[output_video]
757
- )
758
- image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input])
759
- image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
760
- audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
761
- num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
762
-
763
-
764
- if __name__ == "__main__":
765
- demo.unload(cleanup)
766
- demo.queue()
767
  demo.launch(ssr_mode=False)
 
1
+ import spaces
2
+ import subprocess
3
+ import gradio as gr
4
+
5
+ import os, sys
6
+ from glob import glob
7
+ from datetime import datetime
8
+ import math
9
+ import random
10
+ import librosa
11
+ import numpy as np
12
+ import uuid
13
+ import shutil
14
+ from tqdm import tqdm
15
+
16
+ import importlib, site, sys
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+
19
+ # Re-discover all .pth/.egg-link files
20
+ for sitedir in site.getsitepackages():
21
+ site.addsitedir(sitedir)
22
+
23
+ # Clear caches so importlib will pick up new modules
24
+ importlib.invalidate_caches()
25
+
26
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
27
+
28
+ flash_attention_installed = False
29
+
30
+ try:
31
+ print("Attempting to download and install FlashAttention wheel...")
32
+ flash_attention_wheel = hf_hub_download(
33
+ repo_id="alexnasa/flash-attn-3",
34
+ repo_type="model",
35
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
36
+ )
37
+
38
+ sh(f"pip install {flash_attention_wheel}")
39
+
40
+ # tell Python to re-scan site-packages now that the egg-link exists
41
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
42
+
43
+ flash_attention_installed = True
44
+ print("FlashAttention installed successfully.")
45
+
46
+ except Exception as e:
47
+ print(f"⚠️ Could not install FlashAttention: {e}")
48
+ print("Continuing without FlashAttention...")
49
+
50
+ import torch
51
+ print(f"Torch version: {torch.__version__}")
52
+ print(f"FlashAttention available: {flash_attention_installed}")
53
+
54
+
55
+ import torch.nn as nn
56
+ from tqdm import tqdm
57
+ from functools import partial
58
+ from omegaconf import OmegaConf
59
+ from argparse import Namespace
60
+
61
+ # load the one true config you dumped
62
+ _args_cfg = OmegaConf.load("args_config.yaml")
63
+ args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
64
+
65
+ from OmniAvatar.utils.args_config import set_global_args
66
+
67
+ set_global_args(args)
68
+ # args = parse_args()
69
+
70
+ from OmniAvatar.utils.io_utils import load_state_dict
71
+ from peft import LoraConfig, inject_adapter_in_model
72
+ from OmniAvatar.models.model_manager import ModelManager
73
+ from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
74
+ from OmniAvatar.wan_video import WanVideoPipeline
75
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
76
+ import torchvision.transforms as TT
77
+ from transformers import Wav2Vec2FeatureExtractor
78
+ import torchvision.transforms as transforms
79
+ import torch.nn.functional as F
80
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
81
+
82
+
83
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
84
+
85
+ def tensor_to_pil(tensor):
86
+ """
87
+ Args:
88
+ tensor: torch.Tensor with shape like
89
+ (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
90
+ values in [-1, 1], on any device.
91
+ Returns:
92
+ A PIL.Image in RGB mode.
93
+ """
94
+ # 1) Remove batch dim if it exists
95
+ if tensor.dim() > 3 and tensor.shape[0] == 1:
96
+ tensor = tensor[0]
97
+
98
+ # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
99
+ tensor = tensor.squeeze()
100
+
101
+ # Now we should have exactly 3 dims: (C, H, W)
102
+ if tensor.dim() != 3:
103
+ raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
104
+
105
+ # 3) Move to CPU float32
106
+ tensor = tensor.cpu().float()
107
+
108
+ # 4) Undo normalization from [-1,1] -> [0,1]
109
+ tensor = (tensor + 1.0) / 2.0
110
+
111
+ # 5) Clamp to [0,1]
112
+ tensor = torch.clamp(tensor, 0.0, 1.0)
113
+
114
+ # 6) To NumPy H×W×C in [0,255]
115
+ np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
116
+
117
+ # 7) Build PIL Image
118
+ return Image.fromarray(np_img)
119
+
120
+
121
+ def set_seed(seed: int = 42):
122
+ random.seed(seed)
123
+ np.random.seed(seed)
124
+ torch.manual_seed(seed)
125
+ torch.cuda.manual_seed(seed) # seed the current GPU
126
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
127
+
128
+ def read_from_file(p):
129
+ with open(p, "r") as fin:
130
+ for l in fin:
131
+ yield l.strip()
132
+
133
+ def match_size(image_size, h, w):
134
+ ratio_ = 9999
135
+ size_ = 9999
136
+ select_size = None
137
+ for image_s in image_size:
138
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
139
+ size_tmp = abs(max(image_s) - max(w, h))
140
+ if ratio_tmp < ratio_:
141
+ ratio_ = ratio_tmp
142
+ size_ = size_tmp
143
+ select_size = image_s
144
+ if ratio_ == ratio_tmp:
145
+ if size_ == size_tmp:
146
+ select_size = image_s
147
+ return select_size
148
+
149
+ def resize_pad(image, ori_size, tgt_size):
150
+ h, w = ori_size
151
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
152
+ scale_h = int(h * scale_ratio)
153
+ scale_w = int(w * scale_ratio)
154
+
155
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
156
+
157
+ padding_h = tgt_size[0] - scale_h
158
+ padding_w = tgt_size[1] - scale_w
159
+ pad_top = padding_h // 2
160
+ pad_bottom = padding_h - pad_top
161
+ pad_left = padding_w // 2
162
+ pad_right = padding_w - pad_left
163
+
164
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
165
+ return image
166
+
167
+ class WanInferencePipeline(nn.Module):
168
+ def __init__(self, args):
169
+ super().__init__()
170
+ self.args = args
171
+ self.device = torch.device(f"cuda")
172
+ self.dtype = torch.bfloat16
173
+ self.pipe = self.load_model()
174
+ chained_trainsforms = []
175
+ chained_trainsforms.append(TT.ToTensor())
176
+ self.transform = TT.Compose(chained_trainsforms)
177
+
178
+ if self.args.use_audio:
179
+ from OmniAvatar.models.wav2vec import Wav2VecModel
180
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
181
+ self.args.wav2vec_path
182
+ )
183
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
184
+ self.audio_encoder.feature_extractor._freeze_parameters()
185
+
186
+
187
+ def load_model(self):
188
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
189
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
190
+ if self.args.train_architecture == 'lora':
191
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
192
+ else:
193
+ resume_path = ckpt_path
194
+
195
+ self.step = 0
196
+
197
+ # Load models
198
+ model_manager = ModelManager(device="cuda", infer=True)
199
+
200
+ model_manager.load_models(
201
+ [
202
+ self.args.dit_path.split(","),
203
+ self.args.vae_path,
204
+ self.args.text_encoder_path
205
+ ],
206
+ torch_dtype=self.dtype,
207
+ device='cuda',
208
+ )
209
+
210
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
211
+ torch_dtype=self.dtype,
212
+ device="cuda",
213
+ use_usp=False,
214
+ infer=True)
215
+
216
+ if self.args.train_architecture == "lora":
217
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
218
+ self.add_lora_to_model(
219
+ pipe.denoising_model(),
220
+ lora_rank=self.args.lora_rank,
221
+ lora_alpha=self.args.lora_alpha,
222
+ lora_target_modules=self.args.lora_target_modules,
223
+ init_lora_weights=self.args.init_lora_weights,
224
+ pretrained_lora_path=pretrained_lora_path,
225
+ )
226
+ print(next(pipe.denoising_model().parameters()).device)
227
+ else:
228
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
229
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
230
+ pipe.requires_grad_(False)
231
+ pipe.eval()
232
+ # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
233
+ return pipe
234
+
235
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
236
+ # Add LoRA to UNet
237
+
238
+ self.lora_alpha = lora_alpha
239
+ if init_lora_weights == "kaiming":
240
+ init_lora_weights = True
241
+
242
+ lora_config = LoraConfig(
243
+ r=lora_rank,
244
+ lora_alpha=lora_alpha,
245
+ init_lora_weights=init_lora_weights,
246
+ target_modules=lora_target_modules.split(","),
247
+ )
248
+ model = inject_adapter_in_model(lora_config, model)
249
+
250
+ # Lora pretrained lora weights
251
+ if pretrained_lora_path is not None:
252
+ state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
253
+ if state_dict_converter is not None:
254
+ state_dict = state_dict_converter(state_dict)
255
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
256
+ all_keys = [i for i, _ in model.named_parameters()]
257
+ num_updated_keys = len(all_keys) - len(missing_keys)
258
+ num_unexpected_keys = len(unexpected_keys)
259
+
260
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
261
+
262
+ def get_times(self, prompt,
263
+ image_path=None,
264
+ audio_path=None,
265
+ seq_len=101, # not used while audio_path is not None
266
+ height=720,
267
+ width=720,
268
+ overlap_frame=None,
269
+ num_steps=None,
270
+ negative_prompt=None,
271
+ guidance_scale=None,
272
+ audio_scale=None):
273
+
274
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
275
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
276
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
277
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
278
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
279
+
280
+ if image_path is not None:
281
+ from PIL import Image
282
+ image = Image.open(image_path).convert("RGB")
283
+
284
+ image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
285
+
286
+ _, _, h, w = image.shape
287
+ select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
288
+ image = resize_pad(image, (h, w), select_size)
289
+ image = image * 2.0 - 1.0
290
+ image = image[:, :, None]
291
+
292
+ else:
293
+ image = None
294
+ select_size = [height, width]
295
+ num = self.args.max_tokens * 16 * 16 * 4
296
+ den = select_size[0] * select_size[1]
297
+ L0 = num // den
298
+ diff = (L0 - 1) % 4
299
+ L = L0 - diff
300
+ if L < 1:
301
+ L = 1
302
+ T = (L + 3) // 4
303
+
304
+
305
+ if self.args.random_prefix_frames:
306
+ fixed_frame = overlap_frame
307
+ assert fixed_frame % 4 == 1
308
+ else:
309
+ fixed_frame = 1
310
+ prefix_lat_frame = (3 + fixed_frame) // 4
311
+ first_fixed_frame = 1
312
+
313
+
314
+ audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
315
+
316
+ input_values = np.squeeze(
317
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
318
+ )
319
+ input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
320
+ audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
321
+
322
+ if audio_len < L - first_fixed_frame:
323
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
324
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
325
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
326
+
327
+ seq_len = audio_len
328
+
329
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
330
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
331
+ times += 1
332
+
333
+ return times
334
+
335
+ @torch.no_grad()
336
+ def forward(self, prompt,
337
+ image_path=None,
338
+ audio_path=None,
339
+ seq_len=101, # not used while audio_path is not None
340
+ height=720,
341
+ width=720,
342
+ overlap_frame=None,
343
+ num_steps=None,
344
+ negative_prompt=None,
345
+ guidance_scale=None,
346
+ audio_scale=None):
347
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
348
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
349
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
350
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
351
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
352
+
353
+ if image_path is not None:
354
+ from PIL import Image
355
+ image = Image.open(image_path).convert("RGB")
356
+
357
+ image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
358
+
359
+ _, _, h, w = image.shape
360
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
361
+ image = resize_pad(image, (h, w), select_size)
362
+ image = image * 2.0 - 1.0
363
+ image = image[:, :, None]
364
+
365
+ else:
366
+ image = None
367
+ select_size = [height, width]
368
+ # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
369
+ # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
370
+ # T = (L + 3) // 4 # latent frames
371
+
372
+ # step 1: numerator and denominator as ints
373
+ num = args.max_tokens * 16 * 16 * 4
374
+ den = select_size[0] * select_size[1]
375
+
376
+ # step 2: integer division
377
+ L0 = num // den # exact floor division, no float in sight
378
+
379
+ # step 3: make it ≡ 1 mod 4
380
+ # if L0 % 4 == 1, keep L0;
381
+ # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
382
+ # but ensure the result stays positive.
383
+ diff = (L0 - 1) % 4
384
+ L = L0 - diff
385
+ if L < 1:
386
+ L = 1 # or whatever your minimal frame count is
387
+
388
+ # step 4: latent frames
389
+ T = (L + 3) // 4
390
+
391
+
392
+ if self.args.i2v:
393
+ if self.args.random_prefix_frames:
394
+ fixed_frame = overlap_frame
395
+ assert fixed_frame % 4 == 1
396
+ else:
397
+ fixed_frame = 1
398
+ prefix_lat_frame = (3 + fixed_frame) // 4
399
+ first_fixed_frame = 1
400
+ else:
401
+ fixed_frame = 0
402
+ prefix_lat_frame = 0
403
+ first_fixed_frame = 0
404
+
405
+
406
+ if audio_path is not None and self.args.use_audio:
407
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
408
+ input_values = np.squeeze(
409
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
410
+ )
411
+ input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
412
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
413
+ input_values = input_values.unsqueeze(0)
414
+ # padding audio
415
+ if audio_len < L - first_fixed_frame:
416
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
417
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
418
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
419
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
420
+ with torch.no_grad():
421
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
422
+ audio_embeddings = hidden_states.last_hidden_state
423
+ for mid_hidden_states in hidden_states.hidden_states:
424
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
425
+ seq_len = audio_len
426
+ audio_embeddings = audio_embeddings.squeeze(0)
427
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
428
+ else:
429
+ audio_embeddings = None
430
+
431
+ # loop
432
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
433
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
434
+ times += 1
435
+ video = []
436
+ image_emb = {}
437
+ img_lat = None
438
+ if self.args.i2v:
439
+ self.pipe.load_models_to_device(['vae'])
440
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
441
+
442
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
443
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
444
+ msk[:, :, 1:] = 1
445
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
446
+
447
+ total_iterations = times * num_steps
448
+
449
+ with tqdm(total=total_iterations) as pbar:
450
+ for t in range(times):
451
+ print(f"[{t+1}/{times}]")
452
+ audio_emb = {}
453
+ if t == 0:
454
+ overlap = first_fixed_frame
455
+ else:
456
+ overlap = fixed_frame
457
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # first inference: only 1 frame is masked; afterwards the overlap frames are masked
458
+ prefix_overlap = (3 + overlap) // 4
459
+ if audio_embeddings is not None:
460
+ if t == 0:
461
+ audio_tensor = audio_embeddings[
462
+ :min(L - overlap, audio_embeddings.shape[0])
463
+ ]
464
+ else:
465
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
466
+ audio_tensor = audio_embeddings[
467
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
468
+ ]
469
+
470
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
471
+ audio_prefix = audio_tensor[-fixed_frame:]
472
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
473
+ audio_emb["audio_emb"] = audio_tensor
474
+ else:
475
+ audio_prefix = None
476
+ if image is not None and img_lat is None:
477
+ self.pipe.load_models_to_device(['vae'])
478
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
479
+ assert img_lat.shape[2] == prefix_overlap
480
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
481
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
482
+ negative_prompt, num_inference_steps=num_steps,
483
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
484
+ return_latent=True,
485
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
486
+
487
+ torch.cuda.empty_cache()
488
+ img_lat = None
489
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
490
+
491
+ if t == 0:
492
+ video.append(frames)
493
+ else:
494
+ video.append(frames[:, overlap:])
495
+ video = torch.cat(video, dim=1)
496
+ video = video[:, :ori_audio_len + 1]
497
+
498
+ return video
499
+
500
+
501
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
502
+ snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
503
+ snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
504
+
505
+
506
+ # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
507
+ # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
508
+ # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
509
+
510
+ import tempfile
511
+
512
+ from PIL import Image
513
+
514
+
515
+ set_seed(args.seed)
516
+ seq_len = args.seq_len
517
+ inferpipe = WanInferencePipeline(args)
518
+
519
+
520
+ def update_generate_button(image_path, audio_path, text, num_steps):
521
+
522
+ if image_path is None or audio_path is None:
523
+ return gr.update(value="⌚ Zero GPU Required: --")
524
+
525
+ duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
526
+ duration_m = duration_s / 60
527
+
528
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
529
+
530
+ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
531
+
532
+ if image_path is None:
533
+ gr.Info("Step1: Please Provide an Image or Choose from Image Samples")
534
+ print("Step1: Please Provide an Image or Choose from Image Samples")
535
+
536
+ return 0
537
+
538
+ if audio_path is None:
539
+ gr.Info("Step2: Please Provide an Audio or Choose from Audio Samples")
540
+ print("Step2: Please Provide an Audio or Choose from Audio Samples")
541
+
542
+ return 0
543
+
544
+
545
+ audio_chunks = inferpipe.get_times(
546
+ prompt=text,
547
+ image_path=image_path,
548
+ audio_path=audio_path,
549
+ seq_len=args.seq_len,
550
+ num_steps=num_steps
551
+ )
552
+
553
+ warmup_s = 30
554
+ duration_s = (20 * num_steps) + warmup_s
555
+
556
+ if audio_chunks > 1:
557
+ duration_s = (20 * num_steps * audio_chunks) + warmup_s
558
+
559
+ print(f'for {audio_chunks} times, might take {duration_s}')
560
+
561
+ return int(duration_s)
562
+
563
+ def preprocess_img(image_path, session_id = None):
564
+
565
+ if session_id is None:
566
+ session_id = uuid.uuid4().hex
567
+
568
+ image = Image.open(image_path).convert("RGB")
569
+
570
+ image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
571
+
572
+ _, _, h, w = image.shape
573
+ select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
574
+ image = resize_pad(image, (h, w), select_size)
575
+ image = image * 2.0 - 1.0
576
+ image = image[:, :, None]
577
+
578
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
579
+
580
+ img_dir = output_dir + '/image'
581
+ os.makedirs(img_dir, exist_ok=True)
582
+ input_img_path = os.path.join(img_dir, f"img_input.jpg")
583
+
584
+ image = tensor_to_pil(image)
585
+ image.save(input_img_path)
586
+
587
+ return input_img_path
588
+
589
+
590
+ @spaces.GPU(duration=get_duration)
591
+ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
592
+
593
+ if image_path is None:
594
+
595
+ return None
596
+
597
+ if audio_path is None:
598
+
599
+ return None
600
+
601
+ if session_id is None:
602
+ session_id = uuid.uuid4().hex
603
+
604
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
605
+
606
+ audio_dir = output_dir + '/audio'
607
+ os.makedirs(audio_dir, exist_ok=True)
608
+ if args.silence_duration_s > 0:
609
+ input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
610
+ else:
611
+ input_audio_path = audio_path
612
+ prompt_dir = output_dir + '/prompt'
613
+ os.makedirs(prompt_dir, exist_ok=True)
614
+
615
+ if args.silence_duration_s > 0:
616
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
617
+
618
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
619
+ prompt_path = os.path.join(prompt_dir, f"prompt.txt")
620
+
621
+ video = inferpipe(
622
+ prompt=text,
623
+ image_path=image_path,
624
+ audio_path=input_audio_path,
625
+ seq_len=args.seq_len,
626
+ num_steps=num_steps
627
+ )
628
+
629
+ torch.cuda.empty_cache()
630
+
631
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
632
+ video_paths = save_video_as_grid_and_mp4(video,
633
+ output_dir,
634
+ args.fps,
635
+ prompt=text,
636
+ prompt_path = prompt_path,
637
+ audio_path=tmp2_audio_path if args.use_audio else None,
638
+ prefix=f'result')
639
+
640
+ return video_paths[0]
641
+
642
+ def apply(request):
643
+
644
+ return request
645
+
646
+ def cleanup(request: gr.Request):
647
+
648
+ sid = request.session_hash
649
+ if sid:
650
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
651
+ shutil.rmtree(d1, ignore_errors=True)
652
+
653
+ def start_session(request: gr.Request):
654
+
655
+ return request.session_hash
656
+
657
+ css = """
658
+ #col-container {
659
+ margin: 0 auto;
660
+ max-width: 1560px;
661
+ }
662
+ """
663
+
664
+ with gr.Blocks(css=css) as demo:
665
+
666
+ session_state = gr.State()
667
+ demo.load(start_session, outputs=[session_state])
668
+
669
+
670
+ with gr.Column(elem_id="col-container"):
671
+ gr.HTML(
672
+ """
673
+ <div style="text-align: left;">
674
+ <p style="font-size:16px; display: inline; margin: 0;">
675
+ <strong>OmniAvatar</strong> Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
676
+ </p>
677
+ <a href="https://huggingface.co/OmniAvatar/OmniAvatar-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
678
+ [model]
679
+ </a>
680
+ </div>
681
+ <div style="text-align: left;">
682
+ <strong>HF Space by:</strong>
683
+ <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
684
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
685
+ </a>
686
+ </div>
687
+
688
+ """
689
+ )
690
+
691
+ with gr.Row():
692
+
693
+ with gr.Column():
694
+
695
+ image_input = gr.Image(label="Reference Image", type="filepath", height=512)
696
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
697
+
698
+
699
+ with gr.Column():
700
+
701
+ output_video = gr.Video(label="Avatar", height=512)
702
+ num_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")
703
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
704
+ infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
705
+ with gr.Accordion("Advanced Settings", open=False):
706
+ text_input = gr.Textbox(label="Video Prompt", lines=6, value="A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
707
+
708
+ with gr.Column():
709
+
710
+ cached_examples = gr.Examples(
711
+ examples=[
712
+ [
713
+ "examples/images/male-001.png",
714
+ "examples/audios/denial.wav",
715
+ "A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
716
+ 12
717
+ ],
718
+ [
719
+ "examples/images/female-001.png",
720
+ "examples/audios/script.wav",
721
+ "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
722
+ 14
723
+ ],
724
+ [
725
+ "examples/images/female-002.png",
726
+ "examples/audios/nature.wav",
727
+ "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, standing in the woods, with dynamic and rhythmic and extensive hand gestures that complement his speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
728
+ 10
729
+ ],
730
+ ],
731
+ label="Cached Examples",
732
+ inputs=[image_input, audio_input, text_input, num_steps],
733
+ outputs=[output_video],
734
+ fn=infer,
735
+ cache_examples=True
736
+ )
737
+
738
+ image_examples = gr.Examples(
739
+ examples=[
740
+ [
741
+ "examples/images/female-009.png",
742
+ ],
743
+ [
744
+ "examples/images/male-005.png",
745
+ ],
746
+ [
747
+ "examples/images/female-003.png",
748
+ ],
749
+ ],
750
+ label="Image Samples",
751
+ inputs=[image_input],
752
+ outputs=[image_input],
753
+ fn=apply
754
+ )
755
+
756
+ audio_examples = gr.Examples(
757
+ examples=[
758
+ [
759
+ "examples/audios/londoners.wav",
760
+ ],
761
+ [
762
+ "examples/audios/matcha.wav",
763
+ ],
764
+ [
765
+ "examples/audios/keen.wav",
766
+ ],
767
+ ],
768
+ label="Audio Samples",
769
+ inputs=[audio_input],
770
+ outputs=[audio_input],
771
+ fn=apply
772
+ )
773
+
774
+ infer_btn.click(
775
+ fn=infer,
776
+ inputs=[image_input, audio_input, text_input, num_steps, session_state],
777
+ outputs=[output_video]
778
+ )
779
+ image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input])
780
+ image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
781
+ audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
782
+ num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
783
+
784
+
785
+ if __name__ == "__main__":
786
+ demo.unload(cleanup)
787
+ demo.queue()
788
  demo.launch(ssr_mode=False)