openfree committed
Commit f40c908 · verified · 1 Parent(s): 43cb38b

Update sonic.py

Files changed (1)
  1. sonic.py +58 -78
sonic.py CHANGED
@@ -1,9 +1,7 @@
- import os, math
- import torch
+ import os, math, torch, cv2
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
- import cv2
 
  from diffusers import AutoencoderKLTemporalDecoder
  from diffusers.schedulers import EulerDiscreteScheduler
@@ -22,10 +20,6 @@ from src.dataset.face_align.align import AlignImage
 
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
- # ------------------------------------------------------------------
- # single image + speech → video-tensor generator
- # ------------------------------------------------------------------
- # …(imports and other definitions above unchanged)…
 
  # ------------------------------------------------------------------
  # single image + speech → video-tensor generator
@@ -42,20 +36,20 @@ def test(
      ref_img = batch["ref_img"]
      clip_img = batch["clip_images"]
      face_mask = batch["face_mask"]
-     image_embeds = image_encoder(clip_img).image_embeds
+     image_embeds = image_encoder(clip_img).image_embeds  # (1,1024)
 
-     audio_feature = batch["audio_feature"]
+     audio_feature = batch["audio_feature"]  # (1, 80, T)
      audio_len = int(batch["audio_len"])
      step = int(config.step)
 
-     window = 16_000  # 1
+     window = 16_000  # 1-sec chunks
      audio_prompts, last_prompts = [], []
 
      for i in range(0, audio_feature.shape[-1], window):
-         chunk = audio_feature[:, :, i : i + window]
+         chunk = audio_feature[:, :, i : i + window]  # (1, 80, win)
          layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
          last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-         audio_prompts.append(torch.stack(layers, dim=2))  # (1,w,L,384)
+         audio_prompts.append(torch.stack(layers, dim=2))  # (1, w, L, 384)
          last_prompts.append(last)
 
      if not audio_prompts:
@@ -64,6 +58,7 @@ def test(
      audio_prompts = torch.cat(audio_prompts, dim=1)
      last_prompts = torch.cat(last_prompts, dim=1)
 
+     # padding rule
      audio_prompts = torch.cat(
          [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
           torch.zeros_like(audio_prompts[:, :6])], dim=1)
@@ -80,34 +75,35 @@ def test(
          start = i * 2 * step
 
          # ------------ cond_clip : (1,1,10,5,384) ------------------
-         clip_raw = audio_prompts[:, start : start + 10]  # (1,≤10,L,384)
-         if clip_raw.shape[1] < 10:  # w-pad
-             pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
+         clip_raw = audio_prompts[:, start : start + 10]  # (1, ≤10, L, 384)
+
+         # W-padding must be along dim=1!
+         if clip_raw.shape[1] < 10:
+             pad_w = torch.zeros_like(clip_raw[:, : 10 - clip_raw.shape[1]])
              clip_raw = torch.cat([clip_raw, pad_w], dim=1)
 
-         # ★ L-pad → make exactly 5 layers
+         # ★ L-padding along dim=2
          while clip_raw.shape[2] < 5:
             clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
-         clip_raw = clip_raw[:, :, :5]  # (1,10,5,384)
+         clip_raw = clip_raw[:, :, :5]  # (1,10,5,384)
 
-         cond_clip = clip_raw.unsqueeze(1)  # (1,1,10,5,384)
+         cond_clip = clip_raw.unsqueeze(1)  # (1,1,10,5,384)
 
          # ------------ bucket_clip : (1,1,50,1,384) -----------------
          bucket_raw = last_prompts[:, start : start + 50]
-         if bucket_raw.shape[1] < 50:
-             pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
+         if bucket_raw.shape[1] < 50:  # ★ dim=1
+             pad_w = torch.zeros_like(bucket_raw[:, : 50 - bucket_raw.shape[1]])
              bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
-         bucket_clip = bucket_raw.unsqueeze(1)  # (1,1,50,1,384)
+         bucket_clip = bucket_raw.unsqueeze(1)  # (1,1,50,1,384)
 
          motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
 
          ref_list.append(ref_img[0])
-         # ★ here: squeeze(0) (drop the batch dim); no [0] indexing
-         audio_list.append(audio_pe(cond_clip).squeeze(0))  # (50,1024)
+         audio_list.append(audio_pe(cond_clip).squeeze(0))  # (50,1024)
          uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
          motion_buckets.append(motion[0])
 
-     # ---- Stable Video Diffusion call ------------------------------
+     # ---- Stable Video Diffusion call ------------------------------
      video = pipe(
          ref_img, clip_img, face_mask,
          audio_list, uncond_list, motion_buckets,
@@ -132,20 +128,17 @@ def test(
      return video.to(pipe.device).unsqueeze(0).cpu()
 
 
-
-
-
  # ------------------------------------------------------------------
- # Sonic class
+ # Sonic class
  # ------------------------------------------------------------------
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)
 
      def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-         cfg = self.config
+         cfg = self.config
          cfg.use_interframe = enable_interpolate_frame
-         self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
+         self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
          cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
 
          self._load_models(cfg)
@@ -155,18 +148,18 @@ class Sonic:
      def _load_models(self, cfg):
          dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
 
-         vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
-         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
-         imgenc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+         vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
+         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
+         img_e = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
          unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
          a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
          a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
-         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
+         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
          whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
          whisper.requires_grad_(False)
@@ -177,22 +170,21 @@ class Sonic:
          self.rife = RIFEModel(device=self.device)
          self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
 
-         for m in (imgenc, vae, unet):
-             m.to(dtype)
+         img_e.to(dtype); vae.to(dtype); unet.to(dtype)
 
-         self.pipe = SonicPipeline(unet=unet, image_encoder=imgenc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
-         self.image_encoder = imgenc
+         self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+         self.image_encoder = img_e
          self.audio2token = a2t
          self.audio2bucket = a2b
          self.whisper = whisper
 
      # --------------------------------------------------------------
-     def preprocess(self, image_path: str, expand_ratio: float = 1.0):
-         img = cv2.imread(image_path)
+     def preprocess(self, img_path: str, expand_ratio: float = 1.0):
+         img = cv2.imread(img_path)
          h, w = img.shape[:2]
-         _, _, bboxes = self.face_det(img, maxface=True)
-         if bboxes:
-             x1, y1, ww, hh = bboxes[0]
+         _, _, faces = self.face_det(img, maxface=True)
+         if faces:
+             x1, y1, ww, hh = faces[0]
              return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
          return {"face_num": 0, "crop_bbox": None}
 
@@ -200,50 +192,39 @@ class Sonic:
      @torch.no_grad()
      def process(
          self,
-         image_path: str,
-         audio_path: str,
-         output_path: str,
+         img_path: str,
+         audio_path: str,
+         out_path: str,
          min_resolution: int = 512,
-         inference_steps: int = 25,
+         inference_steps: int = 25,
          dynamic_scale: float = 1.0,
          keep_resolution: bool = False,
          seed: int | None = None,
      ):
          cfg = self.config
-         if seed is not None:
-             cfg.seed = seed
-         cfg.num_inference_steps = inference_steps
-         cfg.motion_bucket_scale = dynamic_scale
+         if seed is not None: cfg.seed = seed
+         cfg.num_inference_steps = inference_steps
+         cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)
 
-         # image + audio → tensor
-         test_data = image_audio_to_tensor(
-             self.face_det,
-             self.feature_extractor,
-             image_path,
-             audio_path,
-             limit=-1,
-             image_size=min_resolution,
-             area=cfg.area,
+         sample = image_audio_to_tensor(
+             self.face_det, self.feature_extractor,
+             img_path, audio_path,
+             limit=-1, image_size=min_resolution, area=cfg.area,
          )
-         if test_data is None:
+         if sample is None:
              return -1
 
-         h, w = test_data["ref_img"].shape[-2:]
-         resolution = (
-             f"{(Image.open(image_path).width // 2) * 2}x{(Image.open(image_path).height // 2) * 2}"
-             if keep_resolution
-             else f"{w}x{h}"
-         )
+         h, w = sample["ref_img"].shape[-2:]
+         resolution = (f"{(Image.open(img_path).width // 2) * 2}x{(Image.open(img_path).height // 2) * 2}"
+                       if keep_resolution else f"{w}x{h}")
 
-         # generate video frames
          video = test(
             self.pipe, cfg, self.whisper, self.audio2token,
             self.audio2bucket, self.image_encoder,
-            width=w, height=h, batch=test_data,
+            w, h, sample,
         )
 
-         # interpolate intermediate frames
         if cfg.use_interframe:
             out = video.to(self.device)
             frames = []
@@ -253,12 +234,11 @@ class Sonic:
             frames.append(out[:, :, -1])
             video = torch.stack(frames, 2).cpu()
 
-         # save
-         tmp_mp4 = output_path.replace(".mp4", "_noaudio.mp4")
-         save_videos_grid(video, tmp_mp4, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
+         tmp = out_path.replace(".mp4", "_noaudio.mp4")
+         save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
         os.system(
-             f"ffmpeg -i '{tmp_mp4}' -i '{audio_path}' -s {resolution} "
-             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
+             f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
+             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error"
         )
-         os.remove(tmp_mp4)
+         os.remove(tmp)
         return 0
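
The heart of this change is how the Whisper features are padded before they reach audio_pe / audio2bucket: window (W) padding is applied along dim=1 and layer (L) padding along dim=2, producing the (1,1,10,5,384) cond_clip and (1,1,50,1,384) bucket_clip shapes used in test(). Below is a minimal, self-contained sketch of that rule using dummy tensors; it is not part of sonic.py, and the helper names pad_cond_clip and pad_bucket_clip are made up for illustration.

import torch

def pad_cond_clip(clip_raw: torch.Tensor) -> torch.Tensor:
    # clip_raw: (1, <=10, L, 384) slice of stacked encoder hidden states.
    # W-padding: zero-pad the window axis (dim=1) up to 10.
    if clip_raw.shape[1] < 10:
        pad_w = torch.zeros_like(clip_raw[:, : 10 - clip_raw.shape[1]])
        clip_raw = torch.cat([clip_raw, pad_w], dim=1)
    # L-padding: repeat the last encoder layer along dim=2 until there are 5.
    while clip_raw.shape[2] < 5:
        clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
    clip_raw = clip_raw[:, :, :5]
    return clip_raw.unsqueeze(1)                 # (1, 1, 10, 5, 384)

def pad_bucket_clip(bucket_raw: torch.Tensor) -> torch.Tensor:
    # bucket_raw: (1, <=50, 1, 384) slice of last_hidden_state windows.
    if bucket_raw.shape[1] < 50:                 # again dim=1, never dim=0
        pad_w = torch.zeros_like(bucket_raw[:, : 50 - bucket_raw.shape[1]])
        bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
    return bucket_raw.unsqueeze(1)               # (1, 1, 50, 1, 384)

if __name__ == "__main__":
    assert pad_cond_clip(torch.randn(1, 7, 4, 384)).shape == (1, 1, 10, 5, 384)
    assert pad_bucket_clip(torch.randn(1, 33, 1, 384)).shape == (1, 1, 50, 1, 384)

With those shapes in place, audio_pe(cond_clip).squeeze(0) only drops the batch dimension, which is why the diff replaces the earlier [0] indexing with squeeze(0).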