Spaces:
Sleeping
Sleeping
| import jax | |
| import jax.numpy as jnp | |
| import librosa | |
| import numpy as np | |
| import pax | |
| from text import english_cleaners | |
| from utils import ( | |
| create_tacotron_model, | |
| load_tacotron_ckpt, | |
| load_tacotron_config, | |
| load_wavegru_ckpt, | |
| load_wavegru_config, | |
| ) | |
| from wavegru import WaveGRU | |
def load_tacotron_model(alphabet_file, config_file, model_file):
    """Load the Tacotron model together with its symbol table and config.

    Args:
        alphabet_file: path to a UTF-8 text file; its contents are split on
            "\\n", so each line becomes one symbol (a trailing newline yields a
            final empty entry).
        config_file: path handed to ``load_tacotron_config``.
        model_file: checkpoint path handed to ``load_tacotron_ckpt``.

    Returns:
        Tuple ``(alphabet, net, config)``: the list of symbols, the restored
        network after ``.eval()`` and ``jax.device_put``, and the config.
    """
    with open(alphabet_file, "r", encoding="utf-8") as handle:
        symbols = handle.read().split("\n")
    config = load_tacotron_config(config_file)
    # The checkpoint loader returns a 3-tuple; only the restored network is kept.
    _, restored, _ = load_tacotron_ckpt(create_tacotron_model(config), None, model_file)
    model = jax.device_put(restored.eval())
    return symbols, model, config
@pax.pure
def tacotron_inference_fn(net, text):
    """Call ``net.inference`` on a batch of token ids with ``max_len=500``.

    Wrapped by ``pax.pure``; per the pax docs this treats the call as a pure
    function of its module argument. ``max_len=500`` presumably bounds the
    number of decoder steps — TODO confirm against the model code.
    """
    return net.inference(text, max_len=500)
def text_to_mel(net, text, alphabet, config):
    """Convert raw text to a mel spectrogram using the Tacotron model.

    The text is normalised with ``english_cleaners`` and right-padded with
    ``config["PAD"]``. It is then mapped to symbol ids and passed to
    ``tacotron_inference_fn`` as a batch of one.

    Args:
        net: Tacotron network (e.g. as returned by ``load_tacotron_model``).
        text: input string.
        alphabet: list of symbols. A character's id is the index of its first
            occurrence in this list.
        config: Tacotron config mapping. ``config["PAD"]`` is the pad string.

    Returns:
        The result of ``tacotron_inference_fn`` on the ``(1, length)`` int32
        token array (the mel spectrogram).
    """
    text = english_cleaners(text)
    # Pad up to the next multiple of 100 characters (presumably to limit the
    # number of distinct input lengths). A length that is already a multiple
    # of 100 still receives 100 copies of PAD, so padding is always appended.
    text = text + config["PAD"] * (100 - (len(text) % 100))

    # Build the symbol -> id map once instead of scanning the list twice per
    # character. setdefault keeps the first occurrence, matching list.index.
    symbol_ids = {}
    for idx, symbol in enumerate(alphabet):
        symbol_ids.setdefault(symbol, idx)

    # Characters absent from the alphabet are silently dropped.
    tokens = [symbol_ids[c] for c in text if c in symbol_ids]
    tokens = jnp.array(tokens, dtype=jnp.int32)
    return tacotron_inference_fn(net, tokens[None])
def load_wavegru_net(config_file, model_file):
    """Build the WaveGRU network from its config and restore its weights.

    Args:
        config_file: path handed to ``load_wavegru_config``.
        model_file: checkpoint path handed to ``load_wavegru_ckpt``.

    Returns:
        Tuple ``(config, net)``. Note that this order differs from
        ``load_tacotron_model``. ``net`` is the restored network after
        ``.eval()`` and ``jax.device_put``.
    """
    config = load_wavegru_config(config_file)
    # Constructor hyper-parameters are read from the config under the same
    # key names; they are looked up in this order.
    hparams = {
        key: config[key]
        for key in ("mel_dim", "embed_dim", "rnn_dim", "upsample_factors")
    }
    _, restored, _ = load_wavegru_ckpt(WaveGRU(**hparams), None, model_file)
    return config, jax.device_put(restored.eval())
@pax.pure
def wavegru_inference(net, mel):
    """Call ``net.inference(mel, no_gru=True)`` wrapped by ``pax.pure``.

    ``mel_to_wav`` hands this result to a separate native sampler, so
    ``no_gru=True`` presumably computes only the conditioning features
    without running the GRU in JAX — TODO confirm against ``WaveGRU``.
    """
    return net.inference(mel, no_gru=True)
def mel_to_wav(net, netcpp, mel, config):
    """Vocode a mel spectrogram into 16-bit PCM samples.

    The pipeline has four stages:

    1. Pad the mel along axis 1.
    2. Compute features with ``wavegru_inference``.
    3. Sample codes with ``netcpp``.
    4. Mu-law expand, de-emphasise, peak-normalise and convert to int16.

    Args:
        net: WaveGRU network passed to ``wavegru_inference``.
        netcpp: object whose ``inference(features, 1.0)`` returns the sampled
            codes. They are assumed to lie in 0..255, given the ``-127`` offset
            and ``mu=255`` below — TODO confirm.
        mel: 2-D (unbatched) or 3-D (batched) array. Only batch item 0 of the
            computed features is vocoded.
        config: WaveGRU config mapping. Reads ``"num_pad_frames"`` and
            ``"mel_min"``.

    Returns:
        ``np.ndarray`` of dtype ``int16``.
    """
    batch = mel[None] if len(mel.shape) == 2 else mel

    # Pad both ends of axis 1 with log(mel_min), presumably the silence floor
    # of a log-mel input. Each side gets half the configured pad frames plus 4.
    margin = config["num_pad_frames"] // 2 + 4
    silence = np.log(config["mel_min"])
    batch = np.pad(batch, [(0, 0), (margin, margin), (0, 0)], constant_values=silence)

    # Features for the first batch item, fetched to host memory for netcpp.
    features = jax.device_get(wavegru_inference(net, batch)[0])
    # NOTE(review): if netcpp ever returns an unsigned dtype, `codes - 127`
    # would wrap around; confirm it yields signed or Python ints.
    codes = np.array(netcpp.inference(features, 1.0))

    # Mu-law expand (mu=255), then apply de-emphasis with coefficient 0.86.
    audio = librosa.effects.deemphasis(
        librosa.mu_expand(codes - 127, mu=255), coef=0.86
    )
    # Apply 2x gain, then rescale only if the peak ends up above 1.0.
    audio = audio * 2.0
    audio = audio / max(1.0, np.max(np.abs(audio)))

    # Scale to the int16 range; clipping maps a +1.0 sample (32768) to 32767.
    limits = np.iinfo(np.int16)
    audio = np.clip(audio * 2**15, a_min=limits.min, a_max=limits.max)
    return audio.astype(np.int16)