import torch
import torch.nn.functional as F
from diffusers import AutoencoderKLWan, WanVideoTextToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
import random
import logging
import os
import gc
from typing import List, Optional, Union

# MMAudio imports
try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio

# Set environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
os.environ['HF_HUB_CACHE'] = '/tmp/hub'

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils


# NAG-enhanced pipeline
class NAGWanPipeline(WanVideoTextToVideoPipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.nag_scale = 0.0
        self.nag_tau = 3.5
        self.nag_alpha = 0.5

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        nag_negative_prompt: Optional[Union[str, List[str]]] = None,
        nag_scale: float = 0.0,
        nag_tau: float = 3.5,
        nag_alpha: float = 0.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: int = 16,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback=None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[dict] = None,
        clip_skip: Optional[int] = None,
    ):
        # Use the NAG negative prompt if provided
        if nag_negative_prompt is not None:
            negative_prompt = nag_negative_prompt

        # Store NAG parameters
        self.nag_scale = nag_scale
        self.nag_tau = nag_tau
        self.nag_alpha = nag_alpha

        # Override the transformer's forward method to apply NAG
        original_forward = None
        if hasattr(self, 'transformer') and nag_scale > 0:
            original_forward = self.transformer.forward

            def nag_forward(hidden_states, *args, **kwargs):
                # Standard forward pass
                output = original_forward(hidden_states, *args, **kwargs)

                # Apply NAG guidance
                if nag_scale > 0 and not self.transformer.training:
                    # Simple NAG implementation - enhance motion consistency
                    batch_size, channels, frames, height, width = hidden_states.shape

                    # Compute temporal attention-like guidance
                    hidden_flat = hidden_states.view(batch_size, channels, -1)
                    attention = F.softmax(hidden_flat * nag_tau, dim=-1)

                    # Apply normalized guidance
                    guidance = attention.mean(dim=2, keepdim=True) * nag_alpha
                    guidance = guidance.unsqueeze(-1).unsqueeze(-1)

                    # Scale and add guidance
                    if hasattr(output, 'sample'):
                        output.sample = output.sample + nag_scale * guidance * hidden_states
                    else:
                        output = output + nag_scale * guidance * hidden_states

                return output

            # Temporarily replace the forward method
            self.transformer.forward = nag_forward

        # Call the parent pipeline
        result = super().__call__(
            prompt=prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
            clip_skip=clip_skip,
        )

        # Restore the original forward method only if it was replaced
        if original_forward is not None:
            self.transformer.forward = original_forward

        return result


# Clean up temp files
def cleanup_temp_files():
    temp_dir = tempfile.gettempdir()
    for filename in os.listdir(temp_dir):
        filepath = os.path.join(temp_dir, filename)
        try:
            if filename.endswith(('.mp4', '.flac', '.wav')):
                os.remove(filepath)
        except OSError:
            pass


# Video generation model setup
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

# Load the model components
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = NAGWanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# Load LoRA weights for faster generation
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
pipe.fuse_lora()

# Audio generation model setup
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()
device = 'cuda'
dtype = torch.bfloat16

# Global variables for the audio model
audio_model = None
audio_net = None
audio_feature_utils = None
audio_seq_cfg = None


def load_audio_model():
    global audio_model, audio_net, audio_feature_utils, audio_seq_cfg

    if audio_net is None:
        audio_model = all_model_cfg['small_16k']
        audio_model.download_if_needed()
        setup_eval_logging()

        seq_cfg = audio_model.seq_cfg
        net = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
        net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
        log.info(f'Loaded weights from {audio_model.model_path}')

        feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
                                      synchformer_ckpt=audio_model.synchformer_ckpt,
                                      enable_conditions=True,
                                      mode=audio_model.mode,
                                      bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
                                      need_vae_encoder=False)
        feature_utils = feature_utils.to(device, dtype).eval()

        audio_net = net
        audio_feature_utils = feature_utils
        audio_seq_cfg = seq_cfg

    return audio_net, audio_feature_utils, audio_seq_cfg


# Constants
MOD_VALUE = 32
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 129

DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
default_prompt = "A ginger cat passionately plays electric guitar with intensity and emotion on a stage"
default_audio_prompt = ""
default_audio_negative_prompt = "music"
# CSS
custom_css = """
/* Overall background gradient */
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #f093fb 50%, #f5576c 75%, #fa709a 100%) !important;
    background-size: 400% 400% !important;
    animation: gradientShift 15s ease infinite !important;
}

@keyframes gradientShift {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}

/* Main container style */
.main-container {
    backdrop-filter: blur(10px);
    background: rgba(255, 255, 255, 0.1) !important;
    border-radius: 20px !important;
    padding: 30px !important;
    box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37) !important;
    border: 1px solid rgba(255, 255, 255, 0.18) !important;
}

/* Header style */
h1 {
    background: linear-gradient(45deg, #ffffff, #f0f0f0) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent !important;
    background-clip: text !important;
    font-weight: 800 !important;
    font-size: 2.5rem !important;
    text-align: center !important;
    margin-bottom: 2rem !important;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1) !important;
}

/* Component container style */
.input-container, .output-container {
    background: rgba(255, 255, 255, 0.08) !important;
    border-radius: 15px !important;
    padding: 20px !important;
    margin: 10px 0 !important;
    backdrop-filter: blur(5px) !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
}

/* Input field style */
input, textarea, .gr-box {
    background: rgba(255, 255, 255, 0.9) !important;
    border: 1px solid rgba(255, 255, 255, 0.3) !important;
    border-radius: 10px !important;
    color: #333 !important;
    transition: all 0.3s ease !important;
}

input:focus, textarea:focus {
    background: rgba(255, 255, 255, 1) !important;
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}

/* Button style */
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    padding: 12px 30px !important;
    border-radius: 50px !important;
    border: none !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}

.generate-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}

/* Slider style */
input[type="range"] {
    background: transparent !important;
}

input[type="range"]::-webkit-slider-track {
    background: rgba(255, 255, 255, 0.3) !important;
    border-radius: 5px !important;
    height: 6px !important;
}

input[type="range"]::-webkit-slider-thumb {
    background: linear-gradient(135deg, #667eea, #764ba2) !important;
    border: 2px solid white !important;
    border-radius: 50% !important;
    cursor: pointer !important;
    width: 18px !important;
    height: 18px !important;
    -webkit-appearance: none !important;
}

/* Accordion style */
.gr-accordion {
    background: rgba(255, 255, 255, 0.05) !important;
    border-radius: 10px !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    margin: 15px 0 !important;
}

/* Label style */
label {
    color: #ffffff !important;
    font-weight: 500 !important;
    font-size: 0.95rem !important;
    margin-bottom: 5px !important;
}

/* Video output area */
video {
    border-radius: 15px !important;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3) !important;
}

/* Examples section style */
.gr-examples {
    background: rgba(255, 255, 255, 0.05) !important;
    border-radius: 15px !important;
    padding: 20px !important;
    margin-top: 20px !important;
}

/* Checkbox style */
input[type="checkbox"] {
    accent-color: #667eea !important;
}

/* Radio button style */
input[type="radio"] {
    accent-color: #667eea !important;
}

/* Info box */
.info-box {
    background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    border-left: 4px solid #667eea;
}

/* Responsive adjustments */
@media (max-width: 768px) {
    h1 { font-size: 2rem !important; }
    .main-container { padding: 20px !important; }
}
"""


def clear_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()


def get_duration(prompt, nag_negative_prompt, nag_scale, height, width, duration_seconds,
                 steps, seed, randomize_seed, audio_mode, audio_prompt, audio_negative_prompt,
                 audio_seed, audio_steps, audio_cfg_strength, progress):
    duration = int(duration_seconds) * int(steps) * 2.25 + 5
    if audio_mode == "Enable Audio":
        duration += 60
    return duration


@torch.inference_mode()
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
                       audio_seed, audio_steps, audio_cfg_strength):
    net, feature_utils, seq_cfg = load_audio_model()

    rng = torch.Generator(device=device)
    if audio_seed >= 0:
        rng.manual_seed(audio_seed)
    else:
        rng.seed()

    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)

    video_info = load_video(video_path, duration_sec)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    duration = video_info.duration_sec
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames, sync_frames, [audio_prompt],
                      negative_text=[audio_negative_prompt],
                      feature_utils=feature_utils, net=net, fm=fm, rng=rng,
                      cfg_strength=audio_cfg_strength)
    audio = audios.float().cpu()[0]

    video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)

    return video_with_audio_path


@spaces.GPU(duration=get_duration)
def generate_video(prompt, nag_negative_prompt, nag_scale, height, width, duration_seconds,
                   steps, seed, randomize_seed, audio_mode, audio_prompt, audio_negative_prompt,
                   audio_seed, audio_steps, audio_cfg_strength,
                   progress=gr.Progress(track_tqdm=True)):
    if not prompt.strip():
        raise gr.Error("Please enter a text prompt to generate video.")

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1),
                         MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    # Generate video using NAG
    with torch.inference_mode():
        output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            height=target_h,
            width=target_w,
            num_frames=num_frames,
            guidance_scale=0.,  # NAG replaces traditional guidance
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed)
        ).frames[0]

    # Save video without audio
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    # Generate audio if enabled
    video_with_audio_path = None
    if audio_mode == "Enable Audio":
        progress(0.5, desc="Generating audio...")
        video_with_audio_path = add_audio_to_video(
            video_path, duration_seconds,
            audio_prompt, audio_negative_prompt,
            audio_seed, audio_steps, audio_cfg_strength
        )

    clear_cache()
    cleanup_temp_files()

    return video_path, video_with_audio_path, current_seed


def update_audio_visibility(audio_mode):
    return gr.update(visible=(audio_mode == "Enable Audio"))
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes=["main-container"]):
        gr.Markdown("# ✨ Fast NAG T2V (14B) with Audio Generation")
        gr.Markdown("### 🚀 Normalized Attention Guidance + CausVid LoRA + MMAudio")
        gr.HTML("""
            <div class="info-box">
                🎯 <strong>NAG (Normalized Attention Guidance)</strong>: Enhanced motion consistency and quality<br>
                ⚡ <strong>Speed</strong>: Generate videos in just 4-8 steps with CausVid LoRA<br>
                🎵 <strong>Audio</strong>: Optional synchronized audio generation with MMAudio<br>
                💡 <strong>Tip</strong>: Try different NAG scales for varied artistic effects!