# sonic.py
# ---------------------------------------------------------------------
# Sonic – single-image + speech → talking-head video (offline edition)
# ---------------------------------------------------------------------
import os, math
from typing import Dict, Any, List
import torch
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
add_ip_adapters,
)
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ------------------------------------------------------------------ #
#              Helper: locate the diffusers model directory          #
# ------------------------------------------------------------------ #
def _locate_diffusers_dir(root: str) -> str:
"""
`root` 하위 디렉터리에서 diffusers 스냅샷(model_index.json or config.json)
이 들어 있는 실제 모델 폴더를 찾아서 반환한다. 존재하지 않으면 오류.
"""
for cur, _dirs, files in os.walk(root):
if {"model_index.json", "config.json"} & set(files):
return cur
raise FileNotFoundError(
f"[ERROR] No diffusers model files found under '{root}'. "
"Check that the checkpoint was downloaded correctly."
)
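# Example (hypothetical) checkpoint layout this walk would resolve —
# the directory that directly contains model_index.json is returned:
#
#   checkpoints/
#       stable-video-diffusion-img2vid-xt/
#           model_index.json
#           vae/  unet/  scheduler/  image_encoder/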
# ------------------------------------------------------------------ #
#              Internal helper: generate the video tensor            #
# ------------------------------------------------------------------ #
def _gen_video_tensor(
pipe: SonicPipeline,
cfg: OmegaConf,
wav_enc: WhisperModel,
audio_pe: AudioProjModel,
audio2bucket: Audio2bucketModel,
image_encoder: CLIPVisionModelWithProjection,
width: int,
height: int,
batch: Dict[str, torch.Tensor],
) -> torch.Tensor:
"""
single 이미지 + 오디오 feature → video tensor (C,T,H,W)
"""
    # -------- add a batch dimension and move tensors to device -----
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.unsqueeze(0).to(pipe.device).float()
ref_img = batch["ref_img"] # (1,C,H,W)
clip_img = batch["clip_images"]
face_mask = batch["face_mask"]
image_embeds = image_encoder(clip_img).image_embeds
audio_feat: torch.Tensor = batch["audio_feature"] # (1, 80, T)
audio_len: int = int(batch["audio_len"]) # scalar
step: int = int(cfg.step)
    # if step exceeds the audio length, clamp it to at least 1
if audio_len < step:
step = max(1, audio_len)
    # -------- run the Whisper encoder in 1-second chunks -----------
window = 16_000 # 1-s chunk
aud_prompts: List[torch.Tensor] = []
last_prompts: List[torch.Tensor] = []
for i in range(0, audio_feat.shape[-1], window):
chunk = audio_feat[:, :, i : i + window]
        # all hidden states of the Whisper encoder; the final entry is the
        # (layer-normed) last hidden state, so one forward pass suffices
        layers: List[torch.Tensor] = wav_enc.encoder(
            chunk, output_hidden_states=True
        ).hidden_states
        last_hidden = layers[-1]                               # (1, 80, 384)
        # keep the first 5 hidden states, matching AudioProjModel(blocks=5)
        prompt = torch.stack(layers, dim=2)[:, :, :5]          # (1,80,5,384)
        aud_prompts.append(prompt)
        last_prompts.append(last_hidden.unsqueeze(-2))         # (1,80,1,384)
if len(aud_prompts) == 0:
raise ValueError("[ERROR] No speech recognised in the provided audio.")
    # concatenate all chunks, then zero-pad both ends so the sliding
    # windows below never index past the end of the sequence
aud_prompts = torch.cat(aud_prompts, dim=1) # (1, 80*…, 5, 384)
last_prompts = torch.cat(last_prompts, dim=1) # (1, 80*…, 1, 384)
aud_prompts = torch.cat(
[torch.zeros_like(aud_prompts[:, :4]), aud_prompts, torch.zeros_like(aud_prompts[:, :6])],
dim=1,
)
last_prompts = torch.cat(
[torch.zeros_like(last_prompts[:, :24]), last_prompts, torch.zeros_like(last_prompts[:, :26])],
dim=1,
)
    # -------- slice into clips (f=10 audio frames / w=5 layers) ----
ref_list, aud_list, uncond_list, mb_list = [], [], [], []
total_tokens = aud_prompts.shape[1]
n_chunks = max(1, math.ceil(total_tokens / (2 * step)))
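    # each iteration advances by 2*step tokens: the conditioning window
    # covers 10 audio tokens and the motion-bucket window covers 50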
for i in tqdm(range(n_chunks), desc="audio-chunks", ncols=0):
s = i * 2 * step
cond_clip = aud_prompts[:, s : s + 10] # (1,10,5,384)
        if cond_clip.shape[1] < 10:                   # pad the tail
pad = torch.zeros_like(cond_clip[:, : 10 - cond_clip.shape[1]])
cond_clip = torch.cat([cond_clip, pad], dim=1)
bucket_clip = last_prompts[:, s : s + 50] # (1,50,1,384)
if bucket_clip.shape[1] < 50:
pad = torch.zeros_like(bucket_clip[:, : 50 - bucket_clip.shape[1]])
bucket_clip = torch.cat([bucket_clip, pad], dim=1)
        # reshape to 5-D (bz, f, w, b, c)
cond_clip = cond_clip.unsqueeze(3) # (1,10,5,1,384)
bucket_clip = bucket_clip.unsqueeze(3) # (1,50,1,1,384)
uncond_clip = torch.zeros_like(cond_clip)
motion_bucket = audio2bucket(bucket_clip, image_embeds) * 16 + 16
        ref_list.append(ref_img[0])
        aud_list.append(audio_pe(cond_clip).squeeze(0)[0])       # (ctx, 1024)
        uncond_list.append(audio_pe(uncond_clip).squeeze(0)[0])  # (ctx, 1024)
        mb_list.append(motion_bucket[0])
    # -------- run the UNet pipeline --------------------------------
video = (
pipe(
ref_img,
clip_img,
face_mask,
aud_list,
uncond_list,
mb_list,
height=height,
width=width,
num_frames=len(aud_list),
decode_chunk_size=cfg.decode_chunk_size,
motion_bucket_scale=cfg.motion_bucket_scale,
fps=cfg.fps,
noise_aug_strength=cfg.noise_aug_strength,
min_guidance_scale1=cfg.min_appearance_guidance_scale,
max_guidance_scale1=cfg.max_appearance_guidance_scale,
min_guidance_scale2=cfg.audio_guidance_scale,
max_guidance_scale2=cfg.audio_guidance_scale,
overlap=cfg.overlap,
shift_offset=cfg.shift_offset,
frames_per_batch=cfg.n_sample_frames,
num_inference_steps=cfg.num_inference_steps,
i2i_noise_strength=cfg.i2i_noise_strength,
).frames
* 0.5
+ 0.5
).clamp(0, 1)
# (B,C,T,H,W) → (C,T,H,W)
return video.to(pipe.device).squeeze(0).cpu()
# ------------------------------------------------------------------ #
# Sonic – main class #
# ------------------------------------------------------------------ #
class Sonic:
config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
config = OmegaConf.load(config_file)
def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
cfg = self.config
cfg.use_interframe = enable_interpolate_frame
        # parent folder of the locally downloaded diffusers checkpoint
self.diffusers_root = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
self.device = (
f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
)
self._load_models(cfg)
print("Sonic init done")
# -------------------------------------------------------------- #
def _load_models(self, cfg):
# dtype
dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
diff_root = _locate_diffusers_dir(self.diffusers_root)
        # diffusers components
vae = AutoencoderKLTemporalDecoder.from_pretrained(diff_root, subfolder="vae", variant="fp16")
sched = EulerDiscreteScheduler.from_pretrained(diff_root, subfolder="scheduler")
img_e = CLIPVisionModelWithProjection.from_pretrained(diff_root, subfolder="image_encoder", variant="fp16")
unet = UNetSpatioTemporalConditionModel.from_pretrained(diff_root, subfolder="unet", variant="fp16")
add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
        # audio adapters
a2t = AudioProjModel(seq_len=10, blocks=5, channels=384,
intermediate_dim=1024, output_dim=1024, context_tokens=32).to(self.device)
a2b = Audio2bucketModel(seq_len=50, blocks=1, channels=384,
clip_channels=1024, intermediate_dim=1024, output_dim=1,
context_tokens=2).to(self.device)
        # load checkpoints
a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
# Whisper
whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
whisper.requires_grad_(False)
        # image feature extractor / face detection / frame interpolation
self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
if cfg.use_interframe:
self.rife = RIFEModel(device=self.device)
self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
        # cast the heavy modules to the configured dtype
for m in (vae, img_e, unet):
m.to(dtype)
self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(self.device, dtype=dtype)
self.image_encoder = img_e
self.audio2token = a2t
self.audio2bucket = a2b
self.whisper = whisper
# -------------------------------------------------------------- #
def preprocess(self, image_path: str, expand_ratio: float = 1.0) -> Dict[str, Any]:
img = cv2.imread(image_path)
h, w = img.shape[:2]
_, _, bboxes = self.face_det(img, maxface=True)
if bboxes:
x1, y1, ww, hh = bboxes[0]
crop = process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)
return {"face_num": 1, "crop_bbox": crop}
return {"face_num": 0, "crop_bbox": None}
# -------------------------------------------------------------- #
@torch.no_grad()
def process(
self,
image_path: str,
audio_path: str,
output_path: str,
min_resolution: int = 512,
inference_steps: int = 25,
dynamic_scale: float = 1.0,
keep_resolution: bool = False,
seed: int | None = None,
) -> int:
cfg = self.config
if seed is not None:
cfg.seed = seed
cfg.num_inference_steps = inference_steps
cfg.motion_bucket_scale = dynamic_scale
seed_everything(cfg.seed)
        # convert the image and audio into model-ready tensors
data = image_audio_to_tensor(
self.face_det,
self.feature_extractor,
image_path,
audio_path,
limit=-1,
image_size=min_resolution,
area=cfg.area,
)
if data is None:
return -1
h, w = data["ref_img"].shape[-2:]
if keep_resolution:
im = Image.open(image_path)
resolution = f"{(im.width // 2) * 2}x{(im.height // 2) * 2}"
else:
resolution = f"{w}x{h}"
        # generate the video tensor
video = _gen_video_tensor(
self.pipe, cfg, self.whisper, self.audio2token, self.audio2bucket,
self.image_encoder, w, h, data,
)
        # interpolate intermediate frames with RIFE (doubles the frame count)
if cfg.use_interframe:
out = video.to(self.device)
frames = []
for i in tqdm(range(out.shape[1] - 1), desc="interpolate", ncols=0):
frames.extend([out[:, i], self.rife.inference(out[:, i], out[:, i + 1]).clamp(0, 1)])
frames.append(out[:, -1])
video = torch.stack(frames, 1).cpu() # (C,T',H,W)
        # save the silent video, then mux in the audio track with ffmpeg
tmp = output_path.replace(".mp4", "_noaudio.mp4")
save_videos_grid(video.unsqueeze(0), tmp, n_rows=1, fps=cfg.fps * (2 if cfg.use_interframe else 1))
os.system(
f"ffmpeg -loglevel error -y -i '{tmp}' -i '{audio_path}' -s {resolution} "
f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}'"
)
os.remove(tmp)
return 0
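# ------------------------------------------------------------------ #
#  Minimal usage sketch (illustrative only: the paths below are       #
#  hypothetical and must point to real local files)                   #
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)
    # optional sanity check: confirm a face is detectable before running
    info = sonic.preprocess("examples/face.png", expand_ratio=1.0)
    if info["face_num"] > 0:
        sonic.process(
            image_path="examples/face.png",
            audio_path="examples/speech.wav",
            output_path="output/talking_head.mp4",
            inference_steps=25,
            dynamic_scale=1.0,
            seed=42,
        )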