# sonic.py
# ---------------------------------------------------------------------
# Sonic – single-image + speech → talking-head video (offline edition)
# ---------------------------------------------------------------------
import copy
import math
import os
import subprocess
from typing import Dict, Any, List

import torch
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import (
    UNetSpatioTemporalConditionModel,
    add_ip_adapters,
)
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Number of mel frames the Whisper encoder consumes per call. Whisper works on
# fixed 30-s windows (3000 log-mel frames) and rejects other lengths.
# NOTE(review): assumes the feature extractor used in preprocessing pads the mel
# features to a multiple of this length — confirm in image_audio_to_tensor.
_WHISPER_WINDOW = 3000


# ------------------------------------------------------------------ #
# Helper: locate the diffusers snapshot directory automatically        #
# ------------------------------------------------------------------ #
def _locate_diffusers_dir(root: str) -> str:
    """Return the directory under ``root`` that holds the diffusers snapshot.

    A directory containing ``model_index.json`` (a full pipeline snapshot) is
    preferred. Only if no such directory exists does the search fall back to
    the first directory containing ``config.json``. Component sub-folders
    (``vae/``, ``unet/`` …, which the loader below addresses via
    ``subfolder=``) typically carry their own ``config.json`` and must not
    shadow the pipeline root. Sub-directories are visited in sorted order so
    the result does not depend on filesystem enumeration order.

    Args:
        root: Directory to search recursively.

    Returns:
        Path of the directory containing the snapshot files.

    Raises:
        FileNotFoundError: If neither file exists anywhere under ``root``.
    """
    fallback = None
    for cur, dirs, files in os.walk(root):
        dirs.sort()  # deterministic, top-down traversal order
        if "model_index.json" in files:
            return cur
        if fallback is None and "config.json" in files:
            fallback = cur
    if fallback is not None:
        return fallback
    raise FileNotFoundError(
        f"[ERROR] No diffusers model files found under '{root}'. "
        "Check that the checkpoint was downloaded correctly."
    )


# ------------------------------------------------------------------ #
# Internal helpers for video generation                              #
# ------------------------------------------------------------------ #
def _zero_tokens(like: torch.Tensor, n: int) -> torch.Tensor:
    """Return ``n`` zero tokens matching ``like`` in every dim except dim 1.

    The result has shape ``(like.shape[0], n, *like.shape[2:])`` and the same
    dtype/device as ``like``. Unlike ``zeros_like(like[:, :n])`` it always
    yields exactly ``n`` tokens, even when ``like`` is shorter than ``n``.
    """
    return like.new_zeros((like.shape[0], n, *like.shape[2:]))


def _token_window(seq: torch.Tensor, start: int, length: int) -> torch.Tensor:
    """Slice ``length`` tokens of ``seq`` along dim 1 from ``start``, zero-padding the tail."""
    clip = seq[:, start : start + length]
    missing = length - clip.shape[1]
    if missing > 0:
        clip = torch.cat([clip, _zero_tokens(seq, missing)], dim=1)
    return clip


def _encode_audio(wav_enc: WhisperModel, audio_feat: torch.Tensor):
    """Run the Whisper encoder over fixed-size mel windows.

    Args:
        wav_enc: Whisper model; only its ``encoder`` is used.
        audio_feat: Mel features, windowed along the last dim
            (presumably ``(1, n_mels, T)`` — confirm against preprocessing).

    Returns:
        ``(layer_prompts, last_prompts)`` concatenated along the token dim:
        ``layer_prompts`` stacks (at most 5) encoder hidden states on dim 2,
        ``last_prompts`` holds the final hidden state with a singleton dim 2.

    Raises:
        ValueError: If ``audio_feat`` has no frames at all.
    """
    layer_chunks: List[torch.Tensor] = []
    last_chunks: List[torch.Tensor] = []
    for i in range(0, audio_feat.shape[-1], _WHISPER_WINDOW):
        chunk = audio_feat[:, :, i : i + _WHISPER_WINDOW]
        # One pass yields both the per-layer states and the final state.
        enc = wav_enc.encoder(chunk, output_hidden_states=True)
        # AudioProjModel is built with blocks=5; keep at most 5 hidden states
        # (NOTE(review): whisper-tiny presumably yields exactly 5 — embeddings
        # plus 4 layers — in which case this slice is a no-op).
        layer_chunks.append(torch.stack(enc.hidden_states, dim=2)[:, :, :5])
        last_chunks.append(enc.last_hidden_state.unsqueeze(-2))

    if not layer_chunks:
        raise ValueError("[ERROR] No speech recognised in the provided audio.")

    return torch.cat(layer_chunks, dim=1), torch.cat(last_chunks, dim=1)


