Update sonic.py
sonic.py CHANGED
Previous version (lines removed by this commit are prefixed with -):

@@ -1,7 +1,5 @@
-import os
-import math  # [★ FIX] for ceil calculation
 import torch
-import torch.utils.checkpoint
 from PIL import Image
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -26,109 +24,89 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))


 # ------------------------------------------------------------------
-#
 # ------------------------------------------------------------------
 def test(
-    pipe,
-    config,
-    wav_enc,
-    audio_pe,
-    audio2bucket,
-    image_encoder,
-    width,
-    height,
-    batch,
 ):
-    #
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()

-    ref_img = batch["ref_img"]
     clip_img = batch["clip_images"]
     face_mask = batch["face_mask"]
     image_embeds = image_encoder(clip_img).image_embeds

-    audio_feature = batch["audio_feature"]  # (
-    audio_len = batch["audio_len"]
     step = int(config.step)

-    #
-    # ① window = 16000 for 1-second segments (≈ 1 s for whisper-tiny)
-    # ② if audio_len < step, shrink step to avoid an empty list
-    # --------------------------------------------------------------------
-    window = 16000
     if audio_len < step:
         step = max(1, audio_len)

-
-    audio_prompts,
     for i in range(0, audio_feature.shape[-1], window):
-        chunk = audio_feature[:, :, i : i + window]

-
-
-        last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)

         audio_prompts.append(torch.stack(prompt_layers, dim=2))
-

-    # ── exception: abort if nothing was produced
     if len(audio_prompts) == 0:
-        raise ValueError(
-
-
-

-    # rebuild the Whisper token sequence (+ model padding rule)
-    audio_prompts = torch.cat(audio_prompts, dim=1)[:, : audio_len * 2]
-    audio_prompts = torch.cat(
-        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])],
-        dim=1,
-    )
-
-    last_audio_prompts = torch.cat(last_audio_prompts, dim=1)[:, : audio_len * 2]
-    last_audio_prompts = torch.cat(
-        [torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])],
-        dim=1,
-    )
-
-    # --------------------------------------------------------------------
-    # total number of chunks (ceil), reflecting the adjusted step
-    # --------------------------------------------------------------------
-    num_chunks = math.ceil(audio_len / step)
-
-    ref_tensor_list, audio_tensor_list, uncond_audio_tensor_list, motion_buckets = [], [], [], []
     for i in tqdm(range(num_chunks)):
         start = i * 2 * step
-        audio_clip = audio_prompts[:, start : start + 10].unsqueeze(0)
-        audio_clip_for_bucket = last_audio_prompts[:, start : start + 50].unsqueeze(0)

-
-

-
-

-
-        audio_tensor_list.append(cond_audio[0])
-        uncond_audio_tensor_list.append(uncond_audio[0])

-
-
-

-    #
     video = pipe(
-        ref_img,
-
-
-
-        uncond_audio_tensor_list,
-        motion_buckets,
-        height=height,
-        width=width,
-        num_frames=len(audio_tensor_list),
         decode_chunk_size=config.decode_chunk_size,
         motion_bucket_scale=config.motion_bucket_scale,
         fps=config.fps,
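(The removed comments above describe the windowing that the rewritten test() keeps: the feature tensor is cut into window = 16000 slices along its last axis, roughly one second of 16 kHz input for whisper-tiny, and each slice goes through the encoder. A standalone sketch of just that slicing, with a toy tensor shape assumed purely for illustration:

import torch

window = 16000
audio_feature = torch.randn(1, 80, 40_000)   # toy (batch, mel, T) tensor, not real data

chunks = [audio_feature[:, :, i : i + window]
          for i in range(0, audio_feature.shape[-1], window)]
print([c.shape[-1] for c in chunks])          # [16000, 16000, 8000] – the last slice is shorter

The last slice is simply shorter; the rewritten code further below zero-pads the per-chunk clips instead of relying on it.)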
@@ -143,81 +121,60 @@ def test(
         num_inference_steps=config.num_inference_steps,
         i2i_noise_strength=config.i2i_noise_strength,
     ).frames
-    # --------------------------------------------------------------------

     video = (video * 0.5 + 0.5).clamp(0, 1)
     return video.to(pipe.device).unsqueeze(0).cpu()


 # ------------------------------------------------------------------
-#
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
     config = OmegaConf.load(config_file)

     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
-        cfg
-        cfg.use_interframe
-        self.device
         cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)

-        # ───────────── load models
         self._load_models(cfg)
         print("Sonic init done")

-    # --------------------------------------------------------------
-    # model / pipeline loader
     # --------------------------------------------------------------
     def _load_models(self, cfg):
-
-        weight_dtype = dtype_map.get(cfg.weight_dtype, torch.float32)

-
-
-
-        )
-        scheduler = EulerDiscreteScheduler.from_pretrained(
-            cfg.pretrained_model_name_or_path, subfolder="scheduler"
-        )
-        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-            cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16"
-        )
-        unet = UNetSpatioTemporalConditionModel.from_pretrained(
-            cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16"
-        )
         add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

-
-
-        audio2bucket = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

-        # checkpoints
         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
-
-

-        # whisper
         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
         whisper.requires_grad_(False)

-        # extras
         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
         self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
         if cfg.use_interframe:
             self.rife = RIFEModel(device=self.device)
             self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

-
-
-            m.to(weight_dtype)

-
-
-        self.
-        self.
-        self.audio2bucket = audio2bucket
-        self.image_encoder = image_encoder
         self.whisper = whisper

     # --------------------------------------------------------------
@@ -227,9 +184,7 @@ class Sonic:
         _, _, bboxes = self.face_det(img, maxface=True)
         if bboxes:
             x1, y1, ww, hh = bboxes[0]
-
-            crop_bbox = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)
-            return {"face_num": len(bboxes), "crop_bbox": crop_bbox}
         return {"face_num": 0, "crop_bbox": None}

     # --------------------------------------------------------------
@@ -248,19 +203,17 @@
         cfg = self.config
         if seed is not None:
             cfg.seed = seed
-        cfg.num_inference_steps
-        cfg.motion_bucket_scale
         seed_everything(cfg.seed)

-        #
-        # image · audio → tensors
-        # ----------------------------------------------------------
         test_data = image_audio_to_tensor(
             self.face_det,
             self.feature_extractor,
             image_path,
             audio_path,
-            limit=-1,
             image_size=min_resolution,
             area=cfg.area,
         )
@@ -269,14 +222,12 @@

         h, w = test_data["ref_img"].shape[-2:]
         resolution = (
-            f"{(Image.open(image_path).width // 2)*2}x{(Image.open(image_path).height // 2)*2}"
             if keep_resolution
             else f"{w}x{h}"
         )

-        #
-        # generate frames
-        # ----------------------------------------------------------
         video = test(
             self.pipe,
             cfg,
@@ -291,22 +242,20 @@

         # interpolate intermediate frames
         if cfg.use_interframe:
-            out
             for i in tqdm(range(out.shape[2] - 1), ncols=0):
-
-
-
-
-
-
-
-
-        # ----------------------------------------------------------
-        tmp_video = output_path.replace(".mp4", "_noaudio.mp4")
-        save_videos_grid(video, tmp_video, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
         os.system(
-            f"ffmpeg -i '{
             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
         )
-        os.remove(
         return 0
Updated version (lines added by this commit are prefixed with +):

@@ -1,7 +1,5 @@
+import os, math
 import torch
 from PIL import Image
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -26,109 +24,89 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))


 # ------------------------------------------------------------------
+# single image + speech → video-tensor generator
 # ------------------------------------------------------------------
 def test(
+    pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
+    width, height, batch,
 ):
+    # --- match batch dimensions -------------------------------------------
     for k, v in batch.items():
         if isinstance(v, torch.Tensor):
             batch[k] = v.unsqueeze(0).to(pipe.device).float()

+    ref_img = batch["ref_img"]  # (1,C,H,W)
     clip_img = batch["clip_images"]
     face_mask = batch["face_mask"]
     image_embeds = image_encoder(clip_img).image_embeds

+    audio_feature = batch["audio_feature"]  # (1,80,T)
+    audio_len = int(batch["audio_len"])  # Python int
     step = int(config.step)

+    # --- [★ FIX] clamp step (minimum 1) ------------------------------------
     if audio_len < step:
         step = max(1, audio_len)

+    window = 16000  # 1-second segment
+    audio_prompts, last_prompts = [], []
+
+    # --- Whisper encoding per window ---------------------------------------
     for i in range(0, audio_feature.shape[-1], window):
+        chunk = audio_feature[:, :, i : i + window]

+        prompt_layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+        last_hidden = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)

         audio_prompts.append(torch.stack(prompt_layers, dim=2))
+        last_prompts.append(last_hidden)

     if len(audio_prompts) == 0:
+        raise ValueError("[ERROR] No speech recognised in the provided audio.")
+
+    audio_prompts = torch.cat(audio_prompts, dim=1)
+    last_prompts = torch.cat(last_prompts, dim=1)
+
+    # padding rule
+    audio_prompts = torch.cat(
+        [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
+         torch.zeros_like(audio_prompts[:, :6])], dim=1)
+    last_prompts = torch.cat(
+        [torch.zeros_like(last_prompts[:, :24]), last_prompts,
+         torch.zeros_like(last_prompts[:, :26])], dim=1)
+
+    # --- [★ FIX] always ≥ 1 chunk -------------------------------------------
+    total_tokens = audio_prompts.shape[1]
+    num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
+
+    ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []

     for i in tqdm(range(num_chunks)):
         start = i * 2 * step

+        cond_clip = audio_prompts[:, start : start + 10]
+        if cond_clip.shape[2] < 10:  # [★ FIX] padding
+            pad = torch.zeros_like(cond_clip[:, :, : 10 - cond_clip.shape[2]])
+            cond_clip = torch.cat([cond_clip, pad], dim=2)

+        bucket_clip = last_prompts[:, start : start + 50]
+        if bucket_clip.shape[2] < 50:  # [★ FIX] padding
+            pad = torch.zeros_like(bucket_clip[:, :, : 50 - bucket_clip.shape[2]])
+            bucket_clip = torch.cat([bucket_clip, pad], dim=2)

+        motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16

+        ref_list.append(ref_img[0])
+        audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
+        uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
+        motion_buckets.append(motion[0])

+    # ----------------------------------------------------------------------
     video = pipe(
+        ref_img, clip_img, face_mask,
+        audio_list, uncond_list, motion_buckets,
+        height=height, width=width,
+        num_frames=len(audio_list),
         decode_chunk_size=config.decode_chunk_size,
         motion_bucket_scale=config.motion_bucket_scale,
         fps=config.fps,
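(Two of the [★ FIX] changes above are easy to sanity-check in isolation: the chunk count can no longer be zero, and clips shorter than their fixed size are zero-padded. A standalone sketch with toy shapes — the real tensors have more dimensions and are padded along a different axis:

import math
import torch

# chunk count: never zero, even for very short audio
for total_tokens, step in [(1, 1), (10, 2), (37, 4)]:
    print(max(1, math.ceil(total_tokens / (2 * step))))   # 1, 3, 5

# fixed-size clip with zero padding (1-D toy version of the cond_clip logic)
def fixed_clip(x: torch.Tensor, start: int, size: int) -> torch.Tensor:
    clip = x[:, start : start + size]
    if clip.shape[1] < size:
        pad = torch.zeros_like(x[:, : size - clip.shape[1]])
        clip = torch.cat([clip, pad], dim=1)
    return clip

x = torch.arange(12.0).reshape(1, 12)
print(fixed_clip(x, 8, 10))   # shape (1, 10); the last 4 values are 0.0
)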
@@ -143,81 +121,60 @@ def test(
         num_inference_steps=config.num_inference_steps,
         i2i_noise_strength=config.i2i_noise_strength,
     ).frames

     video = (video * 0.5 + 0.5).clamp(0, 1)
     return video.to(pipe.device).unsqueeze(0).cpu()


 # ------------------------------------------------------------------
+# Sonic class
 # ------------------------------------------------------------------
 class Sonic:
     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
     config = OmegaConf.load(config_file)

     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
+        cfg = self.config
+        cfg.use_interframe = enable_interpolate_frame
+        self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
         cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)

         self._load_models(cfg)
         print("Sonic init done")

     # --------------------------------------------------------------
     def _load_models(self, cfg):
+        dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]

+        vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
+        sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
+        image_enc = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+        unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
         add_ip_adapters(unet, [32], [cfg.ip_audio_scale])

+        a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
+        a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)

         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+        a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+        a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))

         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
         whisper.requires_grad_(False)

         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
         self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
         if cfg.use_interframe:
             self.rife = RIFEModel(device=self.device)
             self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))

+        for m in (image_enc, vae, unet):
+            m.to(dtype)

+        self.pipe = SonicPipeline(unet=unet, image_encoder=image_enc, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+        self.image_encoder = image_enc
+        self.audio2token = a2t
+        self.audio2bucket = a2b
         self.whisper = whisper

     # --------------------------------------------------------------
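(One behavioural difference worth noting: the removed loader used dtype_map.get(cfg.weight_dtype, torch.float32) and silently fell back to fp32, while the new direct indexing raises KeyError for anything other than fp16 / fp32 / bf16. A minimal standalone illustration:

import torch

dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

print(dtype_map["bf16"])                      # torch.bfloat16
print(dtype_map.get("int8", torch.float32))   # old-style fallback -> torch.float32
try:
    dtype_map["int8"]                         # new-style lookup
except KeyError as e:
    print("unsupported weight_dtype:", e)
)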
@@ -227,9 +184,7 @@ class Sonic:
         _, _, bboxes = self.face_det(img, maxface=True)
         if bboxes:
             x1, y1, ww, hh = bboxes[0]
+            return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
         return {"face_num": 0, "crop_bbox": None}

     # --------------------------------------------------------------
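(The new crop call converts the detector's (x, y, width, height) box to corner form inline before passing it to process_bbox; a quick check of that arithmetic with made-up numbers:

x1, y1, ww, hh = 40, 30, 100, 120          # toy detector output (x, y, w, h)
print((x1, y1, x1 + ww, y1 + hh))          # (40, 30, 140, 150) -> (x1, y1, x2, y2)
)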
@@ -248,19 +203,17 @@
         cfg = self.config
         if seed is not None:
             cfg.seed = seed
+        cfg.num_inference_steps = inference_steps
+        cfg.motion_bucket_scale = dynamic_scale
         seed_everything(cfg.seed)

+        # image · audio → tensor
         test_data = image_audio_to_tensor(
             self.face_det,
             self.feature_extractor,
             image_path,
             audio_path,
+            limit=-1,
             image_size=min_resolution,
             area=cfg.area,
         )
@@ -269,14 +222,12 @@

         h, w = test_data["ref_img"].shape[-2:]
         resolution = (
+            f"{(Image.open(image_path).width // 2) * 2}x{(Image.open(image_path).height // 2) * 2}"
             if keep_resolution
             else f"{w}x{h}"
         )

+        # generate video frames
         video = test(
             self.pipe,
             cfg,
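(The (value // 2) * 2 rounding in the resolution string keeps both output dimensions even, which H.264 encoders such as the libx264 call below generally require for 4:2:0 output. For example:

for x in (511, 512, 513):
    print(x, (x // 2) * 2)   # 511 -> 510, 512 -> 512, 513 -> 512
)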
@@ -291,22 +242,20 @@

         # interpolate intermediate frames
         if cfg.use_interframe:
+            out = video.to(self.device)
+            frames = []
             for i in tqdm(range(out.shape[2] - 1), ncols=0):
+                mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
+                frames.extend([out[:, :, i], mid])
+            frames.append(out[:, :, -1])
+            video = torch.stack(frames, 2).cpu()
+
+        # save
+        tmp = output_path.replace(".mp4", "_noaudio.mp4")
+        save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
         os.system(
+            f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{output_path}' -y -loglevel error"
         )
+        os.remove(tmp)
         return 0
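(The interpolation branch above inserts one RIFE frame between every adjacent pair, so N decoded frames become 2N − 1 stacked frames, which is why fps is doubled when use_interframe is enabled. A standalone sketch with the RIFE call replaced by a plain average — a stand-in, not the real model:

import torch

out = torch.rand(1, 3, 5, 8, 8)            # toy (B, C, N=5 frames, H, W) video
frames = []
for i in range(out.shape[2] - 1):
    mid = (out[:, :, i] + out[:, :, i + 1]) / 2   # stand-in for self.rife.inference
    frames.extend([out[:, :, i], mid])
frames.append(out[:, :, -1])
video = torch.stack(frames, 2)
print(video.shape)                          # torch.Size([1, 3, 9, 8, 8]) == 2*5 - 1 frames
)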