# xsigus24's picture
# Upload folder using huggingface_hub
# 1d777c4
import logging
import re
import torch
from num2words import num2words
try:
import extensions.telegram_bot.source.utils as utils
from extensions.telegram_bot.source.user import User as User
except ImportError:
import source.utils as utils
from source.user import User as User
class Silero:
    """Silero text-to-speech wrapper plus the text normaliser that feeds it.

    Models are fetched through ``torch.hub`` from ``snakers4/silero-models``.
    Before synthesis the text goes through :meth:`preprocess`, which strips
    markup, spells out numbers and expands short abbreviations into
    letter-by-letter pronunciations.
    """

    # Regex character class: whitespace or closing punctuation that may
    # directly follow a token. Interpolated into several patterns below.
    punctuation = r"[\s,.?!/)\'\]>]"

    # Upper-case letter -> spelled-out pronunciation. The surrounding spaces
    # keep adjacent letters separated; preprocess() compacts them afterwards.
    alphabet_map = {
        "A": " Ei ",
        "B": " Bee ",
        "C": " See ",
        "D": " Dee ",
        "E": " Eee ",
        "F": " Eff ",
        "G": " Jee ",
        "H": " Eich ",
        "I": " Eye ",
        "J": " Jay ",
        "K": " Kay ",
        "L": " El ",
        "M": " Emm ",
        "N": " Enn ",
        "O": " Ohh ",
        "P": " Pee ",
        "Q": " Queue ",
        "R": " Are ",
        "S": " Ess ",
        "T": " Tee ",
        "U": " You ",
        "V": " Vee ",
        "W": " Double You ",
        "X": " Ex ",
        "Y": " Why ",
        "Z": " Zed ",
    }

    # Per-language model id and the speaker names accepted for that model.
    voices = {
        "en": {
            "model": "v3_en",
            "male": [
                "en_2",
                "en_7",
                "en_9",
                "en_13",
                "en_15",
                "en_17",
                "en_19",
                "en_20",
            ],
            "female": [
                "en_0",
                "en_10",
                "en_11",
                "en_12",
                "en_14",
                "en_16",
                "en_18",
                "en_24",
                "en_25",
                "en_82",
                "en_85",
            ],
        },
        "ru": {
            "model": "v3_1_ru",
            "male": [
                "aidar",
                "eugene",
            ],
            "female": [
                "baya",
                "kseniya",
                "xenia",
            ],
        },
    }

    def __init__(
        self,
    ):
        """Configure torch for CPU use and record the hub repo/model names."""
        torch.set_num_threads(4)
        # NOTE(review): self.device is stored but not used anywhere in this class.
        self.device = torch.device("cpu")
        self.sample_rate = 48000  # 8000, 24000, 48000
        self.silero_repo = "snakers4/silero-models"
        self.model = "silero_tts"
        logging.info("### Silero INIT DONE ###")

    async def get_audio(self, text: str, user_id: int, user: User):
        """Await :meth:`generate_audio` and return its result (wav path or None)."""
        return await self.generate_audio(text, user_id, user)

    @utils.async_wrap
    def generate_audio(self, text: str, user_id: int, user: User):
        """Synthesise ``text`` into ``<user_id>.wav`` using the user's voice settings.

        Args:
            text: Raw text to speak; it is normalised with :meth:`preprocess`.
            user_id: Used to build the output file name ``str(user_id) + ".wav"``.
            user: Object exposing ``language``, ``silero_speaker`` and
                ``silero_model_id``. The last two are overwritten with the
                language defaults when they do not match ``voices``.

        Returns:
            The wav file path, or ``None`` when the user has audio set to
            ``"None"``, the language has no configured voices, the normalised
            text is empty, or synthesis fails.
        """
        # A literal "None" in either setting means: produce no audio.
        if user.silero_speaker == "None" or user.silero_model_id == "None":
            return None

        language_voices = self.voices.get(user.language)
        if language_voices is None:
            # No voices configured for this language; nothing can be synthesised.
            logging.warning("Silero: unsupported language %r", user.language)
            return None

        known_speakers = language_voices["male"] + language_voices["female"]
        # The model id must match exactly; a substring test would accept e.g. "v3".
        if user.silero_speaker not in known_speakers or user.silero_model_id != language_voices["model"]:
            user.silero_model_id, user.silero_speaker = self.get_default_audio_settings(user.language)

        try:
            # Normalise once, and before loading the model, so empty input is cheap.
            text = self.preprocess(text)
            if not text.strip():
                return None
            model, _ = torch.hub.load(
                repo_or_dir=self.silero_repo,
                model=self.model,
                language=user.language,
                speaker=user.silero_model_id,
            )
            wav_path = str(user_id) + ".wav"
            model.save_wav(
                text=text,
                audio_path=wav_path,
                speaker=user.silero_speaker,
                sample_rate=self.sample_rate,
            )
            return wav_path
        except Exception:
            # Best effort: any failure is logged and reported to the caller as None.
            logging.exception("Silero audio generation failed")
            return None

    def get_default_audio_settings(self, language, sex="female"):
        """Return ``(model_id, speaker)`` defaults for ``language``.

        The speaker is the first entry of the ``sex`` list in ``voices``.
        Returns ``("None", "None")`` when the language is not configured.
        Raises ``KeyError`` if ``sex`` is not a key of the language entry.
        """
        if language in self.voices:
            return self.voices[language]["model"], self.voices[language][sex][0]
        else:
            return "None", "None"

    def preprocess(self, string):
        """Normalise ``string`` into plain speakable text and return it."""
        # the order for some of these matter
        # For example, you need to remove the commas in numbers before expanding them
        string = self.remove_surrounded_chars(string)
        string = string.replace('"', "")
        string = string.replace("\u201D", "").replace("\u201C", "")  # right and left quote
        string = string.replace("\u201F", "")  # italic looking quote
        string = string.replace("\n", " ")
        # remove_surrounded_chars already drops every asterisk; kept as a safety net.
        string = string.replace("*", " ! ")
        string = self.convert_num_locale(string)
        string = self.replace_negative(string)
        string = self.replace_roman(string)
        string = self.hyphen_range_to(string)
        string = self.num_to_words(string)
        # For now, expand abbreviations to pronunciations
        # replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
        string = self.replace_abbreviations(string)
        string = self.replace_lowercase_abbreviations(string)
        # cleanup whitespaces
        # remove whitespace before punctuation
        string = re.sub(rf"\s+({self.punctuation})", r"\1", string)
        string = string.strip()
        # compact whitespace
        string = " ".join(string.split())
        return string

    @staticmethod
    def remove_surrounded_chars(string):
        """Keep only image alt text if present, then drop asterisk-delimited spans."""
        # If the string contains text nested between "alt=" and "style=" (the alt
        # text of an embedded image), keep only that part for further processing.
        alt_text = re.search(r"(?<=alt=)(.*)(?=style=)", string, re.DOTALL)
        if alt_text:
            string = alt_text.group(0)
        # Remove as few characters as possible between two asterisks, or between
        # an asterisk and the end of the string.
        return re.sub(r"\*[^*]*?(\*|$)", "", string)

    @staticmethod
    def convert_num_locale(text):
        """Rewrite ``1.234,56`` style numbers as ``1234.56`` and drop digit-group commas."""
        # This detects locale and converts it to American without comma separators
        pattern = re.compile(r"(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)")
        result = text
        # Re-scan from the start after every rewrite: neighbouring numbers share a
        # whitespace delimiter, so a single substitution pass would skip every other one.
        while True:
            match = pattern.search(result)
            if match is None:
                break
            start = match.start()
            end = match.end()
            result = result[0:start] + result[start:end].replace(".", "").replace(",", ".") + result[end: len(result)]
        # removes comma separators from existing American numbers
        pattern = re.compile(r"(\d),(\d)")
        result = pattern.sub(r"\1\2", result)
        return result

    def replace_negative(self, string):
        """Replace a leading minus on an integer with the word ``negative``."""
        # handles situations like -5. -5 would become negative 5, which would then be expanded to negative five
        return re.sub(rf"(\s)(-)(\d+)({self.punctuation})", r"\1negative \3\4", string)

    def replace_roman(self, string):
        """Replace runs of two or more Roman-numeral letters with their integer value."""
        # find a string of roman numerals.
        # Only 2 or more, to avoid capturing I and single character abbreviations, like names
        pattern = re.compile(rf"\s[IVXLCDM]{{2,}}{self.punctuation}")
        result = string
        while True:
            match = pattern.search(result)
            if match is None:
                break
            start = match.start()
            end = match.end()
            # The match is <whitespace><numeral><one punctuation char>; only the
            # numeral between the two delimiters is converted.
            result = (
                result[0: start + 1]
                + str(self.roman_to_int(result[start + 1: end - 1]))
                + result[end - 1: len(result)]
            )
        return result

    @staticmethod
    def roman_to_int(s):
        """Convert a string of Roman-numeral letters to an int (subtractive pairs supported)."""
        rom_val = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
        int_val = 0
        previous = None
        for symbol in s:
            value = rom_val[symbol]
            if previous is not None and value > previous:
                # The smaller previous symbol was already added once; undo that
                # and subtract it, hence "- 2 * previous".
                int_val += value - 2 * previous
            else:
                int_val += value
            previous = value
        return int_val

    @staticmethod
    def hyphen_range_to(text):
        """Turn numeric ranges such as ``3-5`` into ``3 to 5`` (hyphen or en dash)."""
        pattern = re.compile(r"(\d+)[-–](\d+)")
        result = pattern.sub(lambda x: x.group(1) + " to " + x.group(2), text)
        return result

    @staticmethod
    def num_to_words(text):
        """Spell out every integer or decimal number in ``text`` via ``num2words``."""
        # 1000 or 10.23
        pattern = re.compile(r"\d+\.\d+|\d+")
        result = pattern.sub(lambda x: num2words(float(x.group())), text)
        return result

    def replace_abbreviations(self, string):
        """Expand upper-case abbreviations of 1-4 letters into spelled-out letters."""
        # abbreviations 1-4 characters long. It will get things like A and I, but those are pronounced with their letter
        pattern = re.compile(rf"(^|[\s(.\'\[<])([A-Z]{{1,4}})({self.punctuation}|$)")
        result = string
        while True:
            match = pattern.search(result)
            if match is None:
                break
            start = match.start()
            end = match.end()
            result = result[0:start] + self.replace_abbreviation(result[start:end]) + result[end: len(result)]
        return result

    def replace_lowercase_abbreviations(self, string):
        """Expand dotted lower-case abbreviations (``i.e.``, ``e.g.``) into spelled-out letters."""
        # abbreviations 1 to 4 characters long, separated by dots i.e. e.g.
        pattern = re.compile(rf"(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({self.punctuation}|$)")
        result = string
        while True:
            match = pattern.search(result)
            if match is None:
                break
            start = match.start()
            end = match.end()
            # Upper-case the match so the letters hit alphabet_map.
            result = result[0:start] + self.replace_abbreviation(result[start:end].upper()) + result[end: len(result)]
        return result

    def replace_abbreviation(self, string):
        """Map every character of ``string`` through :meth:`match_mapping` and join the results."""
        return "".join(self.match_mapping(char) for char in string)

    def match_mapping(self, char):
        """Return the spelled-out form of ``char`` from ``alphabet_map``, or ``char`` itself."""
        return self.alphabet_map.get(char, char)

    def __main__(self, args):
        """Print the preprocessed form of ``args[1]``."""
        print(self.preprocess(args[1]))