openfree committed
Commit 7b4dc6f · verified · 1 Parent(s): 1d7967c

Update sonic.py

Files changed (1)
  1. sonic.py +227 -170
sonic.py CHANGED
@@ -1,9 +1,15 @@
- # sonic.py ── full file
-
- import os, math, glob, torch, cv2
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
 
  from diffusers import AutoencoderKLTemporalDecoder
  from diffusers.schedulers import EulerDiscreteScheduler
@@ -12,7 +18,8 @@ from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatur
  from src.utils.util import save_videos_grid, seed_everything
  from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
  from src.models.base.unet_spatio_temporal_condition import (
-     UNetSpatioTemporalConditionModel, add_ip_adapters,
  )
  from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
@@ -20,115 +27,165 @@ from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
  from src.utils.RIFE.RIFE_HDv3 import RIFEModel
  from src.dataset.face_align.align import AlignImage
 
- try:
-     from safetensors.torch import load_file as safe_load
- except ImportError:
-     safe_load = None  # fallback when safetensors is not installed
-
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
- # ------------------------------------------------------------ utils
- def _find_ckpt(root: str, keyword: str):
-     """Find one .pth / .pt / .safetensors file under root whose name contains keyword."""
-     patterns = [f"**/*{keyword}*.pth",
-                 f"**/*{keyword}*.pt",
-                 f"**/*{keyword}*.safetensors"]
-     for p in patterns:
-         files = glob.glob(os.path.join(root, p), recursive=True)
-         if files:
-             return files[0]
-     return None
-
-
- # --------------------------------------------------- speech video
- def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
-          width, height, batch):
-
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
              batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
-     ref_img = batch["ref_img"]
      clip_img = batch["clip_images"]
      face_mask = batch["face_mask"]
      image_embeds = image_encoder(clip_img).image_embeds
 
-     audio_feature = batch["audio_feature"]  # (1,80,T)
-     audio_len = int(batch["audio_len"])
-     step = max(1, int(cfg.step))
 
-     window = 16_000
-     audio_prompts, last_prompts = [], []
 
-     for i in range(0, audio_feature.shape[-1], window):
-         chunk = audio_feature[:, :, i:i+window]
-         hidden = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-         last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-         audio_prompts.append(torch.stack(hidden, dim=2))
-         last_prompts.append(last)
 
-     if not audio_prompts:
          raise ValueError("[ERROR] No speech recognised in the provided audio.")
 
-     audio_prompts = torch.cat(audio_prompts, dim=1)
-     last_prompts = torch.cat(last_prompts, dim=1)
 
-     audio_prompts = torch.cat(
-         [torch.zeros_like(audio_prompts[:, :4]),
-          audio_prompts,
-          torch.zeros_like(audio_prompts[:, :6])], 1)
      last_prompts = torch.cat(
-         [torch.zeros_like(last_prompts[:, :24]),
-          last_prompts,
-          torch.zeros_like(last_prompts[:, :26])], 1)
-
-     num_chunks = max(1, math.ceil(audio_prompts.shape[1] / (2*step)))
-
-     ref_list, audio_list, uncond_list, buckets = [], [], [], []
-     for i in tqdm(range(num_chunks)):
-         st = i * 2 * step
-         cond = audio_prompts[:, st: st+10]
-         if cond.shape[2] < 10:
-             pad = torch.zeros_like(cond[:, :, :10-cond.shape[2]])
-             cond = torch.cat([cond, pad], 2)
-
-         bucket_clip = last_prompts[:, st: st+50]
-         if bucket_clip.shape[2] < 50:
-             pad = torch.zeros_like(bucket_clip[:, :, :50-bucket_clip.shape[2]])
-             bucket_clip = torch.cat([bucket_clip, pad], 2)
-
-         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
-
-         ref_list.append(ref_img[0])
-         audio_list.append(audio_pe(cond).squeeze(0))
-         uncond_list.append(audio_pe(torch.zeros_like(cond)).squeeze(0))
-         buckets.append(motion[0])
-
-     video = pipe(
-         ref_img, clip_img, face_mask,
-         audio_list, uncond_list, buckets,
-         height=height, width=width,
-         num_frames=len(audio_list),
-         decode_chunk_size=cfg.decode_chunk_size,
-         motion_bucket_scale=cfg.motion_bucket_scale,
-         fps=cfg.fps,
-         noise_aug_strength=cfg.noise_aug_strength,
-         min_guidance_scale1=cfg.min_appearance_guidance_scale,
-         max_guidance_scale1=cfg.max_appearance_guidance_scale,
-         min_guidance_scale2=cfg.audio_guidance_scale,
-         max_guidance_scale2=cfg.audio_guidance_scale,
-         overlap=cfg.overlap,
-         shift_offset=cfg.shift_offset,
-         frames_per_batch=cfg.n_sample_frames,
-         num_inference_steps=cfg.num_inference_steps,
-         i2i_noise_strength=cfg.i2i_noise_strength,
-     ).frames
-
-     return (video * .5 + .5).clamp(0,1).unsqueeze(0).cpu()
-
-
- # ------------------------------------------------------------ Sonic
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)
@@ -136,96 +193,87 @@ class Sonic:
      def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
          cfg = self.config
          cfg.use_interframe = enable_interpolate_frame
-         self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id >= 0 else "cpu"
 
-         # diffusers base model (including config.json) lives here
-         self.diffusers_root = os.path.join(BASE_DIR, "checkpoints", "stable-video-diffusion-img2vid-xt")
-         # extra pth/pt/safetensors checkpoints go here ⇣
-         self.ckpt_root = os.path.join(BASE_DIR, "checkpoints", "Sonic")
 
          self._load_models(cfg)
          print("Sonic init done")
 
-     def _locate_diffusers_dir(root: str) -> str:
-         """
-         Find and return the directory under root that contains
-         model_index.json or config.json (handles the snapshots/<sha>/… layout).
-         """
-         for cur, _dirs, files in os.walk(root):
-             if {"model_index.json", "config.json"} & set(files):
-                 return cur
-         raise FileNotFoundError(
-             f"[ERROR] diffusers model files(model_index.json/config.json) "
-             f"not found under {root}"
-         )
-
-
-     # --------------------------------------------- load all networks
      def _load_models(self, cfg):
          dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
-         diff_root = _locate_diffusers_dir(self.diffusers_root)  # ★★ key addition
 
-         vae = AutoencoderKLTemporalDecoder.from_pretrained(self.diffusers_root, subfolder="vae", variant="fp16")
-         sched = EulerDiscreteScheduler.from_pretrained(self.diffusers_root, subfolder="scheduler")
-         img_e = CLIPVisionModelWithProjection.from_pretrained(self.diffusers_root, subfolder="image_encoder", variant="fp16")
-         unet = UNetSpatioTemporalConditionModel.from_pretrained(self.diffusers_root, subfolder="unet", variant="fp16")
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
-         def _load_extra(module, key):
-             path = _find_ckpt(self.ckpt_root, key)
-             if not path:
-                 print(f"[WARN] extra ckpt for '{key}' not found → skip")
-                 return
-             print(f"[INFO] load {key} → {os.path.relpath(path, BASE_DIR)}")
-             state = safe_load(path, device="cpu") if (safe_load and path.endswith(".safetensors")) else torch.load(path, map_location="cpu")
-             module.load_state_dict(state, strict=False)
-
-         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
-         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
-
-         _load_extra(unet, "unet")
-         _load_extra(a2t, "audio2token")
-         _load_extra(a2b, "audio2bucket")
-
-         whisper = WhisperModel.from_pretrained(
-             os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
-         ).to(self.device).eval()
          whisper.requires_grad_(False)
 
-         self.feature_extractor = AutoFeatureExtractor.from_pretrained(
-             os.path.join(BASE_DIR, "checkpoints/whisper-tiny")
-         )
          self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
          if cfg.use_interframe:
              self.rife = RIFEModel(device=self.device)
              self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
 
-         for m in (img_e, vae, unet):
              m.to(dtype)
 
-         self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
          self.image_encoder = img_e
          self.audio2token = a2t
          self.audio2bucket = a2b
          self.whisper = whisper
 
-     # --------------------------------------------- preprocess helpers
-     def preprocess(self, image_path: str, expand_ratio: float = 1.0):
          img = cv2.imread(image_path)
          h, w = img.shape[:2]
          _, _, bboxes = self.face_det(img, maxface=True)
          if bboxes:
              x1, y1, ww, hh = bboxes[0]
-             return {"face_num": 1,
-                     "crop_bbox": process_bbox((x1, y1, x1+ww, y1+hh), expand_ratio, h, w)}
          return {"face_num": 0, "crop_bbox": None}
 
-     # --------------------------------------------------------------- run
      @torch.no_grad()
-     def process(self, image_path, audio_path, output_path,
-                 min_resolution=512, inference_steps=25,
-                 dynamic_scale=1.0, keep_resolution=False, seed=None):
-
          cfg = self.config
          if seed is not None:
              cfg.seed = seed
@@ -233,10 +281,15 @@ class Sonic:
          cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)
 
          data = image_audio_to_tensor(
-             self.face_det, self.feature_extractor,
-             image_path, audio_path,
-             limit=-1, image_size=min_resolution, area=cfg.area
          )
          if data is None:
              return -1
@@ -244,27 +297,31 @@ class Sonic:
          h, w = data["ref_img"].shape[-2:]
          if keep_resolution:
              im = Image.open(image_path)
-             resolution = f"{im.width//2*2}x{im.height//2*2}"
          else:
              resolution = f"{w}x{h}"
 
-         video = test(self.pipe, cfg, self.whisper, self.audio2token,
-                      self.audio2bucket, self.image_encoder,
-                      w, h, data)
 
          if cfg.use_interframe:
-             out, frames = video.to(self.device), []
-             for i in tqdm(range(out.shape[2]-1), ncols=0):
-                 mid = self.rife.inference(out[:,:,i], out[:,:,i+1]).clamp(0,1).detach()
-                 frames.extend([out[:,:,i], mid])
-             frames.append(out[:,:,-1])
-             video = torch.stack(frames, 2).cpu()
-
          tmp = output_path.replace(".mp4", "_noaudio.mp4")
-         save_videos_grid(video, tmp, n_rows=video.shape[0],
-                          fps=cfg.fps*(2 if cfg.use_interframe else 1))
          os.system(
              f"ffmpeg -loglevel error -y -i '{tmp}' -i '{audio_path}' -s {resolution} "
-             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}'")
          os.remove(tmp)
          return 0
+ # sonic.py
+ # ---------------------------------------------------------------------
+ # Sonic single-image + speech → talking-head video (offline edition)
+ # ---------------------------------------------------------------------
+ import os, math
+ from typing import Dict, Any, List
+
+ import torch
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
+ import cv2
 
  from diffusers import AutoencoderKLTemporalDecoder
  from diffusers.schedulers import EulerDiscreteScheduler
 
  from src.utils.util import save_videos_grid, seed_everything
  from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
  from src.models.base.unet_spatio_temporal_condition import (
+     UNetSpatioTemporalConditionModel,
+     add_ip_adapters,
  )
  from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
 
  from src.utils.RIFE.RIFE_HDv3 import RIFEModel
  from src.dataset.face_align.align import AlignImage
 
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
+ # ------------------------------------------------------------------ #
+ #               Helper: auto-locate the diffusers path               #
+ # ------------------------------------------------------------------ #
+ def _locate_diffusers_dir(root: str) -> str:
+     """
+     Walk the directories under `root` and return the folder that actually
+     contains a diffusers snapshot (model_index.json or config.json).
+     Raise an error if none is found.
+     """
+     for cur, _dirs, files in os.walk(root):
+         if {"model_index.json", "config.json"} & set(files):
+             return cur
+     raise FileNotFoundError(
+         f"[ERROR] No diffusers model files found under '{root}'. "
+         "Check that the checkpoint was downloaded correctly."
+     )
+
+
+ # ------------------------------------------------------------------ #
+ #                 Internal video-generation function                 #
+ # ------------------------------------------------------------------ #
+ def _gen_video_tensor(
+     pipe: SonicPipeline,
+     cfg: OmegaConf,
+     wav_enc: WhisperModel,
+     audio_pe: AudioProjModel,
+     audio2bucket: Audio2bucketModel,
+     image_encoder: CLIPVisionModelWithProjection,
+     width: int,
+     height: int,
+     batch: Dict[str, torch.Tensor],
+ ) -> torch.Tensor:
+     """
+     Single image + audio features → video tensor (C,T,H,W).
+     """
+
+     # -------- normalise batch dimensions ---------------------------
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
              batch[k] = v.unsqueeze(0).to(pipe.device).float()
 
+     ref_img = batch["ref_img"]                          # (1,C,H,W)
      clip_img = batch["clip_images"]
      face_mask = batch["face_mask"]
      image_embeds = image_encoder(clip_img).image_embeds
 
+     audio_feat: torch.Tensor = batch["audio_feature"]   # (1, 80, T)
+     audio_len: int = int(batch["audio_len"])             # scalar
+     step: int = int(cfg.step)
 
+     # if step exceeds the audio length, clamp it to at least 1
+     if audio_len < step:
+         step = max(1, audio_len)
 
+     # -------- run the Whisper encoder in 1-second windows ----------
+     window = 16_000            # 1-s chunk
+     aud_prompts: List[torch.Tensor] = []
+     last_prompts: List[torch.Tensor] = []
 
+     for i in range(0, audio_feat.shape[-1], window):
+         chunk = audio_feat[:, :, i : i + window]
+
+         # all hidden states / last hidden state
+         layers: List[torch.Tensor] = wav_enc.encoder(
+             chunk, output_hidden_states=True
+         ).hidden_states
+         last_hidden = wav_enc.encoder(chunk).last_hidden_state   # (1, 80, 384)
+
+         # Whisper returns 6 layers → truncate to the 5 that AudioProj expects
+         prompt = torch.stack(layers, dim=2)[:, :, :5]            # (1,80,5,384)
+         aud_prompts.append(prompt)
+         last_prompts.append(last_hidden.unsqueeze(-2))           # (1,80,1,384)
+
+     if len(aud_prompts) == 0:
          raise ValueError("[ERROR] No speech recognised in the provided audio.")
 
+     # concatenate and apply the padding rule
+     aud_prompts = torch.cat(aud_prompts, dim=1)    # (1, 80*…, 5, 384)
+     last_prompts = torch.cat(last_prompts, dim=1)  # (1, 80*…, 1, 384)
 
+     aud_prompts = torch.cat(
+         [torch.zeros_like(aud_prompts[:, :4]), aud_prompts, torch.zeros_like(aud_prompts[:, :6])],
+         dim=1,
+     )
      last_prompts = torch.cat(
+         [torch.zeros_like(last_prompts[:, :24]), last_prompts, torch.zeros_like(last_prompts[:, :26])],
+         dim=1,
+     )
+
+     # -------- slice clips: f=10 cond frames / w=5 layers ------------
+     ref_list, aud_list, uncond_list, mb_list = [], [], [], []
+
+     total_tokens = aud_prompts.shape[1]
+     n_chunks = max(1, math.ceil(total_tokens / (2 * step)))
+
+     for i in tqdm(range(n_chunks), desc="audio-chunks", ncols=0):
+         s = i * 2 * step
+
+         cond_clip = aud_prompts[:, s : s + 10]      # (1,10,5,384)
+         if cond_clip.shape[1] < 10:                 # pad at the end
+             pad = torch.zeros_like(cond_clip[:, : 10 - cond_clip.shape[1]])
+             cond_clip = torch.cat([cond_clip, pad], dim=1)
+
+         bucket_clip = last_prompts[:, s : s + 50]   # (1,50,1,384)
+         if bucket_clip.shape[1] < 50:
+             pad = torch.zeros_like(bucket_clip[:, : 50 - bucket_clip.shape[1]])
+             bucket_clip = torch.cat([bucket_clip, pad], dim=1)
+
+         # reshape to 5-D (bz,f,w,b,c)
+         cond_clip = cond_clip.unsqueeze(3)          # (1,10,5,1,384)
+         bucket_clip = bucket_clip.unsqueeze(3)      # (1,50,1,1,384)
+         uncond_clip = torch.zeros_like(cond_clip)
+
+         motion_bucket = audio2bucket(bucket_clip, image_embeds) * 16 + 16
+
+         ref_list.append(ref_img[0])
+         aud_list.append(audio_pe(cond_clip).squeeze(0)[0])        # (ctx,1024)
+         uncond_list.append(audio_pe(uncond_clip).squeeze(0)[0])   # (ctx,1024)
+         mb_list.append(motion_bucket[0])
+
+     # -------- run the UNet pipeline ---------------------------------
+     video = (
+         pipe(
+             ref_img,
+             clip_img,
+             face_mask,
+             aud_list,
+             uncond_list,
+             mb_list,
+             height=height,
+             width=width,
+             num_frames=len(aud_list),
+             decode_chunk_size=cfg.decode_chunk_size,
+             motion_bucket_scale=cfg.motion_bucket_scale,
+             fps=cfg.fps,
+             noise_aug_strength=cfg.noise_aug_strength,
+             min_guidance_scale1=cfg.min_appearance_guidance_scale,
+             max_guidance_scale1=cfg.max_appearance_guidance_scale,
+             min_guidance_scale2=cfg.audio_guidance_scale,
+             max_guidance_scale2=cfg.audio_guidance_scale,
+             overlap=cfg.overlap,
+             shift_offset=cfg.shift_offset,
+             frames_per_batch=cfg.n_sample_frames,
+             num_inference_steps=cfg.num_inference_steps,
+             i2i_noise_strength=cfg.i2i_noise_strength,
+         ).frames
+         * 0.5
+         + 0.5
+     ).clamp(0, 1)
+
+     # (B,C,T,H,W) → (C,T,H,W)
+     return video.to(pipe.device).squeeze(0).cpu()
+
+
+ # ------------------------------------------------------------------ #
+ #                         Sonic – main class                         #
+ # ------------------------------------------------------------------ #
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)
 
      def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
          cfg = self.config
          cfg.use_interframe = enable_interpolate_frame
 
+         # parent folder of the diffusers model (local download path)
+         self.diffusers_root = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
+         self.device = (
+             f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
+         )
 
          self._load_models(cfg)
          print("Sonic init done")
 
+     # -------------------------------------------------------------- #
      def _load_models(self, cfg):
+         # dtype
          dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
 
+         diff_root = _locate_diffusers_dir(self.diffusers_root)
+
+         # diffusers modules
+         vae = AutoencoderKLTemporalDecoder.from_pretrained(diff_root, subfolder="vae", variant="fp16")
+         sched = EulerDiscreteScheduler.from_pretrained(diff_root, subfolder="scheduler")
+         img_e = CLIPVisionModelWithProjection.from_pretrained(diff_root, subfolder="image_encoder", variant="fp16")
+         unet = UNetSpatioTemporalConditionModel.from_pretrained(diff_root, subfolder="unet", variant="fp16")
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
+         # audio adapters
+         a2t = AudioProjModel(seq_len=10, blocks=5, channels=384,
+                              intermediate_dim=1024, output_dim=1024, context_tokens=32).to(self.device)
+         a2b = Audio2bucketModel(seq_len=50, blocks=1, channels=384,
+                                 clip_channels=1024, intermediate_dim=1024, output_dim=1,
+                                 context_tokens=2).to(self.device)
+
+         # load checkpoints
+         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
+         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+
+         # Whisper
+         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
          whisper.requires_grad_(False)
 
+         # image / face / frame interpolation
+         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
          self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
          if cfg.use_interframe:
              self.rife = RIFEModel(device=self.device)
              self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
 
+         # apply dtype
+         for m in (vae, img_e, unet):
              m.to(dtype)
 
+         self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(self.device, dtype=dtype)
          self.image_encoder = img_e
          self.audio2token = a2t
          self.audio2bucket = a2b
          self.whisper = whisper
 
+     # -------------------------------------------------------------- #
+     def preprocess(self, image_path: str, expand_ratio: float = 1.0) -> Dict[str, Any]:
          img = cv2.imread(image_path)
          h, w = img.shape[:2]
          _, _, bboxes = self.face_det(img, maxface=True)
          if bboxes:
              x1, y1, ww, hh = bboxes[0]
+             crop = process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)
+             return {"face_num": 1, "crop_bbox": crop}
          return {"face_num": 0, "crop_bbox": None}
 
+     # -------------------------------------------------------------- #
      @torch.no_grad()
+     def process(
+         self,
+         image_path: str,
+         audio_path: str,
+         output_path: str,
+         min_resolution: int = 512,
+         inference_steps: int = 25,
+         dynamic_scale: float = 1.0,
+         keep_resolution: bool = False,
+         seed: int | None = None,
+     ) -> int:
          cfg = self.config
          if seed is not None:
              cfg.seed = seed
 
          cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)
 
+         # convert image and audio to tensors
          data = image_audio_to_tensor(
+             self.face_det,
+             self.feature_extractor,
+             image_path,
+             audio_path,
+             limit=-1,
+             image_size=min_resolution,
+             area=cfg.area,
          )
          if data is None:
              return -1
 
          h, w = data["ref_img"].shape[-2:]
          if keep_resolution:
              im = Image.open(image_path)
+             resolution = f"{(im.width // 2) * 2}x{(im.height // 2) * 2}"
          else:
              resolution = f"{w}x{h}"
 
+         # generate the video tensor
+         video = _gen_video_tensor(
+             self.pipe, cfg, self.whisper, self.audio2token, self.audio2bucket,
+             self.image_encoder, w, h, data,
+         )
 
+         # interpolate intermediate frames
          if cfg.use_interframe:
+             out = video.to(self.device)
+             frames = []
+             for i in tqdm(range(out.shape[1] - 1), desc="interpolate", ncols=0):
+                 frames.extend([out[:, i], self.rife.inference(out[:, i], out[:, i + 1]).clamp(0, 1)])
+             frames.append(out[:, -1])
+             video = torch.stack(frames, 1).cpu()   # (C,T',H,W)
+
+         # save
          tmp = output_path.replace(".mp4", "_noaudio.mp4")
+         save_videos_grid(video.unsqueeze(0), tmp, n_rows=1, fps=cfg.fps * (2 if cfg.use_interframe else 1))
          os.system(
              f"ffmpeg -loglevel error -y -i '{tmp}' -i '{audio_path}' -s {resolution} "
+             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}'"
+         )
          os.remove(tmp)
          return 0
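
For reference, a minimal usage sketch of the updated public API as shown in this diff. It assumes the checkpoints referenced by config/inference/sonic.yaml have already been downloaded under checkpoints/; the example paths (examples/face.png, examples/speech.wav, outputs/result.mp4) are hypothetical placeholders.

    from sonic import Sonic

    # device_id=0 selects cuda:0 when available, otherwise CPU;
    # enable_interpolate_frame=True turns on RIFE frame interpolation.
    pipeline = Sonic(device_id=0, enable_interpolate_frame=True)

    # Optional sanity check: confirm a face is detected in the reference image.
    info = pipeline.preprocess("examples/face.png", expand_ratio=1.0)

    if info["face_num"] > 0:
        # process() returns 0 on success and -1 if preprocessing yields no usable data.
        pipeline.process(
            image_path="examples/face.png",
            audio_path="examples/speech.wav",
            output_path="outputs/result.mp4",
            min_resolution=512,
            inference_steps=25,
            dynamic_scale=1.0,
            keep_resolution=False,
            seed=42,
        )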