Upload 7 files
- KOKORO/README.md +163 -0
- KOKORO/config.json +26 -0
- KOKORO/istftnet.py +523 -0
- KOKORO/kokoro.py +161 -0
- KOKORO/models.py +372 -0
- KOKORO/plbert.py +15 -0
- KOKORO/utils.py +342 -0
KOKORO/README.md
ADDED
@@ -0,0 +1,163 @@

---
license: apache-2.0
language:
- en
base_model:
- yl4579/StyleTTS2-LJSpeech
pipeline_tag: text-to-speech
---
📣 Jan 12 Status: Intent to improve the base model https://hf.co/hexgrad/Kokoro-82M/discussions/36

❤️ Kokoro Discord Server: https://discord.gg/QuGxSWBfQy

🚨 Got Synthetic Data? Want Trained Voicepacks? See https://hf.co/posts/hexgrad/418806998707773

<audio controls><source src="https://huggingface.co/hexgrad/Kokoro-82M/resolve/main/demo/HEARME.wav" type="audio/wav"></audio>

**Kokoro** is a frontier TTS model for its size of **82 million parameters** (text in/audio out).

On 25 Dec 2024, Kokoro v0.19 weights were permissively released in full fp32 precision under an Apache 2.0 license. As of 2 Jan 2025, 10 unique Voicepacks have been released, and a `.onnx` version of v0.19 is available.

In the weeks leading up to its release, Kokoro v0.19 was the #1🥇 ranked model in [TTS Spaces Arena](https://huggingface.co/hexgrad/Kokoro-82M#evaluation). Kokoro achieved higher Elo in this single-voice Arena setting than other models, using fewer parameters and less data:
1. **Kokoro v0.19: 82M params, Apache, trained on <100 hours of audio**
2. XTTS v2: 467M, CPML, >10k hours
3. Edge TTS: Microsoft, proprietary
4. MetaVoice: 1.2B, Apache, 100k hours
5. Parler Mini: 880M, Apache, 45k hours
6. Fish Speech: ~500M, CC-BY-NC-SA, 1M hours

Kokoro's ability to top this Elo ladder suggests that the scaling law (Elo vs compute/data/params) for traditional TTS models might have a steeper slope than previously expected.

You can find a hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).

### Usage

The following can be run in a single cell on [Google Colab](https://colab.research.google.com/).
```py
# 1️⃣ Install dependencies silently
!git lfs install
!git clone https://huggingface.co/hexgrad/Kokoro-82M
%cd Kokoro-82M
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
!pip install -q phonemizer torch transformers scipy munch

# 2️⃣ Build the model and load the default voicepack
from models import build_model
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL = build_model('kokoro-v0_19.pth', device)
VOICE_NAME = [
    'af', # Default voice is a 50-50 mix of Bella & Sarah
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]
VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
print(f'Loaded voice: {VOICE_NAME}')

# 3️⃣ Call generate, which returns 24khz audio and the phonemes used
from kokoro import generate
text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
audio, out_ps = generate(MODEL, text, VOICEPACK, lang=VOICE_NAME[0])
# Language is determined by the first letter of the VOICE_NAME:
# 🇺🇸 'a' => American English => en-us
# 🇬🇧 'b' => British English => en-gb

# 4️⃣ Display the 24khz audio and print the output phonemes
from IPython.display import display, Audio
display(Audio(data=audio, rate=24000, autoplay=True))
print(out_ps)
```
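
To keep the audio around after the session ends, the array can also be written to disk with `scipy` (already installed above). A minimal sketch; the `output.wav` filename is arbitrary:

```py
# Write the generated float audio array as a 24khz WAV file
from scipy.io.wavfile import write
write('output.wav', 24000, audio)
```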

If you have trouble with `espeak-ng`, see this [github issue](https://github.com/bootphon/phonemizer/issues/44#issuecomment-1540885186). [Mac users also see this](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12#677435d3d8ace1de46071489), and [Windows users see this](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12#67742594fdeebf74f001ecfc).

For ONNX usage, see [#14](https://huggingface.co/hexgrad/Kokoro-82M/discussions/14).

### Model Facts

No affiliation can be assumed between parties on different lines.

**Architecture:**
- StyleTTS 2: https://arxiv.org/abs/2306.07691
- ISTFTNet: https://arxiv.org/abs/2203.02395
- Decoder only: no diffusion, no encoder release

**Architected by:** Li et al @ https://github.com/yl4579/StyleTTS2

**Trained by**: `@rzvzn` on Discord

**Supported Languages:** American English, British English

**Model SHA256 Hash:** `3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a`

### Releases
- 25 Dec 2024: Model v0.19, `af_bella`, `af_sarah`
- 26 Dec 2024: `am_adam`, `am_michael`
- 28 Dec 2024: `bf_emma`, `bf_isabella`, `bm_george`, `bm_lewis`
- 30 Dec 2024: `af_nicole`
- 31 Dec 2024: `af_sky`
- 2 Jan 2025: ONNX v0.19 `ebef4245`

### Licenses
- Apache 2.0 weights in this repository
- MIT inference code in [spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS) adapted from [yl4579/StyleTTS2](https://github.com/yl4579/StyleTTS2)
- GPLv3 dependency in [espeak-ng](https://github.com/espeak-ng/espeak-ng)

The inference code was originally MIT licensed by the paper author. Note that this card applies only to this model, Kokoro. Original models published by the paper author can be found at [hf.co/yl4579](https://huggingface.co/yl4579).

### Evaluation

**Metric:** Elo rating

**Leaderboard:** [hf.co/spaces/Pendrokar/TTS-Spaces-Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena)

![TTS-Spaces-Arena-25-Dec-2024](demo/TTS-Spaces-Arena-25-Dec-2024.png)

The voice ranked in the Arena is a 50-50 mix of Bella and Sarah. For your convenience, this mix is included in this repository as `af.pt`, but you can trivially reproduce it like this:

```py
import torch
bella = torch.load('voices/af_bella.pt', weights_only=True)
sarah = torch.load('voices/af_sarah.pt', weights_only=True)
af = torch.mean(torch.stack([bella, sarah]), dim=0)
assert torch.equal(af, torch.load('voices/af.pt', weights_only=True))
```
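
The same recipe extends to uneven blends: any weighted combination of voicepack tensors yields a new voice. A minimal sketch (the 75/25 weights and the `af_custom` name are illustrative, not part of this release):

```py
import torch
bella = torch.load('voices/af_bella.pt', weights_only=True)
sarah = torch.load('voices/af_sarah.pt', weights_only=True)
# Weighted blend instead of the 50-50 mean; weights should sum to 1
af_custom = 0.75 * bella + 0.25 * sarah
torch.save(af_custom, 'af_custom.pt')
```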

### Training Details

**Compute:** Kokoro was trained on A100 80GB vRAM instances rented from [Vast.ai](https://cloud.vast.ai/?ref_id=79907) (referral link). Vast was chosen over other compute providers due to its competitive on-demand hourly rates. The average hourly cost for the A100 80GB vRAM instances used for training was below $1/hr per GPU, which was around half the quoted rates from other providers at the time.

**Data:** Kokoro was trained exclusively on **permissive/non-copyrighted audio data** and IPA phoneme labels. Examples of permissive/non-copyrighted audio include:
- Public domain audio
- Audio licensed under Apache, MIT, etc
- Synthetic audio<sup>[1]</sup> generated by closed<sup>[2]</sup> TTS models from large providers<br/>
[1] https://copyright.gov/ai/ai_policy_guidance.pdf<br/>
[2] No synthetic audio from open TTS models or "custom voice clones"

**Epochs:** Less than **20 epochs**

**Total Dataset Size:** Less than **100 hours** of audio

### Limitations

Kokoro v0.19 is limited in some specific ways, due to its training set and/or architecture:
- [Data] Lacks voice cloning capability, likely due to small <100h training set
- [Arch] Relies on external g2p (espeak-ng), which introduces a class of g2p failure modes
- [Data] Training dataset is mostly long-form reading and narration, not conversation
- [Arch] At 82M params, Kokoro almost certainly falls to a well-trained 1B+ param diffusion transformer, or a many-billion-param MLLM like GPT-4o / Gemini 2.0 Flash
- [Data] Multilingual capability is architecturally feasible, but training data is mostly English

Refer to the [Philosophy discussion](https://huggingface.co/hexgrad/Kokoro-82M/discussions/5) to better understand these limitations.

**Will the other voicepacks be released?** There is currently no release date scheduled for the other voicepacks, but in the meantime you can try them in the hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).

### Acknowledgements
- [@yl4579](https://huggingface.co/yl4579) for architecting StyleTTS 2
- [@Pendrokar](https://huggingface.co/Pendrokar) for adding Kokoro as a contender in the TTS Spaces Arena

### Model Card Contact

`@rzvzn` on Discord. Server invite: https://discord.gg/QuGxSWBfQy

<img src="https://static0.gamerantimages.com/wordpress/wp-content/uploads/2024/08/terminator-zero-41-1.jpg" width="400" alt="kokoro" />

https://terminator.fandom.com/wiki/Kokoro
KOKORO/config.json
ADDED
@@ -0,0 +1,26 @@

{
    "decoder": {
        "type": "istftnet",
        "upsample_kernel_sizes": [20, 12],
        "upsample_rates": [10, 6],
        "gen_istft_hop_size": 5,
        "gen_istft_n_fft": 20,
        "resblock_dilation_sizes": [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5]
        ],
        "resblock_kernel_sizes": [3, 7, 11],
        "upsample_initial_channel": 512
    },
    "dim_in": 64,
    "dropout": 0.2,
    "hidden_dim": 512,
    "max_conv_dim": 512,
    "max_dur": 50,
    "multispeaker": true,
    "n_layer": 3,
    "n_mels": 80,
    "n_token": 178,
    "style_dim": 128
}
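
A quick sanity check on the decoder hyperparameters above: in `istftnet.py` below, the `Generator` upsamples by `np.prod(upsample_rates) * gen_istft_hop_size`, so each frame entering the generator becomes 10 × 6 × 5 = 300 samples of 24khz audio. A minimal sketch of that arithmetic:

```py
import numpy as np

# Values from config.json
upsample_rates = [10, 6]
gen_istft_hop_size = 5
samples_per_frame = np.prod(upsample_rates) * gen_istft_hop_size
print(samples_per_frame)          # 300
print(24000 / samples_per_frame)  # 80.0 generator frames per second
```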
KOKORO/istftnet.py
ADDED
@@ -0,0 +1,523 @@

# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)

LRELU_SLOPE = 0.1

class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta

class AdaINResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class TorchSTFT(torch.nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)

        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))

        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # # for normal case

            # # To prevent torch.cumsum numerical overflow,
            # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # # Buffer tmp_over_one_idx indicates the time step to add -1.
            # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            # tmp_over_one = torch.cumsum(rad_values, 1) % 1
            # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            # cumsum_shift = torch.zeros_like(rad_values)
            # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1/self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            # tmp_over_one = torch.cumsum(rad_values, 1) % 1
            # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            # cumsum_shift = torch.zeros_like(rad_values)
            # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)

        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
            f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                             device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv

def padDiff(x):
    return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)


class Generator(torch.nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
        super(Generator, self).__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1

        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, style_dim))

            c_cur = upsample_initial_channel // (2 ** (i + 1))

            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
            else:
                self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))

        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)

    def fw_phase(self, x, s):
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return spec, phase

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        # (this Generator defines no conv_pre, so only conv_post is un-normed here)
        remove_weight_norm(self.conv_post)


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / np.sqrt(2)
        return out

class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')

class Decoder(nn.Module):
    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes = [3,7,11],
                 upsample_rates = [10, 6],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                 upsample_kernel_sizes=[20, 12],
                 gen_istft_n_fft=20, gen_istft_hop_size=5):
        super().__init__()

        self.decode = nn.ModuleList()

        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)

    def forward(self, asr, F0_curve, N, s):
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x
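
The `Decoder` above can be smoke-tested standalone with random tensors shaped like its real inputs. A minimal sketch, assuming it is run from this directory so `istftnet.py` imports directly; the 40-frame length is arbitrary, and note that `F0_curve` and `N` run at twice the frame rate of `asr` (the stride-2 `F0_conv`/`N_conv` bring them back down):

```py
import torch
from istftnet import Decoder  # assumes this directory is on the path

# Dimensions follow config.json: hidden_dim=512, style_dim=128, n_mels=80
decoder = Decoder(dim_in=512, style_dim=128, dim_out=80)
asr = torch.randn(1, 512, 40)   # aligned text features, 40 frames
F0 = torch.randn(1, 80)         # pitch curve at 2x the asr frame rate
N = torch.randn(1, 80)          # energy curve at 2x the asr frame rate
s = torch.randn(1, 128)         # style vector
with torch.no_grad():
    audio = decoder(asr, F0, N, s)
print(audio.shape)  # roughly (1, 1, 24000): ~600 samples per asr frame
```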
KOKORO/kokoro.py
ADDED
@@ -0,0 +1,161 @@

import phonemizer
import re
import torch

import os
import platform

# Check if the system is Windows
is_windows = platform.system() == "Windows"

# If Windows, set the environment variables
if is_windows:
    os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
    os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe"


def split_num(num):
    num = num.group()
    if '.' in num:
        return num
    elif ':' in num:
        h, m = [int(n) for n in num.split(':')]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f'{h} oh {m}'
        return f'{h} {m}'
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = 's' if num.endswith('s') else ''
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f'{left} hundred{s}'
        elif right < 10:
            return f'{left} oh {right}{s}'
    return f'{left} {right}{s}'

def flip_money(m):
    m = m.group()
    bill = 'dollar' if m[0] == '$' else 'pound'
    if m[-1].isalpha():
        return f'{m[1:]} {bill}s'
    elif '.' not in m:
        s = '' if m[1:] == '1' else 's'
        return f'{m[1:]} {bill}{s}'
    b, c = m[1:].split('.')
    s = '' if b == '1' else 's'
    c = int(c.ljust(2, '0'))
    coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
    return f'{b} {bill}{s} and {c} {coins}'

def point_num(num):
    a, b = num.group().split('.')
    return ' point '.join([a, ' '.join(b)])

def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace('«', chr(8220)).replace('»', chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace('(', '«').replace(')', '»')
    for a, b in zip('、。!,:;?', ',.!,:;?'):
        text = text.replace(a, b+' ')
    text = re.sub(r'[^\S \n]', ' ', text)
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'(?<=\n) +(?=\n)', '', text)
    text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
    text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
    text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
    text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
    text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
    text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
    text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
    text = re.sub(r'(?<=\d),(?=\d)', '', text)
    text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
    text = re.sub(r'\d*\.\d+', point_num, text)
    text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
    text = re.sub(r'(?<=\d)S', ' S', text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", 's', text)
    text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
    text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
    return text.strip()

def get_vocab():
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    dicts = {}
    for i in range(len(symbols)):
        dicts[symbols[i]] = i
    return dicts

VOCAB = get_vocab()
def tokenize(ps):
    return [i for i in map(VOCAB.get, ps) if i is not None]

phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
    b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
)
def phonemize(text, lang, norm=True):
    if norm:
        text = normalize_text(text)
    ps = phonemizers[lang].phonemize([text])
    ps = ps[0] if ps else ''
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
    ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
    ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
    if lang == 'a':
        ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
    ps = ''.join(filter(lambda p: p in VOCAB, ps))
    return ps.strip()

def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask+1, lengths.unsqueeze(1))
    return mask

@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    device = ref_s.device
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
        c_frame += pred_dur[0,i].item()
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()

def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    ps = ps or phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    elif len(tokens) > 510:
        tokens = tokens[:510]
        print('Truncated to 510 tokens')
    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)
    ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
    return out, ps
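
To inspect what this frontend feeds the model, `phonemize` and `tokenize` can be exercised on their own. A minimal sketch, assuming espeak-ng is installed and the script runs from the directory containing these files (exact IPA output depends on the espeak-ng version):

```py
from kokoro import phonemize, tokenize

# normalize_text expands '$4.50' and '2030' before espeak-ng g2p runs
ps = phonemize('The total was $4.50, around 2030.', lang='a')
print(ps)                 # IPA phoneme string after normalization and cleanup
print(tokenize(ps)[:12])  # integer ids into the 178-symbol vocab
```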
KOKORO/models.py
ADDED
@@ -0,0 +1,372 @@
1 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
2 |
+
from .istftnet import AdaIN1d, Decoder
|
3 |
+
from munch import Munch
|
4 |
+
from pathlib import Path
|
5 |
+
from .plbert import load_plbert
|
6 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import os.path as osp
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
class LinearNorm(torch.nn.Module):
|
16 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
17 |
+
super(LinearNorm, self).__init__()
|
18 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
19 |
+
|
20 |
+
torch.nn.init.xavier_uniform_(
|
21 |
+
self.linear_layer.weight,
|
22 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return self.linear_layer(x)
|
26 |
+
|
27 |
+
class LayerNorm(nn.Module):
|
28 |
+
def __init__(self, channels, eps=1e-5):
|
29 |
+
super().__init__()
|
30 |
+
self.channels = channels
|
31 |
+
self.eps = eps
|
32 |
+
|
33 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
34 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = x.transpose(1, -1)
|
38 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
39 |
+
return x.transpose(1, -1)
|
40 |
+
|
41 |
+
class TextEncoder(nn.Module):
|
42 |
+
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
43 |
+
super().__init__()
|
44 |
+
self.embedding = nn.Embedding(n_symbols, channels)
|
45 |
+
|
46 |
+
padding = (kernel_size - 1) // 2
|
47 |
+
self.cnn = nn.ModuleList()
|
48 |
+
for _ in range(depth):
|
49 |
+
self.cnn.append(nn.Sequential(
|
50 |
+
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
|
51 |
+
LayerNorm(channels),
|
52 |
+
actv,
|
53 |
+
nn.Dropout(0.2),
|
54 |
+
))
|
55 |
+
# self.cnn = nn.Sequential(*self.cnn)
|
56 |
+
|
57 |
+
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
58 |
+
|
59 |
+
def forward(self, x, input_lengths, m):
|
60 |
+
x = self.embedding(x) # [B, T, emb]
|
61 |
+
x = x.transpose(1, 2) # [B, emb, T]
|
62 |
+
m = m.to(input_lengths.device).unsqueeze(1)
|
63 |
+
x.masked_fill_(m, 0.0)
|
64 |
+
|
65 |
+
for c in self.cnn:
|
66 |
+
x = c(x)
|
67 |
+
x.masked_fill_(m, 0.0)
|
68 |
+
|
69 |
+
x = x.transpose(1, 2) # [B, T, chn]
|
70 |
+
|
71 |
+
input_lengths = input_lengths.cpu().numpy()
|
72 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
73 |
+
x, input_lengths, batch_first=True, enforce_sorted=False)
|
74 |
+
|
75 |
+
self.lstm.flatten_parameters()
|
76 |
+
x, _ = self.lstm(x)
|
77 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(
|
78 |
+
x, batch_first=True)
|
79 |
+
|
80 |
+
x = x.transpose(-1, -2)
|
81 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
82 |
+
|
83 |
+
x_pad[:, :, :x.shape[-1]] = x
|
84 |
+
x = x_pad.to(x.device)
|
85 |
+
|
86 |
+
x.masked_fill_(m, 0.0)
|
87 |
+
|
88 |
+
return x
|
89 |
+
|
90 |
+
def inference(self, x):
|
91 |
+
x = self.embedding(x)
|
92 |
+
x = x.transpose(1, 2)
|
93 |
+
x = self.cnn(x)
|
94 |
+
x = x.transpose(1, 2)
|
95 |
+
self.lstm.flatten_parameters()
|
96 |
+
x, _ = self.lstm(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
def length_to_mask(self, lengths):
|
100 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
101 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
102 |
+
return mask
|
103 |
+
|
104 |
+
|
105 |
+
class UpSample1d(nn.Module):
|
106 |
+
def __init__(self, layer_type):
|
107 |
+
super().__init__()
|
108 |
+
self.layer_type = layer_type
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
if self.layer_type == 'none':
|
112 |
+
return x
|
113 |
+
else:
|
114 |
+
return F.interpolate(x, scale_factor=2, mode='nearest')
|
115 |
+
|
116 |
+
class AdainResBlk1d(nn.Module):
|
117 |
+
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
118 |
+
upsample='none', dropout_p=0.0):
|
119 |
+
super().__init__()
|
120 |
+
self.actv = actv
|
121 |
+
self.upsample_type = upsample
|
122 |
+
self.upsample = UpSample1d(upsample)
|
123 |
+
self.learned_sc = dim_in != dim_out
|
124 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
125 |
+
self.dropout = nn.Dropout(dropout_p)
|
126 |
+
|
127 |
+
if upsample == 'none':
|
128 |
+
self.pool = nn.Identity()
|
129 |
+
else:
|
130 |
+
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
131 |
+
|
132 |
+
|
133 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
134 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
135 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
136 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
137 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
138 |
+
if self.learned_sc:
|
139 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
140 |
+
|
141 |
+
def _shortcut(self, x):
|
142 |
+
x = self.upsample(x)
|
143 |
+
if self.learned_sc:
|
144 |
+
x = self.conv1x1(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def _residual(self, x, s):
|
148 |
+
x = self.norm1(x, s)
|
149 |
+
x = self.actv(x)
|
150 |
+
x = self.pool(x)
|
151 |
+
x = self.conv1(self.dropout(x))
|
152 |
+
x = self.norm2(x, s)
|
153 |
+
x = self.actv(x)
|
154 |
+
x = self.conv2(self.dropout(x))
|
155 |
+
return x
|
156 |
+
|
157 |
+
def forward(self, x, s):
|
158 |
+
out = self._residual(x, s)
|
159 |
+
out = (out + self._shortcut(x)) / np.sqrt(2)
|
160 |
+
return out
|
161 |
+
|
162 |
+
class AdaLayerNorm(nn.Module):
|
163 |
+
def __init__(self, style_dim, channels, eps=1e-5):
|
164 |
+
super().__init__()
|
165 |
+
self.channels = channels
|
166 |
+
self.eps = eps
|
167 |
+
|
168 |
+
self.fc = nn.Linear(style_dim, channels*2)
|
169 |
+
|
170 |
+
def forward(self, x, s):
|
171 |
+
x = x.transpose(-1, -2)
|
172 |
+
x = x.transpose(1, -1)
|
173 |
+
|
174 |
+
h = self.fc(s)
|
175 |
+
h = h.view(h.size(0), h.size(1), 1)
|
176 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
177 |
+
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
178 |
+
|
179 |
+
|
180 |
+
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
181 |
+
x = (1 + gamma) * x + beta
|
182 |
+
return x.transpose(1, -1).transpose(-1, -2)
|
183 |
+
|
184 |
+
class ProsodyPredictor(nn.Module):
|
185 |
+
|
186 |
+
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
187 |
+
super().__init__()
|
188 |
+
|
189 |
+
self.text_encoder = DurationEncoder(sty_dim=style_dim,
|
190 |
+
d_model=d_hid,
|
191 |
+
nlayers=nlayers,
|
192 |
+
dropout=dropout)
|
193 |
+
|
194 |
+
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
195 |
+
self.duration_proj = LinearNorm(d_hid, max_dur)
|
196 |
+
|
197 |
+
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
198 |
+
self.F0 = nn.ModuleList()
|
199 |
+
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
200 |
+
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
201 |
+
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
202 |
+
|
203 |
+
self.N = nn.ModuleList()
|
204 |
+
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
205 |
+
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
206 |
+
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
207 |
+
|
208 |
+
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
209 |
+
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
210 |
+
|
211 |
+
|
212 |
+
def forward(self, texts, style, text_lengths, alignment, m):
|
213 |
+
d = self.text_encoder(texts, style, text_lengths, m)
|
214 |
+
|
215 |
+
batch_size = d.shape[0]
|
216 |
+
text_size = d.shape[1]
|
217 |
+
|
218 |
+
# predict duration
|
219 |
+
input_lengths = text_lengths.cpu().numpy()
|
220 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
221 |
+
d, input_lengths, batch_first=True, enforce_sorted=False)
|
222 |
+
|
223 |
+
m = m.to(text_lengths.device).unsqueeze(1)
|
224 |
+
|
225 |
+
self.lstm.flatten_parameters()
|
226 |
+
x, _ = self.lstm(x)
|
227 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(
|
228 |
+
x, batch_first=True)
|
229 |
+
|
230 |
+
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
231 |
+
|
232 |
+
x_pad[:, :x.shape[1], :] = x
|
233 |
+
x = x_pad.to(x.device)
|
234 |
+
|
235 |
+
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
|
236 |
+
|
237 |
+
en = (d.transpose(-1, -2) @ alignment)
|
238 |
+
|
239 |
+
return duration.squeeze(-1), en
|
240 |
+
|
241 |
+
def F0Ntrain(self, x, s):
|
242 |
+
x, _ = self.shared(x.transpose(-1, -2))
|
243 |
+
|
244 |
+
F0 = x.transpose(-1, -2)
|
245 |
+
for block in self.F0:
|
246 |
+
F0 = block(F0, s)
|
247 |
+
F0 = self.F0_proj(F0)
|
248 |
+
|
249 |
+
N = x.transpose(-1, -2)
|
250 |
+
for block in self.N:
|
251 |
+
N = block(N, s)
|
252 |
+
N = self.N_proj(N)
|
253 |
+
|
254 |
+
return F0.squeeze(1), N.squeeze(1)
|
255 |
+
|
256 |
+
def length_to_mask(self, lengths):
|
257 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
258 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
259 |
+
return mask
|
260 |
+
|
261 |
+
class DurationEncoder(nn.Module):
|
262 |
+
|
263 |
+
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
264 |
+
super().__init__()
|
265 |
+
self.lstms = nn.ModuleList()
|
266 |
+
for _ in range(nlayers):
|
267 |
+
self.lstms.append(nn.LSTM(d_model + sty_dim,
|
268 |
+
d_model // 2,
|
269 |
+
num_layers=1,
|
270 |
+
batch_first=True,
|
271 |
+
bidirectional=True,
|
272 |
+
dropout=dropout))
|
273 |
+
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
274 |
+
|
275 |
+
|
276 |
+
self.dropout = dropout
|
277 |
+
self.d_model = d_model
|
278 |
+
self.sty_dim = sty_dim
|
279 |
+
|
280 |
+
def forward(self, x, style, text_lengths, m):
|
281 |
+
masks = m.to(text_lengths.device)
|
282 |
+
|
283 |
+
x = x.permute(2, 0, 1)
|
284 |
+
s = style.expand(x.shape[0], x.shape[1], -1)
|
285 |
+
x = torch.cat([x, s], axis=-1)
|
286 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
287 |
+
|
288 |
+
x = x.transpose(0, 1)
|
289 |
+
input_lengths = text_lengths.cpu().numpy()
|
290 |
+
x = x.transpose(-1, -2)
|
291 |
+
|
292 |
+
for block in self.lstms:
|
293 |
+
if isinstance(block, AdaLayerNorm):
|
294 |
+
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
295 |
+
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
|
296 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
297 |
+
else:
|
298 |
+
x = x.transpose(-1, -2)
|
299 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
300 |
+
x, input_lengths, batch_first=True, enforce_sorted=False)
|
301 |
+
block.flatten_parameters()
|
302 |
+
x, _ = block(x)
|
303 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(
|
304 |
+
x, batch_first=True)
|
305 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
306 |
+
x = x.transpose(-1, -2)
|
307 |
+
|
308 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
309 |
+
|
310 |
+
x_pad[:, :, :x.shape[-1]] = x
|
311 |
+
x = x_pad.to(x.device)
|
312 |
+
|
313 |
+
return x.transpose(-1, -2)
|
314 |
+
|
    def inference(self, x, style):
        x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d

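# Illustration (added comment, not in the upstream file): recursive_munch turns
# nested dicts into attribute-accessible Munch objects, which is what lets
# build_model below write `args.decoder.type` instead of `args['decoder']['type']`:
#   cfg = recursive_munch({'decoder': {'type': 'istftnet'}})
#   cfg.decoder.type  # -> 'istftnet'
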
def build_model(path, device):
    config = Path(__file__).parent / 'config.json'
    assert config.exists(), f'Config path incorrect: config.json not found at {config}'
    with open(config, 'r') as r:
        args = recursive_munch(json.load(r))
    assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                      resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                      upsample_rates=args.decoder.upsample_rates,
                      upsample_initial_channel=args.decoder.upsample_initial_channel,
                      resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                      upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
                      gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
    bert = load_plbert()
    bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
    for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
        for child in parent.children():
            if isinstance(child, nn.RNNBase):
                child.flatten_parameters()
    model = Munch(
        bert=bert.to(device).eval(),
        bert_encoder=bert_encoder.to(device).eval(),
        predictor=predictor.to(device).eval(),
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
    for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
        assert key in model, key
        try:
            model[key].load_state_dict(state_dict)
        except RuntimeError:
            # Checkpoints saved from nn.DataParallel prefix every key with
            # 'module.'; strip the prefix and retry with strict=False.
            state_dict = {k[7:]: v for k, v in state_dict.items()}
            model[key].load_state_dict(state_dict, strict=False)
    return model
KOKORO/plbert.py
ADDED
@@ -0,0 +1,15 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel

class CustomAlbert(AlbertModel):
    def forward(self, *args, **kwargs):
        # Call the original forward method
        outputs = super().forward(*args, **kwargs)
        # Only return the last_hidden_state
        return outputs.last_hidden_state

def load_plbert():
    plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
    albert_base_configuration = AlbertConfig(**plbert_config)
    bert = CustomAlbert(albert_base_configuration)
    return bert
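# Minimal usage sketch (added for illustration; the model returned here is
# randomly initialized — the trained PLBERT weights are loaded afterwards by
# models.build_model from the checkpoint's 'bert' entry):
#   bert = load_plbert()
#   # input_ids: (batch, seq_len) token ids < 178
#   # hidden = bert(input_ids=input_ids, attention_mask=mask)  # (batch, seq_len, 768)
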
KOKORO/utils.py
ADDED
@@ -0,0 +1,342 @@
from .kokoro import normalize_text, phonemize, generate
import re
import librosa
import os
import uuid
from pydub.silence import split_on_silence
from pydub import AudioSegment
import wave
import numpy as np
import torch


def create_audio_dir():
    """Creates the 'kokoro_audio' directory in the root folder if it doesn't exist."""
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    audio_dir = os.path.join(root_dir, "kokoro_audio")

    if not os.path.exists(audio_dir):
        os.makedirs(audio_dir)
        print(f"Created directory: {audio_dir}")
    else:
        print(f"Directory already exists: {audio_dir}")
    return audio_dir

temp_folder = create_audio_dir()


debug = False

def resplit_strings(arr):
    # Handle edge cases
    if not arr:
        return '', ''
    if len(arr) == 1:
        return arr[0], ''
    # Try each possible split point
    min_diff = float('inf')
    best_split = 0
    # Calculate lengths when joined with spaces
    lengths = [len(s) for s in arr]
    spaces = len(arr) - 1  # Total spaces needed
    # Try each split point
    left_len = 0
    right_len = sum(lengths) + spaces
    for i in range(1, len(arr)):
        # Add current word and space to left side
        left_len += lengths[i-1] + (1 if i > 1 else 0)
        # Remove current word and space from right side
        right_len -= lengths[i-1] + 1
        diff = abs(left_len - right_len)
        if diff < min_diff:
            min_diff = diff
            best_split = i
    # Join the strings with the best split point
    return ' '.join(arr[:best_split]), ' '.join(arr[best_split:])

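# Illustration (added comment): resplit_strings picks the split point that
# leaves the two halves closest in character length, e.g.
#   resplit_strings(['a', 'bb', 'ccc', 'dd'])
#   -> ('a bb', 'ccc dd')
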
def recursive_split(text, voice):
    if not text:
        return []
    tokens = phonemize(text, voice, norm=False)
    if len(tokens) < 511:
        return [(text, tokens, len(tokens))] if tokens else []
    if ' ' not in text:
        return []
    for punctuation in ['!.?…', ':;', ',—']:
        splits = re.split(f'(?:(?<=[{punctuation}])|(?<=[{punctuation}]["\'»])|(?<=[{punctuation}]["\'»]["\'»])) ', text)
        if len(splits) > 1:
            break
    else:
        splits = None
    splits = splits or text.split(' ')
    a, b = resplit_strings(splits)
    return recursive_split(a, voice) + recursive_split(b, voice)

def segment_and_tokenize(text, voice, skip_square_brackets=True, newline_split=2):
    if skip_square_brackets:
        text = re.sub(r'\[.*?\]', '', text)
    texts = [t.strip() for t in re.split('\n{'+str(newline_split)+',}', normalize_text(text))] if newline_split > 0 else [normalize_text(text)]
    segments = [row for t in texts for row in recursive_split(t, voice)]
    return [(i, *row) for i, row in enumerate(segments)]


def large_text(text, VOICE_NAME):
    # Short inputs are returned as a single segment; longer ones are split into
    # phoneme-budget-sized chunks via segment_and_tokenize.
    if len(text) <= 500:
        return [(0, text, len(text))]
    else:
        result = segment_and_tokenize(text, VOICE_NAME[0])
        filtered_result = [(row[0], row[1], row[3]) for row in result]
        return filtered_result

+
def clamp_speed(speed):
|
92 |
+
if not isinstance(speed, float) and not isinstance(speed, int):
|
93 |
+
return 1
|
94 |
+
elif speed < 0.5:
|
95 |
+
# return 0.5
|
96 |
+
return speed
|
97 |
+
elif speed > 2:
|
98 |
+
return 2
|
99 |
+
return speed
|
100 |
+
|
101 |
+
def clamp_trim(trim):
|
102 |
+
if not isinstance(trim, float) and not isinstance(trim, int):
|
103 |
+
return 0.5
|
104 |
+
elif trim <= 0:
|
105 |
+
return 0
|
106 |
+
elif trim > 1:
|
107 |
+
return 0.5
|
108 |
+
return trim
|
109 |
+
|
110 |
+
def trim_if_needed(out, trim):
|
111 |
+
if not trim:
|
112 |
+
return out
|
113 |
+
a, b = librosa.effects.trim(out, top_db=30)[1]
|
114 |
+
a = int(a*trim)
|
115 |
+
b = int(len(out)-(len(out)-b)*trim)
|
116 |
+
return out[a:b]
|
117 |
+
|
118 |
+
#Above code copied from https://huggingface.co/spaces/hexgrad/Kokoro-TTS/blob/main/app.py
|
119 |
+
|
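# Note (added comment): trim is a 0–1 fraction of the leading/trailing silence
# (as detected by librosa.effects.trim at top_db=30) to cut. trim=0 leaves the
# audio untouched, trim=1 removes all detected edge silence, trim=0.5 removes half.
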
def get_random_file_name(output_file=""):
    global temp_folder
    if output_file == "":
        random_id = str(uuid.uuid4())[:8]
        output_file = f"{temp_folder}/{random_id}.wav"
        return output_file
    # A name was supplied: reuse it if the file doesn't exist yet
    if not os.path.exists(output_file):
        return output_file
    try:
        if output_file and os.path.exists(output_file):
            os.remove(output_file)  # Try to remove the file if it exists
        return output_file  # Return the same name if the file was successfully removed
    except Exception as e:
        # print(f"Error removing file {output_file}: {e}")
        random_id = str(uuid.uuid4())[:8]
        output_file = f"{temp_folder}/{random_id}.wav"
        return output_file


def remove_silence_function(file_path, minimum_silence=50):
    # Derive the output name and format from the provided path
    output_path = file_path.replace(".wav", "_no_silence.wav")
    audio_format = "wav"
    # Reading and splitting the audio file into chunks
    sound = AudioSegment.from_file(file_path, format=audio_format)
    audio_chunks = split_on_silence(sound,
                                    min_silence_len=100,
                                    silence_thresh=-45,
                                    keep_silence=minimum_silence)
    # Putting the file back together
    combined = AudioSegment.empty()
    for chunk in audio_chunks:
        combined += chunk
    combined.export(output_path, format=audio_format)
    return output_path

import simpleaudio as sa

def play_audio(filename):
    wave_obj = sa.WaveObject.from_wave_file(filename)
    play_obj = wave_obj.play()
    play_obj.wait_done()


def clean_text(text):
    # Define replacement rules
    replacements = {
        "–": " ",   # Replace en-dash with space
        "-": " ",   # Replace hyphen with space
        ":": ",",   # Replace colon with comma
        "**": " ",  # Replace double asterisks with space
        "*": " ",   # Replace single asterisk with space
        "#": " ",   # Replace hash with space
    }

    # Apply replacements
    for old, new in replacements.items():
        text = text.replace(old, new)

    # Remove emojis using regex (covering wide range of Unicode characters)
    emoji_pattern = re.compile(
        r'[\U0001F600-\U0001F64F]|'  # Emoticons
        r'[\U0001F300-\U0001F5FF]|'  # Miscellaneous symbols and pictographs
        r'[\U0001F680-\U0001F6FF]|'  # Transport and map symbols
        r'[\U0001F700-\U0001F77F]|'  # Alchemical symbols
        r'[\U0001F780-\U0001F7FF]|'  # Geometric shapes extended
        r'[\U0001F800-\U0001F8FF]|'  # Supplemental arrows-C
        r'[\U0001F900-\U0001F9FF]|'  # Supplemental symbols and pictographs
        r'[\U0001FA00-\U0001FA6F]|'  # Chess symbols
        r'[\U0001FA70-\U0001FAFF]|'  # Symbols and pictographs extended-A
        r'[\U00002702-\U000027B0]|'  # Dingbats
        r'[\U0001F1E0-\U0001F1FF]'   # Flags (iOS)
        r'', flags=re.UNICODE)
    text = emoji_pattern.sub(r'', text)

    # Remove multiple spaces and extra line breaks
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# copied from F5TTS 😁
def parse_speechtypes_text(gen_text):
    # Pattern to find {speechtype}
    pattern = r"\{(.*?)\}"

    # Split the text by the pattern
    tokens = re.split(pattern, gen_text)

    segments = []

    current_style = "af"

    for i in range(len(tokens)):
        if i % 2 == 0:
            # This is text
            text = tokens[i].strip()
            if text:
                text = clean_text(text)
                segments.append({"voice_name": current_style, "text": text})
        else:
            # This is style
            style = tokens[i].strip()
            current_style = style

    return segments

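# Illustration (added comment): parse_speechtypes_text splits a multi-voice
# script on {voice_name} tags, defaulting to the 'af' voice, e.g.
#   parse_speechtypes_text("Hello. {am_adam} Hi!")
#   -> [{'voice_name': 'af', 'text': 'Hello.'},
#       {'voice_name': 'am_adam', 'text': 'Hi!'}]
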
def podcast(MODEL, device, gen_text, speed=1.0, trim=0, pad_between_segments=0, remove_silence=True, minimum_silence=50):
    segments = parse_speechtypes_text(gen_text)
    speed = clamp_speed(speed)
    trim = clamp_trim(trim)
    silence_duration = clamp_trim(pad_between_segments)  # clamp_trim caps the padding at 1 s (values > 1 fall back to 0.5)
    # output_file = get_random_file_name(output_file)
    sample_rate = 24000  # Sample rate of the audio

    # Create a silent audio segment in float32
    silence = np.zeros(int(sample_rate * silence_duration), dtype=np.float32)
    if len(segments) >= 1:
        first_line_text = segments[0]["text"]
        output_file = tts_file_name(first_line_text)
    else:
        output_file = get_random_file_name("")

    output_file = output_file.replace('\n', '').replace('\r', '')
    # Open a WAV file for writing
    with wave.open(output_file, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit audio
        wav_file.setframerate(sample_rate)

        for idx, segment in enumerate(segments):  # Added index `idx` to track position
            voice_name = segment["voice_name"]
            text = segment["text"]
            voice_pack_path = f"./KOKORO/voices/{voice_name}.pt"
            VOICEPACK = torch.load(voice_pack_path, weights_only=True).to(device)

            # Generate audio for the segment
            audio, out_ps = generate(MODEL, text, VOICEPACK, lang=voice_name[0], speed=speed)
            audio = trim_if_needed(audio, trim)

            # Scale audio from float32 to int16
            audio = (audio * 32767).astype(np.int16)

            # Write the audio segment to the WAV file
            wav_file.writeframes(audio.tobytes())

            # Add silence between segments, except after the last segment
            if idx != len(segments) - 1:
                wav_file.writeframes((silence * 32767).astype(np.int16).tobytes())

    # Optionally remove silence from the output file
    if remove_silence:
        output_file = remove_silence_function(output_file, minimum_silence=minimum_silence)

    return output_file

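# Usage sketch (added for illustration; assumes MODEL/device come from
# KOKORO.models.build_model and that the referenced ./KOKORO/voices/*.pt
# voicepacks have been downloaded):
#   script = "{af} Hello there! {am_adam} Hi, nice to meet you."
#   out_path = podcast(MODEL, device, script, pad_between_segments=0.3)
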
def tts(MODEL, device, text, voice_name, speed=1.0, trim=0.5, pad_between_segments=0.5, output_file="", remove_silence=True, minimum_silence=50):
    text = clean_text(text)
    segments = large_text(text, voice_name)
    voice_pack_path = f"./KOKORO/voices/{voice_name}.pt"
    VOICEPACK = torch.load(voice_pack_path, weights_only=True).to(device)
    speed = clamp_speed(speed)
    trim = clamp_trim(trim)
    silence_duration = clamp_trim(pad_between_segments)
    output_file = get_random_file_name(output_file)
    if debug:
        print(f'Loaded voice: {voice_name}')
        print(f"Speed: {speed}")
        print(f"Trim: {trim}")
        print(f"Silence duration: {silence_duration}")
    sample_rate = 24000  # Sample rate of the audio

    # Create a silent audio segment in float32
    silence = np.zeros(int(sample_rate * silence_duration), dtype=np.float32)

    # Open a WAV file for writing
    with wave.open(output_file, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit audio
        wav_file.setframerate(sample_rate)

        for i in segments:
            id = i[0]
            text = i[1]
            if debug:
                print(i)
            audio, out_ps = generate(MODEL, text, VOICEPACK, lang=voice_name[0], speed=speed)
            audio = trim_if_needed(audio, trim)

            # Scale audio from float32 to int16
            audio = (audio * 32767).astype(np.int16)

            # Write the audio segment to the WAV file
            wav_file.writeframes(audio.tobytes())

            # Add silence between segments, except after the last segment
            if id != len(segments) - 1:
                wav_file.writeframes((silence * 32767).astype(np.int16).tobytes())
    if remove_silence:
        output_file = remove_silence_function(output_file, minimum_silence=minimum_silence)
    return output_file

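# Usage sketch (added for illustration; assumes MODEL/device come from
# KOKORO.models.build_model and a downloaded voicepack such as af.pt):
#   wav_path = tts(MODEL, device, "Hello world.", "af", speed=1.0)
#   play_audio(wav_path)
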

def tts_file_name(text):
    global temp_folder
    # Remove all non-alphabetic characters and convert to lowercase
    text = re.sub(r'[^a-zA-Z\s]', '', text)  # Retain only alphabets and spaces
    text = text.lower().strip()  # Convert to lowercase and strip leading/trailing spaces
    text = text.replace(" ", "_")  # Replace spaces with underscores

    # Truncate or handle empty text
    truncated_text = text[:25] if len(text) > 25 else text if len(text) > 0 else "empty"

    # Generate a random string for uniqueness
    random_string = uuid.uuid4().hex[:8].upper()

    # Construct the file name
    file_name = f"{temp_folder}/{truncated_text}_{random_string}.wav"
    return file_name
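# Illustration (added comment): tts_file_name builds a slug from the first
# words of the text plus a random hex suffix, producing something like
#   tts_file_name("Hello, World!")  # -> f"{temp_folder}/hello_world_1A2B3C4D.wav"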