# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
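"""Training entrypoint for the Hunyuan3D shape model.

Parses command-line arguments, merges them with an OmegaConf config file,
builds the data module and model via ``instantiate_from_config``, and runs
training with a PyTorch Lightning ``Trainer``, optionally with DDP/DeepSpeed
and mixed precision.
"""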
from typing import Tuple, List
import warnings
warnings.filterwarnings("ignore")
import os
import torch
import argparse
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
allow_ops_in_compiled_graph()
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_info
from hy3dshape.utils import get_config_from_file, instantiate_from_config


class SetupCallback(Callback):
    """Create the log and checkpoint directories on rank 0 when fitting starts."""

def __init__(self, config: DictConfig, basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
super().__init__()
self.logdir = basedir / logdir
self.ckptdir = basedir / ckptdir
self.config = config
def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
if trainer.global_rank == 0:
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)


def setup_callbacks(config: DictConfig) -> Tuple[List[Callback], Logger]:
    """Build the callback list (setup, checkpointing, user-defined) and the TensorBoard logger."""
    training_cfg = config.training
basedir = Path(training_cfg.output_dir)
os.makedirs(basedir, exist_ok=True)
all_callbacks = []
setup_callback = SetupCallback(config, basedir)
all_callbacks.append(setup_callback)
checkpoint_callback = ModelCheckpoint(
dirpath=setup_callback.ckptdir,
filename="ckpt-{step:08d}",
monitor=training_cfg.monitor,
mode="max",
save_top_k=-1,
verbose=False,
every_n_train_steps=training_cfg.every_n_train_steps)
all_callbacks.append(checkpoint_callback)
if "callbacks" in config:
for key, value in config['callbacks'].items():
custom_callback = instantiate_from_config(value)
all_callbacks.append(custom_callback)
logger = TensorBoardLogger(save_dir=str(setup_callback.logdir), name="tensorboard")
return all_callbacks, logger


def merge_cfg(cfg, arg_cfg):
    """Merge CLI arguments into ``cfg.training``.

    For keys that already exist in ``cfg.training``, the value from the config
    file takes precedence over the CLI value; all other CLI values are kept.
    """
    for key in arg_cfg.keys():
        if key in cfg.training:
            arg_cfg[key] = cfg.training[key]
    cfg.training = DictConfig(arg_cfg)
    return cfg


def get_args():
    """Parse command-line arguments for training."""
    parser = argparse.ArgumentParser()
parser.add_argument("--fast", action='store_true')
parser.add_argument("-c", "--config", type=str, required=True)
parser.add_argument("-s", "--seed", type=int, default=0)
parser.add_argument("-nn", "--num_nodes", type=int, default=1)
parser.add_argument("-ng", "--num_gpus", type=int, default=1)
parser.add_argument("-u", "--update_every", type=int, default=1)
parser.add_argument("-st", "--steps", type=int, default=50000000)
parser.add_argument("-lr", "--base_lr", type=float, default=4.5e-6)
parser.add_argument("-a", "--use_amp", default=False, action="store_true")
parser.add_argument("--amp_type", type=str, default="16")
parser.add_argument("--gradient_clip_val", type=float, default=None)
parser.add_argument("--gradient_clip_algorithm", type=str, default=None)
parser.add_argument("--every_n_train_steps", type=int, default=50000)
parser.add_argument("--log_every_n_steps", type=int, default=50)
parser.add_argument("--val_check_interval", type=int, default=1024)
parser.add_argument("--limit_val_batches", type=int, default=64)
parser.add_argument("--monitor", type=str, default="val/total_loss")
parser.add_argument("--output_dir", type=str, help="the output directory to save everything.")
parser.add_argument("--ckpt_path", type=str, default="", help="the restore checkpoints.")
parser.add_argument("--deepspeed", default=False, action="store_true")
parser.add_argument("--deepspeed2", default=False, action="store_true")
parser.add_argument("--scale_lr", type=bool, nargs="?", const=True, default=False,
help="scale base-lr by ngpu * batch_size * n_accumulate")
return parser.parse_args()


if __name__ == "__main__":
args = get_args()
    if args.fast:
        # Trade a little numerical precision for speed: enable TF32 matmuls and
        # shorten the dataloader worker status-check interval.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision('medium')
        torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.05
# Set random seed
pl.seed_everything(args.seed, workers=True)
# Load configuration
config = get_config_from_file(args.config)
config = merge_cfg(config, vars(args))
training_cfg = config.training
# print config
rank_zero_info("Begin to print configuration ...")
rank_zero_info(OmegaConf.to_yaml(config))
rank_zero_info("Finish print ...")
# Setup callbacks
callbacks, loggers = setup_callbacks(config)
# Build data modules
data: pl.LightningDataModule = instantiate_from_config(config.dataset)
# Build model
model: pl.LightningModule = instantiate_from_config(config.model)
nodes = args.num_nodes
ngpus = args.num_gpus
base_lr = training_cfg.base_lr
accumulate_grad_batches = training_cfg.update_every
batch_size = config.dataset.params.batch_size
if 'NNODES' in os.environ:
nodes = int(os.environ['NNODES'])
training_cfg.num_nodes = nodes
args.num_nodes = nodes
if args.scale_lr:
model.learning_rate = accumulate_grad_batches * nodes * ngpus * batch_size * base_lr
info = f"Setting learning rate to {model.learning_rate:.2e} = {accumulate_grad_batches} (accumulate)"
info += f" * {nodes} (nodes) * {ngpus} (num_gpus) * {batch_size} (batchsize) * {base_lr:.2e} (base_lr)"
rank_zero_info(info)
else:
model.learning_rate = base_lr
rank_zero_info("++++ NOT USING LR SCALING ++++")
rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
# Build trainer
if args.num_nodes > 1 or args.num_gpus > 1:
if args.deepspeed:
ddp_strategy = DeepSpeedStrategy(stage=1)
elif args.deepspeed2:
ddp_strategy = 'deepspeed_stage_2'
else:
ddp_strategy = DDPStrategy(find_unused_parameters=False, bucket_cap_mb=1500)
else:
ddp_strategy = None # 'auto'
    rank_zero_info('*' * 100)
    if training_cfg.use_amp:
        amp_type = training_cfg.amp_type
        assert amp_type in ['bf16', '16', '32'], f"Invalid amp_type: {amp_type}"
        rank_zero_info(f'Using {amp_type} precision')
    else:
        amp_type = 32
        rank_zero_info('Using 32-bit precision')
    rank_zero_info('*' * 100)
trainer = pl.Trainer(
max_steps=training_cfg.steps,
precision=amp_type,
callbacks=callbacks,
accelerator="gpu",
devices=training_cfg.num_gpus,
num_nodes=training_cfg.num_nodes,
strategy=ddp_strategy,
gradient_clip_val=training_cfg.get('gradient_clip_val'),
gradient_clip_algorithm=training_cfg.get('gradient_clip_algorithm'),
        accumulate_grad_batches=accumulate_grad_batches,  # same value used for LR scaling above
logger=loggers,
log_every_n_steps=training_cfg.log_every_n_steps,
val_check_interval=training_cfg.val_check_interval,
limit_val_batches=training_cfg.limit_val_batches,
check_val_every_n_epoch=None
)
# Train
if training_cfg.ckpt_path == '':
training_cfg.ckpt_path = None
trainer.fit(model, datamodule=data, ckpt_path=training_cfg.ckpt_path)
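

# Example invocation (script and config names are illustrative; the flags match
# the argparse options defined above):
#
#   python train.py --config configs/shape.yaml --num_gpus 8 \
#       --use_amp --amp_type bf16 --output_dir output/exp01 --scale_lr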