# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from nemo.collections.llm.gpt.model.llama import Llama3Config, LlamaModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io
from nemo.lightning.base import teardown
from torch import Tensor, nn

from .log import log


class RotaryEmbedding3D(RotaryEmbedding):
    """3D rotary position embedding for the Cosmos language model.

    Args:
        kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config.
        rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000.
        use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly on the GPU.
            Defaults to False.
        latent_shape: The (T, H, W) shape of the latents produced after the video is tokenized.
    """

    def __init__(
        self,
        seq_len: int,
        kv_channels: int,
        training_type: Optional[str] = None,
        rotary_base: int = 10000,
        use_cpu_initialization: bool = False,
        latent_shape=[5, 40, 64],
        apply_yarn=False,
        original_latent_shape=None,
        beta_fast=32,
        beta_slow=1,
        scale=None,
        max_position_embeddings=None,
        original_max_position_embeddings=None,
        extrapolation_factor=1,
        attn_factor=1,
    ) -> None:
        super().__init__(
            kv_channels=kv_channels,
            rotary_base=rotary_base,
            rotary_percent=1.0,
            use_cpu_initialization=use_cpu_initialization,
        )
        self.latent_shape = latent_shape
        self.device = "cpu" if use_cpu_initialization else torch.cuda.current_device()
        self.dim = kv_channels
        self.rope_theta = rotary_base
        self.apply_yarn = apply_yarn
        self.original_latent_shape = original_latent_shape
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.scale = scale
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.attn_factor = attn_factor

        # Split the head dimension into two spatial bands (height, width) and one temporal band.
        dim_h = self.dim // 6 * 2
        dim_t = self.dim - 2 * dim_h
        self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.device) / dim_h
        spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
        self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.device) / dim_t
        temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
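
        # When YaRN is enabled, the inverse frequencies computed above are rescaled so that positions
        # beyond the original training-time latent extent are interpolated rather than extrapolated:
        # low-frequency components are divided by `scale` while high-frequency components are left
        # unchanged, with a linear blend in between (see `get_scale_factors`).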
        if self.apply_yarn:
            assert self.original_latent_shape is not None, "Original latent shape required."
            assert self.beta_slow is not None, "Beta slow value required."
            assert self.beta_fast is not None, "Beta fast value required."
            scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
            spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
            scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
            temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
            self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)

        self.spatial_inv_freq = spatial_inv_freq
        self.temporal_inv_freq = temporal_inv_freq

        max_seq_len_cached = max(self.latent_shape)
        if self.apply_yarn and seq_len > max_seq_len_cached:
            max_seq_len_cached = seq_len
        self.max_seq_len_cached = max_seq_len_cached
        self.freqs = self.get_freqs_non_repeated(self.max_seq_len_cached)

    def get_mscale(self, scale: float = 1.0) -> float:
        """Get the magnitude scaling factor for YaRN."""
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0

    def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
        """Get the scale factors for YaRN."""
        # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called
        # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code.
        high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
        low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len
        # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear
        # interpolation in between.
        smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)
        # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency.
        scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
        return scale_factors

    def get_freqs_non_repeated(self, max_seq_len_cached: int, offset: int = 0) -> Tensor:
        """Compute the rotary angle table over the full (T, H, W) latent grid."""
        dtype = self.spatial_inv_freq.dtype
        device = self.spatial_inv_freq.device
        self.seq = (torch.arange(max_seq_len_cached, device=device, dtype=dtype) + offset).cuda()
        assert hasattr(
            self, "latent_shape"
        ), "Latent shape is not set. Please run set_latent_shape() method on rope embedding."
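        # Build separate angle tables for the temporal (T) and spatial (H, W) axes via outer products,
        # broadcast each over the full T x H x W latent grid, and concatenate along the channel axis
        # (the list is repeated twice so the full head dimension is covered, matching how rotary
        # channels are paired). The result is flattened to Megatron's expected layout (T*H*W, 1, 1, dim).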
" T, H, W = self.latent_shape half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq.cuda()) half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq.cuda()) half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq.cuda()) emb = torch.cat( [ repeat(half_emb_t, "t d -> t h w d", h=H, w=W), repeat(half_emb_h, "h d -> t h w d", t=T, w=W), repeat(half_emb_w, "w d -> t h w d", t=T, h=H), ] * 2, dim=-1, ) emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float() return emb @lru_cache(maxsize=32) def forward(self, seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: if self.spatial_inv_freq.device.type == "cpu": # move `inv_freq` to GPU once at the first micro-batch forward pass self.spatial_inv_freq = self.spatial_inv_freq.to(device=torch.cuda.current_device()) max_seq_len_cached = self.max_seq_len_cached if self.apply_yarn and seq_len > max_seq_len_cached: max_seq_len_cached = seq_len self.max_seq_len_cached = max_seq_len_cached emb = self.get_freqs_non_repeated(self.max_seq_len_cached) return emb if TYPE_CHECKING: from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @dataclass class CosmosConfig(Llama3Config): qk_layernorm: bool = True rope_dim: str = "3D" vocab_size: int = 64000 activation_func = F.silu def configure_model(self, tokenizer) -> "MCoreGPTModel": model = super().configure_model(tokenizer) if self.rope_dim == "3D": model.rotary_pos_emb = RotaryEmbedding3D( seq_len=self.seq_length, training_type=None, kv_channels=self.kv_channels, max_position_embeddings=self.seq_length, original_max_position_embeddings=self.original_seq_len if hasattr(self, "original_seq_len") else None, rotary_base=self.rotary_base, apply_yarn=True if hasattr(self, "apply_yarn") else False, scale=self.yarn_scale if hasattr(self, "yarn_scale") else None, extrapolation_factor=1, attn_factor=1, beta_fast=self.yarn_beta_fast if hasattr(self, "yarn_beta_fast") else 32, beta_slow=self.yarn_beta_slow if hasattr(self, "yarn_beta_slow") else 1, latent_shape=[5, 40, 64], original_latent_shape=self.original_latent_shape if hasattr(self, "original_latent_shape") else None, ) return model @dataclass class CosmosConfig4B(CosmosConfig): rotary_base: int = 500_000 seq_length: int = 15360 num_layers: int = 16 hidden_size: int = 4096 ffn_hidden_size: int = 14336 num_attention_heads: int = 32 num_query_groups: int = 8 layernorm_epsilon: float = 1e-5 use_cpu_initialization: bool = True make_vocab_size_divisible_by: int = 128 kv_channels: int = 128 @dataclass class CosmosConfig12B(CosmosConfig): rotary_base: int = 500_000 seq_length: int = 15360 num_layers: int = 40 hidden_size: int = 5120 ffn_hidden_size: int = 14336 num_attention_heads: int = 32 num_query_groups: int = 8 layernorm_epsilon: float = 1e-5 use_cpu_initialization: bool = True make_vocab_size_divisible_by: int = 128 kv_channels: int = 128 original_latent_shape = [3, 40, 64] apply_yarn: bool = True yarn_beta_fast: int = 4 yarn_beta_slow: int = 1 yarn_scale: int = 2 original_seq_len = 8192 class CosmosModel(LlamaModel): def __init__( self, config: Annotated[Optional[CosmosConfig], Config[CosmosConfig]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): super().__init__(config or CosmosConfig4B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) self.config = config @io.state_transform( source_key=( "model.layers.*.feed_forward.w1.weight", "model.layers.*.feed_forward.w3.weight", ), 
target_key="decoder.layers.*.mlp.linear_fc1.weight", ) def _mlp_glu(ctx: io.TransformCTX, w1, w3): return torch.cat((w1, w3), axis=0) @io.state_transform( source_key=( "model.layers.*.attention.wq.weight", "model.layers.*.attention.wk.weight", "model.layers.*.attention.wv.weight", ), target_key="decoder.layers.*.self_attention.linear_qkv.weight", ) def _import_qkv_cosmos(ctx: io.TransformCTX, q, k, v): megatron_config = ctx.target.config head_num = megatron_config.num_attention_heads num_query_groups = megatron_config.num_query_groups heads_per_group = head_num // num_query_groups hidden_size = megatron_config.hidden_size head_size = megatron_config.kv_channels old_tensor_shape = q.size() new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] q = q.view(*new_q_tensor_shape) k = k.view(*new_kv_tensor_shape) v = v.view(*new_kv_tensor_shape) qkv_weights_l = [] for i in range(num_query_groups): qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) qkv_weights_l.append(k[i : i + 1, :, :]) qkv_weights_l.append(v[i : i + 1, :, :]) qkv_weights = torch.cat(qkv_weights_l) assert qkv_weights.ndim == 3, qkv_weights.shape assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape assert qkv_weights.shape[1] == head_size, qkv_weights.shape assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) return qkv_weights @io.model_importer(CosmosModel, "pt") class PTCosmosImporter(io.ModelConnector["PTCosmosModel", CosmosModel]): def init(self) -> CosmosModel: return CosmosModel(self.config, tokenizer=self.tokenizer) def apply(self, output_path: Path) -> Path: pt_model_path = str(self) cosmos_model_state_dict = torch.load(pt_model_path, map_location="cpu") for k, v in cosmos_model_state_dict.items(): # convert to float 32 (for cpu conversion) (Original model is bf16) cosmos_model_state_dict[k] = v.float() # Small wrapper since nemo calls source.state_dict() , to get state dict class WrapperCosmos: def __init__(self, model_state_dict): self.model_state_dict = model_state_dict def state_dict(self): return self.model_state_dict source = WrapperCosmos(cosmos_model_state_dict) target = self.init() trainer = self.nemo_setup(target) self.convert_state(source, target) self.nemo_save(output_path, trainer) log.info(f"Converted PT Cosmos model to Nemo, model saved to {output_path}") teardown(trainer, target) del trainer, target return output_path def convert_state(self, source, target): mapping = { "model.tok_embeddings.weight": "embedding.word_embeddings.weight", "model.layers.*.attention.wo.weight": "decoder.layers.*.self_attention.linear_proj.weight", "model.layers.*.attention.q_norm.weight": "decoder.layers.*.self_attention.q_layernorm.weight", "model.layers.*.attention.k_norm.weight": "decoder.layers.*.self_attention.k_layernorm.weight", "model.layers.*.attention_norm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", "model.layers.*.feed_forward.w2.weight": "decoder.layers.*.mlp.linear_fc2.weight", "model.layers.*.ffn_norm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", "model.norm.weight": "decoder.final_layernorm.weight", "model.output.weight": "output_layer.weight", } return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv_cosmos, _mlp_glu]) @property def tokenizer(self): return None @property def 
    @property
    def config(self):
        if "4B" in str(self) or "4b" in str(self):
            return CosmosConfig4B()
        elif "12B" in str(self) or "12b" in str(self):
            return CosmosConfig12B()
        else:
            raise ValueError("Unable to infer model size from checkpoint")
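

# Example (sketch): converting a raw PyTorch Cosmos checkpoint to the NeMo 2.0 format.
# This is a minimal sketch, assuming NeMo's `llm.import_ckpt` entry point and a hypothetical
# local checkpoint path; the "pt://" prefix routes the source to `PTCosmosImporter`,
# registered above via `@io.model_importer(CosmosModel, "pt")`.
#
#   from nemo.collections import llm
#
#   llm.import_ckpt(
#       model=CosmosModel(CosmosConfig4B()),
#       source="pt://checkpoints/cosmos_4b/model.pt",  # hypothetical path
#   )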