import os
import sys
import glob
import json
import torch
import datetime

from collections import deque
from random import randint, shuffle
from time import time as ttime
from tqdm import tqdm
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torch.nn import functional as F

import torch.distributed as dist
import torch.multiprocessing as mp

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

# Zluda hijack
import rvc.lib.zluda

from utils import (
    HParams,
    plot_spectrogram_to_numpy,
    summarize,
    load_checkpoint,
    save_checkpoint,
    latest_checkpoint_path,
    load_wav_to_torch,
)

from losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)

from mel_processing import (
    mel_spectrogram_torch,
    spec_to_mel_torch,
    MultiScaleMelSpectrogramLoss,
)

from rvc.train.process.extract_model import extract_model

from rvc.lib.algorithm import commons


def strtobool(val):
    """
    Convert a string representation of truth to 1 (true) or 0 (false).

    Drop-in replacement for ``distutils.util.strtobool``; the distutils
    package was removed from the standard library in Python 3.12, so the
    original import made this script fail to start there.

    True values are 'y', 'yes', 't', 'true', 'on' and '1'; false values are
    'n', 'no', 'f', 'false', 'off' and '0' (case-insensitive).

    Raises:
        ValueError: if `val` is anything else.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError("invalid truth value %r" % (val,))


# Parse command line arguments (positional; see the order below)
model_name = sys.argv[1]
save_every_epoch = int(sys.argv[2])
total_epoch = int(sys.argv[3])
pretrainG = sys.argv[4]
pretrainD = sys.argv[5]
gpus = sys.argv[6]  # GPU ids joined with "-", e.g. "0-1"; parsed in main()
batch_size = int(sys.argv[7])
sample_rate = int(sys.argv[8])
save_only_latest = strtobool(sys.argv[9])
save_every_weights = strtobool(sys.argv[10])
cache_data_in_gpu = strtobool(sys.argv[11])
overtraining_detector = strtobool(sys.argv[12])
overtraining_threshold = int(sys.argv[13])
cleanup = strtobool(sys.argv[14])
vocoder = sys.argv[15]
checkpointing = strtobool(sys.argv[16])
randomized = True
optimizer = "RAdam"  # "AdamW"

current_dir = os.getcwd()
experiment_dir = os.path.join(current_dir, "logs", model_name)
config_save_path = os.path.join(experiment_dir, "config.json")
dataset_path = os.path.join(experiment_dir, "sliced_audios")

with open(config_save_path, "r") as f:
    config = json.load(f)
config = HParams(**config)
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Module-level training state.
# NOTE: workers are started with the "spawn" start method (see the __main__
# guard), so every worker re-imports this module and gets fresh copies of
# these values; anything the parent process assigns to them after import is
# NOT visible inside the workers.
global_step = 0
last_loss_gen_all = 0
overtrain_save_epoch = 0
loss_gen_history = []
smoothed_loss_gen_history = []
loss_disc_history = []
smoothed_loss_disc_history = []
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
# Consecutive "overtraining" epoch counters. train_and_evaluate only resets
# them on epoch 1, so they must exist here for runs resumed at epoch > 1
# (otherwise the first `+= 1` raises NameError).
consecutive_increases_gen = 0
consecutive_increases_disc = 0
training_file_path = os.path.join(experiment_dir, "training_data.json")

# Rolling loss windows: per-epoch averages (last 10 epochs) and per-step
# values (last 50 steps) used for TensorBoard scalars.
avg_losses = {
    "gen_loss_queue": deque(maxlen=10),
    "disc_loss_queue": deque(maxlen=10),
    "disc_loss_50": deque(maxlen=50),
    "fm_loss_50": deque(maxlen=50),
    "kl_loss_50": deque(maxlen=50),
    "mel_loss_50": deque(maxlen=50),
    "gen_loss_50": deque(maxlen=50),
}

import logging

logging.getLogger("torch").setLevel(logging.ERROR)


class EpochRecorder:
    """
    Records the time elapsed per epoch.
    """

    def __init__(self):
        self.last_time = ttime()

    def record(self):
        """
        Records the elapsed time and returns a formatted string.

        The elapsed time is measured since construction or the previous call
        and is rendered (whole seconds) as an H:MM:SS timedelta.
        """
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time = round(elapsed_time, 1)
        elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
        current_time = datetime.datetime.now().strftime("%H:%M:%S")
        return f"time={current_time} | training_speed={elapsed_time_str}"


def verify_checkpoint_shapes(checkpoint_path, model):
    """
    Check that a pretrained checkpoint can be loaded into `model`.

    Loads ``checkpoint["model"]`` into `model` (or into ``model.module`` when
    the model is wrapped, e.g. by DDP). Note this loads the weights as a side
    effect. If ``load_state_dict`` raises RuntimeError (e.g. mismatched keys
    or tensor shapes), an explanation is printed and the process exits with
    status 1.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    checkpoint_state_dict = checkpoint["model"]
    try:
        if hasattr(model, "module"):
            model_state_dict = model.module.load_state_dict(checkpoint_state_dict)
        else:
            model_state_dict = model.load_state_dict(checkpoint_state_dict)
    except RuntimeError:
        print(
            "The parameters of the pretrain model such as the sample rate or architecture do not match the selected model."
        )
        sys.exit(1)
    else:
        del checkpoint
        del checkpoint_state_dict
        del model_state_dict


