# Create src directory structure
import os
import sys

# Add current directory to Python path
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:  # __file__ is undefined in some embedded/notebook contexts
    current_dir = os.getcwd()
sys.path.insert(0, current_dir)

os.makedirs("src", exist_ok=True)

# Install required packages
os.system("pip install safetensors")

# Create __init__.py
with open("src/__init__.py", "w") as f:
    f.write("")

print("Creating NAG transformer module...")

# Create transformer_wan_nag.py
with open("src/transformer_wan_nag.py", "w") as f:
    f.write('''
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any


class NagWanTransformer3DModel(nn.Module):
    """NAG-enhanced transformer for video generation.

    Demo stand-in for the full Wan transformer: a small 3D U-Net-style
    encoder/decoder, so the app runs without downloading 28GB of weights.
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        hidden_size: int = 768,
        num_layers: int = 4,
        num_heads: int = 8,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.training = False

        # Dummy config for compatibility with diffusers pipelines
        self.config = type('Config', (), {
            'in_channels': in_channels,
            'out_channels': out_channels,
            'hidden_size': hidden_size
        })()

        # Simple noise-to-noise 3D U-Net instead of the full 28GB model
        self.conv_in = nn.Conv3d(in_channels, 320, kernel_size=3, padding=1)
        self.time_embed = nn.Sequential(
            nn.Linear(320, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )
        self.down_blocks = nn.ModuleList([
            nn.Conv3d(320, 320, kernel_size=3, stride=2, padding=1),
            nn.Conv3d(320, 640, kernel_size=3, stride=2, padding=1),
            nn.Conv3d(640, 1280, kernel_size=3, stride=2, padding=1),
        ])
        self.mid_block = nn.Conv3d(1280, 1280, kernel_size=3, padding=1)
        self.up_blocks = nn.ModuleList([
            nn.ConvTranspose3d(1280, 640, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose3d(640, 320, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose3d(320, 320, kernel_size=3, stride=2, padding=1, output_padding=1),
        ])
        self.conv_out = nn.Conv3d(320, out_channels, kernel_size=3, padding=1)

    @classmethod
    def from_single_file(cls, model_path, **kwargs):
        """Load model from a single file (demo: builds the simplified model instead)."""
        print(f"Note: Loading simplified NAG model instead of {model_path}")
        print("This is a demo version that doesn't require 28GB of weights")
        model = cls(
            in_channels=4,
            out_channels=4,
            hidden_size=768,
            num_layers=4,
            num_heads=8
        )
        return model.to(kwargs.get('torch_dtype', torch.float32))

    @staticmethod
    def attn_processors():
        # Compatibility shim: diffusers expects this attribute on transformers
        return {}

    @staticmethod
    def set_attn_processor(processor):
        # Compatibility shim: accept and ignore custom attention processors
        pass

    def time_proj(self, timesteps, dim=320):
        """Standard sinusoidal timestep embedding."""
        half_dim = dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(-emb * torch.arange(half_dim, device=timesteps.device))
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        # Timestep embeddings (normalize 0-d scheduler timesteps to a batch)
        t_emb = None
        if timestep is not None:
            if not torch.is_tensor(timestep):
                timestep = torch.tensor([timestep], device=hidden_states.device)
            if timestep.ndim == 0:
                timestep = timestep[None]
            timestep = timestep.expand(hidden_states.shape[0]).to(hidden_states.dtype)
            t_emb = self.time_proj(timestep)
            t_emb = self.time_embed(t_emb.to(hidden_states.dtype))

        # Initial conv
        h = self.conv_in(hidden_states)

        # Down blocks (store activations before each downsample for skips)
        down_block_res_samples = []
        for down_block in self.down_blocks:
            down_block_res_samples.append(h)
            h = down_block(h)

        # Mid block, with the time embedding injected at the bottleneck
        if t_emb is not None:
            h = h + t_emb[:, :, None, None, None]
        h = self.mid_block(h)

        # Up blocks with skip connections (crop to handle odd-sized inputs)
        for i, up_block in enumerate(self.up_blocks):
            h = up_block(h)
            if i < len(down_block_res_samples):
                skip = down_block_res_samples[-(i + 1)]
                h = h[..., :skip.shape[-3], :skip.shape[-2], :skip.shape[-1]] + skip

        # Final conv
        h = self.conv_out(h)
        return h
''')
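
# A minimal, optional sanity check for the module written above. This is a hedged
# sketch, not part of the app flow: it is never called, and the tensor sizes are
# illustrative assumptions (dims divisible by 8 so the three stride-2 stages round-trip).
def _smoke_test_transformer():
    import torch
    from src.transformer_wan_nag import NagWanTransformer3DModel

    model = NagWanTransformer3DModel()
    x = torch.randn(1, 4, 16, 64, 64)   # (batch, channels, frames, height, width)
    t = torch.tensor([10.0])            # a single diffusion timestep
    out = model(x, timestep=t)
    assert out.shape == x.shape, f"unexpected output shape: {out.shape}"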

print("Creating NAG pipeline module...")

# Create pipeline_wan_nag.py
with open("src/pipeline_wan_nag.py", "w") as f:
    f.write('''
import torch
import torch.nn.functional as F
from typing import List, Optional, Union, Tuple, Callable, Dict, Any
import numpy as np

from diffusers import DiffusionPipeline
from diffusers.utils import logging
from transformers import CLIPTextModel, CLIPTokenizer

logger = logging.get_logger(__name__)


class NAGWanPipeline(DiffusionPipeline):
    """NAG-enhanced pipeline for video generation."""

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        transformer,
        scheduler,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Set VAE scale factor from the VAE config when available
        if hasattr(self.vae, 'config') and hasattr(self.vae.config, 'block_out_channels'):
            self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        else:
            self.vae_scale_factor = 8  # Default value for most VAEs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load pipeline components, reusing any externally supplied VAE and transformer."""
        vae = kwargs.pop("vae", None)
        transformer = kwargs.pop("transformer", None)
        torch_dtype = kwargs.pop("torch_dtype", torch.float32)

        # Load text encoder and tokenizer
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=torch_dtype
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer"
        )

        # Load scheduler
        from diffusers import UniPCMultistepScheduler
        scheduler = UniPCMultistepScheduler.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="scheduler"
        )

        return cls(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

    def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt=None):
        """Encode a text prompt (and, for CFG, an unconditional prompt) to embeddings."""
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size if negative_prompt is None else negative_prompt
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        nag_negative_prompt: Optional[Union[str, List[str]]] = None,
        nag_scale: float = 0.0,
        nag_tau: float = 3.5,
        nag_alpha: float = 0.5,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_frames: int = 16,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        # Use NAG negative prompt if provided
        if nag_negative_prompt is not None:
            negative_prompt = nag_negative_prompt

        # Setup
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode prompt
        text_embeddings = self._encode_prompt(
            prompt, device, do_classifier_free_guidance, negative_prompt
        )

        # Prepare latents
        if hasattr(self.vae.config, 'latent_channels'):
            num_channels_latents = self.vae.config.latent_channels
        else:
            num_channels_latents = 4  # Default for most VAEs
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if latents is None:
            latents = torch.randn(
                shape,
                generator=generator,
                device=device,
                dtype=text_embeddings.dtype,
            )
        latents = latents * self.scheduler.init_noise_sigma

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Denoising loop with NAG
        for i, t in enumerate(timesteps):
            # Expand latents for classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise residual
            noise_pred = self.transformer(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
            )

            # Apply NAG: rescale the prediction by a normalized-attention weight
            if nag_scale > 0:
                b, c, f, h, w = noise_pred.shape
                noise_flat = noise_pred.view(b, c, -1)

                # Normalize and compute an attention-style weighting
                noise_norm = F.normalize(noise_flat, dim=-1)
                attention = F.softmax(noise_norm * nag_tau, dim=-1)

                # Per-channel guidance strength
                guidance = attention.mean(dim=-1, keepdim=True) * nag_alpha
                guidance = guidance.unsqueeze(-1).unsqueeze(-1)

                noise_pred = noise_pred + nag_scale * guidance * noise_pred

            # Classifier-free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample (UniPC's step() takes no eta/generator args)
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Callback
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # Decode latents (cast to the VAE's dtype; the demo VAE stays in float32)
        if hasattr(self.vae.config, 'scaling_factor'):
            latents = 1 / self.vae.config.scaling_factor * latents
        else:
            latents = 1 / 0.18215 * latents  # Default SD scaling factor
        video = self.vae.decode(latents.to(self.vae.dtype)).sample
        video = (video / 2 + 0.5).clamp(0, 1)

        # Convert to (batch, frames, height, width, channels) uint8
        video = video.cpu().float().numpy()
        video = (video * 255).round().astype("uint8")
        video = video.transpose(0, 2, 3, 4, 1)

        frames = []
        for batch_idx in range(video.shape[0]):
            batch_frames = [video[batch_idx, i] for i in range(video.shape[1])]
            frames.append(batch_frames)

        if not return_dict:
            return (frames,)

        return type('PipelineOutput', (), {'frames': frames})()
''')
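
# Hedged illustration of the NAG weighting used in NAGWanPipeline.__call__ above:
# the noise prediction is flattened, L2-normalized, turned into a softmax
# "attention" map sharpened by nag_tau, and its per-channel mean (scaled by
# nag_alpha and nag_scale) rescales the prediction. Never invoked by the app;
# the tensor sizes and scale values are illustrative assumptions.
def _nag_weighting_demo():
    import torch
    import torch.nn.functional as F

    noise_pred = torch.randn(1, 4, 8, 32, 32)       # (batch, channels, frames, h, w)
    nag_scale, nag_tau, nag_alpha = 5.0, 3.5, 0.5   # illustrative values

    b, c, f, h, w = noise_pred.shape
    noise_flat = noise_pred.view(b, c, -1)
    noise_norm = F.normalize(noise_flat, dim=-1)          # unit-norm per channel
    attention = F.softmax(noise_norm * nag_tau, dim=-1)   # sharpened weighting
    guidance = attention.mean(dim=-1, keepdim=True) * nag_alpha
    guidance = guidance.unsqueeze(-1).unsqueeze(-1)       # -> (b, c, 1, 1, 1)
    return noise_pred + nag_scale * guidance * noise_pred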

print("NAG modules created successfully!")

# Ensure files are written and synced
import time
time.sleep(2)  # Give the filesystem a moment to flush the writes

# Verify files exist
if not os.path.exists("src/transformer_wan_nag.py"):
    raise RuntimeError("transformer_wan_nag.py not created")
if not os.path.exists("src/pipeline_wan_nag.py"):
    raise RuntimeError("pipeline_wan_nag.py not created")

print("Files verified, importing modules...")

# Now import and run the main application
import types
import random
import spaces
import torch
import numpy as np
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.utils import export_to_video
import gradio as gr
import tempfile
from huggingface_hub import hf_hub_download
import logging
import gc

try:
    # Import our custom modules (invalidate import caches since the files were just written)
    import importlib
    importlib.invalidate_caches()
    from src.pipeline_wan_nag import NAGWanPipeline
    from src.transformer_wan_nag import NagWanTransformer3DModel
    print("Successfully imported NAG modules")
except Exception as e:
    print(f"Error importing NAG modules: {e}")
    raise

# MMAudio imports
try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio

# Set environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
os.environ['HF_HUB_CACHE'] = '/tmp/hub'

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video,
                                make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Constants
MOD_VALUE = 32
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
DEFAULT_H_SLIDER_VALUE = 256
DEFAULT_W_SLIDER_VALUE = 256
NEW_FORMULA_MAX_AREA = 480.0 * 832.0

SLIDER_MIN_H, SLIDER_MAX_H = 128, 512
SLIDER_MIN_W, SLIDER_MAX_W = 128, 512
MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 129

DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"

MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

# Initialize models
print("Loading VAE...")
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)

# Skip downloading the large model file
print("Creating simplified NAG transformer model...")
# wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
wan_path = "dummy_path"  # We'll use a simplified model instead

print("Creating transformer model...")
transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)

print("Creating pipeline...")
pipe = NAGWanPipeline.from_pretrained(
    MODEL_ID,
    vae=vae,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)

# Move to the appropriate device
if torch.cuda.is_available():
    pipe.to("cuda")
    print("Using CUDA device")
else:
    pipe.to("cpu")
    print("Warning: CUDA not available, using CPU (will be slow)")

# Load LoRA weights for faster generation
try:
    print("Loading LoRA weights...")
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
    pipe.fuse_lora()
    print("LoRA weights loaded successfully")
except Exception as e:
    print(f"Warning: Could not load LoRA weights: {e}")

# Compatibility shims so diffusers' attention-processor hooks don't fail on the demo model
pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
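
# Optional sanity check (a hedged sketch, never invoked by the app): a minimal
# direct call into the demo pipeline. The prompt, nag_scale, and sizes are
# illustrative assumptions chosen to run quickly, not the app's defaults.
def _smoke_test_pipeline():
    result = pipe(
        prompt="a cat walking on grass",
        nag_negative_prompt=DEFAULT_NAG_NEGATIVE_PROMPT,
        nag_scale=5.0,
        height=128,
        width=128,
        num_frames=16,
        num_inference_steps=2,
        guidance_scale=0.0,  # CFG off, as in generate_video below
    )
    frames = result.frames[0]
    print(f"Generated {len(frames)} frames of shape {frames[0].shape}")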

# Audio model setup
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16

# Global audio model variables (loaded lazily on first use)
audio_model = None
audio_net = None
audio_feature_utils = None
audio_seq_cfg = None


def load_audio_model():
    global audio_model, audio_net, audio_feature_utils, audio_seq_cfg

    if audio_net is None:
        audio_model = all_model_cfg['small_16k']
        audio_model.download_if_needed()
        setup_eval_logging()

        seq_cfg = audio_model.seq_cfg
        net = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
        net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
        log.info(f'Loaded weights from {audio_model.model_path}')

        feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
                                      synchformer_ckpt=audio_model.synchformer_ckpt,
                                      enable_conditions=True,
                                      mode=audio_model.mode,
                                      bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
                                      need_vae_encoder=False)
        feature_utils = feature_utils.to(device, dtype).eval()

        audio_net = net
        audio_feature_utils = feature_utils
        audio_seq_cfg = seq_cfg

    return audio_net, audio_feature_utils, audio_seq_cfg


# Helper functions
def cleanup_temp_files():
    temp_dir = tempfile.gettempdir()
    for filename in os.listdir(temp_dir):
        filepath = os.path.join(temp_dir, filename)
        try:
            if filename.endswith(('.mp4', '.flac', '.wav')):
                os.remove(filepath)
        except OSError:  # File may be in use or already gone
            pass


def clear_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()


# CSS
css = """
.container {
    max-width: 1400px;
    margin: auto;
    padding: 20px;
}
.main-title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 10px;
}
.subtitle {
    text-align: center;
    color: #6b7280;
    margin-bottom: 30px;
}
.prompt-container {
    background: linear-gradient(135deg, #f3f4f6 0%, #e5e7eb 100%);
    border-radius: 15px;
    padding: 20px;
    margin-bottom: 20px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    font-size: 1.2em;
    font-weight: bold;
    padding: 15px 30px;
    border-radius: 10px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
    width: 100%;
    margin-top: 20px;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
}
.video-output {
    border-radius: 15px;
    overflow: hidden;
    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
    background: #1a1a1a;
    padding: 10px;
}
.settings-panel {
    background: #f9fafb;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
}
.slider-container {
    background: white;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 15px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.info-box {
    background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    border-left: 4px solid #667eea;
}
"""

default_audio_prompt = ""
default_audio_negative_prompt = "music"


def get_duration(
    prompt,
    nag_negative_prompt, nag_scale,
    height, width, duration_seconds,
    steps,
    seed, randomize_seed,
    audio_mode, audio_prompt, audio_negative_prompt,
    audio_seed, audio_steps, audio_cfg_strength,
):
    # Rough GPU-time estimate in seconds, e.g. 4 s x 4 steps x 2.25 + 5 = 41 s
    duration = int(duration_seconds) * int(steps) * 2.25 + 5
    if audio_mode == "Enable Audio":
        duration += 60
    return duration


@torch.inference_mode()
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
                       audio_seed, audio_steps, audio_cfg_strength):
    net, feature_utils, seq_cfg = load_audio_model()

    rng = torch.Generator(device=device)
    if audio_seed >= 0:
        rng.manual_seed(audio_seed)
    else:
        rng.seed()

    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)

    video_info = load_video(video_path, duration_sec)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    duration = video_info.duration_sec
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames, sync_frames, [audio_prompt],
                      negative_text=[audio_negative_prompt],
                      feature_utils=feature_utils, net=net, fm=fm, rng=rng,
                      cfg_strength=audio_cfg_strength)
    audio = audios.float().cpu()[0]

    video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)

    return video_with_audio_path
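
# Worked example of the sizing logic used in generate_video below (a hedged
# illustration, not executed): a 4 s request at FIXED_FPS=16 becomes
# 4 * 16 + 1 = 65 frames, clipped to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL],
# and a 300 px side is snapped down to the nearest multiple of MOD_VALUE=32,
# i.e. max(32, (300 // 32) * 32) = 288.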

@spaces.GPU(duration=get_duration)
def generate_video(
    prompt,
    nag_negative_prompt, nag_scale,
    height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE,
    duration_seconds=DEFAULT_DURATION_SECONDS,
    steps=DEFAULT_STEPS,
    seed=DEFAULT_SEED, randomize_seed=False,
    audio_mode="Video Only", audio_prompt="", audio_negative_prompt="music",
    audio_seed=-1, audio_steps=25, audio_cfg_strength=4.5,
):
    # Snap dimensions to multiples of MOD_VALUE and derive the frame count
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1),
                         MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    with torch.inference_mode():
        nag_output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            height=target_h,
            width=target_w,
            num_frames=num_frames,
            guidance_scale=0.,
            num_inference_steps=int(steps),
            generator=torch.Generator(device=device).manual_seed(current_seed)
        ).frames[0]

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        nag_video_path = tmpfile.name
    export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)

    # Generate audio if enabled
    video_with_audio_path = None
    if audio_mode == "Enable Audio":
        video_with_audio_path = add_audio_to_video(
            nag_video_path, duration_seconds,
            audio_prompt, audio_negative_prompt,
            audio_seed, audio_steps, audio_cfg_strength
        )

    clear_cache()
    cleanup_temp_files()

    return nag_video_path, video_with_audio_path, current_seed


def update_audio_visibility(audio_mode):
    return gr.update(visible=(audio_mode == "Enable Audio"))


# Build interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        gr.HTML("""
            <h1 class="main-title">Simplified NAG T2V with MMAudio Integration</h1>
        """)
        gr.HTML("""
            <div class="info-box">
                <p>⚠️ Demo Version: This uses a simplified model to avoid downloading 28GB of weights</p>
                <p>🚀 NAG Technology: Normalized Attention Guidance for enhanced video quality</p>
                <p>🎵 Audio: Optional synchronized audio generation with MMAudio</p>
                <p>💡 Tip: Try different NAG scales for varied artistic effects!</p>