openfree committed on
Commit 79f5781 · verified · 1 Parent(s): cca593e

Update sonic.py

Files changed (1)
  1. sonic.py +53 -22
sonic.py CHANGED
@@ -33,9 +33,11 @@ def test(
     height,
     batch
 ):
+    # Put batch tensors into (1, B, C, H, W) form on the pipeline device
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
+
     ref_img = batch['ref_img']
     clip_img = batch['clip_images']
     face_mask = batch['face_mask']
@@ -45,11 +47,11 @@ def test(
     audio_len = batch['audio_len']
     step = int(config.step)

-    # The value that used to be window=3000 is raised here so up to 60 seconds can be processed
-    # whisper-tiny samples at 16 kHz by default, so steps of 16,000 cut roughly one second at a time
-    window = 16000  # (process chunks in 1-second units)
+    # window=3000 -> 16000 (1-second chunks)
+    window = 16000
     audio_prompts = []
     last_audio_prompts = []
+
     for i in range(0, audio_feature.shape[-1], window):
         audio_clip_chunk = audio_feature[:, :, i:i+window]
         # Whisper encoder
@@ -61,30 +63,38 @@ def test(
         audio_prompts.append(audio_prompt)
         last_audio_prompts.append(last_audio_prompt)

-    # ---------------------- [added exception handling] ----------------------
+    # Raise here if no audio chunks were collected
     if len(audio_prompts) == 0:
         raise ValueError(
             "[ERROR] No speech recognized from the audio. "
             "Please provide a valid speech audio (with clear voice)."
         )
-    # -------------------------------------------------------------

     audio_prompts = torch.cat(audio_prompts, dim=1)
-    # the audio_len*2 slice is the padding handling required by the model's internal logic
     audio_prompts = audio_prompts[:, :audio_len*2]
-    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])], 1)
+    audio_prompts = torch.cat([
+        torch.zeros_like(audio_prompts[:, :4]),
+        audio_prompts,
+        torch.zeros_like(audio_prompts[:, :6])
+    ], dim=1)

     last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
     last_audio_prompts = last_audio_prompts[:, :audio_len*2]
-    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])], 1)
+    last_audio_prompts = torch.cat([
+        torch.zeros_like(last_audio_prompts[:, :24]),
+        last_audio_prompts,
+        torch.zeros_like(last_audio_prompts[:, :26])
+    ], dim=1)

     ref_tensor_list = []
     audio_tensor_list = []
     uncond_audio_tensor_list = []
     motion_buckets = []
+
     for i in tqdm(range(audio_len // step)):
         audio_clip = audio_prompts[:, i*2*step : i*2*step + 10].unsqueeze(0)
         audio_clip_for_bucket = last_audio_prompts[:, i*2*step : i*2*step + 50].unsqueeze(0)
+
         motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
         motion_bucket = motion_bucket * 16 + 16
         motion_buckets.append(motion_bucket[0])
@@ -138,29 +148,33 @@ class Sonic():
         config = self.config
         config.use_interframe = enable_interpolate_frame

-        device = 'cuda:{}'.format(device_id) if device_id > -1 else 'cpu'
-
+        device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
         config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

+        # VAE
         vae = AutoencoderKLTemporalDecoder.from_pretrained(
             config.pretrained_model_name_or_path,
             subfolder="vae",
             variant="fp16")

+        # Scheduler
         val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
             config.pretrained_model_name_or_path,
             subfolder="scheduler")

+        # CLIP Vision
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             config.pretrained_model_name_or_path,
             subfolder="image_encoder",
             variant="fp16")

+        # UNet
         unet = UNetSpatioTemporalConditionModel.from_pretrained(
             config.pretrained_model_name_or_path,
             subfolder="unet",
             variant="fp16")

+        # Adapter
         add_ip_adapters(unet, [32], [config.ip_audio_scale])

         audio2token = AudioProjModel(
@@ -174,6 +188,7 @@ class Sonic():
             context_tokens=2
         ).to(device)

+        # Load local checkpoints
         unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
         audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
         audio2bucket_checkpoint_path = os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path)
@@ -193,6 +208,7 @@ class Sonic():
             strict=True,
         )

+        # Set weight_dtype
         if config.weight_dtype == "fp16":
             weight_dtype = torch.float16
         elif config.weight_dtype == "fp32":
@@ -200,26 +216,34 @@ class Sonic():
         elif config.weight_dtype == "bf16":
             weight_dtype = torch.bfloat16
         else:
-            raise ValueError(
-                f"Do not support weight dtype: {config.weight_dtype}"
-            )
+            raise ValueError(f"Do not support weight dtype: {config.weight_dtype}")

-        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
+        # Whisper
+        whisper = WhisperModel.from_pretrained(
+            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
+        ).to(device).eval()
         whisper.requires_grad_(False)

-        self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
+        )

+        # Face detect
         det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
         self.face_det = AlignImage(device, det_path=det_path)
+
+        # RIFE intermediate-frame interpolation
         if config.use_interframe:
             rife = RIFEModel(device=device)
             rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
             self.rife = rife

+        # Cast to weight_dtype
         image_encoder.to(weight_dtype)
         vae.to(weight_dtype)
         unet.to(weight_dtype)

+        # Initialize SonicPipeline
         pipe = SonicPipeline(
             unet=unet,
             image_encoder=image_encoder,
@@ -237,13 +261,13 @@ class Sonic():

         print('Sonic init done')

-
     def preprocess(self, image_path, expand_ratio=1.0):
         face_image = cv2.imread(image_path)
         h, w = face_image.shape[:2]
         _, _, bboxes = self.face_det(face_image, maxface=True)
         face_num = len(bboxes)
         bbox_s = None
+
         if face_num > 0:
             x1, y1, ww, hh = bboxes[0]
             x2, y2 = x1 + ww, y1 + hh
@@ -270,7 +294,7 @@ class Sonic():
             dynamic_scale=1.0,
             keep_resolution=False,
             seed=None):
-
+
         config = self.config
         device = self.device
         pipe = self.pipe
@@ -279,6 +303,7 @@ class Sonic():
         audio2bucket = self.audio2bucket
         image_encoder = self.image_encoder

+        # Set the seed
         if seed:
             config.seed = seed
         config.num_inference_steps = inference_steps
@@ -288,17 +313,16 @@ class Sonic():
         video_path = output_path.replace('.mp4', '_noaudio.mp4')
         audio_video_path = output_path

-        # use the whole audio instead of limit=config.frame_num
+        # audio + image -> tensor
         test_data = image_audio_to_tensor(
             self.face_det,
             self.feature_extractor,
             image_path,
             audio_path,
-            limit=-1,  # -1 removes the frame cap
+            limit=-1,  # use the full audio
             image_size=min_resolution,
             area=config.area
         )
-
         if test_data is None:
             return -1

@@ -310,6 +334,7 @@ class Sonic():
         else:
             resolution = f'{width}x{height}'

+        # Call test(...) here
         video = test(
             pipe,
             config,
@@ -322,7 +347,7 @@ class Sonic():
             batch=test_data,
         )

-        # when using intermediate-frame interpolation
+        # Intermediate-frame interpolation
         if config.use_interframe:
             rife = self.rife
             out = video.to(device)
@@ -337,6 +362,12 @@ class Sonic():
                 results.append(out[:, :, video_len - 1])
             video = torch.stack(results, 2).cpu()

+        # Save the video
         save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * (2 if config.use_interframe else 1))
-        os.system(f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
+
+        # Mux in the audio for the final mp4
+        os.system(
+            f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} "
+            f"-vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'"
+        )
         return 0
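For reference, a minimal, self-contained sketch of the chunk-and-guard pattern the updated test() loop relies on: encode the audio features in fixed windows, then fail loudly if nothing was collected. The encoder below is a dummy stand-in rather than the real WhisperModel call, and window=16000 simply follows the comment in the diff (roughly one second at 16 kHz).

import torch

def dummy_whisper_encoder(chunk: torch.Tensor) -> torch.Tensor:
    # Stand-in for the Whisper encoder: returns [batch, tokens, dim];
    # two tokens per chunk loosely mirrors the audio_len*2 indexing above.
    return torch.zeros(chunk.shape[0], 2, 4)

def chunk_audio_prompts(audio_feature: torch.Tensor, window: int = 16000) -> torch.Tensor:
    prompts = []
    for i in range(0, audio_feature.shape[-1], window):
        prompts.append(dummy_whisper_encoder(audio_feature[:, :, i:i + window]))
    if len(prompts) == 0:
        # Same guard the commit adds: empty input raises instead of crashing later.
        raise ValueError("[ERROR] No speech recognized from the audio.")
    return torch.cat(prompts, dim=1)

feat = torch.randn(1, 80, 16000 * 3)      # roughly three one-second windows
print(chunk_audio_prompts(feat).shape)    # torch.Size([1, 6, 4])

With three seconds of input the loop yields three chunks, and the guard only fires when the feature tensor is empty along its last dimension.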
 
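The last hunk keeps the audio mux as a single os.system string. A rough equivalent with subprocess.run is sketched below, using the same ffmpeg flags the commit uses; the paths and resolution value are placeholders.

import os
import subprocess

def mux_audio(video_path: str, audio_path: str, out_path: str, resolution: str) -> None:
    # Hedged sketch only: same flags as the commit's ffmpeg call, run without a shell.
    subprocess.run(
        [
            "ffmpeg", "-i", video_path, "-i", audio_path,
            "-s", resolution,
            "-vcodec", "libx264", "-acodec", "aac",
            "-crf", "18", "-shortest", "-y", out_path,
        ],
        check=True,  # raise CalledProcessError instead of failing silently
    )
    os.remove(video_path)  # mirrors the trailing `rm '{video_path}'`

check=True makes a failed ffmpeg run raise instead of leaving a half-written file behind, which os.system on its own does not do.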
 
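Taken together, the changes let the pipeline consume the full audio track (limit=-1) and reject silent input early. A hypothetical end-to-end call is sketched below, assuming the public entry point is Sonic.process(...) as in the upstream Sonic demo; the method name, import path, example file paths, and constructor defaults are assumptions and are not part of this diff.

from sonic import Sonic  # assumed import path

pipeline = Sonic(device_id=0, enable_interpolate_frame=True)  # assumed constructor args
ret = pipeline.process(
    "examples/face.png",    # placeholder portrait image
    "examples/speech.wav",  # placeholder speech audio (clear voice expected)
    "outputs/result.mp4",
    seed=42,
)
print("done" if ret == 0 else "preprocessing failed (no face detected)")

The return codes follow the diff: 0 on success and -1 when image_audio_to_tensor returns None.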