def load_from_json(file_path):
    """
    Load the overtraining-detector history from a JSON file.

    Args:
        file_path (str): The path to the JSON file.

    Returns:
        tuple: (loss_disc_history, smoothed_loss_disc_history,
        loss_gen_history, smoothed_loss_gen_history); each is an empty list
        when the file or the key is missing.
    """
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            data = json.load(f)
            return (
                data.get("loss_disc_history", []),
                data.get("smoothed_loss_disc_history", []),
                data.get("loss_gen_history", []),
                data.get("smoothed_loss_gen_history", []),
            )
    return [], [], [], []


def continue_overtrain_detector(training_file_path):
    """
    Restore the overtraining-detector history from a JSON file.

    The module-level history lists are filled IN PLACE (slice assignment), so
    every function that appends to them sees the restored data. Must be
    called in the process that runs train_and_evaluate (a spawned worker),
    not in the parent process. Does nothing when the detector is disabled or
    the file does not exist.

    Args:
        training_file_path (str): Path of the JSON file written by save_to_json.
    """
    if not overtraining_detector or not os.path.exists(training_file_path):
        return
    (
        disc_history,
        smoothed_disc_history,
        gen_history,
        smoothed_gen_history,
    ) = load_from_json(training_file_path)
    loss_disc_history[:] = disc_history
    smoothed_loss_disc_history[:] = smoothed_disc_history
    loss_gen_history[:] = gen_history
    smoothed_loss_gen_history[:] = smoothed_gen_history


def main():
    """
    Main function to start the training process.

    Validates the dataset sample rate, selects the device (CUDA, MPS or CPU),
    optionally removes artifacts of a prior attempt, then spawns one training
    worker (`run`) per selected device and waits for them.
    """
    global gpus

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    # Check sample rate
    wavs = glob.glob(os.path.join(dataset_path, "*.wav"))
    if wavs:
        _, sr = load_wav_to_torch(wavs[0])
        if sr != sample_rate:
            print(
                f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz)."
            )
            os._exit(1)
    else:
        print("No wav file found.")

    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpus = [int(item) for item in gpus.split("-")]
        n_gpus = len(gpus)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        gpus = [0]
        n_gpus = 1
    else:
        device = torch.device("cpu")
        gpus = [0]
        n_gpus = 1
        print("Training with CPU, this will take a long time.")

    def start():
        """
        Starts the training process with multi-GPU support or CPU.

        Spawns one `run` worker per entry of `gpus`, records their PIDs under
        "process_pids" in config.json and waits for all of them to finish.
        """
        children = []
        pid_data = {"process_pids": []}
        with open(config_save_path, "r") as pid_file:
            try:
                existing_data = json.load(pid_file)
                pid_data.update(existing_data)
            except json.JSONDecodeError:
                pass
        # Forget PIDs left behind by a previous (e.g. crashed) run; only the
        # workers started below belong to this training session.
        pid_data["process_pids"] = []

        for rank, device_id in enumerate(gpus):
            subproc = mp.Process(
                target=run,
                args=(
                    rank,
                    n_gpus,
                    experiment_dir,
                    pretrainG,
                    pretrainD,
                    total_epoch,
                    save_every_weights,
                    config,
                    device,
                    device_id,
                ),
            )
            children.append(subproc)
            subproc.start()
            pid_data["process_pids"].append(subproc.pid)

        # Rewrite config.json only after all workers have been started. Each
        # spawned worker re-imports this module, which reads config.json at
        # import time; keeping the file truncated (opened with "w") while the
        # workers were starting risked handing them an empty file.
        with open(config_save_path, "w") as pid_file:
            json.dump(pid_data, pid_file, indent=4)

        for child in children:
            child.join()

    if cleanup:
        print("Removing files from the prior training attempt...")

        # Clean up unnecessary files
        for root, dirs, files in os.walk(
            os.path.join(now_dir, "logs", model_name), topdown=False
        ):
            for name in files:
                file_path = os.path.join(root, name)
                file_name, file_extension = os.path.splitext(name)
                if (
                    file_extension == ".0"
                    or (file_name.startswith("D_") and file_extension == ".pth")
                    or (file_name.startswith("G_") and file_extension == ".pth")
                    or (file_name.startswith("added") and file_extension == ".index")
                ):
                    os.remove(file_path)
            for name in dirs:
                if name == "eval":
                    folder_path = os.path.join(root, name)
                    for item in os.listdir(folder_path):
                        item_path = os.path.join(folder_path, item)
                        if os.path.isfile(item_path):
                            os.remove(item_path)
                    os.rmdir(folder_path)

        print("Cleanup done!")

    # The overtraining history is restored inside the worker process (see
    # run); loading it here would only populate the parent's copies.
    start()


