import os

import torch
import torch.utils.checkpoint
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def test(
    pipe,
    config,
    wav_enc,
    audio_pe,
    audio2bucket,
    image_encoder,
    width,
    height,
    batch,
):
    """Generate a video tensor for the given batch."""
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch['ref_img']
    clip_img = batch['clip_images']
    face_mask = batch['face_mask']
    image_embeds = image_encoder(clip_img).image_embeds

    audio_feature = batch['audio_feature']
    audio_len = batch['audio_len']
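
    # Whisper consumes 30 s log-mel windows (3000 frames at 100 frames/s), so the
    # audio features are encoded chunk by chunk below. Hidden states from every
    # encoder layer are stacked for the audio-token adapter, while the last hidden
    # state alone is kept for the motion-bucket predictor.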
    step = int(config.step)
    window = 3000
    audio_prompts = []
    last_audio_prompts = []
    for i in range(0, audio_feature.shape[-1], window):
        audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window], output_hidden_states=True).hidden_states
        last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window]).last_hidden_state
        last_audio_prompt = last_audio_prompt.unsqueeze(-2)
        audio_prompt = torch.stack(audio_prompt, dim=2)
        audio_prompts.append(audio_prompt)
        last_audio_prompts.append(last_audio_prompt)

    audio_prompts = torch.cat(audio_prompts, dim=1)
    audio_prompts = audio_prompts[:, :audio_len * 2]
    audio_prompts = torch.cat([
        torch.zeros_like(audio_prompts[:, :4]),
        audio_prompts,
        torch.zeros_like(audio_prompts[:, :6]),
    ], 1)

    last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
    last_audio_prompts = last_audio_prompts[:, :audio_len * 2]
    last_audio_prompts = torch.cat([
        torch.zeros_like(last_audio_prompts[:, :24]),
        last_audio_prompts,
        torch.zeros_like(last_audio_prompts[:, :26]),
    ], 1)
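
    # Per generated frame: a 10-frame slice of stacked Whisper features feeds the
    # audio projection (a zeroed copy provides the unconditional branch for
    # classifier-free guidance), while a 50-frame slice feeds the audio-to-bucket
    # head, whose output is rescaled with `* 16 + 16` to a motion-bucket value.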
    ref_tensor_list = []
    audio_tensor_list = []
    uncond_audio_tensor_list = []
    motion_buckets = []
    for i in tqdm(range(audio_len // step), ncols=0):
        audio_clip = audio_prompts[:, i * 2 * step:i * 2 * step + 10].unsqueeze(0)
        audio_clip_for_bucket = last_audio_prompts[:, i * 2 * step:i * 2 * step + 50].unsqueeze(0)

        motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
        motion_bucket = motion_bucket * 16 + 16
        motion_buckets.append(motion_bucket[0])

        cond_audio_clip = audio_pe(audio_clip).squeeze(0)
        uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)

        ref_tensor_list.append(ref_img[0])
        audio_tensor_list.append(cond_audio_clip[0])
        uncond_audio_tensor_list.append(uncond_audio_clip[0])

    video = pipe(
        ref_img,
        clip_img,
        face_mask,
        audio_tensor_list,
        uncond_audio_tensor_list,
        motion_buckets,
        height=height,
        width=width,
        num_frames=len(audio_tensor_list),
        decode_chunk_size=config.decode_chunk_size,
        motion_bucket_scale=config.motion_bucket_scale,
        fps=config.fps,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale1=config.min_appearance_guidance_scale,
        max_guidance_scale1=config.max_appearance_guidance_scale,
        min_guidance_scale2=config.audio_guidance_scale,
        max_guidance_scale2=config.audio_guidance_scale,
        overlap=config.overlap,
        shift_offset=config.shift_offset,
        frames_per_batch=config.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength,
    ).frames

    # Map frames from [-1, 1] to [0, 1] and move the result to the CPU.
    video = (video * 0.5 + 0.5).clamp(0, 1)
    video = torch.cat([video.to(pipe.device)], dim=0).cpu()
    return video


class Sonic:
    """High-level interface for the Sonic portrait animation pipeline."""

    config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
    config = OmegaConf.load(config_file)

    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        config = self.config
        config.use_interframe = enable_interpolate_frame

        device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
        self.device = device

        config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='vae', variant='fp16')
        val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='scheduler')
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='image_encoder', variant='fp16')
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='unet', variant='fp16')
        add_ip_adapters(unet, [32], [config.ip_audio_scale])

        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024,
                                     output_dim=1024, context_tokens=32).to(device)
        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024,
                                         intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)

        unet.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location='cpu'), strict=True)
        audio2token.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location='cpu'), strict=True)
        audio2bucket.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location='cpu'), strict=True)

        dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}
        weight_dtype = dtype_map.get(config.weight_dtype)
        if weight_dtype is None:
            raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")

        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
        whisper.requires_grad_(False)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))

        self.face_det = AlignImage(device, det_path=os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
        if config.use_interframe:
            self.rife = RIFEModel(device=device)
            self.rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))

        image_encoder.to(weight_dtype)
        vae.to(weight_dtype)
        unet.to(weight_dtype)

        pipe = SonicPipeline(
            unet=unet, image_encoder=image_encoder, vae=vae, scheduler=val_noise_scheduler)
        self.pipe = pipe.to(device=device, dtype=weight_dtype)
        self.whisper = whisper
        self.audio2token = audio2token
        self.audio2bucket = audio2bucket
        self.image_encoder = image_encoder

        print('Sonic initialization complete.')

    def preprocess(self, image_path: str, expand_ratio: float = 1.0):
        face_image = cv2.imread(image_path)
        h, w = face_image.shape[:2]
        _, _, bboxes = self.face_det(face_image, maxface=True)
        face_num = len(bboxes)
        bbox_s = []
        if face_num > 0:
            x1, y1, ww, hh = bboxes[0]
            x2, y2 = x1 + ww, y1 + hh
            # Note: `expand_radio` is the keyword name expected by process_bbox.
            bbox_s = process_bbox((x1, y1, x2, y2), expand_radio=expand_ratio, height=h, width=w)
        return {'face_num': face_num, 'crop_bbox': bbox_s}

    def crop_image(self, input_image_path: str, output_image_path: str, crop_bbox):
        face_image = cv2.imread(input_image_path)
        crop_img = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
        cv2.imwrite(output_image_path, crop_img)

    def process(self, image_path, audio_path, output_path, min_resolution=512,
                inference_steps=25, dynamic_scale=1.0, keep_resolution=False, seed=None):
        config = self.config
        device = self.device
        pipe = self.pipe
        whisper = self.whisper
        audio2token = self.audio2token
        audio2bucket = self.audio2bucket
        image_encoder = self.image_encoder

        if seed is not None:
            config.seed = seed
        seed_everything(config.seed)

        config.num_inference_steps = inference_steps
        config.frame_num = config.fps * 60
        config.motion_bucket_scale = dynamic_scale

        video_path = output_path.replace('.mp4', '_noaudio.mp4')
        audio_video_path = output_path

        imSrc_ = Image.open(image_path).convert('RGB')
        raw_w, raw_h = imSrc_.size

        test_data = image_audio_to_tensor(
            self.face_det, self.feature_extractor, image_path, audio_path,
            limit=config.frame_num, image_size=min_resolution, area=config.area)
        if test_data is None:
            return -1

        height, width = test_data['ref_img'].shape[-2:]
        # When keeping the source resolution, round down to even dimensions for libx264.
        resolution = f"{width}x{height}" if not keep_resolution else f"{raw_w // 2 * 2}x{raw_h // 2 * 2}"

        video = test(pipe, config, wav_enc=whisper, audio_pe=audio2token,
                     audio2bucket=audio2bucket, image_encoder=image_encoder,
                     width=width, height=height, batch=test_data)
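
        # Optional RIFE frame interpolation: synthesize one intermediate frame
        # between every pair of generated frames, doubling the effective frame
        # rate (the video is muxed at fps * 2 below when interpolation is on).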
        if config.use_interframe:
            out = video.to(device)
            results = []
            for idx in tqdm(range(out.shape[2] - 1), ncols=0):
                I1 = out[:, :, idx]
                I2 = out[:, :, idx + 1]
                mid = self.rife.inference(I1, I2).clamp(0, 1).detach()
                results.extend([out[:, :, idx], mid])
            results.append(out[:, :, -1])
            video = torch.stack(results, 2).cpu()

        save_videos_grid(video, video_path, n_rows=video.shape[0],
                         fps=config.fps * (2 if config.use_interframe else 1))
        os.system(
            f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 "
            f"-acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
        return 0
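

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the image, audio, and output paths
    # below are placeholders, and the checkpoints are assumed to be laid out under
    # BASE_DIR/checkpoints as referenced in config/inference/sonic.yaml.
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)

    image_path = 'examples/face.png'    # placeholder portrait image
    audio_path = 'examples/speech.wav'  # placeholder driving audio
    output_path = 'examples/output.mp4'

    face_info = sonic.preprocess(image_path, expand_ratio=1.0)
    if face_info['face_num'] > 0:
        cropped_path = 'examples/face_cropped.png'
        sonic.crop_image(image_path, cropped_path, face_info['crop_bbox'])
        image_path = cropped_path

    status = sonic.process(image_path, audio_path, output_path,
                           inference_steps=25, dynamic_scale=1.0)
    print('done' if status == 0 else 'preprocessing failed')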