diff --git a/fish_speech/__pycache__/conversation.cpython-310.pyc b/fish_speech/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4715d8ad7d19c710c2a4809f48f44ce77b121cfd
Binary files /dev/null and b/fish_speech/__pycache__/conversation.cpython-310.pyc differ
diff --git a/fish_speech/__pycache__/scheduler.cpython-310.pyc b/fish_speech/__pycache__/scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef6897b47e20bec33721ee86f864d6a511f49635
Binary files /dev/null and b/fish_speech/__pycache__/scheduler.cpython-310.pyc differ
diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3
--- /dev/null
+++ b/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]
diff --git a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af0d6a057cc9c633a575f24bbb4806f02586c88c
Binary files /dev/null and b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95bc0a749f8b9617c87c6fb9a0d5936e38d82e7e
Binary files /dev/null and b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc differ
diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a
--- /dev/null
+++ b/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+ _group_tensors_by_device_and_dtype,
+ _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+ parameters: Union[Tensor, list[Tensor]],
+ norm_type: float = 2.0,
+) -> Optional[Tensor]:
+    """
+    Returns the norm of the gradients of the given parameters.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor whose gradient norm will be computed
+        norm_type (float): type of the used p-norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector),
+        or None if no parameter has a gradient.
+    """  # noqa: E501
+
+ if isinstance(parameters, Tensor):
+ parameters = [parameters]
+
+ grads = [p.grad for p in parameters if p.grad is not None]
+ if len(grads) == 0:
+ return None
+
+ first_device = grads[0].device
+ grouped_grads: dict[
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
+ ] = _group_tensors_by_device_and_dtype(
+ [[g.detach() for g in grads]]
+ ) # type: ignore[assignment]
+
+ norms = []
+ for (device, _), ([grads], _) in grouped_grads.items():
+ if _has_foreach_support(grads, device=device):
+ norms.extend(torch._foreach_norm(grads, norm_type))
+ else:
+ norms.extend([torch.norm(g, norm_type) for g in grads])
+
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+ """
+ Callback that computes the gradient norm of the model parameters.
+ """
+
+ def __init__(
+ self,
+ norm_type: float = 2.0,
+ logging_interval: str = "step",
+ sub_module: Optional[Union[str, list[str]]] = None,
+ ) -> None:
+ """
+ Args:
+            norm_type (float): type of the used p-norm.
+            logging_interval (str): "step" or "epoch".
+            sub_module (str or list[str], optional): restrict logging to the given sub-module(s).
+        """
+ super().__init__()
+
+ self.norm_type = norm_type
+ self.logging_interval = logging_interval
+ self.sub_module = sub_module
+
+ def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+ """
+ Computes the gradient norm of the model parameters and logs it to the logger.
+
+ Args:
+ trainer (Trainer): The trainer object
+            model (LightningModule): The current LightningModule
+ """
+
+ lightning_model = model
+
+ if self.sub_module is None:
+ return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+ sub_modules = self.sub_module
+ if isinstance(sub_modules, str):
+ sub_modules = [sub_modules]
+
+ for sub_module in sub_modules:
+ self.log_sub_module_grad_norm(
+ lightning_model, getattr(model, sub_module), f"/{sub_module}"
+ )
+
+ def log_sub_module_grad_norm(
+ self, lightning_model: LightningModule, model: nn.Module, path: str
+ ) -> None:
+ grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+ if grad_norm_val is None:
+ return
+
+ on_step = self.logging_interval == "step"
+ lightning_model.log(
+ f"train{path}/grad_norm",
+ grad_norm_val,
+ on_step=on_step,
+ on_epoch=not on_step,
+ )
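Usage sketch (not part of the diff): the callback is attached to a Lightning Trainer like any other callback; the "model" sub-module name below is illustrative and assumes the LightningModule exposes an attribute with that name.

    import lightning.pytorch as pl
    from fish_speech.callbacks import GradNormMonitor

    # Log the overall grad norm every step, plus the norm of a sub-module
    # named "model" if the LightningModule defines one (assumed here).
    trainer = pl.Trainer(
        max_steps=1000,
        callbacks=[
            GradNormMonitor(norm_type=2.0, logging_interval="step", sub_module="model"),
        ],
    )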
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75
--- /dev/null
+++ b/fish_speech/configs/base.yaml
@@ -0,0 +1,87 @@
+# Base configuration for training a model
+paths:
+ run_dir: results/${project}
+ ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+ run:
+ dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.run_dir}
+ accelerator: gpu
+ num_nodes: 1
+ devices: auto
+ strategy:
+ _target_: lightning.pytorch.strategies.DDPStrategy
+    process_group_backend: nccl # This should be overridden when training on Windows
+
+ precision: bf16-mixed
+
+  # Disable epoch-based validation; validation runs every val_check_interval steps instead
+ check_val_every_n_epoch: null
+ val_check_interval: 5000
+ max_steps: 100_000
+
+ # Use torch.backends.cudnn.benchmark to speed up training
+ benchmark: true
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: ${paths.ckpt_dir}
+ filename: "step_{step:09d}"
+    save_last: false # do not additionally save a copy of the latest checkpoint as last.ckpt
+ save_top_k: 5 # save 5 latest checkpoints
+ monitor: step # use step to monitor checkpoints
+ mode: max # save the latest checkpoint with the highest global_step
+ every_n_epochs: null # don't save checkpoints by epoch end
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
+ auto_insert_metric_name: false
+
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+ learning_rate_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
+
+ grad_norm_monitor:
+ _target_: fish_speech.callbacks.GradNormMonitor
+ norm_type: 2
+ logging_interval: step
+
+# Logger
+logger:
+ tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.run_dir}/tensorboard/"
+ name: null
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+
+ # wandb:
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # # name: "" # name of the run (normally generated by wandb)
+ # save_dir: "${paths.run_dir}"
+ # offline: False
+ # id: null # pass correct id to resume experiment!
+ # anonymous: null # enable anonymous logging
+ # project: "fish-speech"
+ # log_model: False # upload lightning ckpts
+ # prefix: "" # a string to put at the beginning of metric keys
+ # # entity: "" # set to name of your wandb team
+ # group: ""
+ # tags: ["vq", "hq", "finetune"]
+ # job_type: ""
+
+# Loop
+train: true
+test: false
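A rough sketch of how a `_target_`-style config like this becomes live objects (assuming the file is composed with a `project` value so the interpolations resolve; a CUDA machine is assumed since the accelerator is gpu):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("fish_speech/configs/base.yaml")
    cfg.project = "demo"  # ${project} must exist before ${paths.run_dir} can resolve
    # Every node carrying a `_target_` is built recursively by instantiate.
    callbacks = [instantiate(c) for c in cfg.callbacks.values()]
    loggers = [instantiate(lg) for lg in cfg.logger.values()]
    trainer = instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)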
diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..10aa8d4a522f0859ed8f541f5d48672d84b39c8f
--- /dev/null
+++ b/fish_speech/configs/firefly_gan_vq.yaml
@@ -0,0 +1,33 @@
+_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
+spec_transform:
+ _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+ sample_rate: 44100
+ n_mels: 160
+ n_fft: 2048
+ hop_length: 512
+ win_length: 2048
+backbone:
+ _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
+ input_channels: 160
+ depths: [3, 3, 9, 3]
+ dims: [128, 256, 384, 512]
+ drop_path_rate: 0.2
+ kernel_size: 7
+head:
+ _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
+ hop_length: 512
+ upsample_rates: [8, 8, 2, 2, 2] # aka. strides
+ upsample_kernel_sizes: [16, 16, 4, 4, 4]
+ resblock_kernel_sizes: [3, 7, 11]
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+ num_mels: 512
+ upsample_initial_channel: 512
+ pre_conv_kernel_size: 13
+ post_conv_kernel_size: 13
+quantizer:
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+ input_dim: 512
+ n_groups: 8
+ n_codebooks: 1
+ levels: [8, 5, 5, 5]
+ downsample_factor: [2, 2]
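A minimal sketch of building the generator described by this file (assuming the referenced fish_speech classes accept exactly the keyword arguments listed above):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("fish_speech/configs/firefly_gan_vq.yaml")
    # FireflyArchitecture(spec_transform=..., backbone=..., head=..., quantizer=...)
    model = instantiate(cfg)
    print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")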
diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9
--- /dev/null
+++ b/fish_speech/configs/lora/r_8_alpha_16.yaml
@@ -0,0 +1,4 @@
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
+r: 8
+lora_alpha: 16
+lora_dropout: 0.01
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4c1993023099e122fc9e004bda55ec075ed5e1b
--- /dev/null
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -0,0 +1,83 @@
+defaults:
+ - base
+ - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 4096
+pretrained_ckpt_path: checkpoints/fish-speech-1.4
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: "norm"
+ max_steps: 1000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 100
+
+# Tokenizer Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+val_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+data:
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+ model:
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+ path: ${pretrained_ckpt_path}
+ load_weights: true
+ max_length: ${max_length}
+ lora_config: null
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 10
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
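For orientation, `_partial_: true` makes Hydra hand back a functools.partial that the LightningModule later calls with the model parameters and the current step. A rough non-Hydra equivalent of the optimizer / lr_scheduler nodes above (sketch; the Linear layer stands in for the real text2semantic model):

    from functools import partial

    import torch
    from fish_speech.scheduler import get_constant_schedule_with_warmup_lr_lambda

    model = torch.nn.Linear(8, 8)  # stand-in for the actual model
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, weight_decay=0, betas=(0.9, 0.95), eps=1e-5
    )
    # LambdaLR calls lr_lambda(current_step); num_warmup_steps is pre-bound here.
    lr_lambda = partial(get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=10)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)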
diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bbc1cdb6c4a1d276ccf922988a7ad13e058d70a
--- /dev/null
+++ b/fish_speech/conversation.py
@@ -0,0 +1,256 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
+SEMANTIC_TOKEN = "<|semantic|>"
+MEL_TOKEN = "<|mel|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+ALL_SPECIAL_TOKENS = [
+ IM_START_TOKEN,
+ IM_END_TOKEN,
+ SEMANTIC_TOKEN,
+ MEL_TOKEN,
+ PHONEME_START_TOKEN,
+ PHONEME_END_TOKEN,
+]
+
+CODEBOOK_PAD_TOKEN_ID = 0
+
+
+class FishTokenizerConfig(PretrainedConfig):
+ share_codebook_embeddings: bool = True
+ codebook_size: int = 1024
+ num_codebooks: int = 8
+
+
+class FishTokenizerFast(PreTrainedTokenizerFast):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
+ self.codebook_size = kwargs.pop("codebook_size", 1024)
+ self.num_codebooks = kwargs.pop("num_codebooks", 8)
+
+
+AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+
+
+@dataclass(kw_only=True)
+class BasePart:
+ pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+ codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+ text: str
+
+
+@dataclass(kw_only=True)
+class MelPart(BasePart):
+ mels: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+ tokens: torch.Tensor
+ labels: torch.Tensor
+ vq_parts: list[torch.Tensor]
+ mel_parts: list[torch.Tensor]
+ vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+ role: Literal["system", "user", "assistant"]
+ parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+ add_im_start: bool = True
+ add_im_end: bool = True
+ cal_loss: bool = False
+
+ # By default, ignore the loss of the auto-generated im_start token
+ ignore_im_start_loss: bool = True
+
+ def encode(
+ self: "Message",
+ tokenizer: AutoTokenizer,
+ ) -> EncodedMessage:
+ all_tokens = []
+ all_labels = []
+
+ # Multi-modal tokens
+ vq_parts = []
+ mel_parts = []
+
+ semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+ [SEMANTIC_TOKEN, MEL_TOKEN]
+ )
+
+ parts = self.parts.copy()
+ if self.add_im_start:
+ parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
+
+ if self.add_im_end:
+ parts.append(TextPart(text="<|im_end|>"))
+
+ for part in parts:
+ if isinstance(part, TextPart):
+ tokens = tokenizer.encode(
+ part.text,
+ add_special_tokens=False,
+ truncation=False,
+ return_tensors="pt",
+ ).int()[0]
+ elif isinstance(part, VQPart):
+ tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
+ codes = part.codes.clone() + 1
+
+ if getattr(tokenizer, "share_codebook_embeddings", True) is False:
+ for i in range(len(codes)):
+ codes[i] += tokenizer.codebook_size * i
+
+ vq_parts.append(codes)
+ elif isinstance(part, MelPart):
+ tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
+ mel_parts.append(part.mels)
+ else:
+ raise ValueError(f"Unsupported part type: {type(part)}")
+
+ all_tokens.append(tokens)
+ if self.cal_loss:
+ all_labels.append(tokens.clone())
+ else:
+ all_labels.append(torch.full_like(tokens, -100))
+
+ tokens = torch.cat(all_tokens, dim=0)
+ labels = torch.cat(all_labels, dim=0)
+ assert tokens.shape == labels.shape
+
+ if self.ignore_im_start_loss and self.add_im_start:
+ labels[: len(all_tokens[0])] = -100
+
+ return EncodedMessage(
+ tokens=tokens,
+ labels=labels,
+ vq_parts=vq_parts,
+ mel_parts=mel_parts,
+ )
+
+
+@dataclass
+class Conversation:
+ messages: list[Message]
+
+ def encode(
+ self: "Conversation",
+ tokenizer: AutoTokenizer,
+ add_shift: bool = True,
+ ) -> EncodedMessage:
+ # Build the input_ids and labels
+ tokens = []
+ labels = []
+ vq_parts = []
+ mel_parts = []
+ vq_require_losses = []
+
+ for message in self.messages:
+ encoded = message.encode(
+ tokenizer,
+ )
+ tokens.append(encoded.tokens)
+ labels.append(encoded.labels)
+ vq_parts.extend(encoded.vq_parts)
+ mel_parts.extend(encoded.mel_parts)
+ vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+ tokens = torch.cat(tokens, dim=0)
+ labels = torch.cat(labels, dim=0)
+ vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+ if add_shift:
+ tokens = tokens[:-1]
+ labels = labels[1:]
+
+ assert tokens.dtype in [
+ torch.int,
+ torch.long,
+ ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
+
+ return EncodedMessage(
+ tokens=tokens,
+ labels=labels,
+ vq_parts=vq_parts,
+ mel_parts=mel_parts,
+ vq_require_losses=vq_require_losses,
+ )
+
+ def encode_for_inference(
+ self: "Conversation",
+ tokenizer: AutoTokenizer,
+ num_codebooks: int,
+ ) -> EncodedMessage:
+ encoded = self.encode(tokenizer, add_shift=False)
+ tokens = encoded.tokens
+ values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+ values[0] = tokens
+
+ if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+ return values
+
+ semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+ [SEMANTIC_TOKEN, MEL_TOKEN]
+ )
+ vq_parts = encoded.vq_parts
+ vq_parts = torch.cat(vq_parts, dim=1)
+ values[1:, tokens == semantic_id] = vq_parts
+ return values
+
+ def visualize(self: "Conversation", tokenizer: AutoTokenizer):
+ encoded = self.encode(tokenizer, add_shift=False)
+
+ print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
+ print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+
+ for tok, lab in zip(encoded.tokens, encoded.labels):
+ val = tokenizer.decode(tok, skip_special_tokens=False)
+ if val == "\n":
+ val = "\\n\n"
+
+ if lab == -100:
+ print_in_green(val)
+ else:
+ print_in_blue(val)
+
+ print()
+
+
+if __name__ == "__main__":
+ message0 = Message(
+ role="user",
+ parts=[
+ TextPart(text="Hello, how are you?"),
+ VQPart(codes=torch.zeros((4, 10))),
+ ],
+ cal_loss=False,
+ )
+
+ message1 = Message(
+ role="assistant",
+ parts=[TextPart(text="I'm fine, thank you.")],
+ cal_loss=True,
+ )
+ conversation = Conversation([message0, message1])
+ tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+ conversation.visualize(tokenizer)
+
+ encoded = conversation.encode(tokenizer)
+ print(encoded)
+ print(tokenizer.batch_decode(encoded.tokens))
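A small inference-side sketch, reusing the tokenizer loaded in the __main__ block above; the shapes are the point, not the content:

    # Prompt with one 10-frame VQ clip over 4 codebooks, ending in an open assistant turn.
    conv = Conversation(
        messages=[
            Message(
                role="user",
                parts=[
                    TextPart(text="Say hi"),
                    VQPart(codes=torch.zeros((4, 10), dtype=torch.int)),
                ],
            ),
            Message(role="assistant", parts=[], add_im_end=False),
        ]
    )
    prompt = conv.encode_for_inference(tokenizer, num_codebooks=4)
    # Row 0 holds text tokens, rows 1-4 hold the codebook values (offset by 1).
    print(prompt.shape)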
diff --git a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c674545ce61a4bf657e1be68d9969889a90b69f
Binary files /dev/null and b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc differ
diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa
--- /dev/null
+++ b/fish_speech/datasets/concat_repeat.py
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+ datasets: list[Dataset]
+ cumulative_sizes: list[int]
+ repeats: list[int]
+
+ @staticmethod
+ def cumsum(sequence, repeats):
+ r, s = [], 0
+ for dataset, repeat in zip(sequence, repeats):
+            length = len(dataset) * repeat
+            r.append(length + s)
+            s += length
+ return r
+
+ def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+ super().__init__()
+
+ self.datasets = list(datasets)
+ self.repeats = repeats
+
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+ assert len(self.datasets) == len(
+ repeats
+ ), "datasets and repeats should have the same length"
+
+ for d in self.datasets:
+ assert not isinstance(
+ d, IterableDataset
+ ), "ConcatRepeatDataset does not support IterableDataset"
+
+ self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+ dataset = self.datasets[dataset_idx]
+
+ return dataset[sample_idx % len(dataset)]
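A toy sketch of the repeat semantics:

    import torch
    from torch.utils.data import TensorDataset

    ds_a = TensorDataset(torch.arange(3))   # 3 items, repeated twice -> 6 virtual items
    ds_b = TensorDataset(torch.arange(10))  # 10 items, repeated once
    ds = ConcatRepeatDataset([ds_a, ds_b], repeats=[2, 1])
    assert len(ds) == 16
    assert ds[4][0].item() == 1   # index 4 maps to ds_a[4 % 3]
    assert ds[6][0].item() == 0   # first index past ds_a's block maps to ds_b[0]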
diff --git a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..436d43fb714faad6e5b7cbe23b6de6e998fe03e0
Binary files /dev/null and b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc differ
diff --git a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac8b50578898c53aec2ca94f73b92e0d9d7d8537
Binary files /dev/null and b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc differ
diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto
new file mode 100644
index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379
--- /dev/null
+++ b/fish_speech/datasets/protos/text-data.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package text_data;
+
+message Semantics {
+ repeated uint32 values = 1;
+}
+
+message Sentence {
+ repeated string texts = 1;
+ repeated Semantics semantics = 3;
+}
+
+message TextData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence sentences = 4;
+}
+
+message SampledData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence samples = 3;
+}
diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals["_SEMANTICS"]._serialized_start = 30
+ _globals["_SEMANTICS"]._serialized_end = 57
+ _globals["_SENTENCE"]._serialized_start = 59
+ _globals["_SENTENCE"]._serialized_end = 125
+ _globals["_TEXTDATA"]._serialized_start = 127
+ _globals["_TEXTDATA"]._serialized_end = 207
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)
diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+ while True:
+ buf = f.read(4)
+ if len(buf) == 0:
+ break
+ size = struct.unpack("I", buf)[0]
+ buf = f.read(size)
+ text_data = TextData()
+ text_data.ParseFromString(buf)
+ yield text_data
+
+
+def write_pb_stream(f, text_data):
+ buf = text_data.SerializeToString()
+ f.write(struct.pack("I", len(buf)))
+ f.write(buf)
+
+
+def pack_pb_stream(text_data):
+ buf = text_data.SerializeToString()
+ return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+ while True:
+ head = f.read(4)
+ if len(head) == 0:
+ break
+ size = struct.unpack("I", head)[0]
+ buf = f.read(size)
+ yield head + buf
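Round-trip sketch using the messages defined in text-data.proto:

    from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
    from fish_speech.datasets.protos.text_data_stream import read_pb_stream, write_pb_stream

    record = TextData(
        source="demo",
        name="speaker_0",
        sentences=[Sentence(texts=["hello world"], semantics=[Semantics(values=[1, 2, 3])])],
    )
    with open("demo.protos", "wb") as f:
        write_pb_stream(f, record)  # writes a 4-byte length prefix followed by the payload
    with open("demo.protos", "rb") as f:
        for text_data in read_pb_stream(f):
            print(text_data.name, len(text_data.sentences))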
diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418
--- /dev/null
+++ b/fish_speech/datasets/semantic.py
@@ -0,0 +1,496 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def split_by_rank_worker(files):
+ # We need to know the total number of devices
+ # to split the data properly
+
+ total_devices = 1
+ if is_initialized():
+ total_devices = get_world_size()
+
+ worker_info = get_worker_info()
+ if worker_info is not None:
+ total_devices *= worker_info.num_workers
+
+ if len(files) < total_devices:
+ # Repeat the files N times to match the number of devices
+ files = files * (total_devices // len(files) + 1)
+
+ # DDP
+ if is_initialized():
+ files = files[get_rank() :: get_world_size()]
+
+ # Split by worker
+ if worker_info is not None:
+ files = files[worker_info.id :: worker_info.num_workers]
+
+ return files
+
+
+class AutoTextSemanticInstructionDataset(IterableDataset):
+ """
+    Auto-augmenting text/semantic dataset grouped by speaker.
+
+    1. Randomly concatenates multiple sentences from the same speaker to form a longer sequence
+    2. Automatically normalizes the text
+
+    For interactive mode, each sentence becomes its own chat turn (multiple sequences):
+    <|im_start|>user text<|im_end|><|im_start|>{speaker} semantics<|im_end|> ...
+
+    For non-interactive mode, all sampled sentences are packed into one long sequence:
+    <|im_start|>user text text ...<|im_end|><|im_start|>{speaker} semantics<|im_end|>
+ """
+
+ def __init__(
+ self,
+ proto_files: list[str],
+ seed: int = 42,
+ interactive_prob: float = 0.5,
+ max_length: int = 1024,
+ tokenizer: AutoTokenizer = None,
+ use_speaker: bool | float = True,
+ causal: bool = True,
+ num_codebooks: Optional[int] = None,
+ skip_text_prob: float = 0.0,
+ ):
+ """
+ Args:
+ proto_files: proto buf files if using local data
+ seed: random seed
+ interactive_prob: probability to use interactive mode
+ max_length: max length of the text
+ tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt (a float is treated as a probability)
+            causal: sample sentences in order when using local data; disabling this leads to random sampling
+ num_codebooks: number of codebooks, if None, it will be automatically detected
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+ """
+
+ super().__init__()
+
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+ self.seed = seed
+ self.max_length = max_length
+ self.tokenizer = tokenizer
+ self.interactive_prob = interactive_prob
+ self.use_speaker = use_speaker
+ self.proto_files = proto_files
+ self.causal = causal
+ self.num_codebooks = num_codebooks
+ self.skip_text_prob = skip_text_prob
+
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+ self.groups = None
+
+ def init_mock_data_server(self):
+ if self.groups is not None:
+ return
+
+ # Expand the proto files
+ expanded_proto_files = []
+ for filename in self.proto_files:
+ for i in braceexpand(filename):
+ i = Path(i)
+ if i.is_file():
+ expanded_proto_files.append(i)
+ elif i.is_dir():
+ expanded_proto_files.extend(i.rglob("*.proto"))
+ expanded_proto_files.extend(i.rglob("*.protos"))
+ else:
+ raise ValueError(f"{i} is not a file or directory")
+
+ expanded_proto_files = sorted(expanded_proto_files)
+ Random(self.seed).shuffle(expanded_proto_files)
+
+ self.groups = []
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
+ log.info(
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+ )
+
+ count = 0
+ for filename in shard_proto_files:
+ with open(filename, "rb") as f:
+ for text_data in read_pb_stream(f):
+ self.groups.append(text_data)
+ count += 1
+
+ log.info(f"Read total {count} groups of data")
+
+ # Shuffle the lines
+ Random(self.seed).shuffle(self.groups)
+ self.group_weights = [len(i.sentences) for i in self.groups]
+
+ def __iter__(self):
+ while True:
+ yield self.augment()
+
+ def tokenize_sentence(self, sentence: str):
+ sentence = clean_text(sentence)
+ tokens = self.tokenizer.encode(
+ f"{sentence}",
+ max_length=10**6,
+ add_special_tokens=False,
+ truncation=False,
+ )
+ return sentence, len(tokens)
+
+ def sample_data(self):
+ if self.groups is None:
+ self.init_mock_data_server()
+
+        # Estimate that each sample is at least 20 tokens to bound how many sentences to draw
+ num_samples = self.max_length // 20
+
+        # Choose a group weighted by its number of sentences
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
+
+ if self.causal:
+ # Sample in order
+ if num_samples >= len(group.sentences):
+ samples = group.sentences
+ else:
+ begin = random.randint(0, len(group.sentences) - num_samples)
+ samples = group.sentences[begin : begin + num_samples]
+ else:
+ samples = random.choices(
+ group.sentences, k=min(num_samples, len(group.sentences))
+ )
+
+ return SampledData(
+ source=group.source,
+ name=group.name,
+ samples=samples,
+ )
+
+ def augment(self):
+ final_text, final_semantic = [], []
+ response = self.sample_data()
+ if len(response.samples) == 0:
+ # Invalid group
+ return None
+
+ samples = list(response.samples)
+ idx = 0
+ use_interactive = random.random() < self.interactive_prob
+
+ if use_interactive is False:
+            # Sample a target token budget from a truncated normal distribution
+ a = torch.tensor([0], dtype=torch.float32)
+ torch.nn.init.trunc_normal_(
+ a,
+ mean=self.max_length // 2,
+ std=self.max_length // 4,
+ a=10,
+ b=self.max_length,
+ )
+ remaining_tokens = a.long().item() - 4
+ else:
+ remaining_tokens = self.max_length
+
+ # Use speaker
+ if isinstance(self.use_speaker, float):
+ use_speaker = random.random() < self.use_speaker
+ else:
+ use_speaker = self.use_speaker
+
+ all_tokens, all_labels = [], []
+ while remaining_tokens > 0 and len(samples) > 0:
+ sentence = samples.pop(0)
+
+ text = random.choice(sentence.texts)
+ text, length = self.tokenize_sentence(text)
+ remaining_tokens -= length + len(sentence.semantics[0].values)
+
+ if use_interactive is False:
+ final_text.append(text)
+ final_semantic.append(sentence.semantics)
+ else:
+ # For interactive mode, we only apply speaker for the first sentence
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+ tokens, labels = self.pack_sentences(
+ sentences=[text],
+ semantics=[sentence.semantics],
+ speaker=response.name if use_speaker else None,
+ skip_text=random.random() < self.skip_text_prob,
+ )
+
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ idx += 1
+
+ if use_interactive is False:
+ tokens, labels = self.pack_sentences(
+ final_text,
+ semantics=final_semantic,
+ speaker=response.name if use_speaker else None,
+ )
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ tokens = torch.cat(all_tokens, dim=1)
+ labels = torch.cat(all_labels, dim=1)
+
+ # Verify that the length is correct
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+ data = {"tokens": tokens, "labels": labels}
+
+ return data
+
+ def pack_sentences(
+ self,
+ sentences: list[str],
+ semantics: list,
+ speaker: Optional[str] = None,
+ skip_text: bool = False,
+ ):
+ if speaker is None:
+ speaker = "assistant"
+
+ cated_sentences = " ".join(sentences)
+ if skip_text:
+ cated_sentences = "<|skip_text|>"
+
+ final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
+ final_text = final_text + f"<|im_start|>{speaker}\n"
+
+ encoded = self.tokenizer.encode(
+ final_text,
+ add_special_tokens=False,
+ truncation=False,
+ max_length=10**6,
+ )
+ semantic_length = sum([len(i[0].values) for i in semantics])
+ prompt_length = len(encoded)
+ num_codebooks = (
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+ )
+
+        # Pack the tokens and semantics (append <|im_end|> after the semantic tokens)
+ tokens = (
+ encoded
+ + [self.semantic_token_id] * semantic_length
+ + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
+ )
+
+ # Codebook bos/padding: 0, eos: 1
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
+ for segment in semantics:
+ for book_idx, book in zip(range(num_codebooks), segment):
+ for j in book.values:
+ codes[book_idx].append(int(j) + 1)
+
+ for book in codes:
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
+
+ tokens = [tokens] + codes
+
+ tokens = torch.tensor(tokens, dtype=torch.long)
+ labels = tokens.clone()
+
+ if skip_text:
+            # If text is not provided, the sample is used as a condition only, so all labels are -100
+ torch.fill_(labels, -100)
+ return tokens, labels
+
+ # Mask out the tokens for semantic, predict semantic tokens only
+ # Since we don't mask out the input tokens, the language modeling still works
+ labels[1:, :prompt_length] = -100
+
+ tokens = tokens[:, :-1]
+ labels = labels[:, 1:]
+
+ # Verify the padding is correct, and the last token is eos
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+ return tokens, labels
+
+
+@dataclass
+class TextDataCollator:
+ tokenizer: AutoTokenizer
+ max_length: int = 1024
+
+ def __call__(self, examples):
+ if "negative_tokens" in examples:
+ positive_examples = []
+ negative_examples = []
+
+ for i in examples:
+ positive_examples.append(
+ {
+ "tokens": i["tokens"],
+ "labels": i["labels"],
+ }
+ )
+ negative_examples.append(
+ {
+ "tokens": i["negative_tokens"],
+ "labels": i["negative_labels"],
+ }
+ )
+
+ examples = positive_examples + negative_examples
+
+ return self.batchify(examples)
+
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
+ tokens, attention_masks, labels = [], [], []
+
+ # Calculate the max length
+ max_tokens_length = 0
+ for example in examples:
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+ max_tokens_length = min(max_tokens_length, self.max_length)
+
+ for example in examples:
+ _tokens = example[tokens_key][:, :max_tokens_length]
+ _labels = example[labels_key][:, :max_tokens_length]
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
+ tokens_length = _tokens.size(1)
+ _attention_mask[:tokens_length] = False
+
+ assert tokens_length == _labels.size(
+ 1
+ ), f"{tokens_length} != {_labels.size(1)}"
+
+ if tokens_length < max_tokens_length:
+ _tokens = F.pad(
+ _tokens,
+ (0, max_tokens_length - tokens_length),
+ value=self.tokenizer.eos_token_id,
+ )
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
+ _labels = F.pad(
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
+ )
+
+ tokens.append(_tokens)
+ attention_masks.append(_attention_mask)
+ labels.append(_labels)
+
+ tokens = torch.stack(tokens, dim=0)
+ attention_masks = torch.stack(attention_masks, dim=0)
+ labels = torch.stack(labels, dim=0)
+
+ return {
+ "inputs": tokens,
+ "attention_masks": attention_masks,
+ "labels": labels,
+ }
+
+
+class InterleaveDataset(IterableDataset):
+ def __init__(
+ self,
+ datasets: list[IterableDataset],
+ probabilities: list[float],
+ seed: int = 42,
+ ):
+ super().__init__()
+
+ self.datasets = datasets
+ self.probabilities = probabilities
+ self.seed = seed
+
+ def __iter__(self):
+ rng = np.random.default_rng(self.seed)
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+ while True:
+ # Random choice one
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+ dataset_iterator = dataset_iterators[dataset_idx]
+
+ try:
+ yield next(dataset_iterator)
+ except StopIteration:
+ # Exhausted, create a new iterator
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+ yield next(dataset_iterators[dataset_idx])
+
+
+class SemanticDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ batch_size: int = 32,
+ tokenizer: AutoTokenizer = None,
+ max_length: int = 1024,
+ num_workers: int = 4,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+
+if __name__ == "__main__":
+ from tqdm import tqdm
+
+ ds = AutoTextSemanticInstructionDataset(
+ ["data/protos"],
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+ use_speaker=False,
+ interactive_prob=1.0,
+ skip_text_prob=0.5,
+ )
+
+ for i in ds:
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+ # i["labels"][0][i["labels"][0] == -100] = 0
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+ break
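Each training sample is a (1 + num_codebooks, T) tokens/labels pair, and TextDataCollator pads along the time axis. A toy sketch (the tokenizer checkpoint is the same assumption as in the __main__ block above, and only its eos_token_id is used for padding):

    import torch
    from transformers import AutoTokenizer

    from fish_speech.datasets.semantic import TextDataCollator

    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    collator = TextDataCollator(tokenizer, max_length=128)
    batch = collator(
        [
            {"tokens": torch.zeros((9, 50), dtype=torch.long), "labels": torch.full((9, 50), -100)},
            {"tokens": torch.zeros((9, 80), dtype=torch.long), "labels": torch.full((9, 80), -100)},
        ]
    )
    print(batch["inputs"].shape)           # torch.Size([2, 9, 80])
    print(batch["attention_masks"].shape)  # torch.Size([2, 80]); True marks padded positions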
diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e
--- /dev/null
+++ b/fish_speech/datasets/vqgan.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
+class VQGANDataset(Dataset):
+ def __init__(
+ self,
+ filelist: str,
+ sample_rate: int = 32000,
+ hop_length: int = 640,
+ slice_frames: Optional[int] = None,
+ ):
+ super().__init__()
+
+ filelist = Path(filelist)
+ root = filelist.parent
+
+ self.files = [
+ root / line.strip()
+ for line in filelist.read_text(encoding="utf-8").splitlines()
+ if line.strip()
+ ]
+ self.sample_rate = sample_rate
+ self.hop_length = hop_length
+ self.slice_frames = slice_frames
+
+ def __len__(self):
+ return len(self.files)
+
+ def get_item(self, idx):
+ file = self.files[idx]
+
+ audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+
+ # Slice audio and features
+ if (
+ self.slice_frames is not None
+ and audio.shape[0] > self.slice_frames * self.hop_length
+ ):
+ start = np.random.randint(
+ 0, audio.shape[0] - self.slice_frames * self.hop_length
+ )
+ audio = audio[start : start + self.slice_frames * self.hop_length]
+
+ if len(audio) == 0:
+ return None
+
+ max_value = np.abs(audio).max()
+ if max_value > 1.0:
+ audio = audio / max_value
+
+ return {
+ "audio": torch.from_numpy(audio),
+ }
+
+ def __getitem__(self, idx):
+ try:
+ return self.get_item(idx)
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ logger.error(f"Error loading {self.files[idx]}: {e}")
+ return None
+
+
+@dataclass
+class VQGANCollator:
+ def __call__(self, batch):
+ batch = [x for x in batch if x is not None]
+
+ audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+ audio_maxlen = audio_lengths.max()
+
+        # Pad every clip to the length of the longest clip in the batch
+ audios = []
+ for x in batch:
+ audios.append(
+ torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+ )
+
+ return {
+ "audios": torch.stack(audios),
+ "audio_lengths": audio_lengths,
+ }
+
+
+class VQGANDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: VQGANDataset,
+ val_dataset: VQGANDataset,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ val_batch_size: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.val_batch_size = val_batch_size or batch_size
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ shuffle=True,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.val_batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+
+if __name__ == "__main__":
+ dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+ dataloader = DataLoader(
+ dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+ )
+
+ for batch in dataloader:
+ print(batch["audios"].shape)
+ print(batch["features"].shape)
+ print(batch["audio_lengths"])
+ print(batch["feature_lengths"])
+ break
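VQGANCollator simply zero-pads each clip to the longest clip in the batch; a toy sketch:

    import torch
    from fish_speech.datasets.vqgan import VQGANCollator

    batch = VQGANCollator()(
        [
            {"audio": torch.randn(32000)},
            {"audio": torch.randn(48000)},
        ]
    )
    print(batch["audios"].shape)   # torch.Size([2, 48000])
    print(batch["audio_lengths"])  # tensor([32000, 48000])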
diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2
--- /dev/null
+++ b/fish_speech/i18n/README.md
@@ -0,0 +1,27 @@
+## i18n Folder Attribution
+
+The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
+
+### fish_speech/i18n/core.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
+
+**Initial commit:**
+add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
+
+**Initial author:**
+[@L4Ph](https://github.com/L4Ph)
+
+### fish_speech/i18n/scan.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
+
+**Initial commit:**
+File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
+
+**Initial author:**
+[@towzeur](https://github.com/towzeur)
+
+We appreciate the contributions of the RVC project and its authors.
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246
--- /dev/null
+++ b/fish_speech/i18n/__init__.py
@@ -0,0 +1,3 @@
+from .core import i18n
+
+__all__ = ["i18n"]
diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f82bc9c78570a30c3d65065ddadafd44f1ecdc5
Binary files /dev/null and b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-311.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7af52e95a04d4575fd7af359f9d711379d58fb1
Binary files /dev/null and b/fish_speech/i18n/__pycache__/__init__.cpython-311.pyc differ
diff --git a/fish_speech/i18n/__pycache__/core.cpython-310.pyc b/fish_speech/i18n/__pycache__/core.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97139b6392ea03e558405d073a5f9f5ca38b123b
Binary files /dev/null and b/fish_speech/i18n/__pycache__/core.cpython-310.pyc differ
diff --git a/fish_speech/i18n/__pycache__/core.cpython-311.pyc b/fish_speech/i18n/__pycache__/core.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7348d84d230b52d86f2a46b8f472399baded8f01
Binary files /dev/null and b/fish_speech/i18n/__pycache__/core.cpython-311.pyc differ
diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd
--- /dev/null
+++ b/fish_speech/i18n/core.py
@@ -0,0 +1,40 @@
+import json
+import locale
+from pathlib import Path
+
+I18N_FILE_PATH = Path(__file__).parent / "locale"
+DEFAULT_LANGUAGE = "en_US"
+
+
+def load_language_list(language):
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
+ language_list = json.load(f)
+
+ return language_list
+
+
+class I18nAuto:
+ def __init__(self):
+ i18n_file = Path(".locale")
+
+ if i18n_file.exists():
+ with open(i18n_file, "r", encoding="utf-8") as f:
+ language = f.read().strip()
+ else:
+            # locale.getlocale() may return (None, None), so use getdefaultlocale() to detect the system language
+ language = locale.getdefaultlocale()[0]
+
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
+ language = DEFAULT_LANGUAGE
+
+ self.language = language
+ self.language_map = load_language_list(language)
+
+ def __call__(self, key):
+ return self.language_map.get(key, key)
+
+ def __repr__(self):
+ return "Use Language: " + self.language
+
+
+i18n = I18nAuto()
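Usage is a single lookup call; unknown keys fall back to themselves (sketch):

    from fish_speech.i18n import i18n

    print(i18n("Start Training"))    # localized string for the detected language
    print(i18n("not a known key"))   # the key itself is returned when no translation exists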
diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json
new file mode 100644
index 0000000000000000000000000000000000000000..d36c774313628fe9d4ee60e816f404c09935e655
--- /dev/null
+++ b/fish_speech/i18n/locale/en_US.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
+ "Add to Processing Area": "Add to Processing Area",
+ "Added path successfully!": "Added path successfully!",
+ "Advanced Config": "Advanced Config",
+ "Base LLAMA Model": "Base LLAMA Model",
+ "Batch Inference": "Batch Inference",
+ "Batch Size": "Batch Size",
+ "Changing with the Model Path": "Changing with the Model Path",
+ "Chinese": "Chinese",
+ "Compile Model": "Compile Model",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
+ "Copy": "Copy",
+ "Data Preprocessing": "Data Preprocessing",
+ "Data Preprocessing Path": "Data Preprocessing Path",
+ "Data Source": "Data Source",
+ "Decoder Model Config": "Decoder Model Config",
+ "Decoder Model Path": "Decoder Model Path",
+ "Disabled": "Disabled",
+ "Enable Reference Audio": "Enable Reference Audio",
+ "English": "English",
+ "Error Message": "Error Message",
+ "File Preprocessing": "File Preprocessing",
+ "Generate": "Generate",
+ "Generated Audio": "Generated Audio",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
+ "Infer interface is closed": "Infer interface is closed",
+ "Inference Configuration": "Inference Configuration",
+ "Inference Server Configuration": "Inference Server Configuration",
+ "Inference Server Error": "Inference Server Error",
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
+ "Initial Learning Rate": "Initial Learning Rate",
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
+ "Input Text": "Input Text",
+ "Invalid path: {}": "Invalid path: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
+ "Japanese": "Japanese",
+ "LLAMA Configuration": "LLAMA Configuration",
+ "LLAMA Model Config": "LLAMA Model Config",
+ "LLAMA Model Path": "LLAMA Model Path",
+ "Labeling Device": "Labeling Device",
+ "LoRA Model to be merged": "LoRA Model to be merged",
+ "Maximum Audio Duration": "Maximum Audio Duration",
+ "Maximum Length per Sample": "Maximum Length per Sample",
+ "Maximum Training Steps": "Maximum Training Steps",
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
+ "Merge": "Merge",
+ "Merge LoRA": "Merge LoRA",
+ "Merge successfully": "Merge successfully",
+ "Minimum Audio Duration": "Minimum Audio Duration",
+ "Model Output Path": "Model Output Path",
+ "Model Size": "Model Size",
+ "Move": "Move",
+ "Move files successfully": "Move files successfully",
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
+ "No selected options": "No selected options",
+ "Number of Workers": "Number of Workers",
+ "Open Inference Server": "Open Inference Server",
+ "Open Labeler WebUI": "Open Labeler WebUI",
+ "Open Tensorboard": "Open Tensorboard",
+ "Opened labeler in browser": "Opened labeler in browser",
+ "Optional Label Language": "Optional Label Language",
+ "Optional online ver": "Optional online ver",
+ "Output Path": "Output Path",
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
+ "Precision": "Precision",
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
+ "Put your text here.": "Put your text here.",
+ "Reference Audio": "Reference Audio",
+ "Reference Text": "Reference Text",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
+ "Remove Selected Data": "Remove Selected Data",
+ "Removed path successfully!": "Removed path successfully!",
+ "Repetition Penalty": "Repetition Penalty",
+ "Save model every n steps": "Save model every n steps",
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
+ "Select VITS ckpt": "Select VITS ckpt",
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
+ "Select source file processing method": "Select source file processing method",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
+ "Selected: {}": "Selected: {}",
+ "Speaker": "Speaker",
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
+ "Start Training": "Start Training",
+ "Streaming Audio": "Streaming Audio",
+ "Streaming Generate": "Streaming Generate",
+ "Tensorboard Host": "Tensorboard Host",
+ "Tensorboard Log Path": "Tensorboard Log Path",
+ "Tensorboard Port": "Tensorboard Port",
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
+ "Training Configuration": "Training Configuration",
+ "Training Error": "Training Error",
+ "Training stopped": "Training stopped",
+ "Type name of the speaker": "Type name of the speaker",
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
+ "Use LoRA": "Use LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
+ "Use filelist": "Use filelist",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
+ "VITS Configuration": "VITS Configuration",
+ "VQGAN Configuration": "VQGAN Configuration",
+ "Validation Batch Size": "Validation Batch Size",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
+ "WebUI Host": "WebUI Host",
+ "WebUI Port": "WebUI Port",
+ "Whisper Model": "Whisper Model",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
+ "latest": "latest",
+ "new": "new",
+ "Realtime Transform Text": "Realtime Transform Text",
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
+ "Text Normalization": "Text Normalization",
+ "Select Example Audio": "Select Example Audio"
+}
diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json
new file mode 100644
index 0000000000000000000000000000000000000000..7a4757967dd0fe3807ba4d354e75ad7a88eb510e
--- /dev/null
+++ b/fish_speech/i18n/locale/es_ES.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
+ "Advanced Config": "Configuración Avanzada",
+ "Base LLAMA Model": "Modelo Base LLAMA",
+ "Batch Inference": "Inferencia por Lote",
+ "Batch Size": "Tamaño del Lote",
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
+ "Chinese": "Chino",
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Preprocesamiento de Datos",
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
+ "Data Source": "Fuente de Datos",
+ "Decoder Model Config": "Configuración del modelo decodificador",
+ "Decoder Model Path": "Ruta del modelo decodificador",
+ "Disabled": "Desactivado",
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
+ "English": "Inglés",
+ "Error Message": "Mensaje de Error",
+ "File Preprocessing": "Preprocesamiento de Archivos",
+ "Generate": "Generar",
+ "Generated Audio": "Audio Generado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
+ "Inference Configuration": "Configuración de Inferencia",
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
+ "Inference Server Error": "Error del Servidor de Inferencia",
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Ruta inválida: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
+ "Japanese": "Japonés",
+ "LLAMA Configuration": "Configuración de LLAMA",
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Etiquetado",
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
+ "Maximum Audio Duration": "Duración máxima de audio",
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
+ "Merge": "Fusionar",
+ "Merge LoRA": "Fusionar LoRA",
+ "Merge successfully": "Fusionado exitosamente",
+ "Minimum Audio Duration": "Duración mínima de audio",
+ "Model Output Path": "Ruta de Salida del Modelo",
+ "Model Size": "Tamaño del Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Archivos movidos exitosamente",
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
+ "No selected options": "No hay opciones seleccionadas",
+ "Number of Workers": "Número de Trabajadores",
+ "Open Inference Server": "Abrir Servidor de Inferencia",
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
+ "Optional online ver": "Ver en línea opcional",
+ "Output Path": "Ruta de Salida",
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
+ "Precision": "Precisión",
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
+ "Put your text here.": "Ponga su texto aquí.",
+ "Reference Audio": "Audio de Referencia",
+ "Reference Text": "Texto de Referencia",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
+ "Repetition Penalty": "Penalización por Repetición",
+ "Save model every n steps": "Guardar modelo cada n pasos",
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
+ "Selected: {}": "Seleccionado: {}",
+ "Speaker": "Hablante",
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
+ "Start Training": "Iniciar Entrenamiento",
+ "Streaming Audio": "transmisión de audio",
+ "Streaming Generate": "síntesis en flujo",
+ "Tensorboard Host": "Host de Tensorboard",
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
+ "Tensorboard Port": "Puerto de Tensorboard",
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
+ "Training Configuration": "Configuración de Entrenamiento",
+ "Training Error": "Error de Entrenamiento",
+ "Training stopped": "Entrenamiento detenido",
+ "Type name of the speaker": "Escriba el nombre del hablante",
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
+ "Use filelist": "Usar lista de archivos",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
+ "VITS Configuration": "Configuración de VITS",
+ "VQGAN Configuration": "Configuración de VQGAN",
+ "Validation Batch Size": "Tamaño del Lote de Validación",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
+ "WebUI Host": "Host de WebUI",
+ "WebUI Port": "Puerto de WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
+ "latest": "más reciente",
+ "new": "nuevo",
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
+ "Text Normalization": "Normalización de Texto",
+ "Select Example Audio": "Selecionar áudio de exemplo"
+}
diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json
new file mode 100644
index 0000000000000000000000000000000000000000..863b8b0b41da7e504ac0dcc4abf707f1f71a53fa
--- /dev/null
+++ b/fish_speech/i18n/locale/ja_JP.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
+ "Accumulate Gradient Batches": "勾配バッチの累積",
+ "Add to Processing Area": "処理エリアに追加",
+ "Added path successfully!": "パスの追加に成功しました!",
+ "Advanced Config": "詳細設定",
+ "Base LLAMA Model": "基本LLAMAモデル",
+ "Batch Inference": "バッチ推論",
+ "Batch Size": "バッチサイズ",
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
+ "Chinese": "中国語",
+ "Compile Model": "モデルのコンパイル",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
+ "Copy": "コピー",
+ "Data Preprocessing": "データ前処理",
+ "Data Preprocessing Path": "データ前処理パス",
+ "Data Source": "データソース",
+ "Decoder Model Config": "デコーダーモデルの構成",
+ "Decoder Model Path": "デコーダーモデルのパス",
+ "Disabled": "無効",
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
+ "English": "英語",
+ "Error Message": "エラーメッセージ",
+ "File Preprocessing": "文書前处理",
+ "Generate": "生成",
+ "Generated Audio": "生成されたオーディオ",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
+ "Infer interface is closed": "推論インターフェースが閉じられています",
+ "Inference Configuration": "推論設定",
+ "Inference Server Configuration": "推論サーバー設定",
+ "Inference Server Error": "推論サーバーエラー",
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
+ "Initial Learning Rate": "初期学習率",
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
+ "Input Text": "入力テキスト",
+ "Invalid path: {}": "無効なパス: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
+ "Japanese": "日本語",
+ "LLAMA Configuration": "LLAMA設定",
+ "LLAMA Model Config": "LLAMAモデル設定",
+ "LLAMA Model Path": "LLAMAモデルパス",
+ "Labeling Device": "ラベリングデバイス",
+ "LoRA Model to be merged": "マージするLoRAモデル",
+ "Maximum Audio Duration": "最大オーディオの長さ",
+ "Maximum Length per Sample": "サンプルあたりの最大長",
+ "Maximum Training Steps": "最大トレーニングステップ数",
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
+ "Merge": "マージ",
+ "Merge LoRA": "LoRAのマージ",
+ "Merge successfully": "マージに成功しました",
+ "Minimum Audio Duration": "最小オーディオの長さ",
+ "Model Output Path": "モデル出力パス",
+ "Model Size": "モデルサイズ",
+ "Move": "移動",
+ "Move files successfully": "ファイルの移動に成功しました",
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
+ "No selected options": "選択されたオプションはありません",
+ "Number of Workers": "ワーカー数",
+ "Open Inference Server": "推論サーバーを開く",
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
+ "Open Tensorboard": "Tensorboardを開く",
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
+ "Optional Label Language": "オプションのラベル言語",
+ "Optional online ver": "オプションのオンラインバージョン",
+ "Output Path": "出力パス",
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
+ "Put your text here.": "ここにテキストを入力してください。",
+ "Reference Audio": "リファレンスオーディオ",
+ "Reference Text": "リファレンステキスト",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
+ "Remove Selected Data": "選択したデータを削除",
+ "Removed path successfully!": "パスの削除に成功しました!",
+ "Repetition Penalty": "反復ペナルティ",
+ "Save model every n steps": "nステップごとにモデルを保存",
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
+ "Select VITS ckpt": "VITS チェックポイントを選択",
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
+ "Select source file processing method": "ソースファイルの処理方法を選択",
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
+ "Selected: {}": "選択済み: {}",
+ "Speaker": "話者",
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
+ "Start Training": "トレーニング開始",
+ "Streaming Audio": "ストリーミングオーディオ",
+ "Streaming Generate": "ストリーミング合成",
+ "Tensorboard Host": "Tensorboardホスト",
+ "Tensorboard Log Path": "Tensorboardログパス",
+ "Tensorboard Port": "Tensorboardポート",
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
+ "Training Configuration": "トレーニング設定",
+ "Training Error": "トレーニングエラー",
+ "Training stopped": "トレーニングが停止しました",
+ "Type name of the speaker": "話者の名前を入力",
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
+ "Use LoRA": "LoRAを使用",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
+ "Use filelist": "ファイルリストを使用",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
+ "VITS Configuration": "VITS の構成",
+ "VQGAN Configuration": "VQGAN の構成",
+ "Validation Batch Size": "検証バッチサイズ",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
+ "WebUI Host": "WebUIホスト",
+ "WebUI Port": "WebUIポート",
+ "Whisper Model": "Whisperモデル",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
+ "latest": "最新",
+ "new": "新規",
+ "Realtime Transform Text": "リアルタイム変換テキスト",
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
+ "Text Normalization": "テキスト正規化",
+ "Select Example Audio": "サンプル音声を選択"
+}
diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json
new file mode 100644
index 0000000000000000000000000000000000000000..180263874b476059870035d4c2b74ce5fa553a8a
--- /dev/null
+++ b/fish_speech/i18n/locale/ko_KR.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
+ "Add to Processing Area": "처리 영역에 추가",
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+ "Advanced Config": "고급 설정",
+ "Base LLAMA Model": "기본 LLAMA 모델",
+ "Batch Inference": "배치 추론",
+ "Batch Size": "배치 크기",
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
+ "Chinese": "중국어",
+ "Compile Model": "모델 컴파일",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+ "Copy": "복사",
+ "Data Preprocessing": "데이터 전처리",
+ "Data Preprocessing Path": "데이터 전처리 경로",
+ "Data Source": "데이터 소스",
+ "Decoder Model Config": "디코더 모델 설정",
+ "Decoder Model Path": "디코더 모델 경로",
+ "Disabled": "비활성화 됨",
+ "Enable Reference Audio": "참고 음성 활성화",
+ "English": "영어",
+ "Error Message": "오류 메시지",
+ "File Preprocessing": "파일 전처리",
+ "Generate": "생성",
+ "Generated Audio": "생성된 오디오",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+ "Inference Configuration": "추론 설정",
+ "Inference Server Configuration": "추론 서버 설정",
+ "Inference Server Error": "추론 서버 오류",
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+ "Initial Learning Rate": "초기 학습률",
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+ "Input Text": "입력 텍스트",
+ "Invalid path: {}": "유효하지 않은 경로: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+ "Japanese": "일본어",
+ "LLAMA Configuration": "LLAMA 설정",
+ "LLAMA Model Config": "LLAMA 모델 설정",
+ "LLAMA Model Path": "LLAMA 모델 경로",
+ "Labeling Device": "라벨링 장치",
+ "LoRA Model to be merged": "병합할 LoRA 모델",
+ "Maximum Audio Duration": "최대 오디오 길이",
+ "Maximum Length per Sample": "샘플당 최대 길이",
+ "Maximum Training Steps": "최대 학습 단계",
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+ "Merge": "병합",
+ "Merge LoRA": "LoRA 병합",
+ "Merge successfully": "성공적으로 병합 되었습니다.",
+ "Minimum Audio Duration": "최소 오디오 길이",
+ "Model Output Path": "모델 출력 경로",
+ "Model Size": "모델 크기",
+ "Move": "이동",
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+ "No selected options": "옵션이 선택되지 않았습니다.",
+ "Number of Workers": "작업자 수",
+ "Open Inference Server": "추론 서버 열기",
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
+ "Open Tensorboard": "Tensorboard 열기",
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+ "Optional Label Language": "선택적 라벨 언어",
+ "Optional online ver": "온라인 버전 선택",
+ "Output Path": "출력 경로",
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+ "Precision": "정밀도",
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
+ "Reference Audio": "참고 오디오",
+ "Reference Text": "참고 텍스트",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+ "Remove Selected Data": "선택한 데이터 제거",
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+ "Repetition Penalty": "반복 패널티",
+ "Save model every n steps": "n 단계마다 모델 저장",
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
+ "Select VITS ckpt": "VITS ckpt 선택",
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
+ "Select source file processing method": "소스 파일 처리 방법 선택",
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+ "Selected: {}": "선택됨: {}",
+ "Speaker": "화자",
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+ "Start Training": "학습 시작",
+ "Streaming Audio": "스트리밍 오디오",
+ "Streaming Generate": "스트리밍 생성",
+ "Tensorboard Host": "Tensorboard 호스트",
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
+ "Tensorboard Port": "Tensorboard 포트",
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+ "Training Configuration": "학습 설정",
+ "Training Error": "학습 오류",
+ "Training stopped": "학습이 중지되었습니다.",
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+ "Use LoRA": "LoRA 사용",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+ "Use filelist": "파일 목록 사용",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+ "VITS Configuration": "VITS 설정",
+ "VQGAN Configuration": "VQGAN 설정",
+ "Validation Batch Size": "검증 배치 크기",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+ "WebUI Host": "WebUI 호스트",
+ "WebUI Port": "WebUI 포트",
+ "Whisper Model": "Whisper 모델",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+ "latest": "최신",
+ "new": "새로운",
+ "Realtime Transform Text": "실시간 텍스트 변환",
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+ "Text Normalization": "텍스트 정규화",
+ "Select Example Audio": "예시 오디오 선택"
+}
diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json
new file mode 100644
index 0000000000000000000000000000000000000000..385f20272e19053ab9b6cf6463a84c8ece768c68
--- /dev/null
+++ b/fish_speech/i18n/locale/pt_BR.json
@@ -0,0 +1,133 @@
+{
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
+ "Add to Processing Area": "Adicionar à Área de Processamento",
+ "Added path successfully!": "Caminho adicionado com sucesso!",
+ "Advanced Config": "Configuração Avançada",
+ "Base LLAMA Model": "Modelo LLAMA Base",
+ "Batch Inference": "Inferência em Lote",
+ "Batch Size": "Tamanho do Lote",
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
+
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Pré-processamento de Dados",
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
+ "Data Source": "Fonte de Dados",
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
+ "Disabled": "Desativado",
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
+ "English": "Inglês",
+ "Japanese": "Japonês",
+ "Chinese": "Chinês",
+ "Portuguese": "Português",
+ "Spanish": "Espanhol",
+ "Error Message": "Mensagem de Erro",
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
+ "File Preprocessing": "Pré-processamento de Arquivos",
+ "Generate": "Gerar",
+ "Generated Audio": "Áudio Gerado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
+ "Infer interface is closed": "A interface de inferência foi fechada",
+ "Inference Configuration": "Configuração de Inferência",
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
+ "Inference Server Error": "Erro do Servidor de Inferência",
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
+ "Initial Prompt": "Prompt Inicial",
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Caminho inválido: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
+ "LLAMA Configuration": "Configuração do LLAMA",
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Rotulagem",
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
+ "Merge": "Mesclar",
+ "Merge LoRA": "Mesclar LoRA",
+ "Merge successfully": "Mesclado com sucesso",
+ "Model Output Path": "Caminho de Saída do Modelo",
+ "Model Quantization": "Quantização do Modelo",
+ "Model Size": "Tamanho do Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Arquivos movidos com sucesso",
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
+ "No selected options": "Nenhuma opção selecionada",
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
+ "Number of Workers": "Número de Processos",
+ "Open Inference Server": "Abrir Servidor de Inferência",
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
+ "Optional online ver": "Versão online (opcional)",
+ "Output Path": "Caminho de Saída",
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
+ "Post-quantification Precision": "Precisão Pós-quantização",
+ "Precision": "Precisão",
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
+ "Put your text here.": "Insira seu texto aqui.",
+ "Quantify": "Quantizar",
+ "Quantify successfully": "Quantizado com sucesso",
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
+ "Reference Audio": "Áudio de Referência",
+ "Reference Text": "Texto de Referência",
+ "warning": "Aviso",
+ "Pre-processing begins...": "O pré-processamento começou!",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Remover Dados Selecionados",
+ "Removed path successfully!": "Caminho removido com sucesso!",
+ "Repetition Penalty": "Penalidade de Repetição",
+ "Save model every n steps": "Salvar modelo a cada n etapas",
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
+ "Selected: {}": "Selecionado: {}",
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
+ "Start Training": "Iniciar Treinamento",
+ "Streaming Audio": "Áudio em Streaming",
+ "Streaming Generate": "Geração em Streaming",
+ "Tensorboard Host": "Host do Tensorboard",
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
+ "Tensorboard Port": "Porta do Tensorboard",
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
+ "Text Normalization": "Normalização de Texto",
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
+ "Training Configuration": "Configuração de Treinamento",
+ "Training Error": "Erro de Treinamento",
+ "Training stopped": "Treinamento interrompido!",
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
+ "Use filelist": "Usar lista de arquivos",
+ "VQGAN Configuration": "Configuração do VQGAN",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
+ "WebUI Host": "Host da WebUI",
+ "WebUI Port": "Porta da WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
+ "auto": "automático",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
+ "latest": "mais recente",
+ "new": "novo",
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
+ "Yes": "Sim",
+ "No": "Não",
+ "version:": "versão:",
+ "author:": "autor:"
+}
diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json
new file mode 100644
index 0000000000000000000000000000000000000000..9068ef0b9a41b9941b37644c6a4c96ec6a5d836e
--- /dev/null
+++ b/fish_speech/i18n/locale/zh_CN.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
+ "Accumulate Gradient Batches": "梯度累积批次",
+ "Add to Processing Area": "加入处理区",
+ "Added path successfully!": "添加路径成功!",
+ "Advanced Config": "高级参数",
+ "Base LLAMA Model": "基础 LLAMA 模型",
+ "Batch Inference": "批量推理",
+ "Batch Size": "批次大小",
+ "Changing with the Model Path": "随模型路径变化",
+ "Chinese": "中文",
+ "Compile Model": "编译模型",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
+ "Copy": "复制",
+ "Data Preprocessing": "数据预处理",
+ "Data Preprocessing Path": "数据预处理路径",
+ "Data Source": "数据源",
+ "Decoder Model Config": "解码器模型配置",
+ "Decoder Model Path": "解码器模型路径",
+ "Disabled": "禁用",
+ "Enable Reference Audio": "启用参考音频",
+ "English": "英文",
+ "Error Message": "错误信息",
+ "File Preprocessing": "文件预处理",
+ "Generate": "生成",
+ "Generated Audio": "音频",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
+ "Infer interface is closed": "推理界面已关闭",
+ "Inference Configuration": "推理配置",
+ "Inference Server Configuration": "推理服务器配置",
+ "Inference Server Error": "推理服务器错误",
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
+ "Initial Learning Rate": "初始学习率",
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
+ "Input Text": "输入文本",
+ "Invalid path: {}": "无效路径: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
+ "Japanese": "日文",
+ "LLAMA Configuration": "LLAMA 配置",
+ "LLAMA Model Config": "LLAMA 模型配置",
+ "LLAMA Model Path": "LLAMA 模型路径",
+ "Labeling Device": "标注加速设备",
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
+ "Maximum Audio Duration": "最大音频时长",
+ "Maximum Length per Sample": "每个样本的最大长度",
+ "Maximum Training Steps": "最大训练步数",
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
+ "Merge": "合并",
+ "Merge LoRA": "合并 LoRA",
+ "Merge successfully": "合并成功",
+ "Minimum Audio Duration": "最小音频时长",
+ "Model Output Path": "模型输出路径",
+ "Model Size": "模型规模",
+ "Move": "移动",
+ "Move files successfully": "移动文件成功",
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
+ "No selected options": "没有选择的选项",
+ "Number of Workers": "数据加载进程数",
+ "Open Inference Server": "打开推理服务器",
+ "Open Labeler WebUI": "打开标注工具",
+ "Open Tensorboard": "打开 Tensorboard",
+ "Opened labeler in browser": "在浏览器中打开标注工具",
+ "Optional Label Language": "[可选] 标注语言",
+ "Optional online ver": "[可选] 使用在线版",
+ "Output Path": "输出路径",
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
+ "Put your text here.": "在此处输入文本.",
+ "Reference Audio": "参考音频",
+ "Reference Text": "参考文本",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
+ "Remove Selected Data": "移除选中数据",
+ "Removed path successfully!": "移除路径成功!",
+ "Repetition Penalty": "重复惩罚",
+ "Save model every n steps": "每 n 步保存模型",
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
+ "Select VITS ckpt": "选择 VITS 检查点",
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
+ "Select source file processing method": "选择源文件处理方法",
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
+ "Selected: {}": "已选择: {}",
+ "Speaker": "说话人",
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
+ "Start Training": "开始训练",
+ "Streaming Audio": "流式音频",
+ "Streaming Generate": "流式合成",
+ "Tensorboard Host": "Tensorboard 监听地址",
+ "Tensorboard Log Path": "Tensorboard 日志路径",
+ "Tensorboard Port": "Tensorboard 端口",
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
+ "Training Configuration": "训练配置",
+ "Training Error": "训练错误",
+ "Training stopped": "训练已停止",
+ "Type name of the speaker": "输入说话人的名称",
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
+ "Use LoRA": "使用 LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
+ "Use filelist": "使用文件列表",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
+ "VITS Configuration": "VITS 配置",
+ "VQGAN Configuration": "VQGAN 配置",
+ "Validation Batch Size": "验证批次大小",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
+ "WebUI Host": "WebUI 监听地址",
+ "WebUI Port": "WebUI 端口",
+ "Whisper Model": "Whisper 模型",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
+ "latest": "最近的检查点",
+ "new": "创建新的检查点",
+ "Realtime Transform Text": "实时规范化文本",
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
+ "Text Normalization": "文本规范化",
+ "Select Example Audio": "选择参考音频"
+}
diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1
--- /dev/null
+++ b/fish_speech/i18n/scan.py
@@ -0,0 +1,122 @@
+import ast
+import glob
+import json
+from collections import OrderedDict
+from pathlib import Path
+
+from loguru import logger
+
+from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
+
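+# Keeps the locale JSON files under fish_speech/i18n/locale in sync with the
+# i18n("...") calls found in the code base: the default-language file is
+# rewritten from the keys collected below, and every other locale file gets
+# missing keys added and stale keys removed. Intended to be run from the
+# repository root, e.g. `python -m fish_speech.i18n.scan`.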
+
+def extract_i18n_strings(node):
+ i18n_strings = []
+
+ if (
+ isinstance(node, ast.Call)
+ and isinstance(node.func, ast.Name)
+ and node.func.id == "i18n"
+ ):
+ for arg in node.args:
+            # ast.Constant replaces the deprecated ast.Str (removed in Python 3.12)
+            if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
+                i18n_strings.append(arg.value)
+
+ for child_node in ast.iter_child_nodes(node):
+ i18n_strings.extend(extract_i18n_strings(child_node))
+
+ return i18n_strings
+
+
+# scan the directory for all .py files (recursively)
+# for each file, parse the code into an AST
+# for each AST, extract the i18n strings
+
+strings = []
+folders = ["fish_speech", "tools"]
+# for filename in glob.iglob("**/*.py", recursive=True):
+for folder in folders:
+ for f in Path(folder).rglob("*.py"):
+ code = f.read_text(encoding="utf-8")
+ if "i18n(" in code:
+ tree = ast.parse(code)
+ i18n_strings = extract_i18n_strings(tree)
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
+ strings.extend(i18n_strings)
+
+code_keys = set(strings)
+logger.info(f"Total unique: {len(code_keys)}")
+
+
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+standard_keys = set(standard_data.keys())
+
+# Compare the standard file against the keys collected from code
+unused_keys = standard_keys - code_keys
+logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
+for unused_key in unused_keys:
+ logger.info(f"\t{unused_key}")
+
+missing_keys = code_keys - standard_keys
+logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
+for missing_key in missing_keys:
+ logger.info(f"\t{missing_key}")
+
+code_keys_dict = OrderedDict()
+for s in strings:
+ code_keys_dict[s] = s
+
+# write back
+with open(standard_file, "w", encoding="utf-8") as f:
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+logger.info(f"Updated {standard_file}")
+
+
+# Define the standard file name
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+
+# Find all JSON files in the directory
+dir_path = I18N_FILE_PATH
+languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
+
+# Load the standard file
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+
+# Loop through each language file
+for lang_file in languages:
+ # Load the language file
+ with open(lang_file, "r", encoding="utf-8") as f:
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
+
+ # Find the difference between the language file and the standard file
+ diff = set(standard_data.keys()) - set(lang_data.keys())
+
+ miss = set(lang_data.keys()) - set(standard_data.keys())
+
+ # Add any missing keys to the language file
+ for key in diff:
+ lang_data[key] = "#!" + key
+ logger.info(f"Added missing key: {key} to {lang_file}")
+
+    # Delete any extra keys from the language file
+    for key in miss:
+        del lang_data[key]
+        logger.info(f"Deleted extra key: {key} from {lang_file}")
+
+ # Sort the keys of the language file to match the order of the standard file
+ lang_data = OrderedDict(
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
+ )
+
+ # Save the updated language file
+ with open(lang_file, "w", encoding="utf-8") as f:
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+ logger.info(f"Updated {lang_file}")
+
+logger.info("Done")
diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddbc750faebdda397ea9ee396540f805dbe37ad8
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8a12b6c647b866c1dcd1615ff10de5cc5bf13b4
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..893946c7e719a7a96b24a98c22c457735b66942f
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b40834cf2ef865bb0b1d78d34f7df84f73f3f61
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0
--- /dev/null
+++ b/fish_speech/models/text2semantic/lit_module.py
@@ -0,0 +1,202 @@
+from typing import Any, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+class TextToSemantic(L.LightningModule):
+ def __init__(
+ self,
+ model: NaiveTransformer,
+ optimizer: Any,
+ lr_scheduler: Any,
+ ):
+ super().__init__()
+
+ self.model = model
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+
+ def forward(self, x):
+ return self.model(x)
+
+ def on_save_checkpoint(self, checkpoint):
+ # Save only LoRA parameters
+ state_dict = checkpoint["state_dict"]
+ use_lora = any("lora" in name for name in state_dict.keys())
+ if not use_lora:
+ return
+
+ for name in list(state_dict.keys()):
+ if "lora" not in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ # Get weight decay parameters
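+        # (biases, norm weights and embedding weights are excluded from weight decay)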
+ weight_decay_parameters, other_parameters = [], []
+ for name, param in self.named_parameters():
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+ other_parameters.append(param)
+ else:
+ weight_decay_parameters.append(param)
+
+ optimizer = self.optimizer_builder(
+ [
+ {"params": weight_decay_parameters},
+ {"params": other_parameters, "weight_decay": 0.0},
+ ]
+ )
+
+ # Print the parameters and their weight decay
+ for i in optimizer.param_groups:
+ log.info(
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+ )
+
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ },
+ }
+
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
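+    # Note: not called from the training/validation steps below; kept for
+    # potential preference-optimization (DPO) style training.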
+ def get_batch_logps(
+ self,
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ assert logits.shape[:-1] == labels.shape
+
+ labels = labels.clone()
+ loss_mask = labels != -100
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == -100] = 0
+
+ per_token_logps = torch.gather(
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
+
+ def _step(self, batch, batch_idx, stage: str):
+ is_train = stage == "train"
+
+ if is_train:
+            # Key part to make LoRA work;
+            # otherwise the parameters are merged, which leads to incorrect gradients
+ self.model.train()
+
+        # Forward pass over the batch to get text-token and codebook logits
+ labels = batch["labels"]
+ outputs = self.model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+        # Text-token loss: the first row of labels holds the text/control tokens
+ base_loss = F.cross_entropy(
+ token_logits.view(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ )
+
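+        # Rows 1..num_codebooks of the labels hold the codebook targets; `.mT`
+        # swaps the codebook and time axes so they line up with codebook_logits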
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.view(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ )
+
+ loss = base_loss + semantic_loss
+
+ self.log(
+ f"{stage}/loss",
+ loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/base_loss",
+ base_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/semantic_loss",
+ semantic_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ # Top-5 accuracy
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
+ self.log(
+ f"{stage}/top_5_accuracy",
+ accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ return loss
+
+ def get_accuracy(self, logits, labels):
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+ if mask.sum() == 0:
+ return torch.tensor(0.0, device=logits.device)
+
+ _, indices = logits.topk(5, dim=-1)
+ correct = indices.eq(labels.unsqueeze(-1))
+ correct[~mask] = 0
+ correct = correct.sum()
+ accuracy = correct / mask.sum()
+
+ return accuracy
+
+ def training_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "train")
+
+ def validation_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "val")
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b5cd276c0c382a3334c45ca9bf74ea1c8a142d5
--- /dev/null
+++ b/fish_speech/models/text2semantic/llama.py
@@ -0,0 +1,752 @@
+import json
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class BaseModelArgs:
+ model_type: str = "base"
+
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+    intermediate_size: Optional[int] = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ max_seq_len: int = 2048
+ dropout: float = 0.0
+ tie_word_embeddings: bool = True
+ attention_qkv_bias: bool = False
+
+ # Codebook configs
+ codebook_size: int = 160
+ num_codebooks: int = 4
+
+ # Gradient checkpointing
+ use_gradient_checkpointing: bool = True
+
+ # Initialize the model
+ initializer_range: float = 0.02
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
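+            # LLaMA-style FFN sizing: 2/3 of 4 * dim, rounded up to a multiple of 256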
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+ @staticmethod
+ def from_pretrained(path: str):
+ path = Path(path)
+
+ if path.is_dir():
+ path = path / "config.json"
+
+ with open(path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ match data["model_type"]:
+ case "naive":
+ cls = NaiveModelArgs
+ case "dual_ar":
+ cls = DualARModelArgs
+ case _:
+ raise ValueError(f"Unknown model type: {data['model_type']}")
+
+ return cls(**data)
+
+ def save(self, path: str):
+ with open(path, "w") as f:
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+ model_type: str = "naive"
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+ model_type: str = "dual_ar"
+ n_fast_layer: int = 4
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+ token_logits: Tensor
+ codebook_logits: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+ logits: Tensor
+ hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(
+ self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size,
+ config.dim,
+ )
+ self.codebook_embeddings = nn.Embedding(
+ config.codebook_size * config.num_codebooks,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ if self.config.tie_word_embeddings is False:
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.dim // config.n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ if init_weights:
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
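+        # Reuse the existing caches if they are already large enough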
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def embed(self, x: Tensor) -> Tensor:
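+        # x: (batch, 1 + num_codebooks, seq_len); row 0 holds text tokens, the
+        # remaining rows hold codebook indices. Codebook embeddings are zeroed
+        # wherever row 0 is not the semantic placeholder token, then all rows
+        # are summed into a single (batch, seq_len, dim) embedding.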
+ vocab_embeds = [self.embeddings(x[:, 0])]
+ for i in range(self.config.num_codebooks):
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+ emb[x[:, 0] != self.semantic_token_id] = 0
+ vocab_embeds.append(emb)
+
+ x = torch.stack(vocab_embeds, dim=3)
+ x = x.sum(dim=3)
+
+ return x
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ seq_len = inp.size(2)
+
+ # Here we want to merge the embeddings of the codebooks
+ x = self.embed(inp)
+
+ freqs_cis = self.freqs_cis[:seq_len]
+
+        # Note that the causal mask here follows the definition of scaled_dot_product_attention,
+        # i.e. FALSE means masked out.
+        # For consistency, key_padding_mask uses TRUE to mark positions to mask out.
+ mask = None
+ if key_padding_mask is not None:
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+        # Apply the final norm to the slow (token-level) transformer output
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ return_all: bool = False,
+ ) -> BaseTransformerForwardResult:
+ # This is used for generation, optimized for torch compile
+ assert (
+ self.max_seq_len != -1 and self.max_batch_size != -1
+ ), "Please call setup_caches before forward_generate"
+
+ x = self.embed(x)
+
+ mask = self.causal_mask[
+ None, None, input_pos, : self.max_seq_len
+ ] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+        # During prefill, we only calculate the logits of the last token
+ if x.size(1) > 1 and not return_all:
+ x = x[:, -1:]
+
+        # Apply the final norm to the slow (token-level) transformer output
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @staticmethod
+ def from_pretrained(
+ path: str,
+ load_weights: bool = False,
+ max_length: int | None = None,
+ lora_config: LoraConfig | None = None,
+ rope_base: int | None = None,
+ ) -> "BaseTransformer":
+ config = BaseModelArgs.from_pretrained(str(path))
+ if max_length is not None:
+ config.max_seq_len = max_length
+ log.info(f"Override max_seq_len to {max_length}")
+
+ if rope_base is not None:
+ config.rope_base = rope_base
+ log.info(f"Override rope_base to {rope_base}")
+
+ match config.model_type:
+ case "naive":
+ model_cls = NaiveTransformer
+ case "dual_ar":
+ model_cls = DualARTransformer
+ case _:
+ raise ValueError(f"Unknown model type: {config.model_type}")
+
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
+ log.info(f"Loading model from {path}, config: {config}")
+ model = model_cls(config, tokenizer=tokenizer)
+
+ if lora_config is not None:
+ setup_lora(model, lora_config)
+ log.info(f"LoRA setup: {lora_config}")
+
+        if load_weights is False:
+            log.info("Randomly initialized model")
+        else:
+            if "int8" in str(Path(path)):
+                log.info("Using int8 weight-only quantization!")
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ model = simple_quantizer.convert_for_runtime()
+
+ if "int4" in str(Path(path)):
+ logger.info("Using int4 quantization!")
+ path_comps = path.name.split("-")
+ assert path_comps[-2].startswith("g")
+ groupsize = int(path_comps[-2][1:])
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ model = simple_quantizer.convert_for_runtime()
+
+ weights = torch.load(
+ Path(path) / "model.pth", map_location="cpu", mmap=True
+ )
+ err = model.load_state_dict(weights, strict=False, assign=True)
+ log.info(f"Loaded weights with error: {err}")
+
+ return model
+
+ def save_pretrained(self, path: str, drop_lora: bool = False):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ self.config.save(path / "config.json")
+ state_dict = self.state_dict()
+
+ if drop_lora:
+ for key in list(state_dict.keys()):
+ if "lora" not in key:
+ continue
+
+ state_dict.pop(key)
+ log.info(f"Drop LoRA parameter: {key}")
+
+ torch.save(state_dict, path / "model.pth")
+ self.tokenizer.save_pretrained(path)
+
+
+class NaiveTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.codebook_output = nn.Linear(
+ config.dim,
+ config.codebook_size * config.num_codebooks,
+ bias=False,
+ )
+
+ self.apply(self._init_weights)
+
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+ token_logits = result.logits
+ x = result.hidden_states
+
+ # Codebook
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
+ codebook_logits = rearrange(
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ result = super().forward(
+ inp=inp,
+ key_padding_mask=key_padding_mask,
+ )
+ return self.decode(result)
+
+ def forward_generate(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ result = super().forward_generate(x, input_pos)
+ return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ # Fast transformer
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+
+        # The effective batch size of the fast layers is so large that SDPA fails, so it is disabled here
+ self.fast_layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+ )
+ self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.fast_output = nn.Linear(
+ config.dim,
+ config.codebook_size,
+ bias=False,
+ )
+
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+ head_dim = self.config.dim // self.config.n_head
+
+ # Fast transformer
+ # The max seq len here is the number of codebooks
+ for b in self.fast_layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ self.config.num_codebooks,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ parent_result = super().forward(inp, key_padding_mask)
+ token_logits = parent_result.logits
+ x = parent_result.hidden_states
+
+ # Fast transformer
+ fast_seq_len = self.config.num_codebooks
+ fast_mask = self.causal_mask[
+ None, None, :fast_seq_len, :fast_seq_len
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
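+        # Sketch of the fast (codebook-level) pass below: hidden states from the slow
+        # transformer are concatenated with shifted codebook embeddings, the batch and
+        # time axes are flattened so every time step becomes an independent sequence of
+        # length num_codebooks, and fully padded sequences are dropped before the fast
+        # layers to save compute, then re-padded afterwards.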
+ # Drop the last token and rotate left
+ codebooks = inp[:, 1:-1, 1:]
+ codebooks = F.pad(codebooks, (0, 1), value=0)
+ codebook_embeddings = self.fast_embeddings(codebooks)
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
+ b, s = x.size(0), x.size(2)
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
+
+ # Remove padded part
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
+ codebook_mask = (codebooks == 0).all(dim=-1)
+
+ if torch.all(codebook_mask):
+            # If all codebooks are padded, keep the first 8 entries so the model still runs
+ codebook_mask[:8] = False
+
+ x_bs, x_len = x.size(0), x.size(1)
+ x = x[~codebook_mask]
+
+ for layer in self.fast_layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+ else:
+ x = layer(x, fast_freqs_cis, fast_mask)
+
+        # Normalize and project to codebook logits
+ fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ # Re-pad the codebook_logits
+ buffer = torch.zeros(
+ x_bs,
+ x_len,
+ codebook_logits.size(-1),
+ device=codebook_logits.device,
+ dtype=codebook_logits.dtype,
+ )
+ buffer[~codebook_mask] = codebook_logits
+ codebook_logits = buffer
+
+ assert codebook_logits.shape[1] == self.config.num_codebooks
+ codebook_logits = rearrange(
+ codebook_logits,
+ "(b s) n d -> b s n d",
+ b=b,
+ s=s,
+ n=self.config.num_codebooks,
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward_generate_fast(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> Tensor:
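+        # One decoding step of the fast transformer: `x` is a single hidden state viewed
+        # as a length-1 sequence, and `input_pos` indexes the current codebook slot, so the
+        # KV cache allocated with num_codebooks slots accumulates one entry per codebook.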
+ # Fast transformer
+ x = x.view(1, 1, -1)
+
+ fast_mask = self.causal_mask[
+ None, None, input_pos, : self.config.num_codebooks
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.fast_layers:
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
+
+        # Normalize and project the current codebook position to logits
+        fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ return codebook_logits
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+ super().__init__()
+ self.attention = Attention(config, use_sdpa=use_sdpa)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(
+        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
+ )
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.dropout = config.dropout
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.use_sdpa = use_sdpa
+ self._register_load_state_dict_pre_hook(self.load_hook)
+
+ def load_hook(self, state_dict, prefix, *args):
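+        # Compatibility shim: older checkpoints store separate wq/wk/wv projections;
+        # concatenate them here into the fused wqkv weight this module expects.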
+ if prefix + "wq.weight" in state_dict:
+ wq = state_dict.pop(prefix + "wq.weight")
+ wk = state_dict.pop(prefix + "wk.weight")
+ wv = state_dict.pop(prefix + "wv.weight")
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.use_sdpa:
+ if mask is None:
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=True,
+                        # No attn_mask is passed here so the flash attention kernel can be used
+ )
+ else:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+ else:
+ y = self.eq_scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ return self.wo(y)
+
+ def eq_scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ ) -> torch.Tensor:
+        # This is a reference scaled dot product attention implementation.
+        # It is less efficient than SDPA, but it does not trigger the CUDA errors seen with SDPA here.
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
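+# Rotary embedding helpers (editorial note): `precompute_freqs_cis` stores the cos/sin
+# pairs as the real and imaginary parts of e^{i * t * theta_k}, and `apply_rotary_emb`
+# rotates each (even, odd) channel pair of q/k by the angle of its position. The bfloat16
+# cast keeps the table small and matches the default cache dtype used in setup_caches.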
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938
--- /dev/null
+++ b/fish_speech/models/text2semantic/lora.py
@@ -0,0 +1,92 @@
+from dataclasses import dataclass
+
+import loralib as lora
+
+
+@dataclass
+class LoraConfig:
+ r: int
+ lora_alpha: float
+ lora_dropout: float = 0.0
+
+
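+# Illustrative usage (editorial sketch; `model` is any transformer from this package
+# that exposes `embeddings`, `layers`, and `output`):
+#
+#   config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.01)
+#   setup_lora(model, config)                   # swap embeddings/linears for loralib layers
+#   ...train, with only LoRA parameters trainable...
+#   state_dict = get_merged_state_dict(model)   # merged weights with LoRA params stripped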
+def setup_lora(model, lora_config):
+ # Replace the embedding layer with a LoRA layer
+ model.embeddings = lora.Embedding(
+ num_embeddings=model.embeddings.num_embeddings,
+ embedding_dim=model.embeddings.embedding_dim,
+ padding_idx=model.embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ model.codebook_embeddings = lora.Embedding(
+ num_embeddings=model.codebook_embeddings.num_embeddings,
+ embedding_dim=model.codebook_embeddings.embedding_dim,
+ padding_idx=model.codebook_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Replace output layer with a LoRA layer
+ linears = [(model, "output")]
+
+ # Replace all linear layers with LoRA layers
+ for layer in model.layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ if hasattr(model, "fast_layers"):
+ model.fast_embeddings = lora.Embedding(
+ num_embeddings=model.fast_embeddings.num_embeddings,
+ embedding_dim=model.fast_embeddings.embedding_dim,
+ padding_idx=model.fast_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Dual-AR model
+ linears.append((model, "fast_output"))
+
+ for layer in model.fast_layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ for module, layer in linears:
+ updated_linear = lora.Linear(
+ in_features=getattr(module, layer).in_features,
+ out_features=getattr(module, layer).out_features,
+            bias=getattr(module, layer).bias is not None,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ lora_dropout=lora_config.lora_dropout,
+ )
+ setattr(module, layer, updated_linear)
+
+ # Mark only the LoRA layers as trainable
+ lora.mark_only_lora_as_trainable(model, bias="none")
+
+
+def get_merged_state_dict(model):
+    # Switching to eval mode makes loralib merge the LoRA weights into the base weights
+ model.eval()
+
+ # Then we need to remove the LoRA parameters from the state dict
+ state_dict = model.state_dict()
+ for name in list(state_dict.keys()):
+ if "lora" in name:
+ state_dict.pop(name)
+
+ return state_dict
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
new file mode 100644
index 0000000000000000000000000000000000000000..91fc9118cc26f4d99171e7db3ee871071a7a296a
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/firefly.py
@@ -0,0 +1,596 @@
+import math
+from functools import partial
+from math import prod
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.utils.checkpoint import checkpoint
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv1D") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+
+def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left:end]
+
+
+def get_extra_padding_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad1d(
+ x: torch.Tensor,
+ paddings: tuple[int, int],
+ mode: str = "zeros",
+ value: float = 0.0,
+):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right
+ before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == "reflect":
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+class FishConvNet(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
+ ):
+ super(FishConvNet, self).__init__()
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ )
+ self.stride = stride
+ self.kernel_size = (kernel_size - 1) * dilation + 1
+ self.dilation = dilation
+
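+    # Padding note (editorial): the input is left-padded by (effective kernel size - stride)
+    # and right-padded just enough to cover the final frame, so the convolution behaves
+    # causally and the output length is roughly ceil(length / stride).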
+ def forward(self, x):
+ pad = self.kernel_size - self.stride
+ extra_padding = get_extra_padding_for_conv1d(
+ x, self.kernel_size, self.stride, pad
+ )
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
+ return self.conv(x).contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_parametrizations(self, name="weight"):
+ self.conv = remove_parametrizations(self.conv, name)
+ return self
+
+
+class FishTransConvNet(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
+ super(FishTransConvNet, self).__init__()
+ self.conv = nn.ConvTranspose1d(
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
+ )
+ self.stride = stride
+ self.kernel_size = kernel_size
+
+ def forward(self, x):
+ x = self.conv(x)
+ pad = self.kernel_size - self.stride
+ padding_right = math.ceil(pad)
+ padding_left = pad - padding_right
+ x = unpad1d(x, (padding_left, padding_right))
+ return x.contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_parametrizations(self, name="weight"):
+ self.conv = remove_parametrizations(self.conv, name)
+ return self
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
+ ).weight_norm(),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
+ ).weight_norm(),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.silu(x)
+ xt = c1(xt)
+ xt = F.silu(xt)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_parametrizations(self):
+ for conv in self.convs1:
+ conv.remove_parametrizations()
+ for conv in self.convs2:
+ conv.remove_parametrizations()
+
+
+class ParallelBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ kernel_sizes: tuple[int] = (3, 7, 11),
+ dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ ):
+ super().__init__()
+
+ assert len(kernel_sizes) == len(dilation_sizes)
+
+ self.blocks = nn.ModuleList()
+ for k, d in zip(kernel_sizes, dilation_sizes):
+ self.blocks.append(ResBlock1(channels, k, d))
+
+ def forward(self, x):
+ return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
+
+ def remove_parametrizations(self):
+ for block in self.blocks:
+ block.remove_parametrizations()
+
+
+class HiFiGANGenerator(nn.Module):
+ def __init__(
+ self,
+ *,
+ hop_length: int = 512,
+ upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+ upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+ resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+ resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ num_mels: int = 128,
+ upsample_initial_channel: int = 512,
+ pre_conv_kernel_size: int = 7,
+ post_conv_kernel_size: int = 7,
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
+ ):
+ super().__init__()
+
+ assert (
+ prod(upsample_rates) == hop_length
+ ), f"hop_length must be {prod(upsample_rates)}"
+
+ self.conv_pre = FishConvNet(
+ num_mels,
+ upsample_initial_channel,
+ pre_conv_kernel_size,
+ stride=1,
+ ).weight_norm()
+
+        self.num_upsamples = len(upsample_rates)
+        self.num_kernels = len(resblock_kernel_sizes)
+
+        # `checkpointing` is read in forward(); default to False so the module works out of the box.
+        self.checkpointing = False
+
+ self.noise_convs = nn.ModuleList()
+ self.ups = nn.ModuleList()
+
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ FishTransConvNet(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ stride=u,
+ ).weight_norm()
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ self.resblocks.append(
+ ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+ )
+
+ self.activation_post = post_activation()
+ self.conv_post = FishConvNet(
+ ch, 1, post_conv_kernel_size, stride=1
+ ).weight_norm()
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ x = F.silu(x, inplace=True)
+ x = self.ups[i](x)
+
+ if self.training and self.checkpointing:
+ x = checkpoint(
+ self.resblocks[i],
+ x,
+ use_reentrant=False,
+ )
+ else:
+ x = self.resblocks[i](x)
+
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_parametrizations(self):
+ for up in self.ups:
+ up.remove_parametrizations()
+ for block in self.resblocks:
+ block.remove_parametrizations()
+ self.conv_pre.remove_parametrizations()
+ self.conv_post.remove_parametrizations()
+
+
+# DropPath copied from timm library
+def drop_path(
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """ # noqa: E501
+
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """ # noqa: E501
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None] * x + self.bias[:, None]
+ return x
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
+class ConvNeXtBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
+ dilation (int): Dilation for depthwise conv. Default: 1.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim: int,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ mlp_ratio: float = 4.0,
+ kernel_size: int = 7,
+ dilation: int = 1,
+ ):
+ super().__init__()
+
+ self.dwconv = FishConvNet(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ # padding=int(dilation * (kernel_size - 1) / 2),
+ groups=dim,
+ ) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, int(mlp_ratio * dim)
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x, apply_residual: bool = True):
+ input = x
+
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
+ x = self.drop_path(x)
+
+ if apply_residual:
+ x = input + x
+
+ return x
+
+
+class ConvNeXtEncoder(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 3,
+ depths: list[int] = [3, 3, 9, 3],
+ dims: list[int] = [96, 192, 384, 768],
+ drop_path_rate: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ kernel_size: int = 7,
+ ):
+ super().__init__()
+ assert len(depths) == len(dims)
+
+ self.downsample_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ FishConvNet(
+ input_channels,
+ dims[0],
+ kernel_size=7,
+ # padding=3,
+ # padding_mode="replicate",
+ # padding_mode="zeros",
+ ),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+ )
+ self.downsample_layers.append(stem)
+
+ for i in range(len(depths) - 1):
+ mid_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+ )
+ self.downsample_layers.append(mid_layer)
+
+ self.stages = nn.ModuleList()
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+ cur = 0
+ for i in range(len(depths)):
+ stage = nn.Sequential(
+ *[
+ ConvNeXtBlock(
+ dim=dims[i],
+ drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value,
+ kernel_size=kernel_size,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ for i in range(len(self.downsample_layers)):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+
+ return self.norm(x)
+
+
+class FireflyArchitecture(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ head: nn.Module,
+ quantizer: nn.Module,
+ spec_transform: nn.Module,
+ ):
+ super().__init__()
+
+ self.backbone = backbone
+ self.head = head
+ self.quantizer = quantizer
+ self.spec_transform = spec_transform
+ self.downsample_factor = math.prod(self.quantizer.downsample_factor)
+
+ def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
+ if self.spec_transform is not None:
+ x = self.spec_transform(x)
+
+ x = self.backbone(x)
+ if mask is not None:
+ x = x * mask
+
+ if self.quantizer is not None:
+ vq_result = self.quantizer(x)
+ x = vq_result.z
+
+ if mask is not None:
+ x = x * mask
+
+ x = self.head(x, template=template)
+
+ if x.ndim == 2:
+ x = x[:, None, :]
+
+        if self.quantizer is not None:
+ return x, vq_result
+
+ return x
+
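+    # Shape note (editorial): encode() maps audio -> mel frames -> features downsampled by
+    # `downsample_factor` -> codebook indices, and decode() walks the same path in reverse,
+    # so audio_length is roughly feature_length * downsample_factor * hop_length.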
+ def encode(self, audios, audio_lengths):
+ audios = audios.float()
+
+ mels = self.spec_transform(audios)
+ mel_lengths = audio_lengths // self.spec_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ mels = mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.backbone(mels) * mel_masks_float_conv
+ feature_lengths = mel_lengths // self.downsample_factor
+
+ return self.quantizer.encode(encoded_features), feature_lengths
+
+ def decode(self, indices, feature_lengths) -> torch.Tensor:
+ mel_masks = sequence_mask(
+ feature_lengths * self.downsample_factor,
+ indices.shape[2] * self.downsample_factor,
+ )
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ audio_lengths = (
+ feature_lengths * self.downsample_factor * self.spec_transform.hop_length
+ )
+
+ audio_masks = sequence_mask(
+ audio_lengths,
+ indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
+ )
+ audio_masks_float_conv = audio_masks[:, None, :].float()
+
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
+ x = self.head(z) * audio_masks_float_conv
+
+ return x, audio_lengths
+
+ def remove_parametrizations(self):
+ if hasattr(self.backbone, "remove_parametrizations"):
+ self.backbone.remove_parametrizations()
+
+ if hasattr(self.head, "remove_parametrizations"):
+ self.head.remove_parametrizations()
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ea4853376b6e663404ff48d6c6b5f664dde4094
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/fsq.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
+from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
+
+
+@dataclass
+class FSQResult:
+ z: torch.Tensor
+ codes: torch.Tensor
+ latents: torch.Tensor
+
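+# Editorial note: this module downsamples features with strided FishConvNet blocks, applies
+# grouped residual FSQ (the default levels (8, 5, 5, 5) give 1000 ~= 2**10 codes per
+# quantizer), then upsamples back and pads/crops to the original length. `encode`/`decode`
+# flatten the (group, residual) axes into a single "(g r)" codebook axis.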
+
+class DownsampleFiniteScalarQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ n_groups: int = 1,
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
+ downsample_factor: tuple[int] = (2, 2),
+ downsample_dims: tuple[int] | None = None,
+ ):
+ super().__init__()
+
+ if downsample_dims is None:
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+ all_dims = (input_dim,) + tuple(downsample_dims)
+
+ self.residual_fsq = GroupedResidualFSQ(
+ dim=all_dims[-1],
+ levels=levels,
+ num_quantizers=n_codebooks,
+ groups=n_groups,
+ )
+
+ self.downsample_factor = downsample_factor
+ self.downsample_dims = downsample_dims
+
+ self.downsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ FishConvNet(
+ all_dims[idx],
+ all_dims[idx + 1],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
+ )
+ for idx, factor in enumerate(downsample_factor)
+ ]
+ )
+
+ self.upsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ FishTransConvNet(
+ all_dims[idx + 1],
+ all_dims[idx],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx]),
+ )
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
+ ]
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, z) -> FSQResult:
+ original_shape = z.shape
+ z = self.downsample(z)
+ quantized, indices = self.residual_fsq(z.mT)
+ result = FSQResult(
+ z=quantized.mT,
+ codes=indices.mT,
+ latents=z,
+ )
+ result.z = self.upsample(result.z)
+
+ # Pad or crop z to match original shape
+ diff = original_shape[-1] - result.z.shape[-1]
+ left = diff // 2
+ right = diff - left
+
+ if diff > 0:
+ result.z = F.pad(result.z, (left, right))
+ elif diff < 0:
+ result.z = result.z[..., left:-right]
+
+ return result
+
+ def encode(self, z):
+ z = self.downsample(z)
+ _, indices = self.residual_fsq(z.mT)
+ indices = rearrange(indices, "g b l r -> b (g r) l")
+ return indices
+
+ def decode(self, indices: torch.Tensor):
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+ z_q = self.residual_fsq.get_output_from_indices(indices)
+ z_q = self.upsample(z_q.mT)
+ return z_q
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac
--- /dev/null
+++ b/fish_speech/models/vqgan/utils.py
@@ -0,0 +1,94 @@
+import matplotlib
+import torch
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def plot_mel(data, titles=None):
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+ if titles is None:
+ titles = [None for i in range(len(data))]
+
+ plt.tight_layout()
+
+ for i in range(len(data)):
+ mel = data[i]
+
+ if isinstance(mel, torch.Tensor):
+ mel = mel.float().detach().cpu().numpy()
+
+ axes[i][0].imshow(mel, origin="lower")
+ axes[i][0].set_aspect(2.5, adjustable="box")
+ axes[i][0].set_ylim(0, mel.shape[0])
+ axes[i][0].set_title(titles[i], fontsize="medium")
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+ axes[i][0].set_anchor("W")
+
+ return fig
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+
+ return ret
+
+
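+# Editorial note: rand_slice_segments picks an independent random start per batch element
+# (clamped so the slice fits within x_lengths) and returns both the slices and the start
+# indices, which is the usual setup for computing losses on random sub-segments.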
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
+ n_channels_int = n_channels[0]
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+
+ return acts
+
+
+def avg_with_mask(x, mask):
+ assert mask.dtype == torch.float, "Mask should be float"
+
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(1)
+
+ if mask.shape[1] == 1:
+ mask = mask.expand_as(x)
+
+ return (x * mask).sum() / mask.sum()
diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29
--- /dev/null
+++ b/fish_speech/scheduler.py
@@ -0,0 +1,40 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ final_lr_ratio: float = 0.0,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ progress = float(current_step - num_warmup_steps) / float(
+ max(1, num_training_steps - num_warmup_steps)
+ )
+
+ return max(
+ final_lr_ratio,
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+ )
+
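+# Illustrative wiring (editorial sketch; assumes an existing torch.optim optimizer):
+#
+#   from functools import partial
+#   from torch.optim.lr_scheduler import LambdaLR
+#
+#   lr_lambda = partial(
+#       get_cosine_schedule_with_warmup_lr_lambda,
+#       num_warmup_steps=0.01,       # a float < 1 is treated as a fraction of total steps
+#       num_training_steps=100_000,
+#       final_lr_ratio=0.1,
+#   )
+#   scheduler = LambdaLR(optimizer, lr_lambda)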
+
+def get_constant_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int | None = None,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ return 1.0
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce
--- /dev/null
+++ b/fish_speech/text/__init__.py
@@ -0,0 +1,4 @@
+from .clean import clean_text
+from .spliter import split_text
+
+__all__ = ["clean_text", "split_text"]
diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/.gitignore
@@ -0,0 +1,114 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# JetBrains PyCharm
+.idea
+
+# Customize
+references
+url.txt
+
+# Git
+.git
diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/README.md
@@ -0,0 +1,36 @@
+# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
+
+# Chn Text Norm
+
+This is a repository for Chinese text normalization (no longer maintained).
+
+## Quick Start ##
+
+### Git Clone Repo ###
+
+Git clone this repo into the root directory of the project that needs to use it.
+
+ cd /path/to/proj
+ git clone https://github.com/Joee1995/chn-text-norm.git
+
+After that, your directory tree should look like this:
+```
+proj # root of your project
+|--- chn_text_norm # this chn-text-norm tool
+ |--- text.py
+ |--- ...
+|--- text_normalize.py # your text normalization code
+|--- ...
+```
+
+### How to Use? ###
+
+ # text_normalize.py
+ from chn_text_norm.text import *
+
+ raw_text = 'your raw text'
+ text = Text(raw_text=raw_text).normalize()
+
+### How to add quantums ###
+
+Open test.py and you will see how it is done.
diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_class.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""基本类
+中文字符类
+中文数字/数位类
+中文数字类
+中文数位类
+中文数字系统类
+中文数学符号类
+*中文其他符号类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
+
+
+class ChineseChar(object):
+ """
+ 中文字符
+ 每个字符对应简体和繁体,
+ e.g. 简体 = '负', 繁体 = '負'
+ 转换时可转换为简体或繁体
+ """
+
+ def __init__(self, simplified, traditional):
+ self.simplified = simplified
+ self.traditional = traditional
+ self.__repr__ = self.__str__
+
+ def __str__(self):
+ return self.simplified or self.traditional or None
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+ """
+ 中文数字/数位字符
+ 每个字符除繁简体外还有一个额外的大写字符
+ e.g. '陆' 和 '陸'
+ """
+
+ def __init__(self, power, simplified, traditional, big_s, big_t):
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
+ self.power = power
+ self.big_s = big_s
+ self.big_t = big_t
+
+ def __str__(self):
+ return "10^{}".format(self.power)
+
+ @classmethod
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+ if small_unit:
+ return ChineseNumberUnit(
+ power=index + 1,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[1],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[0]:
+ return ChineseNumberUnit(
+ power=index + 8,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[1]:
+ return ChineseNumberUnit(
+ power=(index + 2) * 4,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[2]:
+ return ChineseNumberUnit(
+ power=pow(2, index + 3),
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ else:
+ raise ValueError(
+ "Counting type should be in {0} ({1} provided).".format(
+ NUMBERING_TYPES, numbering_type
+ )
+ )
+
+
+class ChineseNumberDigit(ChineseChar):
+ """
+ 中文数字字符
+ """
+
+ def __init__(
+ self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
+ ):
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
+ self.value = value
+ self.big_s = big_s
+ self.big_t = big_t
+ self.alt_s = alt_s
+ self.alt_t = alt_t
+
+ def __str__(self):
+ return str(self.value)
+
+ @classmethod
+ def create(cls, i, v):
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+ """
+ 中文数位字符
+ """
+
+ def __init__(self, simplified, traditional, symbol, expression=None):
+ super(ChineseMath, self).__init__(simplified, traditional)
+ self.symbol = symbol
+ self.expression = expression
+ self.big_s = simplified
+ self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+ """
+ 中文数字系统
+ """
+
+ pass
+
+
+class MathSymbol(object):
+ """
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
+ positive = ['正', '正']
+ negative = ['负', '負']
+ point = ['点', '點']
+ """
+
+ def __init__(self, positive, negative, point):
+ self.positive = positive
+ self.negative = negative
+ self.point = point
+
+ def __iter__(self):
+ for v in self.__dict__.values():
+ yield v
+
+
+# class OtherSymbol(object):
+# """
+# 其他符号
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_constant.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+"""基本常量
+中文数字/数位/符号字符常量
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+CHINESE_DIGIS = "零一二三四五六七八九"
+BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
+BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
+
+ZERO_ALT = "〇"
+ONE_ALT = "幺"
+TWO_ALTS = ["两", "兩"]
+
+POSITIVE = ["正", "正"]
+NEGATIVE = ["负", "負"]
+POINT = ["点", "點"]
+# PLUS = [u'加', u'加']
+# SIL = [u'杠', u'槓']
+
+# Chinese number system types
+NUMBERING_TYPES = ["low", "mid", "high"]
diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_util.py
@@ -0,0 +1,342 @@
+# -*- coding: utf-8 -*-
+"""基本方法
+创建中文数字系统 方法
+中文字符串 <=> 数字串 方法
+数字串 <=> 中文字符串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_class import *
+from fish_speech.text.chn_text_norm.basic_constant import *
+
+
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+ """
+ 根据数字系统类型返回创建相应的数字系统,默认为 mid
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+ 返回对应的数字系统
+ """
+
+ # chinese number units of '亿' and larger
+ all_larger_units = zip(
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
+ )
+ larger_units = [
+ CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
+ ]
+ # chinese number units of '十, 百, 千, 万'
+ all_smaller_units = zip(
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
+ )
+ smaller_units = [
+ CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
+ ]
+    # digits
+ chinese_digis = zip(
+ CHINESE_DIGIS,
+ CHINESE_DIGIS,
+ BIG_CHINESE_DIGIS_SIMPLIFIED,
+ BIG_CHINESE_DIGIS_TRADITIONAL,
+ )
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+ # symbols
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+ system = NumberSystem()
+ system.units = smaller_units + larger_units
+ system.digits = digits
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+ # system.symbols = OtherSymbol(sil_cn)
+ return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
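+    """Convert a Chinese number string (e.g. '一万零四百零三点八') to a numeric string."""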
+
+ def get_symbol(char, system):
+ for u in system.units:
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+ return u
+ for d in system.digits:
+ if char in [
+ d.traditional,
+ d.simplified,
+ d.big_s,
+ d.big_t,
+ d.alt_s,
+ d.alt_t,
+ ]:
+ return d
+ for m in system.math:
+ if char in [m.traditional, m.simplified]:
+ return m
+
+ def string2symbols(chinese_string, system):
+ int_string, dec_string = chinese_string, ""
+ for p in [system.math.point.simplified, system.math.point.traditional]:
+ if p in chinese_string:
+ int_string, dec_string = chinese_string.split(p)
+ break
+ return [get_symbol(c, system) for c in int_string], [
+ get_symbol(c, system) for c in dec_string
+ ]
+
+ def correct_symbols(integer_symbols, system):
+ """
+ 一百八 to 一百八十
+ 一亿一千三百万 to 一亿 一千万 三百万
+ """
+
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
+ if integer_symbols[0].power == 1:
+ integer_symbols = [system.digits[1]] + integer_symbols
+
+ if len(integer_symbols) > 1:
+ if isinstance(integer_symbols[-1], CND) and isinstance(
+ integer_symbols[-2], CNU
+ ):
+ integer_symbols.append(
+ CNU(integer_symbols[-2].power - 1, None, None, None, None)
+ )
+
+ result = []
+ unit_count = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ result.append(s)
+ unit_count = 0
+ elif isinstance(s, CNU):
+ current_unit = CNU(s.power, None, None, None, None)
+ unit_count += 1
+
+ if unit_count == 1:
+ result.append(current_unit)
+ elif unit_count > 1:
+ for i in range(len(result)):
+ if (
+ isinstance(result[-i - 1], CNU)
+ and result[-i - 1].power < current_unit.power
+ ):
+ result[-i - 1] = CNU(
+ result[-i - 1].power + current_unit.power,
+ None,
+ None,
+ None,
+ None,
+ )
+ return result
+
+ def compute_value(integer_symbols):
+ """
+ Compute the value.
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
+ """
+ value = [0]
+ last_power = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ value[-1] = s.value
+ elif isinstance(s, CNU):
+ value[-1] *= pow(10, s.power)
+ if s.power > last_power:
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
+ last_power = s.power
+ value.append(0)
+ return sum(value)
+
+ system = create_system(numbering_type)
+ int_part, dec_part = string2symbols(chinese_string, system)
+ int_part = correct_symbols(int_part, system)
+ int_str = str(compute_value(int_part))
+ dec_str = "".join([str(d.value) for d in dec_part])
+ if dec_part:
+ return "{0}.{1}".format(int_str, dec_str)
+ else:
+ return int_str
+
+
+def num2chn(
+ number_string,
+ numbering_type=NUMBERING_TYPES[1],
+ big=False,
+ traditional=False,
+ alt_zero=False,
+ alt_one=False,
+ alt_two=True,
+ use_zeros=True,
+ use_units=True,
+):
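+    """Convert a numeric string (e.g. '10260.03') into a Chinese character string.
+
+    big selects the financial (大写) digit forms, traditional selects traditional
+    characters, alt_zero/alt_one/alt_two substitute 〇/幺/两, use_units inserts
+    十/百/千/万/... for the integer part, and use_zeros keeps a 零 where leading
+    zeros were stripped down to a single remaining digit.
+    """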
+
+ def get_value(value_string, use_zeros=True):
+
+ striped_string = value_string.lstrip("0")
+
+ # record nothing if all zeros
+ if not striped_string:
+ return []
+
+ # record one digits
+ elif len(striped_string) == 1:
+ if use_zeros and len(value_string) != len(striped_string):
+ return [system.digits[0], system.digits[int(striped_string)]]
+ else:
+ return [system.digits[int(striped_string)]]
+
+ # recursively record multiple digits
+ else:
+ result_unit = next(
+ u for u in reversed(system.units) if u.power < len(striped_string)
+ )
+ result_string = value_string[: -result_unit.power]
+ return (
+ get_value(result_string)
+ + [result_unit]
+ + get_value(striped_string[-result_unit.power :])
+ )
+
+ system = create_system(numbering_type)
+
+ int_dec = number_string.split(".")
+ if len(int_dec) == 1:
+ int_string = int_dec[0]
+ dec_string = ""
+ elif len(int_dec) == 2:
+ int_string = int_dec[0]
+ dec_string = int_dec[1]
+ else:
+ raise ValueError(
+ "invalid input num string with more than one dot: {}".format(number_string)
+ )
+
+ if use_units and len(int_string) > 1:
+ result_symbols = get_value(int_string)
+ else:
+ result_symbols = [system.digits[int(c)] for c in int_string]
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
+ if dec_string:
+ result_symbols += [system.math.point] + dec_symbols
+
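+    # Use 两 instead of 二 where it reads naturally: when the digit 2 precedes a
+    # unit other than 十 and is at the start or follows a unit other than 十.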
+ if alt_two:
+ liang = CND(
+ 2,
+ system.digits[2].alt_s,
+ system.digits[2].alt_t,
+ system.digits[2].big_s,
+ system.digits[2].big_t,
+ )
+ for i, v in enumerate(result_symbols):
+ if isinstance(v, CND) and v.value == 2:
+ next_symbol = (
+ result_symbols[i + 1] if i < len(result_symbols) - 1 else None
+ )
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
+ if isinstance(next_symbol, CNU) and isinstance(
+ previous_symbol, (CNU, type(None))
+ ):
+ if next_symbol.power != 1 and (
+ (previous_symbol is None) or (previous_symbol.power != 1)
+ ):
+ result_symbols[i] = liang
+
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
+ if big:
+ attr_name = "big_"
+ if traditional:
+ attr_name += "t"
+ else:
+ attr_name += "s"
+ else:
+ if traditional:
+ attr_name = "traditional"
+ else:
+ attr_name = "simplified"
+
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
+
+ # if not use_zeros:
+ # result = result.strip(getattr(system.digits[0], attr_name))
+
+ if alt_zero:
+ result = result.replace(
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s
+ )
+
+ if alt_one:
+ result = result.replace(
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s
+ )
+
+ for i, p in enumerate(POINT):
+ if result.startswith(p):
+ return CHINESE_DIGIS[0] + result
+
+    # Numbers 10-19 at the start of the result: drop the leading 一, e.g. 一十二 -> 十二
+ if (
+ len(result) >= 2
+ and result[1]
+ in [
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
+ ]
+ and result[0]
+ in [
+ CHINESE_DIGIS[1],
+ BIG_CHINESE_DIGIS_SIMPLIFIED[1],
+ BIG_CHINESE_DIGIS_TRADITIONAL[1],
+ ]
+ ):
+ result = result[1:]
+
+ return result
+
+
+if __name__ == "__main__":
+
+    # Test program
+ all_chinese_number_string = (
+ CHINESE_DIGIS
+ + BIG_CHINESE_DIGIS_SIMPLIFIED
+ + BIG_CHINESE_DIGIS_TRADITIONAL
+ + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
+ + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
+ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
+ + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
+ + ZERO_ALT
+ + ONE_ALT
+ + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
+ )
+
+ print("num:", chn2num("一万零四百零三点八零五"))
+ print("num:", chn2num("一亿六点三"))
+ print("num:", chn2num("一亿零六点三"))
+ print("num:", chn2num("两千零一亿六点三"))
+ # print('num:', chn2num('一零零八六'))
+ print("txt:", num2chn("10260.03", alt_zero=True))
+ print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
+ print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
+ print(
+ "txt:",
+ num2chn(
+ "059523810880",
+ alt_one=True,
+ alt_two=False,
+            use_zeros=True,
+ use_units=False,
+ ),
+ )
+
+ print(all_chinese_number_string)
diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/cardinal.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""CARDINAL类 (包含小数DECIMAL类)
+纯数 <=> 中文字符串 方法
+中文字符串 <=> 纯数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Cardinal:
+ """
+ CARDINAL类
+ """
+
+ def __init__(self, cardinal=None, chntext=None):
+ self.cardinal = cardinal
+ self.chntext = chntext
+
+ def chntext2cardinal(self):
+ return chn2num(self.chntext)
+
+ def cardinal2chntext(self):
+ return num2chn(self.cardinal)
+
+
+if __name__ == "__main__":
+
+    # Test program
+ print(Cardinal(cardinal="21357.230").cardinal2chntext())
diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py
new file mode 100644
index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/date.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+"""DATE类
+日期 <=> 中文字符串 方法
+中文字符串 <=> 日期 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-07"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.digit import Digit
+
+
+class Date:
+ """
+ DATE类
+ """
+
+ def __init__(self, date=None, chntext=None):
+ self.date = date
+ self.chntext = chntext
+
+ # def chntext2date(self):
+ # chntext = self.chntext
+ # try:
+ # year, other = chntext.strip().split('年', maxsplit=1)
+ # year = Digit(chntext=year).digit2chntext() + '年'
+ # except ValueError:
+ # other = chntext
+ # year = ''
+ # if other:
+ # try:
+ # month, day = other.strip().split('月', maxsplit=1)
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
+ # except ValueError:
+ # day = chntext
+ # month = ''
+ # if day:
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+ # else:
+ # month = ''
+ # day = ''
+ # date = year + month + day
+ # self.date = date
+ # return self.date
+
+ def date2chntext(self):
+ date = self.date
+ try:
+ year, other = date.strip().split("年", maxsplit=1)
+ year = Digit(digit=year).digit2chntext() + "年"
+ except ValueError:
+ other = date
+ year = ""
+ if other:
+ try:
+ month, day = other.strip().split("月", maxsplit=1)
+ month = Cardinal(cardinal=month).cardinal2chntext() + "月"
+ except ValueError:
+                day = other
+ month = ""
+ if day:
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+ else:
+ month = ""
+ day = ""
+ chntext = year + month + day
+ self.chntext = chntext
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+    # Test
+ print(Date(date="09年3月16日").date2chntext())
diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/digit.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""DIGIT类
+数字串 <=> 中文字符串 方法
+中文字符串 <=> 数字串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Digit:
+ """
+ DIGIT类
+ """
+
+ def __init__(self, digit=None, chntext=None):
+ self.digit = digit
+ self.chntext = chntext
+
+ # def chntext2digit(self):
+ # return chn2num(self.chntext)
+
+ def digit2chntext(self):
+ return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+if __name__ == "__main__":
+
+    # Test program
+ print(Digit(digit="2016").digit2chntext())
diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/fraction.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""FRACTION类
+分数 <=> 中文字符串 方法
+中文字符串 <=> 分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Fraction:
+ """
+ FRACTION类
+ """
+
+ def __init__(self, fraction=None, chntext=None):
+ self.fraction = fraction
+ self.chntext = chntext
+
+ def chntext2fraction(self):
+ denominator, numerator = self.chntext.split("分之")
+ return chn2num(numerator) + "/" + chn2num(denominator)
+
+ def fraction2chntext(self):
+ numerator, denominator = self.fraction.split("/")
+ return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
+if __name__ == "__main__":
+
+    # Test program
+ print(Fraction(fraction="2135/7230").fraction2chntext())
+ print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/money.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+"""MONEY类
+金钱 <=> 中文字符串 方法
+中文字符串 <=> 金钱 方法
+"""
+import re
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-08"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+
+
+class Money:
+ """
+ MONEY类
+ """
+
+ def __init__(self, money=None, chntext=None):
+ self.money = money
+ self.chntext = chntext
+
+ # def chntext2money(self):
+ # return self.money
+
+ def money2chntext(self):
+ money = self.money
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(money)
+ if matchers:
+ for matcher in matchers:
+ money = money.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
+ )
+ self.chntext = money
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+    # Test
+ print(Money(money="21.5万元").money2chntext())
+ print(Money(money="230块5毛").money2chntext())
diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py
new file mode 100644
index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/percentage.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""PERCENTAGE类
+百分数 <=> 中文字符串 方法
+中文字符串 <=> 百分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-06"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Percentage:
+ """
+ PERCENTAGE类
+ """
+
+ def __init__(self, percentage=None, chntext=None):
+ self.percentage = percentage
+ self.chntext = chntext
+
+    def chntext2percentage(self):
+        # removeprefix avoids str.strip's character-set behaviour, which would
+        # also strip a trailing 百/分/之 (e.g. in '百分之一百').
+        return chn2num(self.chntext.strip().removeprefix("百分之")) + "%"
+
+ def percentage2chntext(self):
+ return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+if __name__ == "__main__":
+
+    # Test program
+ print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
+ print(Percentage(percentage="65.3%").percentage2chntext())
diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/telephone.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+"""TELEPHONE类
+电话号码 <=> 中文字符串 方法
+中文字符串 <=> 电话号码 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class TelePhone:
+ """
+ TELEPHONE类
+ """
+
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+ self.telephone = telephone
+ self.raw_chntext = raw_chntext
+ self.chntext = chntext
+
+ # def chntext2telephone(self):
+ # sil_parts = self.raw_chntext.split('')
+ # self.telephone = '-'.join([
+ # str(chn2num(p)) for p in sil_parts
+ # ])
+ # return self.telephone
+
+ def telephone2chntext(self, fixed=False):
+
+ if fixed:
+ sil_parts = self.telephone.split("-")
+ self.raw_chntext = "".join(
+ [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+ )
+            self.chntext = self.raw_chntext
+ else:
+ sp_parts = self.telephone.strip("+").split()
+ self.raw_chntext = "".join(
+ [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+ )
+            self.chntext = self.raw_chntext
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+    # Test program
+ print(TelePhone(telephone="0595-23980880").telephone2chntext())
+ # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/text.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+TEXT class
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+import re
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.date import Date
+from fish_speech.text.chn_text_norm.digit import Digit
+from fish_speech.text.chn_text_norm.fraction import Fraction
+from fish_speech.text.chn_text_norm.money import Money
+from fish_speech.text.chn_text_norm.percentage import Percentage
+from fish_speech.text.chn_text_norm.telephone import TelePhone
+
+CURRENCY_NAMES = (
+ "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+ "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+)
+CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+COM_QUANTIFIERS = (
+ "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+ "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+ "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+ "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+ "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+ "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
+)
+
+
+class Text:
+ """
+    Text class
+ """
+
+ def __init__(self, raw_text, norm_text=None):
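+        # Wrap the text in sentinel characters so the \D-anchored patterns in
+        # normalize() also match numbers at the very start/end of the string;
+        # they are stripped again when normalize() returns.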
+ self.raw_text = "^" + raw_text + "$"
+ self.norm_text = norm_text
+
+ def _particular(self):
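+        # Restore the digit 2 inside alphanumeric tokens such as 'O2O' or 'B2C',
+        # where 二 between ASCII letters should remain an Arabic digit.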
+ text = self.norm_text
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('particular')
+ for matcher in matchers:
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+ self.norm_text = text
+ return self.norm_text
+
+ def normalize(self):
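+        """Apply all normalization passes (dates, money, phone numbers, fractions,
+        percentages, quantified numbers, digit IDs, plain numbers) and return the
+        normalized text."""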
+ text = self.raw_text
+
+        # Normalize dates
+ pattern = re.compile(
+ r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('date')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # Normalize money amounts
+ pattern = re.compile(
+ r"\D+((\d+(\.\d+)?)[多余几]?"
+ + CURRENCY_UNITS
+ + "(\d"
+ + CURRENCY_UNITS
+ + "?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('money')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Money(money=matcher[0]).money2chntext(), 1
+ )
+
+        # Normalize landline / mobile phone numbers
+        # Mobile prefixes (reference: http://www.jihaoba.com/news/show/13680)
+        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
+        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
+        # China Telecom: 133, 153, 189, 180, 181, 177
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+ )
+        # Landline
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fixed telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
+ 1,
+ )
+
+        # Normalize fractions
+ pattern = re.compile(r"(\d+/\d+)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fraction')
+ for matcher in matchers:
+ text = text.replace(
+ matcher, Fraction(fraction=matcher).fraction2chntext(), 1
+ )
+
+        # Normalize percentages
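+        # The first argument below is the full-width percent sign (U+FF05); it is
+        # converted to ASCII '%' so the pattern can match it.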
+ text = text.replace("%", "%")
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('percentage')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ Percentage(percentage=matcher[0]).percentage2chntext(),
+ 1,
+ )
+
+        # Normalize plain numbers followed by quantifiers
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal+quantifier')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+        # Normalize digit-sequence identifiers (read digit by digit)
+ pattern = re.compile(r"(\d{4,32})")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('digit')
+ for matcher in matchers:
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # Normalize plain numbers
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ self.norm_text = text
+ self._particular()
+
+ return self.norm_text.lstrip("^").rstrip("$")
+
+
+if __name__ == "__main__":
+
+    # Test program
+ print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
+ print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
+ print(Text(raw_text="分数:32477/76391。").normalize())
+ print(Text(raw_text="百分数:80.03%。").normalize())
+ print(Text(raw_text="编号:31520181154418。").normalize())
+ print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
+ print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
+ print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
+ print(Text(raw_text="特殊:O2O或B2C。").normalize())
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbaf843d781f113735043319cc00dc2aed5ae382
--- /dev/null
+++ b/fish_speech/text/clean.py
@@ -0,0 +1,62 @@
+import re
+
+SYMBOLS_MAPPING = {
+ "\n": "",
+ "…": ".",
+ "“": "'",
+ "”": "'",
+ "‘": "'",
+ "’": "'",
+ "【": "",
+ "】": "",
+ "[": "",
+ "]": "",
+ "(": "",
+ ")": "",
+ "(": "",
+ ")": "",
+ "・": "",
+ "·": "",
+ "「": "'",
+ "」": "'",
+ "《": "'",
+ "》": "'",
+ "—": "",
+ "~": "",
+ "~": "",
+ ":": ",",
+ ";": ",",
+ ";": ",",
+ ":": ",",
+}
+
+REPLACE_SYMBOL_REGEX = re.compile(
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+
+
+EMOJI_REGEX = re.compile(
+ "["
+ "\U0001F600-\U0001F64F" # emoticons
+ "\U0001F300-\U0001F5FF" # symbols & pictographs
+ "\U0001F680-\U0001F6FF" # transport & map symbols
+ "\U0001F1E0-\U0001F1FF" # flags (iOS)
+ "]+",
+ flags=re.UNICODE,
+)
+
+
+def clean_text(text):
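+    # Illustrative example (assuming the mappings above):
+    #   clean_text("测试:这是一句话……")  ->  "测试,这是一句话."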
+ # Clean the text
+ text = text.strip()
+
+    # Replace full-width / CJK punctuation with ASCII counterparts
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+
+ # Remove emojis
+ text = EMOJI_REGEX.sub(r"", text)
+
+    # Collapse runs of periods (...) or commas (,,,) into a single character
+ text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
+
+ return text
diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4bb995487c4f53818c6b2a16cf0a886b4e02e84
--- /dev/null
+++ b/fish_speech/text/spliter.py
@@ -0,0 +1,130 @@
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
+def utf_8_len(text):
+ return len(text.encode("utf-8"))
+
+
+def break_text(texts, length, splits: set):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if char in splits:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def break_text_by_length(texts, length):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if utf_8_len(curr) >= length:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def add_cleaned(curr, segments):
+ curr = curr.strip()
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+ segments.append(curr)
+
+
+def protect_float(text):
+ # Turns 3.14 into <3_f_14> to prevent splitting
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+def unprotect_float(text):
+ # Turns <3_f_14> into 3.14
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+def split_text(text, length):
+ text = clean_text(text)
+
+ # Break the text into pieces with following rules:
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
+ # 2. If the text is longer than length, split at ","
+ # 3. If the text is still longer than length, split at " "
+ # 4. If the text is still longer than length, split at any character to length
+
+ texts = [text]
+ texts = map(protect_float, texts)
+ texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
+ texts = map(unprotect_float, texts)
+ texts = break_text(texts, length, {",", ","})
+ texts = break_text(texts, length, {" "})
+ texts = list(break_text_by_length(texts, length))
+
+ # Then, merge the texts into segments with length <= length
+ segments = []
+ curr = ""
+
+ for text in texts:
+ if utf_8_len(curr) + utf_8_len(text) <= length:
+ curr += text
+ else:
+ add_cleaned(curr, segments)
+ curr = text
+
+ if curr:
+ add_cleaned(curr, segments)
+
+ return segments
+
+
+if __name__ == "__main__":
+ # Test the split_text function
+
+ text = "This is a test sentence. This is another test sentence. And a third one."
+
+ assert split_text(text, 50) == [
+ "This is a test sentence.",
+ "This is another test sentence. And a third one.",
+ ]
+ assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+ assert split_text(" ", 10) == []
+ assert split_text("a", 10) == ["a"]
+
+ text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+ assert split_text(text, 50) == [
+ "This is a test sentence with only commas,",
+ "and no dots, and no exclamation marks,",
+ "and no question marks, and no newlines.",
+ ]
+
+ text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+ # First half split at " ", second half split at ","
+ assert split_text(text, 50) == [
+ "This is a test sentence This is a test sentence",
+ "This is a test sentence. This is a test sentence,",
+ "This is a test sentence, This is a test sentence.",
+ ]
+
+ text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+ assert split_text(text, 50) == [
+ "这是一段很长的中文文本,",
+ "而且没有句号,也没有感叹号,",
+ "也没有问号,也没有换行符.",
+ ]
diff --git a/fish_speech/train.py b/fish_speech/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e693f3adc4dda787bdd587aec29f53355f2b1653
--- /dev/null
+++ b/fish_speech/train.py
@@ -0,0 +1,141 @@
+import os
+
+os.environ["USE_LIBUV"] = "0"
+import sys
+from typing import Optional
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
+from omegaconf import DictConfig, OmegaConf
+
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """ # noqa: E501
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=False)
+
+ if cfg.get("deterministic"):
+ torch.use_deterministic_algorithms(True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.trainer,
+ callbacks=callbacks,
+ logger=logger,
+ )
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(object_dict)
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+
+ ckpt_path = cfg.get("ckpt_path")
+ auto_resume = False
+
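+        # Prefer the newest checkpoint in ckpt_dir over cfg.ckpt_path so that
+        # interrupted runs resume automatically.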
+ resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+ if resume_ckpt_path is not None:
+ ckpt_path = resume_ckpt_path
+ auto_resume = True
+
+ if ckpt_path is not None:
+ log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+ # resume weights only is disabled for auto-resume
+ if cfg.get("resume_weights_only") and auto_resume is False:
+ log.info("Resuming weights only!")
+ ckpt = torch.load(ckpt_path, map_location=model.device)
+ if "state_dict" in ckpt:
+ ckpt = ckpt["state_dict"]
+ err = model.load_state_dict(ckpt, strict=False)
+ log.info(f"Error loading state dict: {err}")
+ ckpt_path = None
+
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = cfg.get("ckpt_path")
+
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
+@hydra.main(
+ version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+ # train the model
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cf2f23174ddac9bf523730aca2f6a9965d134a
--- /dev/null
+++ b/fish_speech/utils/__init__.py
@@ -0,0 +1,24 @@
+from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, set_seed, task_wrapper
+
+__all__ = [
+ "enforce_tags",
+ "extras",
+ "get_metric_value",
+ "RankedLogger",
+ "instantiate_callbacks",
+ "instantiate_loggers",
+ "log_hyperparameters",
+ "print_config_tree",
+ "task_wrapper",
+ "braceexpand",
+ "get_latest_checkpoint",
+ "autocast_exclude_mps",
+ "set_seed",
+]
diff --git a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f898aea996460deb67220a89b96b9f9a3ac2dcac
Binary files /dev/null and b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/__init__.cpython-311.pyc b/fish_speech/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c47fcd33d7e78d173e979359d42f144aa89ca988
Binary files /dev/null and b/fish_speech/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06520901fe11707b16762007668ac1adc93e1695
Binary files /dev/null and b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-311.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7672cb7128109ee6369f9f188185994c89217c9f
Binary files /dev/null and b/fish_speech/utils/__pycache__/braceexpand.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/context.cpython-310.pyc b/fish_speech/utils/__pycache__/context.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d2786f95f27a2000e27e18488cb831e495d1ada
Binary files /dev/null and b/fish_speech/utils/__pycache__/context.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/context.cpython-311.pyc b/fish_speech/utils/__pycache__/context.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d746b5ee2210833a402bc90148f56bd016f473f
Binary files /dev/null and b/fish_speech/utils/__pycache__/context.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/file.cpython-310.pyc b/fish_speech/utils/__pycache__/file.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..619c698c0dfe596250bdf0c78a7bed2b6525ba0e
Binary files /dev/null and b/fish_speech/utils/__pycache__/file.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/file.cpython-311.pyc b/fish_speech/utils/__pycache__/file.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83d701dcaf2a70f951279fd7e90cd00dfc51324e
Binary files /dev/null and b/fish_speech/utils/__pycache__/file.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd106557d47bc3e454d640c444155422d1b56bee
Binary files /dev/null and b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-311.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b24a6485661e4be93452051b418c1d174bcf977c
Binary files /dev/null and b/fish_speech/utils/__pycache__/instantiators.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/logger.cpython-310.pyc b/fish_speech/utils/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62ac6dd3335b1d7d3c3cf724a7b61185829be505
Binary files /dev/null and b/fish_speech/utils/__pycache__/logger.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/logger.cpython-311.pyc b/fish_speech/utils/__pycache__/logger.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39bb762688c649c20460f08a39a3371347de56c6
Binary files /dev/null and b/fish_speech/utils/__pycache__/logger.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e85c1b95ed30fa702a301660c5f1e7cb2b15c1e
Binary files /dev/null and b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-311.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85167fd9df12676fe436fdfb1ee4f7d8bd9b335d
Binary files /dev/null and b/fish_speech/utils/__pycache__/logging_utils.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fe9fe2cd1aa91cbe7126d363af68c556ae1167c
Binary files /dev/null and b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-311.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..628da7662d28091856d1fcb0389c3db09ad8730c
Binary files /dev/null and b/fish_speech/utils/__pycache__/rich_utils.cpython-311.pyc differ
diff --git a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17ca82b458cd0e7fc97d2ad5c8be8b130cb6a4bd
Binary files /dev/null and b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/utils.cpython-310.pyc b/fish_speech/utils/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2809b72d4e8f82c71aaad3fe06835882c11c0f7e
Binary files /dev/null and b/fish_speech/utils/__pycache__/utils.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/utils.cpython-311.pyc b/fish_speech/utils/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2eded4dfe6cbf55a90158c4a5188fbc38dee6491
Binary files /dev/null and b/fish_speech/utils/__pycache__/utils.cpython-311.pyc differ
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974
--- /dev/null
+++ b/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+ pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+ """braceexpand(pattern) -> iterator over generated strings
+
+ Returns an iterator over the strings resulting from brace expansion
+ of pattern. This function implements Brace Expansion as described in
+ bash(1), with the following limitations:
+
+ * A pattern containing unbalanced braces will raise an
+ UnbalancedBracesError exception. In bash, unbalanced braces will either
+ be partly expanded or ignored.
+
+ * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+ include the characters '[]^_`' between 'Z' and 'a'.
+
+ When escape is True (the default), characters in pattern can be
+ prefixed with a backslash to cause them not to be interpreted as
+ special characters for brace expansion (such as '{', '}', ',').
+    To pass through a literal backslash, double it ('\\\\').
+
+ When escape is False, backslashes in pattern have no special
+ meaning and will be preserved in the output.
+
+ Examples:
+
+ >>> from braceexpand import braceexpand
+
+ # Integer range
+ >>> list(braceexpand('item{1..3}'))
+ ['item1', 'item2', 'item3']
+
+ # Character range
+ >>> list(braceexpand('{a..c}'))
+ ['a', 'b', 'c']
+
+ # Sequence
+ >>> list(braceexpand('index.html{,.backup}'))
+ ['index.html', 'index.html.backup']
+
+ # Nested patterns
+ >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+ ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+ # Prefixing an integer with zero causes all numbers to be padded to
+ # the same width.
+ >>> list(braceexpand('{07..10}'))
+ ['07', '08', '09', '10']
+
+ # An optional increment can be specified for ranges.
+ >>> list(braceexpand('{a..g..2}'))
+ ['a', 'c', 'e', 'g']
+
+ # Ranges can go in both directions.
+ >>> list(braceexpand('{4..1}'))
+ ['4', '3', '2', '1']
+
+ # Numbers can be negative
+ >>> list(braceexpand('{2..-1}'))
+ ['2', '1', '0', '-1']
+
+ # Unbalanced braces raise an exception.
+ >>> list(braceexpand('{1{2,3}'))
+ Traceback (most recent call last):
+ ...
+ UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+ # By default, the backslash is the escape character.
+ >>> list(braceexpand(r'{1\\{2,3}'))
+ ['1{2', '3']
+
+ # Setting 'escape' to False disables backslash escaping.
+ >>> list(braceexpand(r'\\{1,2}', escape=False))
+ ['\\\\1', '\\\\2']
+
+ """
+ return (
+ escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+ )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
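+    # Split the pattern into literal chunks and top-level brace expressions;
+    # the expansions are the cartesian product of all chunks.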
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'pattern:', pattern
+ while pos < len(pattern):
+ if escape and pattern[pos] == "\\":
+ pos += 2
+ continue
+ elif pattern[pos] == "{":
+ if bracketdepth == 0 and pos > start:
+ # print 'literal:', pattern[start:pos]
+ items.append([pattern[start:pos]])
+ start = pos
+ bracketdepth += 1
+ elif pattern[pos] == "}":
+ bracketdepth -= 1
+ if bracketdepth == 0:
+ # print 'expression:', pattern[start+1:pos]
+ expr = pattern[start + 1 : pos]
+ item = parse_expression(expr, escape)
+ if item is None: # not a range or sequence
+ items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+ else:
+ items.append(item)
+ start = pos + 1 # skip the closing brace
+ pos += 1
+
+ if bracketdepth != 0: # unbalanced braces
+ raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+ if start < pos:
+ items.append([pattern[start:]])
+
+ return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+ int_range_match = int_range_re.match(expr)
+ if int_range_match:
+ return make_int_range(*int_range_match.groups())
+
+ char_range_match = char_range_re.match(expr)
+ if char_range_match:
+ return make_char_range(*char_range_match.groups())
+
+ return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+ # sequence -> chain(*sequence_items)
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'sequence:', seq
+ while pos < len(seq):
+ if escape and seq[pos] == "\\":
+ pos += 2
+ continue
+ elif seq[pos] == "{":
+ bracketdepth += 1
+ elif seq[pos] == "}":
+ bracketdepth -= 1
+ elif seq[pos] == "," and bracketdepth == 0:
+ items.append(parse_pattern(seq[start:pos], escape))
+ start = pos + 1 # skip the comma
+ pos += 1
+
+ if bracketdepth != 0:
+ raise UnbalancedBracesError
+ if not items:
+ return None
+
+ # part after the last comma (may be the empty string)
+ items.append(parse_pattern(seq[start:], escape))
+ return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
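+    # Zero-pad the output when either endpoint is written with leading zeros,
+    # e.g. '{07..10}' -> '07', '08', '09', '10'.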
+ if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+ padding = max(len(left), len(right))
+ else:
+ padding = 0
+ step = (int(incr) or 1) if incr else 1
+ start = int(left)
+ end = int(right)
+ r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+ fmt = "%0{}d".format(padding)
+ return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+ step = (int(incr) or 1) if incr else 1
+ start = alphabet.index(left)
+ end = alphabet.index(right)
+ if start < end:
+ return alphabet[start : end + 1 : step]
+ else:
+ end = end or -len(alphabet)
+ return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+ import doctest
+ import sys
+
+ failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+ if failed:
+ sys.exit(1)
diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..f04a99290ab32f7fe5b60656075a2d03af8468d6
--- /dev/null
+++ b/fish_speech/utils/context.py
@@ -0,0 +1,13 @@
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+ device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
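+    # Bypass autocast (return a nullcontext) whenever MPS is available; otherwise
+    # defer to torch.autocast for the given device type and dtype.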
+ return (
+ nullcontext()
+ if torch.backends.mps.is_available()
+ else torch.autocast(device_type, dtype)
+ )
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c82640a963fa556657107729f7543d2e7c3510
--- /dev/null
+++ b/fish_speech/utils/file.py
@@ -0,0 +1,16 @@
+import os
+from pathlib import Path
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+ # Find the latest checkpoint
+ ckpt_dir = Path(path)
+
+    if not ckpt_dir.exists():
+ return None
+
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+ if len(ckpts) == 0:
+ return None
+
+ return ckpts[-1]
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e
--- /dev/null
+++ b/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from lightning.pytorch import Callback
+from lightning.pytorch.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc
--- /dev/null
+++ b/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = True,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``True``.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(
+ self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+ ) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError(
+ "The `rank_zero_only.rank` needs to be set before use"
+ )
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429
--- /dev/null
+++ b/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils.logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f
--- /dev/null
+++ b/fish_speech/utils/rich_utils.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils.logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """ # noqa: E501
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+        if field in cfg:
+            queue.append(field)
+        else:
+            log.warning(
+                f"Field '{field}' not found in config. "
+                + f"Skipping '{field}' config printing..."
+            )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t.strip() != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
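+
+
+# Usage sketch: both helpers take the composed Hydra config, e.g.
+#
+#   enforce_tags(cfg, save_to_file=True)
+#   print_config_tree(cfg, resolve=True, save_to_file=True)
+#
+# which mirrors how extras() in fish_speech/utils/utils.py invokes them.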
diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841
--- /dev/null
+++ b/fish_speech/utils/spectrogram.py
@@ -0,0 +1,122 @@
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ center=False,
+ mode="pow2_sqrt",
+ ):
+ super().__init__()
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.mode = mode
+
+ self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+ def forward(self, y: Tensor) -> Tensor:
+ if y.ndim == 3:
+ y = y.squeeze(1)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ (self.win_length - self.hop_length) // 2,
+ (self.win_length - self.hop_length + 1) // 2,
+ ),
+ mode="reflect",
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+
+ spec = torch.view_as_real(spec)
+
+ if self.mode == "pow2_sqrt":
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ return spec
+
+
+class LogMelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ sample_rate=44100,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ n_mels=128,
+ center=False,
+ f_min=0.0,
+ f_max=None,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max or float(sample_rate // 2)
+
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+ fb = F.melscale_fbanks(
+ n_freqs=self.n_fft // 2 + 1,
+ f_min=self.f_min,
+ f_max=self.f_max,
+ n_mels=self.n_mels,
+ sample_rate=self.sample_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ )
+ self.register_buffer(
+ "fb",
+ fb,
+ persistent=False,
+ )
+
+ def compress(self, x: Tensor) -> Tensor:
+ return torch.log(torch.clamp(x, min=1e-5))
+
+ def decompress(self, x: Tensor) -> Tensor:
+ return torch.exp(x)
+
+ def apply_mel_scale(self, x: Tensor) -> Tensor:
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ def forward(
+ self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+ ) -> Tensor:
+ if sample_rate is not None and sample_rate != self.sample_rate:
+ x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+ linear = self.spectrogram(x)
+ x = self.apply_mel_scale(linear)
+ x = self.compress(x)
+
+ if return_linear:
+ return x, self.compress(linear)
+
+ return x
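+
+
+if __name__ == "__main__":
+    # Minimal shape sanity check (an illustrative sketch, not part of the
+    # training pipeline). With the defaults above (n_fft=2048, hop_length=512,
+    # n_mels=128), a (batch, samples) mono waveform maps to a
+    # (batch, n_mels, frames) log-mel spectrogram.
+    mel_transform = LogMelSpectrogram(sample_rate=44100, n_mels=128)
+    waveform = torch.randn(1, 44100)  # one second of random mono audio
+    mel = mel_transform(waveform)
+    print(mel.shape)  # (1, 128, num_frames)
+    mel_again, linear = mel_transform(waveform, return_linear=True)
+    print(linear.shape)  # (1, n_fft // 2 + 1, num_frames) = (1, 1025, num_frames)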
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a34bdcfedff76c333f50ed8be050d0dd5a8f98a
--- /dev/null
+++ b/fish_speech/utils/utils.py
@@ -0,0 +1,136 @@
+import random
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+ ...
+
+ return metric_dict, object_dict
+ ```
+ """ # noqa: E501
+
+ def wrap(cfg: DictConfig):
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or
+ # cause out-of-memory errors so when using hparam search
+ # plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.run_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
+
+
+def set_seed(seed: int):
+ if seed < 0:
+ seed = -seed
+ if seed > (1 << 31):
+ seed = 1 << 31
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ if torch.backends.cudnn.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
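+
+
+# Usage sketch (illustrative): a Hydra-driven entrypoint would typically call
+#
+#   set_seed(cfg.get("seed", 42))
+#   extras(cfg)  # warnings / tag enforcement / rich config tree, as configured
+#
+# before instantiating the datamodule, model and trainer.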
diff --git a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a6ed5a24f1b659ff6aeb936748358bc413dc26e
Binary files /dev/null and b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc differ
diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7
--- /dev/null
+++ b/fish_speech/webui/css/style.css
@@ -0,0 +1,161 @@
+:root {
+ --my-200: #80eeee;
+ --my-50: #ecfdf5;
+ --water-width: 300px;
+    --water-height: 300px;
+}
+
+
+/* general styled components */
+.tools {
+ align-items: center;
+ justify-content: center;
+}
+
+.gradio-button {
+ max-width: 2.2em;
+ min-width: 2.2em !important;
+ height: 2.4em;
+ align-self: end;
+ line-height: 1em;
+ border-radius: 0.5em;
+
+}
+
+.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
+ box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
+}
+
+/* replace original footer with ours */
+a{
+ font-weight: bold;
+ cursor: pointer;
+ color: #030C14 !important;
+}
+
+footer {
+ display: none !important;
+}
+
+#footer{
+ text-align: center;
+}
+
+#footer div{
+ display: inline-block;
+}
+
+#footer .versions{
+ font-size: 85%;
+ opacity: 0.85;
+}
+
+/*@keyframes moveBackground {*/
+/* 0% {*/
+/* background-position: 0 0;*/
+/* }*/
+/* 100% {*/
+/* background-position: -100px 100px;*/
+/* }*/
+/*}*/
+@keyframes moveJellyBackground {
+ 0% {
+ background-position: 0% 50%;
+ }
+ 50% {
+ background-position: 100% 50%;
+ }
+ 100% {
+ background-position: 0% 50%;
+ }
+}
+
+.gradio-container {
+ position: absolute;
+ z-index: 10;
+}
+
+
+.quan {
+ position: absolute;
+ bottom: 0;
+ width: var(--water-width);
+    height: var(--water-height);
+ border-radius: 0;
+ /*border: 3px solid rgb(246, 247, 248);*/
+ /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
+ z-index: 0;
+
+}
+
+.quan:last-child {
+ margin-right: 0;
+}
+
+.shui {
+ position: absolute;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgb(23, 106, 201);
+ border-radius: 0;
+ overflow: hidden;
+ z-index: 0;
+}
+
+.shui::after {
+
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 40%;
+ background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
+ animation: shi 5s linear infinite;
+}
+
+@keyframes shi {
+ 0% {
+ transform: translate(-50%, -65%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -65%) rotate(360deg);
+ }
+}
+
+.shui::before {
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 42%;
+    background-color: rgba(240, 228, 228, 0.2);
+ animation: xu 7s linear infinite;
+}
+
+@keyframes xu {
+ 0% {
+ transform: translate(-50%, -60%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -60%) rotate(360deg);
+ }
+}
+
+fieldset.data_src div.wrap label {
+ background: #f8bffee0 !important;
+}
+
+.scrollable-component {
+ max-height: 100px;
+ overflow-y: auto;
+}
+
+#file_accordion {
+ max-height: 220px !important;
+}
diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html
new file mode 100644
index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615
--- /dev/null
+++ b/fish_speech/webui/html/footer.html
@@ -0,0 +1,11 @@
+<div>
+    <a href="{api_docs}">API Docs</a>
+</div>
+
+<div class="versions">
+    {versions}
+</div>
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js
new file mode 100644
index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868
--- /dev/null
+++ b/fish_speech/webui/js/animate.js
@@ -0,0 +1,69 @@
+
+function createGradioAnimation() {
+ const params = new URLSearchParams(window.location.search);
+ if (!params.has('__theme')) {
+ params.set('__theme', 'light');
+ window.location.search = params.toString();
+ }
+
+ var gradioApp = document.querySelector('gradio-app');
+ if (gradioApp) {
+
+ document.documentElement.style.setProperty('--my-200', '#80eeee');
+ document.documentElement.style.setProperty('--my-50', '#ecfdf5');
+
+ // gradioApp.style.position = 'relative';
+ // gradioApp.style.backgroundSize = '200% 200%';
+ // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
+ // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
+ // gradioApp.style.display = 'flex';
+ // gradioApp.style.justifyContent = 'flex-start';
+ // gradioApp.style.flexWrap = 'nowrap';
+ // gradioApp.style.overflowX = 'auto';
+
+ // for (let i = 0; i < 6; i++) {
+ // var quan = document.createElement('div');
+ // quan.className = 'quan';
+ // gradioApp.insertBefore(quan, gradioApp.firstChild);
+ // quan.id = 'quan' + i.toString();
+ // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
+ // var quanContainer = document.querySelector('.quan');
+ // if (quanContainer) {
+ // var shui = document.createElement('div');
+ // shui.className = 'shui';
+ // quanContainer.insertBefore(shui, quanContainer.firstChild)
+ // }
+ // }
+ }
+
+ var container = document.createElement('div');
+ container.id = 'gradio-animation';
+ container.style.fontSize = '2em';
+ container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
+ container.style.fontWeight = 'bold';
+ container.style.textAlign = 'center';
+ container.style.marginBottom = '20px';
+
+ var text = 'Welcome to Fish-Speech!';
+ for (var i = 0; i < text.length; i++) {
+ (function(i){
+ setTimeout(function(){
+ var letter = document.createElement('span');
+ letter.style.opacity = '0';
+ letter.style.transition = 'opacity 0.5s';
+ letter.innerText = text[i];
+
+ container.appendChild(letter);
+
+ setTimeout(function() {
+ letter.style.opacity = '1';
+ }, 50);
+ }, i * 200);
+ })(i);
+ }
+
+ var gradioContainer = document.querySelector('.gradio-container');
+ gradioContainer.insertBefore(container, gradioContainer.firstChild);
+
+ return 'Animation created';
+}
diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..790c0e632ce55e099e5578d8824e94b1d1260d6e
--- /dev/null
+++ b/fish_speech/webui/launch_utils.py
@@ -0,0 +1,120 @@
+import importlib.util
+import os
+import subprocess
+import sys
+from functools import lru_cache
+from pathlib import Path
+from typing import Iterable
+
+import gradio as gr
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+
+GIT = (
+ (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
+ if sys.platform == "win32"
+ else "git"
+)
+GIT = str(GIT)
+
+
+def is_module_installed(module_name: str) -> bool:
+ spec = importlib.util.find_spec(module_name)
+ return spec is not None
+
+
+@lru_cache()
+def commit_hash():
+ try:
+ return subprocess.check_output(
+ [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
+ ).strip()
+ except Exception:
+ return ""
+
+
+def versions_html():
+ import torch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = commit_hash()
+ hash = commit.strip("'").split(" ")[0]
+
+ return f"""
+version: {hash}
+ •
+python: {python_version}
+ •
+torch: {getattr(torch, '__long_version__', torch.__version__)}
+ •
+gradio: {gr.__version__}
+ •
+author: fishaudio
+"""
+
+
+def version_check(commit):
+ try:
+ import requests
+
+ commits = requests.get(
+ "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
+ ).json()
+ if commit != "" and commits["commit"]["sha"] != commit:
+ print("--------------------------------------------------------")
+ print("| You are not up to date with the most recent release. |")
+ print("| Consider running `git pull` to update. |")
+ print("--------------------------------------------------------")
+ elif commits["commit"]["sha"] == commit:
+ print("You are up to date with the most recent release.")
+ else:
+ print("Not a git clone, can't perform version check.")
+ except Exception as e:
+ print("version check failed", e)
+
+
+class Seafoam(Base):
+ def __init__(
+ self,
+ *,
+ primary_hue: colors.Color | str = colors.emerald,
+ secondary_hue: colors.Color | str = colors.blue,
+ neutral_hue: colors.Color | str = colors.blue,
+ spacing_size: sizes.Size | str = sizes.spacing_md,
+ radius_size: sizes.Size | str = sizes.radius_md,
+ text_size: sizes.Size | str = sizes.text_lg,
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("Quicksand"),
+ "ui-sans-serif",
+ "sans-serif",
+ ),
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("IBM Plex Mono"),
+ "ui-monospace",
+ "monospace",
+ ),
+ ):
+ super().__init__(
+ primary_hue=primary_hue,
+ secondary_hue=secondary_hue,
+ neutral_hue=neutral_hue,
+ spacing_size=spacing_size,
+ radius_size=radius_size,
+ text_size=text_size,
+ font=font,
+ font_mono=font_mono,
+ )
+ super().set(
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
+ button_primary_text_color="white",
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
+ slider_color="*secondary_300",
+ slider_color_dark="*secondary_600",
+ block_title_text_weight="600",
+ block_border_width="3px",
+ block_shadow="*shadow_drop_lg",
+ # button_shadow="*shadow_drop_lg",
+ button_small_padding="0px",
+ button_large_padding="3px",
+ )
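+
+
+# Note: this theme is consumed by fish_speech/webui/manage.py, roughly as
+#
+#   seafoam = Seafoam()
+#   with gr.Blocks(theme=seafoam, ...) as demo:
+#       ...
+#
+# so any palette tweaks made here propagate to the whole manager WebUI.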
diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py
new file mode 100644
index 0000000000000000000000000000000000000000..c21233eee3e3e99754c68efc2b8809a62217eb53
--- /dev/null
+++ b/fish_speech/webui/manage.py
@@ -0,0 +1,1239 @@
+from __future__ import annotations
+
+import os
+
+os.environ["USE_LIBUV"] = "0"
+import datetime
+import html
+import json
+import platform
+import shutil
+import signal
+import subprocess
+import sys
+from pathlib import Path
+
+import gradio as gr
+import psutil
+import yaml
+from loguru import logger
+from tqdm import tqdm
+
+PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
+sys.path.insert(0, "")
+print(sys.path)
+cur_work_dir = Path(os.getcwd()).resolve()
+print("You are in ", str(cur_work_dir))
+
+from fish_speech.i18n import i18n
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
+
+config_path = cur_work_dir / "fish_speech" / "configs"
+vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
+llama_yml_path = config_path / "text2semantic_finetune.yaml"
+
+env = os.environ.copy()
+env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
+
+seafoam = Seafoam()
+
+
+def build_html_error_message(error):
+    return f"""
+    <div style="color: red; font-weight: bold;">
+        {html.escape(error)}
+    </div>
+    """
+
+
+def build_html_ok_message(msg):
+    return f"""
+    <div style="color: green; font-weight: bold;">
+        {html.escape(msg)}
+    </div>
+    """
+
+
+def build_html_href(link, desc, msg):
+    return f"""
+    <span style="color: green; font-weight: bold; display: inline-block">
+        {html.escape(msg)}
+        <a href="{link}">{desc}</a>
+    </span>
+    """
+
+
+def load_data_in_raw(path):
+ with open(path, "r", encoding="utf-8") as file:
+ data = file.read()
+ return str(data)
+
+
+def kill_proc_tree(pid, including_parent=True):
+ try:
+ parent = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ # Process already terminated
+ return
+
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
+ except OSError:
+ pass
+ if including_parent:
+ try:
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
+ except OSError:
+ pass
+
+
+system = platform.system()
+p_label = None
+p_infer = None
+p_tensorboard = None
+
+
+def kill_process(pid):
+ if system == "Windows":
+ cmd = "taskkill /t /f /pid %s" % pid
+ # os.system(cmd)
+ subprocess.run(cmd)
+ else:
+ kill_proc_tree(pid)
+
+
+def change_label(if_label):
+ global p_label
+    if if_label and p_label is None:
+ url = "http://localhost:3000"
+ remote_url = "https://text-labeler.pages.dev/"
+ try:
+ p_label = subprocess.Popen(
+ [
+ (
+ "asr-label-linux-x64"
+ if sys.platform == "linux"
+ else "asr-label-win-x64.exe"
+ )
+ ]
+ )
+ except FileNotFoundError:
+ logger.warning("asr-label execution not found!")
+
+ yield build_html_href(
+ link=remote_url,
+ desc=i18n("Optional online ver"),
+ msg=i18n("Opened labeler in browser"),
+ )
+
+    elif not if_label and p_label is not None:
+ kill_process(p_label.pid)
+ p_label = None
+ yield build_html_ok_message("Nothing")
+
+
+def clean_infer_cache():
+ import tempfile
+
+ temp_dir = Path(tempfile.gettempdir())
+ gradio_dir = str(temp_dir / "gradio")
+ try:
+ shutil.rmtree(gradio_dir)
+ logger.info(f"Deleted cached audios: {gradio_dir}")
+ except PermissionError:
+ logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+ except FileNotFoundError:
+ logger.info(f"{gradio_dir} was not found")
+ except Exception as e:
+ logger.info(f"An error occurred: {e}")
+
+
+def change_infer(
+ if_infer,
+ host,
+ port,
+ infer_decoder_model,
+ infer_decoder_config,
+ infer_llama_model,
+ infer_compile,
+):
+ global p_infer
+    if if_infer and p_infer is None:
+ env = os.environ.copy()
+
+ env["GRADIO_SERVER_NAME"] = host
+ env["GRADIO_SERVER_PORT"] = port
+        # Launch the inference WebUI as a second process
+ url = f"http://{host}:{port}"
+ yield build_html_ok_message(
+ i18n("Inferring interface is launched at {}").format(url)
+ )
+
+ clean_infer_cache()
+
+ p_infer = subprocess.Popen(
+ [
+ PYTHON,
+ "tools/webui.py",
+ "--decoder-checkpoint-path",
+ infer_decoder_model,
+ "--decoder-config-name",
+ infer_decoder_config,
+ "--llama-checkpoint-path",
+ infer_llama_model,
+ ]
+ + (["--compile"] if infer_compile == "Yes" else []),
+ env=env,
+ )
+
+    elif not if_infer and p_infer is not None:
+ kill_process(p_infer.pid)
+ p_infer = None
+ yield build_html_error_message(i18n("Infer interface is closed"))
+
+
+js = load_data_in_raw("fish_speech/webui/js/animate.js")
+css = load_data_in_raw("fish_speech/webui/css/style.css")
+
+data_pre_output = (cur_work_dir / "data").resolve()
+default_model_output = (cur_work_dir / "results").resolve()
+default_filelist = data_pre_output / "detect.list"
+data_pre_output.mkdir(parents=True, exist_ok=True)
+
+items = []
+dict_items = {}
+
+
+def load_yaml_data_in_fact(yml_path):
+ with open(yml_path, "r", encoding="utf-8") as file:
+ yml = yaml.safe_load(file)
+ return yml
+
+
+def write_yaml_data_in_fact(yml, yml_path):
+ with open(yml_path, "w", encoding="utf-8") as file:
+ yaml.safe_dump(yml, file, allow_unicode=True)
+ return yml
+
+
+def generate_tree(directory, depth=0, max_depth=None, prefix=""):
+ if max_depth is not None and depth > max_depth:
+ return ""
+
+ tree_str = ""
+ files = []
+ directories = []
+ for item in os.listdir(directory):
+ if os.path.isdir(os.path.join(directory, item)):
+ directories.append(item)
+ else:
+ files.append(item)
+
+ entries = directories + files
+ for i, entry in enumerate(entries):
+ connector = "├── " if i < len(entries) - 1 else "└── "
+        tree_str += f"{prefix}{connector}{entry}<br />\n"
+ if i < len(directories):
+ extension = "│ " if i < len(entries) - 1 else " "
+ tree_str += generate_tree(
+ os.path.join(directory, entry),
+ depth + 1,
+ max_depth,
+ prefix=prefix + extension,
+ )
+ return tree_str
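+
+# Example of the rendered output (the "<br />" added per entry above gives
+# gr.Markdown its line breaks); the entries are purely illustrative:
+#   ├── speaker_a
+#   │   ├── 0001.wav
+#   │   └── 0001.lab
+#   └── detect.list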
+
+
+def new_explorer(data_path, max_depth):
+ return gr.Markdown(
+ elem_classes=["scrollable-component"],
+ value=generate_tree(data_path, max_depth=max_depth),
+ )
+
+
+def add_item(
+ folder: str,
+ method: str,
+ label_lang: str,
+ if_initial_prompt: bool,
+ initial_prompt: str | None,
+):
+ folder = folder.strip(" ").strip('"')
+
+ folder_path = Path(folder)
+
+ if folder and folder not in items and data_pre_output not in folder_path.parents:
+ if folder_path.is_dir():
+ items.append(folder)
+ dict_items[folder] = dict(
+ type="folder",
+ method=method,
+ label_lang=label_lang,
+ initial_prompt=initial_prompt if if_initial_prompt else None,
+ )
+ elif folder:
+ err = folder
+ return gr.Checkboxgroup(choices=items), build_html_error_message(
+ i18n("Invalid path: {}").format(err)
+ )
+
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info("After Adding: " + formatted_data)
+ gr.Info(formatted_data)
+ return gr.Checkboxgroup(choices=items), build_html_ok_message(
+ i18n("Added path successfully!")
+ )
+
+
+def remove_items(selected_items):
+ global items, dict_items
+ to_remove = [item for item in items if item in selected_items]
+ for item in to_remove:
+ del dict_items[item]
+ items = [item for item in items if item in dict_items.keys()]
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info(formatted_data)
+ gr.Warning("After Removing: " + formatted_data)
+ return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
+ i18n("Removed path successfully!")
+ )
+
+
+def show_selected(options):
+ selected_options = ", ".join(options)
+
+ if options:
+ return i18n("Selected: {}").format(selected_options)
+ else:
+ return i18n("No selected options")
+
+
+from pydub import AudioSegment
+
+
+def convert_to_mono_in_place(audio_path: Path):
+ audio = AudioSegment.from_file(audio_path)
+ if audio.channels > 1:
+ mono_audio = audio.set_channels(1)
+ mono_audio.export(audio_path, format=audio_path.suffix[1:])
+ logger.info(f"Convert {audio_path} successfully")
+
+
+def list_copy(list_file_path, method):
+ wav_root = data_pre_output
+ lst = []
+ with list_file_path.open("r", encoding="utf-8") as file:
+ for line in tqdm(file, desc="Processing audio/transcript"):
+ wav_path, speaker_name, language, text = line.strip().split("|")
+ original_wav_path = Path(wav_path)
+ target_wav_path = (
+ wav_root / original_wav_path.parent.name / original_wav_path.name
+ )
+ lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
+ if target_wav_path.is_file():
+ continue
+ target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+ if method == i18n("Copy"):
+ shutil.copy(original_wav_path, target_wav_path)
+ else:
+ shutil.move(original_wav_path, target_wav_path.parent)
+ convert_to_mono_in_place(target_wav_path)
+ original_lab_path = original_wav_path.with_suffix(".lab")
+ target_lab_path = (
+ wav_root
+ / original_wav_path.parent.name
+ / original_wav_path.with_suffix(".lab").name
+ )
+ if target_lab_path.is_file():
+ continue
+ if method == i18n("Copy"):
+ shutil.copy(original_lab_path, target_lab_path)
+ else:
+ shutil.move(original_lab_path, target_lab_path.parent)
+
+ if method == i18n("Move"):
+ with list_file_path.open("w", encoding="utf-8") as file:
+ file.writelines("\n".join(lst))
+
+ del lst
+ return build_html_ok_message(i18n("Use filelist"))
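+
+# Sketch of the filelist format consumed by list_copy() above: one utterance per
+# line, pipe separated as "wav_path|speaker_name|language|text", e.g. (paths are
+# illustrative):
+#   /data/raw/speaker_a/0001.wav|speaker_a|en|Transcript of the first clip.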
+
+
+def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
+ global dict_items
+ data_path = Path(data_path)
+ gr.Warning("Pre-processing begins...")
+ for item, content in dict_items.items():
+ item_path = Path(item)
+ tar_path = data_path / item_path.name
+
+ if content["type"] == "folder" and item_path.is_dir():
+ if content["method"] == i18n("Copy"):
+ os.makedirs(tar_path, exist_ok=True)
+ shutil.copytree(
+ src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+ )
+ elif not tar_path.is_dir():
+ shutil.move(src=str(item_path), dst=str(tar_path))
+
+ for suf in ["wav", "flac", "mp3"]:
+ for audio_path in tar_path.glob(f"**/*.{suf}"):
+ convert_to_mono_in_place(audio_path)
+
+ cur_lang = content["label_lang"]
+ initial_prompt = content["initial_prompt"]
+
+ transcribe_cmd = [
+ PYTHON,
+ "tools/whisper_asr.py",
+ "--model-size",
+ label_model,
+ "--device",
+ label_device,
+ "--audio-dir",
+ tar_path,
+ "--save-dir",
+ tar_path,
+ "--language",
+ cur_lang,
+ ]
+
+ if initial_prompt is not None:
+ transcribe_cmd += ["--initial-prompt", initial_prompt]
+
+ if cur_lang != "IGNORE":
+ try:
+ gr.Warning("Begin To Transcribe")
+ subprocess.run(
+ transcribe_cmd,
+ env=env,
+ )
+ except Exception:
+ print("Transcription error occurred")
+
+ elif content["type"] == "file" and item_path.is_file():
+ list_copy(item_path, content["method"])
+
+ return build_html_ok_message(i18n("Move files successfully")), new_explorer(
+ data_path, max_depth=max_depth
+ )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+def train_process(
+ data_path: str,
+ option: str,
+ # llama config
+ llama_ckpt,
+ llama_base_config,
+ llama_lr,
+ llama_maxsteps,
+ llama_data_num_workers,
+ llama_data_batch_size,
+ llama_data_max_length,
+ llama_precision,
+ llama_check_interval,
+ llama_grad_batches,
+ llama_use_speaker,
+ llama_use_lora,
+):
+
+ backend = "nccl" if sys.platform == "linux" else "gloo"
+
+ new_project = generate_folder_name()
+ print("New Project Name: ", new_project)
+
+ if option == "VQGAN":
+ msg = "Skipped VQGAN Training."
+ gr.Warning(msg)
+ logger.info(msg)
+
+ if option == "LLAMA":
+ msg = "LLAMA Training begins..."
+ gr.Warning(msg)
+ logger.info(msg)
+ subprocess.run(
+ [
+ PYTHON,
+ "tools/vqgan/extract_vq.py",
+ str(data_pre_output),
+ "--num-workers",
+ "1",
+ "--batch-size",
+ "16",
+ "--config-name",
+ "firefly_gan_vq",
+ "--checkpoint-path",
+ "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ ]
+ )
+
+ subprocess.run(
+ [
+ PYTHON,
+ "tools/llama/build_dataset.py",
+ "--input",
+ str(data_pre_output),
+ "--text-extension",
+ ".lab",
+ "--num-workers",
+ "16",
+ ]
+ )
+ ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
+ lora_prefix = "lora_" if llama_use_lora else ""
+ llama_name = lora_prefix + "text2semantic_" + new_project
+ latest = next(
+ iter(
+ sorted(
+ [
+ str(p.relative_to("results"))
+ for p in Path("results").glob(lora_prefix + "text2sem*/")
+ ],
+ reverse=True,
+ )
+ ),
+ llama_name,
+ )
+ project = (
+ llama_name
+ if llama_ckpt == i18n("new")
+ else (
+ latest
+ if llama_ckpt == i18n("latest")
+ else Path(llama_ckpt).relative_to("results")
+ )
+ )
+ logger.info(project)
+
+ if llama_check_interval > llama_maxsteps:
+ llama_check_interval = llama_maxsteps
+
+ train_cmd = [
+ PYTHON,
+ "fish_speech/train.py",
+ "--config-name",
+ "text2semantic_finetune",
+ f"project={project}",
+ f"trainer.strategy.process_group_backend={backend}",
+ f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+ f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+ f"model.optimizer.lr={llama_lr}",
+ f"trainer.max_steps={llama_maxsteps}",
+ f"data.num_workers={llama_data_num_workers}",
+ f"data.batch_size={llama_data_batch_size}",
+ f"max_length={llama_data_max_length}",
+ f"trainer.precision={llama_precision}",
+ f"trainer.val_check_interval={llama_check_interval}",
+ f"trainer.accumulate_grad_batches={llama_grad_batches}",
+ f"train_dataset.interactive_prob={llama_use_speaker}",
+ ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+ logger.info(train_cmd)
+ subprocess.run(train_cmd)
+
+ return build_html_ok_message(i18n("Training stopped"))
+
+
+def tensorboard_process(
+ if_tensorboard: bool,
+ tensorboard_dir: str,
+ host: str,
+ port: str,
+):
+ global p_tensorboard
+    if if_tensorboard and p_tensorboard is None:
+ url = f"http://{host}:{port}"
+ yield build_html_ok_message(
+ i18n("Tensorboard interface is launched at {}").format(url)
+ )
+ prefix = ["tensorboard"]
+ if Path("fishenv").exists():
+ prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
+
+ p_tensorboard = subprocess.Popen(
+ prefix
+ + [
+ "--logdir",
+ tensorboard_dir,
+ "--host",
+ host,
+ "--port",
+ port,
+ "--reload_interval",
+ "120",
+ ]
+ )
+    elif not if_tensorboard and p_tensorboard is not None:
+ kill_process(p_tensorboard.pid)
+ p_tensorboard = None
+ yield build_html_error_message(i18n("Tensorboard interface is closed"))
+
+
+def fresh_tb_dir():
+ return gr.Dropdown(
+ choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
+ )
+
+
+def list_decoder_models():
+ paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
+ if not paths:
+ logger.warning("No decoder model found")
+ return paths
+
+
+def list_llama_models():
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+ choices = sorted(choices, reverse=True)
+ if not choices:
+ logger.warning("No LLaMA model found")
+ return choices
+
+
+def list_lora_llama_models():
+ choices = sorted(
+ [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+ )
+ if not choices:
+ logger.warning("No LoRA LLaMA model found")
+ return choices
+
+
+def fresh_decoder_model():
+ return gr.Dropdown(choices=list_decoder_models())
+
+
+def fresh_llama_ckpt(llama_use_lora):
+ return gr.Dropdown(
+ choices=[i18n("latest"), i18n("new")]
+ + (
+ [str(p) for p in Path("results").glob("text2sem*/")]
+ if not llama_use_lora
+ else [str(p) for p in Path("results").glob("lora_*/")]
+ )
+ )
+
+
+def fresh_llama_model():
+ return gr.Dropdown(choices=list_llama_models())
+
+
+def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
+ if (
+ lora_weight is None
+ or not Path(lora_weight).exists()
+ or not Path(llama_weight).exists()
+ ):
+ return build_html_error_message(
+ i18n(
+ "Path error, please check the model file exists in the corresponding path"
+ )
+ )
+ gr.Warning("Merging begins...")
+ merge_cmd = [
+ PYTHON,
+ "tools/llama/merge_lora.py",
+ "--lora-config",
+ "r_8_alpha_16",
+ "--lora-weight",
+ lora_weight,
+ "--output",
+ llama_lora_output + "_" + generate_folder_name(),
+ ]
+ logger.info(merge_cmd)
+ subprocess.run(merge_cmd)
+ return build_html_ok_message(i18n("Merge successfully"))
+
+
+def llama_quantify(llama_weight, quantify_mode):
+ if llama_weight is None or not Path(llama_weight).exists():
+ return build_html_error_message(
+ i18n(
+ "Path error, please check the model file exists in the corresponding path"
+ )
+ )
+
+ gr.Warning("Quantifying begins...")
+
+ now = generate_folder_name()
+ quantify_cmd = [
+ PYTHON,
+ "tools/llama/quantize.py",
+ "--checkpoint-path",
+ llama_weight,
+ "--mode",
+ quantify_mode,
+ "--timestamp",
+ now,
+ ]
+ logger.info(quantify_cmd)
+ subprocess.run(quantify_cmd)
+ if quantify_mode == "int8":
+ quantize_path = str(
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
+ )
+ else:
+ quantize_path = str(
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
+ )
+ return build_html_ok_message(
+ i18n("Quantify successfully") + f"Path: {quantize_path}"
+ )
+
+
+init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
+init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
+
+with gr.Blocks(
+ head="",
+ js=js,
+ theme=seafoam,
+ analytics_enabled=False,
+ title="Fish Speech",
+) as demo:
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
+ with gr.Row():
+ textbox = gr.Textbox(
+ label="\U0000270F "
+ + i18n("Input Audio & Source Path for Transcription"),
+ info=i18n("Speaker is identified by the folder name"),
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ output_radio = gr.Radio(
+ label="\U0001F4C1 "
+ + i18n("Select source file processing method"),
+ choices=[i18n("Copy"), i18n("Move")],
+ value=i18n("Copy"),
+ interactive=True,
+ )
+ with gr.Column():
+ error = gr.HTML(label=i18n("Error Message"))
+ if_label = gr.Checkbox(
+ label=i18n("Open Labeler WebUI"), scale=0, show_label=True
+ )
+
+ with gr.Row():
+ label_device = gr.Dropdown(
+ label=i18n("Labeling Device"),
+ info=i18n(
+ "It is recommended to use CUDA, if you have low configuration, use CPU"
+ ),
+ choices=["cpu", "cuda"],
+ value="cuda",
+ interactive=True,
+ )
+ label_model = gr.Dropdown(
+ label=i18n("Whisper Model"),
+ info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
+ choices=["large-v3", "medium"],
+ value="large-v3",
+ interactive=True,
+ )
+ label_radio = gr.Dropdown(
+ label=i18n("Optional Label Language"),
+ info=i18n(
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
+ ),
+ choices=[
+ (i18n("Chinese"), "zh"),
+ (i18n("English"), "en"),
+ (i18n("Japanese"), "ja"),
+ (i18n("Disabled"), "IGNORE"),
+ (i18n("auto"), "auto"),
+ ],
+ value="IGNORE",
+ interactive=True,
+ )
+
+ with gr.Row():
+ if_initial_prompt = gr.Checkbox(
+ value=False,
+ label=i18n("Enable Initial Prompt"),
+ min_width=120,
+ scale=0,
+ )
+ initial_prompt = gr.Textbox(
+ label=i18n("Initial Prompt"),
+ info=i18n(
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
+ ),
+ placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
+ interactive=False,
+ )
+
+ with gr.Row():
+ add_button = gr.Button(
+ "\U000027A1 " + i18n("Add to Processing Area"),
+ variant="primary",
+ )
+ remove_button = gr.Button(
+ "\U000026D4 " + i18n("Remove Selected Data")
+ )
+
+ with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
+ with gr.Row():
+ model_type_radio = gr.Radio(
+ label=i18n(
+ "Select the model to be trained (Depending on the Tab page you are on)"
+ ),
+ interactive=False,
+ choices=["VQGAN", "LLAMA"],
+ value="VQGAN",
+ )
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
+ gr.HTML("You don't need to train this model!")
+
+ with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
+ with gr.Row(equal_height=False):
+ llama_use_lora = gr.Checkbox(
+ label=i18n("Use LoRA"),
+ info=i18n(
+ "Use LoRA can save GPU memory, but may reduce the quality of the model"
+ ),
+ value=True,
+ interactive=True,
+ )
+ llama_ckpt = gr.Dropdown(
+ label=i18n("Select LLAMA ckpt"),
+ choices=[i18n("latest"), i18n("new")]
+ + [
+ str(p)
+ for p in Path("results").glob("text2sem*/")
+ ]
+ + [str(p) for p in Path("results").glob("lora*/")],
+ value=i18n("latest"),
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lr_slider = gr.Slider(
+ label=i18n("Initial Learning Rate"),
+ info=i18n(
+ "lr smaller -> usually train slower but more stable"
+ ),
+ interactive=True,
+ minimum=1e-5,
+ maximum=1e-4,
+ step=1e-5,
+ value=5e-5,
+ )
+ llama_maxsteps_slider = gr.Slider(
+ label=i18n("Maximum Training Steps"),
+ info=i18n(
+ "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+ ),
+ interactive=True,
+ minimum=1,
+ maximum=10000,
+ step=1,
+ value=50,
+ )
+ with gr.Row(equal_height=False):
+ llama_base_config = gr.Dropdown(
+ label=i18n("Model Size"),
+ choices=[
+ "text2semantic_finetune",
+ ],
+ value="text2semantic_finetune",
+ )
+ llama_data_num_workers_slider = gr.Slider(
+ label=i18n("Number of Workers"),
+ minimum=1,
+ maximum=32,
+ step=1,
+ value=4,
+ )
+ with gr.Row(equal_height=False):
+ llama_data_batch_size_slider = gr.Slider(
+ label=i18n("Batch Size"),
+ interactive=True,
+ minimum=1,
+ maximum=32,
+ step=1,
+ value=2,
+ )
+ llama_data_max_length_slider = gr.Slider(
+ label=i18n("Maximum Length per Sample"),
+ interactive=True,
+ minimum=1024,
+ maximum=4096,
+ step=128,
+ value=2048,
+ )
+ with gr.Row(equal_height=False):
+ llama_precision_dropdown = gr.Dropdown(
+ label=i18n("Precision"),
+ info=i18n(
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
+ ),
+ interactive=True,
+ choices=["32", "bf16-true", "16-mixed"],
+ value="bf16-true",
+ )
+ llama_check_interval_slider = gr.Slider(
+ label=i18n("Save model every n steps"),
+ info=i18n(
+ "make sure that it's not greater than max_steps"
+ ),
+ interactive=True,
+ minimum=1,
+ maximum=1000,
+ step=1,
+ value=50,
+ )
+ with gr.Row(equal_height=False):
+ llama_grad_batches = gr.Slider(
+ label=i18n("Accumulate Gradient Batches"),
+ interactive=True,
+ minimum=1,
+ maximum=20,
+ step=1,
+ value=init_llama_yml["trainer"][
+ "accumulate_grad_batches"
+ ],
+ )
+ llama_use_speaker = gr.Slider(
+ label=i18n(
+ "Probability of applying Speaker Condition"
+ ),
+ interactive=True,
+ minimum=0.1,
+ maximum=1.0,
+ step=0.05,
+ value=init_llama_yml["train_dataset"][
+ "interactive_prob"
+ ],
+ )
+
+ with gr.Tab(label=i18n("Merge LoRA"), id=4):
+ with gr.Row(equal_height=False):
+ llama_weight = gr.Dropdown(
+ label=i18n("Base LLAMA Model"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ "checkpoints/fish-speech-1.4/model.pth",
+ ],
+ value="checkpoints/fish-speech-1.4/model.pth",
+ allow_custom_value=True,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ lora_weight = gr.Dropdown(
+ label=i18n("LoRA Model to be merged"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ str(p)
+ for p in Path("results").glob("lora*/**/*.ckpt")
+ ],
+ allow_custom_value=True,
+ interactive=True,
+ )
+ lora_llama_config = gr.Dropdown(
+ label=i18n("LLAMA Model Config"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ "text2semantic_finetune",
+ ],
+ value="text2semantic_finetune",
+ allow_custom_value=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lora_output = gr.Dropdown(
+ label=i18n("Output Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ value="checkpoints/merged",
+ choices=["checkpoints/merged"],
+ allow_custom_value=True,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lora_merge_btn = gr.Button(
+ value=i18n("Merge"), variant="primary"
+ )
+
+ with gr.Tab(label=i18n("Model Quantization"), id=5):
+ with gr.Row(equal_height=False):
+ llama_weight_to_quantify = gr.Dropdown(
+ label=i18n("Base LLAMA Model"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=list_llama_models(),
+ value="checkpoints/fish-speech-1.4",
+ allow_custom_value=True,
+ interactive=True,
+ )
+ quantify_mode = gr.Dropdown(
+ label=i18n("Post-quantification Precision"),
+ info=i18n(
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
+ ),
+ choices=["int8", "int4"],
+ value="int8",
+ allow_custom_value=False,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_quantify_btn = gr.Button(
+ value=i18n("Quantify"), variant="primary"
+ )
+
+ with gr.Tab(label="Tensorboard", id=6):
+ with gr.Row(equal_height=False):
+ tb_host = gr.Textbox(
+ label=i18n("Tensorboard Host"), value="127.0.0.1"
+ )
+ tb_port = gr.Textbox(
+ label=i18n("Tensorboard Port"), value="11451"
+ )
+ with gr.Row(equal_height=False):
+ tb_dir = gr.Dropdown(
+ label=i18n("Tensorboard Log Path"),
+ allow_custom_value=True,
+ choices=[
+ str(p)
+ for p in Path("results").glob("**/tensorboard/")
+ ],
+ )
+ with gr.Row(equal_height=False):
+ if_tb = gr.Checkbox(
+ label=i18n("Open Tensorboard"),
+ )
+
+ with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
+ with gr.Column():
+ with gr.Row():
+ with gr.Accordion(
+ label="\U0001F5A5 "
+ + i18n("Inference Server Configuration"),
+ open=False,
+ ):
+ with gr.Row():
+ infer_host_textbox = gr.Textbox(
+ label=i18n("WebUI Host"), value="127.0.0.1"
+ )
+ infer_port_textbox = gr.Textbox(
+ label=i18n("WebUI Port"), value="7862"
+ )
+ with gr.Row():
+ infer_decoder_model = gr.Dropdown(
+ label=i18n("Decoder Model Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=list_decoder_models(),
+ value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ allow_custom_value=True,
+ )
+ infer_decoder_config = gr.Dropdown(
+ label=i18n("Decoder Model Config"),
+ info=i18n("Changing with the Model Path"),
+ value="firefly_gan_vq",
+ choices=[
+ "firefly_gan_vq",
+ ],
+ allow_custom_value=True,
+ )
+ with gr.Row():
+ infer_llama_model = gr.Dropdown(
+ label=i18n("LLAMA Model Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ value="checkpoints/fish-speech-1.4",
+ choices=list_llama_models(),
+ allow_custom_value=True,
+ )
+
+ with gr.Row():
+ infer_compile = gr.Radio(
+ label=i18n("Compile Model"),
+ info=i18n(
+ "Compile the model can significantly reduce the inference time, but will increase cold start time"
+ ),
+ choices=["Yes", "No"],
+ value=(
+ "Yes" if (sys.platform == "linux") else "No"
+ ),
+ interactive=is_module_installed("triton"),
+ )
+
+ with gr.Row():
+ infer_checkbox = gr.Checkbox(
+ label=i18n("Open Inference Server")
+ )
+ infer_error = gr.HTML(label=i18n("Inference Server Error"))
+
+ with gr.Column():
+ train_error = gr.HTML(label=i18n("Training Error"))
+ checkbox_group = gr.CheckboxGroup(
+ label="\U0001F4CA " + i18n("Data Source"),
+ info=i18n(
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
+ ),
+ elem_classes=["data_src"],
+ )
+ train_box = gr.Textbox(
+ label=i18n("Data Preprocessing Path"),
+ value=str(data_pre_output),
+ interactive=False,
+ )
+ model_box = gr.Textbox(
+ label="\U0001F4BE " + i18n("Model Output Path"),
+ value=str(default_model_output),
+ interactive=False,
+ )
+
+ with gr.Accordion(
+ i18n(
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
+ ),
+ elem_classes=["scrollable-component"],
+ elem_id="file_accordion",
+ ):
+ tree_slider = gr.Slider(
+ minimum=0,
+ maximum=3,
+ value=0,
+ step=1,
+ show_label=False,
+ container=False,
+ )
+ file_markdown = new_explorer(str(data_pre_output), 0)
+ with gr.Row(equal_height=False):
+ admit_btn = gr.Button(
+ "\U00002705 " + i18n("File Preprocessing"),
+ variant="primary",
+ )
+ fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
+ help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
+ train_btn = gr.Button(i18n("Start Training"), variant="primary")
+
+ footer = load_data_in_raw("fish_speech/webui/html/footer.html")
+ footer = footer.format(
+ versions=versions_html(),
+ api_docs="https://speech.fish.audio/inference/#http-api",
+ )
+ gr.HTML(footer, elem_id="footer")
+ vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
+ llama_page.select(lambda: "LLAMA", None, model_type_radio)
+ add_button.click(
+ fn=add_item,
+ inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
+ outputs=[checkbox_group, error],
+ )
+ remove_button.click(
+ fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
+ )
+ checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
+ help_button.click(
+ fn=None,
+ js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
+ 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
+ )
+ if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
+ if_initial_prompt.change(
+ fn=lambda x: gr.Textbox(value="", interactive=x),
+ inputs=[if_initial_prompt],
+ outputs=[initial_prompt],
+ )
+ train_btn.click(
+ fn=train_process,
+ inputs=[
+ train_box,
+ model_type_radio,
+ # llama config
+ llama_ckpt,
+ llama_base_config,
+ llama_lr_slider,
+ llama_maxsteps_slider,
+ llama_data_num_workers_slider,
+ llama_data_batch_size_slider,
+ llama_data_max_length_slider,
+ llama_precision_dropdown,
+ llama_check_interval_slider,
+ llama_grad_batches,
+ llama_use_speaker,
+ llama_use_lora,
+ ],
+ outputs=[train_error],
+ )
+ if_tb.change(
+ fn=tensorboard_process,
+ inputs=[if_tb, tb_dir, tb_host, tb_port],
+ outputs=[train_error],
+ )
+ tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
+ infer_decoder_model.change(
+ fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
+ )
+ infer_llama_model.change(
+ fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
+ )
+ llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
+ admit_btn.click(
+ fn=check_files,
+ inputs=[train_box, tree_slider, label_model, label_device],
+ outputs=[error, file_markdown],
+ )
+ fresh_btn.click(
+ fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
+ )
+ llama_use_lora.change(
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+ )
+ llama_ckpt.change(
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+ )
+ lora_weight.change(
+ fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
+ inputs=[],
+ outputs=[lora_weight],
+ )
+ llama_lora_merge_btn.click(
+ fn=llama_lora_merge,
+ inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
+ outputs=[train_error],
+ )
+ llama_quantify_btn.click(
+ fn=llama_quantify,
+ inputs=[llama_weight_to_quantify, quantify_mode],
+ outputs=[train_error],
+ )
+ infer_checkbox.change(
+ fn=change_infer,
+ inputs=[
+ infer_checkbox,
+ infer_host_textbox,
+ infer_port_textbox,
+ infer_decoder_model,
+ infer_decoder_config,
+ infer_llama_model,
+ infer_compile,
+ ],
+ outputs=[infer_error],
+ )
+
+demo.launch(inbrowser=True)