# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import importlib
import math
import os
from typing import List

import torch
import torchvision
from huggingface_hub import snapshot_download

from config_helper import get_config_module, override
from cosmos1.models.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens
from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel
from inference_config import DiffusionDecoderSamplingConfig
from inference_utils import (
    load_network_model,
    load_tokenizer_model,
    skip_init_linear,
)

from .log import log

TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16]  # temporal x height x width compression of the DV8x16x16 tokenizer
DATA_RESOLUTION_SUPPORTED = [640, 1024]  # [height, width] expected by the diffusion decoder
NUM_CONTEXT_FRAMES = 33  # number of video frames consumed per sample


def resize_input(video: torch.Tensor, resolution: list[int]):
r"""
Function to perform aspect ratio preserving resizing and center cropping.
This is needed to make the video into target resolution.
Args:
video (torch.Tensor): Input video tensor
resolution (list[int]): Data resolution
Returns:
Cropped video
"""
orig_h, orig_w = video.shape[2], video.shape[3]
target_h, target_w = resolution
scaling_ratio = max((target_w / orig_w), (target_h / orig_h))
resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w)))
video_resized = torchvision.transforms.functional.resize(video, resizing_shape)
video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution)
return video_cropped
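

# Worked example for resize_input (illustrative numbers, not executed): a 720x1280
# input with resolution [640, 1024] gives scaling_ratio = max(1024/1280, 640/720)
# ≈ 0.8889, so the video is first resized to (640, 1138) and then center-cropped
# to (640, 1024), preserving the aspect ratio before the crop.
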
def read_input_videos(input_video: str) -> torch.Tensor:
    """Utility to read an input video and return it as a torch tensor.

    The video is truncated or padded to NUM_CONTEXT_FRAMES frames, then resized
    and center-cropped to DATA_RESOLUTION_SUPPORTED.

    Args:
        input_video (str): Path to an .mp4 file

    Returns:
        torch.Tensor: The video as a [1, C, T, H, W] tensor with values in [-1, 1]
    """
    video, _, _ = torchvision.io.read_video(input_video)
    video = video.float() / 255.0  # uint8 [0, 255] -> float [0, 1]
    video = video * 2 - 1  # [0, 1] -> [-1, 1]
    if video.shape[0] >= NUM_CONTEXT_FRAMES:
        video = video[0:NUM_CONTEXT_FRAMES, :, :, :]
    else:
        log.info(f"Video doesn't have {NUM_CONTEXT_FRAMES} frames. Padding the video with the last frame.")
        # Pad the video by repeating the last frame
        nframes_in_video = video.shape[0]
        video = torch.cat(
            (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_CONTEXT_FRAMES - nframes_in_video, 1, 1, 1)),
            dim=0,
        )
        video = video[0:NUM_CONTEXT_FRAMES, :, :, :]
    video = video.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
    video = resize_input(video, DATA_RESOLUTION_SUPPORTED)
    return video.transpose(0, 1).unsqueeze(0)  # [T, C, H, W] -> [1, C, T, H, W]
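

# Usage sketch ("input.mp4" is a hypothetical path for illustration):
#   frames = read_input_videos("input.mp4")  # shape [1, 3, NUM_CONTEXT_FRAMES, 640, 1024]
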
def run_diffusion_decoder_model(indices_tensor_cur_batch: List[torch.Tensor], out_videos_cur_batch: torch.Tensor):
    """Run the 7B diffusion decoder model to enhance the generated output.

    Args:
        indices_tensor_cur_batch (List[torch.Tensor]): The index tensors, i.e. prompt + generated tokens
        out_videos_cur_batch (torch.Tensor): The decoded output videos of shape [bs, 3, 33, 640, 1024]

    Returns:
        The enhanced output video produced by the diffusion decoder
    """
    # Fetch the diffusion decoder checkpoint and the two tokenizers from the Hugging Face Hub
    diffusion_decoder_ckpt_path = snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8")
    dd_tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    tokenizer_corruptor_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-DV8x16x16")
diffusion_decoder_model = load_model_by_config(
config_job_name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token",
config_file="cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py",
model_class=LatentDiffusionDecoderModel,
encoder_path=os.path.join(tokenizer_corruptor_dir, "encoder.jit"),
decoder_path=os.path.join(tokenizer_corruptor_dir, "decoder.jit"),
)
load_network_model(diffusion_decoder_model, os.path.join(diffusion_decoder_ckpt_path, "model.pt"))
load_tokenizer_model(diffusion_decoder_model, dd_tokenizer_dir)
    # Load the generic T5 text-embedding context shipped with the checkpoint
    generic_prompt = dict()
    aux_vars = torch.load(os.path.join(diffusion_decoder_ckpt_path, "aux_vars.pt"), weights_only=True)
    generic_prompt["context"] = aux_vars["context"].cuda()
    generic_prompt["context_mask"] = aux_vars["context_mask"].cuda()
output_video = diffusion_decoder_process_tokens(
model=diffusion_decoder_model,
indices_tensor=indices_tensor_cur_batch,
dd_sampling_config=DiffusionDecoderSamplingConfig(),
original_video_example=out_videos_cur_batch[0],
t5_emb_batch=[generic_prompt["context"]],
)
    # Release the decoder and free GPU memory before returning
    del diffusion_decoder_model
    diffusion_decoder_model = None
    gc.collect()
    torch.cuda.empty_cache()
return output_video
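

# Usage sketch (hypothetical variable names; in the Cosmos autoregressive pipeline
# these come from the AR model's token indices and its tokenizer-decoded videos):
#   enhanced = run_diffusion_decoder_model(indices_batch, decoded_videos)
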
def load_model_by_config(
config_job_name,
config_file="projects/cosmos_video/config/config.py",
model_class=LatentDiffusionDecoderModel,
encoder_path=None,
decoder_path=None,
):
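    """Instantiate a model from a registered experiment config.

    Args:
        config_job_name: Name of the experiment selected via the config override
        config_file: Path to the config module defining make_config()
        model_class: Model class to instantiate from config.model
        encoder_path: Optional path to a JIT tokenizer-corruptor encoder
        decoder_path: Optional path to a JIT tokenizer-corruptor decoder

    Returns:
        The instantiated model; network weights are loaded separately
    """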
config_module = get_config_module(config_file)
config = importlib.import_module(config_module).make_config()
config = override(config, ["--", f"experiment={config_job_name}"])
    if encoder_path:
        config.model.tokenizer_corruptor["enc_fp"] = encoder_path
    if decoder_path:
        config.model.tokenizer_corruptor["dec_fp"] = decoder_path
    # Check that the config is valid
    config.validate()
    # Freeze the config so it isn't mutated downstream.
    config.freeze()  # type: ignore
# Initialize model
with skip_init_linear():
model = model_class(config.model)
return model