|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
import os |
|
from argparse import ArgumentParser |
|
from typing import List |
|
|
|
import imageio |
|
import nemo.lightning as nl |
|
import numpy as np |
|
import torch |
|
from einops import rearrange |
|
from huggingface_hub import snapshot_download |
|
from megatron.core.inference.common_inference_params import CommonInferenceParams |
|
from megatron.core.inference.engines.mcore_engine import MCoreEngine |
|
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( |
|
SimpleTextGenerationController, |
|
) |
|
from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model |
|
from nemo.lightning import io |
|
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir |
|
|
|
from cosmos1.models.autoregressive.nemo.utils import run_diffusion_decoder_model |
|
from discrete_video import DiscreteVideoFSQJITTokenizer |
|
from cosmos1.models.autoregressive.utils.inference import load_vision_input |
|
from .presets import presets as guardrail_presets |
|
from .log import log |
|
|
|
# Disable TorchScript's TensorExpr fuser for the whole process.
# NOTE(review): the reason is not stated in this file; presumably it works around an issue
# with the TorchScript tokenizer checkpoints (encoder.jit / decoder.jit) loaded in main -- confirm.
torch._C._jit_set_texpr_fuser_enabled(False)

# Per-axis compression factor; not referenced by the code in this file.
# Presumably (temporal, height, width) for the DV8x16x16 tokenizer -- confirm before relying on it.
TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16]
# Pixel-frame chunk length handed to the video tokenizer as ``pixel_chunk_duration``.
NUM_CONTEXT_FRAMES = 33
# Number of frames loaded from a video input (an image input loads a single frame).
NUM_INPUT_FRAMES_VIDEO = 9
# Latent token grid unpacked as (T, H, W); the full token sequence has T * H * W entries, and
# generation fills the latent frames that follow the context frames.
LATENT_SHAPE = [5, 40, 64]
# Resolution passed to ``load_vision_input``; presumably (height, width) -- note that 640 / 16 = 40
# and 1024 / 16 = 64 match the spatial dims of LATENT_SHAPE.
DATA_RESOLUTION = [640, 1024]
|
|
|
|
|
class CosmosMCoreTokenizerWrappper:
    """
    Pass-through stand-in for a text tokenizer, handed to the Megatron-Core text generation controller.

    Prompts in this script are already lists of integer video-token ids, so both ``tokenize`` and
    ``detokenize`` hand back exactly the object they were given.
    """

    def __init__(self):
        # There is no real tokenizer behind this wrapper.
        self.tokenizer = None
        # End-of-document id. main() treats -1 entries in the returned prompt tokens as the slots to
        # overwrite with generated tokens.
        # NOTE(review): presumably the engine pads prompts with this value -- confirm.
        self.eod = -1
        # Vocabulary size reported to the controller (presumably the video-token codebook size).
        self.vocab_size = 64000

    def detokenize(self, tokens: List[int], remove_special_tokens: bool = False) -> List[int]:
        """Identity mapping: the token ids are returned untouched; the flag is ignored."""
        output = tokens
        return output

    def tokenize(self, prompt: List[int]) -> List[int]:
        """Identity mapping: the prompt is already a list of token ids."""
        output = prompt
        return output
|
|
|
|
|
def _load_video_tokenizer(args):
    """Build the discrete FSQ video tokenizer on the current CUDA device.

    If ``args.encoder_path`` / ``args.decoder_path`` still hold the default Hugging Face repo id, the
    repo is fetched with ``snapshot_download`` and the attribute on ``args`` is rewritten in place to
    point at ``encoder.jit`` / ``decoder.jit`` inside the downloaded snapshot.
    """
    default_repo = "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16"
    if args.encoder_path == default_repo:
        args.encoder_path = os.path.join(snapshot_download(args.encoder_path), "encoder.jit")
    if args.decoder_path == default_repo:
        args.decoder_path = os.path.join(snapshot_download(args.decoder_path), "decoder.jit")
    return DiscreteVideoFSQJITTokenizer(
        enc_fp=args.encoder_path,
        dec_fp=args.decoder_path,
        name="discrete_video_fsq",
        pixel_chunk_duration=NUM_CONTEXT_FRAMES,
        latent_chunk_duration=LATENT_SHAPE[0],
    ).cuda()


