# Portrait-Animation / sonic.py
# (Header scraped from the hosting web page: author openfree, commit 1d7967c
#  "Update sonic.py", 11.2 kB.)
# sonic.py ── full file
import os, math, glob, torch, cv2
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel, add_ip_adapters,
)
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage
try:
    from safetensors.torch import load_file as safe_load
except ImportError:
    safe_load = None  # safetensors is optional; callers must check for None
# Directory containing this file; config/ and checkpoints/ paths are resolved from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ------------------------------------------------------------ utils
def _find_ckpt(root: str, keyword: str):
    """Return one checkpoint file under *root* whose name contains *keyword*.

    Searches *root* recursively for ``*keyword*.pth``, then ``*keyword*.pt``,
    then ``*keyword*.safetensors``; the first extension with any match wins.

    Matches are sorted before picking, because ``glob`` yields files in
    arbitrary directory order and the choice would otherwise be
    nondeterministic. *root* and *keyword* are glob-escaped, so characters
    such as ``[`` in a path are matched literally.

    Returns:
        The path of the chosen file, or ``None`` when nothing matches
        (including when *root* does not exist).
    """
    base = glob.escape(root)
    needle = glob.escape(keyword)
    for ext in (".pth", ".pt", ".safetensors"):
        pattern = os.path.join(base, "**", f"*{needle}*{ext}")
        matches = sorted(glob.glob(pattern, recursive=True))
        if matches:
            return matches[0]
    return None
