Update sonic.py
sonic.py CHANGED
@@ -1,9 +1,7 @@
-import os, math
-import torch
+import os, math, torch, cv2
 from PIL import Image
 from omegaconf import OmegaConf
 from tqdm import tqdm
-import cv2
 
 from diffusers import AutoencoderKLTemporalDecoder
 from diffusers.schedulers import EulerDiscreteScheduler
@@ -40,15 +38,15 @@ def test(
     face_mask = batch["face_mask"]
     image_embeds = image_encoder(clip_img).image_embeds
 
-    audio_feature = batch["audio_feature"]
-    audio_len = int(batch["audio_len"])
+    audio_feature = batch["audio_feature"]  # (1,80,T)
+    audio_len = int(batch["audio_len"])     # Python int
     step = int(config.step)
 
-    # ---
+    # --- clamp step (minimum 1) --------------------------------------------
     if audio_len < step:
         step = max(1, audio_len)
 
-    window = 16000
+    window = 16000  # 1-second chunk
     audio_prompts, last_prompts = [], []
 
     # --- per-window Whisper encoding ---------------------------------------
@@ -56,26 +54,28 @@ def test(
         chunk = audio_feature[:, :, i : i + window]
 
         prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-        last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
+        last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)  # (1,t,1,384)
 
-        audio_prompts.append(torch.stack(prompt_layers, dim=2))
-        last_prompts.append(last_hidden)
+        audio_prompts.append(torch.stack(prompt_layers, dim=2))  # (1,L,12,80)
+        last_prompts.append(last_hidden)                         # (1,L,1,384)
 
     if len(audio_prompts) == 0:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")
 
     audio_prompts = torch.cat(audio_prompts, dim=1)
-    last_prompts = torch.cat(last_prompts,
+    last_prompts = torch.cat(last_prompts, dim=1)
 
     # padding rule
     audio_prompts = torch.cat(
-        [torch.zeros_like(audio_prompts[:, :4]),
+        [torch.zeros_like(audio_prompts[:, :4]),
+         audio_prompts,
         torch.zeros_like(audio_prompts[:, :6])], dim=1)
     last_prompts = torch.cat(
-        [torch.zeros_like(last_prompts[:, :24]),
+        [torch.zeros_like(last_prompts[:, :24]),
+         last_prompts,
        torch.zeros_like(last_prompts[:, :26])], dim=1)
 
-    # ---
+    # --- always at least one chunk ------------------------------------------
     total_tokens = audio_prompts.shape[1]
     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
 
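For reference, a minimal standalone sketch (not part of this commit) of the padding rule and the chunk-count arithmetic in the hunk above, using a dummy tensor in place of the real Whisper hidden states:

import math
import torch

step = 2
audio_prompts = torch.randn(1, 25, 12, 80)   # pretend: batch of 1, 25 audio tokens

# pad with 4 zero-tokens in front and 6 at the back, as in the hunk above
audio_prompts = torch.cat(
    [torch.zeros_like(audio_prompts[:, :4]),
     audio_prompts,
     torch.zeros_like(audio_prompts[:, :6])], dim=1)

total_tokens = audio_prompts.shape[1]                      # 4 + 25 + 6 = 35
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))  # ceil(35 / 4) = 9
print(total_tokens, num_chunks)                            # 35 9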
@@ -84,19 +84,25 @@ def test(
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
 
-        cond_clip = audio_prompts[:, start : start + 10]
-        if cond_clip.shape[
-        pad = torch.zeros_like(cond_clip[:,
-        cond_clip = torch.cat([cond_clip, pad], dim=
+        cond_clip = audio_prompts[:, start : start + 10]   # (1,10,12,80)
+        if cond_clip.shape[1] < 10:                        # pad if short
+            pad = torch.zeros_like(cond_clip[:, : 10 - cond_clip.shape[1]])
+            cond_clip = torch.cat([cond_clip, pad], dim=1)
 
-
-        bucket_clip =
-
+        # ------------------ (★) match bucket_clip dimensions ------------------
+        bucket_clip = last_prompts[:, start : start + 50]  # (1,50,1,384)
+        if bucket_clip.shape[1] < 50:                      # pad if short
+            pad = torch.zeros_like(bucket_clip[:, : 50 - bucket_clip.shape[1]])
+            bucket_clip = torch.cat([bucket_clip, pad], dim=1)
+
+        bucket_clip = bucket_clip.unsqueeze(1)             # → (1,1,50,1,384) ✔ 5-D
+        # -----------------------------------------------------------------
 
+        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
         ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
-        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
+        audio_list.append(audio_pe(cond_clip.unsqueeze(1)).squeeze(0)[0])  # (10,···) → 4-D after unsqueeze
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip).unsqueeze(1)).squeeze(0)[0])
         motion_buckets.append(motion[0])
 
     # ----------------------------------------------------------------------
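The same slice-then-zero-pad pattern appears twice in the hunk above (length 10 for cond_clip, length 50 for bucket_clip) before bucket_clip is unsqueezed to the 5-D shape audio2bucket expects. Below is a minimal standalone sketch of that pattern; the helper name slice_and_pad and the dummy shapes are illustrative assumptions, not code from this commit:

import torch

def slice_and_pad(x: torch.Tensor, start: int, length: int) -> torch.Tensor:
    """Take x[:, start:start+length] and zero-pad along dim=1 back up to `length`."""
    clip = x[:, start : start + length]
    missing = length - clip.shape[1]
    if missing > 0:
        pad = clip.new_zeros((clip.shape[0], missing, *clip.shape[2:]))
        clip = torch.cat([clip, pad], dim=1)
    return clip

last_prompts = torch.randn(1, 37, 1, 384)                      # dummy Whisper last-hidden states
bucket_clip = slice_and_pad(last_prompts, start=0, length=50)  # (1, 50, 1, 384)
bucket_clip = bucket_clip.unsqueeze(1)                         # (1, 1, 50, 1, 384): 5-D
print(bucket_clip.shape)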
@@ -125,7 +131,7 @@ def test(
 
 
 # ------------------------------------------------------------------
-#
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
@@ -144,18 +150,18 @@ class Sonic:
     def _load_models(self, cfg):
         dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
 
-        vae
+        vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
-
-        unet
+        imgE = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+        unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
         add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
-        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path),
-        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path),
-        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path),
+        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
         whisper.requires_grad_(False)
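The checkpoint lines above load each state dict on the CPU first (map_location="cpu") and only later move the modules to the target device and dtype. A minimal sketch of that pattern as a standalone helper; the function name and signature are illustrative assumptions, not code from this repository:

import torch

def load_checkpoint(module: torch.nn.Module, path: str,
                    device: str = "cuda", dtype: torch.dtype = torch.float16) -> torch.nn.Module:
    state = torch.load(path, map_location="cpu")   # read weights into CPU memory first
    module.load_state_dict(state)
    return module.to(device=device, dtype=dtype)   # then move to the target device/dtype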
@@ -166,11 +172,11 @@ class Sonic:
         self.rife = RIFEModel(device=self.device)
         self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
 
-        for m in (
+        for m in (imgE, vae, unet):
             m.to(dtype)
 
-        self.pipe = SonicPipeline(unet=unet, image_encoder=
-        self.image_encoder =
+        self.pipe = SonicPipeline(unet=unet, image_encoder=imgE, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+        self.image_encoder = imgE
         self.audio2token = a2t
         self.audio2bucket = a2b
         self.whisper = whisper