# Provenance: copied from EthanZyh/DiffusionText2WorldGeneration (8c31d70).
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import torch
from huggingface_hub import snapshot_download
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from nemo import lightning as nl
from nemo.lightning.megatron_parallel import MegatronParallel
# Monkey-patch: replace MegatronParallel.init_ddp with a no-op that returns None.
# NOTE(review): presumably DDP wrapping is not wanted for this inference-only script — confirm.
MegatronParallel.init_ddp = lambda self: None
from nemo.collections.diffusion.mcore_parallel_utils import Utils
from nemo.collections.diffusion.sampler.conditioner import VideoConditioner
from nemo.collections.diffusion.sampler.conditioner_configs import (
FPSConfig,
ImageSizeConfig,
NumFramesConfig,
PaddingMaskConfig,
TextConfig,
)
from nemo.collections.diffusion.sampler.cosmos.cosmos_diffusion_pipeline import CosmosDiffusionPipeline
from transformers import T5EncoderModel, T5TokenizerFast
from cosmos1.models.diffusion.nemo.inference.inference_utils import process_prompt, save_video
from .log import log
# Default value of the --prompt command-line argument (see parse_args).
EXAMPLE_PROMPT = (
    "The teal robot is cooking food in a kitchen. Steam rises from a simmering pot "
    "as the robot chops vegetables on a worn wooden cutting board. Copper pans hang "
    "from an overhead rack, catching glints of afternoon light, while a well-loved "
    "cast iron skillet sits on the stovetop next to scattered measuring spoons and "
    "a half-empty bottle of olive oil."
)
def parse_args():
    """
    Parse the command-line arguments for text-to-world video inference.

    Returns:
        argparse.Namespace holding the model choice, prompts, sampling settings
        (guidance, sampler, steps, seed, resolution, fps, frame count), parallelism
        settings and the asset / checkpoint directories. Directory options default
        to "" (``--t5_cache_dir`` defaults to None).
    """
    parser = argparse.ArgumentParser(description="Video foundation model inference")
    parser.add_argument(
        "--model",
        type=str,
        default="Cosmos-1.0-Diffusion-7B-Text2World",
        choices=["Cosmos-1.0-Diffusion-7B-Text2World", "Cosmos-1.0-Diffusion-14B-Text2World"],
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=EXAMPLE_PROMPT,
        help="Prompt which the sampled video condition on",
    )
    # The negative prompt is on by default; pass an empty string ("") to turn it off.
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=(
            "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
            "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
            "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
            "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
            "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
            "Overall, the video is of poor quality."
        ),
        help="Negative prompt which the sampled video condition on",
    )
    parser.add_argument("--subject_name", type=str, default="", help="Name of fine-tuned subject")
    # argparse does not run non-string defaults through `type`, so the default is spelled as a float.
    parser.add_argument("--guidance", type=float, default=7.0, help="Classifier-free guidance scale")
    parser.add_argument("--sampler", type=str, default="RES", help="Currently only supports RES sampler.")
    parser.add_argument("--video_save_path", type=str, default="outputs", help="Path to save the video")
    parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video")
    parser.add_argument("--height", type=int, default=704, help="Height of image to sample")
    parser.add_argument("--width", type=int, default=1280, help="Width of image to sample")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--num_devices", type=int, default=1, help="Number of devices for inference")
    parser.add_argument("--cp_size", type=int, default=1, help="Number of cp ranks for multi-gpu inference.")
    # A step count is an integer; parsing it as float let fractional values through.
    parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
    parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample")
    parser.add_argument("--tokenizer_dir", type=str, default="", help="Directory for video tokenizer")
    parser.add_argument("--cosmos_assets_dir", type=str, default="", help="Directory containing cosmos assets")
    parser.add_argument("--prompt_upsampler_dir", type=str, default="", help="Prompt upsampler weights directory")
    parser.add_argument("--guardrail_dir", type=str, default="", help="Guardrails weights directory")
    parser.add_argument("--nemo_checkpoint", type=str, default="", help="Video diffusion model nemo weights")
    parser.add_argument("--t5_cache_dir", type=str, default=None, help="Path to T5 model")
    parser.add_argument(
        "--enable_prompt_upsampler", action="store_true", help="Whether to use prompt upsampling before generation"
    )
    return parser.parse_args()
def print_rank_0(string: str):
    """Emit ``string`` through ``log.info`` only on the process whose distributed rank is 0."""
    if torch.distributed.get_rank() != 0:
        return
    log.info(string)
