Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ import random
|
|
| 12 |
import logging
|
| 13 |
import torchaudio
|
| 14 |
import os
|
|
|
|
| 15 |
|
| 16 |
# MMAudio imports
|
| 17 |
try:
|
|
@@ -20,6 +21,10 @@ except ImportError:
|
|
| 20 |
os.system("pip install -e .")
|
| 21 |
import mmaudio
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
| 24 |
setup_eval_logging)
|
| 25 |
from mmaudio.model.flow_matching import FlowMatching
|
|
@@ -27,6 +32,18 @@ from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
|
| 27 |
from mmaudio.model.sequence_config import SequenceConfig
|
| 28 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# Video generation model setup
|
| 31 |
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
| 32 |
LORA_REPO_ID = "Kijai/WanVideo_comfy"
|
|
@@ -53,26 +70,39 @@ log = logging.getLogger()
|
|
| 53 |
device = 'cuda'
|
| 54 |
dtype = torch.bfloat16
|
| 55 |
|
| 56 |
-
|
| 57 |
-
audio_model
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
seq_cfg = audio_model.seq_cfg
|
| 62 |
-
net: MMAudio = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
|
| 63 |
-
net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
|
| 64 |
-
log.info(f'Loaded weights from {audio_model.model_path}')
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
# Constants
|
| 78 |
MOD_VALUE = 32
|
|
@@ -292,6 +322,13 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_
|
|
| 292 |
gr.Warning("Error attempting to calculate new dimensions")
|
| 293 |
return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
def get_duration(input_image, prompt, height, width,
|
| 296 |
negative_prompt, duration_seconds,
|
| 297 |
guidance_scale, steps,
|
|
@@ -315,6 +352,9 @@ def get_duration(input_image, prompt, height, width,
|
|
| 315 |
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
|
| 316 |
audio_seed, audio_steps, audio_cfg_strength):
|
| 317 |
"""Add audio to video using MMAudio"""
|
|
|
|
|
|
|
|
|
|
| 318 |
rng = torch.Generator(device=device)
|
| 319 |
if audio_seed >= 0:
|
| 320 |
rng.manual_seed(audio_seed)
|
|
@@ -327,14 +367,14 @@ def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_pr
|
|
| 327 |
clip_frames = video_info.clip_frames.unsqueeze(0)
|
| 328 |
sync_frames = video_info.sync_frames.unsqueeze(0)
|
| 329 |
duration = video_info.duration_sec
|
| 330 |
-
|
| 331 |
-
|
| 332 |
|
| 333 |
audios = generate(clip_frames,
|
| 334 |
sync_frames, [audio_prompt],
|
| 335 |
negative_text=[audio_negative_prompt],
|
| 336 |
-
feature_utils=
|
| 337 |
-
net=
|
| 338 |
fm=fm,
|
| 339 |
rng=rng,
|
| 340 |
cfg_strength=audio_cfg_strength)
|
|
@@ -342,7 +382,7 @@ def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_pr
|
|
| 342 |
|
| 343 |
# Save video with audio
|
| 344 |
video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
| 345 |
-
make_video(video_info, video_with_audio_path, audio, sampling_rate=
|
| 346 |
|
| 347 |
return video_with_audio_path
|
| 348 |
|
|
@@ -391,6 +431,10 @@ def generate_video(input_image, prompt, height, width,
|
|
| 391 |
audio_seed, audio_steps, audio_cfg_strength
|
| 392 |
)
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
return video_path, video_with_audio_path, current_seed
|
| 395 |
|
| 396 |
def update_audio_visibility(audio_mode):
|
|
|
|
| 12 |
import logging
|
| 13 |
import torchaudio
|
| 14 |
import os
|
| 15 |
+
import gc
|
| 16 |
|
| 17 |
# MMAudio imports
|
| 18 |
try:
|
|
|
|
| 21 |
os.system("pip install -e .")
|
| 22 |
import mmaudio
|
| 23 |
|
| 24 |
+
# Set environment variables for better memory management
|
| 25 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
|
| 26 |
+
os.environ['HF_HUB_CACHE'] = '/tmp/hub' # Use temp directory to avoid filling persistent storage
|
| 27 |
+
|
| 28 |
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
| 29 |
setup_eval_logging)
|
| 30 |
from mmaudio.model.flow_matching import FlowMatching
|
|
|
|
| 32 |
from mmaudio.model.sequence_config import SequenceConfig
|
| 33 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
| 34 |
|
| 35 |
+
# Clean up temp files periodically
|
| 36 |
+
def cleanup_temp_files(min_age_seconds=3600):
    """Delete stale generated media files from the system temp directory.

    Only entries directly inside ``tempfile.gettempdir()`` whose names end in
    ``.mp4``, ``.flac`` or ``.wav`` are considered.

    Args:
        min_age_seconds: Files modified more recently than this many seconds
            ago are left alone. This protects outputs that were just written
            and are still about to be returned to (or served for) a caller.
            Pass ``0`` to remove every matching file regardless of age.

    Returns:
        None. Cleanup is best-effort: files that cannot be inspected or
        removed are skipped and the failure is logged at debug level.
    """
    # Local import: this block cannot assume `time` is imported at file level.
    import time

    media_suffixes = ('.mp4', '.flac', '.wav')
    temp_dir = tempfile.gettempdir()
    now = time.time()

    try:
        # Materialize the listing first so removals do not race the iterator.
        with os.scandir(temp_dir) as scan:
            entries = list(scan)
    except OSError as exc:
        logging.getLogger().debug('Could not list temp dir %s: %s', temp_dir, exc)
        return

    for entry in entries:
        if not entry.name.endswith(media_suffixes):
            continue
        try:
            age = now - entry.stat(follow_symlinks=False).st_mtime
            # Skip files that may still be in use by an in-flight request.
            if min_age_seconds > 0 and age < min_age_seconds:
                continue
            os.remove(entry.path)
        except OSError as exc:
            # Covers vanished files, directories with a media-like name and
            # permission errors; none of these should abort the cleanup.
            logging.getLogger().debug('Could not remove %s: %s', entry.path, exc)
|
| 46 |
+
|
| 47 |
# Video generation model setup
|
| 48 |
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
| 49 |
LORA_REPO_ID = "Kijai/WanVideo_comfy"
|
|
|
|
| 70 |
device = 'cuda'
|
| 71 |
dtype = torch.bfloat16
|
| 72 |
|
| 73 |
+
# Global variables for audio model (loaded on demand)
|
| 74 |
+
audio_model = None
|
| 75 |
+
audio_net = None
|
| 76 |
+
audio_feature_utils = None
|
| 77 |
+
audio_seq_cfg = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
def load_audio_model():
    """Return the audio network, feature utilities and sequence config, building them on first use.

    The first call selects the ``'small_16k'`` entry of ``all_model_cfg``,
    calls its ``download_if_needed()``, constructs the network and the
    ``FeaturesUtils`` helper on ``device``/``dtype`` in eval mode, and stores
    the results in module-level globals. Later calls return the stored
    objects without rebuilding anything.

    Returns:
        A ``(audio_net, audio_feature_utils, audio_seq_cfg)`` tuple.
    """
    global audio_model, audio_net, audio_feature_utils, audio_seq_cfg

    # Fast path: everything was already built by an earlier call.
    if audio_net is not None:
        return audio_net, audio_feature_utils, audio_seq_cfg

    audio_model = all_model_cfg['small_16k']  # Use smaller model
    audio_model.download_if_needed()
    setup_eval_logging()

    sequence_config = audio_model.seq_cfg

    network = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
    weights = torch.load(audio_model.model_path, map_location=device, weights_only=True)
    network.load_weights(weights)
    log.info(f'Loaded weights from {audio_model.model_path}')

    features = FeaturesUtils(
        tod_vae_ckpt=audio_model.vae_path,
        synchformer_ckpt=audio_model.synchformer_ckpt,
        enable_conditions=True,
        mode=audio_model.mode,
        bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, dtype).eval()

    # Publish the globals only once every component has loaded, so a failure
    # above leaves `audio_net` as None and the next call retries from scratch.
    audio_net, audio_feature_utils, audio_seq_cfg = network, features, sequence_config

    return audio_net, audio_feature_utils, audio_seq_cfg
|
| 106 |
|
| 107 |
# Constants
|
| 108 |
MOD_VALUE = 32
|
|
|
|
| 322 |
gr.Warning("Error attempting to calculate new dimensions")
|
| 323 |
return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
|
| 324 |
|
| 325 |
+
def clear_cache():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give its unoccupied blocks back; doing it the other way round leaves
    that memory sitting in the allocator cache. The CUDA steps are skipped
    when no CUDA device is available.

    Returns:
        None.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued GPU work before emptying the allocator cache.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
|
| 331 |
+
|
| 332 |
def get_duration(input_image, prompt, height, width,
|
| 333 |
negative_prompt, duration_seconds,
|
| 334 |
guidance_scale, steps,
|
|
|
|
| 352 |
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
|
| 353 |
audio_seed, audio_steps, audio_cfg_strength):
|
| 354 |
"""Add audio to video using MMAudio"""
|
| 355 |
+
# Load audio model on demand
|
| 356 |
+
net, feature_utils, seq_cfg = load_audio_model()
|
| 357 |
+
|
| 358 |
rng = torch.Generator(device=device)
|
| 359 |
if audio_seed >= 0:
|
| 360 |
rng.manual_seed(audio_seed)
|
|
|
|
| 367 |
clip_frames = video_info.clip_frames.unsqueeze(0)
|
| 368 |
sync_frames = video_info.sync_frames.unsqueeze(0)
|
| 369 |
duration = video_info.duration_sec
|
| 370 |
+
seq_cfg.duration = duration
|
| 371 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
| 372 |
|
| 373 |
audios = generate(clip_frames,
|
| 374 |
sync_frames, [audio_prompt],
|
| 375 |
negative_text=[audio_negative_prompt],
|
| 376 |
+
feature_utils=feature_utils,
|
| 377 |
+
net=net,
|
| 378 |
fm=fm,
|
| 379 |
rng=rng,
|
| 380 |
cfg_strength=audio_cfg_strength)
|
|
|
|
| 382 |
|
| 383 |
# Save video with audio
|
| 384 |
video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
| 385 |
+
make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
| 386 |
|
| 387 |
return video_with_audio_path
|
| 388 |
|
|
|
|
| 431 |
audio_seed, audio_steps, audio_cfg_strength
|
| 432 |
)
|
| 433 |
|
| 434 |
+
# Clear cache to free memory
|
| 435 |
+
clear_cache()
|
| 436 |
+
cleanup_temp_files() # Clean up temp files
|
| 437 |
+
|
| 438 |
return video_path, video_with_audio_path, current_seed
|
| 439 |
|
| 440 |
def update_audio_visibility(audio_mode):
|