import pickle
import os
import re
import wordsegment
from g2p_en import G2p

from text.symbols import punctuation
from text.symbols2 import symbols

import unicodedata
from builtins import str as unicode
from g2p_en.expand import normalize_numbers
from nltk.tokenize import TweetTokenizer

word_tokenize = TweetTokenizer().tokenize
from nltk import pos_tag

current_file_path = os.path.dirname(__file__)
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
CMU_DICT_FAST_PATH = os.path.join(current_file_path, "cmudict-fast.rep")
CMU_DICT_HOT_PATH = os.path.join(current_file_path, "engdict-hot.rep")
CACHE_PATH = os.path.join(current_file_path, "engdict_cache.pickle")
NAMECACHE_PATH = os.path.join(current_file_path, "namedict_cache.pickle")

# ARPABET phone inventory (with stress markers) recognized by this frontend
arpa = {
    "AH0", "S", "AH1", "EY2", "AE2", "EH0", "OW2", "UH0", "NG", "B", "G", "AY0",
    "M", "AA0", "F", "AO0", "ER2", "UH1", "IY1", "AH2", "DH", "IY0", "EY1", "IH0",
    "K", "N", "W", "IY2", "T", "AA1", "ER1", "EH2", "OY0", "UH2", "UW1", "Z",
    "AW2", "AW1", "V", "UW2", "AA2", "ER", "AW0", "UW0", "R", "OW1", "EH1", "ZH",
    "AE0", "IH2", "IH", "Y", "JH", "P", "AY1", "EY0", "OY2", "TH", "HH", "D",
    "ER0", "CH", "AO1", "AE1", "AO2", "OY1", "AY2", "IH1", "OW0", "L", "SH",
}

def replace_phs(phs):
    rep_map = {"'": "-"}
    phs_new = []
    for ph in phs:
        if ph in symbols:
            phs_new.append(ph)
        elif ph in rep_map:
            phs_new.append(rep_map[ph])
        else:
            print("ph not in symbols: ", ph)
    return phs_new
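
# Illustrative sketch (not in the original source): assuming the default
# GPT-SoVITS symbol set, where the ARPABET phones are present but "'" is not,
# replace_phs keeps known symbols, maps the apostrophe to "-", and only logs
# anything else, e.g.
#   replace_phs(["HH", "AH0", "'"])  ->  ["HH", "AH0", "-"]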

def replace_consecutive_punctuation(text):
    punctuations = ''.join(re.escape(p) for p in punctuation)
    pattern = f'([{punctuations}])([{punctuations}])+'
    result = re.sub(pattern, r'\1', text)
    return result
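
# Illustrative example (assuming "!" and "," are both in `punctuation`):
# a run of adjacent punctuation marks collapses to its first character, e.g.
#   replace_consecutive_punctuation("Wait!!, go.")  ->  "Wait! go."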

def read_dict():
    g2p_dict = {}
    start_line = 49
    with open(CMU_DICT_PATH) as f:
        line = f.readline()
        line_index = 1
        while line:
            if line_index >= start_line:
                line = line.strip()
                # CMU .rep lines use a two-space separator between word and phones
                word_split = line.split("  ")
                word = word_split[0].lower()

                syllable_split = word_split[1].split(" - ")
                g2p_dict[word] = []
                for syllable in syllable_split:
                    phone_split = syllable.split(" ")
                    g2p_dict[word].append(phone_split)

            line_index = line_index + 1
            line = f.readline()

    return g2p_dict

def read_dict_new():
    g2p_dict = {}
    with open(CMU_DICT_PATH) as f:
        line = f.readline()
        line_index = 1
        while line:
            if line_index >= 57:
                line = line.strip()
                # CMU .rep lines use a two-space separator between word and phones
                word_split = line.split("  ")
                word = word_split[0].lower()
                g2p_dict[word] = [word_split[1].split(" ")]

            line_index = line_index + 1
            line = f.readline()

    with open(CMU_DICT_FAST_PATH) as f:
        line = f.readline()
        line_index = 1
        while line:
            if line_index >= 0:
                line = line.strip()
                word_split = line.split(" ")
                word = word_split[0].lower()
                if word not in g2p_dict:
                    g2p_dict[word] = [word_split[1:]]

            line_index = line_index + 1
            line = f.readline()

    return g2p_dict
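
# Resulting entry shape (illustrative, assuming a standard CMU-style line
# such as "HELLO  HH AH0 L OW1"): each word maps to a list holding one
# phone list, e.g. g2p_dict["hello"] == [["HH", "AH0", "L", "OW1"]]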

def hot_reload_hot(g2p_dict):
    with open(CMU_DICT_HOT_PATH) as f:
        line = f.readline()
        line_index = 1
        while line:
            if line_index >= 0:
                line = line.strip()
                word_split = line.split(" ")
                word = word_split[0].lower()
                # custom pronunciations always override the existing entry
                g2p_dict[word] = [word_split[1:]]

            line_index = line_index + 1
            line = f.readline()

    return g2p_dict
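
# Implied hot-dict line format (illustrative): one word followed by
# space-separated ARPABET phones, e.g. the hypothetical line "gpt JH IY1 P IY1 T IY1"
# would yield g2p_dict["gpt"] == [["JH", "IY1", "P", "IY1", "T", "IY1"]]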

def cache_dict(g2p_dict, file_path):
    with open(file_path, "wb") as pickle_file:
        pickle.dump(g2p_dict, pickle_file)

def get_dict():
    if os.path.exists(CACHE_PATH):
        with open(CACHE_PATH, "rb") as pickle_file:
            g2p_dict = pickle.load(pickle_file)
    else:
        g2p_dict = read_dict_new()
        cache_dict(g2p_dict, CACHE_PATH)

    g2p_dict = hot_reload_hot(g2p_dict)
    return g2p_dict
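
# Note: the pickle cache is rebuilt only when engdict_cache.pickle is missing,
# so a stale cache must be deleted by hand; engdict-hot.rep, by contrast, is
# re-applied on every call and always wins over cached entries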

def get_namedict():
    if os.path.exists(NAMECACHE_PATH):
        with open(NAMECACHE_PATH, "rb") as pickle_file:
            name_dict = pickle.load(pickle_file)
    else:
        name_dict = {}

    return name_dict

def text_normalize(text):
    # todo: eng text normalize
    # map Chinese-style and g2p_en punctuation to ASCII equivalents
    rep_map = {
        "[;::,;]": ",",
        '["’]': "'",
        "。": ".",
        "!": "!",
        "?": "?",
    }
    for p, r in rep_map.items():
        text = re.sub(p, r, text)

    # text formatting taken from g2p_en, with added handling for uppercase input
    text = unicode(text)
    text = normalize_numbers(text)
    text = "".join(
        char for char in unicodedata.normalize("NFD", text)
        if unicodedata.category(char) != "Mn"
    )  # strip accents
    text = re.sub(r"[^ A-Za-z'.,?!\-]", "", text)
    text = re.sub(r"(?i)i\.e\.", "that is", text)
    text = re.sub(r"(?i)e\.g\.", "for example", text)

    # avoid reference-audio leakage caused by repeated punctuation
    text = replace_consecutive_punctuation(text)
    return text
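
# Illustrative walk-through (exact output depends on g2p_en's normalize_numbers):
#   text_normalize("café costs 3 dollars!!")
# strips the accent ("cafe"), expands the number ("three"), and collapses the
# repeated "!!", giving roughly "cafe costs three dollars!"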

class en_G2p(G2p):
    def __init__(self):
        super().__init__()

        # initialize the word segmenter
        wordsegment.load()

        # replace the outdated bundled dict and load the name dict
        self.cmu = get_dict()
        self.namedict = get_namedict()

        # drop a few abbreviations whose dictionary pronunciations are wrong
        for word in ["AE", "AI", "AR", "IOS", "HUD", "OS"]:
            del self.cmu[word.lower()]

        # fix homographs; each entry is (pron1, pron2, pos1)
        self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP")
        self.homograph2features["complex"] = (
            ["K", "AH0", "M", "P", "L", "EH1", "K", "S"],
            ["K", "AA1", "M", "P", "L", "EH0", "K", "S"],
            "JJ",
        )

    def __call__(self, text):
        # tokenization
        words = word_tokenize(text)
        tokens = pos_tag(words)  # tuples of (word, tag)

        # steps
        prons = []
        for o_word, pos in tokens:
            # reproduce g2p_en's lowercasing logic
            word = o_word.lower()

            if re.search("[a-z]", word) is None:
                pron = [word]
            # handle single letters first
            elif len(word) == 1:
                # fix the pronunciation of a standalone "A"; the original-case
                # o_word is needed to detect the capital letter
                if o_word == "A":
                    pron = ["EY1"]
                else:
                    pron = self.cmu[word][0]
            # g2p_en's original homograph handling
            elif word in self.homograph2features:  # check homograph
                pron1, pron2, pos1 = self.homograph2features[word]
                if pos.startswith(pos1):
                    pron = pron1
                # pos1 longer than pos only happens for "read"
                elif len(pos) < len(pos1) and pos == pos1[: len(pos)]:
                    pron = pron1
                else:
                    pron = pron2
            else:
                # recursive lookup / prediction
                pron = self.qryword(o_word)

            prons.extend(pron)
            prons.extend([" "])

        return prons[:-1]

    def qryword(self, o_word):
        word = o_word.lower()

        # dictionary lookup, except for single letters
        if len(word) > 1 and word in self.cmu:  # lookup CMU dict
            return self.cmu[word][0]

        # only title-cased words are looked up in the name dict
        if o_word.istitle() and word in self.namedict:
            return self.namedict[word][0]

        # OOV words of length <= 3 are spelled out letter by letter
        if len(word) <= 3:
            phones = []
            for w in word:
                # fix the pronunciation of a standalone "a"; no uppercase can occur here
                if w == "a":
                    phones.extend(["EY1"])
                else:
                    phones.extend(self.cmu[w][0])
            return phones

        # try splitting off a possessive 's
        if re.match(r"^([a-z]+)('s)$", word):
            phones = self.qryword(word[:-2])[:]
            # after the voiceless consonants P T K F TH HH, 's is pronounced ['S']
            if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
                phones.extend(["S"])
            # after the sibilants S Z SH ZH CH JH, 's is pronounced ['IH1', 'Z'] or ['AH0', 'Z']
            elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
                phones.extend(["AH0", "Z"])
            # after the voiced consonants B D G DH V M N NG L R W Y
            # and after any vowel, 's is pronounced ['Z']
            else:
                phones.extend(["Z"])
            return phones

        # try word segmentation to handle compounds
        comps = wordsegment.segment(word.lower())

        # if the word cannot be segmented, send it to the neural predictor
        if len(comps) == 1:
            return self.predict(word)

        # otherwise resolve each component recursively
        return [phone for comp in comps for phone in self.qryword(comp)]
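
# Fallback chain sketch (illustrative): an OOV compound such as "raspberrypi"
# would be split by wordsegment into ["raspberry", "pi"] and each part resolved
# recursively; a truly unsegmentable OOV falls through to g2p_en's neural
# self.predict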

_g2p = en_G2p()

def g2p(text):
    # run g2p_en over the whole passage, then strip tokens that are not valid ARPA output
    phone_list = _g2p(text)
    phones = [ph if ph != "<unk>" else "UNK" for ph in phone_list if ph not in [" ", "<pad>", "UW", "</s>", "<s>"]]

    return replace_phs(phones)
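
# Illustrative usage (exact phones depend on the bundled dictionaries):
#   g2p("hello")  ->  ["HH", "AH0", "L", "OW1"]
# word separators (" ") and g2p_en's special tokens are filtered out before
# replace_phs runs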

if __name__ == "__main__":
    print(g2p("hello"))
    print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))
    print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")))