Spaces: Update sonic.py

sonic.py CHANGED
@@ -22,6 +22,11 @@ from src.dataset.face_align.align import AlignImage
 
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
+# ------------------------------------------------------------------
+# single image + speech → video-tensor generator
+# ------------------------------------------------------------------
+# …(top-of-file imports and other definitions unchanged)…
+
 # ------------------------------------------------------------------
 # single image + speech → video-tensor generator
 # ------------------------------------------------------------------
@@ -29,32 +34,29 @@ def test(
     pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
     width, height, batch,
 ):
-    #
+    # ---- align batch dimensions ------------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
-    ref_img = batch["ref_img"]
+    ref_img = batch["ref_img"]
     clip_img = batch["clip_images"]
     face_mask = batch["face_mask"]
-    image_embeds = image_encoder(clip_img).image_embeds
+    image_embeds = image_encoder(clip_img).image_embeds
 
-    audio_feature = batch["audio_feature"]
-    audio_len = int(batch["audio_len"])
+    audio_feature = batch["audio_feature"]
+    audio_len = int(batch["audio_len"])
     step = int(config.step)
 
-
-    window = 16_000  # 1 second
+    window = 16_000  # 1 second
     audio_prompts, last_prompts = [], []
 
     for i in range(0, audio_feature.shape[-1], window):
         chunk = audio_feature[:, :, i : i + window]
-
         layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
         last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-
-
-        last_prompts.append(last)  # (1,?,1,384)
+        audio_prompts.append(torch.stack(layers, dim=2))  # (1,w,L,384)
+        last_prompts.append(last)
 
     if not audio_prompts:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")
@@ -62,17 +64,13 @@ def test(
     audio_prompts = torch.cat(audio_prompts, dim=1)
     last_prompts = torch.cat(last_prompts, dim=1)
 
-    # ---------- padding to match the model's input rules -------------
     audio_prompts = torch.cat(
-        [torch.zeros_like(audio_prompts[:, :4]),
-         audio_prompts,
+        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
          torch.zeros_like(audio_prompts[:, :6])], dim=1)
     last_prompts = torch.cat(
-        [torch.zeros_like(last_prompts[:, :24]),
-         last_prompts,
+        [torch.zeros_like(last_prompts[:, :24]), last_prompts,
          torch.zeros_like(last_prompts[:, :26])], dim=1)
 
-    # ---------- number of chunks based on the audio length -----------
     total_tokens = audio_prompts.shape[1]
     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
 
@@ -81,38 +79,35 @@ def test(
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
 
-        #
-        clip_raw = audio_prompts[:, start : start + 10]
-
-        # w-pad
-        if clip_raw.shape[1] < 10:
+        # ------------ cond_clip : (1,1,10,5,384) ------------------
+        clip_raw = audio_prompts[:, start : start + 10]    # (1,≤10,L,384)
+        if clip_raw.shape[1] < 10:                          # w-pad
             pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
             clip_raw = torch.cat([clip_raw, pad_w], dim=1)
 
-        # ★ L-pad
-
-
-
+        # ★ L-pad → make exactly 5 layers
+        while clip_raw.shape[2] < 5:
+            clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
+        clip_raw = clip_raw[:, :, :5]                       # (1,10,5,384)
 
-
-        cond_clip = clip_raw.unsqueeze(1)  # (1,1,10,5,384)
+        cond_clip = clip_raw.unsqueeze(1)                   # (1,1,10,5,384)
 
-        #
-        bucket_raw = last_prompts[:, start : start + 50]
+        # ------------ bucket_clip : (1,1,50,1,384) -----------------
+        bucket_raw = last_prompts[:, start : start + 50]
         if bucket_raw.shape[1] < 50:
             pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
             bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
-
-        bucket_clip = bucket_raw.unsqueeze(1)  # (1,1,50,1,384)
+        bucket_clip = bucket_raw.unsqueeze(1)               # (1,1,50,1,384)
 
         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
         ref_list.append(ref_img[0])
-
-
+        # ★ here: squeeze(0) only (drop the batch dim); [0] indexing removed
+        audio_list.append(audio_pe(cond_clip).squeeze(0))   # (50,1024)
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
         motion_buckets.append(motion[0])
 
-    #
+    # ---- Stable Video Diffusion call --------------------------------
     video = pipe(
         ref_img, clip_img, face_mask,
         audio_list, uncond_list, motion_buckets,
@@ -137,6 +132,9 @@ def test(
     return video.to(pipe.device).unsqueeze(0).cpu()
 
 
+
+
+
 # ------------------------------------------------------------------
 # Sonic class
 # ------------------------------------------------------------------
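For reference, here is a minimal standalone sketch (not part of the commit) of the fixed-shape padding that the updated per-chunk loop performs: the window axis is zero-padded to 10 steps, the hidden-state layer axis is repeated until there are exactly 5 layers, and the bucket window is zero-padded to 50 steps. The helper name pad_audio_chunk and the toy tensors below are illustrative only.

import torch

def pad_audio_chunk(clip_raw: torch.Tensor, bucket_raw: torch.Tensor):
    """Pad one audio window to the fixed shapes the pipeline expects:
    cond_clip -> (1, 1, 10, 5, 384), bucket_clip -> (1, 1, 50, 1, 384)."""
    # w-pad: zero-pad the window axis to 10 steps
    if clip_raw.shape[1] < 10:
        pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
        clip_raw = torch.cat([clip_raw, pad_w], dim=1)

    # L-pad: repeat the last encoder layer until there are exactly 5 layers
    while clip_raw.shape[2] < 5:
        clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
    clip_raw = clip_raw[:, :, :5]

    # bucket side: zero-pad the 50-step window the same way
    if bucket_raw.shape[1] < 50:
        pad_b = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
        bucket_raw = torch.cat([bucket_raw, pad_b], dim=1)

    # add the extra leading frame axis used downstream
    return clip_raw.unsqueeze(1), bucket_raw.unsqueeze(1)

# toy tensors with a deliberately short window and too few layers
cond_clip, bucket_clip = pad_audio_chunk(
    torch.randn(1, 7, 3, 384),    # (1, ≤10, L, 384) slice of audio_prompts
    torch.randn(1, 32, 1, 384),   # (1, ≤50, 1, 384) slice of last_prompts
)
print(cond_clip.shape)    # torch.Size([1, 1, 10, 5, 384])
print(bucket_clip.shape)  # torch.Size([1, 1, 50, 1, 384])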