aiqcamp commited on
Commit
9b8d878
·
verified ·
1 Parent(s): b581974

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -17
app.py CHANGED
@@ -68,11 +68,14 @@ net, feature_utils, seq_cfg = get_model()
68
  @spaces.GPU(duration=60)
69
  @torch.inference_mode()
70
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
71
- seed: int = -1, num_steps: int = 25,
72
- cfg_strength: float = 4.5, target_duration: float = 8.0):
73
  try:
74
  logger.info("Starting audio generation process")
75
 
 
 
 
76
  rng = torch.Generator(device=device)
77
  if seed >= 0:
78
  rng.manual_seed(seed)
@@ -81,9 +84,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
81
 
82
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
83
 
84
- # video_info = load_video(video_path, duration) 대신:
85
- kwargs = {'static_duration': target_duration}
86
- video_info = load_video(video_path, **kwargs)
87
 
88
  if video_info is None:
89
  logger.error("Failed to load video")
@@ -97,14 +99,13 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
97
  logger.error("Failed to extract frames from video")
98
  return video_path
99
 
100
- clip_frames = clip_frames.unsqueeze(0).to(device)
101
- sync_frames = sync_frames.unsqueeze(0).to(device)
 
102
 
103
- # 시퀀스 길이 업데이트
104
  seq_cfg.duration = actual_duration
105
  net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
106
 
107
- # 오디오 생성
108
  logger.info("Generating audio...")
109
  audios = generate(clip_frames,
110
  sync_frames,
@@ -122,14 +123,16 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
122
 
123
  audio = audios.float().cpu()[0]
124
 
125
- # 결과 비디오 생성
126
  output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
127
  logger.info(f"Creating final video with audio at {output_path}")
128
 
129
- make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
 
 
 
130
 
131
- if not os.path.exists(output_path):
132
- logger.error("Failed to create output video")
133
  return video_path
134
 
135
  logger.info(f'Successfully saved video with audio to {output_path}')
@@ -137,7 +140,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
137
 
138
  except Exception as e:
139
  logger.error(f"Error in video_to_audio: {str(e)}")
140
- return video_path # 오류 발생 시 원본 비디오 반환
 
141
 
142
  def upload_to_catbox(file_path):
143
  """catbox.moe API를 사용하여 파일 업로드"""
@@ -357,14 +361,13 @@ def generate_video(image, prompt):
357
  prompt=prompt,
358
  negative_prompt="music",
359
  seed=-1,
360
- num_steps=25,
361
  cfg_strength=4.5,
362
- target_duration=8.0 # duration을 target_duration으로 변경
363
  )
364
 
365
  if final_path_with_audio != final_path:
366
  logger.info("Audio generation successful")
367
- # 임시 파일 정리
368
  try:
369
  if output_path != final_path:
370
  os.remove(output_path)
 
68
  @spaces.GPU(duration=60)
69
  @torch.inference_mode()
70
  def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
71
+ seed: int = -1, num_steps: int = 20,
72
+ cfg_strength: float = 4.5, target_duration: float = 6.0):
73
  try:
74
  logger.info("Starting audio generation process")
75
 
76
+ # GPU 메모리 최적화
77
+ torch.cuda.empty_cache()
78
+
79
  rng = torch.Generator(device=device)
80
  if seed >= 0:
81
  rng.manual_seed(seed)
 
84
 
85
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
86
 
87
+ # load_video 함수 호출 수정
88
+ video_info = load_video(video_path, duration=target_duration) # static_duration duration으로 변경
 
89
 
90
  if video_info is None:
91
  logger.error("Failed to load video")
 
99
  logger.error("Failed to extract frames from video")
100
  return video_path
101
 
102
+ # 메모리 효율을 위해 배치 크기 조정
103
+ clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
104
+ sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
105
 
 
106
  seq_cfg.duration = actual_duration
107
  net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
108
 
 
109
  logger.info("Generating audio...")
110
  audios = generate(clip_frames,
111
  sync_frames,
 
123
 
124
  audio = audios.float().cpu()[0]
125
 
 
126
  output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
127
  logger.info(f"Creating final video with audio at {output_path}")
128
 
129
+ success = make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
130
+
131
+ # GPU 메모리 정리
132
+ torch.cuda.empty_cache()
133
 
134
+ if not success:
135
+ logger.error("Failed to create video with audio")
136
  return video_path
137
 
138
  logger.info(f'Successfully saved video with audio to {output_path}')
 
140
 
141
  except Exception as e:
142
  logger.error(f"Error in video_to_audio: {str(e)}")
143
+ torch.cuda.empty_cache()
144
+ return video_path
145
 
146
  def upload_to_catbox(file_path):
147
  """catbox.moe API를 사용하여 파일 업로드"""
 
361
  prompt=prompt,
362
  negative_prompt="music",
363
  seed=-1,
364
+ num_steps=20,
365
  cfg_strength=4.5,
366
+ target_duration=6.0
367
  )
368
 
369
  if final_path_with_audio != final_path:
370
  logger.info("Audio generation successful")
 
371
  try:
372
  if output_path != final_path:
373
  os.remove(output_path)