alexnasa committed
Commit cbd515b · verified · 1 Parent(s): 7e16671

Samples Added

.gitattributes CHANGED
@@ -45,3 +45,6 @@ examples/audios/nature.wav filter=lfs diff=lfs merge=lfs -text
  examples/images/female-003.png filter=lfs diff=lfs merge=lfs -text
  examples/audios/bike.wav filter=lfs diff=lfs merge=lfs -text
  examples/audios/matcha.wav filter=lfs diff=lfs merge=lfs -text
+ examples/audios/keen.wav filter=lfs diff=lfs merge=lfs -text
+ examples/audios/londoners.wav filter=lfs diff=lfs merge=lfs -text
+ examples/images/female-009.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,728 +1,763 @@
1
- import spaces
2
- import subprocess
3
- import gradio as gr
4
-
5
- import os, sys
6
- from glob import glob
7
- from datetime import datetime
8
- import math
9
- import random
10
- import librosa
11
- import numpy as np
12
- import uuid
13
- import shutil
14
-
15
- import importlib, site, sys
16
- from huggingface_hub import hf_hub_download, snapshot_download
17
-
18
- # Re-discover all .pth/.egg-link files
19
- for sitedir in site.getsitepackages():
20
- site.addsitedir(sitedir)
21
-
22
- # Clear caches so importlib will pick up new modules
23
- importlib.invalidate_caches()
24
-
25
- def sh(cmd): subprocess.check_call(cmd, shell=True)
26
-
27
- flash_attention_installed = False
28
-
29
- try:
30
- print("Attempting to download and install FlashAttention wheel...")
31
- flash_attention_wheel = hf_hub_download(
32
- repo_id="alexnasa/flash-attn-3",
33
- repo_type="model",
34
- filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
35
- )
36
-
37
- sh(f"pip install {flash_attention_wheel}")
38
-
39
- # tell Python to re-scan site-packages now that the egg-link exists
40
- import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
41
-
42
- flash_attention_installed = True
43
- print("FlashAttention installed successfully.")
44
-
45
- except Exception as e:
46
- print(f"⚠️ Could not install FlashAttention: {e}")
47
- print("Continuing without FlashAttention...")
48
-
49
- import torch
50
- print(f"Torch version: {torch.__version__}")
51
- print(f"FlashAttention available: {flash_attention_installed}")
52
-
53
-
54
- import torch.nn as nn
55
- from tqdm import tqdm
56
- from functools import partial
57
- from omegaconf import OmegaConf
58
- from argparse import Namespace
59
-
60
- # load the one true config you dumped
61
- _args_cfg = OmegaConf.load("args_config.yaml")
62
- args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
63
-
64
- from OmniAvatar.utils.args_config import set_global_args
65
-
66
- set_global_args(args)
67
- # args = parse_args()
68
-
69
- from OmniAvatar.utils.io_utils import load_state_dict
70
- from peft import LoraConfig, inject_adapter_in_model
71
- from OmniAvatar.models.model_manager import ModelManager
72
- from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
73
- from OmniAvatar.wan_video import WanVideoPipeline
74
- from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
75
- import torchvision.transforms as TT
76
- from transformers import Wav2Vec2FeatureExtractor
77
- import torchvision.transforms as transforms
78
- import torch.nn.functional as F
79
- from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
80
-
81
-
82
- os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
83
-
84
- def tensor_to_pil(tensor):
85
- """
86
- Args:
87
- tensor: torch.Tensor with shape like
88
- (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
89
- values in [-1, 1], on any device.
90
- Returns:
91
- A PIL.Image in RGB mode.
92
- """
93
- # 1) Remove batch dim if it exists
94
- if tensor.dim() > 3 and tensor.shape[0] == 1:
95
- tensor = tensor[0]
96
-
97
- # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
98
- tensor = tensor.squeeze()
99
-
100
- # Now we should have exactly 3 dims: (C, H, W)
101
- if tensor.dim() != 3:
102
- raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
103
-
104
- # 3) Move to CPU float32
105
- tensor = tensor.cpu().float()
106
-
107
- # 4) Undo normalization from [-1,1] -> [0,1]
108
- tensor = (tensor + 1.0) / 2.0
109
-
110
- # 5) Clamp to [0,1]
111
- tensor = torch.clamp(tensor, 0.0, 1.0)
112
-
113
- # 6) To NumPy H×W×C in [0,255]
114
- np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
115
-
116
- # 7) Build PIL Image
117
- return Image.fromarray(np_img)
118
-
119
-
120
- def set_seed(seed: int = 42):
121
- random.seed(seed)
122
- np.random.seed(seed)
123
- torch.manual_seed(seed)
124
- torch.cuda.manual_seed(seed) # seed the current GPU
125
- torch.cuda.manual_seed_all(seed) # seed all GPUs
126
-
127
- def read_from_file(p):
128
- with open(p, "r") as fin:
129
- for l in fin:
130
- yield l.strip()
131
-
132
- def match_size(image_size, h, w):
133
- ratio_ = 9999
134
- size_ = 9999
135
- select_size = None
136
- for image_s in image_size:
137
- ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
138
- size_tmp = abs(max(image_s) - max(w, h))
139
- if ratio_tmp < ratio_:
140
- ratio_ = ratio_tmp
141
- size_ = size_tmp
142
- select_size = image_s
143
- if ratio_ == ratio_tmp:
144
- if size_ == size_tmp:
145
- select_size = image_s
146
- return select_size
147
-
148
- def resize_pad(image, ori_size, tgt_size):
149
- h, w = ori_size
150
- scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
151
- scale_h = int(h * scale_ratio)
152
- scale_w = int(w * scale_ratio)
153
-
154
- image = transforms.Resize(size=[scale_h, scale_w])(image)
155
-
156
- padding_h = tgt_size[0] - scale_h
157
- padding_w = tgt_size[1] - scale_w
158
- pad_top = padding_h // 2
159
- pad_bottom = padding_h - pad_top
160
- pad_left = padding_w // 2
161
- pad_right = padding_w - pad_left
162
-
163
- image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
164
- return image
165
-
166
- class WanInferencePipeline(nn.Module):
167
- def __init__(self, args):
168
- super().__init__()
169
- self.args = args
170
- self.device = torch.device(f"cuda")
171
- self.dtype = torch.bfloat16
172
- self.pipe = self.load_model()
173
- chained_trainsforms = []
174
- chained_trainsforms.append(TT.ToTensor())
175
- self.transform = TT.Compose(chained_trainsforms)
176
-
177
- if self.args.use_audio:
178
- from OmniAvatar.models.wav2vec import Wav2VecModel
179
- self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
180
- self.args.wav2vec_path
181
- )
182
- self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
183
- self.audio_encoder.feature_extractor._freeze_parameters()
184
-
185
-
186
- def load_model(self):
187
- ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
188
- assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
189
- if self.args.train_architecture == 'lora':
190
- self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
191
- else:
192
- resume_path = ckpt_path
193
-
194
- self.step = 0
195
-
196
- # Load models
197
- model_manager = ModelManager(device="cuda", infer=True)
198
-
199
- model_manager.load_models(
200
- [
201
- self.args.dit_path.split(","),
202
- self.args.vae_path,
203
- self.args.text_encoder_path
204
- ],
205
- torch_dtype=self.dtype,
206
- device='cuda',
207
- )
208
-
209
- pipe = WanVideoPipeline.from_model_manager(model_manager,
210
- torch_dtype=self.dtype,
211
- device="cuda",
212
- use_usp=False,
213
- infer=True)
214
-
215
- if self.args.train_architecture == "lora":
216
- print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
217
- self.add_lora_to_model(
218
- pipe.denoising_model(),
219
- lora_rank=self.args.lora_rank,
220
- lora_alpha=self.args.lora_alpha,
221
- lora_target_modules=self.args.lora_target_modules,
222
- init_lora_weights=self.args.init_lora_weights,
223
- pretrained_lora_path=pretrained_lora_path,
224
- )
225
- print(next(pipe.denoising_model().parameters()).device)
226
- else:
227
- missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
228
- print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
229
- pipe.requires_grad_(False)
230
- pipe.eval()
231
- # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
232
- return pipe
233
-
234
- def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
235
- # Add LoRA to UNet
236
-
237
- self.lora_alpha = lora_alpha
238
- if init_lora_weights == "kaiming":
239
- init_lora_weights = True
240
-
241
- lora_config = LoraConfig(
242
- r=lora_rank,
243
- lora_alpha=lora_alpha,
244
- init_lora_weights=init_lora_weights,
245
- target_modules=lora_target_modules.split(","),
246
- )
247
- model = inject_adapter_in_model(lora_config, model)
248
-
249
- # Lora pretrained lora weights
250
- if pretrained_lora_path is not None:
251
- state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
252
- if state_dict_converter is not None:
253
- state_dict = state_dict_converter(state_dict)
254
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
255
- all_keys = [i for i, _ in model.named_parameters()]
256
- num_updated_keys = len(all_keys) - len(missing_keys)
257
- num_unexpected_keys = len(unexpected_keys)
258
-
259
- print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
260
-
261
- def get_times(self, prompt,
262
- image_path=None,
263
- audio_path=None,
264
- seq_len=101, # not used while audio_path is not None
265
- height=720,
266
- width=720,
267
- overlap_frame=None,
268
- num_steps=None,
269
- negative_prompt=None,
270
- guidance_scale=None,
271
- audio_scale=None):
272
-
273
- overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
274
- num_steps = num_steps if num_steps is not None else self.args.num_steps
275
- negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
276
- guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
277
- audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
278
-
279
- if image_path is not None:
280
- from PIL import Image
281
- image = Image.open(image_path).convert("RGB")
282
-
283
- image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
284
-
285
- _, _, h, w = image.shape
286
- select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
287
- image = resize_pad(image, (h, w), select_size)
288
- image = image * 2.0 - 1.0
289
- image = image[:, :, None]
290
-
291
- else:
292
- image = None
293
- select_size = [height, width]
294
- num = self.args.max_tokens * 16 * 16 * 4
295
- den = select_size[0] * select_size[1]
296
- L0 = num // den
297
- diff = (L0 - 1) % 4
298
- L = L0 - diff
299
- if L < 1:
300
- L = 1
301
- T = (L + 3) // 4
302
-
303
-
304
- if self.args.random_prefix_frames:
305
- fixed_frame = overlap_frame
306
- assert fixed_frame % 4 == 1
307
- else:
308
- fixed_frame = 1
309
- prefix_lat_frame = (3 + fixed_frame) // 4
310
- first_fixed_frame = 1
311
-
312
-
313
- audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
314
-
315
- input_values = np.squeeze(
316
- self.wav_feature_extractor(audio, sampling_rate=16000).input_values
317
- )
318
- input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
319
- audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
320
-
321
- if audio_len < L - first_fixed_frame:
322
- audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
323
- elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
324
- audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
325
-
326
- seq_len = audio_len
327
-
328
- times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
329
- if times * (L-fixed_frame) + fixed_frame < seq_len:
330
- times += 1
331
-
332
- return times
333
-
334
- @torch.no_grad()
335
- def forward(self, prompt,
336
- image_path=None,
337
- audio_path=None,
338
- seq_len=101, # not used while audio_path is not None
339
- height=720,
340
- width=720,
341
- overlap_frame=None,
342
- num_steps=None,
343
- negative_prompt=None,
344
- guidance_scale=None,
345
- audio_scale=None):
346
- overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
347
- num_steps = num_steps if num_steps is not None else self.args.num_steps
348
- negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
349
- guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
350
- audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
351
-
352
- if image_path is not None:
353
- from PIL import Image
354
- image = Image.open(image_path).convert("RGB")
355
-
356
- image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
357
-
358
- _, _, h, w = image.shape
359
- select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
360
- image = resize_pad(image, (h, w), select_size)
361
- image = image * 2.0 - 1.0
362
- image = image[:, :, None]
363
-
364
- else:
365
- image = None
366
- select_size = [height, width]
367
- # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
368
- # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
369
- # T = (L + 3) // 4 # latent frames
370
-
371
- # step 1: numerator and denominator as ints
372
- num = args.max_tokens * 16 * 16 * 4
373
- den = select_size[0] * select_size[1]
374
-
375
- # step 2: integer division
376
- L0 = num // den # exact floor division, no float in sight
377
-
378
- # step 3: make it ≡ 1 mod 4
379
- # if L0 % 4 == 1, keep L0;
380
- # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
381
- # but ensure the result stays positive.
382
- diff = (L0 - 1) % 4
383
- L = L0 - diff
384
- if L < 1:
385
- L = 1 # or whatever your minimal frame count is
386
-
387
- # step 4: latent frames
388
- T = (L + 3) // 4
389
-
390
-
391
- if self.args.i2v:
392
- if self.args.random_prefix_frames:
393
- fixed_frame = overlap_frame
394
- assert fixed_frame % 4 == 1
395
- else:
396
- fixed_frame = 1
397
- prefix_lat_frame = (3 + fixed_frame) // 4
398
- first_fixed_frame = 1
399
- else:
400
- fixed_frame = 0
401
- prefix_lat_frame = 0
402
- first_fixed_frame = 0
403
-
404
-
405
- if audio_path is not None and self.args.use_audio:
406
- audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
407
- input_values = np.squeeze(
408
- self.wav_feature_extractor(audio, sampling_rate=16000).input_values
409
- )
410
- input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
411
- ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
412
- input_values = input_values.unsqueeze(0)
413
- # padding audio
414
- if audio_len < L - first_fixed_frame:
415
- audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
416
- elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
417
- audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
418
- input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
419
- with torch.no_grad():
420
- hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
421
- audio_embeddings = hidden_states.last_hidden_state
422
- for mid_hidden_states in hidden_states.hidden_states:
423
- audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
424
- seq_len = audio_len
425
- audio_embeddings = audio_embeddings.squeeze(0)
426
- audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
427
- else:
428
- audio_embeddings = None
429
-
430
- # loop
431
- times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
432
- if times * (L-fixed_frame) + fixed_frame < seq_len:
433
- times += 1
434
- video = []
435
- image_emb = {}
436
- img_lat = None
437
- if self.args.i2v:
438
- self.pipe.load_models_to_device(['vae'])
439
- img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
440
-
441
- msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
442
- image_cat = img_lat.repeat(1, 1, T, 1, 1)
443
- msk[:, :, 1:] = 1
444
- image_emb["y"] = torch.cat([image_cat, msk], dim=1)
445
-
446
- for t in range(times):
447
- print(f"[{t+1}/{times}]")
448
- audio_emb = {}
449
- if t == 0:
450
- overlap = first_fixed_frame
451
- else:
452
- overlap = fixed_frame
453
- image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
454
- prefix_overlap = (3 + overlap) // 4
455
- if audio_embeddings is not None:
456
- if t == 0:
457
- audio_tensor = audio_embeddings[
458
- :min(L - overlap, audio_embeddings.shape[0])
459
- ]
460
- else:
461
- audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
462
- audio_tensor = audio_embeddings[
463
- audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
464
- ]
465
-
466
- audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
467
- audio_prefix = audio_tensor[-fixed_frame:]
468
- audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
469
- audio_emb["audio_emb"] = audio_tensor
470
- else:
471
- audio_prefix = None
472
- if image is not None and img_lat is None:
473
- self.pipe.load_models_to_device(['vae'])
474
- img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
475
- assert img_lat.shape[2] == prefix_overlap
476
- img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
477
- frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
478
- negative_prompt, num_inference_steps=num_steps,
479
- cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
480
- return_latent=True,
481
- tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
482
-
483
- torch.cuda.empty_cache()
484
- img_lat = None
485
- image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
486
-
487
- if t == 0:
488
- video.append(frames)
489
- else:
490
- video.append(frames[:, overlap:])
491
- video = torch.cat(video, dim=1)
492
- video = video[:, :ori_audio_len + 1]
493
-
494
- return video
495
-
496
-
497
- snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
498
- snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
499
- snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
500
-
501
-
502
- # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
503
- # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
504
- # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
505
-
506
- import tempfile
507
-
508
- from PIL import Image
509
-
510
-
511
- set_seed(args.seed)
512
- seq_len = args.seq_len
513
- inferpipe = WanInferencePipeline(args)
514
-
515
-
516
- def update_generate_button(image_path, audio_path, text, num_steps):
517
-
518
- if image_path is None or audio_path is None:
519
- return gr.update(value="⌚ Zero GPU Required: --")
520
-
521
- duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
522
- duration_m = duration_s / 60
523
-
524
- return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
525
-
526
- def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
527
-
528
- audio_chunks = inferpipe.get_times(
529
- prompt=text,
530
- image_path=image_path,
531
- audio_path=audio_path,
532
- seq_len=args.seq_len,
533
- num_steps=num_steps
534
- )
535
-
536
- warmup_s = 30
537
- duration_s = (20 * num_steps) + warmup_s
538
-
539
- if audio_chunks > 1:
540
- duration_s = (20 * num_steps * audio_chunks) + warmup_s
541
-
542
- print(f'for {audio_chunks} times, might take {duration_s}')
543
-
544
- return int(duration_s)
545
-
546
- def preprocess_img(image_path, session_id = None):
547
-
548
- if session_id is None:
549
- session_id = uuid.uuid4().hex
550
-
551
- image = Image.open(image_path).convert("RGB")
552
-
553
- image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
554
-
555
- _, _, h, w = image.shape
556
- select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
557
- image = resize_pad(image, (h, w), select_size)
558
- image = image * 2.0 - 1.0
559
- image = image[:, :, None]
560
-
561
- output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
562
-
563
- img_dir = output_dir + '/image'
564
- os.makedirs(img_dir, exist_ok=True)
565
- input_img_path = os.path.join(img_dir, f"img_input.jpg")
566
-
567
- image = tensor_to_pil(image)
568
- image.save(input_img_path)
569
-
570
- return input_img_path
571
-
572
-
573
- @spaces.GPU(duration=get_duration)
574
- def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
575
-
576
- if session_id is None:
577
- session_id = uuid.uuid4().hex
578
-
579
- output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
580
-
581
- audio_dir = output_dir + '/audio'
582
- os.makedirs(audio_dir, exist_ok=True)
583
- if args.silence_duration_s > 0:
584
- input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
585
- else:
586
- input_audio_path = audio_path
587
- prompt_dir = output_dir + '/prompt'
588
- os.makedirs(prompt_dir, exist_ok=True)
589
-
590
- if args.silence_duration_s > 0:
591
- add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
592
-
593
- tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
594
- prompt_path = os.path.join(prompt_dir, f"prompt.txt")
595
-
596
- video = inferpipe(
597
- prompt=text,
598
- image_path=image_path,
599
- audio_path=input_audio_path,
600
- seq_len=args.seq_len,
601
- num_steps=num_steps
602
- )
603
-
604
- torch.cuda.empty_cache()
605
-
606
- add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
607
- video_paths = save_video_as_grid_and_mp4(video,
608
- output_dir,
609
- args.fps,
610
- prompt=text,
611
- prompt_path = prompt_path,
612
- audio_path=tmp2_audio_path if args.use_audio else None,
613
- prefix=f'result')
614
-
615
- return video_paths[0]
616
-
617
- def cleanup(request: gr.Request):
618
-
619
- sid = request.session_hash
620
- if sid:
621
- d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
622
- shutil.rmtree(d1, ignore_errors=True)
623
-
624
- def start_session(request: gr.Request):
625
-
626
- return request.session_hash
627
-
628
- css = """
629
- #col-container {
630
- margin: 0 auto;
631
- max-width: 1560px;
632
- }
633
- """
634
-
635
- with gr.Blocks(css=css) as demo:
636
-
637
- session_state = gr.State()
638
- demo.load(start_session, outputs=[session_state])
639
-
640
-
641
- with gr.Column(elem_id="col-container"):
642
- gr.HTML(
643
- """
644
- <div style="text-align: left;">
645
- <p style="font-size:16px; display: inline; margin: 0;">
646
- <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
647
- </p>
648
- <a href="https://huggingface.co/OmniAvatar/OmniAvatar-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
649
- [model]
650
- </a>
651
- </div>
652
- <div style="text-align: left;">
653
- <strong>HF Space by:</strong>
654
- <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
655
- <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
656
- </a>
657
- </div>
658
-
659
- """
660
- )
661
-
662
- with gr.Row():
663
-
664
- with gr.Column():
665
-
666
- image_input = gr.Image(label="Reference Image", type="filepath", height=512)
667
- audio_input = gr.Audio(label="Input Audio", type="filepath")
668
-
669
-
670
- with gr.Column():
671
-
672
- output_video = gr.Video(label="Avatar", height=512)
673
- num_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")
674
- time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
675
- infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
676
- with gr.Accordion("Advanced Settings", open=False):
677
- text_input = gr.Textbox(label="Video Prompt", lines=6, value="A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
678
-
679
- with gr.Column():
680
-
681
- examples = gr.Examples(
682
- examples=[
683
- [
684
- "examples/images/male-001.png",
685
- "examples/audios/denial.wav",
686
- "A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
687
- 12
688
- ],
689
- [
690
- "examples/images/female-001.png",
691
- "examples/audios/script.wav",
692
- "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
693
- 14
694
- ],
695
- [
696
- "examples/images/female-002.png",
697
- "examples/audios/nature.wav",
698
- "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, standing in the woods, with dynamic and rhythmic and extensive hand gestures that complement his speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
699
- 10
700
- ],
701
- # [
702
- # "examples/images/female-003.png",
703
- # "examples/audios/matcha.wav",
704
- # "A realistic video of a sad woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, touching a glass in front of her, with dynamic and rhythmic and extensive hand gestures that complement his speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
705
- # 20
706
- # ],
707
- ],
708
- inputs=[image_input, audio_input, text_input, num_steps],
709
- outputs=[output_video],
710
- fn=infer,
711
- cache_examples=True
712
- )
713
-
714
- infer_btn.click(
715
- fn=infer,
716
- inputs=[image_input, audio_input, text_input, num_steps, session_state],
717
- outputs=[output_video]
718
- )
719
- image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input])
720
- image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
721
- audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
722
- num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
723
-
724
-
725
- if __name__ == "__main__":
726
- demo.unload(cleanup)
727
- demo.queue()
728
  demo.launch(ssr_mode=False)
 
1
+ import spaces
2
+ import subprocess
3
+ import gradio as gr
4
+
5
+ import os, sys
6
+ from glob import glob
7
+ from datetime import datetime
8
+ import math
9
+ import random
10
+ import librosa
11
+ import numpy as np
12
+ import uuid
13
+ import shutil
14
+
15
+ import importlib, site, sys
16
+ from huggingface_hub import hf_hub_download, snapshot_download
17
+
18
+ # Re-discover all .pth/.egg-link files
19
+ for sitedir in site.getsitepackages():
20
+ site.addsitedir(sitedir)
21
+
22
+ # Clear caches so importlib will pick up new modules
23
+ importlib.invalidate_caches()
24
+
25
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
26
+
27
+ flash_attention_installed = False
28
+
29
+ try:
30
+ print("Attempting to download and install FlashAttention wheel...")
31
+ flash_attention_wheel = hf_hub_download(
32
+ repo_id="alexnasa/flash-attn-3",
33
+ repo_type="model",
34
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
35
+ )
36
+
37
+ sh(f"pip install {flash_attention_wheel}")
38
+
39
+ # tell Python to re-scan site-packages now that the egg-link exists
40
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
41
+
42
+ flash_attention_installed = True
43
+ print("FlashAttention installed successfully.")
44
+
45
+ except Exception as e:
46
+ print(f"⚠️ Could not install FlashAttention: {e}")
47
+ print("Continuing without FlashAttention...")
48
+
49
+ import torch
50
+ print(f"Torch version: {torch.__version__}")
51
+ print(f"FlashAttention available: {flash_attention_installed}")
52
+
53
+
54
+ import torch.nn as nn
55
+ from tqdm import tqdm
56
+ from functools import partial
57
+ from omegaconf import OmegaConf
58
+ from argparse import Namespace
59
+
60
+ # load the one true config you dumped
61
+ _args_cfg = OmegaConf.load("args_config.yaml")
62
+ args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
63
+
64
+ from OmniAvatar.utils.args_config import set_global_args
65
+
66
+ set_global_args(args)
67
+ # args = parse_args()
68
+
69
+ from OmniAvatar.utils.io_utils import load_state_dict
70
+ from peft import LoraConfig, inject_adapter_in_model
71
+ from OmniAvatar.models.model_manager import ModelManager
72
+ from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
73
+ from OmniAvatar.wan_video import WanVideoPipeline
74
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
75
+ import torchvision.transforms as TT
76
+ from transformers import Wav2Vec2FeatureExtractor
77
+ import torchvision.transforms as transforms
78
+ import torch.nn.functional as F
79
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
80
+
81
+
82
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
83
+
84
+ def tensor_to_pil(tensor):
85
+ """
86
+ Args:
87
+ tensor: torch.Tensor with shape like
88
+ (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
89
+ values in [-1, 1], on any device.
90
+ Returns:
91
+ A PIL.Image in RGB mode.
92
+ """
93
+ # 1) Remove batch dim if it exists
94
+ if tensor.dim() > 3 and tensor.shape[0] == 1:
95
+ tensor = tensor[0]
96
+
97
+ # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
98
+ tensor = tensor.squeeze()
99
+
100
+ # Now we should have exactly 3 dims: (C, H, W)
101
+ if tensor.dim() != 3:
102
+ raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
103
+
104
+ # 3) Move to CPU float32
105
+ tensor = tensor.cpu().float()
106
+
107
+ # 4) Undo normalization from [-1,1] -> [0,1]
108
+ tensor = (tensor + 1.0) / 2.0
109
+
110
+ # 5) Clamp to [0,1]
111
+ tensor = torch.clamp(tensor, 0.0, 1.0)
112
+
113
+ # 6) To NumPy H×W×C in [0,255]
114
+ np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
115
+
116
+ # 7) Build PIL Image
117
+ return Image.fromarray(np_img)
118
+
119
+
120
+ def set_seed(seed: int = 42):
121
+ random.seed(seed)
122
+ np.random.seed(seed)
123
+ torch.manual_seed(seed)
124
+ torch.cuda.manual_seed(seed) # seed the current GPU
125
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
126
+
127
+ def read_from_file(p):
128
+ with open(p, "r") as fin:
129
+ for l in fin:
130
+ yield l.strip()
131
+
132
+ def match_size(image_size, h, w):
133
+ ratio_ = 9999
134
+ size_ = 9999
135
+ select_size = None
136
+ for image_s in image_size:
137
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
138
+ size_tmp = abs(max(image_s) - max(w, h))
139
+ if ratio_tmp < ratio_:
140
+ ratio_ = ratio_tmp
141
+ size_ = size_tmp
142
+ select_size = image_s
143
+ if ratio_ == ratio_tmp:
144
+ if size_ == size_tmp:
145
+ select_size = image_s
146
+ return select_size
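# Editor's note (illustrative, not part of the commit): with hypothetical candidate
# sizes [(720, 720), (720, 1280), (1280, 720)] and a 1080x1920 (h x w) input,
# h/w = 0.5625 matches (720, 1280) exactly (ratio difference 0), so match_size
# selects the 720x1280 target.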
147
+
148
+ def resize_pad(image, ori_size, tgt_size):
149
+ h, w = ori_size
150
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
151
+ scale_h = int(h * scale_ratio)
152
+ scale_w = int(w * scale_ratio)
153
+
154
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
155
+
156
+ padding_h = tgt_size[0] - scale_h
157
+ padding_w = tgt_size[1] - scale_w
158
+ pad_top = padding_h // 2
159
+ pad_bottom = padding_h - pad_top
160
+ pad_left = padding_w // 2
161
+ pad_right = padding_w - pad_left
162
+
163
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
164
+ return image
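# Editor's note (illustrative, not part of the commit): for a 1350x1080 portrait and
# a 720x720 target, scale_ratio = max(720/1350, 720/1080) ≈ 0.667, giving roughly a
# 900x720 intermediate; the negative height padding then center-crops about 180 rows,
# so the output is exactly 720x720 (F.pad with negative values crops).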
165
+
166
+ class WanInferencePipeline(nn.Module):
167
+ def __init__(self, args):
168
+ super().__init__()
169
+ self.args = args
170
+ self.device = torch.device(f"cuda")
171
+ self.dtype = torch.bfloat16
172
+ self.pipe = self.load_model()
173
+ chained_trainsforms = []
174
+ chained_trainsforms.append(TT.ToTensor())
175
+ self.transform = TT.Compose(chained_trainsforms)
176
+
177
+ if self.args.use_audio:
178
+ from OmniAvatar.models.wav2vec import Wav2VecModel
179
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
180
+ self.args.wav2vec_path
181
+ )
182
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
183
+ self.audio_encoder.feature_extractor._freeze_parameters()
184
+
185
+
186
+ def load_model(self):
187
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
188
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
189
+ if self.args.train_architecture == 'lora':
190
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
191
+ else:
192
+ resume_path = ckpt_path
193
+
194
+ self.step = 0
195
+
196
+ # Load models
197
+ model_manager = ModelManager(device="cuda", infer=True)
198
+
199
+ model_manager.load_models(
200
+ [
201
+ self.args.dit_path.split(","),
202
+ self.args.vae_path,
203
+ self.args.text_encoder_path
204
+ ],
205
+ torch_dtype=self.dtype,
206
+ device='cuda',
207
+ )
208
+
209
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
210
+ torch_dtype=self.dtype,
211
+ device="cuda",
212
+ use_usp=False,
213
+ infer=True)
214
+
215
+ if self.args.train_architecture == "lora":
216
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
217
+ self.add_lora_to_model(
218
+ pipe.denoising_model(),
219
+ lora_rank=self.args.lora_rank,
220
+ lora_alpha=self.args.lora_alpha,
221
+ lora_target_modules=self.args.lora_target_modules,
222
+ init_lora_weights=self.args.init_lora_weights,
223
+ pretrained_lora_path=pretrained_lora_path,
224
+ )
225
+ print(next(pipe.denoising_model().parameters()).device)
226
+ else:
227
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
228
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
229
+ pipe.requires_grad_(False)
230
+ pipe.eval()
231
+ # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
232
+ return pipe
233
+
234
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
235
+ # Add LoRA to UNet
236
+
237
+ self.lora_alpha = lora_alpha
238
+ if init_lora_weights == "kaiming":
239
+ init_lora_weights = True
240
+
241
+ lora_config = LoraConfig(
242
+ r=lora_rank,
243
+ lora_alpha=lora_alpha,
244
+ init_lora_weights=init_lora_weights,
245
+ target_modules=lora_target_modules.split(","),
246
+ )
247
+ model = inject_adapter_in_model(lora_config, model)
248
+
249
+ # Lora pretrained lora weights
250
+ if pretrained_lora_path is not None:
251
+ state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
252
+ if state_dict_converter is not None:
253
+ state_dict = state_dict_converter(state_dict)
254
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
255
+ all_keys = [i for i, _ in model.named_parameters()]
256
+ num_updated_keys = len(all_keys) - len(missing_keys)
257
+ num_unexpected_keys = len(unexpected_keys)
258
+
259
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
260
+
261
+ def get_times(self, prompt,
262
+ image_path=None,
263
+ audio_path=None,
264
+ seq_len=101, # not used while audio_path is not None
265
+ height=720,
266
+ width=720,
267
+ overlap_frame=None,
268
+ num_steps=None,
269
+ negative_prompt=None,
270
+ guidance_scale=None,
271
+ audio_scale=None):
272
+
273
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
274
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
275
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
276
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
277
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
278
+
279
+ if image_path is not None:
280
+ from PIL import Image
281
+ image = Image.open(image_path).convert("RGB")
282
+
283
+ image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
284
+
285
+ _, _, h, w = image.shape
286
+ select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
287
+ image = resize_pad(image, (h, w), select_size)
288
+ image = image * 2.0 - 1.0
289
+ image = image[:, :, None]
290
+
291
+ else:
292
+ image = None
293
+ select_size = [height, width]
294
+ num = self.args.max_tokens * 16 * 16 * 4
295
+ den = select_size[0] * select_size[1]
296
+ L0 = num // den
297
+ diff = (L0 - 1) % 4
298
+ L = L0 - diff
299
+ if L < 1:
300
+ L = 1
301
+ T = (L + 3) // 4
302
+
303
+
304
+ if self.args.random_prefix_frames:
305
+ fixed_frame = overlap_frame
306
+ assert fixed_frame % 4 == 1
307
+ else:
308
+ fixed_frame = 1
309
+ prefix_lat_frame = (3 + fixed_frame) // 4
310
+ first_fixed_frame = 1
311
+
312
+
313
+ audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
314
+
315
+ input_values = np.squeeze(
316
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
317
+ )
318
+ input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
319
+ audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
320
+
321
+ if audio_len < L - first_fixed_frame:
322
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
323
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
324
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
325
+
326
+ seq_len = audio_len
327
+
328
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
329
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
330
+ times += 1
331
+
332
+ return times
333
+
334
+ @torch.no_grad()
335
+ def forward(self, prompt,
336
+ image_path=None,
337
+ audio_path=None,
338
+ seq_len=101, # not used while audio_path is not None
339
+ height=720,
340
+ width=720,
341
+ overlap_frame=None,
342
+ num_steps=None,
343
+ negative_prompt=None,
344
+ guidance_scale=None,
345
+ audio_scale=None):
346
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
347
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
348
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
349
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
350
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
351
+
352
+ if image_path is not None:
353
+ from PIL import Image
354
+ image = Image.open(image_path).convert("RGB")
355
+
356
+ image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
357
+
358
+ _, _, h, w = image.shape
359
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
360
+ image = resize_pad(image, (h, w), select_size)
361
+ image = image * 2.0 - 1.0
362
+ image = image[:, :, None]
363
+
364
+ else:
365
+ image = None
366
+ select_size = [height, width]
367
+ # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
368
+ # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
369
+ # T = (L + 3) // 4 # latent frames
370
+
371
+ # step 1: numerator and denominator as ints
372
+ num = args.max_tokens * 16 * 16 * 4
373
+ den = select_size[0] * select_size[1]
374
+
375
+ # step 2: integer division
376
+ L0 = num // den # exact floor division, no float in sight
377
+
378
+ # step 3: make it ≡ 1 mod 4
379
+ # if L0 % 4 == 1, keep L0;
380
+ # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
381
+ # but ensure the result stays positive.
382
+ diff = (L0 - 1) % 4
383
+ L = L0 - diff
384
+ if L < 1:
385
+ L = 1 # or whatever your minimal frame count is
386
+
387
+ # step 4: latent frames
388
+ T = (L + 3) // 4
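# Editor's note (illustrative, assuming max_tokens=30000 and a 720x720 target):
# num = 30000*16*16*4 = 30,720,000 and den = 518,400, so L0 = 59,
# diff = (59 - 1) % 4 = 2, L = 57 video frames (57 % 4 == 1), and
# T = (57 + 3) // 4 = 15 latent frames.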
389
+
390
+
391
+ if self.args.i2v:
392
+ if self.args.random_prefix_frames:
393
+ fixed_frame = overlap_frame
394
+ assert fixed_frame % 4 == 1
395
+ else:
396
+ fixed_frame = 1
397
+ prefix_lat_frame = (3 + fixed_frame) // 4
398
+ first_fixed_frame = 1
399
+ else:
400
+ fixed_frame = 0
401
+ prefix_lat_frame = 0
402
+ first_fixed_frame = 0
403
+
404
+
405
+ if audio_path is not None and self.args.use_audio:
406
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
407
+ input_values = np.squeeze(
408
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
409
+ )
410
+ input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
411
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
412
+ input_values = input_values.unsqueeze(0)
413
+ # padding audio
414
+ if audio_len < L - first_fixed_frame:
415
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
416
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
417
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
418
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
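# Editor's note (illustrative, assuming L=57, first_fixed_frame=1, fixed_frame=13):
# a 120-frame clip takes the elif branch, since (120 - 56) % 44 = 20 != 0, so
# audio_len becomes 120 + (44 - 20) = 144 and the waveform is zero-padded to
# 144 * sample_rate / fps samples so every chunk sees a full window.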
419
+ with torch.no_grad():
420
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
421
+ audio_embeddings = hidden_states.last_hidden_state
422
+ for mid_hidden_states in hidden_states.hidden_states:
423
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
424
+ seq_len = audio_len
425
+ audio_embeddings = audio_embeddings.squeeze(0)
426
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
427
+ else:
428
+ audio_embeddings = None
429
+
430
+ # loop
431
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
432
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
433
+ times += 1
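# Editor's note (continuing the illustration above): with seq_len=144, L=57,
# fixed_frame=13 and first_fixed_frame=1, times = (144 - 57 + 1) // 44 + 1 = 3,
# and 3*44 + 13 = 145 >= 144, so no extra chunk is added: the clip is generated
# in three overlapping windows.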
434
+ video = []
435
+ image_emb = {}
436
+ img_lat = None
437
+ if self.args.i2v:
438
+ self.pipe.load_models_to_device(['vae'])
439
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
440
+
441
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
442
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
443
+ msk[:, :, 1:] = 1
444
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
445
+
446
+ for t in range(times):
447
+ print(f"[{t+1}/{times}]")
448
+ audio_emb = {}
449
+ if t == 0:
450
+ overlap = first_fixed_frame
451
+ else:
452
+ overlap = fixed_frame
453
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
454
+ prefix_overlap = (3 + overlap) // 4
455
+ if audio_embeddings is not None:
456
+ if t == 0:
457
+ audio_tensor = audio_embeddings[
458
+ :min(L - overlap, audio_embeddings.shape[0])
459
+ ]
460
+ else:
461
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
462
+ audio_tensor = audio_embeddings[
463
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
464
+ ]
465
+
466
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
467
+ audio_prefix = audio_tensor[-fixed_frame:]
468
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
469
+ audio_emb["audio_emb"] = audio_tensor
470
+ else:
471
+ audio_prefix = None
472
+ if image is not None and img_lat is None:
473
+ self.pipe.load_models_to_device(['vae'])
474
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
475
+ assert img_lat.shape[2] == prefix_overlap
476
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
477
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
478
+ negative_prompt, num_inference_steps=num_steps,
479
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
480
+ return_latent=True,
481
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
482
+
483
+ torch.cuda.empty_cache()
484
+ img_lat = None
485
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
486
+
487
+ if t == 0:
488
+ video.append(frames)
489
+ else:
490
+ video.append(frames[:, overlap:])
491
+ video = torch.cat(video, dim=1)
492
+ video = video[:, :ori_audio_len + 1]
493
+
494
+ return video
495
+
496
+
497
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
498
+ snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
499
+ snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
500
+
501
+
502
+ # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
503
+ # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
504
+ # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
505
+
506
+ import tempfile
507
+
508
+ from PIL import Image
509
+
510
+
511
+ set_seed(args.seed)
512
+ seq_len = args.seq_len
513
+ inferpipe = WanInferencePipeline(args)
514
+
515
+
516
+ def update_generate_button(image_path, audio_path, text, num_steps):
517
+
518
+ if image_path is None or audio_path is None:
519
+ return gr.update(value="⌚ Zero GPU Required: --")
520
+
521
+ duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
522
+ duration_m = duration_s / 60
523
+
524
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
525
+
526
+ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
527
+
528
+ audio_chunks = inferpipe.get_times(
529
+ prompt=text,
530
+ image_path=image_path,
531
+ audio_path=audio_path,
532
+ seq_len=args.seq_len,
533
+ num_steps=num_steps
534
+ )
535
+
536
+ warmup_s = 30
537
+ duration_s = (20 * num_steps) + warmup_s
538
+
539
+ if audio_chunks > 1:
540
+ duration_s = (20 * num_steps * audio_chunks) + warmup_s
541
+
542
+ print(f'for {audio_chunks} times, might take {duration_s}')
543
+
544
+ return int(duration_s)
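# Editor's note (illustrative): with num_steps=8 and 3 audio chunks the estimate is
# 20*8*3 + 30 = 510 s (about 8.5 min) of ZeroGPU time, which is the figure the
# "Zero GPU Required" label shows before generation starts.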
545
+
546
+ def preprocess_img(image_path, session_id = None):
547
+
548
+ if session_id is None:
549
+ session_id = uuid.uuid4().hex
550
+
551
+ image = Image.open(image_path).convert("RGB")
552
+
553
+ image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
554
+
555
+ _, _, h, w = image.shape
556
+ select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
557
+ image = resize_pad(image, (h, w), select_size)
558
+ image = image * 2.0 - 1.0
559
+ image = image[:, :, None]
560
+
561
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
562
+
563
+ img_dir = output_dir + '/image'
564
+ os.makedirs(img_dir, exist_ok=True)
565
+ input_img_path = os.path.join(img_dir, f"img_input.jpg")
566
+
567
+ image = tensor_to_pil(image)
568
+ image.save(input_img_path)
569
+
570
+ return input_img_path
571
+
572
+
573
+ @spaces.GPU(duration=get_duration)
574
+ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
575
+
576
+ if session_id is None:
577
+ session_id = uuid.uuid4().hex
578
+
579
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
580
+
581
+ audio_dir = output_dir + '/audio'
582
+ os.makedirs(audio_dir, exist_ok=True)
583
+ if args.silence_duration_s > 0:
584
+ input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
585
+ else:
586
+ input_audio_path = audio_path
587
+ prompt_dir = output_dir + '/prompt'
588
+ os.makedirs(prompt_dir, exist_ok=True)
589
+
590
+ if args.silence_duration_s > 0:
591
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
592
+
593
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
594
+ prompt_path = os.path.join(prompt_dir, f"prompt.txt")
595
+
596
+ video = inferpipe(
597
+ prompt=text,
598
+ image_path=image_path,
599
+ audio_path=input_audio_path,
600
+ seq_len=args.seq_len,
601
+ num_steps=num_steps
602
+ )
603
+
604
+ torch.cuda.empty_cache()
605
+
606
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
607
+ video_paths = save_video_as_grid_and_mp4(video,
608
+ output_dir,
609
+ args.fps,
610
+ prompt=text,
611
+ prompt_path = prompt_path,
612
+ audio_path=tmp2_audio_path if args.use_audio else None,
613
+ prefix=f'result')
614
+
615
+ return video_paths[0]
616
+
617
+ def apply(request):
618
+
619
+ return request
620
+
621
+ def cleanup(request: gr.Request):
622
+
623
+ sid = request.session_hash
624
+ if sid:
625
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
626
+ shutil.rmtree(d1, ignore_errors=True)
627
+
628
+ def start_session(request: gr.Request):
629
+
630
+ return request.session_hash
631
+
632
+ css = """
633
+ #col-container {
634
+ margin: 0 auto;
635
+ max-width: 1560px;
636
+ }
637
+ """
638
+
639
+ with gr.Blocks(css=css) as demo:
640
+
641
+ session_state = gr.State()
642
+ demo.load(start_session, outputs=[session_state])
643
+
644
+
645
+ with gr.Column(elem_id="col-container"):
646
+ gr.HTML(
647
+ """
648
+ <div style="text-align: left;">
649
+ <p style="font-size:16px; display: inline; margin: 0;">
650
+ <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
651
+ </p>
652
+ <a href="https://huggingface.co/OmniAvatar/OmniAvatar-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
653
+ [model]
654
+ </a>
655
+ </div>
656
+ <div style="text-align: left;">
657
+ <strong>HF Space by:</strong>
658
+ <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
659
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
660
+ </a>
661
+ </div>
662
+
663
+ """
664
+ )
665
+
666
+ with gr.Row():
667
+
668
+ with gr.Column():
669
+
670
+ image_input = gr.Image(label="Reference Image", type="filepath", height=512)
671
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
672
+
673
+
674
+ with gr.Column():
675
+
676
+ output_video = gr.Video(label="Avatar", height=512)
677
+ num_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")
678
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
679
+ infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
680
+ with gr.Accordion("Advanced Settings", open=False):
681
+ text_input = gr.Textbox(label="Video Prompt", lines=6, value="A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
682
+
683
+ with gr.Column():
684
+
685
+ cached_examples = gr.Examples(
686
+ examples=[
687
+ [
688
+ "examples/images/male-001.png",
689
+ "examples/audios/denial.wav",
690
+ "A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
691
+ 12
692
+ ],
693
+ [
694
+ "examples/images/female-001.png",
695
+ "examples/audios/script.wav",
696
+ "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
697
+ 14
698
+ ],
699
+ [
700
+ "examples/images/female-002.png",
701
+ "examples/audios/nature.wav",
702
+ "A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, standing in the woods, with dynamic and rhythmic and extensive hand gestures that complement his speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
703
+ 10
704
+ ],
705
+ ],
706
+ label="Cached Examples",
707
+ inputs=[image_input, audio_input, text_input, num_steps],
708
+ outputs=[output_video],
709
+ fn=infer,
710
+ cache_examples=True
711
+ )
712
+
713
+ image_examples = gr.Examples(
714
+ examples=[
715
+ [
716
+ "examples/images/female-009.png",
717
+ ],
718
+ [
719
+ "examples/images/male-005.png",
720
+ ],
721
+ [
722
+ "examples/images/female-003.png",
723
+ ],
724
+ ],
725
+ label="Image Samples",
726
+ inputs=[image_input],
727
+ outputs=[image_input],
728
+ fn=apply
729
+ )
730
+
731
+ audio_examples = gr.Examples(
732
+ examples=[
733
+ [
734
+ "examples/audios/londoners.wav",
735
+ ],
736
+ [
737
+ "examples/audios/matcha.wav",
738
+ ],
739
+ [
740
+ "examples/audios/keen.wav",
741
+ ],
742
+ ],
743
+ label="Audio Samples",
744
+ inputs=[audio_input],
745
+ outputs=[audio_input],
746
+ fn=apply
747
+ )
748
+
749
+ infer_btn.click(
750
+ fn=infer,
751
+ inputs=[image_input, audio_input, text_input, num_steps, session_state],
752
+ outputs=[output_video]
753
+ )
754
+ image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input])
755
+ image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
756
+ audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
757
+ num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
758
+
759
+
760
+ if __name__ == "__main__":
761
+ demo.unload(cleanup)
762
+ demo.queue()
763
  demo.launch(ssr_mode=False)
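For readers skimming the diff: the newly added image and audio galleries use a gr.Examples block whose fn is an identity passthrough (apply), so clicking a sample simply loads it into the matching input component. A minimal, self-contained sketch of that pattern, assuming Gradio 4.x and using two of the audio files added in this commit:

import gradio as gr

def apply(value):
    # identity passthrough: return the selected sample unchanged
    return value

with gr.Blocks() as demo:
    audio_input = gr.Audio(label="Input Audio", type="filepath")
    gr.Examples(
        examples=[["examples/audios/keen.wav"], ["examples/audios/londoners.wav"]],
        label="Audio Samples",
        inputs=[audio_input],
        outputs=[audio_input],
        fn=apply,
    )

demo.launch()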
examples/audios/keen.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fdcb6b27d1e8fcf72aa2cc8fbb2c8f92ba754970d5568fcdd83de49cf353f943
+ size 145964
examples/audios/londoners.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7372af9089b330118118dd03cec838a3fa45cc33ed35245eda315da3595715f
+ size 316844
examples/images/female-009.png ADDED

Git LFS Details

  • SHA256: 730b7afe8066c0ce822f713c73f0d660b6ee9ff4f6957d2dddcdeffe9b13b0e7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.97 MB
examples/images/male-005.png ADDED