NeuralFalcon committed
Commit ab3c5d4 · verified · 1 Parent(s): e0afd53

Upload 7 files

Files changed (7)
  1. KOKORO/README.md +163 -0
  2. KOKORO/config.json +26 -0
  3. KOKORO/istftnet.py +523 -0
  4. KOKORO/kokoro.py +161 -0
  5. KOKORO/models.py +372 -0
  6. KOKORO/plbert.py +15 -0
  7. KOKORO/utils.py +342 -0
KOKORO/README.md ADDED
@@ -0,0 +1,163 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ base_model:
+ - yl4579/StyleTTS2-LJSpeech
+ pipeline_tag: text-to-speech
+ ---
+ 📣 Jan 12 Status: Intent to improve the base model https://hf.co/hexgrad/Kokoro-82M/discussions/36
+
+ ❤️ Kokoro Discord Server: https://discord.gg/QuGxSWBfQy
+
+ 🚨 Got Synthetic Data? Want Trained Voicepacks? See https://hf.co/posts/hexgrad/418806998707773
+
+ <audio controls><source src="https://huggingface.co/hexgrad/Kokoro-82M/resolve/main/demo/HEARME.wav" type="audio/wav"></audio>
+
+ **Kokoro** is a frontier TTS model for its size: **82 million parameters** (text in, audio out).
+
+ On 25 Dec 2024, Kokoro v0.19 weights were permissively released in full fp32 precision under an Apache 2.0 license. As of 2 Jan 2025, 10 unique Voicepacks have been released, and a `.onnx` version of v0.19 is available.
+
+ In the weeks leading up to its release, Kokoro v0.19 was the #1🥇 ranked model in [TTS Spaces Arena](https://huggingface.co/hexgrad/Kokoro-82M#evaluation). Kokoro achieved a higher Elo in this single-voice Arena setting than the other models, while using fewer parameters and less data:
+ 1. **Kokoro v0.19: 82M params, Apache, trained on <100 hours of audio**
+ 2. XTTS v2: 467M, CPML, >10k hours
+ 3. Edge TTS: Microsoft, proprietary
+ 4. MetaVoice: 1.2B, Apache, 100k hours
+ 5. Parler Mini: 880M, Apache, 45k hours
+ 6. Fish Speech: ~500M, CC-BY-NC-SA, 1M hours
+
+ Kokoro's ability to top this Elo ladder suggests that the scaling law (Elo vs compute/data/params) for traditional TTS models might have a steeper slope than previously expected.
+
+ You can find a hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
+
+ ### Usage
+
+ The following can be run in a single cell on [Google Colab](https://colab.research.google.com/).
+ ```py
+ # 1️⃣ Install dependencies silently
+ !git lfs install
+ !git clone https://huggingface.co/hexgrad/Kokoro-82M
+ %cd Kokoro-82M
+ !apt-get -qq -y install espeak-ng > /dev/null 2>&1
+ !pip install -q phonemizer torch transformers scipy munch
+
+ # 2️⃣ Build the model and load the default voicepack
+ from models import build_model
+ import torch
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ MODEL = build_model('kokoro-v0_19.pth', device)
+ VOICE_NAME = [
+ 'af', # Default voice is a 50-50 mix of Bella & Sarah
+ 'af_bella', 'af_sarah', 'am_adam', 'am_michael',
+ 'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
+ 'af_nicole', 'af_sky',
+ ][0]
+ VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
+ print(f'Loaded voice: {VOICE_NAME}')
+
+ # 3️⃣ Call generate, which returns 24khz audio and the phonemes used
+ from kokoro import generate
+ text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
+ audio, out_ps = generate(MODEL, text, VOICEPACK, lang=VOICE_NAME[0])
+ # Language is determined by the first letter of the VOICE_NAME:
+ # 🇺🇸 'a' => American English => en-us
+ # 🇬🇧 'b' => British English => en-gb
+
+ # 4️⃣ Display the 24khz audio and print the output phonemes
+ from IPython.display import display, Audio
+ display(Audio(data=audio, rate=24000, autoplay=True))
+ print(out_ps)
+ ```
+ If you have trouble with `espeak-ng`, see this [GitHub issue](https://github.com/bootphon/phonemizer/issues/44#issuecomment-1540885186). [Mac users also see this](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12#677435d3d8ace1de46071489), and [Windows users see this](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12#67742594fdeebf74f001ecfc).
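+
+ On Windows, the bundled `KOKORO/kokoro.py` in this upload works around this by pointing `phonemizer` at the eSpeak NG install via environment variables before building its backends. A minimal standalone sketch of the same workaround (the paths assume the default eSpeak NG installer location):
+ ```py
+ import os
+ # Tell phonemizer where the eSpeak NG library and executable live (default install paths).
+ os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
+ os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe"
+ # Import kokoro only after the variables are set, so its espeak backends can initialize.
+ from kokoro import generate
+ ```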
+
+ For ONNX usage, see [#14](https://huggingface.co/hexgrad/Kokoro-82M/discussions/14).
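+
+ The discussion above is the authoritative reference. As a rough, hypothetical starting point (the file name and provider choice are assumptions, not taken from that thread), you can open the exported graph with `onnxruntime` and inspect the inputs it expects before wiring anything up:
+ ```py
+ import onnxruntime as ort
+
+ # Assumed file name for the exported v0.19 graph; adjust to whatever you downloaded.
+ sess = ort.InferenceSession('kokoro-v0_19.onnx', providers=['CPUExecutionProvider'])
+ for inp in sess.get_inputs():
+     print(inp.name, inp.shape, inp.type)  # discover the token/style inputs the graph expects
+ ```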
+
+ ### Model Facts
+
+ No affiliation can be assumed between parties on different lines.
+
+ **Architecture:**
+ - StyleTTS 2: https://arxiv.org/abs/2306.07691
+ - ISTFTNet: https://arxiv.org/abs/2203.02395
+ - Decoder only: no diffusion, no encoder release
+
+ **Architected by:** Li et al. @ https://github.com/yl4579/StyleTTS2
+
+ **Trained by:** `@rzvzn` on Discord
+
+ **Supported Languages:** American English, British English
+
+ **Model SHA256 Hash:** `3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a`
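+
+ To verify a local download against this hash, a minimal sketch using only the standard library (it assumes the weights file is named `kokoro-v0_19.pth`, as in the Usage section, and sits in the current directory):
+ ```py
+ import hashlib
+
+ with open('kokoro-v0_19.pth', 'rb') as f:
+     digest = hashlib.sha256(f.read()).hexdigest()  # reads the whole file into memory
+ print(digest)  # expect 3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a
+ ```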
+
+ ### Releases
+ - 25 Dec 2024: Model v0.19, `af_bella`, `af_sarah`
+ - 26 Dec 2024: `am_adam`, `am_michael`
+ - 28 Dec 2024: `bf_emma`, `bf_isabella`, `bm_george`, `bm_lewis`
+ - 30 Dec 2024: `af_nicole`
+ - 31 Dec 2024: `af_sky`
+ - 2 Jan 2025: ONNX v0.19 `ebef4245`
+
+ ### Licenses
+ - Apache 2.0 weights in this repository
+ - MIT inference code in [spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS) adapted from [yl4579/StyleTTS2](https://github.com/yl4579/StyleTTS2)
+ - GPLv3 dependency in [espeak-ng](https://github.com/espeak-ng/espeak-ng)
+
+ The inference code was originally MIT licensed by the paper author. Note that this card applies only to this model, Kokoro. Original models published by the paper author can be found at [hf.co/yl4579](https://huggingface.co/yl4579).
+
+ ### Evaluation
+
+ **Metric:** Elo rating
+
+ **Leaderboard:** [hf.co/spaces/Pendrokar/TTS-Spaces-Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena)
+
+ ![TTS-Spaces-Arena-25-Dec-2024](demo/TTS-Spaces-Arena-25-Dec-2024.png)
+
+ The voice ranked in the Arena is a 50-50 mix of Bella and Sarah. For your convenience, this mix is included in this repository as `af.pt`, but you can trivially reproduce it like this:
+
+ ```py
+ import torch
+ bella = torch.load('voices/af_bella.pt', weights_only=True)
+ sarah = torch.load('voices/af_sarah.pt', weights_only=True)
+ af = torch.mean(torch.stack([bella, sarah]), dim=0)
+ assert torch.equal(af, torch.load('voices/af.pt', weights_only=True))
+ ```
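+
+ Any two voicepacks can be blended the same way, since each `.pt` file is just a style tensor. A hypothetical example (the 70/30 weighting and the voice pairing are illustrative, not an officially released mix), assuming `MODEL` and the repository layout from the Usage section above:
+ ```py
+ import torch
+ from kokoro import generate
+
+ bella = torch.load('voices/af_bella.pt', weights_only=True)
+ nicole = torch.load('voices/af_nicole.pt', weights_only=True)
+ custom = 0.7 * bella + 0.3 * nicole  # weighted blend of two style tensors
+ audio, out_ps = generate(MODEL, "A custom blended voice.", custom, lang='a')
+ ```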
+
+ ### Training Details
+
+ **Compute:** Kokoro was trained on A100 80GB vRAM instances rented from [Vast.ai](https://cloud.vast.ai/?ref_id=79907) (referral link). Vast was chosen over other compute providers due to its competitive on-demand hourly rates. The average hourly cost for the A100 80GB vRAM instances used for training was below $1/hr per GPU, around half the rates quoted by other providers at the time.
+
+ **Data:** Kokoro was trained exclusively on **permissive/non-copyrighted audio data** and IPA phoneme labels. Examples of permissive/non-copyrighted audio include:
+ - Public domain audio
+ - Audio licensed under Apache, MIT, etc.
+ - Synthetic audio<sup>[1]</sup> generated by closed<sup>[2]</sup> TTS models from large providers<br/>
+ [1] https://copyright.gov/ai/ai_policy_guidance.pdf<br/>
+ [2] No synthetic audio from open TTS models or "custom voice clones"
+
+ **Epochs:** Less than **20 epochs**
+
+ **Total Dataset Size:** Less than **100 hours** of audio
+
+ ### Limitations
+
+ Kokoro v0.19 is limited in some specific ways due to its training set and/or architecture:
+ - [Data] Lacks voice cloning capability, likely due to the small (<100 hour) training set
+ - [Arch] Relies on an external g2p (espeak-ng), which introduces a class of g2p failure modes
+ - [Data] Training dataset is mostly long-form reading and narration, not conversation
+ - [Arch] At 82M params, Kokoro almost certainly loses to a well-trained 1B+ param diffusion transformer, or a many-billion-param MLLM like GPT-4o / Gemini 2.0 Flash
+ - [Data] Multilingual capability is architecturally feasible, but training data is mostly English
+
+ Refer to the [Philosophy discussion](https://huggingface.co/hexgrad/Kokoro-82M/discussions/5) to better understand these limitations.
+
+ **Will the other voicepacks be released?** There is currently no release date scheduled for the other voicepacks, but in the meantime you can try them in the hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
+
+ ### Acknowledgements
+ - [@yl4579](https://huggingface.co/yl4579) for architecting StyleTTS 2
+ - [@Pendrokar](https://huggingface.co/Pendrokar) for adding Kokoro as a contender in the TTS Spaces Arena
+
+ ### Model Card Contact
+
+ `@rzvzn` on Discord. Server invite: https://discord.gg/QuGxSWBfQy
+
+ <img src="https://static0.gamerantimages.com/wordpress/wp-content/uploads/2024/08/terminator-zero-41-1.jpg" width="400" alt="kokoro" />
+
+ https://terminator.fandom.com/wiki/Kokoro
KOKORO/config.json ADDED
@@ -0,0 +1,26 @@
+ {
+ "decoder": {
+ "type": "istftnet",
+ "upsample_kernel_sizes": [20, 12],
+ "upsample_rates": [10, 6],
+ "gen_istft_hop_size": 5,
+ "gen_istft_n_fft": 20,
+ "resblock_dilation_sizes": [
+ [1, 3, 5],
+ [1, 3, 5],
+ [1, 3, 5]
+ ],
+ "resblock_kernel_sizes": [3, 7, 11],
+ "upsample_initial_channel": 512
+ },
+ "dim_in": 64,
+ "dropout": 0.2,
+ "hidden_dim": 512,
+ "max_conv_dim": 512,
+ "max_dur": 50,
+ "multispeaker": true,
+ "n_layer": 3,
+ "n_mels": 80,
+ "n_token": 178,
+ "style_dim": 128
+ }
KOKORO/istftnet.py ADDED
@@ -0,0 +1,523 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
2
+ from scipy.signal import get_window
3
+ from torch.nn import Conv1d, ConvTranspose1d
4
+ from torch.nn.utils import weight_norm, remove_weight_norm
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
11
+ def init_weights(m, mean=0.0, std=0.01):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ m.weight.data.normal_(mean, std)
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size*dilation - dilation)/2)
18
+
19
+ LRELU_SLOPE = 0.1
20
+
21
+ class AdaIN1d(nn.Module):
22
+ def __init__(self, style_dim, num_features):
23
+ super().__init__()
24
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
25
+ self.fc = nn.Linear(style_dim, num_features*2)
26
+
27
+ def forward(self, x, s):
28
+ h = self.fc(s)
29
+ h = h.view(h.size(0), h.size(1), 1)
30
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
31
+ return (1 + gamma) * self.norm(x) + beta
32
+
33
+ class AdaINResBlock1(torch.nn.Module):
34
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
35
+ super(AdaINResBlock1, self).__init__()
36
+ self.convs1 = nn.ModuleList([
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
38
+ padding=get_padding(kernel_size, dilation[0]))),
39
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
40
+ padding=get_padding(kernel_size, dilation[1]))),
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
42
+ padding=get_padding(kernel_size, dilation[2])))
43
+ ])
44
+ self.convs1.apply(init_weights)
45
+
46
+ self.convs2 = nn.ModuleList([
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1))),
49
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
+ padding=get_padding(kernel_size, 1))),
51
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
+ padding=get_padding(kernel_size, 1)))
53
+ ])
54
+ self.convs2.apply(init_weights)
55
+
56
+ self.adain1 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.adain2 = nn.ModuleList([
63
+ AdaIN1d(style_dim, channels),
64
+ AdaIN1d(style_dim, channels),
65
+ AdaIN1d(style_dim, channels),
66
+ ])
67
+
68
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
69
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
70
+
71
+
72
+ def forward(self, x, s):
73
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
74
+ xt = n1(x, s)
75
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
76
+ xt = c1(xt)
77
+ xt = n2(xt, s)
78
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
79
+ xt = c2(xt)
80
+ x = xt + x
81
+ return x
82
+
83
+ def remove_weight_norm(self):
84
+ for l in self.convs1:
85
+ remove_weight_norm(l)
86
+ for l in self.convs2:
87
+ remove_weight_norm(l)
88
+
89
+ class TorchSTFT(torch.nn.Module):
90
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
91
+ super().__init__()
92
+ self.filter_length = filter_length
93
+ self.hop_length = hop_length
94
+ self.win_length = win_length
95
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
96
+
97
+ def transform(self, input_data):
98
+ forward_transform = torch.stft(
99
+ input_data,
100
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
101
+ return_complex=True)
102
+
103
+ return torch.abs(forward_transform), torch.angle(forward_transform)
104
+
105
+ def inverse(self, magnitude, phase):
106
+ inverse_transform = torch.istft(
107
+ magnitude * torch.exp(phase * 1j),
108
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
109
+
110
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
111
+
112
+ def forward(self, input_data):
113
+ self.magnitude, self.phase = self.transform(input_data)
114
+ reconstruction = self.inverse(self.magnitude, self.phase)
115
+ return reconstruction
116
+
117
+ class SineGen(torch.nn.Module):
118
+ """ Definition of sine generator
119
+ SineGen(samp_rate, harmonic_num = 0,
120
+ sine_amp = 0.1, noise_std = 0.003,
121
+ voiced_threshold = 0,
122
+ flag_for_pulse=False)
123
+ samp_rate: sampling rate in Hz
124
+ harmonic_num: number of harmonic overtones (default 0)
125
+ sine_amp: amplitude of sine-waveform (default 0.1)
126
+ noise_std: std of Gaussian noise (default 0.003)
127
+ voiced_threshold: F0 threshold for U/V classification (default 0)
128
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
129
+ Note: when flag_for_pulse is True, the first time step of a voiced
130
+ segment is always sin(np.pi) or cos(0)
131
+ """
132
+
133
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
134
+ sine_amp=0.1, noise_std=0.003,
135
+ voiced_threshold=0,
136
+ flag_for_pulse=False):
137
+ super(SineGen, self).__init__()
138
+ self.sine_amp = sine_amp
139
+ self.noise_std = noise_std
140
+ self.harmonic_num = harmonic_num
141
+ self.dim = self.harmonic_num + 1
142
+ self.sampling_rate = samp_rate
143
+ self.voiced_threshold = voiced_threshold
144
+ self.flag_for_pulse = flag_for_pulse
145
+ self.upsample_scale = upsample_scale
146
+
147
+ def _f02uv(self, f0):
148
+ # generate uv signal
149
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
150
+ return uv
151
+
152
+ def _f02sine(self, f0_values):
153
+ """ f0_values: (batchsize, length, dim)
154
+ where dim indicates fundamental tone and overtones
155
+ """
156
+ # convert to F0 in rad. The integer part n can be ignored
157
+ # because 2 * np.pi * n doesn't affect phase
158
+ rad_values = (f0_values / self.sampling_rate) % 1
159
+
160
+ # initial phase noise (no noise for fundamental component)
161
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
162
+ device=f0_values.device)
163
+ rand_ini[:, 0] = 0
164
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
165
+
166
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
167
+ if not self.flag_for_pulse:
168
+ # # for normal case
169
+
170
+ # # To prevent torch.cumsum numerical overflow,
171
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
172
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
173
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
174
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
175
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
176
+ # cumsum_shift = torch.zeros_like(rad_values)
177
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
178
+
179
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
180
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
181
+ scale_factor=1/self.upsample_scale,
182
+ mode="linear").transpose(1, 2)
183
+
184
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
185
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
186
+ # cumsum_shift = torch.zeros_like(rad_values)
187
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
188
+
189
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
190
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
191
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
192
+ sines = torch.sin(phase)
193
+
194
+ else:
195
+ # If necessary, make sure that the first time step of every
196
+ # voiced segments is sin(pi) or cos(0)
197
+ # This is used for pulse-train generation
198
+
199
+ # identify the last time step in unvoiced segments
200
+ uv = self._f02uv(f0_values)
201
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
202
+ uv_1[:, -1, :] = 1
203
+ u_loc = (uv < 1) * (uv_1 > 0)
204
+
205
+ # get the instantaneous phase
206
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
207
+ # different batch needs to be processed differently
208
+ for idx in range(f0_values.shape[0]):
209
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
210
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
211
+ # stores the accumulation of i.phase within
212
+ # each voiced segments
213
+ tmp_cumsum[idx, :, :] = 0
214
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
215
+
216
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
217
+ # within the previous voiced segment.
218
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
219
+
220
+ # get the sines
221
+ sines = torch.cos(i_phase * 2 * np.pi)
222
+ return sines
223
+
224
+ def forward(self, f0):
225
+ """ sine_tensor, uv = forward(f0)
226
+ input F0: tensor(batchsize=1, length, dim=1)
227
+ f0 for unvoiced steps should be 0
228
+ output sine_tensor: tensor(batchsize=1, length, dim)
229
+ output uv: tensor(batchsize=1, length, 1)
230
+ """
231
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
232
+ device=f0.device)
233
+ # fundamental component
234
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
235
+
236
+ # generate sine waveforms
237
+ sine_waves = self._f02sine(fn) * self.sine_amp
238
+
239
+ # generate uv signal
240
+ # uv = torch.ones(f0.shape)
241
+ # uv = uv * (f0 > self.voiced_threshold)
242
+ uv = self._f02uv(f0)
243
+
244
+ # noise: for unvoiced should be similar to sine_amp
245
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
246
+ # . for voiced regions is self.noise_std
247
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
248
+ noise = noise_amp * torch.randn_like(sine_waves)
249
+
250
+ # first: set the unvoiced part to 0 by uv
251
+ # then: additive noise
252
+ sine_waves = sine_waves * uv + noise
253
+ return sine_waves, uv, noise
254
+
255
+
256
+ class SourceModuleHnNSF(torch.nn.Module):
257
+ """ SourceModule for hn-nsf
258
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
259
+ add_noise_std=0.003, voiced_threshod=0)
260
+ sampling_rate: sampling_rate in Hz
261
+ harmonic_num: number of harmonic above F0 (default: 0)
262
+ sine_amp: amplitude of sine source signal (default: 0.1)
263
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
264
+ note that amplitude of noise in unvoiced is decided
265
+ by sine_amp
266
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
267
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
268
+ F0_sampled (batchsize, length, 1)
269
+ Sine_source (batchsize, length, 1)
270
+ noise_source (batchsize, length 1)
271
+ uv (batchsize, length, 1)
272
+ """
273
+
274
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
275
+ add_noise_std=0.003, voiced_threshod=0):
276
+ super(SourceModuleHnNSF, self).__init__()
277
+
278
+ self.sine_amp = sine_amp
279
+ self.noise_std = add_noise_std
280
+
281
+ # to produce sine waveforms
282
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
283
+ sine_amp, add_noise_std, voiced_threshod)
284
+
285
+ # to merge source harmonics into a single excitation
286
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
287
+ self.l_tanh = torch.nn.Tanh()
288
+
289
+ def forward(self, x):
290
+ """
291
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
292
+ F0_sampled (batchsize, length, 1)
293
+ Sine_source (batchsize, length, 1)
294
+ noise_source (batchsize, length 1)
295
+ """
296
+ # source for harmonic branch
297
+ with torch.no_grad():
298
+ sine_wavs, uv, _ = self.l_sin_gen(x)
299
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
300
+
301
+ # source for noise branch, in the same shape as uv
302
+ noise = torch.randn_like(uv) * self.sine_amp / 3
303
+ return sine_merge, noise, uv
304
+ def padDiff(x):
305
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
306
+
307
+
308
+ class Generator(torch.nn.Module):
309
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
310
+ super(Generator, self).__init__()
311
+
312
+ self.num_kernels = len(resblock_kernel_sizes)
313
+ self.num_upsamples = len(upsample_rates)
314
+ resblock = AdaINResBlock1
315
+
316
+ self.m_source = SourceModuleHnNSF(
317
+ sampling_rate=24000,
318
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
319
+ harmonic_num=8, voiced_threshod=10)
320
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
321
+ self.noise_convs = nn.ModuleList()
322
+ self.noise_res = nn.ModuleList()
323
+
324
+ self.ups = nn.ModuleList()
325
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
326
+ self.ups.append(weight_norm(
327
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
328
+ k, u, padding=(k-u)//2)))
329
+
330
+ self.resblocks = nn.ModuleList()
331
+ for i in range(len(self.ups)):
332
+ ch = upsample_initial_channel//(2**(i+1))
333
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
334
+ self.resblocks.append(resblock(ch, k, d, style_dim))
335
+
336
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
337
+
338
+ if i + 1 < len(upsample_rates): #
339
+ stride_f0 = np.prod(upsample_rates[i + 1:])
340
+ self.noise_convs.append(Conv1d(
341
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
342
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
343
+ else:
344
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
345
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
346
+
347
+
348
+ self.post_n_fft = gen_istft_n_fft
349
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
350
+ self.ups.apply(init_weights)
351
+ self.conv_post.apply(init_weights)
352
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
353
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
354
+
355
+
356
+ def forward(self, x, s, f0):
357
+ with torch.no_grad():
358
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
359
+
360
+ har_source, noi_source, uv = self.m_source(f0)
361
+ har_source = har_source.transpose(1, 2).squeeze(1)
362
+ har_spec, har_phase = self.stft.transform(har_source)
363
+ har = torch.cat([har_spec, har_phase], dim=1)
364
+
365
+ for i in range(self.num_upsamples):
366
+ x = F.leaky_relu(x, LRELU_SLOPE)
367
+ x_source = self.noise_convs[i](har)
368
+ x_source = self.noise_res[i](x_source, s)
369
+
370
+ x = self.ups[i](x)
371
+ if i == self.num_upsamples - 1:
372
+ x = self.reflection_pad(x)
373
+
374
+ x = x + x_source
375
+ xs = None
376
+ for j in range(self.num_kernels):
377
+ if xs is None:
378
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
379
+ else:
380
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
381
+ x = xs / self.num_kernels
382
+ x = F.leaky_relu(x)
383
+ x = self.conv_post(x)
384
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
385
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
386
+ return self.stft.inverse(spec, phase)
387
+
388
+ def fw_phase(self, x, s):
389
+ for i in range(self.num_upsamples):
390
+ x = F.leaky_relu(x, LRELU_SLOPE)
391
+ x = self.ups[i](x)
392
+ xs = None
393
+ for j in range(self.num_kernels):
394
+ if xs is None:
395
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
396
+ else:
397
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
398
+ x = xs / self.num_kernels
399
+ x = F.leaky_relu(x)
400
+ x = self.reflection_pad(x)
401
+ x = self.conv_post(x)
402
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
403
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
404
+ return spec, phase
405
+
406
+ def remove_weight_norm(self):
407
+ print('Removing weight norm...')
408
+ for l in self.ups:
409
+ remove_weight_norm(l)
410
+ for l in self.resblocks:
411
+ l.remove_weight_norm()
412
+ if hasattr(self, 'conv_pre'): # this Generator defines no conv_pre; guard against AttributeError
+ remove_weight_norm(self.conv_pre)
413
+ remove_weight_norm(self.conv_post)
414
+
415
+
416
+ class AdainResBlk1d(nn.Module):
417
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
418
+ upsample='none', dropout_p=0.0):
419
+ super().__init__()
420
+ self.actv = actv
421
+ self.upsample_type = upsample
422
+ self.upsample = UpSample1d(upsample)
423
+ self.learned_sc = dim_in != dim_out
424
+ self._build_weights(dim_in, dim_out, style_dim)
425
+ self.dropout = nn.Dropout(dropout_p)
426
+
427
+ if upsample == 'none':
428
+ self.pool = nn.Identity()
429
+ else:
430
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
431
+
432
+
433
+ def _build_weights(self, dim_in, dim_out, style_dim):
434
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
435
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
436
+ self.norm1 = AdaIN1d(style_dim, dim_in)
437
+ self.norm2 = AdaIN1d(style_dim, dim_out)
438
+ if self.learned_sc:
439
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
440
+
441
+ def _shortcut(self, x):
442
+ x = self.upsample(x)
443
+ if self.learned_sc:
444
+ x = self.conv1x1(x)
445
+ return x
446
+
447
+ def _residual(self, x, s):
448
+ x = self.norm1(x, s)
449
+ x = self.actv(x)
450
+ x = self.pool(x)
451
+ x = self.conv1(self.dropout(x))
452
+ x = self.norm2(x, s)
453
+ x = self.actv(x)
454
+ x = self.conv2(self.dropout(x))
455
+ return x
456
+
457
+ def forward(self, x, s):
458
+ out = self._residual(x, s)
459
+ out = (out + self._shortcut(x)) / np.sqrt(2)
460
+ return out
461
+
462
+ class UpSample1d(nn.Module):
463
+ def __init__(self, layer_type):
464
+ super().__init__()
465
+ self.layer_type = layer_type
466
+
467
+ def forward(self, x):
468
+ if self.layer_type == 'none':
469
+ return x
470
+ else:
471
+ return F.interpolate(x, scale_factor=2, mode='nearest')
472
+
473
+ class Decoder(nn.Module):
474
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
475
+ resblock_kernel_sizes = [3,7,11],
476
+ upsample_rates = [10, 6],
477
+ upsample_initial_channel=512,
478
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
479
+ upsample_kernel_sizes=[20, 12],
480
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
481
+ super().__init__()
482
+
483
+ self.decode = nn.ModuleList()
484
+
485
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
486
+
487
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
488
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
489
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
490
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
491
+
492
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
493
+
494
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
495
+
496
+ self.asr_res = nn.Sequential(
497
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
498
+ )
499
+
500
+
501
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
502
+ upsample_initial_channel, resblock_dilation_sizes,
503
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
504
+
505
+ def forward(self, asr, F0_curve, N, s):
506
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
507
+ N = self.N_conv(N.unsqueeze(1))
508
+
509
+ x = torch.cat([asr, F0, N], axis=1)
510
+ x = self.encode(x, s)
511
+
512
+ asr_res = self.asr_res(asr)
513
+
514
+ res = True
515
+ for block in self.decode:
516
+ if res:
517
+ x = torch.cat([x, asr_res, F0, N], axis=1)
518
+ x = block(x, s)
519
+ if block.upsample_type != "none":
520
+ res = False
521
+
522
+ x = self.generator(x, s, F0_curve)
523
+ return x
KOKORO/kokoro.py ADDED
@@ -0,0 +1,161 @@
1
+ import phonemizer
2
+ import re
3
+ import torch
4
+
5
+ import os
6
+ import platform
7
+
8
+ # Check if the system is Windows
9
+ is_windows = platform.system() == "Windows"
10
+
11
+ # If Windows, set the environment variables
12
+ if is_windows:
13
+ os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
14
+ os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe"
15
+
16
+
17
+ def split_num(num):
18
+ num = num.group()
19
+ if '.' in num:
20
+ return num
21
+ elif ':' in num:
22
+ h, m = [int(n) for n in num.split(':')]
23
+ if m == 0:
24
+ return f"{h} o'clock"
25
+ elif m < 10:
26
+ return f'{h} oh {m}'
27
+ return f'{h} {m}'
28
+ year = int(num[:4])
29
+ if year < 1100 or year % 1000 < 10:
30
+ return num
31
+ left, right = num[:2], int(num[2:4])
32
+ s = 's' if num.endswith('s') else ''
33
+ if 100 <= year % 1000 <= 999:
34
+ if right == 0:
35
+ return f'{left} hundred{s}'
36
+ elif right < 10:
37
+ return f'{left} oh {right}{s}'
38
+ return f'{left} {right}{s}'
39
+
40
+ def flip_money(m):
41
+ m = m.group()
42
+ bill = 'dollar' if m[0] == '$' else 'pound'
43
+ if m[-1].isalpha():
44
+ return f'{m[1:]} {bill}s'
45
+ elif '.' not in m:
46
+ s = '' if m[1:] == '1' else 's'
47
+ return f'{m[1:]} {bill}{s}'
48
+ b, c = m[1:].split('.')
49
+ s = '' if b == '1' else 's'
50
+ c = int(c.ljust(2, '0'))
51
+ coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
52
+ return f'{b} {bill}{s} and {c} {coins}'
53
+
54
+ def point_num(num):
55
+ a, b = num.group().split('.')
56
+ return ' point '.join([a, ' '.join(b)])
57
+
58
+ def normalize_text(text):
59
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
60
+ text = text.replace('«', chr(8220)).replace('»', chr(8221))
61
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
62
+ text = text.replace('(', '«').replace(')', '»')
63
+ for a, b in zip('、。!,:;?', ',.!,:;?'):
64
+ text = text.replace(a, b+' ')
65
+ text = re.sub(r'[^\S \n]', ' ', text)
66
+ text = re.sub(r' +', ' ', text)
67
+ text = re.sub(r'(?<=\n) +(?=\n)', '', text)
68
+ text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
69
+ text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
70
+ text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
71
+ text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
72
+ text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
73
+ text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
74
+ text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
75
+ text = re.sub(r'(?<=\d),(?=\d)', '', text)
76
+ text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
77
+ text = re.sub(r'\d*\.\d+', point_num, text)
78
+ text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
79
+ text = re.sub(r'(?<=\d)S', ' S', text)
80
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
81
+ text = re.sub(r"(?<=X')S\b", 's', text)
82
+ text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
83
+ text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
84
+ return text.strip()
85
+
86
+ def get_vocab():
87
+ _pad = "$"
88
+ _punctuation = ';:,.!?¡¿—…"«»“” '
89
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
90
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
91
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
92
+ dicts = {}
93
+ for i in range(len((symbols))):
94
+ dicts[symbols[i]] = i
95
+ return dicts
96
+
97
+ VOCAB = get_vocab()
98
+ def tokenize(ps):
99
+ return [i for i in map(VOCAB.get, ps) if i is not None]
100
+
101
+ phonemizers = dict(
102
+ a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
103
+ b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
104
+ )
105
+ def phonemize(text, lang, norm=True):
106
+ if norm:
107
+ text = normalize_text(text)
108
+ ps = phonemizers[lang].phonemize([text])
109
+ ps = ps[0] if ps else ''
110
+ # https://en.wiktionary.org/wiki/kokoro#English
111
+ ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
112
+ ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
113
+ ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
114
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
115
+ if lang == 'a':
116
+ ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
117
+ ps = ''.join(filter(lambda p: p in VOCAB, ps))
118
+ return ps.strip()
119
+
120
+ def length_to_mask(lengths):
121
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
122
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
123
+ return mask
124
+
125
+ @torch.no_grad()
126
+ def forward(model, tokens, ref_s, speed):
127
+ device = ref_s.device
128
+ tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
129
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
130
+ text_mask = length_to_mask(input_lengths).to(device)
131
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
132
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
133
+ s = ref_s[:, 128:]
134
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
135
+ x, _ = model.predictor.lstm(d)
136
+ duration = model.predictor.duration_proj(x)
137
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
138
+ pred_dur = torch.round(duration).clamp(min=1).long()
139
+ pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
140
+ c_frame = 0
141
+ for i in range(pred_aln_trg.size(0)):
142
+ pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
143
+ c_frame += pred_dur[0,i].item()
144
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
145
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
146
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
147
+ asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
148
+ return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
149
+
150
+ def generate(model, text, voicepack, lang='a', speed=1, ps=None):
151
+ ps = ps or phonemize(text, lang)
152
+ tokens = tokenize(ps)
153
+ if not tokens:
154
+ return None
155
+ elif len(tokens) > 510:
156
+ tokens = tokens[:510]
157
+ print('Truncated to 510 tokens')
158
+ ref_s = voicepack[len(tokens)]
159
+ out = forward(model, tokens, ref_s, speed)
160
+ ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
161
+ return out, ps
KOKORO/models.py ADDED
@@ -0,0 +1,372 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
+ from .istftnet import AdaIN1d, Decoder
3
+ from munch import Munch
4
+ from pathlib import Path
5
+ from .plbert import load_plbert
6
+ from torch.nn.utils import weight_norm, spectral_norm
7
+ import json
8
+ import numpy as np
9
+ import os
10
+ import os.path as osp
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ class LinearNorm(torch.nn.Module):
16
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
17
+ super(LinearNorm, self).__init__()
18
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
19
+
20
+ torch.nn.init.xavier_uniform_(
21
+ self.linear_layer.weight,
22
+ gain=torch.nn.init.calculate_gain(w_init_gain))
23
+
24
+ def forward(self, x):
25
+ return self.linear_layer(x)
26
+
27
+ class LayerNorm(nn.Module):
28
+ def __init__(self, channels, eps=1e-5):
29
+ super().__init__()
30
+ self.channels = channels
31
+ self.eps = eps
32
+
33
+ self.gamma = nn.Parameter(torch.ones(channels))
34
+ self.beta = nn.Parameter(torch.zeros(channels))
35
+
36
+ def forward(self, x):
37
+ x = x.transpose(1, -1)
38
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
39
+ return x.transpose(1, -1)
40
+
41
+ class TextEncoder(nn.Module):
42
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
43
+ super().__init__()
44
+ self.embedding = nn.Embedding(n_symbols, channels)
45
+
46
+ padding = (kernel_size - 1) // 2
47
+ self.cnn = nn.ModuleList()
48
+ for _ in range(depth):
49
+ self.cnn.append(nn.Sequential(
50
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
51
+ LayerNorm(channels),
52
+ actv,
53
+ nn.Dropout(0.2),
54
+ ))
55
+ # self.cnn = nn.Sequential(*self.cnn)
56
+
57
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
58
+
59
+ def forward(self, x, input_lengths, m):
60
+ x = self.embedding(x) # [B, T, emb]
61
+ x = x.transpose(1, 2) # [B, emb, T]
62
+ m = m.to(input_lengths.device).unsqueeze(1)
63
+ x.masked_fill_(m, 0.0)
64
+
65
+ for c in self.cnn:
66
+ x = c(x)
67
+ x.masked_fill_(m, 0.0)
68
+
69
+ x = x.transpose(1, 2) # [B, T, chn]
70
+
71
+ input_lengths = input_lengths.cpu().numpy()
72
+ x = nn.utils.rnn.pack_padded_sequence(
73
+ x, input_lengths, batch_first=True, enforce_sorted=False)
74
+
75
+ self.lstm.flatten_parameters()
76
+ x, _ = self.lstm(x)
77
+ x, _ = nn.utils.rnn.pad_packed_sequence(
78
+ x, batch_first=True)
79
+
80
+ x = x.transpose(-1, -2)
81
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
82
+
83
+ x_pad[:, :, :x.shape[-1]] = x
84
+ x = x_pad.to(x.device)
85
+
86
+ x.masked_fill_(m, 0.0)
87
+
88
+ return x
89
+
90
+ def inference(self, x):
91
+ x = self.embedding(x)
92
+ x = x.transpose(1, 2)
93
+ x = self.cnn(x)
94
+ x = x.transpose(1, 2)
95
+ self.lstm.flatten_parameters()
96
+ x, _ = self.lstm(x)
97
+ return x
98
+
99
+ def length_to_mask(self, lengths):
100
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
101
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
102
+ return mask
103
+
104
+
105
+ class UpSample1d(nn.Module):
106
+ def __init__(self, layer_type):
107
+ super().__init__()
108
+ self.layer_type = layer_type
109
+
110
+ def forward(self, x):
111
+ if self.layer_type == 'none':
112
+ return x
113
+ else:
114
+ return F.interpolate(x, scale_factor=2, mode='nearest')
115
+
116
+ class AdainResBlk1d(nn.Module):
117
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
118
+ upsample='none', dropout_p=0.0):
119
+ super().__init__()
120
+ self.actv = actv
121
+ self.upsample_type = upsample
122
+ self.upsample = UpSample1d(upsample)
123
+ self.learned_sc = dim_in != dim_out
124
+ self._build_weights(dim_in, dim_out, style_dim)
125
+ self.dropout = nn.Dropout(dropout_p)
126
+
127
+ if upsample == 'none':
128
+ self.pool = nn.Identity()
129
+ else:
130
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
131
+
132
+
133
+ def _build_weights(self, dim_in, dim_out, style_dim):
134
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
135
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
136
+ self.norm1 = AdaIN1d(style_dim, dim_in)
137
+ self.norm2 = AdaIN1d(style_dim, dim_out)
138
+ if self.learned_sc:
139
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
140
+
141
+ def _shortcut(self, x):
142
+ x = self.upsample(x)
143
+ if self.learned_sc:
144
+ x = self.conv1x1(x)
145
+ return x
146
+
147
+ def _residual(self, x, s):
148
+ x = self.norm1(x, s)
149
+ x = self.actv(x)
150
+ x = self.pool(x)
151
+ x = self.conv1(self.dropout(x))
152
+ x = self.norm2(x, s)
153
+ x = self.actv(x)
154
+ x = self.conv2(self.dropout(x))
155
+ return x
156
+
157
+ def forward(self, x, s):
158
+ out = self._residual(x, s)
159
+ out = (out + self._shortcut(x)) / np.sqrt(2)
160
+ return out
161
+
162
+ class AdaLayerNorm(nn.Module):
163
+ def __init__(self, style_dim, channels, eps=1e-5):
164
+ super().__init__()
165
+ self.channels = channels
166
+ self.eps = eps
167
+
168
+ self.fc = nn.Linear(style_dim, channels*2)
169
+
170
+ def forward(self, x, s):
171
+ x = x.transpose(-1, -2)
172
+ x = x.transpose(1, -1)
173
+
174
+ h = self.fc(s)
175
+ h = h.view(h.size(0), h.size(1), 1)
176
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
177
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
178
+
179
+
180
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
181
+ x = (1 + gamma) * x + beta
182
+ return x.transpose(1, -1).transpose(-1, -2)
183
+
184
+ class ProsodyPredictor(nn.Module):
185
+
186
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
187
+ super().__init__()
188
+
189
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
190
+ d_model=d_hid,
191
+ nlayers=nlayers,
192
+ dropout=dropout)
193
+
194
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
195
+ self.duration_proj = LinearNorm(d_hid, max_dur)
196
+
197
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
198
+ self.F0 = nn.ModuleList()
199
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
200
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
201
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
202
+
203
+ self.N = nn.ModuleList()
204
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
205
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
206
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
207
+
208
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
209
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
210
+
211
+
212
+ def forward(self, texts, style, text_lengths, alignment, m):
213
+ d = self.text_encoder(texts, style, text_lengths, m)
214
+
215
+ batch_size = d.shape[0]
216
+ text_size = d.shape[1]
217
+
218
+ # predict duration
219
+ input_lengths = text_lengths.cpu().numpy()
220
+ x = nn.utils.rnn.pack_padded_sequence(
221
+ d, input_lengths, batch_first=True, enforce_sorted=False)
222
+
223
+ m = m.to(text_lengths.device).unsqueeze(1)
224
+
225
+ self.lstm.flatten_parameters()
226
+ x, _ = self.lstm(x)
227
+ x, _ = nn.utils.rnn.pad_packed_sequence(
228
+ x, batch_first=True)
229
+
230
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
231
+
232
+ x_pad[:, :x.shape[1], :] = x
233
+ x = x_pad.to(x.device)
234
+
235
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
236
+
237
+ en = (d.transpose(-1, -2) @ alignment)
238
+
239
+ return duration.squeeze(-1), en
240
+
241
+ def F0Ntrain(self, x, s):
242
+ x, _ = self.shared(x.transpose(-1, -2))
243
+
244
+ F0 = x.transpose(-1, -2)
245
+ for block in self.F0:
246
+ F0 = block(F0, s)
247
+ F0 = self.F0_proj(F0)
248
+
249
+ N = x.transpose(-1, -2)
250
+ for block in self.N:
251
+ N = block(N, s)
252
+ N = self.N_proj(N)
253
+
254
+ return F0.squeeze(1), N.squeeze(1)
255
+
256
+ def length_to_mask(self, lengths):
257
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
258
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
259
+ return mask
260
+
261
+ class DurationEncoder(nn.Module):
262
+
263
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
264
+ super().__init__()
265
+ self.lstms = nn.ModuleList()
266
+ for _ in range(nlayers):
267
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
268
+ d_model // 2,
269
+ num_layers=1,
270
+ batch_first=True,
271
+ bidirectional=True,
272
+ dropout=dropout))
273
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
274
+
275
+
276
+ self.dropout = dropout
277
+ self.d_model = d_model
278
+ self.sty_dim = sty_dim
279
+
280
+ def forward(self, x, style, text_lengths, m):
281
+ masks = m.to(text_lengths.device)
282
+
283
+ x = x.permute(2, 0, 1)
284
+ s = style.expand(x.shape[0], x.shape[1], -1)
285
+ x = torch.cat([x, s], axis=-1)
286
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
287
+
288
+ x = x.transpose(0, 1)
289
+ input_lengths = text_lengths.cpu().numpy()
290
+ x = x.transpose(-1, -2)
291
+
292
+ for block in self.lstms:
293
+ if isinstance(block, AdaLayerNorm):
294
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
295
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
296
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
297
+ else:
298
+ x = x.transpose(-1, -2)
299
+ x = nn.utils.rnn.pack_padded_sequence(
300
+ x, input_lengths, batch_first=True, enforce_sorted=False)
301
+ block.flatten_parameters()
302
+ x, _ = block(x)
303
+ x, _ = nn.utils.rnn.pad_packed_sequence(
304
+ x, batch_first=True)
305
+ x = F.dropout(x, p=self.dropout, training=self.training)
306
+ x = x.transpose(-1, -2)
307
+
308
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
309
+
310
+ x_pad[:, :, :x.shape[-1]] = x
311
+ x = x_pad.to(x.device)
312
+
313
+ return x.transpose(-1, -2)
314
+
315
+ def inference(self, x, style):
316
+ x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
317
+ style = style.expand(x.shape[0], x.shape[1], -1)
318
+ x = torch.cat([x, style], axis=-1)
319
+ src = self.pos_encoder(x)
320
+ output = self.transformer_encoder(src).transpose(0, 1)
321
+ return output
322
+
323
+ def length_to_mask(self, lengths):
324
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
325
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
326
+ return mask
327
+
328
+ # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
329
+ def recursive_munch(d):
330
+ if isinstance(d, dict):
331
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
332
+ elif isinstance(d, list):
333
+ return [recursive_munch(v) for v in d]
334
+ else:
335
+ return d
336
+
337
+ def build_model(path, device):
338
+ config = Path(__file__).parent / 'config.json'
339
+ assert config.exists(), f'Config path incorrect: config.json not found at {config}'
340
+ with open(config, 'r') as r:
341
+ args = recursive_munch(json.load(r))
342
+ assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
343
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
344
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
345
+ upsample_rates = args.decoder.upsample_rates,
346
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
347
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
348
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
349
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
350
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
351
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
352
+ bert = load_plbert()
353
+ bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
354
+ for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
355
+ for child in parent.children():
356
+ if isinstance(child, nn.RNNBase):
357
+ child.flatten_parameters()
358
+ model = Munch(
359
+ bert=bert.to(device).eval(),
360
+ bert_encoder=bert_encoder.to(device).eval(),
361
+ predictor=predictor.to(device).eval(),
362
+ decoder=decoder.to(device).eval(),
363
+ text_encoder=text_encoder.to(device).eval(),
364
+ )
365
+ for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
366
+ assert key in model, key
367
+ try:
368
+ model[key].load_state_dict(state_dict)
369
+ except Exception: # retry after stripping the 'module.' prefix left by DataParallel checkpoints
370
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
371
+ model[key].load_state_dict(state_dict, strict=False)
372
+ return model
KOKORO/plbert.py ADDED
@@ -0,0 +1,15 @@
+ # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
+ from transformers import AlbertConfig, AlbertModel
+
+ class CustomAlbert(AlbertModel):
+ def forward(self, *args, **kwargs):
+ # Call the original forward method
+ outputs = super().forward(*args, **kwargs)
+ # Only return the last_hidden_state
+ return outputs.last_hidden_state
+
+ def load_plbert():
+ plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
+ albert_base_configuration = AlbertConfig(**plbert_config)
+ bert = CustomAlbert(albert_base_configuration)
+ return bert
KOKORO/utils.py ADDED
@@ -0,0 +1,342 @@
1
+ from .kokoro import normalize_text,phonemize,generate
2
+ import re
3
+ import librosa
4
+ import os
5
+ import uuid
6
+ from pydub.silence import split_on_silence
7
+ from pydub import AudioSegment
8
+ import wave
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ def create_audio_dir():
14
+ """Creates the 'kokoro_audio' directory in the root folder if it doesn't exist."""
15
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
16
+ audio_dir = os.path.join(root_dir, "kokoro_audio")
17
+
18
+ if not os.path.exists(audio_dir):
19
+ os.makedirs(audio_dir)
20
+ print(f"Created directory: {audio_dir}")
21
+ else:
22
+ print(f"Directory already exists: {audio_dir}")
23
+ return audio_dir
24
+
25
+ temp_folder = create_audio_dir()
26
+
27
+
28
+ debug=False
29
+ def resplit_strings(arr):
30
+ # Handle edge cases
31
+ if not arr:
32
+ return '', ''
33
+ if len(arr) == 1:
34
+ return arr[0], ''
35
+ # Try each possible split point
36
+ min_diff = float('inf')
37
+ best_split = 0
38
+ # Calculate lengths when joined with spaces
39
+ lengths = [len(s) for s in arr]
40
+ spaces = len(arr) - 1 # Total spaces needed
41
+ # Try each split point
42
+ left_len = 0
43
+ right_len = sum(lengths) + spaces
44
+ for i in range(1, len(arr)):
45
+ # Add current word and space to left side
46
+ left_len += lengths[i-1] + (1 if i > 1 else 0)
47
+ # Remove current word and space from right side
48
+ right_len -= lengths[i-1] + 1
49
+ diff = abs(left_len - right_len)
50
+ if diff < min_diff:
51
+ min_diff = diff
52
+ best_split = i
53
+ # Join the strings with the best split point
54
+ return ' '.join(arr[:best_split]), ' '.join(arr[best_split:])
55
+
56
+ def recursive_split(text, voice):
57
+ if not text:
58
+ return []
59
+ tokens = phonemize(text, voice, norm=False)
60
+ if len(tokens) < 511:
61
+ return [(text, tokens, len(tokens))] if tokens else []
62
+ if ' ' not in text:
63
+ return []
64
+ for punctuation in ['!.?…', ':;', ',—']:
65
+ splits = re.split(f'(?:(?<=[{punctuation}])|(?<=[{punctuation}]["\'»])|(?<=[{punctuation}]["\'»]["\'»])) ', text)
66
+ if len(splits) > 1:
67
+ break
68
+ else:
69
+ splits = None
70
+ splits = splits or text.split(' ')
71
+ a, b = resplit_strings(splits)
72
+ return recursive_split(a, voice) + recursive_split(b, voice)
73
+
74
+ def segment_and_tokenize(text, voice, skip_square_brackets=True, newline_split=2):
75
+ if skip_square_brackets:
76
+ text = re.sub(r'\[.*?\]', '', text)
77
+ texts = [t.strip() for t in re.split('\n{'+str(newline_split)+',}', normalize_text(text))] if newline_split > 0 else [normalize_text(text)]
78
+ segments = [row for t in texts for row in recursive_split(t, voice)]
79
+ return [(i, *row) for i, row in enumerate(segments)]

def large_text(text, VOICE_NAME):
    if len(text) <= 500:
        return [(0, text, len(text))]
    else:
        result = segment_and_tokenize(text, VOICE_NAME[0])
        filtered_result = [(row[0], row[1], row[3]) for row in result]
        return filtered_result

def clamp_speed(speed):
    if not isinstance(speed, (float, int)):
        return 1
    elif speed < 0.5:
        # return 0.5
        return speed
    elif speed > 2:
        return 2
    return speed

def clamp_trim(trim):
    if not isinstance(trim, (float, int)):
        return 0.5
    elif trim <= 0:
        return 0
    elif trim > 1:
        return 0.5
    return trim

def trim_if_needed(out, trim):
    if not trim:
        return out
    a, b = librosa.effects.trim(out, top_db=30)[1]
    a = int(a*trim)
    b = int(len(out)-(len(out)-b)*trim)
    return out[a:b]

# Above code copied from https://huggingface.co/spaces/hexgrad/Kokoro-TTS/blob/main/app.py
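
# Note on `trim`: it is a 0-1 fraction of the edge silence detected by librosa.effects.trim
# that gets cut away: trim=0 leaves the audio untouched, trim=1 removes all of the detected
# leading/trailing silence.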

def get_random_file_name(output_file=""):
    global temp_folder
    if output_file == "":
        random_id = str(uuid.uuid4())[:8]
        output_file = f"{temp_folder}/{random_id}.wav"
        return output_file
    # If the requested file doesn't exist yet, the name can be used as-is
    if not os.path.exists(output_file):
        return output_file
    try:
        if output_file and os.path.exists(output_file):
            os.remove(output_file)  # Try to remove the file if it exists
        return output_file  # Return the same name if the file was successfully removed
    except Exception as e:
        # print(f"Error removing file {output_file}: {e}")
        random_id = str(uuid.uuid4())[:8]
        output_file = f"{temp_folder}/{random_id}.wav"
        return output_file

def remove_silence_function(file_path, minimum_silence=50):
    # Build the output path next to the input file
    output_path = file_path.replace(".wav", "_no_silence.wav")
    audio_format = "wav"
    # Read the audio and split it into chunks on silence
    sound = AudioSegment.from_file(file_path, format=audio_format)
    audio_chunks = split_on_silence(sound,
                                    min_silence_len=100,
                                    silence_thresh=-45,
                                    keep_silence=minimum_silence)
    # Put the file back together
    combined = AudioSegment.empty()
    for chunk in audio_chunks:
        combined += chunk
    combined.export(output_path, format=audio_format)
    return output_path

import simpleaudio as sa

def play_audio(filename):
    # Blocking playback of a .wav file on the default output device
    wave_obj = sa.WaveObject.from_wave_file(filename)
    play_obj = wave_obj.play()
    play_obj.wait_done()

def clean_text(text):
    # Define replacement rules
    replacements = {
        "–": " ",   # Replace en-dash with space
        "-": " ",   # Replace hyphen with space
        ":": ",",   # Replace colon with comma
        "**": " ",  # Replace double asterisks with space
        "*": " ",   # Replace single asterisk with space
        "#": " ",   # Replace hash with space
    }

    # Apply replacements
    for old, new in replacements.items():
        text = text.replace(old, new)

    # Remove emojis using a regex covering a wide range of Unicode characters
    emoji_pattern = re.compile(
        r'[\U0001F600-\U0001F64F]|'  # Emoticons
        r'[\U0001F300-\U0001F5FF]|'  # Miscellaneous symbols and pictographs
        r'[\U0001F680-\U0001F6FF]|'  # Transport and map symbols
        r'[\U0001F700-\U0001F77F]|'  # Alchemical symbols
        r'[\U0001F780-\U0001F7FF]|'  # Geometric shapes extended
        r'[\U0001F800-\U0001F8FF]|'  # Supplemental arrows-C
        r'[\U0001F900-\U0001F9FF]|'  # Supplemental symbols and pictographs
        r'[\U0001FA00-\U0001FA6F]|'  # Chess symbols
        r'[\U0001FA70-\U0001FAFF]|'  # Symbols and pictographs extended-A
        r'[\U00002702-\U000027B0]|'  # Dingbats
        r'[\U0001F1E0-\U0001F1FF]'   # Flags (iOS)
        r'', flags=re.UNICODE)
    text = emoji_pattern.sub(r'', text)

    # Remove multiple spaces and extra line breaks
    text = re.sub(r'\s+', ' ', text).strip()

    return text
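
# e.g. clean_text("**Intro** – Chapter 1: Hello 😀") -> "Intro Chapter 1, Hello"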

# copied from F5TTS 😁
def parse_speechtypes_text(gen_text):
    # Pattern to find {speechtype}
    pattern = r"\{(.*?)\}"

    # Split the text by the pattern
    tokens = re.split(pattern, gen_text)

    segments = []

    current_style = "af"

    for i in range(len(tokens)):
        if i % 2 == 0:
            # This is text
            text = tokens[i].strip()
            if text:
                text = clean_text(text)
                segments.append({"voice_name": current_style, "text": text})
        else:
            # This is a style (voice name)
            style = tokens[i].strip()
            current_style = style

    return segments
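
# Example: parse_speechtypes_text("{af_bella} Hello there. {am_adam} Hi!") returns
#   [{'voice_name': 'af_bella', 'text': 'Hello there.'},
#    {'voice_name': 'am_adam', 'text': 'Hi!'}]
# Text before the first {voice} tag falls back to the default 'af' voice.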

def podcast(MODEL, device, gen_text, speed=1.0, trim=0, pad_between_segments=0, remove_silence=True, minimum_silence=50):
    segments = parse_speechtypes_text(gen_text)
    speed = clamp_speed(speed)
    trim = clamp_trim(trim)
    silence_duration = clamp_trim(pad_between_segments)
    # output_file = get_random_file_name(output_file)
    sample_rate = 24000  # Sample rate of the audio

    # Create a silent audio segment in float32
    silence = np.zeros(int(sample_rate * silence_duration), dtype=np.float32)
    if len(segments) >= 1:
        first_line_text = segments[0]["text"]
        output_file = tts_file_name(first_line_text)
    else:
        output_file = get_random_file_name("")

    output_file = output_file.replace('\n', '').replace('\r', '')
    # Open a WAV file for writing
    with wave.open(output_file, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit audio
        wav_file.setframerate(sample_rate)

        for idx, segment in enumerate(segments):  # Index `idx` tracks the position
            voice_name = segment["voice_name"]
            text = segment["text"]
            voice_pack_path = f"./KOKORO/voices/{voice_name}.pt"
            VOICEPACK = torch.load(voice_pack_path, weights_only=True).to(device)

            # Generate audio for the segment
            audio, out_ps = generate(MODEL, text, VOICEPACK, lang=voice_name[0], speed=speed)
            audio = trim_if_needed(audio, trim)

            # Scale audio from float32 to int16
            audio = (audio * 32767).astype(np.int16)

            # Write the audio segment to the WAV file
            wav_file.writeframes(audio.tobytes())

            # Add silence between segments, except after the last segment
            if idx != len(segments) - 1:
                wav_file.writeframes((silence * 32767).astype(np.int16).tobytes())

    # Optionally remove silence from the output file
    if remove_silence:
        output_file = remove_silence_function(output_file, minimum_silence=minimum_silence)

    return output_file
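
# Usage sketch (assumes MODEL is a Kokoro model already loaded on `device` and that every
# referenced voicepack exists under ./KOKORO/voices/):
#   script = "{af_bella} Welcome back to the show. {am_adam} Glad to be here."
#   wav_path = podcast(MODEL, device, script, pad_between_segments=0.5)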

def tts(MODEL, device, text, voice_name, speed=1.0, trim=0.5, pad_between_segments=0.5, output_file="", remove_silence=True, minimum_silence=50):
    text = clean_text(text)
    segments = large_text(text, voice_name)
    voice_pack_path = f"./KOKORO/voices/{voice_name}.pt"
    VOICEPACK = torch.load(voice_pack_path, weights_only=True).to(device)
    speed = clamp_speed(speed)
    trim = clamp_trim(trim)
    silence_duration = clamp_trim(pad_between_segments)
    output_file = get_random_file_name(output_file)
    if debug:
        print(f'Loaded voice: {voice_name}')
        print(f"Speed: {speed}")
        print(f"Trim: {trim}")
        print(f"Silence duration: {silence_duration}")
    sample_rate = 24000  # Sample rate of the audio

    # Create a silent audio segment in float32
    silence = np.zeros(int(sample_rate * silence_duration), dtype=np.float32)

    # Open a WAV file for writing
    with wave.open(output_file, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit audio
        wav_file.setframerate(sample_rate)

        for i in segments:
            seg_id = i[0]
            text = i[1]
            if debug:
                print(i)
            audio, out_ps = generate(MODEL, text, VOICEPACK, lang=voice_name[0], speed=speed)
            audio = trim_if_needed(audio, trim)

            # Scale audio from float32 to int16
            audio = (audio * 32767).astype(np.int16)

            # Write the audio segment to the WAV file
            wav_file.writeframes(audio.tobytes())

            # Add silence between segments, except after the last segment
            if seg_id != len(segments) - 1:
                wav_file.writeframes((silence * 32767).astype(np.int16).tobytes())
    if remove_silence:
        output_file = remove_silence_function(output_file, minimum_silence=minimum_silence)
    return output_file
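
# Usage sketch (assumes MODEL is a Kokoro model already loaded on `device`; `voice_name`
# must match a .pt voicepack in ./KOKORO/voices/, e.g. 'af_bella'):
#   wav_path = tts(MODEL, device, "Hello from Kokoro.", "af_bella", speed=1.0)
#   play_audio(wav_path)  # requires the optional simpleaudio dependency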

def tts_file_name(text):
    global temp_folder
    # Remove all non-alphabetic characters and convert to lowercase
    text = re.sub(r'[^a-zA-Z\s]', '', text)  # Retain only alphabets and spaces
    text = text.lower().strip()              # Convert to lowercase and strip leading/trailing spaces
    text = text.replace(" ", "_")            # Replace spaces with underscores

    # Truncate or handle empty text
    truncated_text = text[:25] if len(text) > 25 else text if len(text) > 0 else "empty"

    # Generate a random string for uniqueness
    random_string = uuid.uuid4().hex[:8].upper()

    # Construct the file name
    file_name = f"{temp_folder}/{truncated_text}_{random_string}.wav"
    return file_name
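
# e.g. tts_file_name("Hello, World!") -> '<temp_folder>/hello_world_1A2B3C4D.wav'
# (the 8-character hex suffix is random; the one shown is only illustrative)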