File size: 5,331 Bytes
13a8699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser

import torch
from huggingface_hub import snapshot_download
from lightning.pytorch.loggers import WandbLogger
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

from cosmos1.models.autoregressive.nemo.cosmos import CosmosConfig4B, CosmosConfig12B, CosmosModel


def main(args):
    """Fine-tune a Cosmos autoregressive model on pre-tokenized data with NeMo.

    Args:
        args: Parsed command-line namespace for this script. Reads
            ``model_path``, ``data_path``, ``index_mapping_dir``, ``log_dir``,
            ``split_string``, ``tensor_model_parallel_size``, ``max_steps``,
            ``max_epochs``, ``save_every_n_steps``, ``global_batch_size``,
            ``micro_batch_size`` and ``lr``. ``args`` is not modified.

    Raises:
        NotImplementedError: If the model size cannot be inferred from
            ``args.model_path`` (it must contain "4B" or "12B").
    """
    model_path = args.model_path

    # The model size is inferred purely from the path / repo-ID string, so a
    # local checkpoint directory must keep "4B" or "12B" somewhere in its path.
    if "4B" in model_path:
        config = CosmosConfig4B()
    elif "12B" in model_path:
        config = CosmosConfig12B()
    else:
        raise NotImplementedError(
            f"Cannot infer the Cosmos model size from model_path={model_path!r}; "
            'expected it to contain "4B" or "12B".'
        )

    # Known Hugging Face Hub repo IDs are downloaded (only the "nemo/" subfolder)
    # and replaced by the local path of that subfolder.
    if model_path in ("nvidia/Cosmos-1.0-Autoregressive-4B", "nvidia/Cosmos-1.0-Autoregressive-12B"):
        model_path = os.path.join(snapshot_download(model_path, allow_patterns=["nemo/*"]), "nemo")

    model = CosmosModel(config)

    data_module = llm.PreTrainingDataModule(
        paths=[args.data_path],
        # NOTE(review): fixed sequence length; presumably it must equal the number
        # of tokens per sample produced when the data was tokenized — confirm
        # before changing.
        seq_length=12800,
        global_batch_size=args.global_batch_size,
        micro_batch_size=args.micro_batch_size,
        tokenizer=None,
        split=args.split_string,
        num_workers=1,
        index_mapping_dir=args.index_mapping_dir,
    )

    # Fine-tuning is the same as training, except that llm.api.train accepts
    # tokenizer=None. We use it because no tokenizer is stored with this model.
    llm.api.train(
        model=model,
        data=data_module,
        trainer=nl.Trainer(
            # Single node, one GPU per tensor-parallel rank; with pipeline and
            # context parallelism fixed at 1 the data-parallel size is 1.
            devices=args.tensor_model_parallel_size,
            num_nodes=1,
            # Lightning stops at whichever of max_steps / max_epochs comes first.
            max_steps=args.max_steps,
            accelerator="gpu",
            strategy=nl.MegatronStrategy(
                tensor_model_parallel_size=args.tensor_model_parallel_size,
                pipeline_model_parallel_size=1,
                context_parallel_size=1,
                sequence_parallel=False,
                pipeline_dtype=torch.bfloat16,
            ),
            plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
            # Validation is disabled entirely.
            num_sanity_val_steps=0,
            limit_val_batches=0,
            max_epochs=args.max_epochs,
            log_every_n_steps=1,
            callbacks=[
                ModelCheckpoint(
                    monitor="reduced_train_loss",
                    filename="{epoch}-{step}",
                    every_n_train_steps=args.save_every_n_steps,
                    save_top_k=2,
                ),
                PreemptionCallback(),
            ],
        ),
        log=nl.NeMoLogger(
            # Weights & Biases logging is enabled only when WANDB_API_KEY is set.
            wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None),
            log_dir=args.log_dir,
        ),
        optim=nl.MegatronOptimizerModule(
            config=OptimizerConfig(
                lr=args.lr,
                bf16=True,
                params_dtype=torch.bfloat16,
                use_distributed_optimizer=False,
            )
        ),
        tokenizer=None,
        resume=nl.AutoResume(
            # Initial weights are restored from model_path; resume_if_exists
            # presumably lets a rerun continue from checkpoints already written
            # under log_dir instead — confirm against the NeMo AutoResume docs.
            restore_config=RestoreConfig(path=model_path),
            resume_if_exists=True,
            resume_ignore_no_checkpoint=False,
            resume_past_end=True,
        ),
    )


def build_parser() -> ArgumentParser:
    """Build the command-line parser for the Cosmos fine-tuning script.

    Returns:
        An ``ArgumentParser`` whose parsed namespace is what ``main`` expects.
    """
    parser = ArgumentParser()
    parser.add_argument("--data_path", required=True, type=str, help="The path to the .bin .idx files")
    parser.add_argument(
        "--model_path",
        default="nvidia/Cosmos-1.0-Autoregressive-4B",
        type=str,
        help="Hugging Face repo ID or local path of the NeMo model; must contain '4B' or '12B'",
    )
    parser.add_argument(
        "--index_mapping_dir", default="./index_mapping", type=str, help="The directory to store mapped indices"
    )
    parser.add_argument("--log_dir", default="./log_dir", type=str, help="The path to the logs")
    # Megatron/NeMo interpret the split weights in train, validation, test order.
    parser.add_argument("--split_string", default="98,1,1", type=str, help="The train/validation/test split")
    parser.add_argument(
        "--tensor_model_parallel_size",
        default=2,
        type=int,
        help="Tensor model parallel size (also the number of GPUs used)",
    )
    parser.add_argument("--max_steps", default=100, type=int, help="The max number of steps to run finetuning")
    parser.add_argument("--save_every_n_steps", default=100, type=int, help="How often to save a checkpoint")
    parser.add_argument("--global_batch_size", default=2, type=int, help="The global batch size")
    parser.add_argument("--micro_batch_size", default=1, type=int, help="The per-GPU micro batch size")
    parser.add_argument("--lr", default=5e-5, type=float, help="The learning rate")
    parser.add_argument("--max_epochs", default=10, type=int, help="Max number of epochs")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())