Spaces: Running on Zero
Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- .gitattributes +5 -0
- .github/workflows/publish-pypi.yaml +66 -0
- README_REPO.md +99 -24
- app.py +52 -13
- ckpts/README.md +5 -3
- pyproject.toml +1 -1
- src/f5_tts/api.py +59 -60
- src/f5_tts/configs/E2TTS_Base.yaml +49 -0
- src/f5_tts/configs/E2TTS_Small.yaml +49 -0
- src/f5_tts/configs/F5TTS_Base.yaml +52 -0
- src/f5_tts/configs/F5TTS_Small.yaml +52 -0
- src/f5_tts/configs/F5TTS_v1_Base.yaml +53 -0
- src/f5_tts/eval/eval_infer_batch.py +22 -27
- src/f5_tts/eval/eval_infer_batch.sh +11 -6
- src/f5_tts/eval/eval_librispeech_test_clean.py +21 -27
- src/f5_tts/eval/eval_seedtts_testset.py +21 -27
- src/f5_tts/eval/eval_utmos.py +15 -17
- src/f5_tts/eval/utils_eval.py +11 -6
- src/f5_tts/infer/README.md +38 -80
- src/f5_tts/infer/SHARED.md +19 -9
- src/f5_tts/infer/examples/basic/basic.toml +2 -2
- src/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- src/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- src/f5_tts/infer/examples/multi/country.flac +0 -0
- src/f5_tts/infer/examples/multi/main.flac +0 -0
- src/f5_tts/infer/examples/multi/story.toml +2 -2
- src/f5_tts/infer/examples/multi/town.flac +0 -0
- src/f5_tts/infer/infer_cli.py +26 -31
- src/f5_tts/infer/speech_edit.py +35 -28
- src/f5_tts/infer/utils_infer.py +114 -72
- src/f5_tts/model/backbones/README.md +2 -2
- src/f5_tts/model/backbones/dit.py +63 -8
- src/f5_tts/model/backbones/mmdit.py +52 -9
- src/f5_tts/model/backbones/unett.py +36 -5
- src/f5_tts/model/cfm.py +9 -11
- src/f5_tts/model/dataset.py +21 -10
- src/f5_tts/model/modules.py +115 -42
- src/f5_tts/model/trainer.py +143 -72
- src/f5_tts/model/utils.py +4 -3
- src/f5_tts/scripts/count_max_epoch.py +3 -3
- src/f5_tts/socket_client.py +61 -0
- src/f5_tts/socket_server.py +176 -99
- src/f5_tts/train/README.md +5 -5
- src/f5_tts/train/datasets/prepare_csv_wavs.py +188 -43
- src/f5_tts/train/finetune_cli.py +63 -21
- src/f5_tts/train/finetune_gradio.py +272 -250
- src/f5_tts/train/train.py +12 -11
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/basic/basic_ref_en.wav filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/basic/basic_ref_zh.wav filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/multi/country.flac filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/multi/main.flac filter=lfs diff=lfs merge=lfs -text
+src/f5_tts/infer/examples/multi/town.flac filter=lfs diff=lfs merge=lfs -text
.github/workflows/publish-pypi.yaml
ADDED
@@ -0,0 +1,66 @@
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+# GitHub recommends pinning actions to a commit SHA.
+# To get a newer version, you will need to update the SHA.
+# You can also reference a tag or branch, but the action may change without warning.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  release-build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.x"
+
+      - name: Build release distributions
+        run: |
+          # NOTE: put your own distribution build steps here.
+          python -m pip install build
+          python -m build
+
+      - name: Upload distributions
+        uses: actions/upload-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+  pypi-publish:
+    runs-on: ubuntu-latest
+
+    needs:
+      - release-build
+
+    permissions:
+      # IMPORTANT: this permission is mandatory for trusted publishing
+      id-token: write
+
+    # Dedicated environments with protections for publishing are strongly recommended.
+    environment:
+      name: pypi
+      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
+      # url: https://pypi.org/p/YOURPROJECT
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
README_REPO.md
CHANGED
@@ -6,7 +6,8 @@
 [](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
 [](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
 [](https://x-lance.sjtu.edu.cn/)
+[](https://www.pcl.ac.cn)
+<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
 
 **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
 
@@ -17,40 +18,84 @@
 ### Thanks to all the contributors !
 
 ## News
+- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [Few demo](https://swivid.github.io/F5-TTS_updates).
 - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
 
 ## Installation
 
+### Create a separate environment if needed
+
 ```bash
 # Create a python 3.10 conda env (you could also use virtualenv)
 conda create -n f5-tts python=3.10
 conda activate f5-tts
+```
 
-pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
-```
+### Install PyTorch with matched device
 
+<details>
+<summary>NVIDIA GPU</summary>
 
+> ```bash
+> # Install pytorch with your CUDA version, e.g.
+> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
+> ```
 
+</details>
 
+<details>
+<summary>AMD GPU</summary>
 
+> ```bash
+> # Install pytorch with your ROCm version (Linux only), e.g.
+> pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
+> ```
 
+</details>
+
+<details>
+<summary>Intel GPU</summary>
+
+> ```bash
+> # Install pytorch with your XPU version, e.g.
+> # Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
+> pip install torch torchaudio --index-url https://download.pytorch.org/whl/test/xpu
+>
+> # Intel GPU support is also available through IPEX (Intel® Extension for PyTorch)
+> # IPEX does not require the Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit
+> # See: https://pytorch-extension.intel.com/installation?request=platform
+> ```
+
+</details>
+
+<details>
+<summary>Apple Silicon</summary>
+
+> ```bash
+> # Install the stable pytorch, e.g.
+> pip install torch torchaudio
+> ```
+
+</details>
+
+### Then you can choose one from below:
+
+> ### 1. As a pip package (if just for inference)
+>
+> ```bash
+> pip install f5-tts
+> ```
+>
+> ### 2. Local editable (if also do training, finetuning)
+>
+> ```bash
+> git clone https://github.com/SWivid/F5-TTS.git
+> cd F5-TTS
+> # git submodule update --init --recursive # (optional, if need > bigvgan)
+> pip install -e .
+> ```
+
+### Docker usage also available
 ```bash
 # Build from Dockerfile
 docker build -t f5tts:v1 .
@@ -82,14 +127,40 @@ f5-tts_infer-gradio --port 7860 --host 0.0.0.0
 f5-tts_infer-gradio --share
 ```
 
+<details>
+<summary>NVIDIA device docker compose file example</summary>
+
+```yaml
+services:
+  f5-tts:
+    image: ghcr.io/swivid/f5-tts:main
+    ports:
+      - "7860:7860"
+    environment:
+      GRADIO_SERVER_PORT: 7860
+    entrypoint: ["f5-tts_infer-gradio", "--port", "7860", "--host", "0.0.0.0"]
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+
+volumes:
+  f5-tts:
+    driver: local
+```
+
+</details>
+
 ### 2. CLI Inference
 
 ```bash
 # Run with flags
 # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
-f5-tts_infer-cli \
+f5-tts_infer-cli --model F5TTS_v1_Base \
+--ref_audio "provide_prompt_wav_path_here.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
 
@@ -110,15 +181,19 @@ f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
 
 ## Training
 
+### 1. With Hugging Face Accelerate
 
+Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.
+
+### 2. With Gradio App
 
 ```bash
 # Quick start with Gradio web interface
 f5-tts_finetune-gradio
 ```
 
+Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
+
 
 ## [Evaluation](src/f5_tts/eval)
 
app.py
CHANGED
@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
 )
 
 
-DEFAULT_TTS_MODEL = "F5-
+DEFAULT_TTS_MODEL = "F5-TTS_v1"
 tts_model_choice = DEFAULT_TTS_MODEL
 
 DEFAULT_TTS_MODEL_CFG = [
+    "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
     json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
 ]
 
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
 vocoder = load_vocoder()
 
 
+def load_f5tts():
+    ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
+    F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
     return load_model(DiT, F5TTS_model_cfg, ckpt_path)
 
 
+def load_e2tts():
+    ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
     return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
 
 
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
     if vocab_path.startswith("hf://"):
         vocab_path = str(cached_path(vocab_path))
     if model_cfg is None:
+        model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
     return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
 
 
@@ -130,7 +132,7 @@ def infer(
 
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
+    if model == DEFAULT_TTS_MODEL:
        ema_model = F5TTS_ema_model
    elif model == "E2-TTS":
        global E2TTS_ema_model
@@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
    """
    )
 
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
 
    def load_last_used_custom():
        try:
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
                    custom_model_cfg = gr.Dropdown(
                        choices=[
                            DEFAULT_TTS_MODEL_CFG[2],
+                            json.dumps(
+                                dict(
+                                    dim=1024,
+                                    depth=22,
+                                    heads=16,
+                                    ff_mult=2,
+                                    text_dim=512,
+                                    text_mask_padding=False,
+                                    conv_layers=4,
+                                    pe_attn_head=1,
+                                )
+                            ),
+                            json.dumps(
+                                dict(
+                                    dim=768,
+                                    depth=18,
+                                    heads=12,
+                                    ff_mult=2,
+                                    text_dim=512,
+                                    text_mask_padding=False,
+                                    conv_layers=4,
+                                    pe_attn_head=1,
+                                )
+                            ),
                        ],
                        value=load_last_used_custom()[2],
                        allow_custom_value=True,
@@ -875,10 +900,24 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
    type=str,
    help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
 )
+@click.option(
+    "--inbrowser",
+    "-i",
+    is_flag=True,
+    default=False,
+    help="Automatically launch the interface in the default web browser",
+)
+def main(port, host, share, api, root_path, inbrowser):
    global app
    print("Starting app...")
+    app.queue(api_open=api).launch(
+        server_name=host,
+        server_port=port,
+        share=share,
+        show_api=api,
+        root_path=root_path,
+        inbrowser=inbrowser,
+    )
 
 
 if __name__ == "__main__":
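
Side note on the `--inbrowser` option added above: it just forwards to Gradio's `launch(inbrowser=...)`. A minimal sketch of the same behavior when embedding the interface in your own script (assuming the `f5_tts` package is installed; `app` is the exported Gradio Blocks object, as the infer README shows):

```python
# Sketch: embed the F5-TTS Gradio app and open it in the default browser,
# mirroring what the new --inbrowser CLI flag does in app.py.
from f5_tts.infer.infer_gradio import app

app.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    inbrowser=True,  # same effect as passing --inbrowser on the command line
)
```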
ckpts/README.md
CHANGED
@@ -3,8 +3,10 @@ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
 
 ```
 ckpts/
+    F5TTS_v1_Base/
+        model_1250000.safetensors
     F5TTS_Base/
-        model_1200000.
+        model_1200000.safetensors
+    E2TTS_Base/
+        model_1200000.safetensors
 ```
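
One way to populate this layout locally is a short download script; this is a hedged sketch (not part of the repo) that assumes `huggingface_hub` is installed and uses its standard `hf_hub_download` API with the checkpoint filenames listed above:

```python
# Sketch: fetch the released checkpoints into the ckpts/ layout shown above.
# Assumes huggingface_hub is installed; repo ids and filenames come from this diff.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="SWivid/F5-TTS",
    filename="F5TTS_v1_Base/model_1250000.safetensors",
    local_dir="ckpts",  # lands at ckpts/F5TTS_v1_Base/model_1250000.safetensors
)
hf_hub_download(
    repo_id="SWivid/E2-TTS",
    filename="E2TTS_Base/model_1200000.safetensors",
    local_dir="ckpts",
)
```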
pyproject.toml
CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.
+version = "1.0.1"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
src/f5_tts/api.py
CHANGED
@@ -5,84 +5,84 @@ from importlib.resources import files
 import soundfile as sf
 import tqdm
 from cached_path import cached_path
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
-    hop_length,
-    infer_process,
     load_model,
     load_vocoder,
+    transcribe,
     preprocess_ref_audio_text,
+    infer_process,
     remove_silence_for_generated_wav,
     save_spectrogram,
-    transcribe,
-    target_sample_rate,
 )
-from f5_tts.model import DiT, UNetT
+from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import seed_everything
 
 
 class F5TTS:
     def __init__(
         self,
+        model="F5TTS_v1_Base",
         ckpt_file="",
         vocab_file="",
         ode_method="euler",
         use_ema=True,
-        local_path=None,
+        vocoder_local_path=None,
         device=None,
         hf_cache_dir=None,
     ):
+        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+        model_cls = globals()[model_cfg.model.backbone]
+        model_arc = model_cfg.model.arch
+
+        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+
+        self.ode_method = ode_method
+        self.use_ema = use_ema
+
         if device is not None:
             self.device = device
         else:
             import torch
 
+            self.device = (
+                "cuda"
+                if torch.cuda.is_available()
+                else "xpu"
+                if torch.xpu.is_available()
+                else "mps"
+                if torch.backends.mps.is_available()
+                else "cpu"
+            )
 
         # Load models
-            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
+        self.vocoder = load_vocoder(
+            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
         )
 
-                cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
-            )
-            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-            model_cls = DiT
-        elif model_type == "E2-TTS":
-            if not ckpt_file:
-                ckpt_file = str(
-                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
-                )
-            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-            model_cls = UNetT
+        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
+
+        # override for previous models
+        if model == "F5TTS_Base":
+            if self.mel_spec_type == "vocos":
+                ckpt_step = 1200000
+            elif self.mel_spec_type == "bigvgan":
+                model = "F5TTS_Base_bigvgan"
+                ckpt_type = "pt"
+        elif model == "E2TTS_Base":
+            repo_name = "E2-TTS"
+            ckpt_step = 1200000
         else:
+            raise ValueError(f"Unknown model type: {model}")
 
+        if not ckpt_file:
+            ckpt_file = str(
+                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
+            )
         self.ema_model = load_model(
-            model_cls,
+            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
         )
 
     def transcribe(self, ref_audio, language=None):
@@ -94,8 +94,8 @@ class F5TTS:
         if remove_silence:
             remove_silence_for_generated_wav(file_wave)
 
+    def export_spectrogram(self, spec, file_spec):
+        save_spectrogram(spec, file_spec)
 
     def infer(
         self,
@@ -113,17 +113,16 @@ class F5TTS:
         fix_duration=None,
         remove_silence=False,
         file_wave=None,
+        file_spec=None,
+        seed=None,
     ):
-            seed = random.randint(0, sys.maxsize)
-        seed_everything(seed)
-        self.seed = seed
+        if seed is None:
+            self.seed = random.randint(0, sys.maxsize)
+        seed_everything(self.seed)
 
         ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
 
+        wav, sr, spec = infer_process(
             ref_file,
             ref_text,
             gen_text,
@@ -145,22 +144,22 @@ class F5TTS:
         if file_wave is not None:
             self.export_wav(wav, file_wave, remove_silence)
 
+        if file_spec is not None:
+            self.export_spectrogram(spec, file_spec)
 
+        return wav, sr, spec
 
 
 if __name__ == "__main__":
     f5tts = F5TTS()
 
+    wav, sr, spec = f5tts.infer(
         ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
         ref_text="some call me nature, others call me mother nature.",
         gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
         file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
+        file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+        seed=None,
    )
 
    print("seed :", f5tts.seed)
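
For context, a typical call to the refactored API now selects a model by its config name and lets the class resolve the checkpoint, vocoder, and architecture from the YAML. This is a minimal sketch based on the `__main__` block above; the input/output paths are illustrative only:

```python
# Sketch of the refactored API: the model is chosen by config name
# (configs/F5TTS_v1_Base.yaml); infer() returns the waveform, sample rate, spectrogram.
from f5_tts.api import F5TTS

f5tts = F5TTS(model="F5TTS_v1_Base")  # or "F5TTS_Base", "E2TTS_Base"
wav, sr, spec = f5tts.infer(
    ref_file="infer/examples/basic/basic_ref_en.wav",  # illustrative paths
    ref_text="some call me nature, others call me mother nature.",
    gen_text="A quick test of the refactored F5-TTS API.",
    file_wave="out.wav",
    seed=None,  # None -> a random seed is drawn and stored on f5tts.seed
)
print("seed:", f5tts.seed)
```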
src/f5_tts/configs/E2TTS_Base.yaml
ADDED
@@ -0,0 +1,49 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: E2TTS_Base
+  tokenizer: pinyin
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: UNetT
+  arch:
+    dim: 1024
+    depth: 24
+    heads: 16
+    ff_mult: 4
+    text_mask_padding: False
+    pe_attn_head: 1
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small.yaml
ADDED
@@ -0,0 +1,49 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: UNetT
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+    text_mask_padding: False
+    pe_attn_head: 1
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base.yaml
ADDED
@@ -0,0 +1,52 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    text_mask_padding: False
+    conv_layers: 4
+    pe_attn_head: 1
+    checkpoint_activations: False # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small.yaml
ADDED
@@ -0,0 +1,52 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    text_mask_padding: False
+    conv_layers: 4
+    pe_attn_head: 1
+    checkpoint_activations: False # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_v1_Base.yaml
ADDED
@@ -0,0 +1,53 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # frame | sample
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 11
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup updates
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_v1_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    text_mask_padding: True
+    qk_norm: null # null | rms_norm
+    conv_layers: 4
+    pe_attn_head: null
+    checkpoint_activations: False # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # vocos | bigvgan
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: null # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+  save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+  last_per_updates: 5000 # save last checkpoint per updates
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
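
These YAML files are consumed with OmegaConf; both `api.py` and `eval_infer_batch.py` in this commit resolve the backbone class from `model.backbone` and pass the `model.arch` block straight to its constructor. A condensed sketch of that pattern (names taken from the diffs above; the printout is only illustrative):

```python
# Sketch: how the new config files are consumed elsewhere in this commit.
# The backbone name in the YAML ("DiT" or "UNetT") is looked up by name and
# its `arch` block supplies the constructor arguments.
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.model import DiT, UNetT  # backbone classes referenced by name in the YAML

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml")))
model_cls = {"DiT": DiT, "UNetT": UNetT}[model_cfg.model.backbone]
model_arc = model_cfg.model.arch      # dim, depth, heads, ... fed to the backbone
mel_cfg = model_cfg.model.mel_spec    # sample rate, hop length, mel type, ...

print(model_cls.__name__, OmegaConf.to_container(model_arc), mel_cfg.mel_spec_type)
```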
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
 import torch
 import torchaudio
 from accelerate import Accelerator
+from omegaconf import OmegaConf
 from tqdm import tqdm
 
 from f5_tts.eval.utils_eval import (
@@ -18,36 +19,26 @@ from f5_tts.eval.utils_eval import (
     get_seedtts_testset_metainfo,
 )
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
-from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import get_tokenizer
 
 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"
 
 
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
+use_ema = True
 target_rms = 0.1
 
+
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
 def main():
-    # ---------------------- infer setting ---------------------- #
-
     parser = argparse.ArgumentParser(description="batch inference")
 
     parser.add_argument("-s", "--seed", default=None, type=int)
-    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
     parser.add_argument("-n", "--expname", required=True)
+    parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
-    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
-    parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
 
     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +49,8 @@ def main():
     args = parser.parse_args()
 
     seed = args.seed
-    dataset_name = args.dataset
     exp_name = args.expname
     ckpt_step = args.ckptstep
-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-    mel_spec_type = args.mel_spec_type
-    tokenizer = args.tokenizer
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
@@ -77,13 +64,19 @@ def main():
     use_truth_duration = False
     no_ref_audio = False
 
+    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
+    model_cls = globals()[model_cfg.model.backbone]
+    model_arc = model_cfg.model.arch
+
+    dataset_name = model_cfg.datasets.name
+    tokenizer = model_cfg.model.tokenizer
+
+    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+    target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+    n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
+    hop_length = model_cfg.model.mel_spec.hop_length
+    win_length = model_cfg.model.mel_spec.win_length
+    n_fft = model_cfg.model.mel_spec.n_fft
 
     if testset == "ls_pc_test_clean":
         metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +104,6 @@ def main():
 
     # -------------------------------------------------#
 
-    use_ema = True
-
     prompts_all = get_inference_prompt(
         metainfo,
         speed=speed,
@@ -139,7 +130,7 @@ def main():
 
     # Model
     model = CFM(
+        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         mel_spec_kwargs=dict(
             n_fft=n_fft,
             hop_length=hop_length,
@@ -154,6 +145,10 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
+    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
+    if not os.path.exists(ckpt_path):
+        print("Loading from self-organized training checkpoints rather than released pretrained.")
+        ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
src/f5_tts/eval/eval_infer_batch.sh
CHANGED
@@ -1,13 +1,18 @@
 #!/bin/bash
 
 # e.g. F5-TTS, 16 NFE
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
 
 # e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
+python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
 
 # etc.
src/f5_tts/eval/eval_librispeech_test_clean.py
CHANGED
@@ -53,43 +53,37 @@ def main():
     asr_ckpt_dir = ""  # auto download to cache dir
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
 
+    # --------------------------------------------------------------------------
+
+    full_results = []
+    metrics = []
 
-        wers = []
+    if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
+                full_results.extend(r)
-        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
-        with open(wer_result_path, "w") as f:
-            for line in wer_results:
-                wers.append(line["wer"])
-                json_line = json.dumps(line, ensure_ascii=False)
-                f.write(json_line + "\n")
-
-        wer = round(np.mean(wers) * 100, 3)
-        print(f"\nTotal {len(wers)} samples")
-        print(f"WER : {wer}%")
-        print(f"Results have been saved to {wer_result_path}")
-
-    # --------------------------- SIM ---------------------------
-
-    if eval_task == "sim":
-        sims = []
+    elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
+                full_results.extend(r)
+    else:
+        raise ValueError(f"Unknown metric type: {eval_task}")
+
+    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+    with open(result_path, "w") as f:
+        for line in full_results:
+            metrics.append(line[eval_task])
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        metric = round(np.mean(metrics), 5)
+        f.write(f"\n{eval_task.upper()}: {metric}\n")
+
+    print(f"\nTotal {len(metrics)} samples")
+    print(f"{eval_task.upper()}: {metric}")
+    print(f"{eval_task.upper()} results saved to {result_path}")
 
 
 if __name__ == "__main__":
src/f5_tts/eval/eval_seedtts_testset.py
CHANGED
@@ -52,43 +52,37 @@ def main():
     asr_ckpt_dir = ""  # auto download to cache dir
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
 
+    # --------------------------------------------------------------------------
+
+    full_results = []
+    metrics = []
 
-        wers = []
+    if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
+                full_results.extend(r)
-        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
-        with open(wer_result_path, "w") as f:
-            for line in wer_results:
-                wers.append(line["wer"])
-                json_line = json.dumps(line, ensure_ascii=False)
-                f.write(json_line + "\n")
-
-        wer = round(np.mean(wers) * 100, 3)
-        print(f"\nTotal {len(wers)} samples")
-        print(f"WER : {wer}%")
-        print(f"Results have been saved to {wer_result_path}")
-
-    # --------------------------- SIM ---------------------------
-
-    if eval_task == "sim":
-        sims = []
+    elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
+                full_results.extend(r)
+    else:
+        raise ValueError(f"Unknown metric type: {eval_task}")
+
+    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+    with open(result_path, "w") as f:
+        for line in full_results:
+            metrics.append(line[eval_task])
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        metric = round(np.mean(metrics), 5)
+        f.write(f"\n{eval_task.upper()}: {metric}\n")
+
+    print(f"\nTotal {len(metrics)} samples")
+    print(f"{eval_task.upper()}: {metric}")
+    print(f"{eval_task.upper()} results saved to {result_path}")
 
 
 if __name__ == "__main__":
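
Both refactored eval scripts now write one JSON object per utterance to `_{wer|sim}_results.jsonl`, followed by a plain-text summary line. A small sketch for reading such a file back (a hypothetical helper, not part of the repo; it skips the non-JSON summary line):

```python
# Sketch: re-aggregate a metric from the per-utterance JSONL written by the
# refactored eval scripts. The trailing "WER: ..." / "SIM: ..." summary line
# is not JSON, so unparsable lines are skipped.
import json


def load_metric(result_path: str, key: str) -> float:
    values = []
    with open(result_path, encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                values.append(json.loads(raw)[key])
            except (json.JSONDecodeError, KeyError):
                continue  # summary line or unrelated content
    return sum(values) / len(values)


print(load_metric("results/some_run/seedtts_test_zh/_wer_results.jsonl", "wer"))  # illustrative path
```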
src/f5_tts/eval/eval_utmos.py
CHANGED
@@ -13,31 +13,29 @@ def main():
     parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
 
     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
     predictor = predictor.to(device)
 
     audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
-    utmos_results = {}
     utmos_score = 0
 
-        wav_name = audio_path.stem
-        wav, sr = librosa.load(audio_path, sr=None, mono=True)
-        wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
-        score = predictor(wav_tensor, sr)
-        utmos_results[str(wav_name)] = score.item()
-        utmos_score += score.item()
-
-    avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
-    print(f"UTMOS: {avg_score}")
-
-    utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+    utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
     with open(utmos_result_path, "w", encoding="utf-8") as f:
+        for audio_path in tqdm(audio_paths, desc="Processing"):
+            wav, sr = librosa.load(audio_path, sr=None, mono=True)
+            wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+            score = predictor(wav_tensor, sr)
+            line = {}
+            line["wav"], line["utmos"] = str(audio_path.stem), score.item()
+            utmos_score += score.item()
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+        f.write(f"\nUTMOS: {avg_score:.4f}\n")
 
+    print(f"UTMOS: {avg_score:.4f}")
+    print(f"UTMOS results saved to {utmos_result_path}")
 
 
 if __name__ == "__main__":
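For a quick spot check outside the batch script, the same torch.hub predictor used above can score a single file. A minimal sketch (the wav path is illustrative):

```python
import librosa
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Same UTMOS22-strong predictor the script loads via torch.hub
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)

wav, sr = librosa.load("tests/infer_cli_basic.wav", sr=None, mono=True)  # illustrative path, native sample rate
with torch.no_grad():
    score = predictor(torch.from_numpy(wav).to(device).unsqueeze(0), sr)  # (1, T) waveform plus its sample rate
print(f"UTMOS: {score.item():.4f}")
```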
src/f5_tts/eval/utils_eval.py
CHANGED
@@ -389,10 +389,10 @@ def run_sim(args):
     model = model.cuda(device)
     model.eval()
 
-    for 
-        wav1, sr1 = torchaudio.load(
-        wav2, sr2 = torchaudio.load(
+    sim_results = []
+    for gen_wav, prompt_wav, truth in tqdm(test_set):
+        wav1, sr1 = torchaudio.load(gen_wav)
+        wav2, sr2 = torchaudio.load(prompt_wav)
 
         resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
         resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@ def run_sim(args):
 
         sim = F.cosine_similarity(emb1, emb2)[0].item()
         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
+        sim_results.append(
+            {
+                "wav": Path(gen_wav).stem,
+                "sim": sim,
+            }
+        )
 
-    return 
+    return sim_results
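The scoring pattern in `run_sim` above is plain embedding cosine similarity after resampling both clips to 16 kHz. A self-contained sketch of that pattern; `speaker_encoder` is a stand-in for whatever verification model you load (e.g. the WavLM-large checkpoint referenced above), and its call signature is an assumption:

```python
import torch
import torch.nn.functional as F
import torchaudio


def speaker_similarity(wav_path_a, wav_path_b, speaker_encoder, device="cuda"):
    """Cosine similarity between two utterances' speaker embeddings, in (-1.0, 1.0)."""
    wav_a, sr_a = torchaudio.load(wav_path_a)
    wav_b, sr_b = torchaudio.load(wav_path_b)
    # the verification model expects 16 kHz input
    wav_a = torchaudio.transforms.Resample(orig_freq=sr_a, new_freq=16000)(wav_a).to(device)
    wav_b = torchaudio.transforms.Resample(orig_freq=sr_b, new_freq=16000)(wav_b).to(device)
    with torch.no_grad():
        emb_a = speaker_encoder(wav_a)  # assumed: model maps (1, T) waveform -> (1, D) embedding
        emb_b = speaker_encoder(wav_b)
    return F.cosine_similarity(emb_a, emb_b)[0].item()
```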
src/f5_tts/infer/README.md
CHANGED
@@ -23,12 +23,24 @@ Currently supported features:
 - Basic TTS with Chunk Inference
 - Multi-Style / Multi-Speaker Generation
 - Voice Chat powered by Qwen2.5-3B-Instruct
+- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
 
 The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
 
 The script will load model checkpoints from Huggingface. You can also manually download files and update the path to `load_model()` in `infer_gradio.py`. Currently only load TTS models first, will load ASR model to do transcription if `ref_text` not provided, will load LLM model if use Voice Chat.
 
-
+More flags options:
+
+```bash
+# Automatically launch the interface in the default web browser
+f5-tts_infer-gradio --inbrowser
+
+# Set the root path of the application, if it's not served from the root ("/") of the domain
+# For example, if the application is served at "https://example.com/myapp"
+f5-tts_infer-gradio --root_path "/myapp"
+```
+
+Could also be used as a component for larger application:
 ```python
 import gradio as gr
 from f5_tts.infer.infer_gradio import app
@@ -56,14 +68,16 @@ Basically you can inference with flags:
 ```bash
 # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
 f5-tts_infer-cli \
---model 
+--model F5TTS_v1_Base \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
 
-# 
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local
-
+# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
+f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
+
+# Use custom path checkpoint, e.g.
+f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
 
 # More instructions
 f5-tts_infer-cli --help
@@ -78,8 +92,8 @@ f5-tts_infer-cli -c custom.toml
 For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
 
 ```toml
-# 
-model = "
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/basic/basic_ref_en.wav"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = "Some call me nature, others call me mother nature."
@@ -93,8 +107,8 @@ output_dir = "tests"
 You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
 
 ```toml
-# 
-model = "
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/multi/main.flac"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = ""
@@ -114,83 +128,27 @@ ref_text = ""
 ```
 You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
 
-## 
+## Socket Real-time Service
 
-
+Real-time voice output with chunk stream:
 
 ```bash
-
-
+# Start socket server
 python src/f5_tts/socket_server.py
+
+# If PyAudio not installed
+sudo apt-get install portaudio19-dev
+pip install pyaudio
+
+# Communicate with socket client
+python src/f5_tts/socket_client.py
 ```
 
-<summary>Then create client to communicate</summary>
-
-``` python
-import socket
-import numpy as np
-import asyncio
-import pyaudio
-
-async def listen_to_voice(text, server_ip='localhost', server_port=9999):
-    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    client_socket.connect((server_ip, server_port))
-
-    async def play_audio_stream():
-        buffer = b''
-        p = pyaudio.PyAudio()
-        stream = p.open(format=pyaudio.paFloat32,
-                        channels=1,
-                        rate=24000, # Ensure this matches the server's sampling rate
-                        output=True,
-                        frames_per_buffer=2048)
-
-        try:
-            while True:
-                chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024)
-                if not chunk: # End of stream
-                    break
-                if b"END_OF_AUDIO" in chunk:
-                    buffer += chunk.replace(b"END_OF_AUDIO", b"")
-                    if buffer:
-                        audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy
-                        stream.write(audio_array.tobytes())
-                    break
-                buffer += chunk
-                if len(buffer) >= 4096:
-                    audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy
-                    stream.write(audio_array.tobytes())
-                    buffer = buffer[4096:]
-        finally:
-            stream.stop_stream()
-            stream.close()
-            p.terminate()
-
-    try:
-        # Send only the text to the server
-        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8'))
-        await play_audio_stream()
-        print("Audio playback finished.")
-
-    except Exception as e:
-        print(f"Error in listen_to_voice: {e}")
-
-    finally:
-        client_socket.close()
-
-# Example usage: Replace this with your actual server IP and port
-async def main():
-    await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998)
-
-# Run the main async function
-asyncio.run(main())
-```
-
+## Speech Editing
 
+To test speech editing capabilities, use the following command:
+
+```bash
+python src/f5_tts/infer/speech_edit.py
+```
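For reference, the wire protocol implied by the removed example client above (send UTF-8 text, read back raw float32 PCM at 24 kHz until an `END_OF_AUDIO` marker) can also be exercised without PyAudio, for instance by saving the stream to a file. A minimal sketch only; host, port, output path and the soundfile writer are assumptions, and `src/f5_tts/socket_client.py` remains the supported client:

```python
import socket

import numpy as np
import soundfile as sf  # assumption: any float32-capable wav writer works here

HOST, PORT, SAMPLE_RATE = "localhost", 9998, 24000  # illustrative values


def synthesize_to_file(text, out_path="tests/socket_out.wav"):
    buffer = b""
    with socket.create_connection((HOST, PORT)) as sock:
        sock.sendall(text.encode("utf-8"))  # server expects plain UTF-8 text
        while True:
            chunk = sock.recv(4096)
            if not chunk:
                break  # connection closed by server
            if b"END_OF_AUDIO" in chunk:
                buffer += chunk.replace(b"END_OF_AUDIO", b"")
                break  # sentinel marks end of the stream
            buffer += chunk
    audio = np.frombuffer(buffer, dtype=np.float32)
    sf.write(out_path, audio, SAMPLE_RATE)
    return out_path


synthesize_to_file("Some text you want TTS model generate for you.")
```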
src/f5_tts/infer/SHARED.md
CHANGED
@@ -16,7 +16,7 @@
 <!-- omit in toc -->
 ### Supported Languages
 - [Multilingual](#multilingual)
-    - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
+    - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
 - [English](#english)
 - [Finnish](#finnish)
     - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,17 @@
 
 ## Multilingual
 
-#### F5-TTS Base @ zh & en @ F5-TTS
+#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
+
+```bash
+Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
+Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+```
+
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -45,7 +55,7 @@
 ```bash
 Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
 Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -64,7 +74,7 @@
 ```bash
 Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
 Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 
@@ -78,7 +88,7 @@
 ```bash
 Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
 Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -96,7 +106,7 @@
 ```bash
 Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
 Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
-Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 - Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -113,7 +123,7 @@
 ```bash
 Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
 Vocab: hf://alien79/F5-TTS-italian/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -131,7 +141,7 @@
 ```bash
 Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
 Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 
@@ -148,7 +158,7 @@
 ```bash
 Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
 Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
 - Any improvements are welcome
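The Model / Vocab / Config triples above map directly onto the repo's loading utilities. A rough sketch of wiring one of them up manually, using the `load_vocoder` / `load_model` calls visible in `infer_cli.py`; exact argument names beyond those are assumptions:

```python
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT

# Values copied from a SHARED.md entry (F5-TTS Base shown here)
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
vocab_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"))
model_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512,
    text_mask_padding=False, conv_layers=4, pe_attn_head=1,
)

vocoder = load_vocoder(vocoder_name="vocos")
ema_model = load_model(DiT, model_cfg, ckpt_file, mel_spec_type="vocos", vocab_file=vocab_file)
```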
src/f5_tts/infer/examples/basic/basic.toml
CHANGED
@@ -1,5 +1,5 @@
-# 
-model = "
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/basic/basic_ref_en.wav"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = "Some call me nature, others call me mother nature."
src/f5_tts/infer/examples/basic/basic_ref_en.wav
CHANGED
Binary files a/src/f5_tts/infer/examples/basic/basic_ref_en.wav and b/src/f5_tts/infer/examples/basic/basic_ref_en.wav differ
|
|
src/f5_tts/infer/examples/basic/basic_ref_zh.wav
CHANGED
Binary files a/src/f5_tts/infer/examples/basic/basic_ref_zh.wav and b/src/f5_tts/infer/examples/basic/basic_ref_zh.wav differ
|
|
src/f5_tts/infer/examples/multi/country.flac
CHANGED
Binary files a/src/f5_tts/infer/examples/multi/country.flac and b/src/f5_tts/infer/examples/multi/country.flac differ
|
|
src/f5_tts/infer/examples/multi/main.flac
CHANGED
Binary files a/src/f5_tts/infer/examples/multi/main.flac and b/src/f5_tts/infer/examples/multi/main.flac differ
|
|
src/f5_tts/infer/examples/multi/story.toml
CHANGED
@@ -1,5 +1,5 @@
-# 
-model = "
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/multi/main.flac"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = ""
src/f5_tts/infer/examples/multi/town.flac
CHANGED
Binary files a/src/f5_tts/infer/examples/multi/town.flac and b/src/f5_tts/infer/examples/multi/town.flac differ
|
|
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
     preprocess_ref_audio_text,
     remove_silence_for_generated_wav,
 )
-from f5_tts.model import DiT, UNetT
+from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 
 
 parser = argparse.ArgumentParser(
@@ -50,7 +50,7 @@ parser.add_argument(
     "-m",
     "--model",
     type=str,
-    help="The model name: 
+    help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
 )
 parser.add_argument(
     "-mc",
@@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb"))
 
 # command-line interface parameters
 
-model = args.model or config.get("model", "
-model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
+model = args.model or config.get("model", "F5TTS_v1_Base")
 ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
 vocab_file = args.vocab_file or config.get("vocab_file", "")
 
@@ -245,36 +244,32 @@
 
 # load TTS model
 
-        ckpt_step = 1250000
-        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
-
-elif model == "E2-TTS":
-    assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
-    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-    if not ckpt_file:  # path not specified, download from repo
-        repo_name = "E2-TTS"
-        exp_name = "E2TTS_Base"
-        ckpt_step = 1200000
-
+model_cfg = OmegaConf.load(
+    args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+).model
+model_cls = globals()[model_cfg.backbone]
+
+repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
+
+if model != "F5TTS_Base":
+    assert vocoder_name == model_cfg.mel_spec.mel_spec_type
+
+# override for previous models
+if model == "F5TTS_Base":
+    if vocoder_name == "vocos":
+        ckpt_step = 1200000
+    elif vocoder_name == "bigvgan":
+        model = "F5TTS_Base_bigvgan"
+        ckpt_type = "pt"
+elif model == "E2TTS_Base":
+    repo_name = "E2-TTS"
+    ckpt_step = 1200000
+
+if not ckpt_file:
+    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
 
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
 
 
 # inference process
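The new `-mc/--model_cfg` path resolves everything from the YAML files added under `src/f5_tts/configs/`. A small sketch of inspecting one of those configs the same way `infer_cli.py` does; the printed keys follow the accesses made above, and any key not shown there should be treated as an assumption:

```python
from importlib.resources import files

from omegaconf import OmegaConf

model = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))).model

print(model_cfg.backbone)                # e.g. "DiT", used to pick the model class
print(model_cfg.arch)                    # kwargs handed to the backbone constructor
print(model_cfg.mel_spec.mel_spec_type)  # must match the chosen vocoder ("vocos" or "bigvgan")
```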
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -1,56 +1,63 @@
 import os
 
-os.environ["
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
+
+from importlib.resources import files
 
 import torch
 import torch.nn.functional as F
 import torchaudio
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
-from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
-device = 
-n_fft = 1024
-mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
-target_rms = 0.1
-tokenizer = "pinyin"
-dataset_name = "Emilia_ZH_EN"
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "xpu"
+    if torch.xpu.is_available()
+    else "mps"
+    if torch.backends.mps.is_available()
+    else "cpu"
+)
 
 
 # ---------------------- infer setting ---------------------- #
 
 seed = None  # int | None
 
-exp_name = "
-ckpt_step = 
+exp_name = "F5TTS_v1_Base"  # F5TTS_v1_Base | E2TTS_Base
+ckpt_step = 1250000
 
 nfe_step = 32  # 16, 32
 cfg_strength = 2.0
 ode_method = "euler"  # euler | midpoint
 sway_sampling_coef = -1.0
 speed = 1.0
+target_rms = 0.1
 
-model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+
+model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
+model_cls = globals()[model_cfg.model.backbone]
+model_arc = model_cfg.model.arch
+
+dataset_name = model_cfg.datasets.name
+tokenizer = model_cfg.model.tokenizer
+
+mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
+hop_length = model_cfg.model.mel_spec.hop_length
+win_length = model_cfg.model.mel_spec.win_length
+n_fft = model_cfg.model.mel_spec.n_fft
+
+
+ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
 # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
 # [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -59,7 +66,7 @@ output_dir = "tests"
 # [--language "zho" for Chinese, "eng" for English]
 # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
 
-audio_to_edit = "
+audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
 origin_text = "Some call me nature, others call me mother nature."
 target_text = "Some call me optimist, others call me realist."
 parts_to_edit = [
@@ -98,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
 
 # Model
 model = CFM(
-    transformer=model_cls(**
+    transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
     mel_spec_kwargs=dict(
         n_fft=n_fft,
         hop_length=hop_length,
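The cuda → xpu → mps → cpu fallback chain above is repeated in a few entry points (see also `utils_infer.py` below). A tiny helper with the same priority order, shown only as a sketch if you want it factored out in your own scripts:

```python
import torch


def resolve_device() -> str:
    """Pick the best available backend in the same priority order used by the scripts."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = resolve_device()
```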
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -2,8 +2,9 @@
|
|
2 |
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
3 |
import os
|
4 |
import sys
|
|
|
5 |
|
6 |
-
os.environ["
|
7 |
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
8 |
|
9 |
import hashlib
|
@@ -33,7 +34,15 @@ from f5_tts.model.utils import (
|
|
33 |
|
34 |
_ref_audio_cache = {}
|
35 |
|
36 |
-
device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# -----------------------------------------
|
39 |
|
@@ -292,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
|
|
292 |
)
|
293 |
non_silent_wave = AudioSegment.silent(duration=0)
|
294 |
for non_silent_seg in non_silent_segs:
|
295 |
-
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) >
|
296 |
show_info("Audio is over 15s, clipping short. (1)")
|
297 |
break
|
298 |
non_silent_wave += non_silent_seg
|
299 |
|
300 |
# 2. try to find short silence for clipping if 1. failed
|
301 |
-
if len(non_silent_wave) >
|
302 |
non_silent_segs = silence.split_on_silence(
|
303 |
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
304 |
)
|
305 |
non_silent_wave = AudioSegment.silent(duration=0)
|
306 |
for non_silent_seg in non_silent_segs:
|
307 |
-
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) >
|
308 |
show_info("Audio is over 15s, clipping short. (2)")
|
309 |
break
|
310 |
non_silent_wave += non_silent_seg
|
@@ -312,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
|
|
312 |
aseg = non_silent_wave
|
313 |
|
314 |
# 3. if no proper silence found for clipping
|
315 |
-
if len(aseg) >
|
316 |
-
aseg = aseg[:
|
317 |
show_info("Audio is over 15s, clipping short. (3)")
|
318 |
|
319 |
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
@@ -374,29 +383,31 @@ def infer_process(
|
|
374 |
):
|
375 |
# Split the input text into batches
|
376 |
audio, sr = torchaudio.load(ref_audio)
|
377 |
-
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (
|
378 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
379 |
for i, gen_text in enumerate(gen_text_batches):
|
380 |
print(f"gen_text {i}", gen_text)
|
381 |
print("\n")
|
382 |
|
383 |
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
384 |
-
return
|
385 |
-
(
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
|
|
|
|
400 |
)
|
401 |
|
402 |
|
@@ -419,6 +430,8 @@ def infer_batch_process(
|
|
419 |
speed=1,
|
420 |
fix_duration=None,
|
421 |
device=None,
|
|
|
|
|
422 |
):
|
423 |
audio, sr = ref_audio
|
424 |
if audio.shape[0] > 1:
|
@@ -437,7 +450,12 @@ def infer_batch_process(
|
|
437 |
|
438 |
if len(ref_text[-1].encode("utf-8")) == 1:
|
439 |
ref_text = ref_text + " "
|
440 |
-
|
|
|
|
|
|
|
|
|
|
|
441 |
# Prepare the text
|
442 |
text_list = [ref_text + gen_text]
|
443 |
final_text_list = convert_char_to_pinyin(text_list)
|
@@ -449,7 +467,7 @@ def infer_batch_process(
|
|
449 |
# Calculate duration
|
450 |
ref_text_len = len(ref_text.encode("utf-8"))
|
451 |
gen_text_len = len(gen_text.encode("utf-8"))
|
452 |
-
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len /
|
453 |
|
454 |
# inference
|
455 |
with torch.inference_mode():
|
@@ -461,64 +479,88 @@ def infer_batch_process(
|
|
461 |
cfg_strength=cfg_strength,
|
462 |
sway_sampling_coef=sway_sampling_coef,
|
463 |
)
|
|
|
464 |
|
465 |
-
generated = generated.to(torch.float32)
|
466 |
generated = generated[:, ref_audio_len:, :]
|
467 |
-
|
468 |
if mel_spec_type == "vocos":
|
469 |
-
generated_wave = vocoder.decode(
|
470 |
elif mel_spec_type == "bigvgan":
|
471 |
-
generated_wave = vocoder(
|
472 |
if rms < target_rms:
|
473 |
generated_wave = generated_wave * rms / target_rms
|
474 |
|
475 |
# wav -> numpy
|
476 |
generated_wave = generated_wave.squeeze().cpu().numpy()
|
477 |
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
|
|
|
|
|
|
|
|
|
|
485 |
else:
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
517 |
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
return final_wave, target_sample_rate, combined_spectrogram
|
522 |
|
523 |
|
524 |
# remove silence from generated wav
|
|
|
2 |
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
3 |
import os
|
4 |
import sys
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
|
7 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
8 |
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
9 |
|
10 |
import hashlib
|
|
|
34 |
|
35 |
_ref_audio_cache = {}
|
36 |
|
37 |
+
device = (
|
38 |
+
"cuda"
|
39 |
+
if torch.cuda.is_available()
|
40 |
+
else "xpu"
|
41 |
+
if torch.xpu.is_available()
|
42 |
+
else "mps"
|
43 |
+
if torch.backends.mps.is_available()
|
44 |
+
else "cpu"
|
45 |
+
)
|
46 |
|
47 |
# -----------------------------------------
|
48 |
|
|
|
301 |
)
|
302 |
non_silent_wave = AudioSegment.silent(duration=0)
|
303 |
for non_silent_seg in non_silent_segs:
|
304 |
+
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
305 |
show_info("Audio is over 15s, clipping short. (1)")
|
306 |
break
|
307 |
non_silent_wave += non_silent_seg
|
308 |
|
309 |
# 2. try to find short silence for clipping if 1. failed
|
310 |
+
if len(non_silent_wave) > 12000:
|
311 |
non_silent_segs = silence.split_on_silence(
|
312 |
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
313 |
)
|
314 |
non_silent_wave = AudioSegment.silent(duration=0)
|
315 |
for non_silent_seg in non_silent_segs:
|
316 |
+
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
317 |
show_info("Audio is over 15s, clipping short. (2)")
|
318 |
break
|
319 |
non_silent_wave += non_silent_seg
|
|
|
321 |
aseg = non_silent_wave
|
322 |
|
323 |
# 3. if no proper silence found for clipping
|
324 |
+
if len(aseg) > 12000:
|
325 |
+
aseg = aseg[:12000]
|
326 |
show_info("Audio is over 15s, clipping short. (3)")
|
327 |
|
328 |
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
|
|
383 |
):
|
384 |
# Split the input text into batches
|
385 |
audio, sr = torchaudio.load(ref_audio)
|
386 |
+
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
|
387 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
388 |
for i, gen_text in enumerate(gen_text_batches):
|
389 |
print(f"gen_text {i}", gen_text)
|
390 |
print("\n")
|
391 |
|
392 |
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
393 |
+
return next(
|
394 |
+
infer_batch_process(
|
395 |
+
(audio, sr),
|
396 |
+
ref_text,
|
397 |
+
gen_text_batches,
|
398 |
+
model_obj,
|
399 |
+
vocoder,
|
400 |
+
mel_spec_type=mel_spec_type,
|
401 |
+
progress=progress,
|
402 |
+
target_rms=target_rms,
|
403 |
+
cross_fade_duration=cross_fade_duration,
|
404 |
+
nfe_step=nfe_step,
|
405 |
+
cfg_strength=cfg_strength,
|
406 |
+
sway_sampling_coef=sway_sampling_coef,
|
407 |
+
speed=speed,
|
408 |
+
fix_duration=fix_duration,
|
409 |
+
device=device,
|
410 |
+
)
|
411 |
)
|
412 |
|
413 |
|
|
|
430 |
speed=1,
|
431 |
fix_duration=None,
|
432 |
device=None,
|
433 |
+
streaming=False,
|
434 |
+
chunk_size=2048,
|
435 |
):
|
436 |
audio, sr = ref_audio
|
437 |
if audio.shape[0] > 1:
|
|
|
450 |
|
451 |
if len(ref_text[-1].encode("utf-8")) == 1:
|
452 |
ref_text = ref_text + " "
|
453 |
+
|
454 |
+
def process_batch(gen_text):
|
455 |
+
local_speed = speed
|
456 |
+
if len(gen_text.encode("utf-8")) < 10:
|
457 |
+
local_speed = 0.3
|
458 |
+
|
459 |
# Prepare the text
|
460 |
text_list = [ref_text + gen_text]
|
461 |
final_text_list = convert_char_to_pinyin(text_list)
|
|
|
467 |
# Calculate duration
|
468 |
ref_text_len = len(ref_text.encode("utf-8"))
|
469 |
gen_text_len = len(gen_text.encode("utf-8"))
|
470 |
+
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
|
471 |
|
472 |
# inference
|
473 |
with torch.inference_mode():
|
|
|
479 |
cfg_strength=cfg_strength,
|
480 |
sway_sampling_coef=sway_sampling_coef,
|
481 |
)
|
482 |
+
del _
|
483 |
|
484 |
+
generated = generated.to(torch.float32) # generated mel spectrogram
|
485 |
generated = generated[:, ref_audio_len:, :]
|
486 |
+
generated = generated.permute(0, 2, 1)
|
487 |
if mel_spec_type == "vocos":
|
488 |
+
generated_wave = vocoder.decode(generated)
|
489 |
elif mel_spec_type == "bigvgan":
|
490 |
+
generated_wave = vocoder(generated)
|
491 |
if rms < target_rms:
|
492 |
generated_wave = generated_wave * rms / target_rms
|
493 |
|
494 |
# wav -> numpy
|
495 |
generated_wave = generated_wave.squeeze().cpu().numpy()
|
496 |
|
497 |
+
if streaming:
|
498 |
+
for j in range(0, len(generated_wave), chunk_size):
|
499 |
+
yield generated_wave[j : j + chunk_size], target_sample_rate
|
500 |
+
else:
|
501 |
+
generated_cpu = generated[0].cpu().numpy()
|
502 |
+
del generated
|
503 |
+
yield generated_wave, generated_cpu
|
504 |
+
|
505 |
+
if streaming:
|
506 |
+
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
507 |
+
for chunk in process_batch(gen_text):
|
508 |
+
yield chunk
|
509 |
else:
|
510 |
+
with ThreadPoolExecutor() as executor:
|
511 |
+
futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
|
512 |
+
for future in progress.tqdm(futures) if progress is not None else futures:
|
513 |
+
result = future.result()
|
514 |
+
if result:
|
515 |
+
generated_wave, generated_mel_spec = next(result)
|
516 |
+
generated_waves.append(generated_wave)
|
517 |
+
spectrograms.append(generated_mel_spec)
|
518 |
+
|
519 |
+
if generated_waves:
|
520 |
+
if cross_fade_duration <= 0:
|
521 |
+
# Simply concatenate
|
522 |
+
final_wave = np.concatenate(generated_waves)
|
523 |
+
else:
|
524 |
+
# Combine all generated waves with cross-fading
|
525 |
+
final_wave = generated_waves[0]
|
526 |
+
for i in range(1, len(generated_waves)):
|
527 |
+
prev_wave = final_wave
|
528 |
+
next_wave = generated_waves[i]
|
529 |
+
|
530 |
+
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
531 |
+
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
532 |
+
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
|
533 |
+
|
534 |
+
if cross_fade_samples <= 0:
|
535 |
+
# No overlap possible, concatenate
|
536 |
+
final_wave = np.concatenate([prev_wave, next_wave])
|
537 |
+
continue
|
538 |
+
|
539 |
+
# Overlapping parts
|
540 |
+
prev_overlap = prev_wave[-cross_fade_samples:]
|
541 |
+
next_overlap = next_wave[:cross_fade_samples]
|
542 |
+
|
543 |
+
# Fade out and fade in
|
544 |
+
fade_out = np.linspace(1, 0, cross_fade_samples)
|
545 |
+
fade_in = np.linspace(0, 1, cross_fade_samples)
|
546 |
+
|
547 |
+
# Cross-faded overlap
|
548 |
+
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
549 |
+
|
550 |
+
# Combine
|
551 |
+
new_wave = np.concatenate(
|
552 |
+
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
|
553 |
+
)
|
554 |
+
|
555 |
+
final_wave = new_wave
|
556 |
+
|
557 |
+
# Create a combined spectrogram
|
558 |
+
combined_spectrogram = np.concatenate(spectrograms, axis=1)
|
559 |
+
|
560 |
+
yield final_wave, target_sample_rate, combined_spectrogram
|
561 |
|
562 |
+
else:
|
563 |
+
yield None, target_sample_rate, None
|
|
|
|
|
564 |
|
565 |
|
566 |
# remove silence from generated wav
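The batching heuristic added to `infer_process` above sizes each text chunk from the reference clip: `max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))`, i.e. the reference's bytes-per-second times the seconds remaining under a roughly 22-second budget. A worked example of just that arithmetic (values are illustrative):

```python
# Worked example of the max_chars heuristic used by infer_process.
ref_text = "Some call me nature, others call me mother nature."
ref_audio_seconds = 6.0                    # duration of the reference clip
ref_bytes = len(ref_text.encode("utf-8"))  # 51 bytes here

# bytes spoken per second in the reference, times the seconds left under the ~22 s budget
max_chars = int(ref_bytes / ref_audio_seconds * (22 - ref_audio_seconds))
print(max_chars)  # 51 / 6 * 16 = 136
```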
|
src/f5_tts/model/backbones/README.md
CHANGED
@@ -4,7 +4,7 @@
 ### unett.py
 - flat unet transformer
 - structure same as in e2-tts & voicebox paper except using rotary pos emb
--
+- possible abs pos emb & convnextv2 blocks for embedded text before concat
 
 ### dit.py
 - adaln-zero dit
@@ -14,7 +14,7 @@
 - possible long skip connection (first layer to last layer)
 
 ### mmdit.py
--
+- stable diffusion 3 block structure
 - timestep as condition
 - left stream: text embedded and applied a abs pos emb
 - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
|
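As a rough illustration of how these backbones are configured in practice, the `arch` blocks in the configs (and the `Config` dicts in `SHARED.md`) are passed straight to the constructors. A sketch for the DiT backbone with the common F5-TTS Base hyperparameters; keyword names follow the constructor arguments visible in `dit.py` and the shared configs, anything beyond that is an assumption:

```python
from f5_tts.model import DiT

# Hyperparameters as listed for F5-TTS Base in SHARED.md; mel_dim and text_num_embeds
# are shown with the defaults visible in dit.py.
backbone = DiT(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    conv_layers=4,
    mel_dim=100,
    text_num_embeds=256,
)
print(f"{sum(p.numel() for p in backbone.parameters()) / 1e6:.1f}M parameters")
```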
src/f5_tts/model/backbones/dit.py
CHANGED
@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
|
|
20 |
ConvNeXtV2Block,
|
21 |
ConvPositionEmbedding,
|
22 |
DiTBlock,
|
23 |
-
|
24 |
precompute_freqs_cis,
|
25 |
get_pos_embed_indices,
|
26 |
)
|
@@ -30,10 +30,12 @@ from f5_tts.model.modules import (
|
|
30 |
|
31 |
|
32 |
class TextEmbedding(nn.Module):
|
33 |
-
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
34 |
super().__init__()
|
35 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
36 |
|
|
|
|
|
37 |
if conv_layers > 0:
|
38 |
self.extra_modeling = True
|
39 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
@@ -49,6 +51,8 @@ class TextEmbedding(nn.Module):
|
|
49 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
50 |
batch, text_len = text.shape[0], text.shape[1]
|
51 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
|
|
|
|
52 |
|
53 |
if drop_text: # cfg for text
|
54 |
text = torch.zeros_like(text)
|
@@ -64,7 +68,13 @@ class TextEmbedding(nn.Module):
|
|
64 |
text = text + text_pos_embed
|
65 |
|
66 |
# convnextv2 blocks
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
return text
|
70 |
|
@@ -103,7 +113,10 @@ class DiT(nn.Module):
|
|
103 |
mel_dim=100,
|
104 |
text_num_embeds=256,
|
105 |
text_dim=None,
|
|
|
|
|
106 |
conv_layers=0,
|
|
|
107 |
long_skip_connection=False,
|
108 |
checkpoint_activations=False,
|
109 |
):
|
@@ -112,7 +125,10 @@ class DiT(nn.Module):
|
|
112 |
self.time_embed = TimestepEmbedding(dim)
|
113 |
if text_dim is None:
|
114 |
text_dim = mel_dim
|
115 |
-
self.text_embed = TextEmbedding(
|
|
|
|
|
|
|
116 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
117 |
|
118 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
@@ -121,15 +137,40 @@ class DiT(nn.Module):
|
|
121 |
self.depth = depth
|
122 |
|
123 |
self.transformer_blocks = nn.ModuleList(
|
124 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
)
|
126 |
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
127 |
|
128 |
-
self.norm_out =
|
129 |
self.proj_out = nn.Linear(dim, mel_dim)
|
130 |
|
131 |
self.checkpoint_activations = checkpoint_activations
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
def ckpt_wrapper(self, module):
|
134 |
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
135 |
def ckpt_forward(*inputs):
|
@@ -138,6 +179,9 @@ class DiT(nn.Module):
|
|
138 |
|
139 |
return ckpt_forward
|
140 |
|
|
|
|
|
|
|
141 |
def forward(
|
142 |
self,
|
143 |
x: float["b n d"], # nosied input audio # noqa: F722
|
@@ -147,14 +191,25 @@ class DiT(nn.Module):
|
|
147 |
drop_audio_cond, # cfg for cond audio
|
148 |
drop_text, # cfg for text
|
149 |
mask: bool["b n"] | None = None, # noqa: F722
|
|
|
150 |
):
|
151 |
batch, seq_len = x.shape[0], x.shape[1]
|
152 |
if time.ndim == 0:
|
153 |
time = time.repeat(batch)
|
154 |
|
155 |
-
# t: conditioning time,
|
156 |
t = self.time_embed(time)
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
159 |
|
160 |
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
|
|
20 |
ConvNeXtV2Block,
|
21 |
ConvPositionEmbedding,
|
22 |
DiTBlock,
|
23 |
+
AdaLayerNorm_Final,
|
24 |
precompute_freqs_cis,
|
25 |
get_pos_embed_indices,
|
26 |
)
|
|
|
30 |
|
31 |
|
32 |
class TextEmbedding(nn.Module):
|
33 |
+
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
34 |
super().__init__()
|
35 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
36 |
|
37 |
+
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
38 |
+
|
39 |
if conv_layers > 0:
|
40 |
self.extra_modeling = True
|
41 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
|
|
51 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
52 |
batch, text_len = text.shape[0], text.shape[1]
|
53 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
54 |
+
if self.mask_padding:
|
55 |
+
text_mask = text == 0
|
56 |
|
57 |
if drop_text: # cfg for text
|
58 |
text = torch.zeros_like(text)
|
|
|
68 |
text = text + text_pos_embed
|
69 |
|
70 |
# convnextv2 blocks
|
71 |
+
if self.mask_padding:
|
72 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
73 |
+
for block in self.text_blocks:
|
74 |
+
text = block(text)
|
75 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
76 |
+
else:
|
77 |
+
text = self.text_blocks(text)
|
78 |
|
79 |
return text
|
80 |
|
|
|
113 |
mel_dim=100,
|
114 |
text_num_embeds=256,
|
115 |
text_dim=None,
|
116 |
+
text_mask_padding=True,
|
117 |
+
qk_norm=None,
|
118 |
conv_layers=0,
|
119 |
+
pe_attn_head=None,
|
120 |
long_skip_connection=False,
|
121 |
checkpoint_activations=False,
|
122 |
):
|
|
|
125 |
self.time_embed = TimestepEmbedding(dim)
|
126 |
if text_dim is None:
|
127 |
text_dim = mel_dim
|
128 |
+
self.text_embed = TextEmbedding(
|
129 |
+
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
130 |
+
)
|
131 |
+
self.text_cond, self.text_uncond = None, None # text cache
|
132 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
133 |
|
134 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
|
|
137 |
self.depth = depth
|
138 |
|
139 |
self.transformer_blocks = nn.ModuleList(
|
140 |
+
[
|
141 |
+
DiTBlock(
|
142 |
+
dim=dim,
|
143 |
+
heads=heads,
|
144 |
+
dim_head=dim_head,
|
145 |
+
ff_mult=ff_mult,
|
146 |
+
dropout=dropout,
|
147 |
+
qk_norm=qk_norm,
|
148 |
+
pe_attn_head=pe_attn_head,
|
149 |
+
)
|
150 |
+
for _ in range(depth)
|
151 |
+
]
|
152 |
)
|
153 |
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
154 |
|
155 |
+
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
|
156 |
self.proj_out = nn.Linear(dim, mel_dim)
|
157 |
|
158 |
self.checkpoint_activations = checkpoint_activations
|
159 |
|
160 |
+
self.initialize_weights()
|
161 |
+
|
162 |
+
def initialize_weights(self):
|
163 |
+
# Zero-out AdaLN layers in DiT blocks:
|
164 |
+
for block in self.transformer_blocks:
|
165 |
+
nn.init.constant_(block.attn_norm.linear.weight, 0)
|
166 |
+
nn.init.constant_(block.attn_norm.linear.bias, 0)
|
167 |
+
|
168 |
+
# Zero-out output layers:
|
169 |
+
nn.init.constant_(self.norm_out.linear.weight, 0)
|
170 |
+
nn.init.constant_(self.norm_out.linear.bias, 0)
|
171 |
+
nn.init.constant_(self.proj_out.weight, 0)
|
172 |
+
nn.init.constant_(self.proj_out.bias, 0)
|
173 |
+
|
174 |
def ckpt_wrapper(self, module):
|
175 |
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
176 |
def ckpt_forward(*inputs):
|
|
|
179 |
|
180 |
return ckpt_forward
|
181 |
|
182 |
+
def clear_cache(self):
|
183 |
+
self.text_cond, self.text_uncond = None, None
|
184 |
+
|
185 |
def forward(
|
186 |
self,
|
187 |
x: float["b n d"], # nosied input audio # noqa: F722
|
|
|
191 |
drop_audio_cond, # cfg for cond audio
|
192 |
drop_text, # cfg for text
|
193 |
mask: bool["b n"] | None = None, # noqa: F722
|
194 |
+
cache=False,
|
195 |
):
|
196 |
batch, seq_len = x.shape[0], x.shape[1]
|
197 |
if time.ndim == 0:
|
198 |
time = time.repeat(batch)
|
199 |
|
200 |
+
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
201 |
t = self.time_embed(time)
|
202 |
+
if cache:
|
203 |
+
if drop_text:
|
204 |
+
if self.text_uncond is None:
|
205 |
+
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
206 |
+
text_embed = self.text_uncond
|
207 |
+
else:
|
208 |
+
if self.text_cond is None:
|
209 |
+
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
210 |
+
text_embed = self.text_cond
|
211 |
+
else:
|
212 |
+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
213 |
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
214 |
|
215 |
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
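The `cache` flag and `clear_cache()` added above let multi-step ODE sampling with classifier-free guidance compute the conditional and unconditional text embeddings once and reuse them on every step. A sketch of the calling pattern only; the real loop, shapes and guidance combination live in `CFM.sample`, so treat the details below as assumptions:

```python
# Inside one CFG step of a sampler (simplified sketch):
#   the drop_text=False call reuses the cached conditional text embedding,
#   the drop_text=True call reuses the cached unconditional one.
def cfg_step(transformer, x, cond, text, t, cfg_strength=2.0):
    pred = transformer(
        x=x, cond=cond, text=text, time=t, drop_audio_cond=False, drop_text=False, cache=True
    )
    null_pred = transformer(
        x=x, cond=cond, text=text, time=t, drop_audio_cond=True, drop_text=True, cache=True
    )
    return pred + (pred - null_pred) * cfg_strength


# After sampling one utterance, drop the cached text embeddings before the next:
# transformer.clear_cache()
```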
src/f5_tts/model/backbones/mmdit.py
CHANGED
@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
|
|
18 |
TimestepEmbedding,
|
19 |
ConvPositionEmbedding,
|
20 |
MMDiTBlock,
|
21 |
-
|
22 |
precompute_freqs_cis,
|
23 |
get_pos_embed_indices,
|
24 |
)
|
@@ -28,18 +28,24 @@ from f5_tts.model.modules import (
|
|
28 |
|
29 |
|
30 |
class TextEmbedding(nn.Module):
|
31 |
-
def __init__(self, out_dim, text_num_embeds):
|
32 |
super().__init__()
|
33 |
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
|
34 |
|
|
|
|
|
35 |
self.precompute_max_pos = 1024
|
36 |
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
37 |
|
38 |
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
39 |
-
text = text + 1
|
40 |
-
if
|
|
|
|
|
|
|
41 |
text = torch.zeros_like(text)
|
42 |
-
|
|
|
43 |
|
44 |
# sinus pos emb
|
45 |
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
|
@@ -49,6 +55,9 @@ class TextEmbedding(nn.Module):
|
|
49 |
|
50 |
text = text + text_pos_embed
|
51 |
|
|
|
|
|
|
|
52 |
return text
|
53 |
|
54 |
|
@@ -83,13 +92,16 @@ class MMDiT(nn.Module):
|
|
83 |
dim_head=64,
|
84 |
dropout=0.1,
|
85 |
ff_mult=4,
|
86 |
-
text_num_embeds=256,
|
87 |
mel_dim=100,
|
|
|
|
|
|
|
88 |
):
|
89 |
super().__init__()
|
90 |
|
91 |
self.time_embed = TimestepEmbedding(dim)
|
92 |
-
self.text_embed = TextEmbedding(dim, text_num_embeds)
|
|
|
93 |
self.audio_embed = AudioEmbedding(mel_dim, dim)
|
94 |
|
95 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
@@ -106,13 +118,33 @@ class MMDiT(nn.Module):
|
|
106 |
dropout=dropout,
|
107 |
ff_mult=ff_mult,
|
108 |
context_pre_only=i == depth - 1,
|
|
|
109 |
)
|
110 |
for i in range(depth)
|
111 |
]
|
112 |
)
|
113 |
-
self.norm_out =
|
114 |
self.proj_out = nn.Linear(dim, mel_dim)
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
     TimestepEmbedding,
     ConvPositionEmbedding,
     MMDiTBlock,
+    AdaLayerNorm_Final,
     precompute_freqs_cis,
     get_pos_embed_indices,
 )

 class TextEmbedding(nn.Module):
+    def __init__(self, out_dim, text_num_embeds, mask_padding=True):
         super().__init__()
         self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

+        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not
+
         self.precompute_max_pos = 1024
         self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

     def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
+        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+        if self.mask_padding:
+            text_mask = text == 0
+
+        if drop_text:  # cfg for text
             text = torch.zeros_like(text)
+
+        text = self.text_embed(text)  # b nt -> b nt d

         # sinus pos emb
         batch_start = torch.zeros((text.shape[0],), dtype=torch.long)

         text = text + text_pos_embed

+        if self.mask_padding:
+            text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
+
         return text

         dim_head=64,
         dropout=0.1,
         ff_mult=4,
         mel_dim=100,
+        text_num_embeds=256,
+        text_mask_padding=True,
+        qk_norm=None,
     ):
         super().__init__()

         self.time_embed = TimestepEmbedding(dim)
+        self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
+        self.text_cond, self.text_uncond = None, None  # text cache
         self.audio_embed = AudioEmbedding(mel_dim, dim)

         self.rotary_embed = RotaryEmbedding(dim_head)

                 dropout=dropout,
                 ff_mult=ff_mult,
                 context_pre_only=i == depth - 1,
+                qk_norm=qk_norm,
             )
             for i in range(depth)
         ]
         )
+        self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)

+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # Zero-out AdaLN layers in MMDiT blocks:
+        for block in self.transformer_blocks:
+            nn.init.constant_(block.attn_norm_x.linear.weight, 0)
+            nn.init.constant_(block.attn_norm_x.linear.bias, 0)
+            nn.init.constant_(block.attn_norm_c.linear.weight, 0)
+            nn.init.constant_(block.attn_norm_c.linear.bias, 0)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.norm_out.linear.weight, 0)
+        nn.init.constant_(self.norm_out.linear.bias, 0)
+        nn.init.constant_(self.proj_out.weight, 0)
+        nn.init.constant_(self.proj_out.bias, 0)
+
+    def clear_cache(self):
+        self.text_cond, self.text_uncond = None, None
+
@@ -122,6 +154,7 @@ class MMDiT(nn.Module):
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722
         drop_audio_cond,  # cfg for cond audio
         drop_text,  # cfg for text
         mask: bool["b n"] | None = None,  # noqa: F722
+        cache=False,
     ):
         batch = x.shape[0]
         if time.ndim == 0:
@@ -129,7 +162,17 @@ class MMDiT(nn.Module):

         # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
+        if cache:
+            if drop_text:
+                if self.text_uncond is None:
+                    self.text_uncond = self.text_embed(text, drop_text=True)
+                c = self.text_uncond
+            else:
+                if self.text_cond is None:
+                    self.text_cond = self.text_embed(text, drop_text=False)
+                c = self.text_cond
+        else:
+            c = self.text_embed(text, drop_text=drop_text)
         x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

         seq_len = x.shape[1]
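
Note on the text-cache change above: during CFG sampling the same conditional and unconditional text embeddings are reused at every ODE step, so the backbone memoizes them when called with cache=True and drops them via clear_cache() once sampling ends. A minimal sketch of the same memoization pattern in isolation (the toy embed function and CachedTextEmbed wrapper below are hypothetical, not repo code):

import torch

class CachedTextEmbed:
    """Toy stand-in for the text_cond / text_uncond cache used in the backbone forward."""

    def __init__(self, embed_fn):
        self.embed_fn = embed_fn      # e.g. a TextEmbedding-like callable
        self.text_cond = None         # cached conditional embedding
        self.text_uncond = None       # cached unconditional (drop_text) embedding

    def __call__(self, text, drop_text, cache=False):
        if not cache:
            return self.embed_fn(text, drop_text=drop_text)
        if drop_text:
            if self.text_uncond is None:
                self.text_uncond = self.embed_fn(text, drop_text=True)
            return self.text_uncond
        if self.text_cond is None:
            self.text_cond = self.embed_fn(text, drop_text=False)
        return self.text_cond

    def clear_cache(self):
        self.text_cond, self.text_uncond = None, None

# toy usage: each branch is computed once per sampling run instead of once per ODE step
embed = CachedTextEmbed(lambda text, drop_text: torch.zeros(1, len(text), 4) if drop_text else torch.ones(1, len(text), 4))
for _ in range(32):  # e.g. 32 ODE steps
    c = embed("hello", drop_text=False, cache=True)
    c_null = embed("hello", drop_text=True, cache=True)
embed.clear_cache()  # mirrors CFM.sample calling transformer.clear_cache() after odeint
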
src/f5_tts/model/backbones/unett.py
CHANGED
@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
 class TextEmbedding(nn.Module):
-    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+    def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
         super().__init__()
         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

+        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not
+
         if conv_layers > 0:
             self.extra_modeling = True
             self.precompute_max_pos = 4096  # ~44s of 24khz audio
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         batch, text_len = text.shape[0], text.shape[1]
         text = F.pad(text, (0, seq_len - text_len), value=0)
+        if self.mask_padding:
+            text_mask = text == 0

         if drop_text:  # cfg for text
             text = torch.zeros_like(text)
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
         text = text + text_pos_embed

         # convnextv2 blocks
+        if self.mask_padding:
+            text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
+            for block in self.text_blocks:
+                text = block(text)
+                text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
+        else:
+            text = self.text_blocks(text)

         return text

@@ -106,7 +116,10 @@ class UNetT(nn.Module):
         mel_dim=100,
         text_num_embeds=256,
         text_dim=None,
+        text_mask_padding=True,
+        qk_norm=None,
         conv_layers=0,
+        pe_attn_head=None,
         skip_connect_type: Literal["add", "concat", "none"] = "concat",
     ):
         super().__init__()
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
         self.time_embed = TimestepEmbedding(dim)
         if text_dim is None:
             text_dim = mel_dim
+        self.text_embed = TextEmbedding(
+            text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
+        )
+        self.text_cond, self.text_uncond = None, None  # text cache
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
             attn_norm = RMSNorm(dim)
             attn = Attention(
-                processor=AttnProcessor(),
+                processor=AttnProcessor(pe_attn_head=pe_attn_head),
                 dim=dim,
                 heads=heads,
                 dim_head=dim_head,
                 dropout=dropout,
+                qk_norm=qk_norm,
             )

             ff_norm = RMSNorm(dim)
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
         self.norm_out = RMSNorm(dim)
         self.proj_out = nn.Linear(dim, mel_dim)

+    def clear_cache(self):
+        self.text_cond, self.text_uncond = None, None
+
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
         drop_audio_cond,  # cfg for cond audio
         drop_text,  # cfg for text
         mask: bool["b n"] | None = None,  # noqa: F722
+        cache=False,
     ):
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
+        if cache:
+            if drop_text:
+                if self.text_uncond is None:
+                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
+                text_embed = self.text_uncond
+            else:
+                if self.text_cond is None:
+                    self.text_cond = self.text_embed(text, seq_len, drop_text=False)
+                text_embed = self.text_cond
+        else:
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
         x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

         # postfix time t to input x, [b n d] -> [b n+1 d]
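
The pe_attn_head argument threaded through UNetT above limits rotary position embedding to the first N attention heads (the actual slicing lives in AttnProcessor in modules.py further down this diff). A rough standalone sketch of that idea, assuming tensors already shaped (batch, heads, seq, head_dim) and using a generic rotate-pair RoPE in place of the repo's apply_rotary_pos_emb:

import torch

def apply_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # generic RoPE on interleaved pairs; t: (b, h, s, d), freqs: (s, d // 2) rotation angles
    cos, sin = freqs.cos(), freqs.sin()            # broadcast over batch and head dims
    x1, x2 = t[..., 0::2], t[..., 1::2]
    out = torch.empty_like(t)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def apply_rope_partial_heads(query, key, freqs, pe_attn_head=None):
    # apply RoPE to all heads, or only to the first `pe_attn_head` heads
    if pe_attn_head is None:
        return apply_rope(query, freqs), apply_rope(key, freqs)
    pn = pe_attn_head
    query, key = query.clone(), key.clone()
    query[:, :pn] = apply_rope(query[:, :pn], freqs)
    key[:, :pn] = apply_rope(key[:, :pn], freqs)
    return query, key

q, k = torch.randn(1, 8, 10, 32), torch.randn(1, 8, 10, 32)
freqs = torch.randn(10, 16)
q2, k2 = apply_rope_partial_heads(q, k, freqs, pe_attn_head=4)  # heads 4..7 left untouched
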
src/f5_tts/model/cfm.py
CHANGED
@@ -120,10 +120,6 @@ class CFM(nn.Module):
             text = list_str_to_tensor(text).to(device)
             assert text.shape[0] == batch

-        if exists(text):
-            text_lens = (text != -1).sum(dim=-1)
-            lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters
-
         # duration

         cond_mask = lens_to_mask(lens)
@@ -133,7 +129,9 @@ class CFM(nn.Module):
         if isinstance(duration, int):
             duration = torch.full((batch,), duration, device=device, dtype=torch.long)

+        duration = torch.maximum(
+            torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
+        )  # duration at least text/audio prompt length plus one token, so something is generated
         duration = duration.clamp(max=max_duration)
         max_duration = duration.amax()
@@ -142,6 +140,9 @@ class CFM(nn.Module):
         test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)

         cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
+        if no_ref_audio:
+            cond = torch.zeros_like(cond)
+
         cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
         cond_mask = cond_mask.unsqueeze(-1)
         step_cond = torch.where(
@@ -153,10 +154,6 @@ class CFM(nn.Module):
         else:  # save memory and speed up, as single inference need no mask currently
             mask = None

-        # test for no ref audio
-        if no_ref_audio:
-            cond = torch.zeros_like(cond)
-
         # neural ode

         def fn(t, x):
@@ -165,13 +162,13 @@ class CFM(nn.Module):
             # predict flow
             pred = self.transformer(
-                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
+                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
             )
             if cfg_strength < 1e-5:
                 return pred

             null_pred = self.transformer(
-                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
+                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
             )
             return pred + (pred - null_pred) * cfg_strength
@@ -198,6 +195,7 @@ class CFM(nn.Module):
         t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

         trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+        self.transformer.clear_cache()

         sampled = trajectory[-1]
         out = sampled
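
The reordered duration logic in CFM.sample above guarantees at least one generated frame: the requested duration is floored at max(text length, prompt length) + 1 and then capped at max_duration. A quick numeric check of that rule with toy values (not from the repo):

import torch

text = torch.tensor([[5, 2, 9, -1, -1]])   # batch of 1: three real tokens, -1 padding
lens = torch.tensor([4])                   # 4 frames of reference (prompt) audio
duration = torch.tensor([3])               # caller asked for only 3 frames total
max_duration = 4096

duration = torch.maximum(
    torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
)  # -> tensor([5]): at least prompt length (4) + 1, so something gets generated
duration = duration.clamp(max=max_duration)
print(duration)  # tensor([5])
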
src/f5_tts/model/dataset.py
CHANGED
@@ -1,5 +1,4 @@
 import json
-import random
 from importlib.resources import files

 import torch
@@ -170,14 +169,17 @@ class DynamicBatchSampler(Sampler[list[int]]):
     in a batch to ensure that the total number of frames are less
     than a certain threshold.
     2. Make sure the padding efficiency in the batch is high.
+    3. Shuffle batches each epoch while maintaining reproducibility.
     """

     def __init__(
-        self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None,
+        self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
     ):
         self.sampler = sampler
         self.frames_threshold = frames_threshold
         self.max_samples = max_samples
+        self.random_seed = random_seed
+        self.epoch = 0

         indices, batches = [], []
         data_source = self.sampler.data_source
@@ -206,21 +208,30 @@ class DynamicBatchSampler(Sampler[list[int]]):
                 batch = []
                 batch_frames = 0

+        if not drop_residual and len(batch) > 0:
             batches.append(batch)

         del indices
+        self.batches = batches

-        # e.g. for epoch n, use (random_seed + n)
-        random.seed(random_seed)
-        random.shuffle(batches)
+        # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
+        self.drop_last = True

+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler."""
+        self.epoch = epoch

     def __iter__(self):
+        # Use both random_seed and epoch for deterministic but different shuffling per epoch
+        if self.random_seed is not None:
+            g = torch.Generator()
+            g.manual_seed(self.random_seed + self.epoch)
+            # Use PyTorch's random permutation for better reproducibility across PyTorch versions
+            indices = torch.randperm(len(self.batches), generator=g).tolist()
+            batches = [self.batches[i] for i in indices]
+        else:
+            batches = self.batches
+        return iter(batches)

     def __len__(self):
         return len(self.batches)
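
With this change the sampler defers shuffling to __iter__, keyed on random_seed + epoch, so the trainer can call set_epoch(epoch) each epoch (as trainer.py does later in this diff) and still reproduce the exact batch order on a re-run with the same seed. A minimal sketch of just the per-epoch deterministic shuffle, independent of the sampler class:

import torch

def epoch_shuffle(batches, random_seed, epoch):
    # same (seed, epoch) -> same order; a different epoch gives a new but still deterministic order
    g = torch.Generator()
    g.manual_seed(random_seed + epoch)
    order = torch.randperm(len(batches), generator=g).tolist()
    return [batches[i] for i in order]

batches = [[0, 1], [2, 3], [4], [5, 6, 7]]
print(epoch_shuffle(batches, 666, 0))  # identical output every run
print(epoch_shuffle(batches, 666, 1))  # usually a different order, also reproducible
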
src/f5_tts/model/modules.py
CHANGED
@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
         return residual + x


+# RMSNorm
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
+
+    def forward(self, x):
+        if self.native_rms_norm:
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                x = x.to(self.weight.dtype)
+            x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
+        else:
+            variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.eps)
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                x = x.to(self.weight.dtype)
+            x = x * self.weight
+
+        return x
+
+
+# AdaLayerNorm
 # return with modulated x for attn input, and params for later mlp modulation


+class AdaLayerNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


+# AdaLayerNorm for final layer
 # return only with modulated x for attn input, cuz no more mlp modulation


+class AdaLayerNorm_Final(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -341,7 +366,8 @@ class Attention(nn.Module):
         dim_head: int = 64,
         dropout: float = 0.0,
         context_dim: Optional[int] = None,  # if not None -> joint attention
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
@@ -362,18 +388,32 @@ class Attention(nn.Module):
         self.to_k = nn.Linear(dim, self.inner_dim)
         self.to_v = nn.Linear(dim, self.inner_dim)

+        if qk_norm is None:
+            self.q_norm = None
+            self.k_norm = None
+        elif qk_norm == "rms_norm":
+            self.q_norm = RMSNorm(dim_head, eps=1e-6)
+            self.k_norm = RMSNorm(dim_head, eps=1e-6)
+        else:
+            raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
+
         if self.context_dim is not None:
+            self.to_q_c = nn.Linear(context_dim, self.inner_dim)
             self.to_k_c = nn.Linear(context_dim, self.inner_dim)
             self.to_v_c = nn.Linear(context_dim, self.inner_dim)
+            if qk_norm is None:
+                self.c_q_norm = None
+                self.c_k_norm = None
+            elif qk_norm == "rms_norm":
+                self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
+                self.c_k_norm = RMSNorm(dim_head, eps=1e-6)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(self.inner_dim, dim))
         self.to_out.append(nn.Dropout(dropout))

+        if self.context_dim is not None and not self.context_pre_only:
+            self.to_out_c = nn.Linear(self.inner_dim, context_dim)

     def forward(
         self,
@@ -393,8 +433,11 @@ class Attention(nn.Module):

 class AttnProcessor:
+    def __init__(
+        self,
+        pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
+    ):
+        self.pe_attn_head = pe_attn_head

     def __call__(
         self,
@@ -405,19 +448,11 @@ class AttnProcessor:
     ) -> torch.FloatTensor:
         batch_size = x.shape[0]

+        # `sample` projections
         query = attn.to_q(x)
         key = attn.to_k(x)
         value = attn.to_v(x)

-        # apply rotary position embedding
-        if rope is not None:
-            freqs, xpos_scale = rope
-            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
-
-            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
-            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
-
         # attention
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -425,6 +460,25 @@ class AttnProcessor:
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        # qk norm
+        if attn.q_norm is not None:
+            query = attn.q_norm(query)
+        if attn.k_norm is not None:
+            key = attn.k_norm(key)
+
+        # apply rotary position embedding
+        if rope is not None:
+            freqs, xpos_scale = rope
+            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+
+            if self.pe_attn_head is not None:
+                pn = self.pe_attn_head
+                query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
+                key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
+            else:
+                query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+                key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+
         # mask. e.g. inference got a batch with different target durations, mask out the padding
         if mask is not None:
             attn_mask = mask
@@ -470,16 +524,36 @@ class JointAttnProcessor:
         batch_size = c.shape[0]

+        # `sample` projections
         query = attn.to_q(x)
         key = attn.to_k(x)
         value = attn.to_v(x)

+        # `context` projections
         c_query = attn.to_q_c(c)
         c_key = attn.to_k_c(c)
         c_value = attn.to_v_c(c)

+        # attention
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # qk norm
+        if attn.q_norm is not None:
+            query = attn.q_norm(query)
+        if attn.k_norm is not None:
+            key = attn.k_norm(key)
+        if attn.c_q_norm is not None:
+            c_query = attn.c_q_norm(c_query)
+        if attn.c_k_norm is not None:
+            c_key = attn.c_k_norm(c_key)
+
         # apply rope for context and noised input independently
         if rope is not None:
             freqs, xpos_scale = rope
@@ -492,16 +566,10 @@ class JointAttnProcessor:
             c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
             c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

-        # attention
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        # joint attention
+        query = torch.cat([query, c_query], dim=2)
+        key = torch.cat([key, c_key], dim=2)
+        value = torch.cat([value, c_value], dim=2)

         # mask. e.g. inference got a batch with different target durations, mask out the padding
         if mask is not None:
@@ -540,16 +608,17 @@ class JointAttnProcessor:


 class DiTBlock(nn.Module):
-    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
+    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
         super().__init__()

+        self.attn_norm = AdaLayerNorm(dim)
         self.attn = Attention(
-            processor=AttnProcessor(),
+            processor=AttnProcessor(pe_attn_head=pe_attn_head),
             dim=dim,
             heads=heads,
             dim_head=dim_head,
             dropout=dropout,
+            qk_norm=qk_norm,
         )

         self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
     context_pre_only: last layer only do prenorm + modulation cuz no more ffn
     """

+    def __init__(
+        self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
+    ):
         super().__init__()
+        if context_dim is None:
+            context_dim = dim
         self.context_pre_only = context_pre_only

+        self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
+        self.attn_norm_x = AdaLayerNorm(dim)
         self.attn = Attention(
             processor=JointAttnProcessor(),
             dim=dim,
             heads=heads,
             dim_head=dim_head,
             dropout=dropout,
+            context_dim=context_dim,
             context_pre_only=context_pre_only,
+            qk_norm=qk_norm,
         )

         if not context_pre_only:
+            self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
+            self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
         else:
             self.ff_norm_c = None
             self.ff_c = None
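
The new RMSNorm above prefers torch's native F.rms_norm when available (torch >= 2.4) and otherwise falls back to the explicit rsqrt form; both compute x / sqrt(mean(x^2) + eps) * weight. A quick check that the two paths agree, using toy shapes (requires a torch build that ships F.rms_norm):

import torch
import torch.nn.functional as F

x = torch.randn(2, 7, 16)
weight = torch.ones(16)
eps = 1e-6

# manual fallback path
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
manual = x * torch.rsqrt(variance + eps) * weight

# native path (torch >= 2.4)
native = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=weight, eps=eps)

print(torch.allclose(manual, native, atol=1e-5))  # True
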
src/f5_tts/model/trainer.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import gc
+import math
 import os

 import torch
@@ -29,8 +30,9 @@ class Trainer:
         learning_rate,
         num_warmup_updates=20000,
         save_per_updates=1000,
+        keep_last_n_checkpoints: int = -1,  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
         checkpoint_path=None,
+        batch_size_per_gpu=32,
         batch_size_type: str = "sample",
         max_samples=32,
         grad_accumulation_steps=1,
@@ -38,23 +40,23 @@ class Trainer:
         noise_scheduler: str | None = None,
         duration_predictor: torch.nn.Module | None = None,
         logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
+        wandb_project="test_f5-tts",
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
         log_samples: bool = False,
+        last_per_updates=None,
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
         mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
         is_local_vocoder: bool = False,  # use local path vocoder
         local_vocoder_path: str = "",  # local vocoder path
+        cfg_dict: dict = dict(),  # training config
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

         if logger == "wandb" and not wandb.api.api_key:
             logger = None
-        print(f"Using logger: {logger}")
         self.log_samples = log_samples

         self.accelerator = Accelerator(
@@ -71,21 +73,23 @@ class Trainer:
         else:
             init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}

+        if not cfg_dict:
+            cfg_dict = {
                 "epochs": epochs,
                 "learning_rate": learning_rate,
                 "num_warmup_updates": num_warmup_updates,
+                "batch_size_per_gpu": batch_size_per_gpu,
                 "batch_size_type": batch_size_type,
                 "max_samples": max_samples,
                 "grad_accumulation_steps": grad_accumulation_steps,
                 "max_grad_norm": max_grad_norm,
-                "gpus": self.accelerator.num_processes,
                 "noise_scheduler": noise_scheduler,
+            }
+        cfg_dict["gpus"] = self.accelerator.num_processes
+        self.accelerator.init_trackers(
+            project_name=wandb_project,
+            init_kwargs=init_kwargs,
+            config=cfg_dict,
         )

         elif self.logger == "tensorboard":
@@ -99,13 +103,20 @@ class Trainer:
         self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
         self.ema_model.to(self.accelerator.device)

+        print(f"Using logger: {logger}")
+        if grad_accumulation_steps > 1:
+            print(
+                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+            )
+
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
+        self.keep_last_n_checkpoints = keep_last_n_checkpoints
+        self.last_per_updates = default(last_per_updates, save_per_updates)
+        self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")

+        self.batch_size_per_gpu = batch_size_per_gpu
         self.batch_size_type = batch_size_type
         self.max_samples = max_samples
         self.grad_accumulation_steps = grad_accumulation_steps
@@ -132,7 +143,7 @@ class Trainer:
     def is_main(self):
         return self.accelerator.is_main_process

+    def save_checkpoint(self, update, last=False):
         self.accelerator.wait_for_everyone()
         if self.is_main:
             checkpoint = dict(
@@ -140,21 +151,38 @@ class Trainer:
                 optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                 ema_model_state_dict=self.ema_model.state_dict(),
                 scheduler_state_dict=self.scheduler.state_dict(),
+                update=update,
             )
             if not os.path.exists(self.checkpoint_path):
                 os.makedirs(self.checkpoint_path)
             if last:
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
+                print(f"Saved last checkpoint at update {update}")
             else:
+                if self.keep_last_n_checkpoints == 0:
+                    return
+                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
+                if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
+                    checkpoints = [
+                        f
+                        for f in os.listdir(self.checkpoint_path)
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
+                    ]
+                    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
+                    while len(checkpoints) > self.keep_last_n_checkpoints:
+                        oldest_checkpoint = checkpoints.pop(0)
+                        os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
+                        print(f"Removed old checkpoint: {oldest_checkpoint}")

     def load_checkpoint(self):
         if (
             not exists(self.checkpoint_path)
             or not os.path.exists(self.checkpoint_path)
-            or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
+            or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
         ):
             return 0
@@ -162,12 +190,34 @@ class Trainer:
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
+            ]
+
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
+
+        if latest_checkpoint.endswith(".safetensors"):  # always a pretrained checkpoint
+            from safetensors.torch import load_file
+
+            checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
+            checkpoint = {"ema_model_state_dict": checkpoint}
+        elif latest_checkpoint.endswith(".pt"):
+            # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
+            checkpoint = torch.load(
+                f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
+            )

         # patch for backward compatibility, 305e3ea
         for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
@@ -177,7 +227,14 @@ class Trainer:
         if self.is_main:
             self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])

-        if "step" in checkpoint:
+        if "update" in checkpoint or "step" in checkpoint:
+            # patch for backward compatibility, with before f992c4e
+            if "step" in checkpoint:
+                checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
+                if self.grad_accumulation_steps > 1 and self.is_main:
+                    print(
+                        "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
+                    )
             # patch for backward compatibility, 305e3ea
             for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                 if key in checkpoint["model_state_dict"]:
@@ -187,19 +244,19 @@ class Trainer:
             self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
             if self.scheduler:
                 self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+            update = checkpoint["update"]
         else:
             checkpoint["model_state_dict"] = {
                 k.replace("ema_model.", ""): v
                 for k, v in checkpoint["ema_model_state_dict"].items()
-                if k not in ["initted", "step"]
+                if k not in ["initted", "update", "step"]
             }
             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
+            update = 0

         del checkpoint
         gc.collect()
-        return
+        return update

     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
@@ -225,7 +282,7 @@ class Trainer:
                 num_workers=num_workers,
                 pin_memory=True,
                 persistent_workers=True,
+                batch_size=self.batch_size_per_gpu,
                 shuffle=True,
                 generator=generator,
             )
@@ -233,7 +290,11 @@ class Trainer:
             self.accelerator.even_batches = False
             sampler = SequentialSampler(train_dataset)
             batch_sampler = DynamicBatchSampler(
+                sampler,
+                self.batch_size_per_gpu,
+                max_samples=self.max_samples,
+                random_seed=resumable_with_seed,  # This enables reproducible shuffling
+                drop_residual=False,
             )
             train_dataloader = DataLoader(
                 train_dataset,
@@ -248,25 +309,26 @@ class Trainer:

         # accelerator.prepare() dispatches batches to devices;
         # which means the length of dataloader calculated before, should consider the number of devices
+        warmup_updates = (
             self.num_warmup_updates * self.accelerator.num_processes
         )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
         # otherwise by default with split_batches=False, warmup steps change with num_processes
+        total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
+        decay_updates = total_updates - warmup_updates
+        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
+        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
         self.scheduler = SequentialLR(
+            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
         )
         train_dataloader, self.scheduler = self.accelerator.prepare(
             train_dataloader, self.scheduler
+        )  # actual multi_gpu updates = single_gpu updates / gpu nums
+        start_update = self.load_checkpoint()
+        global_update = start_update

         if exists(resumable_with_seed):
             orig_epoch_step = len(train_dataloader)
+            start_step = start_update * self.grad_accumulation_steps
             skipped_epoch = int(start_step // orig_epoch_step)
             skipped_batch = start_step % orig_epoch_step
             skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
@@ -276,23 +338,25 @@ class Trainer:
         for epoch in range(skipped_epoch, self.epochs):
             self.model.train()
             if exists(resumable_with_seed) and epoch == skipped_epoch:
+                progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
+                current_dataloader = skipped_dataloader
             else:
+                progress_bar_initial = 0
+                current_dataloader = train_dataloader
+
+            # Set epoch for the batch sampler if it exists
+            if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
+                train_dataloader.batch_sampler.set_epoch(epoch)
+
+            progress_bar = tqdm(
+                range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
+                desc=f"Epoch {epoch+1}/{self.epochs}",
+                unit="update",
+                disable=not self.accelerator.is_local_main_process,
+                initial=progress_bar_initial,
+            )

+            for batch in current_dataloader:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
                     mel_spec = batch["mel"].permute(0, 2, 1)
@@ -301,7 +365,7 @@ class Trainer:
                     # TODO. add duration predictor training
                     if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                         dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
+                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)

                     loss, cond, pred = self.model(
                         mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -315,21 +379,24 @@ class Trainer:
                     self.scheduler.step()
                     self.optimizer.zero_grad()

+                if self.accelerator.sync_gradients:
+                    if self.is_main:
+                        self.ema_model.update()

+                    global_update += 1
+                    progress_bar.update(1)
+                    progress_bar.set_postfix(update=str(global_update), loss=loss.item())

                 if self.accelerator.is_local_main_process:
+                    self.accelerator.log(
+                        {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
+                    )
                     if self.logger == "tensorboard":
+                        self.writer.add_scalar("loss", loss.item(), global_update)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)

+                if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
+                    self.save_checkpoint(global_update)

                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio_len = mel_lengths[0]
@@ -355,12 +422,16 @@ class Trainer:
                     gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                     ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()

+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
+                    )
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
+                    )

+                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
+                    self.save_checkpoint(global_update, last=True)

+        self.save_checkpoint(global_update, last=True)

         self.accelerator.end_training()
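
The keep_last_n_checkpoints logic above rotates intermediate model_<update>.pt files while never touching pretrained_* weights or model_last.pt. A standalone sketch of the same selection rule over plain filenames (no file I/O; the example filenames are illustrative):

def checkpoints_to_remove(files, keep_last_n):
    # numbered training checkpoints only: exclude pretrained weights and the rolling "last" file
    candidates = [
        f for f in files
        if f.startswith("model_") and not f.startswith("pretrained_")
        and f.endswith(".pt") and f != "model_last.pt"
    ]
    candidates.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    return candidates[: max(0, len(candidates) - keep_last_n)]

files = ["pretrained_model_1200000.pt", "model_last.pt", "model_1000.pt", "model_2000.pt", "model_3000.pt"]
print(checkpoints_to_remove(files, keep_last_n=2))  # ['model_1000.pt']
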
src/f5_tts/model/utils.py
CHANGED
@@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):

 # convert char to pinyin

-jieba.initialize()
-print("Word segmentation module jieba initialized.\n")
-

 def convert_char_to_pinyin(text_list, polyphone=True):
+    if jieba.dt.initialized is False:
+        jieba.default_logger.setLevel(50)  # CRITICAL
+        jieba.initialize()
+
     final_text_list = []
     custom_trans = str.maketrans(
         {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
src/f5_tts/scripts/count_max_epoch.py
CHANGED
@@ -9,7 +9,7 @@ mel_hop_length = 256
 mel_sampling_rate = 24000

 # target
+wanted_max_updates = 1200000

 # train params
 gpus = 8
@@ -20,13 +20,13 @@ grad_accum = 1
 mini_batch_frames = frames_per_gpu * grad_accum * gpus
 mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
 updates_per_epoch = total_hours / mini_batch_hours
-steps_per_epoch = updates_per_epoch * grad_accum
+# steps_per_epoch = updates_per_epoch * grad_accum

 # result
 epochs = wanted_max_updates / updates_per_epoch
 print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
 print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
-print(f" or approx. 0/{steps_per_epoch:.0f} steps")
+# print(f" or approx. 0/{steps_per_epoch:.0f} steps")

 # others
 print(f"total {total_hours:.0f} hours")
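
For a concrete feel of the numbers this script prints, here is the same arithmetic with illustrative inputs (frames_per_gpu and total_hours below are made-up values, not the repo defaults):

mel_hop_length = 256
mel_sampling_rate = 24000
wanted_max_updates = 1200000

gpus = 8
frames_per_gpu = 38400   # illustrative
grad_accum = 1
total_hours = 95000      # illustrative

mini_batch_frames = frames_per_gpu * grad_accum * gpus                              # 307200 frames per update
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600    # ~0.91 h of audio per update
updates_per_epoch = total_hours / mini_batch_hours                                  # ~104k updates per epoch
epochs = wanted_max_updates / updates_per_epoch                                     # ~11.5 epochs
print(f"epochs should be set to: {epochs:.0f}")
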
src/f5_tts/socket_client.py
ADDED
@@ -0,0 +1,61 @@
+import socket
+import asyncio
+import pyaudio
+import numpy as np
+import logging
+import time
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
+    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
+
+    start_time = time.time()
+    first_chunk_time = None
+
+    async def play_audio_stream():
+        nonlocal first_chunk_time
+        p = pyaudio.PyAudio()
+        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
+
+        try:
+            while True:
+                data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
+                if not data:
+                    break
+                if data == b"END":
+                    logger.info("End of audio received.")
+                    break
+
+                audio_array = np.frombuffer(data, dtype=np.float32)
+                stream.write(audio_array.tobytes())
+
+                if first_chunk_time is None:
+                    first_chunk_time = time.time()
+
+        finally:
+            stream.stop_stream()
+            stream.close()
+            p.terminate()
+
+        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
+
+    try:
+        data_to_send = f"{text}".encode("utf-8")
+        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
+        await play_audio_stream()
+
+    except Exception as e:
+        logger.error(f"Error in listen_to_F5TTS: {e}")
+
+    finally:
+        client_socket.close()
+
+
+if __name__ == "__main__":
+    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
+
+    asyncio.run(listen_to_F5TTS(text_to_send))
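The client above streams raw float32 PCM from the socket server and plays it with PyAudio. A minimal sketch of pointing it at a non-default server from your own script (the host address here is a placeholder; the port matches the server's default of 9998):

```python
# Minimal usage sketch for the client added above; requires pyaudio to be installed.
import asyncio

from f5_tts.socket_client import listen_to_F5TTS

# "192.168.1.10" is a placeholder host; replace with wherever socket_server.py is running.
asyncio.run(listen_to_F5TTS("Hello from the streaming client.", server_ip="192.168.1.10", server_port=9998))
```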
src/f5_tts/socket_server.py
CHANGED
@@ -1,142 +1,213 @@
 import argparse
 import gc
+import logging
+import numpy as np
+import queue
 import socket
 import struct
+import threading
-import torchaudio
 import traceback
+import wave
 from importlib.resources import files
-from threading import Thread
 
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+
+from f5_tts.model.backbones.dit import DiT  # noqa: F401. used for config
+from f5_tts.infer.utils_infer import (
+    chunk_text,
+    preprocess_ref_audio_text,
+    load_vocoder,
+    load_model,
+    infer_batch_process,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class AudioFileWriterThread(threading.Thread):
+    """Threaded file writer to avoid blocking the TTS streaming process."""
+
+    def __init__(self, output_file, sampling_rate):
+        super().__init__()
+        self.output_file = output_file
+        self.sampling_rate = sampling_rate
+        self.queue = queue.Queue()
+        self.stop_event = threading.Event()
+        self.audio_data = []
+
+    def run(self):
+        """Process queued audio data and write it to a file."""
+        logger.info("AudioFileWriterThread started.")
+        with wave.open(self.output_file, "wb") as wf:
+            wf.setnchannels(1)
+            wf.setsampwidth(2)
+            wf.setframerate(self.sampling_rate)
+
+            while not self.stop_event.is_set() or not self.queue.empty():
+                try:
+                    chunk = self.queue.get(timeout=0.1)
+                    if chunk is not None:
+                        chunk = np.int16(chunk * 32767)
+                        self.audio_data.append(chunk)
+                        wf.writeframes(chunk.tobytes())
+                except queue.Empty:
+                    continue
+
+    def add_chunk(self, chunk):
+        """Add a new chunk to the queue."""
+        self.queue.put(chunk)
+
+    def stop(self):
+        """Stop writing and ensure all queued data is written."""
+        self.stop_event.set()
+        self.join()
+        logger.info("Audio writing completed.")
 
 
 class TTSStreamingProcessor:
-    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
+    def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
         self.device = device or (
-            "cuda"
+            "cuda"
+            if torch.cuda.is_available()
+            else "xpu"
+            if torch.xpu.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
         )
+        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+        self.model_cls = globals()[model_cfg.model.backbone]
+        self.model_arc = model_cfg.model.arch
+        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+        self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
+
+        self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
+        self.vocoder = self.load_vocoder_model()
+
+        self.update_reference(ref_audio, ref_text)
-        self._warm_up()
+        self._warm_up()
+        self.file_writer_thread = None
+        self.first_package = True
 
+    def load_ema_model(self, ckpt_file, vocab_file, dtype):
+        return load_model(
+            self.model_cls,
+            self.model_arc,
             ckpt_path=ckpt_file,
-            mel_spec_type=
+            mel_spec_type=self.mel_spec_type,
             vocab_file=vocab_file,
             ode_method="euler",
             use_ema=True,
             device=self.device,
         ).to(self.device, dtype=dtype)
 
+    def load_vocoder_model(self):
+        return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device)
+
+    def update_reference(self, ref_audio, ref_text):
+        self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
+        self.audio, self.sr = torchaudio.load(self.ref_audio)
+
+        ref_audio_duration = self.audio.shape[-1] / self.sr
+        ref_text_byte_len = len(self.ref_text.encode("utf-8"))
+        self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
+        self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
+        self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)
 
     def _warm_up(self):
-        print("Warming up the model...")
-        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
-        audio, sr = torchaudio.load(ref_audio)
+        logger.info("Warming up the model...")
         gen_text = "Warm-up text for the model."
-
-    def generate_stream(self, text, play_steps_in_s=0.5):
-        """Generate audio in chunks and yield them in real-time."""
-        # Preprocess the reference audio and text
-        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
-
-        # Load reference audio
-        audio, sr = torchaudio.load(ref_audio)
-
-        # Run inference for the input text
-        audio_chunk, final_sample_rate, _ = infer_batch_process(
-            (audio, sr),
-            ref_text,
-            [text],
-            self.model,
-            self.vocoder,
-        )
-
-            yield packed_audio
-            return
-
-        for i in range(0, len(audio_chunk), chunk_size):
-            chunk = audio_chunk[i : i + chunk_size]
-
-            # Check if it's the final chunk
-            if i + chunk_size >= len(audio_chunk):
-                chunk = audio_chunk[i:]
-
-            # Send the chunk if it is not empty
-            if len(chunk) > 0:
-                packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
-                yield packed_audio
+        for _ in infer_batch_process(
+            (self.audio, self.sr),
+            self.ref_text,
+            [gen_text],
+            self.model,
+            self.vocoder,
+            progress=None,
+            device=self.device,
+            streaming=True,
+        ):
+            pass
+        logger.info("Warm-up completed.")
+
+    def generate_stream(self, text, conn):
+        text_batches = chunk_text(text, max_chars=self.max_chars)
+        if self.first_package:
+            text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
+            text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
+            self.first_package = False
+
+        audio_stream = infer_batch_process(
+            (self.audio, self.sr),
+            self.ref_text,
+            text_batches,
+            self.model,
+            self.vocoder,
+            progress=None,
+            device=self.device,
+            streaming=True,
+            chunk_size=2048,
+        )
+
+        # Reset the file writer thread
+        if self.file_writer_thread is not None:
+            self.file_writer_thread.stop()
+        self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
+        self.file_writer_thread.start()
+
+        for audio_chunk, _ in audio_stream:
+            if len(audio_chunk) > 0:
+                logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")
+
+                # Send audio chunk via socket
+                conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))
+
+                # Write to file asynchronously
+                self.file_writer_thread.add_chunk(audio_chunk)
+
+        logger.info("Finished sending audio stream.")
+        conn.sendall(b"END")  # Send end signal
+
+        # Ensure all audio data is written before exiting
+        self.file_writer_thread.stop()
 
 
-        while True:
-            # Receive data from the client
-            data = client_socket.recv(1024).decode("utf-8")
-            if not data:
-                break
-
-            text = data.strip()
-
-                client_socket.sendall(audio_chunk)
-
-        except Exception as inner_e:
-            print(f"Error during processing: {inner_e}")
-            traceback.print_exc()  # Print the full traceback to diagnose the issue
-            break
-    except Exception as e:
-        traceback.print_exc()
-    finally:
-        client_socket.close()
+def handle_client(conn, processor):
+    try:
+        with conn:
+            conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            while True:
+                data = conn.recv(1024)
+                if not data:
+                    processor.first_package = True
+                    break
+                data_str = data.decode("utf-8").strip()
+                logger.info(f"Received text: {data_str}")
+
+                try:
+                    processor.generate_stream(data_str, conn)
+                except Exception as inner_e:
+                    logger.error(f"Error during processing: {inner_e}")
+                    traceback.print_exc()
+                    break
+    except Exception as e:
+        logger.error(f"Error handling client: {e}")
+        traceback.print_exc()
 
 
 def start_server(host, port, processor):
-    client_handler = Thread(target=handle_client, args=(client_socket, processor))
-    client_handler.start()
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((host, port))
+        s.listen()
+        logger.info(f"Server started on {host}:{port}")
+        while True:
+            conn, addr = s.accept()
+            logger.info(f"Connected by {addr}")
+            handle_client(conn, processor)
 
 
 if __name__ == "__main__":
@@ -145,9 +216,14 @@ if __name__ == "__main__":
     parser.add_argument("--host", default="0.0.0.0")
     parser.add_argument("--port", default=9998)
 
+    parser.add_argument(
+        "--model",
+        default="F5TTS_v1_Base",
+        help="The model name, e.g. F5TTS_v1_Base",
+    )
     parser.add_argument(
         "--ckpt_file",
-        default=str(
+        default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
         help="Path to the model checkpoint file",
     )
     parser.add_argument(
@@ -175,6 +251,7 @@ if __name__ == "__main__":
     try:
        # Initialize the processor with the model and vocoder
        processor = TTSStreamingProcessor(
+            model=args.model,
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
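The rewritten server keeps the wire protocol simple: the client sends UTF-8 text, the server replies with a stream of native float32 PCM samples packed via `struct.pack(f"{len(chunk)}f", *chunk)` and ends with the literal bytes `b"END"`. A minimal sketch of decoding one received chunk (note that a real client should buffer, since TCP `recv` may split a chunk mid-sample; the bundled `socket_client.py` simply plays whatever arrives):

```python
# Sketch of decoding one response chunk from the streaming server above.
import struct


def decode_chunk(data: bytes):
    """Return float32 samples from one chunk, or None when the END marker arrives."""
    if data == b"END":
        return None
    n = len(data) // 4  # float32 -> 4 bytes per sample
    return struct.unpack(f"{n}f", data[: n * 4])


print(decode_chunk(struct.pack("3f", 0.0, 0.5, -0.5)))  # (0.0, 0.5, -0.5)
```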
src/f5_tts/train/README.md
CHANGED
@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
 accelerate config
 
 # .yaml files are under src/f5_tts/configs directory
-accelerate launch src/f5_tts/train/train.py --config-name
+accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
 
 # possible to overwrite accelerate and hydra config
-accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name
+accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
 ```
 
 ### 2. Finetuning practice
@@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1
 
 The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
 
-### 3.
+### 3. W&B Logging
 
 The `wandb/` dir will be created under path you run training/finetuning scripts.
 
@@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
 To turn on wandb logging, you can either:
 
 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
-2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/
+2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
 
 On Mac & Linux:
 
@@ -75,7 +75,7 @@ On Windows:
 ```
 set WANDB_API_KEY=<YOUR WANDB API KEY>
 ```
-Moreover, if you couldn't access
+Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
 
 ```
 export WANDB_MODE=offline
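The README hunks above set the W&B credentials from the shell. A hedged Python equivalent, useful when launching training from a wrapper script rather than a terminal session (set the variables before the training process starts):

```python
# Equivalent of the shell commands shown in the README hunks above.
import os

os.environ["WANDB_API_KEY"] = "<YOUR WANDB API KEY>"  # obtained from https://wandb.ai/authorize
# or, if W&B is unreachable, log offline and sync later:
os.environ["WANDB_MODE"] = "offline"
```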
src/f5_tts/train/datasets/prepare_csv_wavs.py
CHANGED
@@ -1,12 +1,17 @@
 import os
 import sys
+import signal
+import subprocess  # For invoking ffprobe
+import shutil
+import concurrent.futures
+import multiprocessing
+from contextlib import contextmanager
 
 sys.path.append(os.getcwd())
 
 import argparse
 import csv
 import json
-import shutil
 from importlib.resources import files
 from pathlib import Path
 
@@ -29,32 +34,157 @@ def is_csv_wavs_format(input_dataset_dir):
     return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
 
 
+# Configuration constants
+BATCH_SIZE = 100  # Batch size for text conversion
+MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free
+THREAD_NAME_PREFIX = "AudioProcessor"
+CHUNK_SIZE = 100  # Number of files to process per worker batch
+
+executor = None  # Global executor for cleanup
+
+
+@contextmanager
+def graceful_exit():
+    """Context manager for graceful shutdown on signals"""
+
+    def signal_handler(signum, frame):
+        print("\nReceived signal to terminate. Cleaning up...")
+        if executor is not None:
+            print("Shutting down executor...")
+            executor.shutdown(wait=False, cancel_futures=True)
+        sys.exit(1)
+
+    # Set up signal handlers
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        yield
+    finally:
+        if executor is not None:
+            executor.shutdown(wait=False)
+
+
+def process_audio_file(audio_path, text, polyphone):
+    """Process a single audio file by checking its existence and extracting duration."""
+    if not Path(audio_path).exists():
+        print(f"audio {audio_path} not found, skipping")
+        return None
+    try:
+        audio_duration = get_audio_duration(audio_path)
+        if audio_duration <= 0:
+            raise ValueError(f"Duration {audio_duration} is non-positive.")
+        return (audio_path, text, audio_duration)
+    except Exception as e:
+        print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
+        return None
+
+
+def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
+    """Convert a list of texts to pinyin in batches."""
+    converted_texts = []
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i : i + batch_size]
+        converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
+        converted_texts.extend(converted_batch)
+    return converted_texts
+
+
+def prepare_csv_wavs_dir(input_dir, num_workers=None):
+    global executor
     assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
     input_dir = Path(input_dir)
     metadata_path = input_dir / "metadata.csv"
     audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
 
-    sub_result, durations = [], []
-    vocab_set = set()
     polyphone = True
+    total_files = len(audio_path_text_pairs)
+
+    # Use provided worker count or calculate optimal number
+    worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
+    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
+
+    with graceful_exit():
+        # Initialize thread pool with optimized settings
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
+        ) as exec:
+            executor = exec
+            results = []
+
+            # Process files in chunks for better efficiency
+            for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
+                chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
+                # Submit futures in order
+                chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
+
+                # Iterate over futures in the original submission order to preserve ordering
+                for future in tqdm(
+                    chunk_futures,
+                    total=len(chunk),
+                    desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
+                ):
+                    try:
+                        result = future.result()
+                        if result is not None:
+                            results.append(result)
+                    except Exception as e:
+                        print(f"Error processing file: {e}")
+
+        executor = None
+
+    # Filter out failed results
+    processed = [res for res in results if res is not None]
+    if not processed:
+        raise RuntimeError("No valid audio files were processed!")
+
+    # Batch process text conversion
+    raw_texts = [item[1] for item in processed]
+    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
+
+    # Prepare final results
+    sub_result = []
+    durations = []
+    vocab_set = set()
+
+    for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
+        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
+        durations.append(duration)
+        vocab_set.update(list(conv_text))
 
     return sub_result, durations, vocab_set
 
 
-def get_audio_duration(audio_path):
+def get_audio_duration(audio_path, timeout=5):
+    """
+    Get the duration of an audio file in seconds using ffmpeg's ffprobe.
+    Falls back to torchaudio.load() if ffprobe fails.
+    """
+    try:
+        cmd = [
+            "ffprobe",
+            "-v",
+            "error",
+            "-show_entries",
+            "format=duration",
+            "-of",
+            "default=noprint_wrappers=1:nokey=1",
+            audio_path,
+        ]
+        result = subprocess.run(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
+        )
+        duration_str = result.stdout.strip()
+        if duration_str:
+            return float(duration_str)
+        raise ValueError("Empty duration string from ffprobe.")
+    except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
+        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
+        try:
+            audio, sample_rate = torchaudio.load(audio_path)
+            return audio.shape[1] / sample_rate
+        except Exception as e:
+            raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
 
 
 def read_audio_text_pairs(csv_file_path):
@@ -76,36 +206,27 @@ def read_audio_text_pairs(csv_file_path):
 
 def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
     out_dir = Path(out_dir)
-    # save preprocessed dataset to disk
     out_dir.mkdir(exist_ok=True, parents=True)
     print(f"\nSaving to {out_dir} ...")
 
-    # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
+    # Save dataset with improved batch size for better I/O performance
     raw_arrow_path = out_dir / "raw.arrow"
-    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=
+    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
 
+    # Save durations to JSON
     dur_json_path = out_dir / "duration.json"
     with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
-    # vocab
-    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
-    # if tokenizer == "pinyin":
-    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+    # Handle vocab file - write only once based on finetune flag
     voca_out_path = out_dir / "vocab.txt"
-    with open(voca_out_path.as_posix(), "w") as f:
-        for vocab in sorted(text_vocab_set):
-            f.write(vocab + "\n")
-
     if is_finetune:
         file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
         shutil.copy2(file_vocab_finetune, voca_out_path)
     else:
-        with open(voca_out_path, "w") as f:
+        with open(voca_out_path.as_posix(), "w") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")
 
@@ -115,24 +236,48 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
     print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
 
 
-def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
     if is_finetune:
         assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
-    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
     save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
 
 
 def cli():
+    try:
+        # Before processing, check if ffprobe is available.
+        if shutil.which("ffprobe") is None:
+            print(
+                "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
+            )
+
+        # Usage examples in help text
+        parser = argparse.ArgumentParser(
+            description="Prepare and save dataset.",
+            epilog="""
+Examples:
+    # For fine-tuning (default):
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
+
+    # For pre-training:
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
+
+    # With custom worker count:
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
+            """,
+        )
+        parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
+        parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
+        parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
+        parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
+        args = parser.parse_args()
+
+        prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
+    except KeyboardInterrupt:
+        print("\nOperation cancelled by user. Cleaning up...")
+        if executor is not None:
+            executor.shutdown(wait=False, cancel_futures=True)
+        sys.exit(1)
 
 
 if __name__ == "__main__":
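Besides the CLI shown in the epilog above, the updated module can be driven programmatically. A hedged sketch with placeholder paths (and assuming the package layout lets you import the script as a module; otherwise just run the CLI):

```python
# Programmatic use of the updated prepare_csv_wavs module; both paths are placeholders.
# The input directory is expected to contain metadata.csv plus a wavs/ subfolder.
from f5_tts.train.datasets.prepare_csv_wavs import prepare_and_save_set

prepare_and_save_set(
    "/data/my_dataset_raw",      # placeholder input dir (metadata.csv + wavs/)
    "/data/my_dataset_prepped",  # placeholder output dir (raw.arrow, duration.json, vocab.txt)
    is_finetune=True,            # copy the pretrained vocab instead of building a new one
    num_workers=4,
)
```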
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -1,12 +1,13 @@
 import argparse
 import os
 import shutil
+from importlib.resources import files
 
 from cached_path import cached_path
+
 from f5_tts.model import CFM, UNetT, DiT, Trainer
 from f5_tts.model.utils import get_tokenizer
 from f5_tts.model.dataset import load_dataset
-from importlib.resources import files
 
 
 # -------------------------- Dataset Settings --------------------------- #
@@ -20,19 +21,14 @@ mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
 
 # -------------------------- Argument Parsing --------------------------- #
 def parse_args():
-    # batch_size_per_gpu = 1000 settting for gpu 8GB
-    # batch_size_per_gpu = 1600 settting for gpu 12GB
-    # batch_size_per_gpu = 2000 settting for gpu 16GB
-    # batch_size_per_gpu = 3200 settting for gpu 24GB
-
-    # num_warmup_updates = 300 for 5000 sample about 10 hours
-
-    # change save_per_updates , last_per_steps change this value what you need ,
-
     parser = argparse.ArgumentParser(description="Train CFM Model")
 
     parser.add_argument(
-        "--exp_name",
+        "--exp_name",
+        type=str,
+        default="F5TTS_v1_Base",
+        choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
+        help="Experiment name",
     )
     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
     parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
@@ -44,9 +40,15 @@ def parse_args():
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
     parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup
-    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X
+    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
+    parser.add_argument(
+        "--keep_last_n_checkpoints",
+        type=int,
+        default=-1,
+        help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
+    )
+    parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
     parser.add_argument("--finetune", action="store_true", help="Use Finetune")
     parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
     parser.add_argument(
@@ -61,7 +63,7 @@ def parse_args():
     parser.add_argument(
         "--log_samples",
         action="store_true",
-        help="Log inferenced samples per ckpt save
+        help="Log inferenced samples per ckpt save updates",
     )
     parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
     parser.add_argument(
@@ -82,19 +84,54 @@ def main():
     checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
 
     # Model parameters based on experiment name
+
+    if args.exp_name == "F5TTS_v1_Base":
         wandb_resume_id = None
         model_cls = DiT
-        model_cfg = dict(
+        model_cfg = dict(
+            dim=1024,
+            depth=22,
+            heads=16,
+            ff_mult=2,
+            text_dim=512,
+            conv_layers=4,
+        )
+        if args.finetune:
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
+            else:
+                ckpt_path = args.pretrain
+
+    elif args.exp_name == "F5TTS_Base":
+        wandb_resume_id = None
+        model_cls = DiT
+        model_cfg = dict(
+            dim=1024,
+            depth=22,
+            heads=16,
+            ff_mult=2,
+            text_dim=512,
+            text_mask_padding=False,
+            conv_layers=4,
+            pe_attn_head=1,
+        )
         if args.finetune:
             if args.pretrain is None:
                 ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
             else:
                 ckpt_path = args.pretrain
+
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
-        model_cfg = dict(
+        model_cfg = dict(
+            dim=1024,
+            depth=24,
+            heads=16,
+            ff_mult=4,
+            text_mask_padding=False,
+            pe_attn_head=1,
+        )
         if args.finetune:
             if args.pretrain is None:
                 ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
@@ -105,12 +142,16 @@ def main():
     if not os.path.isdir(checkpoint_path):
         os.makedirs(checkpoint_path, exist_ok=True)
 
-    file_checkpoint = os.path.
+    file_checkpoint = os.path.basename(ckpt_path)
+    if not file_checkpoint.startswith("pretrained_"):  # Change: Add 'pretrained_' prefix to copied model
+        file_checkpoint = "pretrained_" + file_checkpoint
+    file_checkpoint = os.path.join(checkpoint_path, file_checkpoint)
     if not os.path.isfile(file_checkpoint):
         shutil.copy2(ckpt_path, file_checkpoint)
         print("copy checkpoint for finetune")
 
     # Use the tokenizer and tokenizer_path provided in the command line arguments
+
     tokenizer = args.tokenizer
     if tokenizer == "custom":
         if not args.tokenizer_path:
@@ -145,8 +186,9 @@ def main():
         args.learning_rate,
         num_warmup_updates=args.num_warmup_updates,
         save_per_updates=args.save_per_updates,
+        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
         checkpoint_path=checkpoint_path,
+        batch_size_per_gpu=args.batch_size_per_gpu,
         batch_size_type=args.batch_size_type,
         max_samples=args.max_samples,
         grad_accumulation_steps=args.grad_accumulation_steps,
@@ -156,7 +198,7 @@ def main():
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
         log_samples=args.log_samples,
+        last_per_updates=args.last_per_updates,
         bnb_optimizer=args.bnb_optimizer,
     )
 
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -1,36 +1,36 @@
|
|
1 |
-
import threading
|
2 |
-
import queue
|
3 |
-
import re
|
4 |
-
|
5 |
import gc
|
6 |
import json
|
|
|
7 |
import os
|
8 |
import platform
|
9 |
import psutil
|
|
|
10 |
import random
|
|
|
11 |
import signal
|
12 |
import shutil
|
13 |
import subprocess
|
14 |
import sys
|
15 |
import tempfile
|
|
|
16 |
import time
|
17 |
from glob import glob
|
|
|
|
|
18 |
|
19 |
import click
|
20 |
import gradio as gr
|
21 |
import librosa
|
22 |
-
import numpy as np
|
23 |
import torch
|
24 |
import torchaudio
|
|
|
25 |
from datasets import Dataset as Dataset_
|
26 |
from datasets.arrow_writer import ArrowWriter
|
27 |
-
from safetensors.torch import save_file
|
28 |
-
|
29 |
-
from cached_path import cached_path
|
30 |
from f5_tts.api import F5TTS
|
31 |
from f5_tts.model.utils import convert_char_to_pinyin
|
32 |
from f5_tts.infer.utils_infer import transcribe
|
33 |
-
from importlib.resources import files
|
34 |
|
35 |
|
36 |
training_process = None
|
@@ -46,7 +46,15 @@ path_data = str(files("f5_tts").joinpath("../../data"))
|
|
46 |
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
|
47 |
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
|
48 |
|
49 |
-
device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
# Save settings from a JSON file
|
@@ -62,7 +70,8 @@ def save_settings(
|
|
62 |
epochs,
|
63 |
num_warmup_updates,
|
64 |
save_per_updates,
|
65 |
-
|
|
|
66 |
finetune,
|
67 |
file_checkpoint_train,
|
68 |
tokenizer_type,
|
@@ -86,7 +95,8 @@ def save_settings(
|
|
86 |
"epochs": epochs,
|
87 |
"num_warmup_updates": num_warmup_updates,
|
88 |
"save_per_updates": save_per_updates,
|
89 |
-
"
|
|
|
90 |
"finetune": finetune,
|
91 |
"file_checkpoint_train": file_checkpoint_train,
|
92 |
"tokenizer_type": tokenizer_type,
|
@@ -106,73 +116,56 @@ def load_settings(project_name):
|
|
106 |
path_project = os.path.join(path_project_ckpts, project_name)
|
107 |
file_setting = os.path.join(path_project, "setting.json")
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
settings["max_grad_norm"],
|
138 |
-
settings["epochs"],
|
139 |
-
settings["num_warmup_updates"],
|
140 |
-
settings["save_per_updates"],
|
141 |
-
settings["last_per_steps"],
|
142 |
-
settings["finetune"],
|
143 |
-
settings["file_checkpoint_train"],
|
144 |
-
settings["tokenizer_type"],
|
145 |
-
settings["tokenizer_file"],
|
146 |
-
settings["mixed_precision"],
|
147 |
-
settings["logger"],
|
148 |
-
settings["bnb_optimizer"],
|
149 |
-
)
|
150 |
|
151 |
-
|
152 |
-
settings = json.load(f)
|
153 |
-
if "logger" not in settings:
|
154 |
-
settings["logger"] = "wandb"
|
155 |
-
if "bnb_optimizer" not in settings:
|
156 |
-
settings["bnb_optimizer"] = False
|
157 |
return (
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
176 |
)
|
177 |
|
178 |
|
@@ -369,17 +362,18 @@ def terminate_process(pid):
|
|
369 |
|
370 |
def start_training(
|
371 |
dataset_name="",
|
372 |
-
exp_name="
|
373 |
-
learning_rate=1e-
|
374 |
-
batch_size_per_gpu=
|
375 |
-
batch_size_type="
|
376 |
max_samples=64,
|
377 |
-
grad_accumulation_steps=
|
378 |
max_grad_norm=1.0,
|
379 |
-
epochs=
|
380 |
-
num_warmup_updates=
|
381 |
-
save_per_updates=
|
382 |
-
|
|
|
383 |
finetune=True,
|
384 |
file_checkpoint_train="",
|
385 |
tokenizer_type="pinyin",
|
@@ -438,18 +432,19 @@ def start_training(
|
|
438 |
fp16 = ""
|
439 |
|
440 |
cmd = (
|
441 |
-
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}
|
442 |
-
f"--learning_rate {learning_rate}
|
443 |
-
f"--batch_size_per_gpu {batch_size_per_gpu}
|
444 |
-
f"--batch_size_type {batch_size_type}
|
445 |
-
f"--max_samples {max_samples}
|
446 |
-
f"--grad_accumulation_steps {grad_accumulation_steps}
|
447 |
-
f"--max_grad_norm {max_grad_norm}
|
448 |
-
f"--epochs {epochs}
|
449 |
-
f"--num_warmup_updates {num_warmup_updates}
|
450 |
-
f"--save_per_updates {save_per_updates}
|
451 |
-
f"--
|
452 |
-
f"--
|
|
|
453 |
)
|
454 |
|
455 |
if finetune:
|
@@ -482,7 +477,8 @@ def start_training(
|
|
482 |
epochs,
|
483 |
num_warmup_updates,
|
484 |
save_per_updates,
|
485 |
-
|
|
|
486 |
finetune,
|
487 |
file_checkpoint_train,
|
488 |
tokenizer_type,
|
@@ -548,7 +544,7 @@ def start_training(
|
|
548 |
output = stdout_queue.get_nowait()
|
549 |
print(output, end="")
|
550 |
match = re.search(
|
551 |
-
r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+),
|
552 |
)
|
553 |
if match:
|
554 |
current_epoch = match.group(1)
|
@@ -556,13 +552,13 @@ def start_training(
|
|
556 |
percent_complete = match.group(3)
|
557 |
elapsed_time = match.group(4)
|
558 |
loss = match.group(5)
|
559 |
-
|
560 |
message = (
|
561 |
f"Epoch: {current_epoch}/{total_epochs}, "
|
562 |
f"Progress: {percent_complete}%, "
|
563 |
f"Elapsed Time: {elapsed_time}, "
|
564 |
f"Loss: {loss}, "
|
565 |
-
f"
|
566 |
)
|
567 |
yield message, gr.update(interactive=False), gr.update(interactive=True)
|
568 |
elif output.strip():
|
@@ -801,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
|
801 |
print(f"Error processing {file_audio}: {e}")
|
802 |
continue
|
803 |
|
804 |
-
if duration < 1 or duration >
|
805 |
-
if duration >
|
806 |
-
error_files.append([file_audio, "duration >
|
807 |
if duration < 1:
|
808 |
error_files.append([file_audio, "duration < 1 sec "])
|
809 |
continue
|
810 |
if len(text) < 3:
|
811 |
-
error_files.append([file_audio, "very
|
812 |
continue
|
813 |
|
814 |
text = clear_text(text)
|
@@ -875,40 +871,37 @@ def check_user(value):
|
|
875 |
|
876 |
def calculate_train(
|
877 |
name_project,
|
|
|
|
|
|
|
878 |
batch_size_type,
|
879 |
max_samples,
|
880 |
-
learning_rate,
|
881 |
num_warmup_updates,
|
882 |
-
save_per_updates,
|
883 |
-
last_per_steps,
|
884 |
finetune,
|
885 |
):
|
886 |
path_project = os.path.join(path_data, name_project)
|
887 |
-
|
888 |
|
889 |
-
|
|
|
|
|
|
|
890 |
return (
|
891 |
-
|
|
|
|
|
892 |
max_samples,
|
893 |
num_warmup_updates,
|
894 |
-
save_per_updates,
|
895 |
-
last_per_steps,
|
896 |
"project not found !",
|
897 |
-
learning_rate,
|
898 |
)
|
899 |
|
900 |
-
with open(
|
901 |
data = json.load(file)
|
902 |
|
903 |
duration_list = data["duration"]
|
904 |
-
|
905 |
-
|
906 |
-
|
907 |
-
# if torch.cuda.is_available():
|
908 |
-
# gpu_properties = torch.cuda.get_device_properties(0)
|
909 |
-
# total_memory = gpu_properties.total_memory / (1024**3)
|
910 |
-
# elif torch.backends.mps.is_available():
|
911 |
-
# total_memory = psutil.virtual_memory().available / (1024**3)
|
912 |
|
913 |
if torch.cuda.is_available():
|
914 |
gpu_count = torch.cuda.device_count()
|
@@ -916,57 +909,39 @@ def calculate_train(
|
|
916 |
for i in range(gpu_count):
|
917 |
gpu_properties = torch.cuda.get_device_properties(i)
|
918 |
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
919 |
-
|
|
|
|
|
|
|
|
|
|
|
920 |
elif torch.backends.mps.is_available():
|
921 |
gpu_count = 1
|
922 |
total_memory = psutil.virtual_memory().available / (1024**3)
|
923 |
|
|
|
|
|
|
|
924 |
if batch_size_type == "frame":
|
925 |
-
|
926 |
-
|
927 |
-
batch_size_per_gpu = int(
|
928 |
-
else:
|
929 |
-
batch_size_per_gpu = int(total_memory / 8)
|
930 |
-
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
|
931 |
-
batch = batch_size_per_gpu
|
932 |
|
933 |
-
if
|
934 |
-
|
935 |
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
-
|
945 |
-
|
946 |
-
|
947 |
-
|
948 |
-
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
949 |
-
if last_per_steps <= 0:
|
950 |
-
last_per_steps = 2
|
951 |
-
|
952 |
-
total_hours = hours
|
953 |
-
mel_hop_length = 256
|
954 |
-
mel_sampling_rate = 24000
|
955 |
-
|
956 |
-
# target
|
957 |
-
wanted_max_updates = 1000000
|
958 |
-
|
959 |
-
# train params
|
960 |
-
gpus = gpu_count
|
961 |
-
frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
|
962 |
-
grad_accum = 1
|
963 |
-
|
964 |
-
# intermediate
|
965 |
-
mini_batch_frames = frames_per_gpu * grad_accum * gpus
|
966 |
-
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
|
967 |
-
updates_per_epoch = total_hours / mini_batch_hours
|
968 |
-
# steps_per_epoch = updates_per_epoch * grad_accum
|
969 |
-
epochs = wanted_max_updates / updates_per_epoch
|
970 |
|
971 |
if finetune:
|
972 |
learning_rate = 1e-5
|
@@ -974,20 +949,18 @@ def calculate_train(
|
|
974 |
learning_rate = 7.5e-5
|
975 |
|
976 |
return (
|
|
|
|
|
977 |
batch_size_per_gpu,
|
978 |
max_samples,
|
979 |
num_warmup_updates,
|
980 |
-
|
981 |
-
last_per_steps,
|
982 |
-
samples,
|
983 |
-
learning_rate,
|
984 |
-
int(epochs),
|
985 |
)
|
986 |
|
987 |
|
988 |
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
|
989 |
try:
|
990 |
-
checkpoint = torch.load(checkpoint_path)
|
991 |
print("Original Checkpoint Keys:", checkpoint.keys())
|
992 |
|
993 |
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
|
@@ -1018,7 +991,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
|
|
1018 |
torch.backends.cudnn.deterministic = True
|
1019 |
torch.backends.cudnn.benchmark = False
|
1020 |
|
1021 |
-
|
|
|
|
|
|
|
|
|
1022 |
|
1023 |
ema_sd = ckpt.get("ema_model_state_dict", {})
|
1024 |
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
|
@@ -1086,9 +1063,11 @@ def vocab_extend(project_name, symbols, model_type):
|
|
1086 |
with open(file_vocab_project, "w", encoding="utf-8") as f:
|
1087 |
f.write("\n".join(vocab))
|
1088 |
|
1089 |
-
if model_type == "
|
|
|
|
|
1090 |
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
1091 |
-
|
1092 |
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
1093 |
|
1094 |
vocab_size_new = len(miss_symbols)
|
@@ -1096,7 +1075,9 @@ def vocab_extend(project_name, symbols, model_type):
|
|
1096 |
dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
|
1097 |
new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
|
1098 |
os.makedirs(new_ckpt_path, exist_ok=True)
|
1099 |
-
|
|
|
|
|
1100 |
|
1101 |
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
|
1102 |
|
@@ -1226,21 +1207,21 @@ def infer(
|
|
1226 |
vocab_file = os.path.join(path_data, project, "vocab.txt")
|
1227 |
|
1228 |
tts_api = F5TTS(
|
1229 |
-
|
1230 |
)
|
1231 |
|
1232 |
print("update >> ", device_test, file_checkpoint, use_ema)
|
1233 |
|
1234 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
1235 |
tts_api.infer(
|
1236 |
-
gen_text=gen_text.lower().strip(),
|
1237 |
-
ref_text=ref_text.lower().strip(),
|
1238 |
ref_file=ref_audio,
|
|
|
|
|
1239 |
nfe_step=nfe_step,
|
1240 |
-
file_wave=f.name,
|
1241 |
speed=speed,
|
1242 |
-
seed=seed,
|
1243 |
remove_silence=remove_silence,
|
|
|
|
|
1244 |
)
|
1245 |
return f.name, tts_api.device, str(tts_api.seed)
|
1246 |
|
@@ -1256,12 +1237,22 @@ def get_checkpoints_project(project_name, is_gradio=True):
|
|
1256 |
|
1257 |
if os.path.isdir(path_project_ckpts):
|
1258 |
files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
|
1259 |
-
|
1260 |
-
|
1261 |
-
|
1262 |
-
|
1263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1264 |
         )
     else:
         files_checkpoints = []

@@ -1312,7 +1303,21 @@ def get_gpu_stats():
             f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
             f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
         )
 
     elif torch.backends.mps.is_available():
         gpu_count = 1
         gpu_stats += "MPS GPU\n"
@@ -1375,14 +1380,14 @@ get_audio_select(file_sample)
 with gr.Blocks() as app:
     gr.Markdown(
         """
-        #
 
-        This is a local web UI for F5 TTS
 
        * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
        * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
 
-        The checkpoints support English and Chinese.
 
        For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
        """
@@ -1459,7 +1464,9 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
 Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
 ```""")
 
-            exp_name_extend = gr.Radio(
 
             with gr.Row():
                 txt_extend = gr.Textbox(
@@ -1528,9 +1535,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
                 fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
             )
 
-        with gr.TabItem("Train
             gr.Markdown("""```plaintext
-The auto-setting is still experimental.
 If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
 ```""")
             with gr.Row():
@@ -1544,11 +1551,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
 
             with gr.Row():
-                exp_name = gr.Radio(
                 learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
 
             with gr.Row():
-                batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=
                 max_samples = gr.Number(label="Max Samples", value=64)
 
             with gr.Row():
@@ -1556,59 +1565,70 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
 
             with gr.Row():
-                epochs = gr.Number(label="Epochs", value=
-                num_warmup_updates = gr.Number(label="Warmup Updates", value=
 
             with gr.Row():
-                save_per_updates = gr.Number(label="Save per Updates", value=
-
 
             with gr.Row():
                 ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
-                mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="
                 cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
                 start_button = gr.Button("Start Training")
                 stop_button = gr.Button("Stop Training", interactive=False)
 
             if projects_selelect is not None:
                 (
-
                 ) = load_settings(projects_selelect)
-
 
             ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
             txt_info_train = gr.Text(label="Info", value="")
@@ -1659,7 +1679,8 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 epochs,
                 num_warmup_updates,
                 save_per_updates,
-
                 ch_finetune,
                 file_checkpoint_train,
                 tokenizer_type,
@@ -1677,23 +1698,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             fn=calculate_train,
             inputs=[
                 cm_project,
                 batch_size_type,
                 max_samples,
-                learning_rate,
                 num_warmup_updates,
-                save_per_updates,
-                last_per_steps,
                 ch_finetune,
             ],
             outputs=[
                 batch_size_per_gpu,
                 max_samples,
                 num_warmup_updates,
-                save_per_updates,
-                last_per_steps,
                 lb_samples,
-                learning_rate,
-                epochs,
             ],
         )
 
@@ -1713,15 +1732,16 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 epochs,
                 num_warmup_updates,
                 save_per_updates,
-
                 ch_finetune,
                 file_checkpoint_train,
                 tokenizer_type,
                 tokenizer_file,
                 mixed_precision,
                 cd_logger,
             ]
-
             return output_components
 
         outputs = setup_load_settings()
@@ -1742,7 +1762,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         gr.Markdown("""```plaintext
 SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
 ```""")
-            exp_name = gr.Radio(
             list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
 
             with gr.Row():
@@ -1796,9 +1818,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
             bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
             cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
 
-        with gr.TabItem("
             gr.Markdown("""```plaintext
-Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or
 ```""")
             txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
             txt_path_checkpoint_small = gr.Text(label="Path to Output:")
   1   import gc
   2   import json
   3 + import numpy as np
   4   import os
   5   import platform
   6   import psutil
   7 + import queue
   8   import random
   9 + import re
  10   import signal
  11   import shutil
  12   import subprocess
  13   import sys
  14   import tempfile
  15 + import threading
  16   import time
  17   from glob import glob
  18 + from importlib.resources import files
  19 + from scipy.io import wavfile
  20
  21   import click
  22   import gradio as gr
  23   import librosa
  24   import torch
  25   import torchaudio
  26 + from cached_path import cached_path
  27   from datasets import Dataset as Dataset_
  28   from datasets.arrow_writer import ArrowWriter
  29 + from safetensors.torch import load_file, save_file
  30 +
  31   from f5_tts.api import F5TTS
  32   from f5_tts.model.utils import convert_char_to_pinyin
  33   from f5_tts.infer.utils_infer import transcribe
  34
  35
  36   training_process = None

  46   path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
  47   file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
  48
  49 + device = (
  50 +     "cuda"
  51 +     if torch.cuda.is_available()
  52 +     else "xpu"
  53 +     if torch.xpu.is_available()
  54 +     else "mps"
  55 +     if torch.backends.mps.is_available()
  56 +     else "cpu"
  57 + )
  58
  59
  60   # Save settings from a JSON file

  70       epochs,
  71       num_warmup_updates,
  72       save_per_updates,
  73 +     keep_last_n_checkpoints,
  74 +     last_per_updates,
  75       finetune,
  76       file_checkpoint_train,
  77       tokenizer_type,

  95           "epochs": epochs,
  96           "num_warmup_updates": num_warmup_updates,
  97           "save_per_updates": save_per_updates,
  98 +         "keep_last_n_checkpoints": keep_last_n_checkpoints,
  99 +         "last_per_updates": last_per_updates,
 100           "finetune": finetune,
 101           "file_checkpoint_train": file_checkpoint_train,
 102           "tokenizer_type": tokenizer_type,

 116       path_project = os.path.join(path_project_ckpts, project_name)
 117       file_setting = os.path.join(path_project, "setting.json")
 118
 119 +     # Default settings
 120 +     default_settings = {
 121 +         "exp_name": "F5TTS_v1_Base",
 122 +         "learning_rate": 1e-5,
 123 +         "batch_size_per_gpu": 1,
 124 +         "batch_size_type": "sample",
 125 +         "max_samples": 64,
 126 +         "grad_accumulation_steps": 4,
 127 +         "max_grad_norm": 1,
 128 +         "epochs": 100,
 129 +         "num_warmup_updates": 100,
 130 +         "save_per_updates": 500,
 131 +         "keep_last_n_checkpoints": -1,
 132 +         "last_per_updates": 100,
 133 +         "finetune": True,
 134 +         "file_checkpoint_train": "",
 135 +         "tokenizer_type": "pinyin",
 136 +         "tokenizer_file": "",
 137 +         "mixed_precision": "none",
 138 +         "logger": "wandb",
 139 +         "bnb_optimizer": False,
 140 +     }
 141 +
 142 +     # Load settings from file if it exists
 143 +     if os.path.isfile(file_setting):
 144 +         with open(file_setting, "r") as f:
 145 +             file_settings = json.load(f)
 146 +         default_settings.update(file_settings)
 147
 148 +     # Return as a tuple in the correct order
 149       return (
 150 +         default_settings["exp_name"],
 151 +         default_settings["learning_rate"],
 152 +         default_settings["batch_size_per_gpu"],
 153 +         default_settings["batch_size_type"],
 154 +         default_settings["max_samples"],
 155 +         default_settings["grad_accumulation_steps"],
 156 +         default_settings["max_grad_norm"],
 157 +         default_settings["epochs"],
 158 +         default_settings["num_warmup_updates"],
 159 +         default_settings["save_per_updates"],
 160 +         default_settings["keep_last_n_checkpoints"],
 161 +         default_settings["last_per_updates"],
 162 +         default_settings["finetune"],
 163 +         default_settings["file_checkpoint_train"],
 164 +         default_settings["tokenizer_type"],
 165 +         default_settings["tokenizer_file"],
 166 +         default_settings["mixed_precision"],
 167 +         default_settings["logger"],
 168 +         default_settings["bnb_optimizer"],
 169       )
 170
 171
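A point worth noting in the updated `load_settings` above: every field (including the new `keep_last_n_checkpoints` and `last_per_updates`) gets a default first, and whatever subset exists in the project's `setting.json` is then overlaid, so settings files written by older versions of the app still load. A minimal sketch of that pattern, with a hypothetical helper name, example keys, and an illustrative path:

```python
# Sketch only, not the app's code: defaults first, then overlay the saved JSON subset.
import json
import os

defaults = {"learning_rate": 1e-5, "epochs": 100, "keep_last_n_checkpoints": -1}

def load_project_settings(path):  # hypothetical helper name
    settings = dict(defaults)
    if os.path.isfile(path):
        with open(path, "r") as f:
            settings.update(json.load(f))  # older files missing newer keys still load
    return settings

print(load_project_settings("ckpts/my_project/setting.json"))  # illustrative path
```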
 362
 363   def start_training(
 364       dataset_name="",
 365 +     exp_name="F5TTS_v1_Base",
 366 +     learning_rate=1e-5,
 367 +     batch_size_per_gpu=1,
 368 +     batch_size_type="sample",
 369       max_samples=64,
 370 +     grad_accumulation_steps=4,
 371       max_grad_norm=1.0,
 372 +     epochs=100,
 373 +     num_warmup_updates=100,
 374 +     save_per_updates=500,
 375 +     keep_last_n_checkpoints=-1,
 376 +     last_per_updates=100,
 377       finetune=True,
 378       file_checkpoint_train="",
 379       tokenizer_type="pinyin",

 432       fp16 = ""
 433
 434       cmd = (
 435 +         f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
 436 +         f" --learning_rate {learning_rate}"
 437 +         f" --batch_size_per_gpu {batch_size_per_gpu}"
 438 +         f" --batch_size_type {batch_size_type}"
 439 +         f" --max_samples {max_samples}"
 440 +         f" --grad_accumulation_steps {grad_accumulation_steps}"
 441 +         f" --max_grad_norm {max_grad_norm}"
 442 +         f" --epochs {epochs}"
 443 +         f" --num_warmup_updates {num_warmup_updates}"
 444 +         f" --save_per_updates {save_per_updates}"
 445 +         f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
 446 +         f" --last_per_updates {last_per_updates}"
 447 +         f" --dataset_name {dataset_name}"
 448       )
 449
 450       if finetune:

 477           epochs,
 478           num_warmup_updates,
 479           save_per_updates,
 480 +         keep_last_n_checkpoints,
 481 +         last_per_updates,
 482           finetune,
 483           file_checkpoint_train,
 484           tokenizer_type,

 544                   output = stdout_queue.get_nowait()
 545                   print(output, end="")
 546                   match = re.search(
 547 +                     r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", output
 548                   )
 549                   if match:
 550                       current_epoch = match.group(1)
 552                       percent_complete = match.group(3)
 553                       elapsed_time = match.group(4)
 554                       loss = match.group(5)
 555 +                     current_update = match.group(6)
 556                       message = (
 557                           f"Epoch: {current_epoch}/{total_epochs}, "
 558                           f"Progress: {percent_complete}%, "
 559                           f"Elapsed Time: {elapsed_time}, "
 560                           f"Loss: {loss}, "
 561 +                         f"Update: {current_update}"
 562                       )
 563                       yield message, gr.update(interactive=False), gr.update(interactive=True)
 564                   elif output.strip():
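The regex added at line 547 now also captures a sixth group, the global update count, from the trainer's progress line. A quick self-contained check of the pattern (the sample log line below is illustrative only, not actual trainer output):

```python
import re

# Illustrative progress line in the shape the pattern expects; real output may differ.
sample = "Epoch 2/100:  5%|#5  [00:42<13:10, loss=0.8123, update=350]"
pattern = r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)"

m = re.search(pattern, sample)
if m:
    epoch, total, percent, elapsed, loss, update = m.groups()
    print(epoch, total, percent, elapsed, loss, update)  # 2 100 5 00:42 0.8123 350
```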
 797                   print(f"Error processing {file_audio}: {e}")
 798                   continue
 799
 800 +             if duration < 1 or duration > 30:
 801 +                 if duration > 30:
 802 +                     error_files.append([file_audio, "duration > 30 sec"])
 803                   if duration < 1:
 804                       error_files.append([file_audio, "duration < 1 sec "])
 805                   continue
 806               if len(text) < 3:
 807 +                 error_files.append([file_audio, "very short text length 3"])
 808                   continue
 809
 810               text = clear_text(text)

 871
 872   def calculate_train(
 873       name_project,
 874 +     epochs,
 875 +     learning_rate,
 876 +     batch_size_per_gpu,
 877       batch_size_type,
 878       max_samples,
 879       num_warmup_updates,
 880       finetune,
 881   ):
 882       path_project = os.path.join(path_data, name_project)
 883 +     file_duration = os.path.join(path_project, "duration.json")
 884
 885 +     hop_length = 256
 886 +     sampling_rate = 24000
 887 +
 888 +     if not os.path.isfile(file_duration):
 889           return (
 890 +             epochs,
 891 +             learning_rate,
 892 +             batch_size_per_gpu,
 893               max_samples,
 894               num_warmup_updates,
 895               "project not found !",
 896           )
 897
 898 +     with open(file_duration, "r") as file:
 899           data = json.load(file)
 900
 901       duration_list = data["duration"]
 902 +     max_sample_length = max(duration_list) * sampling_rate / hop_length
 903 +     total_samples = len(duration_list)
 904 +     total_duration = sum(duration_list)
 905
 906       if torch.cuda.is_available():
 907           gpu_count = torch.cuda.device_count()
 909           for i in range(gpu_count):
 910               gpu_properties = torch.cuda.get_device_properties(i)
 911               total_memory += gpu_properties.total_memory / (1024**3)  # in GB
 912 +     elif torch.xpu.is_available():
 913 +         gpu_count = torch.xpu.device_count()
 914 +         total_memory = 0
 915 +         for i in range(gpu_count):
 916 +             gpu_properties = torch.xpu.get_device_properties(i)
 917 +             total_memory += gpu_properties.total_memory / (1024**3)
 918       elif torch.backends.mps.is_available():
 919           gpu_count = 1
 920           total_memory = psutil.virtual_memory().available / (1024**3)
 921
 922 +     avg_gpu_memory = total_memory / gpu_count
 923 +
 924 +     # rough estimate of batch size
 925       if batch_size_type == "frame":
 926 +         batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
 927 +     elif batch_size_type == "sample":
 928 +         batch_size_per_gpu = int(200 / (total_duration / total_samples))
 929
 930 +     if total_samples < 64:
 931 +         max_samples = int(total_samples * 0.25)
 932
 933 +     num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
 934 +
 935 +     # take 1.2M updates as the maximum
 936 +     max_updates = 1200000
 937 +
 938 +     if batch_size_type == "frame":
 939 +         mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
 940 +         updates_per_epoch = total_duration / mini_batch_duration
 941 +     elif batch_size_type == "sample":
 942 +         updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
 943 +
 944 +     epochs = int(max_updates / updates_per_epoch)

 945
 946       if finetune:
 947           learning_rate = 1e-5
 949           learning_rate = 7.5e-5
 950
 951       return (
 952 +         epochs,
 953 +         learning_rate,
 954           batch_size_per_gpu,
 955           max_samples,
 956           num_warmup_updates,
 957 +         total_samples,
 958       )
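To make the auto-setting heuristic above concrete, here is the same arithmetic run on a made-up dataset (the numbers are purely illustrative): 1,000 clips totalling 2,000 s, longest clip 15 s, one 24 GB GPU, frame-based batching.

```python
# Worked example mirroring the formulas in calculate_train above; all inputs are hypothetical.
hop_length, sampling_rate = 256, 24000
total_samples, total_duration = 1000, 2000.0            # clips, seconds (made up)
max_sample_length = 15 * sampling_rate / hop_length      # longest clip 15 s -> 1406 frames
avg_gpu_memory, gpu_count = 24.0, 1                      # single 24 GB GPU

batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))  # 9728 frames
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate         # ~103.8 s of audio per step
updates_per_epoch = total_duration / mini_batch_duration                                  # ~19.3 updates
epochs = int(1_200_000 / updates_per_epoch)                                               # ~62000 epochs (capped by max updates)
num_warmup_updates = max(100, int(total_samples * 0.05))                                  # 100
print(batch_size_per_gpu, epochs, num_warmup_updates)
```

The very large epoch count is expected: the cap is expressed in total updates, which is why the Train Model tab advises setting a large epoch value when unsure.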
 961   def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
 962       try:
 963 +         checkpoint = torch.load(checkpoint_path, weights_only=True)
 964           print("Original Checkpoint Keys:", checkpoint.keys())
 965
 966           ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)

 991       torch.backends.cudnn.deterministic = True
 992       torch.backends.cudnn.benchmark = False
 993
 994 +     if ckpt_path.endswith(".safetensors"):
 995 +         ckpt = load_file(ckpt_path, device="cpu")
 996 +         ckpt = {"ema_model_state_dict": ckpt}
 997 +     elif ckpt_path.endswith(".pt"):
 998 +         ckpt = torch.load(ckpt_path, map_location="cpu")
 999
1000       ema_sd = ckpt.get("ema_model_state_dict", {})
1001       embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"

1063       with open(file_vocab_project, "w", encoding="utf-8") as f:
1064           f.write("\n".join(vocab))
1065
1066 +     if model_type == "F5TTS_v1_Base":
1067 +         ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
1068 +     elif model_type == "F5TTS_Base":
1069           ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
1070 +     elif model_type == "E2TTS_Base":
1071           ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
1072
1073       vocab_size_new = len(miss_symbols)

1075       dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
1076       new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
1077       os.makedirs(new_ckpt_path, exist_ok=True)
1078 +
1079 +     # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
1080 +     new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
1081
1082       size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
1083

1207       vocab_file = os.path.join(path_data, project, "vocab.txt")
1208
1209       tts_api = F5TTS(
1210 +         model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
1211       )
1212
1213       print("update >> ", device_test, file_checkpoint, use_ema)
1214
1215       with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1216           tts_api.infer(
1217               ref_file=ref_audio,
1218 +             ref_text=ref_text.lower().strip(),
1219 +             gen_text=gen_text.lower().strip(),
1220               nfe_step=nfe_step,
1221               speed=speed,
1222               remove_silence=remove_silence,
1223 +             file_wave=f.name,
1224 +             seed=seed,
1225           )
1226       return f.name, tts_api.device, str(tts_api.seed)
1227

1237
1238       if os.path.isdir(path_project_ckpts):
1239           files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
1240 +         # Separate pretrained and regular checkpoints
1241 +         pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
1242 +         regular_checkpoints = [
1243 +             f
1244 +             for f in files_checkpoints
1245 +             if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
1246 +         ]
1247 +         last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
1248 +
1249 +         # Sort regular checkpoints by number
1250 +         regular_checkpoints = sorted(
1251 +             regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
1252           )
1253 +
1254 +         # Combine in order: pretrained, regular, last
1255 +         files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
1256       else:
1257           files_checkpoints = []
1258

1303                   f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
1304                   f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
1305               )
1306 +     elif torch.xpu.is_available():
1307 +         gpu_count = torch.xpu.device_count()
1308 +         for i in range(gpu_count):
1309 +             gpu_name = torch.xpu.get_device_name(i)
1310 +             gpu_properties = torch.xpu.get_device_properties(i)
1311 +             total_memory = gpu_properties.total_memory / (1024**3)  # in GB
1312 +             allocated_memory = torch.xpu.memory_allocated(i) / (1024**2)  # in MB
1313 +             reserved_memory = torch.xpu.memory_reserved(i) / (1024**2)  # in MB
1314
1315 +             gpu_stats += (
1316 +                 f"GPU {i} Name: {gpu_name}\n"
1317 +                 f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
1318 +                 f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
1319 +                 f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
1320 +             )
1321       elif torch.backends.mps.is_available():
1322           gpu_count = 1
1323           gpu_stats += "MPS GPU\n"
1380   with gr.Blocks() as app:
1381       gr.Markdown(
1382           """
1383 + # F5 TTS Automatic Finetune
1384
1385 + This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models:
1386
1387   * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1388   * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1389
1390 + The pretrained checkpoints support English and Chinese.
1391
1392   For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
1393   """

1464   Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1465   ```""")
1466
1467 +             exp_name_extend = gr.Radio(
1468 +                 label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1469 +             )
1470
1471               with gr.Row():
1472                   txt_extend = gr.Textbox(

1535               fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1536           )
1537
1538 +         with gr.TabItem("Train Model"):
1539               gr.Markdown("""```plaintext
1540 + The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space.
1541   If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1542   ```""")
1543               with gr.Row():

1551               file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1552
1553               with gr.Row():
1554 +                 exp_name = gr.Radio(
1555 +                     label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1556 +                 )
1557                   learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1558
1559               with gr.Row():
1560 +                 batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
1561                   max_samples = gr.Number(label="Max Samples", value=64)
1562
1563               with gr.Row():

1565                   max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1566
1567               with gr.Row():
1568 +                 epochs = gr.Number(label="Epochs", value=100)
1569 +                 num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
1570
1571               with gr.Row():
1572 +                 save_per_updates = gr.Number(label="Save per Updates", value=500)
1573 +                 keep_last_n_checkpoints = gr.Number(
1574 +                     label="Keep Last N Checkpoints",
1575 +                     value=-1,
1576 +                     step=1,
1577 +                     precision=0,
1578 +                     info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
1579 +                 )
1580 +                 last_per_updates = gr.Number(label="Last per Updates", value=100)
1581
1582               with gr.Row():
1583                   ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1584 +                 mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
1585                   cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1586                   start_button = gr.Button("Start Training")
1587                   stop_button = gr.Button("Stop Training", interactive=False)
1588
1589               if projects_selelect is not None:
1590                   (
1591 +                     exp_name_value,
1592 +                     learning_rate_value,
1593 +                     batch_size_per_gpu_value,
1594 +                     batch_size_type_value,
1595 +                     max_samples_value,
1596 +                     grad_accumulation_steps_value,
1597 +                     max_grad_norm_value,
1598 +                     epochs_value,
1599 +                     num_warmup_updates_value,
1600 +                     save_per_updates_value,
1601 +                     keep_last_n_checkpoints_value,
1602 +                     last_per_updates_value,
1603 +                     finetune_value,
1604 +                     file_checkpoint_train_value,
1605 +                     tokenizer_type_value,
1606 +                     tokenizer_file_value,
1607 +                     mixed_precision_value,
1608 +                     logger_value,
1609 +                     bnb_optimizer_value,
1610                   ) = load_settings(projects_selelect)
1611 +
1612 +                 # Assigning values to the respective components
1613 +                 exp_name.value = exp_name_value
1614 +                 learning_rate.value = learning_rate_value
1615 +                 batch_size_per_gpu.value = batch_size_per_gpu_value
1616 +                 batch_size_type.value = batch_size_type_value
1617 +                 max_samples.value = max_samples_value
1618 +                 grad_accumulation_steps.value = grad_accumulation_steps_value
1619 +                 max_grad_norm.value = max_grad_norm_value
1620 +                 epochs.value = epochs_value
1621 +                 num_warmup_updates.value = num_warmup_updates_value
1622 +                 save_per_updates.value = save_per_updates_value
1623 +                 keep_last_n_checkpoints.value = keep_last_n_checkpoints_value
1624 +                 last_per_updates.value = last_per_updates_value
1625 +                 ch_finetune.value = finetune_value
1626 +                 file_checkpoint_train.value = file_checkpoint_train_value
1627 +                 tokenizer_type.value = tokenizer_type_value
1628 +                 tokenizer_file.value = tokenizer_file_value
1629 +                 mixed_precision.value = mixed_precision_value
1630 +                 cd_logger.value = logger_value
1631 +                 ch_8bit_adam.value = bnb_optimizer_value
1632
1633               ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1634               txt_info_train = gr.Text(label="Info", value="")
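The "Keep Last N Checkpoints" field introduced above is simply forwarded to the training run; the policy its info text describes (-1 keep everything, 0 keep no intermediate checkpoints, N keep only the newest N) can be pictured with this hypothetical helper, which is an illustration only and not how the trainer actually implements it:

```python
# Illustrative sketch of the retention rule; helper name and deletion-based approach are assumptions.
import os
from glob import glob

def prune_checkpoints(ckpt_dir: str, keep_last_n: int) -> None:
    if keep_last_n < 0:  # -1: keep everything
        return
    numbered = sorted(
        (f for f in glob(os.path.join(ckpt_dir, "model_*.pt"))
         if os.path.basename(f) != "model_last.pt"),
        key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0]),
    )
    doomed = numbered if keep_last_n == 0 else numbered[:-keep_last_n]  # 0: no intermediates survive
    for f in doomed:
        os.remove(f)
```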
1679                   epochs,
1680                   num_warmup_updates,
1681                   save_per_updates,
1682 +                 keep_last_n_checkpoints,
1683 +                 last_per_updates,
1684                   ch_finetune,
1685                   file_checkpoint_train,
1686                   tokenizer_type,

1698               fn=calculate_train,
1699               inputs=[
1700                   cm_project,
1701 +                 epochs,
1702 +                 learning_rate,
1703 +                 batch_size_per_gpu,
1704                   batch_size_type,
1705                   max_samples,
1706                   num_warmup_updates,
1707                   ch_finetune,
1708               ],
1709               outputs=[
1710 +                 epochs,
1711 +                 learning_rate,
1712                   batch_size_per_gpu,
1713                   max_samples,
1714                   num_warmup_updates,
1715                   lb_samples,
1716               ],
1717           )
1718

1732                   epochs,
1733                   num_warmup_updates,
1734                   save_per_updates,
1735 +                 keep_last_n_checkpoints,
1736 +                 last_per_updates,
1737                   ch_finetune,
1738                   file_checkpoint_train,
1739                   tokenizer_type,
1740                   tokenizer_file,
1741                   mixed_precision,
1742                   cd_logger,
1743 +                 ch_8bit_adam,
1744               ]
1745               return output_components
1746
1747           outputs = setup_load_settings()

1762           gr.Markdown("""```plaintext
1763   SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
1764   ```""")
1765 +             exp_name = gr.Radio(
1766 +                 label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1767 +             )
1768               list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1769
1770               with gr.Row():

1818               bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1819               cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1820
1821 +         with gr.TabItem("Prune Checkpoint"):
1822               gr.Markdown("""```plaintext
1823 + Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
1824   ```""")
1825               txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
1826               txt_path_checkpoint_small = gr.Text(label="Path to Output:")
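The size reduction described in the renamed "Prune Checkpoint" tab comes from keeping only the EMA weights and dropping optimizer and other training state; the full implementation lives in `extract_and_save_ema_model` earlier in this file. A simplified illustration of the idea, with placeholder paths:

```python
# Simplified sketch of checkpoint pruning: keep the EMA weights, drop everything else.
# Paths are placeholders; the app's real logic is extract_and_save_ema_model above.
import torch

ckpt = torch.load("ckpts/my_project/model_last.pt", map_location="cpu", weights_only=True)
ema_sd = ckpt["ema_model_state_dict"]  # model weights tracked by the EMA
torch.save({"ema_model_state_dict": ema_sd}, "ckpts/my_project/model_last_pruned.pt")
```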
src/f5_tts/train/train.py
CHANGED
@@ -4,8 +4,9 @@ import os
 from importlib.resources import files
 
 import hydra
 
-from f5_tts.model import CFM, DiT,
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
 
@@ -14,9 +15,13 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
 
 @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
 def main(cfg):
     tokenizer = cfg.model.tokenizer
     mel_spec_type = cfg.model.mel_spec.mel_spec_type
     exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
 
     # set text tokenizer
     if tokenizer != "custom":
@@ -26,14 +31,8 @@ def main(cfg):
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
     # set model
-    if "F5TTS" in cfg.model.name:
-        model_cls = DiT
-    elif "E2TTS" in cfg.model.name:
-        model_cls = UNetT
-    wandb_resume_id = None
-
     model = CFM(
-        transformer=model_cls(**
         mel_spec_kwargs=cfg.model.mel_spec,
         vocab_char_map=vocab_char_map,
     )
@@ -45,8 +44,9 @@ def main(cfg):
         learning_rate=cfg.optim.learning_rate,
         num_warmup_updates=cfg.optim.num_warmup_updates,
         save_per_updates=cfg.ckpts.save_per_updates,
         checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
-
         batch_size_type=cfg.datasets.batch_size_type,
         max_samples=cfg.datasets.max_samples,
         grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
@@ -55,12 +55,13 @@ def main(cfg):
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-
-        log_samples=
         bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
         is_local_vocoder=cfg.model.vocoder.is_local,
         local_vocoder_path=cfg.model.vocoder.local_path,
     )
 
     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
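In the updated train.py shown below, the backbone class is no longer hard-coded per model name: it is looked up by name from the Hydra config (`globals()[cfg.model.backbone]`) and its architecture kwargs come from `cfg.model.arch`, so selecting DiT or UNetT is purely a matter of which YAML config is passed. A hedged sketch of that resolution step; the `arch` values below are illustrative, only the `name`/`backbone`/`arch` keys are assumed from the diff:

```python
# Sketch of the config-driven backbone lookup used by the new train.py; field values are illustrative.
from omegaconf import OmegaConf
from f5_tts.model import CFM, DiT, UNetT  # noqa: F401  (classes resolved by name, as in train.py)

cfg = OmegaConf.create({
    "model": {
        "name": "F5TTS_v1_Base",
        "backbone": "DiT",
        "arch": {"dim": 1024, "depth": 22, "heads": 16},  # illustrative subset of the real arch block
    }
})
model_cls = globals()[cfg.model.backbone]  # -> DiT
print(model_cls.__name__, OmegaConf.to_container(cfg.model.arch, resolve=True))
```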
  4   from importlib.resources import files
  5
  6   import hydra
  7 + from omegaconf import OmegaConf
  8
  9 + from f5_tts.model import CFM, DiT, UNetT, Trainer  # noqa: F401. used for config
 10   from f5_tts.model.dataset import load_dataset
 11   from f5_tts.model.utils import get_tokenizer
 12

 15
 16   @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
 17   def main(cfg):
 18 +     model_cls = globals()[cfg.model.backbone]
 19 +     model_arc = cfg.model.arch
 20       tokenizer = cfg.model.tokenizer
 21       mel_spec_type = cfg.model.mel_spec.mel_spec_type
 22 +
 23       exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
 24 +     wandb_resume_id = None
 25
 26       # set text tokenizer
 27       if tokenizer != "custom":

 31       vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 32
 33       # set model
 34       model = CFM(
 35 +         transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
 36           mel_spec_kwargs=cfg.model.mel_spec,
 37           vocab_char_map=vocab_char_map,
 38       )

 44           learning_rate=cfg.optim.learning_rate,
 45           num_warmup_updates=cfg.optim.num_warmup_updates,
 46           save_per_updates=cfg.ckpts.save_per_updates,
 47 +         keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
 48           checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
 49 +         batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
 50           batch_size_type=cfg.datasets.batch_size_type,
 51           max_samples=cfg.datasets.max_samples,
 52           grad_accumulation_steps=cfg.optim.grad_accumulation_steps,

 55           wandb_project="CFM-TTS",
 56           wandb_run_name=exp_name,
 57           wandb_resume_id=wandb_resume_id,
 58 +         last_per_updates=cfg.ckpts.last_per_updates,
 59 +         log_samples=cfg.ckpts.log_samples,
 60           bnb_optimizer=cfg.optim.bnb_optimizer,
 61           mel_spec_type=mel_spec_type,
 62           is_local_vocoder=cfg.model.vocoder.is_local,
 63           local_vocoder_path=cfg.model.vocoder.local_path,
 64 +         cfg_dict=OmegaConf.to_container(cfg, resolve=True),
 65       )
 66
 67       train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
|