def run(
    rank,
    n_gpus,
    experiment_dir,
    pretrainG,
    pretrainD,
    custom_total_epoch,
    custom_save_every_weights,
    config,
    device,
    device_id,
):
    """
    Runs the training loop on a specific GPU or CPU.

    Args:
        rank (int): The rank of the current process within the distributed training setup.
        n_gpus (int): The total number of GPUs available for training.
        experiment_dir (str): The directory where experiment logs and checkpoints will be saved.
        pretrainG (str): Path to the pre-trained generator model ("" or "None" to skip).
        pretrainD (str): Path to the pre-trained discriminator model ("" or "None" to skip).
        custom_total_epoch (int): The total number of epochs for training.
        custom_save_every_weights (int): Truthy to also extract a standalone model file
            every time G/D checkpoints are saved.
        config (object): Configuration object containing training parameters.
        device (torch.device): The device to use for training (CPU, MPS or CUDA).
        device_id (int): The CUDA device index used by this worker.
    """
    global global_step, smoothed_value_gen, smoothed_value_disc, optimizer

    smoothed_value_gen = 0
    smoothed_value_disc = 0

    if rank == 0:
        writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
    else:
        writer_eval = None

    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        world_size=n_gpus if device.type == "cuda" else 1,
        rank=rank if device.type == "cuda" else 0,
    )

    torch.manual_seed(config.train.seed)

    if torch.cuda.is_available():
        torch.cuda.set_device(device_id)

    # Create datasets and dataloaders
    from data_utils import (
        DistributedBucketSampler,
        TextAudioCollateMultiNSFsid,
        TextAudioLoaderMultiNSFsid,
    )

    train_dataset = TextAudioLoaderMultiNSFsid(config.data)
    collate_fn = TextAudioCollateMultiNSFsid()
    train_sampler = DistributedBucketSampler(
        train_dataset,
        batch_size * n_gpus,
        [50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )

    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )

    # Validations
    if len(train_loader) < 3:
        print(
            "Not enough data present in the training set. Perhaps you forgot to slice the audio files in preprocess?"
        )
        os._exit(2333333)
    else:
        g_file = latest_checkpoint_path(experiment_dir, "G_*.pth")
        if g_file is not None:
            print("Checking saved weights...")
            g = torch.load(g_file, map_location="cpu")
            # The saved optimizer's param_group keys are used to tell which
            # optimizer wrote the checkpoint, so that resuming uses the same one.
            if (
                optimizer == "RAdam"
                and "amsgrad" in g["optimizer"]["param_groups"][0].keys()
            ):
                optimizer = "AdamW"
                print(
                    f"Optimizer choice has been reverted to {optimizer} to match the saved D/G weights."
                )
            elif (
                optimizer == "AdamW"
                and "decoupled_weight_decay" in g["optimizer"]["param_groups"][0].keys()
            ):
                optimizer = "RAdam"
                print(
                    f"Optimizer choice has been reverted to {optimizer} to match the saved D/G weights."
                )
            del g

    # Initialize models and optimizers
    from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator
    from rvc.lib.algorithm.synthesizers import Synthesizer

    net_g = Synthesizer(
        config.data.filter_length // 2 + 1,
        config.train.segment_size // config.data.hop_length,
        **config.model,
        use_f0=True,
        sr=sample_rate,
        vocoder=vocoder,
        checkpointing=checkpointing,
        randomized=randomized,
    )

    net_d = MultiPeriodDiscriminator(
        config.model.use_spectral_norm, checkpointing=checkpointing
    )

    if torch.cuda.is_available():
        net_g = net_g.cuda(device_id)
        net_d = net_d.cuda(device_id)
    else:
        net_g.to(device)
        net_d.to(device)

    # `optimizer` is a name string until here; swap it for the class.
    if optimizer == "AdamW":
        optimizer = torch.optim.AdamW
    elif optimizer == "RAdam":
        optimizer = torch.optim.RAdam

    optim_g = optimizer(
        net_g.parameters(),
        config.train.learning_rate,
        betas=config.train.betas,
        eps=config.train.eps,
    )
    optim_d = optimizer(
        net_d.parameters(),
        config.train.learning_rate,
        betas=config.train.betas,
        eps=config.train.eps,
    )

    fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)

    # Wrap models with DDP for multi-gpu processing
    if n_gpus > 1 and device.type == "cuda":
        net_g = DDP(net_g, device_ids=[device_id])
        net_d = DDP(net_d, device_ids=[device_id])

    # Load checkpoint if available; otherwise start from epoch 1, optionally
    # from the pretrained G/D weights.
    try:
        print("Starting training...")
        _, _, _, epoch_str = load_checkpoint(
            latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d
        )
        _, _, _, epoch_str = load_checkpoint(
            latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g
        )
        epoch_str += 1
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception:
        epoch_str = 1
        global_step = 0
        if pretrainG != "" and pretrainG != "None":
            if rank == 0:
                verify_checkpoint_shapes(pretrainG, net_g)
                print(f"Loaded pretrained (G) '{pretrainG}'")
            if hasattr(net_g, "module"):
                net_g.module.load_state_dict(
                    torch.load(pretrainG, map_location="cpu")["model"]
                )
            else:
                net_g.load_state_dict(
                    torch.load(pretrainG, map_location="cpu")["model"]
                )

        if pretrainD != "" and pretrainD != "None":
            if rank == 0:
                print(f"Loaded pretrained (D) '{pretrainD}'")
            if hasattr(net_d, "module"):
                net_d.module.load_state_dict(
                    torch.load(pretrainD, map_location="cpu")["model"]
                )
            else:
                net_d.load_state_dict(
                    torch.load(pretrainD, map_location="cpu")["model"]
                )

    # When resuming, restore the overtraining history in THIS process (only
    # rank 0 runs the detector); otherwise the detector would restart from an
    # empty history and save_to_json would overwrite the saved one.
    if rank == 0 and epoch_str > 1:
        continue_overtrain_detector(training_file_path)

    # Initialize schedulers
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2
    )

    cache = []
    # get the first sample as reference for tensorboard evaluation
    # custom reference temporarily disabled
    use_custom_reference = False
    if use_custom_reference and os.path.isfile(
        os.path.join("logs", "reference", f"ref{sample_rate}.wav")
    ):
        phone = np.load(
            os.path.join("logs", "reference", f"ref{sample_rate}_feats.npy")
        )
        # expanding x2 to match pitch size
        phone = np.repeat(phone, 2, axis=0)
        phone = torch.FloatTensor(phone).unsqueeze(0).to(device)
        # Sequence length is dim 1 after unsqueeze(0) (assumes the feats file
        # is (frames, channels) — TODO confirm). torch.LongTensor(n) with an
        # int would allocate an uninitialised tensor of size n instead.
        phone_lengths = torch.LongTensor([phone.size(1)]).to(device)
        pitch = np.load(os.path.join("logs", "reference", f"ref{sample_rate}_f0c.npy"))
        # removed last frame to match features
        pitch = torch.LongTensor(pitch[:-1]).unsqueeze(0).to(device)
        pitchf = np.load(os.path.join("logs", "reference", f"ref{sample_rate}_f0f.npy"))
        # removed last frame to match features
        pitchf = torch.FloatTensor(pitchf[:-1]).unsqueeze(0).to(device)
        sid = torch.LongTensor([0]).to(device)
        reference = (
            phone,
            phone_lengths,
            pitch,
            pitchf,
            sid,
        )
    else:
        for info in train_loader:
            phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
            if device.type == "cuda":
                reference = (
                    phone.cuda(device_id, non_blocking=True),
                    phone_lengths.cuda(device_id, non_blocking=True),
                    pitch.cuda(device_id, non_blocking=True),
                    pitchf.cuda(device_id, non_blocking=True),
                    sid.cuda(device_id, non_blocking=True),
                )
            else:
                reference = (
                    phone.to(device),
                    phone_lengths.to(device),
                    pitch.to(device),
                    pitchf.to(device),
                    sid.to(device),
                )
            break

    for epoch in range(epoch_str, total_epoch + 1):
        train_and_evaluate(
            rank,
            epoch,
            config,
            [net_g, net_d],
            [optim_g, optim_d],
            [train_loader, None],
            [writer_eval],
            cache,
            custom_save_every_weights,
            custom_total_epoch,
            device,
            device_id,
            reference,
            fn_mel_loss,
        )

        scheduler_g.step()
        scheduler_d.step()


