# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
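"""Cache a video dataset for diffusion post-training.

For each .mp4 under --dataset_path, this script samples random fixed-length
chunks, encodes each chunk with the Cosmos video tokenizer (VAE) and the text
prompt with T5-11B, and writes four files per chunk to --output_path:

    {i}.video_latent.pth        video latents from the Cosmos tokenizer
    {i}.t5_text_embeddings.pth  T5 prompt embeddings padded to 512 tokens
    {i}.t5_text_mask.pth        all-ones text mask of length 512
    {i}.info.json               metadata: resolution, fps, source video, start frame
"""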
import argparse
import glob
import json
import os
import random

import torch
import torchvision
from einops import rearrange
from huggingface_hub import snapshot_download
from nemo.collections.diffusion.models.model import DiT7BConfig
from tqdm import tqdm
from transformers import T5EncoderModel, T5TokenizerFast

from .log import log


def get_parser():
parser = argparse.ArgumentParser(description="Process some configurations.")
parser.add_argument("--tokenizer_dir", type=str, default="", help="Path to the VAE model")
parser.add_argument(
"--dataset_path", type=str, default="video_dataset", help="Path to the dataset (a folder of videos)"
)
parser.add_argument("--output_path", type=str, default="video_dataset_cached", help="Path to the output directory")
parser.add_argument("--prompt", type=str, default="a video of sks.", help="Prompt for the video")
parser.add_argument("--num_chunks", type=int, default=5, help="Number of random chunks to sample per video")
parser.add_argument("--height", type=int, default=704, help="Height to resize video")
parser.add_argument("--width", type=int, default=1280, help="Width to resize video")
    return parser


def init_t5():
"""Initialize and return the T5 tokenizer and text encoder."""
tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b")
text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b")
text_encoder.to("cuda")
text_encoder.eval()
    return tokenizer, text_encoder


def init_video_tokenizer(tokenizer_dir: str):
"""Initialize and return the Cosmos Video tokenizer."""
dit_config = DiT7BConfig(vae_path=tokenizer_dir)
vae = dit_config.configure_vae()
    return vae


@torch.no_grad()
def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512):
"""
    Encode a batch of text prompts to a batch of T5 embeddings.

    Parameters:
tokenizer: T5 embedding tokenizer.
encoder: T5 embedding text encoder.
prompts: A batch of text prompts.
max_length: Sequence length of text embedding (defaults to 512).
"""
batch_encoding = tokenizer.batch_encode_plus(
prompts,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=max_length,
return_length=True,
return_offsets_mapping=False,
)
    # All processing is expected to run on the GPU.
input_ids = batch_encoding.input_ids.cuda()
attn_mask = batch_encoding.attention_mask.cuda()
outputs = encoder(input_ids=input_ids, attention_mask=attn_mask)
encoded_text = outputs.last_hidden_state
lengths = attn_mask.sum(dim=1).cpu()
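    # Zero out hidden states at padded positions so padding tokens contribute nothing downstream.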
for batch_id in range(encoded_text.shape[0]):
encoded_text[batch_id][lengths[batch_id] :] = 0
    return encoded_text


def main(args):
# Set up output directory
os.makedirs(args.output_path, exist_ok=True)
# Initialize T5
tokenizer, text_encoder = init_t5()
# Initialize the VAE
if args.tokenizer_dir == "":
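        # No local path given: download the public Cosmos tokenizer from the Hugging Face Hub.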
args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
vae = init_video_tokenizer(args.tokenizer_dir)
# Constants
    t5_embedding_max_length = 512
chunk_duration = vae.video_vae.pixel_chunk_duration # Frames per chunk
cnt = 0 # File index
# Check if dataset_path is correct
files = glob.glob(os.path.join(args.dataset_path, "*.mp4"))
if not files:
raise ValueError(f"Dataset path {args.dataset_path} does not contain any .mp4 files.")
# Process each video in the dataset folder
    with torch.no_grad():
        for video_path in tqdm(files):
            # Read video (T x H x W x C); pts_unit="sec" avoids torchvision's warning about the default "pts"
            video, _, meta = torchvision.io.read_video(video_path, pts_unit="sec")
T, H, W, C = video.shape
# Skip videos shorter than one chunk
if T < chunk_duration:
log.info(f"Video {video_path} is shorter than {chunk_duration} frames. Skipped.")
continue
# Sample random segments
for _ in range(args.num_chunks):
start_idx = random.randint(0, T - chunk_duration)
chunk = video[start_idx : start_idx + chunk_duration] # (chunk_duration, H, W, C)
# Rearrange dimensions: (T, H, W, C) -> (T, C, H, W)
chunk = rearrange(chunk, "t h w c -> t c h w")
                # Resize each frame to [args.height, args.width] (defaults: 704 x 1280)
                chunk = torchvision.transforms.functional.resize(chunk, [args.height, args.width])
# Expand dims: (T, C, H, W) -> (B=1, C, T, H, W)
chunk = rearrange(chunk, "(b t) c h w -> b c t h w", b=1)
# Convert to bf16 and normalize from [0, 255] to [-1, 1]
chunk = chunk.to(device="cuda", dtype=torch.bfloat16, non_blocking=True) / 127.5 - 1.0
# Encode video
latent = vae.encode(chunk).cpu() # shape: (1, latent_channels, T//factor, H//factor, W//factor)
# Encode text
                out = encode_for_batch(tokenizer, text_encoder, [args.prompt])[0]
                encoded_text = out.to(dtype=torch.bfloat16)
                # Pad T5 embedding to t5_embedding_max_length
                L, C_ = encoded_text.shape
                t5_embed = torch.zeros(1, t5_embedding_max_length, C_, dtype=torch.bfloat16)
t5_embed[0, :L] = encoded_text
# Save data to folder
torch.save(latent[0], os.path.join(args.output_path, f"{cnt}.video_latent.pth"))
torch.save(t5_embed[0], os.path.join(args.output_path, f"{cnt}.t5_text_embeddings.pth"))
                # Create a T5 text mask of all ones
                torch.save(
                    torch.ones(t5_embedding_max_length, dtype=torch.bfloat16),
                    os.path.join(args.output_path, f"{cnt}.t5_text_mask.pth"),
                )
# Save metadata
info = {
"height": H,
"width": W,
"fps": meta["video_fps"],
"num_frames": chunk_duration,
"video_path": os.path.basename(video_path),
"start_frame": start_idx,
}
with open(os.path.join(args.output_path, f"{cnt}.info.json"), "w") as json_file:
json.dump(info, json_file)
                cnt += 1


if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)