import os
import time
from glob import glob
from logging import getLogger
from pathlib import Path
from threading import Thread
from typing import Literal, Optional, Tuple

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import Dataset
from huggingface_hub import HfApi, upload_folder
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from .constants import *
from .pretrained import pretrained_checkpoints
from .synthesizer import commons
from .synthesizer.models import (
    MultiPeriodDiscriminator,
    SynthesizerTrnMs768NSFsid,
)
from .utils.data_utils import TextAudioCollateMultiNSFsid
from .utils.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from .utils.mel_processing import mel_spectrogram_torch, spec_to_mel_torch

logger = getLogger(__name__)
class TrainingCheckpoint:
    """Snapshot of generator/discriminator training state after one epoch.

    Bundles both models, their optimizers and LR schedulers, plus the last
    observed loss values, so the whole training state can be persisted
    together via :meth:`save`.
    """

    def __init__(
        self,
        epoch: int,
        G: SynthesizerTrnMs768NSFsid,
        D: MultiPeriodDiscriminator,
        optimizer_G: torch.optim.AdamW,
        optimizer_D: torch.optim.AdamW,
        scheduler_G: torch.optim.lr_scheduler.ExponentialLR,
        scheduler_D: torch.optim.lr_scheduler.ExponentialLR,
        loss_gen: float,
        loss_fm: float,
        loss_mel: float,
        loss_kl: float,
        loss_gen_all: float,
        loss_disc: float,
    ):
        self.epoch = epoch
        self.G = G
        self.D = D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D
        self.scheduler_G = scheduler_G
        self.scheduler_D = scheduler_D
        self.loss_gen = loss_gen
        self.loss_fm = loss_fm
        self.loss_mel = loss_mel
        self.loss_kl = loss_kl
        self.loss_gen_all = loss_gen_all
        self.loss_disc = loss_disc

    def save(
        self,
        exp_dir: str = "./",
        g_checkpoint: str | None = None,
        d_checkpoint: str | None = None,
    ) -> None:
        """Write generator and discriminator checkpoints under *exp_dir*.

        File names default to ``G_latest.pth`` / ``D_latest.pth``. The
        generator file also records the loss values; the discriminator file
        carries only model/optimizer/scheduler state.
        """
        # Plain literals: the originals were f-strings with no placeholders.
        g_path = g_checkpoint if g_checkpoint is not None else "G_latest.pth"
        d_path = d_checkpoint if d_checkpoint is not None else "D_latest.pth"
        torch.save(
            {
                "epoch": self.epoch,
                "model": self.G.state_dict(),
                "optimizer": self.optimizer_G.state_dict(),
                "scheduler": self.scheduler_G.state_dict(),
                "loss_gen": self.loss_gen,
                "loss_fm": self.loss_fm,
                "loss_mel": self.loss_mel,
                "loss_kl": self.loss_kl,
                "loss_gen_all": self.loss_gen_all,
                "loss_disc": self.loss_disc,
            },
            os.path.join(exp_dir, g_path),
        )
        torch.save(
            {
                "epoch": self.epoch,
                "model": self.D.state_dict(),
                "optimizer": self.optimizer_D.state_dict(),
                "scheduler": self.scheduler_D.state_dict(),
            },
            os.path.join(exp_dir, d_path),
        )
def latest_checkpoint_file(files: list[str]) -> str:
    """Return the most recent checkpoint path from *files*.

    Prefers the file with the highest numeric epoch suffix in its name
    (e.g. ``G_123.pth``). If any name lacks such a suffix (e.g.
    ``G_latest.pth``), falls back to the newest file by creation time.
    """
    try:
        # Names are expected to look like "<prefix>_<epoch>.pth".
        return max(files, key=lambda f: int(Path(f).stem.split("_")[1]))
    except (ValueError, IndexError):
        # Was a bare "except:" — narrowed so real errors (and
        # KeyboardInterrupt/SystemExit) are no longer swallowed.
        return max(files, key=os.path.getctime)
