""" train_tts.py Desc: An example script for training a Diffusion-based TTS model with a speaker encoder. """ import sys import torch import torch.nn as nn import torchaudio import gc import argparse import os from tqdm import tqdm import wandb from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler sys.path.append(".") from models.style_diffusion import StyleVDiffusion, StyleVSampler # from models.utils import MonoTransform # from util import calculate_codebook_bitrate, extract_melspectrogram, get_audio_file_bitrate, get_duration, load_neural_audio_codec from audioldm.pipeline import build_model import torch.multiprocessing as mp # Needed for Instruction/Prompt Models # from transformers import AutoTokenizer, T5EncoderModel import logging # Uncomment out below if wanting to supress # import warnings # warnings.filterwarnings("ignore") # Set Sample Rate if like so if desired SAMPLE_RATE = 16000 BATCH_SIZE = 16 NUM_SAMPLES = int(2.56 * SAMPLE_RATE) # NUM_SAMPLES = 2 ** 15 def create_model(): return DiffusionModel( net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case) # dim=2, # for spectrogram we use 2D-CNN in_channels=314, # U-Net: number of input (audio) channels out_channels=157, # U-Net: number of output (audio) channels channels=[256, 256, 512, 512, 768, 768, 1280, 1280], # U-Net: channels at each layer factors=[2, 2, 2, 2, 2, 2, 2, 1], # U-Net: downsampling and upsampling factors at each layer items=[2, 2, 2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer attentions=[0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer attention_heads=8, # U-Net: number of attention heads per attention item attention_features=64, # U-Net: number of attention features per attention item diffusion_t=StyleVDiffusion, # The diffusion method used sampler_t=StyleVSampler, # The diffusion sampler used # embedding_features = 8, # embedding_features = 2, # Embedding for when it's just res and weight embedding_features = 7, # Embedding Features for when Severity is Dropped cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1] ) def main(): pass # args = parse_args() # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' # os.environ["CUDA_VISIBLE_DEVICES"] = args['cuda_ids'] # cuda_ids = [phy_id for phy_id in range(len(args['cuda_ids'].split(",")))] # logging.basicConfig( # format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", # datefmt="%Y-%m-%d %H:%M:%S", # level=os.environ.get("LOGLEVEL", "INFO").upper(), # stream=sys.stdout, # filemode='w', # ) # logger = logging.getLogger("") # # mp.set_start_method('spawn') # # mp.set_sharing_strategy('file_system') # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # # Load in text model # # tokenizer = AutoTokenizer.from_pretrained("t5-small") # # text_model = T5EncoderModel.from_pretrained("t5-small") # # text_model.eval() # Don't want to train it! 
def main():
    pass
    # args = parse_args()

    # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    # os.environ["CUDA_VISIBLE_DEVICES"] = args['cuda_ids']
    # cuda_ids = [phy_id for phy_id in range(len(args['cuda_ids'].split(",")))]

    # logging.basicConfig(
    #     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    #     datefmt="%Y-%m-%d %H:%M:%S",
    #     level=os.environ.get("LOGLEVEL", "INFO").upper(),
    #     stream=sys.stdout,
    #     filemode='w',
    # )
    # logger = logging.getLogger("")

    # # mp.set_start_method('spawn')
    # # mp.set_sharing_strategy('file_system')

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # # Load in text model
    # # tokenizer = AutoTokenizer.from_pretrained("t5-small")
    # # text_model = T5EncoderModel.from_pretrained("t5-small")
    # # text_model.eval()  # Don't want to train it!

    # dataset = DSVAE_CondStyleWAVDataset(
    #     path="/data/robbizorg/pqvd_gen_w_conditioning/speech_non_speech_timesteps_VCTK.json",
    #     random_crop_size=NUM_SAMPLES,
    #     sample_rate=SAMPLE_RATE,
    #     transforms=AllTransform(
    #         mono=True,
    #     ),
    #     reconstructive=False,  # Make this True to just train a reconstructive model
    #     identity_limit=1,  # Affects how often we learn the identity mapping
    # )
    # print(f"Dataset length: {len(dataset)}")

    # dataloader = torch.utils.data.DataLoader(
    #     dataset,
    #     batch_size=BATCH_SIZE,
    #     shuffle=True,
    #     num_workers=16,
    #     pin_memory=False,
    # )

    # vae_model = DSVAE(logger, **args).cuda()
    # if not os.path.exists(args['model_path']):
    #     logger.warning("Model does not exist; creating a new one...")
    # else:
    #     logger.info("Model Exists")
    #     logger.info("Model Path is " + args['model_path'])
    #     vae_model.loadParameters(args['model_path'])
    # vae_model = torch.nn.DataParallel(vae_model, device_ids=cuda_ids, output_device=cuda_ids[0])
    # vae_model = vae_model.cuda()
    # vae_model.eval()
    # vae_model.module.eer = True

    # diff_model = create_model().to(device)
    # # audio_codec = build_model().to(device)
    # # audio_codec.latent_t_size = 157
    # # config, audio_codec, vocoder = load_neural_audio_codec('2021-05-19T22-16-54_vggsound_codebook', './logs', device)

    # # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # optimizer = torch.optim.AdamW(params=list(diff_model.parameters()), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)
    # print(f"Number of parameters: {sum(p.numel() for p in diff_model.parameters() if p.requires_grad)}")

    # run_id = wandb.util.generate_id()
    # if args["run_id"] is not None:
    #     run_id = args["run_id"]
    # print(f"Run ID: {run_id}")
    # wandb.init(project="audio-diffusion-no-condition", resume=args["resume"], id=run_id)

    # epoch = 0
    # step = 0

    # checkpoint_path = os.path.join(args["checkpoint"], args["run_id"])
    # if not os.path.exists(checkpoint_path):
    #     os.makedirs(checkpoint_path)
    #     os.makedirs(os.path.join(checkpoint_path, "mels"))
    #     os.makedirs(os.path.join(checkpoint_path, "wavs"))

    # if wandb.run.resumed:
    #     if os.path.exists(checkpoint_path):
    #         checkpoint = torch.load(checkpoint_path)
    #     else:
    #         checkpoint = torch.load(wandb.restore(checkpoint_path))
    #     diff_model.load_state_dict(checkpoint['model_state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #     epoch = checkpoint['epoch']
    #     step = epoch * len(dataloader)

    # scaler = torch.cuda.amp.GradScaler()

    # diff_model.train()

    # while epoch < 101:
    #     avg_loss = 0
    #     avg_loss_step = 0
    #     progress = tqdm(dataloader)
    #     for i, (audio, target, embedding) in enumerate(progress):
    #         optimizer.zero_grad()
    #         audio = audio.to(device)
    #         target = target.to(device)
    #         embedding = embedding.to(device)

    #         with torch.no_grad():
    #             embedding = embedding.float()  # Make it a float like the others
    #             speaker_embed_source, content_embed_source = vae_model(audio)
    #             speaker_embed_source = speaker_embed_source.unsqueeze(1).expand(-1, 157, -1)
    #             audio_embed = torch.cat((speaker_embed_source, content_embed_source), axis=-1)
    #             # zeroes = torch.zeros(16, 3, 128, dtype=audio_embed.dtype, device=audio_embed.device)
    #             # audio_embed = torch.cat((audio_embed, zeroes), dim=1)

    #             speaker_embed, content_embed = vae_model(target)
    #             speaker_embed = speaker_embed.unsqueeze(1).expand(-1, 157, -1)
    #             # In order to simulate paired data, do (naive) voice conversion first
    #             target_embed = torch.cat((speaker_embed, content_embed_source), axis=-1)
    #             # target_embed = torch.cat((target_embed, zeroes), dim=1)

    #         with torch.cuda.amp.autocast():
    #             loss = diff_model(audio_embed, target_embed, embedding=embedding)
    #             avg_loss += loss.item()
    #             avg_loss_step += 1
    #         scaler.scale(loss).backward()
    #         scaler.step(optimizer)
    #         scaler.update()

    #         progress.set_postfix(
    #             # loss=loss.item(),
    #             loss=avg_loss / avg_loss_step,
    #             epoch=epoch + i / len(dataloader),
    #         )

    #         if step % 500 == 0:
    #         # if step % 1 == 0:
    #             # Turn noise into new audio sample with diffusion
    #             noise = torch.randn(1, 157, 128, device=device)
    #             with torch.cuda.amp.autocast():
    #                 sample = diff_model.sample(audio_embed[0], noise, embedding=embedding[0][None, :], num_steps=200)

    #             # Save the melspecs
    #             audio_sub = torch.swapaxes(audio[0].unsqueeze(0), 1, 2)
    #             # target_sub = torch.swapaxes(target[0].unsqueeze(0), 1, 2)  # This is the original target audio, not what we want
    #             target_sub = vae_model.module.share_decoder(target_embed).loc
    #             gen_mel = vae_model.module.share_decoder(sample).loc
    #             vae_model.module.draw_mel(audio_sub, mode=f"source_{step}", file_path=os.path.join(checkpoint_path, "mels"))
    #             vae_model.module.draw_mel(target_sub, mode=f"target_{step}", file_path=os.path.join(checkpoint_path, "mels"))
    #             vae_model.module.draw_mel(gen_mel, mode=f"gen_{step}", file_path=os.path.join(checkpoint_path, "mels"))
    #             vae_model.module.mel2wav(audio_sub, mode=f"source_{step}", task="vc", file_path=os.path.join(checkpoint_path, "wavs"))
    #             vae_model.module.mel2wav(target_sub, mode=f"target_{step}", task="vc", file_path=os.path.join(checkpoint_path, "wavs"))
    #             vae_model.module.mel2wav(gen_mel, mode=f"gen_{step}", task="vc", file_path=os.path.join(checkpoint_path, "wavs"))
    #             # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_input_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(audio[0].unsqueeze(0))))[0], SAMPLE_RATE)
    #             # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_generated_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(sample[0].unsqueeze(0))))[0], SAMPLE_RATE)
    #             # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_target_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(target[0].unsqueeze(0))))[0], SAMPLE_RATE)

    #             wandb.log({
    #                 "step": step,
    #                 "epoch": epoch + i / len(dataloader),
    #                 "loss": avg_loss / avg_loss_step,
    #                 "input_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"source_{step}_mel_0.png"), caption="Input Mel"),
    #                 "target_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"target_{step}_mel_0.png"), caption="Target Mel"),
    #                 "gen_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"gen_{step}_mel_0.png"), caption="Gen Mel"),
    #                 "input_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'source_{step}0.wav'), caption="Input audio", sample_rate=SAMPLE_RATE),
    #                 "target_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'target_{step}0.wav'), caption="Target audio", sample_rate=SAMPLE_RATE),
    #                 "generated_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'gen_{step}0.wav'), caption="Generated audio", sample_rate=SAMPLE_RATE)
    #             })

    #         if step % 100 == 0:
    #             wandb.log({
    #                 "step": step,
    #                 "epoch": epoch + i / len(dataloader),
    #                 "loss": avg_loss / avg_loss_step,
    #             })
    #             avg_loss = 0
    #             avg_loss_step = 0

    #         step += 1

    #     epoch += 1
    #     if epoch % 100 == 0:
    #         torch.save({
    #             'epoch': epoch,
    #             'model_state_dict': diff_model.state_dict(),
    #             'optimizer_state_dict': optimizer.state_dict(),
    #         }, os.path.join(checkpoint_path, f"epoch-{epoch}.pt"))
    #         wandb.save(checkpoint_path, base_path=args["checkpoint"])
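
# A standalone sketch of the sampling path from the commented-out loop above. The
# diff_model.sample(...) signature and tensor shapes are taken from that loop; turning the
# returned latent back into a mel spectrogram or waveform would still require the external
# DSVAE decoder (vae_model.module.share_decoder / mel2wav), which is not defined in this file.
# The function and argument names here are illustrative only.
@torch.no_grad()
def _generate_target_latent(diff_model, source_latent, cond_embedding, num_steps=200, device="cuda"):
    """Denoise Gaussian noise into a target latent conditioned on one source latent.

    source_latent: (157, 128) speaker+content latent for a single utterance.
    cond_embedding: (7,) conditioning vector.
    Returns a (1, 157, 128) generated latent.
    """
    noise = torch.randn(1, 157, 128, device=device)
    with torch.cuda.amp.autocast():
        sample = diff_model.sample(source_latent, noise, embedding=cond_embedding[None, :], num_steps=num_steps)
    return sample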
# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--checkpoint", type=str, default='/data/robbizorg/pqvd_gen_w_dsvae/checkpoints/')
#     parser.add_argument("--resume", action="store_true")
#     parser.add_argument("--run_id", type=str, default='condition_ldm')

#     ## Params from DSVAE
#     parser.add_argument('--dataset', type=str, default="VCTK", help='VCTK, LibriTTS')
#     parser.add_argument('--encoder', type=str, default='dsvae', help='dsvae, tdnn')
#     parser.add_argument('--vocoder', type=str, default='hifigan', help='wavenet, hifigan')
#     parser.add_argument('--save_tsne', dest='save_tsne', action='store_true', help='save_tsne')
#     parser.add_argument('--mel_tsne', dest='mel_tsne', action='store_true', help='mel_tsne')
#     parser.add_argument('--feature', type=str, default='mel_spec', help='stft, mel_spec, mfcc')
#     parser.add_argument('--model_path', type=str, default='/home/robbizorg/research/dsvae/save_models/dsvae/best699.pth')
#     # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_003_03/best699.pth')  # Using the fine-tuned dsvae
#     # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_0001_0005/best.pth')  # Using the fine-tuned dsvae
#     parser.add_argument('--save_path', type=str, default='save_models/dsvae')
#     parser.add_argument('--cuda_ids', type=str, default='0')
#     parser.add_argument('--tsne_mode', type=str, default='test')
#     parser.add_argument("--optimizer", type=str, default='adam', help='sgd, adam')
#     parser.add_argument("--path_vc_1", type=str, default='', help='')
#     parser.add_argument("--path_vc_2", type=str, default='', help='')
#     parser.add_argument('--max_frames', type=int, default=100, help='1 frame ~ 10 ms')
#     parser.add_argument("--hop_size", type=int, default=256, help='hop_size')
#     parser.add_argument("--win_length", type=int, default=1024, help='win_length')
#     parser.add_argument("--spk_dim", type=int, default=64, help='spk_embed')
#     parser.add_argument("--ecapa_spk_dim", type=int, default=128, help='ecapa spk_embed')
#     parser.add_argument("--content_dim", type=int, default=64, help="content_embed")
#     parser.add_argument("--conformer_hidden_dim", type=int, default=256, help="content_embed")
#     parser.add_argument('--n_epochs', type=int, default=700, help='n_epochs')
#     parser.add_argument('--eval_epoch', type=int, default=5, help='eval_epoch')
#     parser.add_argument('--step_size', type=int, default=5, help='step_size')
#     parser.add_argument('--num_workers', type=int, default=16, help='num_workers')
#     parser.add_argument('--lr_decay_rate', type=float, default=0.95, help='lr_decay_rate')
#     parser.add_argument('--lr', type=float, default=3e-4, help='lr_rate')
#     # parser.add_argument('--klf_factor', type=float, default=3e-3, help='klf_factor')
#     # parser.add_argument('--klt_factor', type=float, default=5, help='klt_factor')
#     parser.add_argument('--klf_factor', type=float, default=3e-4, help='klf_factor')  # Changed for the fine-tuned version
#     parser.add_argument('--klt_factor', type=float, default=3e-3, help='klt_factor')
#     parser.add_argument('--rec_factor', type=float, default=1, help='rec_factor')
#     parser.add_argument('--vq_factor', type=float, default=1000, help='vq_factor')
#     parser.add_argument('--zf_vq_factor', type=float, default=1000, help='vq_factor')
#     parser.add_argument('--klf_std', type=float, default=0.5, help='klf_std')
#     parser.add_argument('--rec_std', type=float, default=0.04, help='rec_std')
#     parser.add_argument('--clip', type=float, default=1, help='clip')
#     parser.add_argument('--phoneme_factor', type=float, default=1, help='phoneme_factor')
#     parser.add_argument('--r_vq_factor', type=float, default=10, help='r_vq_factor')
#     parser.add_argument('--compute_speaker_eer', dest='compute_speaker_eer', action='store_true', help='ASV EER')
#     parser.add_argument('--eval_phoneme', dest='eval_phoneme', action='store_true', help='ASV EER')
#     parser.add_argument('--num_eval', type=int, default=20, help='num of segments for eval')
#     parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
#     parser.add_argument('--num_phonemes', type=int, default=100, help='num_phonemes')
#     parser.add_argument('--with_phoneme', dest='with_phoneme', action='store_true', help='')
#     parser.add_argument("--conversion", action='store_true', help='for conversion text')
#     parser.add_argument("--conversion2", action='store_true', help='for conversion text')
#     parser.add_argument("--conversion3", action='store_true', help='for conversion text')
#     parser.add_argument("--mel2npy", action='store_true', help='mel2npy')
#     parser.add_argument("--unconditional", action='store_true', help='unconditional')
#     parser.add_argument('--zt_norm_mean', action='store_true', help='instancenorm1d on zt prior and post')
#     parser.add_argument('--zf_norm_mean', action='store_true', help='instancenorm1d on zf prior and post')
#     parser.add_argument('--freeze_encoder', action='store_true', help='whether or not to freeze the encoder')
#     parser.add_argument('--freeze_decoder', action='store_true', help='whether or not to freeze the decoder')
#     parser.add_argument("--sample_rate", type=int, default=16000, help='16000 or 48000')
#     parser.add_argument('--noise_path', type=str, default='datasets/noise_list.scp', help='noise invariant')
#     parser.add_argument('--wav_aug_train', action='store_true', help='with data augmentation')
#     parser.add_argument('--spec_aug_train', action='store_true', help='with data augmentation')
#     parser.add_argument('--noise_train', action='store_true', help='noise')
#     parser.add_argument('--triphn', action='store_true', help='with triphn')
#     parser.add_argument('--train_hifigan', action='store_true', help='train hifigan')
#     parser.add_argument("--prior_alignment", action='store_true', help='')
#     parser.add_argument("--zf_vq", action='store_true', help='')
#     parser.add_argument("--vq_prior_independent", action='store_true', help='')
#     parser.add_argument("--vq_prior_regressive", action='store_true', help='')
#     parser.add_argument("--vq_prior_pseudo", action='store_true', help='')
#     parser.add_argument("--vq_size_zt", type=int, default=200, help='')
#     parser.add_argument("--vq_size_zf", type=int, default=200, help='')
#     parser.add_argument("--ignore_index", type=int, default=0, help='')
#     parser.add_argument("--hidden_dim", type=int, default=256, help='')
#     parser.add_argument("--share_encoder", type=str, default='cnn', help='')
#     parser.add_argument("--share_decoder", type=str, default='cnn_lstm', help='cnn_lstm, cnn_transformer')
#     parser.add_argument("--zt_encoder", type=str, default='lstm', help='lstm, conformer_encoder, transformer_encoder')
#     parser.add_argument("--zf_encoder", type=str, default='lstm', help='lstm, transformer_encoder, ecapa_tdnn')
#     parser.add_argument("--zt_prior_model", type=str, default='lstm', help='lstm, vqvae, transformer')
#     parser.add_argument("--prior_signal", type=str, default='None', help='alignment_triphn, alignment_mono, melspec_pseudo, wavlm_pseudo, vq_embeds, vq_pseudo')
#     parser.add_argument("--multi_scale_add", action='store_true', help='')
#     parser.add_argument("--multi_scale_cat", action='store_true', help='')
#     parser.add_argument("--num_scales", type=int, default=1, help='')
#     parser.add_argument("--kmeans_num_clusters", type=int, default=50, help='')
#     parser.add_argument("--wavlm_dim", type=int, default=768, help='')
#     parser.add_argument("--ema_zt", action='store_true', help='')
#     parser.add_argument("--ema_zf", action='store_true', help='')
#     parser.add_argument("--r_vqvae", action='store_true', help='')
#     parser.add_argument("--masked_mel", action='store_true', help='')
#     parser.add_argument("--rec_noise", action='store_true', help='')
#     parser.add_argument("--rec_mask", action='store_true', help='')
#     parser.add_argument("--mel_classification", action='store_true', help='')
#     parser.add_argument("--test_script", action='store_true', help='')
#     parser.add_argument("--no_klt", action='store_true', help='')
#     parser.add_argument("--zt_prior_ce_r_vq", action='store_true', help='')
#     parser.add_argument('--zt_prior_ce_r_vq_factor', type=float, default=1000, help='factor')
#     parser.add_argument("--zt_post_ce_r_vq", action='store_true', help='')
#     parser.add_argument("--zt_prior_ce_kmeans", action='store_true', help='')
#     parser.add_argument('--zt_prior_ce_kmeans_factor', type=float, default=1000, help='factor')
#     parser.add_argument("--zt_post_ce_kmeans", action='store_true', help='')
#     parser.add_argument('--zt_post_ce_kmeans_factor', type=float, default=10, help='factor')
#     parser.add_argument("--zt_prior_ce_alignment", action='store_true', help='')
#     parser.add_argument('--zt_prior_ce_alignment_factor', type=float, default=1000, help='factor')
#     parser.add_argument("--prior_type", type=str, default='None', help='normal, condition, lm')
#     parser.add_argument("--prior_embedding", type=str, default='one-hot', help='one-hot, embedding')
#     parser.add_argument("--prior_mask", action='store_true', help='')
#     parser.add_argument("--wavlm", action='store_true', help='')
#     parser.add_argument("--wavlm_type", type=str, default='base', help='')
#     parser.add_argument("--tts_phn_wav_path", type=str, default='', help='')
#     parser.add_argument("--sr", type=str, default="16000", help='')
#     parser.add_argument("--text", type=str, default="your tts", help='')
#     parser.add_argument("--tts_align", action='store_true', help='')
#     parser.add_argument("--tts_wavlm", action='store_true', help='')
#     parser.add_argument("--tts", action='store_true', help='')
#     parser.add_argument("--tts_config", type=str, default="conf/LibriTTS/preprocess.yaml", help='')
#     parser.add_argument("--tts_target_wav_path", type=str, default='', help='')
#     parser.add_argument("--speed", type=float, default=1.0, help='')
#     parser.add_argument("--train_mapping", action='store_true', help='')
#     parser.add_argument("--mapping_encoder", type=str, default='lstm', help='')
#     parser.add_argument("--mapping_model_path", type=str, default='lstm', help='')
#     parser.add_argument("--mask_mapping", action='store_true', help='')
#     parser.add_argument("--mask_mapping_factor", type=float, default=1, help='')
#     parser.add_argument("--l1_mapping_factor", type=float, default=1, help='')
#     parser.add_argument("--mapping_ratio", type=float, default=1.0, help='')
#     parser.add_argument("--condition2", action='store_true', help='')

#     args = parser.parse_args()
#     return update_args(**vars(args))  # NOTE: update_args is not defined or imported in this file
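
# Hypothetical invocation, assuming main() and the __main__ guard below are re-enabled.
# The paths are placeholders; the argument defaults above point to the author's machines.
#
#   python train_tts.py \
#       --cuda_ids 0 \
#       --run_id condition_ldm \
#       --checkpoint /path/to/checkpoints/ \
#       --model_path /path/to/dsvae/best699.pth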

# if __name__ == "__main__":
#     # torch.cuda.empty_cache()
#     main()