|
import torch |
|
from cached_path import cached_path |
|
|
|
import audresample |
|
|
|
import numpy as np |
|
import yaml |
|
import torchaudio |
|
import librosa |
|
from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_F0_models |
|
from nltk.tokenize import word_tokenize |
|
|
|
|
|
|
|
# Symbol inventory for the text encoder. Every character below gets an integer
# id equal to its position in ``symbols``; the pad symbol "$" is id 0.
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# Character -> id lookup. "'" occurs twice in ``_letters_ipa``; the later
# position wins, exactly as with sequential assignment.
dicts = {symbol: position for position, symbol in enumerate(symbols)}
|
|
|
class TextCleaner:
    """Map a phoneme string to a list of integer symbol ids.

    Characters missing from the symbol table are dropped. A single warning
    naming them is logged per call; previously the whole input was printed
    once for every unknown character.
    """

    def __init__(self, dummy=None):
        # ``dummy`` is unused; it is kept so existing call sites keep working.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the ids of the known characters of ``text``, in order."""
        lookup = self.word_index_dictionary
        indexes = []
        unknown = []
        for char in text:
            index = lookup.get(char)
            if index is None:
                unknown.append(char)
            else:
                indexes.append(index)
        if unknown:
            import logging
            logging.getLogger(__name__).warning(
                'TextCleaner dropped unknown characters %r from %r',
                ''.join(sorted(set(unknown))), text)
        return indexes
|
|
|
|
|
|
|
# Shared text cleaner instance. The misspelled name is the one ``inference`` uses.
textclenaer = TextCleaner()

# Mel front end used by ``preprocess``, which normalises as
# (log(mel) - mean) / std.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    win_length=1200,
    hop_length=300,
    n_mels=80,
)
mean = -4
std = 4
|
|
|
def alpha_num(f):
    """Keep only ASCII letters, digits and single spaces in ``f``.

    Disallowed characters are removed first and runs of spaces are collapsed
    afterwards. The previous order (collapse, then strip) left double spaces
    behind, e.g. ``"a - b"`` became ``"a  b"``.
    """
    f = re.sub(r'[^A-Za-z0-9 ]+', '', f)
    return re.sub(' +', ' ', f)
