# ---------------------------------------------------------
# sonic.py  (2025-05 rev – fix AudioProjModel tensor shape)
# ---------------------------------------------------------
import copy
import math
import os
import subprocess

import cv2
import torch
import torch.utils.checkpoint
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import image_audio_to_tensor, process_bbox
from src.models.base.unet_spatio_temporal_condition import (
    UNetSpatioTemporalConditionModel,
    add_ip_adapters,
)
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.pipelines.pipeline_sonic import SonicPipeline
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# The Whisper encoder only accepts fixed-size inputs of 3000 mel frames
# (30 s of audio); longer slices are rejected by the model.
_WHISPER_WINDOW = 3_000
# Number of audio tokens fed to the projection / bucket models per video
# frame; these match the first constructor argument of AudioProjModel and
# Audio2bucketModel in Sonic._load_models.
_AUDIO_TOKENS_PER_FRAME = 10
_BUCKET_TOKENS_PER_FRAME = 50


def _pad_time_axis(tokens, length):
    """Zero-pad *tokens* along dim 1 up to *length* entries (no-op if long enough)."""
    missing = length - tokens.shape[1]
    if missing <= 0:
        return tokens
    filler = tokens.new_zeros(tokens.shape[0], missing, *tokens.shape[2:])
    return torch.cat([tokens, filler], dim=1)


# ------------------------------------------------------------------
#  single image + speech → video tensor
# ------------------------------------------------------------------
def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, img_enc,
         width, height, batch):
    """Generate a video tensor from one reference image and its speech features.

    The audio features are encoded with ``wav_enc.encoder``, cut into one
    token window per output frame, projected by ``audio_pe`` /
    ``audio2bucket`` and handed to the diffusion pipeline.

    Args:
        pipe: diffusion pipeline; called once, its ``.frames`` output is used.
        cfg: config object providing ``step`` and the sampling options
            forwarded to ``pipe``.
        wav_enc: model exposing an ``encoder`` that returns ``hidden_states``
            and ``last_hidden_state``.
        audio_pe: audio-token projection model.
        audio2bucket: model mapping audio tokens + image embedding to a
            motion-bucket value.
        img_enc: image encoder whose output exposes ``image_embeds``.
        width: output width passed to ``pipe``.
        height: output height passed to ``pipe``.
        batch: dict holding ``ref_img``, ``clip_images``, ``face_mask``,
            ``audio_feature`` and ``audio_len``. Tensor values are replaced
            in place by batched float copies on ``pipe.device``.

    Returns:
        ``pipe(...).frames`` mapped through ``x * 0.5 + 0.5``, clamped to
        [0, 1], with a leading axis added, on the CPU.

    Raises:
        ValueError: if the audio feature has zero length on its last axis.
    """
    # --- add the batch dimension -----------------------------------
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.unsqueeze(0).float().to(pipe.device)

    ref_img = batch['ref_img']
    clip_img = batch['clip_images']
    face_mask = batch['face_mask']
    img_emb = img_enc(clip_img).image_embeds

    audio_feat = batch['audio_feature']  # presumably (1, n_mels, T) — TODO confirm
    # NOTE(review): audio_len is read but never used below; presumably it was
    # meant to trim the zero-padded tail of the encoder output — confirm.
    audio_len = int(batch['audio_len'])
    step = max(1, int(cfg.step))         # guard against a zero / negative step

    prompt_list, last_list = [], []
    for start in range(0, audio_feat.shape[-1], _WHISPER_WINDOW):
        chunk = audio_feat[:, :, start:start + _WHISPER_WINDOW]
        # One forward pass yields both the per-layer states and the final state.
        enc_out = wav_enc.encoder(chunk, output_hidden_states=True)
        # Stack the per-layer states on a new axis: (1, tokens, layers, dim).
        prompt_list.append(torch.stack(enc_out.hidden_states, 2))
        # Final state with a singleton "layer" axis: (1, tokens, 1, dim).
        last_list.append(enc_out.last_hidden_state.unsqueeze(-2))

    if not prompt_list:
        raise ValueError("❌ No speech recognised in audio.")

    audio_prompts = torch.cat(prompt_list, 1)
    last_prompts = torch.cat(last_list, 1)

    # Zero padding before/after the token sequence (the original comment
    # states this follows the padding rule of the model's paper).
    audio_prompts = torch.cat([
        torch.zeros_like(audio_prompts[:, :4]),
        audio_prompts,
        torch.zeros_like(audio_prompts[:, :6]),
    ], 1)
    last_prompts = torch.cat([
        torch.zeros_like(last_prompts[:, :24]),
        last_prompts,
        torch.zeros_like(last_prompts[:, :26]),
    ], 1)

    # --------------------------------------------------------------
    total_tok = audio_prompts.shape[1]
    n_chunks = max(1, math.ceil(total_tok / (2 * step)))

    ref_L, aud_L, uncond_L, buckets = [], [], [], []

    for i in tqdm(range(n_chunks), ncols=0):
        st = i * 2 * step  # every output frame advances the cursor by 2*step tokens

        # (1) conditioning audio tokens: pad the (possibly short) tail window
        #     to a full window, then add a singleton frame axis
        #     -> (1, 1, window, layers, dim).  reshape() cannot be used here:
        #     it neither pads nor changes the element count.
        cond = _pad_time_axis(
            audio_prompts[:, st:st + _AUDIO_TOKENS_PER_FRAME],
            _AUDIO_TOKENS_PER_FRAME,
        ).unsqueeze(1)

        # (2) tokens used to estimate the motion bucket
        #     -> (1, 1, window, 1, dim).
        buck = _pad_time_axis(
            last_prompts[:, st:st + _BUCKET_TOKENS_PER_FRAME],
            _BUCKET_TOKENS_PER_FRAME,
        ).unsqueeze(1)
        motion = audio2bucket(buck, img_emb) * 16 + 16

        ref_L.append(ref_img[0])
        # NOTE(review): assumes audio_pe returns (batch, frames, tokens, dim)
        # and that the pipeline wants one (tokens, dim) tensor per video
        # frame — confirm against AudioProjModel and SonicPipeline.
        aud_L.append(audio_pe(cond).squeeze(0)[0])
        uncond_L.append(audio_pe(torch.zeros_like(cond)).squeeze(0)[0])
        buckets.append(motion[0])

    # -------------- diffusion -------------------------------------
    vid = pipe(
        ref_img, clip_img, face_mask,
        aud_L, uncond_L, buckets,
        height=height, width=width,
        num_frames=len(aud_L),
        decode_chunk_size=cfg.decode_chunk_size,
        motion_bucket_scale=cfg.motion_bucket_scale,
        fps=cfg.fps,
        noise_aug_strength=cfg.noise_aug_strength,
        min_guidance_scale1=cfg.min_appearance_guidance_scale,
        max_guidance_scale1=cfg.max_appearance_guidance_scale,
        min_guidance_scale2=cfg.audio_guidance_scale,
        max_guidance_scale2=cfg.audio_guidance_scale,
        overlap=cfg.overlap,
        shift_offset=cfg.shift_offset,
        frames_per_batch=cfg.n_sample_frames,
        num_inference_steps=cfg.num_inference_steps,
        i2i_noise_strength=cfg.i2i_noise_strength,
    ).frames

    # NOTE(review): Sonic.process iterates axis 2 of the returned tensor as
    # the frame axis; confirm that `.frames` has no batch axis of its own, so
    # that this unsqueeze yields a 5-D (1, C, F, H, W) tensor.
    return (vid * 0.5 + 0.5).clamp(0, 1).unsqueeze(0).cpu()


