import yaml
import random
import argparse
import os
import time

from tqdm import tqdm
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from diffusers import DDIMScheduler
from transformers import T5Tokenizer, T5EncoderModel

from configs.plugin import get_params
from model.p2e_cross import P2E_Cross
from modules.speaker_encoder.encoder import inference as spk_encoder
from inference_freevc import eval_plugin
from dataset.dreamvc import DreamData
# from vc_wrapper import load_diffvc_models
from freevc_wrapper import get_freevc_models
from utils import minmax_norm_diff, reverse_minmax_norm_diff, scale_shift

parser = argparse.ArgumentParser()

# config settings
parser.add_argument('--config-name', type=str, default='Plugin_freevc')
parser.add_argument('--vc-unet-path', type=str, default='freevc')
parser.add_argument('--speaker-path', type=str, default='speaker_encoder/ckpt/pretrained_bak_5805000.pt')

# training settings
parser.add_argument('--amp', type=str, default='fp16')
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--num-threads', type=int, default=1)
parser.add_argument('--save-every', type=int, default=10)

# log and random seed
parser.add_argument('--random-seed', type=int, default=2023)
parser.add_argument('--log-step', type=int, default=200)
parser.add_argument('--log-dir', type=str, default='../logs/')
parser.add_argument('--save-dir', type=str, default='../ckpts/')

args = parser.parse_args()
params = get_params(args.config_name)
args.log_dir = args.log_dir + args.config_name + '/'

with open('model/p2e_cross.yaml', 'r') as fp:
    config = yaml.safe_load(fp)

os.makedirs(args.save_dir + args.config_name, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)


if __name__ == '__main__':
    # fix the random seed
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # set device
    torch.set_num_threads(args.num_threads)
    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = False
    else:
        args.device = 'cpu'

    train_set = DreamData(data_dir='../prepare_freevc/spk/',
                          meta_dir='../prepare/plugin_meta.csv',
                          subset='train',
                          prompt_dir='../prepare/prompts.csv')
    train_loader = DataLoader(train_set,
                              num_workers=args.num_workers,
                              batch_size=args.batch_size,
                              shuffle=True)

    # use accelerator for multi-gpu training
    accelerator = Accelerator(mixed_precision=args.amp)

    # vc_unet, hifigan, _, logmel, vc_scheduler = load_diffvc_models(args.vc_unet_path,
    #                                                                args.vocoder_path,
    #                                                                args.speaker_path,
    #                                                                args.vc_config_path,
    #                                                                accelerator.device)

    # frozen voice-conversion backbone (FreeVC) and its content model
    freevc_24, cmodel, _, hps = get_freevc_models(args.vc_unet_path, args.speaker_path, accelerator.device)

    # speaker encoder (unused here; kept for reference)
    # spk_encoder.load_model(Path(args.speaker_path), accelerator.device)

    # frozen text encoder for the style prompts
    tokenizer = T5Tokenizer.from_pretrained(params.text_encoder.model)
    text_encoder = T5EncoderModel.from_pretrained(params.text_encoder.model).to(accelerator.device)
    text_encoder.eval()

    # main U-Net (prompt-to-embedding plugin)
    model = P2E_Cross(config['diffwrap']).to(accelerator.device)
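    # Note: the plugin denoises a speaker embedding shaped (batch, 256, 1),
    # conditioned on T5 prompt features (cross-attention conditioning is an
    # assumption based on the P2E_Cross name); the matching (1, 256, 1)
    # latent shape is handed to eval_plugin further below.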
    # resume from a previous checkpoint (hard-coded path)
    model.load_state_dict(torch.load('../ckpts/Plugin_freevc/49.pt')['model'])
    total_params = sum(param.nelement() for param in model.parameters())
    print("Number of parameters: %.2fM" % (total_params / 1e6))

    if params.diff.v_prediction:
        print('v prediction')
        noise_scheduler = DDIMScheduler(num_train_timesteps=params.diff.num_train_steps,
                                        beta_start=params.diff.beta_start,
                                        beta_end=params.diff.beta_end,
                                        rescale_betas_zero_snr=True,
                                        timestep_spacing="trailing",
                                        clip_sample=False,
                                        prediction_type='v_prediction')
    else:
        print('noise prediction')
        # these values come from the config (params.diff), not from argparse
        noise_scheduler = DDIMScheduler(num_train_timesteps=params.diff.num_train_steps,
                                        beta_start=params.diff.beta_start,
                                        beta_end=params.diff.beta_end,
                                        clip_sample=False,
                                        prediction_type='epsilon')

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=params.opt.learning_rate,
                                  betas=(params.opt.beta1, params.opt.beta2),
                                  weight_decay=params.opt.weight_decay,
                                  eps=params.opt.adam_epsilon)
    loss_func = torch.nn.MSELoss()

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    global_step = 0
    losses = 0

    # sanity-check generation before training starts
    if accelerator.is_main_process:
        eval_plugin(freevc_24, cmodel, [tokenizer, text_encoder],
                    model, noise_scheduler, (1, 256, 1),
                    val_meta='../prepare/val_meta.csv',
                    val_folder='/home/jerry/Projects/Dataset/Speech/vctk_libritts/',
                    guidance_scale=3.0, guidance_rescale=0.0,
                    ddim_steps=100, eta=1, random_seed=None,
                    device=accelerator.device,
                    epoch='test', save_path=args.log_dir + 'output/', val_num=10)
    accelerator.wait_for_everyone()

    for epoch in range(args.epochs):
        model.train()
        for step, batch in enumerate(tqdm(train_loader)):
            spk_embed, prompt = batch
            spk_embed = spk_embed.unsqueeze(-1)

            with torch.no_grad():
                # encode the style prompt with the frozen T5
                text_batch = tokenizer(prompt,
                                       max_length=32,
                                       padding='max_length',
                                       truncation=True,
                                       return_tensors="pt")
                text, text_mask = text_batch.input_ids.to(spk_embed.device), \
                    text_batch.attention_mask.to(spk_embed.device)
                text = text_encoder(input_ids=text, attention_mask=text_mask)[0]

                # normalize the speaker embedding into the diffusion value range
                spk_embed = scale_shift(spk_embed, 20, -0.035)
                # spk_embed = minmax_norm_diff(spk_embed, vmax=0.5, vmin=0.0)
                # content_clip = align_seq(content_clip, audio_clip.shape[-1])
                # f0_clip = align_seq(f0_clip, audio_clip.shape[-1])

            # adding noise
            noise = torch.randn(spk_embed.shape).to(accelerator.device)
            timesteps = torch.randint(0, params.diff.num_train_steps,
                                      (noise.shape[0],),
                                      device=accelerator.device).long()
            noisy_target = noise_scheduler.add_noise(spk_embed, noise, timesteps)

            # v-prediction target: v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
            velocity = noise_scheduler.get_velocity(spk_embed, noise, timesteps)

            # forward pass with classifier-free guidance dropout
            pred = model(noisy_target, timesteps, text, text_mask,
                         train_cfg=True, cfg_prob=0.25)

            # backward
            if params.diff.v_prediction:
                loss = loss_func(pred, velocity)
            else:
                loss = loss_func(pred, noise)

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            losses += loss.item()

            if accelerator.is_main_process:
                if global_step % args.log_step == 0:
                    with open(args.log_dir + 'diff_vc.txt', mode='a') as n:
                        n.write(time.asctime(time.localtime(time.time())))
                        n.write('\n')
                        n.write('Epoch: [{}][{}] Batch: [{}][{}] Loss: {:.6f}\n'.format(
                            epoch + 1, args.epochs, step + 1, len(train_loader),
                            losses / args.log_step))
                    losses = 0.0

        accelerator.wait_for_everyone()

        if (epoch + 1) % args.save_every == 0:
            if accelerator.is_main_process:
                eval_plugin(freevc_24, cmodel, [tokenizer, text_encoder],
                            model, noise_scheduler, (1, 256, 1),
                            val_meta='../prepare/val_meta.csv',
                            val_folder='/home/jerry/Projects/Dataset/Speech/vctk_libritts/',
                            guidance_scale=3, guidance_rescale=0.0,
                            ddim_steps=50, eta=1,
                            random_seed=2024,
                            device=accelerator.device,
                            epoch=epoch, save_path=args.log_dir + 'output/', val_num=10)

                unwrapped_unet = accelerator.unwrap_model(model)
                accelerator.save({
                    "model": unwrapped_unet.state_dict(),
                }, args.save_dir + args.config_name + '/' + str(epoch) + '.pt')
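# Launch sketch (assumed commands; "train_plugin.py" is an illustrative file
# name, not one given by this script). Multi-GPU training goes through
# Accelerate, which also supplies the mixed-precision setting used above:
#   accelerate launch train_plugin.py --config-name Plugin_freevc
# A single-process debug run with plain Python works as well:
#   python train_plugin.py --epochs 1 --batch-size 4 --num-workers 0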