|
|
|
def preprocess(wave):
    """Convert a numpy waveform into a normalised log-mel tensor.

    The waveform (presumably mono 1-D audio at 24 kHz; confirm with callers)
    goes through ``to_mel`` and gains a leading batch axis. It is then
    log-compressed with a 1e-5 floor and standardised with the module-level
    ``mean`` / ``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float()).unsqueeze(0)
    log_mel = torch.log(mel + 1e-5)
    return (log_mel - mean) / std
|
|
|
def compute_style(path):
    """Compute the reference style tensor for the audio file at ``path``.

    Steps:
    1. Load the audio at 24 kHz.
    2. Trim leading/trailing low-energy audio with ``top_db=30``.
    3. Encode the log-mel spectrogram with both ``style_encoder`` and
       ``predictor_encoder``.

    The two outputs are concatenated along dim 3, index 0 of dim 2 is taken,
    and dims 1/2 are swapped. In the returned tensor, dim 1 therefore holds
    the ``style_encoder`` features first and the ``predictor_encoder``
    features second. ``inference`` splits them at channel 128.
    """
    target_sr = 24000
    # librosa.load already resamples to ``sr``. The old ``if sr != 24000``
    # branch could never run, and its positional ``librosa.resample`` call is
    # rejected by librosa >= 0.10 (orig_sr/target_sr are keyword-only).
    wave, _ = librosa.load(path, sr=target_sr)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    mel_tensor = preprocess(audio).to(device).unsqueeze(1)

    with torch.no_grad():
        ref_s = style_encoder(mel_tensor)
        ref_p = predictor_encoder(mel_tensor)

    s = torch.cat([ref_s, ref_p], dim=3)
    return s[:, :, 0, :].transpose(1, 2)
|
|
|
# Inference device: CUDA when available, otherwise CPU. Apple MPS is not used;
# the previous MPS branch only contained ``pass`` and so always fell back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
import phonemizer

# Shared espeak phonemizer: US English, keeping punctuation and stress marks.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us', preserve_punctuation=True, with_stress=True)

# Parse the model config with a context manager so the file handle is closed
# (the previous bare ``open(...)`` was never closed).
with open('Utils/config.yml', encoding='utf-8') as _config_file:
    args = yaml.safe_load(_config_file)

ASR_config = args['ASR_config']
F0_path = args['F0_path']
pitch_extractor = load_F0_models(F0_path).eval().to(device)
|
|
|
from Utils.PLBERT.util import load_plbert |
|
from Modules.hifigan import Decoder |
|
|
|
# PL-BERT text model, loaded from the directory named in the config.
bert = load_plbert(args['PLBERT_dir']).eval().to(device)

# Waveform decoder from Modules.hifigan. These hyper-parameters must match
# the checkpoint whose weights are loaded into it further down.
decoder = Decoder(
    dim_in=512,
    style_dim=128,
    dim_out=80,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 5, 3, 2],
    upsample_kernel_sizes=[20, 10, 6, 4],
    upsample_initial_channel=512,
).eval().to(device)

# Token encoder; n_symbols should equal len(symbols).
text_encoder = TextEncoder(
    channels=512,
    kernel_size=5,
    depth=3,
    n_symbols=178,
).eval().to(device)

# Duration / F0 / energy predictor.
predictor = ProsodyPredictor(
    style_dim=128,
    d_hid=512,
    nlayers=3,
    max_dur=50,
    dropout=0.2,
).eval().to(device)

# Two structurally identical style encoders. ``compute_style`` uses the first
# for the acoustic style and the second for the prosodic style.
style_encoder = StyleEncoder(
    dim_in=64,
    style_dim=128,
    max_conv_dim=512,
).eval().to(device)
predictor_encoder = StyleEncoder(
    dim_in=64,
    style_dim=128,
    max_conv_dim=512,
).eval().to(device)

# Projects PL-BERT hidden states down to 512 channels.
bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
|
|
|
|
|
# Fetch (or reuse the cached copy of) the StyleTTS2 LibriTTS checkpoint and keep
# its 'net' entry, which maps sub-model names to their state dicts.
# NOTE(review): torch.load unpickles the file. That is only acceptable because
# the source is a fixed Hugging Face repo. Newer torch releases default to
# weights_only=True; confirm the checkpoint still loads under that setting.
params_whole = torch.load(
    str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")),
    map_location='cpu',
)
params = params_whole['net']
|
|
|
from collections import OrderedDict |
|
|
|
def _del_prefix(d, prefix='module.'):
    """Return a copy of state dict ``d`` with ``prefix`` removed from its keys.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model typically
    prefix every parameter name with ``'module.'``. The old implementation cut
    the first 7 characters off every key, which corrupts keys that lack the
    prefix. Now only keys that actually start with ``prefix`` are changed.

    Args:
        d: mapping of parameter name -> value.
        prefix: leading string to strip (default ``'module.'``).

    Returns:
        An ``OrderedDict`` preserving the iteration order of ``d``.
    """
    cut = len(prefix)
    return OrderedDict(
        (k[cut:] if k.startswith(prefix) else k, v) for k, v in d.items()
    )
|
|
|
# Copy the checkpoint weights into each freshly built sub-model, in the same
# order as before. strict=True makes any missing or unexpected parameter name
# an error.
for _model, _key in (
    (bert, 'bert'),
    (bert_encoder, 'bert_encoder'),
    (predictor, 'predictor'),
    (decoder, 'decoder'),
    (text_encoder, 'text_encoder'),
    (predictor_encoder, 'predictor_encoder'),
    (style_encoder, 'style_encoder'),
    (pitch_extractor, 'pitch_extractor'),
):
    _model.load_state_dict(_del_prefix(params[_key]), strict=True)
del _model, _key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _shift_frames_right(t):
    """Return a copy of ``t`` shifted one step right along its last axis.

    Frame 0 is duplicated so the length stays the same:
    ``[f0, f1, ..., fn]`` -> ``[f0, f0, f1, ..., f(n-1)]``. ``t`` is indexed as
    a 3-D tensor, as in the original inline code.
    """
    shifted = torch.zeros_like(t)
    shifted[:, :, 0] = t[:, :, 0]
    shifted[:, :, 1:] = t[:, :, :-1]
    return shifted


def inference(text,
              ref_s,
              use_gruut=False):
    """Synthesise English speech for ``text`` in the style of ``ref_s``.

    Args:
        text: input sentence; phonemized with ``global_phonemizer``.
        ref_s: style tensor as returned by ``compute_style``. Channels
            ``[:128]`` of dim 1 go to the decoder; the rest drive the prosody
            predictor.
        use_gruut: unused; kept for backwards compatibility.

    Returns:
        A 1-D numpy waveform peak-normalised to ~1. If the decoder produced at
        most 10 samples (after trimming), an empty array is returned instead.
    """
    import logging

    # Text -> phonemes -> symbol ids, with the pad id 0 prepended.
    text = text.strip()
    ps = global_phonemizer.phonemize([text])
    ps = ' '.join(word_tokenize(ps[0]))
    tokens = textclenaer(ps)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    n_tokens = tokens.shape[-1]

    with torch.no_grad():
        input_lengths = torch.LongTensor([n_tokens]).to(device)
        hidden_states = text_encoder(tokens, input_lengths)

        bert_dur = bert(tokens, attention_mask=None)
        d_en = bert_encoder(bert_dur).transpose(-1, -2)

        ref = ref_s[:, :128, :]  # decoder style
        s = ref_s[:, 128:, :]    # prosody-predictor style

        # Run the duration encoder once. It was previously run twice, and the
        # first (transposed) result was thrown away.
        d = predictor.text_encoder(d_en, s, input_lengths)
        x, _ = predictor.lstm(d)
        duration = predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        # reshape(-1) instead of squeeze(): with a single token (e.g. empty
        # text), squeeze() gives a 0-d tensor and pred_dur[i] would fail.
        pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)

        # Hard alignment matrix: token i owns pred_dur[i] consecutive frames.
        frames_per_token = [int(v) for v in pred_dur.tolist()]
        pred_aln_trg = torch.zeros(n_tokens, sum(frames_per_token))
        c_frame = 0
        for i, n_frames in enumerate(frames_per_token):
            pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
            c_frame += n_frames
        aln = pred_aln_trg.unsqueeze(0).to(device)

        en = _shift_frames_right(d.transpose(-1, -2) @ aln)
        F0_pred, N_pred = predictor.F0Ntrain(en, s)

        asr = _shift_frames_right(hidden_states @ aln)

        x = decoder(asr=asr,
                    F0_curve=F0_pred,
                    N=N_pred,
                    s=ref)

    # The last 400 samples are dropped, as before (presumably a decoder tail
    # artifact; confirm).
    x = x.cpu().numpy()[0, 0, :-400]
    if x.shape[0] > 10:
        x /= np.abs(x).max() + 1e-7
    else:
        logging.getLogger(__name__).warning(
            'EMPTY TTS: only %d samples produced', x.shape[0])
        x = np.zeros(0)
    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from num2words import num2words |
|
import os |
|
import re |
|
import tempfile |
|
import torch |
|
import sys |
|
from Modules.vits.models import VitsModel, VitsTokenizer |
|
|
|
# ISO code -> language name, read from a two-column CSV. Only the first comma
# splits the columns, so names may themselves contain commas.
TTS_LANGUAGES = {}

# Explicit UTF-8 so non-ASCII language names do not depend on the platform's
# default encoding.
with open("Utils/all_langs.csv", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines instead of failing to unpack them
        iso, name = line.split(",", 1)
        TTS_LANGUAGES[iso.strip()] = name.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Spelling substitutions that ``fix_phones`` applies to Serbian Latin-script
# text (lang_code 'rmc-script_latin' in ``foreign``) before it reaches the MMS
# tokenizer.
# ``fix_phones`` applies them with ``str.replace`` in insertion order, so ORDER
# MATTERS: later rules also rewrite the output of earlier ones. For example,
# the result of 'služ' -> 'sloooozz' is then hit by 'z' -> 's'. 'ž' -> 'z' is
# listed last so the 'z' it produces survives.
# NOTE(review): 's ' -> 'se' also removes the space after a word-final 's',
# joining it to the next word. Confirm this is intended.
PHONEME_MAP = {
    'služ' : 'sloooozz',
    'suver': 'siuveeerra',
    'država': 'dirrezav',
    'iči': 'ici',
    's ': 'se',
    'q': 'ku',
    'w': 'aou',
    'z': 's',
    "š": "s",
    'th': 'ta',
    'v': 'vv',
    "ž": "z",
}
|
|
|
|
|
|
|
def number_to_phonemes(match):
    """``re.sub`` callback: spell out the matched integer in Serbian words.

    The words are lower-cased and passed through ``fix_phones``, the same as
    ordinary text.
    """
    spoken = num2words(int(match.group()), lang='sr')
    return fix_phones(spoken.lower())
|
|
|
|
|
def fix_phones(text):
    """Rewrite ``text`` with every ``PHONEME_MAP`` rule, in insertion order.

    Afterwards, commas and then full stops are each replaced with ``'_ _'``.
    """
    for old, new in PHONEME_MAP.items():
        text = text.replace(old, new)
    for mark in (',', '.'):
        text = text.replace(mark, '_ _')
    return text
|
|
|
def has_cyrillic(text):
    """Return True if ``text`` has any character in the range U+0400-U+04FF
    (the basic Unicode Cyrillic block)."""
    return re.search('[\u0400-\u04FF]', text) is not None
|
|
|
def foreign(text=None,
            lang='romanian',
            speed=None):
    """Synthesise ``text`` with a Facebook MMS-TTS (VITS) model.

    Args:
        text: sequence of sentences, synthesised one by one and concatenated.
            A plain string is treated as a single sentence; previously it was
            synthesised character by character. ``None`` or empty input
            returns an empty array instead of crashing.
        lang: free-form language name. Known names map to an MMS code below;
            otherwise the first word of ``lang`` is used as the code.
        speed: value assigned to the model's ``speaking_rate``. Romanian,
            German and Albanian have their own defaults. If it is still
            ``None``, the model's configured rate is left untouched.

    Returns:
        A float32 waveform at 24 kHz, peak-normalised. The model output is
        treated as 16 kHz and resampled.
    """
    import logging
    logger = logging.getLogger(__name__)

    if isinstance(text, str):
        text = [text]
    if not text:
        return np.zeros(0, dtype=np.float32)

    lang = lang.lower()

    if 'hun' in lang:
        lang_code = 'hun'
    elif any(tag in lang for tag in ('ser', 'bosn', 'herzegov', 'montenegr', 'macedon')):
        # Choose the script-specific model based on the first sentence.
        if has_cyrillic(text[0]):
            lang_code = 'rmc-script_cyrillic'
        else:
            lang_code = 'rmc-script_latin'
    elif 'rom' in lang:
        lang_code = 'ron'
        speed = 1.24 if speed is None else speed
    elif 'ger' in lang:
        lang_code = 'deu'
        speed = 1.14 if speed is None else speed
    elif 'alban' in lang:
        lang_code = 'sqi'
        speed = 1.04 if speed is None else speed
    else:
        lang_code = lang.split()[0].strip()

    net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
    tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')

    # Set once, outside the loop. Assigning None (as before) would overwrite
    # the model's configured rate; a VITS forward that divides by it would
    # then presumably fail.
    if speed is not None:
        net_g.speaking_rate = speed

    wavs = []
    for sentence in text:
        sentence = sentence.lower()

        if lang_code == 'rmc-script_latin':
            sentence = re.sub(r'\d+', number_to_phonemes, sentence)
            sentence = fix_phones(sentence)
        elif lang_code == 'ron':
            # Normalise the cedilla form ţ to the comma-below form ț, then
            # rewrite ț as 'ts' and î as 'u'.
            sentence = sentence.replace("ţ", "ț").replace('ț', 'ts').replace('î', 'u')

        inputs = tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            wav = net_g(input_ids=inputs.input_ids.to(device),
                        attention_mask=inputs.attention_mask.to(device))
        wavs.append(wav)
        logger.debug('speed=%s shape=%s text=%r', speed, tuple(wav.shape), sentence)

    x = torch.cat(wavs).cpu().numpy()
    x /= np.abs(x).max() + 1e-7

    # Resample from 16 kHz to 24 kHz, the rate compute_style loads audio at.
    x = audresample.resample(signal=x.astype(np.float32),
                             original_rate=16000,
                             target_rate=24000)[0, :]
    return x
|
|
|
|
|
|