def _generate_latent_tokens(args, context_tokens, num_tokens_to_generate):
    """Restore the NeMo autoregressive model and sample the tokens that follow the context.

    Args:
        args: parsed CLI arguments; ``ar_model_dir`` is rewritten in place when it names one of the
            default Hugging Face repos, and ``temperature`` / ``top_p`` configure sampling.
        context_tokens: a one-element list holding the prompt token ids.
        num_tokens_to_generate: number of tokens to sample after the prompt.

    Returns:
        A 1-D CUDA tensor with the prompt tokens followed by the generated tokens.

    Every object that references the model (model, strategy, trainer, inference wrapper, controller,
    engine) is local to this function, so none of them remain reachable from the caller once it
    returns. That lets the caller reclaim GPU memory before running the diffusion decoder.
    """
    if args.ar_model_dir in ["nvidia/Cosmos-1.0-Autoregressive-4B", "nvidia/Cosmos-1.0-Autoregressive-12B"]:
        args.ar_model_dir = os.path.join(snapshot_download(args.ar_model_dir, allow_patterns=["nemo/*"]), "nemo")
    model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(args.ar_model_dir), subpath="model")

    # Single-GPU, no-parallelism Megatron setup for inference only (no optimizer state).
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        num_nodes=1,
        strategy=strategy,
        num_sanity_val_steps=0,
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            pipeline_dtype=torch.bfloat16,
            autocast_enabled=False,
            grad_reduce_in_fp32=False,
        ),
    )
    _setup_trainer_and_restore_model(path=args.ar_model_dir, trainer=trainer, model=model)

    inference_wrapped_model = model.get_inference_wrapper(torch.bfloat16, inference_batch_times_seqlen_threshold=1000)
    text_generation_controller = SimpleTextGenerationController(
        inference_wrapped_model=inference_wrapped_model, tokenizer=CosmosMCoreTokenizerWrappper()
    )
    mcore_engine = MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=1)
    common_inference_params = CommonInferenceParams(
        temperature=args.temperature, top_p=args.top_p, num_tokens_to_generate=num_tokens_to_generate
    )

    log.info(f"Running Inference to generate {num_tokens_to_generate} tokens. This will take some time. ")
    results = mcore_engine.generate(
        prompts=context_tokens,
        add_BOS=False,
        encoder_prompts=None,
        common_inference_params=common_inference_params,
    )

    result = list(results)[0]
    prompt_tokens = torch.tensor(result.prompt_tokens).cuda()
    # The tokenizer wrapper's eod is -1; those slots are overwritten with the generated tokens.
    # NOTE(review): this relies on the engine padding the prompt list in place with eod up to
    # prompt + generation length -- confirm against the installed Megatron-Core version.
    prompt_tokens[prompt_tokens == -1] = result.generated_tokens
    return prompt_tokens


