Update app.py
app.py CHANGED
@@ -68,11 +68,14 @@ net, feature_utils, seq_cfg = get_model()
 @spaces.GPU(duration=60)
 @torch.inference_mode()
 def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
-                   seed: int = -1, num_steps: int =
-                   cfg_strength: float = 4.5, target_duration: float =
+                   seed: int = -1, num_steps: int = 20,
+                   cfg_strength: float = 4.5, target_duration: float = 6.0):
     try:
         logger.info("Starting audio generation process")
 
+        # Optimize GPU memory
+        torch.cuda.empty_cache()
+
         rng = torch.Generator(device=device)
         if seed >= 0:
             rng.manual_seed(seed)
@@ -81,9 +84,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
 
         fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
 
-        #
-
-        video_info = load_video(video_path, **kwargs)
+        # Fix the load_video function call
+        video_info = load_video(video_path, duration=target_duration)  # changed static_duration to duration
 
         if video_info is None:
             logger.error("Failed to load video")
@@ -97,14 +99,13 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
             logger.error("Failed to extract frames from video")
             return video_path
 
-
-
+        # Adjust batch size for memory efficiency
+        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
+        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
 
-        # Update sequence lengths
         seq_cfg.duration = actual_duration
         net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
 
-        # Generate audio
         logger.info("Generating audio...")
         audios = generate(clip_frames,
                           sync_frames,
@@ -122,14 +123,16 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
 
         audio = audios.float().cpu()[0]
 
-        # Create the output video
         output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
         logger.info(f"Creating final video with audio at {output_path}")
 
-        make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
+        success = make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
+
+        # Free GPU memory
+        torch.cuda.empty_cache()
 
-        if not
-            logger.error("Failed to create
+        if not success:
+            logger.error("Failed to create video with audio")
             return video_path
 
         logger.info(f'Successfully saved video with audio to {output_path}')
@@ -137,7 +140,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
 
     except Exception as e:
         logger.error(f"Error in video_to_audio: {str(e)}")
-
+        torch.cuda.empty_cache()
+        return video_path
 
 def upload_to_catbox(file_path):
     """Upload a file using the catbox.moe API"""
@@ -357,14 +361,13 @@ def generate_video(image, prompt):
             prompt=prompt,
             negative_prompt="music",
             seed=-1,
-            num_steps=
+            num_steps=20,
             cfg_strength=4.5,
-            target_duration=
+            target_duration=6.0
         )
 
         if final_path_with_audio != final_path:
             logger.info("Audio generation successful")
-            # Clean up temporary files
             try:
                 if output_path != final_path:
                     os.remove(output_path)
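
For context, here is a minimal, hypothetical smoke test of the updated video_to_audio() entry point with the new defaults from this commit (num_steps=20, target_duration=6.0). It assumes app.py is importable as a module with its model globals already loaded, and the input clip path is a placeholder; in the Space itself the Gradio handler generate_video() calls the function directly.

    # Hypothetical, standalone check of the updated defaults; not part of the commit.
    from app import video_to_audio  # assumes the Space's app.py is on the import path

    result = video_to_audio(
        video_path="sample_input.mp4",   # placeholder clip, not shipped with the repo
        prompt="gentle rain on a window",
        negative_prompt="music",
        seed=-1,                         # negative seed leaves the torch.Generator unseeded
        num_steps=20,                    # new default introduced in this commit
        cfg_strength=4.5,
        target_duration=6.0,             # new default introduced in this commit
    )

    # On failure the function logs the error, clears the CUDA cache, and returns the
    # original input path, so callers can compare paths to see whether audio was added.
    print("output video:", result)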