File size: 6,128 Bytes
01a383f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84490df
01a383f
 
84490df
01a383f
 
 
 
02c5b0e
84490df
01a383f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import importlib
import math
import os
from typing import List

import torch
import torchvision
from huggingface_hub import snapshot_download

from inference_config import DiffusionDecoderSamplingConfig
from cosmos1.models.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens
from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel
from inference_utils import (
    load_network_model,
    load_tokenizer_model,
    skip_init_linear,
)
from .log import log
from config_helper import get_config_module, override

# Per-axis compression factors of the discrete tokenizer. Presumably (T, H, W), matching
# the "DV8x16x16" tokenizer name used in this module -- TODO confirm. Not referenced in
# the functions of this section.
TOKENIZER_COMPRESSION_FACTOR: list[int] = [8, 16, 16]
# Target [height, width] that input videos are resized and center-cropped to
# (consumed by resize_input via read_input_videos).
DATA_RESOLUTION_SUPPORTED: list[int] = [640, 1024]
# Number of frames an input video is trimmed or padded to by read_input_videos.
NUM_CONTEXT_FRAMES: int = 33


def resize_input(video: torch.Tensor, resolution: list[int]):
    r"""
    Scale ``video`` to cover ``resolution`` while keeping its aspect ratio, then center-crop.

    The scale factor is the larger of the width and height ratios, so both scaled
    sides are at least as large as the target (``ceil`` guards against rounding
    below it). The center crop then trims the excess symmetrically.

    Args:
        video (torch.Tensor): Video tensor whose dims 2 and 3 are read as height and width.
        resolution (list[int]): Target ``[height, width]``.
    Returns:
        torch.Tensor: The resized and center-cropped video.
    """
    height, width = video.shape[2], video.shape[3]
    out_h, out_w = resolution

    scale = max((out_w / width), (out_h / height))
    new_size = (
        int(math.ceil(scale * height)),
        int(math.ceil(scale * width)),
    )

    functional = torchvision.transforms.functional
    return functional.center_crop(functional.resize(video, new_size), resolution)


def read_input_videos(input_video: str) -> torch.Tensor:
    """Read a video file and turn it into a fixed-length, fixed-resolution model input.

    Pixel values are rescaled from ``[0, 255]`` to ``[-1, 1]``. The clip is trimmed
    to its first ``NUM_CONTEXT_FRAMES`` frames, or padded to that length by
    repeating its last frame. It is then resized and center-cropped to
    ``DATA_RESOLUTION_SUPPORTED``.

    Args:
        input_video (str): Path to the video file (e.g. an .mp4).

    Returns:
        torch.Tensor: Float tensor of shape ``(1, C, NUM_CONTEXT_FRAMES, H, W)``,
        where ``[H, W] == DATA_RESOLUTION_SUPPORTED``.

    Raises:
        ValueError: If no frames could be decoded from ``input_video``.
    """
    # With the default output_format, read_video yields frames as (T, H, W, C) uint8.
    video, _, _ = torchvision.io.read_video(input_video)
    nframes_in_video = video.shape[0]
    if nframes_in_video == 0:
        # Without this check, padding would fail with an opaque IndexError on video[-1].
        raise ValueError(f"No frames could be decoded from {input_video!r}")

    video = video.float() / 255.0
    video = video * 2 - 1

    # Use >= so a clip with exactly NUM_CONTEXT_FRAMES frames is not reported as padded.
    if nframes_in_video >= NUM_CONTEXT_FRAMES:
        video = video[:NUM_CONTEXT_FRAMES]
    else:
        log.info(f"Video doesn't have {NUM_CONTEXT_FRAMES} frames. Padding the video with the last frame.")
        # video[-1:] keeps the frame axis, so repeat() only needs to grow dim 0.
        padding = video[-1:].repeat(NUM_CONTEXT_FRAMES - nframes_in_video, 1, 1, 1)
        video = torch.cat((video, padding), dim=0)

    # (T, H, W, C) -> (T, C, H, W): resize_input reads dims 2 and 3 as height and width.
    video = video.permute(0, 3, 1, 2)
    video = resize_input(video, DATA_RESOLUTION_SUPPORTED)
    # (T, C, H, W) -> (1, C, T, H, W).
    return video.transpose(0, 1).unsqueeze(0)


