openfree committed
Commit c4a0f5f · verified · 1 Parent(s): 36737d5

Update sonic.py

Files changed (1)
  1. sonic.py  +100 -81
sonic.py CHANGED
@@ -32,9 +32,9 @@ def test(
     image_encoder,
     width,
     height,
-    batch
+    batch,
 ):
-    """Run one forward pass to generate the video tensor."""
+    """Generate a video tensor for the given batch."""
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()
@@ -52,30 +52,36 @@ def test(
     audio_prompts = []
     last_audio_prompts = []
     for i in range(0, audio_feature.shape[-1], window):
-        audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window], output_hidden_states=True).hidden_states
-        last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window]).last_hidden_state
+        audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i+window], output_hidden_states=True).hidden_states
+        last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i+window]).last_hidden_state
         last_audio_prompt = last_audio_prompt.unsqueeze(-2)
         audio_prompt = torch.stack(audio_prompt, dim=2)
         audio_prompts.append(audio_prompt)
         last_audio_prompts.append(last_audio_prompt)

     audio_prompts = torch.cat(audio_prompts, dim=1)
-    audio_prompts = audio_prompts[:, :audio_len * 2]
-    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
-                               torch.zeros_like(audio_prompts[:, :6])], 1)
+    audio_prompts = audio_prompts[:, :audio_len*2]
+    audio_prompts = torch.cat([
+        torch.zeros_like(audio_prompts[:, :4]),
+        audio_prompts,
+        torch.zeros_like(audio_prompts[:, :6])
+    ], 1)

     last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
-    last_audio_prompts = last_audio_prompts[:, :audio_len * 2]
-    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts,
-                                    torch.zeros_like(last_audio_prompts[:, :26])], 1)
+    last_audio_prompts = last_audio_prompts[:, :audio_len*2]
+    last_audio_prompts = torch.cat([
+        torch.zeros_like(last_audio_prompts[:, :24]),
+        last_audio_prompts,
+        torch.zeros_like(last_audio_prompts[:, :26])
+    ], 1)

     ref_tensor_list = []
     audio_tensor_list = []
     uncond_audio_tensor_list = []
     motion_buckets = []
-    for i in tqdm(range(audio_len // step)):
-        audio_clip = audio_prompts[:, i * 2 * step:i * 2 * step + 10].unsqueeze(0)
-        audio_clip_for_bucket = last_audio_prompts[:, i * 2 * step:i * 2 * step + 50].unsqueeze(0)
+    for i in tqdm(range(audio_len//step), ncols=0):
+        audio_clip = audio_prompts[:, i*2*step:i*2*step+10].unsqueeze(0)
+        audio_clip_for_bucket = last_audio_prompts[:, i*2*step:i*2*step+50].unsqueeze(0)
         motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
         motion_bucket = motion_bucket * 16 + 16
         motion_buckets.append(motion_bucket[0])
@@ -114,100 +120,67 @@ def test(

     video = (video * 0.5 + 0.5).clamp(0, 1)
     video = torch.cat([video.to(pipe.device)], dim=0).cpu()
-
     return video


 class Sonic:
-    """Wrapper class for the Sonic portrait animation pipeline."""
+    """High-level interface for the Sonic portrait animation pipeline."""

     config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
     config = OmegaConf.load(config_file)

     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-        # --------- load config & device ---------
         config = self.config
         config.use_interframe = enable_interpolate_frame

         device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
         self.device = device

-        # --------- Model paths ---------
         config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

-        # --------- Load sub-modules ---------
         vae = AutoencoderKLTemporalDecoder.from_pretrained(
-            config.pretrained_model_name_or_path,
-            subfolder="vae",
-            variant="fp16"
-        )
-
+            config.pretrained_model_name_or_path, subfolder='vae', variant='fp16')
         val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
-            config.pretrained_model_name_or_path,
-            subfolder="scheduler"
-        )
-
+            config.pretrained_model_name_or_path, subfolder='scheduler')
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-            config.pretrained_model_name_or_path,
-            subfolder="image_encoder",
-            variant="fp16"
-        )
-
+            config.pretrained_model_name_or_path, subfolder='image_encoder', variant='fp16')
         unet = UNetSpatioTemporalConditionModel.from_pretrained(
-            config.pretrained_model_name_or_path,
-            subfolder="unet",
-            variant="fp16"
-        )
+            config.pretrained_model_name_or_path, subfolder='unet', variant='fp16')
         add_ip_adapters(unet, [32], [config.ip_audio_scale])

-        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024,
-                                     context_tokens=32).to(device)
-        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024,
-                                         output_dim=1, context_tokens=2).to(device)
-
-        # --------- Load checkpoints ---------
-        unet_ckpt = torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location="cpu")
-        audio2token_ckpt = torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location="cpu")
-        audio2bucket_ckpt = torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location="cpu")
-
-        unet.load_state_dict(unet_ckpt, strict=True)
-        audio2token.load_state_dict(audio2token_ckpt, strict=True)
-        audio2bucket.load_state_dict(audio2bucket_ckpt, strict=True)
-
-        # --------- dtype ---------
-        if config.weight_dtype == "fp16":
-            weight_dtype = torch.float16
-        elif config.weight_dtype == "fp32":
-            weight_dtype = torch.float32
-        elif config.weight_dtype == "bf16":
-            weight_dtype = torch.bfloat16
-        else:
+        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024,
+                                     output_dim=1024, context_tokens=32).to(device)
+        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024,
+                                         intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)
+
+        unet.load_state_dict(
+            torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location='cpu'), strict=True)
+        audio2token.load_state_dict(
+            torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location='cpu'), strict=True)
+        audio2bucket.load_state_dict(
+            torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location='cpu'), strict=True)
+
+        dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}
+        weight_dtype = dtype_map.get(config.weight_dtype)
+        if weight_dtype is None:
             raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")

-        # --------- Whisper encoder for audio ---------
         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
         whisper.requires_grad_(False)
-        self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))

-        # --------- Face detector & frame interpolator ---------
-        det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
-        self.face_det = AlignImage(device, det_path=det_path)
+        self.face_det = AlignImage(device, det_path=os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
         if config.use_interframe:
             self.rife = RIFEModel(device=device)
             self.rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))

-        # --------- Move modules to device & dtype ---------
         image_encoder.to(weight_dtype)
         vae.to(weight_dtype)
         unet.to(weight_dtype)

-        # --------- Compose pipeline ---------
         pipe = SonicPipeline(
-            unet=unet,
-            image_encoder=image_encoder,
-            vae=vae,
-            scheduler=val_noise_scheduler,
-        )
+            unet=unet, image_encoder=image_encoder, vae=vae, scheduler=val_noise_scheduler)
         self.pipe = pipe.to(device=device, dtype=weight_dtype)
         self.whisper = whisper
         self.audio2token = audio2token
@@ -216,9 +189,7 @@ class Sonic:

         print('Sonic initialization complete.')

-    # -------------------------- Public helpers --------------------------
     def preprocess(self, image_path: str, expand_ratio: float = 1.0):
-        """Detect face and compute crop bbox (optional)."""
         face_image = cv2.imread(image_path)
         h, w = face_image.shape[:2]
         _, _, bboxes = self.face_det(face_image, maxface=True)
@@ -227,15 +198,63 @@ class Sonic:
         if face_num > 0:
             x1, y1, ww, hh = bboxes[0]
             x2, y2 = x1 + ww, y1 + hh
-            bbox = x1, y1, x2, y2
-            bbox_s = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)
-
-        return {
-            'face_num': face_num,
-            'crop_bbox': bbox_s,
-        }
+            bbox_s = process_bbox((x1, y1, x2, y2), expand_radio=expand_ratio, height=h, width=w)
+        return {'face_num': face_num, 'crop_bbox': bbox_s}

     def crop_image(self, input_image_path: str, output_image_path: str, crop_bbox):
         face_image = cv2.imread(input_image_path)
-        crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
-        cv2.imwrite(output
+        crop_img = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
+        cv2.imwrite(output_image_path, crop_img)
+
+    @torch.no_grad()
+    def process(self, image_path, audio_path, output_path, min_resolution=512,
+                inference_steps=25, dynamic_scale=1.0, keep_resolution=False, seed=None):
+        config = self.config
+        device = self.device
+
+        pipe = self.pipe
+        whisper = self.whisper
+        audio2token = self.audio2token
+        audio2bucket = self.audio2bucket
+        image_encoder = self.image_encoder
+
+        if seed is not None:
+            config.seed = seed
+        seed_everything(config.seed)
+
+        config.num_inference_steps = inference_steps
+        config.frame_num = config.fps * 60
+        config.motion_bucket_scale = dynamic_scale
+
+        video_path = output_path.replace('.mp4', '_noaudio.mp4')
+        audio_video_path = output_path
+
+        imSrc_ = Image.open(image_path).convert('RGB')
+        raw_w, raw_h = imSrc_.size
+
+        test_data = image_audio_to_tensor(
+            self.face_det, self.feature_extractor, image_path, audio_path,
+            limit=config.frame_num, image_size=min_resolution, area=config.area)
+        if test_data is None:
+            return -1
+        height, width = test_data['ref_img'].shape[-2:]
+        resolution = f"{width}x{height}" if not keep_resolution else f"{raw_w//2*2}x{raw_h//2*2}"
+
+        video = test(pipe, config, wav_enc=whisper, audio_pe=audio2token,
+                     audio2bucket=audio2bucket, image_encoder=image_encoder,
+                     width=width, height=height, batch=test_data)
+
+        if config.use_interframe:
+            out = video.to(device)
+            results = []
+            for idx in tqdm(range(out.shape[2]-1), ncols=0):
+                I1 = out[:, :, idx]
+                I2 = out[:, :, idx+1]
+                mid = self.rife.inference(I1, I2).clamp(0,1).detach()
+                results.extend([out[:, :, idx], mid])
+            results.append(out[:, :, -1])
+            video = torch.stack(results, 2).cpu()
+
+        save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * (2 if config.use_interframe else 1))
+        os.system(f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
+        return 0
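
For reference, a minimal usage sketch of the Sonic interface as it stands after this change. The import path, example file names, and argument values are placeholders chosen for illustration, not taken from the commit; the checkpoint and config layout expected by the class must already be in place.

# Hypothetical usage sketch -- import path and file paths below are placeholders.
from sonic import Sonic   # assumes sonic.py is importable from the working directory

pipeline = Sonic(device_id=0, enable_interpolate_frame=True)

image_path  = 'examples/face.png'          # placeholder portrait image
audio_path  = 'examples/speech.wav'        # placeholder driving audio
output_path = 'examples/face_talking.mp4'  # placeholder output file

# Optionally detect the face and crop the portrait before animating it.
info = pipeline.preprocess(image_path, expand_ratio=1.0)
if info['face_num'] > 0:
    cropped_path = 'examples/face_cropped.png'  # placeholder intermediate file
    pipeline.crop_image(image_path, cropped_path, info['crop_bbox'])
    image_path = cropped_path

# Generate the talking-head video; per the diff, process() returns 0 on success
# and -1 when no usable face/audio tensor could be built.
status = pipeline.process(image_path, audio_path, output_path,
                          min_resolution=512, inference_steps=25,
                          dynamic_scale=1.0, keep_resolution=False, seed=42)
print('ok' if status == 0 else 'failed')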