Spaces:
Running
Running
Upload 80 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- api.py +103 -0
- audios/1.wav +0 -0
- audios/2.wav +0 -0
- audios/3.wav +0 -0
- audios/4.wav +3 -0
- audios/5.wav +0 -0
- audios/6.wav +0 -0
- audios/7.wav +0 -0
- audios/8.wav +0 -0
- checkpoints/.keep +0 -0
- checkpoints/checkpoint_0.pt +3 -0
- config.py +50 -0
- datas/__init__.py +0 -0
- datas/dataset.py +69 -0
- datas/sampler.py +131 -0
- models/__init__.py +0 -0
- models/diffusion_transformer.py +205 -0
- models/duration_predictor.py +40 -0
- models/estimator.py +138 -0
- models/flow_matching.py +100 -0
- models/model.py +178 -0
- models/reference_encoder.py +168 -0
- models/text_encoder.py +44 -0
- monotonic_align/__init__.py +16 -0
- monotonic_align/core.py +46 -0
- requirements.txt +33 -0
- text/LICENSE +19 -0
- text/__init__.py +71 -0
- text/cleaners.py +58 -0
- text/cn2an/__init__.py +16 -0
- text/cn2an/an2cn.py +204 -0
- text/cn2an/cn2an.py +294 -0
- text/cn2an/conf.py +135 -0
- text/cn2an/transform.py +104 -0
- text/cnm3/ds_CNM3.txt +606 -0
- text/custom_pypinyin_dict/__init__.py +1 -0
- text/custom_pypinyin_dict/cc_cedict_0.py +0 -0
- text/custom_pypinyin_dict/cc_cedict_1.py +0 -0
- text/custom_pypinyin_dict/cc_cedict_2.py +0 -0
- text/custom_pypinyin_dict/cc_cedict_3.py +14 -0
- text/custom_pypinyin_dict/genshin.py +11 -0
- text/custom_pypinyin_dict/phrase_pinyin_data.py +24 -0
- text/english.py +175 -0
- text/japanese.py +157 -0
- text/mandarin.py +173 -0
- text/symbols.py +79 -0
- utils/__init__.py +0 -0
- utils/audio.py +74 -0
- utils/load.py +43 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
audios/4.wav filter=lfs diff=lfs merge=lfs -text
|
api.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from dataclasses import asdict
|
5 |
+
|
6 |
+
from utils.audio import LogMelSpectrogram
|
7 |
+
from config import ModelConfig, MelConfig
|
8 |
+
from models.model import StableTTS
|
9 |
+
|
10 |
+
from text import symbols
|
11 |
+
from text import cleaned_text_to_sequence
|
12 |
+
from text.mandarin import chinese_to_cnm3
|
13 |
+
from text.english import english_to_ipa2
|
14 |
+
from text.japanese import japanese_to_ipa2
|
15 |
+
|
16 |
+
|
17 |
+
from datas.dataset import intersperse
|
18 |
+
from utils.audio import load_and_resample_audio
|
19 |
+
|
20 |
+
def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
|
21 |
+
if model_name == 'ffgan':
|
22 |
+
# training or changing ffgan config is not supported in this repo
|
23 |
+
# you can train your own model at https://github.com/fishaudio/vocoder
|
24 |
+
from vocoders.ffgan.model import FireflyGANBaseWrapper
|
25 |
+
vocoder = FireflyGANBaseWrapper(model_path)
|
26 |
+
|
27 |
+
elif model_name == 'vocos':
|
28 |
+
from vocoders.vocos.models.model import Vocos
|
29 |
+
from config import VocosConfig, MelConfig
|
30 |
+
vocoder = Vocos(VocosConfig(), MelConfig())
|
31 |
+
vocoder.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu'))
|
32 |
+
vocoder.eval()
|
33 |
+
|
34 |
+
else:
|
35 |
+
raise NotImplementedError(f"Unsupported model: {model_name}")
|
36 |
+
|
37 |
+
return vocoder
|
38 |
+
|
39 |
+
class StableTTSAPI(nn.Module):
|
40 |
+
def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.mel_config = MelConfig()
|
44 |
+
self.tts_model_config = ModelConfig()
|
45 |
+
|
46 |
+
self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))
|
47 |
+
|
48 |
+
# text to mel spectrogram
|
49 |
+
self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
|
50 |
+
self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
|
51 |
+
self.tts_model.eval()
|
52 |
+
|
53 |
+
# mel spectrogram to waveform
|
54 |
+
self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
|
55 |
+
self.vocoder_model.eval()
|
56 |
+
|
57 |
+
self.g2p_mapping = {
|
58 |
+
'chinese': chinese_to_cnm3,
|
59 |
+
'japanese': japanese_to_ipa2,
|
60 |
+
'english': english_to_ipa2,
|
61 |
+
}
|
62 |
+
self.supported_languages = self.g2p_mapping.keys()
|
63 |
+
|
64 |
+
@ torch.inference_mode()
|
65 |
+
def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
|
66 |
+
device = next(self.parameters()).device
|
67 |
+
phonemizer = self.g2p_mapping.get(language)
|
68 |
+
|
69 |
+
text = phonemizer(text)
|
70 |
+
text = torch.tensor(intersperse(cleaned_text_to_sequence(text), item=0), dtype=torch.long, device=device).unsqueeze(0)
|
71 |
+
text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)
|
72 |
+
|
73 |
+
ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
|
74 |
+
ref_audio = self.mel_extractor(ref_audio)
|
75 |
+
|
76 |
+
mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
|
77 |
+
audio_output = self.vocoder_model(mel_output)
|
78 |
+
return audio_output.cpu(), mel_output.cpu()
|
79 |
+
|
80 |
+
def get_params(self):
|
81 |
+
tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
|
82 |
+
vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
|
83 |
+
return tts_param, vocoder_param
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
87 |
+
tts_model_path = './checkpoints/checkpoint_0.pt'
|
88 |
+
vocoder_model_path = './vocoders/pretrained/vocos.pt'
|
89 |
+
|
90 |
+
model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
|
91 |
+
model.to(device)
|
92 |
+
|
93 |
+
text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
|
94 |
+
audio = './audio_1.wav'
|
95 |
+
|
96 |
+
audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3)
|
97 |
+
print(audio_output.shape)
|
98 |
+
print(mel_output.shape)
|
99 |
+
|
100 |
+
import torchaudio
|
101 |
+
torchaudio.save('output.wav', audio_output, MelConfig().sample_rate)
|
102 |
+
|
103 |
+
|
audios/1.wav
ADDED
Binary file (374 kB). View file
|
|
audios/2.wav
ADDED
Binary file (182 kB). View file
|
|
audios/3.wav
ADDED
Binary file (529 kB). View file
|
|
audios/4.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6672b81d7dd41cac56cf49b75bb66a5486b5fe969ddab0f98f14b05be7857df
|
3 |
+
size 1349150
|
audios/5.wav
ADDED
Binary file (368 kB). View file
|
|
audios/6.wav
ADDED
Binary file (431 kB). View file
|
|
audios/7.wav
ADDED
Binary file (514 kB). View file
|
|
audios/8.wav
ADDED
Binary file (420 kB). View file
|
|
checkpoints/.keep
ADDED
File without changes
|
checkpoints/checkpoint_0.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b041bea13241b402bbfcdbfffd14381774be1179bae78e99ebd505d6d89f9367
|
3 |
+
size 126657600
|
config.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
@dataclass
|
4 |
+
class MelConfig:
|
5 |
+
sample_rate: int = 44100
|
6 |
+
n_fft: int = 2048
|
7 |
+
win_length: int = 2048
|
8 |
+
hop_length: int = 512
|
9 |
+
f_min: float = 0.0
|
10 |
+
f_max: float = None
|
11 |
+
pad: int = 0
|
12 |
+
n_mels: int = 128
|
13 |
+
center: bool = False
|
14 |
+
pad_mode: str = "reflect"
|
15 |
+
mel_scale: str = "slaney"
|
16 |
+
|
17 |
+
def __post_init__(self):
|
18 |
+
if self.pad == 0:
|
19 |
+
self.pad = (self.n_fft - self.hop_length) // 2
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class ModelConfig:
|
23 |
+
hidden_channels: int = 256
|
24 |
+
filter_channels: int = 1024
|
25 |
+
n_heads: int = 4
|
26 |
+
n_enc_layers: int = 3
|
27 |
+
n_dec_layers: int = 6
|
28 |
+
kernel_size: int = 3
|
29 |
+
p_dropout: int = 0.1
|
30 |
+
gin_channels: int = 256
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class TrainConfig:
|
34 |
+
train_dataset_path: str = 'filelists/filelist.json'
|
35 |
+
test_dataset_path: str = 'filelists/filelist.json' # not used
|
36 |
+
batch_size: int = 32
|
37 |
+
learning_rate: float = 1e-4
|
38 |
+
num_epochs: int = 10000
|
39 |
+
model_save_path: str = './checkpoints'
|
40 |
+
log_dir: str = './runs'
|
41 |
+
log_interval: int = 16
|
42 |
+
save_interval: int = 1
|
43 |
+
warmup_steps: int = 200
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class VocosConfig:
|
47 |
+
input_channels: int = 128
|
48 |
+
dim: int = 512
|
49 |
+
intermediate_dim: int = 1536
|
50 |
+
num_layers: int = 8
|
datas/__init__.py
ADDED
File without changes
|
datas/dataset.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
|
8 |
+
from text import cleaned_text_to_sequence
|
9 |
+
|
10 |
+
def intersperse(lst: list, item: int):
|
11 |
+
"""
|
12 |
+
putting a blank token between any two input tokens to improve pronunciation
|
13 |
+
see https://github.com/jaywalnut310/glow-tts/issues/43 for more details
|
14 |
+
"""
|
15 |
+
result = [item] * (len(lst) * 2 + 1)
|
16 |
+
result[1::2] = lst
|
17 |
+
return result
|
18 |
+
|
19 |
+
class StableDataset(Dataset):
|
20 |
+
def __init__(self, filelist_path, hop_length):
|
21 |
+
self.filelist_path = filelist_path
|
22 |
+
self.hop_length = hop_length
|
23 |
+
|
24 |
+
self._load_filelist(filelist_path)
|
25 |
+
|
26 |
+
def _load_filelist(self, filelist_path):
|
27 |
+
filelist, lengths = [], []
|
28 |
+
with open(filelist_path, 'r', encoding='utf-8') as f:
|
29 |
+
for line in f:
|
30 |
+
line = json.loads(line.strip())
|
31 |
+
filelist.append((line['mel_path'], line['phone']))
|
32 |
+
lengths.append(line['mel_length'])
|
33 |
+
|
34 |
+
self.filelist = filelist
|
35 |
+
self.lengths = lengths # length is used for DistributedBucketSampler
|
36 |
+
|
37 |
+
def __len__(self):
|
38 |
+
return len(self.filelist)
|
39 |
+
|
40 |
+
def __getitem__(self, idx):
|
41 |
+
mel_path, phone = self.filelist[idx]
|
42 |
+
mel = torch.load(mel_path, map_location='cpu', weights_only=True)
|
43 |
+
phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
|
44 |
+
return mel, phone
|
45 |
+
|
46 |
+
def collate_fn(batch):
|
47 |
+
texts = [item[1] for item in batch]
|
48 |
+
mels = [item[0] for item in batch]
|
49 |
+
mels_sliced = [random_slice_tensor(mel) for mel in mels]
|
50 |
+
|
51 |
+
text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
|
52 |
+
mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)
|
53 |
+
mels_sliced_lengths = torch.tensor([mel_sliced.size(-1) for mel_sliced in mels_sliced], dtype=torch.long)
|
54 |
+
|
55 |
+
# pad to the same length
|
56 |
+
texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
|
57 |
+
mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)
|
58 |
+
mels_sliced_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels_sliced), padding=0)
|
59 |
+
|
60 |
+
return texts_padded, text_lengths, mels_padded, mel_lengths, mels_sliced_padded, mels_sliced_lengths
|
61 |
+
|
62 |
+
# random slice mel for reference encoder to prevent overfitting
|
63 |
+
def random_slice_tensor(x: torch.Tensor):
|
64 |
+
length = x.size(-1)
|
65 |
+
if length < 8:
|
66 |
+
return x
|
67 |
+
segmnt_size = random.randint(length // 12, length // 3)
|
68 |
+
start = random.randint(0, length - segmnt_size)
|
69 |
+
return x[..., start : start + segmnt_size]
|
datas/sampler.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
|
4 |
+
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
5 |
+
"""
|
6 |
+
Maintain similar input lengths in a batch.
|
7 |
+
Length groups are specified by boundaries.
|
8 |
+
Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
|
9 |
+
|
10 |
+
It removes samples which are not included in the boundaries.
|
11 |
+
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
dataset,
|
17 |
+
batch_size,
|
18 |
+
boundaries,
|
19 |
+
num_replicas=None,
|
20 |
+
rank=None,
|
21 |
+
shuffle=True,
|
22 |
+
):
|
23 |
+
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
24 |
+
self.lengths = dataset.lengths
|
25 |
+
self.batch_size = batch_size
|
26 |
+
self.boundaries = boundaries
|
27 |
+
|
28 |
+
self.buckets, self.num_samples_per_bucket = self._create_buckets()
|
29 |
+
self.total_size = sum(self.num_samples_per_bucket)
|
30 |
+
self.num_samples = self.total_size // self.num_replicas
|
31 |
+
|
32 |
+
def _create_buckets(self):
|
33 |
+
buckets = [[] for _ in range(len(self.boundaries) - 1)]
|
34 |
+
for i in range(len(self.lengths)):
|
35 |
+
length = self.lengths[i]
|
36 |
+
idx_bucket = self._bisect(length)
|
37 |
+
if idx_bucket != -1:
|
38 |
+
buckets[idx_bucket].append(i)
|
39 |
+
|
40 |
+
# from https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/data_utils.py
|
41 |
+
# avoid "integer division or modulo by zero" error for very small dataset
|
42 |
+
try:
|
43 |
+
for i in range(len(buckets) - 1, 0, -1):
|
44 |
+
if len(buckets[i]) == 0:
|
45 |
+
buckets.pop(i)
|
46 |
+
self.boundaries.pop(i + 1)
|
47 |
+
assert all(len(bucket) > 0 for bucket in buckets)
|
48 |
+
# When one bucket is not traversed
|
49 |
+
except Exception as e:
|
50 |
+
print('Bucket warning ', e)
|
51 |
+
for i in range(len(buckets) - 1, -1, -1):
|
52 |
+
if len(buckets[i]) == 0:
|
53 |
+
buckets.pop(i)
|
54 |
+
self.boundaries.pop(i + 1)
|
55 |
+
|
56 |
+
num_samples_per_bucket = []
|
57 |
+
for i in range(len(buckets)):
|
58 |
+
len_bucket = len(buckets[i])
|
59 |
+
total_batch_size = self.num_replicas * self.batch_size
|
60 |
+
rem = (
|
61 |
+
total_batch_size - (len_bucket % total_batch_size)
|
62 |
+
) % total_batch_size
|
63 |
+
num_samples_per_bucket.append(len_bucket + rem)
|
64 |
+
return buckets, num_samples_per_bucket
|
65 |
+
|
66 |
+
def __iter__(self):
|
67 |
+
# deterministically shuffle based on epoch
|
68 |
+
g = torch.Generator()
|
69 |
+
g.manual_seed(self.epoch)
|
70 |
+
|
71 |
+
indices = []
|
72 |
+
if self.shuffle:
|
73 |
+
for bucket in self.buckets:
|
74 |
+
indices.append(torch.randperm(len(bucket), generator=g).tolist())
|
75 |
+
else:
|
76 |
+
for bucket in self.buckets:
|
77 |
+
indices.append(list(range(len(bucket))))
|
78 |
+
|
79 |
+
batches = []
|
80 |
+
for i in range(len(self.buckets)):
|
81 |
+
bucket = self.buckets[i]
|
82 |
+
len_bucket = len(bucket)
|
83 |
+
ids_bucket = indices[i]
|
84 |
+
num_samples_bucket = self.num_samples_per_bucket[i]
|
85 |
+
|
86 |
+
# add extra samples to make it evenly divisible
|
87 |
+
rem = num_samples_bucket - len_bucket
|
88 |
+
ids_bucket = (
|
89 |
+
ids_bucket
|
90 |
+
+ ids_bucket * (rem // len_bucket)
|
91 |
+
+ ids_bucket[: (rem % len_bucket)]
|
92 |
+
)
|
93 |
+
|
94 |
+
# subsample
|
95 |
+
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
|
96 |
+
|
97 |
+
# batching
|
98 |
+
for j in range(len(ids_bucket) // self.batch_size):
|
99 |
+
batch = [
|
100 |
+
bucket[idx]
|
101 |
+
for idx in ids_bucket[
|
102 |
+
j * self.batch_size : (j + 1) * self.batch_size
|
103 |
+
]
|
104 |
+
]
|
105 |
+
batches.append(batch)
|
106 |
+
|
107 |
+
if self.shuffle:
|
108 |
+
batch_ids = torch.randperm(len(batches), generator=g).tolist()
|
109 |
+
batches = [batches[i] for i in batch_ids]
|
110 |
+
self.batches = batches
|
111 |
+
|
112 |
+
assert len(self.batches) * self.batch_size == self.num_samples
|
113 |
+
return iter(self.batches)
|
114 |
+
|
115 |
+
def _bisect(self, x, lo=0, hi=None):
|
116 |
+
if hi is None:
|
117 |
+
hi = len(self.boundaries) - 1
|
118 |
+
|
119 |
+
if hi > lo:
|
120 |
+
mid = (hi + lo) // 2
|
121 |
+
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
|
122 |
+
return mid
|
123 |
+
elif x <= self.boundaries[mid]:
|
124 |
+
return self._bisect(x, lo, mid)
|
125 |
+
else:
|
126 |
+
return self._bisect(x, mid + 1, hi)
|
127 |
+
else:
|
128 |
+
return -1
|
129 |
+
|
130 |
+
def __len__(self):
|
131 |
+
return self.num_samples // self.batch_size
|
models/__init__.py
ADDED
File without changes
|
models/diffusion_transformer.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# References:
|
2 |
+
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
|
3 |
+
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
|
4 |
+
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
class FFN(nn.Module):
|
11 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
|
12 |
+
super().__init__()
|
13 |
+
self.in_channels = in_channels
|
14 |
+
self.out_channels = out_channels
|
15 |
+
self.filter_channels = filter_channels
|
16 |
+
self.kernel_size = kernel_size
|
17 |
+
self.p_dropout = p_dropout
|
18 |
+
self.gin_channels = gin_channels
|
19 |
+
|
20 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
21 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
22 |
+
self.drop = nn.Dropout(p_dropout)
|
23 |
+
self.act1 = nn.SiLU(inplace=True)
|
24 |
+
|
25 |
+
def forward(self, x, x_mask):
|
26 |
+
x = self.conv_1(x * x_mask)
|
27 |
+
x = self.act1(x)
|
28 |
+
x = self.drop(x)
|
29 |
+
x = self.conv_2(x * x_mask)
|
30 |
+
return x * x_mask
|
31 |
+
|
32 |
+
class MultiHeadAttention(nn.Module):
|
33 |
+
def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
|
34 |
+
super().__init__()
|
35 |
+
assert channels % n_heads == 0
|
36 |
+
|
37 |
+
self.channels = channels
|
38 |
+
self.out_channels = out_channels
|
39 |
+
self.n_heads = n_heads
|
40 |
+
self.p_dropout = p_dropout
|
41 |
+
|
42 |
+
self.k_channels = channels // n_heads
|
43 |
+
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
44 |
+
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
45 |
+
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
46 |
+
|
47 |
+
# from https://nn.labml.ai/transformers/rope/index.html
|
48 |
+
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
|
49 |
+
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
|
50 |
+
|
51 |
+
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
52 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
53 |
+
|
54 |
+
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
55 |
+
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
56 |
+
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
57 |
+
|
58 |
+
def forward(self, x, attn_mask=None):
|
59 |
+
q = self.conv_q(x)
|
60 |
+
k = self.conv_k(x)
|
61 |
+
v = self.conv_v(x)
|
62 |
+
|
63 |
+
x = self.attention(q, k, v, mask=attn_mask)
|
64 |
+
|
65 |
+
x = self.conv_o(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
def attention(self, query, key, value, mask=None):
|
69 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
70 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
71 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
72 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
73 |
+
|
74 |
+
query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
|
75 |
+
key = self.key_rotary_pe(key)
|
76 |
+
|
77 |
+
output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
|
78 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
79 |
+
return output
|
80 |
+
|
81 |
+
# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
|
82 |
+
class DiTConVBlock(nn.Module):
|
83 |
+
"""
|
84 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
85 |
+
"""
|
86 |
+
def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
|
87 |
+
super().__init__()
|
88 |
+
self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False)
|
89 |
+
self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
|
90 |
+
self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False)
|
91 |
+
self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
|
92 |
+
self.adaLN_modulation = nn.Sequential(
|
93 |
+
nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
|
94 |
+
nn.SiLU(),
|
95 |
+
nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
|
96 |
+
)
|
97 |
+
|
98 |
+
def forward(self, x, c, x_mask):
|
99 |
+
"""
|
100 |
+
Args:
|
101 |
+
x : [batch_size, channel, time]
|
102 |
+
c : [batch_size, channel]
|
103 |
+
x_mask : [batch_size, 1, time]
|
104 |
+
return the same shape as x
|
105 |
+
"""
|
106 |
+
x = x * x_mask
|
107 |
+
attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time]
|
108 |
+
attn_mask = torch.zeros_like(attn_mask).masked_fill(attn_mask == 0, -torch.finfo(x.dtype).max)
|
109 |
+
|
110 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1]
|
111 |
+
x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
|
112 |
+
x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)
|
113 |
+
|
114 |
+
# no condition version
|
115 |
+
# x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
|
116 |
+
# x = x + self.mlp(self.norm2(x.transpose(1,2)).transpose(1,2), x_mask)
|
117 |
+
return x
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def modulate(x, shift, scale):
|
121 |
+
return x * (1 + scale) + shift
|
122 |
+
|
123 |
+
class RotaryPositionalEmbeddings(nn.Module):
|
124 |
+
"""
|
125 |
+
## RoPE module
|
126 |
+
|
127 |
+
Rotary encoding transforms pairs of features by rotating in the 2D plane.
|
128 |
+
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
|
129 |
+
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
|
130 |
+
by an angle depending on the position of the token.
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self, d: int, base: int = 10_000):
|
134 |
+
r"""
|
135 |
+
* `d` is the number of features $d$
|
136 |
+
* `base` is the constant used for calculating $\Theta$
|
137 |
+
"""
|
138 |
+
super().__init__()
|
139 |
+
|
140 |
+
self.base = base
|
141 |
+
self.d = int(d)
|
142 |
+
self.cos_cached = None
|
143 |
+
self.sin_cached = None
|
144 |
+
|
145 |
+
def _build_cache(self, x: torch.Tensor):
|
146 |
+
r"""
|
147 |
+
Cache $\cos$ and $\sin$ values
|
148 |
+
"""
|
149 |
+
# Return if cache is already built
|
150 |
+
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
|
151 |
+
return
|
152 |
+
|
153 |
+
# Get sequence length
|
154 |
+
seq_len = x.shape[0]
|
155 |
+
|
156 |
+
# $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
157 |
+
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
|
158 |
+
|
159 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
160 |
+
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
|
161 |
+
|
162 |
+
# Calculate the product of position index and $\theta_i$
|
163 |
+
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
|
164 |
+
|
165 |
+
# Concatenate so that for row $m$ we have
|
166 |
+
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
|
167 |
+
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
|
168 |
+
|
169 |
+
# Cache them
|
170 |
+
self.cos_cached = idx_theta2.cos()[:, None, None, :]
|
171 |
+
self.sin_cached = idx_theta2.sin()[:, None, None, :]
|
172 |
+
|
173 |
+
def _neg_half(self, x: torch.Tensor):
|
174 |
+
# $\frac{d}{2}$
|
175 |
+
d_2 = self.d // 2
|
176 |
+
|
177 |
+
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
178 |
+
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
|
179 |
+
|
180 |
+
def forward(self, x: torch.Tensor):
|
181 |
+
"""
|
182 |
+
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
|
183 |
+
"""
|
184 |
+
# Cache $\cos$ and $\sin$ values
|
185 |
+
x = x.permute(2, 0, 1, 3) # b h t d -> t b h d
|
186 |
+
|
187 |
+
self._build_cache(x)
|
188 |
+
|
189 |
+
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
|
190 |
+
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
|
191 |
+
|
192 |
+
# Calculate
|
193 |
+
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
194 |
+
neg_half_x = self._neg_half(x_rope)
|
195 |
+
|
196 |
+
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
|
197 |
+
|
198 |
+
return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d
|
199 |
+
|
200 |
+
class Transpose(nn.Identity):
|
201 |
+
"""(N, T, D) -> (N, D, T)"""
|
202 |
+
|
203 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
204 |
+
return input.transpose(1, 2)
|
205 |
+
|
models/duration_predictor.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
# modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
|
5 |
+
class DurationPredictor(nn.Module):
|
6 |
+
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
self.in_channels = in_channels
|
10 |
+
self.filter_channels = filter_channels
|
11 |
+
self.kernel_size = kernel_size
|
12 |
+
self.p_dropout = p_dropout
|
13 |
+
self.gin_channels = gin_channels
|
14 |
+
|
15 |
+
self.drop = nn.Dropout(p_dropout)
|
16 |
+
self.conv1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
17 |
+
self.norm1 = nn.LayerNorm(filter_channels)
|
18 |
+
self.conv2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
19 |
+
self.norm2 = nn.LayerNorm(filter_channels)
|
20 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
21 |
+
|
22 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
23 |
+
|
24 |
+
def forward(self, x, x_mask, g):
|
25 |
+
x = x.detach()
|
26 |
+
x = x + self.cond(g.unsqueeze(2).detach())
|
27 |
+
x = self.conv1(x * x_mask)
|
28 |
+
x = torch.relu(x)
|
29 |
+
x = self.norm1(x.transpose(1,2)).transpose(1,2)
|
30 |
+
x = self.drop(x)
|
31 |
+
x = self.conv2(x * x_mask)
|
32 |
+
x = torch.relu(x)
|
33 |
+
x = self.norm2(x.transpose(1,2)).transpose(1,2)
|
34 |
+
x = self.drop(x)
|
35 |
+
x = self.proj(x * x_mask)
|
36 |
+
return x * x_mask
|
37 |
+
|
38 |
+
def duration_loss(logw, logw_, lengths):
|
39 |
+
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
|
40 |
+
return loss
|
models/estimator.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from models.diffusion_transformer import DiTConVBlock
|
7 |
+
|
8 |
+
class DitWrapper(nn.Module):
|
9 |
+
""" add FiLM layer to condition time embedding to DiT """
|
10 |
+
def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
|
11 |
+
super().__init__()
|
12 |
+
self.time_fusion = FiLMLayer(hidden_channels, time_channels)
|
13 |
+
self.block = DiTConVBlock(hidden_channels, filter_channels, num_heads, kernel_size, p_dropout, gin_channels)
|
14 |
+
|
15 |
+
def forward(self, x, c, t, x_mask):
|
16 |
+
x = self.time_fusion(x, t) * x_mask
|
17 |
+
x = self.block(x, c, x_mask)
|
18 |
+
return x
|
19 |
+
|
20 |
+
class FiLMLayer(nn.Module):
|
21 |
+
"""
|
22 |
+
Feature-wise Linear Modulation (FiLM) layer
|
23 |
+
Reference: https://arxiv.org/abs/1709.07871
|
24 |
+
"""
|
25 |
+
def __init__(self, in_channels, cond_channels):
|
26 |
+
|
27 |
+
super(FiLMLayer, self).__init__()
|
28 |
+
self.in_channels = in_channels
|
29 |
+
self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)
|
30 |
+
|
31 |
+
def forward(self, x, c):
|
32 |
+
gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
|
33 |
+
return gamma * x + beta
|
34 |
+
|
35 |
+
class SinusoidalPosEmb(nn.Module):
|
36 |
+
def __init__(self, dim):
|
37 |
+
super().__init__()
|
38 |
+
self.dim = dim
|
39 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
40 |
+
|
41 |
+
def forward(self, x, scale=1000):
|
42 |
+
if x.ndim < 1:
|
43 |
+
x = x.unsqueeze(0)
|
44 |
+
half_dim = self.dim // 2
|
45 |
+
emb = math.log(10000) / (half_dim - 1)
|
46 |
+
emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
|
47 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
48 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
49 |
+
return emb
|
50 |
+
|
51 |
+
class TimestepEmbedding(nn.Module):
|
52 |
+
def __init__(self, in_channels, out_channels, filter_channels):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
self.layer = nn.Sequential(
|
56 |
+
nn.Linear(in_channels, filter_channels),
|
57 |
+
nn.SiLU(inplace=True),
|
58 |
+
nn.Linear(filter_channels, out_channels)
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
return self.layer(x)
|
63 |
+
|
64 |
+
# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
|
65 |
+
class Decoder(nn.Module):
|
66 |
+
def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, dropout=0.1, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0, use_lsc=True):
|
67 |
+
super().__init__()
|
68 |
+
self.noise_channels = noise_channels
|
69 |
+
self.cond_channels = cond_channels
|
70 |
+
self.hidden_channels = hidden_channels
|
71 |
+
self.out_channels = out_channels
|
72 |
+
self.filter_channels = filter_channels
|
73 |
+
self.use_lsc = use_lsc # whether to use unet-like long skip connection
|
74 |
+
|
75 |
+
self.time_embeddings = SinusoidalPosEmb(hidden_channels)
|
76 |
+
self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)
|
77 |
+
|
78 |
+
self.in_proj = nn.Conv1d(hidden_channels + noise_channels, hidden_channels, 1) # cat noise and encoder output as input
|
79 |
+
self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
|
80 |
+
self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
81 |
+
|
82 |
+
# prenet for encoder output
|
83 |
+
self.cond_proj = nn.Sequential(
|
84 |
+
nn.Conv1d(cond_channels, filter_channels, kernel_size, padding=kernel_size//2),
|
85 |
+
nn.SiLU(inplace=True),
|
86 |
+
nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2), # add about 3M params
|
87 |
+
nn.SiLU(inplace=True),
|
88 |
+
nn.Conv1d(filter_channels, hidden_channels, kernel_size, padding=kernel_size//2)
|
89 |
+
)
|
90 |
+
|
91 |
+
if use_lsc:
|
92 |
+
assert n_layers % 2 == 0
|
93 |
+
self.n_lsc_layers = n_layers // 2
|
94 |
+
self.lsc_layers = nn.ModuleList([nn.Conv1d(hidden_channels + hidden_channels, hidden_channels, kernel_size, padding = kernel_size // 2) for _ in range(self.n_lsc_layers)])
|
95 |
+
|
96 |
+
self.initialize_weights()
|
97 |
+
|
98 |
+
def initialize_weights(self):
|
99 |
+
for block in self.blocks:
|
100 |
+
nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
|
101 |
+
nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)
|
102 |
+
|
103 |
+
def forward(self, t, x, mask, mu, c):
|
104 |
+
"""Forward pass of the DiT model.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
t (torch.Tensor): timestep, shape (batch_size)
|
108 |
+
x (torch.Tensor): noise, shape (batch_size, in_channels, time)
|
109 |
+
mask (torch.Tensor): shape (batch_size, 1, time)
|
110 |
+
mu (torch.Tensor): output of encoder, shape (batch_size, in_channels, time)
|
111 |
+
c (torch.Tensor): shape (batch_size, gin_channels)
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
_type_: _description_
|
115 |
+
"""
|
116 |
+
|
117 |
+
t = self.time_mlp(self.time_embeddings(t))
|
118 |
+
mu = self.cond_proj(mu)
|
119 |
+
|
120 |
+
x = torch.cat((x, mu), dim=1)
|
121 |
+
x = self.in_proj(x)
|
122 |
+
|
123 |
+
lsc_outputs = [] if self.use_lsc else None
|
124 |
+
|
125 |
+
for idx, block in enumerate(self.blocks):
|
126 |
+
# add long skip connection, see https://arxiv.org/pdf/2209.12152 for more details
|
127 |
+
if self.use_lsc:
|
128 |
+
if idx < self.n_lsc_layers:
|
129 |
+
lsc_outputs.append(x)
|
130 |
+
else:
|
131 |
+
x = torch.cat((x, lsc_outputs.pop()), dim=1)
|
132 |
+
x = self.lsc_layers[idx - self.n_lsc_layers](x)
|
133 |
+
|
134 |
+
x = block(x, c, t, mask)
|
135 |
+
|
136 |
+
output = self.final_proj(x * mask)
|
137 |
+
|
138 |
+
return output * mask
|
models/flow_matching.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import functools
|
6 |
+
from torchdiffeq import odeint
|
7 |
+
|
8 |
+
from models.estimator import Decoder
|
9 |
+
|
10 |
+
# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
|
11 |
+
class CFMDecoder(torch.nn.Module):
|
12 |
+
def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
|
13 |
+
super().__init__()
|
14 |
+
self.noise_channels = noise_channels
|
15 |
+
self.cond_channels = cond_channels
|
16 |
+
self.hidden_channels = hidden_channels
|
17 |
+
self.out_channels = out_channels
|
18 |
+
self.filter_channels = filter_channels
|
19 |
+
self.gin_channels = gin_channels
|
20 |
+
self.sigma_min = 1e-4
|
21 |
+
|
22 |
+
self.estimator = Decoder(noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)
|
23 |
+
|
24 |
+
@torch.inference_mode()
|
25 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None, solver=None, cfg_kwargs=None):
|
26 |
+
"""Forward diffusion
|
27 |
+
|
28 |
+
Args:
|
29 |
+
mu (torch.Tensor): output of encoder
|
30 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
31 |
+
mask (torch.Tensor): output_mask
|
32 |
+
shape: (batch_size, 1, mel_timesteps)
|
33 |
+
n_timesteps (int): number of diffusion steps
|
34 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
35 |
+
c (torch.Tensor, optional): speaker embedding
|
36 |
+
shape: (batch_size, gin_channels)
|
37 |
+
solver: see https://github.com/rtqichen/torchdiffeq for supported solvers
|
38 |
+
cfg_kwargs: used for cfg inference
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
sample: generated mel-spectrogram
|
42 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
43 |
+
"""
|
44 |
+
|
45 |
+
z = torch.randn_like(mu) * temperature
|
46 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
47 |
+
|
48 |
+
# cfg control
|
49 |
+
if cfg_kwargs is None:
|
50 |
+
estimator = functools.partial(self.estimator, mask=mask, mu=mu, c=c)
|
51 |
+
else:
|
52 |
+
estimator = functools.partial(self.cfg_wrapper, mask=mask, mu=mu, c=c, cfg_kwargs=cfg_kwargs)
|
53 |
+
|
54 |
+
trajectory = odeint(estimator, z, t_span, method=solver, rtol=1e-5, atol=1e-5)
|
55 |
+
return trajectory[-1]
|
56 |
+
|
57 |
+
# cfg inference
|
58 |
+
def cfg_wrapper(self, t, x, mask, mu, c, cfg_kwargs):
|
59 |
+
fake_speaker = cfg_kwargs['fake_speaker'].repeat(x.size(0), 1)
|
60 |
+
fake_content = cfg_kwargs['fake_content'].repeat(x.size(0), 1, x.size(-1))
|
61 |
+
cfg_strength = cfg_kwargs['cfg_strength']
|
62 |
+
|
63 |
+
cond_output = self.estimator(t, x, mask, mu, c)
|
64 |
+
uncond_output = self.estimator(t, x, mask, fake_content, fake_speaker)
|
65 |
+
|
66 |
+
output = uncond_output + cfg_strength * (cond_output - uncond_output)
|
67 |
+
return output
|
68 |
+
|
69 |
+
def compute_loss(self, x1, mask, mu, c):
|
70 |
+
"""Computes diffusion loss
|
71 |
+
|
72 |
+
Args:
|
73 |
+
x1 (torch.Tensor): Target
|
74 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
75 |
+
mask (torch.Tensor): target mask
|
76 |
+
shape: (batch_size, 1, mel_timesteps)
|
77 |
+
mu (torch.Tensor): output of encoder
|
78 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
79 |
+
c (torch.Tensor, optional): speaker condition.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
loss: conditional flow matching loss
|
83 |
+
y: conditional flow
|
84 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
85 |
+
"""
|
86 |
+
b, _, t = mu.shape
|
87 |
+
|
88 |
+
# random timestep
|
89 |
+
# use cosine timestep scheduler from cosyvoice: https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/flow/flow_matching.py
|
90 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
91 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
92 |
+
|
93 |
+
# sample noise p(x_0)
|
94 |
+
z = torch.randn_like(x1)
|
95 |
+
|
96 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
97 |
+
u = x1 - (1 - self.sigma_min) * z
|
98 |
+
|
99 |
+
loss = F.mse_loss(self.estimator(t.squeeze(), y, mask, mu, c), u, reduction="sum") / (torch.sum(mask) * u.size(1))
|
100 |
+
return loss, y
|
models/model.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
import monotonic_align
|
6 |
+
from models.text_encoder import TextEncoder
|
7 |
+
from models.flow_matching import CFMDecoder
|
8 |
+
from models.reference_encoder import MelStyleEncoder
|
9 |
+
from models.duration_predictor import DurationPredictor, duration_loss
|
10 |
+
from utils.mask import sequence_mask
|
11 |
+
|
12 |
+
def convert_pad_shape(pad_shape):
|
13 |
+
inverted_shape = pad_shape[::-1]
|
14 |
+
pad_shape = [item for sublist in inverted_shape for item in sublist]
|
15 |
+
return pad_shape
|
16 |
+
|
17 |
+
def generate_path(duration, mask):
|
18 |
+
b, t_x, t_y = mask.shape
|
19 |
+
cum_duration = torch.cumsum(duration, 1)
|
20 |
+
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype, device=duration.device)
|
21 |
+
|
22 |
+
cum_duration_flat = cum_duration.view(b * t_x)
|
23 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
24 |
+
path = path.view(b, t_x, t_y)
|
25 |
+
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
26 |
+
path = path * mask
|
27 |
+
return path
|
28 |
+
|
29 |
+
# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
|
30 |
+
class StableTTS(nn.Module):
|
31 |
+
def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.n_vocab = n_vocab
|
35 |
+
self.mel_channels = mel_channels
|
36 |
+
|
37 |
+
self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
|
38 |
+
self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=5, dropout=0.25)
|
39 |
+
self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, 0.5, gin_channels)
|
40 |
+
self.decoder = CFMDecoder(mel_channels, mel_channels, hidden_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)
|
41 |
+
|
42 |
+
# uncondition input for cfg
|
43 |
+
self.fake_speaker = nn.Parameter(torch.zeros(1, gin_channels))
|
44 |
+
self.fake_content = nn.Parameter(torch.zeros(1, mel_channels, 1))
|
45 |
+
|
46 |
+
self.cfg_dropout = 0.2
|
47 |
+
|
48 |
+
@torch.inference_mode()
|
49 |
+
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0, solver=None, cfg=1.0):
|
50 |
+
"""
|
51 |
+
Generates mel-spectrogram from text. Returns:
|
52 |
+
1. encoder outputs
|
53 |
+
2. decoder outputs
|
54 |
+
3. generated alignment
|
55 |
+
|
56 |
+
Args:
|
57 |
+
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
58 |
+
shape: (batch_size, max_text_length)
|
59 |
+
x_lengths (torch.Tensor): lengths of texts in batch.
|
60 |
+
shape: (batch_size,)
|
61 |
+
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
|
62 |
+
temperature (float, optional): controls variance of terminal distribution.
|
63 |
+
y (torch.Tensor): mel spectrogram of reference audio
|
64 |
+
shape: (batch_size, mel_channels, time)
|
65 |
+
length_scale (float, optional): controls speech pace.
|
66 |
+
Increase value to slow down generated speech and vice versa.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
dict: {
|
70 |
+
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
71 |
+
# Average mel spectrogram generated by the encoder
|
72 |
+
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
73 |
+
# Refined mel spectrogram improved by the CFM
|
74 |
+
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
|
75 |
+
# Alignment map between text and mel spectrogram
|
76 |
+
"""
|
77 |
+
|
78 |
+
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
79 |
+
c = self.ref_encoder(y, None)
|
80 |
+
x, mu_x, x_mask = self.encoder(x, c, x_lengths)
|
81 |
+
logw = self.dp(x, x_mask, c)
|
82 |
+
|
83 |
+
w = torch.exp(logw) * x_mask
|
84 |
+
w_ceil = torch.ceil(w) * length_scale
|
85 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
86 |
+
y_max_length = y_lengths.max()
|
87 |
+
|
88 |
+
# Using obtained durations `w` construct alignment map `attn`
|
89 |
+
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
|
90 |
+
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
91 |
+
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
|
92 |
+
|
93 |
+
# Align encoded text and get mu_y
|
94 |
+
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
|
95 |
+
mu_y = mu_y.transpose(1, 2)
|
96 |
+
encoder_outputs = mu_y[:, :, :y_max_length]
|
97 |
+
|
98 |
+
# Generate sample tracing the probability flow
|
99 |
+
if cfg == 1.0:
|
100 |
+
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver)
|
101 |
+
else:
|
102 |
+
cfg_kwargs = {'fake_speaker': self.fake_speaker, 'fake_content': self.fake_content, 'cfg_strength': cfg}
|
103 |
+
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver, cfg_kwargs)
|
104 |
+
|
105 |
+
decoder_outputs = decoder_outputs[:, :, :y_max_length]
|
106 |
+
|
107 |
+
|
108 |
+
return {
|
109 |
+
"encoder_outputs": encoder_outputs,
|
110 |
+
"decoder_outputs": decoder_outputs,
|
111 |
+
"attn": attn[:, :, :y_max_length],
|
112 |
+
}
|
113 |
+
|
114 |
+
def forward(self, x, x_lengths, y, y_lengths, z, z_lengths):
|
115 |
+
"""
|
116 |
+
Computes 3 losses:
|
117 |
+
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
|
118 |
+
2. prior loss: loss between mel-spectrogram and encoder outputs.
|
119 |
+
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
123 |
+
shape: (batch_size, max_text_length)
|
124 |
+
x_lengths (torch.Tensor): lengths of texts in batch.
|
125 |
+
shape: (batch_size,)
|
126 |
+
y (torch.Tensor): batch of corresponding mel-spectrograms.
|
127 |
+
shape: (batch_size, n_feats, max_mel_length)
|
128 |
+
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
|
129 |
+
shape: (batch_size,)
|
130 |
+
z (torch.Tensor): batch of cliced mel-spectrograms.
|
131 |
+
shape: (batch_size, n_feats, max_mel_length)
|
132 |
+
z_lengths (torch.Tensor): lengths of sliced mel-spectrograms in batch.
|
133 |
+
shape: (batch_size,)
|
134 |
+
"""
|
135 |
+
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
136 |
+
y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
|
137 |
+
z_mask = sequence_mask(z_lengths, z.size(2)).unsqueeze(1).to(z.dtype)
|
138 |
+
cfg_mask = torch.rand(y.size(0), 1, device=y.device) > self.cfg_dropout
|
139 |
+
|
140 |
+
# compute global speaker embedding
|
141 |
+
c = self.ref_encoder(z, z_mask) * cfg_mask + ~cfg_mask * self.fake_speaker.repeat(z.size(0), 1)
|
142 |
+
|
143 |
+
x, mu_x, x_mask = self.encoder(x, c, x_lengths)
|
144 |
+
logw = self.dp(x, x_mask, c)
|
145 |
+
|
146 |
+
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
147 |
+
|
148 |
+
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
149 |
+
with torch.no_grad():
|
150 |
+
s_p_sq_r = torch.ones_like(mu_x) # [b, d, t]
|
151 |
+
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True)
|
152 |
+
neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
|
153 |
+
neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
|
154 |
+
neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
|
155 |
+
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
|
156 |
+
|
157 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
158 |
+
attn = (monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach())
|
159 |
+
|
160 |
+
# Compute loss between predicted log-scaled durations and those obtained from MAS
|
161 |
+
# refered to as prior loss in the paper
|
162 |
+
logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
|
163 |
+
dur_loss = duration_loss(logw, logw_, x_lengths)
|
164 |
+
|
165 |
+
# Align encoded text with mel-spectrogram and get mu_y segment
|
166 |
+
attn = attn.squeeze(1).transpose(1,2)
|
167 |
+
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
|
168 |
+
mu_y = mu_y.transpose(1, 2)
|
169 |
+
|
170 |
+
# Compute loss of the decoder
|
171 |
+
cfg_mask = cfg_mask.unsqueeze(-1)
|
172 |
+
mu_y_masked = mu_y * cfg_mask + ~cfg_mask * self.fake_content.repeat(mu_y.size(0), 1, mu_y.size(-1)) # mask content information for better diversity for flow-matching
|
173 |
+
diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y_masked, c)
|
174 |
+
|
175 |
+
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
|
176 |
+
prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)
|
177 |
+
|
178 |
+
return dur_loss, diff_loss, prior_loss, attn
|
models/reference_encoder.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class Conv1dGLU(nn.Module):
|
5 |
+
"""
|
6 |
+
Conv1d + GLU(Gated Linear Unit) with residual connection.
|
7 |
+
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self, in_channels, out_channels, kernel_size, dropout):
|
11 |
+
super(Conv1dGLU, self).__init__()
|
12 |
+
self.out_channels = out_channels
|
13 |
+
self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
|
14 |
+
self.dropout = nn.Dropout(dropout)
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
residual = x
|
18 |
+
x = self.conv1(x)
|
19 |
+
x1, x2 = torch.split(x, self.out_channels, dim=1)
|
20 |
+
x = x1 * torch.sigmoid(x2)
|
21 |
+
x = residual + self.dropout(x)
|
22 |
+
return x
|
23 |
+
|
24 |
+
# modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
|
25 |
+
class MelStyleEncoder(nn.Module):
|
26 |
+
"""MelStyleEncoder"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
n_mel_channels=80,
|
31 |
+
style_hidden=128,
|
32 |
+
style_vector_dim=256,
|
33 |
+
style_kernel_size=5,
|
34 |
+
style_head=2,
|
35 |
+
dropout=0.1,
|
36 |
+
):
|
37 |
+
super(MelStyleEncoder, self).__init__()
|
38 |
+
self.in_dim = n_mel_channels
|
39 |
+
self.hidden_dim = style_hidden
|
40 |
+
self.out_dim = style_vector_dim
|
41 |
+
self.kernel_size = style_kernel_size
|
42 |
+
self.n_head = style_head
|
43 |
+
self.dropout = dropout
|
44 |
+
|
45 |
+
self.spectral = nn.Sequential(
|
46 |
+
nn.Linear(self.in_dim, self.hidden_dim),
|
47 |
+
nn.Mish(inplace=True),
|
48 |
+
nn.Dropout(self.dropout),
|
49 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
50 |
+
nn.Mish(inplace=True),
|
51 |
+
nn.Dropout(self.dropout),
|
52 |
+
)
|
53 |
+
|
54 |
+
self.temporal = nn.Sequential(
|
55 |
+
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
|
56 |
+
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
|
57 |
+
)
|
58 |
+
|
59 |
+
self.slf_attn = nn.MultiheadAttention(
|
60 |
+
self.hidden_dim,
|
61 |
+
self.n_head,
|
62 |
+
self.dropout,
|
63 |
+
batch_first=True
|
64 |
+
)
|
65 |
+
|
66 |
+
self.fc = nn.Linear(self.hidden_dim, self.out_dim)
|
67 |
+
|
68 |
+
def temporal_avg_pool(self, x, mask=None):
|
69 |
+
if mask is None:
|
70 |
+
return torch.mean(x, dim=1)
|
71 |
+
else:
|
72 |
+
return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1)
|
73 |
+
|
74 |
+
def forward(self, x, x_mask=None):
|
75 |
+
x = x.transpose(1, 2)
|
76 |
+
|
77 |
+
# spectral
|
78 |
+
x = self.spectral(x)
|
79 |
+
# temporal
|
80 |
+
x = x.transpose(1, 2)
|
81 |
+
x = self.temporal(x)
|
82 |
+
x = x.transpose(1, 2)
|
83 |
+
# self-attention
|
84 |
+
if x_mask is not None:
|
85 |
+
x_mask = ~x_mask.squeeze(1).to(torch.bool)
|
86 |
+
x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask, need_weights=False)
|
87 |
+
# fc
|
88 |
+
x = self.fc(x)
|
89 |
+
# temoral average pooling
|
90 |
+
w = self.temporal_avg_pool(x, mask=x_mask)
|
91 |
+
|
92 |
+
return w
|
93 |
+
|
94 |
+
# Attention Pool version of MelStyleEncoder, not used
|
95 |
+
class AttnMelStyleEncoder(nn.Module):
|
96 |
+
"""MelStyleEncoder"""
|
97 |
+
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
n_mel_channels=80,
|
101 |
+
style_hidden=128,
|
102 |
+
style_vector_dim=256,
|
103 |
+
style_kernel_size=5,
|
104 |
+
style_head=2,
|
105 |
+
dropout=0.1,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
self.in_dim = n_mel_channels
|
109 |
+
self.hidden_dim = style_hidden
|
110 |
+
self.out_dim = style_vector_dim
|
111 |
+
self.kernel_size = style_kernel_size
|
112 |
+
self.n_head = style_head
|
113 |
+
self.dropout = dropout
|
114 |
+
|
115 |
+
self.spectral = nn.Sequential(
|
116 |
+
nn.Linear(self.in_dim, self.hidden_dim),
|
117 |
+
nn.Mish(inplace=True),
|
118 |
+
nn.Dropout(self.dropout),
|
119 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
120 |
+
nn.Mish(inplace=True),
|
121 |
+
nn.Dropout(self.dropout),
|
122 |
+
)
|
123 |
+
|
124 |
+
self.temporal = nn.Sequential(
|
125 |
+
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
|
126 |
+
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
|
127 |
+
)
|
128 |
+
|
129 |
+
self.slf_attn = nn.MultiheadAttention(
|
130 |
+
self.hidden_dim,
|
131 |
+
self.n_head,
|
132 |
+
self.dropout,
|
133 |
+
batch_first=True
|
134 |
+
)
|
135 |
+
|
136 |
+
self.fc = nn.Linear(self.hidden_dim, self.out_dim)
|
137 |
+
|
138 |
+
def temporal_avg_pool(self, x, mask=None):
|
139 |
+
if mask is None:
|
140 |
+
return torch.mean(x, dim=1)
|
141 |
+
else:
|
142 |
+
return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1)
|
143 |
+
|
144 |
+
def forward(self, x, x_mask=None):
|
145 |
+
x = x.transpose(1, 2)
|
146 |
+
|
147 |
+
# spectral
|
148 |
+
x = self.spectral(x)
|
149 |
+
# temporal
|
150 |
+
x = x.transpose(1, 2)
|
151 |
+
x = self.temporal(x)
|
152 |
+
x = x.transpose(1, 2)
|
153 |
+
# self-attention
|
154 |
+
if x_mask is not None:
|
155 |
+
x_mask = ~x_mask.squeeze(1).to(torch.bool)
|
156 |
+
zeros = torch.zeros(x_mask.size(0), 1, device=x_mask.device, dtype=x_mask.dtype)
|
157 |
+
x_attn_mask = torch.cat((zeros, x_mask), dim=1)
|
158 |
+
else:
|
159 |
+
x_attn_mask = None
|
160 |
+
|
161 |
+
avg = self.temporal_avg_pool(x, x_mask).unsqueeze(1)
|
162 |
+
x = torch.cat([avg, x], dim=1)
|
163 |
+
x, _ = self.slf_attn(x, x, x, key_padding_mask=x_attn_mask, need_weights=False)
|
164 |
+
x = x[:, 0, :]
|
165 |
+
# fc
|
166 |
+
x = self.fc(x)
|
167 |
+
|
168 |
+
return x
|
models/text_encoder.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from models.diffusion_transformer import DiTConVBlock
|
5 |
+
from utils.mask import sequence_mask
|
6 |
+
|
7 |
+
# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
|
8 |
+
class TextEncoder(nn.Module):
|
9 |
+
def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
|
10 |
+
super().__init__()
|
11 |
+
self.n_vocab = n_vocab
|
12 |
+
self.out_channels = out_channels
|
13 |
+
self.hidden_channels = hidden_channels
|
14 |
+
self.filter_channels = filter_channels
|
15 |
+
self.n_heads = n_heads
|
16 |
+
self.n_layers = n_layers
|
17 |
+
self.kernel_size = kernel_size
|
18 |
+
self.p_dropout = p_dropout
|
19 |
+
|
20 |
+
self.scale = self.hidden_channels ** 0.5
|
21 |
+
|
22 |
+
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
23 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
24 |
+
|
25 |
+
self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
|
26 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
27 |
+
|
28 |
+
self.initialize_weights()
|
29 |
+
|
30 |
+
def initialize_weights(self):
|
31 |
+
for block in self.encoder:
|
32 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
33 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
34 |
+
|
35 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
|
36 |
+
x = self.emb(x) * self.scale # [b, t, h]
|
37 |
+
x = x.transpose(1, -1) # [b, h, t]
|
38 |
+
x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
|
39 |
+
|
40 |
+
for layer in self.encoder:
|
41 |
+
x = layer(x, c, x_mask)
|
42 |
+
mu_x = self.proj(x) * x_mask
|
43 |
+
|
44 |
+
return x, mu_x, x_mask
|
monotonic_align/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy import zeros, int32, float32
|
2 |
+
from torch import from_numpy
|
3 |
+
|
4 |
+
from .core import maximum_path_jit
|
5 |
+
|
6 |
+
|
7 |
+
def maximum_path(neg_cent, mask):
|
8 |
+
device = neg_cent.device
|
9 |
+
dtype = neg_cent.dtype
|
10 |
+
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
|
11 |
+
path = zeros(neg_cent.shape, dtype=int32)
|
12 |
+
|
13 |
+
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
|
14 |
+
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
|
15 |
+
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
|
16 |
+
return from_numpy(path).to(device=device, dtype=dtype)
|
monotonic_align/core.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numba
|
2 |
+
|
3 |
+
|
4 |
+
@numba.jit(
|
5 |
+
numba.void(
|
6 |
+
numba.int32[:, :, ::1],
|
7 |
+
numba.float32[:, :, ::1],
|
8 |
+
numba.int32[::1],
|
9 |
+
numba.int32[::1],
|
10 |
+
),
|
11 |
+
nopython=True,
|
12 |
+
nogil=True,
|
13 |
+
)
|
14 |
+
def maximum_path_jit(paths, values, t_ys, t_xs):
|
15 |
+
b = paths.shape[0]
|
16 |
+
max_neg_val = -1e9
|
17 |
+
for i in range(int(b)):
|
18 |
+
path = paths[i]
|
19 |
+
value = values[i]
|
20 |
+
t_y = t_ys[i]
|
21 |
+
t_x = t_xs[i]
|
22 |
+
|
23 |
+
v_prev = v_cur = 0.0
|
24 |
+
index = t_x - 1
|
25 |
+
|
26 |
+
for y in range(t_y):
|
27 |
+
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
28 |
+
if x == y:
|
29 |
+
v_cur = max_neg_val
|
30 |
+
else:
|
31 |
+
v_cur = value[y - 1, x]
|
32 |
+
if x == 0:
|
33 |
+
if y == 0:
|
34 |
+
v_prev = 0.0
|
35 |
+
else:
|
36 |
+
v_prev = max_neg_val
|
37 |
+
else:
|
38 |
+
v_prev = value[y - 1, x - 1]
|
39 |
+
value[y, x] += max(v_prev, v_cur)
|
40 |
+
|
41 |
+
for y in range(t_y - 1, -1, -1):
|
42 |
+
path[y, index] = 1
|
43 |
+
if index != 0 and (
|
44 |
+
index == y or value[y - 1, index] < value[y - 1, index - 1]
|
45 |
+
):
|
46 |
+
index = index - 1
|
requirements.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchaudio
|
3 |
+
|
4 |
+
tqdm
|
5 |
+
numpy
|
6 |
+
soundfile # to make sure that torchaudio has at least one valid backend
|
7 |
+
|
8 |
+
tensorboard
|
9 |
+
|
10 |
+
# for monotonic_align
|
11 |
+
numba
|
12 |
+
|
13 |
+
# ODE-solver
|
14 |
+
torchdiffeq
|
15 |
+
|
16 |
+
# for g2p
|
17 |
+
# chinese
|
18 |
+
pypinyin
|
19 |
+
jieba
|
20 |
+
# english
|
21 |
+
eng_to_ipa
|
22 |
+
unidecode
|
23 |
+
inflect
|
24 |
+
# japanese
|
25 |
+
# if pyopenjtalk fail to download open_jtalk_dic_utf_8-1.11.tar.gz, manually download and unzip the file below
|
26 |
+
# https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz
|
27 |
+
# and set os.environ['OPEN_JTALK_DICT_DIR'] to the folder path
|
28 |
+
pyopenjtalk-prebuilt # if using python >= 3.12, install pyopenjtalk instead
|
29 |
+
|
30 |
+
# for webui
|
31 |
+
gradio
|
32 |
+
matplotlib
|
33 |
+
|
text/LICENSE
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2017 Keith Ito
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
of this software and associated documentation files (the "Software"), to deal
|
5 |
+
in the Software without restriction, including without limitation the rights
|
6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
copies of the Software, and to permit persons to whom the Software is
|
8 |
+
furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
The above copyright notice and this permission notice shall be included in
|
11 |
+
all copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
19 |
+
THE SOFTWARE.
|
text/__init__.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
from text import cleaners
|
3 |
+
from text.symbols import symbols
|
4 |
+
|
5 |
+
|
6 |
+
# Mappings from symbol to numeric ID and vice versa:
|
7 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
8 |
+
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
9 |
+
|
10 |
+
|
11 |
+
def text_to_sequence(text, symbols, cleaner_names):
|
12 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
13 |
+
Args:
|
14 |
+
text: string to convert to a sequence
|
15 |
+
cleaner_names: names of the cleaner functions to run the text through
|
16 |
+
Returns:
|
17 |
+
List of integers corresponding to the symbols in the text
|
18 |
+
'''
|
19 |
+
sequence = []
|
20 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
21 |
+
clean_text = _clean_text(text, cleaner_names)
|
22 |
+
print(clean_text)
|
23 |
+
print(f" length:{len(clean_text)}")
|
24 |
+
for symbol in clean_text:
|
25 |
+
if symbol not in symbol_to_id.keys():
|
26 |
+
continue
|
27 |
+
symbol_id = symbol_to_id[symbol]
|
28 |
+
sequence += [symbol_id]
|
29 |
+
print(f" length:{len(sequence)}")
|
30 |
+
return sequence
|
31 |
+
|
32 |
+
|
33 |
+
def cleaned_text_to_sequence(cleaned_text):
|
34 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
35 |
+
Args:
|
36 |
+
text: string to convert to a sequence
|
37 |
+
Returns:
|
38 |
+
List of integers corresponding to the symbols in the text
|
39 |
+
'''
|
40 |
+
# symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
41 |
+
sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
|
42 |
+
return sequence
|
43 |
+
|
44 |
+
def cleaned_text_to_sequence_chinese(cleaned_text):
|
45 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
46 |
+
Args:
|
47 |
+
text: string to convert to a sequence
|
48 |
+
Returns:
|
49 |
+
List of integers corresponding to the symbols in the text
|
50 |
+
'''
|
51 |
+
# symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
52 |
+
sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()]
|
53 |
+
return sequence
|
54 |
+
|
55 |
+
|
56 |
+
def sequence_to_text(sequence):
|
57 |
+
'''Converts a sequence of IDs back to a string'''
|
58 |
+
result = ''
|
59 |
+
for symbol_id in sequence:
|
60 |
+
s = _id_to_symbol[symbol_id]
|
61 |
+
result += s
|
62 |
+
return result
|
63 |
+
|
64 |
+
|
65 |
+
def _clean_text(text, cleaner_names):
|
66 |
+
for name in cleaner_names:
|
67 |
+
cleaner = getattr(cleaners, name)
|
68 |
+
if not cleaner:
|
69 |
+
raise Exception('Unknown cleaner: %s' % name)
|
70 |
+
text = cleaner(text)
|
71 |
+
return text
|
text/cleaners.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from text.english import english_to_ipa2
|
4 |
+
from text.mandarin import chinese_to_cnm3
|
5 |
+
from text.japanese import japanese_to_ipa2
|
6 |
+
|
7 |
+
language_module_map = {"PAD":0, "ZH": 1, "EN": 2, "JA": 3}
|
8 |
+
|
9 |
+
# 预编译正则表达式
|
10 |
+
ZH_PATTERN = re.compile(r'[\u3400-\u4DBF\u4e00-\u9FFF\uF900-\uFAFF\u3000-\u303F]')
|
11 |
+
EN_PATTERN = re.compile(r'[a-zA-Z.,!?\'"(){}[\]<>:;@#$%^&*-_+=/\\|~`]+')
|
12 |
+
JP_PATTERN = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF\u31F0-\u31FF\uFF00-\uFFEF\u3000-\u303F]')
|
13 |
+
CLEANER_PATTERN = re.compile(r'\[(ZH|EN|JA)\]')
|
14 |
+
|
15 |
+
def detect_language(text: str, prev_lang=None):
|
16 |
+
"""
|
17 |
+
根据给定的文本检测语言
|
18 |
+
|
19 |
+
:param text: 输入文本
|
20 |
+
:param prev_lang: 上一个检测到的语言
|
21 |
+
:return: 'ZH' for Chinese, 'EN' for English, 'JA' for Japanese, or prev_lang for spaces
|
22 |
+
"""
|
23 |
+
if ZH_PATTERN.search(text): return 'ZH'
|
24 |
+
if EN_PATTERN.search(text): return 'EN'
|
25 |
+
if JP_PATTERN.search(text): return 'JA'
|
26 |
+
if text.isspace(): return prev_lang # 若是空格,则返回前一个语言
|
27 |
+
return None
|
28 |
+
|
29 |
+
# auto detect language using re
|
30 |
+
def cjke_cleaners4(text: str):
|
31 |
+
"""
|
32 |
+
根据文本内容自动检测语言并转换为IPA音标
|
33 |
+
|
34 |
+
:param text: 输入文本
|
35 |
+
:return: 转换为IPA音标的文本
|
36 |
+
"""
|
37 |
+
text = CLEANER_PATTERN.sub('', text)
|
38 |
+
pointer = 0
|
39 |
+
output = ''
|
40 |
+
current_language = detect_language(text[pointer])
|
41 |
+
|
42 |
+
while pointer < len(text):
|
43 |
+
temp_text = ''
|
44 |
+
while pointer < len(text) and detect_language(text[pointer], current_language) == current_language:
|
45 |
+
temp_text += text[pointer]
|
46 |
+
pointer += 1
|
47 |
+
if current_language == 'ZH':
|
48 |
+
output += chinese_to_cnm3(temp_text)
|
49 |
+
elif current_language == 'JA':
|
50 |
+
output += japanese_to_ipa2(temp_text)
|
51 |
+
elif current_language == 'EN':
|
52 |
+
output += english_to_ipa2(temp_text)
|
53 |
+
if pointer < len(text):
|
54 |
+
current_language = detect_language(text[pointer])
|
55 |
+
|
56 |
+
output = re.sub(r'\s+$', '', output)
|
57 |
+
output = re.sub(r'([^\.,!\?\-…~])$', r'\1.', output)
|
58 |
+
return output
|
text/cn2an/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "0.5.22"
|
2 |
+
|
3 |
+
from .cn2an import Cn2An
|
4 |
+
from .an2cn import An2Cn
|
5 |
+
from .transform import Transform
|
6 |
+
|
7 |
+
cn2an = Cn2An().cn2an
|
8 |
+
an2cn = An2Cn().an2cn
|
9 |
+
transform = Transform().transform
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
"__version__",
|
13 |
+
"cn2an",
|
14 |
+
"an2cn",
|
15 |
+
"transform"
|
16 |
+
]
|
text/cn2an/an2cn.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
from warnings import warn
|
3 |
+
|
4 |
+
# from proces import preprocess
|
5 |
+
|
6 |
+
from .conf import NUMBER_LOW_AN2CN, NUMBER_UP_AN2CN, UNIT_LOW_ORDER_AN2CN, UNIT_UP_ORDER_AN2CN
|
7 |
+
|
8 |
+
|
9 |
+
class An2Cn(object):
|
10 |
+
def __init__(self) -> None:
|
11 |
+
self.all_num = "0123456789"
|
12 |
+
self.number_low = NUMBER_LOW_AN2CN
|
13 |
+
self.number_up = NUMBER_UP_AN2CN
|
14 |
+
self.mode_list = ["low", "up", "rmb", "direct"]
|
15 |
+
|
16 |
+
def an2cn(self, inputs: Union[str, int, float] = None, mode: str = "low") -> str:
|
17 |
+
"""阿拉伯数字转中文数字
|
18 |
+
|
19 |
+
:param inputs: 阿拉伯数字
|
20 |
+
:param mode: low 小写数字,up 大写数字,rmb 人民币大写,direct 直接转化
|
21 |
+
:return: 中文数字
|
22 |
+
"""
|
23 |
+
if inputs is not None and inputs != "":
|
24 |
+
if mode not in self.mode_list:
|
25 |
+
raise ValueError(f"mode 仅支持 {str(self.mode_list)} !")
|
26 |
+
|
27 |
+
# 将数字转化为字符串,这里会有Python会自动做转化
|
28 |
+
# 1. -> 1.0 1.00 -> 1.0 -0 -> 0
|
29 |
+
if not isinstance(inputs, str):
|
30 |
+
inputs = self.__number_to_string(inputs)
|
31 |
+
|
32 |
+
# 数据预处理:
|
33 |
+
# 1. 繁体转简体
|
34 |
+
# 2. 全角转半角
|
35 |
+
# inputs = preprocess(inputs, pipelines=[
|
36 |
+
# "traditional_to_simplified",
|
37 |
+
# "full_angle_to_half_angle"
|
38 |
+
# ])
|
39 |
+
|
40 |
+
# 检查数据是否有效
|
41 |
+
self.__check_inputs_is_valid(inputs)
|
42 |
+
|
43 |
+
# 判断正负
|
44 |
+
if inputs[0] == "-":
|
45 |
+
sign = "负"
|
46 |
+
inputs = inputs[1:]
|
47 |
+
else:
|
48 |
+
sign = ""
|
49 |
+
|
50 |
+
if mode == "direct":
|
51 |
+
output = self.__direct_convert(inputs)
|
52 |
+
else:
|
53 |
+
# 切割整数部分和小数部分
|
54 |
+
split_result = inputs.split(".")
|
55 |
+
len_split_result = len(split_result)
|
56 |
+
if len_split_result == 1:
|
57 |
+
# 不包含小数的输入
|
58 |
+
integer_data = split_result[0]
|
59 |
+
if mode == "rmb":
|
60 |
+
output = self.__integer_convert(integer_data, "up") + "元整"
|
61 |
+
else:
|
62 |
+
output = self.__integer_convert(integer_data, mode)
|
63 |
+
elif len_split_result == 2:
|
64 |
+
# 包含小数的输入
|
65 |
+
integer_data, decimal_data = split_result
|
66 |
+
if mode == "rmb":
|
67 |
+
int_data = self.__integer_convert(integer_data, "up")
|
68 |
+
dec_data = self.__decimal_convert(decimal_data, "up")
|
69 |
+
len_dec_data = len(dec_data)
|
70 |
+
|
71 |
+
if len_dec_data == 0:
|
72 |
+
output = int_data + "元整"
|
73 |
+
elif len_dec_data == 1:
|
74 |
+
raise ValueError(f"异常输出:{dec_data}")
|
75 |
+
elif len_dec_data == 2:
|
76 |
+
if dec_data[1] != "零":
|
77 |
+
if int_data == "零":
|
78 |
+
output = dec_data[1] + "角"
|
79 |
+
else:
|
80 |
+
output = int_data + "元" + dec_data[1] + "角"
|
81 |
+
else:
|
82 |
+
output = int_data + "元整"
|
83 |
+
else:
|
84 |
+
if dec_data[1] != "零":
|
85 |
+
if dec_data[2] != "零":
|
86 |
+
if int_data == "零":
|
87 |
+
output = dec_data[1] + "角" + dec_data[2] + "分"
|
88 |
+
else:
|
89 |
+
output = int_data + "元" + dec_data[1] + "角" + dec_data[2] + "分"
|
90 |
+
else:
|
91 |
+
if int_data == "零":
|
92 |
+
output = dec_data[1] + "角"
|
93 |
+
else:
|
94 |
+
output = int_data + "元" + dec_data[1] + "角"
|
95 |
+
else:
|
96 |
+
if dec_data[2] != "零":
|
97 |
+
if int_data == "零":
|
98 |
+
output = dec_data[2] + "分"
|
99 |
+
else:
|
100 |
+
output = int_data + "元" + "零" + dec_data[2] + "分"
|
101 |
+
else:
|
102 |
+
output = int_data + "元整"
|
103 |
+
else:
|
104 |
+
output = self.__integer_convert(integer_data, mode) + self.__decimal_convert(decimal_data, mode)
|
105 |
+
else:
|
106 |
+
raise ValueError(f"输入格式错误:{inputs}!")
|
107 |
+
else:
|
108 |
+
raise ValueError("输入数据为空!")
|
109 |
+
|
110 |
+
return sign + output
|
111 |
+
|
112 |
+
def __direct_convert(self, inputs: str) -> str:
|
113 |
+
_output = ""
|
114 |
+
for d in inputs:
|
115 |
+
if d == ".":
|
116 |
+
_output += "点"
|
117 |
+
else:
|
118 |
+
_output += self.number_low[int(d)]
|
119 |
+
return _output
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def __number_to_string(number_data: Union[int, float]) -> str:
|
123 |
+
# 小数处理:python 会自动把 0.00005 转化成 5e-05,因此 str(0.00005) != "0.00005"
|
124 |
+
string_data = str(number_data)
|
125 |
+
if "e" in string_data:
|
126 |
+
string_data_list = string_data.split("e")
|
127 |
+
string_key = string_data_list[0]
|
128 |
+
string_value = string_data_list[1]
|
129 |
+
if string_value[0] == "-":
|
130 |
+
string_data = "0." + "0" * (int(string_value[1:]) - 1) + string_key
|
131 |
+
else:
|
132 |
+
string_data = string_key + "0" * int(string_value)
|
133 |
+
return string_data
|
134 |
+
|
135 |
+
def __check_inputs_is_valid(self, check_data: str) -> None:
|
136 |
+
# 检查输入数据是否在规定的字典中
|
137 |
+
all_check_keys = self.all_num + ".-"
|
138 |
+
for data in check_data:
|
139 |
+
if data not in all_check_keys:
|
140 |
+
raise ValueError(f"输入的数据不在转化范围内:{data}!")
|
141 |
+
|
142 |
+
def __integer_convert(self, integer_data: str, mode: str) -> str:
|
143 |
+
if mode == "low":
|
144 |
+
numeral_list = NUMBER_LOW_AN2CN
|
145 |
+
unit_list = UNIT_LOW_ORDER_AN2CN
|
146 |
+
elif mode == "up":
|
147 |
+
numeral_list = NUMBER_UP_AN2CN
|
148 |
+
unit_list = UNIT_UP_ORDER_AN2CN
|
149 |
+
else:
|
150 |
+
raise ValueError(f"error mode: {mode}")
|
151 |
+
|
152 |
+
# 去除前面的 0,比如 007 => 7
|
153 |
+
integer_data = str(int(integer_data))
|
154 |
+
|
155 |
+
len_integer_data = len(integer_data)
|
156 |
+
if len_integer_data > len(unit_list):
|
157 |
+
raise ValueError(f"超出数据范围,最长支持 {len(unit_list)} 位")
|
158 |
+
|
159 |
+
output_an = ""
|
160 |
+
for i, d in enumerate(integer_data):
|
161 |
+
if int(d):
|
162 |
+
output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
|
163 |
+
else:
|
164 |
+
if not (len_integer_data - i - 1) % 4:
|
165 |
+
output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
|
166 |
+
|
167 |
+
if i > 0 and not output_an[-1] == "零":
|
168 |
+
output_an += numeral_list[int(d)]
|
169 |
+
|
170 |
+
output_an = output_an.replace("零零", "零").replace("零万", "万").replace("零亿", "亿").replace("亿万", "亿") \
|
171 |
+
.strip("零")
|
172 |
+
|
173 |
+
# 解决「一十几」问题
|
174 |
+
if output_an[:2] in ["一十"]:
|
175 |
+
output_an = output_an[1:]
|
176 |
+
|
177 |
+
# 0 - 1 之间的小数
|
178 |
+
if not output_an:
|
179 |
+
output_an = "零"
|
180 |
+
|
181 |
+
return output_an
|
182 |
+
|
183 |
+
def __decimal_convert(self, decimal_data: str, o_mode: str) -> str:
|
184 |
+
len_decimal_data = len(decimal_data)
|
185 |
+
|
186 |
+
if len_decimal_data > 16:
|
187 |
+
warn(f"注意:小数部分长度为 {len_decimal_data} ,将自动截取前 16 位有效精度!")
|
188 |
+
decimal_data = decimal_data[:16]
|
189 |
+
|
190 |
+
if len_decimal_data:
|
191 |
+
output_an = "点"
|
192 |
+
else:
|
193 |
+
output_an = ""
|
194 |
+
|
195 |
+
if o_mode == "low":
|
196 |
+
numeral_list = NUMBER_LOW_AN2CN
|
197 |
+
elif o_mode == "up":
|
198 |
+
numeral_list = NUMBER_UP_AN2CN
|
199 |
+
else:
|
200 |
+
raise ValueError(f"error mode: {o_mode}")
|
201 |
+
|
202 |
+
for data in decimal_data:
|
203 |
+
output_an += numeral_list[int(data)]
|
204 |
+
return output_an
|
text/cn2an/cn2an.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from warnings import warn
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
# from proces import preprocess
|
6 |
+
|
7 |
+
from .an2cn import An2Cn
|
8 |
+
from .conf import NUMBER_CN2AN, UNIT_CN2AN, STRICT_CN_NUMBER, NORMAL_CN_NUMBER, NUMBER_LOW_AN2CN, UNIT_LOW_AN2CN
|
9 |
+
|
10 |
+
|
11 |
+
class Cn2An(object):
|
12 |
+
def __init__(self) -> None:
|
13 |
+
self.all_num = "".join(list(NUMBER_CN2AN.keys()))
|
14 |
+
self.all_unit = "".join(list(UNIT_CN2AN.keys()))
|
15 |
+
self.strict_cn_number = STRICT_CN_NUMBER
|
16 |
+
self.normal_cn_number = NORMAL_CN_NUMBER
|
17 |
+
self.check_key_dict = {
|
18 |
+
"strict": "".join(self.strict_cn_number.values()) + "点负",
|
19 |
+
"normal": "".join(self.normal_cn_number.values()) + "点负",
|
20 |
+
"smart": "".join(self.normal_cn_number.values()) + "点负" + "01234567890.-"
|
21 |
+
}
|
22 |
+
self.pattern_dict = self.__get_pattern()
|
23 |
+
self.ac = An2Cn()
|
24 |
+
self.mode_list = ["strict", "normal", "smart"]
|
25 |
+
self.yjf_pattern = re.compile(fr"^.*?[元圆][{self.all_num}]角([{self.all_num}]分)?$")
|
26 |
+
self.pattern1 = re.compile(fr"^-?\d+(\.\d+)?[{self.all_unit}]?$")
|
27 |
+
self.ptn_all_num = re.compile(f"^[{self.all_num}]+$")
|
28 |
+
# "十?" is for special case "十一万三"
|
29 |
+
self.ptn_speaking_mode = re.compile(f"^([{self.all_num}]{{0,2}}[{self.all_unit}])+[{self.all_num}]$")
|
30 |
+
|
31 |
+
def cn2an(self, inputs: Union[str, int, float] = None, mode: str = "strict") -> Union[float, int]:
|
32 |
+
"""中文数字转阿拉伯数字
|
33 |
+
|
34 |
+
:param inputs: 中文数字、阿拉伯数字、中文数字和阿拉伯数字
|
35 |
+
:param mode: strict 严格,normal 正常,smart 智能
|
36 |
+
:return: 阿拉伯数字
|
37 |
+
"""
|
38 |
+
if inputs is not None or inputs == "":
|
39 |
+
if mode not in self.mode_list:
|
40 |
+
raise ValueError(f"mode 仅支持 {str(self.mode_list)} !")
|
41 |
+
|
42 |
+
# 将数字转化为字符串
|
43 |
+
if not isinstance(inputs, str):
|
44 |
+
inputs = str(inputs)
|
45 |
+
|
46 |
+
# 数据预处理:
|
47 |
+
# 1. 繁体转简体
|
48 |
+
# 2. 全角转半角
|
49 |
+
# inputs = preprocess(inputs, pipelines=[
|
50 |
+
# "traditional_to_simplified",
|
51 |
+
# "full_angle_to_half_angle"
|
52 |
+
# ])
|
53 |
+
|
54 |
+
# 特殊转化 廿
|
55 |
+
inputs = inputs.replace("廿", "二十")
|
56 |
+
|
57 |
+
# 检查输入数据是否有效
|
58 |
+
sign, integer_data, decimal_data, is_all_num = self.__check_input_data_is_valid(inputs, mode)
|
59 |
+
|
60 |
+
# smart 下的特殊情况
|
61 |
+
if sign == 0:
|
62 |
+
return integer_data
|
63 |
+
else:
|
64 |
+
if not is_all_num:
|
65 |
+
if decimal_data is None:
|
66 |
+
output = self.__integer_convert(integer_data)
|
67 |
+
else:
|
68 |
+
output = self.__integer_convert(integer_data) + self.__decimal_convert(decimal_data)
|
69 |
+
# fix 1 + 0.57 = 1.5699999999999998
|
70 |
+
output = round(output, len(decimal_data))
|
71 |
+
else:
|
72 |
+
if decimal_data is None:
|
73 |
+
output = self.__direct_convert(integer_data)
|
74 |
+
else:
|
75 |
+
output = self.__direct_convert(integer_data) + self.__decimal_convert(decimal_data)
|
76 |
+
# fix 1 + 0.57 = 1.5699999999999998
|
77 |
+
output = round(output, len(decimal_data))
|
78 |
+
else:
|
79 |
+
raise ValueError("输入数据为空!")
|
80 |
+
|
81 |
+
return sign * output
|
82 |
+
|
83 |
+
def __get_pattern(self) -> dict:
|
84 |
+
# 整数严格检查
|
85 |
+
_0 = "[零]"
|
86 |
+
_1_9 = "[一二三四五六七八九]"
|
87 |
+
_10_99 = f"{_1_9}?[十]{_1_9}?"
|
88 |
+
_1_99 = f"({_10_99}|{_1_9})"
|
89 |
+
_100_999 = f"({_1_9}[百]([零]{_1_9})?|{_1_9}[百]{_10_99})"
|
90 |
+
_1_999 = f"({_100_999}|{_1_99})"
|
91 |
+
_1000_9999 = f"({_1_9}[千]([零]{_1_99})?|{_1_9}[千]{_100_999})"
|
92 |
+
_1_9999 = f"({_1000_9999}|{_1_999})"
|
93 |
+
_10000_99999999 = f"({_1_9999}[万]([零]{_1_999})?|{_1_9999}[万]{_1000_9999})"
|
94 |
+
_1_99999999 = f"({_10000_99999999}|{_1_9999})"
|
95 |
+
_100000000_9999999999999999 = f"({_1_99999999}[亿]([零]{_1_99999999})?|{_1_99999999}[亿]{_10000_99999999})"
|
96 |
+
_1_9999999999999999 = f"({_100000000_9999999999999999}|{_1_99999999})"
|
97 |
+
str_int_pattern = f"^({_0}|{_1_9999999999999999})$"
|
98 |
+
nor_int_pattern = f"^({_0}|{_1_9999999999999999})$"
|
99 |
+
|
100 |
+
str_dec_pattern = "^[零一二三四五六七八九]{0,15}[一二三四五六七八九]$"
|
101 |
+
nor_dec_pattern = "^[零一二三四五六七八九]{0,16}$"
|
102 |
+
|
103 |
+
for str_num in self.strict_cn_number.keys():
|
104 |
+
str_int_pattern = str_int_pattern.replace(str_num, self.strict_cn_number[str_num])
|
105 |
+
str_dec_pattern = str_dec_pattern.replace(str_num, self.strict_cn_number[str_num])
|
106 |
+
for nor_num in self.normal_cn_number.keys():
|
107 |
+
nor_int_pattern = nor_int_pattern.replace(nor_num, self.normal_cn_number[nor_num])
|
108 |
+
nor_dec_pattern = nor_dec_pattern.replace(nor_num, self.normal_cn_number[nor_num])
|
109 |
+
|
110 |
+
pattern_dict = {
|
111 |
+
"strict": {
|
112 |
+
"int": re.compile(str_int_pattern),
|
113 |
+
"dec": re.compile(str_dec_pattern)
|
114 |
+
},
|
115 |
+
"normal": {
|
116 |
+
"int": re.compile(nor_int_pattern),
|
117 |
+
"dec": re.compile(nor_dec_pattern)
|
118 |
+
}
|
119 |
+
}
|
120 |
+
return pattern_dict
|
121 |
+
|
122 |
+
def __copy_num(self, num):
|
123 |
+
cn_num = ""
|
124 |
+
for n in num:
|
125 |
+
cn_num += NUMBER_LOW_AN2CN[int(n)]
|
126 |
+
return cn_num
|
127 |
+
|
128 |
+
def __check_input_data_is_valid(self, check_data: str, mode: str) -> (int, str, str, bool):
|
129 |
+
# 去除 元整、圆整、元正、圆正
|
130 |
+
stop_words = ["元整", "圆整", "元正", "圆正"]
|
131 |
+
for word in stop_words:
|
132 |
+
if check_data[-2:] == word:
|
133 |
+
check_data = check_data[:-2]
|
134 |
+
|
135 |
+
# 去除 元、圆
|
136 |
+
if mode != "strict":
|
137 |
+
normal_stop_words = ["圆", "元"]
|
138 |
+
for word in normal_stop_words:
|
139 |
+
if check_data[-1] == word:
|
140 |
+
check_data = check_data[:-1]
|
141 |
+
|
142 |
+
# 处理元角分
|
143 |
+
result = self.yjf_pattern.search(check_data)
|
144 |
+
if result:
|
145 |
+
check_data = check_data.replace("元", "点").replace("角", "").replace("分", "")
|
146 |
+
|
147 |
+
# 处理特殊问法:一千零十一 一万零百一十一
|
148 |
+
if "零十" in check_data:
|
149 |
+
check_data = check_data.replace("零十", "零一十")
|
150 |
+
if "零百" in check_data:
|
151 |
+
check_data = check_data.replace("零百", "零一百")
|
152 |
+
|
153 |
+
for data in check_data:
|
154 |
+
if data not in self.check_key_dict[mode]:
|
155 |
+
raise ValueError(f"当前为{mode}模式,输入的数据不在转化范围内:{data}!")
|
156 |
+
|
157 |
+
# 确定正负号
|
158 |
+
if check_data[0] == "负":
|
159 |
+
check_data = check_data[1:]
|
160 |
+
sign = -1
|
161 |
+
else:
|
162 |
+
sign = 1
|
163 |
+
|
164 |
+
if "点" in check_data:
|
165 |
+
split_data = check_data.split("点")
|
166 |
+
if len(split_data) == 2:
|
167 |
+
integer_data, decimal_data = split_data
|
168 |
+
# 将 smart 模式中的阿拉伯数字转化成中文数字
|
169 |
+
if mode == "smart":
|
170 |
+
integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
|
171 |
+
decimal_data = re.sub(r"\d+", lambda x: self.__copy_num(x.group()), decimal_data)
|
172 |
+
mode = "normal"
|
173 |
+
else:
|
174 |
+
raise ValueError("数据中包含不止一个点!")
|
175 |
+
else:
|
176 |
+
integer_data = check_data
|
177 |
+
decimal_data = None
|
178 |
+
# 将 smart 模式中的阿拉伯数字转化成中文数字
|
179 |
+
if mode == "smart":
|
180 |
+
# 10.1万 10.1
|
181 |
+
result1 = self.pattern1.search(integer_data)
|
182 |
+
if result1:
|
183 |
+
if result1.group() == integer_data:
|
184 |
+
if integer_data[-1] in UNIT_CN2AN.keys():
|
185 |
+
output = int(float(integer_data[:-1]) * UNIT_CN2AN[integer_data[-1]])
|
186 |
+
else:
|
187 |
+
output = float(integer_data)
|
188 |
+
return 0, output, None, None
|
189 |
+
|
190 |
+
integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
|
191 |
+
mode = "normal"
|
192 |
+
|
193 |
+
result_int = self.pattern_dict[mode]["int"].search(integer_data)
|
194 |
+
if result_int:
|
195 |
+
if result_int.group() == integer_data:
|
196 |
+
if decimal_data is not None:
|
197 |
+
result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
|
198 |
+
if result_dec:
|
199 |
+
if result_dec.group() == decimal_data:
|
200 |
+
return sign, integer_data, decimal_data, False
|
201 |
+
else:
|
202 |
+
return sign, integer_data, decimal_data, False
|
203 |
+
else:
|
204 |
+
if mode == "strict":
|
205 |
+
raise ValueError(f"不符合格式的数据:{integer_data}")
|
206 |
+
elif mode == "normal":
|
207 |
+
# 纯数模式:一二三
|
208 |
+
result_all_num = self.ptn_all_num.search(integer_data)
|
209 |
+
if result_all_num:
|
210 |
+
if result_all_num.group() == integer_data:
|
211 |
+
if decimal_data is not None:
|
212 |
+
result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
|
213 |
+
if result_dec:
|
214 |
+
if result_dec.group() == decimal_data:
|
215 |
+
return sign, integer_data, decimal_data, True
|
216 |
+
else:
|
217 |
+
return sign, integer_data, decimal_data, True
|
218 |
+
|
219 |
+
# 口语模式:一万二,两千三,三百四,十三万六,一百二十五万���
|
220 |
+
result_speaking_mode = self.ptn_speaking_mode.search(integer_data)
|
221 |
+
if len(integer_data) >= 3 and result_speaking_mode and result_speaking_mode.group() == integer_data:
|
222 |
+
# len(integer_data)>=3: because the minimum length of integer_data that can be matched is 3
|
223 |
+
# to find the last unit
|
224 |
+
last_unit = result_speaking_mode.groups()[-1][-1]
|
225 |
+
_unit = UNIT_LOW_AN2CN[UNIT_CN2AN[last_unit] // 10]
|
226 |
+
integer_data = integer_data + _unit
|
227 |
+
if decimal_data is not None:
|
228 |
+
result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
|
229 |
+
if result_dec:
|
230 |
+
if result_dec.group() == decimal_data:
|
231 |
+
return sign, integer_data, decimal_data, False
|
232 |
+
else:
|
233 |
+
return sign, integer_data, decimal_data, False
|
234 |
+
|
235 |
+
raise ValueError(f"不符合格式的数据:{check_data}")
|
236 |
+
|
237 |
+
def __integer_convert(self, integer_data: str) -> int:
|
238 |
+
# 核心
|
239 |
+
output_integer = 0
|
240 |
+
unit = 1
|
241 |
+
ten_thousand_unit = 1
|
242 |
+
for index, cn_num in enumerate(reversed(integer_data)):
|
243 |
+
# 数值
|
244 |
+
if cn_num in NUMBER_CN2AN:
|
245 |
+
num = NUMBER_CN2AN[cn_num]
|
246 |
+
output_integer += num * unit
|
247 |
+
# 单位
|
248 |
+
elif cn_num in UNIT_CN2AN:
|
249 |
+
unit = UNIT_CN2AN[cn_num]
|
250 |
+
# 判断出万、亿、万亿
|
251 |
+
if unit % 10000 == 0:
|
252 |
+
# 万 亿
|
253 |
+
if unit > ten_thousand_unit:
|
254 |
+
ten_thousand_unit = unit
|
255 |
+
# 万亿
|
256 |
+
else:
|
257 |
+
ten_thousand_unit = unit * ten_thousand_unit
|
258 |
+
unit = ten_thousand_unit
|
259 |
+
|
260 |
+
if unit < ten_thousand_unit:
|
261 |
+
unit = unit * ten_thousand_unit
|
262 |
+
|
263 |
+
if index == len(integer_data) - 1:
|
264 |
+
output_integer += unit
|
265 |
+
else:
|
266 |
+
raise ValueError(f"{cn_num} 不在转化范围内")
|
267 |
+
|
268 |
+
return int(output_integer)
|
269 |
+
|
270 |
+
def __decimal_convert(self, decimal_data: str) -> float:
|
271 |
+
len_decimal_data = len(decimal_data)
|
272 |
+
|
273 |
+
if len_decimal_data > 16:
|
274 |
+
warn(f"注意:小数部分长度为 {len_decimal_data} ,将自动截取前 16 位有效精度!")
|
275 |
+
decimal_data = decimal_data[:16]
|
276 |
+
len_decimal_data = 16
|
277 |
+
|
278 |
+
output_decimal = 0
|
279 |
+
for index in range(len(decimal_data) - 1, -1, -1):
|
280 |
+
unit_key = NUMBER_CN2AN[decimal_data[index]]
|
281 |
+
output_decimal += unit_key * 10 ** -(index + 1)
|
282 |
+
|
283 |
+
# 处理精度溢出问题
|
284 |
+
output_decimal = round(output_decimal, len_decimal_data)
|
285 |
+
|
286 |
+
return output_decimal
|
287 |
+
|
288 |
+
def __direct_convert(self, data: str) -> int:
|
289 |
+
output_data = 0
|
290 |
+
for index in range(len(data) - 1, -1, -1):
|
291 |
+
unit_key = NUMBER_CN2AN[data[index]]
|
292 |
+
output_data += unit_key * 10 ** (len(data) - index - 1)
|
293 |
+
|
294 |
+
return output_data
|
text/cn2an/conf.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NUMBER_CN2AN = {
|
2 |
+
"零": 0,
|
3 |
+
"〇": 0,
|
4 |
+
"一": 1,
|
5 |
+
"壹": 1,
|
6 |
+
"幺": 1,
|
7 |
+
"二": 2,
|
8 |
+
"贰": 2,
|
9 |
+
"两": 2,
|
10 |
+
"三": 3,
|
11 |
+
"叁": 3,
|
12 |
+
"四": 4,
|
13 |
+
"肆": 4,
|
14 |
+
"五": 5,
|
15 |
+
"伍": 5,
|
16 |
+
"六": 6,
|
17 |
+
"陆": 6,
|
18 |
+
"七": 7,
|
19 |
+
"柒": 7,
|
20 |
+
"八": 8,
|
21 |
+
"捌": 8,
|
22 |
+
"九": 9,
|
23 |
+
"玖": 9,
|
24 |
+
}
|
25 |
+
UNIT_CN2AN = {
|
26 |
+
"十": 10,
|
27 |
+
"拾": 10,
|
28 |
+
"百": 100,
|
29 |
+
"佰": 100,
|
30 |
+
"千": 1000,
|
31 |
+
"仟": 1000,
|
32 |
+
"万": 10000,
|
33 |
+
"亿": 100000000,
|
34 |
+
}
|
35 |
+
UNIT_LOW_AN2CN = {
|
36 |
+
10: "十",
|
37 |
+
100: "百",
|
38 |
+
1000: "千",
|
39 |
+
10000: "万",
|
40 |
+
100000000: "亿",
|
41 |
+
}
|
42 |
+
NUMBER_LOW_AN2CN = {
|
43 |
+
0: "零",
|
44 |
+
1: "一",
|
45 |
+
2: "二",
|
46 |
+
3: "三",
|
47 |
+
4: "四",
|
48 |
+
5: "五",
|
49 |
+
6: "六",
|
50 |
+
7: "七",
|
51 |
+
8: "八",
|
52 |
+
9: "九",
|
53 |
+
}
|
54 |
+
NUMBER_UP_AN2CN = {
|
55 |
+
0: "零",
|
56 |
+
1: "壹",
|
57 |
+
2: "贰",
|
58 |
+
3: "叁",
|
59 |
+
4: "肆",
|
60 |
+
5: "伍",
|
61 |
+
6: "陆",
|
62 |
+
7: "柒",
|
63 |
+
8: "捌",
|
64 |
+
9: "玖",
|
65 |
+
}
|
66 |
+
UNIT_LOW_ORDER_AN2CN = [
|
67 |
+
"",
|
68 |
+
"十",
|
69 |
+
"百",
|
70 |
+
"千",
|
71 |
+
"万",
|
72 |
+
"十",
|
73 |
+
"百",
|
74 |
+
"千",
|
75 |
+
"亿",
|
76 |
+
"十",
|
77 |
+
"百",
|
78 |
+
"千",
|
79 |
+
"万",
|
80 |
+
"十",
|
81 |
+
"百",
|
82 |
+
"千",
|
83 |
+
]
|
84 |
+
UNIT_UP_ORDER_AN2CN = [
|
85 |
+
"",
|
86 |
+
"拾",
|
87 |
+
"佰",
|
88 |
+
"仟",
|
89 |
+
"万",
|
90 |
+
"拾",
|
91 |
+
"佰",
|
92 |
+
"仟",
|
93 |
+
"亿",
|
94 |
+
"拾",
|
95 |
+
"佰",
|
96 |
+
"仟",
|
97 |
+
"万",
|
98 |
+
"拾",
|
99 |
+
"佰",
|
100 |
+
"仟",
|
101 |
+
]
|
102 |
+
STRICT_CN_NUMBER = {
|
103 |
+
"零": "零",
|
104 |
+
"一": "一壹",
|
105 |
+
"二": "二贰",
|
106 |
+
"三": "三叁",
|
107 |
+
"四": "四肆",
|
108 |
+
"五": "五伍",
|
109 |
+
"六": "六陆",
|
110 |
+
"七": "七柒",
|
111 |
+
"八": "八捌",
|
112 |
+
"九": "九玖",
|
113 |
+
"十": "十拾",
|
114 |
+
"百": "百佰",
|
115 |
+
"千": "千仟",
|
116 |
+
"万": "万",
|
117 |
+
"亿": "亿",
|
118 |
+
}
|
119 |
+
NORMAL_CN_NUMBER = {
|
120 |
+
"零": "零〇",
|
121 |
+
"一": "一壹幺",
|
122 |
+
"二": "二贰两",
|
123 |
+
"三": "三叁仨",
|
124 |
+
"四": "四肆",
|
125 |
+
"五": "五伍",
|
126 |
+
"六": "六陆",
|
127 |
+
"七": "七柒",
|
128 |
+
"八": "八捌",
|
129 |
+
"九": "九玖",
|
130 |
+
"十": "十拾",
|
131 |
+
"百": "百佰",
|
132 |
+
"千": "千仟",
|
133 |
+
"万": "万",
|
134 |
+
"亿": "亿",
|
135 |
+
}
|
text/cn2an/transform.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from warnings import warn
|
3 |
+
|
4 |
+
from .cn2an import Cn2An
|
5 |
+
from .an2cn import An2Cn
|
6 |
+
from .conf import UNIT_CN2AN
|
7 |
+
|
8 |
+
|
9 |
+
class Transform(object):
|
10 |
+
def __init__(self) -> None:
|
11 |
+
self.all_num = "零一二三四五六七八九"
|
12 |
+
self.all_unit = "".join(list(UNIT_CN2AN.keys()))
|
13 |
+
self.cn2an = Cn2An().cn2an
|
14 |
+
self.an2cn = An2Cn().an2cn
|
15 |
+
self.cn_pattern = f"负?([{self.all_num}{self.all_unit}]+点)?[{self.all_num}{self.all_unit}]+"
|
16 |
+
self.smart_cn_pattern = f"-?([0-9]+.)?[0-9]+[{self.all_unit}]+"
|
17 |
+
|
18 |
+
def transform(self, inputs: str, method: str = "cn2an") -> str:
|
19 |
+
if method == "cn2an":
|
20 |
+
inputs = inputs.replace("廿", "二十").replace("半", "0.5").replace("两", "2")
|
21 |
+
# date
|
22 |
+
inputs = re.sub(
|
23 |
+
fr"((({self.smart_cn_pattern})|({self.cn_pattern}))年)?([{self.all_num}十]+月)?([{self.all_num}十]+日)?",
|
24 |
+
lambda x: self.__sub_util(x.group(), "cn2an", "date"), inputs)
|
25 |
+
# fraction
|
26 |
+
inputs = re.sub(fr"{self.cn_pattern}分之{self.cn_pattern}",
|
27 |
+
lambda x: self.__sub_util(x.group(), "cn2an", "fraction"), inputs)
|
28 |
+
# percent
|
29 |
+
inputs = re.sub(fr"百分之{self.cn_pattern}",
|
30 |
+
lambda x: self.__sub_util(x.group(), "cn2an", "percent"), inputs)
|
31 |
+
# celsius
|
32 |
+
inputs = re.sub(fr"{self.cn_pattern}摄氏度",
|
33 |
+
lambda x: self.__sub_util(x.group(), "cn2an", "celsius"), inputs)
|
34 |
+
# number
|
35 |
+
output = re.sub(self.cn_pattern,
|
36 |
+
lambda x: self.__sub_util(x.group(), "cn2an", "number"), inputs)
|
37 |
+
|
38 |
+
elif method == "an2cn":
|
39 |
+
# date
|
40 |
+
inputs = re.sub(r"(\d{2,4}年)?(\d{1,2}月)?(\d{1,2}日)?",
|
41 |
+
lambda x: self.__sub_util(x.group(), "an2cn", "date"), inputs)
|
42 |
+
# fraction
|
43 |
+
inputs = re.sub(r"\d+/\d+",
|
44 |
+
lambda x: self.__sub_util(x.group(), "an2cn", "fraction"), inputs)
|
45 |
+
# percent
|
46 |
+
inputs = re.sub(r"-?(\d+\.)?\d+%",
|
47 |
+
lambda x: self.__sub_util(x.group(), "an2cn", "percent"), inputs)
|
48 |
+
# celsius
|
49 |
+
inputs = re.sub(r"\d+℃",
|
50 |
+
lambda x: self.__sub_util(x.group(), "an2cn", "celsius"), inputs)
|
51 |
+
# number
|
52 |
+
output = re.sub(r"-?(\d+\.)?\d+",
|
53 |
+
lambda x: self.__sub_util(x.group(), "an2cn", "number"), inputs)
|
54 |
+
else:
|
55 |
+
raise ValueError(f"error method: {method}, only support 'cn2an' and 'an2cn'!")
|
56 |
+
|
57 |
+
return output
|
58 |
+
|
59 |
+
def __sub_util(self, inputs, method: str = "cn2an", sub_mode: str = "number") -> str:
|
60 |
+
try:
|
61 |
+
if inputs:
|
62 |
+
if method == "cn2an":
|
63 |
+
if sub_mode == "date":
|
64 |
+
return re.sub(fr"(({self.smart_cn_pattern})|({self.cn_pattern}))",
|
65 |
+
lambda x: str(self.cn2an(x.group(), "smart")), inputs)
|
66 |
+
elif sub_mode == "fraction":
|
67 |
+
if inputs[0] != "百":
|
68 |
+
frac_result = re.sub(self.cn_pattern,
|
69 |
+
lambda x: str(self.cn2an(x.group(), "smart")), inputs)
|
70 |
+
numerator, denominator = frac_result.split("分之")
|
71 |
+
return f"{denominator}/{numerator}"
|
72 |
+
else:
|
73 |
+
return inputs
|
74 |
+
elif sub_mode == "percent":
|
75 |
+
return re.sub(f"(?<=百分之){self.cn_pattern}",
|
76 |
+
lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("百分之", "") + "%"
|
77 |
+
elif sub_mode == "celsius":
|
78 |
+
return re.sub(f"{self.cn_pattern}(?=摄氏度)",
|
79 |
+
lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("摄氏度", "℃")
|
80 |
+
elif sub_mode == "number":
|
81 |
+
return str(self.cn2an(inputs, "smart"))
|
82 |
+
else:
|
83 |
+
raise Exception(f"error sub_mode: {sub_mode} !")
|
84 |
+
else:
|
85 |
+
if sub_mode == "date":
|
86 |
+
inputs = re.sub(r"\d+(?=年)",
|
87 |
+
lambda x: self.an2cn(x.group(), "direct"), inputs)
|
88 |
+
return re.sub(r"\d+",
|
89 |
+
lambda x: self.an2cn(x.group(), "low"), inputs)
|
90 |
+
elif sub_mode == "fraction":
|
91 |
+
frac_result = re.sub(r"\d+", lambda x: self.an2cn(x.group(), "low"), inputs)
|
92 |
+
numerator, denominator = frac_result.split("/")
|
93 |
+
return f"{denominator}分之{numerator}"
|
94 |
+
elif sub_mode == "celsius":
|
95 |
+
return self.an2cn(inputs[:-1], "low") + "摄氏度"
|
96 |
+
elif sub_mode == "percent":
|
97 |
+
return "百分之" + self.an2cn(inputs[:-1], "low")
|
98 |
+
elif sub_mode == "number":
|
99 |
+
return self.an2cn(inputs, "low")
|
100 |
+
else:
|
101 |
+
raise Exception(f"error sub_mode: {sub_mode} !")
|
102 |
+
except Exception as e:
|
103 |
+
warn(str(e))
|
104 |
+
return inputs
|
text/cnm3/ds_CNM3.txt
ADDED
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
a,a
|
2 |
+
ai,ai
|
3 |
+
ai0,a0 I0
|
4 |
+
an,an
|
5 |
+
an0,a0 N0
|
6 |
+
ang,ang
|
7 |
+
ang0,A0 ng0
|
8 |
+
ao,ao
|
9 |
+
ao0,A0 O0
|
10 |
+
ba,b a
|
11 |
+
bai,b a0 I0
|
12 |
+
ban,b a0 N0
|
13 |
+
bang,b A0 ng0
|
14 |
+
bao,b A0 O0
|
15 |
+
be,b e
|
16 |
+
bei,b E0 I0
|
17 |
+
ben,b e0 N0
|
18 |
+
beng,b e0 ng0
|
19 |
+
ber,b er
|
20 |
+
bi,b i
|
21 |
+
bia,b ia
|
22 |
+
bian,b iE0 N0
|
23 |
+
biang,b iA0 ng0
|
24 |
+
biao,b iA0 O0
|
25 |
+
bie,b ie
|
26 |
+
bin,b i N0
|
27 |
+
bing,b i ng0
|
28 |
+
biong,b iO0 ng0
|
29 |
+
biu,b io0 U0
|
30 |
+
bo,b o
|
31 |
+
bong,b oo0 ng0
|
32 |
+
bou,b o0 U0
|
33 |
+
bu,b u
|
34 |
+
bua,b ua
|
35 |
+
buai,b ua0 I0
|
36 |
+
buan,b ua0 N0
|
37 |
+
buang,b uA0 ng0
|
38 |
+
bui,b uE0 I0
|
39 |
+
bun,b ue0 N0
|
40 |
+
bv,b v
|
41 |
+
bve,b ve
|
42 |
+
ca,c a
|
43 |
+
cai,c a0 I0
|
44 |
+
can,c a0 N0
|
45 |
+
cang,c A0 ng0
|
46 |
+
cao,c A0 O0
|
47 |
+
ce,c e
|
48 |
+
cei,c E0 I0
|
49 |
+
cen,c e0 N0
|
50 |
+
ceng,c e0 ng0
|
51 |
+
cer,c er
|
52 |
+
cha,ch a
|
53 |
+
chai,ch a0 I0
|
54 |
+
chan,ch a0 N0
|
55 |
+
chang,ch A0 ng0
|
56 |
+
chao,ch A0 O0
|
57 |
+
che,ch e
|
58 |
+
chei,ch E0 I0
|
59 |
+
chen,ch e0 N0
|
60 |
+
cheng,ch e0 ng0
|
61 |
+
cher,ch er
|
62 |
+
chi,ch ir
|
63 |
+
chong,ch oo0 ng0
|
64 |
+
chou,ch o0 U0
|
65 |
+
chu,ch u
|
66 |
+
chua,ch ua
|
67 |
+
chuai,ch ua0 I0
|
68 |
+
chuan,ch ua0 N0
|
69 |
+
chuang,ch uA0 ng0
|
70 |
+
chui,ch uE0 I0
|
71 |
+
chun,ch ue0 N0
|
72 |
+
chuo,ch uo
|
73 |
+
chv,ch v
|
74 |
+
chyi,ch i
|
75 |
+
ci,c i0
|
76 |
+
cong,c oo0 ng0
|
77 |
+
cou,c o0 U0
|
78 |
+
cu,c u
|
79 |
+
cua,c ua
|
80 |
+
cuai,c ua0 I0
|
81 |
+
cuan,c ua0 N0
|
82 |
+
cuang,c uA0 ng0
|
83 |
+
cui,c uE0 I0
|
84 |
+
cun,c ue0 N0
|
85 |
+
cuo,c uo
|
86 |
+
cv,c v
|
87 |
+
cyi,c i
|
88 |
+
da,d a
|
89 |
+
dai,d a0 I0
|
90 |
+
dan,d a0 N0
|
91 |
+
dang,d A0 ng0
|
92 |
+
dao,d A0 O0
|
93 |
+
de,d e
|
94 |
+
dei,d E0 I0
|
95 |
+
den,d e0 N0
|
96 |
+
deng,d e0 ng0
|
97 |
+
der,d er
|
98 |
+
di,d i
|
99 |
+
dia,d ia
|
100 |
+
dian,d iE0 N0
|
101 |
+
diang,d iA0 ng0
|
102 |
+
diao,d iA0 O0
|
103 |
+
die,d ie
|
104 |
+
din,d i N0
|
105 |
+
ding,d i ng0
|
106 |
+
diong,d iO0 ng0
|
107 |
+
diu,d io0 U0
|
108 |
+
dong,d oo0 ng0
|
109 |
+
dou,d o0 U0
|
110 |
+
du,d u
|
111 |
+
dua,d ua
|
112 |
+
duai,d ua0 I0
|
113 |
+
duan,d ua0 N0
|
114 |
+
duang,d uA0 ng0
|
115 |
+
dui,d uE0 I0
|
116 |
+
dun,d ue0 N0
|
117 |
+
duo,d uo
|
118 |
+
dv,d v
|
119 |
+
dve,d ve
|
120 |
+
e,e
|
121 |
+
ei,E0 I0
|
122 |
+
en,e0 N0
|
123 |
+
eng,e0 ng0
|
124 |
+
er,er
|
125 |
+
fa,f a
|
126 |
+
fai,f a0 I0
|
127 |
+
fan,f a0 N0
|
128 |
+
fang,f A0 ng0
|
129 |
+
fao,f A0 O0
|
130 |
+
fe,f e
|
131 |
+
fei,f E0 I0
|
132 |
+
fen,f e0 N0
|
133 |
+
feng,f e0 ng0
|
134 |
+
fer,f er
|
135 |
+
fi,f i
|
136 |
+
fia,f ia
|
137 |
+
fian,f iE0 N0
|
138 |
+
fiang,f iA0 ng0
|
139 |
+
fiao,f iA0 O0
|
140 |
+
fie,f ie
|
141 |
+
fin,f i N0
|
142 |
+
fing,f i ng0
|
143 |
+
fiong,f iO0 ng0
|
144 |
+
fiu,f io0 U0
|
145 |
+
fo,f o
|
146 |
+
fong,f oo0 ng0
|
147 |
+
fou,f o0 U0
|
148 |
+
fu,f u
|
149 |
+
fua,f ua
|
150 |
+
fuai,f ua0 I0
|
151 |
+
fuan,f ua0 N0
|
152 |
+
fuang,f uA0 ng0
|
153 |
+
fui,f uE0 I0
|
154 |
+
fun,f ue0 N0
|
155 |
+
fv,f v
|
156 |
+
fve,f ve
|
157 |
+
ga,g a
|
158 |
+
gai,g a0 I0
|
159 |
+
gan,g a0 N0
|
160 |
+
gang,g A0 ng0
|
161 |
+
gao,g A0 O0
|
162 |
+
ge,g e
|
163 |
+
gei,g E0 I0
|
164 |
+
gen,g e0 N0
|
165 |
+
geng,g e0 ng0
|
166 |
+
ger,g er
|
167 |
+
gi,g i
|
168 |
+
gia,g ia
|
169 |
+
gian,g iE0 N0
|
170 |
+
giang,g iA0 ng0
|
171 |
+
giao,g iA0 O0
|
172 |
+
gie,g ie
|
173 |
+
gin,g i N0
|
174 |
+
ging,g i ng0
|
175 |
+
giong,g iO0 ng0
|
176 |
+
giu,g io0 U0
|
177 |
+
gong,g oo0 ng0
|
178 |
+
gou,g o0 U0
|
179 |
+
gu,g u
|
180 |
+
gua,g ua
|
181 |
+
guai,g ua0 I0
|
182 |
+
guan,g ua0 N0
|
183 |
+
guang,g uA0 ng0
|
184 |
+
gui,g uE0 I0
|
185 |
+
gun,g ue0 N0
|
186 |
+
guo,g uo
|
187 |
+
gv,g v
|
188 |
+
gve,g ve
|
189 |
+
ha,h a
|
190 |
+
hai,h a0 I0
|
191 |
+
han,h a0 N0
|
192 |
+
hang,h A0 ng0
|
193 |
+
hao,h A0 O0
|
194 |
+
he,h e
|
195 |
+
hei,h E0 I0
|
196 |
+
hen,h e0 N0
|
197 |
+
heng,h e0 ng0
|
198 |
+
her,h er
|
199 |
+
hi,h i
|
200 |
+
hia,h ia
|
201 |
+
hian,h iE0 N0
|
202 |
+
hiang,h iA0 ng0
|
203 |
+
hiao,h iA0 O0
|
204 |
+
hie,h ie
|
205 |
+
hin,h i N0
|
206 |
+
hing,h i ng0
|
207 |
+
hiong,h iO0 ng0
|
208 |
+
hiu,h io0 U0
|
209 |
+
hong,h oo0 ng0
|
210 |
+
hou,h o0 U0
|
211 |
+
hu,h u
|
212 |
+
hua,h ua
|
213 |
+
huai,h ua0 I0
|
214 |
+
huan,h ua0 N0
|
215 |
+
huang,h uA0 ng0
|
216 |
+
hui,h uE0 I0
|
217 |
+
hun,h ue0 N0
|
218 |
+
huo,h uo
|
219 |
+
hv,h v
|
220 |
+
hve,h ve
|
221 |
+
ji,j i
|
222 |
+
jia,j ia
|
223 |
+
jian,j iE0 N0
|
224 |
+
jiang,j iA0 ng0
|
225 |
+
jiao,j iA0 O0
|
226 |
+
jie,j ie
|
227 |
+
jin,j i N0
|
228 |
+
jing,j i ng0
|
229 |
+
jiong,j iO0 ng0
|
230 |
+
jiu,j io0 U0
|
231 |
+
ju,j v
|
232 |
+
juan,j vE0 N0
|
233 |
+
jue,j ve
|
234 |
+
jun,j v0 N0
|
235 |
+
ka,k a
|
236 |
+
kai,k a0 I0
|
237 |
+
kan,k a0 N0
|
238 |
+
kang,k A0 ng0
|
239 |
+
kao,k A0 O0
|
240 |
+
ke,k e
|
241 |
+
kei,k E0 I0
|
242 |
+
ken,k e0 N0
|
243 |
+
keng,k e0 ng0
|
244 |
+
ker,k er
|
245 |
+
ki,k i
|
246 |
+
kia,k ia
|
247 |
+
kian,k iE0 N0
|
248 |
+
kiang,k iA0 ng0
|
249 |
+
kiao,k iA0 O0
|
250 |
+
kie,k ie
|
251 |
+
kin,k i N0
|
252 |
+
king,k i ng0
|
253 |
+
kiong,k iO0 ng0
|
254 |
+
kiu,k io0 U0
|
255 |
+
kong,k oo0 ng0
|
256 |
+
kou,k o0 U0
|
257 |
+
ku,k u
|
258 |
+
kua,k ua
|
259 |
+
kuai,k ua0 I0
|
260 |
+
kuan,k ua0 N0
|
261 |
+
kuang,k uA0 ng0
|
262 |
+
kui,k uE0 I0
|
263 |
+
kun,k ue0 N0
|
264 |
+
kuo,k uo
|
265 |
+
kv,k v
|
266 |
+
kve,k ve
|
267 |
+
la,l a
|
268 |
+
lai,l a0 I0
|
269 |
+
lan,l a0 N0
|
270 |
+
lang,l A0 ng0
|
271 |
+
lao,l A0 O0
|
272 |
+
le,l e
|
273 |
+
lei,l E0 I0
|
274 |
+
len,l e0 N0
|
275 |
+
leng,l e0 ng0
|
276 |
+
ler,l er
|
277 |
+
li,l i
|
278 |
+
lia,l ia
|
279 |
+
lian,l iE0 N0
|
280 |
+
liang,l iA0 ng0
|
281 |
+
liao,l iA0 O0
|
282 |
+
lie,l ie
|
283 |
+
lin,l i N0
|
284 |
+
ling,l i ng0
|
285 |
+
liong,l iO0 ng0
|
286 |
+
liu,l io0 U0
|
287 |
+
lo,l o
|
288 |
+
long,l oo0 ng0
|
289 |
+
lou,l o0 U0
|
290 |
+
lu,l u
|
291 |
+
lua,l ua
|
292 |
+
luai,l ua0 I0
|
293 |
+
luan,l ua0 N0
|
294 |
+
luang,l uA0 ng0
|
295 |
+
lui,l uE0 I0
|
296 |
+
lun,l ue0 N0
|
297 |
+
luo,l uo
|
298 |
+
lv,l v
|
299 |
+
lve,l ve
|
300 |
+
ma,m a
|
301 |
+
mai,m a0 I0
|
302 |
+
man,m a0 N0
|
303 |
+
mang,m A0 ng0
|
304 |
+
mao,m A0 O0
|
305 |
+
me,m e
|
306 |
+
mei,m E0 I0
|
307 |
+
men,m e0 N0
|
308 |
+
meng,m e0 ng0
|
309 |
+
mer,m er
|
310 |
+
mi,m i
|
311 |
+
mia,m ia
|
312 |
+
mian,m iE0 N0
|
313 |
+
miang,m iA0 ng0
|
314 |
+
miao,m iA0 O0
|
315 |
+
mie,m ie
|
316 |
+
min,m i N0
|
317 |
+
ming,m i ng0
|
318 |
+
miong,m iO0 ng0
|
319 |
+
miu,m io0 U0
|
320 |
+
mo,m o
|
321 |
+
mong,m oo0 ng0
|
322 |
+
mou,m o0 U0
|
323 |
+
mu,m u
|
324 |
+
mua,m ua
|
325 |
+
muai,m ua0 I0
|
326 |
+
muan,m ua0 N0
|
327 |
+
muang,m uA0 ng0
|
328 |
+
mui,m uE0 I0
|
329 |
+
mun,m ue0 N0
|
330 |
+
mv,m v
|
331 |
+
mve,m ve
|
332 |
+
n,ng
|
333 |
+
na,n a
|
334 |
+
nai,n a0 I0
|
335 |
+
nan,n a0 N0
|
336 |
+
nang,n A0 ng0
|
337 |
+
nao,n A0 O0
|
338 |
+
ne,n e
|
339 |
+
nei,n E0 I0
|
340 |
+
nen,n e0 N0
|
341 |
+
neng,n e0 ng0
|
342 |
+
ner,n er
|
343 |
+
ni,n i
|
344 |
+
nia,n ia
|
345 |
+
nian,n iE0 N0
|
346 |
+
niang,n iA0 ng0
|
347 |
+
niao,n iA0 O0
|
348 |
+
nie,n ie
|
349 |
+
nin,n i N0
|
350 |
+
ning,n i ng0
|
351 |
+
niong,n iO0 ng0
|
352 |
+
niu,n io0 U0
|
353 |
+
nong,n oo0 ng0
|
354 |
+
nou,n o0 U0
|
355 |
+
nu,n u
|
356 |
+
nua,n ua
|
357 |
+
nuai,n ua0 I0
|
358 |
+
nuan,n ua0 N0
|
359 |
+
nuang,n uA0 ng0
|
360 |
+
nui,n uE0 I0
|
361 |
+
nun,n ue0 N0
|
362 |
+
nuo,n uo
|
363 |
+
nv,n v
|
364 |
+
nve,n ve
|
365 |
+
o,o
|
366 |
+
ong,ong
|
367 |
+
ou,ou
|
368 |
+
pa,p a
|
369 |
+
pai,p a0 I0
|
370 |
+
pan,p a0 N0
|
371 |
+
pang,p A0 ng0
|
372 |
+
pao,p A0 O0
|
373 |
+
pe,p e
|
374 |
+
pei,p E0 I0
|
375 |
+
pen,p e0 N0
|
376 |
+
peng,p e0 ng0
|
377 |
+
per,p er
|
378 |
+
pi,p i
|
379 |
+
pia,p ia
|
380 |
+
pian,p iE0 N0
|
381 |
+
piang,p iA0 ng0
|
382 |
+
piao,p iA0 O0
|
383 |
+
pie,p ie
|
384 |
+
pin,p i N0
|
385 |
+
ping,p i ng0
|
386 |
+
piong,p iO0 ng0
|
387 |
+
piu,p io0 U0
|
388 |
+
po,p o
|
389 |
+
pong,p oo0 ng0
|
390 |
+
pou,p o0 U0
|
391 |
+
pu,p u
|
392 |
+
pua,p ua
|
393 |
+
puai,p ua0 I0
|
394 |
+
puan,p ua0 N0
|
395 |
+
puang,p uA0 ng0
|
396 |
+
pui,p uE0 I0
|
397 |
+
pun,p ue0 N0
|
398 |
+
pv,p v
|
399 |
+
pve,p ve
|
400 |
+
qi,q i
|
401 |
+
qia,q ia
|
402 |
+
qian,q iE0 N0
|
403 |
+
qiang,q iA0 ng0
|
404 |
+
qiao,q iA0 O0
|
405 |
+
qie,q ie
|
406 |
+
qin,q i N0
|
407 |
+
qing,q i ng0
|
408 |
+
qiong,q iO0 ng0
|
409 |
+
qiu,q io0 U0
|
410 |
+
qu,q v
|
411 |
+
quan,q vE0 N0
|
412 |
+
que,q ve
|
413 |
+
qun,q v0 N0
|
414 |
+
ra,r a
|
415 |
+
rai,r a0 I0
|
416 |
+
ran,r a0 N0
|
417 |
+
rang,r A0 ng0
|
418 |
+
rao,r A0 O0
|
419 |
+
re,r e
|
420 |
+
rei,r E0 I0
|
421 |
+
ren,r e0 N0
|
422 |
+
reng,r e0 ng0
|
423 |
+
rer,r er
|
424 |
+
ri,r ir
|
425 |
+
rong,r oo0 ng0
|
426 |
+
rou,r o0 U0
|
427 |
+
ru,r u
|
428 |
+
rua,r ua
|
429 |
+
ruai,r ua0 I0
|
430 |
+
ruan,r ua0 N0
|
431 |
+
ruang,r uA0 ng0
|
432 |
+
rui,r uE0 I0
|
433 |
+
run,r ue0 N0
|
434 |
+
ruo,r uo
|
435 |
+
rv,r v
|
436 |
+
ryi,r i
|
437 |
+
sa,s a
|
438 |
+
sai,s a0 I0
|
439 |
+
san,s a0 N0
|
440 |
+
sang,s A0 ng0
|
441 |
+
sao,s A0 O0
|
442 |
+
se,s e
|
443 |
+
sei,s E0 I0
|
444 |
+
sen,s e0 N0
|
445 |
+
seng,s e0 ng0
|
446 |
+
ser,s er
|
447 |
+
sha,sh a
|
448 |
+
shai,sh a0 I0
|
449 |
+
shan,sh a0 N0
|
450 |
+
shang,sh A0 ng0
|
451 |
+
shao,sh A0 O0
|
452 |
+
she,sh e
|
453 |
+
shei,sh E0 I0
|
454 |
+
shen,sh e0 N0
|
455 |
+
sheng,sh e0 ng0
|
456 |
+
sher,sh er
|
457 |
+
shi,sh ir
|
458 |
+
shong,sh oo0 ng0
|
459 |
+
shou,sh o0 U0
|
460 |
+
shu,sh u
|
461 |
+
shua,sh ua
|
462 |
+
shuai,sh ua0 I0
|
463 |
+
shuan,sh ua0 N0
|
464 |
+
shuang,sh uA0 ng0
|
465 |
+
shui,sh uE0 I0
|
466 |
+
shun,sh ue0 N0
|
467 |
+
shuo,sh uo
|
468 |
+
shv,sh v
|
469 |
+
shyi,sh i
|
470 |
+
si,s i0
|
471 |
+
song,s oo0 ng0
|
472 |
+
sou,s o0 U0
|
473 |
+
su,s u
|
474 |
+
sua,s ua
|
475 |
+
suai,s ua0 I0
|
476 |
+
suan,s ua0 N0
|
477 |
+
suang,s uA0 ng0
|
478 |
+
sui,s uE0 I0
|
479 |
+
sun,s ue0 N0
|
480 |
+
suo,s uo
|
481 |
+
sv,s v
|
482 |
+
syi,s i
|
483 |
+
ta,t a
|
484 |
+
tai,t a0 I0
|
485 |
+
tan,t a0 N0
|
486 |
+
tang,t A0 ng0
|
487 |
+
tao,t A0 O0
|
488 |
+
te,t e
|
489 |
+
tei,t E0 I0
|
490 |
+
ten,t e0 N0
|
491 |
+
teng,t e0 ng0
|
492 |
+
ter,t er
|
493 |
+
ti,t i
|
494 |
+
tia,t ia
|
495 |
+
tian,t iE0 N0
|
496 |
+
tiang,t iA0 ng0
|
497 |
+
tiao,t iA0 O0
|
498 |
+
tie,t ie
|
499 |
+
tin,t i N0
|
500 |
+
ting,t i ng0
|
501 |
+
tiong,t iO0 ng0
|
502 |
+
tong,t oo0 ng0
|
503 |
+
tou,t o0 U0
|
504 |
+
tu,t u
|
505 |
+
tua,t ua
|
506 |
+
tuai,t ua0 I0
|
507 |
+
tuan,t ua0 N0
|
508 |
+
tuang,t uA0 ng0
|
509 |
+
tui,t uE0 I0
|
510 |
+
tun,t ue0 N0
|
511 |
+
tuo,t uo
|
512 |
+
tv,t v
|
513 |
+
tve,t ve
|
514 |
+
wa,w a
|
515 |
+
wai,w a0 I0
|
516 |
+
wan,w a0 N0
|
517 |
+
wang,w A0 ng0
|
518 |
+
wao,w A0 O0
|
519 |
+
we,w e
|
520 |
+
wei,w E0 I0
|
521 |
+
wen,w e0 N0
|
522 |
+
weng,w e0 ng0
|
523 |
+
wer,w er
|
524 |
+
wi,w i
|
525 |
+
wo,w o
|
526 |
+
wong,w oo0 ng0
|
527 |
+
wou,w o0 U0
|
528 |
+
wu,w u
|
529 |
+
xi,x i
|
530 |
+
xia,x ia
|
531 |
+
xian,x iE0 N0
|
532 |
+
xiang,x iA0 ng0
|
533 |
+
xiao,x iA0 O0
|
534 |
+
xie,x ie
|
535 |
+
xin,x i N0
|
536 |
+
xing,x i ng0
|
537 |
+
xiong,x iO0 ng0
|
538 |
+
xiu,x io0 U0
|
539 |
+
xu,x v
|
540 |
+
xuan,x vE0 N0
|
541 |
+
xue,x ve
|
542 |
+
xun,x v0 N0
|
543 |
+
ya,y a
|
544 |
+
yai,y a0 I0
|
545 |
+
yan,y iE0 N0
|
546 |
+
yang,y A0 ng0
|
547 |
+
yao,y A0 O0
|
548 |
+
ye,y E
|
549 |
+
yei,y E0 I0
|
550 |
+
yi,y i
|
551 |
+
yin,y i N0
|
552 |
+
ying,y i ng0
|
553 |
+
yo,y o
|
554 |
+
yong,y oo0 ng0
|
555 |
+
you,y o0 U0
|
556 |
+
yu,y v
|
557 |
+
yuan,y vE0 N0
|
558 |
+
yue,y ve
|
559 |
+
yun,y v0 N0
|
560 |
+
ywu,y u
|
561 |
+
za,z a
|
562 |
+
zai,z a0 I0
|
563 |
+
zan,z a0 N0
|
564 |
+
zang,z A0 ng0
|
565 |
+
zao,z A0 O0
|
566 |
+
ze,z e
|
567 |
+
zei,z E0 I0
|
568 |
+
zen,z e0 N0
|
569 |
+
zeng,z e0 ng0
|
570 |
+
zer,z er
|
571 |
+
zha,zh a
|
572 |
+
zhai,zh a0 I0
|
573 |
+
zhan,zh a0 N0
|
574 |
+
zhang,zh A0 ng0
|
575 |
+
zhao,zh A0 O0
|
576 |
+
zhe,zh e
|
577 |
+
zhei,zh E0 I0
|
578 |
+
zhen,zh e0 N0
|
579 |
+
zheng,zh e0 ng0
|
580 |
+
zher,zh er
|
581 |
+
zhi,zh ir
|
582 |
+
zhong,zh oo0 ng0
|
583 |
+
zhou,zh o0 U0
|
584 |
+
zhu,zh u
|
585 |
+
zhua,zh ua
|
586 |
+
zhuai,zh ua0 I0
|
587 |
+
zhuan,zh ua0 N0
|
588 |
+
zhuang,zh uA0 ng0
|
589 |
+
zhui,zh uE0 I0
|
590 |
+
zhun,zh ue0 N0
|
591 |
+
zhuo,zh uo
|
592 |
+
zhv,zh v
|
593 |
+
zhyi,zh i
|
594 |
+
zi,z i0
|
595 |
+
zong,z oo0 ng0
|
596 |
+
zou,z o0 U0
|
597 |
+
zu,z u
|
598 |
+
zua,z ua
|
599 |
+
zuai,z ua0 I0
|
600 |
+
zuan,z ua0 N0
|
601 |
+
zuang,z uA0 ng0
|
602 |
+
zui,z uE0 I0
|
603 |
+
zun,z ue0 N0
|
604 |
+
zuo,z uo
|
605 |
+
zv,z v
|
606 |
+
zyi,z i
|
text/custom_pypinyin_dict/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
text/custom_pypinyin_dict/cc_cedict_0.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text/custom_pypinyin_dict/cc_cedict_1.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text/custom_pypinyin_dict/cc_cedict_2.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
text/custom_pypinyin_dict/cc_cedict_3.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from __future__ import unicode_literals
|
3 |
+
|
4 |
+
# Warning: Auto-generated file, don't edit.
|
5 |
+
phrases_dict = {
|
6 |
+
'𰻝𰻝面': [['biáng'], ['biáng'], ['miàn']],
|
7 |
+
}
|
8 |
+
|
9 |
+
|
10 |
+
from pypinyin import load_phrases_dict
|
11 |
+
|
12 |
+
|
13 |
+
def load():
|
14 |
+
load_phrases_dict(phrases_dict)
|
text/custom_pypinyin_dict/genshin.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from __future__ import unicode_literals
|
3 |
+
|
4 |
+
phrases_dict = {
|
5 |
+
'㐖毒': [['xié'], ['dú']],
|
6 |
+
'若陀': [['rě'], ['tuó']],
|
7 |
+
'平藏': [['píng'], ['zàng']],
|
8 |
+
'派蒙': [['pài'], ['méng']],
|
9 |
+
'安柏': [['ān'], ['bó']],
|
10 |
+
'一斗': [['yī'], ['dǒu']]
|
11 |
+
}
|
text/custom_pypinyin_dict/phrase_pinyin_data.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from __future__ import unicode_literals
|
3 |
+
|
4 |
+
from pypinyin import load_phrases_dict
|
5 |
+
|
6 |
+
from text.custom_pypinyin_dict import cc_cedict_0
|
7 |
+
from text.custom_pypinyin_dict import cc_cedict_1
|
8 |
+
from text.custom_pypinyin_dict import cc_cedict_2
|
9 |
+
from text.custom_pypinyin_dict import cc_cedict_3
|
10 |
+
from text.custom_pypinyin_dict import genshin
|
11 |
+
|
12 |
+
phrases_dict = {}
|
13 |
+
phrases_dict.update(cc_cedict_0.phrases_dict)
|
14 |
+
phrases_dict.update(cc_cedict_1.phrases_dict)
|
15 |
+
phrases_dict.update(cc_cedict_2.phrases_dict)
|
16 |
+
phrases_dict.update(cc_cedict_3.phrases_dict)
|
17 |
+
phrases_dict.update(genshin.phrases_dict)
|
18 |
+
|
19 |
+
def load():
|
20 |
+
load_phrases_dict(phrases_dict)
|
21 |
+
print("加载自定义词典成功")
|
22 |
+
|
23 |
+
if __name__ == '__main__':
|
24 |
+
print(phrases_dict)
|
text/english.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
'''
|
4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
1. "english_cleaners" for English text
|
9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
the symbols in symbols.py to match your data).
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
# Regular expression matching whitespace:
|
17 |
+
|
18 |
+
|
19 |
+
import re
|
20 |
+
import inflect
|
21 |
+
from unidecode import unidecode
|
22 |
+
import eng_to_ipa as ipa
|
23 |
+
_inflect = inflect.engine()
|
24 |
+
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
25 |
+
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
26 |
+
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
27 |
+
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
28 |
+
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
29 |
+
_number_re = re.compile(r'[0-9]+')
|
30 |
+
|
31 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
32 |
+
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
33 |
+
('mrs', 'misess'),
|
34 |
+
('mr', 'mister'),
|
35 |
+
('dr', 'doctor'),
|
36 |
+
('st', 'saint'),
|
37 |
+
('co', 'company'),
|
38 |
+
('jr', 'junior'),
|
39 |
+
('maj', 'major'),
|
40 |
+
('gen', 'general'),
|
41 |
+
('drs', 'doctors'),
|
42 |
+
('rev', 'reverend'),
|
43 |
+
('lt', 'lieutenant'),
|
44 |
+
('hon', 'honorable'),
|
45 |
+
('sgt', 'sergeant'),
|
46 |
+
('capt', 'captain'),
|
47 |
+
('esq', 'esquire'),
|
48 |
+
('ltd', 'limited'),
|
49 |
+
('col', 'colonel'),
|
50 |
+
('ft', 'fort'),
|
51 |
+
]]
|
52 |
+
|
53 |
+
|
54 |
+
# List of (ipa, lazy ipa) pairs:
|
55 |
+
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
56 |
+
('r', 'ɹ'),
|
57 |
+
('æ', 'e'),
|
58 |
+
('ɑ', 'a'),
|
59 |
+
('ɔ', 'o'),
|
60 |
+
('ð', 'z'),
|
61 |
+
('θ', 's'),
|
62 |
+
('ɛ', 'e'),
|
63 |
+
('ɪ', 'i'),
|
64 |
+
('ʊ', 'u'),
|
65 |
+
('ʒ', 'ʥ'),
|
66 |
+
('ʤ', 'ʥ'),
|
67 |
+
('ˈ', '↓'),
|
68 |
+
]]
|
69 |
+
|
70 |
+
# List of (ipa, lazy ipa2) pairs:
|
71 |
+
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
72 |
+
('r', 'ɹ'),
|
73 |
+
('ð', 'z'),
|
74 |
+
('θ', 's'),
|
75 |
+
('ʒ', 'ʑ'),
|
76 |
+
('ʤ', 'dʑ'),
|
77 |
+
('ˈ', '↓'),
|
78 |
+
]]
|
79 |
+
|
80 |
+
# List of (ipa, ipa2) pairs
|
81 |
+
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
82 |
+
('r', 'ɹ'),
|
83 |
+
('ʤ', 'dʒ'),
|
84 |
+
('ʧ', 'tʃ')
|
85 |
+
]]
|
86 |
+
|
87 |
+
|
88 |
+
def expand_abbreviations(text):
|
89 |
+
for regex, replacement in _abbreviations:
|
90 |
+
text = re.sub(regex, replacement, text)
|
91 |
+
return text
|
92 |
+
|
93 |
+
|
94 |
+
def collapse_whitespace(text):
|
95 |
+
return re.sub(r'\s+', ' ', text)
|
96 |
+
|
97 |
+
|
98 |
+
def _remove_commas(m):
|
99 |
+
return m.group(1).replace(',', '')
|
100 |
+
|
101 |
+
|
102 |
+
def _expand_decimal_point(m):
|
103 |
+
return m.group(1).replace('.', ' point ')
|
104 |
+
|
105 |
+
|
106 |
+
def _expand_dollars(m):
|
107 |
+
match = m.group(1)
|
108 |
+
parts = match.split('.')
|
109 |
+
if len(parts) > 2:
|
110 |
+
return match + ' dollars' # Unexpected format
|
111 |
+
dollars = int(parts[0]) if parts[0] else 0
|
112 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
113 |
+
if dollars and cents:
|
114 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
115 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
116 |
+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
117 |
+
elif dollars:
|
118 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
119 |
+
return '%s %s' % (dollars, dollar_unit)
|
120 |
+
elif cents:
|
121 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
122 |
+
return '%s %s' % (cents, cent_unit)
|
123 |
+
else:
|
124 |
+
return 'zero dollars'
|
125 |
+
|
126 |
+
|
127 |
+
def _expand_ordinal(m):
|
128 |
+
return _inflect.number_to_words(m.group(0))
|
129 |
+
|
130 |
+
|
131 |
+
def _expand_number(m):
|
132 |
+
num = int(m.group(0))
|
133 |
+
if num > 1000 and num < 3000:
|
134 |
+
if num == 2000:
|
135 |
+
return 'two thousand'
|
136 |
+
elif num > 2000 and num < 2010:
|
137 |
+
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
138 |
+
elif num % 100 == 0:
|
139 |
+
return _inflect.number_to_words(num // 100) + ' hundred'
|
140 |
+
else:
|
141 |
+
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
142 |
+
else:
|
143 |
+
return _inflect.number_to_words(num, andword='')
|
144 |
+
|
145 |
+
|
146 |
+
def normalize_numbers(text):
|
147 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
148 |
+
text = re.sub(_pounds_re, r'\1 pounds', text)
|
149 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
150 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
151 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
152 |
+
text = re.sub(_number_re, _expand_number, text)
|
153 |
+
return text
|
154 |
+
|
155 |
+
|
156 |
+
def mark_dark_l(text):
|
157 |
+
return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
|
158 |
+
|
159 |
+
|
160 |
+
def english_to_ipa(text):
|
161 |
+
text = unidecode(text).lower()
|
162 |
+
text = expand_abbreviations(text)
|
163 |
+
text = normalize_numbers(text)
|
164 |
+
phonemes = ipa.convert(text)
|
165 |
+
phonemes = collapse_whitespace(phonemes)
|
166 |
+
return phonemes
|
167 |
+
|
168 |
+
|
169 |
+
def english_to_ipa2(text):
|
170 |
+
text = english_to_ipa(text)
|
171 |
+
text = mark_dark_l(text)
|
172 |
+
for regex, replacement in _ipa_to_ipa2:
|
173 |
+
text = re.sub(regex, replacement, text)
|
174 |
+
return list(text.replace('...', '…'))
|
175 |
+
|
text/japanese.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from unidecode import unidecode
|
3 |
+
import pyopenjtalk
|
4 |
+
|
5 |
+
|
6 |
+
# Regular expression matching Japanese without punctuation marks:
|
7 |
+
_japanese_characters = re.compile(
|
8 |
+
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
|
9 |
+
|
10 |
+
# Regular expression matching non-Japanese characters or punctuation marks:
|
11 |
+
_japanese_marks = re.compile(
|
12 |
+
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
|
13 |
+
|
14 |
+
# List of (symbol, Japanese) pairs for marks:
|
15 |
+
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
|
16 |
+
('%', 'パーセント')
|
17 |
+
]]
|
18 |
+
|
19 |
+
# List of (romaji, ipa) pairs for marks:
|
20 |
+
_romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
21 |
+
('ts', 'ʦ'),
|
22 |
+
('u', 'ɯ'),
|
23 |
+
('j', 'ʥ'),
|
24 |
+
('y', 'j'),
|
25 |
+
('ni', 'n^i'),
|
26 |
+
('nj', 'n^'),
|
27 |
+
('hi', 'çi'),
|
28 |
+
('hj', 'ç'),
|
29 |
+
('f', 'ɸ'),
|
30 |
+
('I', 'i*'),
|
31 |
+
('U', 'ɯ*'),
|
32 |
+
('r', 'ɾ')
|
33 |
+
]]
|
34 |
+
|
35 |
+
# List of (romaji, ipa2) pairs for marks:
|
36 |
+
_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
37 |
+
('u', 'ɯ'),
|
38 |
+
('ʧ', 'tʃ'),
|
39 |
+
('j', 'dʑ'),
|
40 |
+
('y', 'j'),
|
41 |
+
('ni', 'n^i'),
|
42 |
+
('nj', 'n^'),
|
43 |
+
('hi', 'çi'),
|
44 |
+
('hj', 'ç'),
|
45 |
+
('f', 'ɸ'),
|
46 |
+
('I', 'i*'),
|
47 |
+
('U', 'ɯ*'),
|
48 |
+
('r', 'ɾ')
|
49 |
+
]]
|
50 |
+
|
51 |
+
# List of (consonant, sokuon) pairs:
|
52 |
+
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
|
53 |
+
(r'Q([↑↓]*[kg])', r'k#\1'),
|
54 |
+
(r'Q([↑↓]*[tdjʧ])', r't#\1'),
|
55 |
+
(r'Q([↑↓]*[sʃ])', r's\1'),
|
56 |
+
(r'Q([↑↓]*[pb])', r'p#\1')
|
57 |
+
]]
|
58 |
+
|
59 |
+
# List of (consonant, hatsuon) pairs:
|
60 |
+
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
|
61 |
+
(r'N([↑↓]*[pbm])', r'm\1'),
|
62 |
+
(r'N([↑↓]*[ʧʥj])', r'n^\1'),
|
63 |
+
(r'N([↑↓]*[tdn])', r'n\1'),
|
64 |
+
(r'N([↑↓]*[kg])', r'ŋ\1')
|
65 |
+
]]
|
66 |
+
|
67 |
+
|
68 |
+
def symbols_to_japanese(text):
|
69 |
+
for regex, replacement in _symbols_to_japanese:
|
70 |
+
text = re.sub(regex, replacement, text)
|
71 |
+
return text
|
72 |
+
|
73 |
+
|
74 |
+
def japanese_to_romaji_with_accent(text):
|
75 |
+
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
|
76 |
+
text = symbols_to_japanese(text)
|
77 |
+
sentences = re.split(_japanese_marks, text)
|
78 |
+
marks = re.findall(_japanese_marks, text)
|
79 |
+
text = ''
|
80 |
+
for i, sentence in enumerate(sentences):
|
81 |
+
if re.match(_japanese_characters, sentence):
|
82 |
+
if text != '':
|
83 |
+
text += ' '
|
84 |
+
labels = pyopenjtalk.extract_fullcontext(sentence)
|
85 |
+
for n, label in enumerate(labels):
|
86 |
+
phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
|
87 |
+
if phoneme not in ['sil', 'pau']:
|
88 |
+
text += phoneme.replace('ch', 'ʧ').replace('sh',
|
89 |
+
'ʃ').replace('cl', 'Q')
|
90 |
+
else:
|
91 |
+
continue
|
92 |
+
# n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
|
93 |
+
a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
|
94 |
+
a2 = int(re.search(r"\+(\d+)\+", label).group(1))
|
95 |
+
a3 = int(re.search(r"\+(\d+)/", label).group(1))
|
96 |
+
if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
|
97 |
+
a2_next = -1
|
98 |
+
else:
|
99 |
+
a2_next = int(
|
100 |
+
re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
|
101 |
+
# Accent phrase boundary
|
102 |
+
if a3 == 1 and a2_next == 1:
|
103 |
+
text += ' '
|
104 |
+
# Falling
|
105 |
+
elif a1 == 0 and a2_next == a2 + 1:
|
106 |
+
text += '↓'
|
107 |
+
# Rising
|
108 |
+
elif a2 == 1 and a2_next == 2:
|
109 |
+
text += '↑'
|
110 |
+
if i < len(marks):
|
111 |
+
text += unidecode(marks[i]).replace(' ', '')
|
112 |
+
return text
|
113 |
+
|
114 |
+
|
115 |
+
def get_real_sokuon(text):
|
116 |
+
for regex, replacement in _real_sokuon:
|
117 |
+
text = re.sub(regex, replacement, text)
|
118 |
+
return text
|
119 |
+
|
120 |
+
|
121 |
+
def get_real_hatsuon(text):
|
122 |
+
for regex, replacement in _real_hatsuon:
|
123 |
+
text = re.sub(regex, replacement, text)
|
124 |
+
return text
|
125 |
+
|
126 |
+
|
127 |
+
def japanese_to_ipa(text):
|
128 |
+
text = japanese_to_romaji_with_accent(text).replace('...', '…')
|
129 |
+
text = re.sub(
|
130 |
+
r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
|
131 |
+
text = get_real_sokuon(text)
|
132 |
+
text = get_real_hatsuon(text)
|
133 |
+
for regex, replacement in _romaji_to_ipa:
|
134 |
+
text = re.sub(regex, replacement, text)
|
135 |
+
return text
|
136 |
+
|
137 |
+
|
138 |
+
def japanese_to_ipa2(text):
|
139 |
+
text = japanese_to_romaji_with_accent(text).replace('...', '…')
|
140 |
+
text = get_real_sokuon(text)
|
141 |
+
text = get_real_hatsuon(text)
|
142 |
+
for regex, replacement in _romaji_to_ipa2:
|
143 |
+
text = re.sub(regex, replacement, text)
|
144 |
+
return list(text)
|
145 |
+
|
146 |
+
|
147 |
+
def japanese_to_ipa3(text):
|
148 |
+
text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
|
149 |
+
'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
|
150 |
+
text = re.sub(
|
151 |
+
r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
|
152 |
+
text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
|
153 |
+
return text
|
154 |
+
|
155 |
+
if __name__ == '__main__':
|
156 |
+
a = japanese_to_romaji_with_accent('こんにちは!はい、元気です。あなたは?')
|
157 |
+
print(a)
|
text/mandarin.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Dict, List
|
3 |
+
from pypinyin import lazy_pinyin, Style
|
4 |
+
from .custom_pypinyin_dict import phrase_pinyin_data
|
5 |
+
import jieba
|
6 |
+
from .cn2an import an2cn
|
7 |
+
|
8 |
+
# 加载自定义拼音词典数据
|
9 |
+
phrase_pinyin_data.load()
|
10 |
+
|
11 |
+
# 标点符号正则
|
12 |
+
PUNC_MAP: Dict[str, str] = {
|
13 |
+
":": ",",
|
14 |
+
";": ",",
|
15 |
+
",": ",",
|
16 |
+
"。": ".",
|
17 |
+
"!": "!",
|
18 |
+
"?": "?",
|
19 |
+
"\n": ".",
|
20 |
+
"·": ",",
|
21 |
+
"、": ",",
|
22 |
+
"$": ".",
|
23 |
+
"/": ",",
|
24 |
+
"“": "'",
|
25 |
+
"”": "'",
|
26 |
+
'"': "'",
|
27 |
+
"‘": "'",
|
28 |
+
"’": "'",
|
29 |
+
"(": "'",
|
30 |
+
")": "'",
|
31 |
+
"(": "'",
|
32 |
+
")": "'",
|
33 |
+
"《": "'",
|
34 |
+
"》": "'",
|
35 |
+
"【": "'",
|
36 |
+
"】": "'",
|
37 |
+
"[": "'",
|
38 |
+
"]": "'",
|
39 |
+
"—": "-",
|
40 |
+
"~": "~",
|
41 |
+
"「": "'",
|
42 |
+
"」": "'",
|
43 |
+
"『": "'",
|
44 |
+
"』": "'",
|
45 |
+
}
|
46 |
+
|
47 |
+
# from GPT_SoVITS.text.zh_normalization.text_normlization
|
48 |
+
PUNC_MAP.update ({
|
49 |
+
'/': '每',
|
50 |
+
'①': '一',
|
51 |
+
'②': '二',
|
52 |
+
'③': '三',
|
53 |
+
'④': '四',
|
54 |
+
'⑤': '五',
|
55 |
+
'⑥': '六',
|
56 |
+
'⑦': '七',
|
57 |
+
'⑧': '八',
|
58 |
+
'⑨': '九',
|
59 |
+
'⑩': '十',
|
60 |
+
'α': '阿尔法',
|
61 |
+
'β': '贝塔',
|
62 |
+
'γ': '伽玛',
|
63 |
+
'Γ': '伽玛',
|
64 |
+
'δ': '德尔塔',
|
65 |
+
'Δ': '德尔塔',
|
66 |
+
'ε': '艾普西龙',
|
67 |
+
'ζ': '捷塔',
|
68 |
+
'η': '依塔',
|
69 |
+
'θ': '西塔',
|
70 |
+
'Θ': '西塔',
|
71 |
+
'ι': '艾欧塔',
|
72 |
+
'κ': '喀帕',
|
73 |
+
'λ': '拉姆达',
|
74 |
+
'Λ': '拉姆达',
|
75 |
+
'μ': '缪',
|
76 |
+
'ν': '拗',
|
77 |
+
'ξ': '克西',
|
78 |
+
'Ξ': '克西',
|
79 |
+
'ο': '欧米克伦',
|
80 |
+
'π': '派',
|
81 |
+
'Π': '派',
|
82 |
+
'ρ': '肉',
|
83 |
+
'ς': '西格玛',
|
84 |
+
'σ': '西格玛',
|
85 |
+
'Σ': '西格玛',
|
86 |
+
'τ': '套',
|
87 |
+
'υ': '宇普西龙',
|
88 |
+
'φ': '服艾',
|
89 |
+
'Φ': '服艾',
|
90 |
+
'χ': '器',
|
91 |
+
'ψ': '普赛',
|
92 |
+
'Ψ': '普赛',
|
93 |
+
'ω': '欧米伽',
|
94 |
+
'Ω': '欧米伽',
|
95 |
+
'+': '加',
|
96 |
+
'-': '减',
|
97 |
+
'×': '乘',
|
98 |
+
'÷': '除',
|
99 |
+
'=': '等',
|
100 |
+
|
101 |
+
"嗯": "恩",
|
102 |
+
"呣": "母"
|
103 |
+
})
|
104 |
+
|
105 |
+
PUNC_TABLE = str.maketrans(PUNC_MAP)
|
106 |
+
|
107 |
+
# 数字正则化
|
108 |
+
NUMBER_PATTERN: re.Pattern = re.compile(r'\d+(?:\.?\d+)?')
|
109 |
+
|
110 |
+
# 阿拉伯数字转汉字
|
111 |
+
def replace_number(match: re.Match) -> str:
|
112 |
+
return an2cn(match.group())
|
113 |
+
|
114 |
+
def normalize_number(text: str) -> str:
|
115 |
+
return NUMBER_PATTERN.sub(replace_number, text)
|
116 |
+
|
117 |
+
# get symbols of phones, not used
|
118 |
+
def load_pinyin_symbols(path):
|
119 |
+
pinyin_dict={}
|
120 |
+
temp = []
|
121 |
+
with open(path, "r", encoding='utf-8') as f:
|
122 |
+
content = f.readlines()
|
123 |
+
for line in content:
|
124 |
+
cuts = line.strip().split(',')
|
125 |
+
pinyin = cuts[0]
|
126 |
+
phones = cuts[1].split(' ')
|
127 |
+
pinyin_dict[pinyin] = phones
|
128 |
+
temp.extend(phones)
|
129 |
+
temp = list(set(temp))
|
130 |
+
tone = []
|
131 |
+
for phone in temp:
|
132 |
+
for i in range(1, 6):
|
133 |
+
phone2 = phone + str(i)
|
134 |
+
tone.append(phone2)
|
135 |
+
print(sorted(tone, key=lambda x: len(x)))
|
136 |
+
return pinyin_dict
|
137 |
+
|
138 |
+
def load_pinyin_dict(path: str) -> Dict[str, List[str]]:
|
139 |
+
pinyin_dict = {}
|
140 |
+
with open(path, "r", encoding='utf-8') as f:
|
141 |
+
for line in f:
|
142 |
+
key, value = line.strip().split(',', 1)
|
143 |
+
pinyin_dict[key] = value.split()
|
144 |
+
return pinyin_dict
|
145 |
+
|
146 |
+
import os
|
147 |
+
pinyin_dict = load_pinyin_dict(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cnm3', 'ds_CNM3.txt'))
|
148 |
+
# pinyin_dict = load_pinyin_dict('text/cnm3/ds_CNM3.txt')
|
149 |
+
|
150 |
+
def chinese_to_cnm3(text: str) -> List[str]:
|
151 |
+
# 标点符号和数字正则化
|
152 |
+
text = text.translate(PUNC_TABLE)
|
153 |
+
text = normalize_number(text)
|
154 |
+
# 过滤掉特殊字符
|
155 |
+
text = re.sub(r'[#&@“”^_|\\]', '', text)
|
156 |
+
|
157 |
+
words = jieba.lcut(text, cut_all=False)
|
158 |
+
|
159 |
+
phones = []
|
160 |
+
for word in words:
|
161 |
+
pinyin_list: List[str] = lazy_pinyin(word, style=Style.TONE3, neutral_tone_with_five=True)
|
162 |
+
for pinyin in pinyin_list:
|
163 |
+
if pinyin[-1].isdigit():
|
164 |
+
tone = pinyin[-1]
|
165 |
+
syllable = pinyin[:-1]
|
166 |
+
phone = pinyin_dict[syllable]
|
167 |
+
phones.extend([ph + tone for ph in phone])
|
168 |
+
elif pinyin[-1].isalpha():
|
169 |
+
pass
|
170 |
+
else:
|
171 |
+
phones.extend(pinyin)
|
172 |
+
|
173 |
+
return phones
|
text/symbols.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Defines the set of symbols used in text input to the model.
|
3 |
+
'''
|
4 |
+
|
5 |
+
# japanese_cleaners
|
6 |
+
# _pad = '_'
|
7 |
+
# _punctuation = ',.!?-'
|
8 |
+
# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
|
9 |
+
|
10 |
+
|
11 |
+
'''# japanese_cleaners2
|
12 |
+
_pad = '_'
|
13 |
+
_punctuation = ',.!?-~…'
|
14 |
+
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
|
15 |
+
'''
|
16 |
+
|
17 |
+
|
18 |
+
'''# korean_cleaners
|
19 |
+
_pad = '_'
|
20 |
+
_punctuation = ',.!?…~'
|
21 |
+
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
22 |
+
'''
|
23 |
+
|
24 |
+
'''# chinese_cleaners
|
25 |
+
_pad = '_'
|
26 |
+
_punctuation = ',。!?—…'
|
27 |
+
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
|
28 |
+
'''
|
29 |
+
|
30 |
+
# # zh_ja_mixture_cleaners
|
31 |
+
# _pad = '_'
|
32 |
+
# _punctuation = ',.!?-~…'
|
33 |
+
# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
|
34 |
+
|
35 |
+
|
36 |
+
'''# sanskrit_cleaners
|
37 |
+
_pad = '_'
|
38 |
+
_punctuation = '।'
|
39 |
+
_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
|
40 |
+
'''
|
41 |
+
|
42 |
+
'''# cjks_cleaners
|
43 |
+
_pad = '_'
|
44 |
+
_punctuation = ',.!?-~…'
|
45 |
+
_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
|
46 |
+
'''
|
47 |
+
|
48 |
+
'''# thai_cleaners
|
49 |
+
_pad = '_'
|
50 |
+
_punctuation = '.!? '
|
51 |
+
_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
|
52 |
+
'''
|
53 |
+
|
54 |
+
# # cjke_cleaners2
|
55 |
+
_pad = '_'
|
56 |
+
_punctuation = ',.!?-~…' + "'"
|
57 |
+
_IPA_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
|
58 |
+
_CNM3_letters = ['y1', 'y2', 'y3', 'y4', 'y5', 'n1', 'n2', 'n3', 'n4', 'n5', 'p1', 'p2', 'p3', 'p4', 'p5', 'x1', 'x2', 'x3', 'x4', 'x5', 'k1', 'k2', 'k3', 'k4', 'k5', 'l1', 'l2', 'l3', 'l4', 'l5', 'q1', 'q2', 'q3', 'q4', 'q5', 'w1', 'w2', 'w3', 'w4', 'w5', 'E1', 'E2', 'E3', 'E4', 'E5', 'b1', 'b2', 'b3', 'b4', 'b5', 'c1', 'c2', 'c3', 'c4', 'c5', 'z1', 'z2', 'z3', 'z4', 'z5', 'e1', 'e2', 'e3', 'e4', 'e5', 'f1', 'f2', 'f3', 'f4', 'f5', 's1', 's2', 's3', 's4', 's5', 'j1', 'j2', 'j3', 'j4', 'j5', 'o1', 'o2', 'o3', 'o4', 'o5', 'i1', 'i2', 'i3', 'i4', 'i5', 'd1', 'd2', 'd3', 'd4', 'd5', 'm1', 'm2', 'm3', 'm4', 'm5', 't1', 't2', 't3', 't4', 't5', 'h1', 'h2', 'h3', 'h4', 'h5', 'g1', 'g2', 'g3', 'g4', 'g5', 'v1', 'v2', 'v3', 'v4', 'v5', 'r1', 'r2', 'r3', 'r4', 'r5', 'a1', 'a2', 'a3', 'a4', 'a5', 'u1', 'u2', 'u3', 'u4', 'u5', 'I01', 'I02', 'I03', 'I04', 'I05', 'i01', 'i02', 'i03', 'i04', 'i05', 'uo1', 'uo2', 'uo3', 'uo4', 'uo5', 'o01', 'o02', 'o03', 'o04', 'o05', 'U01', 'U02', 'U03', 'U04', 'U05', 'v01', 'v02', 'v03', 'v04', 'v05', 'er1', 'er2', 'er3', 'er4', 'er5', 'A01', 'A02', 'A03', 'A04', 'A05', 'ai1', 'ai2', 'ai3', 'ai4', 'ai5', 'e01', 'e02', 'e03', 'e04', 'e05', 'sh1', 'sh2', 'sh3', 'sh4', 'sh5', 'an1', 'an2', 'an3', 'an4', 'an5', 'ou1', 'ou2', 'ou3', 'ou4', 'ou5', 'ch1', 'ch2', 'ch3', 'ch4', 'ch5', 'a01', 'a02', 'a03', 'a04', 'a05', 'N01', 'N02', 'N03', 'N04', 'N05', 'ao1', 'ao2', 'ao3', 'ao4', 'ao5', 've1', 've2', 've3', 've4', 've5', 'ir1', 'ir2', 'ir3', 'ir4', 'ir5', 'ng1', 'ng2', 'ng3', 'ng4', 'ng5', 'ua1', 'ua2', 'ua3', 'ua4', 'ua5', 'zh1', 'zh2', 'zh3', 'zh4', 'zh5', 'O01', 'O02', 'O03', 'O04', 'O05', 'ie1', 'ie2', 'ie3', 'ie4', 'ie5', 'E01', 'E02', 'E03', 'E04', 'E05', 'ia1', 'ia2', 'ia3', 'ia4', 'ia5', 'iE01', 'iE02', 'iE03', 'iE04', 'iE05', 'ang1', 'ang2', 'ang3', 'ang4', 'ang5', 'ng01', 'ng02', 'ng03', 'ng04', 'ng05', 'io01', 'io02', 'io03', 'io04', 'io05', 'iA01', 'iA02', 'iA03', 'iA04', 'iA05', 'uA01', 'uA02', 'uA03', 'uA04', 'uA05', 'ong1', 'ong2', 'ong3', 'ong4', 'ong5', 'oo01', 'oo02', 'oo03', 'oo04', 'oo05', 'uE01', 'uE02', 'uE03', 'uE04', 'uE05', 'vE01', 'vE02', 'vE03', 'vE04', 'vE05', 'ue01', 'ue02', 'ue03', 'ue04', 'ue05', 'ua01', 'ua02', 'ua03', 'ua04', 'ua05', 'iO01', 'iO02', 'iO03', 'iO04', 'iO05']
|
59 |
+
_additional = ['<sil>', '<asp>']
|
60 |
+
# _CNM3_letters = []
|
61 |
+
|
62 |
+
|
63 |
+
'''# shanghainese_cleaners
|
64 |
+
_pad = '_'
|
65 |
+
_punctuation = ',.!?…'
|
66 |
+
_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
|
67 |
+
'''
|
68 |
+
|
69 |
+
'''# chinese_dialect_cleaners
|
70 |
+
_pad = '_'
|
71 |
+
_punctuation = ',.!?~…─'
|
72 |
+
_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
|
73 |
+
'''
|
74 |
+
|
75 |
+
# Export all symbols:
|
76 |
+
symbols = [_pad] + list(_punctuation) + list(_IPA_letters) + _CNM3_letters + _additional
|
77 |
+
|
78 |
+
# Special symbol ids
|
79 |
+
SPACE_ID = symbols.index(" ")
|
utils/__init__.py
ADDED
File without changes
|
utils/audio.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
class LinearSpectrogram(nn.Module):
|
7 |
+
def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.n_fft = n_fft
|
11 |
+
self.win_length = win_length
|
12 |
+
self.hop_length = hop_length
|
13 |
+
self.pad = pad
|
14 |
+
self.center = center
|
15 |
+
self.pad_mode = pad_mode
|
16 |
+
|
17 |
+
self.register_buffer("window", torch.hann_window(win_length))
|
18 |
+
|
19 |
+
def forward(self, waveform: Tensor) -> Tensor:
|
20 |
+
if waveform.ndim == 3:
|
21 |
+
waveform = waveform.squeeze(1)
|
22 |
+
waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode).squeeze(1)
|
23 |
+
spec = torch.stft(waveform, self.n_fft, self.hop_length, self.win_length, self.window, self.center, self.pad_mode, False, True, True)
|
24 |
+
spec = torch.view_as_real(spec)
|
25 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
26 |
+
return spec
|
27 |
+
|
28 |
+
|
29 |
+
class LogMelSpectrogram(nn.Module):
|
30 |
+
def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, center, pad_mode, mel_scale):
|
31 |
+
super().__init__()
|
32 |
+
self.sample_rate = sample_rate
|
33 |
+
self.n_fft = n_fft
|
34 |
+
self.win_length = win_length
|
35 |
+
self.hop_length = hop_length
|
36 |
+
self.f_min = f_min
|
37 |
+
self.f_max = f_max
|
38 |
+
self.pad = pad
|
39 |
+
self.n_mels = n_mels
|
40 |
+
self.center = center
|
41 |
+
self.pad_mode = pad_mode
|
42 |
+
self.mel_scale = mel_scale
|
43 |
+
|
44 |
+
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, pad, center, pad_mode)
|
45 |
+
self.mel_scale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min, f_max, (n_fft//2)+1, mel_scale, mel_scale)
|
46 |
+
|
47 |
+
def compress(self, x: Tensor) -> Tensor:
|
48 |
+
return torch.log(torch.clamp(x, min=1e-5))
|
49 |
+
|
50 |
+
def decompress(self, x: Tensor) -> Tensor:
|
51 |
+
return torch.exp(x)
|
52 |
+
|
53 |
+
def forward(self, x: Tensor) -> Tensor:
|
54 |
+
linear_spec = self.spectrogram(x)
|
55 |
+
x = self.mel_scale(linear_spec)
|
56 |
+
x = self.compress(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor:
|
60 |
+
try:
|
61 |
+
y, sr = torchaudio.load(audio_path)
|
62 |
+
except Exception as e:
|
63 |
+
print(str(e))
|
64 |
+
return None
|
65 |
+
|
66 |
+
y.to(device)
|
67 |
+
# Convert to mono
|
68 |
+
if y.size(0) > 1:
|
69 |
+
y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]
|
70 |
+
|
71 |
+
# resample audio to target sample_rate
|
72 |
+
if sr != target_sr:
|
73 |
+
y = torchaudio.functional.resample(y, sr, target_sr)
|
74 |
+
return y
|
utils/load.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
6 |
+
|
7 |
+
def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
|
8 |
+
"""load the latest checkpoints and optimizers"""
|
9 |
+
model_dict = {}
|
10 |
+
optimizer_dict = {}
|
11 |
+
|
12 |
+
# globt all the checkpoints in the directory
|
13 |
+
for file in os.listdir(checkpoint_path):
|
14 |
+
if file.endswith(".pt") and '_' in file:
|
15 |
+
name, epoch_str = file.rsplit('_', 1)
|
16 |
+
epoch = int(epoch_str.split('.')[0])
|
17 |
+
|
18 |
+
if name.startswith("checkpoint"):
|
19 |
+
model_dict[epoch] = file
|
20 |
+
elif name.startswith("optimizer"):
|
21 |
+
optimizer_dict[epoch] = file
|
22 |
+
|
23 |
+
# get the largest epoch
|
24 |
+
common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys())
|
25 |
+
if common_epochs:
|
26 |
+
max_epoch = max(common_epochs)
|
27 |
+
model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
|
28 |
+
optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])
|
29 |
+
|
30 |
+
# load model and optimizer
|
31 |
+
model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
|
32 |
+
optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
|
33 |
+
|
34 |
+
print(f'resume model and optimizer from {max_epoch} epoch')
|
35 |
+
return max_epoch + 1
|
36 |
+
|
37 |
+
else:
|
38 |
+
# load pretrained checkpoint
|
39 |
+
if model_dict:
|
40 |
+
model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())])
|
41 |
+
model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
|
42 |
+
|
43 |
+
return 0
|