openfree committed (verified)
Commit a47303a · 1 Parent(s): 0329637

Update sonic.py

Files changed (1)
  1. sonic.py +54 -70
sonic.py CHANGED
@@ -1,4 +1,4 @@
-# sonic.py (entire file)
+# sonic.py ── entire file

 import os, math, glob, torch, cv2
 from PIL import Image
@@ -22,32 +22,29 @@ from src.dataset.face_align.align import AlignImage

 try:
     from safetensors.torch import load_file as safe_load
-except ImportError:  # fall back to torch.load when safetensors is unavailable
-    safe_load = None
+except ImportError:
+    safe_load = None  # in case safetensors is not installed

 BASE_DIR = os.path.dirname(os.path.abspath(__file__))


-# -------------------------------------------------------------------
-# shared: checkpoint (weights) lookup helper
-# -------------------------------------------------------------------
+# ------------------------------------------------------------ utils
 def _find_ckpt(root: str, keyword: str):
-    """Search under root for a .pth / .pt / .safetensors file whose name contains keyword."""
-    patterns = [f"**/*{keyword}*.pth", f"**/*{keyword}*.pt",
+    """Find one .pth / .pt / .safetensors file under root whose name contains keyword."""
+    patterns = [f"**/*{keyword}*.pth",
+                f"**/*{keyword}*.pt",
                 f"**/*{keyword}*.safetensors"]
-    files = []
     for p in patterns:
-        files.extend(glob.glob(os.path.join(root, p), recursive=True))
-    return files[0] if files else None
+        files = glob.glob(os.path.join(root, p), recursive=True)
+        if files:
+            return files[0]
+    return None


-# -------------------------------------------------------------------
-# single image + speech → video tensor
-# -------------------------------------------------------------------
+# --------------------------------------------------- speech → video
 def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
          width, height, batch):

-    # add the batch dimension
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
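The rewritten _find_ckpt returns as soon as one glob pattern matches, so a .pth file named after the keyword now takes priority over a .pt or .safetensors file, whereas the old version pooled all matches and returned whichever glob happened to list first. A minimal standalone sketch of that lookup order, using a hypothetical directory layout:

import glob, os, tempfile

def find_ckpt(root: str, keyword: str):
    # try extensions in priority order and return the first hit
    for pattern in (f"**/*{keyword}*.pth", f"**/*{keyword}*.pt", f"**/*{keyword}*.safetensors"):
        files = glob.glob(os.path.join(root, pattern), recursive=True)
        if files:
            return files[0]
    return None

with tempfile.TemporaryDirectory() as root:
    # hypothetical checkpoints: both a .pt and a .safetensors exist for "unet"
    open(os.path.join(root, "unet_audio.pt"), "w").close()
    open(os.path.join(root, "unet_audio.safetensors"), "w").close()
    print(find_ckpt(root, "unet"))   # the .pt file wins over .safetensors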
@@ -59,17 +56,17 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,

     audio_feature = batch["audio_feature"]  # (1,80,T)
     audio_len = int(batch["audio_len"])
-    step = max(1, int(cfg.step),)  # at least 1
+    step = max(1, int(cfg.step))

-    window = 16_000  # 1-second units
+    window = 16_000
     audio_prompts, last_prompts = [], []

     for i in range(0, audio_feature.shape[-1], window):
         chunk = audio_feature[:, :, i:i+window]
-        hidden_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-        last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-        audio_prompts.append(torch.stack(hidden_layers, dim=2))
-        last_prompts.append(last_hidden)
+        hidden = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+        last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
+        audio_prompts.append(torch.stack(hidden, dim=2))
+        last_prompts.append(last)

     if not audio_prompts:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")
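The loop above walks the (1, 80, T) audio feature along its time axis in fixed windows of 16,000 frames and encodes each slice separately, so long audio never reaches the encoder in one piece. A rough sketch of just the slicing arithmetic, with a dummy tensor standing in for the Whisper features:

import torch

window = 16_000
audio_feature = torch.zeros(1, 80, 40_000)   # stand-in for the (1, 80, T) log-mel features

for i in range(0, audio_feature.shape[-1], window):
    chunk = audio_feature[:, :, i:i + window]
    print(i, chunk.shape)   # the last chunk may be shorter than the window
# 0 torch.Size([1, 80, 16000])
# 16000 torch.Size([1, 80, 16000])
# 32000 torch.Size([1, 80, 8000])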
@@ -77,33 +74,29 @@
     audio_prompts = torch.cat(audio_prompts, dim=1)
     last_prompts = torch.cat(last_prompts , dim=1)

-    # padding rule
     audio_prompts = torch.cat(
         [torch.zeros_like(audio_prompts[:, :4]),
          audio_prompts,
-         torch.zeros_like(audio_prompts[:, :6])], dim=1)
-
+         torch.zeros_like(audio_prompts[:, :6])], 1)
     last_prompts = torch.cat(
         [torch.zeros_like(last_prompts[:, :24]),
          last_prompts,
-         torch.zeros_like(last_prompts[:, :26])], dim=1)
+         torch.zeros_like(last_prompts[:, :26])], 1)

-    total_tokens = audio_prompts.shape[1]
-    num_chunks = max(1, math.ceil(total_tokens / (2*step)))
+    num_chunks = max(1, math.ceil(audio_prompts.shape[1] / (2*step)))

     ref_list, audio_list, uncond_list, buckets = [], [], [], []
-
     for i in tqdm(range(num_chunks)):
         st = i * 2 * step
         cond = audio_prompts[:, st: st+10]
         if cond.shape[2] < 10:
             pad = torch.zeros_like(cond[:, :, :10-cond.shape[2]])
-            cond = torch.cat([cond, pad], dim=2)
+            cond = torch.cat([cond, pad], 2)

         bucket_clip = last_prompts[:, st: st+50]
         if bucket_clip.shape[2] < 50:
             pad = torch.zeros_like(bucket_clip[:, :, :50-bucket_clip.shape[2]])
-            bucket_clip = torch.cat([bucket_clip, pad], dim=2)
+            bucket_clip = torch.cat([bucket_clip, pad], 2)

         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16

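The new num_chunks line folds the old two-step computation into a single expression: the padded prompt sequence is consumed in strides of 2*step, and each chunk reads a fixed-width window that is zero-padded on the right when it runs past the end. A simplified sketch of that stride-and-pad pattern on a plain (batch, tokens, dim) tensor; the real code pads a different axis of its stacked hidden-state tensor, so this is illustrative only:

import math, torch

step = 4
seq = torch.randn(1, 25, 384)                  # illustrative (batch, tokens, dim) tensor
num_chunks = max(1, math.ceil(seq.shape[1] / (2 * step)))

for i in range(num_chunks):
    st = i * 2 * step
    cond = seq[:, st:st + 10]                  # 10-token window starting every 2*step tokens
    if cond.shape[1] < 10:                     # zero-pad the short final window on the right
        pad = cond.new_zeros(cond.shape[0], 10 - cond.shape[1], cond.shape[2])
        cond = torch.cat([cond, pad], dim=1)
    assert cond.shape[1] == 10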
@@ -132,12 +125,10 @@
         i2i_noise_strength=cfg.i2i_noise_strength,
     ).frames

-    return (video * 0.5 + 0.5).clamp(0, 1).unsqueeze(0).cpu()
+    return (video * .5 + .5).clamp(0,1).unsqueeze(0).cpu()


-# -------------------------------------------------------------------
-# Sonic ✨
-# -------------------------------------------------------------------
+# ------------------------------------------------------------ Sonic
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
     config = OmegaConf.load(config_file)
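The returned frames are rescaled from a [-1, 1] range back to [0, 1] before clamping; relative to the removed line only the numeric literals changed. A one-line check of that rescaling:

import torch

frames = torch.tensor([-1.0, -0.5, 0.0, 1.0])
print((frames * .5 + .5).clamp(0, 1))   # tensor([0.0000, 0.2500, 0.5000, 1.0000])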
@@ -147,45 +138,40 @@ class Sonic:
         cfg.use_interframe = enable_interpolate_frame
         self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id >= 0 else "cpu"

-        # weights root
-        ckpt_root = os.path.join(BASE_DIR, "checkpoints", "Sonic")
-        cfg.pretrained_model_name_or_path = ckpt_root  # diffusers layout
+        # diffusers base model lives here ⇣ (config.json included)
+        self.diffusers_root = os.path.join(BASE_DIR, "checkpoints", "stable-video-diffusion-img2vid-xt")
+        # extra pth/pt/safetensors
+        self.ckpt_root = os.path.join(BASE_DIR, "checkpoints", "Sonic")

-        self._load_models(cfg, ckpt_root)
+        self._load_models(cfg)
         print("Sonic init done")

-    # --------------------------------------------------------------
-    def _load_models(self, cfg, ckpt_root):
+    # --------------------------------------------- load all networks
+    def _load_models(self, cfg):
         dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]

-        # diffusers base weights
-        vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
-        sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
-        image_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
-        unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
+        vae = AutoencoderKLTemporalDecoder.from_pretrained(self.diffusers_root, subfolder="vae", variant="fp16")
+        sched = EulerDiscreteScheduler.from_pretrained(self.diffusers_root, subfolder="scheduler")
+        img_e = CLIPVisionModelWithProjection.from_pretrained(self.diffusers_root, subfolder="image_encoder", variant="fp16")
+        unet = UNetSpatioTemporalConditionModel.from_pretrained(self.diffusers_root, subfolder="unet", variant="fp16")
         add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

-        # ------------ extra checkpoints (.pth / .safetensors) ------------
-        def _try_load(module, keyword):
-            path = _find_ckpt(ckpt_root, keyword)
+        def _load_extra(module, key):
+            path = _find_ckpt(self.ckpt_root, key)
             if not path:
-                print(f"[WARN] {keyword} checkpoint not found → skip")
+                print(f"[WARN] extra ckpt for '{key}' not found → skip")
                 return
-            print(f"[INFO] load {keyword} ckpt → {os.path.relpath(path, BASE_DIR)}")
-            if path.endswith(".safetensors") and safe_load is not None:
-                state = safe_load(path, device="cpu")
-            else:
-                state = torch.load(path, map_location="cpu")
+            print(f"[INFO] load {key} → {os.path.relpath(path, BASE_DIR)}")
+            state = safe_load(path, device="cpu") if (safe_load and path.endswith(".safetensors")) else torch.load(path, map_location="cpu")
             module.load_state_dict(state, strict=False)

-        _try_load(unet, "unet")
-        # audio adapters (required)
         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
-        _try_load(a2t, "audio2token")
-        _try_load(a2b, "audio2bucket")

-        # whisper tiny
+        _load_extra(unet, "unet")
+        _load_extra(a2t, "audio2token")
+        _load_extra(a2b, "audio2bucket")
+
         whisper = WhisperModel.from_pretrained(
             os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
         ).to(self.device).eval()
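_load_extra collapses the old if/else into one expression: a .safetensors file is read with safe_load when the library imported successfully, anything else falls back to torch.load, and the resulting state dict is applied with strict=False. The same fallback as a standalone sketch (the checkpoint path in the last line is hypothetical):

import torch

try:
    from safetensors.torch import load_file as safe_load
except ImportError:
    safe_load = None

def load_state(path: str) -> dict:
    # prefer safetensors when both the library and the file extension allow it
    if safe_load is not None and path.endswith(".safetensors"):
        return safe_load(path, device="cpu")
    return torch.load(path, map_location="cpu")

# hypothetical checkpoint path; strict=False tolerates missing/unexpected keys
# module.load_state_dict(load_state("checkpoints/Sonic/unet.pth"), strict=False)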
@@ -199,16 +185,16 @@ class Sonic:
         self.rife = RIFEModel(device=self.device)
         self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

-        for m in (image_enc, vae, unet):
+        for m in (img_e, vae, unet):
             m.to(dtype)

-        self.pipe = SonicPipeline(unet=unet, image_encoder=image_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
-        self.image_encoder = image_enc
+        self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+        self.image_encoder = img_e
         self.audio2token = a2t
         self.audio2bucket = a2b
         self.whisper = whisper

-    # --------------------------------------------------------------
+    # --------------------------------------------- preprocess helpers
     def preprocess(self, image_path: str, expand_ratio: float = 1.0):
         img = cv2.imread(image_path)
         h, w = img.shape[:2]
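The weight-dtype table maps the cfg.weight_dtype string onto a torch dtype, and the image encoder, VAE and UNet are cast to it before the assembled pipeline is moved to the target device. A minimal sketch of that lookup-and-cast pattern with a placeholder module:

import torch

weight_dtype = "fp16"   # would come from cfg.weight_dtype
dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[weight_dtype]

module = torch.nn.Linear(4, 4)              # placeholder for image encoder / vae / unet
module.to(dtype)
print(next(module.parameters()).dtype)      # torch.float16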
@@ -219,20 +205,19 @@
                     "crop_bbox": process_bbox((x1, y1, x1+ww, y1+hh), expand_ratio, h, w)}
         return {"face_num": 0, "crop_bbox": None}

-    # --------------------------------------------------------------
+    # --------------------------------------------------------------- run
     @torch.no_grad()
     def process(self, image_path, audio_path, output_path,
-                min_resolution=512, inference_steps=25, dynamic_scale=1.0,
-                keep_resolution=False, seed=None):
+                min_resolution=512, inference_steps=25,
+                dynamic_scale=1.0, keep_resolution=False, seed=None):

         cfg = self.config
         if seed is not None:
             cfg.seed = seed
-        cfg.num_inference_steps = inference_steps
-        cfg.motion_bucket_scale = dynamic_scale
+        cfg.num_inference_steps = inference_steps
+        cfg.motion_bucket_scale = dynamic_scale
         seed_everything(cfg.seed)

-        # image · audio → tensor
         data = image_audio_to_tensor(
             self.face_det, self.feature_extractor,
             image_path, audio_path,
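process() keeps the pattern of copying its keyword arguments onto the shared OmegaConf config before seeding, so everything downstream reads only cfg. A small sketch of that override flow with a hand-built config whose keys mirror the ones used above:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"seed": 42, "num_inference_steps": 10, "motion_bucket_scale": 1.0})

def apply_overrides(cfg, inference_steps=25, dynamic_scale=1.0, seed=None):
    # mirror the assignments at the top of process()
    if seed is not None:
        cfg.seed = seed
    cfg.num_inference_steps = inference_steps
    cfg.motion_bucket_scale = dynamic_scale
    return cfg

print(apply_overrides(cfg, inference_steps=30, dynamic_scale=1.5, seed=7))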
@@ -252,7 +237,6 @@
                     self.audio2bucket, self.image_encoder,
                     w, h, data)

-        # inter-frame interpolation
         if cfg.use_interframe:
             out, frames = video.to(self.device), []
             for i in tqdm(range(out.shape[2]-1), ncols=0):
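For orientation, a minimal usage sketch of the class after this commit; the constructor arguments are inferred from the body shown above and the example paths are placeholders, so treat this as an assumption rather than documented API:

from sonic import Sonic

model = Sonic(device_id=0, enable_interpolate_frame=True)   # assumed signature
info = model.preprocess("examples/face.png")                # returns face_num and crop_bbox
if info["face_num"] > 0:
    model.process("examples/face.png", "examples/speech.wav", "output/result.mp4",
                  inference_steps=25, dynamic_scale=1.0, seed=42)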