openfree committed
Commit 581c19e · verified · 1 Parent(s): 430d42a

Update sonic.py

Files changed (1)
  1. sonic.py +19 -18
sonic.py CHANGED
@@ -25,24 +25,24 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
          width, height, batch):
 
-    # ---------- align batch dimensions -------------------------------------
+    # -------- normalize batch dimensions -----------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
-    ref_img = batch["ref_img"]                    # (1,C,H,W)
+    ref_img = batch["ref_img"]                    # (1,C,H,W)
     clip_img = batch["clip_images"]
     face_mask = batch["face_mask"]
     image_embeds = image_encoder(clip_img).image_embeds
 
-    audio_feature = batch["audio_feature"]        # (1,80,T)
+    audio_feature = batch["audio_feature"]        # (1,80,T)
     audio_len = int(batch["audio_len"])
-    step = max(1, int(cfg.step))                  # guarantee at least 1
+    step = max(1, int(cfg.step))                  # guarantee at least 1
 
-    window = 16_000                               # 1-second chunk
+    # -------- Whisper encoding ----------------------------------------------
+    window = 16_000                               # 1-second chunks
     audio_prompts, last_prompts = [], []
 
-    # ---------- Whisper encoding --------------------------------------------
     for i in range(0, audio_feature.shape[-1], window):
         chunk = audio_feature[:, :, i:i+window]
 
@@ -58,7 +58,7 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
     audio_prompts = torch.cat(audio_prompts, dim=1)   # (1,T,12,384)
     last_prompts = torch.cat(last_prompts, dim=1)     # (1,T,1,384)
 
-    # ---------- padding rules -----------------------------------------------
+    # -------- pad front and back --------------------------------------------
     audio_prompts = torch.cat(
         [torch.zeros_like(audio_prompts[:, :4]),
          audio_prompts,
@@ -76,28 +76,29 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
 
-        # --------- cond_clip : (1,10,12,384) → (1,10,5,384) ---------------
-        cond_clip = audio_prompts[:, start : start + 10]   # (1,≤10,12,384)
-        if cond_clip.shape[1] < 10:                        # pad seq_len
+        # ------ cond_clip : (bz=1, f=1, w=10, b=5, c=384) -----------------
+        cond_clip = audio_prompts[:, start:start+10]       # (1,≤10,12,384)
+        if cond_clip.shape[1] < 10:                        # pad w length
             pad = torch.zeros_like(cond_clip[:, :10-cond_clip.shape[1]])
             cond_clip = torch.cat([cond_clip, pad], dim=1)
-        cond_clip = cond_clip[:, :, :5, :]                 # select 5 blocks
+        cond_clip = cond_clip.unsqueeze(1)                 # insert f dim → (1,1,10,12,384)
+        cond_clip = cond_clip[:, :, :, :5, :]              # cut b dim to 5 → (1,1,10,5,384)
 
-        # --------- bucket_clip : (1,50,1,384) → unsqueeze(0) --------------
-        bucket_clip = last_prompts[:, start : start + 50]  # (1,≤50,1,384)
-        if bucket_clip.shape[1] < 50:                      # pad length
+        # ------ bucket_clip : (1,1,50,1,384) ------------------------------
+        bucket_clip = last_prompts[:, start:start+50]      # (1,≤50,1,384)
+        if bucket_clip.shape[1] < 50:
             pad = torch.zeros_like(bucket_clip[:, :50-bucket_clip.shape[1]])
             bucket_clip = torch.cat([bucket_clip, pad], dim=1)
-        bucket_clip = bucket_clip.unsqueeze(0)             # (1,1,50,1,384)
+        bucket_clip = bucket_clip.unsqueeze(1)             # (1,1,50,1,384)
 
         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
         ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])   # (10,*)
+        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])   # (tokens,1024)
         uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
         motion_buckets.append(motion[0])
 
-    # ---------- diffusion ----------------------------------------------------
+    # -------- diffusion ------------------------------------------------------
     video = pipe(
         ref_img, clip_img, face_mask,
         audio_list, uncond_list, motion_buckets,
@@ -123,7 +124,7 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
 
 
 # ------------------------------------------------------------------
-# Sonic class
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
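
For reference, a minimal standalone sketch (not part of the repository) of the reworked window logic above: dummy tensors with the shapes noted in the diff comments are sliced, zero-padded, and reshaped exactly as in the new loop body, so the target shapes (1,1,10,5,384) for cond_clip and (1,1,50,1,384) for bucket_clip can be checked in isolation. The tensor sizes and start value are illustrative assumptions, not values taken from the code.

import torch

# Dummy prompt tensors with the layouts described in the diff comments.
audio_prompts = torch.randn(1, 7, 12, 384)   # (1, T, 12, 384) with T = 7 < 10
last_prompts  = torch.randn(1, 33, 1, 384)   # (1, T, 1, 384)  with T = 33 < 50
start = 0                                    # illustrative window offset

# cond_clip: take a 10-frame window, zero-pad the tail, add the f axis,
# keep only the first 5 Whisper blocks → (bz=1, f=1, w=10, b=5, c=384).
cond_clip = audio_prompts[:, start:start+10]
if cond_clip.shape[1] < 10:
    pad = torch.zeros_like(cond_clip[:, :10 - cond_clip.shape[1]])
    cond_clip = torch.cat([cond_clip, pad], dim=1)
cond_clip = cond_clip.unsqueeze(1)[:, :, :, :5, :]
assert cond_clip.shape == (1, 1, 10, 5, 384)

# bucket_clip: 50-frame window, zero-pad, add the f axis → (1, 1, 50, 1, 384).
bucket_clip = last_prompts[:, start:start+50]
if bucket_clip.shape[1] < 50:
    pad = torch.zeros_like(bucket_clip[:, :50 - bucket_clip.shape[1]])
    bucket_clip = torch.cat([bucket_clip, pad], dim=1)
bucket_clip = bucket_clip.unsqueeze(1)
assert bucket_clip.shape == (1, 1, 50, 1, 384)

print(cond_clip.shape, bucket_clip.shape)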