def _save_video(args, out_video):
    """Convert the first decoded video to uint8 frames, optionally run guardrails, and write it out.

    ``out_video`` is indexed with ``[0]``, and the result is permuted (1, 2, 3, 0) before saving.
    Presumably this turns a (C, T, H, W) video with values in [0, 1] into (T, H, W, C) frames
    -- confirm.

    Raises:
        ValueError: if the guardrail rejects the generated video.
    """
    out_video = out_video[0].detach().clone()
    output_video = (out_video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy()

    if args.guardrail_dir:
        log.info("Running guardrails on the generated video")
        if args.guardrail_dir == "nvidia/Cosmos-1.0-Guardrail":
            args.guardrail_dir = snapshot_download(args.guardrail_dir)
        video_guardrail = guardrail_presets.create_video_guardrail_runner(checkpoint_dir=args.guardrail_dir)
        output_video = guardrail_presets.run_video_guardrail(output_video, video_guardrail)
        if output_video is None:
            raise ValueError("Guardrail blocked world generation.")

    imageio.mimsave(
        args.video_save_name,
        output_video,
        fps=25,
    )
    log.info(f"Saved to {args.video_save_name}")


def main(args):
    """Generate a continuation video from an input image or video and save it to disk.

    Steps:
        1. Load the input.
        2. Tokenize it with the discrete video tokenizer, keeping only the context latent frames as the prompt.
        3. Sample the remaining latent tokens with the NeMo autoregressive model.
        4. Decode the tokens, optionally refine them with the diffusion decoder, and optionally
           run guardrails.
        5. Write ``args.video_save_name``.

    Decoding and saving happen only on CUDA device 0.
    """
    num_input_frames = 1 if args.input_type == "image" else NUM_INPUT_FRAMES_VIDEO
    vision_input_dict = load_vision_input(
        input_type=args.input_type,
        batch_input_path=None,
        input_image_or_video_path=args.input_image_or_video_path,
        data_resolution=DATA_RESOLUTION,
        num_input_frames=num_input_frames,
    )
    # Only the first loaded input is used.
    vision_input = list(vision_input_dict.values())[0].cuda()

    T, H, W = LATENT_SHAPE
    # Number of latent frames taken from the input as context: 1 for an image, 2 for a video.
    latent_context_t_size = 1 if args.input_type == "image" else 2
    num_tokens_to_generate = int(np.prod([T - latent_context_t_size, H, W]))

    video_tokenizer = _load_video_tokenizer(args)
    quantized_out, _ = video_tokenizer.encode(vision_input, pixel_chunk_duration=None)
    indices = video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1))
    indices = rearrange(indices, "B T H W -> B (T H W)")
    # Drop the trailing num_tokens_to_generate tokens; what remains is the single context prompt.
    video_tokens = [indices[0][0:-num_tokens_to_generate].tolist()]

    # All AR-model objects live (and die) inside this call.
    full_tokens = _generate_latent_tokens(args, video_tokens, num_tokens_to_generate)
    indices_tensor = rearrange(full_tokens.unsqueeze(dim=0), "B (T H W) -> B T H W", T=T, H=H, W=W)

    if torch.cuda.current_device() != 0:
        return

    video_decoded = video_tokenizer.decode(indices_tensor.cuda())
    # Map decoder output to [0, 1] (presumably it is in [-1, 1]).
    out_video = (video_decoded * 0.5 + 0.5).clamp_(0, 1)

    if not args.disable_diffusion_decoder:
        log.info("Running diffusion model on the generated result")
        # Drop the tokenizer before the diffusion decoder runs. The AR model's references were already
        # dropped when _generate_latent_tokens returned; gc.collect() clears any reference cycles
        # among them. NOTE(review): memory is only reclaimed if no NeMo/Lightning global state still
        # holds the model.
        del video_tokenizer, video_decoded
        gc.collect()
        torch.cuda.empty_cache()
        out_video = run_diffusion_decoder_model(
            indices_tensor_cur_batch=[indices_tensor.squeeze()], out_videos_cur_batch=out_video
        )

    _save_video(args, out_video)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run generation.
    parser = ArgumentParser()
    parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"])
    parser.add_argument(
        "--input_image_or_video_path",
        required=True,
        type=str,
        help="The path to the input image or video to run inference",
    )
    parser.add_argument(
        "--video_save_name", default="./nemo_generated_video.mp4", type=str, help="The path to generated video"
    )
    parser.add_argument(
        "--ar_model_dir",
        default="nvidia/Cosmos-1.0-Autoregressive-4B",
        type=str,
        help="The path to the nemo autoregressive model",
    )
    parser.add_argument(
        "--encoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to encoder"
    )
    parser.add_argument(
        "--decoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to the decoder"
    )
    parser.add_argument(
        "--guardrail_dir", default="nvidia/Cosmos-1.0-Guardrail", type=str, help="The path to the guardrails"
    )
    parser.add_argument("--top_p", default=0.8, type=float, help="The top_p inference parameter ")
    # Temperature is a real number; type=int made values such as 0.7 fail to parse.
    parser.add_argument("--temperature", default=1.0, type=float, help="Sampling temperature")
    parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")

    args = parser.parse_args()

    main(args)
|
|