# training script.

import os
from importlib.resources import files

import hydra
from omegaconf import OmegaConf

from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer


os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
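
# Hydra entry point. config_name=None means there is no default config: pick one
# at launch time with --config-name, pointing at a YAML under f5_tts/configs
# (the exact config file names depend on the checkout, so treat any example
# name in the comments below as an assumption).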
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(model_cfg):
    model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch
    tokenizer = model_cfg.model.tokenizer
    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
    exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
    wandb_resume_id = None

    # set text tokenizer
    if tokenizer != "custom":
        tokenizer_path = model_cfg.datasets.name
    else:
        tokenizer_path = model_cfg.model.tokenizer_path
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
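    # NOTE: get_tokenizer is expected to return a character/token -> id map plus
    # its size (behavior assumed from its usage here); vocab_size then sets
    # text_num_embeds below so the text embedding table matches the vocabulary.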

    # set model
    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
        mel_spec_kwargs=model_cfg.model.mel_spec,
        vocab_char_map=vocab_char_map,
    )
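    # CFM = conditional flow matching: it wraps the configured backbone
    # (model_cfg.model.backbone, resolved from the f5_tts.model namespace above)
    # with the flow-matching training objective; the mel-spectrogram settings are
    # passed through so model and features stay consistent.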

    # init trainer
    trainer = Trainer(
        model,
        epochs=model_cfg.optim.epochs,
        learning_rate=model_cfg.optim.learning_rate,
        num_warmup_updates=model_cfg.optim.num_warmup_updates,
        save_per_updates=model_cfg.ckpts.save_per_updates,
        keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
        batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
        batch_size_type=model_cfg.datasets.batch_size_type,
        max_samples=model_cfg.datasets.max_samples,
        grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
        max_grad_norm=model_cfg.optim.max_grad_norm,
        logger=model_cfg.ckpts.logger,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_updates=model_cfg.ckpts.last_per_updates,
        log_samples=model_cfg.ckpts.log_samples,
        bnb_optimizer=model_cfg.optim.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=model_cfg.model.vocoder.is_local,
        local_vocoder_path=model_cfg.model.vocoder.local_path,
        model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
    )
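    # Rule of thumb (assuming standard data-parallel training): the effective
    # batch size per optimizer update is roughly
    # batch_size_per_gpu * num_gpus * grad_accumulation_steps,
    # so tune those three settings together.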

    train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
    trainer.train(
        train_dataset,
        num_workers=model_cfg.datasets.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )


if __name__ == "__main__":
    main()
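
# Example launch (assumed invocation; the script path and config name depend on
# your checkout, so adapt them as needed):
#   accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train
# or, for a quick single-process run:
#   python src/f5_tts/train/train.py --config-name F5TTS_Base_train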