Spaces:
Running
on
Zero
Running
on
Zero
Update sonic.py
Browse files
sonic.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# sonic.py
|
2 |
|
3 |
import os, math, glob, torch, cv2
|
4 |
from PIL import Image
|
@@ -22,32 +22,29 @@ from src.dataset.face_align.align import AlignImage
|
|
22 |
|
23 |
try:
|
24 |
from safetensors.torch import load_file as safe_load
|
25 |
-
except ImportError:
|
26 |
-
safe_load = None
|
27 |
|
28 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
29 |
|
30 |
|
31 |
-
#
|
32 |
-
# 공용 : 체크포인트(가중치) 탐색 함수
|
33 |
-
# -------------------------------------------------------------------
|
34 |
def _find_ckpt(root: str, keyword: str):
|
35 |
-
"""root
|
36 |
-
patterns = [f"**/*{keyword}*.pth",
|
|
|
37 |
f"**/*{keyword}*.safetensors"]
|
38 |
-
files = []
|
39 |
for p in patterns:
|
40 |
-
files
|
41 |
-
|
|
|
|
|
42 |
|
43 |
|
44 |
-
#
|
45 |
-
# single image + speech → video tensor
|
46 |
-
# -------------------------------------------------------------------
|
47 |
def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
|
48 |
width, height, batch):
|
49 |
|
50 |
-
# 배치 차원 맞추기
|
51 |
for k, v in batch.items():
|
52 |
if isinstance(v, torch.Tensor):
|
53 |
batch[k] = v.unsqueeze(0).to(pipe.device).float()
|
@@ -59,17 +56,17 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
|
|
59 |
|
60 |
audio_feature = batch["audio_feature"] # (1,80,T)
|
61 |
audio_len = int(batch["audio_len"])
|
62 |
-
step = max(1, int(cfg.step)
|
63 |
|
64 |
-
window = 16_000
|
65 |
audio_prompts, last_prompts = [], []
|
66 |
|
67 |
for i in range(0, audio_feature.shape[-1], window):
|
68 |
chunk = audio_feature[:, :, i:i+window]
|
69 |
-
|
70 |
-
|
71 |
-
audio_prompts.append(torch.stack(
|
72 |
-
last_prompts.append(
|
73 |
|
74 |
if not audio_prompts:
|
75 |
raise ValueError("[ERROR] No speech recognised in the provided audio.")
|
@@ -77,33 +74,29 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
|
|
77 |
audio_prompts = torch.cat(audio_prompts, dim=1)
|
78 |
last_prompts = torch.cat(last_prompts , dim=1)
|
79 |
|
80 |
-
# padding 규칙
|
81 |
audio_prompts = torch.cat(
|
82 |
[torch.zeros_like(audio_prompts[:, :4]),
|
83 |
audio_prompts,
|
84 |
-
torch.zeros_like(audio_prompts[:, :6])],
|
85 |
-
|
86 |
last_prompts = torch.cat(
|
87 |
[torch.zeros_like(last_prompts[:, :24]),
|
88 |
last_prompts,
|
89 |
-
torch.zeros_like(last_prompts[:, :26])],
|
90 |
|
91 |
-
|
92 |
-
num_chunks = max(1, math.ceil(total_tokens / (2*step)))
|
93 |
|
94 |
ref_list, audio_list, uncond_list, buckets = [], [], [], []
|
95 |
-
|
96 |
for i in tqdm(range(num_chunks)):
|
97 |
st = i * 2 * step
|
98 |
cond = audio_prompts[:, st: st+10]
|
99 |
if cond.shape[2] < 10:
|
100 |
pad = torch.zeros_like(cond[:, :, :10-cond.shape[2]])
|
101 |
-
cond = torch.cat([cond, pad],
|
102 |
|
103 |
bucket_clip = last_prompts[:, st: st+50]
|
104 |
if bucket_clip.shape[2] < 50:
|
105 |
pad = torch.zeros_like(bucket_clip[:, :, :50-bucket_clip.shape[2]])
|
106 |
-
bucket_clip = torch.cat([bucket_clip, pad],
|
107 |
|
108 |
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
|
109 |
|
@@ -132,12 +125,10 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
|
|
132 |
i2i_noise_strength=cfg.i2i_noise_strength,
|
133 |
).frames
|
134 |
|
135 |
-
return (video *
|
136 |
|
137 |
|
138 |
-
#
|
139 |
-
# Sonic ✨
|
140 |
-
# -------------------------------------------------------------------
|
141 |
class Sonic:
|
142 |
config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
|
143 |
config = OmegaConf.load(config_file)
|
@@ -147,45 +138,40 @@ class Sonic:
|
|
147 |
cfg.use_interframe = enable_interpolate_frame
|
148 |
self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id >= 0 else "cpu"
|
149 |
|
150 |
-
#
|
151 |
-
|
152 |
-
|
|
|
153 |
|
154 |
-
self._load_models(cfg
|
155 |
print("Sonic init done")
|
156 |
|
157 |
-
#
|
158 |
-
def _load_models(self, cfg
|
159 |
dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
|
166 |
add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
path = _find_ckpt(ckpt_root, keyword)
|
171 |
if not path:
|
172 |
-
print(f"[WARN] {
|
173 |
return
|
174 |
-
print(f"[INFO] load {
|
175 |
-
if path.endswith(".safetensors")
|
176 |
-
state = safe_load(path, device="cpu")
|
177 |
-
else:
|
178 |
-
state = torch.load(path, map_location="cpu")
|
179 |
module.load_state_dict(state, strict=False)
|
180 |
|
181 |
-
_try_load(unet, "unet")
|
182 |
-
# audio adapters (필수)
|
183 |
a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
|
184 |
a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
|
185 |
-
_try_load(a2t, "audio2token")
|
186 |
-
_try_load(a2b, "audio2bucket")
|
187 |
|
188 |
-
|
|
|
|
|
|
|
189 |
whisper = WhisperModel.from_pretrained(
|
190 |
os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
|
191 |
).to(self.device).eval()
|
@@ -199,16 +185,16 @@ class Sonic:
|
|
199 |
self.rife = RIFEModel(device=self.device)
|
200 |
self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
|
201 |
|
202 |
-
for m in (
|
203 |
m.to(dtype)
|
204 |
|
205 |
-
self.pipe = SonicPipeline(unet=unet, image_encoder=
|
206 |
-
self.image_encoder =
|
207 |
self.audio2token = a2t
|
208 |
self.audio2bucket = a2b
|
209 |
self.whisper = whisper
|
210 |
|
211 |
-
#
|
212 |
def preprocess(self, image_path: str, expand_ratio: float = 1.0):
|
213 |
img = cv2.imread(image_path)
|
214 |
h, w = img.shape[:2]
|
@@ -219,20 +205,19 @@ class Sonic:
|
|
219 |
"crop_bbox": process_bbox((x1, y1, x1+ww, y1+hh), expand_ratio, h, w)}
|
220 |
return {"face_num": 0, "crop_bbox": None}
|
221 |
|
222 |
-
#
|
223 |
@torch.no_grad()
|
224 |
def process(self, image_path, audio_path, output_path,
|
225 |
-
min_resolution=512, inference_steps=25,
|
226 |
-
keep_resolution=False, seed=None):
|
227 |
|
228 |
cfg = self.config
|
229 |
if seed is not None:
|
230 |
cfg.seed = seed
|
231 |
-
cfg.num_inference_steps
|
232 |
-
cfg.motion_bucket_scale
|
233 |
seed_everything(cfg.seed)
|
234 |
|
235 |
-
# 이미지·오디오 → tensor
|
236 |
data = image_audio_to_tensor(
|
237 |
self.face_det, self.feature_extractor,
|
238 |
image_path, audio_path,
|
@@ -252,7 +237,6 @@ class Sonic:
|
|
252 |
self.audio2bucket, self.image_encoder,
|
253 |
w, h, data)
|
254 |
|
255 |
-
# 인터프레임 보간
|
256 |
if cfg.use_interframe:
|
257 |
out, frames = video.to(self.device), []
|
258 |
for i in tqdm(range(out.shape[2]-1), ncols=0):
|
|
|
1 |
+
# sonic.py ── 전체
|
2 |
|
3 |
import os, math, glob, torch, cv2
|
4 |
from PIL import Image
|
|
|
22 |
|
23 |
try:
|
24 |
from safetensors.torch import load_file as safe_load
|
25 |
+
except ImportError:
|
26 |
+
safe_load = None # safetensors 미설치 시 대비
|
27 |
|
28 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
29 |
|
30 |
|
31 |
+
# ------------------------------------------------------------ utils
|
|
|
|
|
32 |
def _find_ckpt(root: str, keyword: str):
|
33 |
+
"""root 밑에서 keyword 가 포함된 .pth / .pt / .safetensors 하나 찾기"""
|
34 |
+
patterns = [f"**/*{keyword}*.pth",
|
35 |
+
f"**/*{keyword}*.pt",
|
36 |
f"**/*{keyword}*.safetensors"]
|
|
|
37 |
for p in patterns:
|
38 |
+
files = glob.glob(os.path.join(root, p), recursive=True)
|
39 |
+
if files:
|
40 |
+
return files[0]
|
41 |
+
return None
|
42 |
|
43 |
|
44 |
+
# --------------------------------------------------- speech → video
|
|
|
|
|
45 |
def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
|
46 |
width, height, batch):
|
47 |
|
|
|
48 |
for k, v in batch.items():
|
49 |
if isinstance(v, torch.Tensor):
|
50 |
batch[k] = v.unsqueeze(0).to(pipe.device).float()
|
|
|
56 |
|
57 |
audio_feature = batch["audio_feature"] # (1,80,T)
|
58 |
audio_len = int(batch["audio_len"])
|
59 |
+
step = max(1, int(cfg.step))
|
60 |
|
61 |
+
window = 16_000
|
62 |
audio_prompts, last_prompts = [], []
|
63 |
|
64 |
for i in range(0, audio_feature.shape[-1], window):
|
65 |
chunk = audio_feature[:, :, i:i+window]
|
66 |
+
hidden = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
|
67 |
+
last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
|
68 |
+
audio_prompts.append(torch.stack(hidden, dim=2))
|
69 |
+
last_prompts.append(last)
|
70 |
|
71 |
if not audio_prompts:
|
72 |
raise ValueError("[ERROR] No speech recognised in the provided audio.")
|
|
|
74 |
audio_prompts = torch.cat(audio_prompts, dim=1)
|
75 |
last_prompts = torch.cat(last_prompts , dim=1)
|
76 |
|
|
|
77 |
audio_prompts = torch.cat(
|
78 |
[torch.zeros_like(audio_prompts[:, :4]),
|
79 |
audio_prompts,
|
80 |
+
torch.zeros_like(audio_prompts[:, :6])], 1)
|
|
|
81 |
last_prompts = torch.cat(
|
82 |
[torch.zeros_like(last_prompts[:, :24]),
|
83 |
last_prompts,
|
84 |
+
torch.zeros_like(last_prompts[:, :26])], 1)
|
85 |
|
86 |
+
num_chunks = max(1, math.ceil(audio_prompts.shape[1] / (2*step)))
|
|
|
87 |
|
88 |
ref_list, audio_list, uncond_list, buckets = [], [], [], []
|
|
|
89 |
for i in tqdm(range(num_chunks)):
|
90 |
st = i * 2 * step
|
91 |
cond = audio_prompts[:, st: st+10]
|
92 |
if cond.shape[2] < 10:
|
93 |
pad = torch.zeros_like(cond[:, :, :10-cond.shape[2]])
|
94 |
+
cond = torch.cat([cond, pad], 2)
|
95 |
|
96 |
bucket_clip = last_prompts[:, st: st+50]
|
97 |
if bucket_clip.shape[2] < 50:
|
98 |
pad = torch.zeros_like(bucket_clip[:, :, :50-bucket_clip.shape[2]])
|
99 |
+
bucket_clip = torch.cat([bucket_clip, pad], 2)
|
100 |
|
101 |
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
|
102 |
|
|
|
125 |
i2i_noise_strength=cfg.i2i_noise_strength,
|
126 |
).frames
|
127 |
|
128 |
+
return (video * .5 + .5).clamp(0,1).unsqueeze(0).cpu()
|
129 |
|
130 |
|
131 |
+
# ------------------------------------------------------------ Sonic
|
|
|
|
|
132 |
class Sonic:
|
133 |
config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
|
134 |
config = OmegaConf.load(config_file)
|
|
|
138 |
cfg.use_interframe = enable_interpolate_frame
|
139 |
self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id >= 0 else "cpu"
|
140 |
|
141 |
+
# diffusers 베이스 모델은 ⇣ (config.json 포함)
|
142 |
+
self.diffusers_root = os.path.join(BASE_DIR, "checkpoints", "stable-video-diffusion-img2vid-xt")
|
143 |
+
# 추가 pth/pt/safetensors 는 ⇣
|
144 |
+
self.ckpt_root = os.path.join(BASE_DIR, "checkpoints", "Sonic")
|
145 |
|
146 |
+
self._load_models(cfg)
|
147 |
print("Sonic init done")
|
148 |
|
149 |
+
# --------------------------------------------- load all networks
|
150 |
+
def _load_models(self, cfg):
|
151 |
dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
|
152 |
|
153 |
+
vae = AutoencoderKLTemporalDecoder.from_pretrained(self.diffusers_root, subfolder="vae", variant="fp16")
|
154 |
+
sched = EulerDiscreteScheduler.from_pretrained(self.diffusers_root, subfolder="scheduler")
|
155 |
+
img_e = CLIPVisionModelWithProjection.from_pretrained(self.diffusers_root, subfolder="image_encoder", variant="fp16")
|
156 |
+
unet = UNetSpatioTemporalConditionModel.from_pretrained(self.diffusers_root, subfolder="unet", variant="fp16")
|
|
|
157 |
add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
|
158 |
|
159 |
+
def _load_extra(module, key):
|
160 |
+
path = _find_ckpt(self.ckpt_root, key)
|
|
|
161 |
if not path:
|
162 |
+
print(f"[WARN] extra ckpt for '{key}' not found → skip")
|
163 |
return
|
164 |
+
print(f"[INFO] load {key} → {os.path.relpath(path, BASE_DIR)}")
|
165 |
+
state = safe_load(path, device="cpu") if (safe_load and path.endswith(".safetensors")) else torch.load(path, map_location="cpu")
|
|
|
|
|
|
|
166 |
module.load_state_dict(state, strict=False)
|
167 |
|
|
|
|
|
168 |
a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
|
169 |
a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
|
|
|
|
|
170 |
|
171 |
+
_load_extra(unet, "unet")
|
172 |
+
_load_extra(a2t, "audio2token")
|
173 |
+
_load_extra(a2b, "audio2bucket")
|
174 |
+
|
175 |
whisper = WhisperModel.from_pretrained(
|
176 |
os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
|
177 |
).to(self.device).eval()
|
|
|
185 |
self.rife = RIFEModel(device=self.device)
|
186 |
self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
|
187 |
|
188 |
+
for m in (img_e, vae, unet):
|
189 |
m.to(dtype)
|
190 |
|
191 |
+
self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
|
192 |
+
self.image_encoder = img_e
|
193 |
self.audio2token = a2t
|
194 |
self.audio2bucket = a2b
|
195 |
self.whisper = whisper
|
196 |
|
197 |
+
# --------------------------------------------- preprocess helpers
|
198 |
def preprocess(self, image_path: str, expand_ratio: float = 1.0):
|
199 |
img = cv2.imread(image_path)
|
200 |
h, w = img.shape[:2]
|
|
|
205 |
"crop_bbox": process_bbox((x1, y1, x1+ww, y1+hh), expand_ratio, h, w)}
|
206 |
return {"face_num": 0, "crop_bbox": None}
|
207 |
|
208 |
+
# --------------------------------------------------------------- run
|
209 |
@torch.no_grad()
|
210 |
def process(self, image_path, audio_path, output_path,
|
211 |
+
min_resolution=512, inference_steps=25,
|
212 |
+
dynamic_scale=1.0, keep_resolution=False, seed=None):
|
213 |
|
214 |
cfg = self.config
|
215 |
if seed is not None:
|
216 |
cfg.seed = seed
|
217 |
+
cfg.num_inference_steps = inference_steps
|
218 |
+
cfg.motion_bucket_scale = dynamic_scale
|
219 |
seed_everything(cfg.seed)
|
220 |
|
|
|
221 |
data = image_audio_to_tensor(
|
222 |
self.face_det, self.feature_extractor,
|
223 |
image_path, audio_path,
|
|
|
237 |
self.audio2bucket, self.image_encoder,
|
238 |
w, h, data)
|
239 |
|
|
|
240 |
if cfg.use_interframe:
|
241 |
out, frames = video.to(self.device), []
|
242 |
for i in tqdm(range(out.shape[2]-1), ncols=0):
|