# ------------------------------------------------------------------
#  Sonic wrapper
# ------------------------------------------------------------------
class Sonic:
    """Loads every model once and turns (image, wav) pairs into talking videos."""

    config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
    config = OmegaConf.load(config_file)

    def __init__(self, device_id=0, enable_interpolate_frame=True):
        """Build the pipeline on ``cuda:<device_id>`` (or CPU) and load all weights.

        Args:
            device_id: CUDA device index; a negative value, or no available
                CUDA device, selects the CPU.
            enable_interpolate_frame: load RIFE and double the frame rate of
                generated videos by frame interpolation.
        """
        # Work on a private copy: the class-level config is shared by every
        # instance, so mutating it would leak settings between instances.
        cfg = copy.deepcopy(self.config)
        self.config = cfg

        cfg.use_interframe = enable_interpolate_frame
        use_cuda = torch.cuda.is_available() and device_id >= 0
        self.device = f"cuda:{device_id}" if use_cuda else "cpu"
        cfg.pretrained_model_name_or_path = os.path.join(
            BASE_DIR, cfg.pretrained_model_name_or_path
        )

        self._load_models(cfg)
        print("Sonic init done")

    # model-loader ---------------------------------------------------
    def _load_models(self, cfg):
        """Instantiate all sub-models, load their checkpoints and build ``self.pipe``."""
        dtype = {
            "fp16": torch.float16,
            "fp32": torch.float32,
            "bf16": torch.bfloat16,
        }[cfg.weight_dtype]
        base = cfg.pretrained_model_name_or_path

        vae = AutoencoderKLTemporalDecoder.from_pretrained(base, subfolder="vae", variant="fp16")
        sched = EulerDiscreteScheduler.from_pretrained(base, subfolder="scheduler")
        img_enc = CLIPVisionModelWithProjection.from_pretrained(
            base, subfolder="image_encoder", variant="fp16"
        )
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            base, subfolder="unet", variant="fp16"
        )
        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

        self.audio2token = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
        self.audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        unet.load_state_dict(
            torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu")
        )
        self.audio2token.load_state_dict(
            torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu")
        )
        self.audio2bucket.load_state_dict(
            torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu")
        )

        whisper_dir = os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
        self.whisper = WhisperModel.from_pretrained(whisper_dir).to(self.device).eval()
        self.whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_dir)

        self.face_det = AlignImage(
            self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt")
        )
        if cfg.use_interframe:
            self.rife = RIFEModel(device=self.device)
            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

        for module in (img_enc, vae, unet):
            module.to(dtype)

        self.pipe = SonicPipeline(
            unet=unet, image_encoder=img_enc, vae=vae, scheduler=sched
        ).to(device=self.device, dtype=dtype)
        self.image_encoder = img_enc

    # ------------------------------------------------------------------
    def preprocess(self, img_path, expand_ratio=1.0):
        """Detect the main face in ``img_path`` and return its crop box.

        Returns:
            ``{"face_num": 1, "crop_bbox": <process_bbox result>}`` when the
            detector reports at least one box, otherwise
            ``{"face_num": 0, "crop_bbox": None}``.
        """
        img = cv2.imread(img_path)
        _, _, boxes = self.face_det(img, maxface=True)
        if not boxes:
            return {"face_num": 0, "crop_bbox": None}

        x, y, w, h = boxes[0]
        crop_bbox = process_bbox((x, y, x + w, y + h), expand_ratio, *img.shape[:2])
        return {"face_num": 1, "crop_bbox": crop_bbox}

    # ------------------------------------------------------------------
    @torch.no_grad()
    def process(self, img_path, wav_path, out_path,
                min_resolution=512, inference_steps=25,
                dynamic_scale=1.0, keep_resolution=False, seed=None):
        """Render the talking video for ``img_path`` + ``wav_path`` into ``out_path``.

        Args:
            img_path: reference image file.
            wav_path: speech audio file; also muxed into the result.
            out_path: destination video file.
            min_resolution: forwarded to ``image_audio_to_tensor`` as
                ``image_size``.
            inference_steps: number of diffusion steps.
            dynamic_scale: stored as ``motion_bucket_scale``.
            keep_resolution: scale the output to the source image size
                (rounded down to even numbers) instead of the model size.
            seed: optional seed; when given it replaces ``cfg.seed`` for this
                and later calls.

        Returns:
            0 on success; -1 when ``image_audio_to_tensor`` yields no sample
            or when ffmpeg fails to produce the final file.
        """
        cfg = self.config
        if seed is not None:
            cfg.seed = seed
        cfg.num_inference_steps = inference_steps
        cfg.motion_bucket_scale = dynamic_scale
        seed_everything(cfg.seed)

        sample = image_audio_to_tensor(
            self.face_det, self.feature_extractor,
            img_path, wav_path,
            limit=-1, image_size=min_resolution, area=cfg.area,
        )
        if sample is None:
            return -1

        h, w = sample['ref_img'].shape[-2:]
        if keep_resolution:
            # Open the image once and close it deterministically.
            with Image.open(img_path) as src_img:
                resolution = f"{src_img.width // 2 * 2}x{src_img.height // 2 * 2}"
        else:
            resolution = f"{w}x{h}"

        video = test(self.pipe, cfg, self.whisper, self.audio2token,
                     self.audio2bucket, self.image_encoder,
                     w, h, sample)

        if cfg.use_interframe:  # RIFE interpolation: insert one frame between neighbours
            out = video.to(self.device)
            frames = []
            for i in tqdm(range(out.shape[2] - 1), ncols=0):
                mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1)
                frames += [out[:, :, i], mid]
            frames.append(out[:, :, -1])
            video = torch.stack(frames, 2).cpu()

        # Derive the temp name from the extension only, so it can never equal
        # out_path and directory names containing ".mp4" are left untouched.
        root, _ = os.path.splitext(out_path)
        tmp = f"{root}_noaudio.mp4"
        save_videos_grid(video, tmp, n_rows=video.shape[0],
                         fps=cfg.fps * (2 if cfg.use_interframe else 1))

        # Argument list + no shell: paths with quotes or spaces are passed verbatim.
        cmd = [
            "ffmpeg", "-i", tmp, "-i", wav_path, "-s", resolution,
            "-vcodec", "libx264", "-acodec", "aac", "-crf", "18",
            "-shortest", out_path, "-y", "-loglevel", "error",
        ]
        try:
            muxed = subprocess.run(cmd, check=False).returncode == 0
        except OSError:
            # ffmpeg is missing or not executable.
            muxed = False
        finally:
            os.remove(tmp)
        return 0 if muxed else -1