openfree committed (verified)
Commit 6ee08fc · Parent: 1fc29a2

Update sonic.py

Files changed (1): sonic.py (+58, −53)
sonic.py CHANGED
@@ -20,7 +20,6 @@ from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
 from src.utils.RIFE.RIFE_HDv3 import RIFEModel
 from src.dataset.face_align.align import AlignImage

-
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))

 def test(
@@ -40,43 +39,44 @@ def test(
     ref_img = batch['ref_img']
     clip_img = batch['clip_images']
     face_mask = batch['face_mask']
-    image_embeds = image_encoder(
-        clip_img
-    ).image_embeds
+    image_embeds = image_encoder(clip_img).image_embeds

     audio_feature = batch['audio_feature']
     audio_len = batch['audio_len']
     step = int(config.step)

-    window = 3000
+    # Raise the previous window of 3000 so that up to 60 seconds of audio can be handled.
+    # whisper-tiny works on 16 kHz input, so steps of 16,000 correspond to roughly one second.
+    window = 16000  # process the audio in ~1-second chunks
     audio_prompts = []
     last_audio_prompts = []
     for i in range(0, audio_feature.shape[-1], window):
-        audio_prompt = wav_enc.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
-        last_audio_prompt = wav_enc.encoder(audio_feature[:,:,i:i+window]).last_hidden_state
+        audio_clip_chunk = audio_feature[:, :, i:i+window]
+        # Whisper encoder
+        audio_prompt = wav_enc.encoder(audio_clip_chunk, output_hidden_states=True).hidden_states
+        last_audio_prompt = wav_enc.encoder(audio_clip_chunk).last_hidden_state
         last_audio_prompt = last_audio_prompt.unsqueeze(-2)
+
         audio_prompt = torch.stack(audio_prompt, dim=2)
         audio_prompts.append(audio_prompt)
         last_audio_prompts.append(last_audio_prompt)

     audio_prompts = torch.cat(audio_prompts, dim=1)
-    audio_prompts = audio_prompts[:,:audio_len*2]
-    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:,:4]), audio_prompts, torch.zeros_like(audio_prompts[:,:6])], 1)
+    # The audio_len*2 truncation and the zero padding follow the model's internal windowing logic.
+    audio_prompts = audio_prompts[:, :audio_len*2]
+    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])], 1)

     last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
-    last_audio_prompts = last_audio_prompts[:,:audio_len*2]
-    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:,:24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:,:26])], 1)
-
+    last_audio_prompts = last_audio_prompts[:, :audio_len*2]
+    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])], 1)

     ref_tensor_list = []
     audio_tensor_list = []
     uncond_audio_tensor_list = []
     motion_buckets = []
-    for i in tqdm(range(audio_len//step)):
-
-
-        audio_clip = audio_prompts[:,i*2*step:i*2*step+10].unsqueeze(0)
-        audio_clip_for_bucket = last_audio_prompts[:,i*2*step:i*2*step+50].unsqueeze(0)
+    for i in tqdm(range(audio_len // step)):
+        audio_clip = audio_prompts[:, i*2*step : i*2*step + 10].unsqueeze(0)
+        audio_clip_for_bucket = last_audio_prompts[:, i*2*step : i*2*step + 50].unsqueeze(0)
         motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
         motion_bucket = motion_bucket * 16 + 16
         motion_buckets.append(motion_bucket[0])
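For reference, a minimal standalone sketch of the chunked Whisper encoding used in the loop above. The public openai/whisper-tiny checkpoint and the synthetic 5-second waveform are assumptions for the sketch (sonic.py loads a local whisper-tiny copy and its own feature pipeline); the feature extractor pads each clip to 3000 log-mel frames, so the chunk window here stays at 3000 on the mel axis.

import torch
from transformers import AutoFeatureExtractor, WhisperModel

# Assumption: the public whisper-tiny checkpoint; sonic.py uses a local copy instead.
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
wav_enc = WhisperModel.from_pretrained("openai/whisper-tiny").eval()

wav = torch.zeros(16000 * 5)  # placeholder: 5 seconds of silence at 16 kHz
audio_feature = feature_extractor(
    wav.numpy(), sampling_rate=16000, return_tensors="pt"
).input_features                              # (1, 80, 3000): log-mel, padded to 30 s

window = 3000                                  # one full 30 s mel window per encoder pass
audio_prompts, last_audio_prompts = [], []
with torch.no_grad():
    for i in range(0, audio_feature.shape[-1], window):
        chunk = audio_feature[:, :, i:i + window]
        out = wav_enc.encoder(chunk, output_hidden_states=True)
        # stack every encoder layer (5 for whisper-tiny) along a new axis,
        # mirroring torch.stack(audio_prompt, dim=2) above
        audio_prompts.append(torch.stack(out.hidden_states, dim=2))
        last_audio_prompts.append(out.last_hidden_state.unsqueeze(-2))

audio_prompts = torch.cat(audio_prompts, dim=1)            # (1, 1500, 5, 384)
last_audio_prompts = torch.cat(last_audio_prompts, dim=1)  # (1, 1500, 1, 384)

Stacking all five encoder hidden states matches the blocks=5 that AudioProjModel expects further down in this diff.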
@@ -102,9 +102,9 @@ def test(
         motion_bucket_scale=config.motion_bucket_scale,
         fps=config.fps,
         noise_aug_strength=config.noise_aug_strength,
-        min_guidance_scale1=config.min_appearance_guidance_scale, # 1.0,
+        min_guidance_scale1=config.min_appearance_guidance_scale,
         max_guidance_scale1=config.max_appearance_guidance_scale,
-        min_guidance_scale2=config.audio_guidance_scale, # 1.0,
+        min_guidance_scale2=config.audio_guidance_scale,
         max_guidance_scale2=config.audio_guidance_scale,
         overlap=config.overlap,
         shift_offset=config.shift_offset,
@@ -113,12 +113,8 @@ def test(
         i2i_noise_strength=config.i2i_noise_strength
     ).frames

-
-    # Concat it with pose tensor
-    # pose_tensor = torch.stack(pose_tensor_list,1).unsqueeze(0)
-    video = (video*0.5 + 0.5).clamp(0, 1)
+    video = (video * 0.5 + 0.5).clamp(0, 1)
     video = torch.cat([video.to(pipe.device)], dim=0).cpu()
-
     return video

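The (video * 0.5 + 0.5).clamp(0, 1) line is the usual affine remap of diffusion output from the [-1, 1] range into displayable [0, 1] values; a quick check of the arithmetic:

import torch

frames = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])   # assumed pipeline output range
print(frames * 0.5 + 0.5)   # -> 0.00, 0.25, 0.50, 0.75, 1.00; clamp(0, 1) is a no-op here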
 
@@ -151,14 +147,24 @@ class Sonic():
             config.pretrained_model_name_or_path,
             subfolder="image_encoder",
             variant="fp16")
+
         unet = UNetSpatioTemporalConditionModel.from_pretrained(
             config.pretrained_model_name_or_path,
             subfolder="unet",
             variant="fp16")
+
         add_ip_adapters(unet, [32], [config.ip_audio_scale])

-        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024, context_tokens=32).to(device)
-        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)
+        audio2token = AudioProjModel(
+            seq_len=10, blocks=5, channels=384,
+            intermediate_dim=1024, output_dim=1024, context_tokens=32
+        ).to(device)
+
+        audio2bucket = Audio2bucketModel(
+            seq_len=50, blocks=1, channels=384,
+            clip_channels=1024, intermediate_dim=1024, output_dim=1,
+            context_tokens=2
+        ).to(device)

         unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
         audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
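The constructor arguments pin down the tensor bookkeeping: audio2token consumes windows of 10 Whisper frames × 5 encoder layers × 384 channels and emits 32 context tokens of width 1024, while audio2bucket reads 50-frame windows of the last hidden state plus the 1024-d CLIP image embedding and emits a single scalar. The class below is a hypothetical stand-in, not Sonic's AudioProjModel; it only illustrates how those dimensions compose.

import torch
import torch.nn as nn

class AudioProjSketch(nn.Module):
    """Hypothetical stand-in: flattens (seq_len, blocks, channels) audio features
    and projects them to context_tokens embeddings of size output_dim."""
    def __init__(self, seq_len=10, blocks=5, channels=384,
                 intermediate_dim=1024, output_dim=1024, context_tokens=32):
        super().__init__()
        self.context_tokens = context_tokens
        self.output_dim = output_dim
        self.proj = nn.Sequential(
            nn.Linear(seq_len * blocks * channels, intermediate_dim),
            nn.SiLU(),
            nn.Linear(intermediate_dim, context_tokens * output_dim),
        )

    def forward(self, audio_feats):             # (bs, seq_len, blocks, channels)
        bs = audio_feats.shape[0]
        tokens = self.proj(audio_feats.flatten(1))
        return tokens.view(bs, self.context_tokens, self.output_dim)

proj = AudioProjSketch()
print(proj(torch.randn(2, 10, 5, 384)).shape)   # torch.Size([2, 32, 1024])

The scalar produced by audio2bucket is later rescaled in test() via motion_bucket * 16 + 16.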
@@ -179,7 +185,6 @@ class Sonic():
             strict=True,
         )

-
         if config.weight_dtype == "fp16":
             weight_dtype = torch.float16
         elif config.weight_dtype == "fp32":
@@ -188,23 +193,21 @@ class Sonic():
             weight_dtype = torch.bfloat16
         else:
             raise ValueError(
-                f"Do not support weight dtype: {config.weight_dtype} during training"
+                f"Do not support weight dtype: {config.weight_dtype}"
             )

         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
-
         whisper.requires_grad_(False)

         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))

-        det_path = os.path.join(BASE_DIR, os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
+        det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
         self.face_det = AlignImage(device, det_path=det_path)
         if config.use_interframe:
             rife = RIFEModel(device=device)
             rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
             self.rife = rife

-
         image_encoder.to(weight_dtype)
         vae.to(weight_dtype)
         unet.to(weight_dtype)
@@ -217,7 +220,6 @@ class Sonic():
         )
         pipe = pipe.to(device=device, dtype=weight_dtype)

-
         self.pipe = pipe
         self.whisper = whisper
         self.audio2token = audio2token
@@ -225,16 +227,15 @@ class Sonic():
         self.image_encoder = image_encoder
         self.device = device

-        print('init done')
+        print('Sonic init done')


-    def preprocess(self,
-                   image_path, expand_ratio=1.0):
+    def preprocess(self, image_path, expand_ratio=1.0):
         face_image = cv2.imread(image_path)
         h, w = face_image.shape[:2]
         _, _, bboxes = self.face_det(face_image, maxface=True)
         face_num = len(bboxes)
-        bbox = []
+        bbox_s = None
         if face_num > 0:
             x1, y1, ww, hh = bboxes[0]
             x2, y2 = x1 + ww, y1 + hh
@@ -246,10 +247,7 @@ class Sonic():
             'crop_bbox': bbox_s,
         }

-    def crop_image(self,
-                   input_image_path,
-                   output_image_path,
-                   crop_bbox):
+    def crop_image(self, input_image_path, output_image_path, crop_bbox):
         face_image = cv2.imread(input_image_path)
         crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
         cv2.imwrite(output_image_path, crop_image)
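A usage sketch for the two helpers above, assuming the Sonic() constructor defaults and placeholder file paths (both assumptions, not taken from the commit); crop_bbox is an (x1, y1, x2, y2) box, which is why crop_image slices rows with indices 1 and 3 and columns with 0 and 2.

# Hypothetical driver code; paths and constructor arguments are placeholders.
sonic = Sonic()
info = sonic.preprocess('examples/face.png', expand_ratio=1.0)
if info['crop_bbox'] is not None:   # bbox_s stays None when no face is detected
    sonic.crop_image('examples/face.png', 'examples/face_cropped.png', info['crop_bbox'])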
@@ -273,27 +271,34 @@ class Sonic():
         audio2bucket = self.audio2bucket
         image_encoder = self.image_encoder

-        # specific parameters
         if seed:
             config.seed = seed
-
         config.num_inference_steps = inference_steps
-
         config.motion_bucket_scale = dynamic_scale
-
         seed_everything(config.seed)

         video_path = output_path.replace('.mp4', '_noaudio.mp4')
         audio_video_path = output_path

-        imSrc_ = Image.open(image_path).convert('RGB')
-        raw_w, raw_h = imSrc_.size
-
-        test_data = image_audio_to_tensor(self.face_det, self.feature_extractor, image_path, audio_path, limit=config.frame_num, image_size=min_resolution, area=config.area)
+        # Use the whole audio track instead of limit=config.frame_num;
+        # if a small config.frame_num was imposing a ~2-second cap, that cap has to go.
+        test_data = image_audio_to_tensor(
+            self.face_det,
+            self.feature_extractor,
+            image_path,
+            audio_path,
+            limit=-1,  # lift the frame limit (e.g. with -1)
+            image_size=min_resolution,
+            area=config.area
+        )
+
         if test_data is None:
             return -1
+
         height, width = test_data['ref_img'].shape[-2:]
         if keep_resolution:
+            imSrc_ = Image.open(image_path).convert('RGB')
+            raw_w, raw_h = imSrc_.size
             resolution = f'{raw_w//2*2}x{raw_h//2*2}'
         else:
             resolution = f'{width}x{height}'
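The keep_resolution branch floors both source dimensions to even values via integer division, since odd widths or heights are commonly rejected by libx264 when encoding yuv420p output; the numbers below are only an example.

raw_w, raw_h = 853, 481                      # example odd-sized source image
resolution = f'{raw_w//2*2}x{raw_h//2*2}'    # floor each dimension to an even value
print(resolution)                            # 852x480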
@@ -308,23 +313,23 @@ class Sonic():
             width=width,
             height=height,
             batch=test_data,
-            )
+        )

+        # when inter-frame interpolation is enabled
         if config.use_interframe:
             rife = self.rife
             out = video.to(device)
             results = []
             video_len = out.shape[2]
-            for idx in tqdm(range(video_len-1), ncols=0):
+            for idx in tqdm(range(video_len - 1), ncols=0):
                 I1 = out[:, :, idx]
-                I2 = out[:, :, idx+1]
+                I2 = out[:, :, idx + 1]
                 middle = rife.inference(I1, I2).clamp(0, 1).detach()
                 results.append(out[:, :, idx])
                 results.append(middle)
-            results.append(out[:, :, video_len-1])
+            results.append(out[:, :, video_len - 1])
             video = torch.stack(results, 2).cpu()

-        save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * 2 if config.use_interframe else config.fps)
+        save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * (2 if config.use_interframe else 1))
         os.system(f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
         return 0
-
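As a shape-level illustration of the interpolation loop above: with a plain frame average standing in for rife.inference (an assumption for this sketch, not RIFE's flow-based model), N frames become 2N − 1 frames, which is why save_videos_grid doubles the fps when use_interframe is set. The ffmpeg call then muxes the silent video with the original audio, with -shortest trimming to the shorter of the two streams.

import torch

def fake_interpolate(a, b):
    # stand-in for rife.inference: a simple average instead of flow-based interpolation
    return ((a + b) / 2).clamp(0, 1)

out = torch.rand(1, 3, 8, 64, 64)            # (batch, channels, frames, height, width)
video_len = out.shape[2]

results = []
for idx in range(video_len - 1):
    results.append(out[:, :, idx])
    results.append(fake_interpolate(out[:, :, idx], out[:, :, idx + 1]))
results.append(out[:, :, video_len - 1])

video = torch.stack(results, 2)
print(video.shape[2])                        # 15 == 2*8 - 1, played back at twice the fps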
 