import torch
import torchaudio
import wandb
from einops import rearrange
from safetensors.torch import save_file, save_model
from ema_pytorch import EMA
from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss
import pytorch_lightning as pl
from ..models.autoencoders import AudioAutoencoder
from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
from .utils import create_optimizer_from_config, create_scheduler_from_config

from pytorch_lightning.utilities.rank_zero import rank_zero_only
from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image


class AutoencoderTrainingWrapper(pl.LightningModule):
    """Lightning module that trains an ``AudioAutoencoder`` against a discriminator.

    Uses manual optimization with two optimizers (autoencoder and discriminator).
    After ``warmup_steps``, odd global steps update the discriminator and all
    other steps update the autoencoder. Before warmup only the autoencoder is
    trained and the adversarial terms are zero.

    The generator objective is a weighted sum of a multi-resolution STFT loss,
    adversarial and feature-matching losses, an optional time-domain L1 loss,
    any losses required by the autoencoder's bottleneck and, when a
    ``teacher_model`` is given, distillation losses.
    """

    def __init__(
            self,
            autoencoder: AudioAutoencoder,
            lr: float = 1e-4,
            warmup_steps: int = 0,
            encoder_freeze_on_warmup: bool = False,
            sample_rate=48000,
            loss_config: dict = None,
            optimizer_configs: dict = None,
            use_ema: bool = True,
            ema_copy = None,
            force_input_mono = False,
            latent_mask_ratio = 0.0,
            teacher_model: AudioAutoencoder = None
    ):
        """Build loss modules, the discriminator and (optionally) an EMA copy.

        Args:
            autoencoder: Model being trained. Its ``out_channels`` selects
                stereo (2) vs. non-stereo spectral losses.
            lr: Learning rate used only by the default optimizer configs.
            warmup_steps: Global step at which adversarial training starts.
            encoder_freeze_on_warmup: If True, the encoder runs under
                ``torch.no_grad()`` once warmup is reached.
            sample_rate: Passed to the STFT losses and the DAC discriminator.
            loss_config: Dict with ``'discriminator'``, ``'spectral'`` and
                ``'time'`` sections (each with ``'type'``/``'config'``/
                ``'weights'``) and an optional ``'bottleneck'`` section. A
                default encodec + perceptual MRSTFT setup is used if None.
            optimizer_configs: Dict with ``'autoencoder'`` and
                ``'discriminator'`` entries, each holding an ``'optimizer'``
                config and an optional ``'scheduler'`` config. Defaults to
                AdamW(lr, betas=(.8, .99)) for both.
            use_ema: Whether to keep an EMA copy of the autoencoder.
            ema_copy: Optional pre-built model passed to ``EMA(ema_model=...)``.
            force_input_mono: Average input channels before encoding.
            latent_mask_ratio: Fraction of latent entries randomly zeroed
                before decoding during training.
            teacher_model: Optional autoencoder to distill from. It is never
                optimized and is frozen here.

        Raises:
            ValueError: If ``loss_config['discriminator']['type']`` is not one
                of ``'oobleck'``, ``'encodec'`` or ``'dac'``.
        """
        super().__init__()

        self.automatic_optimization = False

        self.autoencoder = autoencoder

        self.warmed_up = False
        self.warmup_steps = warmup_steps
        self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
        self.lr = lr

        self.force_input_mono = force_input_mono

        self.teacher_model = teacher_model
        if self.teacher_model is not None:
            # The teacher is not passed to any optimizer (see configure_optimizers).
            # Freezing it lets the distillation losses backpropagate *through* it
            # into the student without accumulating never-zeroed .grad buffers.
            self.teacher_model.requires_grad_(False)

        if optimizer_configs is None:
            optimizer_configs = {
                "autoencoder": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (.8, .99)
                        }
                    }
                },
                "discriminator": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (.8, .99)
                        }
                    }
                }
            }

        self.optimizer_configs = optimizer_configs

        if loss_config is None:
            scales = [2048, 1024, 512, 256, 128, 64, 32]
            overlap = 0.75
            hop_sizes = [int(s * (1 - overlap)) for s in scales]
            win_lengths = list(scales)

            loss_config = {
                "discriminator": {
                    "type": "encodec",
                    "config": {
                        "n_ffts": scales,
                        "hop_lengths": hop_sizes,
                        "win_lengths": win_lengths,
                        "filters": 32
                    },
                    "weights": {
                        "adversarial": 0.1,
                        "feature_matching": 5.0,
                    }
                },
                "spectral": {
                    "type": "mrstft",
                    "config": {
                        "fft_sizes": scales,
                        "hop_sizes": hop_sizes,
                        "win_lengths": win_lengths,
                        "perceptual_weighting": True
                    },
                    "weights": {
                        "mrstft": 1.0,
                    }
                },
                "time": {
                    "type": "l1",
                    "config": {},
                    "weights": {
                        "l1": 0.0,
                    }
                }
            }

        self.loss_config = loss_config

        # Spectral reconstruction loss: sum/difference STFT for stereo output,
        # plus a plain MRSTFT used for the separate left/right channel losses.
        stft_loss_args = loss_config['spectral']['config']

        if self.autoencoder.out_channels == 2:
            self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
            self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
        else:
            self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)

        # Discriminator
        disc_type = loss_config['discriminator']['type']
        disc_config = loss_config['discriminator']['config']

        if disc_type == 'oobleck':
            self.discriminator = OobleckDiscriminator(**disc_config)
        elif disc_type == 'encodec':
            self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **disc_config)
        elif disc_type == 'dac':
            self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **disc_config)
        else:
            # Fail here rather than with an AttributeError in configure_optimizers.
            raise ValueError(f"Unknown discriminator type: {disc_type!r}")

        disc_weights = self.loss_config['discriminator']['weights']
        mrstft_weight = self.loss_config['spectral']['weights']['mrstft']

        # Adversarial and feature matching losses
        self.gen_loss_modules = [
            ValueLoss(key='loss_adv', weight=disc_weights['adversarial'], name='loss_adv'),
            ValueLoss(key='feature_matching_distance', weight=disc_weights['feature_matching'], name='feature_matching'),
        ]

        if self.teacher_model is not None:
            # Distillation losses: four equally weighted terms sharing the MRSTFT budget.
            stft_loss_weight = mrstft_weight * 0.25

            self.gen_loss_modules += [
                # Reconstruction loss
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight),
                # Distilled model's decoder is compatible with teacher's decoder
                AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight),
                # Distilled model's encoder is compatible with teacher's decoder
                AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight),
                # Teacher's encoder is compatible with distilled model's decoder
                AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight),
            ]
        else:
            # Reconstruction loss. Added exactly once: MultiLoss sums every module,
            # so a duplicate entry would silently double this term's weight.
            self.gen_loss_modules.append(
                AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=mrstft_weight)
            )

            if self.autoencoder.out_channels == 2:
                # Add left and right channel reconstruction losses in addition to the sum and difference
                self.gen_loss_modules += [
                    AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=mrstft_weight / 2),
                    AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=mrstft_weight / 2),
                ]

        l1_weight = self.loss_config['time']['weights']['l1']
        if l1_weight > 0.0:
            self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=l1_weight, name='l1_time_loss'))

        if self.autoencoder.bottleneck is not None:
            self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)

        self.losses_gen = MultiLoss(self.gen_loss_modules)

        self.disc_loss_modules = [
            ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
        ]

        self.losses_disc = MultiLoss(self.disc_loss_modules)

        # Set up EMA for model weights
        self.autoencoder_ema = None

        self.use_ema = use_ema

        if self.use_ema:
            self.autoencoder_ema = EMA(
                self.autoencoder,
                ema_model=ema_copy,
                beta=0.9999,
                power=3/4,
                update_every=1,
                update_after_step=1
            )

        self.latent_mask_ratio = latent_mask_ratio

    def configure_optimizers(self):
        """Create the autoencoder and discriminator optimizers.

        Schedulers are attached only when BOTH optimizer configs contain a
        ``'scheduler'`` entry; otherwise just the two optimizers are returned.
        """
        opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
        opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())

        if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
            sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
            sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
            return [opt_gen, opt_disc], [sched_gen, sched_disc]

        return [opt_gen, opt_disc]

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step for either the discriminator or the autoencoder.

        ``batch`` is unpacked as ``(audio, _)``. The audio is indexed as
        ``[:, channel, :]`` below, so it is presumably (batch, channels,
        samples) — TODO confirm against the dataloader.

        Returns:
            The scalar loss that was backpropagated this step.
        """
        reals, _ = batch

        # Remove extra dimension added by WebDataset
        if reals.ndim == 4 and reals.shape[0] == 1:
            reals = reals[0]

        if self.global_step >= self.warmup_steps:
            self.warmed_up = True

        loss_info = {}

        loss_info["reals"] = reals

        encoder_input = reals

        if self.force_input_mono and encoder_input.shape[1] > 1:
            encoder_input = encoder_input.mean(dim=1, keepdim=True)

        loss_info["encoder_input"] = encoder_input

        data_std = encoder_input.std()

        if self.warmed_up and self.encoder_freeze_on_warmup:
            with torch.no_grad():
                latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
        else:
            latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)

        loss_info["latents"] = latents

        loss_info.update(encoder_info)

        # Encode with teacher model for distillation (a fixed target, no graph needed)
        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
                loss_info['teacher_latents'] = teacher_latents

        # Optionally mask out some latents for noise resistance
        if self.latent_mask_ratio > 0.0:
            mask = torch.rand_like(latents) < self.latent_mask_ratio
            latents = torch.where(mask, torch.zeros_like(latents), latents)

        decoded = self.autoencoder.decode(latents)

        loss_info["decoded"] = decoded

        if self.autoencoder.out_channels == 2:
            loss_info["decoded_left"] = decoded[:, 0:1, :]
            loss_info["decoded_right"] = decoded[:, 1:2, :]
            loss_info["reals_left"] = reals[:, 0:1, :]
            loss_info["reals_right"] = reals[:, 1:2, :]

        # Distillation
        if self.teacher_model is not None:
            # The teacher's own reconstruction is only a target.
            with torch.no_grad():
                teacher_decoded = self.teacher_model.decode(teacher_latents)

            # These two must stay in the autograd graph: their losses exist to train
            # the student's encoder (via the frozen teacher decoder) and the
            # student's decoder (on teacher latents). Under no_grad they would
            # contribute no gradient at all.
            own_latents_teacher_decoded = self.teacher_model.decode(latents)  # Distilled model's latents decoded by teacher
            teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents)  # Teacher's latents decoded by distilled model

            loss_info['teacher_decoded'] = teacher_decoded
            loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
            loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded

        if self.warmed_up:
            loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
        else:
            loss_dis = torch.tensor(0.).to(reals)
            loss_adv = torch.tensor(0.).to(reals)
            feature_matching_distance = torch.tensor(0.).to(reals)

        loss_info["loss_dis"] = loss_dis
        loss_info["loss_adv"] = loss_adv
        loss_info["feature_matching_distance"] = feature_matching_distance

        opt_gen, opt_disc = self.optimizers()

        lr_schedulers = self.lr_schedulers()

        sched_gen = None
        sched_disc = None

        if lr_schedulers is not None:
            sched_gen, sched_disc = lr_schedulers

        # Train the discriminator on odd steps once warmed up
        if self.global_step % 2 and self.warmed_up:
            loss, losses = self.losses_disc(loss_info)

            log_dict = {
                'train/disc_lr': opt_disc.param_groups[0]['lr']
            }

            opt_disc.zero_grad()
            self.manual_backward(loss)
            opt_disc.step()

            if sched_disc is not None:
                # sched step every step
                sched_disc.step()

        # Train the generator
        else:
            loss, losses = self.losses_gen(loss_info)

            opt_gen.zero_grad()
            self.manual_backward(loss)
            opt_gen.step()

            # Update the EMA after the optimizer step so it tracks the freshly
            # updated weights rather than lagging one step behind.
            if self.use_ema:
                self.autoencoder_ema.update()

            if sched_gen is not None:
                # scheduler step every step
                sched_gen.step()

            log_dict = {
                'train/loss': loss.detach(),
                'train/latent_std': latents.std().detach(),
                'train/data_std': data_std.detach(),
                'train/gen_lr': opt_gen.param_groups[0]['lr']
            }

        for loss_name, loss_value in losses.items():
            log_dict[f'train/{loss_name}'] = loss_value.detach()

        self.log_dict(log_dict, prog_bar=True, on_step=True)

        return loss

    def export_model(self, path, use_safetensors=False):
        """Save the autoencoder (the EMA copy if one exists) to ``path``.

        With ``use_safetensors`` the model is written via
        ``safetensors.torch.save_model``; otherwise a ``torch.save`` file
        containing ``{"state_dict": ...}`` is written.
        """
        if self.autoencoder_ema is not None:
            model = self.autoencoder_ema.ema_model
        else:
            model = self.autoencoder

        if use_safetensors:
            save_model(model, path)
        else:
            torch.save({"state_dict": model.state_dict()}, path)


class AutoencoderDemoCallback(pl.Callback):
    """Periodically reconstruct a demo batch and log audio and latent visualisations.

    On rank zero, every ``demo_every`` steps, one batch is pulled from
    ``demo_dl`` and encoded/decoded with the EMA model when ``module.use_ema``
    is set (the raw autoencoder otherwise). Reals and reconstructions are
    interleaved into a single WAV file written to the working directory and
    logged through ``trainer.logger.experiment.log`` along with latent plots.
    The ``wandb.Audio``/``wandb.Image`` payloads suggest a W&B logger is
    expected — TODO confirm.
    """

    def __init__(
            self,
            demo_dl,
            demo_every=2000,
            sample_size=65536,
            sample_rate=48000
    ):
        super().__init__()
        self.demo_every = demo_every
        self.demo_samples = sample_size
        # Note: a single iterator is reused across demos, so it can be exhausted.
        self.demo_dl = iter(demo_dl)
        self.sample_rate = sample_rate
        self.last_demo_step = -1

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        """Run a reconstruction demo when ``global_step - 1`` is a multiple of ``demo_every``.

        The same step is never demoed twice.
        """
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        self.last_demo_step = trainer.global_step

        module.eval()

        try:
            demo_reals, _ = next(self.demo_dl)

            # Remove extra dimension added by WebDataset
            if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
                demo_reals = demo_reals[0]

            encoder_input = demo_reals.to(module.device)

            if module.force_input_mono:
                encoder_input = encoder_input.mean(dim=1, keepdim=True)

            demo_reals = demo_reals.to(module.device)

            # Gradients are already disabled by the @torch.no_grad() decorator.
            model = module.autoencoder_ema.ema_model if module.use_ema else module.autoencoder
            latents = model.encode(encoder_input)
            fakes = model.decode(latents)

            # Interleave reals and fakes
            reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')

            # Put the demos together
            reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')

            log_dict = {}

            filename = f'recon_{trainer.global_step:08}.wav'
            reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            torchaudio.save(filename, reals_fakes, self.sample_rate)

            log_dict['recon'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption='Reconstructed')
            log_dict['embeddings_3dpca'] = pca_point_cloud(latents)
            log_dict['embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
            log_dict['recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))

            trainer.logger.experiment.log(log_dict)
        except Exception as e:
            print(f'{type(e).__name__}: {e}')
            raise
        finally:
            module.train()


def _bottleneck_loss_weight(loss_config, name, default):
    """Return ``loss_config['bottleneck']['weights'][name]``, or ``default`` if it is not configured."""
    try:
        return loss_config['bottleneck']['weights'][name]
    except (KeyError, TypeError):
        # Missing section/key, or a None/non-dict config at some level.
        return default


def create_loss_modules_from_bottleneck(bottleneck, loss_config):
    """Return the ``ValueLoss`` modules required by ``bottleneck``'s type.

    - VAE-style bottlenecks: KL loss (weight from config, default 1e-6).
    - RVQ bottlenecks: quantizer loss (weight 1.0).
    - DAC RVQ bottlenecks: codebook loss (1.0) and commitment loss (0.25).
    - Wasserstein bottleneck: MMD loss (weight from config, default 100).

    A bottleneck matching several categories gets all the corresponding losses,
    appended in the order listed above.
    """
    losses = []

    if isinstance(bottleneck, (VAEBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck)):
        kl_weight = _bottleneck_loss_weight(loss_config, 'kl', 1e-6)
        losses.append(ValueLoss(key='kl', weight=kl_weight, name='kl_loss'))

    if isinstance(bottleneck, (RVQBottleneck, RVQVAEBottleneck)):
        losses.append(ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss'))

    if isinstance(bottleneck, (DACRVQBottleneck, DACRVQVAEBottleneck)):
        losses.append(ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss'))
        losses.append(ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss'))

    if isinstance(bottleneck, WassersteinBottleneck):
        mmd_weight = _bottleneck_loss_weight(loss_config, 'mmd', 100)
        losses.append(ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss'))

    return losses