Spaces:
Sleeping
Sleeping
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Credits | |
This code is modified from https://github.com/GitYCC/g2pW | |
""" | |
import os | |
import re | |
def wordize_and_map(text: str):
    """Split ``text`` into words and build character/word index maps.

    A "word" is either a maximal run of ASCII letters/digits or any other
    single character (CJK characters, punctuation, tabs, newlines, ...).
    Runs of ASCII space characters (``" "``) only separate words; they are
    not words themselves.

    Args:
        text: Input string.

    Returns:
        A tuple ``(words, text2word, word2text)``:
        - ``words``: list of word strings in order of appearance.
        - ``text2word``: one entry per character of ``text``; the index of
          the word covering that character, or ``None`` for a space.
        - ``word2text``: one ``(start, end)`` half-open character span per
          word.
    """
    words = []
    text2word = []
    word2text = []
    # Alternation order encodes the precedence: a run of spaces first, then
    # an ASCII alphanumeric run, otherwise exactly one arbitrary character.
    # DOTALL lets the last branch also consume newlines.
    for piece in re.finditer(r"( +)|([a-zA-Z0-9]+)|(.)", text, flags=re.DOTALL):
        start, end = piece.span()
        width = end - start
        if piece.group(1) is not None:
            text2word.extend([None] * width)
            continue
        word_idx = len(words)
        word2text.append((start, end))
        text2word.extend([word_idx] * width)
        words.append(piece.group(0))
    return words, text2word, word2text
def tokenize_and_map(tokenizer, text: str):
    """Tokenize ``text`` word by word and map characters to/from tokens.

    ``text`` is first split with ``wordize_and_map``; each word is then passed
    to ``tokenizer.tokenize``. A word that yields no tokens, or only
    ``["[UNK]"]``, becomes a single ``"[UNK]"`` token spanning the whole word.
    Otherwise each returned token is assigned a consecutive character span
    whose width is the token length with any leading ``"##"`` removed.

    Args:
        tokenizer: Object exposing ``tokenize(str) -> list[str]``
            (presumably a WordPiece/BERT tokenizer; verify against caller).
        text: Input string.

    Returns:
        A tuple ``(tokens, text2token, token2text)``:
        - ``tokens``: list of token strings.
        - ``text2token``: one entry per character; the covering token's index,
          or ``None`` for space characters.
        - ``token2text``: one ``(start, end)`` half-open span per token.
    """
    words, text2word, word2text = wordize_and_map(text=text)

    tokens = []
    token2text = []
    for word, (start, end) in zip(words, word2text):
        pieces = tokenizer.tokenize(word)
        if not pieces or pieces == ["[UNK]"]:
            # Untokenizable word: emit one [UNK] covering the whole word.
            token2text.append((start, end))
            tokens.append("[UNK]")
            continue
        cursor = start
        for piece in pieces:
            # Continuation pieces carry a "##" marker that is absent from text.
            width = len(re.sub(r"^##", "", piece))
            token2text.append((cursor, cursor + width))
            cursor += width
            tokens.append(piece)

    # The word map is reused in place: positions covered by a token span are
    # overwritten with that token's index; space positions remain None.
    text2token = text2word
    for token_idx, (start, end) in enumerate(token2text):
        for pos in range(start, end):
            text2token[pos] = token_idx
    return tokens, text2token, token2text
def _load_config(config_path: os.PathLike):
    """Execute a Python config file and return it as a module object.

    Every top-level name defined in the file becomes an attribute of the
    returned module.

    Args:
        config_path: Path to the config source file (e.g. ``config.py``).

    Returns:
        The executed module object.

    Raises:
        ImportError: If importlib cannot build a loader for ``config_path``
            (e.g. its suffix is not a recognised module suffix such as ``.py``).
        Any error raised while reading or executing the file propagates.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location("__init__", config_path)
    # spec_from_file_location returns None for unrecognised file suffixes;
    # report that clearly instead of failing with AttributeError on .loader.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load config from {config_path!r}: not a loadable Python module"
        )
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)
    return config
# Fallback configuration values used by ``load_config(..., use_default=True)``:
# each top-level key missing from a loaded config is added as an attribute,
# and for dict-valued entries any missing sub-keys are merged in.
# Keys must exactly match the attribute names used in config files, so they
# must not contain stray whitespace.
default_config_dict = {
    "manual_seed": 1313,
    "model_source": "bert-base-chinese",
    "window_size": 32,
    "num_workers": 2,
    "use_mask": True,
    "use_char_phoneme": False,
    "use_conditional": True,
    "param_conditional": {
        "affect_location": "softmax",
        "bias": True,
        "char-linear": True,
        "pos-linear": False,
        "char+pos-second": True,
        "char+pos-second_lowrank": False,
        "lowrank_size": 0,
        "char+pos-second_fm": False,
        "fm_size": 0,
        "fix_mode": None,
        "count_json": "train.count.json",
    },
    "lr": 5e-5,
    "val_interval": 200,
    "num_iter": 10000,
    "use_focal": False,
    "param_focal": {"alpha": 0.0, "gamma": 0.7},
    "use_pos": True,
    # Previously spelled "param_pos " (trailing space), which never matched a
    # config's ``param_pos`` attribute, so these defaults were never merged.
    "param_pos": {
        "weight": 0.1,
        "pos_joint_training": True,
        "train_pos_path": "train.pos",
        "valid_pos_path": "dev.pos",
        "test_pos_path": "test.pos",
    },
}
def load_config(config_path: os.PathLike, use_default: bool = False):
    """Load a Python config file, optionally filling in missing defaults.

    Args:
        config_path: Path to the config source file, loaded via
            ``_load_config``.
        use_default: If True, each top-level key of ``default_config_dict``
            that the config lacks is set as an attribute. For dict-valued
            defaults that the config does define, any missing sub-keys are
            added to the config's own dict in place.

    Returns:
        The loaded config module.
    """
    import copy

    config = _load_config(config_path)
    if use_default:
        for attr, val in default_config_dict.items():
            if not hasattr(config, attr):
                # Deep-copy so mutating the returned config cannot alter the
                # module-level defaults seen by later calls.
                setattr(config, attr, copy.deepcopy(val))
            elif isinstance(val, dict):
                current = getattr(config, attr)
                for dict_k, dict_v in val.items():
                    if dict_k not in current:
                        current[dict_k] = copy.deepcopy(dict_v)
    return config