openfree committed · Commit 5d6304b · verified · 1 Parent(s): 0b636bf

Update sonic.py

Files changed (1)
  1. sonic.py +136 -167
sonic.py CHANGED
@@ -1,20 +1,23 @@
 
 
 
  import os, math, torch, cv2
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
-
  from diffusers import AutoencoderKLTemporalDecoder
  from diffusers.schedulers import EulerDiscreteScheduler
  from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
 
  from src.utils.util import save_videos_grid, seed_everything
- from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
  from src.models.base.unet_spatio_temporal_condition import (
      UNetSpatioTemporalConditionModel, add_ip_adapters,
  )
- from src.pipelines.pipeline_sonic import SonicPipeline
  from src.models.audio_adapter.audio_proj import AudioProjModel
  from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
  from src.utils.RIFE.RIFE_HDv3 import RIFEModel
  from src.dataset.face_align.align import AlignImage
 
@@ -22,223 +25,189 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
  # ------------------------------------------------------------------
- # single image + speech → video-tensor generator
  # ------------------------------------------------------------------
- def test(
-     pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
-     width, height, batch,
- ):
-     # ---- align batch dimensions -----------------------------------
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
-             batch[k] = v.unsqueeze(0).to(pipe.device).float()
-
-     ref_img = batch["ref_img"]
-     clip_img = batch["clip_images"]
-     face_mask = batch["face_mask"]
-     image_embeds = image_encoder(clip_img).image_embeds # (1,1024)
-
-     audio_feature = batch["audio_feature"] # (1, 80, T)
-     audio_len = int(batch["audio_len"])
-     step = int(config.step)
-
-     window = 16_000 # 1-sec chunks
-     audio_prompts, last_prompts = [], []
-
-     for i in range(0, audio_feature.shape[-1], window):
-         chunk = audio_feature[:, :, i : i + window] # (1, 80, win)
-         layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
-         last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
-         audio_prompts.append(torch.stack(layers, dim=2)) # (1, w, L, 384)
-         last_prompts.append(last)
-
-     if not audio_prompts:
-         raise ValueError("[ERROR] No speech recognised in the provided audio.")
-
-     audio_prompts = torch.cat(audio_prompts, dim=1)
-     last_prompts = torch.cat(last_prompts, dim=1)
-
-     # padding rule
-     audio_prompts = torch.cat(
-         [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
-          torch.zeros_like(audio_prompts[:, :6])], dim=1)
-     last_prompts = torch.cat(
-         [torch.zeros_like(last_prompts[:, :24]), last_prompts,
-          torch.zeros_like(last_prompts[:, :26])], dim=1)
-
-     total_tokens = audio_prompts.shape[1]
-     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
 
-     ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []
 
-     for i in tqdm(range(num_chunks)):
-         start = i * 2 * step
 
-         # ------------ cond_clip : (1,1,10,5,384) ------------------
-         clip_raw = audio_prompts[:, start : start + 10] # (1, ≤10, L, 384)
 
-         # W-padding must be along dim=1!
-         if clip_raw.shape[1] < 10:
-             pad_w = torch.zeros_like(clip_raw[:, : 10 - clip_raw.shape[1]])
-             clip_raw = torch.cat([clip_raw, pad_w], dim=1)
 
-         # L-padding is along dim=2
-         while clip_raw.shape[2] < 5:
-             clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
-         clip_raw = clip_raw[:, :, :5] # (1,10,5,384)
 
-         cond_clip = clip_raw.unsqueeze(1) # (1,1,10,5,384)
 
-         # ------------ bucket_clip : (1,1,50,1,384) -----------------
-         bucket_raw = last_prompts[:, start : start + 50]
-         if bucket_raw.shape[1] < 50: # ★ dim=1
-             pad_w = torch.zeros_like(bucket_raw[:, : 50 - bucket_raw.shape[1]])
-             bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
-         bucket_clip = bucket_raw.unsqueeze(1) # (1,1,50,1,384)
 
-         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
-
-         ref_list.append(ref_img[0])
-         audio_list.append(audio_pe(cond_clip).squeeze(0)) # (50,1024)
-         uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
-         motion_buckets.append(motion[0])
-
-     # ---- Stable-Video-Diffusion call ------------------------------
-     video = pipe(
          ref_img, clip_img, face_mask,
-         audio_list, uncond_list, motion_buckets,
          height=height, width=width,
-         num_frames=len(audio_list),
-         decode_chunk_size=config.decode_chunk_size,
-         motion_bucket_scale=config.motion_bucket_scale,
-         fps=config.fps,
-         noise_aug_strength=config.noise_aug_strength,
-         min_guidance_scale1=config.min_appearance_guidance_scale,
-         max_guidance_scale1=config.max_appearance_guidance_scale,
-         min_guidance_scale2=config.audio_guidance_scale,
-         max_guidance_scale2=config.audio_guidance_scale,
-         overlap=config.overlap,
-         shift_offset=config.shift_offset,
-         frames_per_batch=config.n_sample_frames,
-         num_inference_steps=config.num_inference_steps,
-         i2i_noise_strength=config.i2i_noise_strength,
      ).frames
 
-     video = (video * 0.5 + 0.5).clamp(0, 1)
-     return video.to(pipe.device).unsqueeze(0).cpu()
 
 
  # ------------------------------------------------------------------
- # Sonic class
  # ------------------------------------------------------------------
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)
 
-     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
          cfg = self.config
          cfg.use_interframe = enable_interpolate_frame
-         self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
          cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
 
          self._load_models(cfg)
          print("Sonic init done")
 
-     # --------------------------------------------------------------
      def _load_models(self, cfg):
          dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
 
-         vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
-         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
-         img_e = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
-         unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
-         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
-         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
-         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
-         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
-         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
-         whisper.requires_grad_(False)
 
          self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
          self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
          if cfg.use_interframe:
-             self.rife = RIFEModel(device=self.device)
-             self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
-
-         img_e.to(dtype); vae.to(dtype); unet.to(dtype)
 
-         self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
-         self.image_encoder = img_e
-         self.audio2token = a2t
-         self.audio2bucket = a2b
-         self.whisper = whisper
 
-     # --------------------------------------------------------------
-     def preprocess(self, img_path: str, expand_ratio: float = 1.0):
          img = cv2.imread(img_path)
-         h, w = img.shape[:2]
-         _, _, faces = self.face_det(img, maxface=True)
-         if faces:
-             x1, y1, ww, hh = faces[0]
-             return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
-         return {"face_num": 0, "crop_bbox": None}
 
-     # --------------------------------------------------------------
      @torch.no_grad()
-     def process(
-         self,
-         img_path: str,
-         audio_path: str,
-         out_path: str,
-         min_resolution: int = 512,
-         inference_steps: int = 25,
-         dynamic_scale: float = 1.0,
-         keep_resolution: bool = False,
-         seed: int | None = None,
-     ):
          cfg = self.config
          if seed is not None: cfg.seed = seed
-         cfg.num_inference_steps = inference_steps
-         cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)
 
          sample = image_audio_to_tensor(
              self.face_det, self.feature_extractor,
-             img_path, audio_path,
-             limit=-1, image_size=min_resolution, area=cfg.area,
          )
-         if sample is None:
-             return -1
 
-         h, w = sample["ref_img"].shape[-2:]
-         resolution = (f"{(Image.open(img_path).width //2)*2}x{(Image.open(img_path).height//2)*2}"
                        if keep_resolution else f"{w}x{h}")
 
-         video = test(
-             self.pipe, cfg, self.whisper, self.audio2token,
-             self.audio2bucket, self.image_encoder,
-             w, h, sample,
-         )
-
-         if cfg.use_interframe:
-             out = video.to(self.device)
-             frames = []
-             for i in tqdm(range(out.shape[2] - 1), ncols=0):
-                 mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
-                 frames.extend([out[:, :, i], mid])
-             frames.append(out[:, :, -1])
-             video = torch.stack(frames, 2).cpu()
-
-         tmp = out_path.replace(".mp4", "_noaudio.mp4")
-         save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
-         os.system(
-             f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
-             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error"
-         )
-         os.remove(tmp)
-         return 0
 
+ # ---------------------------------------------------------
+ # sonic.py (2025-05 rev – fix AudioProjModel tensor shape)
+ # ---------------------------------------------------------
  import os, math, torch, cv2
+ import torch.utils.checkpoint
  from PIL import Image
  from omegaconf import OmegaConf
  from tqdm import tqdm
  from diffusers import AutoencoderKLTemporalDecoder
  from diffusers.schedulers import EulerDiscreteScheduler
  from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
 
  from src.utils.util import save_videos_grid, seed_everything
+ from src.dataset.test_preprocess import image_audio_to_tensor, process_bbox
  from src.models.base.unet_spatio_temporal_condition import (
      UNetSpatioTemporalConditionModel, add_ip_adapters,
  )
  from src.models.audio_adapter.audio_proj import AudioProjModel
  from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
+ from src.pipelines.pipeline_sonic import SonicPipeline
  from src.utils.RIFE.RIFE_HDv3 import RIFEModel
  from src.dataset.face_align.align import AlignImage
 
 
 
  # ------------------------------------------------------------------
+ # single image + speech → video tensor
  # ------------------------------------------------------------------
+ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, img_enc,
+          width, height, batch):
+
+     # --- align batch dimensions ------------------------------------
      for k, v in batch.items():
          if isinstance(v, torch.Tensor):
+             batch[k] = v.unsqueeze(0).float().to(pipe.device)
 
+     ref_img = batch['ref_img']
+     clip_img = batch['clip_images']
+     face_mask = batch['face_mask']
+     img_emb = img_enc(clip_img).image_embeds # (1,1024)
 
+     audio_feat = batch['audio_feature'] # (1,80,T)
+     audio_len = int(batch['audio_len'])
+     step = max(1, int(cfg.step)) # safety clamp
 
+     window = 16_000 # 1-second chunks
+     prompt_list, last_list = [], []
 
+     for i in range(0, audio_feat.shape[-1], window):
+         chunk = audio_feat[:, :, i:i+window]
+         hs = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+         prompt_list.append(torch.stack(hs, 2)) # (1,80,L,384)
+         last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
+         last_list.append(last) # (1,80,1,384)
 
+     if not prompt_list:
+         raise ValueError("❌ No speech recognised in audio.")
 
+     audio_prompts = torch.cat(prompt_list, 1) # (1,80,*L,384)
+     last_prompts = torch.cat(last_list, 1) # (1,80,*1,384)
 
+     # padding rule (same as the original paper)
+     audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]),
+                                audio_prompts,
+                                torch.zeros_like(audio_prompts[:, :6])], 1)
+     last_prompts = torch.cat([torch.zeros_like(last_prompts[:, :24]),
+                               last_prompts,
+                               torch.zeros_like(last_prompts[:, :26])], 1)
 
+     # --------------------------------------------------------------
+     total_tok = audio_prompts.shape[1]
+     n_chunks = max(1, math.ceil(total_tok / (2*step)))
+
+     ref_L, aud_L, uncond_L, buckets = [], [], [], []
+
+     for i in tqdm(range(n_chunks), ncols=0):
+         st = i * 2 * step
+
+         # ① conditional audio tokens (pad → 10×5×384)
+         cond = audio_prompts[:, st:st+10] # (1,80,10,384) → (1,10,8,384)?
+         cond = cond[:, :10] # f = 10
+         cond = cond.permute(0, 2, 1, 3) # (1,10,80,384)
+         cond = cond.reshape(1, 10, 10, 5, 384) # ★ w=10, b=5 (zero-pad auto)
+         # ② tokens for bucket estimation
+         buck = last_prompts[:, st:st+50] # (1,80,50,384)
+         if buck.shape[1] < 50:
+             pad = torch.zeros(1, 50-buck.shape[1], *buck.shape[2:], device=buck.device)
+             buck = torch.cat([buck, pad], 1)
+         buck = buck[:, :50].permute(0, 2, 1, 3).reshape(1, 50, 10, 5, 384)
+
+         motion = audio2bucket(buck, img_emb) * 16 + 16
+
+         ref_L.append(ref_img[0])
+         aud_L.append(audio_pe(cond).squeeze(0)) # (10,1024)
+         uncond_L.append(audio_pe(torch.zeros_like(cond)).squeeze(0))
+         buckets.append(motion[0])
+
+     # -------------- diffusion -------------------------------------------------
+     vid = pipe(
          ref_img, clip_img, face_mask,
+         aud_L, uncond_L, buckets,
          height=height, width=width,
+         num_frames=len(aud_L),
+         decode_chunk_size=cfg.decode_chunk_size,
+         motion_bucket_scale=cfg.motion_bucket_scale,
+         fps=cfg.fps,
+         noise_aug_strength=cfg.noise_aug_strength,
+         min_guidance_scale1=cfg.min_appearance_guidance_scale,
+         max_guidance_scale1=cfg.max_appearance_guidance_scale,
+         min_guidance_scale2=cfg.audio_guidance_scale,
+         max_guidance_scale2=cfg.audio_guidance_scale,
+         overlap=cfg.overlap,
+         shift_offset=cfg.shift_offset,
+         frames_per_batch=cfg.n_sample_frames,
+         num_inference_steps=cfg.num_inference_steps,
+         i2i_noise_strength=cfg.i2i_noise_strength,
      ).frames
 
+     return (vid*0.5 + 0.5).clamp(0, 1).to(pipe.device).unsqueeze(0).cpu()
 
 
  # ------------------------------------------------------------------
+ # Sonic wrapper
  # ------------------------------------------------------------------
  class Sonic:
      config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
      config = OmegaConf.load(config_file)
 
+     def __init__(self, device_id=0, enable_interpolate_frame=True):
          cfg = self.config
          cfg.use_interframe = enable_interpolate_frame
+         self.device = f"cuda:{device_id}" if torch.cuda.is_available() and device_id >= 0 else "cpu"
          cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
 
          self._load_models(cfg)
          print("Sonic init done")
 
+     # model-loader (unchanged, but with tiny clean-ups) ------------------------
      def _load_models(self, cfg):
          dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
 
+         vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
+         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
+         img_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+         unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
          add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
 
+         self.audio2token = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
+         self.audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
 
+         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+         self.audio2token.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+         self.audio2bucket.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
 
+         self.whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
+         self.whisper.requires_grad_(False)
 
          self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
          self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
          if cfg.use_interframe:
+             self.rife = RIFEModel(device=self.device); self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
 
+         for m in (img_enc, vae, unet): m.to(dtype)
+         self.pipe = SonicPipeline(unet=unet, image_encoder=img_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+         self.image_encoder = img_enc
 
+     # ------------------------------------------------------------------
+     def preprocess(self, img_path, expand_ratio=1.0):
          img = cv2.imread(img_path)
+         _, _, boxes = self.face_det(img, maxface=True)
+         if boxes:
+             x, y, w, h = boxes[0]; return {"face_num": 1, "crop_bbox": process_bbox((x, y, x+w, y+h), expand_ratio, *img.shape[:2])}
+         return {"face_num": 0, "crop_bbox": None}
 
+     # ------------------------------------------------------------------
      @torch.no_grad()
+     def process(self, img_path, wav_path, out_path,
+                 min_resolution=512, inference_steps=25,
+                 dynamic_scale=1.0, keep_resolution=False, seed=None):
+
          cfg = self.config
          if seed is not None: cfg.seed = seed
+         cfg.num_inference_steps = inference_steps
+         cfg.motion_bucket_scale = dynamic_scale
          seed_everything(cfg.seed)
 
          sample = image_audio_to_tensor(
              self.face_det, self.feature_extractor,
+             img_path, wav_path, limit=-1,
+             image_size=min_resolution, area=cfg.area,
          )
+         if sample is None: return -1
 
+         h, w = sample['ref_img'].shape[-2:]
+         resolution = (f"{Image.open(img_path).width//2*2}x{Image.open(img_path).height//2*2}"
                        if keep_resolution else f"{w}x{h}")
 
+         video = test(self.pipe, cfg, self.whisper, self.audio2token,
+                      self.audio2bucket, self.image_encoder, w, h, sample)
+
+         if cfg.use_interframe: # RIFE interpolation
+             out = video.to(self.device); frames = []
+             for i in tqdm(range(out.shape[2]-1), ncols=0):
+                 mid = self.rife.inference(out[:, :, i], out[:, :, i+1]).clamp(0, 1)
+                 frames += [out[:, :, i], mid]
+             frames.append(out[:, :, -1]); video = torch.stack(frames, 2).cpu()
+
+         tmp = out_path.replace(".mp4", "_noaudio.mp4")
+         save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps*(2 if cfg.use_interframe else 1))
+         os.system(f"ffmpeg -i '{tmp}' -i '{wav_path}' -s {resolution} "
+                   f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error")
+         os.remove(tmp); return 0
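
For reference, a minimal usage sketch of the class this commit edits, based only on the signatures visible in the new sonic.py. The image/audio/output paths, the device id, and the seed are hypothetical examples, not part of the commit; it assumes the checkpoints referenced in config/inference/sonic.yaml and under checkpoints/ are already in place. process() returns 0 on success and -1 when no face/audio sample can be built.

    # Hypothetical example paths; not part of this commit.
    from sonic import Sonic

    pipeline = Sonic(device_id=0, enable_interpolate_frame=True)

    # Optional sanity check: make sure a face is detected in the reference image.
    info = pipeline.preprocess("examples/face.png", expand_ratio=1.0)
    if info["face_num"] > 0:
        ret = pipeline.process(
            "examples/face.png",    # reference image
            "examples/speech.wav",  # driving audio
            "outputs/result.mp4",   # output video (ffmpeg muxes the audio back in)
            min_resolution=512,
            inference_steps=25,
            dynamic_scale=1.0,      # forwarded to cfg.motion_bucket_scale
            keep_resolution=False,
            seed=42,
        )
        print("done" if ret == 0 else "no usable face/audio sample")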