import itertools
import os
import warnings
import matplotlib.pyplot as plt
import pyloudnorm
import sounddevice
import soundfile
import torch
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from speechbrain.pretrained import EncoderClassifier
from torchaudio.transforms import Resample
from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm
class ToucanTTSInterface(torch.nn.Module):

    def __init__(self,
                 device="cpu",  # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
                 tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt"),  # path to the ToucanTTS checkpoint or just a shorthand if run standalone
                 vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),  # path to the Vocoder checkpoint
                 language="eng",  # initial language of the model, can be changed later with the setter methods
                 enhance=None  # legacy argument
                 ):
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            # default to shorthand system
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")

        ################################
        #   build text to phone        #
        ################################
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)

        #####################################
        #   load phone to features model    #
        #####################################
        checkpoint = torch.load(tts_model_path, map_location='cpu')
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()  # this also removes weight norm
        self.phone2mel = self.phone2mel.to(torch.device(device))

        ######################################
        #   load features to style models    #
        ######################################
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                           run_opts={"device": str(device)},
                                                                           savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))

        ################################
        #   load mel to wave model     #
        ################################
        vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
        self.vocoder = HiFiGAN()
        self.vocoder.load_state_dict(vocoder_checkpoint)
        self.vocoder = self.vocoder.to(device).eval()
        self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(24000)

        ################################
        #   set defaults               #
        ################################
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
        self.phone2mel.eval()
        self.vocoder.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))
        self.eval()
    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        if not isinstance(path_to_reference_audio, list):
            path_to_reference_audio = [path_to_reference_audio]
        if len(path_to_reference_audio) > 0:
            for path in path_to_reference_audio:
                assert os.path.exists(path)
            speaker_embs = list()
            for path in path_to_reference_audio:
                wave, sr = soundfile.read(path)
                wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32))
                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze()
                speaker_embs.append(speaker_embedding)
            # average the embeddings of all reference clips into a single speaker representation
            self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)
    def set_language(self, lang_id):
        """
        The lang_id parameter actually refers to the language shorthand; this has become
        ambiguous with the introduction of the actual language IDs.
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        self.text2phone.change_lang(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        # map regional variants and uncovered codes to the closest supported language
        if lang_id in ['ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so']:
            if lang_id == 'vi-so' or lang_id == 'vi-ctr':
                lang_id = 'vie'
            elif lang_id == 'spa-lat':
                lang_id = 'spa'
            elif lang_id == 'pt-br':
                lang_id = 'por'
            elif lang_id == 'fr-sw' or lang_id == 'fr-be':
                lang_id = 'fra'
            elif lang_id == 'en-sc' or lang_id == 'en-us':
                lang_id = 'eng'
            else:
                # no clue where these others are even coming from, they are not in ISO 639-2
                lang_id = 'eng'
        self.lang_id = get_language_id(lang_id).to(self.device)
    def forward(self,
                text,
                view=False,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0,
                durations=None,
                pitch=None,
                energy=None,
                input_is_phones=False,
                return_plot_as_filepath=False,
                loudness_in_db=-24.0,
                glow_sampling_temperature=0.2):
        """
        duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                 1.0 means no scaling happens, higher values increase durations for the whole
                                 utterance, lower values decrease durations for the whole utterance.
        pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                              1.0 means no scaling happens, higher values increase variance of the pitch curve,
                              lower values decrease variance of the pitch curve.
        energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                               1.0 means no scaling happens, higher values increase variance of the energy curve,
                               lower values decrease variance of the energy curve.
        """
        with torch.inference_mode():
            phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
            mel, durations, pitch, energy = self.phone2mel(phones,
                                                           return_duration_pitch_energy=True,
                                                           utterance_embedding=self.default_utterance_embedding.to(self.device),
                                                           durations=durations,
                                                           pitch=pitch,
                                                           energy=energy,
                                                           lang_id=self.lang_id.to(self.device),
                                                           duration_scaling_factor=duration_scaling_factor,
                                                           pitch_variance_scale=pitch_variance_scale,
                                                           energy_variance_scale=energy_variance_scale,
                                                           pause_duration_scaling_factor=pause_duration_scaling_factor,
                                                           glow_sampling_temperature=glow_sampling_temperature)
            wave, _, _ = self.vocoder(mel.unsqueeze(0))
            wave = wave.squeeze().cpu()
            wave = wave.numpy()
        sr = 24000

        try:
            loudness = self.meter.integrated_loudness(wave)
            wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db)
        except ValueError:
            # if the audio is too short, a value error will arise
            pass

        if view or return_plot_as_filepath:
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
            ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
            ax.yaxis.set_visible(False)
            duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
            ax.xaxis.grid(True, which='minor')
            ax.set_xticks(label_positions, minor=False)
            if input_is_phones:
                phones = text.replace(" ", "|")
            else:
                phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
            ax.set_xticklabels(phones)
            word_boundaries = list()
            for label_index, phone in enumerate(phones):
                if phone == "|":
                    word_boundaries.append(label_positions[label_index])
            try:
                prev_word_boundary = 0
                word_label_positions = list()
                for word_boundary in word_boundaries:
                    word_label_positions.append((word_boundary + prev_word_boundary) / 2)
                    prev_word_boundary = word_boundary
                word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
                secondary_ax = ax.secondary_xaxis('bottom')
                secondary_ax.tick_params(axis="x", direction="out", pad=24)
                secondary_ax.set_xticks(word_label_positions, minor=False)
                secondary_ax.set_xticklabels(text.split())
                secondary_ax.tick_params(axis='x', colors='orange')
                secondary_ax.xaxis.label.set_color('orange')
            except (ValueError, IndexError):
                ax.set_title(text)
            ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
            ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
            plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
            ax.set_aspect("auto")

        if return_plot_as_filepath:
            plt.savefig("tmp.png")
            return wave, sr, "tmp.png"
        return wave, sr
    def read_to_file(self,
                     text_list,
                     file_location,
                     duration_scaling_factor=1.0,
                     pitch_variance_scale=1.0,
                     energy_variance_scale=1.0,
                     pause_duration_scaling_factor=1.0,
                     silent=False,
                     dur_list=None,
                     pitch_list=None,
                     energy_list=None,
                     glow_sampling_temperature=0.2):
        """
        Args:
            silent: Whether to suppress the progress prints during synthesis
            text_list: A list of strings to be read
            file_location: The path and name of the file it should be saved to
            energy_list: list of energy tensors to be used for the texts
            pitch_list: list of pitch tensors to be used for the texts
            dur_list: list of duration tensors to be used for the texts
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
        """
        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []
        silence = torch.zeros([14300])
        wav = silence.clone()
        for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            if text.strip() != "":
                if not silent:
                    print("Now synthesizing: {}".format(text))
                spoken_sentence, sr = self(text,
                                           durations=durations.to(self.device) if durations is not None else None,
                                           pitch=pitch.to(self.device) if pitch is not None else None,
                                           energy=energy.to(self.device) if energy is not None else None,
                                           duration_scaling_factor=duration_scaling_factor,
                                           pitch_variance_scale=pitch_variance_scale,
                                           energy_variance_scale=energy_variance_scale,
                                           pause_duration_scaling_factor=pause_duration_scaling_factor,
                                           glow_sampling_temperature=glow_sampling_temperature)
                spoken_sentence = torch.tensor(spoken_sentence).cpu()
                wav = torch.cat((wav, spoken_sentence, silence), 0)
        soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")
    def read_aloud(self,
                   text,
                   view=False,
                   duration_scaling_factor=1.0,
                   pitch_variance_scale=1.0,
                   energy_variance_scale=1.0,
                   blocking=False,
                   glow_sampling_temperature=0.2):
        if text.strip() == "":
            return
        wav, sr = self(text,
                       view,
                       duration_scaling_factor=duration_scaling_factor,
                       pitch_variance_scale=pitch_variance_scale,
                       energy_variance_scale=energy_variance_scale,
                       glow_sampling_temperature=glow_sampling_temperature)
        silence = torch.zeros([sr // 2])
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()
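

# Minimal usage sketch: it assumes the default ToucanTTS_Meta and Vocoder checkpoints are
# available under MODELS_DIR, and the file paths below ("reference.wav", "output.wav") are
# illustrative placeholders rather than files shipped with the repository.
if __name__ == "__main__":
    tts_interface = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu",
                                       language="eng")

    # optionally condition the voice on one or more reference recordings (placeholder path)
    # tts_interface.set_utterance_embedding(path_to_reference_audio=["reference.wav"])

    # synthesize a few sentences into a single 24 kHz PCM-16 wav file with slightly slower speech
    tts_interface.read_to_file(text_list=["This is a test sentence.",
                                          "And this is another one."],
                               file_location="output.wav",
                               duration_scaling_factor=1.1,
                               pitch_variance_scale=1.0,
                               energy_variance_scale=1.0)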