class RVCTrainer: | |
def __init__( | |
self, | |
exp_dir: str, | |
dataset_train: Dataset, | |
dataset_test: Optional[Dataset] = None, | |
sr: int = SR_48K, | |
): | |
self.exp_dir = exp_dir | |
self.dataset_train = dataset_train | |
self.dataset_test = dataset_test | |
self.sr = sr | |
self.writer = SummaryWriter( | |
os.path.join(exp_dir, "logs", time.strftime("%Y%m%d-%H%M%S")) | |
) | |
def latest_checkpoint(self, fallback_to_pretrained: bool = True): | |
files_g = glob(os.path.join(self.exp_dir, "G_*.pth")) | |
if not files_g: | |
return pretrained_checkpoints() if fallback_to_pretrained else None | |
latest_g = latest_checkpoint_file(files_g) | |
files_d = glob(os.path.join(self.exp_dir, "D_*.pth")) | |
if not files_d: | |
return pretrained_checkpoints() if fallback_to_pretrained else None | |
latest_d = latest_checkpoint_file(files_d) | |
return latest_g, latest_d | |
def setup_models( | |
self, | |
resume_from: Tuple[str, str] | None = None, | |
accelerator: Accelerator | None = None, | |
lr=1e-4, | |
lr_decay=0.999875, | |
betas: Tuple[float, float] = (0.8, 0.99), | |
eps=1e-9, | |
use_spectral_norm=False, | |
segment_size=17280, | |
filter_length=N_FFT, | |
hop_length=HOP_LENGTH, | |
inter_channels=192, | |
hidden_channels=192, | |
filter_channels=768, | |
n_heads=2, | |
n_layers=6, | |
kernel_size=3, | |
p_dropout=0.0, | |
resblock: Literal["1", "2"] = "1", | |
resblock_kernel_sizes: list[int] = [3, 7, 11], | |
resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], | |
upsample_initial_channel=512, | |
upsample_rates: list[int] = [12, 10, 2, 2], | |
upsample_kernel_sizes: list[int] = [24, 20, 4, 4], | |
spk_embed_dim=109, | |
gin_channels=256, | |
) -> Tuple[ | |
SynthesizerTrnMs768NSFsid, | |
MultiPeriodDiscriminator, | |
torch.optim.AdamW, | |
torch.optim.AdamW, | |
torch.optim.lr_scheduler.ExponentialLR, | |
torch.optim.lr_scheduler.ExponentialLR, | |
int, | |
]: | |
if accelerator is None: | |
accelerator = Accelerator() | |
G = SynthesizerTrnMs768NSFsid( | |
spec_channels=filter_length // 2 + 1, | |
segment_size=segment_size // hop_length, | |
inter_channels=inter_channels, | |
hidden_channels=hidden_channels, | |
filter_channels=filter_channels, | |
n_heads=n_heads, | |
n_layers=n_layers, | |
kernel_size=kernel_size, | |
p_dropout=p_dropout, | |
resblock=resblock, | |
resblock_kernel_sizes=resblock_kernel_sizes, | |
resblock_dilation_sizes=resblock_dilation_sizes, | |
upsample_initial_channel=upsample_initial_channel, | |
upsample_rates=upsample_rates, | |
upsample_kernel_sizes=upsample_kernel_sizes, | |
spk_embed_dim=spk_embed_dim, | |
gin_channels=gin_channels, | |
sr=self.sr, | |
).to(accelerator.device) | |
D = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm).to( | |
accelerator.device | |
) | |
optimizer_G = torch.optim.AdamW( | |
G.parameters(), | |
lr, | |
betas=betas, | |
eps=eps, | |
) | |
optimizer_D = torch.optim.AdamW( | |
D.parameters(), | |
lr, | |
betas=betas, | |
eps=eps, | |
) | |
if resume_from is not None: | |
g_checkpoint, d_checkpoint = resume_from | |
logger.info(f"Resuming from {g_checkpoint} and {d_checkpoint}") | |
G_checkpoint = torch.load( | |
g_checkpoint, map_location=accelerator.device, weights_only=True | |
) | |
D_checkpoint = torch.load( | |
d_checkpoint, map_location=accelerator.device, weights_only=True | |
) | |
if "epoch" in G_checkpoint: | |
finished_epoch = int(G_checkpoint["epoch"]) | |
try: | |
finished_epoch = int(Path(g_checkpoint).stem.split("_")[1]) | |
except: | |
finished_epoch = 0 | |
scheduler_G = torch.optim.lr_scheduler.ExponentialLR( | |
optimizer_G, gamma=lr_decay, last_epoch=finished_epoch - 1 | |
) | |
scheduler_D = torch.optim.lr_scheduler.ExponentialLR( | |
optimizer_D, gamma=lr_decay, last_epoch=finished_epoch - 1 | |
) | |
G.load_state_dict(G_checkpoint["model"]) | |
if "optimizer" in G_checkpoint: | |
optimizer_G.load_state_dict(G_checkpoint["optimizer"]) | |
if "scheduler" in G_checkpoint: | |
scheduler_G.load_state_dict(G_checkpoint["scheduler"]) | |
D.load_state_dict(D_checkpoint["model"]) | |
if "optimizer" in D_checkpoint: | |
optimizer_D.load_state_dict(D_checkpoint["optimizer"]) | |
if "scheduler" in D_checkpoint: | |
scheduler_D.load_state_dict(D_checkpoint["scheduler"]) | |
else: | |
finished_epoch = 0 | |
scheduler_G = torch.optim.lr_scheduler.ExponentialLR( | |
optimizer_G, gamma=lr_decay, last_epoch=-1 | |
) | |
scheduler_D = torch.optim.lr_scheduler.ExponentialLR( | |
optimizer_D, gamma=lr_decay, last_epoch=-1 | |
) | |
G, D, optimizer_G, optimizer_D = accelerator.prepare( | |
G, D, optimizer_G, optimizer_D | |
) | |
return G, D, optimizer_G, optimizer_D, scheduler_G, scheduler_D, finished_epoch | |
def setup_dataloader( | |
self, | |
dataset: Dataset, | |
batch_size=1, | |
shuffle=True, | |
accelerator: Accelerator | None = None, | |
): | |
if accelerator is None: | |
accelerator = Accelerator() | |
dataset = dataset.with_format("torch", device=accelerator.device) | |
loader = DataLoader( | |
dataset, | |
batch_size=batch_size, | |
shuffle=shuffle, | |
collate_fn=TextAudioCollateMultiNSFsid(), | |
) | |
loader = accelerator.prepare(loader) | |
return loader | |
    def run(
        self,
        G,
        D,
        optimizer_G,
        optimizer_D,
        scheduler_G,
        scheduler_D,
        finished_epoch,
        loader_train,
        loader_test,
        accelerator: Accelerator | None = None,
        epochs=100,
        segment_size=17280,
        filter_length=N_FFT,
        hop_length=HOP_LENGTH,
        n_mel_channels=N_MELS,
        win_length=WIN_LENGTH,
        mel_fmin=0.0,
        mel_fmax: float | None = None,
        c_mel=45,
        c_kl=1.0,
        upload_to_hub: str | None = None,
        upload_window_minutes=5,
    ):
        """Main GAN training loop.

        Per batch: forward the generator, update the discriminator on the
        detached output, then re-run the discriminator and update the
        generator on the combined adversarial/feature-matching/mel/KL loss.
        Per epoch (main process only): log averages to TensorBoard, render
        test audio, save checkpoints, and optionally push to the HF Hub at
        most once per *upload_window_minutes*.

        Epochs up to and including *finished_epoch* are skipped so resumed
        runs continue where their checkpoint left off.
        """
        if accelerator is None:
            accelerator = Accelerator()
        if accelerator.is_main_process:
            logger.info("Start training")
        # Epoch-seconds timestamp of the last successful Hub upload;
        # 0.0 lets the first attempt pass the rate-limit window.
        upload_state_last = 0.0
        # Most recent per-batch losses; -1.0 marks "no batch processed yet".
        prev_loss_gen = -1.0
        prev_loss_fm = -1.0
        prev_loss_mel = -1.0
        prev_loss_kl = -1.0
        prev_loss_disc = -1.0
        prev_loss_gen_all = -1.0
        with accelerator.autocast():
            epoch_iterator = tqdm(
                range(1, epochs + 1),
                desc="Training",
                disable=not accelerator.is_main_process,
            )
            for epoch in epoch_iterator:
                # Skip epochs already covered by the resumed checkpoint.
                if epoch <= finished_epoch:
                    continue
                G.train()
                D.train()
                # Running sums for the per-epoch averages logged below.
                epoch_loss_gen = 0.0
                epoch_loss_fm = 0.0
                epoch_loss_mel = 0.0
                epoch_loss_kl = 0.0
                epoch_loss_disc = 0.0
                epoch_loss_gen_all = 0.0
                num_batches = 0
                batch_iterator = tqdm(
                    loader_train,
                    desc=f"Epoch {epoch}",
                    leave=False,
                    disable=not accelerator.is_main_process,
                )
                for batch in batch_iterator:
                    # Batch layout produced by TextAudioCollateMultiNSFsid.
                    (
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = batch
                    # Generator
                    optimizer_G.zero_grad()
                    (
                        y_hat,
                        ids_slice,
                        x_mask,
                        z_mask,
                        (z, z_p, m_p, logs_p, m_q, logs_q),
                    ) = G(
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        sid,
                    )
                    mel = spec_to_mel_torch(
                        spec,
                        filter_length,
                        n_mel_channels,
                        self.sr,
                        mel_fmin,
                        mel_fmax,
                    )
                    # Slice the reference mel to the same random segments the
                    # generator produced (ids_slice is in frame units).
                    y_mel = commons.slice_segments(
                        mel, ids_slice, segment_size // hop_length
                    )
                    y_hat_mel = mel_spectrogram_torch(
                        y_hat.squeeze(1),
                        filter_length,
                        n_mel_channels,
                        self.sr,
                        hop_length,
                        win_length,
                        mel_fmin,
                        mel_fmax,
                    )
                    # Waveform slice uses sample units: frames * hop_length.
                    wave = commons.slice_segments(
                        wave, ids_slice * hop_length, segment_size
                    )
                    # Discriminator
                    optimizer_D.zero_grad()
                    # detach(): D's update must not backprop into G.
                    y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = D(wave, y_hat.detach())
                    # Update Discriminator
                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                        y_d_hat_r, y_d_hat_g
                    )
                    accelerator.backward(loss_disc)
                    optimizer_D.step()
                    # Re-compute discriminator output (since we just got a "better" discriminator)
                    y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = D(wave, y_hat)
                    # Update Generator
                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * c_mel
                    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * c_kl
                    loss_fm = feature_loss(fmap_r, fmap_g)
                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
                    accelerator.backward(loss_gen_all)
                    optimizer_G.step()
                    prev_loss_gen = loss_gen.item()
                    prev_loss_fm = loss_fm.item()
                    prev_loss_mel = loss_mel.item()
                    prev_loss_kl = loss_kl.item()
                    prev_loss_disc = loss_disc.item()
                    prev_loss_gen_all = loss_gen_all.item()
                    # Update progress bar with current losses
                    if accelerator.is_main_process:
                        batch_iterator.set_postfix(
                            {
                                "g_loss": f"{prev_loss_gen:.4f}",
                                "d_loss": f"{prev_loss_disc:.4f}",
                                "mel_loss": f"{prev_loss_mel:.4f}",
                                "total": f"{prev_loss_gen_all:.4f}",
                            }
                        )
                    epoch_loss_gen += prev_loss_gen
                    epoch_loss_fm += prev_loss_fm
                    epoch_loss_mel += prev_loss_mel
                    epoch_loss_kl += prev_loss_kl
                    epoch_loss_disc += prev_loss_disc
                    epoch_loss_gen_all += prev_loss_gen_all
                    num_batches += 1
                # LR decay steps once per epoch, not per batch.
                scheduler_G.step()
                scheduler_D.step()
                # NOTE(review): everything below (logging, test-audio render,
                # checkpoint save, Hub upload) is nested under the
                # main-process guard — confirm against the original intent;
                # the source indentation was lost.
                if accelerator.is_main_process and num_batches > 0:
                    avg_gen = epoch_loss_gen / num_batches
                    avg_disc = epoch_loss_disc / num_batches
                    avg_fm = epoch_loss_fm / num_batches
                    avg_mel = epoch_loss_mel / num_batches
                    avg_kl = epoch_loss_kl / num_batches
                    avg_total = epoch_loss_gen_all / num_batches
                    logger.info(
                        f"Epoch {epoch} | "
                        f"Generator Loss: {avg_gen:.4f} | "
                        f"Discriminator Loss: {avg_disc:.4f} | "
                        f"Mel Loss: {avg_mel:.4f} | "
                        f"Total Loss: {avg_total:.4f}"
                    )
                    # Update epoch progress bar
                    epoch_iterator.set_postfix(
                        {
                            "g_loss": f"{avg_gen:.4f}",
                            "d_loss": f"{avg_disc:.4f}",
                            "total": f"{avg_total:.4f}",
                        }
                    )
                    self.writer.add_scalar("Loss/Generator", avg_gen, epoch)
                    self.writer.add_scalar("Loss/Feature_Matching", avg_fm, epoch)
                    self.writer.add_scalar("Loss/Mel", avg_mel, epoch)
                    self.writer.add_scalar("Loss/KL", avg_kl, epoch)
                    self.writer.add_scalar("Loss/Discriminator", avg_disc, epoch)
                    self.writer.add_scalar("Loss/Generator_Total", avg_total, epoch)
                    self.writer.add_scalar(
                        "Learning_Rate/Generator",
                        scheduler_G.get_last_lr()[0],
                        epoch,
                    )
                    self.writer.add_scalar(
                        "Learning_Rate/Discriminator",
                        scheduler_D.get_last_lr()[0],
                        epoch,
                    )
                    if loader_test is not None:
                        with torch.no_grad():
                            sample_idx = 0
                            test_iterator = tqdm(
                                loader_test,
                                desc=f"Testing epoch {epoch}",
                                leave=False,
                                disable=not accelerator.is_main_process,
                            )
                            for batch_idx, (
                                phone,
                                phone_lengths,
                                pitch,
                                pitchf,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ) in enumerate(test_iterator):
                                # Generate audio for each sample in the batch
                                audio_segments = G.infer(
                                    phone, phone_lengths, pitch, pitchf, sid
                                )[0]
                                # Log each audio sample in the batch
                                for i, audio in enumerate(audio_segments):
                                    audio_numpy = audio[0].data.cpu().float().numpy()
                                    self.writer.add_audio(
                                        f"Audio/{sample_idx}",
                                        audio_numpy,
                                        epoch,
                                        sample_rate=self.sr,
                                    )
                                    sample_idx += 1
                    # Persist full training state every epoch (G_latest.pth /
                    # D_latest.pth are overwritten each time).
                    res = TrainingCheckpoint(
                        epoch,
                        G,
                        D,
                        optimizer_G,
                        optimizer_D,
                        scheduler_G,
                        scheduler_D,
                        prev_loss_gen,
                        prev_loss_fm,
                        prev_loss_mel,
                        prev_loss_kl,
                        prev_loss_gen_all,
                        prev_loss_disc,
                    )
                    res.save(self.exp_dir)
                    G.save_pretrained(self.exp_dir)
                    if upload_to_hub is not None:
                        # Rate-limit uploads to one per window, but always
                        # upload after the final epoch.
                        if (
                            time.time() - upload_state_last > 60 * upload_window_minutes
                            or epoch == epochs
                        ):
                            try:
                                self.push_to_hub(upload_to_hub)
                                upload_state_last = time.time()
                            except Exception:
                                logger.error(f"Failed to upload to Hub.", exc_info=1)
                        else:
                            next_upload = 60 * upload_window_minutes - (
                                time.time() - upload_state_last
                            )
                            logger.info(
                                f"Skipping upload to Hub (next upload in {next_upload:.0f} seconds)"
                            )
