import copy
import os
import subprocess

import cv2
import torch
import torch.nn.functional as F
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import (
    UNetSpatioTemporalConditionModel,
    add_ip_adapters,
)
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


# ------------------------------------------------------------------
# single image + speech → video-tensor generator
# ------------------------------------------------------------------
def test(
    pipe,
    config,
    wav_enc,
    audio_pe,
    audio2bucket,
    image_encoder,
    width,
    height,
    batch,
):
    """Generate a video tensor from one reference image and a speech clip.

    Returns a float tensor in [0, 1] on the CPU, shaped (B, C, F, H, W).
    """
    # ---------------- add the batch dimension -------------------------
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch["ref_img"]                           # (1, C, H, W)
    clip_img = batch["clip_images"]
    face_mask = batch["face_mask"]
    image_embeds = image_encoder(clip_img).image_embeds  # (1, 1024)

    audio_feature = batch["audio_feature"]  # (1, 80, T) log-mel frames
    audio_len = int(batch["audio_len"])     # audio length in frames (Python int)
    step = int(config.step)

    # ---------- Whisper encoding, one window at a time ----------------
    # Whisper's encoder accepts exactly 3000 mel frames per call
    # (30 s at 100 frames/s), so the features are encoded window by window.
    window = 3000
    audio_prompts, last_prompts = [], []
    for i in range(0, audio_feature.shape[-1], window):
        chunk = audio_feature[:, :, i : i + window]
        # One encoder pass serves both prompt types: the last entry of
        # hidden_states is the same tensor as last_hidden_state.
        enc_out = wav_enc.encoder(chunk, output_hidden_states=True)
        audio_prompts.append(torch.stack(enc_out.hidden_states, dim=2))  # (1, N, L, 384)
        last_prompts.append(enc_out.last_hidden_state.unsqueeze(-2))     # (1, N, 1, 384)

    if not audio_prompts:
        raise ValueError("[ERROR] No audio features were extracted from the provided audio.")

    # Keep only the tokens that cover the real audio (windows advance 2 tokens
    # per audio frame) and drop the output for Whisper's 30 s zero padding.
    audio_prompts = torch.cat(audio_prompts, dim=1)[:, : audio_len * 2]
    last_prompts = torch.cat(last_prompts, dim=1)[:, : audio_len * 2]

    # ---------- zero-pad along the token axis (dim=1) -----------------
    # 4 head + 6 tail tokens for the 10-token cond window and 24 + 26 for the
    # 50-token bucket window, so the last window is still full-width.
    audio_prompts = F.pad(audio_prompts, (0, 0, 0, 0, 4, 6))
    last_prompts = F.pad(last_prompts, (0, 0, 0, 0, 24, 26))

    # ---------- one conditioning entry per generated frame ------------
    num_frames = max(1, audio_len // step)

    audio_list, uncond_list, motion_buckets = [], [], []
    for i in tqdm(range(num_frames)):
        start = i * 2 * step

        # ---- cond clip: 10 tokens x 5 Whisper layers ------------------
        clip_raw = audio_prompts[:, start : start + 10]  # (1, ≤10, L, 384)
        if clip_raw.shape[1] < 10:  # width guard
            clip_raw = F.pad(clip_raw, (0, 0, 0, 0, 0, 10 - clip_raw.shape[1]))
        # AudioProjModel expects 5 layers. whisper-tiny already yields 5
        # (embedding output + 4 encoder layers), so this is a no-op for it;
        # a shallower encoder is padded by repeating its last layer.
        if clip_raw.shape[2] < 5:
            pad_L = clip_raw[:, :, -1:].repeat(1, 1, 5 - clip_raw.shape[2], 1)
            clip_raw = torch.cat([clip_raw, pad_L], dim=2)
        cond_clip = clip_raw[:, :, :5].unsqueeze(1)  # (1, 1, 10, 5, 384)

        # ---- bucket clip: 50 tokens x last layer only -----------------
        bucket_raw = last_prompts[:, start : start + 50]  # (1, ≤50, 1, 384)
        if bucket_raw.shape[1] < 50:  # width guard
            bucket_raw = F.pad(bucket_raw, (0, 0, 0, 0, 0, 50 - bucket_raw.shape[1]))
        bucket_clip = bucket_raw.unsqueeze(1)  # (1, 1, 50, 1, 384)

        # Rescale the predicted motion value into the motion-bucket range.
        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16

        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
        motion_buckets.append(motion[0])

    # ---------- run the Stable Video Diffusion pipeline ---------------
    video = pipe(
        ref_img,
        clip_img,
        face_mask,
        audio_list,
        uncond_list,
        motion_buckets,
        height=height,
        width=width,
        num_frames=len(audio_list),
        decode_chunk_size=config.decode_chunk_size,
        motion_bucket_scale=config.motion_bucket_scale,
        fps=config.fps,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale1=config.min_appearance_guidance_scale,
        max_guidance_scale1=config.max_appearance_guidance_scale,
        min_guidance_scale2=config.audio_guidance_scale,
        max_guidance_scale2=config.audio_guidance_scale,
        overlap=config.overlap,
        shift_offset=config.shift_offset,
        frames_per_batch=config.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength,
    ).frames

    video = (video * 0.5 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    return video.cpu()                       # (B, C, F, H, W)


# ------------------------------------------------------------------
# Sonic class
# ------------------------------------------------------------------
class Sonic:
    config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
    config = OmegaConf.load(config_file)

    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        # Per-instance copy so that instances do not share mutated settings.
        self.config = cfg = copy.deepcopy(Sonic.config)
        cfg.use_interframe = enable_interpolate_frame
        self.device = (
            f"cuda:{device_id}"
            if device_id >= 0 and torch.cuda.is_available()
            else "cpu"
        )
        cfg.pretrained_model_name_or_path = os.path.join(
            BASE_DIR, cfg.pretrained_model_name_or_path
        )
        self._load_models(cfg)
        print("Sonic init done")

    # --------------------------------------------------------------
    def _load_models(self, cfg):
        dtype = {
            "fp16": torch.float16,
            "fp32": torch.float32,
            "bf16": torch.bfloat16,
        }[cfg.weight_dtype]
        root = cfg.pretrained_model_name_or_path

        vae = AutoencoderKLTemporalDecoder.from_pretrained(root, subfolder="vae", variant="fp16")
        sched = EulerDiscreteScheduler.from_pretrained(root, subfolder="scheduler")
        imgenc = CLIPVisionModelWithProjection.from_pretrained(
            root, subfolder="image_encoder", variant="fp16"
        )
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            root, subfolder="unet", variant="fp16"
        )
        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

        # Audio adapters, matching the clip shapes built in test():
        # audio -> tokens: 10-token window, 5 Whisper layers, 384-dim features
        a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
        # audio -> motion bucket: 50-token window, last layer only
        a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

        def _ckpt(rel_path):
            return torch.load(os.path.join(BASE_DIR, rel_path), map_location="cpu")

        unet.load_state_dict(_ckpt(cfg.unet_checkpoint_path))
        a2t.load_state_dict(_ckpt(cfg.audio2token_checkpoint_path))
        a2b.load_state_dict(_ckpt(cfg.audio2bucket_checkpoint_path))

        whisper_dir = os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
        whisper = WhisperModel.from_pretrained(whisper_dir).to(self.device).eval()
        whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_dir)

        self.face_det = AlignImage(
            self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt")
        )
        if cfg.use_interframe:
            self.rife = RIFEModel(device=self.device)
            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

        for m in (imgenc, vae, unet):
            m.to(dtype)

        self.pipe = SonicPipeline(
            unet=unet, image_encoder=imgenc, vae=vae, scheduler=sched
        ).to(device=self.device, dtype=dtype)
        self.image_encoder = imgenc
        self.audio2token = a2t
        self.audio2bucket = a2b
        self.whisper = whisper

    # --------------------------------------------------------------
    def preprocess(self, image_path: str, expand_ratio: float = 1.0):
        """Detect the largest face and return its expanded crop box."""
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        h, w = img.shape[:2]
        _, _, bboxes = self.face_det(img, maxface=True)
        # len() works whether the detector returns a list or an array.
        if len(bboxes) > 0:
            x1, y1, ww, hh = bboxes[0]
            crop = process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)
            return {"face_num": 1, "crop_bbox": crop}
        return {"face_num": 0, "crop_bbox": None}

    # --------------------------------------------------------------
    @torch.no_grad()
    def process(
        self,
        image_path: str,
        audio_path: str,
        output_path: str,
        min_resolution: int = 512,
        inference_steps: int = 25,
        dynamic_scale: float = 1.0,
        keep_resolution: bool = False,
        seed: int | None = None,
    ):
        """Render a talking-head video. Returns 0 on success, -1 if preprocessing fails."""
        cfg = self.config
        if seed is not None:
            cfg.seed = seed
        cfg.num_inference_steps = inference_steps
        cfg.motion_bucket_scale = dynamic_scale
        seed_everything(cfg.seed)

        # image + audio -> tensors
        test_data = image_audio_to_tensor(
            self.face_det,
            self.feature_extractor,
            image_path,
            audio_path,
            limit=-1,
            image_size=min_resolution,
            area=cfg.area,
        )
        if test_data is None:
            return -1

        h, w = test_data["ref_img"].shape[-2:]
        if keep_resolution:
            # Round down to even sizes, which libx264 needs for yuv420p output.
            with Image.open(image_path) as im:
                resolution = f"{im.width // 2 * 2}x{im.height // 2 * 2}"
        else:
            resolution = f"{w}x{h}"

        # generate the video frames
        video = test(
            self.pipe,
            cfg,
            self.whisper,
            self.audio2token,
            self.audio2bucket,
            self.image_encoder,
            width=w,
            height=h,
            batch=test_data,
        )

        # frame interpolation: insert one RIFE frame between each pair (F -> 2F - 1)
        if cfg.use_interframe:
            out = video.to(self.device)
            frames = []
            for i in tqdm(range(out.shape[2] - 1), ncols=0):
                mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
                frames.extend([out[:, :, i], mid])
            frames.append(out[:, :, -1])
            video = torch.stack(frames, 2).cpu()

        # save: write a silent video, then mux in the audio with ffmpeg
        tmp_mp4 = os.path.splitext(output_path)[0] + "_noaudio.mp4"
        save_videos_grid(
            video,
            tmp_mp4,
            n_rows=video.shape[0],
            fps=cfg.fps * (2 if cfg.use_interframe else 1),
        )
        try:
            # An argument list (no shell) is safe for paths with quotes or spaces.
            subprocess.run(
                [
                    "ffmpeg", "-y", "-loglevel", "error",
                    "-i", tmp_mp4, "-i", audio_path,
                    "-s", resolution,
                    "-vcodec", "libx264", "-acodec", "aac", "-crf", "18",
                    "-shortest", output_path,
                ],
                check=True,
            )
        finally:
            os.remove(tmp_mp4)
        return 0
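

# ------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not called by the code above)
# ------------------------------------------------------------------
# A pure-Python view of the token-window bookkeeping in `test()`, under
# these assumptions: the encoder output is trimmed to 2 * audio_len tokens,
# the cond sequence is zero-padded by 4 + 6 tokens and the bucket sequence
# by 24 + 26, and one frame is generated per `step` audio frames. For each
# frame it returns (start, cond_end, bucket_end): the half-open token ranges
# read by the cond window and the bucket window. Ends are clipped to the
# padded lengths, so a window that would run short shows up as a short range.
def _audio_window_plan(audio_len: int, step: int, cond_w: int = 10, bucket_w: int = 50):
    """Hypothetical helper: list the token windows each generated frame reads."""
    if audio_len < 0 or step <= 0:
        raise ValueError("audio_len must be >= 0 and step must be > 0")
    num_tokens = 2 * audio_len            # tokens kept after trimming
    cond_total = 4 + num_tokens + 6       # padded length of the cond sequence
    bucket_total = 24 + num_tokens + 26   # padded length of the bucket sequence
    plan = []
    for i in range(max(1, audio_len // step)):
        start = i * 2 * step
        plan.append(
            (start, min(start + cond_w, cond_total), min(start + bucket_w, bucket_total))
        )
    return plan


# Example: _audio_window_plan(10, 2) yields five frames starting at tokens
# 0, 4, 8, 12 and 16. With the default widths every cond window spans 10
# tokens and every bucket window 50, i.e. under these assumptions the
# padding guarantees that the final window is never short.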