def run_diffusion_decoder_model(indices_tensor_cur_batch: List[torch.Tensor], out_videos_cur_batch):
    """Refine autoregressive generation output with the 7B diffusion decoder.

    Fetches the decoder checkpoint and both tokenizers with ``snapshot_download``,
    builds the model, and decodes ``indices_tensor_cur_batch`` with a generic text
    context stored in ``aux_vars.pt``. The model is always released afterwards,
    including when loading or decoding raises, and cached CUDA memory is freed.

    Args:
        indices_tensor_cur_batch (List[torch.Tensor]): Token index tensors (prompt + generation tokens).
        out_videos_cur_batch (torch.Tensor): Decoded videos of the batch, expected
            shape ``[bs, 3, 33, 640, 1024]`` (per the original description -- TODO confirm).
            Only ``out_videos_cur_batch[0]`` is used, as ``original_video_example``.

    Returns:
        The result of ``diffusion_decoder_process_tokens`` (the refined video output).
    """
    diffusion_decoder_ckpt_path = snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8")
    dd_tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    tokenizer_corruptor_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-DV8x16x16")

    diffusion_decoder_model = load_model_by_config(
        config_job_name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token",
        config_file="cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py",
        model_class=LatentDiffusionDecoderModel,
        encoder_path=os.path.join(tokenizer_corruptor_dir, "encoder.jit"),
        decoder_path=os.path.join(tokenizer_corruptor_dir, "decoder.jit"),
    )
    try:
        load_network_model(diffusion_decoder_model, os.path.join(diffusion_decoder_ckpt_path, "model.pt"))
        load_tokenizer_model(diffusion_decoder_model, dd_tokenizer_dir)

        aux_vars = torch.load(os.path.join(diffusion_decoder_ckpt_path, "aux_vars.pt"), weights_only=True)
        # Only the generic prompt context is passed to the decoder; aux_vars' "context_mask"
        # is not consumed by this call, so it is not moved to the GPU.
        generic_context = aux_vars["context"].cuda()

        output_video = diffusion_decoder_process_tokens(
            model=diffusion_decoder_model,
            indices_tensor=indices_tensor_cur_batch,
            dd_sampling_config=DiffusionDecoderSamplingConfig(),
            original_video_example=out_videos_cur_batch[0],
            t5_emb_batch=[generic_context],
        )
    finally:
        # Drop this function's reference to the model even on the error path (a propagating
        # traceback would otherwise keep the frame's locals alive), then release cached CUDA memory.
        del diffusion_decoder_model
        gc.collect()
        torch.cuda.empty_cache()

    return output_video


def load_model_by_config(
    config_job_name,
    config_file="projects/cosmos_video/config/config.py",
    model_class=LatentDiffusionDecoderModel,
    encoder_path=None,
    decoder_path=None,
):
    """Instantiate a model from a named experiment config.

    Args:
        config_job_name: Experiment name, applied as the ``experiment=<name>`` override.
        config_file: Path to the config file; its module must expose ``make_config()``.
        model_class: Class constructed with ``config.model``.
        encoder_path: If truthy, replaces ``config.model.tokenizer_corruptor["enc_fp"]``.
        decoder_path: If truthy, replaces ``config.model.tokenizer_corruptor["dec_fp"]``.

    Returns:
        The constructed ``model_class`` instance. It is built inside
        ``skip_init_linear()``; presumably linear-layer initialization is skipped because
        weights are loaded afterwards (e.g. via ``load_network_model``) -- TODO confirm.
    """
    config_module = get_config_module(config_file)
    config = importlib.import_module(config_module).make_config()

    config = override(config, ["--", f"experiment={config_job_name}"])

    # Check that the config is valid
    config.validate()

    # Apply the caller's tokenizer paths *before* freezing, so the frozen config is
    # final and is never mutated after freeze().
    if encoder_path:
        config.model.tokenizer_corruptor["enc_fp"] = encoder_path
    if decoder_path:
        config.model.tokenizer_corruptor["dec_fp"] = decoder_path

    # Freeze the config so it is not changed while the model is in use.
    config.freeze()  # type: ignore

    # Initialize model
    with skip_init_linear():
        model = model_class(config.model)
    return model