Spaces:
Runtime error
Runtime error
fix
Browse files- data/tokenizer.py +0 -260
data/tokenizer.py
CHANGED
|
@@ -22,160 +22,6 @@ import torch
|
|
| 22 |
import torchaudio
|
| 23 |
from encodec import EncodecModel
|
| 24 |
from encodec.utils import convert_audio
|
| 25 |
-
from phonemizer.backend import EspeakBackend
|
| 26 |
-
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
| 27 |
-
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
| 28 |
-
from phonemizer.punctuation import Punctuation
|
| 29 |
-
from phonemizer.separator import Separator
|
| 30 |
-
from phonemizer.separator import Separator
|
| 31 |
-
|
| 32 |
-
try:
|
| 33 |
-
from pypinyin import Style, pinyin
|
| 34 |
-
from pypinyin.style._utils import get_finals, get_initials
|
| 35 |
-
except Exception:
|
| 36 |
-
pass
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class PypinyinBackend:
    """Pinyin phonemizer backend for Chinese text (adapted from espnet).

    Two output flavours are implemented by :meth:`phonemize`:

    * ``"pypinyin"`` -- whole syllables, just like "ni1 hao3";
    * ``"pypinyin_initials_finals"`` -- initial/final pairs, like
      "n i1 h ao3".

    NOTE(review): the constructor default ``"initials_finals"`` matches
    neither name, so :meth:`phonemize` raises ``NotImplementedError`` unless
    one of the two names above is passed explicitly.
    """

    def __init__(
        self,
        backend="initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        """Store the backend flavour and the set of punctuation characters.

        Args:
            backend: Name of the output flavour (see the class docstring).
            punctuation_marks: Characters treated as punctuation; membership
                is tested character by character in :meth:`phonemize`.
        """
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        """Convert each utterance in ``text`` to a pinyin phone string.

        Args:
            text: List of utterances.
            separator: Supplies the ``word``, ``syllable`` and ``phone``
                separator strings inserted into the output.
            strip: Unused by this backend.
            njobs: Unused by this backend.

        Returns:
            One string per utterance, with trailing word/syllable separator
            characters stripped.

        Raises:
            NotImplementedError: If ``self.backend`` is not a supported
                flavour.
            ValueError: In initials/finals mode, if a non-punctuation token
                does not end in an alphanumeric character and therefore
                cannot be split into an initial and a final.
        """
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            # Collapse runs of spaces, then mark word boundaries explicitly.
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)

            if self.backend == "pypinyin":
                split_syllables = False
            elif self.backend == "pypinyin_initials_finals":
                split_syllables = True
            else:
                raise NotImplementedError

            phones = []
            for py in pinyin(
                _text, style=Style.TONE3, neutral_tone_with_five=True
            ):
                token = py[0]
                if all(c in self.punctuation_marks for c in token):
                    # Punctuation attaches directly to the previous syllable,
                    # so drop the syllable separator that was just emitted.
                    if phones:
                        assert phones[-1] == separator.syllable
                        phones.pop()
                    phones.extend(token)
                elif not split_syllables:
                    phones.extend([token, separator.syllable])
                elif token[-1].isalnum():
                    initial = get_initials(token, strict=False)
                    if token[-1].isdigit():
                        # Keep the trailing tone digit attached to the final.
                        final = get_finals(token[:-1], strict=False) + token[-1]
                    else:
                        final = get_finals(token, strict=False)
                    phones.extend(
                        [initial, separator.phone, final, separator.syllable]
                    )
                else:
                    # The original `assert ValueError` was always true and
                    # silently dropped the token; fail loudly instead.
                    raise ValueError(
                        f"Cannot split {token!r} into an initial and a final"
                    )
            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
class TextTokenizer:
    """Turn raw text into lists of phoneme symbols via a phonemizer backend."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        """Build the backend selected by ``backend`` and keep the separator.

        ``"espeak"`` constructs an ``EspeakBackend`` with the given options;
        ``"pypinyin"`` and ``"pypinyin_initials_finals"`` construct a
        ``PypinyinBackend`` whose punctuation set also includes the word
        separator.

        Raises:
            NotImplementedError: For any other ``backend`` name.
        """
        if backend == "espeak":
            self.backend = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
        elif backend in ("pypinyin", "pypinyin_initials_finals"):
            self.backend = PypinyinBackend(
                backend=backend,
                punctuation_marks=punctuation_marks + separator.word,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized string into a flat list of symbols.

        The string is cut at the word separator. Each word is tokenised into
        runs of word characters or single non-space symbols, tokens equal to
        the phone separator are discarded, and the word separator is
        re-inserted between consecutive words.
        """
        word_sep = self.separator.word
        phone_sep = self.separator.phone

        symbols = []
        for chunk in phonemized.split(word_sep):
            # e.g. "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            for piece in re.findall(r"\w+|[^\w\s]", chunk, re.UNICODE):
                if piece != phone_sep:
                    symbols.append(piece)
            symbols.append(word_sep)
        # The loop appends a separator after every word; drop the last one.
        symbols.pop()

        # Sanity check: only the phone separators may have been lost.
        expected_length = len(phonemized) - phonemized.count(phone_sep)
        assert len("".join(symbols)) == expected_length
        return symbols

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize ``text`` (a string or a list of strings).

        Returns one symbol list per input utterance.
        """
        batch = [text] if isinstance(text, str) else text
        phonemized = self.backend.phonemize(
            batch, separator=self.separator, strip=strip, njobs=1
        )
        return list(map(self.to_list, phonemized))
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize a single utterance and return its symbol list.

    The text is stripped of surrounding whitespace, passed to ``tokenizer``
    as a one-element batch, and the first (only) result is returned.
    """
    (symbols, *_) = tokenizer([text.strip()])
    return symbols  # k2symbols
|
| 178 |
-
|
| 179 |
|
| 180 |
def remove_encodec_weight_norm(model):
|
| 181 |
from encodec.modules import SConv1d
|
|
@@ -256,112 +102,6 @@ def tokenize_audio(tokenizer: AudioTokenizer, audio):
|
|
| 256 |
return encoded_frames
|
| 257 |
|
| 258 |
|
| 259 |
-
# @dataclass
|
| 260 |
-
# class AudioTokenConfig:
|
| 261 |
-
# frame_shift: Seconds = 320.0 / 24000
|
| 262 |
-
# num_quantizers: int = 8
|
| 263 |
-
#
|
| 264 |
-
# def to_dict(self) -> Dict[str, Any]:
|
| 265 |
-
# return asdict(self)
|
| 266 |
-
#
|
| 267 |
-
# @staticmethod
|
| 268 |
-
# def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
|
| 269 |
-
# return AudioTokenConfig(**data)
|
| 270 |
-
#
|
| 271 |
-
#
|
| 272 |
-
# class AudioTokenExtractor(FeatureExtractor):
|
| 273 |
-
# name = "encodec"
|
| 274 |
-
# config_type = AudioTokenConfig
|
| 275 |
-
#
|
| 276 |
-
# def __init__(self, config: Optional[Any] = None):
|
| 277 |
-
# super(AudioTokenExtractor, self).__init__(config)
|
| 278 |
-
# self.tokenizer = AudioTokenizer()
|
| 279 |
-
#
|
| 280 |
-
# def extract(
|
| 281 |
-
# self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
|
| 282 |
-
# ) -> np.ndarray:
|
| 283 |
-
# if not isinstance(samples, torch.Tensor):
|
| 284 |
-
# samples = torch.from_numpy(samples)
|
| 285 |
-
# if sampling_rate != self.tokenizer.sample_rate:
|
| 286 |
-
# samples = convert_audio(
|
| 287 |
-
# samples,
|
| 288 |
-
# sampling_rate,
|
| 289 |
-
# self.tokenizer.sample_rate,
|
| 290 |
-
# self.tokenizer.channels,
|
| 291 |
-
# )
|
| 292 |
-
# if len(samples.shape) == 2:
|
| 293 |
-
# samples = samples.unsqueeze(0)
|
| 294 |
-
# else:
|
| 295 |
-
# raise ValueError()
|
| 296 |
-
#
|
| 297 |
-
# device = self.tokenizer.device
|
| 298 |
-
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
| 299 |
-
# codes = encoded_frames[0][0] # [B, n_q, T]
|
| 300 |
-
# if True:
|
| 301 |
-
# duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
|
| 302 |
-
# expected_num_frames = compute_num_frames(
|
| 303 |
-
# duration=duration,
|
| 304 |
-
# frame_shift=self.frame_shift,
|
| 305 |
-
# sampling_rate=sampling_rate,
|
| 306 |
-
# )
|
| 307 |
-
# assert abs(codes.shape[-1] - expected_num_frames) <= 1
|
| 308 |
-
# codes = codes[..., :expected_num_frames]
|
| 309 |
-
# return codes.cpu().squeeze(0).permute(1, 0).numpy()
|
| 310 |
-
#
|
| 311 |
-
# @property
|
| 312 |
-
# def frame_shift(self) -> Seconds:
|
| 313 |
-
# return self.config.frame_shift
|
| 314 |
-
#
|
| 315 |
-
# def feature_dim(self, sampling_rate: int) -> int:
|
| 316 |
-
# return self.config.num_quantizers
|
| 317 |
-
#
|
| 318 |
-
# def pad_tensor_list(self, tensor_list, device, padding_value=0):
|
| 319 |
-
# # Compute the length of each tensor
|
| 320 |
-
# lengths = [tensor.shape[0] for tensor in tensor_list]
|
| 321 |
-
# # Pad the tensors using the pad_sequence function
|
| 322 |
-
# tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
|
| 323 |
-
# padded_tensor = torch.nn.utils.rnn.pad_sequence(
|
| 324 |
-
# tensor_list, batch_first=True, padding_value=padding_value
|
| 325 |
-
# )
|
| 326 |
-
# return padded_tensor, lengths
|
| 327 |
-
#
|
| 328 |
-
# def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
|
| 329 |
-
# samples = [wav.squeeze() for wav in samples]
|
| 330 |
-
# device = self.tokenizer.device
|
| 331 |
-
# samples, lengths = self.pad_tensor_list(samples, device)
|
| 332 |
-
# samples = samples.unsqueeze(1)
|
| 333 |
-
#
|
| 334 |
-
# if not isinstance(samples, torch.Tensor):
|
| 335 |
-
# samples = torch.from_numpy(samples)
|
| 336 |
-
# if len(samples.shape) != 3:
|
| 337 |
-
# raise ValueError()
|
| 338 |
-
# if sampling_rate != self.tokenizer.sample_rate:
|
| 339 |
-
# samples = [
|
| 340 |
-
# convert_audio(
|
| 341 |
-
# wav,
|
| 342 |
-
# sampling_rate,
|
| 343 |
-
# self.tokenizer.sample_rate,
|
| 344 |
-
# self.tokenizer.channels,
|
| 345 |
-
# )
|
| 346 |
-
# for wav in samples
|
| 347 |
-
# ]
|
| 348 |
-
# # Extract discrete codes from EnCodec
|
| 349 |
-
# with torch.no_grad():
|
| 350 |
-
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
| 351 |
-
# encoded_frames = encoded_frames[0][0] # [B, n_q, T]
|
| 352 |
-
# batch_codes = []
|
| 353 |
-
# for b, length in enumerate(lengths):
|
| 354 |
-
# codes = encoded_frames[b]
|
| 355 |
-
# duration = round(length / sampling_rate, ndigits=12)
|
| 356 |
-
# expected_num_frames = compute_num_frames(
|
| 357 |
-
# duration=duration,
|
| 358 |
-
# frame_shift=self.frame_shift,
|
| 359 |
-
# sampling_rate=sampling_rate,
|
| 360 |
-
# )
|
| 361 |
-
# batch_codes.append(codes[..., :expected_num_frames])
|
| 362 |
-
# return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
|
| 363 |
-
|
| 364 |
-
|
| 365 |
if __name__ == "__main__":
|
| 366 |
model = EncodecModel.encodec_model_24khz()
|
| 367 |
model.set_target_bandwidth(6.0)
|
|
|
|
| 22 |
import torchaudio
|
| 23 |
from encodec import EncodecModel
|
| 24 |
from encodec.utils import convert_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def remove_encodec_weight_norm(model):
|
| 27 |
from encodec.modules import SConv1d
|
|
|
|
| 102 |
return encoded_frames
|
| 103 |
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
if __name__ == "__main__":
|
| 106 |
model = EncodecModel.encodec_model_24khz()
|
| 107 |
model.set_target_bandwidth(6.0)
|