openfree committed
Commit e10969c · verified · 1 Parent(s): 6ab32bd

Update sonic.py

Files changed (1):
  1. sonic.py +39 -33
sonic.py CHANGED
@@ -1,9 +1,7 @@
-import os, math
-import torch
+import os, math, torch, cv2
 from PIL import Image
 from omegaconf import OmegaConf
 from tqdm import tqdm
-import cv2
 
 from diffusers import AutoencoderKLTemporalDecoder
 from diffusers.schedulers import EulerDiscreteScheduler
@@ -40,15 +38,15 @@ def test(
     face_mask    = batch["face_mask"]
     image_embeds = image_encoder(clip_img).image_embeds
 
-    audio_feature = batch["audio_feature"]        # (1,80,T)
-    audio_len     = int(batch["audio_len"])       # Python int
+    audio_feature = batch["audio_feature"]        # (1,80,T)
+    audio_len     = int(batch["audio_len"])       # Python int
     step = int(config.step)
 
-    # --- [★ fix] clamp step to at least 1 ---------------------------------
+    # --- clamp step to at least 1 -----------------------------------------
     if audio_len < step:
         step = max(1, audio_len)
 
-    window = 16000                                # one segment
+    window = 16000                                # 1-second chunk
     audio_prompts, last_prompts = [], []
 
     # --- Whisper encoding per window --------------------------------------
@@ -56,26 +54,28 @@ def test(
         chunk = audio_feature[:, :, i : i + window]
 
         prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-        last_hidden   = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
+        last_hidden   = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)  # (1,t,1,384)
 
-        audio_prompts.append(torch.stack(prompt_layers, dim=2))
-        last_prompts.append(last_hidden)
+        audio_prompts.append(torch.stack(prompt_layers, dim=2))                 # (1,L,12,80)
+        last_prompts.append(last_hidden)                                        # (1,L,1,384)
 
     if len(audio_prompts) == 0:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")
 
     audio_prompts = torch.cat(audio_prompts, dim=1)
-    last_prompts  = torch.cat(last_prompts, dim=1)
+    last_prompts  = torch.cat(last_prompts, dim=1)
 
     # padding rule
     audio_prompts = torch.cat(
-        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
+        [torch.zeros_like(audio_prompts[:, :4]),
+         audio_prompts,
          torch.zeros_like(audio_prompts[:, :6])], dim=1)
     last_prompts = torch.cat(
-        [torch.zeros_like(last_prompts[:, :24]), last_prompts,
+        [torch.zeros_like(last_prompts[:, :24]),
+         last_prompts,
          torch.zeros_like(last_prompts[:, :26])], dim=1)
 
-    # --- [★ fix] always ≥ 1 chunk ------------------------------------------
+    # --- always ≥ 1 chunk --------------------------------------------------
     total_tokens = audio_prompts.shape[1]
     num_chunks   = max(1, math.ceil(total_tokens / (2 * step)))
 
@@ -84,19 +84,25 @@ def test(
     for i in tqdm(range(num_chunks)):
        start = i * 2 * step
 
-        cond_clip = audio_prompts[:, start : start + 10]
-        if cond_clip.shape[2] < 10:                       # [★ fix] padding
-            pad = torch.zeros_like(cond_clip[:, :, : 10 - cond_clip.shape[2]])
-            cond_clip = torch.cat([cond_clip, pad], dim=2)
+        cond_clip = audio_prompts[:, start : start + 10]  # (1,10,12,80)
+        if cond_clip.shape[1] < 10:                       # pad if too short
+            pad = torch.zeros_like(cond_clip[:, : 10 - cond_clip.shape[1]])
+            cond_clip = torch.cat([cond_clip, pad], dim=1)
 
-        bucket_clip = last_prompts[:, start : start + 50]      # (1, 50, 384)
-        bucket_clip = bucket_clip.unsqueeze(0).unsqueeze(-2)   # (1, 1, 50, 1, 384)
-        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
+        # ------------------ (★) match bucket_clip dimensions ---------------
+        bucket_clip = last_prompts[:, start : start + 50]  # (1,50,1,384)
+        if bucket_clip.shape[1] < 50:                      # pad if too short
+            pad = torch.zeros_like(bucket_clip[:, : 50 - bucket_clip.shape[1]])
+            bucket_clip = torch.cat([bucket_clip, pad], dim=1)
+
+        bucket_clip = bucket_clip.unsqueeze(1)             # → (1,1,50,1,384) ✔ 5-D
+        # --------------------------------------------------------------------
 
+        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
        ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
-        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
+        audio_list.append(audio_pe(cond_clip.unsqueeze(1)).squeeze(0)[0])    # (1,10,…) → 4-D after unsqueeze
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip).unsqueeze(1)).squeeze(0)[0])
        motion_buckets.append(motion[0])
 
     # ----------------------------------------------------------------------
@@ -125,7 +131,7 @@ def test(
 
 
 # ------------------------------------------------------------------
-# Sonic class
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
@@ -144,18 +150,18 @@ class Sonic:
     def _load_models(self, cfg):
        dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
 
-        vae   = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
+        vae   = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
        sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
-        image_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
-        unet      = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
+        imgE  = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+        unet  = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
        add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
        a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
        a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
-        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
+        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
        whisper.requires_grad_(False)
@@ -166,11 +172,11 @@ class Sonic:
        self.rife = RIFEModel(device=self.device)
        self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
 
-        for m in (image_enc, vae, unet):
+        for m in (imgE, vae, unet):
            m.to(dtype)
 
-        self.pipe = SonicPipeline(unet=unet, image_encoder=image_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
-        self.image_encoder = image_enc
+        self.pipe = SonicPipeline(unet=unet, image_encoder=imgE, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+        self.image_encoder = imgE
        self.audio2token  = a2t
        self.audio2bucket = a2b
        self.whisper      = whisper
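
The chunk loop above keeps every clip at a fixed token length and feeds 5-D inputs to audio_pe and audio2bucket. The following is a minimal, standalone sketch of that pad-and-reshape pattern with dummy tensors; the pad_to helper, the example shapes, and the 2*step stride are illustrative assumptions, not code from this repository.

import math
import torch

def pad_to(x: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad x along dim 1 up to `length` (no-op if already long enough)."""
    missing = length - x.shape[1]
    if missing <= 0:
        return x
    pad = x.new_zeros(x.shape[0], missing, *x.shape[2:])
    return torch.cat([x, pad], dim=1)

step          = 2                               # assumed small value for the demo
audio_prompts = torch.randn(1, 17, 12, 80)      # dummy stand-in for stacked hidden states
last_prompts  = torch.randn(1, 17, 1, 384)      # dummy stand-in for last_hidden_state tokens

num_chunks = max(1, math.ceil(audio_prompts.shape[1] / (2 * step)))

for i in range(num_chunks):
    start = i * 2 * step

    # Slice a fixed-size window, pad the tail chunk, then add a leading
    # singleton dim so the conditioning tensors arrive as 5-D batches.
    cond_clip   = pad_to(audio_prompts[:, start : start + 10], 10).unsqueeze(1)  # (1,1,10,12,80)
    bucket_clip = pad_to(last_prompts[:, start : start + 50], 50).unsqueeze(1)   # (1,1,50,1,384)

    assert cond_clip.shape   == (1, 1, 10, 12, 80)
    assert bucket_clip.shape == (1, 1, 50, 1, 384)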