# Portrait-Animation / sonic.py
import os, math
import torch
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel, add_ip_adapters,
)
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ------------------------------------------------------------------
# single image + speech → video-tensor generator
# ------------------------------------------------------------------
def test(
pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
width, height, batch,
):
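    """
    Generate a talking-head video tensor from one reference image and a speech
    feature batch: encode the audio with Whisper, build per-frame audio prompts
    and motion buckets, then run the SVD-based SonicPipeline. Returns the
    generated frames as a CPU float tensor scaled to [0, 1].
    """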
    # ---------------- add batch dimension and move to device --------------
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.unsqueeze(0).to(pipe.device).float()
ref_img = batch["ref_img"] # (1,C,H,W)
clip_img = batch["clip_images"]
face_mask = batch["face_mask"]
image_embeds = image_encoder(clip_img).image_embeds # (1,1024)
audio_feature = batch["audio_feature"] # (1,80,T)
audio_len = int(batch["audio_len"]) # python int
step = int(config.step)
    # ---------- Whisper encoding over fixed-size windows ------------------
    window = 16_000  # 1 second
audio_prompts, last_prompts = [], []
for i in range(0, audio_feature.shape[-1], window):
chunk = audio_feature[:, :, i : i + window]
        layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
        last = layers[-1].unsqueeze(-2)  # equals last_hidden_state; avoids a second encoder pass
audio_prompts.append(torch.stack(layers, dim=2)) # (1,?,L,384)
last_prompts.append(last) # (1,?,1,384)
if not audio_prompts:
raise ValueError("[ERROR] No speech recognised in the provided audio.")
audio_prompts = torch.cat(audio_prompts, dim=1)
last_prompts = torch.cat(last_prompts, dim=1)
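    # audio_prompts: stacked per-layer Whisper hidden states (fed to audio_pe),
    # last_prompts:  final-layer states, used only for motion-bucket prediction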
    # ---------- pad to match the model's expected input layout ------------
audio_prompts = torch.cat(
[torch.zeros_like(audio_prompts[:, :4]), # head pad
audio_prompts,
torch.zeros_like(audio_prompts[:, :6])], dim=1)
last_prompts = torch.cat(
[torch.zeros_like(last_prompts[:, :24]),
last_prompts,
torch.zeros_like(last_prompts[:, :26])], dim=1)
    # ---------- number of chunks, derived from the audio length -----------
total_tokens = audio_prompts.shape[1]
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
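    # one chunk per generated frame: each iteration advances by 2 * step audio
    # tokens and yields one audio prompt and one motion bucket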
ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []
for i in tqdm(range(num_chunks)):
start = i * 2 * step
# ---------------- cond_clip (w=10,L=5) --------------------
clip_raw = audio_prompts[:, start : start + 10] # (1,≤10,L,384)
# w-pad
if clip_raw.shape[1] < 10:
pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
clip_raw = torch.cat([clip_raw, pad_w], dim=1)
        # ★ L-pad (Whisper-tiny → L=2 → expand to 5)
if clip_raw.shape[2] < 5:
pad_L = clip_raw[:, :, -1:].repeat(1, 1, 5 - clip_raw.shape[2], 1)
clip_raw = torch.cat([clip_raw, pad_L], dim=2)
clip_raw = clip_raw[:, :, :5] # (1,10,5,384)
cond_clip = clip_raw.unsqueeze(1) # (1,1,10,5,384)
# ---------------- bucket_clip (w=50,L=1) ------------------
bucket_raw = last_prompts[:, start : start + 50] # (1,≤50,1,384)
if bucket_raw.shape[1] < 50:
pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
bucket_clip = bucket_raw.unsqueeze(1) # (1,1,50,1,384)
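        # scale/shift the audio2bucket prediction (×16 + 16) into a motion-bucket value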
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
ref_list.append(ref_img[0])
audio_list.append(audio_pe(cond_clip).squeeze(0)[0]) # (10,1024)
uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
motion_buckets.append(motion[0])
    # ---------- Stable-Video-Diffusion call -------------------------------
video = pipe(
ref_img, clip_img, face_mask,
audio_list, uncond_list, motion_buckets,
height=height, width=width,
num_frames=len(audio_list),
decode_chunk_size=config.decode_chunk_size,
motion_bucket_scale=config.motion_bucket_scale,
fps=config.fps,
noise_aug_strength=config.noise_aug_strength,
min_guidance_scale1=config.min_appearance_guidance_scale,
max_guidance_scale1=config.max_appearance_guidance_scale,
min_guidance_scale2=config.audio_guidance_scale,
max_guidance_scale2=config.audio_guidance_scale,
overlap=config.overlap,
shift_offset=config.shift_offset,
frames_per_batch=config.n_sample_frames,
num_inference_steps=config.num_inference_steps,
i2i_noise_strength=config.i2i_noise_strength,
).frames
video = (video * 0.5 + 0.5).clamp(0, 1)
    return video.unsqueeze(0).cpu()
# ------------------------------------------------------------------
# Sonic class
# ------------------------------------------------------------------
class Sonic:
config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
config = OmegaConf.load(config_file)
def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
cfg = self.config
cfg.use_interframe = enable_interpolate_frame
self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
self._load_models(cfg)
print("Sonic init done")
# --------------------------------------------------------------
def _load_models(self, cfg):
dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
        vae    = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
        sched  = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
        imgenc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
        unet   = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
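        # attach the audio IP-Adapter branch to the UNet, weighted by cfg.ip_audio_scale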
add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
whisper.requires_grad_(False)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
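        # optional RIFE frame interpolator; when enabled, process() doubles the saved fps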
if cfg.use_interframe:
self.rife = RIFEModel(device=self.device)
self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
for m in (imgenc, vae, unet):
m.to(dtype)
self.pipe = SonicPipeline(unet=unet, image_encoder=imgenc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
self.image_encoder = imgenc
self.audio2token = a2t
self.audio2bucket = a2b
self.whisper = whisper
# --------------------------------------------------------------
def preprocess(self, image_path: str, expand_ratio: float = 1.0):
img = cv2.imread(image_path)
h, w = img.shape[:2]
_, _, bboxes = self.face_det(img, maxface=True)
if bboxes:
x1, y1, ww, hh = bboxes[0]
return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
return {"face_num": 0, "crop_bbox": None}
# --------------------------------------------------------------
@torch.no_grad()
def process(
self,
image_path: str,
audio_path: str,
output_path: str,
min_resolution: int = 512,
inference_steps: int = 25,
dynamic_scale: float = 1.0,
keep_resolution: bool = False,
seed: int | None = None,
):
cfg = self.config
if seed is not None:
cfg.seed = seed
cfg.num_inference_steps = inference_steps
cfg.motion_bucket_scale = dynamic_scale
seed_everything(cfg.seed)
        # image + audio → tensors
test_data = image_audio_to_tensor(
self.face_det,
self.feature_extractor,
image_path,
audio_path,
limit=-1,
image_size=min_resolution,
area=cfg.area,
)
if test_data is None:
return -1
h, w = test_data["ref_img"].shape[-2:]
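        # ffmpeg output size: the original image size rounded down to even when
        # keep_resolution is set, otherwise the preprocessed size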
        if keep_resolution:
            src = Image.open(image_path)
            resolution = f"{(src.width // 2) * 2}x{(src.height // 2) * 2}"
        else:
            resolution = f"{w}x{h}"
        # generate the video frames
video = test(
self.pipe, cfg, self.whisper, self.audio2token,
self.audio2bucket, self.image_encoder,
width=w, height=h, batch=test_data,
)
        # interpolate intermediate frames with RIFE
if cfg.use_interframe:
out = video.to(self.device)
frames = []
for i in tqdm(range(out.shape[2] - 1), ncols=0):
mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
frames.extend([out[:, :, i], mid])
frames.append(out[:, :, -1])
video = torch.stack(frames, 2).cpu()
        # save, then mux the audio with ffmpeg
tmp_mp4 = output_path.replace(".mp4", "_noaudio.mp4")
save_videos_grid(video, tmp_mp4, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
os.system(
f"ffmpeg -i '{tmp_mp4}' -i '{audio_path}' -s {resolution} "
f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
)
os.remove(tmp_mp4)
return 0
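# ------------------------------------------------------------------
# minimal usage sketch (not part of the original pipeline code);
# the input/output paths below are placeholders
# ------------------------------------------------------------------
if __name__ == "__main__":
    sonic = Sonic(device_id=0)
    face_info = sonic.preprocess("examples/face.png")   # hypothetical portrait image
    if face_info["face_num"] > 0:
        sonic.process(
            "examples/face.png",
            "examples/speech.wav",                       # hypothetical driving audio
            "outputs/result.mp4",                        # output directory must exist
            inference_steps=25,
            dynamic_scale=1.0,
        )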