# Page header from the original listing: author AlexK-PL, commit 31277c7 ("Update app.py"), 2.41 kB.
import gradio as gr
from hyper_parameters import tacotron_params as hparams
from training import load_model
from text import text_to_sequence
from melgan.model.generator import Generator
from melgan.utils.hparams import load_hparam
import torch
import numpy as np
# Presumably the int16 PCM full-scale magnitude (2**15) used when scaling
# waveforms; it is not referenced elsewhere in this file — confirm before removing.
MAX_WAV_VALUE = 2.0 ** 15

# Markdown blurb describing the demo model.
DESCRIPTION = """# Single-Head Attention Tacotron2 with Global Style Tokens
This is a Tacotron2 model based on the NVIDIA's model plus three unsupervised Global Style Tokens (GST).
The whole architecture has been trained from scratch with the LJSpeech dataset. In order to control the relevance
of each style token, we configured the attention module as a single-head.
"""

# Fix PyTorch's global RNG seed once, at import time.
torch.manual_seed(1234)
# Tacotron2 + GST acoustic model, restored from a local checkpoint onto the CPU.
checkpoint_path = "trained_models/checkpoint_78000.model"
model = load_model(hparams)
# The checkpoint is a mapping; its "state_dict" entry holds the model weights.
model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu")["state_dict"]
)
# model.to('cuda')  # GPU toggle; note synthesize() builds its input tensors on 'cpu'
_ = model.eval()
# Load the pre-trained MelGAN generator used as the mel-spectrogram -> audio vocoder.
vocoder_checkpoint_path = "trained_models/nvidia_tacotron2_LJ11_epoch6400.pt"
# NOTE(review): the whole checkpoint mapping stays referenced by this module-level
# name after loading, although only its 'model_g' entry is used below.
checkpoint = torch.load(vocoder_checkpoint_path, map_location="cpu")
# NOTE(review): hp_melgan is loaded but never used in this file; the generator's
# constructor argument below is hard-coded instead — confirm the two agree.
hp_melgan = load_hparam("melgan/config/default.yaml")
# Presumably the number of mel channels; it must match the mel-spectrograms the
# Tacotron2 model produces — TODO confirm against the melgan config.
vocoder_model = Generator(80)
vocoder_model.load_state_dict(checkpoint['model_g'])
# vocoder_model = vocoder_model.to('cuda')
# Generator.eval is a project override taking an `inference` flag; its effect is
# defined in melgan/model/generator.py (not shown here) — confirm False is intended.
vocoder_model.eval(inference=False)
def synthesize(text, gst_1, gst_2, gst_3):
    """Synthesize speech for *text* using the given global-style-token weights.

    Args:
        text: Input text; converted to symbol ids by ``text_to_sequence`` with
            the ``english_cleaners`` pipeline.
        gst_1, gst_2, gst_3: Weights for the three style tokens, passed to the
            Tacotron2+GST model as a float32 tensor of length 3 (the author's
            original fixed weights were 0.5, 0.15, 0.35).

    Returns:
        ``(22050, audio)``: the sample rate and the vocoder output as a NumPy
        array, the tuple format ``gr.Audio(type="numpy")`` expects. 22050 is
        hard-coded; presumably the LJSpeech/vocoder rate — confirm.

    Raises:
        ValueError: if the text yields no symbols (e.g. an empty text box);
            the model cannot synthesize an empty sequence.
    """
    symbol_ids = text_to_sequence(text, ['english_cleaners'])
    if len(symbol_ids) == 0:
        raise ValueError("Input text produced no symbols to synthesize.")

    # Batch of one: shape (1, T) of int64 symbol ids.
    sequence = torch.tensor([symbol_ids], dtype=torch.int64, device='cpu')
    gst_scores = torch.tensor([gst_1, gst_2, gst_3], dtype=torch.float32)

    # Pure inference: disable autograd for BOTH models so the Tacotron2 pass
    # does not record a computation graph on every request.
    with torch.no_grad():
        _, mel_outputs_postnet, _, _ = model.inference(sequence, gst_scores)
        # mel2wav inference:
        audio = vocoder_model.inference(mel_outputs_postnet)

    return (22050, audio.cpu().numpy())
# Web UI: one text box plus one slider (range 0.25-0.55) per style token; the
# synthesized waveform is returned as (sample_rate, samples) for gr.Audio.
iface = gr.Interface(
    fn=synthesize,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(0.25, 0.55, label="First style token weight:"),
        gr.Slider(0.25, 0.55, label="Second style token weight:"),
        gr.Slider(0.25, 0.55, label="Third style token weight:"),
    ],
    outputs=[gr.Audio(label="Generated Speech", type="numpy")],
    # Show the module-level Markdown DESCRIPTION, which was otherwise unused.
    description=DESCRIPTION,
)
iface.launch()