def train_and_evaluate(
    rank,
    epoch,
    hps,
    nets,
    optims,
    loaders,
    writers,
    cache,
    custom_save_every_weights,
    custom_total_epoch,
    device,
    device_id,
    reference,
    fn_mel_loss,
):
    """
    Trains and evaluates the model for one epoch.

    Args:
        rank (int): Rank of the current process; only rank 0 logs and saves.
        epoch (int): Current epoch number.
        hps (Namespace): Hyperparameters (passed on to extract_model).
        nets (list): List of models [net_g, net_d].
        optims (list): List of optimizers [optim_g, optim_d].
        loaders (list): List of dataloaders [train_loader, eval_loader].
        writers (list): List of TensorBoard writers [writer_eval].
        cache (list): List used to cache (batch_idx, batch) pairs in GPU memory.
        custom_save_every_weights (int): Truthy to extract a standalone model
            file whenever checkpoints are saved.
        custom_total_epoch (int): Epoch at which training is finished.
        device (torch.device): Training device.
        device_id (int): CUDA device index for this worker.
        reference (tuple): (phone, phone_lengths, pitch, pitchf, sid) inputs
            used to render a TensorBoard audio sample.
        fn_mel_loss (callable): Mel-spectrogram loss function.
    """
    global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc, smoothed_value_gen, smoothed_value_disc

    if epoch == 1:
        lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
        consecutive_increases_gen = 0
        consecutive_increases_disc = 0

    epoch_disc_sum = 0.0
    epoch_gen_sum = 0.0

    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader = loaders[0] if loaders is not None else None
    writer = writers[0] if writers is not None else None

    train_loader.batch_sampler.set_epoch(epoch)

    net_g.train()
    net_d.train()

    # Data caching: on the first epoch the whole loader is moved to the GPU
    # once; later epochs reshuffle the cached batches instead of reloading.
    if device.type == "cuda" and cache_data_in_gpu:
        data_iterator = cache
        if cache == []:
            for batch_idx, info in enumerate(train_loader):
                # phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
                cache.append((batch_idx, info))
        else:
            shuffle(cache)
    else:
        data_iterator = enumerate(train_loader)

    epoch_recorder = EpochRecorder()
    with tqdm(total=len(train_loader), leave=False) as pbar:
        for batch_idx, info in data_iterator:
            if device.type == "cuda" and not cache_data_in_gpu:
                info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
            elif device.type != "cuda":
                info = [tensor.to(device) for tensor in info]
            # else iterator is going thru a cached list with a device already assigned

            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info

            # Forward pass
            model_output = net_g(
                phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
            )
            y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = (
                model_output
            )
            # slice of the original waveform to match a generate slice
            if randomized:
                wave = commons.slice_segments(
                    wave,
                    ids_slice * config.data.hop_length,
                    config.train.segment_size,
                    dim=3,
                )
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
            # Discriminator backward and update
            epoch_disc_sum += loss_disc.item()
            optim_d.zero_grad()
            loss_disc.backward()
            grad_norm_d = torch.nn.utils.clip_grad_norm_(
                net_d.parameters(), max_norm=1000.0
            )
            optim_d.step()

            # Generator backward and update
            _, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)

            loss_mel = fn_mel_loss(wave, y_hat) * config.train.c_mel / 3.0
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _ = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

            # Store a plain float: keeping the loss tensor itself would retain
            # its autograd graph and device memory until a lower loss appears.
            loss_gen_all_value = loss_gen_all.item()
            if loss_gen_all_value < lowest_value["value"]:
                lowest_value = {
                    "step": global_step,
                    "value": loss_gen_all_value,
                    "epoch": epoch,
                }
            epoch_gen_sum += loss_gen_all_value
            optim_g.zero_grad()
            loss_gen_all.backward()
            grad_norm_g = torch.nn.utils.clip_grad_norm_(
                net_g.parameters(), max_norm=1000.0
            )
            optim_g.step()

            global_step += 1

            # queue for rolling losses over 50 steps
            avg_losses["disc_loss_50"].append(loss_disc.detach())
            avg_losses["fm_loss_50"].append(loss_fm.detach())
            avg_losses["kl_loss_50"].append(loss_kl.detach())
            avg_losses["mel_loss_50"].append(loss_mel.detach())
            avg_losses["gen_loss_50"].append(loss_gen_all.detach())

            if rank == 0 and global_step % 50 == 0:
                # logging rolling averages
                scalar_dict = {
                    "loss_avg_50/d/total": torch.mean(
                        torch.stack(list(avg_losses["disc_loss_50"]))
                    ),
                    "loss_avg_50/g/fm": torch.mean(
                        torch.stack(list(avg_losses["fm_loss_50"]))
                    ),
                    "loss_avg_50/g/kl": torch.mean(
                        torch.stack(list(avg_losses["kl_loss_50"]))
                    ),
                    "loss_avg_50/g/mel": torch.mean(
                        torch.stack(list(avg_losses["mel_loss_50"]))
                    ),
                    "loss_avg_50/g/total": torch.mean(
                        torch.stack(list(avg_losses["gen_loss_50"]))
                    ),
                }
                summarize(
                    writer=writer,
                    global_step=global_step,
                    scalars=scalar_dict,
                )

            pbar.update(1)
        # end of batch train
    # end of tqdm
    with torch.no_grad():
        torch.cuda.empty_cache()

    # Logging and checkpointing (tensors below come from the last batch)
    if rank == 0:

        avg_losses["disc_loss_queue"].append(epoch_disc_sum / len(train_loader))
        avg_losses["gen_loss_queue"].append(epoch_gen_sum / len(train_loader))

        # used for tensorboard chart - all/mel
        mel = spec_to_mel_torch(
            spec,
            config.data.filter_length,
            config.data.n_mel_channels,
            config.data.sample_rate,
            config.data.mel_fmin,
            config.data.mel_fmax,
        )
        # used for tensorboard chart - slice/mel_org
        if randomized:
            y_mel = commons.slice_segments(
                mel,
                ids_slice,
                config.train.segment_size // config.data.hop_length,
                dim=3,
            )
        else:
            y_mel = mel
        # used for tensorboard chart - slice/mel_gen
        y_hat_mel = mel_spectrogram_torch(
            y_hat.float().squeeze(1),
            config.data.filter_length,
            config.data.n_mel_channels,
            config.data.sample_rate,
            config.data.hop_length,
            config.data.win_length,
            config.data.mel_fmin,
            config.data.mel_fmax,
        )

        lr = optim_g.param_groups[0]["lr"]

        scalar_dict = {
            "loss/g/total": loss_gen_all,
            "loss/d/total": loss_disc,
            "learning_rate": lr,
            "grad/norm_d": grad_norm_d.item(),
            "grad/norm_g": grad_norm_g.item(),
            "loss/g/fm": loss_fm,
            "loss/g/mel": loss_mel,
            "loss/g/kl": loss_kl,
            "loss_avg_epoch/disc": np.mean(avg_losses["disc_loss_queue"]),
            "loss_avg_epoch/gen": np.mean(avg_losses["gen_loss_queue"]),
        }

        image_dict = {
            "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
            "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
            "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
        }

        if epoch % save_every_epoch == 0:
            with torch.no_grad():
                if hasattr(net_g, "module"):
                    o, *_ = net_g.module.infer(*reference)
                else:
                    o, *_ = net_g.infer(*reference)
            audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}
            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
                scalars=scalar_dict,
                audios=audio_dict,
                audio_sample_rate=config.data.sample_rate,
            )
        else:
            summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
                scalars=scalar_dict,
            )

    # Save checkpoint
    model_add = []  # model files to extract
    model_del = []  # stale "best epoch" files to delete
    done = False

    if rank == 0:
        overtrain_info = ""
        # Check overtraining
        if overtraining_detector and rank == 0 and epoch > 1:
            # Add the current loss to the history
            current_loss_disc = float(loss_disc)
            loss_disc_history.append(current_loss_disc)
            # Update smoothed loss history with loss_disc
            smoothed_value_disc = update_exponential_moving_average(
                smoothed_loss_disc_history, current_loss_disc
            )
            # Check overtraining with smoothed loss_disc
            is_overtraining_disc = check_overtraining(
                smoothed_loss_disc_history, overtraining_threshold * 2
            )
            if is_overtraining_disc:
                consecutive_increases_disc += 1
            else:
                consecutive_increases_disc = 0
            # Add the current loss_gen to the history
            current_loss_gen = float(lowest_value["value"])
            loss_gen_history.append(current_loss_gen)
            # Update the smoothed loss_gen history
            smoothed_value_gen = update_exponential_moving_average(
                smoothed_loss_gen_history, current_loss_gen
            )
            # Check for overtraining with the smoothed loss_gen
            is_overtraining_gen = check_overtraining(
                smoothed_loss_gen_history, overtraining_threshold, 0.01
            )
            if is_overtraining_gen:
                consecutive_increases_gen += 1
            else:
                consecutive_increases_gen = 0
            overtrain_info = f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
            # Save the data in the JSON file if the epoch is divisible by save_every_epoch
            if epoch % save_every_epoch == 0:
                save_to_json(
                    training_file_path,
                    loss_disc_history,
                    smoothed_loss_disc_history,
                    loss_gen_history,
                    smoothed_loss_gen_history,
                )

            if (
                is_overtraining_gen
                and consecutive_increases_gen == overtraining_threshold
                or is_overtraining_disc
                and consecutive_increases_disc == overtraining_threshold * 2
            ):
                print(
                    f"Overtraining detected at epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
                )
                done = True
            else:
                print(
                    f"New best epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
                )
                old_model_files = glob.glob(
                    os.path.join(experiment_dir, f"{model_name}_*e_*s_best_epoch.pth")
                )
                for file in old_model_files:
                    model_del.append(file)
                model_add.append(
                    os.path.join(
                        experiment_dir,
                        f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth",
                    )
                )

        # Print training progress
        lowest_value_rounded = float(lowest_value["value"])
        lowest_value_rounded = round(lowest_value_rounded, 3)

        record = f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()}"
        if epoch > 1:
            record = (
                record
                + f" | lowest_value={lowest_value_rounded} (epoch {lowest_value['epoch']} and step {lowest_value['step']})"
            )

        if overtraining_detector:
            remaining_epochs_gen = overtraining_threshold - consecutive_increases_gen
            remaining_epochs_disc = (
                overtraining_threshold * 2 - consecutive_increases_disc
            )
            record = (
                record
                + f" | Number of epochs remaining for overtraining: g/total: {remaining_epochs_gen} d/total: {remaining_epochs_disc} | smoothed_loss_gen={smoothed_value_gen:.3f} | smoothed_loss_disc={smoothed_value_disc:.3f}"
            )
        print(record)

        # Save weights every N epochs
        if epoch % save_every_epoch == 0:
            checkpoint_suffix = f"{2333333 if save_only_latest else global_step}.pth"
            save_checkpoint(
                net_g,
                optim_g,
                config.train.learning_rate,
                epoch,
                os.path.join(experiment_dir, "G_" + checkpoint_suffix),
            )
            save_checkpoint(
                net_d,
                optim_d,
                config.train.learning_rate,
                epoch,
                os.path.join(experiment_dir, "D_" + checkpoint_suffix),
            )
            if custom_save_every_weights:
                model_add.append(
                    os.path.join(
                        experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"
                    )
                )

        # Check completion. This must run BEFORE the extraction loop below:
        # the final model is only written if it is already in `model_add`
        # when that loop runs, and the process exits right afterwards.
        if epoch >= custom_total_epoch:
            lowest_value_rounded = float(lowest_value["value"])
            lowest_value_rounded = round(lowest_value_rounded, 3)
            print(
                f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen."
            )
            print(
                f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}"
            )

            pid_file_path = os.path.join(experiment_dir, "config.json")
            with open(pid_file_path, "r") as pid_file:
                pid_data = json.load(pid_file)
            with open(pid_file_path, "w") as pid_file:
                pid_data.pop("process_pids", None)
                json.dump(pid_data, pid_file, indent=4)
            # Final model (skip if the periodic save above already queued it)
            final_model_path = os.path.join(
                experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"
            )
            if final_model_path not in model_add:
                model_add.append(final_model_path)
            done = True

        # Clean-up old best epochs
        for m in model_del:
            os.remove(m)

        if model_add:
            ckpt = (
                net_g.module.state_dict()
                if hasattr(net_g, "module")
                else net_g.state_dict()
            )
            for m in model_add:
                if os.path.exists(m):
                    print(f"{m} already exists. Overwriting.")
                extract_model(
                    ckpt=ckpt,
                    sr=sample_rate,
                    name=model_name,
                    model_path=m,
                    epoch=epoch,
                    step=global_step,
                    hps=hps,
                    overtrain_info=overtrain_info,
                    vocoder=vocoder,
                )

        if done:
            os._exit(2333333)

    with torch.no_grad():
        torch.cuda.empty_cache()