# --------------------------------------------------- speech → video
def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
         width, height, batch):
    """Run Sonic inference for one preprocessed (reference image, audio) sample.

    Args:
        pipe: Sonic pipeline; every tensor in ``batch`` is moved to
            ``pipe.device``.
        cfg: inference config. Reads ``step`` plus the sampling options
            forwarded to ``pipe`` (guidance scales, ``fps``, ``overlap``, ...).
        wav_enc: Whisper model; only ``wav_enc.encoder`` is called.
        audio_pe: projection from audio features to conditioning tokens.
        audio2bucket: model mapping audio features + image embeddings to a
            motion-bucket value.
        image_encoder: CLIP vision encoder; its ``image_embeds`` output is used.
        width, height: output frame size forwarded to ``pipe``.
        batch: sample dict with keys ``ref_img``, ``clip_images``,
            ``face_mask``, ``audio_feature`` and ``audio_len``. Modified IN
            PLACE: every tensor gains a leading batch dim, is moved to
            ``pipe.device`` and is cast to float32.

    Returns:
        The pipeline's ``.frames`` rescaled with ``x * 0.5 + 0.5``, clamped to
        [0, 1] and moved to CPU. No extra batch dimension is added; callers
        index frames on dim 2.

    Raises:
        ValueError: if ``audio_feature`` is empty, so no encoder features exist.
    """
    def _pad_time(x, length):
        # Right-pad with zeros along the time axis (dim 1) up to `length` steps.
        missing = length - x.shape[1]
        if missing <= 0:
            return x
        pad = x.new_zeros((x.shape[0], missing, *x.shape[2:]))
        return torch.cat([x, pad], dim=1)

    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch["ref_img"]
    clip_img = batch["clip_images"]
    face_mask = batch["face_mask"]
    image_embeds = image_encoder(clip_img).image_embeds

    audio_feature = batch["audio_feature"]
    audio_len = int(batch["audio_len"])
    step = max(1, int(cfg.step))

    # Whisper's encoder expects fixed 3000-mel-frame (30 s) inputs, so encode in
    # 3000-frame windows. NOTE(review): assumes the feature extractor padded
    # audio_feature to a multiple of 3000 frames — confirm.
    window = 3000
    audio_prompts, last_prompts = [], []
    for i in range(0, audio_feature.shape[-1], window):
        # One forward pass gives both the per-layer states and the final output.
        enc = wav_enc.encoder(audio_feature[:, :, i:i + window],
                              output_hidden_states=True)
        audio_prompts.append(torch.stack(enc.hidden_states, dim=2))
        last_prompts.append(enc.last_hidden_state.unsqueeze(-2))
    if not audio_prompts:
        raise ValueError("[ERROR] No speech recognised in the provided audio.")
    audio_prompts = torch.cat(audio_prompts, dim=1)  # time axis is dim 1
    last_prompts = torch.cat(last_prompts, dim=1)

    # Drop encoder steps that belong to Whisper's silence padding.
    # NOTE(review): assumes audio_len counts frames of 2 encoder steps each
    # (upstream Sonic truncates to audio_len * 2). A non-positive or too-large
    # audio_len falls back to using every encoder step.
    feat_frames = audio_prompts.shape[1] // 2
    if audio_len <= 0 or audio_len > feat_frames:
        audio_len = feat_frames
    audio_prompts = audio_prompts[:, :audio_len * 2]
    last_prompts = last_prompts[:, :audio_len * 2]

    # Zero context on both sides so each frame sees a centred 10-step (token)
    # and 50-step (bucket) window.
    audio_prompts = torch.cat(
        [torch.zeros_like(audio_prompts[:, :4]),
         audio_prompts,
         torch.zeros_like(audio_prompts[:, :6])], dim=1)
    last_prompts = torch.cat(
        [torch.zeros_like(last_prompts[:, :24]),
         last_prompts,
         torch.zeros_like(last_prompts[:, :26])], dim=1)

    num_frames = max(1, audio_len // step)
    audio_list, uncond_list, buckets = [], [], []
    for i in tqdm(range(num_frames)):
        st = i * 2 * step
        # Add a frames axis: the audio models are fed
        # (batch, frames, window, layers, channels).
        # NOTE(review): layout follows upstream Sonic's AudioProjModel — confirm.
        cond = _pad_time(audio_prompts[:, st:st + 10], 10).unsqueeze(0)
        bucket_clip = _pad_time(last_prompts[:, st:st + 50], 50).unsqueeze(0)

        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
        buckets.append(motion[0])
        # Strip the batch and frames axes to get one per-frame token tensor.
        audio_list.append(audio_pe(cond).squeeze(0)[0])
        uncond_list.append(audio_pe(torch.zeros_like(cond)).squeeze(0)[0])

    video = pipe(
        ref_img, clip_img, face_mask,
        audio_list, uncond_list, buckets,
        height=height, width=width,
        num_frames=len(audio_list),
        decode_chunk_size=cfg.decode_chunk_size,
        motion_bucket_scale=cfg.motion_bucket_scale,
        fps=cfg.fps,
        noise_aug_strength=cfg.noise_aug_strength,
        min_guidance_scale1=cfg.min_appearance_guidance_scale,
        max_guidance_scale1=cfg.max_appearance_guidance_scale,
        min_guidance_scale2=cfg.audio_guidance_scale,
        max_guidance_scale2=cfg.audio_guidance_scale,
        overlap=cfg.overlap,
        shift_offset=cfg.shift_offset,
        frames_per_batch=cfg.n_sample_frames,
        num_inference_steps=cfg.num_inference_steps,
        i2i_noise_strength=cfg.i2i_noise_strength,
    ).frames
    return (video * 0.5 + 0.5).clamp(0, 1).cpu()
# ------------------------------------------------------------ Sonic
class Sonic:
    """Audio-driven portrait animation: one reference image + speech -> mp4.

    All networks are loaded once in ``__init__``; ``process`` renders a video.

    NOTE: ``config`` is loaded from YAML when the class is defined and is one
    object shared by every instance. ``__init__`` and ``process`` write into it
    (``use_interframe``, ``seed``, ``num_inference_steps``,
    ``motion_bucket_scale``), so those values persist between calls.
    """

    config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
    config = OmegaConf.load(config_file)

    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        """Load every network.

        Args:
            device_id: CUDA device index. Falls back to CPU when CUDA is
                unavailable or ``device_id`` is negative.
            enable_interpolate_frame: enable RIFE frame interpolation. This
                inserts a frame between each pair and doubles the output fps.
        """
        cfg = self.config
        cfg.use_interframe = enable_interpolate_frame
        use_cuda = torch.cuda.is_available() and device_id >= 0
        self.device = f"cuda:{device_id}" if use_cuda else "cpu"
        # Diffusers base model (model_index.json / per-component config.json).
        self.diffusers_root = os.path.join(BASE_DIR, "checkpoints", "stable-video-diffusion-img2vid-xt")
        # Extra Sonic weights (.pth / .pt / .safetensors), looked up by keyword.
        self.ckpt_root = os.path.join(BASE_DIR, "checkpoints", "Sonic")
        self._load_models(cfg)
        print("Sonic init done")

    @staticmethod
    def _locate_diffusers_dir(root: str) -> str:
        """Return the directory under *root* that holds the diffusers model.

        Handles both a plain layout (``root/model_index.json``) and nested
        layouts such as ``snapshots/<sha>/``. The first directory (top-down,
        sorted walk) containing ``model_index.json`` is preferred. Failing
        that, the first one containing ``config.json`` is returned.

        Raises:
            FileNotFoundError: if neither file exists anywhere under *root*.
        """
        fallback = None
        for cur, dirs, files in os.walk(root):
            dirs.sort()  # deterministic traversal order
            if "model_index.json" in files:
                return cur
            if fallback is None and "config.json" in files:
                fallback = cur
        if fallback is not None:
            return fallback
        raise FileNotFoundError(
            f"[ERROR] diffusers model files(model_index.json/config.json) "
            f"not found under {root}"
        )

    # --------------------------------------------- load all networks
    def _load_models(self, cfg):
        """Build every network, load Sonic's extra weights and assemble the pipeline.

        Sets ``self.pipe``, ``self.image_encoder``, ``self.audio2token``,
        ``self.audio2bucket``, ``self.whisper``, ``self.feature_extractor``,
        ``self.face_det`` and, when ``cfg.use_interframe``, ``self.rife``.
        """
        dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
        # Resolve the actual model directory (e.g. snapshots/<sha>/) and load from it.
        diff_root = self._locate_diffusers_dir(self.diffusers_root)
        vae = AutoencoderKLTemporalDecoder.from_pretrained(diff_root, subfolder="vae", variant="fp16")
        sched = EulerDiscreteScheduler.from_pretrained(diff_root, subfolder="scheduler")
        img_e = CLIPVisionModelWithProjection.from_pretrained(diff_root, subfolder="image_encoder", variant="fp16")
        unet = UNetSpatioTemporalConditionModel.from_pretrained(diff_root, subfolder="unet", variant="fp16")
        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

        a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device).eval()
        a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device).eval()
        self._load_extra(unet, "unet")
        self._load_extra(a2t, "audio2token")
        self._load_extra(a2b, "audio2bucket")

        whisper = WhisperModel.from_pretrained(
            os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
        ).to(self.device).eval()
        whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
        )
        self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
        if cfg.use_interframe:
            self.rife = RIFEModel(device=self.device)
            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

        for m in (img_e, vae, unet):
            m.to(dtype)
        self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
        self.image_encoder = img_e
        self.audio2token = a2t
        self.audio2bucket = a2b
        self.whisper = whisper

    def _load_extra(self, module, key):
        """Load the checkpoint under ``self.ckpt_root`` matching *key* into *module*.

        Loading is non-strict. Missing/unexpected key counts are reported
        instead of being silently ignored. Prints a warning and skips when
        no checkpoint matches.

        Raises:
            ImportError: for a ``.safetensors`` file when safetensors is not installed.
        """
        path = _find_ckpt(self.ckpt_root, key)
        if not path:
            print(f"[WARN] extra ckpt for '{key}' not found → skip")
            return
        print(f"[INFO] load {key} from {os.path.relpath(path, BASE_DIR)}")
        if path.endswith(".safetensors"):
            if safe_load is None:
                raise ImportError(
                    f"[ERROR] '{path}' is a .safetensors file but safetensors is not installed")
            state = safe_load(path, device="cpu")
        else:
            state = torch.load(path, map_location="cpu")
        result = module.load_state_dict(state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"[WARN] {key}: {len(result.missing_keys)} missing / "
                  f"{len(result.unexpected_keys)} unexpected keys")

    # --------------------------------------------- preprocess helpers
    def preprocess(self, image_path: str, expand_ratio: float = 1.0):
        """Run the face detector (``maxface=True``) on *image_path* and build a crop box.

        The first detected box is read as ``(x, y, w, h)``. It is expanded with
        ``process_bbox`` by *expand_ratio* within the image bounds.

        Returns:
            ``{"face_num": 1, "crop_bbox": <box>}`` when a face is found,
            otherwise ``{"face_num": 0, "crop_bbox": None}``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"[ERROR] cannot read image: {image_path}")
        h, w = img.shape[:2]
        _, _, bboxes = self.face_det(img, maxface=True)
        # len() instead of truthiness: works for lists and array-like detector output.
        if bboxes is None or len(bboxes) == 0:
            return {"face_num": 0, "crop_bbox": None}
        x1, y1, bw, bh = bboxes[0]
        crop = process_bbox((x1, y1, x1 + bw, y1 + bh), expand_ratio, h, w)
        return {"face_num": 1, "crop_bbox": crop}

    # --------------------------------------------------------------- run
    @torch.no_grad()
    def process(self, image_path, audio_path, output_path,
                min_resolution=512, inference_steps=25,
                dynamic_scale=1.0, keep_resolution=False, seed=None):
        """Render *image_path* animated by *audio_path* into *output_path*.

        Args:
            image_path: reference portrait image.
            audio_path: driving audio; it is also muxed into the output.
            output_path: final video path.
            min_resolution: ``image_size`` passed to ``image_audio_to_tensor``.
            inference_steps: number of denoising steps.
            dynamic_scale: stored as ``motion_bucket_scale`` in the config.
            keep_resolution: scale the output to the source image size
                (rounded down to even) instead of the generated size.
            seed: new seed to store in the shared config; ``None`` keeps the
                current ``cfg.seed``.

        Returns:
            0 on success, -1 when ``image_audio_to_tensor`` returns ``None``.

        Raises:
            RuntimeError: if ffmpeg fails while adding the audio track.
        """
        cfg = self.config
        if seed is not None:
            cfg.seed = seed
        cfg.num_inference_steps = inference_steps
        cfg.motion_bucket_scale = dynamic_scale
        seed_everything(cfg.seed)

        data = image_audio_to_tensor(
            self.face_det, self.feature_extractor,
            image_path, audio_path,
            limit=-1, image_size=min_resolution, area=cfg.area,
        )
        if data is None:
            return -1

        h, w = data["ref_img"].shape[-2:]
        if keep_resolution:
            with Image.open(image_path) as im:
                resolution = f"{im.width // 2 * 2}x{im.height // 2 * 2}"
        else:
            resolution = f"{w}x{h}"

        video = test(self.pipe, cfg, self.whisper, self.audio2token,
                     self.audio2bucket, self.image_encoder,
                     w, h, data)

        # The config is shared: another instance may have enabled interpolation
        # although this one never loaded RIFE.
        interpolate = cfg.use_interframe and hasattr(self, "rife")
        if interpolate:
            video = self._interpolate_frames(video)

        # Derive the temp name from the extension. A plain str.replace leaves the
        # name unchanged for non-.mp4 outputs, and the output would then be deleted.
        stem, ext = os.path.splitext(output_path)
        tmp = f"{stem}_noaudio{ext or '.mp4'}"
        try:
            save_videos_grid(video, tmp, n_rows=video.shape[0],
                             fps=cfg.fps * (2 if interpolate else 1))
            self._mux_audio(tmp, audio_path, output_path, resolution)
        finally:
            if os.path.exists(tmp):
                os.remove(tmp)
        return 0

    def _interpolate_frames(self, video):
        """Insert one RIFE frame between each consecutive pair along dim 2.

        ``N`` frames become ``2N - 1``. The result is returned on CPU.
        """
        out = video.to(self.device)
        frames = []
        for i in tqdm(range(out.shape[2] - 1), ncols=0):
            mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
            frames.extend([out[:, :, i], mid])
        frames.append(out[:, :, -1])
        return torch.stack(frames, 2).cpu()

    @staticmethod
    def _mux_audio(video_path, audio_path, output_path, resolution):
        """Combine *video_path* and *audio_path* into *output_path* with ffmpeg.

        The video is scaled to *resolution* and re-encoded as H.264 (CRF 18).
        The audio is encoded as AAC. Output stops at the shorter stream.
        Arguments are passed as a list (no shell), so quotes and spaces in
        paths are safe.

        Raises:
            RuntimeError: if ffmpeg exits with a non-zero status.
        """
        import subprocess

        cmd = ["ffmpeg", "-loglevel", "error", "-y",
               "-i", video_path, "-i", audio_path,
               "-s", resolution,
               "-vcodec", "libx264", "-acodec", "aac", "-crf", "18",
               "-shortest", output_path]
        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            raise RuntimeError(
                f"[ERROR] ffmpeg failed ({proc.returncode}): {proc.stderr.strip()}")