import base64
import json
import logging
from pathlib import Path

import tiktoken

logger = logging.getLogger(__name__)

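# Pre-tokenization split pattern handed to tiktoken. The alternatives match,
# in order: English contractions, punctuation, letter runs (optionally
# preceded by one non-letter/non-digit character), individual numerals, other
# symbol runs, newline runs, trailing whitespace, and any remaining whitespace.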
FISH_TIKTOKEN_PATTERN = "|".join(
    [
        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
        r"\p{P}",
        r"[^\r\n\p{L}\p{N}]?\p{L}+",
        r"\p{N}",
        r" ?[^\s\p{L}\p{N}]+[\r\n]*",
        r"\s*[\r\n]+",
        r"\s+(?!\S)",
        r"\s+",
    ]
)

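# Long inputs are split into chunks of this many characters before being
# passed to tiktoken's encode_batch (see FishTokenizer.encode).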
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|end_of_text|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"

MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}

PLACEHOLDER_TOKEN = [f"<|placeholder:{i}|>" for i in range(4)]

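# 1024 semantic tokens (<|semantic:0|> ... <|semantic:1023|>), registered as
# special tokens after the BPE vocabulary.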
SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]

ALL_SPECIAL_TOKENS = [
    BOS_TOKEN,
    EOS_TOKEN,
    PAD_TOKEN,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PLACEHOLDER_TOKEN[0],
    PLACEHOLDER_TOKEN[1],
    PLACEHOLDER_TOKEN[2],
    PLACEHOLDER_TOKEN[3],
    MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN,
    MODALITY_INTERLEAVE_TOKEN,
    *SEMANTIC_TOKENS,
]


class FishTokenizer:
    def __init__(self, model_path: str) -> None:
        mergeable_ranks = self.load_tiktoken_bpe(model_path)
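        # Special tokens are assigned ids immediately after the BPE vocabulary.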
        special_token_begin = len(mergeable_ranks)
        self.all_special_tokens_with_ids = {
            token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
        }
        self.semantic_id_to_token_id = {
            i: self.all_special_tokens_with_ids[token]
            for i, token in enumerate(SEMANTIC_TOKENS)
        }
        self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
        self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]

        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=FISH_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.all_special_tokens_with_ids,
        )

    @staticmethod
    def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
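        # Each non-empty line of a .tiktoken file is "<base64-encoded token> <rank>".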
        data = {}
        with open(tiktoken_bpe_file) as f:
            for line in f.read().splitlines():
                if not line:
                    continue
                token, rank = line.split()
                data[base64.b64decode(token)] = int(rank)
        return data

    def get_token_id(self, token: str) -> int:
        return self.all_special_tokens_with_ids[token]

    def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
        assert isinstance(s, str)

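        # Split long inputs into chunks of at most TIKTOKEN_MAX_ENCODE_CHARS
        # characters; the chunks are encoded with encode_batch and concatenated.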
        subs = []
        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])

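        # allowed_special=True permits every registered special token, False
        # encodes special-token text as ordinary text, and a set restricts
        # which special tokens are recognised.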
        if allowed_special is True:
            allowed_special = self.tkt_model.special_tokens_set
        elif allowed_special is False:
            allowed_special = set()

        return sum(
            self.tkt_model.encode_batch(
                subs, allowed_special=allowed_special, disallowed_special=set()
            ),
            start=[],
        )

    def decode(self, tokens: list[int]) -> str:
        return self.tkt_model.decode(tokens)

    def save_pretrained(self, path: str):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

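        # Persist the BPE ranks in tiktoken's "<base64 token> <rank>" format,
        # plus a JSON map of special tokens to their ids.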
        with open(path / "tokenizer.tiktoken", "w") as f:
            for token, rank in self.tkt_model._mergeable_ranks.items():
                f.write(f"{base64.b64encode(token).decode()} {rank}\n")

        with open(path / "special_tokens.json", "w") as f:
            json.dump(
                self.all_special_tokens_with_ids,
                f,
                indent=2,
                ensure_ascii=False,
            )

    @staticmethod
    def from_pretrained(path: str):
        return FishTokenizer(Path(path) / "tokenizer.tiktoken")


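# Round-trip smoke test: build the tokenizer from a raw .tiktoken file, save
# it, reload it, and print how a short bilingual string tokenizes.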
if __name__ == "__main__":
    tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
    tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")

    print(
        [
            tokenizer.decode([i])
            for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
        ]
    )