|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import glob |
|
import json |
|
import os |
|
import random |
|
|
|
import torch |
|
import torchvision |
|
from einops import rearrange |
|
from huggingface_hub import snapshot_download |
|
from nemo.collections.diffusion.models.model import DiT7BConfig |
|
from tqdm import tqdm |
|
from transformers import T5EncoderModel, T5TokenizerFast |
|
|
|
from .log import log |
|
|
|
|
|
def get_parser():
    """Build the command-line interface for the latent/embedding caching script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--tokenizer_dir``, ``--dataset_path``,
        ``--output_path``, ``--prompt``, ``--num_chunks``, ``--height`` and ``--width``.
    """
    cli = argparse.ArgumentParser(description="Process some configurations.")

    # (flag, type, default, help) for every supported option, in registration order.
    option_specs = (
        ("--tokenizer_dir", str, "", "Path to the VAE model"),
        ("--dataset_path", str, "video_dataset", "Path to the dataset (a folder of videos)"),
        ("--output_path", str, "video_dataset_cached", "Path to the output directory"),
        ("--prompt", str, "a video of sks.", "Prompt for the video"),
        ("--num_chunks", int, 5, "Number of random chunks to sample per video"),
        ("--height", int, 704, "Height to resize video"),
        ("--width", int, 1280, "Width to resize video"),
    )
    for flag, value_type, fallback, description in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    return cli
|
|
|
|
|
def init_t5():
    """Load the ``google-t5/t5-11b`` tokenizer and encoder.

    The encoder is moved to CUDA and switched to eval mode before being returned.

    Returns:
        tuple: ``(tokenizer, text_encoder)``.
    """
    checkpoint = "google-t5/t5-11b"
    t5_tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
    t5_encoder = T5EncoderModel.from_pretrained(checkpoint)
    t5_encoder.to("cuda")
    t5_encoder.eval()
    return t5_tokenizer, t5_encoder
|
|
|
|
|
def init_video_tokenizer(tokenizer_dir: str):
    """Build the Cosmos video tokenizer (VAE) from the checkpoint in ``tokenizer_dir``.

    The VAE is produced by ``DiT7BConfig.configure_vae`` with ``vae_path`` pointed at
    ``tokenizer_dir``.
    """
    return DiT7BConfig(vae_path=tokenizer_dir).configure_vae()
|
|
|
|
|
@torch.no_grad()
def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512):
    """
    Encode a batch of text prompts to a batch of T5 embeddings, with padding zeroed.

    Parameters:
        tokenizer: T5 tokenizer (must provide ``batch_encode_plus``).
        encoder: T5 text encoder; its output must expose ``last_hidden_state``.
        prompts: A batch of text prompts.
        max_length: Sequence length of the text embedding (defaults to 512). Every
            prompt is truncated to, and padded up to, exactly this many tokens.

    Returns:
        torch.Tensor: the encoder's ``last_hidden_state`` (on CUDA), one row per
        prompt and ``max_length`` positions per row, with every position whose
        attention-mask entry is 0 (i.e. padding) set to zero.
    """
    batch_encoding = tokenizer.batch_encode_plus(
        prompts,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_length=True,
        return_offsets_mapping=False,
    )

    input_ids = batch_encoding.input_ids.cuda()
    attn_mask = batch_encoding.attention_mask.cuda()

    encoded_text = encoder(input_ids=input_ids, attention_mask=attn_mask).last_hidden_state

    # Zero padded positions straight from the attention mask. Unlike zeroing everything
    # past sum(mask) per row, this is correct for either tokenizer padding side, and it
    # avoids a Python loop plus a device->host copy of the lengths.
    padding = (attn_mask == 0).unsqueeze(-1)
    encoded_text.masked_fill_(padding, 0)

    return encoded_text
|
|
|
|
|
def _encode_prompt(tokenizer, text_encoder, prompt, max_length):
    """Return the bfloat16 CPU T5 embedding of ``prompt``, zero-padded to ``max_length`` rows."""
    encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt], max_length=max_length)[0]
    # Explicit dtype + device move (instead of torch.tensor(tensor, ...), which copies,
    # warns, and keeps the result on the encoder's device).
    encoded_text = encoded_text.to(dtype=torch.bfloat16).cpu()

    length, hidden = encoded_text.shape
    t5_embed = torch.zeros(max_length, hidden, dtype=torch.bfloat16)
    t5_embed[:length] = encoded_text
    return t5_embed


