Update sonic.py
sonic.py CHANGED
@@ -25,24 +25,24 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
          width, height, batch):
 
-    #
+    # -------- tidy up batch dimensions -----------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
-    ref_img = batch["ref_img"]
+    ref_img = batch["ref_img"]                      # (1,C,H,W)
     clip_img = batch["clip_images"]
     face_mask = batch["face_mask"]
     image_embeds = image_encoder(clip_img).image_embeds
 
-    audio_feature = batch["audio_feature"]
+    audio_feature = batch["audio_feature"]          # (1,80,T)
     audio_len = int(batch["audio_len"])
-    step = max(1, int(cfg.step))
+    step = max(1, int(cfg.step))                    # ensure at least 1
 
-
+    # -------- Whisper encoding --------------------------------------------
+    window = 16_000                                 # 1-second chunks
     audio_prompts, last_prompts = [], []
 
-    # ---------- Whisper encoding ------------------------------------------
     for i in range(0, audio_feature.shape[-1], window):
         chunk = audio_feature[:, :, i:i+window]
 
@@ -58,7 +58,7 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
     audio_prompts = torch.cat(audio_prompts, dim=1)   # (1,T,12,384)
     last_prompts = torch.cat(last_prompts, dim=1)     # (1,T,1,384)
 
-    #
+    # -------- leading/trailing padding ------------------------------------
     audio_prompts = torch.cat(
         [torch.zeros_like(audio_prompts[:, :4]),
          audio_prompts,
@@ -76,28 +76,29 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
 
-        #
-        cond_clip = audio_prompts[:, start
-        if cond_clip.shape[1] < 10:
+        # ------ cond_clip : (bz=1, f=1, w=10, b=5, c=384) ----------------
+        cond_clip = audio_prompts[:, start:start+10]          # (1,≤10,12,384)
+        if cond_clip.shape[1] < 10:                           # pad up to length w
             pad = torch.zeros_like(cond_clip[:, :10-cond_clip.shape[1]])
             cond_clip = torch.cat([cond_clip, pad], dim=1)
-        cond_clip = cond_clip
+        cond_clip = cond_clip.unsqueeze(1)                    # insert f dim → (1,1,10,12,384)
+        cond_clip = cond_clip[:, :, :, :5, :]                 # cut b dim to 5 → (1,1,10,5,384)
 
-        #
-        bucket_clip = last_prompts[:, start
-        if bucket_clip.shape[1] < 50:
+        # ------ bucket_clip : (1,1,50,1,384) -----------------------------
+        bucket_clip = last_prompts[:, start:start+50]         # (1,≤50,1,384)
+        if bucket_clip.shape[1] < 50:
             pad = torch.zeros_like(bucket_clip[:, :50-bucket_clip.shape[1]])
             bucket_clip = torch.cat([bucket_clip, pad], dim=1)
-        bucket_clip = bucket_clip.unsqueeze(
+        bucket_clip = bucket_clip.unsqueeze(1)                # (1,1,50,1,384)
 
         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
         ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
+        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])  # (tokens,1024)
         uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
         motion_buckets.append(motion[0])
 
-        #
+        # -------- diffusion ----------------------------------------------
         video = pipe(
             ref_img, clip_img, face_mask,
             audio_list, uncond_list, motion_buckets,
@@ -123,7 +124,7 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
 
 
 # ------------------------------------------------------------------
-#
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
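
For reference, the pattern the loop uses to build fixed-length windows (slice ten or fifty prompt frames, zero-pad the tail, then reshape) can be exercised in isolation. The sketch below is illustrative only: fixed_window, the dummy length T = 37, and the explicit-size zero pad are assumptions, not part of sonic.py (the in-loop version pads with zeros_like on a slice of the clip itself, which only covers shortfalls up to the clip's current length).

import torch

def fixed_window(seq, start, length):
    # Slice seq[:, start:start+length] along dim 1 and zero-pad up to `length`.
    # Hypothetical helper (not in sonic.py): the pad is built with an explicit
    # size so it also covers shortfalls larger than the remaining clip.
    clip = seq[:, start:start + length]
    short = length - clip.shape[1]
    if short > 0:
        pad = torch.zeros(clip.shape[0], short, *clip.shape[2:],
                          dtype=clip.dtype, device=clip.device)
        clip = torch.cat([clip, pad], dim=1)
    return clip

audio_prompts = torch.randn(1, 37, 12, 384)    # (1, T, 12, 384), per the diff comments
last_prompts  = torch.randn(1, 37, 1, 384)     # (1, T, 1, 384)

cond_clip = fixed_window(audio_prompts, start=30, length=10)   # (1, 10, 12, 384)
cond_clip = cond_clip.unsqueeze(1)[:, :, :, :5, :]             # (1, 1, 10, 5, 384)

bucket_clip = fixed_window(last_prompts, start=30, length=50)  # (1, 50, 1, 384)
bucket_clip = bucket_clip.unsqueeze(1)                         # (1, 1, 50, 1, 384)

print(cond_clip.shape, bucket_clip.shape)

With T = 37 and start = 30 the slice is only 7 frames long, so the zero pad supplies the remaining 3 or 43 frames before the unsqueeze adds the f dimension.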
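
The Whisper-encoding loop simply steps through the feature tensor in window-sized strides, so only the final chunk can come up short, which is one reason for the zero padding above. A minimal sketch of that iteration, assuming a feature length of 40,000 purely for illustration:

import torch

audio_feature = torch.randn(1, 80, 40_000)   # (1, 80, T), shape per the diff comment
window = 16_000                              # 1-second chunks, as in the new code

chunk_lengths = []
for i in range(0, audio_feature.shape[-1], window):
    chunk = audio_feature[:, :, i:i + window]
    chunk_lengths.append(chunk.shape[-1])

print(chunk_lengths)   # [16000, 16000, 8000] -- only the last chunk is ragged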