Spaces: Running
NTT123 committed · Commit df1ad02
1 Parent(s): 73eaac3

a slow but working model
Files changed:
- .gitattributes +2 -0
- alphabet.txt +41 -0
- app.py +30 -4
- inference.py +82 -0
- packages.txt +1 -0
- pooch.py +10 -0
- pretrained_model_ljs_500k.ckpt +3 -0
- requirements.txt +10 -0
- tacotron.py +446 -0
- tacotron.toml +31 -0
- text.py +87 -0
- utils.py +74 -0
- wavegru.py +234 -0
- wavegru.yaml +14 -0
- wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt +3 -0
.gitattributes
CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+pretrained_model_ljs_500k.ckpt filter=lfs diff=lfs merge=lfs -text
+wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt filter=lfs diff=lfs merge=lfs -text
alphabet.txt
ADDED
@@ -0,0 +1,41 @@
+_
+ 
+!
+"
+'
+(
+)
+,
+-
+.
+:
+;
+?
+[
+]
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
app.py
CHANGED
@@ -1,7 +1,33 @@
 import gradio as gr
 
-
-    return "Hello " + name + "!!"
+from inference import load_tacotron_model, load_wavegru_net, text_to_mel, mel_to_wav
 
-
-
+alphabet, tacotron_net, tacotron_config = load_tacotron_model(
+    "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_500k.ckpt"
+)
+
+
+wavegru_config, wavegru_net = load_wavegru_net(
+    "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt"
+)
+
+
+def speak(text):
+    mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
+    y = mel_to_wav(wavegru_net, mel, wavegru_config)
+    return 24_000, y
+
+
+title = "WaveGRU-TTS"
+description = "WaveGRU text-to-speech demo."
+
+gr.Interface(
+    fn=speak,
+    inputs="text",
+    outputs="audio",
+    title=title,
+    description=description,
+    theme="default",
+    allow_screenshot=False,
+    allow_flagging="never",
+).launch(debug=False)
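A note on the Gradio wiring above: speak returns a (sample_rate, waveform) tuple, which is the form a Gradio "audio" output component accepts, the waveform being the int16 array produced by mel_to_wav. A minimal local smoke test, assuming the checkpoints in this commit are present in the working directory (hypothetical, not part of the Space itself):

# Hypothetical local check; `speak` is the function defined in app.py above.
rate, wav = speak("hello world")
print(rate, wav.dtype, wav.shape)  # expected: 24000 int16 (num_samples,)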
inference.py
ADDED
@@ -0,0 +1,82 @@
+import jax
+import jax.numpy as jnp
+import librosa
+import numpy as np
+import pax
+
+from text import english_cleaners
+from utils import (
+    create_tacotron_model,
+    load_tacotron_ckpt,
+    load_tacotron_config,
+    load_wavegru_ckpt,
+    load_wavegru_config,
+)
+from wavegru import WaveGRU
+
+
+def load_tacotron_model(alphabet_file, config_file, model_file):
+    """load tacotron model to memory"""
+    with open(alphabet_file, "r", encoding="utf-8") as f:
+        alphabet = f.read().split("\n")
+
+    config = load_tacotron_config(config_file)
+    net = create_tacotron_model(config)
+    _, net, _ = load_tacotron_ckpt(net, None, model_file)
+    net = net.eval()
+    net = jax.device_put(net)
+    return alphabet, net, config
+
+
+tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=10000))
+
+
+def text_to_mel(net, text, alphabet, config):
+    """convert text to mel spectrogram"""
+    text = english_cleaners(text)
+    text = text + config["PAD"] * (100 - (len(text) % 100))
+    tokens = [alphabet.index(c) for c in text]
+    tokens = jnp.array(tokens, dtype=jnp.int32)
+    mel = tacotron_inference_fn(net, tokens[None])
+    return mel
+
+
+def load_wavegru_net(config_file, model_file):
+    """load wavegru to memory"""
+    config = load_wavegru_config(config_file)
+    net = WaveGRU(
+        mel_dim=config["mel_dim"],
+        embed_dim=config["embed_dim"],
+        rnn_dim=config["rnn_dim"],
+        upsample_factors=config["upsample_factors"],
+    )
+    _, net, _ = load_wavegru_ckpt(net, None, model_file)
+    net = net.eval()
+    net = jax.device_put(net)
+    return config, net
+
+
+wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=False))
+
+
+def mel_to_wav(net, mel, config):
+    """convert mel to wav"""
+    if len(mel.shape) == 2:
+        mel = mel[None]
+    pad = config["num_pad_frames"] // 2 + 4
+    mel = np.pad(
+        mel,
+        [(0, 0), (pad, pad), (0, 0)],
+        constant_values=np.log(config["mel_min"]),
+    )
+    x = wavegru_inference(net, mel)
+    x = jax.device_get(x)
+
+    wav = librosa.mu_expand(x - 127, mu=255)
+    wav = librosa.effects.deemphasis(wav, coef=0.86)
+    wav = wav * 2.0
+    wav = wav / max(1.0, np.max(np.abs(wav)))
+    wav = wav * 2**15
+    wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1)
+    wav = wav.astype(np.int16)
+    return wav
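inference.py therefore exposes a two-stage pipeline: text to mel with the Tacotron model, then mel to waveform with the WaveGRU vocoder. A minimal offline sketch of driving it outside Gradio; the soundfile import (installed as a librosa dependency) and the output filename are assumptions, not part of this commit:

import soundfile as sf  # assumed available via librosa's dependencies

from inference import load_tacotron_model, load_wavegru_net, text_to_mel, mel_to_wav

alphabet, taco_net, taco_cfg = load_tacotron_model(
    "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_500k.ckpt"
)
wavegru_cfg, wavegru_net = load_wavegru_net(
    "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt"
)

mel = text_to_mel(taco_net, "hello world", alphabet, taco_cfg)
wav = mel_to_wav(wavegru_net, mel, wavegru_cfg)  # int16 samples at 24 kHz
sf.write("output.wav", wav, 24000)               # hypothetical output path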
packages.txt
ADDED
@@ -0,0 +1 @@
+libsndfile1-dev
pooch.py
ADDED
@@ -0,0 +1,10 @@
+def os_cache(x):
+    return x
+
+
+def create(*args, **kwargs):
+    class T:
+        def load_registry(self, *args, **kwargs):
+            return None
+
+    return T()
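pooch.py looks like a deliberate local stub that shadows the real pooch package, presumably so that importing librosa (which sets up a pooch registry for its example data at import time) works in the Space without network access or a writable cache; that reading is an inference, not something stated in the commit. The stub only has to absorb calls of roughly this shape:

import pooch  # resolves to the local pooch.py stub above, not the PyPI package

cache_dir = pooch.os_cache("librosa")   # the stub just echoes its argument
fetcher = pooch.create(path=cache_dir, base_url="", registry=None)
fetcher.load_registry(None)             # no-op in the stub, returns None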
pretrained_model_ljs_500k.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4eabdcac35cd016469d17678f9549bd25d1c9bf66c9089ea9f0632619ba91194
+size 53221435
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+jax==0.3.1
+jaxlib==0.3.0
+numpy==1.22.3
+librosa==0.9.1
+pax3==0.5.6
+gradio
+jinja2
+toml==0.10.2
+unidecode==1.3.4
+pyyaml==6.0
tacotron.py
ADDED
@@ -0,0 +1,446 @@
+"""
+Tacotron + stepwise monotonic attention
+"""
+
+import jax
+import jax.numpy as jnp
+import pax
+
+
+def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
+    """
+    Conv >> LayerNorm >> activation >> Dropout
+    """
+    f = pax.Sequential(
+        pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False),
+        pax.LayerNorm(out_ft, -1, True, True),
+    )
+    if activation_fn is not None:
+        f >>= activation_fn
+    if use_dropout:
+        f >>= pax.Dropout(0.5)
+    return f
+
+
+class HighwayBlock(pax.Module):
+    """
+    Highway block
+    """
+
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        self.dim = dim
+        self.fc = pax.Linear(dim, 2 * dim)
+
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        t, h = jnp.split(self.fc(x), 2, axis=-1)
+        t = jax.nn.sigmoid(t - 1.0)  # bias toward keeping x
+        h = jax.nn.relu(h)
+        x = x * (1.0 - t) + h * t
+        return x
+
+
+class BiGRU(pax.Module):
+    """
+    Bidirectional GRU
+    """
+
+    def __init__(self, dim):
+        super().__init__()
+
+        self.rnn_fwd = pax.GRU(dim, dim)
+        self.rnn_bwd = pax.GRU(dim, dim)
+
+    def __call__(self, x, reset_masks):
+        N = x.shape[0]
+        x_fwd = x
+        x_bwd = jnp.flip(x, axis=1)
+        x_fwd_states = self.rnn_fwd.initial_state(N)
+        x_bwd_states = self.rnn_bwd.initial_state(N)
+        x_fwd_states, x_fwd = pax.scan(
+            self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
+        )
+
+        reset_masks = jnp.flip(reset_masks, axis=1)
+        x_bwd_states0 = x_bwd_states
+
+        def rnn_reset_core(prev, inputs):
+            x, reset_mask = inputs
+
+            def reset_state(x0, xt):
+                return jnp.where(reset_mask, x0, xt)
+
+            state, _ = self.rnn_bwd(prev, x)
+            state = jax.tree_map(reset_state, x_bwd_states0, state)
+            return state, state.hidden
+
+        x_bwd_states, x_bwd = pax.scan(
+            rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
+        )
+        x_bwd = jnp.flip(x_bwd, axis=1)
+        x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
+        return x
+
+
+class CBHG(pax.Module):
+    """
+    Conv Bank >> Highway net >> GRU
+    """
+
+    def __init__(self, dim):
+        super().__init__()
+        self.convs = [conv_block(dim, dim, i, jax.nn.relu, False) for i in range(1, 17)]
+        self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
+        self.conv_projection_2 = conv_block(dim, dim, 3, None, False)
+
+        self.highway = pax.Sequential(
+            HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim)
+        )
+        self.rnn = BiGRU(dim)
+
+    def __call__(self, x, x_mask):
+        conv_input = x * x_mask
+        fts = [f(conv_input) for f in self.convs]
+        residual = jnp.concatenate(fts, axis=-1)
+        residual = pax.max_pool(residual, 2, 1, "SAME", -1)
+        residual = self.conv_projection_1(residual * x_mask)
+        residual = self.conv_projection_2(residual * x_mask)
+        x = x + residual
+        x = self.highway(x)
+        x = self.rnn(x * x_mask, reset_masks=1 - x_mask)
+        return x * x_mask
+
+
+class PreNet(pax.Module):
+    """
+    Linear >> relu >> dropout >> Linear >> relu >> dropout
+    """
+
+    def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
+        super().__init__()
+        self.fc1 = pax.Linear(input_dim, hidden_dim)
+        self.fc2 = pax.Linear(hidden_dim, output_dim)
+        self.rng_seq = pax.RngSeq()
+        self.always_dropout = always_dropout
+
+    def __call__(self, x, k1=None, k2=None):
+        x = self.fc1(x)
+        x = jax.nn.relu(x)
+        if self.always_dropout or self.training:
+            if k1 is None:
+                k1 = self.rng_seq.next_rng_key()
+            x = pax.dropout(k1, 0.5, x)
+        x = self.fc2(x)
+        x = jax.nn.relu(x)
+        if self.always_dropout or self.training:
+            if k2 is None:
+                k2 = self.rng_seq.next_rng_key()
+            x = pax.dropout(k2, 0.5, x)
+        return x
+
+
+class Tacotron(pax.Module):
+    """
+    Tacotron TTS model.
+
+    It uses stepwise monotonic attention for robust attention.
+    """
+
+    def __init__(
+        self,
+        mel_dim: int,
+        attn_bias,
+        rr,
+        max_rr,
+        mel_min,
+        sigmoid_noise,
+        pad_token,
+        prenet_dim,
+        attn_hidden_dim,
+        attn_rnn_dim,
+        rnn_dim,
+        postnet_dim,
+        text_dim,
+    ):
+        """
+        New Tacotron model
+
+        Args:
+            mel_dim (int): dimension of log mel-spectrogram features.
+            attn_bias (float): control how "slow" the attention will
+                move forward at initialization.
+            rr (int): the reduction factor.
+                Number of predicted frame at each time step. Default is 2.
+            max_rr (int): max value of rr.
+            mel_min (float): the minimum value of mel features.
+                The <go> frame is filled by `log(mel_min)` values.
+            sigmoid_noise (float): the variance of gaussian noise added
+                to attention scores in training.
+            pad_token (int): the pad value at the end of text sequences.
+            prenet_dim (int): dimension of prenet output.
+            attn_hidden_dim (int): dimension of attention hidden vectors.
+            attn_rnn_dim (int): number of cells in the attention RNN.
+            rnn_dim (int): number of cells in the decoder RNNs.
+            postnet_dim (int): number of features in the postnet convolutions.
+            text_dim (int): dimension of text embedding vectors.
+        """
+        super().__init__()
+        self.text_dim = text_dim
+        assert rr <= max_rr
+        self.rr = rr
+        self.max_rr = max_rr
+        self.mel_dim = mel_dim
+        self.mel_min = mel_min
+        self.sigmoid_noise = sigmoid_noise
+        self.pad_token = pad_token
+        self.prenet_dim = prenet_dim
+
+        # encoder submodules
+        self.encoder_embed = pax.Embed(256, text_dim)
+        self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
+        self.encoder_cbhg = CBHG(prenet_dim)
+
+        # random key generator
+        self.rng_seq = pax.RngSeq()
+
+        # pre-net
+        self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)
+
+        # decoder submodules
+        self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
+        self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
+        self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)
+
+        self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
+        self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
+        self.attn_V_bias = jnp.array(attn_bias)
+        self.attn_log = jnp.zeros((1,))
+        self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
+        self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
+        self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
+        # mel + end-of-sequence token
+        self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)
+
+        # post-net
+        self.post_net = pax.Sequential(
+            conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
+            conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
+            conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
+            conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
+            conv_block(postnet_dim, mel_dim, 5, None, True),
+        )
+
+    parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")
+
+    def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
+        """
+        Encode text to a sequence of real vectors
+        """
+        N, L = text.shape
+        text_mask = (text != self.pad_token)[..., None]
+        x = self.encoder_embed(text)
+        x = self.encoder_pre_net(x)
+        x = self.encoder_cbhg(x, text_mask)
+        return x
+
+    def go_frame(self, batch_size: int) -> jnp.ndarray:
+        """
+        return the go frame
+        """
+        return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)
+
+    def decoder_initial_state(self, N: int, L: int):
+        """
+        setup decoder initial state
+        """
+        attn_context = jnp.zeros((N, self.prenet_dim * 2))
+        attn_pr = jax.nn.one_hot(
+            jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
+        )
+
+        attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
+        decoder_rnn_states = (
+            self.decoder_rnn1.initial_state(N),
+            self.decoder_rnn2.initial_state(N),
+        )
+        return attn_state, decoder_rnn_states
+
+    def monotonic_attention(self, prev_state, inputs, envs):
+        """
+        Stepwise monotonic attention
+        """
+        attn_rnn_state, attn_context, prev_attn_pr = prev_state
+        x, attn_rng_key = inputs
+        text, text_key = envs
+        attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1)
+        attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input)
+        attn_query_input = attn_rnn_output
+        attn_query = self.attn_query_fc(attn_query_input)
+        attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key)
+        score = self.attn_V(attn_hidden)
+        score = jnp.squeeze(score, axis=-1)
+        weight_norm = jnp.linalg.norm(self.attn_V.weight)
+        score = score * (self.attn_V_weight_norm / weight_norm)
+        score = score + self.attn_V_bias
+        noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise
+        pr_stay = jax.nn.sigmoid(score + noise)
+        pr_move = 1.0 - pr_stay
+        pr_new_location = pr_move * prev_attn_pr
+        pr_new_location = jnp.pad(
+            pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0
+        )
+        attn_pr = pr_stay * prev_attn_pr + pr_new_location
+        attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text)
+        new_state = (attn_rnn_state, attn_context, attn_pr)
+        return new_state, attn_rnn_output
+
+    def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
+        """
+        Return a zoneout lstm core.
+
+        It will zoneout the new hidden states and keep the new cell states unchanged.
+        """
+
+        def core(state, x):
+            new_state, _ = lstm_core(state, x)
+            h_old = state.hidden
+            h_new = new_state.hidden
+            mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape)
+            h_new = h_old * mask + h_new * (1.0 - mask)
+            return pax.LSTMState(h_new, new_state.cell), h_new
+
+        return core
+
+    def decoder_step(
+        self,
+        attn_state,
+        decoder_rnn_states,
+        rng_key,
+        mel,
+        text,
+        text_key,
+        call_pre_net=False,
+    ):
+        """
+        One decoder step
+        """
+        if call_pre_net:
+            k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6)
+            mel = self.decoder_pre_net(mel, k1, k2)
+        else:
+            zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4)
+        attn_inputs = (mel, rng_key)
+        attn_envs = (text, text_key)
+        attn_state, attn_rnn_output = self.monotonic_attention(
+            attn_state, attn_inputs, attn_envs
+        )
+        (_, attn_context, attn_pr) = attn_state
+        (decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states
+        decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1)
+        decoder_rnn1_input = self.decoder_input(decoder_rnn1_input)
+        decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1)
+        decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1(
+            decoder_rnn_state1, decoder_rnn1_input
+        )
+        decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1
+        decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2)
+        decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2(
+            decoder_rnn_state2, decoder_rnn2_input
+        )
+        x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2
+        decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2)
+        return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0]
+
+    @jax.jit
+    def inference_step(
+        self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
+    ):
+        """one inference step"""
+        attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step(
+            attn_state,
+            decoder_rnn_states,
+            rng_key,
+            mel,
+            text,
+            text_key,
+            call_pre_net=True,
+        )
+        x = self.output_fc(x)
+        N, D2 = x.shape
+        x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr))
+        x = x[:, : self.rr, :]
+        x = jnp.reshape(x, (N, self.rr, -1))
+        mel = x[..., :-1]
+        eos = x[..., -1]
+        return attn_state, decoder_rnn_states, rng_key, (mel, eos)
+
+    def inference(self, text, seed=42, max_len=1000):
+        """
+        text to mel
+        """
+        text = self.encode_text(text)
+        text_key = self.text_key_fc(text)
+        N, L, D = text.shape
+        mel = self.go_frame(N)
+
+        attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
+        rng_key = jax.random.PRNGKey(seed)
+        mels = []
+        count = 0
+        while True:
+            count = count + 1
+            attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
+                attn_state, decoder_rnn_states, rng_key, mel, text, text_key
+            )
+            mels.append(mel)
+            if eos[0, -1].item() > 0 or count > max_len:
+                break
+
+            mel = mel[:, -1, :]
+
+        mels = jnp.concatenate(mels, axis=1)
+        mel = mel + self.post_net(mel)
+        return mels
+
+    def decode(self, mel, text):
+        """
+        Attention mechanism + Decoder
+        """
+        text_key = self.text_key_fc(text)
+
+        def scan_fn(prev_states, inputs):
+            attn_state, decoder_rnn_states = prev_states
+            x, rng_key = inputs
+            attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step(
+                attn_state, decoder_rnn_states, rng_key, x, text, text_key
+            )
+            states = (attn_state, decoder_rnn_states)
+            return states, (output, attn_pr)
+
+        N, L, D = text.shape
+        decoder_states = self.decoder_initial_state(N, L)
+        rng_keys = self.rng_seq.next_rng_key(mel.shape[1])
+        rng_keys = jnp.stack(rng_keys, axis=1)
+        decoder_states, (x, attn_log) = pax.scan(
+            scan_fn,
+            decoder_states,
+            (mel, rng_keys),
+            time_major=False,
+        )
+        self.attn_log = attn_log
+        del decoder_states
+        x = self.output_fc(x)
+
+        N, T2, D2 = x.shape
+        x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
+        x = x[:, :, : self.rr, :]
+        x = jnp.reshape(x, (N, T2 * self.rr, -1))
+        mel = x[..., :-1]
+        eos = x[..., -1]
+        return mel, eos
+
+    def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
+        text = self.encode_text(text)
+        mel = self.decoder_pre_net(mel)
+        mel, eos = self.decode(mel, text)
+        return mel, mel + self.post_net(mel), eos
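For intuition, the stepwise monotonic attention above reduces to a simple per-step update of a probability vector over text positions: each position either keeps its attention mass (pr_stay) or hands it to the next position (pr_move), so the alignment can only stay put or advance by one. A toy sketch of just that update, with made-up numbers and the RNN/query machinery stripped out:

import jax.numpy as jnp

attn_pr = jnp.array([[1.0, 0.0, 0.0, 0.0]])   # all mass on the first token
pr_stay = jnp.array([[0.3, 0.5, 0.9, 0.9]])   # per-position "stay" probability
pr_move = 1.0 - pr_stay

moved = pr_move * attn_pr                                            # mass that advances
moved = jnp.pad(moved[:, :-1], ((0, 0), (1, 0)), constant_values=0)  # shift right by one
attn_pr = pr_stay * attn_pr + moved                                  # [[0.3, 0.7, 0.0, 0.0]]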
tacotron.toml
ADDED
@@ -0,0 +1,31 @@
+[tacotron]
+
+# training
+BATCH_SIZE = 64
+LR=1024e-6 # learning rate
+MODEL_PREFIX = "mono_tts_cbhg_small"
+LOG_DIR = "./logs"
+CKPT_DIR = "./ckpts"
+USE_MP = false # use mixed-precision training
+
+# data
+TF_DATA_DIR = "./tf_data" # tensorflow data directory
+TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory
+SAMPLE_RATE = 24000 # convert to this sample rate if needed
+MEL_DIM = 80 # the dimension of melspectrogram features
+MEL_MIN = 1e-5
+PAD = "_" # padding character
+PAD_TOKEN = 0
+TEST_DATA_SIZE = 1024
+
+# model
+RR = 2 # reduction factor
+MAX_RR=2
+ATTN_BIAS = 0.0 # control how slow the attention moves forward
+SIGMOID_NOISE = 2.0
+PRENET_DIM = 128
+TEXT_DIM = 256
+RNN_DIM = 512
+ATTN_RNN_DIM = 256
+ATTN_HIDDEN_DIM = 128
+POSTNET_DIM = 512
text.py
ADDED
@@ -0,0 +1,87 @@
+""" from https://github.com/keithito/tacotron """
+
+"""
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+  1. "english_cleaners" for English text
+  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+     the symbols in symbols.py to match your data).
+"""
+
+import re
+
+from unidecode import unidecode
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r"\s+")
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [
+    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    for x in [
+        ("mrs", "misess"),
+        ("mr", "mister"),
+        ("dr", "doctor"),
+        ("st", "saint"),
+        ("co", "company"),
+        ("jr", "junior"),
+        ("maj", "major"),
+        ("gen", "general"),
+        ("drs", "doctors"),
+        ("rev", "reverend"),
+        ("lt", "lieutenant"),
+        ("hon", "honorable"),
+        ("sgt", "sergeant"),
+        ("capt", "captain"),
+        ("esq", "esquire"),
+        ("ltd", "limited"),
+        ("col", "colonel"),
+        ("ft", "fort"),
+    ]
+]
+
+
+def expand_abbreviations(text):
+    for regex, replacement in _abbreviations:
+        text = re.sub(regex, replacement, text)
+    return text
+
+
+def lowercase(text):
+    return text.lower()
+
+
+def collapse_whitespace(text):
+    return re.sub(_whitespace_re, " ", text)
+
+
+def convert_to_ascii(text):
+    return unidecode(text)
+
+
+def basic_cleaners(text):
+    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
+def transliteration_cleaners(text):
+    """Pipeline for non-English text that transliterates to ASCII."""
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
+def english_cleaners(text):
+    """Pipeline for English text, including number and abbreviation expansion."""
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = expand_abbreviations(text)
+    text = collapse_whitespace(text)
+    return text
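A quick usage example of the English pipeline; the output shown in the comment is an expectation based on the abbreviation table above, not a captured run:

from text import english_cleaners

print(english_cleaners("Dr. Strange  lives on St.   Mark's Road."))
# expected: "doctor strange lives on saint mark's road."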
utils.py
ADDED
@@ -0,0 +1,74 @@
+"""
+Utility functions
+"""
+import pickle
+from pathlib import Path
+
+import pax
+import toml
+import yaml
+
+from tacotron import Tacotron
+
+
+def load_tacotron_config(config_file=Path("tacotron.toml")):
+    """
+    Load the project configurations
+    """
+    return toml.load(config_file)["tacotron"]
+
+
+def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
+    """
+    load checkpoint from disk
+    """
+    with open(path, "rb") as f:
+        dic = pickle.load(f)
+    if net is not None:
+        net = net.load_state_dict(dic["model_state_dict"])
+    if optim is not None:
+        optim = optim.load_state_dict(dic["optim_state_dict"])
+    return dic["step"], net, optim
+
+
+def create_tacotron_model(config):
+    """
+    return a random initialized Tacotron model
+    """
+    return Tacotron(
+        mel_dim=config["MEL_DIM"],
+        attn_bias=config["ATTN_BIAS"],
+        rr=config["RR"],
+        max_rr=config["MAX_RR"],
+        mel_min=config["MEL_MIN"],
+        sigmoid_noise=config["SIGMOID_NOISE"],
+        pad_token=config["PAD_TOKEN"],
+        prenet_dim=config["PRENET_DIM"],
+        attn_hidden_dim=config["ATTN_HIDDEN_DIM"],
+        attn_rnn_dim=config["ATTN_RNN_DIM"],
+        rnn_dim=config["RNN_DIM"],
+        postnet_dim=config["POSTNET_DIM"],
+        text_dim=config["TEXT_DIM"],
+    )
+
+
+def load_wavegru_config(config_file):
+    """
+    Load project configurations
+    """
+    with open(config_file, "r", encoding="utf-8") as f:
+        return yaml.safe_load(f)
+
+
+def load_wavegru_ckpt(net, optim, ckpt_file):
+    """
+    load training checkpoint from file
+    """
+    with open(ckpt_file, "rb") as f:
+        dic = pickle.load(f)
+
+    if net is not None:
+        net = net.load_state_dict(dic["net_state_dict"])
+    if optim is not None:
+        optim = optim.load_state_dict(dic["optim_state_dict"])
+    return dic["step"], net, optim
wavegru.py
ADDED
@@ -0,0 +1,234 @@
+"""
+WaveGRU model: melspectrogram => mu-law encoded waveform
+"""
+
+import jax
+import jax.numpy as jnp
+import pax
+
+
+class ReLU(pax.Module):
+    def __call__(self, x):
+        return jax.nn.relu(x)
+
+
+def dilated_residual_conv_block(dim, kernel, stride, dilation):
+    """
+    Use dilated convs to enlarge the receptive field
+    """
+    return pax.Sequential(
+        pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
+        pax.LayerNorm(dim, -1, True, True),
+        ReLU(),
+        pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
+        pax.LayerNorm(dim, -1, True, True),
+        ReLU(),
+    )
+
+
+def tile_1d(x, factor):
+    """
+    Tile tensor of shape N, L, D into N, L*factor, D
+    """
+    N, L, D = x.shape
+    x = x[:, :, None, :]
+    x = jnp.tile(x, (1, 1, factor, 1))
+    x = jnp.reshape(x, (N, L * factor, D))
+    return x
+
+
+def up_block(dim, factor):
+    """
+    Tile >> Conv >> BatchNorm >> ReLU
+    """
+    return pax.Sequential(
+        lambda x: tile_1d(x, factor),
+        pax.Conv1D(dim, dim, 2 * factor, stride=1, padding="VALID", with_bias=False),
+        pax.LayerNorm(dim, -1, True, True),
+        ReLU(),
+    )
+
+
+class Upsample(pax.Module):
+    """
+    Upsample melspectrogram to match raw audio sample rate.
+    """
+
+    def __init__(self, input_dim, upsample_factors):
+        super().__init__()
+        self.input_conv = pax.Sequential(
+            pax.Conv1D(input_dim, 512, 1, with_bias=False),
+            pax.LayerNorm(512, -1, True, True),
+        )
+        self.upsample_factors = upsample_factors
+        self.dilated_convs = [
+            dilated_residual_conv_block(512, 3, 1, 2**i) for i in range(5)
+        ]
+        self.up_factors = upsample_factors[:-1]
+        self.up_blocks = [up_block(512, x) for x in self.up_factors]
+        self.final_tile = upsample_factors[-1]
+
+    def __call__(self, x):
+        x = self.input_conv(x)
+        for residual in self.dilated_convs:
+            y = residual(x)
+            pad = (x.shape[1] - y.shape[1]) // 2
+            x = x[:, pad:-pad, :] + y
+
+        for f in self.up_blocks:
+            x = f(x)
+
+        x = tile_1d(x, self.final_tile)
+        return x
+
+
+class Pruner(pax.Module):
+    """
+    Base class for pruners
+    """
+
+    def __init__(self, update_freq=500):
+        super().__init__()
+        self.update_freq = update_freq
+
+    def compute_sparsity(self, step):
+        """
+        Two-stages pruning
+        """
+        t = jnp.power(1 - (step * 1.0 - 1_000) / 300_000, 3)
+        z = 0.5 * jnp.clip(1.0 - t, a_min=0, a_max=1)
+        for i in range(4):
+            t = jnp.power(1 - (step * 1.0 - 1_000 - 400_000 - i * 200_000) / 100_000, 3)
+            z = z + 0.1 * jnp.clip(1 - t, a_min=0, a_max=1)
+        return z
+
+    def prune(self, step, weights):
+        """
+        Return a mask
+        """
+        z = self.compute_sparsity(step)
+        x = weights
+        H, W = x.shape
+        x = x.reshape(H // 4, 4, W // 4, 4)
+        x = jnp.abs(x)
+        x = jnp.sum(x, axis=(1, 3), keepdims=True)
+        q = jnp.quantile(jnp.reshape(x, (-1,)), z)
+        x = x >= q
+        x = jnp.tile(x, (1, 4, 1, 4))
+        x = jnp.reshape(x, (H, W))
+        return x
+
+
+class GRUPruner(Pruner):
+    def __init__(self, gru, update_freq=500):
+        super().__init__(update_freq=update_freq)
+        self.xh_zr_fc_mask = jnp.ones_like(gru.xh_zr_fc.weight) == 1
+        self.xh_h_fc_mask = jnp.ones_like(gru.xh_h_fc.weight) == 1
+
+    def __call__(self, gru: pax.GRU):
+        """
+        Apply mask after an optimization step
+        """
+        zr_masked_weights = jnp.where(self.xh_zr_fc_mask, gru.xh_zr_fc.weight, 0)
+        gru = gru.replace_node(gru.xh_zr_fc.weight, zr_masked_weights)
+        h_masked_weights = jnp.where(self.xh_h_fc_mask, gru.xh_h_fc.weight, 0)
+        gru = gru.replace_node(gru.xh_h_fc.weight, h_masked_weights)
+        return gru
+
+    def update_mask(self, step, gru: pax.GRU):
+        """
+        Update internal masks
+        """
+        xh_z_weight, xh_r_weight = jnp.split(gru.xh_zr_fc.weight, 2, axis=1)
+        xh_z_weight = self.prune(step, xh_z_weight)
+        xh_r_weight = self.prune(step, xh_r_weight)
+        self.xh_zr_fc_mask *= jnp.concatenate((xh_z_weight, xh_r_weight), axis=1)
+        self.xh_h_fc_mask *= self.prune(step, gru.xh_h_fc.weight)
+
+
+class LinearPruner(Pruner):
+    def __init__(self, linear, update_freq=500):
+        super().__init__(update_freq=update_freq)
+        self.mask = jnp.ones_like(linear.weight) == 1
+
+    def __call__(self, linear: pax.Linear):
+        """
+        Apply mask after an optimization step
+        """
+        return linear.replace(weight=jnp.where(self.mask, linear.weight, 0))
+
+    def update_mask(self, step, linear: pax.Linear):
+        """
+        Update internal masks
+        """
+        self.mask *= self.prune(step, linear.weight)
+
+
+class WaveGRU(pax.Module):
+    """
+    WaveGRU vocoder model
+    """
+
+    def __init__(
+        self, mel_dim=80, embed_dim=32, rnn_dim=512, upsample_factors=(5, 4, 3, 5)
+    ):
+        super().__init__()
+        self.embed = pax.Embed(256, embed_dim)
+        self.upsample = Upsample(input_dim=mel_dim, upsample_factors=upsample_factors)
+        self.rnn = pax.GRU(embed_dim + rnn_dim, rnn_dim)
+        self.o1 = pax.Linear(rnn_dim, rnn_dim)
+        self.o2 = pax.Linear(rnn_dim, 256)
+        self.gru_pruner = GRUPruner(self.rnn)
+        self.o1_pruner = LinearPruner(self.o1)
+        self.o2_pruner = LinearPruner(self.o2)
+
+    def output(self, x):
+        x = self.o1(x)
+        x = jax.nn.relu(x)
+        x = self.o2(x)
+        return x
+
+    @jax.jit
+    def inference_step(self, rnn_state, mel, rng_key, x):
+        """one inference step"""
+        x = self.embed(x)
+        x = jnp.concatenate((x, mel), axis=-1)
+        rnn_state, x = self.rnn(rnn_state, x)
+        x = self.output(x)
+        rng_key, next_rng_key = jax.random.split(rng_key, 2)
+        x = jax.random.categorical(rng_key, x, axis=-1)
+        return rnn_state, next_rng_key, x
+
+    def inference(self, mel, no_gru=False, seed=42):
+        """
+        generate waveform form melspectrogram
+        """
+
+        y = self.upsample(mel)
+        if no_gru:
+            return y
+        x = jnp.array([127], dtype=jnp.int32)
+        rnn_state = self.rnn.initial_state(1)
+        output = []
+        rng_key = jax.random.PRNGKey(seed)
+        for i in range(y.shape[1]):
+            rnn_state, rng_key, x = self.inference_step(rnn_state, y[:, i], rng_key, x)
+            output.append(x)
+        x = jnp.concatenate(output, axis=0)
+        return x
+
+    def __call__(self, mel, x):
+        x = self.embed(x)
+        y = self.upsample(mel)
+        pad_left = (x.shape[1] - y.shape[1]) // 2
+        pad_right = x.shape[1] - y.shape[1] - pad_left
+        x = x[:, pad_left:-pad_right]
+        x = jnp.concatenate((x, y), axis=-1)
+        _, x = pax.scan(
+            self.rnn,
+            self.rnn.initial_state(x.shape[0]),
+            x,
+            time_major=False,
+        )
+        x = self.output(x)
+        return x
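One detail worth spelling out: the dilated stack in Upsample uses VALID padding, so it trims mel frames from both ends. Each dilated_residual_conv_block with kernel 3 and dilation 2**i drops (3 - 1) * 2**i frames, and the five blocks together drop 2 * (1 + 2 + 4 + 8 + 16) = 62 frames, which is presumably where num_pad_frames: 62 in wavegru.yaml comes from (the extra padding added in inference.py also covers the up_block convolutions). A one-line check of that arithmetic:

kernel = 3
print(sum((kernel - 1) * 2**i for i in range(5)))  # 62, matching num_pad_frames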
wavegru.yaml
ADDED
@@ -0,0 +1,14 @@
+## dsp
+sample_rate : 24000
+window_length: 50.0 # ms
+hop_length: 12.5 # ms
+mel_min: 1.0e-5 ## need .0 to make it a float
+mel_dim: 80
+n_fft: 2048
+
+## wavegru
+embed_dim: 32
+rnn_dim: 512
+frames_per_sequence: 67
+num_pad_frames: 62
+upsample_factors: [5, 4, 3, 5]
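As a sanity check on this config, the upsample factors multiply out to exactly one hop of audio per mel frame: 5 * 4 * 3 * 5 = 300 samples, and a 12.5 ms hop at 24000 Hz is 0.0125 * 24000 = 300 samples as well.

assert 5 * 4 * 3 * 5 == round(0.0125 * 24000) == 300  # mel-frame upsampling == hop length in samples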
wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c09ed822c5daac0afbd19e8ba4e0ded26dd5732e0efd13ce193c3f54c4e63f54
+size 56479599