def train( | |
self, | |
resume_from: Tuple[str, str] | None = None, | |
accelerator: Accelerator | None = None, | |
batch_size=1, | |
epochs=100, | |
lr=1e-4, | |
lr_decay=0.999875, | |
betas: Tuple[float, float] = (0.8, 0.99), | |
eps=1e-9, | |
use_spectral_norm=False, | |
segment_size=17280, | |
filter_length=N_FFT, | |
hop_length=HOP_LENGTH, | |
inter_channels=192, | |
hidden_channels=192, | |
filter_channels=768, | |
n_heads=2, | |
n_layers=6, | |
kernel_size=3, | |
p_dropout=0.0, | |
resblock: Literal["1", "2"] = "1", | |
resblock_kernel_sizes: list[int] = [3, 7, 11], | |
resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], | |
upsample_initial_channel=512, | |
upsample_rates: list[int] = [12, 10, 2, 2], | |
upsample_kernel_sizes: list[int] = [24, 20, 4, 4], | |
spk_embed_dim=109, | |
gin_channels=256, | |
n_mel_channels=N_MELS, | |
win_length=WIN_LENGTH, | |
mel_fmin=0.0, | |
mel_fmax: float | None = None, | |
c_mel=45, | |
c_kl=1.0, | |
upload_to_hub: str | None = None, | |
): | |
if not os.path.exists(self.exp_dir): | |
os.makedirs(self.exp_dir) | |
if accelerator is None: | |
accelerator = Accelerator() | |
( | |
G, | |
D, | |
optimizer_G, | |
optimizer_D, | |
scheduler_G, | |
scheduler_D, | |
finished_epoch, | |
) = self.setup_models( | |
resume_from=resume_from or self.latest_checkpoint(), | |
accelerator=accelerator, | |
lr=lr, | |
lr_decay=lr_decay, | |
betas=betas, | |
eps=eps, | |
use_spectral_norm=use_spectral_norm, | |
segment_size=segment_size, | |
filter_length=filter_length, | |
hop_length=hop_length, | |
inter_channels=inter_channels, | |
hidden_channels=hidden_channels, | |
filter_channels=filter_channels, | |
n_heads=n_heads, | |
n_layers=n_layers, | |
kernel_size=kernel_size, | |
p_dropout=p_dropout, | |
resblock=resblock, | |
resblock_kernel_sizes=resblock_kernel_sizes, | |
resblock_dilation_sizes=resblock_dilation_sizes, | |
upsample_initial_channel=upsample_initial_channel, | |
upsample_rates=upsample_rates, | |
upsample_kernel_sizes=upsample_kernel_sizes, | |
spk_embed_dim=spk_embed_dim, | |
gin_channels=gin_channels, | |
) | |
loader_train = self.setup_dataloader( | |
self.dataset_train, | |
batch_size=batch_size, | |
accelerator=accelerator, | |
) | |
loader_test = ( | |
self.setup_dataloader( | |
self.dataset_test, | |
batch_size=batch_size, | |
accelerator=accelerator, | |
shuffle=False, | |
) | |
if self.dataset_test is not None | |
else None | |
) | |
return self.run( | |
G, | |
D, | |
optimizer_G, | |
optimizer_D, | |
scheduler_G, | |
scheduler_D, | |
finished_epoch, | |
loader_train, | |
loader_test, | |
accelerator, | |
epochs=epochs, | |
segment_size=segment_size, | |
filter_length=filter_length, | |
hop_length=hop_length, | |
n_mel_channels=n_mel_channels, | |
win_length=win_length, | |
mel_fmin=mel_fmin, | |
mel_fmax=mel_fmax, | |
c_mel=c_mel, | |
c_kl=c_kl, | |
upload_to_hub=upload_to_hub, | |
) | |
def push_to_hub(self, repo: str, private: bool = True): | |
if not os.path.exists(self.exp_dir): | |
raise FileNotFoundError("exp_dir not found") | |
api = HfApi() | |
repo_id = api.create_repo(repo_id=repo, private=private, exist_ok=True).repo_id | |
return upload_folder( | |
repo_id=repo_id, | |
folder_path=self.exp_dir, | |
commit_message="Upload via ZeroRVC", | |
) | |
def __del__(self): | |
if hasattr(self, "writer"): | |
self.writer.close() | |