openfree committed
Commit 430d42a · verified · 1 Parent(s): e10969c

Update sonic.py

Files changed (1)
  1. sonic.py +73 -109
sonic.py CHANGED
@@ -9,9 +9,7 @@ from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatur
 
 from src.utils.util import save_videos_grid, seed_everything
 from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
-from src.models.base.unet_spatio_temporal_condition import (
-    UNetSpatioTemporalConditionModel, add_ip_adapters,
-)
+from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
 from src.pipelines.pipeline_sonic import SonicPipeline
 from src.models.audio_adapter.audio_proj import AudioProjModel
 from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
@@ -22,13 +20,12 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 # ------------------------------------------------------------------
-# single image + speech video-tensor generator
+# single image + speech video-tensor generator
 # ------------------------------------------------------------------
-def test(
-    pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
-    width, height, batch,
-):
-    # --- match batch dimensions -------------------------------------------
+def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
+         width, height, batch):
+
+    # ---------- match batch dimensions -----------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
@@ -39,33 +36,29 @@ def test(
     image_embeds = image_encoder(clip_img).image_embeds
 
     audio_feature = batch["audio_feature"]  # (1,80,T)
-    audio_len = int(batch["audio_len"])     # Python int
-    step = int(config.step)
-
-    # --- clamp step (minimum 1) -------------------------------------------
-    if audio_len < step:
-        step = max(1, audio_len)
+    audio_len = int(batch["audio_len"])
+    step = max(1, int(cfg.step))            # guarantee at least 1
 
-    window = 16000  # 1 chunk
+    window = 16_000                         # 1-second chunk
     audio_prompts, last_prompts = [], []
 
-    # --- Whisper encoding per window --------------------------------------
+    # ---------- Whisper encoding ------------------------------------------
     for i in range(0, audio_feature.shape[-1], window):
-        chunk = audio_feature[:, :, i : i + window]
+        chunk = audio_feature[:, :, i:i+window]
 
-        prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-        last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)  # (1,t,1,384)
+        hs_all = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+        last_hid = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)  # (1,t,1,384)
 
-        audio_prompts.append(torch.stack(prompt_layers, dim=2))  # (1,L,12,80)
-        last_prompts.append(last_hidden)                         # (1,L,1,384)
+        audio_prompts.append(torch.stack(hs_all, dim=2))  # (1,t,12,384)
+        last_prompts.append(last_hid)                     # (1,t,1,384)
 
-    if len(audio_prompts) == 0:
+    if not audio_prompts:
         raise ValueError("[ERROR] No speech recognised in the provided audio.")
 
-    audio_prompts = torch.cat(audio_prompts, dim=1)
-    last_prompts = torch.cat(last_prompts, dim=1)
+    audio_prompts = torch.cat(audio_prompts, dim=1)  # (1,T,12,384)
+    last_prompts = torch.cat(last_prompts, dim=1)    # (1,T,1,384)
 
-    # padding rules
+    # ---------- padding rules ---------------------------------------------
     audio_prompts = torch.cat(
         [torch.zeros_like(audio_prompts[:, :4]),
          audio_prompts,
@@ -75,7 +68,6 @@ def test(
          last_prompts,
          torch.zeros_like(last_prompts[:, :26])], dim=1)
 
-    # --- always at least one chunk -----------------------------------------
     total_tokens = audio_prompts.shape[1]
     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
 
@@ -84,46 +76,46 @@ def test(
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
 
-        cond_clip = audio_prompts[:, start : start + 10]   # (1,10,12,80)
-        if cond_clip.shape[1] < 10:                        # pad if short
-            pad = torch.zeros_like(cond_clip[:, : 10 - cond_clip.shape[1]])
+        # --------- cond_clip : (1,10,12,384) → (1,10,5,384) ----------------
+        cond_clip = audio_prompts[:, start : start + 10]   # (1,≤10,12,384)
+        if cond_clip.shape[1] < 10:                        # pad seq_len
+            pad = torch.zeros_like(cond_clip[:, :10 - cond_clip.shape[1]])
             cond_clip = torch.cat([cond_clip, pad], dim=1)
+        cond_clip = cond_clip[:, :, :5, :]                 # keep 5 blocks
 
-        # ------------------ bucket_clip dimension fix ----------------------
-        bucket_clip = last_prompts[:, start : start + 50]  # (1,50,1,384)
-        if bucket_clip.shape[1] < 50:                      # pad if short
-            pad = torch.zeros_like(bucket_clip[:, : 50 - bucket_clip.shape[1]])
+        # --------- bucket_clip : (1,50,1,384) → unsqueeze(0) ----------------
+        bucket_clip = last_prompts[:, start : start + 50]  # (1,≤50,1,384)
+        if bucket_clip.shape[1] < 50:                      # pad length
+            pad = torch.zeros_like(bucket_clip[:, :50 - bucket_clip.shape[1]])
             bucket_clip = torch.cat([bucket_clip, pad], dim=1)
-
-        bucket_clip = bucket_clip.unsqueeze(1)             # → (1,1,50,1,384) ✔ 5-D
-        # -----------------------------------------------------------------
+        bucket_clip = bucket_clip.unsqueeze(0)             # (1,1,50,1,384)
 
         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
         ref_list.append(ref_img[0])
-        audio_list.append(audio_pe(cond_clip.unsqueeze(1)).squeeze(0)[0])   # (10,···) → 4-D after unsqueeze
-        uncond_list.append(audio_pe(torch.zeros_like(cond_clip).unsqueeze(1)).squeeze(0)[0])
+        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])                # (10,*)
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
         motion_buckets.append(motion[0])
 
-    # ----------------------------------------------------------------------
+    # ---------- diffusion -------------------------------------------------
     video = pipe(
         ref_img, clip_img, face_mask,
         audio_list, uncond_list, motion_buckets,
         height=height, width=width,
         num_frames=len(audio_list),
-        decode_chunk_size=config.decode_chunk_size,
-        motion_bucket_scale=config.motion_bucket_scale,
-        fps=config.fps,
-        noise_aug_strength=config.noise_aug_strength,
-        min_guidance_scale1=config.min_appearance_guidance_scale,
-        max_guidance_scale1=config.max_appearance_guidance_scale,
-        min_guidance_scale2=config.audio_guidance_scale,
-        max_guidance_scale2=config.audio_guidance_scale,
-        overlap=config.overlap,
-        shift_offset=config.shift_offset,
-        frames_per_batch=config.n_sample_frames,
-        num_inference_steps=config.num_inference_steps,
-        i2i_noise_strength=config.i2i_noise_strength,
+        decode_chunk_size=cfg.decode_chunk_size,
+        motion_bucket_scale=cfg.motion_bucket_scale,
+        fps=cfg.fps,
+        noise_aug_strength=cfg.noise_aug_strength,
+        min_guidance_scale1=cfg.min_appearance_guidance_scale,
+        max_guidance_scale1=cfg.max_appearance_guidance_scale,
+        min_guidance_scale2=cfg.audio_guidance_scale,
+        max_guidance_scale2=cfg.audio_guidance_scale,
+        overlap=cfg.overlap,
+        shift_offset=cfg.shift_offset,
+        frames_per_batch=cfg.n_sample_frames,
+        num_inference_steps=cfg.num_inference_steps,
+        i2i_noise_strength=cfg.i2i_noise_strength,
     ).frames
 
     video = (video * 0.5 + 0.5).clamp(0, 1)
@@ -131,16 +123,16 @@ def test(
 
 
 # ------------------------------------------------------------------
-# Sonic class
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
     config = OmegaConf.load(config_file)
 
     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-        cfg = self.config
+        cfg = self.config
         cfg.use_interframe = enable_interpolate_frame
-        self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
+        self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
         cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
 
         self._load_models(cfg)
@@ -159,9 +151,9 @@ class Sonic:
         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
-        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
+        unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
         whisper.requires_grad_(False)
@@ -188,78 +180,50 @@ class Sonic:
         _, _, bboxes = self.face_det(img, maxface=True)
         if bboxes:
             x1, y1, ww, hh = bboxes[0]
-            return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
+            return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1+ww, y1+hh), expand_ratio, h, w)}
         return {"face_num": 0, "crop_bbox": None}
 
     # --------------------------------------------------------------
    @torch.no_grad()
-    def process(
-        self,
-        image_path: str,
-        audio_path: str,
-        output_path: str,
-        min_resolution: int = 512,
-        inference_steps: int = 25,
-        dynamic_scale: float = 1.0,
-        keep_resolution: bool = False,
-        seed: int | None = None,
-    ):
+    def process(self, image_path: str, audio_path: str, output_path: str,
+                min_resolution: int = 512, inference_steps: int = 25,
+                dynamic_scale: float = 1.0, keep_resolution: bool = False,
+                seed: int | None = None):
+
         cfg = self.config
         if seed is not None:
             cfg.seed = seed
-        cfg.num_inference_steps = inference_steps
-        cfg.motion_bucket_scale = dynamic_scale
+        cfg.num_inference_steps = inference_steps
+        cfg.motion_bucket_scale = dynamic_scale
         seed_everything(cfg.seed)
 
-        # image & audio → tensor
         test_data = image_audio_to_tensor(
-            self.face_det,
-            self.feature_extractor,
-            image_path,
-            audio_path,
-            limit=-1,
-            image_size=min_resolution,
-            area=cfg.area,
+            self.face_det, self.feature_extractor,
+            image_path, audio_path, limit=-1,
+            image_size=min_resolution, area=cfg.area,
         )
         if test_data is None:
            return -1
 
         h, w = test_data["ref_img"].shape[-2:]
-        resolution = (
-            f"{(Image.open(image_path).width // 2) * 2}x{(Image.open(image_path).height // 2) * 2}"
-            if keep_resolution
-            else f"{w}x{h}"
-        )
+        resolution = (f"{(Image.open(image_path).width//2)*2}x{(Image.open(image_path).height//2)*2}"
+                      if keep_resolution else f"{w}x{h}")
 
-        # generate video frames
-        video = test(
-            self.pipe,
-            cfg,
-            wav_enc=self.whisper,
-            audio_pe=self.audio2token,
-            audio2bucket=self.audio2bucket,
-            image_encoder=self.image_encoder,
-            width=w,
-            height=h,
-            batch=test_data,
-        )
+        video = test(self.pipe, cfg, self.whisper, self.audio2token,
+                     self.audio2bucket, self.image_encoder, w, h, test_data)
 
-        # interpolate intermediate frames
         if cfg.use_interframe:
             out = video.to(self.device)
             frames = []
-            for i in tqdm(range(out.shape[2] - 1), ncols=0):
-                mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
-                frames.extend([out[:, :, i], mid])
-            frames.append(out[:, :, -1])
+            for i in tqdm(range(out.shape[2]-1), ncols=0):
+                mid = self.rife.inference(out[:,:,i], out[:,:,i+1]).clamp(0,1).detach()
+                frames.extend([out[:,:,i], mid])
+            frames.append(out[:,:,-1])
             video = torch.stack(frames, 2).cpu()
 
-        # save
         tmp = output_path.replace(".mp4", "_noaudio.mp4")
-        save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
-        os.system(
-            f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
-            f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
-        )
+        save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps*(2 if cfg.use_interframe else 1))
+        os.system(f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
                  f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error")
         os.remove(tmp)
         return 0
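
For reference, the main user-facing surface touched by this commit is the reworked `Sonic.process()` signature. The sketch below shows how the updated entry point would typically be driven; it is not part of the commit. The `from sonic import Sonic` import path and the file paths are hypothetical examples, and it assumes the checkpoints referenced by `config/inference/sonic.yaml` are already downloaded.

# Hypothetical driver sketch for the updated Sonic.process() API.
# Paths are placeholders; checkpoints must already be in place under
# checkpoints/ as config/inference/sonic.yaml expects.
from sonic import Sonic

pipeline = Sonic(device_id=0, enable_interpolate_frame=True)

ret = pipeline.process(
    image_path="examples/face.png",    # single reference portrait (placeholder)
    audio_path="examples/speech.wav",  # driving speech audio (placeholder)
    output_path="outputs/result.mp4",
    min_resolution=512,                # short-side size used during preprocessing
    inference_steps=25,                # copied into cfg.num_inference_steps
    dynamic_scale=1.0,                 # copied into cfg.motion_bucket_scale
    keep_resolution=False,             # False: render at the preprocessed WxH
    seed=42,
)
if ret == -1:
    print("Preprocessing produced no usable data (e.g. no face detected); nothing was rendered.")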