@torch.no_grad()
def encode_for_batch(tokenizer: T5TokenizerFast, encoder: T5EncoderModel, prompts: list[str], max_length: int = 512):
    """
    Turn a batch of text prompts into T5 embeddings, with padded positions set to zero.

    Parameters:
        tokenizer: T5 embedding tokenizer.
        encoder: T5 embedding text encoder.
        prompts: A batch of text prompts.
        max_length: Length every prompt is truncated / padded to (defaults to 512).

    Returns:
        The encoder's ``last_hidden_state``; for each prompt, every position at or
        beyond its attention-mask length is overwritten with 0.
    """
    tokenized = tokenizer.batch_encode_plus(
        prompts,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_length=True,
        return_offsets_mapping=False,
    )

    # All processing is expected to happen on the GPU.
    token_ids = tokenized.input_ids.cuda()
    mask = tokenized.attention_mask.cuda()

    hidden_states = encoder(input_ids=token_ids, attention_mask=mask).last_hidden_state

    # Number of attended tokens per prompt; everything after that is padding and is zeroed.
    valid_counts = mask.sum(dim=1).cpu().tolist()
    for row_index, valid_count in enumerate(valid_counts):
        hidden_states[row_index, valid_count:] = 0

    return hidden_states
def init_video_tokenizer(args):
    """
    Initializes video tokenizer based on specified video tokenizer config / path.

    The DiT config class is chosen from ``args.nemo_checkpoint``: a path containing
    "14b" (case-insensitive) selects the 14B config, otherwise one containing "7b"
    selects the 7B config. The tokenizer path is
    ``os.path.join(args.cosmos_assets_dir, args.tokenizer_dir)``.

    Returns:
        The object returned by the selected config's ``configure_vae()``.

    Raises:
        ValueError: if ``args.nemo_checkpoint`` contains neither "7b" nor "14b".
    """
    # Validate first so a bad path fails with a clear message instead of an
    # UnboundLocalError, and before paying for the nemo import.
    checkpoint_name = args.nemo_checkpoint.lower()
    use_14b = "14b" in checkpoint_name
    if not use_14b and "7b" not in checkpoint_name:
        raise ValueError(
            f"Cannot infer the DiT model size from nemo_checkpoint {args.nemo_checkpoint!r}: "
            "expected the path to contain '7b' or '14b' (case-insensitive)."
        )

    from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig

    vae_path = os.path.join(args.cosmos_assets_dir, args.tokenizer_dir)
    # "14b" takes precedence when both markers appear, matching the previous behaviour.
    config_cls = DiT14BConfig if use_14b else DiT7BConfig
    dit_config = config_cls(vae_path=vae_path)
    return dit_config.configure_vae()
def check_prompt(args):
    """
    Pass ``args.prompt`` through ``process_prompt`` and optionally prefix a subject phrase.

    ``process_prompt`` receives the assets, prompt-upsampler and guardrail directories
    plus the ``enable_prompt_upsampler`` flag. When ``args.subject_name`` is non-empty,
    the result is prefixed with "A video of sks <subject_name>. ".

    Returns:
        The processed (and possibly prefixed) prompt string.
    """
    processed = process_prompt(
        prompt=args.prompt,
        checkpoint_dir=args.cosmos_assets_dir,
        prompt_upsampler_dir=args.prompt_upsampler_dir,
        guardrails_dir=args.guardrail_dir,
        enable_prompt_upsampler=args.enable_prompt_upsampler,
    )

    if not args.subject_name:
        return processed

    return f"A video of sks {args.subject_name}. {processed}"
def _encode_padded_t5(tokenizer, text_encoder, prompt, max_length):
    """Encode one prompt with T5 and copy it into a zero-filled (1, max_length, C) bfloat16 tensor."""
    embedding = encode_for_batch(tokenizer, text_encoder, [prompt])[0]
    # Convert the existing tensor with .to(); torch.tensor(<tensor>) copy-constructs and warns.
    embedding = embedding.detach().to(dtype=torch.bfloat16)

    seq_len, channels = embedding.shape
    padded = torch.zeros(1, max_length, channels, dtype=torch.bfloat16)
    padded[0, :seq_len] = embedding
    return padded


