mrfakename committed
Commit 1a19e0f · verified · 1 parent: 40eed23

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.

Files changed (47)
  1. .gitattributes +5 -0
  2. .github/workflows/publish-pypi.yaml +66 -0
  3. README_REPO.md +99 -24
  4. app.py +52 -13
  5. ckpts/README.md +5 -3
  6. pyproject.toml +1 -1
  7. src/f5_tts/api.py +59 -60
  8. src/f5_tts/configs/E2TTS_Base.yaml +49 -0
  9. src/f5_tts/configs/E2TTS_Small.yaml +49 -0
  10. src/f5_tts/configs/F5TTS_Base.yaml +52 -0
  11. src/f5_tts/configs/F5TTS_Small.yaml +52 -0
  12. src/f5_tts/configs/F5TTS_v1_Base.yaml +53 -0
  13. src/f5_tts/eval/eval_infer_batch.py +22 -27
  14. src/f5_tts/eval/eval_infer_batch.sh +11 -6
  15. src/f5_tts/eval/eval_librispeech_test_clean.py +21 -27
  16. src/f5_tts/eval/eval_seedtts_testset.py +21 -27
  17. src/f5_tts/eval/eval_utmos.py +15 -17
  18. src/f5_tts/eval/utils_eval.py +11 -6
  19. src/f5_tts/infer/README.md +38 -80
  20. src/f5_tts/infer/SHARED.md +19 -9
  21. src/f5_tts/infer/examples/basic/basic.toml +2 -2
  22. src/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  23. src/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  24. src/f5_tts/infer/examples/multi/country.flac +0 -0
  25. src/f5_tts/infer/examples/multi/main.flac +0 -0
  26. src/f5_tts/infer/examples/multi/story.toml +2 -2
  27. src/f5_tts/infer/examples/multi/town.flac +0 -0
  28. src/f5_tts/infer/infer_cli.py +26 -31
  29. src/f5_tts/infer/speech_edit.py +35 -28
  30. src/f5_tts/infer/utils_infer.py +114 -72
  31. src/f5_tts/model/backbones/README.md +2 -2
  32. src/f5_tts/model/backbones/dit.py +63 -8
  33. src/f5_tts/model/backbones/mmdit.py +52 -9
  34. src/f5_tts/model/backbones/unett.py +36 -5
  35. src/f5_tts/model/cfm.py +9 -11
  36. src/f5_tts/model/dataset.py +21 -10
  37. src/f5_tts/model/modules.py +115 -42
  38. src/f5_tts/model/trainer.py +143 -72
  39. src/f5_tts/model/utils.py +4 -3
  40. src/f5_tts/scripts/count_max_epoch.py +3 -3
  41. src/f5_tts/socket_client.py +61 -0
  42. src/f5_tts/socket_server.py +176 -99
  43. src/f5_tts/train/README.md +5 -5
  44. src/f5_tts/train/datasets/prepare_csv_wavs.py +188 -43
  45. src/f5_tts/train/finetune_cli.py +63 -21
  46. src/f5_tts/train/finetune_gradio.py +272 -250
  47. src/f5_tts/train/train.py +12 -11
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/basic/basic_ref_en.wav filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/basic/basic_ref_zh.wav filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/multi/country.flac filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/multi/main.flac filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/multi/town.flac filter=lfs diff=lfs merge=lfs -text
.github/workflows/publish-pypi.yaml ADDED
@@ -0,0 +1,66 @@
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+# GitHub recommends pinning actions to a commit SHA.
+# To get a newer version, you will need to update the SHA.
+# You can also reference a tag or branch, but the action may change without warning.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  release-build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.x"
+
+      - name: Build release distributions
+        run: |
+          # NOTE: put your own distribution build steps here.
+          python -m pip install build
+          python -m build
+
+      - name: Upload distributions
+        uses: actions/upload-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+  pypi-publish:
+    runs-on: ubuntu-latest
+
+    needs:
+      - release-build
+
+    permissions:
+      # IMPORTANT: this permission is mandatory for trusted publishing
+      id-token: write
+
+    # Dedicated environments with protections for publishing are strongly recommended.
+    environment:
+      name: pypi
+      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
+      # url: https://pypi.org/p/YOURPROJECT
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
README_REPO.md CHANGED
@@ -6,7 +6,8 @@
 [![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
 [![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
 [![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
-<img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto">
+[![lab](https://img.shields.io/badge/Peng%20Cheng-Lab-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
+<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
 
 **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
 
@@ -17,40 +18,84 @@
 ### Thanks to all the contributors !
 
 ## News
+- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [Few demo](https://swivid.github.io/F5-TTS_updates).
 - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
 
 ## Installation
 
+### Create a separate environment if needed
+
 ```bash
 # Create a python 3.10 conda env (you could also use virtualenv)
 conda create -n f5-tts python=3.10
 conda activate f5-tts
+```
 
-# NVIDIA GPU: install pytorch with your CUDA version, e.g.
-pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
+### Install PyTorch with matched device
 
-# AMD GPU: install pytorch with your ROCm version, e.g.
-pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
-```
+<details>
+<summary>NVIDIA GPU</summary>
 
-Then you can choose from a few options below:
+> ```bash
+> # Install pytorch with your CUDA version, e.g.
+> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
+> ```
 
-### 1. As a pip package (if just for inference)
+</details>
 
-```bash
-pip install git+https://github.com/SWivid/F5-TTS.git
-```
+<details>
+<summary>AMD GPU</summary>
 
-### 2. Local editable (if also do training, finetuning)
+> ```bash
+> # Install pytorch with your ROCm version (Linux only), e.g.
+> pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
+> ```
 
-```bash
-git clone https://github.com/SWivid/F5-TTS.git
-cd F5-TTS
-# git submodule update --init --recursive # (optional, if need bigvgan)
-pip install -e .
-```
+</details>
+
+<details>
+<summary>Intel GPU</summary>
+
+> ```bash
+> # Install pytorch with your XPU version, e.g.
+> # Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
+> pip install torch torchaudio --index-url https://download.pytorch.org/whl/test/xpu
+>
+> # Intel GPU support is also available through IPEX (Intel® Extension for PyTorch)
+> # IPEX does not require the Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit
+> # See: https://pytorch-extension.intel.com/installation?request=platform
+> ```
+
+</details>
+
+<details>
+<summary>Apple Silicon</summary>
+
+> ```bash
+> # Install the stable pytorch, e.g.
+> pip install torch torchaudio
+> ```
 
-### 3. Docker usage
+</details>
+
+### Then you can choose one from below:
+
+> ### 1. As a pip package (if just for inference)
+>
+> ```bash
+> pip install f5-tts
+> ```
+>
+> ### 2. Local editable (if also do training, finetuning)
+>
+> ```bash
+> git clone https://github.com/SWivid/F5-TTS.git
+> cd F5-TTS
+> # git submodule update --init --recursive # (optional, if need > bigvgan)
+> pip install -e .
+> ```
+
+### Docker usage also available
 ```bash
 # Build from Dockerfile
 docker build -t f5tts:v1 .
@@ -82,14 +127,40 @@ f5-tts_infer-gradio --port 7860 --host 0.0.0.0
 f5-tts_infer-gradio --share
 ```
 
+<details>
+<summary>NVIDIA device docker compose file example</summary>
+
+```yaml
+services:
+  f5-tts:
+    image: ghcr.io/swivid/f5-tts:main
+    ports:
+      - "7860:7860"
+    environment:
+      GRADIO_SERVER_PORT: 7860
+    entrypoint: ["f5-tts_infer-gradio", "--port", "7860", "--host", "0.0.0.0"]
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+
+volumes:
+  f5-tts:
+    driver: local
+```
+
+</details>
+
 ### 2. CLI Inference
 
 ```bash
 # Run with flags
 # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
-f5-tts_infer-cli \
---model "F5-TTS" \
---ref_audio "ref_audio.wav" \
+f5-tts_infer-cli --model F5TTS_v1_Base \
+--ref_audio "provide_prompt_wav_path_here.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
 
@@ -110,15 +181,19 @@ f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
 
 ## Training
 
-### 1. Gradio App
+### 1. With Hugging Face Accelerate
 
-Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
+Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.
+
+### 2. With Gradio App
 
 ```bash
 # Quick start with Gradio web interface
 f5-tts_finetune-gradio
 ```
 
+Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
+
 
 ## [Evaluation](src/f5_tts/eval)
 
app.py CHANGED
@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
 )
 
 
-DEFAULT_TTS_MODEL = "F5-TTS"
+DEFAULT_TTS_MODEL = "F5-TTS_v1"
 tts_model_choice = DEFAULT_TTS_MODEL
 
 DEFAULT_TTS_MODEL_CFG = [
-    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
-    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+    "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
     json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
 ]
 
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
 vocoder = load_vocoder()
 
 
-def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
-    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+def load_f5tts():
+    ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
+    F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
     return load_model(DiT, F5TTS_model_cfg, ckpt_path)
 
 
-def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
-    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+def load_e2tts():
+    ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
     return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
 
 
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
     if vocab_path.startswith("hf://"):
         vocab_path = str(cached_path(vocab_path))
     if model_cfg is None:
-        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+        model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
     return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
 
 
@@ -130,7 +132,7 @@ def infer(
 
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
-    if model == "F5-TTS":
+    if model == DEFAULT_TTS_MODEL:
         ema_model = F5TTS_ema_model
     elif model == "E2-TTS":
         global E2TTS_ema_model
@@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 """
     )
 
-    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
 
     def load_last_used_custom():
         try:
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
             custom_model_cfg = gr.Dropdown(
                 choices=[
                     DEFAULT_TTS_MODEL_CFG[2],
-                    json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
+                    json.dumps(
+                        dict(
+                            dim=1024,
+                            depth=22,
+                            heads=16,
+                            ff_mult=2,
+                            text_dim=512,
+                            text_mask_padding=False,
+                            conv_layers=4,
+                            pe_attn_head=1,
+                        )
+                    ),
+                    json.dumps(
+                        dict(
+                            dim=768,
+                            depth=18,
+                            heads=12,
+                            ff_mult=2,
+                            text_dim=512,
+                            text_mask_padding=False,
+                            conv_layers=4,
+                            pe_attn_head=1,
+                        )
+                    ),
                 ],
                 value=load_last_used_custom()[2],
                 allow_custom_value=True,
@@ -875,10 +900,24 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     type=str,
     help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
 )
-def main(port, host, share, api, root_path):
+@click.option(
+    "--inbrowser",
+    "-i",
+    is_flag=True,
+    default=False,
+    help="Automatically launch the interface in the default web browser",
+)
+def main(port, host, share, api, root_path, inbrowser):
     global app
     print("Starting app...")
-    app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api, root_path=root_path)
+    app.queue(api_open=api).launch(
+        server_name=host,
+        server_port=port,
+        share=share,
+        show_api=api,
+        root_path=root_path,
+        inbrowser=inbrowser,
+    )
 
 
 if __name__ == "__main__":
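The app.py change above adds an `--inbrowser` click flag and forwards it into `app.queue(...).launch(...)`. Below is a minimal, hypothetical sketch of that pattern in isolation (the placeholder Blocks app and option defaults are illustrative and not part of the commit); it assumes `gradio` and `click` are installed.

```python
# Minimal sketch: forwarding a click flag into Gradio's launch(),
# mirroring the new --inbrowser option in app.py. The demo app is a placeholder.
import click
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("F5-TTS demo placeholder")


@click.command()
@click.option("--port", "-p", default=7860, type=int, help="Port to serve on")
@click.option("--inbrowser", "-i", is_flag=True, default=False,
              help="Automatically launch the interface in the default web browser")
def main(port, inbrowser):
    # inbrowser=True makes Gradio open a browser tab once the server is up
    demo.queue().launch(server_port=port, inbrowser=inbrowser)


if __name__ == "__main__":
    main()
```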
ckpts/README.md CHANGED
@@ -3,8 +3,10 @@ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
 
 ```
 ckpts/
-    E2TTS_Base/
-        model_1200000.pt
+    F5TTS_v1_Base/
+        model_1250000.safetensors
     F5TTS_Base/
-        model_1200000.pt
+        model_1200000.safetensors
+    E2TTS_Base/
+        model_1200000.safetensors
 ```
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.3.4"
+version = "1.0.1"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
src/f5_tts/api.py CHANGED
@@ -5,84 +5,84 @@ from importlib.resources import files
 import soundfile as sf
 import tqdm
 from cached_path import cached_path
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
-    hop_length,
-    infer_process,
     load_model,
     load_vocoder,
+    transcribe,
     preprocess_ref_audio_text,
+    infer_process,
     remove_silence_for_generated_wav,
     save_spectrogram,
-    transcribe,
-    target_sample_rate,
 )
-from f5_tts.model import DiT, UNetT
+from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import seed_everything
 
 
 class F5TTS:
     def __init__(
         self,
-        model_type="F5-TTS",
+        model="F5TTS_v1_Base",
         ckpt_file="",
         vocab_file="",
         ode_method="euler",
         use_ema=True,
-        vocoder_name="vocos",
-        local_path=None,
+        vocoder_local_path=None,
         device=None,
         hf_cache_dir=None,
     ):
-        # Initialize parameters
-        self.final_wave = None
-        self.target_sample_rate = target_sample_rate
-        self.hop_length = hop_length
-        self.seed = -1
-        self.mel_spec_type = vocoder_name
-
-        # Set device
+        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+        model_cls = globals()[model_cfg.model.backbone]
+        model_arc = model_cfg.model.arch
+
+        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+
+        self.ode_method = ode_method
+        self.use_ema = use_ema
+
         if device is not None:
             self.device = device
         else:
             import torch
 
-            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+            self.device = (
+                "cuda"
+                if torch.cuda.is_available()
+                else "xpu"
+                if torch.xpu.is_available()
+                else "mps"
+                if torch.backends.mps.is_available()
+                else "cpu"
+            )
 
         # Load models
-        self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
-        self.load_ema_model(
-            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
+        self.vocoder = load_vocoder(
+            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
         )
 
-    def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
-        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
-
-    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
-        if model_type == "F5-TTS":
-            if not ckpt_file:
-                if mel_spec_type == "vocos":
-                    ckpt_file = str(
-                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
-                    )
-                elif mel_spec_type == "bigvgan":
-                    ckpt_file = str(
-                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
-                    )
-            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-            model_cls = DiT
-        elif model_type == "E2-TTS":
-            if not ckpt_file:
-                ckpt_file = str(
-                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
-                )
-            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-            model_cls = UNetT
+        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
+
+        # override for previous models
+        if model == "F5TTS_Base":
+            if self.mel_spec_type == "vocos":
+                ckpt_step = 1200000
+            elif self.mel_spec_type == "bigvgan":
+                model = "F5TTS_Base_bigvgan"
+                ckpt_type = "pt"
+        elif model == "E2TTS_Base":
+            repo_name = "E2-TTS"
+            ckpt_step = 1200000
         else:
-            raise ValueError(f"Unknown model type: {model_type}")
+            raise ValueError(f"Unknown model type: {model}")
 
+        if not ckpt_file:
+            ckpt_file = str(
+                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
+            )
         self.ema_model = load_model(
-            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
+            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
         )
 
     def transcribe(self, ref_audio, language=None):
@@ -94,8 +94,8 @@ class F5TTS:
         if remove_silence:
             remove_silence_for_generated_wav(file_wave)
 
-    def export_spectrogram(self, spect, file_spect):
-        save_spectrogram(spect, file_spect)
+    def export_spectrogram(self, spec, file_spec):
+        save_spectrogram(spec, file_spec)
 
     def infer(
         self,
@@ -113,17 +113,16 @@ class F5TTS:
         fix_duration=None,
         remove_silence=False,
         file_wave=None,
-        file_spect=None,
-        seed=-1,
+        file_spec=None,
+        seed=None,
     ):
-        if seed == -1:
-            seed = random.randint(0, sys.maxsize)
-        seed_everything(seed)
-        self.seed = seed
+        if seed is None:
+            self.seed = random.randint(0, sys.maxsize)
+        seed_everything(self.seed)
 
         ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
 
-        wav, sr, spect = infer_process(
+        wav, sr, spec = infer_process(
             ref_file,
             ref_text,
             gen_text,
@@ -145,22 +144,22 @@ class F5TTS:
         if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)
 
-        if file_spect is not None:
-            self.export_spectrogram(spect, file_spect)
+        if file_spec is not None:
+            self.export_spectrogram(spec, file_spec)
 
-        return wav, sr, spect
+        return wav, sr, spec
 
 
 if __name__ == "__main__":
     f5tts = F5TTS()
 
-    wav, sr, spect = f5tts.infer(
+    wav, sr, spec = f5tts.infer(
         ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
         ref_text="some call me nature, others call me mother nature.",
         gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
         file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
-        file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
-        seed=-1,  # random seed = -1
+        file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+        seed=None,
    )
 
     print("seed :", f5tts.seed)
src/f5_tts/configs/E2TTS_Base.yaml ADDED
@@ -0,0 +1,49 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: E2TTS_Base
+  tokenizer: pinyin
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: UNetT
+  arch:
+    dim: 1024
+    depth: 24
+    heads: 16
+    ff_mult: 4
+    text_mask_padding: False
+    pe_attn_head: 1
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small.yaml ADDED
@@ -0,0 +1,49 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: UNetT
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+    text_mask_padding: False
+    pe_attn_head: 1
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base.yaml ADDED
@@ -0,0 +1,52 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    text_mask_padding: False
+    conv_layers: 4
+    pe_attn_head: 1
+    checkpoint_activations: False # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small.yaml ADDED
@@ -0,0 +1,52 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    text_mask_padding: False
+    conv_layers: 4
+    pe_attn_head: 1
+    checkpoint_activations: False # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_v1_Base.yaml ADDED
@@ -0,0 +1,53 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_v1_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    text_mask_padding: True
+    qk_norm: null # null | rms_norm
+    conv_layers: 4
+    pe_attn_head: null
+    checkpoint_activations: False # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
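These new YAML configs are what the refactored scripts read at runtime: `api.py` above (and `eval_infer_batch.py` below) load them with OmegaConf, look up the backbone class by name, and pass `model.arch` on as keyword arguments. The snippet below is a hedged sketch of that pattern, with an illustrative config name, not code from the commit itself.

```python
# Sketch of the config-driven model selection used in this commit.
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.model import DiT, UNetT  # noqa: F401. resolved by name below

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml")))
model_cls = globals()[model_cfg.model.backbone]            # e.g. "DiT" -> DiT
model_arc = OmegaConf.to_container(model_cfg.model.arch)   # dim, depth, heads, ... as a dict
print(model_cls.__name__, model_arc)
```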
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
 import torch
 import torchaudio
 from accelerate import Accelerator
+from omegaconf import OmegaConf
 from tqdm import tqdm
 
 from f5_tts.eval.utils_eval import (
@@ -18,36 +19,26 @@ from f5_tts.eval.utils_eval import (
     get_seedtts_testset_metainfo,
 )
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
-from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import get_tokenizer
 
 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"
 
 
-# --------------------- Dataset Settings -------------------- #
-
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
+use_ema = True
 target_rms = 0.1
 
+
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
 def main():
-    # ---------------------- infer setting ---------------------- #
-
     parser = argparse.ArgumentParser(description="batch inference")
 
     parser.add_argument("-s", "--seed", default=None, type=int)
-    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
     parser.add_argument("-n", "--expname", required=True)
-    parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
-    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
-    parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
+    parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
 
     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +49,8 @@ def main():
     args = parser.parse_args()
 
     seed = args.seed
-    dataset_name = args.dataset
     exp_name = args.expname
     ckpt_step = args.ckptstep
-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-    mel_spec_type = args.mel_spec_type
-    tokenizer = args.tokenizer
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
@@ -77,13 +64,19 @@ def main():
     use_truth_duration = False
     no_ref_audio = False
 
-    if exp_name == "F5TTS_Base":
-        model_cls = DiT
-        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
+    model_cls = globals()[model_cfg.model.backbone]
+    model_arc = model_cfg.model.arch
 
-    elif exp_name == "E2TTS_Base":
-        model_cls = UNetT
-        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+    dataset_name = model_cfg.datasets.name
+    tokenizer = model_cfg.model.tokenizer
+
+    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+    target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+    n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
+    hop_length = model_cfg.model.mel_spec.hop_length
+    win_length = model_cfg.model.mel_spec.win_length
+    n_fft = model_cfg.model.mel_spec.n_fft
 
     if testset == "ls_pc_test_clean":
         metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +104,6 @@ def main():
 
     # -------------------------------------------------#
 
-    use_ema = True
-
     prompts_all = get_inference_prompt(
         metainfo,
         speed=speed,
@@ -139,7 +130,7 @@ def main():
 
     # Model
     model = CFM(
-        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         mel_spec_kwargs=dict(
             n_fft=n_fft,
             hop_length=hop_length,
@@ -154,6 +145,10 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
+    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
+    if not os.path.exists(ckpt_path):
+        print("Loading from self-organized training checkpoints rather than released pretrained.")
+        ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
src/f5_tts/eval/eval_infer_batch.sh CHANGED
@@ -1,13 +1,18 @@
 #!/bin/bash
 
 # e.g. F5-TTS, 16 NFE
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
 
 # e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
+python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
 
 # etc.
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -53,43 +53,37 @@ def main():
     asr_ckpt_dir = ""  # auto download to cache dir
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
 
-    # --------------------------- WER ---------------------------
+    # --------------------------------------------------------------------------
 
-    if eval_task == "wer":
-        wer_results = []
-        wers = []
+    full_results = []
+    metrics = []
 
+    if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
-                wer_results.extend(r)
-
-        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
-        with open(wer_result_path, "w") as f:
-            for line in wer_results:
-                wers.append(line["wer"])
-                json_line = json.dumps(line, ensure_ascii=False)
-                f.write(json_line + "\n")
-
-        wer = round(np.mean(wers) * 100, 3)
-        print(f"\nTotal {len(wers)} samples")
-        print(f"WER : {wer}%")
-        print(f"Results have been saved to {wer_result_path}")
-
-    # --------------------------- SIM ---------------------------
-
-    if eval_task == "sim":
-        sims = []
+                full_results.extend(r)
+    elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
-                sims.extend(r)
-
-        sim = round(sum(sims) / len(sims), 3)
-        print(f"\nTotal {len(sims)} samples")
-        print(f"SIM : {sim}")
+                full_results.extend(r)
+    else:
+        raise ValueError(f"Unknown metric type: {eval_task}")
+
+    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+    with open(result_path, "w") as f:
+        for line in full_results:
+            metrics.append(line[eval_task])
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        metric = round(np.mean(metrics), 5)
+        f.write(f"\n{eval_task.upper()}: {metric}\n")
+
+    print(f"\nTotal {len(metrics)} samples")
+    print(f"{eval_task.upper()}: {metric}")
+    print(f"{eval_task.upper()} results saved to {result_path}")
 
 
 if __name__ == "__main__":
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -52,43 +52,37 @@ def main():
     asr_ckpt_dir = ""  # auto download to cache dir
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
 
-    # --------------------------- WER ---------------------------
+    # --------------------------------------------------------------------------
 
-    if eval_task == "wer":
-        wer_results = []
-        wers = []
+    full_results = []
+    metrics = []
 
+    if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
-                wer_results.extend(r)
-
-        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
-        with open(wer_result_path, "w") as f:
-            for line in wer_results:
-                wers.append(line["wer"])
-                json_line = json.dumps(line, ensure_ascii=False)
-                f.write(json_line + "\n")
-
-        wer = round(np.mean(wers) * 100, 3)
-        print(f"\nTotal {len(wers)} samples")
-        print(f"WER : {wer}%")
-        print(f"Results have been saved to {wer_result_path}")
-
-    # --------------------------- SIM ---------------------------
-
-    if eval_task == "sim":
-        sims = []
+                full_results.extend(r)
+    elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
-                sims.extend(r)
-
-        sim = round(sum(sims) / len(sims), 3)
-        print(f"\nTotal {len(sims)} samples")
-        print(f"SIM : {sim}")
+                full_results.extend(r)
+    else:
+        raise ValueError(f"Unknown metric type: {eval_task}")
+
+    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+    with open(result_path, "w") as f:
+        for line in full_results:
+            metrics.append(line[eval_task])
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        metric = round(np.mean(metrics), 5)
+        f.write(f"\n{eval_task.upper()}: {metric}\n")
+
+    print(f"\nTotal {len(metrics)} samples")
+    print(f"{eval_task.upper()}: {metric}")
+    print(f"{eval_task.upper()} results saved to {result_path}")
 
 
 if __name__ == "__main__":
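With this change both eval scripts write one JSON object per sample to `_{wer|sim}_results.jsonl`, followed by a plain-text summary line. Below is a small illustrative sketch, not part of the commit, for re-reading such a file while skipping the non-JSON summary line; the path is taken from the example in `eval_infer_batch.sh`.

```python
# Re-read a _wer_results.jsonl produced by the eval script and recompute the mean WER.
import json

path = "results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0/_wer_results.jsonl"
scores = []
with open(path, encoding="utf-8") as f:
    for raw in f:
        raw = raw.strip()
        if not raw:
            continue
        try:
            scores.append(json.loads(raw)["wer"])  # one record per sample
        except json.JSONDecodeError:
            continue  # trailing "WER: ..." summary line is not JSON
print(f"{len(scores)} samples, mean WER {sum(scores) / len(scores):.5f}")
```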
src/f5_tts/eval/eval_utmos.py CHANGED
@@ -13,31 +13,29 @@ def main():
     parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
 
     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
     predictor = predictor.to(device)
 
     audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
-    utmos_results = {}
     utmos_score = 0
 
-    for audio_path in tqdm(audio_paths, desc="Processing"):
-        wav_name = audio_path.stem
-        wav, sr = librosa.load(audio_path, sr=None, mono=True)
-        wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
-        score = predictor(wav_tensor, sr)
-        utmos_results[str(wav_name)] = score.item()
-        utmos_score += score.item()
-
-    avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
-    print(f"UTMOS: {avg_score}")
-
-    utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+    utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
     with open(utmos_result_path, "w", encoding="utf-8") as f:
-        json.dump(utmos_results, f, ensure_ascii=False, indent=4)
-
-    print(f"Results have been saved to {utmos_result_path}")
+        for audio_path in tqdm(audio_paths, desc="Processing"):
+            wav, sr = librosa.load(audio_path, sr=None, mono=True)
+            wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+            score = predictor(wav_tensor, sr)
+            line = {}
+            line["wav"], line["utmos"] = str(audio_path.stem), score.item()
+            utmos_score += score.item()
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+        f.write(f"\nUTMOS: {avg_score:.4f}\n")
+
+    print(f"UTMOS: {avg_score:.4f}")
+    print(f"UTMOS results saved to {utmos_result_path}")
 
 
 if __name__ == "__main__":
src/f5_tts/eval/utils_eval.py CHANGED
@@ -389,10 +389,10 @@ def run_sim(args):
389
  model = model.cuda(device)
390
  model.eval()
391
 
392
- sims = []
393
- for wav1, wav2, truth in tqdm(test_set):
394
- wav1, sr1 = torchaudio.load(wav1)
395
- wav2, sr2 = torchaudio.load(wav2)
396
 
397
  resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
398
  resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@ def run_sim(args):
408
 
409
  sim = F.cosine_similarity(emb1, emb2)[0].item()
410
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
411
- sims.append(sim)
412
 
413
- return sims
 
389
  model = model.cuda(device)
390
  model.eval()
391
 
392
+ sim_results = []
393
+ for gen_wav, prompt_wav, truth in tqdm(test_set):
394
+ wav1, sr1 = torchaudio.load(gen_wav)
395
+ wav2, sr2 = torchaudio.load(prompt_wav)
396
 
397
  resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
398
  resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
 
408
 
409
  sim = F.cosine_similarity(emb1, emb2)[0].item()
410
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
411
+ sim_results.append(
412
+ {
413
+ "wav": Path(gen_wav).stem,
414
+ "sim": sim,
415
+ }
416
+ )
417
 
418
+ return sim_results
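`run_sim` now returns a list of per-utterance dicts instead of bare floats. A hedged sketch of aggregating that return value into the same jsonl-plus-summary format the other eval scripts write (the helper name and output path are illustrative):

```python
# Sketch of consuming run_sim's new return value: a list of {"wav": ..., "sim": ...} dicts.
# Mirrors the jsonl-plus-summary layout used by the WER/UTMOS scripts; the path is illustrative.
import json
import numpy as np

def summarize_sim(sim_results, result_path="results/_sim_results.jsonl"):
    sims = [item["sim"] for item in sim_results]
    avg = round(float(np.mean(sims)), 5) if sims else 0.0
    with open(result_path, "w", encoding="utf-8") as f:
        for item in sim_results:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
        f.write(f"\nSIM: {avg}\n")
    return avg
```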
src/f5_tts/infer/README.md CHANGED
@@ -23,12 +23,24 @@ Currently supported features:
23
  - Basic TTS with Chunk Inference
24
  - Multi-Style / Multi-Speaker Generation
25
  - Voice Chat powered by Qwen2.5-3B-Instruct
 
26
 
27
  The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.
28
 
29
  The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at first; the ASR model is loaded to transcribe the reference audio if `ref_text` is not provided, and the LLM is loaded only if Voice Chat is used.
30
 
31
- Could also be used as a component for larger application.
 
 
 
 
 
 
 
 
 
 
 
32
  ```python
33
  import gradio as gr
34
  from f5_tts.infer.infer_gradio import app
@@ -56,14 +68,16 @@ Basically you can inference with flags:
56
  ```bash
57
  # Leave --ref_text "" to have the ASR model transcribe the reference audio (extra GPU memory usage)
58
  f5-tts_infer-cli \
59
- --model "F5-TTS" \
60
  --ref_audio "ref_audio.wav" \
61
  --ref_text "The content, subtitle or transcription of reference audio." \
62
  --gen_text "Some text you want TTS model generate for you."
63
 
64
- # Choose Vocoder
65
- f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
66
- f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
 
 
67
 
68
  # More instructions
69
  f5-tts_infer-cli --help
@@ -78,8 +92,8 @@ f5-tts_infer-cli -c custom.toml
78
  For example, you can use a `.toml` file to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`:
79
 
80
  ```toml
81
- # F5-TTS | E2-TTS
82
- model = "F5-TTS"
83
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
84
  # If an empty "", transcribes the reference audio automatically.
85
  ref_text = "Some call me nature, others call me mother nature."
@@ -93,8 +107,8 @@ output_dir = "tests"
93
  You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.
94
 
95
  ```toml
96
- # F5-TTS | E2-TTS
97
- model = "F5-TTS"
98
  ref_audio = "infer/examples/multi/main.flac"
99
  # If an empty "", transcribes the reference audio automatically.
100
  ref_text = ""
@@ -114,83 +128,27 @@ ref_text = ""
114
  ```
115
  Mark the voice with `[main]`, `[town]`, or `[country]` whenever you want to switch voices; refer to `src/f5_tts/infer/examples/multi/story.txt`.
116
 
117
- ## Speech Editing
118
 
119
- To test speech editing capabilities, use the following command:
120
 
121
  ```bash
122
- python src/f5_tts/infer/speech_edit.py
123
- ```
124
 
125
- ## Socket Realtime Client
 
 
126
 
127
- To communicate with socket server you need to run
128
- ```bash
129
- python src/f5_tts/socket_server.py
130
  ```
131
 
132
- <details>
133
- <summary>Then create client to communicate</summary>
134
-
135
- ``` python
136
- import socket
137
- import numpy as np
138
- import asyncio
139
- import pyaudio
140
-
141
- async def listen_to_voice(text, server_ip='localhost', server_port=9999):
142
- client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
143
- client_socket.connect((server_ip, server_port))
144
-
145
- async def play_audio_stream():
146
- buffer = b''
147
- p = pyaudio.PyAudio()
148
- stream = p.open(format=pyaudio.paFloat32,
149
- channels=1,
150
- rate=24000, # Ensure this matches the server's sampling rate
151
- output=True,
152
- frames_per_buffer=2048)
153
-
154
- try:
155
- while True:
156
- chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024)
157
- if not chunk: # End of stream
158
- break
159
- if b"END_OF_AUDIO" in chunk:
160
- buffer += chunk.replace(b"END_OF_AUDIO", b"")
161
- if buffer:
162
- audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy
163
- stream.write(audio_array.tobytes())
164
- break
165
- buffer += chunk
166
- if len(buffer) >= 4096:
167
- audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy
168
- stream.write(audio_array.tobytes())
169
- buffer = buffer[4096:]
170
- finally:
171
- stream.stop_stream()
172
- stream.close()
173
- p.terminate()
174
-
175
- try:
176
- # Send only the text to the server
177
- await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8'))
178
- await play_audio_stream()
179
- print("Audio playback finished.")
180
-
181
- except Exception as e:
182
- print(f"Error in listen_to_voice: {e}")
183
-
184
- finally:
185
- client_socket.close()
186
-
187
- # Example usage: Replace this with your actual server IP and port
188
- async def main():
189
- await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998)
190
-
191
- # Run the main async function
192
- asyncio.run(main())
193
- ```
194
 
195
- </details>
 
 
 
 
196
 
 
23
  - Basic TTS with Chunk Inference
24
  - Multi-Style / Multi-Speaker Generation
25
  - Voice Chat powered by Qwen2.5-3B-Instruct
26
+ - [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
27
 
28
  The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.
29
 
30
  The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at first; the ASR model is loaded to transcribe the reference audio if `ref_text` is not provided, and the LLM is loaded only if Voice Chat is used.
31
 
32
+ More flag options:
33
+
34
+ ```bash
35
+ # Automatically launch the interface in the default web browser
36
+ f5-tts_infer-gradio --inbrowser
37
+
38
+ # Set the root path of the application, if it's not served from the root ("/") of the domain
39
+ # For example, if the application is served at "https://example.com/myapp"
40
+ f5-tts_infer-gradio --root_path "/myapp"
41
+ ```
42
+
43
+ Can also be used as a component of a larger application:
44
  ```python
45
  import gradio as gr
46
  from f5_tts.infer.infer_gradio import app
 
68
  ```bash
69
  # Leave --ref_text "" to have the ASR model transcribe the reference audio (extra GPU memory usage)
70
  f5-tts_infer-cli \
71
+ --model F5TTS_v1_Base \
72
  --ref_audio "ref_audio.wav" \
73
  --ref_text "The content, subtitle or transcription of reference audio." \
74
  --gen_text "Some text you want TTS model generate for you."
75
 
76
+ # Use BigVGAN as the vocoder. Currently only supported with F5TTS_Base.
77
+ f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
78
+
79
+ # Use a checkpoint from a custom path, e.g.
80
+ f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
81
 
82
  # More instructions
83
  f5-tts_infer-cli --help
 
92
  For example, you can use a `.toml` file to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`:
93
 
94
  ```toml
95
+ # F5TTS_v1_Base | E2TTS_Base
96
+ model = "F5TTS_v1_Base"
97
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
98
  # If an empty "", transcribes the reference audio automatically.
99
  ref_text = "Some call me nature, others call me mother nature."
 
107
  You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.
108
 
109
  ```toml
110
+ # F5TTS_v1_Base | E2TTS_Base
111
+ model = "F5TTS_v1_Base"
112
  ref_audio = "infer/examples/multi/main.flac"
113
  # If an empty "", transcribes the reference audio automatically.
114
  ref_text = ""
 
128
  ```
129
  Mark the voice with `[main]`, `[town]`, or `[country]` whenever you want to switch voices; refer to `src/f5_tts/infer/examples/multi/story.txt`.
130
 
131
+ ## Socket Real-time Service
132
 
133
+ Real-time voice output with chunked streaming:
134
 
135
  ```bash
136
+ # Start socket server
137
+ python src/f5_tts/socket_server.py
138
 
139
+ # If PyAudio is not installed
140
+ sudo apt-get install portaudio19-dev
141
+ pip install pyaudio
142
 
143
+ # Communicate with socket client
144
+ python src/f5_tts/socket_client.py
 
145
  ```
146
 
147
+ ## Speech Editing
 
148
 
149
+ To test speech editing capabilities, use the following command:
150
+
151
+ ```bash
152
+ python src/f5_tts/infer/speech_edit.py
153
+ ```
154
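The multi-voice flow above switches voices wherever a `[main]`/`[town]`/`[country]` tag appears in the story text. A simplified sketch of that tag splitting (the idea only, not the exact parser `infer_cli.py` uses):

```python
# Simplified illustration of splitting a tag-marked story text into (voice, text) segments.
# This is a sketch of the concept only, not the exact regex/parser used by infer_cli.py.
import re

def split_story_by_voice(story_text, default_voice="main"):
    segments = []
    voice = default_voice
    for chunk in re.split(r"(\[\w+\])", story_text):  # keep the [tag] markers
        if not chunk.strip():
            continue
        if re.fullmatch(r"\[\w+\]", chunk):
            voice = chunk[1:-1]  # switch the active voice
        else:
            segments.append((voice, chunk.strip()))
    return segments

print(split_story_by_voice("A long journey began. [town]Welcome, traveler! [main]he heard someone say."))
```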
 
src/f5_tts/infer/SHARED.md CHANGED
@@ -16,7 +16,7 @@
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
- - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
20
  - [English](#english)
21
  - [Finnish](#finnish)
22
  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,17 @@
37
 
38
  ## Multilingual
39
 
40
- #### F5-TTS Base @ zh & en @ F5-TTS
 
 
 
 
 
 
 
 
 
 
41
  |Model|🤗Hugging Face|Data (Hours)|Model License|
42
  |:---:|:------------:|:-----------:|:-------------:|
43
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -45,7 +55,7 @@
45
  ```bash
46
  Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
47
  Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
48
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
49
  ```
50
 
51
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -64,7 +74,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
64
  ```bash
65
  Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
66
  Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
67
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
68
  ```
69
 
70
 
@@ -78,7 +88,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
78
  ```bash
79
  Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
80
  Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
81
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
82
  ```
83
 
84
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -96,7 +106,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
96
  ```bash
97
  Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
98
  Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
99
- Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
100
  ```
101
 
102
  - Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -113,7 +123,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c
113
  ```bash
114
  Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
115
  Vocab: hf://alien79/F5-TTS-italian/vocab.txt
116
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
117
  ```
118
 
119
  - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -131,7 +141,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
131
  ```bash
132
  Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
133
  Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
134
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
135
  ```
136
 
137
 
@@ -148,7 +158,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
148
  ```bash
149
  Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
150
  Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
151
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
152
  ```
153
  - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
154
  - Any improvements are welcome
 
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
+ - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
20
  - [English](#english)
21
  - [Finnish](#finnish)
22
  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
 
37
 
38
  ## Multilingual
39
 
40
+ #### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
41
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
42
+ |:---:|:------------:|:-----------:|:-------------:|
43
+ |F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
44
+
45
+ ```bash
46
+ Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
47
+ Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
48
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
49
+ ```
50
+
51
  |Model|🤗Hugging Face|Data (Hours)|Model License|
52
  |:---:|:------------:|:-----------:|:-------------:|
53
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
 
55
  ```bash
56
  Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
57
  Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
58
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
59
  ```
60
 
61
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
 
74
  ```bash
75
  Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
76
  Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
77
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
78
  ```
79
 
80
 
 
88
  ```bash
89
  Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
90
  Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
91
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
92
  ```
93
 
94
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
 
106
  ```bash
107
  Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
108
  Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
109
+ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
110
  ```
111
 
112
  - Authors: SPRING Lab, Indian Institute of Technology, Madras
 
123
  ```bash
124
  Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
125
  Vocab: hf://alien79/F5-TTS-italian/vocab.txt
126
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
127
  ```
128
 
129
  - Trained by [Mithril Man](https://github.com/MithrilMan)
 
141
  ```bash
142
  Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
143
  Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
144
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
145
  ```
146
 
147
 
 
158
  ```bash
159
  Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
160
  Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
161
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
162
  ```
163
  - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
164
  - Any improvements are welcome
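To try one of the community checkpoints listed above, the `hf://` paths can be resolved to local files and passed to the CLI. A hedged sketch (the Italian model is just an example; `--ckpt_file`/`--vocab_file` are the flags exposed by `infer_cli.py`, and the reference clip and texts are illustrative):

```python
# Sketch: resolve a community checkpoint from the table above and hand it to f5-tts_infer-cli.
# cached_path resolves hf:// URIs into the local Hugging Face cache, as infer_cli.py itself does.
import subprocess
from cached_path import cached_path

ckpt = str(cached_path("hf://alien79/F5-TTS-italian/model_159600.safetensors"))
vocab = str(cached_path("hf://alien79/F5-TTS-italian/vocab.txt"))

subprocess.run(
    [
        "f5-tts_infer-cli",
        "--ckpt_file", ckpt,
        "--vocab_file", vocab,
        "--ref_audio", "ref_audio.wav",  # your reference clip
        "--ref_text", "",                # empty -> automatic transcription
        "--gen_text", "Ciao, questo è un piccolo esempio.",
    ],
    check=True,
)
```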
src/f5_tts/infer/examples/basic/basic.toml CHANGED
@@ -1,5 +1,5 @@
1
- # F5-TTS | E2-TTS
2
- model = "F5-TTS"
3
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = "Some call me nature, others call me mother nature."
 
1
+ # F5TTS_v1_Base | E2TTS_Base
2
+ model = "F5TTS_v1_Base"
3
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = "Some call me nature, others call me mother nature."
src/f5_tts/infer/examples/basic/basic_ref_en.wav CHANGED
Binary files a/src/f5_tts/infer/examples/basic/basic_ref_en.wav and b/src/f5_tts/infer/examples/basic/basic_ref_en.wav differ
 
src/f5_tts/infer/examples/basic/basic_ref_zh.wav CHANGED
Binary files a/src/f5_tts/infer/examples/basic/basic_ref_zh.wav and b/src/f5_tts/infer/examples/basic/basic_ref_zh.wav differ
 
src/f5_tts/infer/examples/multi/country.flac CHANGED
Binary files a/src/f5_tts/infer/examples/multi/country.flac and b/src/f5_tts/infer/examples/multi/country.flac differ
 
src/f5_tts/infer/examples/multi/main.flac CHANGED
Binary files a/src/f5_tts/infer/examples/multi/main.flac and b/src/f5_tts/infer/examples/multi/main.flac differ
 
src/f5_tts/infer/examples/multi/story.toml CHANGED
@@ -1,5 +1,5 @@
1
- # F5-TTS | E2-TTS
2
- model = "F5-TTS"
3
  ref_audio = "infer/examples/multi/main.flac"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = ""
 
1
+ # F5TTS_v1_Base | E2TTS_Base
2
+ model = "F5TTS_v1_Base"
3
  ref_audio = "infer/examples/multi/main.flac"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = ""
src/f5_tts/infer/examples/multi/town.flac CHANGED
Binary files a/src/f5_tts/infer/examples/multi/town.flac and b/src/f5_tts/infer/examples/multi/town.flac differ
 
src/f5_tts/infer/infer_cli.py CHANGED
@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
- from f5_tts.model import DiT, UNetT
31
 
32
 
33
  parser = argparse.ArgumentParser(
@@ -50,7 +50,7 @@ parser.add_argument(
50
  "-m",
51
  "--model",
52
  type=str,
53
- help="The model name: F5-TTS | E2-TTS",
54
  )
55
  parser.add_argument(
56
  "-mc",
@@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb"))
172
 
173
  # command-line interface parameters
174
 
175
- model = args.model or config.get("model", "F5-TTS")
176
- model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
177
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
178
  vocab_file = args.vocab_file or config.get("vocab_file", "")
179
 
@@ -245,36 +244,32 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
245
 
246
  # load TTS model
247
 
248
- if model == "F5-TTS":
249
- model_cls = DiT
250
- model_cfg = OmegaConf.load(model_cfg).model.arch
251
- if not ckpt_file: # path not specified, download from repo
252
- if vocoder_name == "vocos":
253
- repo_name = "F5-TTS"
254
- exp_name = "F5TTS_Base"
255
- ckpt_step = 1200000
256
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
257
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
258
- elif vocoder_name == "bigvgan":
259
- repo_name = "F5-TTS"
260
- exp_name = "F5TTS_Base_bigvgan"
261
- ckpt_step = 1250000
262
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
263
-
264
- elif model == "E2-TTS":
265
- assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
266
- assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
267
- model_cls = UNetT
268
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
269
- if not ckpt_file: # path not specified, download from repo
270
- repo_name = "E2-TTS"
271
- exp_name = "E2TTS_Base"
272
  ckpt_step = 1200000
273
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
274
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
 
 
 
 
 
 
 
275
 
276
  print(f"Using {model}...")
277
- ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
278
 
279
 
280
  # inference process
 
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
+ from f5_tts.model import DiT, UNetT # noqa: F401. used for config
31
 
32
 
33
  parser = argparse.ArgumentParser(
 
50
  "-m",
51
  "--model",
52
  type=str,
53
+ help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
54
  )
55
  parser.add_argument(
56
  "-mc",
 
172
 
173
  # command-line interface parameters
174
 
175
+ model = args.model or config.get("model", "F5TTS_v1_Base")
 
176
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
177
  vocab_file = args.vocab_file or config.get("vocab_file", "")
178
 
 
244
 
245
  # load TTS model
246
 
247
+ model_cfg = OmegaConf.load(
248
+ args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
249
+ ).model
250
+ model_cls = globals()[model_cfg.backbone]
251
+
252
+ repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
253
+
254
+ if model != "F5TTS_Base":
255
+ assert vocoder_name == model_cfg.mel_spec.mel_spec_type
256
+
257
+ # override for previous models
258
+ if model == "F5TTS_Base":
259
+ if vocoder_name == "vocos":
 
 
 
 
 
 
 
 
 
 
 
260
  ckpt_step = 1200000
261
+ elif vocoder_name == "bigvgan":
262
+ model = "F5TTS_Base_bigvgan"
263
+ ckpt_type = "pt"
264
+ elif model == "E2TTS_Base":
265
+ repo_name = "E2-TTS"
266
+ ckpt_step = 1200000
267
+
268
+ if not ckpt_file:
269
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
270
 
271
  print(f"Using {model}...")
272
+ ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
273
 
274
 
275
  # inference process
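Model construction above is now driven by the `configs/<model>.yaml` files: the YAML names the backbone class and architecture, and the class is looked up by name. A condensed sketch of the same flow outside the CLI (the checkpoint path is illustrative):

```python
# Condensed sketch of the config-driven loading used above. The backbone class is resolved
# from the name stored in configs/<model>.yaml; the checkpoint path here is illustrative.
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT, UNetT  # noqa: F401. resolved by name from the config

model_name = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model_name}.yaml"))).model

model_cls = {"DiT": DiT, "UNetT": UNetT}[model_cfg.backbone]  # same idea as globals()[...]
vocoder = load_vocoder(vocoder_name=model_cfg.mel_spec.mel_spec_type)
ema_model = load_model(
    model_cls,
    model_cfg.arch,
    "ckpts/F5TTS_v1_Base/model_1250000.safetensors",  # or a path resolved via cached_path
    mel_spec_type=model_cfg.mel_spec.mel_spec_type,
)
```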
src/f5_tts/infer/speech_edit.py CHANGED
@@ -1,56 +1,63 @@
1
  import os
2
 
3
- os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
 
 
4
 
5
  import torch
6
  import torch.nn.functional as F
7
  import torchaudio
 
8
 
9
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
10
- from f5_tts.model import CFM, DiT, UNetT
11
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
12
 
13
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
14
-
15
-
16
- # --------------------- Dataset Settings -------------------- #
17
-
18
- target_sample_rate = 24000
19
- n_mel_channels = 100
20
- hop_length = 256
21
- win_length = 1024
22
- n_fft = 1024
23
- mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
24
- target_rms = 0.1
25
-
26
- tokenizer = "pinyin"
27
- dataset_name = "Emilia_ZH_EN"
28
 
29
 
30
  # ---------------------- infer setting ---------------------- #
31
 
32
  seed = None # int | None
33
 
34
- exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
35
- ckpt_step = 1200000
36
 
37
  nfe_step = 32 # 16, 32
38
  cfg_strength = 2.0
39
  ode_method = "euler" # euler | midpoint
40
  sway_sampling_coef = -1.0
41
  speed = 1.0
 
 
 
 
 
 
42
 
43
- if exp_name == "F5TTS_Base":
44
- model_cls = DiT
45
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
 
47
- elif exp_name == "E2TTS_Base":
48
- model_cls = UNetT
49
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
50
 
51
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 
52
  output_dir = "tests"
53
 
 
54
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
55
  # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
56
  # [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -59,7 +66,7 @@ output_dir = "tests"
59
  # [--language "zho" for Chinese, "eng" for English]
60
  # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
61
 
62
- audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
63
  origin_text = "Some call me nature, others call me mother nature."
64
  target_text = "Some call me optimist, others call me realist."
65
  parts_to_edit = [
@@ -98,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
98
 
99
  # Model
100
  model = CFM(
101
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
102
  mel_spec_kwargs=dict(
103
  n_fft=n_fft,
104
  hop_length=hop_length,
 
1
  import os
2
 
3
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
+
5
+ from importlib.resources import files
6
 
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
10
+ from omegaconf import OmegaConf
11
 
12
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
13
+ from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
14
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
15
 
16
+ device = (
17
+ "cuda"
18
+ if torch.cuda.is_available()
19
+ else "xpu"
20
+ if torch.xpu.is_available()
21
+ else "mps"
22
+ if torch.backends.mps.is_available()
23
+ else "cpu"
24
+ )
 
 
 
 
 
 
25
 
26
 
27
  # ---------------------- infer setting ---------------------- #
28
 
29
  seed = None # int | None
30
 
31
+ exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
32
+ ckpt_step = 1250000
33
 
34
  nfe_step = 32 # 16, 32
35
  cfg_strength = 2.0
36
  ode_method = "euler" # euler | midpoint
37
  sway_sampling_coef = -1.0
38
  speed = 1.0
39
+ target_rms = 0.1
40
+
41
+
42
+ model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
43
+ model_cls = globals()[model_cfg.model.backbone]
44
+ model_arc = model_cfg.model.arch
45
 
46
+ dataset_name = model_cfg.datasets.name
47
+ tokenizer = model_cfg.model.tokenizer
 
48
 
49
+ mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
50
+ target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
51
+ n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
52
+ hop_length = model_cfg.model.mel_spec.hop_length
53
+ win_length = model_cfg.model.mel_spec.win_length
54
+ n_fft = model_cfg.model.mel_spec.n_fft
55
 
56
+
57
+ ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
58
  output_dir = "tests"
59
 
60
+
61
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
62
  # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
63
  # [write the origin_text into a file, e.g. tests/test_edit.txt]
 
66
  # [--language "zho" for Chinese, "eng" for English]
67
  # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
68
 
69
+ audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
70
  origin_text = "Some call me nature, others call me mother nature."
71
  target_text = "Some call me optimist, others call me realist."
72
  parts_to_edit = [
 
105
 
106
  # Model
107
  model = CFM(
108
+ transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
109
  mel_spec_kwargs=dict(
110
  n_fft=n_fft,
111
  hop_length=hop_length,
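For reference, the `parts_to_edit` start/end times are mapped onto mel-spectrogram frames before masking; with `hop_length=256` at 24 kHz that is simply `seconds * sample_rate / hop_length` (about 93.75 frames per second). A small sketch of the conversion, with illustrative time ranges:

```python
# Quick sketch: convert (start, end) second ranges into mel-frame indices, the unit the
# editing mask operates on. The time ranges below are illustrative, not taken from the script.
def seconds_to_frames(parts_to_edit, sample_rate=24000, hop_length=256):
    return [
        (round(start * sample_rate / hop_length), round(end * sample_rate / hop_length))
        for start, end in parts_to_edit
    ]

print(seconds_to_frames([[1.0, 2.2], [4.0, 4.8]]))  # -> [(94, 206), (375, 450)]
```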
src/f5_tts/infer/utils_infer.py CHANGED
@@ -2,8 +2,9 @@
2
  # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
  import os
4
  import sys
 
5
 
6
- os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
7
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
8
 
9
  import hashlib
@@ -33,7 +34,15 @@ from f5_tts.model.utils import (
33
 
34
  _ref_audio_cache = {}
35
 
36
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
 
 
 
 
37
 
38
  # -----------------------------------------
39
 
@@ -292,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
292
  )
293
  non_silent_wave = AudioSegment.silent(duration=0)
294
  for non_silent_seg in non_silent_segs:
295
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
296
  show_info("Audio is over 15s, clipping short. (1)")
297
  break
298
  non_silent_wave += non_silent_seg
299
 
300
  # 2. try to find short silence for clipping if 1. failed
301
- if len(non_silent_wave) > 15000:
302
  non_silent_segs = silence.split_on_silence(
303
  aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
304
  )
305
  non_silent_wave = AudioSegment.silent(duration=0)
306
  for non_silent_seg in non_silent_segs:
307
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
308
  show_info("Audio is over 15s, clipping short. (2)")
309
  break
310
  non_silent_wave += non_silent_seg
@@ -312,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
312
  aseg = non_silent_wave
313
 
314
  # 3. if no proper silence found for clipping
315
- if len(aseg) > 15000:
316
- aseg = aseg[:15000]
317
  show_info("Audio is over 15s, clipping short. (3)")
318
 
319
  aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
@@ -374,29 +383,31 @@ def infer_process(
374
  ):
375
  # Split the input text into batches
376
  audio, sr = torchaudio.load(ref_audio)
377
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
378
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
379
  for i, gen_text in enumerate(gen_text_batches):
380
  print(f"gen_text {i}", gen_text)
381
  print("\n")
382
 
383
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
384
- return infer_batch_process(
385
- (audio, sr),
386
- ref_text,
387
- gen_text_batches,
388
- model_obj,
389
- vocoder,
390
- mel_spec_type=mel_spec_type,
391
- progress=progress,
392
- target_rms=target_rms,
393
- cross_fade_duration=cross_fade_duration,
394
- nfe_step=nfe_step,
395
- cfg_strength=cfg_strength,
396
- sway_sampling_coef=sway_sampling_coef,
397
- speed=speed,
398
- fix_duration=fix_duration,
399
- device=device,
 
 
400
  )
401
 
402
 
@@ -419,6 +430,8 @@ def infer_batch_process(
419
  speed=1,
420
  fix_duration=None,
421
  device=None,
 
 
422
  ):
423
  audio, sr = ref_audio
424
  if audio.shape[0] > 1:
@@ -437,7 +450,12 @@ def infer_batch_process(
437
 
438
  if len(ref_text[-1].encode("utf-8")) == 1:
439
  ref_text = ref_text + " "
440
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
 
 
 
 
 
441
  # Prepare the text
442
  text_list = [ref_text + gen_text]
443
  final_text_list = convert_char_to_pinyin(text_list)
@@ -449,7 +467,7 @@ def infer_batch_process(
449
  # Calculate duration
450
  ref_text_len = len(ref_text.encode("utf-8"))
451
  gen_text_len = len(gen_text.encode("utf-8"))
452
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
453
 
454
  # inference
455
  with torch.inference_mode():
@@ -461,64 +479,88 @@ def infer_batch_process(
461
  cfg_strength=cfg_strength,
462
  sway_sampling_coef=sway_sampling_coef,
463
  )
 
464
 
465
- generated = generated.to(torch.float32)
466
  generated = generated[:, ref_audio_len:, :]
467
- generated_mel_spec = generated.permute(0, 2, 1)
468
  if mel_spec_type == "vocos":
469
- generated_wave = vocoder.decode(generated_mel_spec)
470
  elif mel_spec_type == "bigvgan":
471
- generated_wave = vocoder(generated_mel_spec)
472
  if rms < target_rms:
473
  generated_wave = generated_wave * rms / target_rms
474
 
475
  # wav -> numpy
476
  generated_wave = generated_wave.squeeze().cpu().numpy()
477
 
478
- generated_waves.append(generated_wave)
479
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
480
-
481
- # Combine all generated waves with cross-fading
482
- if cross_fade_duration <= 0:
483
- # Simply concatenate
484
- final_wave = np.concatenate(generated_waves)
 
 
 
 
 
485
  else:
486
- final_wave = generated_waves[0]
487
- for i in range(1, len(generated_waves)):
488
- prev_wave = final_wave
489
- next_wave = generated_waves[i]
490
-
491
- # Calculate cross-fade samples, ensuring it does not exceed wave lengths
492
- cross_fade_samples = int(cross_fade_duration * target_sample_rate)
493
- cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
494
-
495
- if cross_fade_samples <= 0:
496
- # No overlap possible, concatenate
497
- final_wave = np.concatenate([prev_wave, next_wave])
498
- continue
499
-
500
- # Overlapping parts
501
- prev_overlap = prev_wave[-cross_fade_samples:]
502
- next_overlap = next_wave[:cross_fade_samples]
503
-
504
- # Fade out and fade in
505
- fade_out = np.linspace(1, 0, cross_fade_samples)
506
- fade_in = np.linspace(0, 1, cross_fade_samples)
507
-
508
- # Cross-faded overlap
509
- cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
510
-
511
- # Combine
512
- new_wave = np.concatenate(
513
- [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
514
- )
515
-
516
- final_wave = new_wave
 
517
 
518
- # Create a combined spectrogram
519
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
520
-
521
- return final_wave, target_sample_rate, combined_spectrogram
522
 
523
 
524
  # remove silence from generated wav
 
2
  # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
  import os
4
  import sys
5
+ from concurrent.futures import ThreadPoolExecutor
6
 
7
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
8
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
9
 
10
  import hashlib
 
34
 
35
  _ref_audio_cache = {}
36
 
37
+ device = (
38
+ "cuda"
39
+ if torch.cuda.is_available()
40
+ else "xpu"
41
+ if torch.xpu.is_available()
42
+ else "mps"
43
+ if torch.backends.mps.is_available()
44
+ else "cpu"
45
+ )
46
 
47
  # -----------------------------------------
48
 
 
301
  )
302
  non_silent_wave = AudioSegment.silent(duration=0)
303
  for non_silent_seg in non_silent_segs:
304
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
305
  show_info("Audio is over 15s, clipping short. (1)")
306
  break
307
  non_silent_wave += non_silent_seg
308
 
309
  # 2. try to find short silence for clipping if 1. failed
310
+ if len(non_silent_wave) > 12000:
311
  non_silent_segs = silence.split_on_silence(
312
  aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
313
  )
314
  non_silent_wave = AudioSegment.silent(duration=0)
315
  for non_silent_seg in non_silent_segs:
316
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
317
  show_info("Audio is over 15s, clipping short. (2)")
318
  break
319
  non_silent_wave += non_silent_seg
 
321
  aseg = non_silent_wave
322
 
323
  # 3. if no proper silence found for clipping
324
+ if len(aseg) > 12000:
325
+ aseg = aseg[:12000]
326
  show_info("Audio is over 15s, clipping short. (3)")
327
 
328
  aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
 
383
  ):
384
  # Split the input text into batches
385
  audio, sr = torchaudio.load(ref_audio)
386
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
387
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
388
  for i, gen_text in enumerate(gen_text_batches):
389
  print(f"gen_text {i}", gen_text)
390
  print("\n")
391
 
392
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
393
+ return next(
394
+ infer_batch_process(
395
+ (audio, sr),
396
+ ref_text,
397
+ gen_text_batches,
398
+ model_obj,
399
+ vocoder,
400
+ mel_spec_type=mel_spec_type,
401
+ progress=progress,
402
+ target_rms=target_rms,
403
+ cross_fade_duration=cross_fade_duration,
404
+ nfe_step=nfe_step,
405
+ cfg_strength=cfg_strength,
406
+ sway_sampling_coef=sway_sampling_coef,
407
+ speed=speed,
408
+ fix_duration=fix_duration,
409
+ device=device,
410
+ )
411
  )
412
 
413
 
 
430
  speed=1,
431
  fix_duration=None,
432
  device=None,
433
+ streaming=False,
434
+ chunk_size=2048,
435
  ):
436
  audio, sr = ref_audio
437
  if audio.shape[0] > 1:
 
450
 
451
  if len(ref_text[-1].encode("utf-8")) == 1:
452
  ref_text = ref_text + " "
453
+
454
+ def process_batch(gen_text):
455
+ local_speed = speed
456
+ if len(gen_text.encode("utf-8")) < 10:
457
+ local_speed = 0.3
458
+
459
  # Prepare the text
460
  text_list = [ref_text + gen_text]
461
  final_text_list = convert_char_to_pinyin(text_list)
 
467
  # Calculate duration
468
  ref_text_len = len(ref_text.encode("utf-8"))
469
  gen_text_len = len(gen_text.encode("utf-8"))
470
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
471
 
472
  # inference
473
  with torch.inference_mode():
 
479
  cfg_strength=cfg_strength,
480
  sway_sampling_coef=sway_sampling_coef,
481
  )
482
+ del _
483
 
484
+ generated = generated.to(torch.float32) # generated mel spectrogram
485
  generated = generated[:, ref_audio_len:, :]
486
+ generated = generated.permute(0, 2, 1)
487
  if mel_spec_type == "vocos":
488
+ generated_wave = vocoder.decode(generated)
489
  elif mel_spec_type == "bigvgan":
490
+ generated_wave = vocoder(generated)
491
  if rms < target_rms:
492
  generated_wave = generated_wave * rms / target_rms
493
 
494
  # wav -> numpy
495
  generated_wave = generated_wave.squeeze().cpu().numpy()
496
 
497
+ if streaming:
498
+ for j in range(0, len(generated_wave), chunk_size):
499
+ yield generated_wave[j : j + chunk_size], target_sample_rate
500
+ else:
501
+ generated_cpu = generated[0].cpu().numpy()
502
+ del generated
503
+ yield generated_wave, generated_cpu
504
+
505
+ if streaming:
506
+ for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
507
+ for chunk in process_batch(gen_text):
508
+ yield chunk
509
  else:
510
+ with ThreadPoolExecutor() as executor:
511
+ futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
512
+ for future in progress.tqdm(futures) if progress is not None else futures:
513
+ result = future.result()
514
+ if result:
515
+ generated_wave, generated_mel_spec = next(result)
516
+ generated_waves.append(generated_wave)
517
+ spectrograms.append(generated_mel_spec)
518
+
519
+ if generated_waves:
520
+ if cross_fade_duration <= 0:
521
+ # Simply concatenate
522
+ final_wave = np.concatenate(generated_waves)
523
+ else:
524
+ # Combine all generated waves with cross-fading
525
+ final_wave = generated_waves[0]
526
+ for i in range(1, len(generated_waves)):
527
+ prev_wave = final_wave
528
+ next_wave = generated_waves[i]
529
+
530
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
531
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
532
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
533
+
534
+ if cross_fade_samples <= 0:
535
+ # No overlap possible, concatenate
536
+ final_wave = np.concatenate([prev_wave, next_wave])
537
+ continue
538
+
539
+ # Overlapping parts
540
+ prev_overlap = prev_wave[-cross_fade_samples:]
541
+ next_overlap = next_wave[:cross_fade_samples]
542
+
543
+ # Fade out and fade in
544
+ fade_out = np.linspace(1, 0, cross_fade_samples)
545
+ fade_in = np.linspace(0, 1, cross_fade_samples)
546
+
547
+ # Cross-faded overlap
548
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
549
+
550
+ # Combine
551
+ new_wave = np.concatenate(
552
+ [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
553
+ )
554
+
555
+ final_wave = new_wave
556
+
557
+ # Create a combined spectrogram
558
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
559
+
560
+ yield final_wave, target_sample_rate, combined_spectrogram
561
 
562
+ else:
563
+ yield None, target_sample_rate, None
 
 
564
 
565
 
566
  # remove silence from generated wav
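`infer_batch_process` is now a generator: with `streaming=True` it yields `(audio_chunk, sample_rate)` pairs as each batch is vocoded, otherwise it yields a single `(wave, sample_rate, spectrogram)` tuple, which is why `infer_process` wraps it in `next(...)`. A minimal consumption sketch, assuming the model, vocoder, and reference audio have already been prepared:

```python
# Minimal sketch of consuming the streaming path added above. Model/vocoder loading and
# reference preprocessing (load_model / load_vocoder / preprocess_ref_audio_text) are assumed
# to have happened elsewhere; chunks could instead be pushed to a socket or playback buffer.
import numpy as np

from f5_tts.infer.utils_infer import infer_batch_process

def synthesize_streaming(ref_audio_tensor, sr, ref_text, gen_text_batches, model, vocoder):
    chunks, sample_rate = [], 24000  # fallback sample rate if nothing is yielded
    for chunk, sample_rate in infer_batch_process(
        (ref_audio_tensor, sr),
        ref_text,
        gen_text_batches,
        model,
        vocoder,
        streaming=True,
        chunk_size=2048,
    ):
        chunks.append(chunk)
    return (np.concatenate(chunks) if chunks else np.array([], dtype=np.float32)), sample_rate
```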
src/f5_tts/model/backbones/README.md CHANGED
@@ -4,7 +4,7 @@
4
  ### unett.py
5
  - flat unet transformer
6
  - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
- - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
 
9
  ### dit.py
10
  - adaln-zero dit
@@ -14,7 +14,7 @@
14
  - possible long skip connection (first layer to last layer)
15
 
16
  ### mmdit.py
17
- - sd3 structure
18
  - timestep as condition
19
  - left stream: text embedded, with an abs pos emb applied
20
  - right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
 
4
  ### unett.py
5
  - flat unet transformer
6
  - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
8
 
9
  ### dit.py
10
  - adaln-zero dit
 
14
  - possible long skip connection (first layer to last layer)
15
 
16
  ### mmdit.py
17
+ - stable diffusion 3 block structure
18
  - timestep as condition
19
  - left stream: text embedded, with an abs pos emb applied
20
  - right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
src/f5_tts/model/backbones/dit.py CHANGED
@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
- AdaLayerNormZero_Final,
24
  precompute_freqs_cis,
25
  get_pos_embed_indices,
26
  )
@@ -30,10 +30,12 @@ from f5_tts.model.modules import (
30
 
31
 
32
  class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
 
37
  if conv_layers > 0:
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -49,6 +51,8 @@ class TextEmbedding(nn.Module):
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
  batch, text_len = text.shape[0], text.shape[1]
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
 
 
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
@@ -64,7 +68,13 @@ class TextEmbedding(nn.Module):
64
  text = text + text_pos_embed
65
 
66
  # convnextv2 blocks
67
- text = self.text_blocks(text)
 
 
 
 
 
 
68
 
69
  return text
70
 
@@ -103,7 +113,10 @@ class DiT(nn.Module):
103
  mel_dim=100,
104
  text_num_embeds=256,
105
  text_dim=None,
 
 
106
  conv_layers=0,
 
107
  long_skip_connection=False,
108
  checkpoint_activations=False,
109
  ):
@@ -112,7 +125,10 @@ class DiT(nn.Module):
112
  self.time_embed = TimestepEmbedding(dim)
113
  if text_dim is None:
114
  text_dim = mel_dim
115
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
 
116
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
117
 
118
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -121,15 +137,40 @@ class DiT(nn.Module):
121
  self.depth = depth
122
 
123
  self.transformer_blocks = nn.ModuleList(
124
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
 
 
125
  )
126
  self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
127
 
128
- self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
129
  self.proj_out = nn.Linear(dim, mel_dim)
130
 
131
  self.checkpoint_activations = checkpoint_activations
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def ckpt_wrapper(self, module):
134
  # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
135
  def ckpt_forward(*inputs):
@@ -138,6 +179,9 @@ class DiT(nn.Module):
138
 
139
  return ckpt_forward
140
 
 
 
 
141
  def forward(
142
  self,
143
  x: float["b n d"], # noised input audio # noqa: F722
@@ -147,14 +191,25 @@ class DiT(nn.Module):
147
  drop_audio_cond, # cfg for cond audio
148
  drop_text, # cfg for text
149
  mask: bool["b n"] | None = None, # noqa: F722
 
150
  ):
151
  batch, seq_len = x.shape[0], x.shape[1]
152
  if time.ndim == 0:
153
  time = time.repeat(batch)
154
 
155
- # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
156
  t = self.time_embed(time)
157
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
158
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
159
 
160
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
 
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
+ AdaLayerNorm_Final,
24
  precompute_freqs_cis,
25
  get_pos_embed_indices,
26
  )
 
30
 
31
 
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
37
+ self.mask_padding = mask_padding # mask filler and batch padding tokens or not
38
+
39
  if conv_layers > 0:
40
  self.extra_modeling = True
41
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
51
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
  batch, text_len = text.shape[0], text.shape[1]
53
  text = F.pad(text, (0, seq_len - text_len), value=0)
54
+ if self.mask_padding:
55
+ text_mask = text == 0
56
 
57
  if drop_text: # cfg for text
58
  text = torch.zeros_like(text)
 
68
  text = text + text_pos_embed
69
 
70
  # convnextv2 blocks
71
+ if self.mask_padding:
72
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
73
+ for block in self.text_blocks:
74
+ text = block(text)
75
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
76
+ else:
77
+ text = self.text_blocks(text)
78
 
79
  return text
80
 
 
113
  mel_dim=100,
114
  text_num_embeds=256,
115
  text_dim=None,
116
+ text_mask_padding=True,
117
+ qk_norm=None,
118
  conv_layers=0,
119
+ pe_attn_head=None,
120
  long_skip_connection=False,
121
  checkpoint_activations=False,
122
  ):
 
125
  self.time_embed = TimestepEmbedding(dim)
126
  if text_dim is None:
127
  text_dim = mel_dim
128
+ self.text_embed = TextEmbedding(
129
+ text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
130
+ )
131
+ self.text_cond, self.text_uncond = None, None # text cache
132
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
133
 
134
  self.rotary_embed = RotaryEmbedding(dim_head)
 
137
  self.depth = depth
138
 
139
  self.transformer_blocks = nn.ModuleList(
140
+ [
141
+ DiTBlock(
142
+ dim=dim,
143
+ heads=heads,
144
+ dim_head=dim_head,
145
+ ff_mult=ff_mult,
146
+ dropout=dropout,
147
+ qk_norm=qk_norm,
148
+ pe_attn_head=pe_attn_head,
149
+ )
150
+ for _ in range(depth)
151
+ ]
152
  )
153
  self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
154
 
155
+ self.norm_out = AdaLayerNorm_Final(dim) # final modulation
156
  self.proj_out = nn.Linear(dim, mel_dim)
157
 
158
  self.checkpoint_activations = checkpoint_activations
159
 
160
+ self.initialize_weights()
161
+
162
+ def initialize_weights(self):
163
+ # Zero-out AdaLN layers in DiT blocks:
164
+ for block in self.transformer_blocks:
165
+ nn.init.constant_(block.attn_norm.linear.weight, 0)
166
+ nn.init.constant_(block.attn_norm.linear.bias, 0)
167
+
168
+ # Zero-out output layers:
169
+ nn.init.constant_(self.norm_out.linear.weight, 0)
170
+ nn.init.constant_(self.norm_out.linear.bias, 0)
171
+ nn.init.constant_(self.proj_out.weight, 0)
172
+ nn.init.constant_(self.proj_out.bias, 0)
173
+
174
  def ckpt_wrapper(self, module):
175
  # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
176
  def ckpt_forward(*inputs):
 
179
 
180
  return ckpt_forward
181
 
182
+ def clear_cache(self):
183
+ self.text_cond, self.text_uncond = None, None
184
+
185
  def forward(
186
  self,
187
  x: float["b n d"], # noised input audio # noqa: F722
 
191
  drop_audio_cond, # cfg for cond audio
192
  drop_text, # cfg for text
193
  mask: bool["b n"] | None = None, # noqa: F722
194
+ cache=False,
195
  ):
196
  batch, seq_len = x.shape[0], x.shape[1]
197
  if time.ndim == 0:
198
  time = time.repeat(batch)
199
 
200
+ # t: conditioning time, text: text, x: noised audio + cond audio + text
201
  t = self.time_embed(time)
202
+ if cache:
203
+ if drop_text:
204
+ if self.text_uncond is None:
205
+ self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
206
+ text_embed = self.text_uncond
207
+ else:
208
+ if self.text_cond is None:
209
+ self.text_cond = self.text_embed(text, seq_len, drop_text=False)
210
+ text_embed = self.text_cond
211
+ else:
212
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
213
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
214
 
215
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
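The new `cache=True` path stores the conditional and unconditional text embeddings so classifier-free guidance does not recompute them at every ODE step; `clear_cache()` is meant to be called once an utterance is finished. A rough usage sketch (the exact guidance weighting lives in `cfm.py`):

```python
# Rough sketch of how the text-embedding cache above is meant to be used during sampling:
# the transformer is called twice per step (conditional + unconditional) with cache=True,
# and clear_cache() is called when the utterance is done. The combination shown is one
# common CFG form; the exact weighting used by the repo lives in cfm.py.
def cfg_step(transformer, x, cond, text, time, cfg_strength=2.0):
    pred_cond = transformer(
        x=x, cond=cond, text=text, time=time, drop_audio_cond=False, drop_text=False, cache=True
    )
    pred_uncond = transformer(
        x=x, cond=cond, text=text, time=time, drop_audio_cond=True, drop_text=True, cache=True
    )
    return pred_cond + (pred_cond - pred_uncond) * cfg_strength

# after the ODE solve for one utterance:
# transformer.clear_cache()
```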
src/f5_tts/model/backbones/mmdit.py CHANGED
@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
18
  TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
- AdaLayerNormZero_Final,
22
  precompute_freqs_cis,
23
  get_pos_embed_indices,
24
  )
@@ -28,18 +28,24 @@ from f5_tts.model.modules import (
28
 
29
 
30
  class TextEmbedding(nn.Module):
31
- def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
 
 
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
  def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
- text = text + 1
40
- if drop_text:
 
 
 
41
  text = torch.zeros_like(text)
42
- text = self.text_embed(text)
 
43
 
44
  # sinus pos emb
45
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
@@ -49,6 +55,9 @@ class TextEmbedding(nn.Module):
49
 
50
  text = text + text_pos_embed
51
 
 
 
 
52
  return text
53
 
54
 
@@ -83,13 +92,16 @@ class MMDiT(nn.Module):
83
  dim_head=64,
84
  dropout=0.1,
85
  ff_mult=4,
86
- text_num_embeds=256,
87
  mel_dim=100,
 
 
 
88
  ):
89
  super().__init__()
90
 
91
  self.time_embed = TimestepEmbedding(dim)
92
- self.text_embed = TextEmbedding(dim, text_num_embeds)
 
93
  self.audio_embed = AudioEmbedding(mel_dim, dim)
94
 
95
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -106,13 +118,33 @@ class MMDiT(nn.Module):
106
  dropout=dropout,
107
  ff_mult=ff_mult,
108
  context_pre_only=i == depth - 1,
 
109
  )
110
  for i in range(depth)
111
  ]
112
  )
113
- self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114
  self.proj_out = nn.Linear(dim, mel_dim)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def forward(
117
  self,
118
  x: float["b n d"], # noised input audio # noqa: F722
@@ -122,6 +154,7 @@ class MMDiT(nn.Module):
122
  drop_audio_cond, # cfg for cond audio
123
  drop_text, # cfg for text
124
  mask: bool["b n"] | None = None, # noqa: F722
 
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
@@ -129,7 +162,17 @@ class MMDiT(nn.Module):
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
- c = self.text_embed(text, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
133
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
 
18
  TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
+ AdaLayerNorm_Final,
22
  precompute_freqs_cis,
23
  get_pos_embed_indices,
24
  )
 
28
 
29
 
30
  class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds, mask_padding=True):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
 
35
+ self.mask_padding = mask_padding # mask filler and batch padding tokens or not
36
+
37
  self.precompute_max_pos = 1024
38
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
39
 
40
  def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
41
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
42
+ if self.mask_padding:
43
+ text_mask = text == 0
44
+
45
+ if drop_text: # cfg for text
46
  text = torch.zeros_like(text)
47
+
48
+ text = self.text_embed(text) # b nt -> b nt d
49
 
50
  # sinus pos emb
51
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
 
55
 
56
  text = text + text_pos_embed
57
 
58
+ if self.mask_padding:
59
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
60
+
61
  return text
62
 
63
 
 
92
  dim_head=64,
93
  dropout=0.1,
94
  ff_mult=4,
 
95
  mel_dim=100,
96
+ text_num_embeds=256,
97
+ text_mask_padding=True,
98
+ qk_norm=None,
99
  ):
100
  super().__init__()
101
 
102
  self.time_embed = TimestepEmbedding(dim)
103
+ self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
104
+ self.text_cond, self.text_uncond = None, None # text cache
105
  self.audio_embed = AudioEmbedding(mel_dim, dim)
106
 
107
  self.rotary_embed = RotaryEmbedding(dim_head)
 
118
  dropout=dropout,
119
  ff_mult=ff_mult,
120
  context_pre_only=i == depth - 1,
121
+ qk_norm=qk_norm,
122
  )
123
  for i in range(depth)
124
  ]
125
  )
126
+ self.norm_out = AdaLayerNorm_Final(dim) # final modulation
127
  self.proj_out = nn.Linear(dim, mel_dim)
128
 
129
+ self.initialize_weights()
130
+
131
+ def initialize_weights(self):
132
+ # Zero-out AdaLN layers in MMDiT blocks:
133
+ for block in self.transformer_blocks:
134
+ nn.init.constant_(block.attn_norm_x.linear.weight, 0)
135
+ nn.init.constant_(block.attn_norm_x.linear.bias, 0)
136
+ nn.init.constant_(block.attn_norm_c.linear.weight, 0)
137
+ nn.init.constant_(block.attn_norm_c.linear.bias, 0)
138
+
139
+ # Zero-out output layers:
140
+ nn.init.constant_(self.norm_out.linear.weight, 0)
141
+ nn.init.constant_(self.norm_out.linear.bias, 0)
142
+ nn.init.constant_(self.proj_out.weight, 0)
143
+ nn.init.constant_(self.proj_out.bias, 0)
144
+
145
+ def clear_cache(self):
146
+ self.text_cond, self.text_uncond = None, None
147
+
148
  def forward(
149
  self,
150
  x: float["b n d"], # noised input audio # noqa: F722
 
154
  drop_audio_cond, # cfg for cond audio
155
  drop_text, # cfg for text
156
  mask: bool["b n"] | None = None, # noqa: F722
157
+ cache=False,
158
  ):
159
  batch = x.shape[0]
160
  if time.ndim == 0:
 
162
 
163
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
164
  t = self.time_embed(time)
165
+ if cache:
166
+ if drop_text:
167
+ if self.text_uncond is None:
168
+ self.text_uncond = self.text_embed(text, drop_text=True)
169
+ c = self.text_uncond
170
+ else:
171
+ if self.text_cond is None:
172
+ self.text_cond = self.text_embed(text, drop_text=False)
173
+ c = self.text_cond
174
+ else:
175
+ c = self.text_embed(text, drop_text=drop_text)
176
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
177
 
178
  seq_len = x.shape[1]
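Both text embedders now zero out filler/padding positions (`text == 0` after the `+1` shift of batch-padded `-1` entries) before and after the extra modeling layers, so padding cannot leak through attention. A toy illustration of the masking itself:

```python
# Toy illustration of the text_mask_padding behavior: embedded rows at filler/padding
# positions (token id 0 after the +1 shift) are zeroed with masked_fill.
import torch
import torch.nn as nn

text = torch.tensor([[3, 7, 0, 0]])  # last two positions are filler/padding
embedded = nn.Embedding(10, 4)(text)  # (b, nt, d)

text_mask = text == 0  # True where padding
embedded = embedded.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, embedded.size(-1)), 0.0)
print(embedded[0, 2:].abs().sum().item())  # 0.0 -> padded rows contribute nothing
```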
src/f5_tts/model/backbones/unett.py CHANGED
@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
 
40
  if conv_layers > 0:
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
  batch, text_len = text.shape[0], text.shape[1]
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
 
 
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
67
  text = text + text_pos_embed
68
 
69
  # convnextv2 blocks
70
- text = self.text_blocks(text)
 
 
 
 
 
 
71
 
72
  return text
73
 
@@ -106,7 +116,10 @@ class UNetT(nn.Module):
106
  mel_dim=100,
107
  text_num_embeds=256,
108
  text_dim=None,
 
 
109
  conv_layers=0,
 
110
  skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
 
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
- processor=AttnProcessor(),
138
  dim=dim,
139
  heads=heads,
140
  dim_head=dim_head,
141
  dropout=dropout,
 
142
  )
143
 
144
  ff_norm = RMSNorm(dim)
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
 
 
 
164
  def forward(
165
  self,
166
  x: float["b n d"], # noised input audio # noqa: F722
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
170
  drop_audio_cond, # cfg for cond audio
171
  drop_text, # cfg for text
172
  mask: bool["b n"] | None = None, # noqa: F722
 
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
177
 
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
181
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
 
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
40
+ self.mask_padding = mask_padding # whether to mask filler and batch padding tokens
41
+
42
  if conv_layers > 0:
43
  self.extra_modeling = True
44
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
54
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
55
  batch, text_len = text.shape[0], text.shape[1]
56
  text = F.pad(text, (0, seq_len - text_len), value=0)
57
+ if self.mask_padding:
58
+ text_mask = text == 0
59
 
60
  if drop_text: # cfg for text
61
  text = torch.zeros_like(text)
 
71
  text = text + text_pos_embed
72
 
73
  # convnextv2 blocks
74
+ if self.mask_padding:
75
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
76
+ for block in self.text_blocks:
77
+ text = block(text)
78
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
79
+ else:
80
+ text = self.text_blocks(text)
81
 
82
  return text
83
 
 
116
  mel_dim=100,
117
  text_num_embeds=256,
118
  text_dim=None,
119
+ text_mask_padding=True,
120
+ qk_norm=None,
121
  conv_layers=0,
122
+ pe_attn_head=None,
123
  skip_connect_type: Literal["add", "concat", "none"] = "concat",
124
  ):
125
  super().__init__()
 
128
  self.time_embed = TimestepEmbedding(dim)
129
  if text_dim is None:
130
  text_dim = mel_dim
131
+ self.text_embed = TextEmbedding(
132
+ text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
133
+ )
134
+ self.text_cond, self.text_uncond = None, None # text cache
135
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
136
 
137
  self.rotary_embed = RotaryEmbedding(dim_head)
 
150
 
151
  attn_norm = RMSNorm(dim)
152
  attn = Attention(
153
+ processor=AttnProcessor(pe_attn_head=pe_attn_head),
154
  dim=dim,
155
  heads=heads,
156
  dim_head=dim_head,
157
  dropout=dropout,
158
+ qk_norm=qk_norm,
159
  )
160
 
161
  ff_norm = RMSNorm(dim)
 
178
  self.norm_out = RMSNorm(dim)
179
  self.proj_out = nn.Linear(dim, mel_dim)
180
 
181
+ def clear_cache(self):
182
+ self.text_cond, self.text_uncond = None, None
183
+
184
  def forward(
185
  self,
186
  x: float["b n d"], # noised input audio # noqa: F722
 
190
  drop_audio_cond, # cfg for cond audio
191
  drop_text, # cfg for text
192
  mask: bool["b n"] | None = None, # noqa: F722
193
+ cache=False,
194
  ):
195
  batch, seq_len = x.shape[0], x.shape[1]
196
  if time.ndim == 0:
 
198
 
199
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
200
  t = self.time_embed(time)
201
+ if cache:
202
+ if drop_text:
203
+ if self.text_uncond is None:
204
+ self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
205
+ text_embed = self.text_uncond
206
+ else:
207
+ if self.text_cond is None:
208
+ self.text_cond = self.text_embed(text, seq_len, drop_text=False)
209
+ text_embed = self.text_cond
210
+ else:
211
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
212
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
213
 
214
  # postfix time t to input x, [b n d] -> [b n+1 d]
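
With mask_padding enabled above, filler/padding positions are zeroed again after every ConvNeXt block, because a 1-D convolution mixes neighbouring frames and would otherwise smear padding into real tokens. A small sketch of that re-masking loop, assuming (b, nt, d) embeddings and a (b, nt) boolean mask:

    import torch

    def run_blocks_with_padding_mask(text, text_mask, blocks):
        # text: (b, nt, d) embeddings; text_mask: (b, nt) bool, True on filler/padding
        pad = text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1))
        text = text.masked_fill(pad, 0.0)
        for block in blocks:  # any iterable of sequence-to-sequence modules
            text = block(text)
            text = text.masked_fill(pad, 0.0)  # re-zero positions the block may have filled in
        return text
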
src/f5_tts/model/cfm.py CHANGED
@@ -120,10 +120,6 @@ class CFM(nn.Module):
120
  text = list_str_to_tensor(text).to(device)
121
  assert text.shape[0] == batch
122
 
123
- if exists(text):
124
- text_lens = (text != -1).sum(dim=-1)
125
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
126
-
127
  # duration
128
 
129
  cond_mask = lens_to_mask(lens)
@@ -133,7 +129,9 @@ class CFM(nn.Module):
133
  if isinstance(duration, int):
134
  duration = torch.full((batch,), duration, device=device, dtype=torch.long)
135
 
136
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
 
 
137
  duration = duration.clamp(max=max_duration)
138
  max_duration = duration.amax()
139
 
@@ -142,6 +140,9 @@ class CFM(nn.Module):
142
  test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
143
 
144
  cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
 
 
 
145
  cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
146
  cond_mask = cond_mask.unsqueeze(-1)
147
  step_cond = torch.where(
@@ -153,10 +154,6 @@ class CFM(nn.Module):
153
  else: # save memory and speed up, as single inference need no mask currently
154
  mask = None
155
 
156
- # test for no ref audio
157
- if no_ref_audio:
158
- cond = torch.zeros_like(cond)
159
-
160
  # neural ode
161
 
162
  def fn(t, x):
@@ -165,13 +162,13 @@ class CFM(nn.Module):
165
 
166
  # predict flow
167
  pred = self.transformer(
168
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
169
  )
170
  if cfg_strength < 1e-5:
171
  return pred
172
 
173
  null_pred = self.transformer(
174
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
175
  )
176
  return pred + (pred - null_pred) * cfg_strength
177
 
@@ -198,6 +195,7 @@ class CFM(nn.Module):
198
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
199
 
200
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
 
201
 
202
  sampled = trajectory[-1]
203
  out = sampled
 
120
  text = list_str_to_tensor(text).to(device)
121
  assert text.shape[0] == batch
122
 
 
 
 
 
123
  # duration
124
 
125
  cond_mask = lens_to_mask(lens)
 
129
  if isinstance(duration, int):
130
  duration = torch.full((batch,), duration, device=device, dtype=torch.long)
131
 
132
+ duration = torch.maximum(
133
+ torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
134
+ ) # duration at least text/audio prompt length plus one token, so something is generated
135
  duration = duration.clamp(max=max_duration)
136
  max_duration = duration.amax()
137
 
 
140
  test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
141
 
142
  cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
143
+ if no_ref_audio:
144
+ cond = torch.zeros_like(cond)
145
+
146
  cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
147
  cond_mask = cond_mask.unsqueeze(-1)
148
  step_cond = torch.where(
 
154
  else: # save memory and speed up, as single inference need no mask currently
155
  mask = None
156
 
 
 
 
 
157
  # neural ode
158
 
159
  def fn(t, x):
 
162
 
163
  # predict flow
164
  pred = self.transformer(
165
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
166
  )
167
  if cfg_strength < 1e-5:
168
  return pred
169
 
170
  null_pred = self.transformer(
171
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
172
  )
173
  return pred + (pred - null_pred) * cfg_strength
174
 
 
195
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
196
 
197
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
198
+ self.transformer.clear_cache()
199
 
200
  sampled = trajectory[-1]
201
  out = sampled
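
Two details of the sampling loop above are easy to check in isolation: the CFG update is pred + (pred - null_pred) * cfg_strength, and the sway-sampling coefficient warps the uniform time grid before it reaches odeint, so a negative coefficient concentrates steps near t = 0 where the flow changes fastest. A small sketch of the warp (values rounded):

    import torch

    def sway_sample(nfe: int, coef: float = -1.0) -> torch.Tensor:
        t = torch.linspace(0, 1, nfe + 1)
        return t + coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

    # sway_sample(4, 0.0)  -> [0.00, 0.25, 0.50, 0.75, 1.00]  (plain uniform grid)
    # sway_sample(4, -1.0) -> [0.00, 0.08, 0.29, 0.62, 1.00]  (denser near t = 0)
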
src/f5_tts/model/dataset.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import random
3
  from importlib.resources import files
4
 
5
  import torch
@@ -170,14 +169,17 @@ class DynamicBatchSampler(Sampler[list[int]]):
170
  in a batch to ensure that the total number of frames are less
171
  than a certain threshold.
172
  2. Make sure the padding efficiency in the batch is high.
 
173
  """
174
 
175
  def __init__(
176
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
180
  self.max_samples = max_samples
 
 
181
 
182
  indices, batches = [], []
183
  data_source = self.sampler.data_source
@@ -206,21 +208,30 @@ class DynamicBatchSampler(Sampler[list[int]]):
206
  batch = []
207
  batch_frames = 0
208
 
209
- if not drop_last and len(batch) > 0:
210
  batches.append(batch)
211
 
212
  del indices
 
213
 
214
- # if want to have different batches between epochs, may just set a seed and log it in ckpt
215
- # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
216
- # e.g. for epoch n, use (random_seed + n)
217
- random.seed(random_seed)
218
- random.shuffle(batches)
219
 
220
- self.batches = batches
 
 
221
 
222
  def __iter__(self):
223
- return iter(self.batches)
 
 
 
 
 
 
 
 
 
224
 
225
  def __len__(self):
226
  return len(self.batches)
 
1
  import json
 
2
  from importlib.resources import files
3
 
4
  import torch
 
169
  in a batch to ensure that the total number of frames are less
170
  than a certain threshold.
171
  2. Make sure the padding efficiency in the batch is high.
172
+ 3. Shuffle batches each epoch while maintaining reproducibility.
173
  """
174
 
175
  def __init__(
176
+ self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
180
  self.max_samples = max_samples
181
+ self.random_seed = random_seed
182
+ self.epoch = 0
183
 
184
  indices, batches = [], []
185
  data_source = self.sampler.data_source
 
208
  batch = []
209
  batch_frames = 0
210
 
211
+ if not drop_residual and len(batch) > 0:
212
  batches.append(batch)
213
 
214
  del indices
215
+ self.batches = batches
216
 
217
+ # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
218
+ self.drop_last = True
 
 
 
219
 
220
+ def set_epoch(self, epoch: int) -> None:
221
+ """Sets the epoch for this sampler."""
222
+ self.epoch = epoch
223
 
224
  def __iter__(self):
225
+ # Use both random_seed and epoch for deterministic but different shuffling per epoch
226
+ if self.random_seed is not None:
227
+ g = torch.Generator()
228
+ g.manual_seed(self.random_seed + self.epoch)
229
+ # Use PyTorch's random permutation for better reproducibility across PyTorch versions
230
+ indices = torch.randperm(len(self.batches), generator=g).tolist()
231
+ batches = [self.batches[i] for i in indices]
232
+ else:
233
+ batches = self.batches
234
+ return iter(batches)
235
 
236
  def __len__(self):
237
  return len(self.batches)
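
The epoch-aware shuffle added above is the usual set_epoch pattern: seeding a fresh torch.Generator with random_seed + epoch gives an ordering that differs every epoch yet is reproducible across restarts. Condensed to just that piece:

    import torch

    def shuffle_batches(batches, random_seed: int, epoch: int):
        g = torch.Generator()
        g.manual_seed(random_seed + epoch)
        order = torch.randperm(len(batches), generator=g).tolist()
        return [batches[i] for i in order]

The trainer is expected to call batch_sampler.set_epoch(epoch) once per epoch, which the training-loop change further down in this commit does.
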
src/f5_tts/model/modules.py CHANGED
@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
269
  return residual + x
270
 
271
 
272
- # AdaLayerNormZero
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # return with modulated x for attn input, and params for later mlp modulation
274
 
275
 
276
- class AdaLayerNormZero(nn.Module):
277
  def __init__(self, dim):
278
  super().__init__()
279
 
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
290
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
291
 
292
 
293
- # AdaLayerNormZero for final layer
294
  # return only with modulated x for attn input, cuz no more mlp modulation
295
 
296
 
297
- class AdaLayerNormZero_Final(nn.Module):
298
  def __init__(self, dim):
299
  super().__init__()
300
 
@@ -341,7 +366,8 @@ class Attention(nn.Module):
341
  dim_head: int = 64,
342
  dropout: float = 0.0,
343
  context_dim: Optional[int] = None, # if not None -> joint attention
344
- context_pre_only=None,
 
345
  ):
346
  super().__init__()
347
 
@@ -362,18 +388,32 @@ class Attention(nn.Module):
362
  self.to_k = nn.Linear(dim, self.inner_dim)
363
  self.to_v = nn.Linear(dim, self.inner_dim)
364
 
 
 
 
 
 
 
 
 
 
365
  if self.context_dim is not None:
 
366
  self.to_k_c = nn.Linear(context_dim, self.inner_dim)
367
  self.to_v_c = nn.Linear(context_dim, self.inner_dim)
368
- if self.context_pre_only is not None:
369
- self.to_q_c = nn.Linear(context_dim, self.inner_dim)
 
 
 
 
370
 
371
  self.to_out = nn.ModuleList([])
372
  self.to_out.append(nn.Linear(self.inner_dim, dim))
373
  self.to_out.append(nn.Dropout(dropout))
374
 
375
- if self.context_pre_only is not None and not self.context_pre_only:
376
- self.to_out_c = nn.Linear(self.inner_dim, dim)
377
 
378
  def forward(
379
  self,
@@ -393,8 +433,11 @@ class Attention(nn.Module):
393
 
394
 
395
  class AttnProcessor:
396
- def __init__(self):
397
- pass
 
 
 
398
 
399
  def __call__(
400
  self,
@@ -405,19 +448,11 @@ class AttnProcessor:
405
  ) -> torch.FloatTensor:
406
  batch_size = x.shape[0]
407
 
408
- # `sample` projections.
409
  query = attn.to_q(x)
410
  key = attn.to_k(x)
411
  value = attn.to_v(x)
412
 
413
- # apply rotary position embedding
414
- if rope is not None:
415
- freqs, xpos_scale = rope
416
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
417
-
418
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
419
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
420
-
421
  # attention
422
  inner_dim = key.shape[-1]
423
  head_dim = inner_dim // attn.heads
@@ -425,6 +460,25 @@ class AttnProcessor:
425
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  # mask. e.g. inference got a batch with different target durations, mask out the padding
429
  if mask is not None:
430
  attn_mask = mask
@@ -470,16 +524,36 @@ class JointAttnProcessor:
470
 
471
  batch_size = c.shape[0]
472
 
473
- # `sample` projections.
474
  query = attn.to_q(x)
475
  key = attn.to_k(x)
476
  value = attn.to_v(x)
477
 
478
- # `context` projections.
479
  c_query = attn.to_q_c(c)
480
  c_key = attn.to_k_c(c)
481
  c_value = attn.to_v_c(c)
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  # apply rope for context and noised input independently
484
  if rope is not None:
485
  freqs, xpos_scale = rope
@@ -492,16 +566,10 @@ class JointAttnProcessor:
492
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
493
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
494
 
495
- # attention
496
- query = torch.cat([query, c_query], dim=1)
497
- key = torch.cat([key, c_key], dim=1)
498
- value = torch.cat([value, c_value], dim=1)
499
-
500
- inner_dim = key.shape[-1]
501
- head_dim = inner_dim // attn.heads
502
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
503
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
504
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
505
 
506
  # mask. e.g. inference got a batch with different target durations, mask out the padding
507
  if mask is not None:
@@ -540,16 +608,17 @@ class JointAttnProcessor:
540
 
541
 
542
  class DiTBlock(nn.Module):
543
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
544
  super().__init__()
545
 
546
- self.attn_norm = AdaLayerNormZero(dim)
547
  self.attn = Attention(
548
- processor=AttnProcessor(),
549
  dim=dim,
550
  heads=heads,
551
  dim_head=dim_head,
552
  dropout=dropout,
 
553
  )
554
 
555
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
585
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
586
  """
587
 
588
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
 
 
589
  super().__init__()
590
-
 
591
  self.context_pre_only = context_pre_only
592
 
593
- self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
594
- self.attn_norm_x = AdaLayerNormZero(dim)
595
  self.attn = Attention(
596
  processor=JointAttnProcessor(),
597
  dim=dim,
598
  heads=heads,
599
  dim_head=dim_head,
600
  dropout=dropout,
601
- context_dim=dim,
602
  context_pre_only=context_pre_only,
 
603
  )
604
 
605
  if not context_pre_only:
606
- self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
607
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
608
  else:
609
  self.ff_norm_c = None
610
  self.ff_c = None
 
269
  return residual + x
270
 
271
 
272
+ # RMSNorm
273
+
274
+
275
+ class RMSNorm(nn.Module):
276
+ def __init__(self, dim: int, eps: float):
277
+ super().__init__()
278
+ self.eps = eps
279
+ self.weight = nn.Parameter(torch.ones(dim))
280
+ self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
281
+
282
+ def forward(self, x):
283
+ if self.native_rms_norm:
284
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
285
+ x = x.to(self.weight.dtype)
286
+ x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
287
+ else:
288
+ variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
289
+ x = x * torch.rsqrt(variance + self.eps)
290
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
291
+ x = x.to(self.weight.dtype)
292
+ x = x * self.weight
293
+
294
+ return x
295
+
296
+
297
+ # AdaLayerNorm
298
  # return with modulated x for attn input, and params for later mlp modulation
299
 
300
 
301
+ class AdaLayerNorm(nn.Module):
302
  def __init__(self, dim):
303
  super().__init__()
304
 
 
315
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
316
 
317
 
318
+ # AdaLayerNorm for final layer
319
  # return only with modulated x for attn input, cuz no more mlp modulation
320
 
321
 
322
+ class AdaLayerNorm_Final(nn.Module):
323
  def __init__(self, dim):
324
  super().__init__()
325
 
 
366
  dim_head: int = 64,
367
  dropout: float = 0.0,
368
  context_dim: Optional[int] = None, # if not None -> joint attention
369
+ context_pre_only: bool = False,
370
+ qk_norm: Optional[str] = None,
371
  ):
372
  super().__init__()
373
 
 
388
  self.to_k = nn.Linear(dim, self.inner_dim)
389
  self.to_v = nn.Linear(dim, self.inner_dim)
390
 
391
+ if qk_norm is None:
392
+ self.q_norm = None
393
+ self.k_norm = None
394
+ elif qk_norm == "rms_norm":
395
+ self.q_norm = RMSNorm(dim_head, eps=1e-6)
396
+ self.k_norm = RMSNorm(dim_head, eps=1e-6)
397
+ else:
398
+ raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
399
+
400
  if self.context_dim is not None:
401
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
402
  self.to_k_c = nn.Linear(context_dim, self.inner_dim)
403
  self.to_v_c = nn.Linear(context_dim, self.inner_dim)
404
+ if qk_norm is None:
405
+ self.c_q_norm = None
406
+ self.c_k_norm = None
407
+ elif qk_norm == "rms_norm":
408
+ self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
409
+ self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
410
 
411
  self.to_out = nn.ModuleList([])
412
  self.to_out.append(nn.Linear(self.inner_dim, dim))
413
  self.to_out.append(nn.Dropout(dropout))
414
 
415
+ if self.context_dim is not None and not self.context_pre_only:
416
+ self.to_out_c = nn.Linear(self.inner_dim, context_dim)
417
 
418
  def forward(
419
  self,
 
433
 
434
 
435
  class AttnProcessor:
436
+ def __init__(
437
+ self,
438
+ pe_attn_head: int | None = None, # number of attention heads to apply RoPE to; None applies it to all heads
439
+ ):
440
+ self.pe_attn_head = pe_attn_head
441
 
442
  def __call__(
443
  self,
 
448
  ) -> torch.FloatTensor:
449
  batch_size = x.shape[0]
450
 
451
+ # `sample` projections
452
  query = attn.to_q(x)
453
  key = attn.to_k(x)
454
  value = attn.to_v(x)
455
 
 
 
 
 
 
 
 
 
456
  # attention
457
  inner_dim = key.shape[-1]
458
  head_dim = inner_dim // attn.heads
 
460
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
461
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
462
 
463
+ # qk norm
464
+ if attn.q_norm is not None:
465
+ query = attn.q_norm(query)
466
+ if attn.k_norm is not None:
467
+ key = attn.k_norm(key)
468
+
469
+ # apply rotary position embedding
470
+ if rope is not None:
471
+ freqs, xpos_scale = rope
472
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
473
+
474
+ if self.pe_attn_head is not None:
475
+ pn = self.pe_attn_head
476
+ query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
477
+ key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
478
+ else:
479
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
480
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
481
+
482
  # mask. e.g. inference got a batch with different target durations, mask out the padding
483
  if mask is not None:
484
  attn_mask = mask
 
524
 
525
  batch_size = c.shape[0]
526
 
527
+ # `sample` projections
528
  query = attn.to_q(x)
529
  key = attn.to_k(x)
530
  value = attn.to_v(x)
531
 
532
+ # `context` projections
533
  c_query = attn.to_q_c(c)
534
  c_key = attn.to_k_c(c)
535
  c_value = attn.to_v_c(c)
536
 
537
+ # attention
538
+ inner_dim = key.shape[-1]
539
+ head_dim = inner_dim // attn.heads
540
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
541
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
542
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
543
+ c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
544
+ c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
545
+ c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
546
+
547
+ # qk norm
548
+ if attn.q_norm is not None:
549
+ query = attn.q_norm(query)
550
+ if attn.k_norm is not None:
551
+ key = attn.k_norm(key)
552
+ if attn.c_q_norm is not None:
553
+ c_query = attn.c_q_norm(c_query)
554
+ if attn.c_k_norm is not None:
555
+ c_key = attn.c_k_norm(c_key)
556
+
557
  # apply rope for context and noised input independently
558
  if rope is not None:
559
  freqs, xpos_scale = rope
 
566
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
567
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
568
 
569
+ # joint attention
570
+ query = torch.cat([query, c_query], dim=2)
571
+ key = torch.cat([key, c_key], dim=2)
572
+ value = torch.cat([value, c_value], dim=2)
 
 
 
 
 
 
573
 
574
  # mask. e.g. inference got a batch with different target durations, mask out the padding
575
  if mask is not None:
 
608
 
609
 
610
  class DiTBlock(nn.Module):
611
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
612
  super().__init__()
613
 
614
+ self.attn_norm = AdaLayerNorm(dim)
615
  self.attn = Attention(
616
+ processor=AttnProcessor(pe_attn_head=pe_attn_head),
617
  dim=dim,
618
  heads=heads,
619
  dim_head=dim_head,
620
  dropout=dropout,
621
+ qk_norm=qk_norm,
622
  )
623
 
624
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
 
654
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
655
  """
656
 
657
+ def __init__(
658
+ self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
659
+ ):
660
  super().__init__()
661
+ if context_dim is None:
662
+ context_dim = dim
663
  self.context_pre_only = context_pre_only
664
 
665
+ self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
666
+ self.attn_norm_x = AdaLayerNorm(dim)
667
  self.attn = Attention(
668
  processor=JointAttnProcessor(),
669
  dim=dim,
670
  heads=heads,
671
  dim_head=dim_head,
672
  dropout=dropout,
673
+ context_dim=context_dim,
674
  context_pre_only=context_pre_only,
675
+ qk_norm=qk_norm,
676
  )
677
 
678
  if not context_pre_only:
679
+ self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
680
+ self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
681
  else:
682
  self.ff_norm_c = None
683
  self.ff_c = None
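
The new pe_attn_head option restricts rotary position embedding to the first few attention heads, leaving the remaining heads position-agnostic. A sketch of the idea, assuming a (b, heads, n, d) layout and the common rotate-half RoPE formulation (the repo's apply_rotary_pos_emb may differ in detail):

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope_partial(q, cos, sin, pe_attn_head=None):
        # q: (b, heads, n, d); cos/sin: (n, d), broadcast over batch and heads
        if pe_attn_head is None:
            return q * cos + rotate_half(q) * sin
        out = q.clone()
        head = q[:, :pe_attn_head]
        out[:, :pe_attn_head] = head * cos + rotate_half(head) * sin
        return out
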
src/f5_tts/model/trainer.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import gc
 
4
  import os
5
 
6
  import torch
@@ -29,8 +30,9 @@ class Trainer:
29
  learning_rate,
30
  num_warmup_updates=20000,
31
  save_per_updates=1000,
 
32
  checkpoint_path=None,
33
- batch_size=32,
34
  batch_size_type: str = "sample",
35
  max_samples=32,
36
  grad_accumulation_steps=1,
@@ -38,23 +40,23 @@ class Trainer:
38
  noise_scheduler: str | None = None,
39
  duration_predictor: torch.nn.Module | None = None,
40
  logger: str | None = "wandb", # "wandb" | "tensorboard" | None
41
- wandb_project="test_e2-tts",
42
  wandb_run_name="test_run",
43
  wandb_resume_id: str = None,
44
  log_samples: bool = False,
45
- last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
  ema_kwargs: dict = dict(),
48
  bnb_optimizer: bool = False,
49
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
50
  is_local_vocoder: bool = False, # use local path vocoder
51
  local_vocoder_path: str = "", # local vocoder path
 
52
  ):
53
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
54
 
55
  if logger == "wandb" and not wandb.api.api_key:
56
  logger = None
57
- print(f"Using logger: {logger}")
58
  self.log_samples = log_samples
59
 
60
  self.accelerator = Accelerator(
@@ -71,21 +73,23 @@ class Trainer:
71
  else:
72
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
73
 
74
- self.accelerator.init_trackers(
75
- project_name=wandb_project,
76
- init_kwargs=init_kwargs,
77
- config={
78
  "epochs": epochs,
79
  "learning_rate": learning_rate,
80
  "num_warmup_updates": num_warmup_updates,
81
- "batch_size": batch_size,
82
  "batch_size_type": batch_size_type,
83
  "max_samples": max_samples,
84
  "grad_accumulation_steps": grad_accumulation_steps,
85
  "max_grad_norm": max_grad_norm,
86
- "gpus": self.accelerator.num_processes,
87
  "noise_scheduler": noise_scheduler,
88
- },
 
 
 
 
 
89
  )
90
 
91
  elif self.logger == "tensorboard":
@@ -99,13 +103,20 @@ class Trainer:
99
  self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
100
  self.ema_model.to(self.accelerator.device)
101
 
 
 
 
 
 
 
102
  self.epochs = epochs
103
  self.num_warmup_updates = num_warmup_updates
104
  self.save_per_updates = save_per_updates
105
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
106
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
 
107
 
108
- self.batch_size = batch_size
109
  self.batch_size_type = batch_size_type
110
  self.max_samples = max_samples
111
  self.grad_accumulation_steps = grad_accumulation_steps
@@ -132,7 +143,7 @@ class Trainer:
132
  def is_main(self):
133
  return self.accelerator.is_main_process
134
 
135
- def save_checkpoint(self, step, last=False):
136
  self.accelerator.wait_for_everyone()
137
  if self.is_main:
138
  checkpoint = dict(
@@ -140,21 +151,38 @@ class Trainer:
140
  optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
141
  ema_model_state_dict=self.ema_model.state_dict(),
142
  scheduler_state_dict=self.scheduler.state_dict(),
143
- step=step,
144
  )
145
  if not os.path.exists(self.checkpoint_path):
146
  os.makedirs(self.checkpoint_path)
147
  if last:
148
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
149
- print(f"Saved last checkpoint at step {step}")
150
  else:
151
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def load_checkpoint(self):
154
  if (
155
  not exists(self.checkpoint_path)
156
  or not os.path.exists(self.checkpoint_path)
157
- or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
158
  ):
159
  return 0
160
 
@@ -162,12 +190,34 @@ class Trainer:
162
  if "model_last.pt" in os.listdir(self.checkpoint_path):
163
  latest_checkpoint = "model_last.pt"
164
  else:
165
- latest_checkpoint = sorted(
166
- [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
167
- key=lambda x: int("".join(filter(str.isdigit, x))),
168
- )[-1]
169
- # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
170
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  # patch for backward compatibility, 305e3ea
173
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
@@ -177,7 +227,14 @@ class Trainer:
177
  if self.is_main:
178
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
179
 
180
- if "step" in checkpoint:
 
 
 
 
 
 
 
181
  # patch for backward compatibility, 305e3ea
182
  for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
183
  if key in checkpoint["model_state_dict"]:
@@ -187,19 +244,19 @@ class Trainer:
187
  self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
188
  if self.scheduler:
189
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
190
- step = checkpoint["step"]
191
  else:
192
  checkpoint["model_state_dict"] = {
193
  k.replace("ema_model.", ""): v
194
  for k, v in checkpoint["ema_model_state_dict"].items()
195
- if k not in ["initted", "step"]
196
  }
197
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
198
- step = 0
199
 
200
  del checkpoint
201
  gc.collect()
202
- return step
203
 
204
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
205
  if self.log_samples:
@@ -225,7 +282,7 @@ class Trainer:
225
  num_workers=num_workers,
226
  pin_memory=True,
227
  persistent_workers=True,
228
- batch_size=self.batch_size,
229
  shuffle=True,
230
  generator=generator,
231
  )
@@ -233,7 +290,11 @@ class Trainer:
233
  self.accelerator.even_batches = False
234
  sampler = SequentialSampler(train_dataset)
235
  batch_sampler = DynamicBatchSampler(
236
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
 
 
 
 
237
  )
238
  train_dataloader = DataLoader(
239
  train_dataset,
@@ -248,25 +309,26 @@ class Trainer:
248
 
249
  # accelerator.prepare() dispatches batches to devices;
250
  # which means the length of dataloader calculated before, should consider the number of devices
251
- warmup_steps = (
252
  self.num_warmup_updates * self.accelerator.num_processes
253
  ) # keep a fixed number of warmup steps while using accelerate multi-gpu ddp
254
  # otherwise by default with split_batches=False, warmup steps change with num_processes
255
- total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
256
- decay_steps = total_steps - warmup_steps
257
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
258
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
259
  self.scheduler = SequentialLR(
260
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
261
  )
262
  train_dataloader, self.scheduler = self.accelerator.prepare(
263
  train_dataloader, self.scheduler
264
- ) # actual steps = 1 gpu steps / gpus
265
- start_step = self.load_checkpoint()
266
- global_step = start_step
267
 
268
  if exists(resumable_with_seed):
269
  orig_epoch_step = len(train_dataloader)
 
270
  skipped_epoch = int(start_step // orig_epoch_step)
271
  skipped_batch = start_step % orig_epoch_step
272
  skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
@@ -276,23 +338,25 @@ class Trainer:
276
  for epoch in range(skipped_epoch, self.epochs):
277
  self.model.train()
278
  if exists(resumable_with_seed) and epoch == skipped_epoch:
279
- progress_bar = tqdm(
280
- skipped_dataloader,
281
- desc=f"Epoch {epoch+1}/{self.epochs}",
282
- unit="step",
283
- disable=not self.accelerator.is_local_main_process,
284
- initial=skipped_batch,
285
- total=orig_epoch_step,
286
- )
287
  else:
288
- progress_bar = tqdm(
289
- train_dataloader,
290
- desc=f"Epoch {epoch+1}/{self.epochs}",
291
- unit="step",
292
- disable=not self.accelerator.is_local_main_process,
293
- )
 
 
 
 
 
 
 
 
294
 
295
- for batch in progress_bar:
296
  with self.accelerator.accumulate(self.model):
297
  text_inputs = batch["text"]
298
  mel_spec = batch["mel"].permute(0, 2, 1)
@@ -301,7 +365,7 @@ class Trainer:
301
  # TODO. add duration predictor training
302
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
303
  dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
304
- self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
305
 
306
  loss, cond, pred = self.model(
307
  mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -315,21 +379,24 @@ class Trainer:
315
  self.scheduler.step()
316
  self.optimizer.zero_grad()
317
 
318
- if self.is_main and self.accelerator.sync_gradients:
319
- self.ema_model.update()
 
320
 
321
- global_step += 1
 
 
322
 
323
  if self.accelerator.is_local_main_process:
324
- self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
 
 
325
  if self.logger == "tensorboard":
326
- self.writer.add_scalar("loss", loss.item(), global_step)
327
- self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
328
-
329
- progress_bar.set_postfix(step=str(global_step), loss=loss.item())
330
 
331
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
332
- self.save_checkpoint(global_step)
333
 
334
  if self.log_samples and self.accelerator.is_local_main_process:
335
  ref_audio_len = mel_lengths[0]
@@ -355,12 +422,16 @@ class Trainer:
355
  gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
356
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
357
 
358
- torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
359
- torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
 
 
 
360
 
361
- if global_step % self.last_per_steps == 0:
362
- self.save_checkpoint(global_step, last=True)
363
 
364
- self.save_checkpoint(global_step, last=True)
365
 
366
  self.accelerator.end_training()
 
1
  from __future__ import annotations
2
 
3
  import gc
4
+ import math
5
  import os
6
 
7
  import torch
 
30
  learning_rate,
31
  num_warmup_updates=20000,
32
  save_per_updates=1000,
33
+ keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
34
  checkpoint_path=None,
35
+ batch_size_per_gpu=32,
36
  batch_size_type: str = "sample",
37
  max_samples=32,
38
  grad_accumulation_steps=1,
 
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
  logger: str | None = "wandb", # "wandb" | "tensorboard" | None
43
+ wandb_project="test_f5-tts",
44
  wandb_run_name="test_run",
45
  wandb_resume_id: str = None,
46
  log_samples: bool = False,
47
+ last_per_updates=None,
48
  accelerate_kwargs: dict = dict(),
49
  ema_kwargs: dict = dict(),
50
  bnb_optimizer: bool = False,
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
54
+ cfg_dict: dict = dict(), # training config
55
  ):
56
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
57
 
58
  if logger == "wandb" and not wandb.api.api_key:
59
  logger = None
 
60
  self.log_samples = log_samples
61
 
62
  self.accelerator = Accelerator(
 
73
  else:
74
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
75
 
76
+ if not cfg_dict:
77
+ cfg_dict = {
 
 
78
  "epochs": epochs,
79
  "learning_rate": learning_rate,
80
  "num_warmup_updates": num_warmup_updates,
81
+ "batch_size_per_gpu": batch_size_per_gpu,
82
  "batch_size_type": batch_size_type,
83
  "max_samples": max_samples,
84
  "grad_accumulation_steps": grad_accumulation_steps,
85
  "max_grad_norm": max_grad_norm,
 
86
  "noise_scheduler": noise_scheduler,
87
+ }
88
+ cfg_dict["gpus"] = self.accelerator.num_processes
89
+ self.accelerator.init_trackers(
90
+ project_name=wandb_project,
91
+ init_kwargs=init_kwargs,
92
+ config=cfg_dict,
93
  )
94
 
95
  elif self.logger == "tensorboard":
 
103
  self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
104
  self.ema_model.to(self.accelerator.device)
105
 
106
+ print(f"Using logger: {logger}")
107
+ if grad_accumulation_steps > 1:
108
+ print(
109
+ "Gradient accumulation checkpointing now counts per_updates; the old per_steps logic applies to checkpoints saved before f992c4e"
110
+ )
111
+
112
  self.epochs = epochs
113
  self.num_warmup_updates = num_warmup_updates
114
  self.save_per_updates = save_per_updates
115
+ self.keep_last_n_checkpoints = keep_last_n_checkpoints
116
+ self.last_per_updates = default(last_per_updates, save_per_updates)
117
+ self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
118
 
119
+ self.batch_size_per_gpu = batch_size_per_gpu
120
  self.batch_size_type = batch_size_type
121
  self.max_samples = max_samples
122
  self.grad_accumulation_steps = grad_accumulation_steps
 
143
  def is_main(self):
144
  return self.accelerator.is_main_process
145
 
146
+ def save_checkpoint(self, update, last=False):
147
  self.accelerator.wait_for_everyone()
148
  if self.is_main:
149
  checkpoint = dict(
 
151
  optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
152
  ema_model_state_dict=self.ema_model.state_dict(),
153
  scheduler_state_dict=self.scheduler.state_dict(),
154
+ update=update,
155
  )
156
  if not os.path.exists(self.checkpoint_path):
157
  os.makedirs(self.checkpoint_path)
158
  if last:
159
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
160
+ print(f"Saved last checkpoint at update {update}")
161
  else:
162
+ if self.keep_last_n_checkpoints == 0:
163
+ return
164
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
165
+ if self.keep_last_n_checkpoints > 0:
166
+ # Updated logic to exclude pretrained model from rotation
167
+ checkpoints = [
168
+ f
169
+ for f in os.listdir(self.checkpoint_path)
170
+ if f.startswith("model_")
171
+ and not f.startswith("pretrained_") # Exclude pretrained models
172
+ and f.endswith(".pt")
173
+ and f != "model_last.pt"
174
+ ]
175
+ checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
176
+ while len(checkpoints) > self.keep_last_n_checkpoints:
177
+ oldest_checkpoint = checkpoints.pop(0)
178
+ os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
179
+ print(f"Removed old checkpoint: {oldest_checkpoint}")
180
 
181
  def load_checkpoint(self):
182
  if (
183
  not exists(self.checkpoint_path)
184
  or not os.path.exists(self.checkpoint_path)
185
+ or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
186
  ):
187
  return 0
188
 
 
190
  if "model_last.pt" in os.listdir(self.checkpoint_path):
191
  latest_checkpoint = "model_last.pt"
192
  else:
193
+ # Updated to consider pretrained models for loading but prioritize training checkpoints
194
+ all_checkpoints = [
195
+ f
196
+ for f in os.listdir(self.checkpoint_path)
197
+ if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
198
+ ]
199
+
200
+ # First try to find regular training checkpoints
201
+ training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
202
+ if training_checkpoints:
203
+ latest_checkpoint = sorted(
204
+ training_checkpoints,
205
+ key=lambda x: int("".join(filter(str.isdigit, x))),
206
+ )[-1]
207
+ else:
208
+ # If no training checkpoints, use pretrained model
209
+ latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
210
+
211
+ if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
212
+ from safetensors.torch import load_file
213
+
214
+ checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
215
+ checkpoint = {"ema_model_state_dict": checkpoint}
216
+ elif latest_checkpoint.endswith(".pt"):
217
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
218
+ checkpoint = torch.load(
219
+ f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
220
+ )
221
 
222
  # patch for backward compatibility, 305e3ea
223
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
 
227
  if self.is_main:
228
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
229
 
230
+ if "update" in checkpoint or "step" in checkpoint:
231
+ # patch for backward compatibility with checkpoints saved before f992c4e
232
+ if "step" in checkpoint:
233
+ checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
234
+ if self.grad_accumulation_steps > 1 and self.is_main:
235
+ print(
236
+ "F5-TTS WARNING: Loading a checkpoint saved with the per_steps logic (before f992c4e); converting to per_updates using the grad_accumulation_steps setting, which may cause unexpected behaviour."
237
+ )
238
  # patch for backward compatibility, 305e3ea
239
  for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
240
  if key in checkpoint["model_state_dict"]:
 
244
  self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
245
  if self.scheduler:
246
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
247
+ update = checkpoint["update"]
248
  else:
249
  checkpoint["model_state_dict"] = {
250
  k.replace("ema_model.", ""): v
251
  for k, v in checkpoint["ema_model_state_dict"].items()
252
+ if k not in ["initted", "update", "step"]
253
  }
254
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
255
+ update = 0
256
 
257
  del checkpoint
258
  gc.collect()
259
+ return update
260
 
261
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
262
  if self.log_samples:
 
282
  num_workers=num_workers,
283
  pin_memory=True,
284
  persistent_workers=True,
285
+ batch_size=self.batch_size_per_gpu,
286
  shuffle=True,
287
  generator=generator,
288
  )
 
290
  self.accelerator.even_batches = False
291
  sampler = SequentialSampler(train_dataset)
292
  batch_sampler = DynamicBatchSampler(
293
+ sampler,
294
+ self.batch_size_per_gpu,
295
+ max_samples=self.max_samples,
296
+ random_seed=resumable_with_seed, # This enables reproducible shuffling
297
+ drop_residual=False,
298
  )
299
  train_dataloader = DataLoader(
300
  train_dataset,
 
309
 
310
  # accelerator.prepare() dispatches batches to devices;
311
  # which means the length of dataloader calculated before, should consider the number of devices
312
+ warmup_updates = (
313
  self.num_warmup_updates * self.accelerator.num_processes
314
  ) # keep a fixed number of warmup steps while using accelerate multi-gpu ddp
315
  # otherwise by default with split_batches=False, warmup steps change with num_processes
316
+ total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
317
+ decay_updates = total_updates - warmup_updates
318
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
319
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
320
  self.scheduler = SequentialLR(
321
+ self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
322
  )
323
  train_dataloader, self.scheduler = self.accelerator.prepare(
324
  train_dataloader, self.scheduler
325
+ ) # actual multi_gpu updates = single_gpu updates / gpu nums
326
+ start_update = self.load_checkpoint()
327
+ global_update = start_update
328
 
329
  if exists(resumable_with_seed):
330
  orig_epoch_step = len(train_dataloader)
331
+ start_step = start_update * self.grad_accumulation_steps
332
  skipped_epoch = int(start_step // orig_epoch_step)
333
  skipped_batch = start_step % orig_epoch_step
334
  skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
 
338
  for epoch in range(skipped_epoch, self.epochs):
339
  self.model.train()
340
  if exists(resumable_with_seed) and epoch == skipped_epoch:
341
+ progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
342
+ current_dataloader = skipped_dataloader
 
 
 
 
 
 
343
  else:
344
+ progress_bar_initial = 0
345
+ current_dataloader = train_dataloader
346
+
347
+ # Set epoch for the batch sampler if it exists
348
+ if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
349
+ train_dataloader.batch_sampler.set_epoch(epoch)
350
+
351
+ progress_bar = tqdm(
352
+ range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
353
+ desc=f"Epoch {epoch+1}/{self.epochs}",
354
+ unit="update",
355
+ disable=not self.accelerator.is_local_main_process,
356
+ initial=progress_bar_initial,
357
+ )
358
 
359
+ for batch in current_dataloader:
360
  with self.accelerator.accumulate(self.model):
361
  text_inputs = batch["text"]
362
  mel_spec = batch["mel"].permute(0, 2, 1)
 
365
  # TODO. add duration predictor training
366
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
367
  dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
368
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
369
 
370
  loss, cond, pred = self.model(
371
  mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
 
379
  self.scheduler.step()
380
  self.optimizer.zero_grad()
381
 
382
+ if self.accelerator.sync_gradients:
383
+ if self.is_main:
384
+ self.ema_model.update()
385
 
386
+ global_update += 1
387
+ progress_bar.update(1)
388
+ progress_bar.set_postfix(update=str(global_update), loss=loss.item())
389
 
390
  if self.accelerator.is_local_main_process:
391
+ self.accelerator.log(
392
+ {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
393
+ )
394
  if self.logger == "tensorboard":
395
+ self.writer.add_scalar("loss", loss.item(), global_update)
396
+ self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
 
397
 
398
+ if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
399
+ self.save_checkpoint(global_update)
400
 
401
  if self.log_samples and self.accelerator.is_local_main_process:
402
  ref_audio_len = mel_lengths[0]
 
422
  gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
423
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
424
 
425
+ torchaudio.save(
426
+ f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
427
+ )
428
+ torchaudio.save(
429
+ f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
430
+ )
431
 
432
+ if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
433
+ self.save_checkpoint(global_update, last=True)
434
 
435
+ self.save_checkpoint(global_update, last=True)
436
 
437
  self.accelerator.end_training()
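
The checkpoint rotation added to save_checkpoint boils down to a simple policy: keep only the newest keep_last_n_checkpoints intermediate model_<update>.pt files, skipping model_last.pt and any pretrained_* seed checkpoint. A standalone sketch of that policy:

    import os

    def rotate_checkpoints(ckpt_dir: str, keep_last_n: int) -> None:
        if keep_last_n <= 0:  # -1 keeps everything; with 0 no intermediates were written at all
            return
        ckpts = [
            f for f in os.listdir(ckpt_dir)
            if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
        ]
        ckpts.sort(key=lambda name: int(name.split("_")[1].split(".")[0]))
        for old in ckpts[:-keep_last_n]:  # empty slice when fewer than keep_last_n exist
            os.remove(os.path.join(ckpt_dir, old))
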
src/f5_tts/model/utils.py CHANGED
@@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
133
 
134
  # convert char to pinyin
135
 
136
- jieba.initialize()
137
- print("Word segmentation module jieba initialized.\n")
138
-
139
 
140
  def convert_char_to_pinyin(text_list, polyphone=True):
 
 
 
 
141
  final_text_list = []
142
  custom_trans = str.maketrans(
143
  {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
 
133
 
134
  # convert char to pinyin
135
 
 
 
 
136
 
137
  def convert_char_to_pinyin(text_list, polyphone=True):
138
+ if jieba.dt.initialized is False:
139
+ jieba.default_logger.setLevel(50) # CRITICAL
140
+ jieba.initialize()
141
+
142
  final_text_list = []
143
  custom_trans = str.maketrans(
144
  {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
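
The change above defers jieba.initialize() to the first call of convert_char_to_pinyin and silences the startup banner (level 50 is logging.CRITICAL). The same lazy, quiet initialization can be wrapped as a tiny guard:

    import logging

    import jieba

    def ensure_jieba_initialized() -> None:
        if not jieba.dt.initialized:
            jieba.default_logger.setLevel(logging.CRITICAL)  # hide the prefix-dict banner
            jieba.initialize()
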
src/f5_tts/scripts/count_max_epoch.py CHANGED
@@ -9,7 +9,7 @@ mel_hop_length = 256
9
  mel_sampling_rate = 24000
10
 
11
  # target
12
- wanted_max_updates = 1000000
13
 
14
  # train params
15
  gpus = 8
@@ -20,13 +20,13 @@ grad_accum = 1
20
  mini_batch_frames = frames_per_gpu * grad_accum * gpus
21
  mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
22
  updates_per_epoch = total_hours / mini_batch_hours
23
- steps_per_epoch = updates_per_epoch * grad_accum
24
 
25
  # result
26
  epochs = wanted_max_updates / updates_per_epoch
27
  print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
28
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29
- print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30
 
31
  # others
32
  print(f"total {total_hours:.0f} hours")
 
9
  mel_sampling_rate = 24000
10
 
11
  # target
12
+ wanted_max_updates = 1200000
13
 
14
  # train params
15
  gpus = 8
 
20
  mini_batch_frames = frames_per_gpu * grad_accum * gpus
21
  mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
22
  updates_per_epoch = total_hours / mini_batch_hours
23
+ # steps_per_epoch = updates_per_epoch * grad_accum
24
 
25
  # result
26
  epochs = wanted_max_updates / updates_per_epoch
27
  print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
28
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29
+ # print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30
 
31
  # others
32
  print(f"total {total_hours:.0f} hours")
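
For a concrete feel of the arithmetic in this script, here is a worked example with hypothetical numbers; frames_per_gpu and total_hours below are placeholders, not the script's actual settings:

    frames_per_gpu = 38400  # hypothetical
    total_hours = 950       # hypothetical
    mel_hop_length, mel_sampling_rate = 256, 24000
    gpus, grad_accum = 8, 1
    wanted_max_updates = 1_200_000

    mini_batch_frames = frames_per_gpu * grad_accum * gpus  # 307,200 frames
    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600  # ~0.91 h of audio
    updates_per_epoch = total_hours / mini_batch_hours  # ~1,044 updates per epoch
    epochs = wanted_max_updates / updates_per_epoch  # ~1,150 epochs
    print(f"epochs should be set to: {epochs:.0f}")
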
src/f5_tts/socket_client.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import socket
2
+ import asyncio
3
+ import pyaudio
4
+ import numpy as np
5
+ import logging
6
+ import time
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
13
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
14
+ await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
15
+
16
+ start_time = time.time()
17
+ first_chunk_time = None
18
+
19
+ async def play_audio_stream():
20
+ nonlocal first_chunk_time
21
+ p = pyaudio.PyAudio()
22
+ stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
23
+
24
+ try:
25
+ while True:
26
+ data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
27
+ if not data:
28
+ break
29
+ if data == b"END":
30
+ logger.info("End of audio received.")
31
+ break
32
+
33
+ audio_array = np.frombuffer(data, dtype=np.float32)
34
+ stream.write(audio_array.tobytes())
35
+
36
+ if first_chunk_time is None:
37
+ first_chunk_time = time.time()
38
+
39
+ finally:
40
+ stream.stop_stream()
41
+ stream.close()
42
+ p.terminate()
43
+
44
+ logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
45
+
46
+ try:
47
+ data_to_send = f"{text}".encode("utf-8")
48
+ await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
49
+ await play_audio_stream()
50
+
51
+ except Exception as e:
52
+ logger.error(f"Error in listen_to_F5TTS: {e}")
53
+
54
+ finally:
55
+ client_socket.close()
56
+
57
+
58
+ if __name__ == "__main__":
59
+ text_to_send = "As a Reader assistant, I'm familiar with new technologies, which are key to improved performance in terms of both training speed and inference efficiency. Let's break down the components."
60
+
61
+ asyncio.run(listen_to_F5TTS(text_to_send))
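
The client above treats the stream as raw float32 PCM at 24 kHz terminated by an END sentinel. Under the same framing assumption (and the same simplification that the sentinel arrives as its own packet), the chunks can just as easily be collected into a WAV file instead of being played back with PyAudio:

    import socket

    import numpy as np
    import soundfile as sf

    def save_stream_to_wav(text: str, path: str, host: str = "localhost", port: int = 9998) -> None:
        chunks = []
        with socket.create_connection((host, port)) as sock:
            sock.sendall(text.encode("utf-8"))
            while True:
                data = sock.recv(8192)
                if not data or data == b"END":
                    break
                chunks.append(np.frombuffer(data, dtype=np.float32))
        if chunks:
            sf.write(path, np.concatenate(chunks), 24000)
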
src/f5_tts/socket_server.py CHANGED
@@ -1,142 +1,213 @@
1
  import argparse
2
  import gc
 
 
 
3
  import socket
4
  import struct
5
- import torch
6
- import torchaudio
7
  import traceback
 
8
  from importlib.resources import files
9
- from threading import Thread
10
 
11
- from cached_path import cached_path
12
-
13
- from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
14
- from model.backbones.dit import DiT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  class TTSStreamingProcessor:
18
- def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
19
  self.device = device or (
20
- "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
 
 
21
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Load the model using the provided checkpoint and vocab files
24
- self.model = load_model(
25
- model_cls=DiT,
26
- model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
27
  ckpt_path=ckpt_file,
28
- mel_spec_type="vocos", # or "bigvgan" depending on vocoder
29
  vocab_file=vocab_file,
30
  ode_method="euler",
31
  use_ema=True,
32
  device=self.device,
33
  ).to(self.device, dtype=dtype)
34
 
35
- # Load the vocoder
36
- self.vocoder = load_vocoder(is_local=False)
37
 
38
- # Set sampling rate for streaming
39
- self.sampling_rate = 24000 # Consistency with client
 
40
 
41
- # Set reference audio and text
42
- self.ref_audio = ref_audio
43
- self.ref_text = ref_text
44
-
45
- # Warm up the model
46
- self._warm_up()
47
 
48
  def _warm_up(self):
49
- """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
50
- print("Warming up the model...")
51
- ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
52
- audio, sr = torchaudio.load(ref_audio)
53
  gen_text = "Warm-up text for the model."
54
-
55
- # Pass the vocoder as an argument here
56
- infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
57
- print("Warm-up completed.")
58
-
59
- def generate_stream(self, text, play_steps_in_s=0.5):
60
- """Generate audio in chunks and yield them in real-time."""
61
- # Preprocess the reference audio and text
62
- ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
63
-
64
- # Load reference audio
65
- audio, sr = torchaudio.load(ref_audio)
66
-
67
- # Run inference for the input text
68
- audio_chunk, final_sample_rate, _ = infer_batch_process(
69
- (audio, sr),
70
- ref_text,
71
- [text],
72
  self.model,
73
  self.vocoder,
74
- device=self.device, # Pass vocoder here
 
 
 
 
 
 
75
  )
76
 
77
- # Break the generated audio into chunks and send them
78
- chunk_size = int(final_sample_rate * play_steps_in_s)
79
-
80
- if len(audio_chunk) < chunk_size:
81
- packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk)
82
- yield packed_audio
83
- return
84
-
85
- for i in range(0, len(audio_chunk), chunk_size):
86
- chunk = audio_chunk[i : i + chunk_size]
87
-
88
- # Check if it's the final chunk
89
- if i + chunk_size >= len(audio_chunk):
90
- chunk = audio_chunk[i:]
91
-
92
- # Send the chunk if it is not empty
93
- if len(chunk) > 0:
94
- packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
95
- yield packed_audio
96
 
 
 
 
97
 
98
- def handle_client(client_socket, processor):
99
- try:
100
- while True:
101
- # Receive data from the client
102
- data = client_socket.recv(1024).decode("utf-8")
103
- if not data:
104
- break
105
 
106
- try:
107
- # The client sends the text input
108
- text = data.strip()
109
 
110
- # Generate and stream audio chunks
111
- for audio_chunk in processor.generate_stream(text):
112
- client_socket.sendall(audio_chunk)
113
 
114
- # Send end-of-audio signal
115
- client_socket.sendall(b"END_OF_AUDIO")
116
 
117
- except Exception as inner_e:
118
- print(f"Error during processing: {inner_e}")
119
- traceback.print_exc() # Print the full traceback to diagnose the issue
120
- break
121
 
 
 
 
 
 
 
122
  except Exception as e:
123
- print(f"Error handling client: {e}")
124
  traceback.print_exc()
125
- finally:
126
- client_socket.close()
127
 
128
 
129
  def start_server(host, port, processor):
130
- server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
131
- server.bind((host, port))
132
- server.listen(5)
133
- print(f"Server listening on {host}:{port}")
134
-
135
- while True:
136
- client_socket, addr = server.accept()
137
- print(f"Accepted connection from {addr}")
138
- client_handler = Thread(target=handle_client, args=(client_socket, processor))
139
- client_handler.start()
140
 
141
 
142
  if __name__ == "__main__":
@@ -145,9 +216,14 @@ if __name__ == "__main__":
145
  parser.add_argument("--host", default="0.0.0.0")
146
  parser.add_argument("--port", default=9998)
147
 
 
 
 
 
 
148
  parser.add_argument(
149
  "--ckpt_file",
150
- default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
151
  help="Path to the model checkpoint file",
152
  )
153
  parser.add_argument(
@@ -175,6 +251,7 @@ if __name__ == "__main__":
175
  try:
176
  # Initialize the processor with the model and vocoder
177
  processor = TTSStreamingProcessor(
 
178
  ckpt_file=args.ckpt_file,
179
  vocab_file=args.vocab_file,
180
  ref_audio=args.ref_audio,
 
1
  import argparse
2
  import gc
3
+ import logging
4
+ import numpy as np
5
+ import queue
6
  import socket
7
  import struct
8
+ import threading
 
9
  import traceback
10
+ import wave
11
  from importlib.resources import files
 
12
 
13
+ import torch
14
+ import torchaudio
15
+ from huggingface_hub import hf_hub_download
16
+ from omegaconf import OmegaConf
17
+
18
+ from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
19
+ from f5_tts.infer.utils_infer import (
20
+ chunk_text,
21
+ preprocess_ref_audio_text,
22
+ load_vocoder,
23
+ load_model,
24
+ infer_batch_process,
25
+ )
26
+
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class AudioFileWriterThread(threading.Thread):
32
+ """Threaded file writer to avoid blocking the TTS streaming process."""
33
+
34
+ def __init__(self, output_file, sampling_rate):
35
+ super().__init__()
36
+ self.output_file = output_file
37
+ self.sampling_rate = sampling_rate
38
+ self.queue = queue.Queue()
39
+ self.stop_event = threading.Event()
40
+ self.audio_data = []
41
+
42
+ def run(self):
43
+ """Process queued audio data and write it to a file."""
44
+ logger.info("AudioFileWriterThread started.")
45
+ with wave.open(self.output_file, "wb") as wf:
46
+ wf.setnchannels(1)
47
+ wf.setsampwidth(2)
48
+ wf.setframerate(self.sampling_rate)
49
+
50
+ while not self.stop_event.is_set() or not self.queue.empty():
51
+ try:
52
+ chunk = self.queue.get(timeout=0.1)
53
+ if chunk is not None:
54
+ chunk = np.int16(chunk * 32767)
55
+ self.audio_data.append(chunk)
56
+ wf.writeframes(chunk.tobytes())
57
+ except queue.Empty:
58
+ continue
59
+
60
+ def add_chunk(self, chunk):
61
+ """Add a new chunk to the queue."""
62
+ self.queue.put(chunk)
63
+
64
+ def stop(self):
65
+ """Stop writing and ensure all queued data is written."""
66
+ self.stop_event.set()
67
+ self.join()
68
+ logger.info("Audio writing completed.")
69
 
70
 
71
  class TTSStreamingProcessor:
72
+ def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
73
  self.device = device or (
74
+ "cuda"
75
+ if torch.cuda.is_available()
76
+ else "xpu"
77
+ if torch.xpu.is_available()
78
+ else "mps"
79
+ if torch.backends.mps.is_available()
80
+ else "cpu"
81
  )
82
+ model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
83
+ self.model_cls = globals()[model_cfg.model.backbone]
84
+ self.model_arc = model_cfg.model.arch
85
+ self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
86
+ self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
87
+
88
+ self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
89
+ self.vocoder = self.load_vocoder_model()
90
+
91
+ self.update_reference(ref_audio, ref_text)
92
+ self._warm_up()
93
+ self.file_writer_thread = None
94
+ self.first_package = True
95
 
96
+ def load_ema_model(self, ckpt_file, vocab_file, dtype):
97
+ return load_model(
98
+ self.model_cls,
99
+ self.model_arc,
100
  ckpt_path=ckpt_file,
101
+ mel_spec_type=self.mel_spec_type,
102
  vocab_file=vocab_file,
103
  ode_method="euler",
104
  use_ema=True,
105
  device=self.device,
106
  ).to(self.device, dtype=dtype)
107
 
108
+ def load_vocoder_model(self):
109
+ return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device)
110
 
111
+ def update_reference(self, ref_audio, ref_text):
112
+ self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
113
+ self.audio, self.sr = torchaudio.load(self.ref_audio)
114
 
115
+ ref_audio_duration = self.audio.shape[-1] / self.sr
116
+ ref_text_byte_len = len(self.ref_text.encode("utf-8"))
117
+ self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
118
+ self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
119
+ self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)
 
120
 
121
  def _warm_up(self):
122
+ logger.info("Warming up the model...")
 
 
 
123
  gen_text = "Warm-up text for the model."
124
+ for _ in infer_batch_process(
125
+ (self.audio, self.sr),
126
+ self.ref_text,
127
+ [gen_text],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  self.model,
129
  self.vocoder,
130
+ progress=None,
131
+ device=self.device,
132
+ streaming=True,
133
+ ):
134
+ pass
135
+ logger.info("Warm-up completed.")
136
+
137
+ def generate_stream(self, text, conn):
138
+ text_batches = chunk_text(text, max_chars=self.max_chars)
139
+ if self.first_package:
140
+ text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
141
+ text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
142
+ self.first_package = False
143
+
144
+ audio_stream = infer_batch_process(
145
+ (self.audio, self.sr),
146
+ self.ref_text,
147
+ text_batches,
148
+ self.model,
149
+ self.vocoder,
150
+ progress=None,
151
+ device=self.device,
152
+ streaming=True,
153
+ chunk_size=2048,
154
  )
155
 
156
+ # Reset the file writer thread
157
+ if self.file_writer_thread is not None:
158
+ self.file_writer_thread.stop()
159
+ self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
160
+ self.file_writer_thread.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ for audio_chunk, _ in audio_stream:
163
+ if len(audio_chunk) > 0:
164
+ logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")
165
 
166
+ # Send audio chunk via socket
167
+ conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))
 
 
 
 
 
168
 
169
+ # Write to file asynchronously
170
+ self.file_writer_thread.add_chunk(audio_chunk)
 
171
 
172
+ logger.info("Finished sending audio stream.")
173
+ conn.sendall(b"END") # Send end signal
 
174
 
175
+ # Ensure all audio data is written before exiting
176
+ self.file_writer_thread.stop()
177
 
 
 
 
 
178
 
179
+ def handle_client(conn, processor):
180
+ try:
181
+ with conn:
182
+ conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
183
+ while True:
184
+ data = conn.recv(1024)
185
+ if not data:
186
+ processor.first_package = True
187
+ break
188
+ data_str = data.decode("utf-8").strip()
189
+ logger.info(f"Received text: {data_str}")
190
+
191
+ try:
192
+ processor.generate_stream(data_str, conn)
193
+ except Exception as inner_e:
194
+ logger.error(f"Error during processing: {inner_e}")
195
+ traceback.print_exc()
196
+ break
197
  except Exception as e:
198
+ logger.error(f"Error handling client: {e}")
199
  traceback.print_exc()
 
 
200
 
201
 
202
  def start_server(host, port, processor):
203
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
204
+ s.bind((host, port))
205
+ s.listen()
206
+ logger.info(f"Server started on {host}:{port}")
207
+ while True:
208
+ conn, addr = s.accept()
209
+ logger.info(f"Connected by {addr}")
210
+ handle_client(conn, processor)
 
 
211
 
212
 
213
  if __name__ == "__main__":
 
216
  parser.add_argument("--host", default="0.0.0.0")
217
  parser.add_argument("--port", default=9998)
218
 
219
+ parser.add_argument(
220
+ "--model",
221
+ default="F5TTS_v1_Base",
222
+ help="The model name, e.g. F5TTS_v1_Base",
223
+ )
224
  parser.add_argument(
225
  "--ckpt_file",
226
+ default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
227
  help="Path to the model checkpoint file",
228
  )
229
  parser.add_argument(
 
251
  try:
252
  # Initialize the processor with the model and vocoder
253
  processor = TTSStreamingProcessor(
254
+ model=args.model,
255
  ckpt_file=args.ckpt_file,
256
  vocab_file=args.vocab_file,
257
  ref_audio=args.ref_audio,
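
A minimal launch sketch for this streaming server, using only the flags visible in the argument parser above (the reference audio/text passed into `TTSStreamingProcessor` come from the remaining parser options not shown in this hunk, and the checkpoint is fetched from the Hugging Face Hub unless `--ckpt_file` is overridden):

```
python src/f5_tts/socket_server.py --model F5TTS_v1_Base --host 0.0.0.0 --port 9998
```

The client shown earlier can then connect, send UTF-8 text, and receive float32 PCM chunks followed by the `END` terminator, while a copy of the audio is written to `output.wav` by the file-writer thread.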
src/f5_tts/train/README.md CHANGED
@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
40
  accelerate config
41
 
42
  # .yaml files are under src/f5_tts/configs directory
43
- accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
44
 
45
  # possible to overwrite accelerate and hydra config
46
- accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
@@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1
53
 
54
The `use_ema = True` setting can be harmful for early-stage finetuned checkpoints (trained for only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
55
 
56
- ### 3. Wandb Logging
57
 
58
  The `wandb/` dir will be created under path you run training/finetuning scripts.
59
 
@@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
62
  To turn on wandb logging, you can either:
63
 
64
  1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
65
- 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
66
 
67
  On Mac & Linux:
68
 
@@ -75,7 +75,7 @@ On Windows:
75
  ```
76
  set WANDB_API_KEY=<YOUR WANDB API KEY>
77
  ```
78
- Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
79
 
80
  ```
81
  export WANDB_MODE=offline
 
40
  accelerate config
41
 
42
  # .yaml files are under src/f5_tts/configs directory
43
+ accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
44
 
45
  # possible to overwrite accelerate and hydra config
46
+ accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
 
53
 
54
The `use_ema = True` setting can be harmful for early-stage finetuned checkpoints (trained for only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
55
 
56
+ ### 3. W&B Logging
57
 
58
  The `wandb/` dir will be created under path you run training/finetuning scripts.
59
 
 
62
  To turn on wandb logging, you can either:
63
 
64
  1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
65
+ 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
66
 
67
  On Mac & Linux:
68
 
 
75
  ```
76
  set WANDB_API_KEY=<YOUR WANDB API KEY>
77
  ```
78
+ Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
79
 
80
  ```
81
  export WANDB_MODE=offline
src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -1,12 +1,17 @@
1
  import os
2
  import sys
 
 
 
 
 
 
3
 
4
  sys.path.append(os.getcwd())
5
 
6
  import argparse
7
  import csv
8
  import json
9
- import shutil
10
  from importlib.resources import files
11
  from pathlib import Path
12
 
@@ -29,32 +34,157 @@ def is_csv_wavs_format(input_dataset_dir):
29
  return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
30
 
31
 
32
- def prepare_csv_wavs_dir(input_dir):
 
 
 
 
 
 
 
33
  assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
34
  input_dir = Path(input_dir)
35
  metadata_path = input_dir / "metadata.csv"
36
  audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
37
 
38
- sub_result, durations = [], []
39
- vocab_set = set()
40
  polyphone = True
41
- for audio_path, text in audio_path_text_pairs:
42
- if not Path(audio_path).exists():
43
- print(f"audio {audio_path} not found, skipping")
44
- continue
45
- audio_duration = get_audio_duration(audio_path)
46
- # assume tokenizer = "pinyin" ("pinyin" | "char")
47
- text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
48
- sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
49
- durations.append(audio_duration)
50
- vocab_set.update(list(text))
 
 
 
 
 
 
 
51
 
52
  return sub_result, durations, vocab_set
53
 
54
 
55
- def get_audio_duration(audio_path):
56
- audio, sample_rate = torchaudio.load(audio_path)
57
- return audio.shape[1] / sample_rate
 
 
 
 
 
 
 
58
 
59
 
60
  def read_audio_text_pairs(csv_file_path):
@@ -76,36 +206,27 @@ def read_audio_text_pairs(csv_file_path):
76
 
77
  def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
78
  out_dir = Path(out_dir)
79
- # save preprocessed dataset to disk
80
  out_dir.mkdir(exist_ok=True, parents=True)
81
  print(f"\nSaving to {out_dir} ...")
82
 
83
- # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
84
- # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
85
  raw_arrow_path = out_dir / "raw.arrow"
86
- with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
87
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
88
  writer.write(line)
89
 
90
- # dup a json separately saving duration in case for DynamicBatchSampler ease
91
  dur_json_path = out_dir / "duration.json"
92
  with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
93
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
94
 
95
- # vocab map, i.e. tokenizer
96
- # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
97
- # if tokenizer == "pinyin":
98
- # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
99
  voca_out_path = out_dir / "vocab.txt"
100
- with open(voca_out_path.as_posix(), "w") as f:
101
- for vocab in sorted(text_vocab_set):
102
- f.write(vocab + "\n")
103
-
104
  if is_finetune:
105
  file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
106
  shutil.copy2(file_vocab_finetune, voca_out_path)
107
  else:
108
- with open(voca_out_path, "w") as f:
109
  for vocab in sorted(text_vocab_set):
110
  f.write(vocab + "\n")
111
 
@@ -115,24 +236,48 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
115
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
116
 
117
 
118
- def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
119
  if is_finetune:
120
  assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
121
- sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
122
  save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
123
 
124
 
125
  def cli():
126
- # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
127
- # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
128
- parser = argparse.ArgumentParser(description="Prepare and save dataset.")
129
- parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
130
- parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
131
- parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
132
-
133
- args = parser.parse_args()
134
-
135
- prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
 
 
 
 
 
 
 
136
 
137
 
138
  if __name__ == "__main__":
 
1
  import os
2
  import sys
3
+ import signal
4
+ import subprocess # For invoking ffprobe
5
+ import shutil
6
+ import concurrent.futures
7
+ import multiprocessing
8
+ from contextlib import contextmanager
9
 
10
  sys.path.append(os.getcwd())
11
 
12
  import argparse
13
  import csv
14
  import json
 
15
  from importlib.resources import files
16
  from pathlib import Path
17
 
 
34
  return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
35
 
36
 
37
+ # Configuration constants
38
+ BATCH_SIZE = 100 # Batch size for text conversion
39
+ MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free
40
+ THREAD_NAME_PREFIX = "AudioProcessor"
41
+ CHUNK_SIZE = 100 # Number of files to process per worker batch
42
+
43
+ executor = None # Global executor for cleanup
44
+
45
+
46
+ @contextmanager
47
+ def graceful_exit():
48
+ """Context manager for graceful shutdown on signals"""
49
+
50
+ def signal_handler(signum, frame):
51
+ print("\nReceived signal to terminate. Cleaning up...")
52
+ if executor is not None:
53
+ print("Shutting down executor...")
54
+ executor.shutdown(wait=False, cancel_futures=True)
55
+ sys.exit(1)
56
+
57
+ # Set up signal handlers
58
+ signal.signal(signal.SIGINT, signal_handler)
59
+ signal.signal(signal.SIGTERM, signal_handler)
60
+
61
+ try:
62
+ yield
63
+ finally:
64
+ if executor is not None:
65
+ executor.shutdown(wait=False)
66
+
67
+
68
+ def process_audio_file(audio_path, text, polyphone):
69
+ """Process a single audio file by checking its existence and extracting duration."""
70
+ if not Path(audio_path).exists():
71
+ print(f"audio {audio_path} not found, skipping")
72
+ return None
73
+ try:
74
+ audio_duration = get_audio_duration(audio_path)
75
+ if audio_duration <= 0:
76
+ raise ValueError(f"Duration {audio_duration} is non-positive.")
77
+ return (audio_path, text, audio_duration)
78
+ except Exception as e:
79
+ print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
80
+ return None
81
+
82
+
83
+ def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
84
+ """Convert a list of texts to pinyin in batches."""
85
+ converted_texts = []
86
+ for i in range(0, len(texts), batch_size):
87
+ batch = texts[i : i + batch_size]
88
+ converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
89
+ converted_texts.extend(converted_batch)
90
+ return converted_texts
91
+
92
+
93
+ def prepare_csv_wavs_dir(input_dir, num_workers=None):
94
+ global executor
95
  assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
96
  input_dir = Path(input_dir)
97
  metadata_path = input_dir / "metadata.csv"
98
  audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
99
 
 
 
100
  polyphone = True
101
+ total_files = len(audio_path_text_pairs)
102
+
103
+ # Use provided worker count or calculate optimal number
104
+ worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
105
+ print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
106
+
107
+ with graceful_exit():
108
+ # Initialize thread pool with optimized settings
109
+ with concurrent.futures.ThreadPoolExecutor(
110
+ max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
111
+ ) as exec:
112
+ executor = exec
113
+ results = []
114
+
115
+ # Process files in chunks for better efficiency
116
+ for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
117
+ chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
118
+ # Submit futures in order
119
+ chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
120
+
121
+ # Iterate over futures in the original submission order to preserve ordering
122
+ for future in tqdm(
123
+ chunk_futures,
124
+ total=len(chunk),
125
+ desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
126
+ ):
127
+ try:
128
+ result = future.result()
129
+ if result is not None:
130
+ results.append(result)
131
+ except Exception as e:
132
+ print(f"Error processing file: {e}")
133
+
134
+ executor = None
135
+
136
+ # Filter out failed results
137
+ processed = [res for res in results if res is not None]
138
+ if not processed:
139
+ raise RuntimeError("No valid audio files were processed!")
140
+
141
+ # Batch process text conversion
142
+ raw_texts = [item[1] for item in processed]
143
+ converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
144
+
145
+ # Prepare final results
146
+ sub_result = []
147
+ durations = []
148
+ vocab_set = set()
149
+
150
+ for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
151
+ sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
152
+ durations.append(duration)
153
+ vocab_set.update(list(conv_text))
154
 
155
  return sub_result, durations, vocab_set
156
 
157
 
158
+ def get_audio_duration(audio_path, timeout=5):
159
+ """
160
+ Get the duration of an audio file in seconds using ffmpeg's ffprobe.
161
+ Falls back to torchaudio.load() if ffprobe fails.
162
+ """
163
+ try:
164
+ cmd = [
165
+ "ffprobe",
166
+ "-v",
167
+ "error",
168
+ "-show_entries",
169
+ "format=duration",
170
+ "-of",
171
+ "default=noprint_wrappers=1:nokey=1",
172
+ audio_path,
173
+ ]
174
+ result = subprocess.run(
175
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
176
+ )
177
+ duration_str = result.stdout.strip()
178
+ if duration_str:
179
+ return float(duration_str)
180
+ raise ValueError("Empty duration string from ffprobe.")
181
+ except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
182
+ print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
183
+ try:
184
+ audio, sample_rate = torchaudio.load(audio_path)
185
+ return audio.shape[1] / sample_rate
186
+ except Exception as e:
187
+ raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
188
 
189
 
190
  def read_audio_text_pairs(csv_file_path):
 
206
 
207
  def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
208
  out_dir = Path(out_dir)
 
209
  out_dir.mkdir(exist_ok=True, parents=True)
210
  print(f"\nSaving to {out_dir} ...")
211
 
212
+ # Save dataset with improved batch size for better I/O performance
 
213
  raw_arrow_path = out_dir / "raw.arrow"
214
+ with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
215
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
216
  writer.write(line)
217
 
218
+ # Save durations to JSON
219
  dur_json_path = out_dir / "duration.json"
220
  with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
221
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
222
 
223
+ # Handle vocab file - write only once based on finetune flag
 
 
 
224
  voca_out_path = out_dir / "vocab.txt"
 
 
 
 
225
  if is_finetune:
226
  file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
227
  shutil.copy2(file_vocab_finetune, voca_out_path)
228
  else:
229
+ with open(voca_out_path.as_posix(), "w") as f:
230
  for vocab in sorted(text_vocab_set):
231
  f.write(vocab + "\n")
232
 
 
236
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
237
 
238
 
239
+ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
240
  if is_finetune:
241
  assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
242
+ sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
243
  save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
244
 
245
 
246
  def cli():
247
+ try:
248
+ # Before processing, check if ffprobe is available.
249
+ if shutil.which("ffprobe") is None:
250
+ print(
251
+ "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
252
+ )
253
+
254
+ # Usage examples in help text
255
+ parser = argparse.ArgumentParser(
256
+ description="Prepare and save dataset.",
257
+ epilog="""
258
+ Examples:
259
+ # For fine-tuning (default):
260
+ python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
261
+
262
+ # For pre-training:
263
+ python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
264
+
265
+ # With custom worker count:
266
+ python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
267
+ """,
268
+ )
269
+ parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
270
+ parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
271
+ parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
272
+ parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
273
+ args = parser.parse_args()
274
+
275
+ prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
276
+ except KeyboardInterrupt:
277
+ print("\nOperation cancelled by user. Cleaning up...")
278
+ if executor is not None:
279
+ executor.shutdown(wait=False, cancel_futures=True)
280
+ sys.exit(1)
281
 
282
 
283
  if __name__ == "__main__":
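
For reference, the ffprobe call assembled in `get_audio_duration` is equivalent to running the following directly (the audio path is a placeholder):

```
ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 /path/to/audio.wav
```

If ffprobe is unavailable or times out, the function falls back to loading the file with torchaudio and computing `num_samples / sample_rate`.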
src/f5_tts/train/finetune_cli.py CHANGED
@@ -1,12 +1,13 @@
1
  import argparse
2
  import os
3
  import shutil
 
4
 
5
  from cached_path import cached_path
 
6
  from f5_tts.model import CFM, UNetT, DiT, Trainer
7
  from f5_tts.model.utils import get_tokenizer
8
  from f5_tts.model.dataset import load_dataset
9
- from importlib.resources import files
10
 
11
 
12
  # -------------------------- Dataset Settings --------------------------- #
@@ -20,19 +21,14 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
20
 
21
  # -------------------------- Argument Parsing --------------------------- #
22
  def parse_args():
23
- # batch_size_per_gpu = 1000 settting for gpu 8GB
24
- # batch_size_per_gpu = 1600 settting for gpu 12GB
25
- # batch_size_per_gpu = 2000 settting for gpu 16GB
26
- # batch_size_per_gpu = 3200 settting for gpu 24GB
27
-
28
- # num_warmup_updates = 300 for 5000 sample about 10 hours
29
-
30
- # change save_per_updates , last_per_steps change this value what you need ,
31
-
32
  parser = argparse.ArgumentParser(description="Train CFM Model")
33
 
34
  parser.add_argument(
35
- "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
 
 
 
 
36
  )
37
  parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
38
  parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
@@ -44,9 +40,15 @@ def parse_args():
44
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
45
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
46
  parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
47
- parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
48
- parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
49
- parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
 
 
 
 
 
 
50
  parser.add_argument("--finetune", action="store_true", help="Use Finetune")
51
  parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
52
  parser.add_argument(
@@ -61,7 +63,7 @@ def parse_args():
61
  parser.add_argument(
62
  "--log_samples",
63
  action="store_true",
64
- help="Log inferenced samples per ckpt save steps",
65
  )
66
  parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
67
  parser.add_argument(
@@ -82,19 +84,54 @@ def main():
82
  checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
83
 
84
  # Model parameters based on experiment name
85
- if args.exp_name == "F5TTS_Base":
 
86
  wandb_resume_id = None
87
  model_cls = DiT
88
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
 
 
 
 
 
 
89
  if args.finetune:
90
  if args.pretrain is None:
91
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
92
  else:
93
  ckpt_path = args.pretrain
 
94
  elif args.exp_name == "E2TTS_Base":
95
  wandb_resume_id = None
96
  model_cls = UNetT
97
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
 
 
 
 
98
  if args.finetune:
99
  if args.pretrain is None:
100
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
@@ -105,12 +142,16 @@ def main():
105
  if not os.path.isdir(checkpoint_path):
106
  os.makedirs(checkpoint_path, exist_ok=True)
107
 
108
- file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
 
 
 
109
  if not os.path.isfile(file_checkpoint):
110
  shutil.copy2(ckpt_path, file_checkpoint)
111
  print("copy checkpoint for finetune")
112
 
113
  # Use the tokenizer and tokenizer_path provided in the command line arguments
 
114
  tokenizer = args.tokenizer
115
  if tokenizer == "custom":
116
  if not args.tokenizer_path:
@@ -145,8 +186,9 @@ def main():
145
  args.learning_rate,
146
  num_warmup_updates=args.num_warmup_updates,
147
  save_per_updates=args.save_per_updates,
 
148
  checkpoint_path=checkpoint_path,
149
- batch_size=args.batch_size_per_gpu,
150
  batch_size_type=args.batch_size_type,
151
  max_samples=args.max_samples,
152
  grad_accumulation_steps=args.grad_accumulation_steps,
@@ -156,7 +198,7 @@ def main():
156
  wandb_run_name=args.exp_name,
157
  wandb_resume_id=wandb_resume_id,
158
  log_samples=args.log_samples,
159
- last_per_steps=args.last_per_steps,
160
  bnb_optimizer=args.bnb_optimizer,
161
  )
162
 
 
1
  import argparse
2
  import os
3
  import shutil
4
+ from importlib.resources import files
5
 
6
  from cached_path import cached_path
7
+
8
  from f5_tts.model import CFM, UNetT, DiT, Trainer
9
  from f5_tts.model.utils import get_tokenizer
10
  from f5_tts.model.dataset import load_dataset
 
11
 
12
 
13
  # -------------------------- Dataset Settings --------------------------- #
 
21
 
22
  # -------------------------- Argument Parsing --------------------------- #
23
  def parse_args():
 
 
 
 
 
 
 
 
 
24
  parser = argparse.ArgumentParser(description="Train CFM Model")
25
 
26
  parser.add_argument(
27
+ "--exp_name",
28
+ type=str,
29
+ default="F5TTS_v1_Base",
30
+ choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
31
+ help="Experiment name",
32
  )
33
  parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
34
  parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
 
40
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
41
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
42
  parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
43
+ parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
44
+ parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
45
+ parser.add_argument(
46
+ "--keep_last_n_checkpoints",
47
+ type=int,
48
+ default=-1,
49
+ help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
50
+ )
51
+ parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
52
  parser.add_argument("--finetune", action="store_true", help="Use Finetune")
53
  parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
54
  parser.add_argument(
 
63
  parser.add_argument(
64
  "--log_samples",
65
  action="store_true",
66
+ help="Log inferenced samples per ckpt save updates",
67
  )
68
  parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
69
  parser.add_argument(
 
84
  checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
85
 
86
  # Model parameters based on experiment name
87
+
88
+ if args.exp_name == "F5TTS_v1_Base":
89
  wandb_resume_id = None
90
  model_cls = DiT
91
+ model_cfg = dict(
92
+ dim=1024,
93
+ depth=22,
94
+ heads=16,
95
+ ff_mult=2,
96
+ text_dim=512,
97
+ conv_layers=4,
98
+ )
99
+ if args.finetune:
100
+ if args.pretrain is None:
101
+ ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
102
+ else:
103
+ ckpt_path = args.pretrain
104
+
105
+ elif args.exp_name == "F5TTS_Base":
106
+ wandb_resume_id = None
107
+ model_cls = DiT
108
+ model_cfg = dict(
109
+ dim=1024,
110
+ depth=22,
111
+ heads=16,
112
+ ff_mult=2,
113
+ text_dim=512,
114
+ text_mask_padding=False,
115
+ conv_layers=4,
116
+ pe_attn_head=1,
117
+ )
118
  if args.finetune:
119
  if args.pretrain is None:
120
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
121
  else:
122
  ckpt_path = args.pretrain
123
+
124
  elif args.exp_name == "E2TTS_Base":
125
  wandb_resume_id = None
126
  model_cls = UNetT
127
+ model_cfg = dict(
128
+ dim=1024,
129
+ depth=24,
130
+ heads=16,
131
+ ff_mult=4,
132
+ text_mask_padding=False,
133
+ pe_attn_head=1,
134
+ )
135
  if args.finetune:
136
  if args.pretrain is None:
137
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
 
142
  if not os.path.isdir(checkpoint_path):
143
  os.makedirs(checkpoint_path, exist_ok=True)
144
 
145
+ file_checkpoint = os.path.basename(ckpt_path)
146
+ if not file_checkpoint.startswith("pretrained_"): # Change: Add 'pretrained_' prefix to copied model
147
+ file_checkpoint = "pretrained_" + file_checkpoint
148
+ file_checkpoint = os.path.join(checkpoint_path, file_checkpoint)
149
  if not os.path.isfile(file_checkpoint):
150
  shutil.copy2(ckpt_path, file_checkpoint)
151
  print("copy checkpoint for finetune")
152
 
153
  # Use the tokenizer and tokenizer_path provided in the command line arguments
154
+
155
  tokenizer = args.tokenizer
156
  if tokenizer == "custom":
157
  if not args.tokenizer_path:
 
186
  args.learning_rate,
187
  num_warmup_updates=args.num_warmup_updates,
188
  save_per_updates=args.save_per_updates,
189
+ keep_last_n_checkpoints=args.keep_last_n_checkpoints,
190
  checkpoint_path=checkpoint_path,
191
+ batch_size_per_gpu=args.batch_size_per_gpu,
192
  batch_size_type=args.batch_size_type,
193
  max_samples=args.max_samples,
194
  grad_accumulation_steps=args.grad_accumulation_steps,
 
198
  wandb_run_name=args.exp_name,
199
  wandb_resume_id=wandb_resume_id,
200
  log_samples=args.log_samples,
201
+ last_per_updates=args.last_per_updates,
202
  bnb_optimizer=args.bnb_optimizer,
203
  )
204
 
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1,36 +1,36 @@
1
- import threading
2
- import queue
3
- import re
4
-
5
  import gc
6
  import json
 
7
  import os
8
  import platform
9
  import psutil
 
10
  import random
 
11
  import signal
12
  import shutil
13
  import subprocess
14
  import sys
15
  import tempfile
 
16
  import time
17
  from glob import glob
 
 
18
 
19
  import click
20
  import gradio as gr
21
  import librosa
22
- import numpy as np
23
  import torch
24
  import torchaudio
 
25
  from datasets import Dataset as Dataset_
26
  from datasets.arrow_writer import ArrowWriter
27
- from safetensors.torch import save_file
28
- from scipy.io import wavfile
29
- from cached_path import cached_path
30
  from f5_tts.api import F5TTS
31
  from f5_tts.model.utils import convert_char_to_pinyin
32
  from f5_tts.infer.utils_infer import transcribe
33
- from importlib.resources import files
34
 
35
 
36
  training_process = None
@@ -46,7 +46,15 @@ path_data = str(files("f5_tts").joinpath("../../data"))
46
  path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
47
  file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
48
 
49
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
 
 
 
 
50
 
51
 
52
  # Save settings from a JSON file
@@ -62,7 +70,8 @@ def save_settings(
62
  epochs,
63
  num_warmup_updates,
64
  save_per_updates,
65
- last_per_steps,
 
66
  finetune,
67
  file_checkpoint_train,
68
  tokenizer_type,
@@ -86,7 +95,8 @@ def save_settings(
86
  "epochs": epochs,
87
  "num_warmup_updates": num_warmup_updates,
88
  "save_per_updates": save_per_updates,
89
- "last_per_steps": last_per_steps,
 
90
  "finetune": finetune,
91
  "file_checkpoint_train": file_checkpoint_train,
92
  "tokenizer_type": tokenizer_type,
@@ -106,73 +116,56 @@ def load_settings(project_name):
106
  path_project = os.path.join(path_project_ckpts, project_name)
107
  file_setting = os.path.join(path_project, "setting.json")
108
 
109
- if not os.path.isfile(file_setting):
110
- settings = {
111
- "exp_name": "F5TTS_Base",
112
- "learning_rate": 1e-05,
113
- "batch_size_per_gpu": 1000,
114
- "batch_size_type": "frame",
115
- "max_samples": 64,
116
- "grad_accumulation_steps": 1,
117
- "max_grad_norm": 1,
118
- "epochs": 100,
119
- "num_warmup_updates": 2,
120
- "save_per_updates": 300,
121
- "last_per_steps": 100,
122
- "finetune": True,
123
- "file_checkpoint_train": "",
124
- "tokenizer_type": "pinyin",
125
- "tokenizer_file": "",
126
- "mixed_precision": "none",
127
- "logger": "wandb",
128
- "bnb_optimizer": False,
129
- }
130
- return (
131
- settings["exp_name"],
132
- settings["learning_rate"],
133
- settings["batch_size_per_gpu"],
134
- settings["batch_size_type"],
135
- settings["max_samples"],
136
- settings["grad_accumulation_steps"],
137
- settings["max_grad_norm"],
138
- settings["epochs"],
139
- settings["num_warmup_updates"],
140
- settings["save_per_updates"],
141
- settings["last_per_steps"],
142
- settings["finetune"],
143
- settings["file_checkpoint_train"],
144
- settings["tokenizer_type"],
145
- settings["tokenizer_file"],
146
- settings["mixed_precision"],
147
- settings["logger"],
148
- settings["bnb_optimizer"],
149
- )
150
 
151
- with open(file_setting, "r") as f:
152
- settings = json.load(f)
153
- if "logger" not in settings:
154
- settings["logger"] = "wandb"
155
- if "bnb_optimizer" not in settings:
156
- settings["bnb_optimizer"] = False
157
  return (
158
- settings["exp_name"],
159
- settings["learning_rate"],
160
- settings["batch_size_per_gpu"],
161
- settings["batch_size_type"],
162
- settings["max_samples"],
163
- settings["grad_accumulation_steps"],
164
- settings["max_grad_norm"],
165
- settings["epochs"],
166
- settings["num_warmup_updates"],
167
- settings["save_per_updates"],
168
- settings["last_per_steps"],
169
- settings["finetune"],
170
- settings["file_checkpoint_train"],
171
- settings["tokenizer_type"],
172
- settings["tokenizer_file"],
173
- settings["mixed_precision"],
174
- settings["logger"],
175
- settings["bnb_optimizer"],
 
176
  )
177
 
178
 
@@ -369,17 +362,18 @@ def terminate_process(pid):
369
 
370
  def start_training(
371
  dataset_name="",
372
- exp_name="F5TTS_Base",
373
- learning_rate=1e-4,
374
- batch_size_per_gpu=400,
375
- batch_size_type="frame",
376
  max_samples=64,
377
- grad_accumulation_steps=1,
378
  max_grad_norm=1.0,
379
- epochs=11,
380
- num_warmup_updates=200,
381
- save_per_updates=400,
382
- last_per_steps=800,
 
383
  finetune=True,
384
  file_checkpoint_train="",
385
  tokenizer_type="pinyin",
@@ -438,18 +432,19 @@ def start_training(
438
  fp16 = ""
439
 
440
  cmd = (
441
- f"accelerate launch {fp16} {file_train} --exp_name {exp_name} "
442
- f"--learning_rate {learning_rate} "
443
- f"--batch_size_per_gpu {batch_size_per_gpu} "
444
- f"--batch_size_type {batch_size_type} "
445
- f"--max_samples {max_samples} "
446
- f"--grad_accumulation_steps {grad_accumulation_steps} "
447
- f"--max_grad_norm {max_grad_norm} "
448
- f"--epochs {epochs} "
449
- f"--num_warmup_updates {num_warmup_updates} "
450
- f"--save_per_updates {save_per_updates} "
451
- f"--last_per_steps {last_per_steps} "
452
- f"--dataset_name {dataset_name}"
 
453
  )
454
 
455
  if finetune:
@@ -482,7 +477,8 @@ def start_training(
482
  epochs,
483
  num_warmup_updates,
484
  save_per_updates,
485
- last_per_steps,
 
486
  finetune,
487
  file_checkpoint_train,
488
  tokenizer_type,
@@ -548,7 +544,7 @@ def start_training(
548
  output = stdout_queue.get_nowait()
549
  print(output, end="")
550
  match = re.search(
551
- r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), step=(\d+)", output
552
  )
553
  if match:
554
  current_epoch = match.group(1)
@@ -556,13 +552,13 @@ def start_training(
556
  percent_complete = match.group(3)
557
  elapsed_time = match.group(4)
558
  loss = match.group(5)
559
- current_step = match.group(6)
560
  message = (
561
  f"Epoch: {current_epoch}/{total_epochs}, "
562
  f"Progress: {percent_complete}%, "
563
  f"Elapsed Time: {elapsed_time}, "
564
  f"Loss: {loss}, "
565
- f"Step: {current_step}"
566
  )
567
  yield message, gr.update(interactive=False), gr.update(interactive=True)
568
  elif output.strip():
@@ -801,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
801
  print(f"Error processing {file_audio}: {e}")
802
  continue
803
 
804
- if duration < 1 or duration > 25:
805
- if duration > 25:
806
- error_files.append([file_audio, "duration > 25 sec"])
807
  if duration < 1:
808
  error_files.append([file_audio, "duration < 1 sec "])
809
  continue
810
  if len(text) < 3:
811
- error_files.append([file_audio, "very small text len 3"])
812
  continue
813
 
814
  text = clear_text(text)
@@ -875,40 +871,37 @@ def check_user(value):
875
 
876
  def calculate_train(
877
  name_project,
 
 
 
878
  batch_size_type,
879
  max_samples,
880
- learning_rate,
881
  num_warmup_updates,
882
- save_per_updates,
883
- last_per_steps,
884
  finetune,
885
  ):
886
  path_project = os.path.join(path_data, name_project)
887
- file_duraction = os.path.join(path_project, "duration.json")
888
 
889
- if not os.path.isfile(file_duraction):
 
 
 
890
  return (
891
- 1000,
 
 
892
  max_samples,
893
  num_warmup_updates,
894
- save_per_updates,
895
- last_per_steps,
896
  "project not found !",
897
- learning_rate,
898
  )
899
 
900
- with open(file_duraction, "r") as file:
901
  data = json.load(file)
902
 
903
  duration_list = data["duration"]
904
- samples = len(duration_list)
905
- hours = sum(duration_list) / 3600
906
-
907
- # if torch.cuda.is_available():
908
- # gpu_properties = torch.cuda.get_device_properties(0)
909
- # total_memory = gpu_properties.total_memory / (1024**3)
910
- # elif torch.backends.mps.is_available():
911
- # total_memory = psutil.virtual_memory().available / (1024**3)
912
 
913
  if torch.cuda.is_available():
914
  gpu_count = torch.cuda.device_count()
@@ -916,57 +909,39 @@ def calculate_train(
916
  for i in range(gpu_count):
917
  gpu_properties = torch.cuda.get_device_properties(i)
918
  total_memory += gpu_properties.total_memory / (1024**3) # in GB
919
-
 
 
 
 
 
920
  elif torch.backends.mps.is_available():
921
  gpu_count = 1
922
  total_memory = psutil.virtual_memory().available / (1024**3)
923
 
 
 
 
924
  if batch_size_type == "frame":
925
- batch = int(total_memory * 0.5)
926
- batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
927
- batch_size_per_gpu = int(38400 / batch)
928
- else:
929
- batch_size_per_gpu = int(total_memory / 8)
930
- batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
931
- batch = batch_size_per_gpu
932
 
933
- if batch_size_per_gpu <= 0:
934
- batch_size_per_gpu = 1
935
 
936
- if samples < 64:
937
- max_samples = int(samples * 0.25)
938
- else:
939
- max_samples = 64
940
-
941
- num_warmup_updates = int(samples * 0.05)
942
- save_per_updates = int(samples * 0.10)
943
- last_per_steps = int(save_per_updates * 0.25)
944
-
945
- max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
946
- num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
947
- save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
948
- last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
949
- if last_per_steps <= 0:
950
- last_per_steps = 2
951
-
952
- total_hours = hours
953
- mel_hop_length = 256
954
- mel_sampling_rate = 24000
955
-
956
- # target
957
- wanted_max_updates = 1000000
958
-
959
- # train params
960
- gpus = gpu_count
961
- frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
962
- grad_accum = 1
963
-
964
- # intermediate
965
- mini_batch_frames = frames_per_gpu * grad_accum * gpus
966
- mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
967
- updates_per_epoch = total_hours / mini_batch_hours
968
- # steps_per_epoch = updates_per_epoch * grad_accum
969
- epochs = wanted_max_updates / updates_per_epoch
970
 
971
  if finetune:
972
  learning_rate = 1e-5
@@ -974,20 +949,18 @@ def calculate_train(
974
  learning_rate = 7.5e-5
975
 
976
  return (
 
 
977
  batch_size_per_gpu,
978
  max_samples,
979
  num_warmup_updates,
980
- save_per_updates,
981
- last_per_steps,
982
- samples,
983
- learning_rate,
984
- int(epochs),
985
  )
986
 
987
 
988
  def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
989
  try:
990
- checkpoint = torch.load(checkpoint_path)
991
  print("Original Checkpoint Keys:", checkpoint.keys())
992
 
993
  ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
@@ -1018,7 +991,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
1018
  torch.backends.cudnn.deterministic = True
1019
  torch.backends.cudnn.benchmark = False
1020
 
1021
- ckpt = torch.load(ckpt_path, map_location="cpu")
 
 
 
 
1022
 
1023
  ema_sd = ckpt.get("ema_model_state_dict", {})
1024
  embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
@@ -1086,9 +1063,11 @@ def vocab_extend(project_name, symbols, model_type):
1086
  with open(file_vocab_project, "w", encoding="utf-8") as f:
1087
  f.write("\n".join(vocab))
1088
 
1089
- if model_type == "F5-TTS":
 
 
1090
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
1091
- else:
1092
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
1093
 
1094
  vocab_size_new = len(miss_symbols)
@@ -1096,7 +1075,9 @@ def vocab_extend(project_name, symbols, model_type):
1096
  dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
1097
  new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
1098
  os.makedirs(new_ckpt_path, exist_ok=True)
1099
- new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
 
 
1100
 
1101
  size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
1102
 
@@ -1226,21 +1207,21 @@ def infer(
1226
  vocab_file = os.path.join(path_data, project, "vocab.txt")
1227
 
1228
  tts_api = F5TTS(
1229
- model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
1230
  )
1231
 
1232
  print("update >> ", device_test, file_checkpoint, use_ema)
1233
 
1234
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1235
  tts_api.infer(
1236
- gen_text=gen_text.lower().strip(),
1237
- ref_text=ref_text.lower().strip(),
1238
  ref_file=ref_audio,
 
 
1239
  nfe_step=nfe_step,
1240
- file_wave=f.name,
1241
  speed=speed,
1242
- seed=seed,
1243
  remove_silence=remove_silence,
 
 
1244
  )
1245
  return f.name, tts_api.device, str(tts_api.seed)
1246
 
@@ -1256,12 +1237,22 @@ def get_checkpoints_project(project_name, is_gradio=True):
1256
 
1257
  if os.path.isdir(path_project_ckpts):
1258
  files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
1259
- files_checkpoints = sorted(
1260
- files_checkpoints,
1261
- key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
1262
- if os.path.basename(x) != "model_last.pt"
1263
- else float("inf"),
 
 
 
 
 
 
 
1264
  )
 
 
 
1265
  else:
1266
  files_checkpoints = []
1267
 
@@ -1312,7 +1303,21 @@ def get_gpu_stats():
1312
  f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
1313
  f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
1314
  )
 
 
 
 
 
 
 
 
1315
 
 
 
 
 
 
 
1316
  elif torch.backends.mps.is_available():
1317
  gpu_count = 1
1318
  gpu_stats += "MPS GPU\n"
@@ -1375,14 +1380,14 @@ def get_audio_select(file_sample):
1375
  with gr.Blocks() as app:
1376
  gr.Markdown(
1377
  """
1378
- # E2/F5 TTS Automatic Finetune
1379
 
1380
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
1381
 
1382
  * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1383
  * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1384
 
1385
- The checkpoints support English and Chinese.
1386
 
1387
  For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
1388
  """
@@ -1459,7 +1464,9 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
1459
  Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1460
  ```""")
1461
 
1462
- exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
 
 
1463
 
1464
  with gr.Row():
1465
  txt_extend = gr.Textbox(
@@ -1528,9 +1535,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
1528
  fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1529
  )
1530
 
1531
- with gr.TabItem("Train Data"):
1532
  gr.Markdown("""```plaintext
1533
- The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.
1534
  If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1535
  ```""")
1536
  with gr.Row():
@@ -1544,11 +1551,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1544
  file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1545
 
1546
  with gr.Row():
1547
- exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
 
 
1548
  learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1549
 
1550
  with gr.Row():
1551
- batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
1552
  max_samples = gr.Number(label="Max Samples", value=64)
1553
 
1554
  with gr.Row():
@@ -1556,59 +1565,70 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1556
  max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1557
 
1558
  with gr.Row():
1559
- epochs = gr.Number(label="Epochs", value=10)
1560
- num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
1561
 
1562
  with gr.Row():
1563
- save_per_updates = gr.Number(label="Save per Updates", value=300)
1564
- last_per_steps = gr.Number(label="Last per Steps", value=100)
 
 
 
 
 
 
 
1565
 
1566
  with gr.Row():
1567
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1568
- mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
1569
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1570
  start_button = gr.Button("Start Training")
1571
  stop_button = gr.Button("Stop Training", interactive=False)
1572
 
1573
  if projects_selelect is not None:
1574
  (
1575
- exp_namev,
1576
- learning_ratev,
1577
- batch_size_per_gpuv,
1578
- batch_size_typev,
1579
- max_samplesv,
1580
- grad_accumulation_stepsv,
1581
- max_grad_normv,
1582
- epochsv,
1583
- num_warmupv_updatesv,
1584
- save_per_updatesv,
1585
- last_per_stepsv,
1586
- finetunev,
1587
- file_checkpoint_trainv,
1588
- tokenizer_typev,
1589
- tokenizer_filev,
1590
- mixed_precisionv,
1591
- cd_loggerv,
1592
- ch_8bit_adamv,
 
1593
  ) = load_settings(projects_selelect)
1594
- exp_name.value = exp_namev
1595
- learning_rate.value = learning_ratev
1596
- batch_size_per_gpu.value = batch_size_per_gpuv
1597
- batch_size_type.value = batch_size_typev
1598
- max_samples.value = max_samplesv
1599
- grad_accumulation_steps.value = grad_accumulation_stepsv
1600
- max_grad_norm.value = max_grad_normv
1601
- epochs.value = epochsv
1602
- num_warmup_updates.value = num_warmupv_updatesv
1603
- save_per_updates.value = save_per_updatesv
1604
- last_per_steps.value = last_per_stepsv
1605
- ch_finetune.value = finetunev
1606
- file_checkpoint_train.value = file_checkpoint_trainv
1607
- tokenizer_type.value = tokenizer_typev
1608
- tokenizer_file.value = tokenizer_filev
1609
- mixed_precision.value = mixed_precisionv
1610
- cd_logger.value = cd_loggerv
1611
- ch_8bit_adam.value = ch_8bit_adamv
 
 
 
1612
 
1613
  ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1614
  txt_info_train = gr.Text(label="Info", value="")
@@ -1659,7 +1679,8 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1659
  epochs,
1660
  num_warmup_updates,
1661
  save_per_updates,
1662
- last_per_steps,
 
1663
  ch_finetune,
1664
  file_checkpoint_train,
1665
  tokenizer_type,
@@ -1677,23 +1698,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1677
  fn=calculate_train,
1678
  inputs=[
1679
  cm_project,
 
 
 
1680
  batch_size_type,
1681
  max_samples,
1682
- learning_rate,
1683
  num_warmup_updates,
1684
- save_per_updates,
1685
- last_per_steps,
1686
  ch_finetune,
1687
  ],
1688
  outputs=[
 
 
1689
  batch_size_per_gpu,
1690
  max_samples,
1691
  num_warmup_updates,
1692
- save_per_updates,
1693
- last_per_steps,
1694
  lb_samples,
1695
- learning_rate,
1696
- epochs,
1697
  ],
1698
  )
1699
 
@@ -1713,15 +1732,16 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1713
  epochs,
1714
  num_warmup_updates,
1715
  save_per_updates,
1716
- last_per_steps,
 
1717
  ch_finetune,
1718
  file_checkpoint_train,
1719
  tokenizer_type,
1720
  tokenizer_file,
1721
  mixed_precision,
1722
  cd_logger,
 
1723
  ]
1724
-
1725
  return output_components
1726
 
1727
  outputs = setup_load_settings()
@@ -1742,7 +1762,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1742
  gr.Markdown("""```plaintext
1743
  SOS: Check the use_ema setting (True or False) for your model to see which works best for you. Use seed -1 for a random seed.
1744
  ```""")
1745
- exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
 
 
1746
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1747
 
1748
  with gr.Row():
@@ -1796,9 +1818,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
1796
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1797
  cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1798
 
1799
- with gr.TabItem("Reduce Checkpoint"):
1800
  gr.Markdown("""```plaintext
1801
- Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training.
1802
  ```""")
1803
  txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
1804
  txt_path_checkpoint_small = gr.Text(label="Path to Output:")
 
 
 
 
 
1
  import gc
2
  import json
3
+ import numpy as np
4
  import os
5
  import platform
6
  import psutil
7
+ import queue
8
  import random
9
+ import re
10
  import signal
11
  import shutil
12
  import subprocess
13
  import sys
14
  import tempfile
15
+ import threading
16
  import time
17
  from glob import glob
18
+ from importlib.resources import files
19
+ from scipy.io import wavfile
20
 
21
  import click
22
  import gradio as gr
23
  import librosa
 
24
  import torch
25
  import torchaudio
26
+ from cached_path import cached_path
27
  from datasets import Dataset as Dataset_
28
  from datasets.arrow_writer import ArrowWriter
29
+ from safetensors.torch import load_file, save_file
30
+
 
31
  from f5_tts.api import F5TTS
32
  from f5_tts.model.utils import convert_char_to_pinyin
33
  from f5_tts.infer.utils_infer import transcribe
 
34
 
35
 
36
  training_process = None
 
46
  path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
47
  file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
48
 
49
+ device = (
50
+ "cuda"
51
+ if torch.cuda.is_available()
52
+ else "xpu"
53
+ if torch.xpu.is_available()
54
+ else "mps"
55
+ if torch.backends.mps.is_available()
56
+ else "cpu"
57
+ )
58
 
59
 
60
  # Save settings from a JSON file
 
70
  epochs,
71
  num_warmup_updates,
72
  save_per_updates,
73
+ keep_last_n_checkpoints,
74
+ last_per_updates,
75
  finetune,
76
  file_checkpoint_train,
77
  tokenizer_type,
 
95
  "epochs": epochs,
96
  "num_warmup_updates": num_warmup_updates,
97
  "save_per_updates": save_per_updates,
98
+ "keep_last_n_checkpoints": keep_last_n_checkpoints,
99
+ "last_per_updates": last_per_updates,
100
  "finetune": finetune,
101
  "file_checkpoint_train": file_checkpoint_train,
102
  "tokenizer_type": tokenizer_type,
 
116
  path_project = os.path.join(path_project_ckpts, project_name)
117
  file_setting = os.path.join(path_project, "setting.json")
118
 
119
+ # Default settings
120
+ default_settings = {
121
+ "exp_name": "F5TTS_v1_Base",
122
+ "learning_rate": 1e-5,
123
+ "batch_size_per_gpu": 1,
124
+ "batch_size_type": "sample",
125
+ "max_samples": 64,
126
+ "grad_accumulation_steps": 4,
127
+ "max_grad_norm": 1,
128
+ "epochs": 100,
129
+ "num_warmup_updates": 100,
130
+ "save_per_updates": 500,
131
+ "keep_last_n_checkpoints": -1,
132
+ "last_per_updates": 100,
133
+ "finetune": True,
134
+ "file_checkpoint_train": "",
135
+ "tokenizer_type": "pinyin",
136
+ "tokenizer_file": "",
137
+ "mixed_precision": "none",
138
+ "logger": "wandb",
139
+ "bnb_optimizer": False,
140
+ }
141
+
142
+ # Load settings from file if it exists
143
+ if os.path.isfile(file_setting):
144
+ with open(file_setting, "r") as f:
145
+ file_settings = json.load(f)
146
+ default_settings.update(file_settings)
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+ # Return as a tuple in the correct order
 
 
 
 
 
149
  return (
150
+ default_settings["exp_name"],
151
+ default_settings["learning_rate"],
152
+ default_settings["batch_size_per_gpu"],
153
+ default_settings["batch_size_type"],
154
+ default_settings["max_samples"],
155
+ default_settings["grad_accumulation_steps"],
156
+ default_settings["max_grad_norm"],
157
+ default_settings["epochs"],
158
+ default_settings["num_warmup_updates"],
159
+ default_settings["save_per_updates"],
160
+ default_settings["keep_last_n_checkpoints"],
161
+ default_settings["last_per_updates"],
162
+ default_settings["finetune"],
163
+ default_settings["file_checkpoint_train"],
164
+ default_settings["tokenizer_type"],
165
+ default_settings["tokenizer_file"],
166
+ default_settings["mixed_precision"],
167
+ default_settings["logger"],
168
+ default_settings["bnb_optimizer"],
169
  )
170
 
171
 
 
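A quick usage sketch of the settings loader above (the project name is a placeholder; the unpacking order mirrors the returned tuple, falling back to the defaults when no setting.json exists yet):

```python
# Hypothetical usage of load_settings defined above; "my_project_pinyin" is a placeholder.
(
    exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples,
    grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates,
    save_per_updates, keep_last_n_checkpoints, last_per_updates, finetune,
    file_checkpoint_train, tokenizer_type, tokenizer_file, mixed_precision,
    logger, bnb_optimizer,
) = load_settings("my_project_pinyin")
```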
362
 
363
  def start_training(
364
  dataset_name="",
365
+ exp_name="F5TTS_v1_Base",
366
+ learning_rate=1e-5,
367
+ batch_size_per_gpu=1,
368
+ batch_size_type="sample",
369
  max_samples=64,
370
+ grad_accumulation_steps=4,
371
  max_grad_norm=1.0,
372
+ epochs=100,
373
+ num_warmup_updates=100,
374
+ save_per_updates=500,
375
+ keep_last_n_checkpoints=-1,
376
+ last_per_updates=100,
377
  finetune=True,
378
  file_checkpoint_train="",
379
  tokenizer_type="pinyin",
 
432
  fp16 = ""
433
 
434
  cmd = (
435
+ f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
436
+ f" --learning_rate {learning_rate}"
437
+ f" --batch_size_per_gpu {batch_size_per_gpu}"
438
+ f" --batch_size_type {batch_size_type}"
439
+ f" --max_samples {max_samples}"
440
+ f" --grad_accumulation_steps {grad_accumulation_steps}"
441
+ f" --max_grad_norm {max_grad_norm}"
442
+ f" --epochs {epochs}"
443
+ f" --num_warmup_updates {num_warmup_updates}"
444
+ f" --save_per_updates {save_per_updates}"
445
+ f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
446
+ f" --last_per_updates {last_per_updates}"
447
+ f" --dataset_name {dataset_name}"
448
  )
449
 
450
  if finetune:
 
477
  epochs,
478
  num_warmup_updates,
479
  save_per_updates,
480
+ keep_last_n_checkpoints,
481
+ last_per_updates,
482
  finetune,
483
  file_checkpoint_train,
484
  tokenizer_type,
 
544
  output = stdout_queue.get_nowait()
545
  print(output, end="")
546
  match = re.search(
547
+ r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", output
548
  )
549
  if match:
550
  current_epoch = match.group(1)
 
552
  percent_complete = match.group(3)
553
  elapsed_time = match.group(4)
554
  loss = match.group(5)
555
+ current_update = match.group(6)
556
  message = (
557
  f"Epoch: {current_epoch}/{total_epochs}, "
558
  f"Progress: {percent_complete}%, "
559
  f"Elapsed Time: {elapsed_time}, "
560
  f"Loss: {loss}, "
561
+ f"Update: {current_update}"
562
  )
563
  yield message, gr.update(interactive=False), gr.update(interactive=True)
564
  elif output.strip():
 
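As a sanity check on the progress-parsing pattern above, a line in that format can be matched as follows; the sample string is constructed to fit the regex, not captured from real training output:

```python
import re

# Pattern copied from the stdout parser above; the sample line is fabricated to
# show which fields the six capture groups pick out.
pattern = r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)"
sample = "Epoch 2/100:  45%|####      | [12:34<15:00, loss=0.123, update=4600]"

m = re.search(pattern, sample)
assert m is not None
assert m.groups() == ("2", "100", "45", "12:34", "0.123", "4600")
```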
797
  print(f"Error processing {file_audio}: {e}")
798
  continue
799
 
800
+ if duration < 1 or duration > 30:
801
+ if duration > 30:
802
+ error_files.append([file_audio, "duration > 30 sec"])
803
  if duration < 1:
804
  error_files.append([file_audio, "duration < 1 sec "])
805
  continue
806
  if len(text) < 3:
807
+ error_files.append([file_audio, "text shorter than 3 characters"])
808
  continue
809
 
810
  text = clear_text(text)
 
871
 
872
  def calculate_train(
873
  name_project,
874
+ epochs,
875
+ learning_rate,
876
+ batch_size_per_gpu,
877
  batch_size_type,
878
  max_samples,
 
879
  num_warmup_updates,
 
 
880
  finetune,
881
  ):
882
  path_project = os.path.join(path_data, name_project)
883
+ file_duration = os.path.join(path_project, "duration.json")
884
 
885
+ hop_length = 256
886
+ sampling_rate = 24000
887
+
888
+ if not os.path.isfile(file_duration):
889
  return (
890
+ epochs,
891
+ learning_rate,
892
+ batch_size_per_gpu,
893
  max_samples,
894
  num_warmup_updates,
 
 
895
  "project not found !",
 
896
  )
897
 
898
+ with open(file_duration, "r") as file:
899
  data = json.load(file)
900
 
901
  duration_list = data["duration"]
902
+ max_sample_length = max(duration_list) * sampling_rate / hop_length
903
+ total_samples = len(duration_list)
904
+ total_duration = sum(duration_list)
 
 
 
 
 
905
 
906
  if torch.cuda.is_available():
907
  gpu_count = torch.cuda.device_count()
 
909
  for i in range(gpu_count):
910
  gpu_properties = torch.cuda.get_device_properties(i)
911
  total_memory += gpu_properties.total_memory / (1024**3) # in GB
912
+ elif torch.xpu.is_available():
913
+ gpu_count = torch.xpu.device_count()
914
+ total_memory = 0
915
+ for i in range(gpu_count):
916
+ gpu_properties = torch.xpu.get_device_properties(i)
917
+ total_memory += gpu_properties.total_memory / (1024**3)
918
  elif torch.backends.mps.is_available():
919
  gpu_count = 1
920
  total_memory = psutil.virtual_memory().available / (1024**3)
921
 
922
+ avg_gpu_memory = total_memory / gpu_count
923
+
924
+ # rough estimate of batch size
925
  if batch_size_type == "frame":
926
+ batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
927
+ elif batch_size_type == "sample":
928
+ batch_size_per_gpu = int(200 / (total_duration / total_samples))
 
 
 
 
929
 
930
+ if total_samples < 64:
931
+ max_samples = int(total_samples * 0.25)
932
 
933
+ num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
934
+
935
+ # take 1.2M updates as the maximum
936
+ max_updates = 1200000
937
+
938
+ if batch_size_type == "frame":
939
+ mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
940
+ updates_per_epoch = total_duration / mini_batch_duration
941
+ elif batch_size_type == "sample":
942
+ updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
943
+
944
+ epochs = int(max_updates / updates_per_epoch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
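To make the heuristic above concrete, here is the arithmetic for a hypothetical single 24 GB GPU with 10 hours of data and frame-based batching (all inputs are assumptions, and max_sample_length is assumed to be below the memory-based estimate):

```python
# Hypothetical inputs: one 24 GB GPU, 10 hours of audio, batch_size_type == "frame".
avg_gpu_memory = 24          # GB
gpu_count = 1
total_duration = 10 * 3600   # seconds of audio in the dataset
hop_length, sampling_rate = 256, 24000

batch_size_per_gpu = int(38400 * (avg_gpu_memory - 5) / 75)                         # 9728 frames
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate   # ~103.8 s
updates_per_epoch = total_duration / mini_batch_duration                            # ~347 updates
epochs = int(1_200_000 / updates_per_epoch)                                          # ~3458, from the 1.2M-update cap
```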
945
 
946
  if finetune:
947
  learning_rate = 1e-5
 
949
  learning_rate = 7.5e-5
950
 
951
  return (
952
+ epochs,
953
+ learning_rate,
954
  batch_size_per_gpu,
955
  max_samples,
956
  num_warmup_updates,
957
+ total_samples,
 
 
 
 
958
  )
959
 
960
 
961
  def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
962
  try:
963
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
964
  print("Original Checkpoint Keys:", checkpoint.keys())
965
 
966
  ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
 
991
  torch.backends.cudnn.deterministic = True
992
  torch.backends.cudnn.benchmark = False
993
 
994
+ if ckpt_path.endswith(".safetensors"):
995
+ ckpt = load_file(ckpt_path, device="cpu")
996
+ ckpt = {"ema_model_state_dict": ckpt}
997
+ elif ckpt_path.endswith(".pt"):
998
+ ckpt = torch.load(ckpt_path, map_location="cpu")
999
 
1000
  ema_sd = ckpt.get("ema_model_state_dict", {})
1001
  embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
 
1063
  with open(file_vocab_project, "w", encoding="utf-8") as f:
1064
  f.write("\n".join(vocab))
1065
 
1066
+ if model_type == "F5TTS_v1_Base":
1067
+ ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
1068
+ elif model_type == "F5TTS_Base":
1069
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
1070
+ elif model_type == "E2TTS_Base":
1071
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
1072
 
1073
  vocab_size_new = len(miss_symbols)
 
1075
  dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
1076
  new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
1077
  os.makedirs(new_ckpt_path, exist_ok=True)
1078
+
1079
+ # Add a pretrained_ prefix to the copied checkpoint for consistency with finetune_cli.py
1080
+ new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
1081
 
1082
  size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
1083
 
 
1207
  vocab_file = os.path.join(path_data, project, "vocab.txt")
1208
 
1209
  tts_api = F5TTS(
1210
+ model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
1211
  )
1212
 
1213
  print("update >> ", device_test, file_checkpoint, use_ema)
1214
 
1215
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1216
  tts_api.infer(
 
 
1217
  ref_file=ref_audio,
1218
+ ref_text=ref_text.lower().strip(),
1219
+ gen_text=gen_text.lower().strip(),
1220
  nfe_step=nfe_step,
 
1221
  speed=speed,
 
1222
  remove_silence=remove_silence,
1223
+ file_wave=f.name,
1224
+ seed=seed,
1225
  )
1226
  return f.name, tts_api.device, str(tts_api.seed)
1227
 
 
1237
 
1238
  if os.path.isdir(path_project_ckpts):
1239
  files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
1240
+ # Separate pretrained and regular checkpoints
1241
+ pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
1242
+ regular_checkpoints = [
1243
+ f
1244
+ for f in files_checkpoints
1245
+ if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
1246
+ ]
1247
+ last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
1248
+
1249
+ # Sort regular checkpoints by number
1250
+ regular_checkpoints = sorted(
1251
+ regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
1252
  )
1253
+
1254
+ # Combine in order: pretrained, regular, last
1255
+ files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
1256
  else:
1257
  files_checkpoints = []
1258
 
 
1303
  f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
1304
  f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
1305
  )
1306
+ elif torch.xpu.is_available():
1307
+ gpu_count = torch.xpu.device_count()
1308
+ for i in range(gpu_count):
1309
+ gpu_name = torch.xpu.get_device_name(i)
1310
+ gpu_properties = torch.xpu.get_device_properties(i)
1311
+ total_memory = gpu_properties.total_memory / (1024**3) # in GB
1312
+ allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB
1313
+ reserved_memory = torch.xpu.memory_reserved(i) / (1024**2) # in MB
1314
 
1315
+ gpu_stats += (
1316
+ f"GPU {i} Name: {gpu_name}\n"
1317
+ f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
1318
+ f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
1319
+ f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
1320
+ )
1321
  elif torch.backends.mps.is_available():
1322
  gpu_count = 1
1323
  gpu_stats += "MPS GPU\n"
 
1380
  with gr.Blocks() as app:
1381
  gr.Markdown(
1382
  """
1383
+ # F5 TTS Automatic Finetune
1384
 
1385
+ This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models:
1386
 
1387
  * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1388
  * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1389
 
1390
+ The pretrained checkpoints support English and Chinese.
1391
 
1392
  For tutorials and updates, check here (https://github.com/SWivid/F5-TTS/discussions/143)
1393
  """
 
1464
  Using the extended model, you can finetune on a new language whose symbols are missing from the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1465
  ```""")
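The heavy lifting here is done by expand_model_embeddings further down; conceptually it amounts to appending rows to the checkpoint's text-embedding matrix for the newly added symbols. A minimal sketch of that idea, not the exact implementation:

```python
import torch

def expand_text_embedding(weight: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    # weight: (vocab_size, dim) text-embedding matrix taken from the checkpoint.
    # Existing rows are kept untouched; new symbols get small random rows appended.
    vocab_size, dim = weight.shape
    new_rows = torch.randn(num_new_tokens, dim) * weight.std()
    return torch.cat([weight, new_rows], dim=0)  # (vocab_size + num_new_tokens, dim)
```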
1466
 
1467
+ exp_name_extend = gr.Radio(
1468
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1469
+ )
1470
 
1471
  with gr.Row():
1472
  txt_extend = gr.Textbox(
 
1535
  fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1536
  )
1537
 
1538
+ with gr.TabItem("Train Model"):
1539
  gr.Markdown("""```plaintext
1540
+ The auto-setting is still experimental. If you are unsure, set a large number of epochs, and keep only the last N checkpoints if disk space is limited.
1541
  If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1542
  ```""")
1543
  with gr.Row():
 
1551
  file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1552
 
1553
  with gr.Row():
1554
+ exp_name = gr.Radio(
1555
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1556
+ )
1557
  learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1558
 
1559
  with gr.Row():
1560
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
1561
  max_samples = gr.Number(label="Max Samples", value=64)
1562
 
1563
  with gr.Row():
 
1565
  max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1566
 
1567
  with gr.Row():
1568
+ epochs = gr.Number(label="Epochs", value=100)
1569
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
1570
 
1571
  with gr.Row():
1572
+ save_per_updates = gr.Number(label="Save per Updates", value=500)
1573
+ keep_last_n_checkpoints = gr.Number(
1574
+ label="Keep Last N Checkpoints",
1575
+ value=-1,
1576
+ step=1,
1577
+ precision=0,
1578
+ info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
1579
+ )
1580
+ last_per_updates = gr.Number(label="Last per Updates", value=100)
1581
 
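The retention control above follows the semantics spelled out in its info text; a small sketch of that policy (illustrative only, not the Trainer's actual code):

```python
def checkpoints_to_keep(saved_checkpoints: list, keep_last_n: int) -> list:
    # -1 keeps every intermediate checkpoint, 0 keeps none of them,
    # and N > 0 keeps only the most recent N (model_last is handled separately).
    if keep_last_n < 0:
        return saved_checkpoints
    if keep_last_n == 0:
        return []
    return saved_checkpoints[-keep_last_n:]
```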
1582
  with gr.Row():
1583
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1584
+ mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
1585
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1586
  start_button = gr.Button("Start Training")
1587
  stop_button = gr.Button("Stop Training", interactive=False)
1588
 
1589
  if projects_selelect is not None:
1590
  (
1591
+ exp_name_value,
1592
+ learning_rate_value,
1593
+ batch_size_per_gpu_value,
1594
+ batch_size_type_value,
1595
+ max_samples_value,
1596
+ grad_accumulation_steps_value,
1597
+ max_grad_norm_value,
1598
+ epochs_value,
1599
+ num_warmup_updates_value,
1600
+ save_per_updates_value,
1601
+ keep_last_n_checkpoints_value,
1602
+ last_per_updates_value,
1603
+ finetune_value,
1604
+ file_checkpoint_train_value,
1605
+ tokenizer_type_value,
1606
+ tokenizer_file_value,
1607
+ mixed_precision_value,
1608
+ logger_value,
1609
+ bnb_optimizer_value,
1610
  ) = load_settings(projects_selelect)
1611
+
1612
+ # Assigning values to the respective components
1613
+ exp_name.value = exp_name_value
1614
+ learning_rate.value = learning_rate_value
1615
+ batch_size_per_gpu.value = batch_size_per_gpu_value
1616
+ batch_size_type.value = batch_size_type_value
1617
+ max_samples.value = max_samples_value
1618
+ grad_accumulation_steps.value = grad_accumulation_steps_value
1619
+ max_grad_norm.value = max_grad_norm_value
1620
+ epochs.value = epochs_value
1621
+ num_warmup_updates.value = num_warmup_updates_value
1622
+ save_per_updates.value = save_per_updates_value
1623
+ keep_last_n_checkpoints.value = keep_last_n_checkpoints_value
1624
+ last_per_updates.value = last_per_updates_value
1625
+ ch_finetune.value = finetune_value
1626
+ file_checkpoint_train.value = file_checkpoint_train_value
1627
+ tokenizer_type.value = tokenizer_type_value
1628
+ tokenizer_file.value = tokenizer_file_value
1629
+ mixed_precision.value = mixed_precision_value
1630
+ cd_logger.value = logger_value
1631
+ ch_8bit_adam.value = bnb_optimizer_value
1632
 
1633
  ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1634
  txt_info_train = gr.Text(label="Info", value="")
 
1679
  epochs,
1680
  num_warmup_updates,
1681
  save_per_updates,
1682
+ keep_last_n_checkpoints,
1683
+ last_per_updates,
1684
  ch_finetune,
1685
  file_checkpoint_train,
1686
  tokenizer_type,
 
1698
  fn=calculate_train,
1699
  inputs=[
1700
  cm_project,
1701
+ epochs,
1702
+ learning_rate,
1703
+ batch_size_per_gpu,
1704
  batch_size_type,
1705
  max_samples,
 
1706
  num_warmup_updates,
 
 
1707
  ch_finetune,
1708
  ],
1709
  outputs=[
1710
+ epochs,
1711
+ learning_rate,
1712
  batch_size_per_gpu,
1713
  max_samples,
1714
  num_warmup_updates,
 
 
1715
  lb_samples,
 
 
1716
  ],
1717
  )
1718
 
 
1732
  epochs,
1733
  num_warmup_updates,
1734
  save_per_updates,
1735
+ keep_last_n_checkpoints,
1736
+ last_per_updates,
1737
  ch_finetune,
1738
  file_checkpoint_train,
1739
  tokenizer_type,
1740
  tokenizer_file,
1741
  mixed_precision,
1742
  cd_logger,
1743
+ ch_8bit_adam,
1744
  ]
 
1745
  return output_components
1746
 
1747
  outputs = setup_load_settings()
 
1762
  gr.Markdown("""```plaintext
1763
  SOS: Check the use_ema setting (True or False) for your model to see which works best for you. Use seed -1 for a random seed.
1764
  ```""")
1765
+ exp_name = gr.Radio(
1766
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1767
+ )
1768
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1769
 
1770
  with gr.Row():
 
1818
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1819
  cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1820
 
1821
+ with gr.TabItem("Prune Checkpoint"):
1822
  gr.Markdown("""```plaintext
1823
+ Reduce the Base model checkpoint size from 5GB to 1.3GB. The new checkpoint prunes out the optimizer state and other training-only data; it can be used for inference or finetuning afterward, but it cannot be used to resume pretraining.
1824
  ```""")
1825
  txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
1826
  txt_path_checkpoint_small = gr.Text(label="Path to Output:")
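For context on the Prune Checkpoint tab above: pruning keeps only the EMA weights and drops the optimizer state, which is why the result can still be used for inference or finetuning but not to resume pretraining. A minimal sketch of the idea, assuming the checkpoint stores its EMA weights under "ema_model_state_dict" as extract_and_save_ema_model expects (paths are placeholders):

```python
import torch

ckpt = torch.load("ckpts/my_project/model_last.pt", weights_only=True)  # full training checkpoint
ema_sd = ckpt["ema_model_state_dict"]                                   # keep only the EMA weights
torch.save({"ema_model_state_dict": ema_sd}, "ckpts/my_project/model_pruned.pt")
```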
src/f5_tts/train/train.py CHANGED
@@ -4,8 +4,9 @@ import os
4
  from importlib.resources import files
5
 
6
  import hydra
 
7
 
8
- from f5_tts.model import CFM, DiT, Trainer, UNetT
9
  from f5_tts.model.dataset import load_dataset
10
  from f5_tts.model.utils import get_tokenizer
11
 
@@ -14,9 +15,13 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
14
 
15
  @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
16
  def main(cfg):
 
 
17
  tokenizer = cfg.model.tokenizer
18
  mel_spec_type = cfg.model.mel_spec.mel_spec_type
 
19
  exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
 
20
 
21
  # set text tokenizer
22
  if tokenizer != "custom":
@@ -26,14 +31,8 @@ def main(cfg):
26
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
27
 
28
  # set model
29
- if "F5TTS" in cfg.model.name:
30
- model_cls = DiT
31
- elif "E2TTS" in cfg.model.name:
32
- model_cls = UNetT
33
- wandb_resume_id = None
34
-
35
  model = CFM(
36
- transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
37
  mel_spec_kwargs=cfg.model.mel_spec,
38
  vocab_char_map=vocab_char_map,
39
  )
@@ -45,8 +44,9 @@ def main(cfg):
45
  learning_rate=cfg.optim.learning_rate,
46
  num_warmup_updates=cfg.optim.num_warmup_updates,
47
  save_per_updates=cfg.ckpts.save_per_updates,
 
48
  checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
49
- batch_size=cfg.datasets.batch_size_per_gpu,
50
  batch_size_type=cfg.datasets.batch_size_type,
51
  max_samples=cfg.datasets.max_samples,
52
  grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
@@ -55,12 +55,13 @@ def main(cfg):
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
- last_per_steps=cfg.ckpts.last_per_steps,
59
- log_samples=True,
60
  bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
62
  is_local_vocoder=cfg.model.vocoder.is_local,
63
  local_vocoder_path=cfg.model.vocoder.local_path,
 
64
  )
65
 
66
  train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
 
4
  from importlib.resources import files
5
 
6
  import hydra
7
+ from omegaconf import OmegaConf
8
 
9
+ from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config
10
  from f5_tts.model.dataset import load_dataset
11
  from f5_tts.model.utils import get_tokenizer
12
 
 
15
 
16
  @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
17
  def main(cfg):
18
+ model_cls = globals()[cfg.model.backbone]
19
+ model_arc = cfg.model.arch
20
  tokenizer = cfg.model.tokenizer
21
  mel_spec_type = cfg.model.mel_spec.mel_spec_type
22
+
23
  exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
24
+ wandb_resume_id = None
25
 
26
  # set text tokenizer
27
  if tokenizer != "custom":
 
31
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
32
 
33
  # set model
 
 
 
 
 
 
34
  model = CFM(
35
+ transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
36
  mel_spec_kwargs=cfg.model.mel_spec,
37
  vocab_char_map=vocab_char_map,
38
  )
 
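train.py now resolves the backbone class by name from the Hydra config instead of hard-coding it; a minimal sketch of that lookup (the example value is an assumption about what cfg.model.backbone may contain, matching the old F5TTS-to-DiT mapping):

```python
from f5_tts.model import CFM, DiT, UNetT, Trainer  # noqa: F401

backbone_name = "DiT"                 # e.g. cfg.model.backbone from a config such as F5TTS_v1_Base.yaml
model_cls = globals()[backbone_name]  # resolve the class object by its imported name
assert model_cls is DiT
```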
44
  learning_rate=cfg.optim.learning_rate,
45
  num_warmup_updates=cfg.optim.num_warmup_updates,
46
  save_per_updates=cfg.ckpts.save_per_updates,
47
+ keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
48
  checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
49
+ batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
50
  batch_size_type=cfg.datasets.batch_size_type,
51
  max_samples=cfg.datasets.max_samples,
52
  grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
 
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
+ last_per_updates=cfg.ckpts.last_per_updates,
59
+ log_samples=cfg.ckpts.log_samples,
60
  bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
62
  is_local_vocoder=cfg.model.vocoder.is_local,
63
  local_vocoder_path=cfg.model.vocoder.local_path,
64
+ cfg_dict=OmegaConf.to_container(cfg, resolve=True),
65
  )
66
 
67
  train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)