"""
train_tts.py
Desc: An example script for training a Diffusion-based TTS model with a speaker encoder.
"""
import sys
import torch
import torch.nn as nn
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
sys.path.append(".")
from models.style_diffusion import StyleVDiffusion, StyleVSampler
# from models.utils import MonoTransform
# from util import calculate_codebook_bitrate, extract_melspectrogram, get_audio_file_bitrate, get_duration, load_neural_audio_codec
from audioldm.pipeline import build_model
import torch.multiprocessing as mp
# Needed for Instruction/Prompt Models
# from transformers import AutoTokenizer, T5EncoderModel
import logging
# Uncomment the lines below to suppress warnings
# import warnings
# warnings.filterwarnings("ignore")
# Adjust the sample rate below if desired
SAMPLE_RATE = 16000
BATCH_SIZE = 16
NUM_SAMPLES = int(2.56 * SAMPLE_RATE)  # 2.56 s of audio per training crop
# NUM_SAMPLES = 2 ** 15
def create_model():
    return DiffusionModel(
        net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
        # dim=2,  # for spectrograms we would use a 2D-CNN
        in_channels=314,  # U-Net: number of input (audio) channels
        out_channels=157,  # U-Net: number of output (audio) channels
        channels=[256, 256, 512, 512, 768, 768, 1280, 1280],  # U-Net: channels at each layer
        factors=[2, 2, 2, 2, 2, 2, 2, 1],  # U-Net: downsampling and upsampling factors at each layer
        items=[2, 2, 2, 2, 2, 2, 2, 2],  # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
        attention_heads=8,  # U-Net: number of attention heads per attention item
        attention_features=64,  # U-Net: number of attention features per attention item
        diffusion_t=StyleVDiffusion,  # The diffusion method used
        sampler_t=StyleVSampler,  # The diffusion sampler used
        # embedding_features=8,
        # embedding_features=2,  # Embedding size when conditioning only on res and weight
        embedding_features=7,  # Embedding size when severity is dropped
        cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: cross-attention enabled/disabled at each layer
    )
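# A minimal sketch of how create_model() might be exercised, mirroring the
# commented-out training loop in main() below. The forward/sample signatures
# follow the custom StyleVDiffusion/StyleVSampler usage shown there; the
# tensor shapes and the 7-dim conditioning embedding are illustrative
# assumptions, not verified against the dataset.
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = create_model().to(device)
# source = torch.randn(BATCH_SIZE, 157, 128, device=device)  # source speech latent
# target = torch.randn(BATCH_SIZE, 157, 128, device=device)  # target speech latent
# cond = torch.randn(BATCH_SIZE, 7, device=device)            # conditioning embedding
# loss = model(source, target, embedding=cond)                 # v-diffusion training loss
# noise = torch.randn(1, 157, 128, device=device)
# sample = model.sample(source[0], noise, embedding=cond[0][None, :], num_steps=200)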
def main():
    pass
# args = parse_args()
# os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
# os.environ["CUDA_VISIBLE_DEVICES"] = args['cuda_ids']
# cuda_ids = [phy_id for phy_id in range(len(args['cuda_ids'].split(",")))]
# logging.basicConfig(
# format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
# datefmt="%Y-%m-%d %H:%M:%S",
# level=os.environ.get("LOGLEVEL", "INFO").upper(),
# stream=sys.stdout,
# filemode='w',
# )
# logger = logging.getLogger("")
# # mp.set_start_method('spawn')
# # mp.set_sharing_strategy('file_system')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# # Load in text model
# # tokenizer = AutoTokenizer.from_pretrained("t5-small")
# # text_model = T5EncoderModel.from_pretrained("t5-small")
# # text_model.eval() # Don't want to train it!
# dataset = DSVAE_CondStyleWAVDataset(
# path="/data/robbizorg/pqvd_gen_w_conditioning/speech_non_speech_timesteps_VCTK.json",
# random_crop_size=NUM_SAMPLES,
# sample_rate=SAMPLE_RATE,
# transforms=AllTransform(
# mono=True,
# ),
# reconstructive = False, # Make this true to just train a reconstructive model
# identity_limit = 1 # Affects how often we learn identity mapping
# )
# print(f"Dataset length: {len(dataset)}")
# dataloader = torch.utils.data.DataLoader(
# dataset,
# batch_size=BATCH_SIZE,
# shuffle=True,
# num_workers=16,
# pin_memory=False,
# )
# vae_model = DSVAE(logger, **args).cuda()
# if not os.path.exists(args['model_path']):
# logger.warning("model not exist and we just create the new model......")
# else:
# logger.info("Model Exists")
# logger.info("Model Path is " + args['model_path'])
# vae_model.loadParameters(args['model_path'])
# vae_model = torch.nn.DataParallel(vae_model, device_ids = cuda_ids, output_device=cuda_ids[0])
# vae_model = vae_model.cuda()
# vae_model.eval()
# vae_model.module.eer = True
# diff_model = create_model().to(device)
# # audio_codec = build_model().to(device)
# # audio_codec.latent_t_size = 157
# # config, audio_codec, vocoder = load_neural_audio_codec('2021-05-19T22-16-54_vggsound_codebook', './logs', device)
# # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# optimizer = torch.optim.AdamW(params=list(diff_model.parameters()), lr=1e-4, betas= (0.95, 0.999), eps=1e-6, weight_decay=1e-3)
# print(f"Number of parameters: {sum(p.numel() for p in diff_model.parameters() if p.requires_grad)}")
# run_id = wandb.util.generate_id()
# if args["run_id"] is not None:
# run_id = args["run_id"]
# print(f"Run ID: {run_id}")
# wandb.init(project="audio-diffusion-no-condition", resume=args["resume"], id=run_id)
# epoch = 0
# step = 0
# checkpoint_path = os.path.join(args["checkpoint"], args["run_id"])
# if not os.path.exists(checkpoint_path):
# os.makedirs(checkpoint_path)
# os.makedirs(os.path.join(checkpoint_path, "mels"))
# os.makedirs(os.path.join(checkpoint_path, "wavs"))
# if wandb.run.resumed:
# if os.path.exists(checkpoint_path):
# checkpoint = torch.load(checkpoint_path)
# else:
# checkpoint = torch.load(wandb.restore(checkpoint_path))
# diff_model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# step = epoch * len(dataloader)
# scaler = torch.cuda.amp.GradScaler()
# diff_model.train()
# while epoch < 101:
# avg_loss = 0
# avg_loss_step = 0
# progress = tqdm(dataloader)
# for i, (audio, target, embedding) in enumerate(progress):
# optimizer.zero_grad()
# audio = audio.to(device)
# target = target.to(device)
# embedding = embedding.to(device)
# with torch.no_grad():
# embedding = embedding.float() # Make it float like the others
# speaker_embed_source, content_embed_source = vae_model(audio)
# speaker_embed_source = speaker_embed_source.unsqueeze(1).expand(-1, 157, -1)
# audio_embed = torch.cat((speaker_embed_source, content_embed_source), axis = -1)
# # zeroes = torch.zeros(16, 3, 128, dtype=audio_embed.dtype, device = audio_embed.device)
# # audio_embed = torch.cat((audio_embed, zeroes), dim=1)
# speaker_embed, content_embed = vae_model(target)
# speaker_embed = speaker_embed.unsqueeze(1).expand(-1, 157, -1)
# # in order to simulate paired data, do (naive) voice conversion first
# target_embed = torch.cat((speaker_embed, content_embed_source), axis = -1)
# # target_embed = torch.cat((target_embed, zeroes), dim = 1)
# with torch.cuda.amp.autocast():
# loss = diff_model(audio_embed, target_embed, embedding=embedding)
# avg_loss += loss.item()
# avg_loss_step += 1
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
# progress.set_postfix(
# # loss=loss.item(),
# loss=avg_loss / avg_loss_step,
# epoch=epoch + i / len(dataloader),
# )
# if step % 500 == 0:
# # if step % 1 == 0:
# # Turn noise into new audio sample with diffusion
# noise = torch.randn(1, 157, 128, device=device)
# with torch.cuda.amp.autocast():
# sample = diff_model.sample(audio_embed[0], noise, embedding=embedding[0][None, :], num_steps=200)
# # Save the melspecs
# audio_sub = torch.swapaxes(audio[0].unsqueeze(0), 1, 2)
# # target_sub = torch.swapaxes(target[0].unsqueeze(0), 1, 2) # This is the original target audio, not what we want
# target_sub = vae_model.module.share_decoder(target_embed).loc
# gen_mel = vae_model.module.share_decoder(sample).loc
# vae_model.module.draw_mel(audio_sub, mode=f"source_{step}", file_path = os.path.join(checkpoint_path, "mels"))
# vae_model.module.draw_mel(target_sub, mode=f"target_{step}", file_path = os.path.join(checkpoint_path, "mels"))
# vae_model.module.draw_mel(gen_mel, mode=f"gen_{step}", file_path = os.path.join(checkpoint_path, "mels"))
# vae_model.module.mel2wav(audio_sub, mode=f"source_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
# vae_model.module.mel2wav(target_sub, mode=f"target_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
# vae_model.module.mel2wav(gen_mel, mode=f"gen_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
# # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_input_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(audio[0].unsqueeze(0))))[0], SAMPLE_RATE)
# # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_generated_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(sample[0].unsqueeze(0))))[0], SAMPLE_RATE)
# # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_target_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(target[0].unsqueeze(0))))[0], SAMPLE_RATE)
# wandb.log({
# "step": step,
# "epoch": epoch + i / len(dataloader),
# "loss": avg_loss / avg_loss_step,
# "input_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"source_{step}_mel_0.png"), caption="Input Mel"),
# "target_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"target_{step}_mel_0.png"), caption="Target Mel"),
# "gen_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"gen_{step}_mel_0.png"), caption="Gen Mel"),
# "input_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'source_{step}0.wav'), caption="Input audio", sample_rate=SAMPLE_RATE),
# "target_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'target_{step}0.wav'), caption="Target audio", sample_rate=SAMPLE_RATE),
# "generated_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'gen_{step}0.wav'), caption="Generated audio", sample_rate=SAMPLE_RATE)
# })
# if step % 100 == 0:
# wandb.log({
# "step": step,
# "epoch": epoch + i / len(dataloader),
# "loss": avg_loss / avg_loss_step,
# })
# avg_loss = 0
# avg_loss_step = 0
# step += 1
# epoch += 1
# if epoch % 100 == 0:
# torch.save({
# 'epoch': epoch,
# 'model_state_dict': diff_model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# }, os.path.join(checkpoint_path, f"epoch-{epoch}.pt"))
# wandb.save(checkpoint_path, base_path=args["checkpoint"])
# def parse_args():
# parser = argparse.ArgumentParser()
# parser.add_argument("--checkpoint", type=str, default='/data/robbizorg/pqvd_gen_w_dsvae/checkpoints/')
# parser.add_argument("--resume", action="store_true")
# parser.add_argument("--run_id", type=str, default='condition_ldm')
# ## Params from DSVAE
# parser.add_argument('--dataset', type=str, default="VCTK", help='VCTK, LibriTTS')
# parser.add_argument('--encoder', type=str, default='dsvae', help='dsvae, tdnn')
# parser.add_argument('--vocoder', type=str, default='hifigan', help='wavenet, hifigan')
# parser.add_argument('--save_tsne', dest='save_tsne', action='store_true', help='save_tsne')
# parser.add_argument('--mel_tsne', dest='mel_tsne', action='store_true', help='mel_tsne')
# parser.add_argument('--feature', type=str, default='mel_spec', help='stft, mel_spec, mfcc')
# parser.add_argument('--model_path', type=str, default='/home/robbizorg/research/dsvae/save_models/dsvae/best699.pth')
# # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_003_03/best699.pth') # Using the fine-tuned dsvae
# # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_0001_0005/best.pth') # Using the fine-tuned dsvae
# parser.add_argument('--save_path', type=str, default='save_models/dsvae')
# parser.add_argument('--cuda_ids', type=str, default='0')
# parser.add_argument('--tsne_mode', type=str, default='test')
# parser.add_argument("--optimizer", type=str, default='adam', help='sgd, adam')
# parser.add_argument("--path_vc_1", type=str, default='', help='')
# parser.add_argument("--path_vc_2", type=str, default='', help='')
# parser.add_argument('--max_frames', type=int, default=100, help='1frame~10ms')
# parser.add_argument("--hop_size", type=int, default=256, help='hop_size')
# parser.add_argument("--win_length", type=int, default=1024, help='win_length')
# parser.add_argument("--spk_dim", type=int, default=64, help='spk_embed')
# parser.add_argument("--ecapa_spk_dim", type=int, default=128, help='ecapa spk_embed')
# parser.add_argument("--content_dim", type=int, default=64, help="content_embed")
# parser.add_argument("--conformer_hidden_dim", type=int, default=256, help="content_embed")
# parser.add_argument('--n_epochs', type=int, default=700, help='n_epochs')
# parser.add_argument('--eval_epoch', type=int, default=5, help='eval_epoch')
# parser.add_argument('--step_size', type=int, default=5, help='step_size')
# parser.add_argument('--num_workers', type=int, default=16, help='num_workers')
# parser.add_argument('--lr_decay_rate',type=float, default=0.95, help='lr_decay_rate')
# parser.add_argument('--lr',type=float, default=3e-4, help='lr_rate')
# # parser.add_argument('--klf_factor', type=float, default=3e-3, help='klf_factor')
# # parser.add_argument('--klt_factor', type=float, default=5, help='klt_factor')
# parser.add_argument('--klf_factor', type=float, default=3e-4, help='klf_factor') # Changed for the Fine-tuned Version
# parser.add_argument('--klt_factor', type=float, default=3e-3, help='klt_factor')
# parser.add_argument('--rec_factor', type=float, default=1, help='rec_factor')
# parser.add_argument('--vq_factor', type=float, default=1000, help='vq_factor')
# parser.add_argument('--zf_vq_factor', type=float, default=1000, help='vq_factor')
# parser.add_argument('--klf_std', type=float, default=0.5, help='klf_std')
# parser.add_argument('--rec_std', type=float, default=0.04, help='rec_std')
# parser.add_argument('--clip', type=float, default=1, help='clip')
# parser.add_argument('--phoneme_factor', type=float, default=1, help='phoneme_factor')
# parser.add_argument('--r_vq_factor', type=float, default=10, help='r_vq_factor')
# parser.add_argument('--compute_speaker_eer', dest='compute_speaker_eer', action='store_true', help='ASV EER')
# parser.add_argument('--eval_phoneme', dest='eval_phoneme', action='store_true', help='ASV EER')
# parser.add_argument('--num_eval', type=int, default=20, help='num of segments for eval')
# parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
# parser.add_argument('--num_phonemes', type=int, default=100, help='num_phonemes')
# parser.add_argument('--with_phoneme', dest='with_phoneme', action='store_true', help='')
# parser.add_argument("--conversion", action='store_true', help='for conversion text')
# parser.add_argument("--conversion2", action='store_true', help='for conversion text')
# parser.add_argument("--conversion3", action='store_true', help='for conversion text')
# parser.add_argument("--mel2npy", action='store_true', help='mel2npy')
# parser.add_argument("--unconditional", action='store_true', help='unconditional')
# parser.add_argument('--zt_norm_mean', action='store_true', help='instancenorm1d on zt prior and post')
# parser.add_argument('--zf_norm_mean', action='store_true', help='instancenorm1d on zf prior and post')
# parser.add_argument('--freeze_encoder', action='store_true', help='if or not to freeze encoder')
# parser.add_argument('--freeze_decoder', action='store_true', help='if or not to freeze decoder')
# parser.add_argument("--sample_rate",type=int, default=16000, help='16000 or 48000')
# parser.add_argument('--noise_path', type=str, default='datasets/noise_list.scp', help='noise invariant')
# parser.add_argument('--wav_aug_train', action='store_true', help='with data augmentation')
# parser.add_argument('--spec_aug_train', action='store_true', help='with data augmentation')
# parser.add_argument('--noise_train', action='store_true', help='noise')
# parser.add_argument('--triphn', action='store_true', help='with triphn')
# parser.add_argument('--train_hifigan', action='store_true', help='train hifigan')
# parser.add_argument("--prior_alignment", action='store_true', help='')
# parser.add_argument("--zf_vq", action='store_true', help='')
# parser.add_argument("--vq_prior_independent", action='store_true', help='')
# parser.add_argument("--vq_prior_regressive", action='store_true', help='')
# parser.add_argument("--vq_prior_pseudo", action='store_true', help='')
# parser.add_argument("--vq_size_zt",type=int, default=200, help='')
# parser.add_argument("--vq_size_zf",type=int, default=200, help='')
# parser.add_argument("--ignore_index",type=int, default=0, help='')
# parser.add_argument("--hidden_dim",type=int, default=256, help='')
# parser.add_argument("--share_encoder", type=str, default='cnn', help='')
# parser.add_argument("--share_decoder", type=str, default='cnn_lstm', help='cnn_lstm, cnn_transformer')
# parser.add_argument("--zt_encoder", type=str, default='lstm', help='lstm, conformer_encoder, transformer_encoder')
# parser.add_argument("--zf_encoder", type=str, default='lstm', help='lstm, transformer_encoder, ecapa_tdnn')
# parser.add_argument("--zt_prior_model", type=str, default='lstm', help='lstm, vqvae, transformer')
# parser.add_argument("--prior_signal", type=str, default='None', help='alignment_triphn, alignment_mono, melspec_pseudo, wavlm_pseudo, vq_embeds, vq_pseudo')
# parser.add_argument("--multi_scale_add", action='store_true', help='')
# parser.add_argument("--multi_scale_cat", action='store_true', help='')
# parser.add_argument("--num_scales",type=int, default=1, help='')
# parser.add_argument("--kmeans_num_clusters",type=int, default=50, help='')
# parser.add_argument("--wavlm_dim", type=int, default=768, help='')
# parser.add_argument("--ema_zt", action='store_true', help='')
# parser.add_argument("--ema_zf", action='store_true', help='')
# parser.add_argument("--r_vqvae", action='store_true', help='')
# parser.add_argument("--masked_mel", action='store_true', help='')
# parser.add_argument("--rec_noise", action='store_true', help='')
# parser.add_argument("--rec_mask", action='store_true', help='')
# parser.add_argument("--mel_classification", action='store_true', help='')
# parser.add_argument("--test_script", action='store_true', help='')
# parser.add_argument("--no_klt", action='store_true', help='')
# parser.add_argument("--zt_prior_ce_r_vq", action='store_true', help='')
# parser.add_argument('--zt_prior_ce_r_vq_factor', type=float, default=1000, help='factor')
# parser.add_argument("--zt_post_ce_r_vq", action='store_true', help='')
# parser.add_argument("--zt_prior_ce_kmeans", action='store_true', help='')
# parser.add_argument('--zt_prior_ce_kmeans_factor', type=float, default=1000, help='factor')
# parser.add_argument("--zt_post_ce_kmeans", action='store_true', help='')
# parser.add_argument('--zt_post_ce_kmeans_factor', type=float, default=10, help='factor')
# parser.add_argument("--zt_prior_ce_alignment", action='store_true', help='')
# parser.add_argument('--zt_prior_ce_alignment_factor', type=float, default=1000, help='factor')
# parser.add_argument("--prior_type", type=str, default='None', help='normal, condition, lm')
# parser.add_argument("--prior_embedding", type=str, default='one-hot', help='one-hot, embedding')
# parser.add_argument("--prior_mask", action='store_true', help='')
# parser.add_argument("--wavlm", action='store_true', help='')
# parser.add_argument("--wavlm_type", type=str, default='base', help='')
# parser.add_argument("--tts_phn_wav_path", type=str, default='', help='')
# parser.add_argument("--sr", type=str, default="16000", help='')
# parser.add_argument("--text", type=str, default="your tts", help='')
# parser.add_argument("--tts_align", action='store_true', help='')
# parser.add_argument("--tts_wavlm", action='store_true', help='')
# parser.add_argument("--tts", action='store_true', help='')
# parser.add_argument("--tts_config", type=str, default="conf/LibriTTS/preprocess.yaml", help='')
# parser.add_argument("--tts_target_wav_path", type=str, default='', help='')
# parser.add_argument("--speed", type=float, default='1.0', help='')
# parser.add_argument("--train_mapping", action='store_true', help='')
# parser.add_argument("--mapping_encoder", type=str, default='lstm', help='')
# parser.add_argument("--mapping_model_path", type=str, default='lstm', help='')
# parser.add_argument("--mask_mapping", action='store_true', help='')
# parser.add_argument("--mask_mapping_factor", type=float, default=1, help='')
# parser.add_argument("--l1_mapping_factor", type=float, default=1, help='')
# parser.add_argument("--mapping_ratio", type=float, default=1.0, help='')
# parser.add_argument("--condition2", action='store_true', help='')
# args = parser.parse_args()
# return update_args(**vars(args))
# if __name__ == "__main__":
# # torch.cuda.empty_cache()
# main()