def _encode_chunk(vae, frames, height, width):
    """Resize a ``(t, h, w, c)`` frame window, map it to [-1, 1], and encode it with ``vae``.

    ``frames`` presumably holds uint8 pixels as returned by ``torchvision.io.read_video``
    (the ``/ 127.5 - 1.0`` scaling assumes a [0, 255] range). Returns the latent on CPU,
    still carrying the leading batch dimension of 1.
    """
    chunk = rearrange(frames, "t h w c -> t c h w")
    chunk = torchvision.transforms.functional.resize(chunk, [height, width])
    # Add a batch axis and move time after channels, as the VAE input layout here expects.
    chunk = rearrange(chunk, "(b t) c h w -> b c t h w", b=1)
    chunk = chunk.to(device="cuda", dtype=torch.bfloat16, non_blocking=True) / 127.5 - 1.0
    return vae.encode(chunk).cpu()


def _save_sample(output_path, cnt, latent, t5_embed, t5_mask, info):
    """Write the latent, text embedding, text mask and metadata JSON for sample ``cnt``."""
    torch.save(latent[0], os.path.join(output_path, f"{cnt}.video_latent.pth"))
    torch.save(t5_embed, os.path.join(output_path, f"{cnt}.t5_text_embeddings.pth"))
    torch.save(t5_mask, os.path.join(output_path, f"{cnt}.t5_text_mask.pth"))
    with open(os.path.join(output_path, f"{cnt}.info.json"), "w", encoding="utf-8") as json_file:
        json.dump(info, json_file)


def main(args):
    """Cache video latents and T5 prompt embeddings for every ``.mp4`` in a folder.

    For each video in ``args.dataset_path`` with at least ``chunk_duration`` frames
    (the VAE's ``pixel_chunk_duration``), ``args.num_chunks`` random windows of
    ``chunk_duration`` consecutive frames are sampled. Each window is resized to
    ``(args.height, args.width)``, scaled by ``x / 127.5 - 1`` and encoded with the
    video tokenizer. For each window ``<cnt>`` (numbered from 0 across all videos),
    four files are written to ``args.output_path``:

    * ``<cnt>.video_latent.pth``: the VAE latent with its batch dimension removed.
    * ``<cnt>.t5_text_embeddings.pth``: the bfloat16 T5 embedding of ``args.prompt``.
    * ``<cnt>.t5_text_mask.pth``: an all-ones bfloat16 mask of length 512.
    * ``<cnt>.info.json``: source-video metadata and the sampled start frame.

    If ``args.tokenizer_dir`` is empty, the default Cosmos tokenizer is downloaded,
    and ``args.tokenizer_dir`` is overwritten with its local path.

    Raises:
        ValueError: if ``args.dataset_path`` contains no ``.mp4`` files.
    """
    os.makedirs(args.output_path, exist_ok=True)

    tokenizer, text_encoder = init_t5()

    if args.tokenizer_dir == "":
        args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    vae = init_video_tokenizer(args.tokenizer_dir)

    t5_embeding_max_length = 512
    chunk_duration = vae.video_vae.pixel_chunk_duration

    # Sorted so sample numbering is reproducible regardless of filesystem listing order.
    files = sorted(glob.glob(os.path.join(args.dataset_path, "*.mp4")))
    if not files:
        raise ValueError(f"Dataset path {args.dataset_path} does not contain any .mp4 files.")

    # The prompt is the same for every chunk, so run T5 once instead of once per chunk.
    t5_embed = _encode_prompt(tokenizer, text_encoder, args.prompt, t5_embeding_max_length)
    t5_mask = torch.ones(t5_embeding_max_length, dtype=torch.bfloat16)

    cnt = 0
    with torch.no_grad():
        for video_path in tqdm(files):
            # pts_unit="sec" avoids torchvision's warning about the default "pts"; with no
            # start/end bounds the whole video is read either way.
            video, _, meta = torchvision.io.read_video(video_path, pts_unit="sec")
            T, H, W, _ = video.shape

            if T < chunk_duration:
                log.info(f"Video {video_path} is shorter than {chunk_duration} frames. Skipped.")
                continue

            for _ in range(args.num_chunks):
                start_idx = random.randint(0, T - chunk_duration)
                latent = _encode_chunk(
                    vae, video[start_idx : start_idx + chunk_duration], args.height, args.width
                )

                # NOTE(review): height/width are the SOURCE video's dimensions, not the
                # resized args.height/args.width the latent was computed at; confirm this
                # is what downstream consumers of info.json expect.
                info = {
                    "height": H,
                    "width": W,
                    "fps": meta["video_fps"],
                    "num_frames": chunk_duration,
                    "video_path": os.path.basename(video_path),
                    "start_frame": start_idx,
                }
                _save_sample(args.output_path, cnt, latent, t5_embed, t5_mask, info)
                cnt += 1
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the caching pipeline with them.
    main(get_parser().parse_args())
|
|