Spaces:
Running
Running
import logging | |
import re | |
import torch | |
from num2words import num2words | |
try: | |
import extensions.telegram_bot.source.utils as utils | |
from extensions.telegram_bot.source.user import User as User | |
except ImportError: | |
import source.utils as utils | |
from source.user import User as User | |
class Silero:
    """Silero text-to-speech wrapper plus a text normaliser for spoken output.

    ``generate_audio`` loads a Silero model through ``torch.hub`` and writes a
    wav file named after the user id. ``preprocess`` rewrites raw chat text
    (numbers, roman numerals, abbreviations, markup) into a form that reads
    well when synthesised.
    """

    # Regex character class: whitespace or closing punctuation that may follow a token.
    punctuation = r"[\s,.?!/)\'\]>]"
    # Spoken spelling of every capital letter; values are space padded so that
    # consecutive letters never run together.
    alphabet_map = {
        "A": " Ei ",
        "B": " Bee ",
        "C": " See ",
        "D": " Dee ",
        "E": " Eee ",
        "F": " Eff ",
        "G": " Jee ",
        "H": " Eich ",
        "I": " Eye ",
        "J": " Jay ",
        "K": " Kay ",
        "L": " El ",
        "M": " Emm ",
        "N": " Enn ",
        "O": " Ohh ",
        "P": " Pee ",
        "Q": " Queue ",
        "R": " Are ",
        "S": " Ess ",
        "T": " Tee ",
        "U": " You ",
        "V": " Vee ",
        "W": " Double You ",
        "X": " Ex ",
        "Y": " Why ",
        "Z": " Zed ",
    }
    # Per language: the Silero model id and the speaker names grouped by sex.
    voices = {
        "en": {
            "model": "v3_en",
            "male": [
                "en_2",
                "en_7",
                "en_9",
                "en_13",
                "en_15",
                "en_17",
                "en_19",
                "en_20",
            ],
            "female": [
                "en_0",
                "en_10",
                "en_11",
                "en_12",
                "en_14",
                "en_16",
                "en_18",
                "en_24",
                "en_25",
                "en_82",
                "en_85",
            ],
        },
        "ru": {
            "model": "v3_1_ru",
            "male": [
                "aidar",
                "eugene",
            ],
            "female": [
                "baya",
                "kseniya",
                "xenia",
            ],
        },
    }

    def __init__(
        self,
    ):
        """Configure torch for CPU inference and store the Silero hub settings."""
        torch.set_num_threads(4)
        self.device = torch.device("cpu")
        self.sample_rate = 48000  # 8000, 24000, 48000
        self.silero_repo = "snakers4/silero-models"
        self.model = "silero_tts"
        logging.info("### Silero INIT DONE ###")

    async def get_audio(self, text: str, user_id: int, user: "User"):
        """Coroutine entry point around :meth:`generate_audio`.

        Returns the wav path produced by ``generate_audio`` or ``None``.
        """
        # generate_audio is a plain (blocking) function that returns str/None;
        # awaiting its result would raise TypeError, so return it directly.
        return self.generate_audio(text, user_id, user)

    def generate_audio(self, text: str, user_id: int, user: "User"):
        """Synthesise ``text`` with the user's voice and save it as ``<user_id>.wav``.

        Args:
            text: Raw text; it is normalised with :meth:`preprocess` first.
            user_id: Used to build the output file name.
            user: Object exposing ``language``, ``silero_speaker`` and
                ``silero_model_id``. The last two are overwritten with the
                language defaults when they do not match ``voices``.

        Returns:
            The wav file path, or ``None`` when audio is switched off (either
            setting is the string ``"None"``), the language has no voices,
            the normalised text is empty, or synthesis fails.
        """
        if user.silero_speaker == "None" or user.silero_model_id == "None":
            return None

        lang_voices = self.voices.get(user.language)
        if lang_voices is None:
            logging.warning("Silero: no voices configured for language %r", user.language)
            return None

        known_speakers = lang_voices["male"] + lang_voices["female"]
        # The model id must match exactly; a substring test would accept
        # values such as "v3" or "".
        if user.silero_speaker not in known_speakers or user.silero_model_id != lang_voices["model"]:
            user.silero_model_id, user.silero_speaker = self.get_default_audio_settings(user.language)

        try:
            # Normalise once, and before loading the model, so that empty
            # input never pays for the model load.
            text = self.preprocess(text)
            if not text.strip():
                return None
            model, _ = torch.hub.load(
                repo_or_dir=self.silero_repo,
                model=self.model,
                language=user.language,
                speaker=user.silero_model_id,
            )
            wav_path = str(user_id) + ".wav"
            model.save_wav(
                text=text,
                audio_path=wav_path,
                speaker=user.silero_speaker,
                sample_rate=self.sample_rate,
            )
            return wav_path
        except Exception:
            # Best-effort boundary: callers only ever see a path or None.
            logging.exception("Silero: audio generation failed")
            return None

    def get_default_audio_settings(self, language, sex="female"):
        """Return ``(model_id, speaker)`` defaults for ``language``.

        The speaker is the first entry of the ``sex`` list. Unknown languages
        yield the string pair ``("None", "None")``.
        """
        if language not in self.voices:
            return "None", "None"
        entry = self.voices[language]
        return entry["model"], entry[sex][0]

    def preprocess(self, string):
        """Normalise ``string`` for speech synthesis and return the result."""
        # the order for some of these matter
        # For example, you need to remove the commas in numbers before expanding them
        string = self.remove_surrounded_chars(string)
        string = string.replace('"', "")
        string = string.replace("\u201D", "").replace("\u201C", "")  # right and left quote
        string = string.replace("\u201F", "")  # italic looking quote
        string = string.replace("\n", " ")
        string = string.replace("*", " ! ")
        string = self.convert_num_locale(string)
        string = self.replace_negative(string)
        string = self.replace_roman(string)
        string = self.hyphen_range_to(string)
        string = self.num_to_words(string)
        # For now, expand abbreviations to pronunciations
        # replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
        string = self.replace_abbreviations(string)
        string = self.replace_lowercase_abbreviations(string)
        # cleanup whitespaces
        # remove whitespace before punctuation
        string = re.sub(rf"\s+({self.punctuation})", r"\1", string)
        string = string.strip()
        # compact whitespace
        string = " ".join(string.split())
        return string

    @staticmethod
    def _replace_repeatedly(pattern, transform, text):
        """Rewrite the leftmost match of ``pattern`` until none is left.

        ``transform`` receives the matched text and returns its replacement.
        Searching restarts from the beginning after every replacement, so
        matches that share a delimiter character are all handled.
        """
        match = pattern.search(text)
        while match is not None:
            start, end = match.span()
            text = text[:start] + transform(match.group(0)) + text[end:]
            match = pattern.search(text)
        return text

    @staticmethod
    def remove_surrounded_chars(string):
        """Drop asterisk-delimited spans, keeping only image alt text if present.

        If the text contains ``alt=`` followed later by ``style=``, only what
        lies between them is kept. Afterwards every span running from an
        asterisk to the next asterisk, or to the end of the string, is removed.
        """
        alt_text = re.search(r"(?<=alt=)(.*)(?=style=)", string, re.DOTALL)
        if alt_text:
            string = alt_text.group(0)
        return re.sub(r"\*[^*]*?(\*|$)", "", string)

    @staticmethod
    def convert_num_locale(text):
        """Rewrite ``1.234,56`` style numbers as ``1234.56`` and drop digit-group commas."""
        # This detects locale and converts it to American without comma separators
        european = re.compile(r"(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)")
        result = Silero._replace_repeatedly(
            european,
            lambda chunk: chunk.replace(".", "").replace(",", "."),
            text,
        )
        # removes comma separators from existing American numbers
        return re.sub(r"(\d),(\d)", r"\1\2", result)

    def replace_negative(self, string):
        """Turn ``-5`` into ``negative 5`` so it is later expanded to words."""
        # The delimiter after the digits is only looked at, not consumed, so
        # back-to-back values (" -5 -6 ") are both rewritten; a value at the
        # very start or end of the text is handled as well.
        return re.sub(rf"(^|\s)-(\d+)(?={self.punctuation}|$)", r"\1negative \2", string)

    def replace_roman(self, string):
        """Replace runs of roman-numeral letters with their integer value."""
        # find a string of roman numerals.
        # Only 2 or more, to avoid capturing I and single character abbreviations, like names
        pattern = re.compile(rf"\s[IVXLCDM]{{2,}}{self.punctuation}")

        def to_digits(chunk):
            # chunk is one delimiter, the numeral letters, one delimiter.
            return chunk[0] + str(self.roman_to_int(chunk[1:-1])) + chunk[-1]

        return self._replace_repeatedly(pattern, to_digits, string)

    @staticmethod
    def roman_to_int(s):
        """Convert a string made of the letters ``IVXLCDM`` to an integer."""
        rom_val = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
        int_val = 0
        previous = None
        for symbol in s:
            value = rom_val[symbol]
            if previous is not None and value > previous:
                # The smaller previous symbol was already added once; undo
                # that and subtract it (e.g. IV -> 1 + (5 - 2) == 4).
                int_val += value - 2 * previous
            else:
                int_val += value
            previous = value
        return int_val

    @staticmethod
    def hyphen_range_to(text):
        """Read ``3-5`` (hyphen or en dash) as ``3 to 5``."""
        return re.sub(r"(\d+)[-–](\d+)", r"\1 to \2", text)

    @staticmethod
    def num_to_words(text):
        """Spell out every integer or decimal number in ``text`` with ``num2words``."""

        def spell(match):
            token = match.group()
            # Integers stay integers so long digit strings keep full precision.
            return num2words(float(token) if "." in token else int(token))

        # 1000 or 10.23
        return re.sub(r"\d+\.\d+|\d+", spell, text)

    def replace_abbreviations(self, string):
        """Spell out upper-case abbreviations letter by letter."""
        # abbreviations 1-4 characters long. It will get things like A and I, but those are pronounced with their letter
        pattern = re.compile(rf"(^|[\s(.\'\[<])([A-Z]{{1,4}})({self.punctuation}|$)")
        return self._replace_repeatedly(pattern, self.replace_abbreviation, string)

    def replace_lowercase_abbreviations(self, string):
        """Spell out dotted lower-case abbreviations such as ``e.g.``."""
        # abbreviations 1 to 4 characters long, separated by dots i.e. e.g.
        pattern = re.compile(rf"(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({self.punctuation}|$)")
        return self._replace_repeatedly(
            pattern,
            lambda chunk: self.replace_abbreviation(chunk.upper()),
            string,
        )

    def replace_abbreviation(self, string):
        """Map every character of ``string`` through :meth:`match_mapping`."""
        return "".join(self.match_mapping(char) for char in string)

    def match_mapping(self, char):
        """Return the spoken form of a capital letter, or ``char`` unchanged."""
        return self.alphabet_map.get(char, char)

    def __main__(self, args):
        """Print the normalised form of ``args[1]``."""
        print(self.preprocess(args[1]))