openfree committed · verified
Commit 36737d5 · 1 Parent(s): 7bf073b

Update sonic.py

Files changed (1):
  1. sonic.py +70 -159

sonic.py CHANGED
@@ -20,9 +20,9 @@ from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
 from src.utils.RIFE.RIFE_HDv3 import RIFEModel
 from src.dataset.face_align.align import AlignImage
 
-
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
+
 def test(
     pipe,
     config,
@@ -34,15 +34,15 @@ def test(
     height,
     batch
 ):
+    """Run one forward pass to generate the video tensor."""
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
+
     ref_img = batch['ref_img']
     clip_img = batch['clip_images']
     face_mask = batch['face_mask']
-    image_embeds = image_encoder(
-        clip_img
-    ).image_embeds
+    image_embeds = image_encoder(clip_img).image_embeds
 
     audio_feature = batch['audio_feature']
     audio_len = batch['audio_len']
@@ -52,31 +52,30 @@ def test(
     audio_prompts = []
     last_audio_prompts = []
     for i in range(0, audio_feature.shape[-1], window):
-        audio_prompt = wav_enc.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
-        last_audio_prompt = wav_enc.encoder(audio_feature[:,:,i:i+window]).last_hidden_state
+        audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window], output_hidden_states=True).hidden_states
+        last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window]).last_hidden_state
         last_audio_prompt = last_audio_prompt.unsqueeze(-2)
         audio_prompt = torch.stack(audio_prompt, dim=2)
         audio_prompts.append(audio_prompt)
         last_audio_prompts.append(last_audio_prompt)
 
     audio_prompts = torch.cat(audio_prompts, dim=1)
-    audio_prompts = audio_prompts[:,:audio_len*2]
-    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:,:4]), audio_prompts, torch.zeros_like(audio_prompts[:,:6])], 1)
+    audio_prompts = audio_prompts[:, :audio_len * 2]
+    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
+                               torch.zeros_like(audio_prompts[:, :6])], 1)
 
     last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
-    last_audio_prompts = last_audio_prompts[:,:audio_len*2]
-    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:,:24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:,:26])], 1)
-
+    last_audio_prompts = last_audio_prompts[:, :audio_len * 2]
+    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts,
+                                    torch.zeros_like(last_audio_prompts[:, :26])], 1)
 
     ref_tensor_list = []
     audio_tensor_list = []
     uncond_audio_tensor_list = []
     motion_buckets = []
-    for i in tqdm(range(audio_len//step)):
-
-
-        audio_clip = audio_prompts[:,i*2*step:i*2*step+10].unsqueeze(0)
-        audio_clip_for_bucket = last_audio_prompts[:,i*2*step:i*2*step+50].unsqueeze(0)
+    for i in tqdm(range(audio_len // step)):
+        audio_clip = audio_prompts[:, i * 2 * step:i * 2 * step + 10].unsqueeze(0)
+        audio_clip_for_bucket = last_audio_prompts[:, i * 2 * step:i * 2 * step + 50].unsqueeze(0)
         motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
         motion_bucket = motion_bucket * 16 + 16
         motion_buckets.append(motion_bucket[0])
@@ -102,9 +101,9 @@ def test(
         motion_bucket_scale=config.motion_bucket_scale,
         fps=config.fps,
         noise_aug_strength=config.noise_aug_strength,
-        min_guidance_scale1=config.min_appearance_guidance_scale, # 1.0,
+        min_guidance_scale1=config.min_appearance_guidance_scale,
         max_guidance_scale1=config.max_appearance_guidance_scale,
-        min_guidance_scale2=config.audio_guidance_scale, # 1.0,
+        min_guidance_scale2=config.audio_guidance_scale,
         max_guidance_scale2=config.audio_guidance_scale,
         overlap=config.overlap,
         shift_offset=config.shift_offset,
@@ -113,73 +112,69 @@ def test(
         i2i_noise_strength=config.i2i_noise_strength
     ).frames
 
-
-    # Concat it with pose tensor
-    # pose_tensor = torch.stack(pose_tensor_list,1).unsqueeze(0)
-    video = (video*0.5 + 0.5).clamp(0, 1)
+    video = (video * 0.5 + 0.5).clamp(0, 1)
     video = torch.cat([video.to(pipe.device)], dim=0).cpu()
 
     return video
 
 
-class Sonic():
+class Sonic:
+    """Wrapper class for the Sonic portrait animation pipeline."""
+
     config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
     config = OmegaConf.load(config_file)
 
-    def __init__(self,
-                 device_id=0,
-                 enable_interpolate_frame=True,
-                 ):
-
+    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
+        # --------- load config & device ---------
         config = self.config
         config.use_interframe = enable_interpolate_frame
 
-        device = 'cuda:{}'.format(device_id) if device_id > -1 else 'cpu'
+        device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
+        self.device = device
 
+        # --------- Model paths ---------
        config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)
 
+        # --------- Load sub-modules ---------
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
-            config.pretrained_model_name_or_path,
+            config.pretrained_model_name_or_path,
            subfolder="vae",
-            variant="fp16")
-
+            variant="fp16"
+        )
+
        val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
-            config.pretrained_model_name_or_path,
-            subfolder="scheduler")
-
+            config.pretrained_model_name_or_path,
+            subfolder="scheduler"
+        )
+
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-            config.pretrained_model_name_or_path,
+            config.pretrained_model_name_or_path,
            subfolder="image_encoder",
-            variant="fp16")
+            variant="fp16"
+        )
+
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="unet",
-            variant="fp16")
+            variant="fp16"
+        )
        add_ip_adapters(unet, [32], [config.ip_audio_scale])
-
-        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024, context_tokens=32).to(device)
-        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)
 
-        unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
-        audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
-        audio2bucket_checkpoint_path = os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path)
+        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024,
+                                     context_tokens=32).to(device)
+        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024,
+                                         output_dim=1, context_tokens=2).to(device)
 
-        unet.load_state_dict(
-            torch.load(unet_checkpoint_path, map_location="cpu"),
-            strict=True,
-        )
-
-        audio2token.load_state_dict(
-            torch.load(audio2token_checkpoint_path, map_location="cpu"),
-            strict=True,
-        )
+        # --------- Load checkpoints ---------
+        unet_ckpt = torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location="cpu")
+        audio2token_ckpt = torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location="cpu")
+        audio2bucket_ckpt = torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location="cpu")
 
-        audio2bucket.load_state_dict(
-            torch.load(audio2bucket_checkpoint_path, map_location="cpu"),
-            strict=True,
-        )
-
+        unet.load_state_dict(unet_ckpt, strict=True)
+        audio2token.load_state_dict(audio2token_ckpt, strict=True)
+        audio2bucket.load_state_dict(audio2bucket_ckpt, strict=True)
 
+        # --------- dtype ---------
        if config.weight_dtype == "fp16":
            weight_dtype = torch.float16
        elif config.weight_dtype == "fp32":
@@ -187,54 +182,48 @@ class Sonic():
        elif config.weight_dtype == "bf16":
            weight_dtype = torch.bfloat16
        else:
-            raise ValueError(
-                f"Do not support weight dtype: {config.weight_dtype} during training"
-            )
+            raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")
 
+        # --------- Whisper encoder for audio ---------
        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
-
        whisper.requires_grad_(False)
-
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))
 
-        det_path = os.path.join(BASE_DIR, os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
+        # --------- Face detector & frame interpolator ---------
+        det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
        self.face_det = AlignImage(device, det_path=det_path)
        if config.use_interframe:
-            rife = RIFEModel(device=device)
-            rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
-            self.rife = rife
-
+            self.rife = RIFEModel(device=device)
+            self.rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
 
+        # --------- Move modules to device & dtype ---------
        image_encoder.to(weight_dtype)
        vae.to(weight_dtype)
        unet.to(weight_dtype)
 
+        # --------- Compose pipeline ---------
        pipe = SonicPipeline(
            unet=unet,
            image_encoder=image_encoder,
            vae=vae,
            scheduler=val_noise_scheduler,
        )
-        pipe = pipe.to(device=device, dtype=weight_dtype)
-
-
-        self.pipe = pipe
+        self.pipe = pipe.to(device=device, dtype=weight_dtype)
        self.whisper = whisper
        self.audio2token = audio2token
        self.audio2bucket = audio2bucket
        self.image_encoder = image_encoder
-        self.device = device
-
-        print('init done')
 
+        print('Sonic initialization complete.')
 
-    def preprocess(self,
-                   image_path, expand_ratio=1.0):
+    # -------------------------- Public helpers --------------------------
+    def preprocess(self, image_path: str, expand_ratio: float = 1.0):
+        """Detect face and compute crop bbox (optional)."""
        face_image = cv2.imread(image_path)
        h, w = face_image.shape[:2]
        _, _, bboxes = self.face_det(face_image, maxface=True)
        face_num = len(bboxes)
-        bbox = []
+        bbox_s = []
        if face_num > 0:
            x1, y1, ww, hh = bboxes[0]
            x2, y2 = x1 + ww, y1 + hh
@@ -245,86 +234,8 @@ class Sonic():
            'face_num': face_num,
            'crop_bbox': bbox_s,
        }
-
-    def crop_image(self,
-                   input_image_path,
-                   output_image_path,
-                   crop_bbox):
+
+    def crop_image(self, input_image_path: str, output_image_path: str, crop_bbox):
        face_image = cv2.imread(input_image_path)
        crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
-        cv2.imwrite(output_image_path, crop_image)
-
-    @torch.no_grad()
-    def process(self,
-                image_path,
-                audio_path,
-                output_path,
-                min_resolution=512,
-                inference_steps=25,
-                dynamic_scale=1.0,
-                keep_resolution=False,
-                seed=None):
-
-        config = self.config
-        device = self.device
-        pipe = self.pipe
-        whisper = self.whisper
-        audio2token = self.audio2token
-        audio2bucket = self.audio2bucket
-        image_encoder = self.image_encoder
-
-        # specific parameters
-        if seed:
-            config.seed = seed
-
-        config.num_inference_steps = inference_steps
-
-        config.motion_bucket_scale = dynamic_scale
-
-        seed_everything(config.seed)
-
-        video_path = output_path.replace('.mp4', '_noaudio.mp4')
-        audio_video_path = output_path
-
-        imSrc_ = Image.open(image_path).convert('RGB')
-        raw_w, raw_h = imSrc_.size
-
-        test_data = image_audio_to_tensor(self.face_det, self.feature_extractor, image_path, audio_path, limit=config.frame_num, image_size=min_resolution, area=config.area)
-        if test_data is None:
-            return -1
-        height, width = test_data['ref_img'].shape[-2:]
-        if keep_resolution:
-            resolution = f'{raw_w//2*2}x{raw_h//2*2}'
-        else:
-            resolution = f'{width}x{height}'
-
-        video = test(
-            pipe,
-            config,
-            wav_enc=whisper,
-            audio_pe=audio2token,
-            audio2bucket=audio2bucket,
-            image_encoder=image_encoder,
-            width=width,
-            height=height,
-            batch=test_data,
-        )
-
-        if config.use_interframe:
-            rife = self.rife
-            out = video.to(device)
-            results = []
-            video_len = out.shape[2]
-            for idx in tqdm(range(video_len-1), ncols=0):
-                I1 = out[:, :, idx]
-                I2 = out[:, :, idx+1]
-                middle = rife.inference(I1, I2).clamp(0, 1).detach()
-                results.append(out[:, :, idx])
-                results.append(middle)
-            results.append(out[:, :, video_len-1])
-            video = torch.stack(results, 2).cpu()
-
-        save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * 2 if config.use_interframe else config.fps)
-        os.system(f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
-        return 0
-
+        cv2.imwrite(output_image_path, crop_image)
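
For reference, a minimal usage sketch of the refactored Sonic class as it stands after this commit. The module name `sonic`, the example image paths, and the device id are illustrative assumptions, not part of the diff; only `__init__`, `preprocess`, and `crop_image` are exercised, since `process()` is removed in this change.

# Minimal usage sketch (assumes sonic.py is importable as `sonic`; paths are illustrative).
from sonic import Sonic

sonic = Sonic(device_id=0, enable_interpolate_frame=True)

# Detect a face in a reference portrait and crop the image to the returned bbox.
info = sonic.preprocess('examples/portrait.png', expand_ratio=1.0)
if info['face_num'] > 0:
    sonic.crop_image('examples/portrait.png', 'examples/portrait_crop.png', info['crop_bbox'])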