def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
    """
    Checks for overtraining based on the smoothed loss history.

    Walks the consecutive pairs among the last `threshold` values, oldest to
    newest: returns True at the first increase, returns False at the first
    change whose magnitude is >= epsilon (a significant decrease), and
    returns True if every change was smaller than epsilon (a plateau).
    With threshold <= 1 there are no pairs, so it returns True as soon as
    enough history exists.

    Args:
        smoothed_loss_history (list): List of smoothed losses for each epoch.
        threshold (int): Number of consecutive epochs with insignificant changes or increases to consider overtraining.
        epsilon (float): The maximum change considered insignificant.

    Returns:
        bool: False when fewer than threshold + 1 values are available,
        otherwise the verdict described above.
    """
    if len(smoothed_loss_history) < threshold + 1:
        return False

    for i in range(-threshold, -1):
        if smoothed_loss_history[i + 1] > smoothed_loss_history[i]:
            return True
        if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon:
            return False
    return True


def update_exponential_moving_average(
    smoothed_loss_history, new_value, smoothing=0.987
):
    """
    Updates the exponential moving average with a new value.

    The new smoothed value is ``smoothing * last + (1 - smoothing) * new_value``
    (or `new_value` itself when the history is empty); it is appended to
    `smoothed_loss_history` in place and returned.

    Args:
        smoothed_loss_history (list): List of smoothed values.
        new_value (float): New value to be added.
        smoothing (float): Smoothing factor.
    """
    if smoothed_loss_history:
        smoothed_value = (
            smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value
        )
    else:
        smoothed_value = new_value
    smoothed_loss_history.append(smoothed_value)
    return smoothed_value


def save_to_json(
    file_path,
    loss_disc_history,
    smoothed_loss_disc_history,
    loss_gen_history,
    smoothed_loss_gen_history,
):
    """
    Save the training history to a JSON file (overwriting it).

    The keys match those read back by load_from_json.
    """
    data = {
        "loss_disc_history": loss_disc_history,
        "smoothed_loss_disc_history": smoothed_loss_disc_history,
        "loss_gen_history": loss_gen_history,
        "smoothed_loss_gen_history": smoothed_loss_gen_history,
    }
    with open(file_path, "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()