Update sonic.py

sonic.py CHANGED
@@ -9,9 +9,7 @@ from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatur

 from src.utils.util import save_videos_grid, seed_everything
 from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
-from src.models.base.unet_spatio_temporal_condition import (
-    UNetSpatioTemporalConditionModel, add_ip_adapters,
-)
+from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
 from src.pipelines.pipeline_sonic import SonicPipeline
 from src.models.audio_adapter.audio_proj import AudioProjModel
 from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
@@ -22,13 +20,12 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))


 # ------------------------------------------------------------------
-#
+# single image + speech → video-tensor generator
 # ------------------------------------------------------------------
-def test(
-
-
-
-    # --- match the batch dimensions ----------------------------------------
+def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
+         width, height, batch):
+
+    # ---------- match the batch dimensions --------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
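This hunk only normalises the preprocessed sample: every tensor gets a leading batch axis and is moved onto the pipeline device, while non-tensor entries pass through untouched. A minimal standalone sketch of that pattern, using a hypothetical sample dict rather than the real output of image_audio_to_tensor:

import torch

# Hypothetical sample dict, standing in for what image_audio_to_tensor returns.
sample = {"ref_img": torch.rand(3, 512, 512), "audio_len": 120}

device = "cuda:0" if torch.cuda.is_available() else "cpu"
for k, v in sample.items():
    if isinstance(v, torch.Tensor):
        # (C,H,W) -> (1,C,H,W): the pipeline expects a leading batch axis.
        sample[k] = v.unsqueeze(0).to(device).float()

print(sample["ref_img"].shape)   # torch.Size([1, 3, 512, 512])
print(sample["audio_len"])       # non-tensor entries are left as-is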
@@ -39,33 +36,29 @@ def test(
     image_embeds = image_encoder(clip_img).image_embeds

     audio_feature = batch["audio_feature"]  # (1,80,T)
-    audio_len = int(batch["audio_len"])
-    step = int(
-
-    # --- clamp step (minimum 1) ---------------------------------------------
-    if audio_len < step:
-        step = max(1, audio_len)
+    audio_len = int(batch["audio_len"])
+    step = max(1, int(cfg.step))                       # guarantee at least 1

-    window =
+    window = 16_000                                    # 1-second chunk
     audio_prompts, last_prompts = [], []

-    #
+    # ---------- Whisper encoding -------------------------------------------
     for i in range(0, audio_feature.shape[-1], window):
-        chunk = audio_feature[:, :, i
+        chunk = audio_feature[:, :, i:i+window]

-
-
+        hs_all   = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+        last_hid = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)  # (1,t,1,384)

-        audio_prompts.append(torch.stack(
-        last_prompts.append(
+        audio_prompts.append(torch.stack(hs_all, dim=2))   # (1,t,12,384)
+        last_prompts.append(last_hid)                      # (1,t,1,384)

-    if
+    if not audio_prompts:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")

-    audio_prompts = torch.cat(audio_prompts, dim=1)
-    last_prompts = torch.cat(last_prompts, dim=1)
+    audio_prompts = torch.cat(audio_prompts, dim=1)    # (1,T,12,384)
+    last_prompts  = torch.cat(last_prompts, dim=1)     # (1,T,1,384)

-    # padding rule
+    # ---------- padding rule -----------------------------------------------
     audio_prompts = torch.cat(
         [torch.zeros_like(audio_prompts[:, :4]),
          audio_prompts,
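The Whisper-encoding loop walks the log-mel features in fixed windows; the final window is simply shorter, which the later zero-padding compensates for. A small sketch of just the windowing behaviour with a dummy tensor (no Whisper model involved), reusing the 16_000 window constant from the hunk:

import torch

# Dummy stand-in for batch["audio_feature"]: (1, 80, T) log-mel features.
feat = torch.rand(1, 80, 37_500)
window = 16_000

# Slicing past the end of the tensor just truncates, so no bounds check is needed.
chunks = [feat[:, :, i:i + window] for i in range(0, feat.shape[-1], window)]
print([c.shape[-1] for c in chunks])   # [16000, 16000, 5500] – last chunk is shorter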
@@ -75,7 +68,6 @@ def test(
          last_prompts,
          torch.zeros_like(last_prompts[:, :26])], dim=1)

-    # --- always at least 1 chunk ---------------------------------------------
     total_tokens = audio_prompts.shape[1]
     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))

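num_chunks is clamped to at least one, so very short clips still produce a single diffusion chunk. A quick worked example with illustrative numbers (the real values come from cfg.step and the padded prompt length):

import math

step = 2                                   # assumed cfg.step value, for illustration
total_tokens = 9                           # e.g. audio_prompts.shape[1] after padding
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
print(num_chunks)                          # 3 – and never 0, even for very short audio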
@@ -84,46 +76,46 @@ def test(
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step

-
-
-
+        # --------- cond_clip : (1,10,12,384) → (1,10,5,384) --------------
+        cond_clip = audio_prompts[:, start : start + 10]    # (1,≤10,12,384)
+        if cond_clip.shape[1] < 10:                         # pad seq_len
+            pad = torch.zeros_like(cond_clip[:, :10-cond_clip.shape[1]])
             cond_clip = torch.cat([cond_clip, pad], dim=1)
+        cond_clip = cond_clip[:, :, :5, :]                  # select 5 blocks

-        #
-        bucket_clip = last_prompts[:, start : start + 50]
-        if bucket_clip.shape[1] < 50:
-            pad = torch.zeros_like(bucket_clip[:, :
+        # --------- bucket_clip : (1,50,1,384) → unsqueeze(0) -----
+        bucket_clip = last_prompts[:, start : start + 50]   # (1,≤50,1,384)
+        if bucket_clip.shape[1] < 50:                       # pad length
+            pad = torch.zeros_like(bucket_clip[:, :50-bucket_clip.shape[1]])
             bucket_clip = torch.cat([bucket_clip, pad], dim=1)
-
-        bucket_clip = bucket_clip.unsqueeze(1)  # → (1,1,50,1,384) ✔ 5-D
-        # -----------------------------------------------------------------
+        bucket_clip = bucket_clip.unsqueeze(0)              # (1,1,50,1,384)

         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16

         ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip
-        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)
+        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])        # (10,*)
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
         motion_buckets.append(motion[0])

-    #
+    # ---------- diffusion ---------------------------------------------------
     video = pipe(
         ref_img, clip_img, face_mask,
         audio_list, uncond_list, motion_buckets,
         height=height, width=width,
         num_frames=len(audio_list),
-        decode_chunk_size=
-        motion_bucket_scale=
-        fps=
-        noise_aug_strength=
-        min_guidance_scale1=
-        max_guidance_scale1=
-        min_guidance_scale2=
-        max_guidance_scale2=
-        overlap=
-        shift_offset=
-        frames_per_batch=
-        num_inference_steps=
-        i2i_noise_strength=
+        decode_chunk_size=cfg.decode_chunk_size,
+        motion_bucket_scale=cfg.motion_bucket_scale,
+        fps=cfg.fps,
+        noise_aug_strength=cfg.noise_aug_strength,
+        min_guidance_scale1=cfg.min_appearance_guidance_scale,
+        max_guidance_scale1=cfg.max_appearance_guidance_scale,
+        min_guidance_scale2=cfg.audio_guidance_scale,
+        max_guidance_scale2=cfg.audio_guidance_scale,
+        overlap=cfg.overlap,
+        shift_offset=cfg.shift_offset,
+        frames_per_batch=cfg.n_sample_frames,
+        num_inference_steps=cfg.num_inference_steps,
+        i2i_noise_strength=cfg.i2i_noise_strength,
     ).frames

     video = (video * 0.5 + 0.5).clamp(0, 1)
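Both cond_clip and bucket_clip are zero-padded up to a fixed window length before being handed to the audio adapters. A generic sketch of that pad-to-length pattern with hypothetical shapes; pad_window is an illustrative helper, not a function from the repository:

import torch

def pad_window(x: torch.Tensor, length: int) -> torch.Tensor:
    # Zero-pad dim 1 of x up to `length`, mirroring the cond_clip/bucket_clip logic.
    if x.shape[1] < length:
        pad = torch.zeros_like(x[:, : length - x.shape[1]])
        x = torch.cat([x, pad], dim=1)
    return x

clip = torch.rand(1, 7, 12, 384)       # a short final window (hypothetical shapes)
print(pad_window(clip, 10).shape)      # torch.Size([1, 10, 12, 384])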
@@ -131,16 +123,16 @@ def test(


 # ------------------------------------------------------------------
-#
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
     config = OmegaConf.load(config_file)

     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-        cfg
+        cfg = self.config
         cfg.use_interframe = enable_interpolate_frame
-        self.device
+        self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
         cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)

         self._load_models(cfg)
@@ -159,9 +151,9 @@ class Sonic:
         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

-        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path),
-        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path),
-        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path),
+        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))

         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
         whisper.requires_grad_(False)
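This hunk loads the checkpoints with map_location="cpu": the serialized tensors are deserialized on the CPU and then copied into the parameters of modules that were already placed with .to(self.device). A self-contained toy example of the same pattern:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)                        # stand-in for the real modules
torch.save(net.state_dict(), "toy_state.pth")

# map_location="cpu" keeps the deserialized tensors on the CPU even if they
# were saved from CUDA; load_state_dict then copies them element-wise into
# the parameters of the already-placed module.
state = torch.load("toy_state.pth", map_location="cpu")
net.load_state_dict(state)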
@@ -188,78 +180,50 @@ class Sonic:
         _, _, bboxes = self.face_det(img, maxface=True)
         if bboxes:
             x1, y1, ww, hh = bboxes[0]
-            return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1
+            return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1+ww, y1+hh), expand_ratio, h, w)}
         return {"face_num": 0, "crop_bbox": None}

     # --------------------------------------------------------------
     @torch.no_grad()
-    def process(
-
-
-
-
-        min_resolution: int = 512,
-        inference_steps: int = 25,
-        dynamic_scale: float = 1.0,
-        keep_resolution: bool = False,
-        seed: int | None = None,
-    ):
+    def process(self, image_path: str, audio_path: str, output_path: str,
+                min_resolution: int = 512, inference_steps: int = 25,
+                dynamic_scale: float = 1.0, keep_resolution: bool = False,
+                seed: int | None = None):
+
         cfg = self.config
         if seed is not None:
             cfg.seed = seed
-        cfg.num_inference_steps
-        cfg.motion_bucket_scale
+        cfg.num_inference_steps = inference_steps
+        cfg.motion_bucket_scale = dynamic_scale
         seed_everything(cfg.seed)

-        # image/audio → tensor
         test_data = image_audio_to_tensor(
-            self.face_det,
-
-
-            audio_path,
-            limit=-1,
-            image_size=min_resolution,
-            area=cfg.area,
+            self.face_det, self.feature_extractor,
+            image_path, audio_path, limit=-1,
+            image_size=min_resolution, area=cfg.area,
         )
         if test_data is None:
             return -1

         h, w = test_data["ref_img"].shape[-2:]
-        resolution = (
-
-            if keep_resolution
-            else f"{w}x{h}"
-        )
+        resolution = (f"{(Image.open(image_path).width//2)*2}x{(Image.open(image_path).height//2)*2}"
+                      if keep_resolution else f"{w}x{h}")

-
-
-            self.pipe,
-            cfg,
-            wav_enc=self.whisper,
-            audio_pe=self.audio2token,
-            audio2bucket=self.audio2bucket,
-            image_encoder=self.image_encoder,
-            width=w,
-            height=h,
-            batch=test_data,
-        )
+        video = test(self.pipe, cfg, self.whisper, self.audio2token,
+                     self.audio2bucket, self.image_encoder, w, h, test_data)

-        # intermediate-frame interpolation
         if cfg.use_interframe:
             out = video.to(self.device)
             frames = []
-            for i in tqdm(range(out.shape[2]
-                mid = self.rife.inference(out[
-                frames.extend([out[
-            frames.append(out[
+            for i in tqdm(range(out.shape[2]-1), ncols=0):
+                mid = self.rife.inference(out[:,:,i], out[:,:,i+1]).clamp(0,1).detach()
+                frames.extend([out[:,:,i], mid])
+            frames.append(out[:,:,-1])
             video = torch.stack(frames, 2).cpu()

-        # save
         tmp = output_path.replace(".mp4", "_noaudio.mp4")
-        save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps
-        os.system(
-
-            f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
-        )
+        save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps*(2 if cfg.use_interframe else 1))
+        os.system(f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
+                  f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error")
         os.remove(tmp)
         return 0
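Putting the pieces together, a hypothetical driver script for the updated class could look like the following; the input and output paths are placeholders, and the checkpoints referenced by config/inference/sonic.yaml are assumed to be in place next to sonic.py:

from sonic import Sonic   # assumes this file is importable as sonic.py

model = Sonic(device_id=0, enable_interpolate_frame=True)
status = model.process(
    image_path="examples/face.png",      # placeholder inputs
    audio_path="examples/speech.wav",
    output_path="out/result.mp4",
    inference_steps=25,
    dynamic_scale=1.0,
)
print("done" if status == 0 else "preprocessing failed (no usable face/audio)")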