openfree committed
Commit 914dc02 · verified · 1 Parent(s): b1cb088

Update sonic.py

Files changed (1):
  1. sonic.py +33 -35
sonic.py CHANGED
@@ -22,6 +22,11 @@ from src.dataset.face_align.align import AlignImage
 
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
+# ------------------------------------------------------------------
+# single image + speech → video-tensor generator
+# ------------------------------------------------------------------
+# …(imports and other definitions above unchanged)…
+
 # ------------------------------------------------------------------
 # single image + speech → video-tensor generator
 # ------------------------------------------------------------------
@@ -29,32 +34,29 @@ def test(
     pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
     width, height, batch,
 ):
-    # ---------------- match batch dimensions -----------------------
+    # ---- match batch dimensions -----------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
-    ref_img = batch["ref_img"]                            # (1,C,H,W)
+    ref_img = batch["ref_img"]
     clip_img = batch["clip_images"]
     face_mask = batch["face_mask"]
-    image_embeds = image_encoder(clip_img).image_embeds   # (1,1024)
+    image_embeds = image_encoder(clip_img).image_embeds
 
-    audio_feature = batch["audio_feature"]                # (1,80,T)
-    audio_len = int(batch["audio_len"])                   # python int
+    audio_feature = batch["audio_feature"]
+    audio_len = int(batch["audio_len"])
     step = int(config.step)
 
-    # ---------- Whisper encoding per window --------------------------
-    window = 16_000                                       # 1 second
+    window = 16_000                                       # 1 second
     audio_prompts, last_prompts = [], []
 
     for i in range(0, audio_feature.shape[-1], window):
         chunk = audio_feature[:, :, i : i + window]
-
         layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
         last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-
-        audio_prompts.append(torch.stack(layers, dim=2))  # (1,?,L,384)
-        last_prompts.append(last)                         # (1,?,1,384)
+        audio_prompts.append(torch.stack(layers, dim=2))  # (1,w,L,384)
+        last_prompts.append(last)
 
     if not audio_prompts:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")
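
For orientation, the per-window stacking above can be checked in isolation. A minimal sketch with dummy tensors standing in for the Whisper encoder's hidden states; the window length w, layer count L, and hidden size 384 follow the inline comments, and the dummy shapes are assumptions, not the repo's actual values:

import torch

# Stand-ins for wav_enc.encoder(chunk, output_hidden_states=True).hidden_states:
# a tuple of L tensors, each of shape (1, w, 384).
w, L, d = 10, 5, 384
layers = tuple(torch.randn(1, w, d) for _ in range(L))

stacked = torch.stack(layers, dim=2)  # (1, w, L, 384) -- one audio_prompts entry
last = layers[-1].unsqueeze(-2)       # (1, w, 1, 384) -- one last_prompts entry
print(stacked.shape, last.shape)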
@@ -62,17 +64,13 @@ def test(
     audio_prompts = torch.cat(audio_prompts, dim=1)
     last_prompts = torch.cat(last_prompts, dim=1)
 
-    # ---------- pad to the model's input contract --------------------
     audio_prompts = torch.cat(
-        [torch.zeros_like(audio_prompts[:, :4]),          # head pad
-         audio_prompts,
+        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
          torch.zeros_like(audio_prompts[:, :6])], dim=1)
     last_prompts = torch.cat(
-        [torch.zeros_like(last_prompts[:, :24]),
-         last_prompts,
+        [torch.zeros_like(last_prompts[:, :24]), last_prompts,
          torch.zeros_like(last_prompts[:, :26])], dim=1)
 
-    # ---------- number of chunks from the audio length ---------------
     total_tokens = audio_prompts.shape[1]
     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))

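The padding adds a fixed 4-token head and 6-token tail to audio_prompts (24 and 26 for last_prompts), so num_chunks is plain arithmetic on the padded length. A quick sketch; the pre-pad token count and the config.step value are made-up inputs for illustration:

import math

step = 4                           # illustrative stand-in for config.step
raw = 37                           # hypothetical pre-pad token count
total_tokens = 4 + raw + 6         # head pad + tokens + tail pad
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
print(total_tokens, num_chunks)    # 47 -> ceil(47 / 8) = 6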
@@ -81,38 +79,35 @@ def test(
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
 
-        # ---------------- cond_clip (w=10,L=5) --------------------
-        clip_raw = audio_prompts[:, start : start + 10]   # (1,≤10,L,384)
-
-        # w-pad
-        if clip_raw.shape[1] < 10:
+        # ------------ cond_clip : (1,1,10,5,384) ------------------
+        clip_raw = audio_prompts[:, start : start + 10]   # (1,≤10,L,384)
+        if clip_raw.shape[1] < 10:                        # w-pad
             pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
             clip_raw = torch.cat([clip_raw, pad_w], dim=1)
 
-        # ★ L-pad (expand Whisper-tiny L=2 → 5)
-        if clip_raw.shape[2] < 5:
-            pad_L = clip_raw[:, :, -1:].repeat(1, 1, 5 - clip_raw.shape[2], 1)
-            clip_raw = torch.cat([clip_raw, pad_L], dim=2)
+        # ★ L-pad → make it exactly 5 layers
+        while clip_raw.shape[2] < 5:
+            clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
+        clip_raw = clip_raw[:, :, :5]                     # (1,10,5,384)
 
-        clip_raw = clip_raw[:, :, :5]                     # (1,10,5,384)
-        cond_clip = clip_raw.unsqueeze(1)                 # (1,1,10,5,384)
+        cond_clip = clip_raw.unsqueeze(1)                 # (1,1,10,5,384)
 
-        # ---------------- bucket_clip (w=50,L=1) ------------------
-        bucket_raw = last_prompts[:, start : start + 50]  # (1,≤50,1,384)
+        # ------------ bucket_clip : (1,1,50,1,384) -----------------
+        bucket_raw = last_prompts[:, start : start + 50]
         if bucket_raw.shape[1] < 50:
             pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
             bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
-
-        bucket_clip = bucket_raw.unsqueeze(1)             # (1,1,50,1,384)
+        bucket_clip = bucket_raw.unsqueeze(1)             # (1,1,50,1,384)
 
         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
         ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])   # (10,1024)
-        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
+        # ★ here: squeeze(0) only (drops the batch dim); the [0] indexing is removed
+        audio_list.append(audio_pe(cond_clip).squeeze(0))       # (50,1024)
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
         motion_buckets.append(motion[0])
 
-    # ---------- call Stable-Video-Diffusion --------------------------
+    # ---- call Stable Video Diffusion --------------------------------
     video = pipe(
         ref_img, clip_img, face_mask,
         audio_list, uncond_list, motion_buckets,
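
The two behavioural fixes in this hunk are easiest to see on dummy tensors: the while-loop L-pad repeats the last hidden-state layer until exactly 5 are present (Whisper-tiny yields fewer), and squeeze(0) without the trailing [0] keeps the full per-chunk embedding instead of a single row. The (1, 10, 2, 384) and (1, 50, 1024) shapes below are assumptions read off the inline comments, not verified against audio_pe:

import torch

# L-pad: grow the layer axis to exactly 5 by repeating the last layer.
clip_raw = torch.randn(1, 10, 2, 384)            # e.g. Whisper-tiny: only 2 layers
while clip_raw.shape[2] < 5:
    clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
clip_raw = clip_raw[:, :, :5]
print(clip_raw.shape)                            # torch.Size([1, 10, 5, 384])

# squeeze fix: .squeeze(0)[0] consumed one axis too many.
out = torch.randn(1, 50, 1024)                   # assumed audio_pe(cond_clip) output
print(out.squeeze(0)[0].shape)                   # torch.Size([1024])     -- old, broken
print(out.squeeze(0).shape)                      # torch.Size([50, 1024]) -- new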
@@ -137,6 +132,9 @@ def test(
     return video.to(pipe.device).unsqueeze(0).cpu()
 
 
+
+
+
 # ------------------------------------------------------------------
 # Sonic class
 # ------------------------------------------------------------------
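
Taken together, test now expects unbatched tensors in batch (the first loop adds the batch dimension) and returns a CPU video tensor with a fresh leading batch axis. A hypothetical call sketch; every shape and the 512×512 size below are illustrative assumptions, not values from this repo:

import torch

batch = {
    "ref_img":       torch.randn(3, 512, 512),   # single reference image, unbatched
    "clip_images":   torch.randn(3, 224, 224),   # CLIP-encoder input
    "face_mask":     torch.randn(1, 512, 512),
    "audio_feature": torch.randn(80, 64_000),    # a few seconds of audio features
    "audio_len":     torch.tensor(4),
}
# video = test(pipe, config, wav_enc, audio_pe, audio2bucket,
#              image_encoder, 512, 512, batch)   # -> batched video tensor on CPU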
 