diff --git a/fish_speech/__pycache__/conversation.cpython-310.pyc b/fish_speech/__pycache__/conversation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4715d8ad7d19c710c2a4809f48f44ce77b121cfd Binary files /dev/null and b/fish_speech/__pycache__/conversation.cpython-310.pyc differ diff --git a/fish_speech/__pycache__/scheduler.cpython-310.pyc b/fish_speech/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6897b47e20bec33721ee86f864d6a511f49635 Binary files /dev/null and b/fish_speech/__pycache__/scheduler.cpython-310.pyc differ diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3 --- /dev/null +++ b/fish_speech/callbacks/__init__.py @@ -0,0 +1,3 @@ +from .grad_norm import GradNormMonitor + +__all__ = ["GradNormMonitor"] diff --git a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af0d6a057cc9c633a575f24bbb4806f02586c88c Binary files /dev/null and b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95bc0a749f8b9617c87c6fb9a0d5936e38d82e7e Binary files /dev/null and b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc differ diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a --- /dev/null +++ b/fish_speech/callbacks/grad_norm.py @@ -0,0 +1,113 @@ +from typing import Optional, Union + +import lightning.pytorch as pl +import torch +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from torch import Tensor, nn +from torch.utils._foreach_utils import ( + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def grad_norm( + parameters: Union[Tensor, list[Tensor]], + norm_type: float = 2.0, +) -> float: + """ + Returns the norm of the gradients of the given parameters. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ # noqa: E501 + + if isinstance(parameters, Tensor): + parameters = [parameters] + + grads = [p.grad for p in parameters if p.grad is not None] + if len(grads) == 0: + return None + + first_device = grads[0].device + grouped_grads: dict[ + tuple[torch.device, torch.dtype], list[list[Tensor]] + ] = _group_tensors_by_device_and_dtype( + [[g.detach() for g in grads]] + ) # type: ignore[assignment] + + norms = [] + for (device, _), ([grads], _) in grouped_grads.items(): + if _has_foreach_support(grads, device=device): + norms.extend(torch._foreach_norm(grads, norm_type)) + else: + norms.extend([torch.norm(g, norm_type) for g in grads]) + + return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + +class GradNormMonitor(Callback): + """ + Callback that computes the gradient norm of the model parameters. 
+ """ + + def __init__( + self, + norm_type: float = 2.0, + logging_interval: str = "step", + sub_module: Optional[Union[str, list[str]]] = None, + ) -> None: + """ + Args: + norm_type (float): type of the used p-norm. + logging_interval (str): "step" or "epoch". + """ + super().__init__() + + self.norm_type = norm_type + self.logging_interval = logging_interval + self.sub_module = sub_module + + def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: + """ + Computes the gradient norm of the model parameters and logs it to the logger. + + Args: + trainer (Trainer): The trainer object + model (LightningModule): The current lightningModule + """ + + lightning_model = model + + if self.sub_module is None: + return self.log_sub_module_grad_norm(lightning_model, model, "") + + sub_modules = self.sub_module + if isinstance(sub_modules, str): + sub_modules = [sub_modules] + + for sub_module in sub_modules: + self.log_sub_module_grad_norm( + lightning_model, getattr(model, sub_module), f"/{sub_module}" + ) + + def log_sub_module_grad_norm( + self, lightning_model: LightningModule, model: nn.Module, path: str + ) -> None: + grad_norm_val = grad_norm(model.parameters(), self.norm_type) + if grad_norm_val is None: + return + + on_step = self.logging_interval == "step" + lightning_model.log( + f"train{path}/grad_norm", + grad_norm_val, + on_step=on_step, + on_epoch=not on_step, + ) diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75 --- /dev/null +++ b/fish_speech/configs/base.yaml @@ -0,0 +1,87 @@ +# Base configuration for training a model +paths: + run_dir: results/${project} + ckpt_dir: ${paths.run_dir}/checkpoints + +hydra: + run: + dir: ${paths.run_dir} + +# Lightning Trainer +trainer: + _target_: lightning.pytorch.trainer.Trainer + + default_root_dir: ${paths.run_dir} + accelerator: gpu + num_nodes: 1 + devices: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + process_group_backend: nccl # This should be override when training on windows + + precision: bf16-mixed + + # disable validation by epoch end + check_val_every_n_epoch: null + val_check_interval: 5000 + max_steps: 100_000 + + # Use torch.backends.cudnn.benchmark to speed up training + benchmark: true + +# Callbacks +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.ckpt_dir} + filename: "step_{step:09d}" + save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save 5 latest checkpoints + monitor: step # use step to monitor checkpoints + mode: max # save the latest checkpoint with the highest global_step + every_n_epochs: null # don't save checkpoints by epoch end + every_n_train_steps: 5000 # save checkpoints every 5000 steps + auto_insert_metric_name: false + + model_summary: + _target_: lightning.pytorch.callbacks.ModelSummary + max_depth: 2 # the maximum depth of layer nesting that the summary will include + + learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: step + log_momentum: false + + grad_norm_monitor: + _target_: fish_speech.callbacks.GradNormMonitor + norm_type: 2 + logging_interval: step + +# Logger +logger: + tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.run_dir}/tensorboard/" + name: null + log_graph: false + 
default_hp_metric: true + prefix: "" + + # wandb: + # _target_: lightning.pytorch.loggers.wandb.WandbLogger + # # name: "" # name of the run (normally generated by wandb) + # save_dir: "${paths.run_dir}" + # offline: False + # id: null # pass correct id to resume experiment! + # anonymous: null # enable anonymous logging + # project: "fish-speech" + # log_model: False # upload lightning ckpts + # prefix: "" # a string to put at the beginning of metric keys + # # entity: "" # set to name of your wandb team + # group: "" + # tags: ["vq", "hq", "finetune"] + # job_type: "" + +# Loop +train: true +test: false diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10aa8d4a522f0859ed8f541f5d48672d84b39c8f --- /dev/null +++ b/fish_speech/configs/firefly_gan_vq.yaml @@ -0,0 +1,33 @@ +_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture +spec_transform: + _target_: fish_speech.utils.spectrogram.LogMelSpectrogram + sample_rate: 44100 + n_mels: 160 + n_fft: 2048 + hop_length: 512 + win_length: 2048 +backbone: + _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder + input_channels: 160 + depths: [3, 3, 9, 3] + dims: [128, 256, 384, 512] + drop_path_rate: 0.2 + kernel_size: 7 +head: + _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator + hop_length: 512 + upsample_rates: [8, 8, 2, 2, 2] # aka. strides + upsample_kernel_sizes: [16, 16, 4, 4, 4] + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + num_mels: 512 + upsample_initial_channel: 512 + pre_conv_kernel_size: 13 + post_conv_kernel_size: 13 +quantizer: + _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize + input_dim: 512 + n_groups: 8 + n_codebooks: 1 + levels: [8, 5, 5, 5] + downsample_factor: [2, 2] diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9 --- /dev/null +++ b/fish_speech/configs/lora/r_8_alpha_16.yaml @@ -0,0 +1,4 @@ +_target_: fish_speech.models.text2semantic.lora.LoraConfig +r: 8 +lora_alpha: 16 +lora_dropout: 0.01 diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4c1993023099e122fc9e004bda55ec075ed5e1b --- /dev/null +++ b/fish_speech/configs/text2semantic_finetune.yaml @@ -0,0 +1,83 @@ +defaults: + - base + - _self_ + +project: text2semantic_finetune_dual_ar +max_length: 4096 +pretrained_ckpt_path: checkpoints/fish-speech-1.4 + +# Lightning Trainer +trainer: + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" + max_steps: 1000 + precision: bf16-true + limit_val_batches: 10 + val_check_interval: 100 + +# Dataset Configuration +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: ${pretrained_ckpt_path} + +# Dataset Configuration +train_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +val_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: 
${max_length} + use_speaker: false + interactive_prob: 0.7 + +data: + _target_: fish_speech.datasets.semantic.SemanticDataModule + train_dataset: ${train_dataset} + val_dataset: ${val_dataset} + num_workers: 4 + batch_size: 8 + tokenizer: ${tokenizer} + max_length: ${max_length} + +# Model Configuration +model: + _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic + model: + _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained + path: ${pretrained_ckpt_path} + load_weights: true + max_length: ${max_length} + lora_config: null + + optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + weight_decay: 0 + betas: [0.9, 0.95] + eps: 1e-5 + + lr_scheduler: + _target_: torch.optim.lr_scheduler.LambdaLR + _partial_: true + lr_lambda: + _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda + _partial_: true + num_warmup_steps: 10 + +# Callbacks +callbacks: + model_checkpoint: + every_n_train_steps: ${trainer.val_check_interval} diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..9bbc1cdb6c4a1d276ccf922988a7ad13e058d70a --- /dev/null +++ b/fish_speech/conversation.py @@ -0,0 +1,256 @@ +from dataclasses import dataclass, field +from typing import Literal + +import torch +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast + +IM_START_TOKEN = "<|im_start|>" +IM_END_TOKEN = "<|im_end|>" +SEMANTIC_TOKEN = "<|semantic|>" +MEL_TOKEN = "<|mel|>" +PHONEME_START_TOKEN = "<|phoneme_start|>" +PHONEME_END_TOKEN = "<|phoneme_end|>" +ALL_SPECIAL_TOKENS = [ + IM_START_TOKEN, + IM_END_TOKEN, + SEMANTIC_TOKEN, + MEL_TOKEN, + PHONEME_START_TOKEN, + PHONEME_END_TOKEN, +] + +CODEBOOK_PAD_TOKEN_ID = 0 + + +class FishTokenizerConfig(PretrainedConfig): + share_codebook_embeddings: bool = True + codebook_size: int = 1024 + num_codebooks: int = 8 + + +class FishTokenizerFast(PreTrainedTokenizerFast): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True) + self.codebook_size = kwargs.pop("codebook_size", 1024) + self.num_codebooks = kwargs.pop("num_codebooks", 8) + + +AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast) + + +@dataclass(kw_only=True) +class BasePart: + pass + + +@dataclass(kw_only=True) +class VQPart(BasePart): + codes: torch.Tensor + + +@dataclass(kw_only=True) +class TextPart(BasePart): + text: str + + +@dataclass(kw_only=True) +class MelPart(BasePart): + mels: torch.Tensor + + +@dataclass(kw_only=True) +class EncodedMessage: + tokens: torch.Tensor + labels: torch.Tensor + vq_parts: list[torch.Tensor] + mel_parts: list[torch.Tensor] + vq_require_losses: torch.Tensor | None = None + + +@dataclass(kw_only=True) +class Message: + role: Literal["system", "user", "assistant"] + parts: list[VQPart | TextPart | MelPart] = field(default_factory=list) + add_im_start: bool = True + add_im_end: bool = True + cal_loss: bool = False + + # By default, ignore the loss of the auto-generated im_start token + ignore_im_start_loss: bool = True + + def encode( + self: "Message", + tokenizer: AutoTokenizer, + ) -> EncodedMessage: + all_tokens = [] + all_labels = [] + + # Multi-modal tokens + vq_parts = [] + mel_parts = [] + + semantic_id, mel_id = tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, MEL_TOKEN] + ) + + parts = self.parts.copy() + if self.add_im_start: + parts.insert(0, 
TextPart(text=f"<|im_start|>{self.role}\n")) + + if self.add_im_end: + parts.append(TextPart(text="<|im_end|>")) + + for part in parts: + if isinstance(part, TextPart): + tokens = tokenizer.encode( + part.text, + add_special_tokens=False, + truncation=False, + return_tensors="pt", + ).int()[0] + elif isinstance(part, VQPart): + tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id + codes = part.codes.clone() + 1 + + if getattr(tokenizer, "share_codebook_embeddings", True) is False: + for i in range(len(codes)): + codes[i] += tokenizer.codebook_size * i + + vq_parts.append(codes) + elif isinstance(part, MelPart): + tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id + mel_parts.append(part.mels) + else: + raise ValueError(f"Unsupported part type: {type(part)}") + + all_tokens.append(tokens) + if self.cal_loss: + all_labels.append(tokens.clone()) + else: + all_labels.append(torch.full_like(tokens, -100)) + + tokens = torch.cat(all_tokens, dim=0) + labels = torch.cat(all_labels, dim=0) + assert tokens.shape == labels.shape + + if self.ignore_im_start_loss and self.add_im_start: + labels[: len(all_tokens[0])] = -100 + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + mel_parts=mel_parts, + ) + + +@dataclass +class Conversation: + messages: list[Message] + + def encode( + self: "Conversation", + tokenizer: AutoTokenizer, + add_shift: bool = True, + ) -> EncodedMessage: + # Build the input_ids and labels + tokens = [] + labels = [] + vq_parts = [] + mel_parts = [] + vq_require_losses = [] + + for message in self.messages: + encoded = message.encode( + tokenizer, + ) + tokens.append(encoded.tokens) + labels.append(encoded.labels) + vq_parts.extend(encoded.vq_parts) + mel_parts.extend(encoded.mel_parts) + vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts)) + + tokens = torch.cat(tokens, dim=0) + labels = torch.cat(labels, dim=0) + vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) + + if add_shift: + tokens = tokens[:-1] + labels = labels[1:] + + assert tokens.dtype in [ + torch.int, + torch.long, + ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}" + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + mel_parts=mel_parts, + vq_require_losses=vq_require_losses, + ) + + def encode_for_inference( + self: "Conversation", + tokenizer: AutoTokenizer, + num_codebooks: int, + ) -> EncodedMessage: + encoded = self.encode(tokenizer, add_shift=False) + tokens = encoded.tokens + values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) + values[0] = tokens + + if encoded.vq_parts is None or len(encoded.vq_parts) == 0: + return values + + semantic_id, mel_id = tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, MEL_TOKEN] + ) + vq_parts = encoded.vq_parts + vq_parts = torch.cat(vq_parts, dim=1) + values[1:, tokens == semantic_id] = vq_parts + return values + + def visualize(self: "Conversation", tokenizer: AutoTokenizer): + encoded = self.encode(tokenizer, add_shift=False) + + print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="") + print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="") + + for tok, lab in zip(encoded.tokens, encoded.labels): + val = tokenizer.decode(tok, skip_special_tokens=False) + if val == "\n": + val = "\\n\n" + + if lab == -100: + print_in_green(val) + else: + print_in_blue(val) + + print() + + +if __name__ == "__main__": + message0 = Message( + role="user", + parts=[ + TextPart(text="Hello, how are you?"), + 
VQPart(codes=torch.zeros((4, 10))), + ], + cal_loss=False, + ) + + message1 = Message( + role="assistant", + parts=[TextPart(text="I'm fine, thank you.")], + cal_loss=True, + ) + conversation = Conversation([message0, message1]) + tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct") + conversation.visualize(tokenizer) + + encoded = conversation.encode(tokenizer) + print(encoded) + print(tokenizer.batch_decode(encoded.tokens)) diff --git a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c674545ce61a4bf657e1be68d9969889a90b69f Binary files /dev/null and b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc differ diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa --- /dev/null +++ b/fish_speech/datasets/concat_repeat.py @@ -0,0 +1,53 @@ +import bisect +import random +from typing import Iterable + +from torch.utils.data import Dataset, IterableDataset + + +class ConcatRepeatDataset(Dataset): + datasets: list[Dataset] + cumulative_sizes: list[int] + repeats: list[int] + + @staticmethod + def cumsum(sequence, repeats): + r, s = [], 0 + for dataset, repeat in zip(sequence, repeats): + l = len(dataset) * repeat + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): + super().__init__() + + self.datasets = list(datasets) + self.repeats = repeats + + assert len(self.datasets) > 0, "datasets should not be an empty iterable" + assert len(self.datasets) == len( + repeats + ), "datasets and repeats should have the same length" + + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatRepeatDataset does not support IterableDataset" + + self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + dataset = self.datasets[dataset_idx] + + return dataset[sample_idx % len(dataset)] diff --git a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..436d43fb714faad6e5b7cbe23b6de6e998fe03e0 Binary files /dev/null and b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc differ diff --git a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac8b50578898c53aec2ca94f73b92e0d9d7d8537 Binary files /dev/null and b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc differ diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto new file mode 100644 index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379 --- /dev/null +++ b/fish_speech/datasets/protos/text-data.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package text_data; + +message Semantics { + repeated uint32 values = 1; +} + +message Sentence { + repeated string 
texts = 1; + repeated Semantics semantics = 3; +} + +message TextData { + string source = 1; + string name = 2; + repeated Sentence sentences = 4; +} + +message SampledData { + string source = 1; + string name = 2; + repeated Sentence samples = 3; +} diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e --- /dev/null +++ b/fish_speech/datasets/protos/text_data_pb2.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: text-data.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_SEMANTICS"]._serialized_start = 30 + _globals["_SEMANTICS"]._serialized_end = 57 + _globals["_SENTENCE"]._serialized_start = 59 + _globals["_SENTENCE"]._serialized_end = 125 + _globals["_TEXTDATA"]._serialized_start = 127 + _globals["_TEXTDATA"]._serialized_end = 207 + _globals["_SAMPLEDDATA"]._serialized_start = 209 + _globals["_SAMPLEDDATA"]._serialized_end = 290 +# @@protoc_insertion_point(module_scope) diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107 --- /dev/null +++ b/fish_speech/datasets/protos/text_data_stream.py @@ -0,0 +1,36 @@ +import struct + +from .text_data_pb2 import TextData + + +def read_pb_stream(f): + while True: + buf = f.read(4) + if len(buf) == 0: + break + size = struct.unpack("I", buf)[0] + buf = f.read(size) + text_data = TextData() + text_data.ParseFromString(buf) + yield text_data + + +def write_pb_stream(f, text_data): + buf = text_data.SerializeToString() + f.write(struct.pack("I", len(buf))) + f.write(buf) + + +def pack_pb_stream(text_data): + buf = text_data.SerializeToString() + return struct.pack("I", len(buf)) + buf + + +def split_pb_stream(f): + while True: + head = f.read(4) + if len(head) == 0: + break + size = struct.unpack("I", head)[0] + buf = f.read(size) + yield head + buf diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418 --- /dev/null +++ b/fish_speech/datasets/semantic.py @@ 
-0,0 +1,496 @@ +import random +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from random import Random +from typing import Optional, Union + +import numpy as np +import pyarrow.parquet as pq +import torch +import torch.nn.functional as F +from datasets.download.streaming_download_manager import xopen +from huggingface_hub import HfApi +from lightning import LightningDataModule +from torch.distributed import get_rank, get_world_size, is_initialized +from torch.utils.data import DataLoader, IterableDataset, get_worker_info +from transformers import AutoTokenizer + +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.datasets.protos.text_data_pb2 import SampledData +from fish_speech.datasets.protos.text_data_stream import read_pb_stream +from fish_speech.text.clean import clean_text +from fish_speech.utils import RankedLogger +from fish_speech.utils.braceexpand import braceexpand + +log = RankedLogger(__name__, rank_zero_only=True) + + +def split_by_rank_worker(files): + # We need to know the total number of devices + # to split the data properly + + total_devices = 1 + if is_initialized(): + total_devices = get_world_size() + + worker_info = get_worker_info() + if worker_info is not None: + total_devices *= worker_info.num_workers + + if len(files) < total_devices: + # Repeat the files N times to match the number of devices + files = files * (total_devices // len(files) + 1) + + # DDP + if is_initialized(): + files = files[get_rank() :: get_world_size()] + + # Split by worker + if worker_info is not None: + files = files[worker_info.id :: worker_info.num_workers] + + return files + + +class AutoTextSemanticInstructionDataset(IterableDataset): + """ + Auto Augment Dataset by Speaker + + 1. Random concatenate multiple sentences from the same speaker to form a longer sentence + 2. Automatically normalize the text + + For interactive mode, we use the following format (multiple sequences): + [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + + For non-interactive mode, we use the following format (one long sequence): + [INST] text [/INST] ... 
+ """ + + def __init__( + self, + proto_files: list[str], + seed: int = 42, + interactive_prob: float = 0.5, + max_length: int = 1024, + tokenizer: AutoTokenizer = None, + use_speaker: bool | float = True, + causal: bool = True, + num_codebooks: Optional[int] = None, + skip_text_prob: float = 0.0, + ): + """ + Args: + proto_files: proto buf files if using local data + seed: random seed + interactive_prob: probability to use interactive mode + max_length: max length of the text + tokenizer: tokenizer + use_speaker: include speaker information in the prompt + causal: use causal sampling when using local data, disable will lead to random sampling + num_codebooks: number of codebooks, if None, it will be automatically detected + skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode + """ + + super().__init__() + + assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" + + self.seed = seed + self.max_length = max_length + self.tokenizer = tokenizer + self.interactive_prob = interactive_prob + self.use_speaker = use_speaker + self.proto_files = proto_files + self.causal = causal + self.num_codebooks = num_codebooks + self.skip_text_prob = skip_text_prob + + self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") + self.groups = None + + def init_mock_data_server(self): + if self.groups is not None: + return + + # Expand the proto files + expanded_proto_files = [] + for filename in self.proto_files: + for i in braceexpand(filename): + i = Path(i) + if i.is_file(): + expanded_proto_files.append(i) + elif i.is_dir(): + expanded_proto_files.extend(i.rglob("*.proto")) + expanded_proto_files.extend(i.rglob("*.protos")) + else: + raise ValueError(f"{i} is not a file or directory") + + expanded_proto_files = sorted(expanded_proto_files) + Random(self.seed).shuffle(expanded_proto_files) + + self.groups = [] + shard_proto_files = split_by_rank_worker(expanded_proto_files) + log.info( + f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" + ) + + count = 0 + for filename in shard_proto_files: + with open(filename, "rb") as f: + for text_data in read_pb_stream(f): + self.groups.append(text_data) + count += 1 + + log.info(f"Read total {count} groups of data") + + # Shuffle the lines + Random(self.seed).shuffle(self.groups) + self.group_weights = [len(i.sentences) for i in self.groups] + + def __iter__(self): + while True: + yield self.augment() + + def tokenize_sentence(self, sentence: str): + sentence = clean_text(sentence) + tokens = self.tokenizer.encode( + f"{sentence}", + max_length=10**6, + add_special_tokens=False, + truncation=False, + ) + return sentence, len(tokens) + + def sample_data(self): + if self.groups is None: + self.init_mock_data_server() + + # Shuffle unique lines, estimate that each sample is at least 20 tokens + num_samples = self.max_length // 20 + + # choice group based on their number of samples + group = random.choices(self.groups, weights=self.group_weights, k=1)[0] + + if self.causal: + # Sample in order + if num_samples >= len(group.sentences): + samples = group.sentences + else: + begin = random.randint(0, len(group.sentences) - num_samples) + samples = group.sentences[begin : begin + num_samples] + else: + samples = random.choices( + group.sentences, k=min(num_samples, len(group.sentences)) + ) + + return SampledData( + source=group.source, + name=group.name, + samples=samples, + ) + + def augment(self): + final_text, final_semantic = [], [] + response = self.sample_data() + if 
len(response.samples) == 0: + # Invalid group + return None + + samples = list(response.samples) + idx = 0 + use_interactive = random.random() < self.interactive_prob + + if use_interactive is False: + # Random sample based on speaker using a truncated normal distribution + a = torch.tensor([0], dtype=torch.float32) + torch.nn.init.trunc_normal_( + a, + mean=self.max_length // 2, + std=self.max_length // 4, + a=10, + b=self.max_length, + ) + remaining_tokens = a.long().item() - 4 + else: + remaining_tokens = self.max_length + + # Use speaker + if isinstance(self.use_speaker, float): + use_speaker = random.random() < self.use_speaker + else: + use_speaker = self.use_speaker + + all_tokens, all_labels = [], [] + while remaining_tokens > 0 and len(samples) > 0: + sentence = samples.pop(0) + + text = random.choice(sentence.texts) + text, length = self.tokenize_sentence(text) + remaining_tokens -= length + len(sentence.semantics[0].values) + + if use_interactive is False: + final_text.append(text) + final_semantic.append(sentence.semantics) + else: + # For interactive mode, we only apply speaker for the first sentence + # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + tokens, labels = self.pack_sentences( + sentences=[text], + semantics=[sentence.semantics], + speaker=response.name if use_speaker else None, + skip_text=random.random() < self.skip_text_prob, + ) + + all_tokens.append(tokens) + all_labels.append(labels) + + idx += 1 + + if use_interactive is False: + tokens, labels = self.pack_sentences( + final_text, + semantics=final_semantic, + speaker=response.name if use_speaker else None, + ) + all_tokens.append(tokens) + all_labels.append(labels) + + tokens = torch.cat(all_tokens, dim=1) + labels = torch.cat(all_labels, dim=1) + + # Verify that the length is correct + assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" + + data = {"tokens": tokens, "labels": labels} + + return data + + def pack_sentences( + self, + sentences: list[str], + semantics: list, + speaker: Optional[str] = None, + skip_text: bool = False, + ): + if speaker is None: + speaker = "assistant" + + cated_sentences = " ".join(sentences) + if skip_text: + cated_sentences = "<|skip_text|>" + + final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>" + final_text = final_text + f"<|im_start|>{speaker}\n" + + encoded = self.tokenizer.encode( + final_text, + add_special_tokens=False, + truncation=False, + max_length=10**6, + ) + semantic_length = sum([len(i[0].values) for i in semantics]) + prompt_length = len(encoded) + num_codebooks = ( + len(semantics[0]) if self.num_codebooks is None else self.num_codebooks + ) + + # Pack the tokens and semantics (add and to semantic tokens) + tokens = ( + encoded + + [self.semantic_token_id] * semantic_length + + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"]) + ) + + # Codebook bos/padding: 0, eos: 1 + codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)] + for segment in semantics: + for book_idx, book in zip(range(num_codebooks), segment): + for j in book.values: + codes[book_idx].append(int(j) + 1) + + for book in codes: + book.extend([CODEBOOK_PAD_TOKEN_ID] * 1) + + tokens = [tokens] + codes + + tokens = torch.tensor(tokens, dtype=torch.long) + labels = tokens.clone() + + if skip_text: + # If text is not provided, the sentence is used for condition only, all labels are -100 + torch.fill_(labels, -100) + return tokens, labels + + # Mask out the tokens for semantic, predict semantic tokens only + # 
Since we don't mask out the input tokens, the language modeling still works + labels[1:, :prompt_length] = -100 + + tokens = tokens[:, :-1] + labels = labels[:, 1:] + + # Verify the padding is correct, and the last token is eos + assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all() + assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() + + return tokens, labels + + +@dataclass +class TextDataCollator: + tokenizer: AutoTokenizer + max_length: int = 1024 + + def __call__(self, examples): + if "negative_tokens" in examples: + positive_examples = [] + negative_examples = [] + + for i in examples: + positive_examples.append( + { + "tokens": i["tokens"], + "labels": i["labels"], + } + ) + negative_examples.append( + { + "tokens": i["negative_tokens"], + "labels": i["negative_labels"], + } + ) + + examples = positive_examples + negative_examples + + return self.batchify(examples) + + def batchify(self, examples, tokens_key="tokens", labels_key="labels"): + tokens, attention_masks, labels = [], [], [] + + # Calculate the max length + max_tokens_length = 0 + for example in examples: + max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) + max_tokens_length = min(max_tokens_length, self.max_length) + + for example in examples: + _tokens = example[tokens_key][:, :max_tokens_length] + _labels = example[labels_key][:, :max_tokens_length] + _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) + tokens_length = _tokens.size(1) + _attention_mask[:tokens_length] = False + + assert tokens_length == _labels.size( + 1 + ), f"{tokens_length} != {_labels.size(1)}" + + if tokens_length < max_tokens_length: + _tokens = F.pad( + _tokens, + (0, max_tokens_length - tokens_length), + value=self.tokenizer.eos_token_id, + ) + _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID + _labels = F.pad( + _labels, (0, max_tokens_length - _labels.size(1)), value=-100 + ) + + tokens.append(_tokens) + attention_masks.append(_attention_mask) + labels.append(_labels) + + tokens = torch.stack(tokens, dim=0) + attention_masks = torch.stack(attention_masks, dim=0) + labels = torch.stack(labels, dim=0) + + return { + "inputs": tokens, + "attention_masks": attention_masks, + "labels": labels, + } + + +class InterleaveDataset(IterableDataset): + def __init__( + self, + datasets: list[IterableDataset], + probabilities: list[float], + seed: int = 42, + ): + super().__init__() + + self.datasets = datasets + self.probabilities = probabilities + self.seed = seed + + def __iter__(self): + rng = np.random.default_rng(self.seed) + dataset_iterators = [iter(dataset) for dataset in self.datasets] + + while True: + # Random choice one + dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) + dataset_iterator = dataset_iterators[dataset_idx] + + try: + yield next(dataset_iterator) + except StopIteration: + # Exhausted, create a new iterator + dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) + yield next(dataset_iterators[dataset_idx]) + + +class SemanticDataModule(LightningDataModule): + def __init__( + self, + train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + batch_size: int = 32, + tokenizer: AutoTokenizer = None, + max_length: int = 1024, + num_workers: int = 4, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.tokenizer = tokenizer + self.max_length = max_length + 
self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + from tqdm import tqdm + + ds = AutoTextSemanticInstructionDataset( + ["data/protos"], + tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), + use_speaker=False, + interactive_prob=1.0, + skip_text_prob=0.5, + ) + + for i in ds: + print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) + # i["labels"][0][i["labels"][0] == -100] = 0 + # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) + break diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e --- /dev/null +++ b/fish_speech/datasets/vqgan.py @@ -0,0 +1,147 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import librosa +import numpy as np +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + +from fish_speech.utils import RankedLogger + +logger = RankedLogger(__name__, rank_zero_only=False) + + +class VQGANDataset(Dataset): + def __init__( + self, + filelist: str, + sample_rate: int = 32000, + hop_length: int = 640, + slice_frames: Optional[int] = None, + ): + super().__init__() + + filelist = Path(filelist) + root = filelist.parent + + self.files = [ + root / line.strip() + for line in filelist.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + self.sample_rate = sample_rate + self.hop_length = hop_length + self.slice_frames = slice_frames + + def __len__(self): + return len(self.files) + + def get_item(self, idx): + file = self.files[idx] + + audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) + + # Slice audio and features + if ( + self.slice_frames is not None + and audio.shape[0] > self.slice_frames * self.hop_length + ): + start = np.random.randint( + 0, audio.shape[0] - self.slice_frames * self.hop_length + ) + audio = audio[start : start + self.slice_frames * self.hop_length] + + if len(audio) == 0: + return None + + max_value = np.abs(audio).max() + if max_value > 1.0: + audio = audio / max_value + + return { + "audio": torch.from_numpy(audio), + } + + def __getitem__(self, idx): + try: + return self.get_item(idx) + except Exception as e: + import traceback + + traceback.print_exc() + logger.error(f"Error loading {self.files[idx]}: {e}") + return None + + +@dataclass +class VQGANCollator: + def __call__(self, batch): + batch = [x for x in batch if x is not None] + + audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) + audio_maxlen = audio_lengths.max() + + # Rounds up to nearest multiple of 2 (audio_lengths) + audios = [] + for x in batch: + audios.append( + torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) + ) + + return { + "audios": torch.stack(audios), + "audio_lengths": audio_lengths, + } + + +class VQGANDataModule(LightningDataModule): + def __init__( + self, + train_dataset: VQGANDataset, + val_dataset: VQGANDataset, + batch_size: int = 32, + num_workers: int = 4, + 
val_batch_size: Optional[int] = None, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.val_batch_size = val_batch_size or batch_size + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + shuffle=True, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") + dataloader = DataLoader( + dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() + ) + + for batch in dataloader: + print(batch["audios"].shape) + print(batch["features"].shape) + print(batch["audio_lengths"]) + print(batch["feature_lengths"]) + break diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md new file mode 100644 index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2 --- /dev/null +++ b/fish_speech/i18n/README.md @@ -0,0 +1,27 @@ +## i18n Folder Attribution + +The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: + +### fish_speech/i18n/core.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) + +**Initial commit:** +add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) + +**Initial author:** +[@L4Ph](https://github.com/L4Ph) + +### fish_speech/i18n/scan.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) + +**Initial commit:** +File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) + +**Initial author:** +[@towzeur](https://github.com/towzeur) + +We appreciate the contributions of the RVC project and its authors. 
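For reference alongside the i18n module files that follow, the exported `i18n` helper is consumed roughly as follows — a minimal usage sketch, with the key taken from an entry that appears in en_US.json; it resolves the key against the active locale JSON and falls back to returning the key itself when no translation exists.

from fish_speech.i18n import i18n

# The locale is read from a local ".locale" file if present, otherwise from the
# system default locale; keys missing from locale/<lang>.json are echoed back unchanged.
print(i18n("Open Labeler WebUI"))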
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246 --- /dev/null +++ b/fish_speech/i18n/__init__.py @@ -0,0 +1,3 @@ +from .core import i18n + +__all__ = ["i18n"] diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f82bc9c78570a30c3d65065ddadafd44f1ecdc5 Binary files /dev/null and b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-311.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7af52e95a04d4575fd7af359f9d711379d58fb1 Binary files /dev/null and b/fish_speech/i18n/__pycache__/__init__.cpython-311.pyc differ diff --git a/fish_speech/i18n/__pycache__/core.cpython-310.pyc b/fish_speech/i18n/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97139b6392ea03e558405d073a5f9f5ca38b123b Binary files /dev/null and b/fish_speech/i18n/__pycache__/core.cpython-310.pyc differ diff --git a/fish_speech/i18n/__pycache__/core.cpython-311.pyc b/fish_speech/i18n/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7348d84d230b52d86f2a46b8f472399baded8f01 Binary files /dev/null and b/fish_speech/i18n/__pycache__/core.cpython-311.pyc differ diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py new file mode 100644 index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd --- /dev/null +++ b/fish_speech/i18n/core.py @@ -0,0 +1,40 @@ +import json +import locale +from pathlib import Path + +I18N_FILE_PATH = Path(__file__).parent / "locale" +DEFAULT_LANGUAGE = "en_US" + + +def load_language_list(language): + with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: + language_list = json.load(f) + + return language_list + + +class I18nAuto: + def __init__(self): + i18n_file = Path(".locale") + + if i18n_file.exists(): + with open(i18n_file, "r", encoding="utf-8") as f: + language = f.read().strip() + else: + # getlocale can't identify the system's language ((None, None)) + language = locale.getdefaultlocale()[0] + + if (I18N_FILE_PATH / f"{language}.json").exists() is False: + language = DEFAULT_LANGUAGE + + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + + +i18n = I18nAuto() diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json new file mode 100644 index 0000000000000000000000000000000000000000..d36c774313628fe9d4ee60e816f404c09935e655 --- /dev/null +++ b/fish_speech/i18n/locale/en_US.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Accumulate Gradient Batches", + "Add to Processing Area": "Add to 
Processing Area", + "Added path successfully!": "Added path successfully!", + "Advanced Config": "Advanced Config", + "Base LLAMA Model": "Base LLAMA Model", + "Batch Inference": "Batch Inference", + "Batch Size": "Batch Size", + "Changing with the Model Path": "Changing with the Model Path", + "Chinese": "Chinese", + "Compile Model": "Compile Model", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", + "Copy": "Copy", + "Data Preprocessing": "Data Preprocessing", + "Data Preprocessing Path": "Data Preprocessing Path", + "Data Source": "Data Source", + "Decoder Model Config": "Decoder Model Config", + "Decoder Model Path": "Decoder Model Path", + "Disabled": "Disabled", + "Enable Reference Audio": "Enable Reference Audio", + "English": "English", + "Error Message": "Error Message", + "File Preprocessing": "File Preprocessing", + "Generate": "Generate", + "Generated Audio": "Generated Audio", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", + "Infer interface is closed": "Infer interface is closed", + "Inference Configuration": "Inference Configuration", + "Inference Server Configuration": "Inference Server Configuration", + "Inference Server Error": "Inference Server Error", + "Inferring interface is launched at {}": "Inferring interface is launched at {}", + "Initial Learning Rate": "Initial Learning Rate", + "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", + "Input Text": "Input Text", + "Invalid path: {}": "Invalid path: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", + "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", + "Japanese": "Japanese", + "LLAMA Configuration": "LLAMA Configuration", + "LLAMA Model Config": "LLAMA Model Config", + "LLAMA Model Path": "LLAMA Model Path", + "Labeling Device": "Labeling Device", + "LoRA Model to be merged": "LoRA Model to be merged", + "Maximum Audio Duration": "Maximum Audio Duration", + "Maximum Length per Sample": "Maximum Length per Sample", + "Maximum Training Steps": "Maximum Training Steps", + "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", + "Merge": "Merge", + "Merge LoRA": "Merge LoRA", + "Merge successfully": "Merge successfully", + "Minimum Audio Duration": "Minimum Audio Duration", + "Model Output Path": "Model Output Path", + "Model Size": "Model Size", + "Move": "Move", + "Move files successfully": "Move files successfully", + "No audio generated, please check the input text.": "No audio generated, please check the input text.", + "No selected options": "No selected options", + "Number of Workers": "Number of Workers", + "Open Inference Server": "Open Inference Server", + "Open Labeler WebUI": "Open Labeler WebUI", + "Open Tensorboard": "Open Tensorboard", + "Opened labeler in browser": "Opened labeler in browser", + "Optional Label Language": "Optional Label Language", + "Optional online ver": "Optional online ver", + "Output Path": "Output Path", + "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", + 
"Precision": "Precision", + "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", + "Put your text here.": "Put your text here.", + "Reference Audio": "Reference Audio", + "Reference Text": "Reference Text", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", + "Remove Selected Data": "Remove Selected Data", + "Removed path successfully!": "Removed path successfully!", + "Repetition Penalty": "Repetition Penalty", + "Save model every n steps": "Save model every n steps", + "Select LLAMA ckpt": "Select LLAMA ckpt", + "Select VITS ckpt": "Select VITS ckpt", + "Select VQGAN ckpt": "Select VQGAN ckpt", + "Select source file processing method": "Select source file processing method", + "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", + "Selected: {}": "Selected: {}", + "Speaker": "Speaker", + "Speaker is identified by the folder name": "Speaker is identified by the folder name", + "Start Training": "Start Training", + "Streaming Audio": "Streaming Audio", + "Streaming Generate": "Streaming Generate", + "Tensorboard Host": "Tensorboard Host", + "Tensorboard Log Path": "Tensorboard Log Path", + "Tensorboard Port": "Tensorboard Port", + "Tensorboard interface is closed": "Tensorboard interface is closed", + "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", + "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.", + "Training Configuration": "Training Configuration", + "Training Error": "Training Error", + "Training stopped": "Training stopped", + "Type name of the speaker": "Type name of the speaker", + "Type the path or select from the dropdown": "Type the path or select from the dropdown", + "Use LoRA": "Use LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", + "Use filelist": "Use filelist", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", + "VITS Configuration": "VITS Configuration", + "VQGAN Configuration": "VQGAN Configuration", + "Validation Batch Size": "Validation Batch Size", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", + "WebUI Host": "WebUI Host", + "WebUI Port": "WebUI Port", + "Whisper Model": "Whisper Model", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", + "latest": "latest", + "new": "new", + "Realtime Transform Text": "Realtime Transform Text", + "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", + "Text Normalization": "Text Normalization", + "Select Example Audio": "Select Example Audio" +} diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json new file mode 100644 index 0000000000000000000000000000000000000000..7a4757967dd0fe3807ba4d354e75ad7a88eb510e --- /dev/null +++ b/fish_speech/i18n/locale/es_ES.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular lotes de gradientes", + "Add to Processing Area": "Agregar al Área de Procesamiento", + "Added path successfully!": "¡Ruta agregada exitosamente!", + "Advanced Config": "Configuración Avanzada", + "Base LLAMA Model": "Modelo Base LLAMA", + "Batch Inference": "Inferencia por Lote", + "Batch Size": "Tamaño del Lote", + "Changing with the Model Path": "Cambiando con la Ruta del Modelo", + "Chinese": "Chino", + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero 
aumentará el tiempo de inicio en frío", + "Copy": "Copiar", + "Data Preprocessing": "Preprocesamiento de Datos", + "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", + "Data Source": "Fuente de Datos", + "Decoder Model Config": "Configuración del modelo decodificador", + "Decoder Model Path": "Ruta del modelo decodificador", + "Disabled": "Desactivado", + "Enable Reference Audio": "Habilitar Audio de Referencia", + "English": "Inglés", + "Error Message": "Mensaje de Error", + "File Preprocessing": "Preprocesamiento de Archivos", + "Generate": "Generar", + "Generated Audio": "Audio Generado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", + "Infer interface is closed": "La interfaz de inferencia está cerrada", + "Inference Configuration": "Configuración de Inferencia", + "Inference Server Configuration": "Configuración del Servidor de Inferencia", + "Inference Server Error": "Error del Servidor de Inferencia", + "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", + "Initial Learning Rate": "Tasa de Aprendizaje Inicial", + "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Ruta inválida: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", + "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", + "Japanese": "Japonés", + "LLAMA Configuration": "Configuración de LLAMA", + "LLAMA Model Config": "Configuración del Modelo LLAMA", + "LLAMA Model Path": "Ruta del Modelo LLAMA", + "Labeling Device": "Dispositivo de Etiquetado", + "LoRA Model to be merged": "Modelo LoRA a fusionar", + "Maximum Audio Duration": "Duración máxima de audio", + "Maximum Length per Sample": "Longitud Máxima por Muestra", + "Maximum Training Steps": "Pasos Máximos de Entrenamiento", + "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", + "Merge": "Fusionar", + "Merge LoRA": "Fusionar LoRA", + "Merge successfully": "Fusionado exitosamente", + "Minimum Audio Duration": "Duración mínima de audio", + "Model Output Path": "Ruta de Salida del Modelo", + "Model Size": "Tamaño del Modelo", + "Move": "Mover", + "Move files successfully": "Archivos movidos exitosamente", + "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", + "No selected options": "No hay opciones seleccionadas", + "Number of Workers": "Número de Trabajadores", + "Open Inference Server": "Abrir Servidor de Inferencia", + "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "Se abrió el etiquetador en el navegador", + "Optional Label Language": "Idioma de Etiquetado Opcional", + "Optional online ver": "Ver en línea opcional", + "Output Path": "Ruta de Salida", + "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", + "Precision": "Precisión", + "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", + "Put your text here.": "Ponga su texto aquí.", + 
"Reference Audio": "Audio de Referencia", + "Reference Text": "Texto de Referencia", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", + "Remove Selected Data": "Eliminar Datos Seleccionados", + "Removed path successfully!": "¡Ruta eliminada exitosamente!", + "Repetition Penalty": "Penalización por Repetición", + "Save model every n steps": "Guardar modelo cada n pasos", + "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", + "Select VITS ckpt": "Seleccionar punto de control VITS", + "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", + "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", + "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", + "Selected: {}": "Seleccionado: {}", + "Speaker": "Hablante", + "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", + "Start Training": "Iniciar Entrenamiento", + "Streaming Audio": "transmisión de audio", + "Streaming Generate": "síntesis en flujo", + "Tensorboard Host": "Host de Tensorboard", + "Tensorboard Log Path": "Ruta de Registro de Tensorboard", + "Tensorboard Port": "Puerto de Tensorboard", + "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", + "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", + "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", + "Training Configuration": "Configuración de Entrenamiento", + "Training Error": "Error de Entrenamiento", + "Training stopped": "Entrenamiento detenido", + "Type name of the speaker": "Escriba el nombre del hablante", + "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", + "Use filelist": "Usar lista de archivos", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", + "VITS Configuration": "Configuración de VITS", + "VQGAN Configuration": "Configuración de VQGAN", + "Validation Batch Size": "Tamaño del Lote de Validación", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", + "WebUI Host": "Host de WebUI", + "WebUI Port": "Puerto de WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", + "latest": "más reciente", + "new": "nuevo", + "Realtime Transform Text": "Transformación de Texto en Tiempo Real", + "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", + "Text Normalization": "Normalización de Texto", + "Select Example Audio": "Selecionar áudio de exemplo" +} diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json new file mode 100644 index 0000000000000000000000000000000000000000..863b8b0b41da7e504ac0dcc4abf707f1f71a53fa --- /dev/null +++ b/fish_speech/i18n/locale/ja_JP.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", + "Accumulate Gradient Batches": "勾配バッチの累積", + "Add to Processing Area": "処理エリアに追加", + "Added path successfully!": "パスの追加に成功しました!", + "Advanced Config": "詳細設定", + "Base LLAMA Model": "基本LLAMAモデル", + "Batch Inference": "バッチ推論", + "Batch Size": "バッチサイズ", + "Changing with the Model Path": "モデルのパスに伴って変化する", + "Chinese": "中国語", + "Compile Model": "モデルのコンパイル", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", + "Copy": "コピー", + "Data Preprocessing": "データ前処理", + "Data 
Preprocessing Path": "データ前処理パス", + "Data Source": "データソース", + "Decoder Model Config": "デコーダーモデルの構成", + "Decoder Model Path": "デコーダーモデルのパス", + "Disabled": "無効", + "Enable Reference Audio": "リファレンスオーディオを有効にする", + "English": "英語", + "Error Message": "エラーメッセージ", + "File Preprocessing": "文書前处理", + "Generate": "生成", + "Generated Audio": "生成されたオーディオ", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", + "Infer interface is closed": "推論インターフェースが閉じられています", + "Inference Configuration": "推論設定", + "Inference Server Configuration": "推論サーバー設定", + "Inference Server Error": "推論サーバーエラー", + "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", + "Initial Learning Rate": "初期学習率", + "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", + "Input Text": "入力テキスト", + "Invalid path: {}": "無効なパス: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", + "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", + "Japanese": "日本語", + "LLAMA Configuration": "LLAMA設定", + "LLAMA Model Config": "LLAMAモデル設定", + "LLAMA Model Path": "LLAMAモデルパス", + "Labeling Device": "ラベリングデバイス", + "LoRA Model to be merged": "マージするLoRAモデル", + "Maximum Audio Duration": "最大オーディオの長さ", + "Maximum Length per Sample": "サンプルあたりの最大長", + "Maximum Training Steps": "最大トレーニングステップ数", + "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", + "Merge": "マージ", + "Merge LoRA": "LoRAのマージ", + "Merge successfully": "マージに成功しました", + "Minimum Audio Duration": "最小オーディオの長さ", + "Model Output Path": "モデル出力パス", + "Model Size": "モデルサイズ", + "Move": "移動", + "Move files successfully": "ファイルの移動に成功しました", + "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", + "No selected options": "選択されたオプションはありません", + "Number of Workers": "ワーカー数", + "Open Inference Server": "推論サーバーを開く", + "Open Labeler WebUI": "ラベラーWebUIを開く", + "Open Tensorboard": "Tensorboardを開く", + "Opened labeler in browser": "ブラウザでラベラーを開きました", + "Optional Label Language": "オプションのラベル言語", + "Optional online ver": "オプションのオンラインバージョン", + "Output Path": "出力パス", + "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", + "Precision": "精度", + "Probability of applying Speaker Condition": "話者条件を適用する確率", + "Put your text here.": "ここにテキストを入力してください。", + "Reference Audio": "リファレンスオーディオ", + "Reference Text": "リファレンステキスト", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", + "Remove Selected Data": "選択したデータを削除", + "Removed path successfully!": "パスの削除に成功しました!", + "Repetition Penalty": "反復ペナルティ", + "Save model every n steps": "nステップごとにモデルを保存", + "Select LLAMA ckpt": " LLAMA チェックポイントを選択", + "Select VITS ckpt": "VITS チェックポイントを選択", + "Select VQGAN ckpt": "VQGAN チェックポイントを選択", + "Select source file processing method": "ソースファイルの処理方法を選択", + "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", + "Selected: {}": "選択済み: {}", + "Speaker": "話者", + "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", + "Start Training": "トレーニング開始", + "Streaming Audio": "ストリーミングオーディオ", + "Streaming Generate": "ストリーミング合成", + "Tensorboard Host": "Tensorboardホスト", + "Tensorboard Log Path": "Tensorboardログパス", + "Tensorboard Port": "Tensorboardポート", + "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", + 
"Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", + "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", + "Training Configuration": "トレーニング設定", + "Training Error": "トレーニングエラー", + "Training stopped": "トレーニングが停止しました", + "Type name of the speaker": "話者の名前を入力", + "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", + "Use LoRA": "LoRAを使用", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", + "Use filelist": "ファイルリストを使用", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", + "VITS Configuration": "VITS の構成", + "VQGAN Configuration": "VQGAN の構成", + "Validation Batch Size": "検証バッチサイズ", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", + "WebUI Host": "WebUIホスト", + "WebUI Port": "WebUIポート", + "Whisper Model": "Whisperモデル", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", + "latest": "最新", + "new": "新規", + "Realtime Transform Text": "リアルタイム変換テキスト", + "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", + "Text Normalization": "テキスト正規化", + "Select Example Audio": "サンプル音声を選択" +} diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json new file mode 100644 index 0000000000000000000000000000000000000000..180263874b476059870035d4c2b74ce5fa553a8a --- /dev/null +++ b/fish_speech/i18n/locale/ko_KR.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.", + "Accumulate Gradient Batches": "그라디언트 배치 누적", + "Add to Processing Area": "처리 영역에 추가", + "Added path successfully!": "경로가 성공적으로 추가되었습니다!", + "Advanced Config": "고급 설정", + "Base LLAMA Model": "기본 LLAMA 모델", + "Batch Inference": "배치 추론", + "Batch Size": "배치 크기", + "Changing with the Model Path": "모델 경로에 따라 변경 중", + "Chinese": "중국어", + "Compile Model": "모델 컴파일", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.", + "Copy": "복사", + "Data Preprocessing": "데이터 전처리", + "Data Preprocessing Path": "데이터 전처리 경로", + "Data Source": "데이터 소스", + "Decoder Model Config": "디코더 모델 설정", + "Decoder Model Path": "디코더 모델 경로", + "Disabled": "비활성화 됨", + "Enable Reference Audio": "참고 음성 활성화", + "English": "영어", + "Error Message": "오류 메시지", + "File Preprocessing": "파일 전처리", + 
"Generate": "생성", + "Generated Audio": "생성된 오디오", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.", + "Infer interface is closed": "추론 인터페이스가 닫혔습니다.", + "Inference Configuration": "추론 설정", + "Inference Server Configuration": "추론 서버 설정", + "Inference Server Error": "추론 서버 오류", + "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.", + "Initial Learning Rate": "초기 학습률", + "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로", + "Input Text": "입력 텍스트", + "Invalid path: {}": "유효하지 않은 경로: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.", + "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)", + "Japanese": "일본어", + "LLAMA Configuration": "LLAMA 설정", + "LLAMA Model Config": "LLAMA 모델 설정", + "LLAMA Model Path": "LLAMA 모델 경로", + "Labeling Device": "라벨링 장치", + "LoRA Model to be merged": "병합할 LoRA 모델", + "Maximum Audio Duration": "최대 오디오 길이", + "Maximum Length per Sample": "샘플당 최대 길이", + "Maximum Training Steps": "최대 학습 단계", + "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)", + "Merge": "병합", + "Merge LoRA": "LoRA 병합", + "Merge successfully": "성공적으로 병합 되었습니다.", + "Minimum Audio Duration": "최소 오디오 길이", + "Model Output Path": "모델 출력 경로", + "Model Size": "모델 크기", + "Move": "이동", + "Move files successfully": "파일이 성공적으로 이동되었습니다.", + "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.", + "No selected options": "옵션이 선택되지 않았습니다.", + "Number of Workers": "작업자 수", + "Open Inference Server": "추론 서버 열기", + "Open Labeler WebUI": "라벨러 WebUI 열기", + "Open Tensorboard": "Tensorboard 열기", + "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.", + "Optional Label Language": "선택적 라벨 언어", + "Optional online ver": "온라인 버전 선택", + "Output Path": "출력 경로", + "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.", + "Precision": "정밀도", + "Probability of applying Speaker Condition": "화자 조건 적용 확률", + "Put your text here.": "여기에 텍스트를 입력하세요.", + "Reference Audio": "참고 오디오", + "Reference Text": "참고 텍스트", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.", + "Remove Selected Data": "선택한 데이터 제거", + "Removed path successfully!": "경로가 성공적으로 제거되었습니다!", + "Repetition Penalty": "반복 패널티", + "Save model every n steps": "n 단계마다 모델 저장", + "Select LLAMA ckpt": "LLAMA ckpt 선택", + "Select VITS ckpt": "VITS ckpt 선택", + "Select VQGAN ckpt": "VQGAN ckpt 선택", + "Select source file processing method": "소스 파일 처리 방법 선택", + "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)", + "Selected: {}": "선택됨: {}", + "Speaker": "화자", + "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다", + "Start Training": "학습 시작", + "Streaming Audio": "스트리밍 오디오", + "Streaming Generate": "스트리밍 생성", + "Tensorboard Host": "Tensorboard 호스트", + "Tensorboard Log Path": "Tensorboard 로그 경로", + "Tensorboard Port": "Tensorboard 포트", + "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다", + "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.", + "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.", + "Training Configuration": "학습 설정", + "Training Error": "학습 오류", + "Training stopped": "학습이 중지되었습니다.", + "Type name of the speaker": "화자의 이름을 입력하세요.", + "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.", + "Use LoRA": "LoRA 사용", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.", + "Use filelist": "파일 목록 사용", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.", + "VITS Configuration": "VITS 설정", + "VQGAN Configuration": "VQGAN 설정", + "Validation Batch Size": "검증 배치 크기", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.", + "WebUI Host": "WebUI 호스트", + "WebUI Port": "WebUI 포트", + "Whisper Model": "Whisper 모델", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다", + "latest": "최신", + "new": "새로운", + "Realtime Transform Text": "실시간 텍스트 변환", + "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)", + "Text Normalization": "텍스트 정규화", + "Select Example Audio": "예시 오디오 선택" +} diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json new file mode 100644 index 0000000000000000000000000000000000000000..385f20272e19053ab9b6cf6463a84c8ece768c68 --- /dev/null +++ b/fish_speech/i18n/locale/pt_BR.json @@ -0,0 +1,133 @@ +{ + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", + "Add to Processing Area": "Adicionar à Área de Processamento", + "Added path successfully!": "Caminho adicionado com sucesso!", + "Advanced Config": "Configuração Avançada", + "Base LLAMA Model": "Modelo LLAMA Base", + "Batch Inference": "Inferência em Lote", + "Batch Size": "Tamanho do Lote", + "Changing with the Model Path": "Alterando com o Caminho do Modelo", + + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", + "Copy": "Copiar", + "Data Preprocessing": "Pré-processamento de Dados", + "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", + "Data Source": "Fonte de Dados", + "Decoder Model Config": "Configuração do Modelo Decodificador", + "Decoder Model Path": "Caminho do Modelo Decodificador", + "Disabled": "Desativado", + "Enable Initial Prompt": "Habilitar 
Prompt Inicial", + "Enable Reference Audio": "Habilitar Áudio de Referência", + "English": "Inglês", + "Japanese": "Japonês", + "Chinese": "Chinês", + "Portuguese": "Português", + "Spanish": "Espanhol", + "Error Message": "Mensagem de Erro", + "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", + "File Preprocessing": "Pré-processamento de Arquivos", + "Generate": "Gerar", + "Generated Audio": "Áudio Gerado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", + "Infer interface is closed": "A interface de inferência foi fechada", + "Inference Configuration": "Configuração de Inferência", + "Inference Server Configuration": "Configuração do Servidor de Inferência", + "Inference Server Error": "Erro do Servidor de Inferência", + "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}", + "Initial Learning Rate": "Taxa de Aprendizagem Inicial", + "Initial Prompt": "Prompt Inicial", + "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", + "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Caminho inválido: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", + "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", + "LLAMA Configuration": "Configuração do LLAMA", + "LLAMA Model Config": "Configuração do Modelo LLAMA", + "LLAMA Model Path": "Caminho do Modelo LLAMA", + "Labeling Device": "Dispositivo de Rotulagem", + "LoRA Model to be merged": "Modelo LoRA para mesclagem", + "Maximum Length per Sample": "Comprimento Máximo por Amostra", + "Maximum Training Steps": "Etapas Máximas de Treinamento", + "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", + "Merge": "Mesclar", + "Merge LoRA": "Mesclar LoRA", + "Merge successfully": "Mesclado com sucesso", + "Model Output Path": "Caminho de Saída do Modelo", + "Model Quantization": "Quantização do Modelo", + "Model Size": "Tamanho do Modelo", + "Move": "Mover", + "Move files successfully": "Arquivos movidos com sucesso", + "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", + "No selected options": "Nenhuma opção selecionada", + "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", + "Number of Workers": "Número de Processos", + "Open Inference Server": "Abrir Servidor de Inferência", + "Open Labeler WebUI": "Abrir WebUI de Rotulagem", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", + "Optional Label Language": "Idioma do Rótulo (Opcional)", + "Optional online ver": "Versão online (opcional)", + "Output Path": "Caminho de Saída", + "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", + "Post-quantification Precision": "Precisão Pós-quantização", + "Precision": "Precisão", + "Probability of 
applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", + "Put your text here.": "Insira seu texto aqui.", + "Quantify": "Quantizar", + "Quantify successfully": "Quantizado com sucesso", + "Realtime Transform Text": "Transformar Texto em Tempo Real", + "Reference Audio": "Áudio de Referência", + "Reference Text": "Texto de Referência", + "warning": "Aviso", + "Pre-processing begins...": "O pré-processamento começou!", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", + "Remove Selected Data": "Remover Dados Selecionados", + "Removed path successfully!": "Caminho removido com sucesso!", + "Repetition Penalty": "Penalidade de Repetição", + "Save model every n steps": "Salvar modelo a cada n etapas", + "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", + "Select source file processing method": "Escolha como processar o arquivo de origem", + "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", + "Selected: {}": "Selecionado: {}", + "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", + "Start Training": "Iniciar Treinamento", + "Streaming Audio": "Áudio em Streaming", + "Streaming Generate": "Geração em Streaming", + "Tensorboard Host": "Host do Tensorboard", + "Tensorboard Log Path": "Caminho de Log do Tensorboard", + "Tensorboard Port": "Porta do Tensorboard", + "Tensorboard interface is closed": "A interface do Tensorboard está fechada", + "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", + "Text Normalization": "Normalização de Texto", + "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", + "Training Configuration": "Configuração de Treinamento", + "Training Error": "Erro de Treinamento", + "Training stopped": "Treinamento interrompido!", + "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", + "Use filelist": "Usar lista de arquivos", + "VQGAN Configuration": "Configuração do VQGAN", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", + "WebUI Host": "Host da WebUI", + "WebUI Port": "Porta da WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", + "auto": "automático", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", + "latest": "mais recente", + "new": "novo", + "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", + "You don't need to train this model!": "Não é necessário treinar este modelo!", + "Yes": "Sim", + "No": "Não", + "version:": "versão:", + "author:": "autor:" +} diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json new file mode 100644 index 0000000000000000000000000000000000000000..9068ef0b9a41b9941b37644c6a4c96ec6a5d836e --- /dev/null +++ b/fish_speech/i18n/locale/zh_CN.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", + "Accumulate Gradient Batches": "梯度累积批次", + "Add to Processing Area": "加入处理区", + "Added path successfully!": "添加路径成功!", + "Advanced Config": "高级参数", + "Base LLAMA Model": "基础 LLAMA 模型", + "Batch Inference": "批量推理", + "Batch Size": "批次大小", + "Changing with the Model Path": "随模型路径变化", + "Chinese": "中文", + "Compile Model": "编译模型", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", + "Copy": "复制", + "Data Preprocessing": "数据预处理", + "Data Preprocessing Path": "数据预处理路径", + "Data Source": "数据源", + "Decoder Model Config": "解码器模型配置", + "Decoder Model Path": "解码器模型路径", + "Disabled": "禁用", + "Enable Reference Audio": "启用参考音频", + "English": "英文", + "Error Message": "错误信息", + "File Preprocessing": "文件预处理", + "Generate": "生成", + "Generated Audio": "音频", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", + "Infer interface is closed": "推理界面已关闭", + "Inference Configuration": "推理配置", + "Inference Server Configuration": "推理服务器配置", + "Inference Server Error": "推理服务器错误", + "Inferring interface is launched at {}": "推理界面已在 {} 上启动", + "Initial Learning Rate": "初始学习率", + "Input Audio & Source Path for Transcription": "输入音频和转录源路径", + "Input Text": "输入文本", + "Invalid path: {}": "无效路径: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", + "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", + "Japanese": "日文", + "LLAMA Configuration": "LLAMA 配置", + "LLAMA Model Config": "LLAMA 模型配置", + "LLAMA Model Path": "LLAMA 模型路径", + "Labeling Device": "标注加速设备", + "LoRA Model to be merged": "要合并的 LoRA 模型", + "Maximum Audio Duration": "最大音频时长", + "Maximum Length per Sample": 
"每个样本的最大长度", + "Maximum Training Steps": "最大训练步数", + "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", + "Merge": "合并", + "Merge LoRA": "合并 LoRA", + "Merge successfully": "合并成功", + "Minimum Audio Duration": "最小音频时长", + "Model Output Path": "模型输出路径", + "Model Size": "模型规模", + "Move": "移动", + "Move files successfully": "移动文件成功", + "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", + "No selected options": "没有选择的选项", + "Number of Workers": "数据加载进程数", + "Open Inference Server": "打开推理服务器", + "Open Labeler WebUI": "打开标注工具", + "Open Tensorboard": "打开 Tensorboard", + "Opened labeler in browser": "在浏览器中打开标注工具", + "Optional Label Language": "[可选] 标注语言", + "Optional online ver": "[可选] 使用在线版", + "Output Path": "输出路径", + "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", + "Precision": "精度", + "Probability of applying Speaker Condition": "应用说话人条件的概率", + "Put your text here.": "在此处输入文本.", + "Reference Audio": "参考音频", + "Reference Text": "参考文本", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", + "Remove Selected Data": "移除选中数据", + "Removed path successfully!": "移除路径成功!", + "Repetition Penalty": "重复惩罚", + "Save model every n steps": "每 n 步保存模型", + "Select LLAMA ckpt": "选择 LLAMA 检查点", + "Select VITS ckpt": "选择 VITS 检查点", + "Select VQGAN ckpt": "选择 VQGAN 检查点", + "Select source file processing method": "选择源文件处理方法", + "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", + "Selected: {}": "已选择: {}", + "Speaker": "说话人", + "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", + "Start Training": "开始训练", + "Streaming Audio": "流式音频", + "Streaming Generate": "流式合成", + "Tensorboard Host": "Tensorboard 监听地址", + "Tensorboard Log Path": "Tensorboard 日志路径", + "Tensorboard Port": "Tensorboard 端口", + "Tensorboard interface is closed": "Tensorboard 界面已关闭", + "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", + "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", + "Training Configuration": "训练配置", + "Training Error": "训练错误", + "Training stopped": "训练已停止", + "Type name of the speaker": "输入说话人的名称", + "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", + "Use LoRA": "使用 LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", + "Use filelist": "使用文件列表", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", + "VITS Configuration": "VITS 配置", + "VQGAN Configuration": "VQGAN 配置", + "Validation Batch Size": "验证批次大小", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", + "WebUI Host": "WebUI 监听地址", + "WebUI Port": "WebUI 端口", + "Whisper Model": "Whisper 模型", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", + "latest": "最近的检查点", + "new": "创建新的检查点", + "Realtime Transform Text": "实时规范化文本", + "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", + "Text Normalization": "文本规范化", + "Select Example Audio": "选择参考音频" +} diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py new file mode 100644 index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1 --- /dev/null +++ b/fish_speech/i18n/scan.py @@ -0,0 +1,122 @@ +import ast +import glob +import json +from collections import OrderedDict +from pathlib import Path + +from loguru import logger + +from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH + + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + + +# scan the directory for all .py files (recursively) +# for each file, parse the code into an AST +# for each AST, extract the i18n strings + +strings = [] +folders = ["fish_speech", "tools"] +# for filename in glob.iglob("**/*.py", recursive=True): +for folder in folders: + for f in Path(folder).rglob("*.py"): + code = f.read_text(encoding="utf-8") + if "i18n(" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") + strings.extend(i18n_strings) + +code_keys = set(strings) +logger.info(f"Total unique: {len(code_keys)}") + + +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) +standard_keys = set(standard_data.keys()) + +# Define the standard file name +unused_keys = standard_keys - code_keys +logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") +for unused_key in unused_keys: + 
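+ # Listing each unused key here makes stale translations easy to prune,
+ # since the default-language file is rewritten below from the keys found in code.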
logger.info(f"\t{unused_key}") + +missing_keys = code_keys - standard_keys +logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") +for missing_key in missing_keys: + logger.info(f"\t{missing_key}") + +code_keys_dict = OrderedDict() +for s in strings: + code_keys_dict[s] = s + +# write back +with open(standard_file, "w", encoding="utf-8") as f: + json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + +logger.info(f"Updated {standard_file}") + + +# Define the standard file name +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" + +# Find all JSON files in the directory +dir_path = I18N_FILE_PATH +languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] + +# Load the standard file +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) + +# Loop through each language file +for lang_file in languages: + # Load the language file + with open(lang_file, "r", encoding="utf-8") as f: + lang_data = json.load(f, object_pairs_hook=OrderedDict) + + # Find the difference between the language file and the standard file + diff = set(standard_data.keys()) - set(lang_data.keys()) + + miss = set(lang_data.keys()) - set(standard_data.keys()) + + # Add any missing keys to the language file + for key in diff: + lang_data[key] = "#!" + key + logger.info(f"Added missing key: {key} to {lang_file}") + + # Del any extra keys to the language file + for key in miss: + del lang_data[key] + logger.info(f"Del extra key: {key} from {lang_file}") + + # Sort the keys of the language file to match the order of the standard file + lang_data = OrderedDict( + sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) + ) + + # Save the updated language file + with open(lang_file, "w", encoding="utf-8") as f: + json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + + logger.info(f"Updated {lang_file}") + +logger.info("Done") diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddbc750faebdda397ea9ee396540f805dbe37ad8 Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8a12b6c647b866c1dcd1615ff10de5cc5bf13b4 Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..893946c7e719a7a96b24a98c22c457735b66942f Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6b40834cf2ef865bb0b1d78d34f7df84f73f3f61 Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py new file mode 100644 index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0 --- /dev/null +++ b/fish_speech/models/text2semantic/lit_module.py @@ -0,0 +1,202 @@ +from typing import Any, Optional + +import lightning as L +import torch +import torch.nn.functional as F +from lightning.pytorch.utilities.types import OptimizerLRScheduler + +import fish_speech.utils as utils +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.models.text2semantic.llama import NaiveTransformer + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +class TextToSemantic(L.LightningModule): + def __init__( + self, + model: NaiveTransformer, + optimizer: Any, + lr_scheduler: Any, + ): + super().__init__() + + self.model = model + self.optimizer_builder = optimizer + self.lr_scheduler_builder = lr_scheduler + + def forward(self, x): + return self.model(x) + + def on_save_checkpoint(self, checkpoint): + # Save only LoRA parameters + state_dict = checkpoint["state_dict"] + use_lora = any("lora" in name for name in state_dict.keys()) + if not use_lora: + return + + for name in list(state_dict.keys()): + if "lora" not in name: + state_dict.pop(name) + + def configure_optimizers(self) -> OptimizerLRScheduler: + # Get weight decay parameters + weight_decay_parameters, other_parameters = [], [] + for name, param in self.named_parameters(): + if ".bias" in name or "norm.weight" in name or ".embeddings." in name: + other_parameters.append(param) + else: + weight_decay_parameters.append(param) + + optimizer = self.optimizer_builder( + [ + {"params": weight_decay_parameters}, + {"params": other_parameters, "weight_decay": 0.0}, + ] + ) + + # Print the parameters and their weight decay + for i in optimizer.param_groups: + log.info( + f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" + ) + + lr_scheduler = self.lr_scheduler_builder(optimizer) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + }, + } + + # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + assert logits.shape[:-1] == labels.shape + + labels = labels.clone() + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def _step(self, batch, batch_idx, stage: str): + is_train = stage == "train" + + if is_train: + # Key part to make lora work + # Otherwise the parameters are merged, which lead to incorrect gradients + self.model.train() + + # Do positive and negative samples in the same batch to speed up training + labels = batch["labels"] + outputs = self.model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.view(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + ) + + codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.view(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + ) + + loss = base_loss + semantic_loss + + self.log( + f"{stage}/loss", + loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/base_loss", + base_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/semantic_loss", + semantic_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + # Top-5 accuracy + accuracy = self.get_accuracy(codebook_logits, codebook_labels) + self.log( + f"{stage}/top_5_accuracy", + accuracy, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + return loss + + def get_accuracy(self, logits, labels): + mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) + if mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + _, indices = logits.topk(5, dim=-1) + correct = indices.eq(labels.unsqueeze(-1)) + correct[~mask] = 0 + correct = correct.sum() + accuracy = correct / mask.sum() + + return accuracy + + def training_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "val") diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5cd276c0c382a3334c45ca9bf74ea1c8a142d5 --- /dev/null +++ b/fish_speech/models/text2semantic/llama.py @@ -0,0 +1,752 @@ +import json +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from loguru import logger +from torch import Tensor +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.checkpoint import checkpoint +from transformers import AutoTokenizer + +from fish_speech.conversation import SEMANTIC_TOKEN +from fish_speech.utils import RankedLogger + +from .lora import LoraConfig, setup_lora + +log = RankedLogger(__name__, 
rank_zero_only=True) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class BaseModelArgs: + model_type: str = "base" + + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + max_seq_len: int = 2048 + dropout: float = 0.0 + tie_word_embeddings: bool = True + attention_qkv_bias: bool = False + + # Codebook configs + codebook_size: int = 160 + num_codebooks: int = 4 + + # Gradient checkpointing + use_gradient_checkpointing: bool = True + + # Initialize the model + initializer_range: float = 0.02 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + if path.is_dir(): + path = path / "config.json" + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + match data["model_type"]: + case "naive": + cls = NaiveModelArgs + case "dual_ar": + cls = DualARModelArgs + case _: + raise ValueError(f"Unknown model type: {data['model_type']}") + + return cls(**data) + + def save(self, path: str): + with open(path, "w") as f: + json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) + + +@dataclass +class NaiveModelArgs(BaseModelArgs): + model_type: str = "naive" + + +@dataclass +class DualARModelArgs(BaseModelArgs): + model_type: str = "dual_ar" + n_fast_layer: int = 4 + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +@dataclass +class TransformerForwardResult: + token_logits: Tensor + codebook_logits: Tensor + + +@dataclass +class BaseTransformerForwardResult: + logits: Tensor + hidden_states: Tensor + + +class BaseTransformer(nn.Module): + def __init__( + self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer + + self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN) + + # Slow transformer + self.embeddings = nn.Embedding( + config.vocab_size, + config.dim, + ) + self.codebook_embeddings = nn.Embedding( + config.codebook_size * config.num_codebooks, + config.dim, + ) + self.layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + if self.config.tie_word_embeddings is False: + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + config.max_seq_len, + config.dim // config.n_head, + config.rope_base, + ), + persistent=False, + ) + 
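+ # freqs_cis caches the rotary cos/sin pairs as a (max_seq_len, head_dim // 2, 2)
+ # bfloat16 tensor; apply_rotary_emb() consumes it later to rotate q and k.
+ # Illustrative shapes (not executed): precompute_freqs_cis(8, 64) returns a cache
+ # of shape (8, 32, 2), and for q of shape (B, 8, H, 64), apply_rotary_emb(q, cache)
+ # keeps that shape.
+ # The buffer registered below is a lower-triangular boolean matrix in which
+ # True means "may attend"; forward() combines it with the inverted key_padding_mask.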
self.register_buffer( + "causal_mask", + torch.tril( + torch.ones( + config.max_seq_len, + config.max_seq_len, + dtype=torch.bool, + ) + ), + persistent=False, + ) + + # For kv cache + self.max_batch_size = -1 + self.max_seq_len = -1 + + if init_weights: + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: + return + + head_dim = self.config.dim // self.config.n_head + max_seq_len = find_multiple(max_seq_len, 8) + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_len, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def embed(self, x: Tensor) -> Tensor: + vocab_embeds = [self.embeddings(x[:, 0])] + for i in range(self.config.num_codebooks): + emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) + emb[x[:, 0] != self.semantic_token_id] = 0 + vocab_embeds.append(emb) + + x = torch.stack(vocab_embeds, dim=3) + x = x.sum(dim=3) + + return x + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> BaseTransformerForwardResult: + seq_len = inp.size(2) + + # Here we want to merge the embeddings of the codebooks + x = self.embed(inp) + + freqs_cis = self.freqs_cis[:seq_len] + + # Not that the causal mask here follows the definition of scaled_dot_product_attention + # That is, FALSE means masked out + # To maintain consistency, key_padding_mask use TRUE to mask out + mask = None + if key_padding_mask is not None: + mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) + mask = mask & key_padding_mask[:, None, None, :].logical_not() + + for layer in self.layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) + else: + x = layer(x, freqs_cis, mask) + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def forward_generate( + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + return_all: bool = False, + ) -> BaseTransformerForwardResult: + # This is used for generation, optimized for torch compile + assert ( + self.max_seq_len != -1 and self.max_batch_size != -1 + ), "Please call setup_caches before forward_generate" + + x = self.embed(x) + + mask = self.causal_mask[ + None, None, input_pos, : self.max_seq_len + ] # (B, N, Q, K) + freqs_cis = self.freqs_cis[input_pos] + + for layer in self.layers: + x = layer(x, freqs_cis, mask, input_pos=input_pos) + + # If prefill, we only calculate the logits of last token + if x.size(1) > 1 and not return_all: + x = x[:, -1:] + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + 
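+ # Embedding weights use the same N(0, initializer_range) init as Linear layers;
+ # the padding_idx row (if any) is re-zeroed right after.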
module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @staticmethod + def from_pretrained( + path: str, + load_weights: bool = False, + max_length: int | None = None, + lora_config: LoraConfig | None = None, + rope_base: int | None = None, + ) -> "BaseTransformer": + config = BaseModelArgs.from_pretrained(str(path)) + if max_length is not None: + config.max_seq_len = max_length + log.info(f"Override max_seq_len to {max_length}") + + if rope_base is not None: + config.rope_base = rope_base + log.info(f"Override rope_base to {rope_base}") + + match config.model_type: + case "naive": + model_cls = NaiveTransformer + case "dual_ar": + model_cls = DualARTransformer + case _: + raise ValueError(f"Unknown model type: {config.model_type}") + + tokenizer = AutoTokenizer.from_pretrained(str(path)) + log.info(f"Loading model from {path}, config: {config}") + model = model_cls(config, tokenizer=tokenizer) + + if lora_config is not None: + setup_lora(model, lora_config) + log.info(f"LoRA setup: {lora_config}") + + if load_weights is False: + log.info("Randomly initialized model") + else: + + if "int8" in str(Path(path)): + logger.info("Using int8 weight-only quantization!") + from tools.llama.quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(Path(path)): + logger.info("Using int4 quantization!") + path_comps = path.name.split("-") + assert path_comps[-2].startswith("g") + groupsize = int(path_comps[-2][1:]) + from tools.llama.quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + weights = torch.load( + Path(path) / "model.pth", map_location="cpu", mmap=True + ) + err = model.load_state_dict(weights, strict=False, assign=True) + log.info(f"Loaded weights with error: {err}") + + return model + + def save_pretrained(self, path: str, drop_lora: bool = False): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + self.config.save(path / "config.json") + state_dict = self.state_dict() + + if drop_lora: + for key in list(state_dict.keys()): + if "lora" not in key: + continue + + state_dict.pop(key) + log.info(f"Drop LoRA parameter: {key}") + + torch.save(state_dict, path / "model.pth") + self.tokenizer.save_pretrained(path) + + +class NaiveTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.codebook_output = nn.Linear( + config.dim, + config.codebook_size * config.num_codebooks, + bias=False, + ) + + self.apply(self._init_weights) + + def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: + token_logits = result.logits + x = result.hidden_states + + # Codebook + codebook_logits = self.codebook_output(self.codebook_norm(x)) + codebook_logits = rearrange( + codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + result = super().forward( + inp=inp, + key_padding_mask=key_padding_mask, + ) + return self.decode(result) + + def 
forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + result = super().forward_generate(x, input_pos) + return self.decode(result) + + +class DualARTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + # Fast transformer + self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) + + # The equivalent bs is so large that sdpa doesn't work + self.fast_layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) + ) + self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.fast_output = nn.Linear( + config.dim, + config.codebook_size, + bias=False, + ) + + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + super().setup_caches(max_batch_size, max_seq_len, dtype) + + head_dim = self.config.dim // self.config.n_head + + # Fast transformer + # The max seq len here is the number of codebooks + for b in self.fast_layers: + b.attention.kv_cache = KVCache( + max_batch_size, + self.config.num_codebooks, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + parent_result = super().forward(inp, key_padding_mask) + token_logits = parent_result.logits + x = parent_result.hidden_states + + # Fast transformer + fast_seq_len = self.config.num_codebooks + fast_mask = self.causal_mask[ + None, None, :fast_seq_len, :fast_seq_len + ] # (B, N, Q, K) + fast_freqs_cis = self.freqs_cis[:fast_seq_len] + + # Drop the last token and rotate left + codebooks = inp[:, 1:-1, 1:] + codebooks = F.pad(codebooks, (0, 1), value=0) + codebook_embeddings = self.fast_embeddings(codebooks) + x = torch.cat([x[:, None], codebook_embeddings], dim=1) + b, s = x.size(0), x.size(2) + x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len + + # Remove padded part + codebooks = rearrange(codebooks, "b n s -> (b s) n") + codebook_mask = (codebooks == 0).all(dim=-1) + + if torch.all(codebook_mask): + # If all codebooks are padded, we keep first 8 to make sure the model runs + codebook_mask[:8] = False + + x_bs, x_len = x.size(0), x.size(1) + x = x[~codebook_mask] + + for layer in self.fast_layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) + else: + x = layer(x, fast_freqs_cis, fast_mask) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) + codebook_logits = self.fast_output(fast_out) + + # Re-pad the codebook_logits + buffer = torch.zeros( + x_bs, + x_len, + codebook_logits.size(-1), + device=codebook_logits.device, + dtype=codebook_logits.dtype, + ) + buffer[~codebook_mask] = codebook_logits + codebook_logits = buffer + + assert codebook_logits.shape[1] == self.config.num_codebooks + codebook_logits = rearrange( + codebook_logits, + "(b s) n d -> b s n d", + b=b, + s=s, + n=self.config.num_codebooks, + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward_generate_fast( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> Tensor: + # Fast transformer + x = x.view(1, 1, -1) + + fast_mask = self.causal_mask[ + None, None, input_pos, : self.config.num_codebooks + ] # (B, N, Q, K) 
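+ # In the fast (per-codebook) decoding loop the "sequence" axis is the codebook
+ # index, so input_pos, the causal slice above, and the fast KV cache are all
+ # bounded by num_codebooks rather than max_seq_len.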
+ fast_freqs_cis = self.freqs_cis[input_pos] + + for layer in self.fast_layers: + x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) # only take the last token + codebook_logits = self.fast_output(fast_out) + + return codebook_logits + + +class TransformerBlock(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: + super().__init__() + self.attention = Attention(config, use_sdpa=use_sdpa) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear( + config.dim, total_head_dim, bias=config.attention_qkv_bias + ) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.dropout = config.dropout + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.use_sdpa = use_sdpa + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.use_sdpa: + if mask is None: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + # No third party attn_mask here to use flash_attention + ) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + else: + y = self.eq_scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + return self.wo(y) + + def eq_scaled_dot_product_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + ) -> torch.Tensor: + # This is a standard scaled dot product 
attention + # It's low efficient, but it doesn't raise cuda error + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + +class FeedForward(nn.Module): + def __init__(self, config: BaseModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938 --- /dev/null +++ b/fish_speech/models/text2semantic/lora.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass + +import loralib as lora + + +@dataclass +class LoraConfig: + r: int + lora_alpha: float + lora_dropout: float = 0.0 + + +def setup_lora(model, lora_config): + # Replace the embedding layer with a LoRA layer + model.embeddings = lora.Embedding( + num_embeddings=model.embeddings.num_embeddings, + embedding_dim=model.embeddings.embedding_dim, + padding_idx=model.embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + model.codebook_embeddings = lora.Embedding( + num_embeddings=model.codebook_embeddings.num_embeddings, + embedding_dim=model.codebook_embeddings.embedding_dim, + padding_idx=model.codebook_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Replace output layer with a LoRA layer + linears = [(model, "output")] + + # Replace all linear layers with LoRA layers + for layer in model.layers: + linears.extend([(layer.attention, "wqkv"), 
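# --- Illustrative aside (not part of the patch above; sizes are made up) ---------------
# precompute_freqs_cis() caches cos/sin pairs and apply_rotary_emb() is complex
# multiplication written out in real arithmetic. A quick self-check, assuming the two
# functions defined in this file are in scope:
import torch

def rotate_as_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Same rotation via torch's complex dtype instead of the hand-written formula.
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    fc = torch.view_as_complex(freqs_cis.float().contiguous())[None, :, None, :]
    return torch.view_as_real(xc * fc).flatten(3).type_as(x)

b, s, heads, head_dim = 1, 5, 2, 8
q = torch.randn(b, s, heads, head_dim)
freqs = precompute_freqs_cis(s, head_dim).float()   # (s, head_dim // 2, 2)
print(torch.allclose(apply_rotary_emb(q, freqs), rotate_as_complex(q, freqs), atol=1e-5))
# Expected to print True.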
(layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + if hasattr(model, "fast_layers"): + model.fast_embeddings = lora.Embedding( + num_embeddings=model.fast_embeddings.num_embeddings, + embedding_dim=model.fast_embeddings.embedding_dim, + padding_idx=model.fast_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Dual-AR model + linears.append((model, "fast_output")) + + for layer in model.fast_layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + for module, layer in linears: + updated_linear = lora.Linear( + in_features=getattr(module, layer).in_features, + out_features=getattr(module, layer).out_features, + bias=getattr(module, layer).bias, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + setattr(module, layer, updated_linear) + + # Mark only the LoRA layers as trainable + lora.mark_only_lora_as_trainable(model, bias="none") + + +def get_merged_state_dict(model): + # This line will merge the state dict of the model and the LoRA parameters + model.eval() + + # Then we need to remove the LoRA parameters from the state dict + state_dict = model.state_dict() + for name in list(state_dict.keys()): + if "lora" in name: + state_dict.pop(name) + + return state_dict diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..816aebc6ece20e81ff3e60c00cd37e3b851db9cc Binary files /dev/null and b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..011bdcaf170c195fd9f89d3e4c4c50df7ef44da1 Binary files /dev/null and b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc differ diff --git a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba2acfe4938fe71e64952119373c2b118a520356 Binary files /dev/null and b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc differ diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py new file mode 100644 index 0000000000000000000000000000000000000000..91fc9118cc26f4d99171e7db3ee871071a7a296a --- /dev/null +++ b/fish_speech/models/vqgan/modules/firefly.py @@ -0,0 +1,596 @@ +import math +from functools import partial +from math import prod +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from torch.utils.checkpoint import checkpoint + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, 
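# --- Illustrative aside (not part of the patch above; the toy module is made up) -------
# setup_lora() swaps the model's embeddings/linears for loralib layers and freezes
# everything else; get_merged_state_dict() relies on eval() folding the LoRA delta into
# the base weight before the lora_* tensors are dropped. The same mechanics on a toy:
import loralib as lora
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16, bias=False)

toy = Toy()
# What setup_lora does for wqkv / wo / w1 / w2 / w3 / output, for a single layer:
toy.proj = lora.Linear(in_features=16, out_features=16, bias=False, r=4, lora_alpha=8)
lora.mark_only_lora_as_trainable(toy, bias="none")
print([n for n, p in toy.named_parameters() if p.requires_grad])  # only lora_A / lora_B

# What get_merged_state_dict does: merge on eval(), then drop the adapter tensors.
toy.eval()
merged = {k: v for k, v in toy.state_dict().items() if "lora" not in k}
print(list(merged.keys()))  # ['proj.weight']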
dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv1D") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +def unpad1d(x: torch.Tensor, paddings: tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "zeros", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right + before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class FishConvNet(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1 + ): + super(FishConvNet, self).__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + + def forward(self, x): + pad = self.kernel_size - self.stride + extra_padding = get_extra_padding_for_conv1d( + x, self.kernel_size, self.stride, pad + ) + x = pad1d(x, (pad, extra_padding), mode="constant", value=0) + return self.conv(x).contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) + return self + + +class FishTransConvNet(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1): + super(FishTransConvNet, self).__init__() + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride=stride, dilation=dilation + ) + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + x = self.conv(x) + pad = self.kernel_size - self.stride + padding_right = math.ceil(pad) + padding_left = pad - padding_right + x = unpad1d(x, (padding_left, padding_right)) + return x.contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) + return 
self + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.silu(x) + xt = c1(xt) + xt = F.silu(xt) + xt = c2(xt) + x = xt + x + return x + + def remove_parametrizations(self): + for conv in self.convs1: + conv.remove_parametrizations() + for conv in self.convs2: + conv.remove_parametrizations() + + +class ParallelBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_sizes: tuple[int] = (3, 7, 11), + dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + ): + super().__init__() + + assert len(kernel_sizes) == len(dilation_sizes) + + self.blocks = nn.ModuleList() + for k, d in zip(kernel_sizes, dilation_sizes): + self.blocks.append(ResBlock1(channels, k, d)) + + def forward(self, x): + return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) + + def remove_parametrizations(self): + for block in self.blocks: + block.remove_parametrizations() + + +class HiFiGANGenerator(nn.Module): + def __init__( + self, + *, + hop_length: int = 512, + upsample_rates: tuple[int] = (8, 8, 2, 2, 2), + upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), + resblock_kernel_sizes: tuple[int] = (3, 7, 11), + resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + num_mels: int = 128, + upsample_initial_channel: int = 512, + pre_conv_kernel_size: int = 7, + post_conv_kernel_size: int = 7, + post_activation: Callable = partial(nn.SiLU, inplace=True), + ): + super().__init__() + + assert ( + prod(upsample_rates) == hop_length + ), f"hop_length must be {prod(upsample_rates)}" + + self.conv_pre = FishConvNet( + num_mels, + upsample_initial_channel, + pre_conv_kernel_size, + stride=1, + ).weight_norm() + + self.num_upsamples = len(upsample_rates) + self.num_kernels = len(resblock_kernel_sizes) + + self.noise_convs = nn.ModuleList() + self.ups = nn.ModuleList() + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + FishTransConvNet( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + stride=u, + ).weight_norm() + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.resblocks.append( + ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) + ) + + self.activation_post = post_activation() + self.conv_post = FishConvNet( + ch, 1, post_conv_kernel_size, stride=1 + ).weight_norm() + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.silu(x, inplace=True) + x = self.ups[i](x) + + 
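# --- Illustrative aside (not part of the patch above; sizes are made up) ---------------
# FishConvNet left-pads (plus a small right pad) so a strided conv yields
# ceil(length / stride) frames, and FishTransConvNet trims its transposed conv back to
# exactly length * stride. Quick check, assuming both classes from this file are in scope:
import torch

down = FishConvNet(in_channels=1, out_channels=1, kernel_size=7, stride=2)
up = FishTransConvNet(in_channels=1, out_channels=1, kernel_size=4, stride=2)

x = torch.randn(1, 1, 101)
y = down(x)
print(y.shape)       # torch.Size([1, 1, 51])  == ceil(101 / 2)
print(up(y).shape)   # torch.Size([1, 1, 102]) == 51 * 2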
if self.training and self.checkpointing: + x = checkpoint( + self.resblocks[i], + x, + use_reentrant=False, + ) + else: + x = self.resblocks[i](x) + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_parametrizations(self): + for up in self.ups: + up.remove_parametrizations() + for block in self.resblocks: + block.remove_parametrizations() + self.conv_pre.remove_parametrizations() + self.conv_post.remove_parametrizations() + + +# DropPath copied from timm library +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ # noqa: E501 + + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ # noqa: E501 + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None] * x + self.bias[:, None] + return x + + +# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. 
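# --- Illustrative aside (not part of the patch above; sizes are made up) ---------------
# With the default upsample_rates (8, 8, 2, 2, 2), HiFiGANGenerator turns each mel frame
# into hop_length = 8*8*2*2*2 = 512 samples. eval() keeps forward() out of the
# `self.training and self.checkpointing` branch:
import torch

gen = HiFiGANGenerator().eval()     # defaults: num_mels=128, hop_length=512
mels = torch.randn(1, 128, 10)      # (batch, num_mels, frames)
with torch.no_grad():
    audio = gen(mels)
print(audio.shape)                  # torch.Size([1, 1, 5120]) == 10 frames * 512 samples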
There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + kernel_size (int): Kernel size for depthwise conv. Default: 7. + dilation (int): Dilation for depthwise conv. Default: 1. + """ # noqa: E501 + + def __init__( + self, + dim: int, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6, + mlp_ratio: float = 4.0, + kernel_size: int = 7, + dilation: int = 1, + ): + super().__init__() + + self.dwconv = FishConvNet( + dim, + dim, + kernel_size=kernel_size, + # padding=int(dilation * (kernel_size - 1) / 2), + groups=dim, + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, int(mlp_ratio * dim) + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, apply_residual: bool = True): + input = x + + x = self.dwconv(x) + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma * x + + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) + x = self.drop_path(x) + + if apply_residual: + x = input + x + + return x + + +class ConvNeXtEncoder(nn.Module): + def __init__( + self, + input_channels: int = 3, + depths: list[int] = [3, 3, 9, 3], + dims: list[int] = [96, 192, 384, 768], + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + kernel_size: int = 7, + ): + super().__init__() + assert len(depths) == len(dims) + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + FishConvNet( + input_channels, + dims[0], + kernel_size=7, + # padding=3, + # padding_mode="replicate", + # padding_mode="zeros", + ), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + + for i in range(len(depths) - 1): + mid_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), + ) + self.downsample_layers.append(mid_layer) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + cur = 0 + for i in range(len(depths)): + stage = nn.Sequential( + *[ + ConvNeXtBlock( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + kernel_size=kernel_size, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + for i in 
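# --- Illustrative aside (not part of the patch above; sizes are made up) ---------------
# ConvNeXtEncoder only mixes channels: the depthwise and 1x1 convs all run at stride 1,
# so the frame count is unchanged. Assuming ConvNeXtEncoder from this file is in scope:
import torch

enc = ConvNeXtEncoder(input_channels=128, depths=[3, 3, 9, 3], dims=[128, 256, 384, 512])
mels = torch.randn(2, 128, 50)      # (batch, num_mels, frames)
print(enc(mels).shape)              # torch.Size([2, 512, 50])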
range(len(self.downsample_layers)): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + + return self.norm(x) + + +class FireflyArchitecture(nn.Module): + def __init__( + self, + backbone: nn.Module, + head: nn.Module, + quantizer: nn.Module, + spec_transform: nn.Module, + ): + super().__init__() + + self.backbone = backbone + self.head = head + self.quantizer = quantizer + self.spec_transform = spec_transform + self.downsample_factor = math.prod(self.quantizer.downsample_factor) + + def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor: + if self.spec_transform is not None: + x = self.spec_transform(x) + + x = self.backbone(x) + if mask is not None: + x = x * mask + + if self.quantizer is not None: + vq_result = self.quantizer(x) + x = vq_result.z + + if mask is not None: + x = x * mask + + x = self.head(x, template=template) + + if x.ndim == 2: + x = x[:, None, :] + + if self.vq is not None: + return x, vq_result + + return x + + def encode(self, audios, audio_lengths): + audios = audios.float() + + mels = self.spec_transform(audios) + mel_lengths = audio_lengths // self.spec_transform.hop_length + mel_masks = sequence_mask(mel_lengths, mels.shape[2]) + mel_masks_float_conv = mel_masks[:, None, :].float() + mels = mels * mel_masks_float_conv + + # Encode + encoded_features = self.backbone(mels) * mel_masks_float_conv + feature_lengths = mel_lengths // self.downsample_factor + + return self.quantizer.encode(encoded_features), feature_lengths + + def decode(self, indices, feature_lengths) -> torch.Tensor: + mel_masks = sequence_mask( + feature_lengths * self.downsample_factor, + indices.shape[2] * self.downsample_factor, + ) + mel_masks_float_conv = mel_masks[:, None, :].float() + audio_lengths = ( + feature_lengths * self.downsample_factor * self.spec_transform.hop_length + ) + + audio_masks = sequence_mask( + audio_lengths, + indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length, + ) + audio_masks_float_conv = audio_masks[:, None, :].float() + + z = self.quantizer.decode(indices) * mel_masks_float_conv + x = self.head(z) * audio_masks_float_conv + + return x, audio_lengths + + def remove_parametrizations(self): + if hasattr(self.backbone, "remove_parametrizations"): + self.backbone.remove_parametrizations() + + if hasattr(self.head, "remove_parametrizations"): + self.head.remove_parametrizations() + + @property + def device(self): + return next(self.parameters()).device diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea4853376b6e663404ff48d6c6b5f664dde4094 --- /dev/null +++ b/fish_speech/models/vqgan/modules/fsq.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from vector_quantize_pytorch import GroupedResidualFSQ + +from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet + + +@dataclass +class FSQResult: + z: torch.Tensor + codes: torch.Tensor + latents: torch.Tensor + + +class DownsampleFiniteScalarQuantize(nn.Module): + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + n_groups: int = 1, + levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 + downsample_factor: tuple[int] = (2, 2), + downsample_dims: tuple[int] | None = None, + ): + super().__init__() + + if downsample_dims is None: + downsample_dims = [input_dim for _ in range(len(downsample_factor))] + + all_dims = 
(input_dim,) + tuple(downsample_dims) + + self.residual_fsq = GroupedResidualFSQ( + dim=all_dims[-1], + levels=levels, + num_quantizers=n_codebooks, + groups=n_groups, + ) + + self.downsample_factor = downsample_factor + self.downsample_dims = downsample_dims + + self.downsample = nn.Sequential( + *[ + nn.Sequential( + FishConvNet( + all_dims[idx], + all_dims[idx + 1], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx + 1]), + ) + for idx, factor in enumerate(downsample_factor) + ] + ) + + self.upsample = nn.Sequential( + *[ + nn.Sequential( + FishTransConvNet( + all_dims[idx + 1], + all_dims[idx], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx]), + ) + for idx, factor in reversed(list(enumerate(downsample_factor))) + ] + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, z) -> FSQResult: + original_shape = z.shape + z = self.downsample(z) + quantized, indices = self.residual_fsq(z.mT) + result = FSQResult( + z=quantized.mT, + codes=indices.mT, + latents=z, + ) + result.z = self.upsample(result.z) + + # Pad or crop z to match original shape + diff = original_shape[-1] - result.z.shape[-1] + left = diff // 2 + right = diff - left + + if diff > 0: + result.z = F.pad(result.z, (left, right)) + elif diff < 0: + result.z = result.z[..., left:-right] + + return result + + def encode(self, z): + z = self.downsample(z) + _, indices = self.residual_fsq(z.mT) + indices = rearrange(indices, "g b l r -> b (g r) l") + return indices + + def decode(self, indices: torch.Tensor): + indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) + z_q = self.residual_fsq.get_output_from_indices(indices) + z_q = self.upsample(z_q.mT) + return z_q diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac --- /dev/null +++ b/fish_speech/models/vqgan/utils.py @@ -0,0 +1,94 @@ +import matplotlib +import torch +from matplotlib import pyplot as plt + +matplotlib.use("Agg") + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def plot_mel(data, titles=None): + fig, axes = plt.subplots(len(data), 1, squeeze=False) + + if titles is None: + titles = [None for i in range(len(data))] + + plt.tight_layout() + + for i in range(len(data)): + mel = data[i] + + if isinstance(mel, torch.Tensor): + mel = mel.float().detach().cpu().numpy() + + axes[i][0].imshow(mel, origin="lower") + axes[i][0].set_aspect(2.5, adjustable="box") + axes[i][0].set_ylim(0, mel.shape[0]) + axes[i][0].set_title(titles[i], fontsize="medium") + axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) + axes[i][0].set_anchor("W") + + return fig + + +def slice_segments(x, ids_str, segment_size=4): + ret = 
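# --- Illustrative aside (not part of the patch above; sizes are made up) ---------------
# Round trip through DownsampleFiniteScalarQuantize as defined in this file:
# downsample_factor (2, 2) shrinks the frame axis by 4, encode() emits
# n_groups * n_codebooks integer codes per remaining frame, and decode() upsamples back
# to the original resolution:
import torch

quantizer = DownsampleFiniteScalarQuantize(
    input_dim=512, n_codebooks=9, n_groups=1, levels=(8, 5, 5, 5), downsample_factor=(2, 2)
)
features = torch.randn(1, 512, 32)   # (batch, input_dim, frames)

indices = quantizer.encode(features)
print(indices.shape)                 # torch.Size([1, 9, 8])   -- 32 / (2 * 2) frames

recon = quantizer.decode(indices)
print(recon.shape)                   # torch.Size([1, 512, 32])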
torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(in_act, n_channels): + n_channels_int = n_channels[0] + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + + return acts + + +def avg_with_mask(x, mask): + assert mask.dtype == torch.float, "Mask should be float" + + if mask.ndim == 2: + mask = mask.unsqueeze(1) + + if mask.shape[1] == 1: + mask = mask.expand_as(x) + + return (x * mask).sum() / mask.sum() diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29 --- /dev/null +++ b/fish_speech/scheduler.py @@ -0,0 +1,40 @@ +import math + + +def get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int, + num_cycles: float = 0.5, + final_lr_ratio: float = 0.0, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + + return max( + final_lr_ratio, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + +def get_constant_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int | None = None, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + return 1.0 diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce --- /dev/null +++ b/fish_speech/text/__init__.py @@ -0,0 +1,4 @@ +from .clean import clean_text +from .spliter import split_text + +__all__ = ["clean_text", "split_text"] diff --git a/fish_speech/text/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee3c407f140e4719affed265a5c38ccdd391e280 Binary files /dev/null and b/fish_speech/text/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/text/__pycache__/__init__.cpython-311.pyc b/fish_speech/text/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e526318b27af352e84a843b8853b16c65c70e406 Binary files /dev/null and b/fish_speech/text/__pycache__/__init__.cpython-311.pyc differ diff --git a/fish_speech/text/__pycache__/clean.cpython-310.pyc b/fish_speech/text/__pycache__/clean.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70ef441e0d1522129bf92e40ddf8b00372e6312f Binary files /dev/null and 
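# --- Illustrative aside (not part of the patch above; optimizer and params are made up) -
# The schedule lambdas in fish_speech/scheduler.py plug straight into LambdaLR; a float
# num_warmup_steps in (0, 1) is read as a fraction of num_training_steps, and
# final_lr_ratio stops the cosine from decaying all the way to zero:
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.01,        # 1% of the total steps
    num_training_steps=100_000,
    final_lr_ratio=0.1,           # decay to 10% of the base LR, not to zero
)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())    # base LR scaled by the warmup ramp at step 5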
b/fish_speech/text/__pycache__/clean.cpython-310.pyc differ diff --git a/fish_speech/text/__pycache__/clean.cpython-311.pyc b/fish_speech/text/__pycache__/clean.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50e9eaa666868f9be7dea2da436442b03d001bf9 Binary files /dev/null and b/fish_speech/text/__pycache__/clean.cpython-311.pyc differ diff --git a/fish_speech/text/__pycache__/spliter.cpython-310.pyc b/fish_speech/text/__pycache__/spliter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01c21677d80136aaed94e853b3faa4ee43c2bda6 Binary files /dev/null and b/fish_speech/text/__pycache__/spliter.cpython-310.pyc differ diff --git a/fish_speech/text/__pycache__/spliter.cpython-311.pyc b/fish_speech/text/__pycache__/spliter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aa7e6e5ff356d530f44de2e763cb4815197926c Binary files /dev/null and b/fish_speech/text/__pycache__/spliter.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89 --- /dev/null +++ b/fish_speech/text/chn_text_norm/.gitignore @@ -0,0 +1,114 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# JetBrains PyCharm +.idea + +# Customize +references +url.txt + +# Git +.git diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6 --- /dev/null +++ b/fish_speech/text/chn_text_norm/README.md @@ -0,0 +1,36 @@ +# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. + +# Chn Text Norm + +this is a repository for chinese text normalization (no longer maintained). + +## Quick Start ## + +### Git Clone Repo ### + +git clone this repo to the root directory of your project which need to use it. + + cd /path/to/proj + git clone https://github.com/Joee1995/chn-text-norm.git + +after that, your doc tree should be: +``` +proj # root of your project +|--- chn_text_norm # this chn-text-norm tool + |--- text.py + |--- ... 
+|--- text_normalize.py # your text normalization code +|--- ... +``` + +### How to Use ? ### + + # text_normalize.py + from chn_text_norm.text import * + + raw_text = 'your raw text' + text = Text(raw_text=raw_text).normalize() + +### How to add quantums ### + +打开test.py,然后你就知道怎么做了。 diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf43a50e99266da3a54e51615d8dc370008fd748 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3da0f8c0604b6b72c65fea53dd7e4566b5d3cca5 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4dfdb015766efc309ead7be2cc40690445f7719 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61ea440e6de5ab546e5723513e18a17f2e08d50 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77d8601c1463006fde09a0e5959dfa46bc6450ee Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..510b5d1edbbe9f8d35cf484a7b605077c1a4268b Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daf7f728bae1ae58904e6ecfe6cb6a83b7dc20db Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ade7c3a6fe8a8203d58e0cd2dddd76c4142dec4 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-311.pyc differ diff --git 
a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892bfe68b3cf4287aaf5b875aec92bb2dfc917c6 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05fb51069de415b89cd3bcd567949e0c71ba3b4a Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d09d9f6f67baf2c792ce64c8250f7754d10b0e23 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dd12ae4851e5b97088d7b084fa395f2270cbafd Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ff89ce017f6aba2ea60cd4418c44ae9aa92cc6 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61773b54c1f390144c050c2d6edc523cea861272 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4e3ef77729e1372a5658809e4b04b0bc61a6db1 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6571976f0f4c020136f2fd6023148a546ea5a87 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7f74b20554c6c9a93718ec5378dfa7dcc0fb340 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dec683ad68c44eaf2c699cad4866870db617c375 Binary files /dev/null and 
b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ba258f5f8632961d8095de91c29be3be1d7a032 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a0e5b4fcb77f1f8779e5094554754ddf9dd9d6 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8b4cbd724f19dd63970f717eaf869181cd4f5d4 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e4b471b4dcd4644cb0d264dfb701ca09cc5244 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66b3817e2446e570067fe4cd1ccad0b9692134cd Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-311.pyc b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb01e18f4f23dd14cd5c4872f96841d775e9a868 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-311.pyc differ diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py new file mode 100644 index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761 --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_class.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +"""基本类 +中文字符类 +中文数字/数位类 +中文数字类 +中文数位类 +中文数字系统类 +中文数学符号类 +*中文其他符号类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES + + +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. 
'陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit( + power=index + 1, + simplified=value[0], + traditional=value[1], + big_s=value[1], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + else: + raise ValueError( + "Counting type should be in {0} ({1} provided).".format( + NUMBERING_TYPES, numbering_type + ) + ) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__( + self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None + ): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_constant.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +"""基本常量 +中文数字/数位/符号字符常量 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_util.py @@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +"""基本方法 +创建中文数字系统 方法 +中文字符串 <=> 数字串 方法 +数字串 <=> 中文字符串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_class import * +from fish_speech.text.chn_text_norm.basic_constant import * + + +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. 
+ 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + larger_units = [ + CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) + ] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + smaller_units = [ + CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) + ] + # digis + chinese_digis = zip( + CHINESE_DIGIS, + CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, + BIG_CHINESE_DIGIS_TRADITIONAL, + ) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [ + d.traditional, + d.simplified, + d.big_s, + d.big_t, + d.alt_s, + d.alt_t, + ]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [ + get_symbol(c, system) for c in dec_string + ] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance( + integer_symbols[-2], CNU + ): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None) + ) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if ( + isinstance(result[-i - 1], CNU) + and result[-i - 1].power < current_unit.power + ): + result[-i - 1] = CNU( + result[-i - 1].power + current_unit.power, + None, + None, + None, + None, + ) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next( + u for u in reversed(system.units) if u.power < len(striped_string) + ) + result_string = value_string[: -result_unit.power] + return ( + get_value(result_string) + + [result_unit] + + get_value(striped_string[-result_unit.power :]) + ) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string) + ) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND( + 2, + system.digits[2].alt_s, + system.digits[2].alt_t, + system.digits[2].big_s, + system.digits[2].big_t, + ) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = ( + result_symbols[i + 1] if i < len(result_symbols) - 1 else None + ) + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance( + previous_symbol, (CNU, type(None)) + ): + if next_symbol.power != 1 and ( + (previous_symbol is None) or (previous_symbol.power != 1) + ): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s + ) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s + ) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + 
result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] + in [ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], + ] + and result[0] + in [ + CHINESE_DIGIS[1], + BIG_CHINESE_DIGIS_SIMPLIFIED[1], + BIG_CHINESE_DIGIS_TRADITIONAL[1], + ] + ): + result = result[1:] + + return result + + +if __name__ == "__main__": + + # 测试程序 + all_chinese_number_string = ( + CHINESE_DIGIS + + BIG_CHINESE_DIGIS_SIMPLIFIED + + BIG_CHINESE_DIGIS_TRADITIONAL + + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL + + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL + + ZERO_ALT + + ONE_ALT + + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT) + ) + + print("num:", chn2num("一万零四百零三点八零五")) + print("num:", chn2num("一亿六点三")) + print("num:", chn2num("一亿零六点三")) + print("num:", chn2num("两千零一亿六点三")) + # print('num:', chn2num('一零零八六')) + print("txt:", num2chn("10260.03", alt_zero=True)) + print("txt:", num2chn("20037.090", numbering_type="low", traditional=True)) + print("txt:", num2chn("100860001.77", numbering_type="high", big=True)) + print( + "txt:", + num2chn( + "059523810880", + alt_one=True, + alt_two=False, + use_lzeros=True, + use_rzeros=True, + use_units=False, + ), + ) + + print(all_chinese_number_string) diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py new file mode 100644 index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616 --- /dev/null +++ b/fish_speech/text/chn_text_norm/cardinal.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""CARDINAL类 (包含小数DECIMAL类) +纯数 <=> 中文字符串 方法 +中文字符串 <=> 纯数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +if __name__ == "__main__": + + # 测试程序 + print(Cardinal(cardinal="21357.230").cardinal2chntext()) diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py new file mode 100644 index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e --- /dev/null +++ b/fish_speech/text/chn_text_norm/date.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +"""DATE类 +日期 <=> 中文字符串 方法 +中文字符串 <=> 日期 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-07" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.digit import Digit + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, 
other = date.strip().split("年", maxsplit=1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", maxsplit=1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +if __name__ == "__main__": + + # 测试 + print(Date(date="09年3月16日").date2chntext()) diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py new file mode 100644 index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185 --- /dev/null +++ b/fish_speech/text/chn_text_norm/digit.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""DIGIT类 +数字串 <=> 中文字符串 方法 +中文字符串 <=> 数字串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +if __name__ == "__main__": + + # 测试程序 + print(Digit(digit="2016").digit2chntext()) diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103 --- /dev/null +++ b/fish_speech/text/chn_text_norm/fraction.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +"""FRACTION类 +分数 <=> 中文字符串 方法 +中文字符串 <=> 分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +if __name__ == "__main__": + + # 测试程序 + print(Fraction(fraction="2135/7230").fraction2chntext()) + print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a --- /dev/null +++ b/fish_speech/text/chn_text_norm/money.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +"""MONEY类 +金钱 <=> 中文字符串 方法 +中文字符串 <=> 金钱 方法 +""" +import re + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-08" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() + ) + self.chntext = money + return self.chntext + + +if __name__ == "__main__": + + # 测试 + 
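# Expected behaviour (an illustrative note, not asserted here): money2chntext
+    # only rewrites the digit runs matched by the regex above, so "21.5万元"
+    # should become "二十一点五万元" while the currency words are kept as-is. +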
print(Money(money="21.5万元").money2chntext()) + print(Money(money="230块5毛").money2chntext()) diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py new file mode 100644 index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0 --- /dev/null +++ b/fish_speech/text/chn_text_norm/percentage.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +"""PERCENTAGE类 +百分数 <=> 中文字符串 方法 +中文字符串 <=> 百分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-06" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +if __name__ == "__main__": + + # 测试程序 + print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) + print(Percentage(percentage="65.3%").percentage2chntext()) diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff --- /dev/null +++ b/fish_speech/text/chn_text_norm/telephone.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +"""TELEPHONE类 +电话号码 <=> 中文字符串 方法 +中文字符串 <=> 电话号码 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +if __name__ == "__main__": + + # 测试程序 + print(TelePhone(telephone="0595-23980880").telephone2chntext()) + # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31 --- /dev/null +++ b/fish_speech/text/chn_text_norm/text.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" +TEXT类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +import re + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.date import Date +from fish_speech.text.chn_text_norm.digit import Digit +from fish_speech.text.chn_text_norm.fraction import Fraction +from fish_speech.text.chn_text_norm.money import Money +from fish_speech.text.chn_text_norm.percentage import Percentage +from fish_speech.text.chn_text_norm.telephone import TelePhone + +CURRENCY_NAMES = ( + "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" + 
"里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +) +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)" +) + + +class Text: + """ + Text类 + """ + + def __init__(self, raw_text, norm_text=None): + self.raw_text = "^" + raw_text + "$" + self.norm_text = norm_text + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + self.norm_text = text + return self.norm_text + + def normalize(self): + text = self.raw_text + + # 规范化日期 + pattern = re.compile( + r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile( + r"\D+((\d+(\.\d+)?)[多余几]?" + + CURRENCY_UNITS + + "(\d" + + CURRENCY_UNITS + + "?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace( + matcher[0], Money(money=matcher[0]).money2chntext(), 1 + ) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace( + matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 + ) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), + 1, + ) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace( + matcher, Fraction(fraction=matcher).fraction2chntext(), 1 + ) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace( + matcher[0], + Percentage(percentage=matcher[0]).percentage2chntext(), + 1, + ) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + self.norm_text = text + self._particular() + + return self.norm_text.lstrip("^").rstrip("$") + + +if __name__ == "__main__": + + # 测试程序 + print(Text(raw_text="固话:0595-23865596或23880880。").normalize()) + print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize()) + print(Text(raw_text="分数:32477/76391。").normalize()) + print(Text(raw_text="百分数:80.03%。").normalize()) + print(Text(raw_text="编号:31520181154418。").normalize()) + print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize()) + print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize()) + print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize()) + print(Text(raw_text="特殊:O2O或B2C。").normalize()) diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py new file mode 100644 index 0000000000000000000000000000000000000000..dbaf843d781f113735043319cc00dc2aed5ae382 --- /dev/null +++ b/fish_speech/text/clean.py @@ -0,0 +1,62 @@ +import re + +SYMBOLS_MAPPING = { + "\n": "", + "…": ".", + "“": "'", + "”": "'", + "‘": "'", + "’": "'", + "【": "", + "】": "", + "[": "", + "]": "", + "(": "", + ")": "", + "(": "", + ")": "", + "・": "", + "·": "", + "「": "'", + "」": "'", + "《": "'", + "》": "'", + "—": "", + "~": "", + "~": "", + ":": ",", + ";": ",", + ";": ",", + ":": ",", +} + +REPLACE_SYMBOL_REGEX = re.compile( + "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) +) + + +EMOJI_REGEX = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map symbols + "\U0001F1E0-\U0001F1FF" # flags (iOS) + "]+", + flags=re.UNICODE, +) + + +def clean_text(text): + # Clean the text + text = text.strip() + + # Replace all chinese symbols with their english counterparts + text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + + # Remove emojis + text = EMOJI_REGEX.sub(r"", text) + + # Remove continuous periods (...) 
and commas (,,,) + text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text) + + return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py new file mode 100644 index 0000000000000000000000000000000000000000..d4bb995487c4f53818c6b2a16cf0a886b4e02e84 --- /dev/null +++ b/fish_speech/text/spliter.py @@ -0,0 +1,130 @@ +import re +import string + +from fish_speech.text.clean import clean_text + + +def utf_8_len(text): + return len(text.encode("utf-8")) + + +def break_text(texts, length, splits: set): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if char in splits: + yield curr + curr = "" + + if curr: + yield curr + + +def break_text_by_length(texts, length): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if utf_8_len(curr) >= length: + yield curr + curr = "" + + if curr: + yield curr + + +def add_cleaned(curr, segments): + curr = curr.strip() + if curr and not all(c.isspace() or c in string.punctuation for c in curr): + segments.append(curr) + + +def protect_float(text): + # Turns 3.14 into <3_f_14> to prevent splitting + return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) + + +def unprotect_float(text): + # Turns <3_f_14> into 3.14 + return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) + + +def split_text(text, length): + text = clean_text(text) + + # Break the text into pieces with following rules: + # 1. Split the text at ".", "!", "?" if text is NOT a float + # 2. If the text is longer than length, split at "," + # 3. If the text is still longer than length, split at " " + # 4. If the text is still longer than length, split at any character to length + + texts = [text] + texts = map(protect_float, texts) + texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) + texts = map(unprotect_float, texts) + texts = break_text(texts, length, {",", ","}) + texts = break_text(texts, length, {" "}) + texts = list(break_text_by_length(texts, length)) + + # Then, merge the texts into segments with length <= length + segments = [] + curr = "" + + for text in texts: + if utf_8_len(curr) + utf_8_len(text) <= length: + curr += text + else: + add_cleaned(curr, segments) + curr = text + + if curr: + add_cleaned(curr, segments) + + return segments + + +if __name__ == "__main__": + # Test the split_text function + + text = "This is a test sentence. This is another test sentence. And a third one." + + assert split_text(text, 50) == [ + "This is a test sentence.", + "This is another test sentence. And a third one.", + ] + assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] + assert split_text(" ", 10) == [] + assert split_text("a", 10) == ["a"] + + text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." + assert split_text(text, 50) == [ + "This is a test sentence with only commas,", + "and no dots, and no exclamation marks,", + "and no question marks, and no newlines.", + ] + + text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." + # First half split at " ", second half split at "," + assert split_text(text, 50) == [ + "This is a test sentence This is a test sentence", + "This is a test sentence. 
This is a test sentence,", + "This is a test sentence, This is a test sentence.", + ] + + text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" + assert split_text(text, 50) == [ + "这是一段很长的中文文本,", + "而且没有句号,也没有感叹号,", + "也没有问号,也没有换行符.", + ] diff --git a/fish_speech/train.py b/fish_speech/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e693f3adc4dda787bdd587aec29f53355f2b1653 --- /dev/null +++ b/fish_speech/train.py @@ -0,0 +1,141 @@ +import os + +os.environ["USE_LIBUV"] = "0" +import sys +from typing import Optional + +import hydra +import lightning as L +import pyrootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies import DDPStrategy +from omegaconf import DictConfig, OmegaConf + +os.environ.pop("SLURM_NTASKS", None) +os.environ.pop("SLURM_JOB_NAME", None) +os.environ.pop("SLURM_NTASKS_PER_NODE", None) + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# Allow TF32 on Ampere GPUs +torch.set_float32_matmul_precision("high") +torch.backends.cudnn.allow_tf32 = True + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + +import fish_speech.utils as utils + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + Args: + cfg (DictConfig): Configuration composed by Hydra. + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 
+ """ # noqa: E501 + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=False) + + if cfg.get("deterministic"): + torch.use_deterministic_algorithms(True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + + ckpt_path = cfg.get("ckpt_path") + auto_resume = False + + resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) + if resume_ckpt_path is not None: + ckpt_path = resume_ckpt_path + auto_resume = True + + if ckpt_path is not None: + log.info(f"Resuming from checkpoint: {ckpt_path}") + + # resume weights only is disabled for auto-resume + if cfg.get("resume_weights_only") and auto_resume is False: + log.info("Resuming weights only!") + ckpt = torch.load(ckpt_path, map_location=model.device) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + err = model.load_state_dict(ckpt, strict=False) + log.info(f"Error loading state dict: {err}") + ckpt_path = None + + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! 
Using current weights for testing...") + ckpt_path = cfg.get("ckpt_path") + + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main( + version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" +) +def main(cfg: DictConfig) -> Optional[float]: + # train the model + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53cf2f23174ddac9bf523730aca2f6a9965d134a --- /dev/null +++ b/fish_speech/utils/__init__.py @@ -0,0 +1,24 @@ +from .braceexpand import braceexpand +from .context import autocast_exclude_mps +from .file import get_latest_checkpoint +from .instantiators import instantiate_callbacks, instantiate_loggers +from .logger import RankedLogger +from .logging_utils import log_hyperparameters +from .rich_utils import enforce_tags, print_config_tree +from .utils import extras, get_metric_value, set_seed, task_wrapper + +__all__ = [ + "enforce_tags", + "extras", + "get_metric_value", + "RankedLogger", + "instantiate_callbacks", + "instantiate_loggers", + "log_hyperparameters", + "print_config_tree", + "task_wrapper", + "braceexpand", + "get_latest_checkpoint", + "autocast_exclude_mps", + "set_seed", +] diff --git a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f898aea996460deb67220a89b96b9f9a3ac2dcac Binary files /dev/null and b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/__init__.cpython-311.pyc b/fish_speech/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47fcd33d7e78d173e979359d42f144aa89ca988 Binary files /dev/null and b/fish_speech/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06520901fe11707b16762007668ac1adc93e1695 Binary files /dev/null and b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-311.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7672cb7128109ee6369f9f188185994c89217c9f Binary files /dev/null and b/fish_speech/utils/__pycache__/braceexpand.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/context.cpython-310.pyc b/fish_speech/utils/__pycache__/context.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d2786f95f27a2000e27e18488cb831e495d1ada Binary files /dev/null and b/fish_speech/utils/__pycache__/context.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/context.cpython-311.pyc b/fish_speech/utils/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d746b5ee2210833a402bc90148f56bd016f473f Binary files /dev/null and b/fish_speech/utils/__pycache__/context.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/file.cpython-310.pyc 
b/fish_speech/utils/__pycache__/file.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..619c698c0dfe596250bdf0c78a7bed2b6525ba0e Binary files /dev/null and b/fish_speech/utils/__pycache__/file.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/file.cpython-311.pyc b/fish_speech/utils/__pycache__/file.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83d701dcaf2a70f951279fd7e90cd00dfc51324e Binary files /dev/null and b/fish_speech/utils/__pycache__/file.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd106557d47bc3e454d640c444155422d1b56bee Binary files /dev/null and b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-311.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b24a6485661e4be93452051b418c1d174bcf977c Binary files /dev/null and b/fish_speech/utils/__pycache__/instantiators.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/logger.cpython-310.pyc b/fish_speech/utils/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62ac6dd3335b1d7d3c3cf724a7b61185829be505 Binary files /dev/null and b/fish_speech/utils/__pycache__/logger.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/logger.cpython-311.pyc b/fish_speech/utils/__pycache__/logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39bb762688c649c20460f08a39a3371347de56c6 Binary files /dev/null and b/fish_speech/utils/__pycache__/logger.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e85c1b95ed30fa702a301660c5f1e7cb2b15c1e Binary files /dev/null and b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-311.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85167fd9df12676fe436fdfb1ee4f7d8bd9b335d Binary files /dev/null and b/fish_speech/utils/__pycache__/logging_utils.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fe9fe2cd1aa91cbe7126d363af68c556ae1167c Binary files /dev/null and b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-311.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..628da7662d28091856d1fcb0389c3db09ad8730c Binary files /dev/null and b/fish_speech/utils/__pycache__/rich_utils.cpython-311.pyc differ diff --git a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17ca82b458cd0e7fc97d2ad5c8be8b130cb6a4bd Binary files /dev/null and 
b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/utils.cpython-310.pyc b/fish_speech/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2809b72d4e8f82c71aaad3fe06835882c11c0f7e Binary files /dev/null and b/fish_speech/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/utils.cpython-311.pyc b/fish_speech/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eded4dfe6cbf55a90158c4a5188fbc38dee6491 Binary files /dev/null and b/fish_speech/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974 --- /dev/null +++ b/fish_speech/utils/braceexpand.py @@ -0,0 +1,217 @@ +""" +Bash-style brace expansion +Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py +License: MIT +""" + +import re +import string +from itertools import chain, product +from typing import Iterable, Iterator, Optional + +__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] + + +class UnbalancedBracesError(ValueError): + pass + + +alphabet = string.ascii_uppercase + string.ascii_lowercase + +int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") +char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") +escape_re = re.compile(r"\\(.)") + + +def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: + """braceexpand(pattern) -> iterator over generated strings + + Returns an iterator over the strings resulting from brace expansion + of pattern. This function implements Brace Expansion as described in + bash(1), with the following limitations: + + * A pattern containing unbalanced braces will raise an + UnbalancedBracesError exception. In bash, unbalanced braces will either + be partly expanded or ignored. + + * A mixed-case character range like '{Z..a}' or '{a..Z}' will not + include the characters '[]^_`' between 'Z' and 'a'. + + When escape is True (the default), characters in pattern can be + prefixed with a backslash to cause them not to be interpreted as + special characters for brace expansion (such as '{', '}', ','). + To pass through a a literal backslash, double it ('\\\\'). + + When escape is False, backslashes in pattern have no special + meaning and will be preserved in the output. + + Examples: + + >>> from braceexpand import braceexpand + + # Integer range + >>> list(braceexpand('item{1..3}')) + ['item1', 'item2', 'item3'] + + # Character range + >>> list(braceexpand('{a..c}')) + ['a', 'b', 'c'] + + # Sequence + >>> list(braceexpand('index.html{,.backup}')) + ['index.html', 'index.html.backup'] + + # Nested patterns + >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) + ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] + + # Prefixing an integer with zero causes all numbers to be padded to + # the same width. + >>> list(braceexpand('{07..10}')) + ['07', '08', '09', '10'] + + # An optional increment can be specified for ranges. + >>> list(braceexpand('{a..g..2}')) + ['a', 'c', 'e', 'g'] + + # Ranges can go in both directions. + >>> list(braceexpand('{4..1}')) + ['4', '3', '2', '1'] + + # Numbers can be negative + >>> list(braceexpand('{2..-1}')) + ['2', '1', '0', '-1'] + + # Unbalanced braces raise an exception. 
+ >>> list(braceexpand('{1{2,3}')) + Traceback (most recent call last): + ... + UnbalancedBracesError: Unbalanced braces: '{1{2,3}' + + # By default, the backslash is the escape character. + >>> list(braceexpand(r'{1\\{2,3}')) + ['1{2', '3'] + + # Setting 'escape' to False disables backslash escaping. + >>> list(braceexpand(r'\\{1,2}', escape=False)) + ['\\\\1', '\\\\2'] + + """ + return ( + escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) + ) + + +def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'pattern:', pattern + while pos < len(pattern): + if escape and pattern[pos] == "\\": + pos += 2 + continue + elif pattern[pos] == "{": + if bracketdepth == 0 and pos > start: + # print 'literal:', pattern[start:pos] + items.append([pattern[start:pos]]) + start = pos + bracketdepth += 1 + elif pattern[pos] == "}": + bracketdepth -= 1 + if bracketdepth == 0: + # print 'expression:', pattern[start+1:pos] + expr = pattern[start + 1 : pos] + item = parse_expression(expr, escape) + if item is None: # not a range or sequence + items.extend([["{"], parse_pattern(expr, escape), ["}"]]) + else: + items.append(item) + start = pos + 1 # skip the closing brace + pos += 1 + + if bracketdepth != 0: # unbalanced braces + raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) + + if start < pos: + items.append([pattern[start:]]) + + return ("".join(item) for item in product(*items)) + + +def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: + int_range_match = int_range_re.match(expr) + if int_range_match: + return make_int_range(*int_range_match.groups()) + + char_range_match = char_range_re.match(expr) + if char_range_match: + return make_char_range(*char_range_match.groups()) + + return parse_sequence(expr, escape) + + +def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: + # sequence -> chain(*sequence_items) + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'sequence:', seq + while pos < len(seq): + if escape and seq[pos] == "\\": + pos += 2 + continue + elif seq[pos] == "{": + bracketdepth += 1 + elif seq[pos] == "}": + bracketdepth -= 1 + elif seq[pos] == "," and bracketdepth == 0: + items.append(parse_pattern(seq[start:pos], escape)) + start = pos + 1 # skip the comma + pos += 1 + + if bracketdepth != 0: + raise UnbalancedBracesError + if not items: + return None + + # part after the last comma (may be the empty string) + items.append(parse_pattern(seq[start:], escape)) + return chain(*items) + + +def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: + if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): + padding = max(len(left), len(right)) + else: + padding = 0 + step = (int(incr) or 1) if incr else 1 + start = int(left) + end = int(right) + r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) + fmt = "%0{}d".format(padding) + return (fmt % i for i in r) + + +def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: + step = (int(incr) or 1) if incr else 1 + start = alphabet.index(left) + end = alphabet.index(right) + if start < end: + return alphabet[start : end + 1 : step] + else: + end = end or -len(alphabet) + return alphabet[start : end - 1 : -step] + + +if __name__ == "__main__": + import doctest + import sys + + failed, _ = 
doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) + if failed: + sys.exit(1) diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py new file mode 100644 index 0000000000000000000000000000000000000000..f04a99290ab32f7fe5b60656075a2d03af8468d6 --- /dev/null +++ b/fish_speech/utils/context.py @@ -0,0 +1,13 @@ +from contextlib import nullcontext + +import torch + + +def autocast_exclude_mps( + device_type: str, dtype: torch.dtype +) -> nullcontext | torch.autocast: + return ( + nullcontext() + if torch.backends.mps.is_available() + else torch.autocast(device_type, dtype) + ) diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..78c82640a963fa556657107729f7543d2e7c3510 --- /dev/null +++ b/fish_speech/utils/file.py @@ -0,0 +1,16 @@ +import os +from pathlib import Path + + +def get_latest_checkpoint(path: Path | str) -> Path | None: + # Find the latest checkpoint + ckpt_dir = Path(path) + + if ckpt_dir.exists() is False: + return None + + ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) + if len(ckpts) == 0: + return None + + return ckpts[-1] diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e --- /dev/null +++ b/fish_speech/utils/instantiators.py @@ -0,0 +1,50 @@ +from typing import List + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger + +from .logger import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! 
Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc --- /dev/null +++ b/fish_speech/utils/logger.py @@ -0,0 +1,55 @@ +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = True, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429 --- /dev/null +++ b/fish_speech/utils/logging_utils.py @@ -0,0 +1,48 @@ +from lightning.pytorch.utilities import rank_zero_only + +from fish_speech.utils import logger as log + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f --- /dev/null +++ b/fish_speech/utils/rich_utils.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from fish_speech.utils import logger as log + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ # noqa: E501 + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. " + + f"Skipping '{field}' config printing..." + ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. 
Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841 --- /dev/null +++ b/fish_speech/utils/spectrogram.py @@ -0,0 +1,122 @@ +import torch +import torchaudio.functional as F +from torch import Tensor, nn +from torchaudio.transforms import MelScale + + +class LinearSpectrogram(nn.Module): + def __init__( + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", + ): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.mode = mode + + self.register_buffer("window", torch.hann_window(win_length), persistent=False) + + def forward(self, y: Tensor) -> Tensor: + if y.ndim == 3: + y = y.squeeze(1) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + (self.win_length - self.hop_length) // 2, + (self.win_length - self.hop_length + 1) // 2, + ), + mode="reflect", + ).squeeze(1) + + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + spec = torch.view_as_real(spec) + + if self.mode == "pow2_sqrt": + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + return spec + + +class LogMelSpectrogram(nn.Module): + def __init__( + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max or float(sample_rate // 2) + + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + + fb = F.melscale_fbanks( + n_freqs=self.n_fft // 2 + 1, + f_min=self.f_min, + f_max=self.f_max, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer( + "fb", + fb, + persistent=False, + ) + + def compress(self, x: Tensor) -> Tensor: + return torch.log(torch.clamp(x, min=1e-5)) + + def decompress(self, x: Tensor) -> Tensor: + return torch.exp(x) + + def apply_mel_scale(self, x: Tensor) -> Tensor: + return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) + + def forward( + self, x: Tensor, return_linear: bool = False, sample_rate: int = None + ) -> Tensor: + if sample_rate is not None and sample_rate != self.sample_rate: + x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) + + linear = self.spectrogram(x) + x = self.apply_mel_scale(linear) + x = self.compress(x) + + if return_linear: + return x, self.compress(linear) + + return x diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a34bdcfedff76c333f50ed8be050d0dd5a8f98a --- /dev/null +++ b/fish_speech/utils/utils.py @@ -0,0 +1,136 @@ +import random +import warnings +from importlib.util import find_spec 
+from typing import Callable + +import numpy as np +import torch +from omegaconf import DictConfig + +from .logger import RankedLogger +from .rich_utils import enforce_tags, print_config_tree + +log = RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: + + ... + + return metric_dict, object_dict + ``` + """ # noqa: E501 + + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or + # cause out-of-memory errors so when using hparam search + # plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.run_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +def set_seed(seed: int): + if seed < 0: + seed = -seed + if seed > (1 << 31): + seed = 1 << 31 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if torch.backends.cudnn.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a6ed5a24f1b659ff6aeb936748358bc413dc26e Binary files /dev/null and b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc differ diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css new file mode 100644 index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7 --- /dev/null +++ b/fish_speech/webui/css/style.css @@ -0,0 +1,161 @@ +:root { + --my-200: #80eeee; + --my-50: #ecfdf5; + --water-width: 300px; + --water-heigh: 300px; +} + + +/* general styled components */ +.tools { + align-items: center; + justify-content: center; +} + +.gradio-button { + max-width: 2.2em; + min-width: 2.2em !important; + height: 2.4em; + align-self: end; + line-height: 1em; + border-radius: 0.5em; + +} + +.gradio-button.secondary-down, .gradio-button.secondary-down:hover{ + box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; +} + +/* replace original footer with ours */ +a{ + font-weight: bold; + cursor: pointer; + color: #030C14 !important; +} + +footer { + display: none !important; +} + +#footer{ + text-align: center; +} + +#footer div{ + display: inline-block; +} + +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + +/*@keyframes moveBackground {*/ +/* 0% {*/ +/* background-position: 0 0;*/ +/* }*/ +/* 100% {*/ +/* background-position: -100px 100px;*/ +/* }*/ +/*}*/ +@keyframes moveJellyBackground { + 0% { + background-position: 0% 50%; + } + 50% { + background-position: 100% 50%; + } + 100% { + background-position: 0% 50%; + } +} + +.gradio-container { + position: absolute; + z-index: 10; +} + + +.quan { + position: absolute; + bottom: 0; + width: var(--water-width); + height: var(--water-heigh); + border-radius: 0; + /*border: 3px solid rgb(246, 247, 248);*/ + /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ + z-index: 0; + +} + +.quan:last-child { + margin-right: 0; +} + +.shui { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgb(23, 106, 201); + border-radius: 0; + overflow: hidden; + z-index: 0; +} + +.shui::after { + + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 40%; + background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); + animation: shi 5s linear infinite; +} + +@keyframes shi { + 0% { + transform: translate(-50%, -65%) rotate(0deg); + } + 100% { + transform: translate(-50%, -65%) rotate(360deg); + } +} + +.shui::before { + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 42%; + background-color: rgb(240, 228, 228, 0.2); + animation: xu 7s linear infinite; +} + +@keyframes xu { + 0% { + transform: translate(-50%, -60%) rotate(0deg); + } + 100% { + transform: translate(-50%, -60%) rotate(360deg); + } +} + +fieldset.data_src div.wrap label { + background: #f8bffee0 
!important; +} + +.scrollable-component { + max-height: 100px; + overflow-y: auto; +} + +#file_accordion { + max-height: 220px !important; +} diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html new file mode 100644 index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615 --- /dev/null +++ b/fish_speech/webui/html/footer.html @@ -0,0 +1,11 @@ +
+ API +  •  + Github +  •  + Gradio +
+
+
+{versions} +
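This footer file is a plain str.format template: manage.py (further below) loads it with load_data_in_raw and substitutes {versions} with the output of versions_html() and {api_docs} with the documentation URL. A minimal sketch of that consumption path, assuming the repository root as the working directory and the project dependencies installed:

# Sketch only: mirrors how manage.py fills the footer template at startup.
from fish_speech.webui.launch_utils import versions_html

with open("fish_speech/webui/html/footer.html", "r", encoding="utf-8") as f:
    footer_template = f.read()

footer_html = footer_template.format(
    versions=versions_html(),
    api_docs="https://speech.fish.audio/inference/#http-api",
)
print(footer_html)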
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js new file mode 100644 index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868 --- /dev/null +++ b/fish_speech/webui/js/animate.js @@ -0,0 +1,69 @@ + +function createGradioAnimation() { + const params = new URLSearchParams(window.location.search); + if (!params.has('__theme')) { + params.set('__theme', 'light'); + window.location.search = params.toString(); + } + + var gradioApp = document.querySelector('gradio-app'); + if (gradioApp) { + + document.documentElement.style.setProperty('--my-200', '#80eeee'); + document.documentElement.style.setProperty('--my-50', '#ecfdf5'); + + // gradioApp.style.position = 'relative'; + // gradioApp.style.backgroundSize = '200% 200%'; + // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; + // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; + // gradioApp.style.display = 'flex'; + // gradioApp.style.justifyContent = 'flex-start'; + // gradioApp.style.flexWrap = 'nowrap'; + // gradioApp.style.overflowX = 'auto'; + + // for (let i = 0; i < 6; i++) { + // var quan = document.createElement('div'); + // quan.className = 'quan'; + // gradioApp.insertBefore(quan, gradioApp.firstChild); + // quan.id = 'quan' + i.toString(); + // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; + // var quanContainer = document.querySelector('.quan'); + // if (quanContainer) { + // var shui = document.createElement('div'); + // shui.className = 'shui'; + // quanContainer.insertBefore(shui, quanContainer.firstChild) + // } + // } + } + + var container = document.createElement('div'); + container.id = 'gradio-animation'; + container.style.fontSize = '2em'; + container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; + container.style.fontWeight = 'bold'; + container.style.textAlign = 'center'; + container.style.marginBottom = '20px'; + + var text = 'Welcome to Fish-Speech!'; + for (var i = 0; i < text.length; i++) { + (function(i){ + setTimeout(function(){ + var letter = document.createElement('span'); + letter.style.opacity = '0'; + letter.style.transition = 'opacity 0.5s'; + letter.innerText = text[i]; + + container.appendChild(letter); + + setTimeout(function() { + letter.style.opacity = '1'; + }, 50); + }, i * 200); + })(i); + } + + var gradioContainer = document.querySelector('.gradio-container'); + gradioContainer.insertBefore(container, gradioContainer.firstChild); + + return 'Animation created'; +} diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..790c0e632ce55e099e5578d8824e94b1d1260d6e --- /dev/null +++ b/fish_speech/webui/launch_utils.py @@ -0,0 +1,120 @@ +import importlib.util +import os +import subprocess +import sys +from functools import lru_cache +from pathlib import Path +from typing import Iterable + +import gradio as gr +from gradio.themes.base import Base +from gradio.themes.utils import colors, fonts, sizes + +GIT = ( + (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() + if sys.platform == "win32" + else "git" +) +GIT = str(GIT) + + +def is_module_installed(module_name: str) -> bool: + spec = importlib.util.find_spec(module_name) + return spec is not None + + +@lru_cache() +def commit_hash(): + try: + return subprocess.check_output( + [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" + ).strip() + except Exception: + return "" 
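Because commit_hash() passes --format='%h %s' as a single argv element with shell=False, git receives the surrounding quotes literally, so the captured text keeps them; versions_html() (next) strips them before extracting the short hash. A small illustration with made-up values:

# Illustration only; the hash and subject below are invented.
commit = "'a1b2c3d Fix training config'"   # shape of commit_hash() output
short_hash = commit.strip("'").split(" ")[0]
assert short_hash == "a1b2c3d"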
+ + +def versions_html(): + import torch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = commit_hash() + hash = commit.strip("'").split(" ")[0] + + return f""" +version: {hash} + •  +python: {python_version} + •  +torch: {getattr(torch, '__long_version__',torch.__version__)} + •  +gradio: {gr.__version__} + •  +author: fishaudio +""" + + +def version_check(commit): + try: + import requests + + commits = requests.get( + "https://api.github.com/repos/fishaudio/fish-speech/branches/main" + ).json() + if commit != "" and commits["commit"]["sha"] != commit: + print("--------------------------------------------------------") + print("| You are not up to date with the most recent release. |") + print("| Consider running `git pull` to update. |") + print("--------------------------------------------------------") + elif commits["commit"]["sha"] == commit: + print("You are up to date with the most recent release.") + else: + print("Not a git clone, can't perform version check.") + except Exception as e: + print("version check failed", e) + + +class Seafoam(Base): + def __init__( + self, + *, + primary_hue: colors.Color | str = colors.emerald, + secondary_hue: colors.Color | str = colors.blue, + neutral_hue: colors.Color | str = colors.blue, + spacing_size: sizes.Size | str = sizes.spacing_md, + radius_size: sizes.Size | str = sizes.radius_md, + text_size: sizes.Size | str = sizes.text_lg, + font: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("Quicksand"), + "ui-sans-serif", + "sans-serif", + ), + font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("IBM Plex Mono"), + "ui-monospace", + "monospace", + ), + ): + super().__init__( + primary_hue=primary_hue, + secondary_hue=secondary_hue, + neutral_hue=neutral_hue, + spacing_size=spacing_size, + radius_size=radius_size, + text_size=text_size, + font=font, + font_mono=font_mono, + ) + super().set( + button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", + button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", + button_primary_text_color="white", + button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", + slider_color="*secondary_300", + slider_color_dark="*secondary_600", + block_title_text_weight="600", + block_border_width="3px", + block_shadow="*shadow_drop_lg", + # button_shadow="*shadow_drop_lg", + button_small_padding="0px", + button_large_padding="3px", + ) diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py new file mode 100644 index 0000000000000000000000000000000000000000..c21233eee3e3e99754c68efc2b8809a62217eb53 --- /dev/null +++ b/fish_speech/webui/manage.py @@ -0,0 +1,1239 @@ +from __future__ import annotations + +import os + +os.environ["USE_LIBUV"] = "0" +import datetime +import html +import json +import platform +import shutil +import signal +import subprocess +import sys +from pathlib import Path + +import gradio as gr +import psutil +import yaml +from loguru import logger +from tqdm import tqdm + +PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") +sys.path.insert(0, "") +print(sys.path) +cur_work_dir = Path(os.getcwd()).resolve() +print("You are in ", str(cur_work_dir)) + +from fish_speech.i18n import i18n +from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html + +config_path = cur_work_dir / "fish_speech" / "configs" +vqgan_yml_path = config_path / "firefly_gan_vq.yaml" 
+llama_yml_path = config_path / "text2semantic_finetune.yaml" + +env = os.environ.copy() +env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" + +seafoam = Seafoam() + + +def build_html_error_message(error): + return f""" +
+ {html.escape(error)} +
+ """ + + +def build_html_ok_message(msg): + return f""" +
+ {html.escape(msg)} +
+ """ + + +def build_html_href(link, desc, msg): + return f""" + + {html.escape(msg)} + {desc} + + """ + + +def load_data_in_raw(path): + with open(path, "r", encoding="utf-8") as file: + data = file.read() + return str(data) + + +def kill_proc_tree(pid, including_parent=True): + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + # Process already terminated + return + + children = parent.children(recursive=True) + for child in children: + try: + os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + if including_parent: + try: + os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + + +system = platform.system() +p_label = None +p_infer = None +p_tensorboard = None + + +def kill_process(pid): + if system == "Windows": + cmd = "taskkill /t /f /pid %s" % pid + # os.system(cmd) + subprocess.run(cmd) + else: + kill_proc_tree(pid) + + +def change_label(if_label): + global p_label + if if_label == True and p_label is None: + url = "http://localhost:3000" + remote_url = "https://text-labeler.pages.dev/" + try: + p_label = subprocess.Popen( + [ + ( + "asr-label-linux-x64" + if sys.platform == "linux" + else "asr-label-win-x64.exe" + ) + ] + ) + except FileNotFoundError: + logger.warning("asr-label execution not found!") + + yield build_html_href( + link=remote_url, + desc=i18n("Optional online ver"), + msg=i18n("Opened labeler in browser"), + ) + + elif if_label == False and p_label is not None: + kill_process(p_label.pid) + p_label = None + yield build_html_ok_message("Nothing") + + +def clean_infer_cache(): + import tempfile + + temp_dir = Path(tempfile.gettempdir()) + gradio_dir = str(temp_dir / "gradio") + try: + shutil.rmtree(gradio_dir) + logger.info(f"Deleted cached audios: {gradio_dir}") + except PermissionError: + logger.info(f"Permission denied: Unable to delete {gradio_dir}") + except FileNotFoundError: + logger.info(f"{gradio_dir} was not found") + except Exception as e: + logger.info(f"An error occurred: {e}") + + +def change_infer( + if_infer, + host, + port, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, +): + global p_infer + if if_infer == True and p_infer == None: + env = os.environ.copy() + + env["GRADIO_SERVER_NAME"] = host + env["GRADIO_SERVER_PORT"] = port + # 启动第二个进程 + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Inferring interface is launched at {}").format(url) + ) + + clean_infer_cache() + + p_infer = subprocess.Popen( + [ + PYTHON, + "tools/webui.py", + "--decoder-checkpoint-path", + infer_decoder_model, + "--decoder-config-name", + infer_decoder_config, + "--llama-checkpoint-path", + infer_llama_model, + ] + + (["--compile"] if infer_compile == "Yes" else []), + env=env, + ) + + elif if_infer == False and p_infer is not None: + kill_process(p_infer.pid) + p_infer = None + yield build_html_error_message(i18n("Infer interface is closed")) + + +js = load_data_in_raw("fish_speech/webui/js/animate.js") +css = load_data_in_raw("fish_speech/webui/css/style.css") + +data_pre_output = (cur_work_dir / "data").resolve() +default_model_output = (cur_work_dir / "results").resolve() +default_filelist = data_pre_output / "detect.list" +data_pre_output.mkdir(parents=True, exist_ok=True) + +items = [] +dict_items = {} + + +def load_yaml_data_in_fact(yml_path): + with open(yml_path, "r", encoding="utf-8") as file: + yml = yaml.safe_load(file) + return yml + + +def write_yaml_data_in_fact(yml, yml_path): + with open(yml_path, "w", encoding="utf-8") as 
file: + yaml.safe_dump(yml, file, allow_unicode=True) + return yml + + +def generate_tree(directory, depth=0, max_depth=None, prefix=""): + if max_depth is not None and depth > max_depth: + return "" + + tree_str = "" + files = [] + directories = [] + for item in os.listdir(directory): + if os.path.isdir(os.path.join(directory, item)): + directories.append(item) + else: + files.append(item) + + entries = directories + files + for i, entry in enumerate(entries): + connector = "├── " if i < len(entries) - 1 else "└── " + tree_str += f"{prefix}{connector}{entry}
" + if i < len(directories): + extension = "│ " if i < len(entries) - 1 else " " + tree_str += generate_tree( + os.path.join(directory, entry), + depth + 1, + max_depth, + prefix=prefix + extension, + ) + return tree_str + + +def new_explorer(data_path, max_depth): + return gr.Markdown( + elem_classes=["scrollable-component"], + value=generate_tree(data_path, max_depth=max_depth), + ) + + +def add_item( + folder: str, + method: str, + label_lang: str, + if_initial_prompt: bool, + initial_prompt: str | None, +): + folder = folder.strip(" ").strip('"') + + folder_path = Path(folder) + + if folder and folder not in items and data_pre_output not in folder_path.parents: + if folder_path.is_dir(): + items.append(folder) + dict_items[folder] = dict( + type="folder", + method=method, + label_lang=label_lang, + initial_prompt=initial_prompt if if_initial_prompt else None, + ) + elif folder: + err = folder + return gr.Checkboxgroup(choices=items), build_html_error_message( + i18n("Invalid path: {}").format(err) + ) + + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info("After Adding: " + formatted_data) + gr.Info(formatted_data) + return gr.Checkboxgroup(choices=items), build_html_ok_message( + i18n("Added path successfully!") + ) + + +def remove_items(selected_items): + global items, dict_items + to_remove = [item for item in items if item in selected_items] + for item in to_remove: + del dict_items[item] + items = [item for item in items if item in dict_items.keys()] + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info(formatted_data) + gr.Warning("After Removing: " + formatted_data) + return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( + i18n("Removed path successfully!") + ) + + +def show_selected(options): + selected_options = ", ".join(options) + + if options: + return i18n("Selected: {}").format(selected_options) + else: + return i18n("No selected options") + + +from pydub import AudioSegment + + +def convert_to_mono_in_place(audio_path: Path): + audio = AudioSegment.from_file(audio_path) + if audio.channels > 1: + mono_audio = audio.set_channels(1) + mono_audio.export(audio_path, format=audio_path.suffix[1:]) + logger.info(f"Convert {audio_path} successfully") + + +def list_copy(list_file_path, method): + wav_root = data_pre_output + lst = [] + with list_file_path.open("r", encoding="utf-8") as file: + for line in tqdm(file, desc="Processing audio/transcript"): + wav_path, speaker_name, language, text = line.strip().split("|") + original_wav_path = Path(wav_path) + target_wav_path = ( + wav_root / original_wav_path.parent.name / original_wav_path.name + ) + lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") + if target_wav_path.is_file(): + continue + target_wav_path.parent.mkdir(parents=True, exist_ok=True) + if method == i18n("Copy"): + shutil.copy(original_wav_path, target_wav_path) + else: + shutil.move(original_wav_path, target_wav_path.parent) + convert_to_mono_in_place(target_wav_path) + original_lab_path = original_wav_path.with_suffix(".lab") + target_lab_path = ( + wav_root + / original_wav_path.parent.name + / original_wav_path.with_suffix(".lab").name + ) + if target_lab_path.is_file(): + continue + if method == i18n("Copy"): + shutil.copy(original_lab_path, target_lab_path) + else: + shutil.move(original_lab_path, target_lab_path.parent) + + if method == i18n("Move"): + with list_file_path.open("w", encoding="utf-8") as file: + file.writelines("\n".join(lst)) + + del lst + 
return build_html_ok_message(i18n("Use filelist")) + + +def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): + global dict_items + data_path = Path(data_path) + gr.Warning("Pre-processing begins...") + for item, content in dict_items.items(): + item_path = Path(item) + tar_path = data_path / item_path.name + + if content["type"] == "folder" and item_path.is_dir(): + if content["method"] == i18n("Copy"): + os.makedirs(tar_path, exist_ok=True) + shutil.copytree( + src=str(item_path), dst=str(tar_path), dirs_exist_ok=True + ) + elif not tar_path.is_dir(): + shutil.move(src=str(item_path), dst=str(tar_path)) + + for suf in ["wav", "flac", "mp3"]: + for audio_path in tar_path.glob(f"**/*.{suf}"): + convert_to_mono_in_place(audio_path) + + cur_lang = content["label_lang"] + initial_prompt = content["initial_prompt"] + + transcribe_cmd = [ + PYTHON, + "tools/whisper_asr.py", + "--model-size", + label_model, + "--device", + label_device, + "--audio-dir", + tar_path, + "--save-dir", + tar_path, + "--language", + cur_lang, + ] + + if initial_prompt is not None: + transcribe_cmd += ["--initial-prompt", initial_prompt] + + if cur_lang != "IGNORE": + try: + gr.Warning("Begin To Transcribe") + subprocess.run( + transcribe_cmd, + env=env, + ) + except Exception: + print("Transcription error occurred") + + elif content["type"] == "file" and item_path.is_file(): + list_copy(item_path, content["method"]) + + return build_html_ok_message(i18n("Move files successfully")), new_explorer( + data_path, max_depth=max_depth + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +def train_process( + data_path: str, + option: str, + # llama config + llama_ckpt, + llama_base_config, + llama_lr, + llama_maxsteps, + llama_data_num_workers, + llama_data_batch_size, + llama_data_max_length, + llama_precision, + llama_check_interval, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, +): + + backend = "nccl" if sys.platform == "linux" else "gloo" + + new_project = generate_folder_name() + print("New Project Name: ", new_project) + + if option == "VQGAN": + msg = "Skipped VQGAN Training." + gr.Warning(msg) + logger.info(msg) + + if option == "LLAMA": + msg = "LLAMA Training begins..." 
+ gr.Warning(msg) + logger.info(msg) + subprocess.run( + [ + PYTHON, + "tools/vqgan/extract_vq.py", + str(data_pre_output), + "--num-workers", + "1", + "--batch-size", + "16", + "--config-name", + "firefly_gan_vq", + "--checkpoint-path", + "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ] + ) + + subprocess.run( + [ + PYTHON, + "tools/llama/build_dataset.py", + "--input", + str(data_pre_output), + "--text-extension", + ".lab", + "--num-workers", + "16", + ] + ) + ckpt_path = "checkpoints/fish-speech-1.4/model.pth" + lora_prefix = "lora_" if llama_use_lora else "" + llama_name = lora_prefix + "text2semantic_" + new_project + latest = next( + iter( + sorted( + [ + str(p.relative_to("results")) + for p in Path("results").glob(lora_prefix + "text2sem*/") + ], + reverse=True, + ) + ), + llama_name, + ) + project = ( + llama_name + if llama_ckpt == i18n("new") + else ( + latest + if llama_ckpt == i18n("latest") + else Path(llama_ckpt).relative_to("results") + ) + ) + logger.info(project) + + if llama_check_interval > llama_maxsteps: + llama_check_interval = llama_maxsteps + + train_cmd = [ + PYTHON, + "fish_speech/train.py", + "--config-name", + "text2semantic_finetune", + f"project={project}", + f"trainer.strategy.process_group_backend={backend}", + f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"model.optimizer.lr={llama_lr}", + f"trainer.max_steps={llama_maxsteps}", + f"data.num_workers={llama_data_num_workers}", + f"data.batch_size={llama_data_batch_size}", + f"max_length={llama_data_max_length}", + f"trainer.precision={llama_precision}", + f"trainer.val_check_interval={llama_check_interval}", + f"trainer.accumulate_grad_batches={llama_grad_batches}", + f"train_dataset.interactive_prob={llama_use_speaker}", + ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) + logger.info(train_cmd) + subprocess.run(train_cmd) + + return build_html_ok_message(i18n("Training stopped")) + + +def tensorboard_process( + if_tensorboard: bool, + tensorboard_dir: str, + host: str, + port: str, +): + global p_tensorboard + if if_tensorboard == True and p_tensorboard == None: + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Tensorboard interface is launched at {}").format(url) + ) + prefix = ["tensorboard"] + if Path("fishenv").exists(): + prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"] + + p_tensorboard = subprocess.Popen( + prefix + + [ + "--logdir", + tensorboard_dir, + "--host", + host, + "--port", + port, + "--reload_interval", + "120", + ] + ) + elif if_tensorboard == False and p_tensorboard != None: + kill_process(p_tensorboard.pid) + p_tensorboard = None + yield build_html_error_message(i18n("Tensorboard interface is closed")) + + +def fresh_tb_dir(): + return gr.Dropdown( + choices=[str(p) for p in Path("results").glob("**/tensorboard/")] + ) + + +def list_decoder_models(): + paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")] + if not paths: + logger.warning("No decoder model found") + return paths + + +def list_llama_models(): + choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")] + choices = sorted(choices, reverse=True) + if not choices: + logger.warning("No LLaMA model found") + return choices + 
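The two listing helpers above discover models purely by globbing under checkpoints/; note that list_llama_models() returns the parent directories rather than the .pth files themselves. A small sketch of a layout that would populate the dropdowns, using a temporary directory and illustrative names modeled on the defaults used elsewhere in this file:

# Sketch only: illustrative checkpoint names (the merged_* timestamp is an example
# of generate_folder_name() output); files are created under a temporary root.
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp()) / "checkpoints"
for rel in [
    "fish-speech-1.4/model.pth",
    "fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    "merged_20240101_000000/model.pth",
]:
    (root / rel).parent.mkdir(parents=True, exist_ok=True)
    (root / rel).touch()

# Same glob patterns as list_decoder_models() / list_llama_models(), applied to the temp root:
decoders = [str(p) for p in root.glob("fish*/firefly*.pth")]
llamas = sorted(
    {
        str(p.parent)
        for pattern in ("merged*/*model*.pth", "fish*/*model*.pth", "fs*/*model*.pth")
        for p in root.glob(pattern)
    },
    reverse=True,
)
print(decoders)  # [".../checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"]
print(llamas)    # [".../checkpoints/merged_20240101_000000", ".../checkpoints/fish-speech-1.4"]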
+ +def list_lora_llama_models(): + choices = sorted( + [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True + ) + if not choices: + logger.warning("No LoRA LLaMA model found") + return choices + + +def fresh_decoder_model(): + return gr.Dropdown(choices=list_decoder_models()) + + +def fresh_llama_ckpt(llama_use_lora): + return gr.Dropdown( + choices=[i18n("latest"), i18n("new")] + + ( + [str(p) for p in Path("results").glob("text2sem*/")] + if not llama_use_lora + else [str(p) for p in Path("results").glob("lora_*/")] + ) + ) + + +def fresh_llama_model(): + return gr.Dropdown(choices=list_llama_models()) + + +def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output): + if ( + lora_weight is None + or not Path(lora_weight).exists() + or not Path(llama_weight).exists() + ): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + gr.Warning("Merging begins...") + merge_cmd = [ + PYTHON, + "tools/llama/merge_lora.py", + "--lora-config", + "r_8_alpha_16", + "--lora-weight", + lora_weight, + "--output", + llama_lora_output + "_" + generate_folder_name(), + ] + logger.info(merge_cmd) + subprocess.run(merge_cmd) + return build_html_ok_message(i18n("Merge successfully")) + + +def llama_quantify(llama_weight, quantify_mode): + if llama_weight is None or not Path(llama_weight).exists(): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + + gr.Warning("Quantifying begins...") + + now = generate_folder_name() + quantify_cmd = [ + PYTHON, + "tools/llama/quantize.py", + "--checkpoint-path", + llama_weight, + "--mode", + quantify_mode, + "--timestamp", + now, + ] + logger.info(quantify_cmd) + subprocess.run(quantify_cmd) + if quantify_mode == "int8": + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}" + ) + else: + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}" + ) + return build_html_ok_message( + i18n("Quantify successfully") + f"Path: {quantize_path}" + ) + + +init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path) +init_llama_yml = load_yaml_data_in_fact(llama_yml_path) + +with gr.Blocks( + head="", + js=js, + theme=seafoam, + analytics_enabled=False, + title="Fish Speech", +) as demo: + with gr.Row(): + with gr.Column(): + with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")): + with gr.Row(): + textbox = gr.Textbox( + label="\U0000270F " + + i18n("Input Audio & Source Path for Transcription"), + info=i18n("Speaker is identified by the folder name"), + interactive=True, + ) + with gr.Row(equal_height=False): + with gr.Column(): + output_radio = gr.Radio( + label="\U0001F4C1 " + + i18n("Select source file processing method"), + choices=[i18n("Copy"), i18n("Move")], + value=i18n("Copy"), + interactive=True, + ) + with gr.Column(): + error = gr.HTML(label=i18n("Error Message")) + if_label = gr.Checkbox( + label=i18n("Open Labeler WebUI"), scale=0, show_label=True + ) + + with gr.Row(): + label_device = gr.Dropdown( + label=i18n("Labeling Device"), + info=i18n( + "It is recommended to use CUDA, if you have low configuration, use CPU" + ), + choices=["cpu", "cuda"], + value="cuda", + interactive=True, + ) + label_model = gr.Dropdown( + label=i18n("Whisper Model"), + info=i18n("Faster Whisper, Up to 5g GPU memory usage"), + choices=["large-v3", "medium"], + value="large-v3", + interactive=True, + ) + label_radio = 
gr.Dropdown( + label=i18n("Optional Label Language"), + info=i18n( + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format" + ), + choices=[ + (i18n("Chinese"), "zh"), + (i18n("English"), "en"), + (i18n("Japanese"), "ja"), + (i18n("Disabled"), "IGNORE"), + (i18n("auto"), "auto"), + ], + value="IGNORE", + interactive=True, + ) + + with gr.Row(): + if_initial_prompt = gr.Checkbox( + value=False, + label=i18n("Enable Initial Prompt"), + min_width=120, + scale=0, + ) + initial_prompt = gr.Textbox( + label=i18n("Initial Prompt"), + info=i18n( + "Initial prompt can provide contextual or vocabulary-specific guidance to the model." + ), + placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.", + interactive=False, + ) + + with gr.Row(): + add_button = gr.Button( + "\U000027A1 " + i18n("Add to Processing Area"), + variant="primary", + ) + remove_button = gr.Button( + "\U000026D4 " + i18n("Remove Selected Data") + ) + + with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")): + with gr.Row(): + model_type_radio = gr.Radio( + label=i18n( + "Select the model to be trained (Depending on the Tab page you are on)" + ), + interactive=False, + choices=["VQGAN", "LLAMA"], + value="VQGAN", + ) + with gr.Row(): + with gr.Column(): + with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: + gr.HTML("You don't need to train this model!") + + with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page: + with gr.Row(equal_height=False): + llama_use_lora = gr.Checkbox( + label=i18n("Use LoRA"), + info=i18n( + "Use LoRA can save GPU memory, but may reduce the quality of the model" + ), + value=True, + interactive=True, + ) + llama_ckpt = gr.Dropdown( + label=i18n("Select LLAMA ckpt"), + choices=[i18n("latest"), i18n("new")] + + [ + str(p) + for p in Path("results").glob("text2sem*/") + ] + + [str(p) for p in Path("results").glob("lora*/")], + value=i18n("latest"), + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lr_slider = gr.Slider( + label=i18n("Initial Learning Rate"), + info=i18n( + "lr smaller -> usually train slower but more stable" + ), + interactive=True, + minimum=1e-5, + maximum=1e-4, + step=1e-5, + value=5e-5, + ) + llama_maxsteps_slider = gr.Slider( + label=i18n("Maximum Training Steps"), + info=i18n( + "recommend: max_steps = num_audios // batch_size * (2 to 5)" + ), + interactive=True, + minimum=1, + maximum=10000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_base_config = gr.Dropdown( + label=i18n("Model Size"), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + ) + llama_data_num_workers_slider = gr.Slider( + label=i18n("Number of Workers"), + minimum=1, + maximum=32, + step=1, + value=4, + ) + with gr.Row(equal_height=False): + llama_data_batch_size_slider = gr.Slider( + label=i18n("Batch Size"), + interactive=True, + minimum=1, + maximum=32, + step=1, + value=2, + ) + llama_data_max_length_slider = gr.Slider( + label=i18n("Maximum Length per Sample"), + interactive=True, + minimum=1024, + maximum=4096, + step=128, + value=2048, + ) + with gr.Row(equal_height=False): + llama_precision_dropdown = gr.Dropdown( + label=i18n("Precision"), + info=i18n( + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU" + ), + interactive=True, + choices=["32", "bf16-true", "16-mixed"], + value="bf16-true", + ) + llama_check_interval_slider = gr.Slider( + label=i18n("Save model every 
n steps"), + info=i18n( + "make sure that it's not greater than max_steps" + ), + interactive=True, + minimum=1, + maximum=1000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_grad_batches = gr.Slider( + label=i18n("Accumulate Gradient Batches"), + interactive=True, + minimum=1, + maximum=20, + step=1, + value=init_llama_yml["trainer"][ + "accumulate_grad_batches" + ], + ) + llama_use_speaker = gr.Slider( + label=i18n( + "Probability of applying Speaker Condition" + ), + interactive=True, + minimum=0.1, + maximum=1.0, + step=0.05, + value=init_llama_yml["train_dataset"][ + "interactive_prob" + ], + ) + + with gr.Tab(label=i18n("Merge LoRA"), id=4): + with gr.Row(equal_height=False): + llama_weight = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "checkpoints/fish-speech-1.4/model.pth", + ], + value="checkpoints/fish-speech-1.4/model.pth", + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + lora_weight = gr.Dropdown( + label=i18n("LoRA Model to be merged"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + str(p) + for p in Path("results").glob("lora*/**/*.ckpt") + ], + allow_custom_value=True, + interactive=True, + ) + lora_llama_config = gr.Dropdown( + label=i18n("LLAMA Model Config"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + allow_custom_value=True, + ) + with gr.Row(equal_height=False): + llama_lora_output = gr.Dropdown( + label=i18n("Output Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/merged", + choices=["checkpoints/merged"], + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lora_merge_btn = gr.Button( + value=i18n("Merge"), variant="primary" + ) + + with gr.Tab(label=i18n("Model Quantization"), id=5): + with gr.Row(equal_height=False): + llama_weight_to_quantify = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_llama_models(), + value="checkpoints/fish-speech-1.4", + allow_custom_value=True, + interactive=True, + ) + quantify_mode = gr.Dropdown( + label=i18n("Post-quantification Precision"), + info=i18n( + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase" + ), + choices=["int8", "int4"], + value="int8", + allow_custom_value=False, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_quantify_btn = gr.Button( + value=i18n("Quantify"), variant="primary" + ) + + with gr.Tab(label="Tensorboard", id=6): + with gr.Row(equal_height=False): + tb_host = gr.Textbox( + label=i18n("Tensorboard Host"), value="127.0.0.1" + ) + tb_port = gr.Textbox( + label=i18n("Tensorboard Port"), value="11451" + ) + with gr.Row(equal_height=False): + tb_dir = gr.Dropdown( + label=i18n("Tensorboard Log Path"), + allow_custom_value=True, + choices=[ + str(p) + for p in Path("results").glob("**/tensorboard/") + ], + ) + with gr.Row(equal_height=False): + if_tb = gr.Checkbox( + label=i18n("Open Tensorboard"), + ) + + with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")): + with gr.Column(): + with gr.Row(): + with gr.Accordion( + label="\U0001F5A5 " + + i18n("Inference Server Configuration"), + open=False, + ): + with gr.Row(): + infer_host_textbox = gr.Textbox( + label=i18n("WebUI Host"), 
value="127.0.0.1" + ) + infer_port_textbox = gr.Textbox( + label=i18n("WebUI Port"), value="7862" + ) + with gr.Row(): + infer_decoder_model = gr.Dropdown( + label=i18n("Decoder Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_decoder_models(), + value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + allow_custom_value=True, + ) + infer_decoder_config = gr.Dropdown( + label=i18n("Decoder Model Config"), + info=i18n("Changing with the Model Path"), + value="firefly_gan_vq", + choices=[ + "firefly_gan_vq", + ], + allow_custom_value=True, + ) + with gr.Row(): + infer_llama_model = gr.Dropdown( + label=i18n("LLAMA Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/fish-speech-1.4", + choices=list_llama_models(), + allow_custom_value=True, + ) + + with gr.Row(): + infer_compile = gr.Radio( + label=i18n("Compile Model"), + info=i18n( + "Compile the model can significantly reduce the inference time, but will increase cold start time" + ), + choices=["Yes", "No"], + value=( + "Yes" if (sys.platform == "linux") else "No" + ), + interactive=is_module_installed("triton"), + ) + + with gr.Row(): + infer_checkbox = gr.Checkbox( + label=i18n("Open Inference Server") + ) + infer_error = gr.HTML(label=i18n("Inference Server Error")) + + with gr.Column(): + train_error = gr.HTML(label=i18n("Training Error")) + checkbox_group = gr.CheckboxGroup( + label="\U0001F4CA " + i18n("Data Source"), + info=i18n( + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list." + ), + elem_classes=["data_src"], + ) + train_box = gr.Textbox( + label=i18n("Data Preprocessing Path"), + value=str(data_pre_output), + interactive=False, + ) + model_box = gr.Textbox( + label="\U0001F4BE " + i18n("Model Output Path"), + value=str(default_model_output), + interactive=False, + ) + + with gr.Accordion( + i18n( + "View the status of the preprocessing folder (use the slider to control the depth of the tree)" + ), + elem_classes=["scrollable-component"], + elem_id="file_accordion", + ): + tree_slider = gr.Slider( + minimum=0, + maximum=3, + value=0, + step=1, + show_label=False, + container=False, + ) + file_markdown = new_explorer(str(data_pre_output), 0) + with gr.Row(equal_height=False): + admit_btn = gr.Button( + "\U00002705 " + i18n("File Preprocessing"), + variant="primary", + ) + fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) + help_button = gr.Button("\U00002753", scale=0, min_width=80) # question + train_btn = gr.Button(i18n("Start Training"), variant="primary") + + footer = load_data_in_raw("fish_speech/webui/html/footer.html") + footer = footer.format( + versions=versions_html(), + api_docs="https://speech.fish.audio/inference/#http-api", + ) + gr.HTML(footer, elem_id="footer") + vqgan_page.select(lambda: "VQGAN", None, model_type_radio) + llama_page.select(lambda: "LLAMA", None, model_type_radio) + add_button.click( + fn=add_item, + inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt], + outputs=[checkbox_group, error], + ) + remove_button.click( + fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] + ) + checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) + help_button.click( + fn=None, + js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' + 'toolbar=no, menubar=no, scrollbars=no, resizable=no, 
location=no, status=no")}', + ) + if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) + if_initial_prompt.change( + fn=lambda x: gr.Textbox(value="", interactive=x), + inputs=[if_initial_prompt], + outputs=[initial_prompt], + ) + train_btn.click( + fn=train_process, + inputs=[ + train_box, + model_type_radio, + # llama config + llama_ckpt, + llama_base_config, + llama_lr_slider, + llama_maxsteps_slider, + llama_data_num_workers_slider, + llama_data_batch_size_slider, + llama_data_max_length_slider, + llama_precision_dropdown, + llama_check_interval_slider, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, + ], + outputs=[train_error], + ) + if_tb.change( + fn=tensorboard_process, + inputs=[if_tb, tb_dir, tb_host, tb_port], + outputs=[train_error], + ) + tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir]) + infer_decoder_model.change( + fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model] + ) + infer_llama_model.change( + fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model] + ) + llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight]) + admit_btn.click( + fn=check_files, + inputs=[train_box, tree_slider, label_model, label_device], + outputs=[error, file_markdown], + ) + fresh_btn.click( + fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] + ) + llama_use_lora.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + llama_ckpt.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + lora_weight.change( + fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), + inputs=[], + outputs=[lora_weight], + ) + llama_lora_merge_btn.click( + fn=llama_lora_merge, + inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output], + outputs=[train_error], + ) + llama_quantify_btn.click( + fn=llama_quantify, + inputs=[llama_weight_to_quantify, quantify_mode], + outputs=[train_error], + ) + infer_checkbox.change( + fn=change_infer, + inputs=[ + infer_checkbox, + infer_host_textbox, + infer_port_textbox, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, + ], + outputs=[infer_error], + ) + +demo.launch(inbrowser=True)
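Since demo.launch(inbrowser=True) does not pin a host or port, the manager honors Gradio's GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables, the same ones change_infer sets for the inference subprocess. A minimal launch sketch, assuming the repository root as the working directory and the project dependencies installed:

# Hypothetical launcher: choose the bind address and port, then start the manager.
import os
import subprocess
import sys

env = os.environ.copy()
env["GRADIO_SERVER_NAME"] = "0.0.0.0"  # example: listen on all interfaces
env["GRADIO_SERVER_PORT"] = "7860"     # example port

subprocess.run([sys.executable, "fish_speech/webui/manage.py"], env=env, check=True)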