def _gen_video_tensor(
    pipe: SonicPipeline,
    cfg: OmegaConf,
    wav_enc: WhisperModel,
    audio_pe: AudioProjModel,
    audio2bucket: Audio2bucketModel,
    image_encoder: CLIPVisionModelWithProjection,
    width: int,
    height: int,
    batch: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Generate a video tensor from one reference image and its audio features.

    ``batch`` is modified in place: every tensor value gets a leading batch
    dim and is moved to ``pipe.device`` as float32.

    Args:
        pipe: Sonic diffusion pipeline.
        cfg: Inference config (``step``, guidance scales, fps, …).
        wav_enc: Whisper model used to encode ``batch["audio_feature"]``.
        audio_pe: Audio → context-token projector.
        audio2bucket: Audio + image embedding → motion-bucket predictor.
        image_encoder: CLIP vision encoder for ``batch["clip_images"]``.
        width, height: Output frame size passed to the pipeline.
        batch: Output of ``image_audio_to_tensor``; must contain
            ``ref_img``, ``clip_images``, ``face_mask``, ``audio_feature``
            and ``audio_len``.

    Returns:
        Video tensor on CPU with values clamped to [0, 1], batch dim removed
        (``(C, T, H, W)``).

    Raises:
        ValueError: If the audio yields no frames.
    """
    # -------- add batch dim & move to device ------------------------
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch["ref_img"]
    clip_img = batch["clip_images"]
    face_mask = batch["face_mask"]
    image_embeds = image_encoder(clip_img).image_embeds

    audio_len = int(batch["audio_len"])  # length in video frames
    if audio_len <= 0:
        raise ValueError("[ERROR] Audio is too short to generate any frame.")
    # Clamp the frame stride into [1, audio_len] so at least one chunk exists.
    step = max(1, min(int(cfg.step), audio_len))

    # -------- Whisper encoding --------------------------------------
    aud_prompts, last_prompts = _encode_audio(wav_enc, batch["audio_feature"])

    # The chunk loop below advances 2 tokens per frame, so `audio_len` frames
    # span 2 * audio_len tokens; anything beyond is padding of the last
    # fixed-size Whisper window and must not produce frames.
    n_tokens = 2 * audio_len
    aud_prompts = aud_prompts[:, :n_tokens]
    last_prompts = last_prompts[:, :n_tokens]

    # Context padding so every window (10 / 50 tokens) around a frame exists.
    aud_prompts = torch.cat(
        [_zero_tokens(aud_prompts, 4), aud_prompts, _zero_tokens(aud_prompts, 6)],
        dim=1,
    )
    last_prompts = torch.cat(
        [_zero_tokens(last_prompts, 24), last_prompts, _zero_tokens(last_prompts, 26)],
        dim=1,
    )

    # -------- per-frame audio conditioning --------------------------
    aud_list: List[torch.Tensor] = []
    uncond_list: List[torch.Tensor] = []
    mb_list: List[torch.Tensor] = []
    n_chunks = audio_len // step
    for i in tqdm(range(n_chunks), desc="audio-chunks", ncols=0):
        s = i * 2 * step
        # Both adapters take (bz, f, w, b, c); insert the singleton frame dim so
        # w/b match their constructor args (seq_len=10, blocks=5 and
        # seq_len=50, blocks=1 respectively).
        cond_clip = _token_window(aud_prompts, s, 10).unsqueeze(1)
        bucket_clip = _token_window(last_prompts, s, 50).unsqueeze(1)

        motion_bucket = audio2bucket(bucket_clip, image_embeds) * 16 + 16

        aud_list.append(audio_pe(cond_clip).squeeze(0)[0])
        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
        mb_list.append(motion_bucket[0])

    # -------- run the UNet pipeline ---------------------------------
    frames = pipe(
        ref_img,
        clip_img,
        face_mask,
        aud_list,
        uncond_list,
        mb_list,
        height=height,
        width=width,
        num_frames=len(aud_list),
        decode_chunk_size=cfg.decode_chunk_size,
        motion_bucket_scale=cfg.motion_bucket_scale,
        fps=cfg.fps,
        noise_aug_strength=cfg.noise_aug_strength,
        min_guidance_scale1=cfg.min_appearance_guidance_scale,
        max_guidance_scale1=cfg.max_appearance_guidance_scale,
        min_guidance_scale2=cfg.audio_guidance_scale,
        max_guidance_scale2=cfg.audio_guidance_scale,
        overlap=cfg.overlap,
        shift_offset=cfg.shift_offset,
        frames_per_batch=cfg.n_sample_frames,
        num_inference_steps=cfg.num_inference_steps,
        i2i_noise_strength=cfg.i2i_noise_strength,
    ).frames

    # Map [-1, 1] → [0, 1]; (B, C, T, H, W) → (C, T, H, W).
    video = (frames * 0.5 + 0.5).clamp(0, 1)
    return video.squeeze(0).cpu()


# ------------------------------------------------------------------ #
# Sonic – main class                                                 #
# ------------------------------------------------------------------ #
class Sonic:
    """Talking-head generator: one portrait image + a speech file → mp4.

    The class-level ``config`` is the template loaded from ``config_file``;
    each instance works on its own deep copy (``self.config``).
    """

    config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
    config = OmegaConf.load(config_file)

    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        """Load every model onto ``cuda:<device_id>`` (or CPU if unavailable / negative).

        Args:
            device_id: CUDA device index; a negative value forces CPU.
            enable_interpolate_frame: Load RIFE and double the frame rate.
        """
        # Private copy: the class attribute is shared by all instances, and
        # both __init__ and process() mutate the config.
        self.config = copy.deepcopy(type(self).config)
        cfg = self.config
        cfg.use_interframe = enable_interpolate_frame

        # Parent folder of the locally downloaded diffusers model.
        self.diffusers_root = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)

        self.device = (
            f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
        )

        self._load_models(cfg)
        print("Sonic init done")

    # -------------------------------------------------------------- #
    def _load_models(self, cfg) -> None:
        """Instantiate and load all sub-models, then build ``self.pipe``.

        Raises:
            ValueError: If ``cfg.weight_dtype`` is not fp16/fp32/bf16.
            FileNotFoundError: If no diffusers snapshot is found.
        """
        dtypes = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
        dtype_name = cfg.weight_dtype
        if dtype_name not in dtypes:
            raise ValueError(
                f"[ERROR] Unsupported weight_dtype {dtype_name!r}; "
                f"expected one of {sorted(dtypes)}."
            )
        dtype = dtypes[dtype_name]

        diff_root = _locate_diffusers_dir(self.diffusers_root)

        # diffusers components
        vae = AutoencoderKLTemporalDecoder.from_pretrained(diff_root, subfolder="vae", variant="fp16")
        sched = EulerDiscreteScheduler.from_pretrained(diff_root, subfolder="scheduler")
        img_e = CLIPVisionModelWithProjection.from_pretrained(diff_root, subfolder="image_encoder", variant="fp16")
        unet = UNetSpatioTemporalConditionModel.from_pretrained(diff_root, subfolder="unet", variant="fp16")
        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

        # Audio adapters
        a2t = AudioProjModel(seq_len=10, blocks=5, channels=384,
                             intermediate_dim=1024, output_dim=1024, context_tokens=32).to(self.device)
        a2b = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024,
                                intermediate_dim=1024, output_dim=1, context_tokens=2).to(self.device)

        # Fine-tuned checkpoints (paths are relative to BASE_DIR).
        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))

        # Whisper (frozen, eval mode) and its feature extractor
        whisper_dir = os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
        whisper = WhisperModel.from_pretrained(whisper_dir).to(self.device).eval()
        whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_dir)

        # Face detector / frame interpolation
        self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
        if cfg.use_interframe:
            self.rife = RIFEModel(device=self.device)
            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

        # Cast the diffusion components to the configured dtype.
        for m in (vae, img_e, unet):
            m.to(dtype)

        self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(self.device, dtype=dtype)
        self.image_encoder = img_e
        self.audio2token = a2t
        self.audio2bucket = a2b
        self.whisper = whisper

    # -------------------------------------------------------------- #
    def preprocess(self, image_path: str, expand_ratio: float = 1.0) -> Dict[str, Any]:
        """Detect the largest face and compute an expanded crop box.

        Args:
            image_path: Path of the input image.
            expand_ratio: Expansion factor forwarded to ``process_bbox``.

        Returns:
            ``{"face_num": 1, "crop_bbox": <box>}`` if a face was found,
            otherwise ``{"face_num": 0, "crop_bbox": None}``.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        img = cv2.imread(image_path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"[ERROR] Cannot read image '{image_path}'.")
        h, w = img.shape[:2]
        _, _, bboxes = self.face_det(img, maxface=True)
        # len() instead of truthiness so array-valued results also work.
        if bboxes is not None and len(bboxes) > 0:
            x1, y1, ww, hh = bboxes[0]  # (x, y, width, height)
            crop = process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)
            return {"face_num": 1, "crop_bbox": crop}
        return {"face_num": 0, "crop_bbox": None}

    # -------------------------------------------------------------- #
    def _interpolate_frames(self, video: torch.Tensor) -> torch.Tensor:
        """Insert one RIFE frame between every pair of consecutive frames.

        Args:
            video: ``(C, T, H, W)`` tensor in [0, 1].

        Returns:
            ``(C, 2T-1, H, W)`` tensor on CPU.
        """
        clip = video.unsqueeze(0).to(self.device)  # (1, C, T, H, W)
        n_frames = clip.shape[2]
        frames: List[torch.Tensor] = []
        for i in tqdm(range(n_frames - 1), desc="interpolate", ncols=0):
            # Keep the batch dim: RIFE works on (N, C, H, W) image batches.
            cur, nxt = clip[:, :, i], clip[:, :, i + 1]
            frames.append(cur)
            frames.append(self.rife.inference(cur, nxt).clamp(0, 1))
        frames.append(clip[:, :, -1])
        return torch.stack(frames, dim=2)[0].cpu()

    # -------------------------------------------------------------- #
    @torch.no_grad()
    def process(
        self,
        image_path: str,
        audio_path: str,
        output_path: str,
        min_resolution: int = 512,
        inference_steps: int = 25,
        dynamic_scale: float = 1.0,
        keep_resolution: bool = False,
        seed: int | None = None,
    ) -> int:
        """Render a talking-head video for ``image_path`` driven by ``audio_path``.

        Args:
            image_path: Portrait image.
            audio_path: Speech audio; also muxed into the output.
            output_path: Destination video file.
            min_resolution: ``image_size`` for preprocessing.
            inference_steps: Diffusion steps.
            dynamic_scale: Motion-bucket scale.
            keep_resolution: Rescale the output to the source image size
                (rounded down to even) instead of the generated size.
            seed: Random seed; ``None`` keeps ``self.config.seed``.

        Returns:
            0 on success, -1 if preprocessing yields nothing or ffmpeg fails.
        """
        cfg = self.config
        if seed is not None:
            cfg.seed = seed
        cfg.num_inference_steps = inference_steps
        cfg.motion_bucket_scale = dynamic_scale
        seed_everything(cfg.seed)

        # Image / audio → tensors
        data = image_audio_to_tensor(
            self.face_det,
            self.feature_extractor,
            image_path,
            audio_path,
            limit=-1,
            image_size=min_resolution,
            area=cfg.area,
        )
        if data is None:
            return -1

        h, w = data["ref_img"].shape[-2:]
        if keep_resolution:
            with Image.open(image_path) as im:  # close the file handle promptly
                resolution = f"{(im.width // 2) * 2}x{(im.height // 2) * 2}"
        else:
            resolution = f"{w}x{h}"

        video = _gen_video_tensor(
            self.pipe,
            cfg,
            self.whisper,
            self.audio2token,
            self.audio2bucket,
            self.image_encoder,
            w,
            h,
            data,
        )

        if cfg.use_interframe:
            video = self._interpolate_frames(video)

        return self._write_video(video, audio_path, output_path, resolution)

    # -------------------------------------------------------------- #
    def _write_video(self, video: torch.Tensor, audio_path: str, output_path: str, resolution: str) -> int:
        """Save ``video`` (``(C, T, H, W)``) and mux it with ``audio_path`` via ffmpeg.

        Returns:
            0 if ffmpeg succeeded, -1 otherwise. The intermediate silent video
            is always removed.
        """
        cfg = self.config
        stem, ext = os.path.splitext(output_path)
        # Never equal to output_path, whatever the extension is.
        tmp = f"{stem}_noaudio{ext or '.mp4'}"
        fps = cfg.fps * (2 if cfg.use_interframe else 1)
        save_videos_grid(video.unsqueeze(0), tmp, n_rows=1, fps=fps)

        # Argument list (no shell): paths containing quotes/spaces are safe.
        cmd = [
            "ffmpeg", "-loglevel", "error", "-y",
            "-i", tmp, "-i", audio_path,
            "-s", resolution,
            "-vcodec", "libx264", "-acodec", "aac", "-crf", "18",
            "-shortest", output_path,
        ]
        try:
            result = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            print("[ERROR] ffmpeg executable not found.")
            return -1
        finally:
            if os.path.exists(tmp):
                os.remove(tmp)
        return 0 if result.returncode == 0 else -1