openfree committed on
Commit 85ad908 · verified · 1 Parent(s): 1fb410d

Update sonic.py
Files changed (1):
  1. sonic.py +91 -142
sonic.py CHANGED
@@ -1,7 +1,5 @@
- import os
- import math                      # [★ fix] for ceil calculation
  import torch
- import torch.utils.checkpoint
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
@@ -26,109 +24,89 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))


  # ------------------------------------------------------------------
- # test() : single face image + a sequence of audio-frame tensors
  # ------------------------------------------------------------------
  def test(
-     pipe,
-     config,
-     wav_enc,
-     audio_pe,
-     audio2bucket,
-     image_encoder,
-     width,
-     height,
-     batch,
  ):
-     # (B,C,H,W) -> (1,B,C,H,W)
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
              batch[k] = v.unsqueeze(0).to(pipe.device).float()

-     ref_img = batch["ref_img"]
      clip_img = batch["clip_images"]
      face_mask = batch["face_mask"]
      image_embeds = image_encoder(clip_img).image_embeds

-     audio_feature = batch["audio_feature"]  # (C,T)
-     audio_len = batch["audio_len"]          # number of whisper tokens
      step = int(config.step)

-     # ----------------------------- [★ fix] -------------------------------
-     # ① window = 16000 → 1 second of audio for whisper-tiny
-     # ② if audio_len < step, shrink step to avoid an empty list
-     # ---------------------------------------------------------------------
-     window = 16000
      if audio_len < step:
          step = max(1, audio_len)

-     # ── Whisper-encode the audio in 1-second slices
-     audio_prompts, last_audio_prompts = [], []
      for i in range(0, audio_feature.shape[-1], window):
-         chunk = audio_feature[:, :, i : i + window]  # (B,C,window)

-         # whisper encoder
-         prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-         last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)

          audio_prompts.append(torch.stack(prompt_layers, dim=2))
-         last_audio_prompts.append(last_hidden)

-     # ── exception: abort if nothing was produced
      if len(audio_prompts) == 0:
-         raise ValueError(
-             "[ERROR] No speech recognized from the audio. "
-             "Please provide a valid speech recording."
-         )

-     # reconstruct the Whisper token sequence (+ model padding rule)
-     audio_prompts = torch.cat(audio_prompts, dim=1)[:, : audio_len * 2]
-     audio_prompts = torch.cat(
-         [torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])],
-         dim=1,
-     )
-
-     last_audio_prompts = torch.cat(last_audio_prompts, dim=1)[:, : audio_len * 2]
-     last_audio_prompts = torch.cat(
-         [torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])],
-         dim=1,
-     )
-
-     # ---------------------------------------------------------------------
-     # total number of chunks (ceil), reflecting the adjusted step
-     # ---------------------------------------------------------------------
-     num_chunks = math.ceil(audio_len / step)
-
-     ref_tensor_list, audio_tensor_list, uncond_audio_tensor_list, motion_buckets = [], [], [], []
      for i in tqdm(range(num_chunks)):
          start = i * 2 * step
-         audio_clip = audio_prompts[:, start : start + 10].unsqueeze(0)
-         audio_clip_for_bucket = last_audio_prompts[:, start : start + 50].unsqueeze(0)

-         motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds) * 16 + 16
-         motion_buckets.append(motion_bucket[0])

-         cond_audio = audio_pe(audio_clip).squeeze(0)
-         uncond_audio = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)

-         ref_tensor_list.append(ref_img[0])
-         audio_tensor_list.append(cond_audio[0])
-         uncond_audio_tensor_list.append(uncond_audio[0])

-     # guard against an empty list
-     if len(audio_tensor_list) == 0:
-         raise ValueError("[ERROR] Audio too short for the configured 'step' size; no frames produced.")

-     # ---------------------------------------------------------------------
      video = pipe(
-         ref_img,
-         clip_img,
-         face_mask,
-         audio_tensor_list,
-         uncond_audio_tensor_list,
-         motion_buckets,
-         height=height,
-         width=width,
-         num_frames=len(audio_tensor_list),
          decode_chunk_size=config.decode_chunk_size,
          motion_bucket_scale=config.motion_bucket_scale,
          fps=config.fps,
@@ -143,81 +121,60 @@ def test(
          num_inference_steps=config.num_inference_steps,
          i2i_noise_strength=config.i2i_noise_strength,
      ).frames
-     # ---------------------------------------------------------------------

      video = (video * 0.5 + 0.5).clamp(0, 1)
      return video.to(pipe.device).unsqueeze(0).cpu()


  # ------------------------------------------------------------------
- # Sonic class
  # ------------------------------------------------------------------
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)

      def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-         cfg = self.config
-         cfg.use_interframe = enable_interpolate_frame
-         self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
          cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)

-         # ───────────── load models
          self._load_models(cfg)
          print("Sonic init done")

-     # --------------------------------------------------------------
-     # model / pipeline loader
      # --------------------------------------------------------------
      def _load_models(self, cfg):
-         dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
-         weight_dtype = dtype_map.get(cfg.weight_dtype, torch.float32)

-         # backbone
-         vae = AutoencoderKLTemporalDecoder.from_pretrained(
-             cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16"
-         )
-         scheduler = EulerDiscreteScheduler.from_pretrained(
-             cfg.pretrained_model_name_or_path, subfolder="scheduler"
-         )
-         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-             cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16"
-         )
-         unet = UNetSpatioTemporalConditionModel.from_pretrained(
-             cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16"
-         )
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

-         # audio adapters
-         audio2token = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
-         audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

-         # checkpoints
          unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-         audio2token.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-         audio2bucket.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))

-         # whisper
          whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
          whisper.requires_grad_(False)

-         # extras
          self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
          self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
          if cfg.use_interframe:
              self.rife = RIFEModel(device=self.device)
              self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

-         # dtype
-         for m in (image_encoder, vae, unet):
-             m.to(weight_dtype)

-         # pipeline
-         pipe = SonicPipeline(unet=unet, image_encoder=image_encoder, vae=vae, scheduler=scheduler)
-         self.pipe = pipe.to(device=self.device, dtype=weight_dtype)
-         self.audio2token = audio2token
-         self.audio2bucket = audio2bucket
-         self.image_encoder = image_encoder
          self.whisper = whisper

      # --------------------------------------------------------------
@@ -227,9 +184,7 @@ class Sonic:
          _, _, bboxes = self.face_det(img, maxface=True)
          if bboxes:
              x1, y1, ww, hh = bboxes[0]
-             bbox = (x1, y1, x1 + ww, y1 + hh)
-             crop_bbox = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)
-             return {"face_num": len(bboxes), "crop_bbox": crop_bbox}
          return {"face_num": 0, "crop_bbox": None}

      # --------------------------------------------------------------
@@ -248,19 +203,17 @@ class Sonic:
          cfg = self.config
          if seed is not None:
              cfg.seed = seed
-         cfg.num_inference_steps = inference_steps
-         cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)

-         # ----------------------------------------------------------
-         # image · audio → tensor
-         # ----------------------------------------------------------
          test_data = image_audio_to_tensor(
              self.face_det,
              self.feature_extractor,
              image_path,
              audio_path,
-             limit=-1,  # use the full audio
              image_size=min_resolution,
              area=cfg.area,
          )
@@ -269,14 +222,12 @@ class Sonic:

          h, w = test_data["ref_img"].shape[-2:]
          resolution = (
-             f"{(Image.open(image_path).width // 2)*2}x{(Image.open(image_path).height // 2)*2}"
              if keep_resolution
              else f"{w}x{h}"
          )

-         # ----------------------------------------------------------
-         # generate frames
-         # ----------------------------------------------------------
          video = test(
              self.pipe,
              cfg,
@@ -291,22 +242,20 @@ class Sonic:

          # intermediate-frame interpolation
          if cfg.use_interframe:
-             out, results = video.to(self.device), []
              for i in tqdm(range(out.shape[2] - 1), ncols=0):
-                 I1, I2 = out[:, :, i], out[:, :, i + 1]
-                 middle = self.rife.inference(I1, I2).clamp(0, 1).detach()
-                 results.extend([out[:, :, i], middle])
-             results.append(out[:, :, -1])
-             video = torch.stack(results, 2).cpu()
-
-         # ----------------------------------------------------------
-         # save files
-         # ----------------------------------------------------------
-         tmp_video = output_path.replace(".mp4", "_noaudio.mp4")
-         save_videos_grid(video, tmp_video, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
          os.system(
-             f"ffmpeg -i '{tmp_video}' -i '{audio_path}' -s {resolution} "
              f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
          )
-         os.remove(tmp_video)
          return 0
 
+ import os, math
  import torch
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm


  # ------------------------------------------------------------------
+ # single image + speech video-tensor generator
  # ------------------------------------------------------------------
  def test(
+     pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
+     width, height, batch,
  ):
+     # --- match batch dimensions --------------------------------------------
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
              batch[k] = v.unsqueeze(0).to(pipe.device).float()

+     ref_img = batch["ref_img"]               # (1,C,H,W)
      clip_img = batch["clip_images"]
      face_mask = batch["face_mask"]
      image_embeds = image_encoder(clip_img).image_embeds

+     audio_feature = batch["audio_feature"]   # (1,80,T)
+     audio_len = int(batch["audio_len"])      # Python int
      step = int(config.step)

+     # --- [★ fix] clamp step (minimum 1) -------------------------------------
      if audio_len < step:
          step = max(1, audio_len)

+     window = 16000                           # 1-second window
+     audio_prompts, last_prompts = [], []
+
+     # --- Whisper encoding per window -----------------------------------------
      for i in range(0, audio_feature.shape[-1], window):
+         chunk = audio_feature[:, :, i : i + window]

+         prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+         last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)

          audio_prompts.append(torch.stack(prompt_layers, dim=2))
+         last_prompts.append(last_hidden)

      if len(audio_prompts) == 0:
+         raise ValueError("[ERROR] No speech recognised in the provided audio.")
+
+     audio_prompts = torch.cat(audio_prompts, dim=1)
+     last_prompts = torch.cat(last_prompts, dim=1)
+
+     # padding rule
+     audio_prompts = torch.cat(
+         [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
+          torch.zeros_like(audio_prompts[:, :6])], dim=1)
+     last_prompts = torch.cat(
+         [torch.zeros_like(last_prompts[:, :24]), last_prompts,
+          torch.zeros_like(last_prompts[:, :26])], dim=1)
+
+     # --- [★ fix] always ≥ 1 chunk --------------------------------------------
+     total_tokens = audio_prompts.shape[1]
+     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
+
+     ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []

      for i in tqdm(range(num_chunks)):
          start = i * 2 * step

+         cond_clip = audio_prompts[:, start : start + 10]
+         if cond_clip.shape[2] < 10:          # [★ fix] pad
+             pad = torch.zeros_like(cond_clip[:, :, : 10 - cond_clip.shape[2]])
+             cond_clip = torch.cat([cond_clip, pad], dim=2)

+         bucket_clip = last_prompts[:, start : start + 50]
+         if bucket_clip.shape[2] < 50:        # [★ fix] pad
+             pad = torch.zeros_like(bucket_clip[:, :, : 50 - bucket_clip.shape[2]])
+             bucket_clip = torch.cat([bucket_clip, pad], dim=2)

+         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16

+         ref_list.append(ref_img[0])
+         audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
+         uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
+         motion_buckets.append(motion[0])

+     # -----------------------------------------------------------------------
      video = pipe(
+         ref_img, clip_img, face_mask,
+         audio_list, uncond_list, motion_buckets,
+         height=height, width=width,
+         num_frames=len(audio_list),
          decode_chunk_size=config.decode_chunk_size,
          motion_bucket_scale=config.motion_bucket_scale,
          fps=config.fps,

          num_inference_steps=config.num_inference_steps,
          i2i_noise_strength=config.i2i_noise_strength,
      ).frames

      video = (video * 0.5 + 0.5).clamp(0, 1)
      return video.to(pipe.device).unsqueeze(0).cpu()
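
The chunk count and tail padding introduced above can be sanity-checked in isolation. The snippet below is a minimal, illustrative sketch with dummy tensors and hypothetical sizes (and it pads along the token axis for simplicity); it is not part of sonic.py.

import math
import torch

# Dummy stand-in for the padded Whisper features: (batch, tokens, dim).
total_tokens, step = 37, 4
audio_prompts = torch.zeros(1, total_tokens, 384)

# Same guard as in the commit: even a very short clip yields at least one chunk.
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))

clips = []
for i in range(num_chunks):
    start = i * 2 * step
    clip = audio_prompts[:, start : start + 10]       # the last slice may be shorter
    if clip.shape[1] < 10:                            # pad to a fixed length
        pad = torch.zeros_like(clip[:, : 10 - clip.shape[1]])
        clip = torch.cat([clip, pad], dim=1)
    clips.append(clip)

print(num_chunks, clips[-1].shape)                    # 5 torch.Size([1, 10, 384])
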


  # ------------------------------------------------------------------
+ # Sonic class
  # ------------------------------------------------------------------
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)

      def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
+         cfg = self.config
+         cfg.use_interframe = enable_interpolate_frame
+         self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
          cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)

          self._load_models(cfg)
          print("Sonic init done")

      # --------------------------------------------------------------
      def _load_models(self, cfg):
+         dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]

+         vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
+         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
+         image_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+         unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

+         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
+         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

          unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))

          whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
          whisper.requires_grad_(False)

          self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
          self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
          if cfg.use_interframe:
              self.rife = RIFEModel(device=self.device)
              self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

+         for m in (image_enc, vae, unet):
+             m.to(dtype)

+         self.pipe = SonicPipeline(unet=unet, image_encoder=image_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+         self.image_encoder = image_enc
+         self.audio2token = a2t
+         self.audio2bucket = a2b
          self.whisper = whisper

      # --------------------------------------------------------------

          _, _, bboxes = self.face_det(img, maxface=True)
          if bboxes:
              x1, y1, ww, hh = bboxes[0]
+             return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
          return {"face_num": 0, "crop_bbox": None}

      # --------------------------------------------------------------

          cfg = self.config
          if seed is not None:
              cfg.seed = seed
+         cfg.num_inference_steps = inference_steps
+         cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)

+         # image · audio → tensor
          test_data = image_audio_to_tensor(
              self.face_det,
              self.feature_extractor,
              image_path,
              audio_path,
+             limit=-1,
              image_size=min_resolution,
              area=cfg.area,
          )

          h, w = test_data["ref_img"].shape[-2:]
          resolution = (
+             f"{(Image.open(image_path).width // 2) * 2}x{(Image.open(image_path).height // 2) * 2}"
              if keep_resolution
              else f"{w}x{h}"
          )

+         # generate video frames
          video = test(
              self.pipe,
              cfg,

          # intermediate-frame interpolation
          if cfg.use_interframe:
+             out = video.to(self.device)
+             frames = []
              for i in tqdm(range(out.shape[2] - 1), ncols=0):
+                 mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
+                 frames.extend([out[:, :, i], mid])
+             frames.append(out[:, :, -1])
+             video = torch.stack(frames, 2).cpu()
+
+         # save
+         tmp = output_path.replace(".mp4", "_noaudio.mp4")
+         save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
          os.system(
+             f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
              f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
          )
+         os.remove(tmp)
          return 0
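
For context, a minimal driver for the updated class could look like the sketch below. The driver method name and exact signature are assumptions inferred from the parameters visible in this diff (the definition lines fall outside the shown hunks), so they may differ from the actual sonic.py API.

# Hypothetical usage sketch -- method name and signature assumed, not shown in this diff.
from sonic import Sonic                  # assumes sonic.py is importable from the working directory

pipeline = Sonic(device_id=0, enable_interpolate_frame=True)

pipeline.process(                        # assumed driver method
    image_path="examples/face.png",      # hypothetical input paths
    audio_path="examples/speech.wav",
    output_path="outputs/result.mp4",
    min_resolution=512,
    inference_steps=25,
    dynamic_scale=1.0,
    keep_resolution=False,
    seed=42,
)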