# ---------------------------------------------------------
# sonic.py (2025-05 rev – fix AudioProjModel tensor shape)
# ---------------------------------------------------------
import os, math, torch, cv2
import torch.utils.checkpoint
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import image_audio_to_tensor, process_bbox
from src.models.base.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel, add_ip_adapters,
)
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.pipelines.pipeline_sonic import SonicPipeline
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ------------------------------------------------------------------
# single image + speech → video tensor
# ------------------------------------------------------------------
def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, img_enc,
width, height, batch):
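    """
    Run the Sonic pipeline on a single reference image + speech clip and return
    the generated frames as a video tensor scaled to [0, 1].

    Argument roles (inferred from how they are used below, not from upstream docs):
        pipe         – SonicPipeline (UNet / VAE / scheduler)
        wav_enc      – Whisper model whose encoder extracts audio features
        audio_pe     – AudioProjModel projecting Whisper features to audio tokens
        audio2bucket – Audio2bucketModel predicting per-chunk motion buckets
        img_enc      – CLIP vision encoder for the reference image
        batch        – sample dict produced by image_audio_to_tensor
    """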
    # --- add a batch dimension and move tensors to the pipeline device ---------
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.unsqueeze(0).float().to(pipe.device)
ref_img = batch['ref_img']
clip_img = batch['clip_images']
face_mask = batch['face_mask']
img_emb = img_enc(clip_img).image_embeds # (1,1024)
audio_feat = batch['audio_feature'] # (1,80,T)
audio_len = int(batch['audio_len'])
    step = max(1, int(cfg.step))               # safety: never below 1
    window = 16_000                            # 1-second chunk
prompt_list, last_list = [], []
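    # Encode the mel features chunk by chunk with the frozen Whisper encoder:
    # all hidden states are stacked (for the audio-token adapter) and the last
    # hidden state is kept separately (for motion-bucket estimation).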
for i in range(0, audio_feat.shape[-1], window):
chunk = audio_feat[:, :, i:i+window]
hs = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
        prompt_list.append(torch.stack(hs, 2))             # (1, L, 5, 384) – 5 hidden states of whisper-tiny
        last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
        last_list.append(last)                             # (1, L, 1, 384)
if not prompt_list:
raise ValueError("❌ No speech recognised in audio.")
    audio_prompts = torch.cat(prompt_list, 1)              # (1, T, 5, 384)
    last_prompts = torch.cat(last_list, 1)                 # (1, T, 1, 384)
    # padding rule (kept identical to the original model/paper)
audio_prompts = torch.cat([ torch.zeros_like(audio_prompts[:,:4]),
audio_prompts,
torch.zeros_like(audio_prompts[:,:6]) ], 1)
last_prompts = torch.cat([ torch.zeros_like(last_prompts[:,:24]),
last_prompts,
torch.zeros_like(last_prompts[:,:26]) ], 1)
# --------------------------------------------------------------
total_tok = audio_prompts.shape[1]
n_chunks = max(1, math.ceil(total_tok / (2*step)))
ref_L, aud_L, uncond_L, buckets = [], [], [], []
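    # For each chunk of 2*step tokens build: conditional audio tokens, the matching
    # zero-audio (unconditional) tokens for guidance, and a motion-bucket value
    # predicted from the audio plus the CLIP image embedding.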
for i in tqdm(range(n_chunks), ncols=0):
st = i * 2 * step
        # (1) conditional audio tokens for this chunk (pad → 10×5×384)
        # assumption: AudioProjModel (seq_len=10, blocks=5, channels=384) expects
        # input shaped (bz, f, w=10, b=5, c=384), so add a frame axis after slicing.
        cond = audio_prompts[:, st:st+10]                  # (1, ≤10, 5, 384)
        if cond.shape[1] < 10:                             # zero-pad a short tail chunk
            pad = torch.zeros(1, 10 - cond.shape[1], *cond.shape[2:],
                              dtype=cond.dtype, device=cond.device)
            cond = torch.cat([cond, pad], 1)
        cond = cond.unsqueeze(1)                           # (1, 1, 10, 5, 384)
        # (2) tokens for motion-bucket estimation
        # assumption: Audio2bucketModel (seq_len=50, blocks=1, channels=384) expects
        # input shaped (bz, f, w=50, b=1, c=384).
        buck = last_prompts[:, st:st+50]                   # (1, ≤50, 1, 384)
        if buck.shape[1] < 50:                             # zero-pad a short tail chunk
            pad = torch.zeros(1, 50 - buck.shape[1], *buck.shape[2:],
                              dtype=buck.dtype, device=buck.device)
            buck = torch.cat([buck, pad], 1)
        buck = buck.unsqueeze(1)                           # (1, 1, 50, 1, 384)
motion = audio2bucket(buck, img_emb) * 16 + 16
ref_L.append(ref_img[0])
        aud_L.append(audio_pe(cond).squeeze(0))            # projected audio context tokens
        uncond_L.append(audio_pe(torch.zeros_like(cond)).squeeze(0))  # zero-audio tokens for guidance
buckets.append(motion[0])
# -------------- diffusion -------------------------------------------------
vid = pipe(
ref_img, clip_img, face_mask,
aud_L, uncond_L, buckets,
height=height, width=width,
num_frames=len(aud_L),
decode_chunk_size=cfg.decode_chunk_size,
motion_bucket_scale=cfg.motion_bucket_scale,
fps=cfg.fps,
noise_aug_strength=cfg.noise_aug_strength,
min_guidance_scale1=cfg.min_appearance_guidance_scale,
max_guidance_scale1=cfg.max_appearance_guidance_scale,
min_guidance_scale2=cfg.audio_guidance_scale,
max_guidance_scale2=cfg.audio_guidance_scale,
overlap=cfg.overlap,
shift_offset=cfg.shift_offset,
frames_per_batch=cfg.n_sample_frames,
num_inference_steps=cfg.num_inference_steps,
i2i_noise_strength=cfg.i2i_noise_strength,
).frames
    return (vid * 0.5 + 0.5).clamp(0, 1).unsqueeze(0).cpu()
# ------------------------------------------------------------------
# Sonic wrapper
# ------------------------------------------------------------------
class Sonic:
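    """
    Thin wrapper around the Sonic portrait-animation pipeline: loads the UNet, VAE,
    CLIP image encoder, Whisper, audio adapters, face detector and (optionally) the
    RIFE frame interpolator, and exposes preprocess() / process() to turn a single
    portrait image plus a speech file into a talking-head video.
    """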
config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
config = OmegaConf.load(config_file)
def __init__(self, device_id=0, enable_interpolate_frame=True):
cfg = self.config
cfg.use_interframe = enable_interpolate_frame
self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id>=0 else "cpu"
cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
self._load_models(cfg)
print("Sonic init done")
# model-loader (unchanged, but with tiny clean-ups) ------------------------
def _load_models(self, cfg):
dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
img_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
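        # Audio adapters – the constructor arguments mirror the chunk shapes built in
        # test(): AudioProjModel consumes 10×5×384 blocks of stacked Whisper hidden
        # states, Audio2bucketModel consumes 50×1×384 blocks plus the CLIP image embed.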
self.audio2token = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
self.audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
        self.audio2token.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
        self.audio2bucket.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
self.whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
self.whisper.requires_grad_(False)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
if cfg.use_interframe:
            self.rife = RIFEModel(device=self.device)
            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
for m in (img_enc, vae, unet): m.to(dtype)
self.pipe = SonicPipeline(unet=unet, image_encoder=img_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
self.image_encoder = img_enc
# ------------------------------------------------------------------
def preprocess(self, img_path, expand_ratio=1.0):
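        """Detect the largest face in img_path and return its (optionally expanded) crop bbox."""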
img = cv2.imread(img_path)
_, _, boxes = self.face_det(img, maxface=True)
        if boxes:
            x, y, w, h = boxes[0]
            crop_bbox = process_bbox((x, y, x + w, y + h), expand_ratio, *img.shape[:2])
            return {"face_num": 1, "crop_bbox": crop_bbox}
        return {"face_num": 0, "crop_bbox": None}
# ------------------------------------------------------------------
@torch.no_grad()
def process(self, img_path, wav_path, out_path,
min_resolution=512, inference_steps=25,
dynamic_scale=1.0, keep_resolution=False, seed=None):
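        """
        Generate a talking-head video for img_path driven by wav_path and write it to
        out_path (with the original audio muxed in via ffmpeg). Returns 0 on success,
        -1 if preprocessing could not build a sample (e.g. no usable face/audio).
        """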
cfg = self.config
if seed is not None: cfg.seed = seed
cfg.num_inference_steps = inference_steps
cfg.motion_bucket_scale = dynamic_scale
seed_everything(cfg.seed)
sample = image_audio_to_tensor(
self.face_det, self.feature_extractor,
img_path, wav_path, limit=-1,
image_size=min_resolution, area=cfg.area,
)
if sample is None: return -1
        h, w = sample['ref_img'].shape[-2:]
        if keep_resolution:
            src = Image.open(img_path)                       # keep the source size, rounded down to even
            resolution = f"{src.width // 2 * 2}x{src.height // 2 * 2}"
        else:
            resolution = f"{w}x{h}"
video = test(self.pipe, cfg, self.whisper, self.audio2token,
self.audio2bucket, self.image_encoder, w, h, sample)
        if cfg.use_interframe:                               # RIFE interpolation
            out = video.to(self.device)
            frames = []
            for i in tqdm(range(out.shape[2] - 1), ncols=0):
                mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1)
                frames += [out[:, :, i], mid]
            frames.append(out[:, :, -1])
            video = torch.stack(frames, 2).cpu()
tmp = out_path.replace(".mp4","_noaudio.mp4")
save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps*(2 if cfg.use_interframe else 1))
os.system(f"ffmpeg -i '{tmp}' -i '{wav_path}' -s {resolution} "
f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error")
        os.remove(tmp)
        return 0
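

# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline code). The input and
# output paths below are placeholders; checkpoints are expected under the layout
# referenced by config/inference/sonic.yaml.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)
    face_info = sonic.preprocess("examples/portrait.png", expand_ratio=1.0)  # placeholder image path
    if face_info["face_num"] > 0:
        sonic.process(
            "examples/portrait.png",     # reference portrait (placeholder)
            "examples/speech.wav",       # driving speech clip (placeholder)
            "outputs/result.mp4",        # output video path (placeholder)
            min_resolution=512,
            inference_steps=25,
            dynamic_scale=1.0,
        )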