""" | |
sonic.py – 2025-05 hot-fix | |
주요 수정 | |
• config.pretrained_model_name_or_path 가 실제 폴더인지 확인 | |
• 없다면 huggingface_hub.snapshot_download 로 자동 다운로드 | |
• 경로가 준비된 뒤 모델 로드 진행 | |
""" | |
import os, math, torch, cv2
from PIL import Image
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
from huggingface_hub import snapshot_download, hf_hub_download
from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage
# ------------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
HF_STABLE_REPO = "stabilityai/stable-video-diffusion-img2vid-xt"
LOCAL_STABLE_DIR = os.path.join(BASE_DIR, "checkpoints", "stable-video-diffusion-img2vid-xt")
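
# Expected layout under ./checkpoints, inferred from the paths referenced in this file
# (the Sonic-specific weight filenames come from config/inference/sonic.yaml; this is a
#  sketch, not an authoritative list):
#
#   checkpoints/
#   ├── stable-video-diffusion-img2vid-xt/   # auto-downloaded on first run (unet/, vae/, image_encoder/, scheduler/)
#   ├── whisper-tiny/                        # Whisper encoder + feature extractor
#   ├── RIFE/                                # frame-interpolation weights
#   ├── yoloface_v5m.pt                      # face detector used by AlignImage
#   └── (unet / audio2token / audio2bucket weights named by the config)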
# ------------------------------------------------------------------
# single image + speech → video tensor
# ------------------------------------------------------------------
def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, img_enc,
         width, height, batch):
    # --- move batch tensors to the pipeline device and add a batch dim -------
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).float().to(pipe.device)
    ref_img   = batch['ref_img']
    clip_img  = batch['clip_images']
    face_mask = batch['face_mask']
    img_emb   = img_enc(clip_img).image_embeds          # (1, 1024)
    audio_feat = batch['audio_feature']                  # (1, 80, T) log-mel features
    audio_len  = int(batch['audio_len'])
    step   = max(1, int(cfg.step))                       # safety clamp
    window = 3000                                        # Whisper encoder takes 3000 mel frames (30 s) per forward pass
    prompt_list, last_list = [], []
    for i in range(0, audio_feat.shape[-1], window):
        chunk = audio_feat[:, :, i:i+window]
        out   = wav_enc.encoder(chunk, output_hidden_states=True)
        prompt_list.append(torch.stack(out.hidden_states, 2))        # (1, T', 5, 384) for whisper-tiny
        last_list.append(out.last_hidden_state.unsqueeze(-2))        # (1, T', 1, 384)
    if not prompt_list:
        raise ValueError("❌ No speech recognised in audio.")
    audio_prompts = torch.cat(prompt_list, 1)                        # (1, ΣT', 5, 384)
    last_prompts  = torch.cat(last_list, 1)                          # (1, ΣT', 1, 384)
    # padding rule (same as the original model paper)
    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]),
                               audio_prompts,
                               torch.zeros_like(audio_prompts[:, :6])], 1)
    last_prompts  = torch.cat([torch.zeros_like(last_prompts[:, :24]),
                               last_prompts,
                               torch.zeros_like(last_prompts[:, :26])], 1)
    # --------------------------------------------------------------
    total_tok = audio_prompts.shape[1]
    n_chunks  = max(1, math.ceil(total_tok / (2 * step)))
    ref_L, aud_L, uncond_L, buckets = [], [], [], []
    for i in tqdm(range(n_chunks), ncols=0):
        st = i * 2 * step
        # ① conditioning audio tokens for this window
        #    (reshaped to the (bz, f, w, b, c) layout implied by AudioProjModel(10, 5, 384, …);
        #     short tail windows are zero-padded so the reshape stays valid)
        cond = audio_prompts[:, st:st+10]                       # (1, ≤10, 5, 384)
        if cond.shape[1] < 10:
            pad  = torch.zeros(1, 10 - cond.shape[1], *cond.shape[2:],
                               device=cond.device, dtype=cond.dtype)
            cond = torch.cat([cond, pad], 1)
        cond = cond.reshape(1, 1, 10, 5, 384)                   # (bz, f, w, b, c)
        # ② tokens for motion-bucket estimation
        #    (same layout, mirroring Audio2bucketModel(50, 1, 384, …))
        buck = last_prompts[:, st:st+50]                        # (1, ≤50, 1, 384)
        if buck.shape[1] < 50:
            pad  = torch.zeros(1, 50 - buck.shape[1], *buck.shape[2:],
                               device=buck.device, dtype=buck.dtype)
            buck = torch.cat([buck, pad], 1)
        buck = buck.reshape(1, 1, 50, 1, 384)
        motion = audio2bucket(buck, img_emb) * 16 + 16
        ref_L.append(ref_img[0])
        aud_L.append(audio_pe(cond).squeeze(0))                 # audio context tokens for this chunk
        uncond_L.append(audio_pe(torch.zeros_like(cond)).squeeze(0))
        buckets.append(motion[0])
    # -------------- diffusion -------------------------------------------------
    vid = pipe(
        ref_img, clip_img, face_mask,
        aud_L, uncond_L, buckets,
        height=height, width=width,
        num_frames=len(aud_L),
        decode_chunk_size=cfg.decode_chunk_size,
        motion_bucket_scale=cfg.motion_bucket_scale,
        fps=cfg.fps,
        noise_aug_strength=cfg.noise_aug_strength,
        min_guidance_scale1=cfg.min_appearance_guidance_scale,
        max_guidance_scale1=cfg.max_appearance_guidance_scale,
        min_guidance_scale2=cfg.audio_guidance_scale,
        max_guidance_scale2=cfg.audio_guidance_scale,
        overlap=cfg.overlap,
        shift_offset=cfg.shift_offset,
        frames_per_batch=cfg.n_sample_frames,
        num_inference_steps=cfg.num_inference_steps,
        i2i_noise_strength=cfg.i2i_noise_strength,
    ).frames
    return (vid * 0.5 + 0.5).clamp(0, 1).unsqueeze(0).cpu()


# ------------------------------------------------------------------
# Sonic wrapper
# ------------------------------------------------------------------
class Sonic:
    config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
    config = OmegaConf.load(config_file)
    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        cfg = self.config
        cfg.use_interframe = enable_interpolate_frame
        self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id >= 0 else "cpu"
        # ----------- ✨ [NEW] make sure the pretrained model folder exists ----------
        if not os.path.isdir(LOCAL_STABLE_DIR) or not os.path.isfile(os.path.join(LOCAL_STABLE_DIR, "vae", "config.json")):
            print("[INFO] First run: downloading the Stable Video Diffusion base model (this can take a while)…")
            snapshot_download(repo_id=HF_STABLE_REPO,
                              local_dir=LOCAL_STABLE_DIR,
                              resume_download=True,
                              local_dir_use_symlinks=False)
        cfg.pretrained_model_name_or_path = LOCAL_STABLE_DIR
        # ------------------------------------------------------------------
        self._load_models(cfg)
        print("Sonic init done")

    # model-loader (unchanged, but with tiny clean-ups) ------------------------
    def _load_models(self, cfg):
        dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]

        vae     = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
        sched   = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
        img_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
        unet    = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

        self.audio2token  = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
        self.audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
        self.audio2token.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
        self.audio2bucket.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))

        self.whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
        self.whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))

        self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
        if cfg.use_interframe:
            self.rife = RIFEModel(device=self.device)
            self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

        for m in (img_enc, vae, unet):
            m.to(dtype)
        self.pipe = SonicPipeline(unet=unet, image_encoder=img_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
        self.image_encoder = img_enc

    # ------------------------------------------------------------------
    def preprocess(self, img_path, expand_ratio=1.0):
        img = cv2.imread(img_path)
        _, _, boxes = self.face_det(img, maxface=True)
        if boxes:
            x, y, w, h = boxes[0]
            crop_bbox = process_bbox((x, y, x + w, y + h), expand_ratio, *img.shape[:2])
            return {"face_num": 1, "crop_bbox": crop_bbox}
        return {"face_num": 0, "crop_bbox": None}
# ------------------------------------------------------------------ | |
def process(self, img_path, wav_path, out_path, | |
min_resolution=512, inference_steps=25, | |
dynamic_scale=1.0, keep_resolution=False, seed=None): | |
cfg = self.config | |
if seed is not None: cfg.seed = seed | |
cfg.num_inference_steps = inference_steps | |
cfg.motion_bucket_scale = dynamic_scale | |
seed_everything(cfg.seed) | |
sample = image_audio_to_tensor( | |
self.face_det, self.feature_extractor, | |
img_path, wav_path, limit=-1, | |
image_size=min_resolution, area=cfg.area, | |
) | |
if sample is None: return -1 | |
h,w = sample['ref_img'].shape[-2:] | |
resolution = (f"{Image.open(img_path).width//2*2}x{Image.open(img_path).height//2*2}" | |
if keep_resolution else f"{w}x{h}") | |
video = test(self.pipe, cfg, self.whisper, self.audio2token, | |
self.audio2bucket, self.image_encoder, w, h, sample) | |
if cfg.use_interframe: # RIFE interpolation | |
out = video.to(self.device); frames=[] | |
for i in tqdm(range(out.shape[2]-1), ncols=0): | |
mid = self.rife.inference(out[:,:,i], out[:,:,i+1]).clamp(0,1) | |
frames += [out[:,:,i], mid] | |
frames.append(out[:,:,-1]); video = torch.stack(frames,2).cpu() | |
tmp = out_path.replace(".mp4","_noaudio.mp4") | |
save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps*(2 if cfg.use_interframe else 1)) | |
os.system(f"ffmpeg -i '{tmp}' -i '{wav_path}' -s {resolution} " | |
f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error") | |
os.remove(tmp); return 0 | |
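

# Minimal usage sketch (the example paths below are hypothetical; run from the repository
# root so the relative config and checkpoint paths above resolve):
if __name__ == "__main__":
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)
    face_info = sonic.preprocess("examples/image/face.png", expand_ratio=1.0)
    if face_info["face_num"] > 0:
        sonic.process("examples/image/face.png", "examples/wav/speech.wav",
                      "output/result.mp4", inference_steps=25, dynamic_scale=1.0)
    else:
        print("No face detected; aborting.")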