openfree committed on
Commit 1fb410d · verified · 1 Parent(s): 79f5781

Update sonic.py

Files changed (1)
  1. sonic.py +208 -269
sonic.py CHANGED
@@ -1,8 +1,8 @@
  import os
  import torch
  import torch.utils.checkpoint
  from PIL import Image
- import numpy as np
  from omegaconf import OmegaConf
  from tqdm import tqdm
  import cv2
@@ -13,7 +13,9 @@ from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatur
 
  from src.utils.util import save_videos_grid, seed_everything
  from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
- from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
  from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
  from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
@@ -22,6 +24,10 @@ from src.dataset.face_align.align import AlignImage
 
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
  def test(
  pipe,
  config,
@@ -31,81 +37,88 @@ def test(
  image_encoder,
  width,
  height,
- batch
  ):
- # reshape batch tensors to (1,B,C,H,W)
  for k, v in batch.items():
  if isinstance(v, torch.Tensor):
  batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
- ref_img = batch['ref_img']
- clip_img = batch['clip_images']
- face_mask = batch['face_mask']
  image_embeds = image_encoder(clip_img).image_embeds
 
- audio_feature = batch['audio_feature']
- audio_len = batch['audio_len']
- step = int(config.step)
 
- # window changed from 3000 to 16000 (1-second chunks)
  window = 16000
- audio_prompts = []
- last_audio_prompts = []
 
  for i in range(0, audio_feature.shape[-1], window):
- audio_clip_chunk = audio_feature[:, :, i:i+window]
- # Whisper encoder
- audio_prompt = wav_enc.encoder(audio_clip_chunk, output_hidden_states=True).hidden_states
- last_audio_prompt = wav_enc.encoder(audio_clip_chunk).last_hidden_state
- last_audio_prompt = last_audio_prompt.unsqueeze(-2)
 
- audio_prompt = torch.stack(audio_prompt, dim=2)
- audio_prompts.append(audio_prompt)
- last_audio_prompts.append(last_audio_prompt)
 
- # raise here if nothing was collected
  if len(audio_prompts) == 0:
  raise ValueError(
  "[ERROR] No speech recognized from the audio. "
- "Please provide a valid speech audio (with clear voice)."
  )
 
- audio_prompts = torch.cat(audio_prompts, dim=1)
- audio_prompts = audio_prompts[:, :audio_len*2]
- audio_prompts = torch.cat([
- torch.zeros_like(audio_prompts[:, :4]),
- audio_prompts,
- torch.zeros_like(audio_prompts[:, :6])
- ], dim=1)
-
- last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
- last_audio_prompts = last_audio_prompts[:, :audio_len*2]
- last_audio_prompts = torch.cat([
- torch.zeros_like(last_audio_prompts[:, :24]),
- last_audio_prompts,
- torch.zeros_like(last_audio_prompts[:, :26])
- ], dim=1)
-
- ref_tensor_list = []
- audio_tensor_list = []
- uncond_audio_tensor_list = []
- motion_buckets = []
-
- for i in tqdm(range(audio_len // step)):
- audio_clip = audio_prompts[:, i*2*step : i*2*step + 10].unsqueeze(0)
- audio_clip_for_bucket = last_audio_prompts[:, i*2*step : i*2*step + 50].unsqueeze(0)
-
- motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
- motion_bucket = motion_bucket * 16 + 16
  motion_buckets.append(motion_bucket[0])
 
- cond_audio_clip = audio_pe(audio_clip).squeeze(0)
- uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)
 
  ref_tensor_list.append(ref_img[0])
- audio_tensor_list.append(cond_audio_clip[0])
- uncond_audio_tensor_list.append(uncond_audio_clip[0])
 
  video = pipe(
  ref_img,
  clip_img,
@@ -128,246 +141,172 @@ def test(
  shift_offset=config.shift_offset,
  frames_per_batch=config.n_sample_frames,
  num_inference_steps=config.num_inference_steps,
- i2i_noise_strength=config.i2i_noise_strength
  ).frames
 
  video = (video * 0.5 + 0.5).clamp(0, 1)
- video = torch.cat([video.to(pipe.device)], dim=0).cpu()
- return video
-
-
- class Sonic():
- config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
- config = OmegaConf.load(config_file)
-
- def __init__(self,
- device_id=0,
- enable_interpolate_frame=True,
- ):
-
- config = self.config
- config.use_interframe = enable_interpolate_frame
-
- device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
- config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)
-
- # VAE
  vae = AutoencoderKLTemporalDecoder.from_pretrained(
- config.pretrained_model_name_or_path,
- subfolder="vae",
- variant="fp16")
-
- # scheduler
- val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
- config.pretrained_model_name_or_path,
- subfolder="scheduler")
-
- # CLIP Vision
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- config.pretrained_model_name_or_path,
- subfolder="image_encoder",
- variant="fp16")
-
- # UNet
- unet = UNetSpatioTemporalConditionModel.from_pretrained(
- config.pretrained_model_name_or_path,
- subfolder="unet",
- variant="fp16")
-
- # Adapter
- add_ip_adapters(unet, [32], [config.ip_audio_scale])
-
- audio2token = AudioProjModel(
- seq_len=10, blocks=5, channels=384,
- intermediate_dim=1024, output_dim=1024, context_tokens=32
- ).to(device)
-
- audio2bucket = Audio2bucketModel(
- seq_len=50, blocks=1, channels=384,
- clip_channels=1024, intermediate_dim=1024, output_dim=1,
- context_tokens=2
- ).to(device)
-
- # load local checkpoints
- unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
- audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
- audio2bucket_checkpoint_path = os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path)
-
- unet.load_state_dict(
- torch.load(unet_checkpoint_path, map_location="cpu"),
- strict=True,
- )
-
- audio2token.load_state_dict(
- torch.load(audio2token_checkpoint_path, map_location="cpu"),
- strict=True,
  )
-
- audio2bucket.load_state_dict(
- torch.load(audio2bucket_checkpoint_path, map_location="cpu"),
- strict=True,
  )
-
- # set weight_dtype
- if config.weight_dtype == "fp16":
- weight_dtype = torch.float16
- elif config.weight_dtype == "fp32":
- weight_dtype = torch.float32
- elif config.weight_dtype == "bf16":
- weight_dtype = torch.bfloat16
- else:
- raise ValueError(f"Do not support weight dtype: {config.weight_dtype}")
-
- # Whisper
- whisper = WhisperModel.from_pretrained(
- os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
- ).to(device).eval()
- whisper.requires_grad_(False)
-
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(
- os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
  )
-
- # Face detect
- det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
- self.face_det = AlignImage(device, det_path=det_path)
-
- # RIFE intermediate-frame interpolation
- if config.use_interframe:
- rife = RIFEModel(device=device)
- rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
- self.rife = rife
-
- # convert dtype
- image_encoder.to(weight_dtype)
- vae.to(weight_dtype)
- unet.to(weight_dtype)
-
- # initialize SonicPipeline
- pipe = SonicPipeline(
- unet=unet,
- image_encoder=image_encoder,
- vae=vae,
- scheduler=val_noise_scheduler,
  )
- pipe = pipe.to(device=device, dtype=weight_dtype)
 
- self.pipe = pipe
- self.whisper = whisper
- self.audio2token = audio2token
- self.audio2bucket = audio2bucket
- self.image_encoder = image_encoder
- self.device = device
 
- print('Sonic init done')
 
- def preprocess(self, image_path, expand_ratio=1.0):
- face_image = cv2.imread(image_path)
- h, w = face_image.shape[:2]
- _, _, bboxes = self.face_det(face_image, maxface=True)
- face_num = len(bboxes)
- bbox_s = None
 
- if face_num > 0:
  x1, y1, ww, hh = bboxes[0]
- x2, y2 = x1 + ww, y1 + hh
- bbox = x1, y1, x2, y2
- bbox_s = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)
-
- return {
- 'face_num': face_num,
- 'crop_bbox': bbox_s,
- }
-
- def crop_image(self, input_image_path, output_image_path, crop_bbox):
- face_image = cv2.imread(input_image_path)
- crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
- cv2.imwrite(output_image_path, crop_image)
 
  @torch.no_grad()
- def process(self,
- image_path,
- audio_path,
- output_path,
- min_resolution=512,
- inference_steps=25,
- dynamic_scale=1.0,
- keep_resolution=False,
- seed=None):
-
- config = self.config
- device = self.device
- pipe = self.pipe
- whisper = self.whisper
- audio2token = self.audio2token
- audio2bucket = self.audio2bucket
- image_encoder = self.image_encoder
-
- # set seed
- if seed:
- config.seed = seed
- config.num_inference_steps = inference_steps
- config.motion_bucket_scale = dynamic_scale
- seed_everything(config.seed)
-
- video_path = output_path.replace('.mp4', '_noaudio.mp4')
- audio_video_path = output_path
-
- # audio + image -> tensor
  test_data = image_audio_to_tensor(
- self.face_det,
- self.feature_extractor,
- image_path,
- audio_path,
- limit=-1, # use the whole audio
- image_size=min_resolution,
- area=config.area
  )
  if test_data is None:
  return -1
-
- height, width = test_data['ref_img'].shape[-2:]
- if keep_resolution:
- imSrc_ = Image.open(image_path).convert('RGB')
- raw_w, raw_h = imSrc_.size
- resolution = f'{raw_w//2*2}x{raw_h//2*2}'
- else:
- resolution = f'{width}x{height}'
-
- # call test(...) here
  video = test(
- pipe,
- config,
- wav_enc=whisper,
- audio_pe=audio2token,
- audio2bucket=audio2bucket,
- image_encoder=image_encoder,
- width=width,
- height=height,
  batch=test_data,
  )
 
- # intermediate-frame interpolation
- if config.use_interframe:
- rife = self.rife
- out = video.to(device)
- results = []
- video_len = out.shape[2]
- for idx in tqdm(range(video_len - 1), ncols=0):
- I1 = out[:, :, idx]
- I2 = out[:, :, idx + 1]
- middle = rife.inference(I1, I2).clamp(0, 1).detach()
- results.append(out[:, :, idx])
- results.append(middle)
- results.append(out[:, :, video_len - 1])
  video = torch.stack(results, 2).cpu()
-
- # save video
- save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * (2 if config.use_interframe else 1))
 
- # mux audio and write the final mp4
  os.system(
- f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} "
- f"-vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'"
  )
  return 0
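The core behavioral change in test() is easiest to see in isolation: when the audio is shorter than config.step, the step size now falls back to the audio length, and the chunk count uses a ceiling division instead of floor division, so short clips no longer yield an empty frame list. A minimal standalone sketch of that logic (not part of sonic.py; the variable names follow the new code, the numbers are made up):

    import math

    def num_audio_chunks(audio_len: int, step: int) -> int:
        # mirrors the updated fallback: shrink the step for very short audio
        if audio_len < step:
            step = max(1, audio_len)
        return math.ceil(audio_len / step)

    # a clip worth 7 whisper tokens with config.step = 25:
    #   old code: range(7 // 25) -> 0 iterations -> empty tensor lists
    #   new code: step falls back to 7, ceil(7 / 7) -> 1 chunk
    assert num_audio_chunks(7, 25) == 1
    assert num_audio_chunks(100, 25) == 4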
 
  import os
+ import math # [★ fix] for the ceil calculation
  import torch
  import torch.utils.checkpoint
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
  import cv2
 
  from src.utils.util import save_videos_grid, seed_everything
  from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
+ from src.models.base.unet_spatio_temporal_condition import (
+ UNetSpatioTemporalConditionModel, add_ip_adapters,
+ )
  from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
  from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
 
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
+
+ # ------------------------------------------------------------------
+ # test() : one face image + audio → sequence of frame tensors
+ # ------------------------------------------------------------------
  def test(
  pipe,
  config,
 
  image_encoder,
  width,
  height,
+ batch,
  ):
+ # (B,C,H,W) → (1,B,C,H,W)
  for k, v in batch.items():
  if isinstance(v, torch.Tensor):
  batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
+ ref_img = batch["ref_img"]
+ clip_img = batch["clip_images"]
+ face_mask = batch["face_mask"]
  image_embeds = image_encoder(clip_img).image_embeds
 
+ audio_feature = batch["audio_feature"] # (C,T)
+ audio_len = batch["audio_len"] # number of whisper tokens
+ step = int(config.step)
 
+ # ----------------------------- [★ fix] -----------------------------
+ # ① window = 16000 for 1-second chunks (1 second for whisper-tiny)
+ # ② if audio_len < step, shrink step to avoid an empty list
+ # --------------------------------------------------------------------
  window = 16000
+ if audio_len < step:
+ step = max(1, audio_len)
 
+ # ── slice the audio into 1-second windows and encode with Whisper
+ audio_prompts, last_audio_prompts = [], []
  for i in range(0, audio_feature.shape[-1], window):
+ chunk = audio_feature[:, :, i : i + window] # (B,C,window)
+
+ # whisper encoder
+ prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+ last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
 
+ audio_prompts.append(torch.stack(prompt_layers, dim=2))
+ last_audio_prompts.append(last_hidden)
 
+ # ── exception: stop if nothing was encoded
  if len(audio_prompts) == 0:
  raise ValueError(
  "[ERROR] No speech recognized from the audio. "
+ "Please provide a valid speech recording."
  )
 
+ # rebuild the Whisper token sequence (+ the model's padding rule)
+ audio_prompts = torch.cat(audio_prompts, dim=1)[:, : audio_len * 2]
+ audio_prompts = torch.cat(
+ [torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])],
+ dim=1,
+ )
+
+ last_audio_prompts = torch.cat(last_audio_prompts, dim=1)[:, : audio_len * 2]
+ last_audio_prompts = torch.cat(
+ [torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])],
+ dim=1,
+ )
+
+ # --------------------------------------------------------------------
+ # compute the total number of chunks with the adjusted step (ceil)
+ # --------------------------------------------------------------------
+ num_chunks = math.ceil(audio_len / step)
+
+ ref_tensor_list, audio_tensor_list, uncond_audio_tensor_list, motion_buckets = [], [], [], []
+ for i in tqdm(range(num_chunks)):
+ start = i * 2 * step
+ audio_clip = audio_prompts[:, start : start + 10].unsqueeze(0)
+ audio_clip_for_bucket = last_audio_prompts[:, start : start + 50].unsqueeze(0)
+
+ motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds) * 16 + 16
  motion_buckets.append(motion_bucket[0])
 
+ cond_audio = audio_pe(audio_clip).squeeze(0)
+ uncond_audio = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)
 
  ref_tensor_list.append(ref_img[0])
+ audio_tensor_list.append(cond_audio[0])
+ uncond_audio_tensor_list.append(uncond_audio[0])
+
+ # guard against an empty list
+ if len(audio_tensor_list) == 0:
+ raise ValueError("[ERROR] Audio too short for the configured 'step' size; no frames produced.")
 
+ # --------------------------------------------------------------------
  video = pipe(
  ref_img,
  clip_img,
 
  shift_offset=config.shift_offset,
  frames_per_batch=config.n_sample_frames,
  num_inference_steps=config.num_inference_steps,
+ i2i_noise_strength=config.i2i_noise_strength,
  ).frames
+ # --------------------------------------------------------------------
 
  video = (video * 0.5 + 0.5).clamp(0, 1)
+ return video.to(pipe.device).unsqueeze(0).cpu()
+
+
+ # ------------------------------------------------------------------
+ # Sonic class
+ # ------------------------------------------------------------------
+ class Sonic:
+ config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
+ config = OmegaConf.load(config_file)
+
+ def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
+ cfg = self.config
+ cfg.use_interframe = enable_interpolate_frame
+ self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
+ cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
+
+ # ───────────── load models
+ self._load_models(cfg)
+ print("Sonic init done")
+
+ # --------------------------------------------------------------
+ # model / pipeline loader
+ # --------------------------------------------------------------
+ def _load_models(self, cfg):
+ dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
+ weight_dtype = dtype_map.get(cfg.weight_dtype, torch.float32)
+
+ # backbone
  vae = AutoencoderKLTemporalDecoder.from_pretrained(
+ cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16"
  )
+ scheduler = EulerDiscreteScheduler.from_pretrained(
+ cfg.pretrained_model_name_or_path, subfolder="scheduler"
  )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16"
  )
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
+ cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16"
  )
+ add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
+ # audio adapters
+ audio2token = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
+ audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
+ # checkpoints
+ unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+ audio2token.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+ audio2bucket.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
+ # whisper
+ whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
+ whisper.requires_grad_(False)
 
+ # extras
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
+ self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
+ if cfg.use_interframe:
+ self.rife = RIFEModel(device=self.device)
+ self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
+
+ # dtype
+ for m in (image_encoder, vae, unet):
+ m.to(weight_dtype)
+
+ # pipeline
+ pipe = SonicPipeline(unet=unet, image_encoder=image_encoder, vae=vae, scheduler=scheduler)
+ self.pipe = pipe.to(device=self.device, dtype=weight_dtype)
+ self.audio2token = audio2token
+ self.audio2bucket = audio2bucket
+ self.image_encoder = image_encoder
+ self.whisper = whisper
+
+ # --------------------------------------------------------------
+ def preprocess(self, image_path: str, expand_ratio: float = 1.0):
+ img = cv2.imread(image_path)
+ h, w = img.shape[:2]
+ _, _, bboxes = self.face_det(img, maxface=True)
+ if bboxes:
  x1, y1, ww, hh = bboxes[0]
+ bbox = (x1, y1, x1 + ww, y1 + hh)
+ crop_bbox = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)
+ return {"face_num": len(bboxes), "crop_bbox": crop_bbox}
+ return {"face_num": 0, "crop_bbox": None}
 
+ # --------------------------------------------------------------
  @torch.no_grad()
+ def process(
+ self,
+ image_path: str,
+ audio_path: str,
+ output_path: str,
+ min_resolution: int = 512,
+ inference_steps: int = 25,
+ dynamic_scale: float = 1.0,
+ keep_resolution: bool = False,
+ seed: int | None = None,
+ ):
+ cfg = self.config
+ if seed is not None:
+ cfg.seed = seed
+ cfg.num_inference_steps = inference_steps
+ cfg.motion_bucket_scale = dynamic_scale
+ seed_everything(cfg.seed)
+
+ # ----------------------------------------------------------
+ # image + audio → tensors
+ # ----------------------------------------------------------
  test_data = image_audio_to_tensor(
+ self.face_det,
+ self.feature_extractor,
+ image_path,
+ audio_path,
+ limit=-1, # use the entire audio
+ image_size=min_resolution,
+ area=cfg.area,
  )
  if test_data is None:
  return -1
+
+ h, w = test_data["ref_img"].shape[-2:]
+ resolution = (
+ f"{(Image.open(image_path).width // 2)*2}x{(Image.open(image_path).height // 2)*2}"
+ if keep_resolution
+ else f"{w}x{h}"
+ )
+
+ # ----------------------------------------------------------
+ # frame generation
+ # ----------------------------------------------------------
  video = test(
+ self.pipe,
+ cfg,
+ wav_enc=self.whisper,
+ audio_pe=self.audio2token,
+ audio2bucket=self.audio2bucket,
+ image_encoder=self.image_encoder,
+ width=w,
+ height=h,
  batch=test_data,
  )
 
+ # intermediate-frame interpolation
+ if cfg.use_interframe:
+ out, results = video.to(self.device), []
+ for i in tqdm(range(out.shape[2] - 1), ncols=0):
+ I1, I2 = out[:, :, i], out[:, :, i + 1]
+ middle = self.rife.inference(I1, I2).clamp(0, 1).detach()
+ results.extend([out[:, :, i], middle])
+ results.append(out[:, :, -1])
  video = torch.stack(results, 2).cpu()
 
+ # ----------------------------------------------------------
+ # save output files
+ # ----------------------------------------------------------
+ tmp_video = output_path.replace(".mp4", "_noaudio.mp4")
+ save_videos_grid(video, tmp_video, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
  os.system(
+ f"ffmpeg -i '{tmp_video}' -i '{audio_path}' -s {resolution} "
+ f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
  )
+ os.remove(tmp_video)
  return 0
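For reference, a minimal usage sketch of the updated class, based only on the signatures shown above (Sonic.__init__, preprocess, process). It assumes the checkpoints referenced in config/inference/sonic.yaml are present under checkpoints/; the input and output paths are placeholders:

    from sonic import Sonic

    sonic = Sonic(device_id=0, enable_interpolate_frame=True)

    # optional sanity check: is there a detectable face in the reference image?
    info = sonic.preprocess("examples/face.png", expand_ratio=1.0)   # placeholder path
    if info["face_num"] == 0:
        raise SystemExit("no face detected")

    # returns 0 on success, -1 if preprocessing of the image/audio pair failed
    ret = sonic.process(
        image_path="examples/face.png",     # placeholder path
        audio_path="examples/speech.wav",   # placeholder path
        output_path="results/output.mp4",
        inference_steps=25,
        dynamic_scale=1.0,
    )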