from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os
import torch
import pytorch_lightning as pl
import random

from stable_audio_tools.data.dataset import create_dataloader_from_config
from stable_audio_tools.models import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config
from stable_audio_tools.training.utils import copy_state_dict
class ExceptionCallback(pl.Callback):
    def on_exception(self, trainer, module, err):
        print(f'{type(err).__name__}: {err}')
class ModelConfigEmbedderCallback(pl.Callback):
    def __init__(self, model_config):
        self.model_config = model_config

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint["model_config"] = self.model_config
def main():

    args = get_all_args()

    seed = args.seed

    # Set a different seed for each process if using SLURM
    if os.environ.get("SLURM_PROCID") is not None:
        seed += int(os.environ.get("SLURM_PROCID"))

    random.seed(seed)
    torch.manual_seed(seed)

    # Load the JSON configs from args.model_config and args.dataset_config
    with open(args.model_config) as f:
        model_config = json.load(f)

    with open(args.dataset_config) as f:
        dataset_config = json.load(f)

    train_dl = create_dataloader_from_config(
        dataset_config,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        sample_rate=model_config["sample_rate"],
        sample_size=model_config["sample_size"],
        audio_channels=model_config.get("audio_channels", 2),
    )
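    # Illustrative dataset_config shape (an assumption based on the public
    # stable-audio-tools documentation; check docs/datasets.md in your checkout
    # for the exact field names):
    #
    #   {
    #       "dataset_type": "audio_dir",
    #       "datasets": [{"id": "my_audio", "path": "/path/to/audio/"}],
    #       "random_crop": true
    #   }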
    model = create_model_from_config(model_config)

    print(model)

    if args.pretrained_ckpt_path:
        copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path))

    if args.remove_pretransform_weight_norm == "pre_load":
        remove_weight_norm_from_model(model.pretransform)

    if args.pretransform_ckpt_path:
        model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path))

    # Remove weight_norm from the pretransform if specified
    if args.remove_pretransform_weight_norm == "post_load":
        remove_weight_norm_from_model(model.pretransform)
print("creating training wrapper") | |
training_wrapper = create_training_wrapper_from_config(model_config, model) | |
wandb_logger = pl.loggers.WandbLogger(project=args.name) | |
wandb_logger.watch(training_wrapper) | |
exc_callback = ExceptionCallback() | |
# if args.save_dir and isinstance(wandb_logger.experiment.id, str): | |
# checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints") | |
# else: | |
# checkpoint_dir = None | |
checkpoint_dir = args.save_dir | |
ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1) | |
save_model_config_callback = ModelConfigEmbedderCallback(model_config) | |
demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl) | |
#Combine args and config dicts | |
args_dict = vars(args) | |
args_dict.update({"model_config": model_config}) | |
args_dict.update({"dataset_config": dataset_config}) | |
push_wandb_config(wandb_logger, args_dict) | |
    # Set multi-GPU strategy if specified
    if args.strategy:
        if args.strategy == "deepspeed":
            from pytorch_lightning.strategies import DeepSpeedStrategy
            strategy = DeepSpeedStrategy(stage=2,
                                         contiguous_gradients=True,
                                         overlap_comm=True,
                                         reduce_scatter=True,
                                         reduce_bucket_size=5e8,
                                         allgather_bucket_size=5e8,
                                         load_full_weights=True
                                         )
        else:
            strategy = args.strategy
    else:
        strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"
    trainer = pl.Trainer(
        devices=args.num_gpus,
        accelerator="gpu",
        num_nodes=args.num_nodes,
        strategy=strategy,
        precision=args.precision,
        accumulate_grad_batches=args.accum_batches,
        callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback],
        logger=wandb_logger,
        log_every_n_steps=1,
        max_epochs=12,
        default_root_dir=args.save_dir,
        gradient_clip_val=args.gradient_clip_val,
        reload_dataloaders_every_n_epochs=0
    )

    trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)
if __name__ == '__main__':
    main()
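# Example invocation (a sketch: the exact flag names come from prefigure's
# defaults.ini in the stable-audio-tools repo and may differ in this Space):
#
#   python train.py \
#       --dataset-config dataset.json \
#       --model-config model_config.json \
#       --name my_training_run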