NeoPy committed · verified
Commit df92f27 · 1 parent: 0ef0349

Delete train.py

Files changed (1)
  1. train.py +0 -92
train.py DELETED
@@ -1,92 +0,0 @@
-from model import CFM, UNetT, DiT, Trainer
-from model.utils import get_tokenizer
-from model.dataset import load_dataset
-
-
-# -------------------------- Dataset Settings --------------------------- #
-
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-
-tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
-tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
-dataset_name = "Emilia_ZH_EN"
-
-# -------------------------- Training Settings -------------------------- #
-
-exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
-
-learning_rate = 7.5e-5
-
-batch_size_per_gpu = 38400  # 8 GPUs, 8 * 38400 = 307200
-batch_size_type = "frame"  # "frame" or "sample"
-max_samples = 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
-grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
-max_grad_norm = 1.0
-
-epochs = 11  # use linear decay, thus epochs control the slope
-num_warmup_updates = 20000  # warmup steps
-save_per_updates = 50000  # save checkpoint per steps
-last_per_steps = 5000  # save last checkpoint per steps
-
-# model params
-if exp_name == "F5TTS_Base":
-    wandb_resume_id = None
-    model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-elif exp_name == "E2TTS_Base":
-    wandb_resume_id = None
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
-
-# ----------------------------------------------------------------------- #
-
-
-def main():
-    if tokenizer == "custom":
-        tokenizer_path = tokenizer_path
-    else:
-        tokenizer_path = dataset_name
-    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
-
-    mel_spec_kwargs = dict(
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
-        hop_length=hop_length,
-    )
-
-    model = CFM(
-        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-        mel_spec_kwargs=mel_spec_kwargs,
-        vocab_char_map=vocab_char_map,
-    )
-
-    trainer = Trainer(
-        model,
-        epochs,
-        learning_rate,
-        num_warmup_updates=num_warmup_updates,
-        save_per_updates=save_per_updates,
-        checkpoint_path=f"ckpts/{exp_name}",
-        batch_size=batch_size_per_gpu,
-        batch_size_type=batch_size_type,
-        max_samples=max_samples,
-        grad_accumulation_steps=grad_accumulation_steps,
-        max_grad_norm=max_grad_norm,
-        wandb_project="CFM-TTS",
-        wandb_run_name=exp_name,
-        wandb_resume_id=wandb_resume_id,
-        last_per_steps=last_per_steps,
-    )
-
-    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
-    trainer.train(
-        train_dataset,
-        resumable_with_seed=666,  # seed for shuffling dataset
-    )
-
-
-if __name__ == "__main__":
-    main()
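
One detail worth noting in the deleted script: inside main(), the custom-tokenizer branch reads and assigns the same name, tokenizer_path = tokenizer_path. Because that assignment makes tokenizer_path a local variable for the whole function, running with tokenizer = "custom" would raise UnboundLocalError before get_tokenizer is ever called. Below is a minimal sketch of how that branch could be written to sidestep the shadowing, using a hypothetical local name tokenizer_target; this is illustrative only, not the project's actual replacement code.

    # Inside a main() like the one deleted above; tokenizer, tokenizer_path,
    # and dataset_name are the module-level settings defined at the top.
    if tokenizer == "custom":
        # Use the configured vocab.txt path; fail early if it was left unset.
        if tokenizer_path is None:
            raise ValueError("set tokenizer_path to a vocab.txt when tokenizer == 'custom'")
        tokenizer_target = tokenizer_path
    else:
        # Default behaviour: look the tokenizer up by dataset name.
        tokenizer_target = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_target, tokenizer)

Assigning to a distinct local name keeps the module-level tokenizer_path readable inside the function, so both branches work without a global declaration.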