def prepare_data_batch(args, vae, t5_embeding_max_length=512):
    """
    Build the conditioning batch and latent state shape for one sampling run.

    Loads the "google-t5/t5-11b" tokenizer and encoder (cached under
    ``args.t5_cache_dir``), encodes ``args.prompt`` and, when non-empty,
    ``args.negative_prompt``, and assembles the dictionary consumed by the
    diffusion pipeline.

    Parameters:
        args: Parsed command-line arguments (prompt, negative_prompt, height, width,
            fps, num_video_frames, t5_cache_dir are read).
        vae: Video tokenizer; ``channel``, ``get_latent_num_frames`` and
            ``spatial_compression_factor`` are read to derive the latent shape.
        t5_embeding_max_length: Sequence length the T5 embeddings are padded to.
            NOTE(review): encode_for_batch is called with its own default max_length
            and pads to it, so a value smaller than that default would make the
            padding copy fail — confirm before lowering it.

    Returns:
        ``(data_batch, state_shape)`` where ``state_shape`` is
        ``[channels, latent_frames, height // factor, width // factor]``.
    """
    tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b", cache_dir=args.t5_cache_dir)
    text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b", cache_dir=args.t5_cache_dir)
    text_encoder.to("cuda")
    text_encoder.eval()

    # Encode text to T5 embeddings, padded to t5_embeding_max_length.
    t5_embed = _encode_padded_t5(tokenizer, text_encoder, args.prompt, t5_embeding_max_length)
    neg_t5_embed = None
    if args.negative_prompt:
        neg_t5_embed = _encode_padded_t5(tokenizer, text_encoder, args.negative_prompt, t5_embeding_max_length)

    # Prepare data sample
    t, h, w = args.num_video_frames, args.height, args.width
    state_shape = [
        vae.channel,
        vae.get_latent_num_frames(t),
        h // vae.spatial_compression_factor,
        w // vae.spatial_compression_factor,
    ]

    data_batch = {
        "video": torch.zeros((1, 3, t, h, w), dtype=torch.uint8).cuda(),
        "t5_text_embeddings": t5_embed,
        "t5_text_mask": torch.ones(1, t5_embeding_max_length, dtype=torch.bfloat16).cuda(),
        # other conditions
        "image_size": torch.tensor([[h, w, h, w]], dtype=torch.bfloat16).cuda(),
        "fps": torch.tensor([args.fps], dtype=torch.bfloat16).cuda(),
        "num_frames": torch.tensor([t], dtype=torch.bfloat16).cuda(),
        "padding_mask": torch.zeros((1, 1, h, w), dtype=torch.bfloat16).cuda(),
    }
    if neg_t5_embed is not None:
        data_batch["neg_t5_text_embeddings"] = neg_t5_embed
        data_batch["neg_t5_text_mask"] = torch.ones(1, t5_embeding_max_length, dtype=torch.bfloat16)

    return data_batch, state_shape
def setup_diffusion_pipeline(args):
    """
    Initialize DiT model, parallel strategy, and diffusion pipeline for inference.

    The DiT config is chosen from ``args.nemo_checkpoint``: a path containing "14b"
    (case-insensitive) selects the 14B config, otherwise one containing "7b" selects
    the 7B config. Weights are loaded from ``args.nemo_checkpoint`` and moved to CUDA
    in bfloat16.

    Returns:
        A ``CosmosDiffusionPipeline`` built from the loaded model, a ``VideoConditioner``
        with default sub-configs, ``args.sampler`` and ``args.seed``.

    Raises:
        ValueError: if ``args.nemo_checkpoint`` contains neither "7b" nor "14b".
    """
    # Validate first so a bad path fails with a clear message instead of an
    # UnboundLocalError, and before paying for the nemo import.
    checkpoint_name = args.nemo_checkpoint.lower()
    use_14b = "14b" in checkpoint_name
    if not use_14b and "7b" not in checkpoint_name:
        raise ValueError(
            f"Cannot infer the DiT model size from nemo_checkpoint {args.nemo_checkpoint!r}: "
            "expected the path to contain '7b' or '14b' (case-insensitive)."
        )

    # Initialize DiT model
    from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig, DiTModel

    # "14b" takes precedence when both markers appear, matching the previous behaviour.
    dit_config = DiT14BConfig() if use_14b else DiT7BConfig()
    dit_model = DiTModel(dit_config)

    # Initialize model parallel strategy. Here, we only use context parallel.
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=args.cp_size,
        pipeline_dtype=torch.bfloat16,
    )

    # Initialize ptl trainer
    trainer = nl.Trainer(
        devices=args.num_devices,  # you can change the number of devices to suit your setup
        max_steps=1,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    # Convert trainer to fabric for inference
    fabric = trainer.to_fabric()
    fabric.strategy.checkpoint_io.save_ckpt_format = "zarr"
    fabric.strategy.checkpoint_io.validate_access_integrity = False
    model = fabric.load_model(args.nemo_checkpoint, dit_model).to(device="cuda", dtype=torch.bfloat16)

    # Set up diffusion pipeline
    conditioner = VideoConditioner(
        text=TextConfig(),
        fps=FPSConfig(),
        num_frames=NumFramesConfig(),
        image_size=ImageSizeConfig(),
        padding_mask=PaddingMaskConfig(),
    )
    diffusion_pipeline = CosmosDiffusionPipeline(
        net=model.module, conditioner=conditioner, sampler_type=args.sampler, seed=args.seed
    )

    return diffusion_pipeline
def run_diffusion_inference(args, data_batch, state_shape, vae, diffusion_pipeline):
    """
    Sample latents with the diffusion pipeline and, on rank 0, decode and save the video.

    Parameters:
        args: Parsed command-line arguments (guidance, num_steps, fps, height, width,
            video_save_path, cosmos_assets_dir, guardrail_dir are read).
        data_batch: Conditioning dictionary; tensor values are moved to CUDA in a copy,
            the caller's dictionary is left untouched.
        state_shape: Latent shape passed to the sampler.
        vae: Video tokenizer used to decode the sampled latents.
        diffusion_pipeline: Pipeline exposing ``generate_samples_from_batch``.
    """
    # Move every tensor in the batch to the GPU; non-tensor entries pass through.
    cuda_batch = {}
    for key, value in data_batch.items():
        cuda_batch[key] = value.cuda() if torch.is_tensor(value) else value
    cuda_batch["inference_fwd"] = True

    # A negative prompt is in use exactly when its embeddings are in the batch.
    has_negative_prompt = "neg_t5_text_embeddings" in cuda_batch
    sample = diffusion_pipeline.generate_samples_from_batch(
        cuda_batch,
        guidance=args.guidance,
        state_shape=state_shape,
        num_steps=args.num_steps,
        is_negative_prompt=has_negative_prompt,
    )

    # Only rank 0 decodes and writes the result.
    if torch.distributed.get_rank() != 0:
        return

    # Post-processing and save video
    sigma_data = 0.5
    decoded = vae.decode(sample / sigma_data)
    # Shift by 1, clamp to [0, 2] and halve, so values land in [0, 1].
    unit_range = (1.0 + decoded).clamp(0, 2) / 2
    # presumably (C, T, H, W) -> (T, H, W, C) for the first batch item — TODO confirm
    first_video = unit_range[0].permute(1, 2, 3, 0)
    frames = (first_video * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)

    save_video(
        grid=frames,
        fps=args.fps,
        H=args.height,
        W=args.width,
        video_save_quality=5,
        video_save_path=args.video_save_path,
        checkpoint_dir=args.cosmos_assets_dir,
        guardrails_dir=args.guardrail_dir,
    )
    print_rank_0(f"saved video to {args.video_save_path}!")
def main(args):
    """
    Run text-to-world inference end to end.

    Fills in any asset directory left as "" by downloading it with
    ``snapshot_download``, initializes the distributed / model-parallel environment,
    processes the prompt, then builds the video tokenizer, data batch and diffusion
    pipeline and generates the video. ``args`` is modified in place.
    """
    # Download whichever assets were not supplied on the command line.
    if args.guardrail_dir == "":
        args.guardrail_dir = snapshot_download("nvidia/Cosmos-1.0-Guardrail")
    if args.tokenizer_dir == "":
        args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    # The prompt upsampler is fetched only when it is actually enabled.
    if args.enable_prompt_upsampler and args.prompt_upsampler_dir == "":
        args.prompt_upsampler_dir = snapshot_download("nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World")
    if args.nemo_checkpoint == "":
        # Only the "nemo/" sub-tree of the model repo is downloaded; point at it.
        snapshot_root = snapshot_download(f"nvidia/{args.model}", allow_patterns=["nemo/*"])
        args.nemo_checkpoint = os.path.join(snapshot_root, "nemo")

    # Initialize megatron model parallel environment
    Utils.initialize_distributed(1, 1, context_parallel_size=args.cp_size)
    model_parallel_cuda_manual_seed(args.seed)

    args.prompt = check_prompt(args)

    # Load video tokenizer
    print_rank_0("initializing video tokenizer...")
    video_tokenizer = init_video_tokenizer(args)

    # Prepare data batch
    print_rank_0("preparing data batch...")
    batch, latent_shape = prepare_data_batch(args, video_tokenizer)

    # Setup model / diffusion pipeline
    print_rank_0("setting up diffusion pipeline...")
    pipeline = setup_diffusion_pipeline(args)

    # Generate video from prompt
    print_rank_0("generating video...")
    run_diffusion_inference(args, batch, latent_shape, video_tokenizer, pipeline)
# Script entry point: parse the command-line arguments and run inference.
if __name__ == "__main__":
    args = parse_args()
    main(args)