diff --git a/fish_speech/__pycache__/conversation.cpython-310.pyc b/fish_speech/__pycache__/conversation.cpython-310.pyc deleted file mode 100644 index b4dc1336106c5d496e7a1c091e609089eb30d096..0000000000000000000000000000000000000000 Binary files a/fish_speech/__pycache__/conversation.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/__pycache__/scheduler.cpython-310.pyc b/fish_speech/__pycache__/scheduler.cpython-310.pyc deleted file mode 100644 index 5ce90919af88b3c612722a85c3799f2cc4d58d76..0000000000000000000000000000000000000000 Binary files a/fish_speech/__pycache__/scheduler.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py deleted file mode 100644 index bbcf3f33656d180ca87cd14a21ede1544e5a61a3..0000000000000000000000000000000000000000 --- a/fish_speech/callbacks/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .grad_norm import GradNormMonitor - -__all__ = ["GradNormMonitor"] diff --git a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 033bf77b0edc8dbe764c3e4386c005136b1ee50c..0000000000000000000000000000000000000000 Binary files a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc deleted file mode 100644 index 2058510bf280afba7b92dc027d4629aa030b72fc..0000000000000000000000000000000000000000 Binary files a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py deleted file mode 100644 index dbc95ef2a3723323b2d976001ed1e3c79c00b21a..0000000000000000000000000000000000000000 --- a/fish_speech/callbacks/grad_norm.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Optional, Union - -import lightning.pytorch as pl -import torch -from lightning import LightningModule, Trainer -from lightning.pytorch.callbacks import Callback -from torch import Tensor, nn -from torch.utils._foreach_utils import ( - _group_tensors_by_device_and_dtype, - _has_foreach_support, -) - - -@torch.no_grad() -def grad_norm( - parameters: Union[Tensor, list[Tensor]], - norm_type: float = 2.0, -) -> float: - """ - Returns the norm of the gradients of the given parameters. - - Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - norm_type (float): type of the used p-norm. - - Returns: - Total norm of the parameter gradients (viewed as a single vector). - """ # noqa: E501 - - if isinstance(parameters, Tensor): - parameters = [parameters] - - grads = [p.grad for p in parameters if p.grad is not None] - if len(grads) == 0: - return None - - first_device = grads[0].device - grouped_grads: dict[ - tuple[torch.device, torch.dtype], list[list[Tensor]] - ] = _group_tensors_by_device_and_dtype( - [[g.detach() for g in grads]] - ) # type: ignore[assignment] - - norms = [] - for (device, _), ([grads], _) in grouped_grads.items(): - if _has_foreach_support(grads, device=device): - norms.extend(torch._foreach_norm(grads, norm_type)) - else: - norms.extend([torch.norm(g, norm_type) for g in grads]) - - return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) - - -class GradNormMonitor(Callback): - """ - Callback that computes the gradient norm of the model parameters. 
- """ - - def __init__( - self, - norm_type: float = 2.0, - logging_interval: str = "step", - sub_module: Optional[Union[str, list[str]]] = None, - ) -> None: - """ - Args: - norm_type (float): type of the used p-norm. - logging_interval (str): "step" or "epoch". - """ - super().__init__() - - self.norm_type = norm_type - self.logging_interval = logging_interval - self.sub_module = sub_module - - def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: - """ - Computes the gradient norm of the model parameters and logs it to the logger. - - Args: - trainer (Trainer): The trainer object - model (LightningModule): The current lightningModule - """ - - lightning_model = model - - if self.sub_module is None: - return self.log_sub_module_grad_norm(lightning_model, model, "") - - sub_modules = self.sub_module - if isinstance(sub_modules, str): - sub_modules = [sub_modules] - - for sub_module in sub_modules: - self.log_sub_module_grad_norm( - lightning_model, getattr(model, sub_module), f"/{sub_module}" - ) - - def log_sub_module_grad_norm( - self, lightning_model: LightningModule, model: nn.Module, path: str - ) -> None: - grad_norm_val = grad_norm(model.parameters(), self.norm_type) - if grad_norm_val is None: - return - - on_step = self.logging_interval == "step" - lightning_model.log( - f"train{path}/grad_norm", - grad_norm_val, - on_step=on_step, - on_epoch=not on_step, - ) diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml deleted file mode 100644 index 99e6dab54d3f57bce4f6d29a9129a19a523cad75..0000000000000000000000000000000000000000 --- a/fish_speech/configs/base.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# Base configuration for training a model -paths: - run_dir: results/${project} - ckpt_dir: ${paths.run_dir}/checkpoints - -hydra: - run: - dir: ${paths.run_dir} - -# Lightning Trainer -trainer: - _target_: lightning.pytorch.trainer.Trainer - - default_root_dir: ${paths.run_dir} - accelerator: gpu - num_nodes: 1 - devices: auto - strategy: - _target_: lightning.pytorch.strategies.DDPStrategy - process_group_backend: nccl # This should be override when training on windows - - precision: bf16-mixed - - # disable validation by epoch end - check_val_every_n_epoch: null - val_check_interval: 5000 - max_steps: 100_000 - - # Use torch.backends.cudnn.benchmark to speed up training - benchmark: true - -# Callbacks -callbacks: - model_checkpoint: - _target_: lightning.pytorch.callbacks.ModelCheckpoint - dirpath: ${paths.ckpt_dir} - filename: "step_{step:09d}" - save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt - save_top_k: 5 # save 5 latest checkpoints - monitor: step # use step to monitor checkpoints - mode: max # save the latest checkpoint with the highest global_step - every_n_epochs: null # don't save checkpoints by epoch end - every_n_train_steps: 5000 # save checkpoints every 5000 steps - auto_insert_metric_name: false - - model_summary: - _target_: lightning.pytorch.callbacks.ModelSummary - max_depth: 2 # the maximum depth of layer nesting that the summary will include - - learning_rate_monitor: - _target_: lightning.pytorch.callbacks.LearningRateMonitor - logging_interval: step - log_momentum: false - - grad_norm_monitor: - _target_: fish_speech.callbacks.GradNormMonitor - norm_type: 2 - logging_interval: step - -# Logger -logger: - tensorboard: - _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger - save_dir: "${paths.run_dir}/tensorboard/" - name: null - log_graph: false - 
default_hp_metric: true - prefix: "" - - # wandb: - # _target_: lightning.pytorch.loggers.wandb.WandbLogger - # # name: "" # name of the run (normally generated by wandb) - # save_dir: "${paths.run_dir}" - # offline: False - # id: null # pass correct id to resume experiment! - # anonymous: null # enable anonymous logging - # project: "fish-speech" - # log_model: False # upload lightning ckpts - # prefix: "" # a string to put at the beginning of metric keys - # # entity: "" # set to name of your wandb team - # group: "" - # tags: ["vq", "hq", "finetune"] - # job_type: "" - -# Loop -train: true -test: false diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml deleted file mode 100644 index 10aa8d4a522f0859ed8f541f5d48672d84b39c8f..0000000000000000000000000000000000000000 --- a/fish_speech/configs/firefly_gan_vq.yaml +++ /dev/null @@ -1,33 +0,0 @@ -_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture -spec_transform: - _target_: fish_speech.utils.spectrogram.LogMelSpectrogram - sample_rate: 44100 - n_mels: 160 - n_fft: 2048 - hop_length: 512 - win_length: 2048 -backbone: - _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder - input_channels: 160 - depths: [3, 3, 9, 3] - dims: [128, 256, 384, 512] - drop_path_rate: 0.2 - kernel_size: 7 -head: - _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator - hop_length: 512 - upsample_rates: [8, 8, 2, 2, 2] # aka. strides - upsample_kernel_sizes: [16, 16, 4, 4, 4] - resblock_kernel_sizes: [3, 7, 11] - resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] - num_mels: 512 - upsample_initial_channel: 512 - pre_conv_kernel_size: 13 - post_conv_kernel_size: 13 -quantizer: - _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize - input_dim: 512 - n_groups: 8 - n_codebooks: 1 - levels: [8, 5, 5, 5] - downsample_factor: [2, 2] diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml deleted file mode 100644 index aecc4d9766a18fe31c55941e01b1f590c95e77c9..0000000000000000000000000000000000000000 --- a/fish_speech/configs/lora/r_8_alpha_16.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: fish_speech.models.text2semantic.lora.LoraConfig -r: 8 -lora_alpha: 16 -lora_dropout: 0.01 diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml deleted file mode 100644 index f4c1993023099e122fc9e004bda55ec075ed5e1b..0000000000000000000000000000000000000000 --- a/fish_speech/configs/text2semantic_finetune.yaml +++ /dev/null @@ -1,83 +0,0 @@ -defaults: - - base - - _self_ - -project: text2semantic_finetune_dual_ar -max_length: 4096 -pretrained_ckpt_path: checkpoints/fish-speech-1.4 - -# Lightning Trainer -trainer: - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - gradient_clip_algorithm: "norm" - max_steps: 1000 - precision: bf16-true - limit_val_batches: 10 - val_check_interval: 100 - -# Dataset Configuration -tokenizer: - _target_: transformers.AutoTokenizer.from_pretrained - pretrained_model_name_or_path: ${pretrained_ckpt_path} - -# Dataset Configuration -train_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset - proto_files: - - data/protos - tokenizer: ${tokenizer} - causal: true - max_length: ${max_length} - use_speaker: false - interactive_prob: 0.7 - -val_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset - proto_files: - - data/protos - tokenizer: ${tokenizer} - causal: true - 
max_length: ${max_length} - use_speaker: false - interactive_prob: 0.7 - -data: - _target_: fish_speech.datasets.semantic.SemanticDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 8 - tokenizer: ${tokenizer} - max_length: ${max_length} - -# Model Configuration -model: - _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic - model: - _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained - path: ${pretrained_ckpt_path} - load_weights: true - max_length: ${max_length} - lora_config: null - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 1e-4 - weight_decay: 0 - betas: [0.9, 0.95] - eps: 1e-5 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 10 - -# Callbacks -callbacks: - model_checkpoint: - every_n_train_steps: ${trainer.val_check_interval} diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py deleted file mode 100644 index c9ca0ef9181754eda7e6b49e01abeafbe07fb00f..0000000000000000000000000000000000000000 --- a/fish_speech/conversation.py +++ /dev/null @@ -1,2 +0,0 @@ -SEMANTIC_TOKEN = "<|semantic|>" -CODEBOOK_PAD_TOKEN_ID = 0 diff --git a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc deleted file mode 100644 index ca763c1c12b41234a939ddbe343575b67a79bb92..0000000000000000000000000000000000000000 Binary files a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py deleted file mode 100644 index 4aa596b95a572ee15c5570cbdb792c9a78e62dfa..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/concat_repeat.py +++ /dev/null @@ -1,53 +0,0 @@ -import bisect -import random -from typing import Iterable - -from torch.utils.data import Dataset, IterableDataset - - -class ConcatRepeatDataset(Dataset): - datasets: list[Dataset] - cumulative_sizes: list[int] - repeats: list[int] - - @staticmethod - def cumsum(sequence, repeats): - r, s = [], 0 - for dataset, repeat in zip(sequence, repeats): - l = len(dataset) * repeat - r.append(l + s) - s += l - return r - - def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): - super().__init__() - - self.datasets = list(datasets) - self.repeats = repeats - - assert len(self.datasets) > 0, "datasets should not be an empty iterable" - assert len(self.datasets) == len( - repeats - ), "datasets and repeats should have the same length" - - for d in self.datasets: - assert not isinstance( - d, IterableDataset - ), "ConcatRepeatDataset does not support IterableDataset" - - self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) - - def __len__(self): - return self.cumulative_sizes[-1] - - def __getitem__(self, idx): - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - - if dataset_idx == 0: - sample_idx = idx - else: - sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] - - dataset = self.datasets[dataset_idx] - - return dataset[sample_idx % len(dataset)] diff --git a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc deleted file mode 100644 index 2b7bb23609e78b12b6e608581f1e8d764bd9db3a..0000000000000000000000000000000000000000 Binary files 
a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc deleted file mode 100644 index 6e22635c991e5d669704c3cf95dd528a39b8b822..0000000000000000000000000000000000000000 Binary files a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto deleted file mode 100644 index 5eb26d94aa3be1e21066f2bf38c90d54e85a8379..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/protos/text-data.proto +++ /dev/null @@ -1,24 +0,0 @@ -syntax = "proto3"; - -package text_data; - -message Semantics { - repeated uint32 values = 1; -} - -message Sentence { - repeated string texts = 1; - repeated Semantics semantics = 3; -} - -message TextData { - string source = 1; - string name = 2; - repeated Sentence sentences = 4; -} - -message SampledData { - string source = 1; - string name = 2; - repeated Sentence samples = 3; -} diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py deleted file mode 100644 index bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/protos/text_data_pb2.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: text-data.proto -# Protobuf Python Version: 4.25.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' -) - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_SEMANTICS"]._serialized_start = 30 - _globals["_SEMANTICS"]._serialized_end = 57 - _globals["_SENTENCE"]._serialized_start = 59 - _globals["_SENTENCE"]._serialized_end = 125 - _globals["_TEXTDATA"]._serialized_start = 127 - _globals["_TEXTDATA"]._serialized_end = 207 - _globals["_SAMPLEDDATA"]._serialized_start = 209 - _globals["_SAMPLEDDATA"]._serialized_end = 290 -# @@protoc_insertion_point(module_scope) diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py deleted file mode 100644 index ec3c25bcd764e8245de47dcdf9686d6adfb5a107..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/protos/text_data_stream.py +++ /dev/null @@ -1,36 +0,0 @@ -import struct - 
-from .text_data_pb2 import TextData - - -def read_pb_stream(f): - while True: - buf = f.read(4) - if len(buf) == 0: - break - size = struct.unpack("I", buf)[0] - buf = f.read(size) - text_data = TextData() - text_data.ParseFromString(buf) - yield text_data - - -def write_pb_stream(f, text_data): - buf = text_data.SerializeToString() - f.write(struct.pack("I", len(buf))) - f.write(buf) - - -def pack_pb_stream(text_data): - buf = text_data.SerializeToString() - return struct.pack("I", len(buf)) + buf - - -def split_pb_stream(f): - while True: - head = f.read(4) - if len(head) == 0: - break - size = struct.unpack("I", head)[0] - buf = f.read(size) - yield head + buf diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py deleted file mode 100644 index 3c64e01077ae253bdc4e4d9cd948f8fb50df7418..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/semantic.py +++ /dev/null @@ -1,496 +0,0 @@ -import random -from dataclasses import dataclass -from itertools import chain -from pathlib import Path -from random import Random -from typing import Optional, Union - -import numpy as np -import pyarrow.parquet as pq -import torch -import torch.nn.functional as F -from datasets.download.streaming_download_manager import xopen -from huggingface_hub import HfApi -from lightning import LightningDataModule -from torch.distributed import get_rank, get_world_size, is_initialized -from torch.utils.data import DataLoader, IterableDataset, get_worker_info -from transformers import AutoTokenizer - -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID -from fish_speech.datasets.protos.text_data_pb2 import SampledData -from fish_speech.datasets.protos.text_data_stream import read_pb_stream -from fish_speech.text.clean import clean_text -from fish_speech.utils import RankedLogger -from fish_speech.utils.braceexpand import braceexpand - -log = RankedLogger(__name__, rank_zero_only=True) - - -def split_by_rank_worker(files): - # We need to know the total number of devices - # to split the data properly - - total_devices = 1 - if is_initialized(): - total_devices = get_world_size() - - worker_info = get_worker_info() - if worker_info is not None: - total_devices *= worker_info.num_workers - - if len(files) < total_devices: - # Repeat the files N times to match the number of devices - files = files * (total_devices // len(files) + 1) - - # DDP - if is_initialized(): - files = files[get_rank() :: get_world_size()] - - # Split by worker - if worker_info is not None: - files = files[worker_info.id :: worker_info.num_workers] - - return files - - -class AutoTextSemanticInstructionDataset(IterableDataset): - """ - Auto Augment Dataset by Speaker - - 1. Random concatenate multiple sentences from the same speaker to form a longer sentence - 2. Automatically normalize the text - - For interactive mode, we use the following format (multiple sequences): - [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] - - For non-interactive mode, we use the following format (one long sequence): - [INST] text [/INST] ... 
- """ - - def __init__( - self, - proto_files: list[str], - seed: int = 42, - interactive_prob: float = 0.5, - max_length: int = 1024, - tokenizer: AutoTokenizer = None, - use_speaker: bool | float = True, - causal: bool = True, - num_codebooks: Optional[int] = None, - skip_text_prob: float = 0.0, - ): - """ - Args: - proto_files: proto buf files if using local data - seed: random seed - interactive_prob: probability to use interactive mode - max_length: max length of the text - tokenizer: tokenizer - use_speaker: include speaker information in the prompt - causal: use causal sampling when using local data, disable will lead to random sampling - num_codebooks: number of codebooks, if None, it will be automatically detected - skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode - """ - - super().__init__() - - assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" - - self.seed = seed - self.max_length = max_length - self.tokenizer = tokenizer - self.interactive_prob = interactive_prob - self.use_speaker = use_speaker - self.proto_files = proto_files - self.causal = causal - self.num_codebooks = num_codebooks - self.skip_text_prob = skip_text_prob - - self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") - self.groups = None - - def init_mock_data_server(self): - if self.groups is not None: - return - - # Expand the proto files - expanded_proto_files = [] - for filename in self.proto_files: - for i in braceexpand(filename): - i = Path(i) - if i.is_file(): - expanded_proto_files.append(i) - elif i.is_dir(): - expanded_proto_files.extend(i.rglob("*.proto")) - expanded_proto_files.extend(i.rglob("*.protos")) - else: - raise ValueError(f"{i} is not a file or directory") - - expanded_proto_files = sorted(expanded_proto_files) - Random(self.seed).shuffle(expanded_proto_files) - - self.groups = [] - shard_proto_files = split_by_rank_worker(expanded_proto_files) - log.info( - f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" - ) - - count = 0 - for filename in shard_proto_files: - with open(filename, "rb") as f: - for text_data in read_pb_stream(f): - self.groups.append(text_data) - count += 1 - - log.info(f"Read total {count} groups of data") - - # Shuffle the lines - Random(self.seed).shuffle(self.groups) - self.group_weights = [len(i.sentences) for i in self.groups] - - def __iter__(self): - while True: - yield self.augment() - - def tokenize_sentence(self, sentence: str): - sentence = clean_text(sentence) - tokens = self.tokenizer.encode( - f"{sentence}", - max_length=10**6, - add_special_tokens=False, - truncation=False, - ) - return sentence, len(tokens) - - def sample_data(self): - if self.groups is None: - self.init_mock_data_server() - - # Shuffle unique lines, estimate that each sample is at least 20 tokens - num_samples = self.max_length // 20 - - # choice group based on their number of samples - group = random.choices(self.groups, weights=self.group_weights, k=1)[0] - - if self.causal: - # Sample in order - if num_samples >= len(group.sentences): - samples = group.sentences - else: - begin = random.randint(0, len(group.sentences) - num_samples) - samples = group.sentences[begin : begin + num_samples] - else: - samples = random.choices( - group.sentences, k=min(num_samples, len(group.sentences)) - ) - - return SampledData( - source=group.source, - name=group.name, - samples=samples, - ) - - def augment(self): - final_text, final_semantic = [], [] - response = self.sample_data() - if 
len(response.samples) == 0: - # Invalid group - return None - - samples = list(response.samples) - idx = 0 - use_interactive = random.random() < self.interactive_prob - - if use_interactive is False: - # Random sample based on speaker using a truncated normal distribution - a = torch.tensor([0], dtype=torch.float32) - torch.nn.init.trunc_normal_( - a, - mean=self.max_length // 2, - std=self.max_length // 4, - a=10, - b=self.max_length, - ) - remaining_tokens = a.long().item() - 4 - else: - remaining_tokens = self.max_length - - # Use speaker - if isinstance(self.use_speaker, float): - use_speaker = random.random() < self.use_speaker - else: - use_speaker = self.use_speaker - - all_tokens, all_labels = [], [] - while remaining_tokens > 0 and len(samples) > 0: - sentence = samples.pop(0) - - text = random.choice(sentence.texts) - text, length = self.tokenize_sentence(text) - remaining_tokens -= length + len(sentence.semantics[0].values) - - if use_interactive is False: - final_text.append(text) - final_semantic.append(sentence.semantics) - else: - # For interactive mode, we only apply speaker for the first sentence - # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] - tokens, labels = self.pack_sentences( - sentences=[text], - semantics=[sentence.semantics], - speaker=response.name if use_speaker else None, - skip_text=random.random() < self.skip_text_prob, - ) - - all_tokens.append(tokens) - all_labels.append(labels) - - idx += 1 - - if use_interactive is False: - tokens, labels = self.pack_sentences( - final_text, - semantics=final_semantic, - speaker=response.name if use_speaker else None, - ) - all_tokens.append(tokens) - all_labels.append(labels) - - tokens = torch.cat(all_tokens, dim=1) - labels = torch.cat(all_labels, dim=1) - - # Verify that the length is correct - assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" - - data = {"tokens": tokens, "labels": labels} - - return data - - def pack_sentences( - self, - sentences: list[str], - semantics: list, - speaker: Optional[str] = None, - skip_text: bool = False, - ): - if speaker is None: - speaker = "assistant" - - cated_sentences = " ".join(sentences) - if skip_text: - cated_sentences = "<|skip_text|>" - - final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>" - final_text = final_text + f"<|im_start|>{speaker}\n" - - encoded = self.tokenizer.encode( - final_text, - add_special_tokens=False, - truncation=False, - max_length=10**6, - ) - semantic_length = sum([len(i[0].values) for i in semantics]) - prompt_length = len(encoded) - num_codebooks = ( - len(semantics[0]) if self.num_codebooks is None else self.num_codebooks - ) - - # Pack the tokens and semantics (add and to semantic tokens) - tokens = ( - encoded - + [self.semantic_token_id] * semantic_length - + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"]) - ) - - # Codebook bos/padding: 0, eos: 1 - codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)] - for segment in semantics: - for book_idx, book in zip(range(num_codebooks), segment): - for j in book.values: - codes[book_idx].append(int(j) + 1) - - for book in codes: - book.extend([CODEBOOK_PAD_TOKEN_ID] * 1) - - tokens = [tokens] + codes - - tokens = torch.tensor(tokens, dtype=torch.long) - labels = tokens.clone() - - if skip_text: - # If text is not provided, the sentence is used for condition only, all labels are -100 - torch.fill_(labels, -100) - return tokens, labels - - # Mask out the tokens for semantic, predict semantic tokens only - # 
Since we don't mask out the input tokens, the language modeling still works - labels[1:, :prompt_length] = -100 - - tokens = tokens[:, :-1] - labels = labels[:, 1:] - - # Verify the padding is correct, and the last token is eos - assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all() - assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() - - return tokens, labels - - -@dataclass -class TextDataCollator: - tokenizer: AutoTokenizer - max_length: int = 1024 - - def __call__(self, examples): - if "negative_tokens" in examples: - positive_examples = [] - negative_examples = [] - - for i in examples: - positive_examples.append( - { - "tokens": i["tokens"], - "labels": i["labels"], - } - ) - negative_examples.append( - { - "tokens": i["negative_tokens"], - "labels": i["negative_labels"], - } - ) - - examples = positive_examples + negative_examples - - return self.batchify(examples) - - def batchify(self, examples, tokens_key="tokens", labels_key="labels"): - tokens, attention_masks, labels = [], [], [] - - # Calculate the max length - max_tokens_length = 0 - for example in examples: - max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) - max_tokens_length = min(max_tokens_length, self.max_length) - - for example in examples: - _tokens = example[tokens_key][:, :max_tokens_length] - _labels = example[labels_key][:, :max_tokens_length] - _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) - tokens_length = _tokens.size(1) - _attention_mask[:tokens_length] = False - - assert tokens_length == _labels.size( - 1 - ), f"{tokens_length} != {_labels.size(1)}" - - if tokens_length < max_tokens_length: - _tokens = F.pad( - _tokens, - (0, max_tokens_length - tokens_length), - value=self.tokenizer.eos_token_id, - ) - _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID - _labels = F.pad( - _labels, (0, max_tokens_length - _labels.size(1)), value=-100 - ) - - tokens.append(_tokens) - attention_masks.append(_attention_mask) - labels.append(_labels) - - tokens = torch.stack(tokens, dim=0) - attention_masks = torch.stack(attention_masks, dim=0) - labels = torch.stack(labels, dim=0) - - return { - "inputs": tokens, - "attention_masks": attention_masks, - "labels": labels, - } - - -class InterleaveDataset(IterableDataset): - def __init__( - self, - datasets: list[IterableDataset], - probabilities: list[float], - seed: int = 42, - ): - super().__init__() - - self.datasets = datasets - self.probabilities = probabilities - self.seed = seed - - def __iter__(self): - rng = np.random.default_rng(self.seed) - dataset_iterators = [iter(dataset) for dataset in self.datasets] - - while True: - # Random choice one - dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) - dataset_iterator = dataset_iterators[dataset_idx] - - try: - yield next(dataset_iterator) - except StopIteration: - # Exhausted, create a new iterator - dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) - yield next(dataset_iterators[dataset_idx]) - - -class SemanticDataModule(LightningDataModule): - def __init__( - self, - train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], - val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], - batch_size: int = 32, - tokenizer: AutoTokenizer = None, - max_length: int = 1024, - num_workers: int = 4, - ): - super().__init__() - - self.train_dataset = train_dataset - self.val_dataset = val_dataset - self.batch_size = batch_size - self.tokenizer = tokenizer - self.max_length = max_length - 
self.num_workers = num_workers - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - collate_fn=TextDataCollator(self.tokenizer, self.max_length), - num_workers=self.num_workers, - persistent_workers=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - collate_fn=TextDataCollator(self.tokenizer, self.max_length), - num_workers=self.num_workers, - persistent_workers=True, - ) - - -if __name__ == "__main__": - from tqdm import tqdm - - ds = AutoTextSemanticInstructionDataset( - ["data/protos"], - tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), - use_speaker=False, - interactive_prob=1.0, - skip_text_prob=0.5, - ) - - for i in ds: - print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) - # i["labels"][0][i["labels"][0] == -100] = 0 - # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) - break diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py deleted file mode 100644 index a45583d22efb0feb9dc1e823bae1ef74534b299e..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/vqgan.py +++ /dev/null @@ -1,147 +0,0 @@ -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import librosa -import numpy as np -import torch -from lightning import LightningDataModule -from torch.utils.data import DataLoader, Dataset - -from fish_speech.utils import RankedLogger - -logger = RankedLogger(__name__, rank_zero_only=False) - - -class VQGANDataset(Dataset): - def __init__( - self, - filelist: str, - sample_rate: int = 32000, - hop_length: int = 640, - slice_frames: Optional[int] = None, - ): - super().__init__() - - filelist = Path(filelist) - root = filelist.parent - - self.files = [ - root / line.strip() - for line in filelist.read_text(encoding="utf-8").splitlines() - if line.strip() - ] - self.sample_rate = sample_rate - self.hop_length = hop_length - self.slice_frames = slice_frames - - def __len__(self): - return len(self.files) - - def get_item(self, idx): - file = self.files[idx] - - audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) - - # Slice audio and features - if ( - self.slice_frames is not None - and audio.shape[0] > self.slice_frames * self.hop_length - ): - start = np.random.randint( - 0, audio.shape[0] - self.slice_frames * self.hop_length - ) - audio = audio[start : start + self.slice_frames * self.hop_length] - - if len(audio) == 0: - return None - - max_value = np.abs(audio).max() - if max_value > 1.0: - audio = audio / max_value - - return { - "audio": torch.from_numpy(audio), - } - - def __getitem__(self, idx): - try: - return self.get_item(idx) - except Exception as e: - import traceback - - traceback.print_exc() - logger.error(f"Error loading {self.files[idx]}: {e}") - return None - - -@dataclass -class VQGANCollator: - def __call__(self, batch): - batch = [x for x in batch if x is not None] - - audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) - audio_maxlen = audio_lengths.max() - - # Rounds up to nearest multiple of 2 (audio_lengths) - audios = [] - for x in batch: - audios.append( - torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) - ) - - return { - "audios": torch.stack(audios), - "audio_lengths": audio_lengths, - } - - -class VQGANDataModule(LightningDataModule): - def __init__( - self, - train_dataset: VQGANDataset, - val_dataset: VQGANDataset, - batch_size: int = 32, - num_workers: int = 4, - 
val_batch_size: Optional[int] = None, - ): - super().__init__() - - self.train_dataset = train_dataset - self.val_dataset = val_dataset - self.batch_size = batch_size - self.val_batch_size = val_batch_size or batch_size - self.num_workers = num_workers - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - collate_fn=VQGANCollator(), - num_workers=self.num_workers, - shuffle=True, - persistent_workers=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, - batch_size=self.val_batch_size, - collate_fn=VQGANCollator(), - num_workers=self.num_workers, - persistent_workers=True, - ) - - -if __name__ == "__main__": - dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") - dataloader = DataLoader( - dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() - ) - - for batch in dataloader: - print(batch["audios"].shape) - print(batch["features"].shape) - print(batch["audio_lengths"]) - print(batch["feature_lengths"]) - break diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md deleted file mode 100644 index 700902b09db20911ef1ad678cbdce5644b84aea2..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/README.md +++ /dev/null @@ -1,27 +0,0 @@ -## i18n Folder Attribution - -The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: - -### fish_speech/i18n/core.py - -**Related code from RVC:** -[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) - -**Initial commit:** -add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) - -**Initial author:** -[@L4Ph](https://github.com/L4Ph) - -### fish_speech/i18n/scan.py - -**Related code from RVC:** -[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) - -**Initial commit:** -File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) - -**Initial author:** -[@towzeur](https://github.com/towzeur) - -We appreciate the contributions of the RVC project and its authors. 
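The removed text_data_stream.py helpers above frame each protobuf message as a struct-packed unsigned-int byte count followed by the serialized payload. A minimal, self-contained sketch of the same framing, using raw bytes in place of the generated TextData messages so it runs without text_data_pb2:

import io
import struct


def write_frame(f, payload: bytes) -> None:
    # Same framing as write_pb_stream(): a native unsigned-int byte count,
    # then the payload itself.
    f.write(struct.pack("I", len(payload)))
    f.write(payload)


def read_frames(f):
    # Same loop as read_pb_stream(): an empty header read signals end of stream.
    while True:
        head = f.read(4)
        if len(head) == 0:
            break
        (size,) = struct.unpack("I", head)
        yield f.read(size)


stream = io.BytesIO()
for payload in (b"first message", b"second message"):
    write_frame(stream, payload)
stream.seek(0)
assert list(read_frames(stream)) == [b"first message", b"second message"]

Note that "I" with no prefix uses native byte order and size, so files written this way are not portable across machines with different endianness; the original helpers share that property.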
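Likewise, the index arithmetic in the removed ConcatRepeatDataset is easy to sanity-check in isolation. A sketch with two tiny list-backed stand-in datasets (hypothetical data; the cumulative-size and bisect logic mirror the deleted class):

import bisect

# Stand-in datasets and repeat counts (hypothetical values for illustration).
datasets = [["a0", "a1"], ["b0", "b1", "b2"]]
repeats = [2, 1]

# Same cumulative-size construction as ConcatRepeatDataset.cumsum().
cumulative_sizes = []
running = 0
for dataset, repeat in zip(datasets, repeats):
    running += len(dataset) * repeat
    cumulative_sizes.append(running)  # -> [4, 7]


def get_item(idx: int):
    # Same lookup as ConcatRepeatDataset.__getitem__(): locate the dataset
    # with bisect, then wrap the local index with modulo to realize repeats.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    dataset = datasets[dataset_idx]
    return dataset[sample_idx % len(dataset)]


# The first dataset is traversed twice before the second one starts.
assert [get_item(i) for i in range(7)] == ["a0", "a1", "a0", "a1", "b0", "b1", "b2"]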
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py deleted file mode 100644 index 981dbb3b3ecf28043ec9ff5757f947182821a246..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .core import i18n - -__all__ = ["i18n"] diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index ba5a935b26a69595794d6840da906e6615c3a52f..0000000000000000000000000000000000000000 Binary files a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/i18n/__pycache__/core.cpython-310.pyc b/fish_speech/i18n/__pycache__/core.cpython-310.pyc deleted file mode 100644 index 66d2787af00a38bc8ffebb84ed30565a71e94b01..0000000000000000000000000000000000000000 Binary files a/fish_speech/i18n/__pycache__/core.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py deleted file mode 100644 index 9f793ec95669228f7f4e8f9a7a5fe38da85c74bd..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/core.py +++ /dev/null @@ -1,40 +0,0 @@ -import json -import locale -from pathlib import Path - -I18N_FILE_PATH = Path(__file__).parent / "locale" -DEFAULT_LANGUAGE = "en_US" - - -def load_language_list(language): - with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: - language_list = json.load(f) - - return language_list - - -class I18nAuto: - def __init__(self): - i18n_file = Path(".locale") - - if i18n_file.exists(): - with open(i18n_file, "r", encoding="utf-8") as f: - language = f.read().strip() - else: - # getlocale can't identify the system's language ((None, None)) - language = locale.getdefaultlocale()[0] - - if (I18N_FILE_PATH / f"{language}.json").exists() is False: - language = DEFAULT_LANGUAGE - - self.language = language - self.language_map = load_language_list(language) - - def __call__(self, key): - return self.language_map.get(key, key) - - def __repr__(self): - return "Use Language: " + self.language - - -i18n = I18nAuto() diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json deleted file mode 100644 index 6e280c236e9c79de2087ec33c7bf6f8e1a5296c4..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/locale/en_US.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Accumulate Gradient Batches", - "Add to Processing Area": "Add to Processing Area", - "Added path successfully!": "Added path successfully!", - "Advanced Config": "Advanced Config", - "Base LLAMA Model": "Base LLAMA Model", - "Batch Inference": "Batch Inference", - "Batch Size": "Batch Size", - "Changing with the Model Path": "Changing with the Model Path", - "Chinese": "Chinese", - "Compile Model": "Compile Model", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", - "Copy": "Copy", - "Data Preprocessing": "Data Preprocessing", - 
"Data Preprocessing Path": "Data Preprocessing Path", - "Data Source": "Data Source", - "Decoder Model Config": "Decoder Model Config", - "Decoder Model Path": "Decoder Model Path", - "Disabled": "Disabled", - "Enable Reference Audio": "Enable Reference Audio", - "English": "English", - "Error Message": "Error Message", - "File Preprocessing": "File Preprocessing", - "Generate": "Generate", - "Generated Audio": "Generated Audio", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", - "Infer interface is closed": "Infer interface is closed", - "Inference Configuration": "Inference Configuration", - "Inference Server Configuration": "Inference Server Configuration", - "Inference Server Error": "Inference Server Error", - "Inferring interface is launched at {}": "Inferring interface is launched at {}", - "Initial Learning Rate": "Initial Learning Rate", - "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", - "Input Text": "Input Text", - "Invalid path: {}": "Invalid path: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", - "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", - "Japanese": "Japanese", - "LLAMA Configuration": "LLAMA Configuration", - "LLAMA Model Config": "LLAMA Model Config", - "LLAMA Model Path": "LLAMA Model Path", - "Labeling Device": "Labeling Device", - "LoRA Model to be merged": "LoRA Model to be merged", - "Maximum Audio Duration": "Maximum Audio Duration", - "Maximum Length per Sample": "Maximum Length per Sample", - "Maximum Training Steps": "Maximum Training Steps", - "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", - "Merge": "Merge", - "Merge LoRA": "Merge LoRA", - "Merge successfully": "Merge successfully", - "Minimum Audio Duration": "Minimum Audio Duration", - "Model Output Path": "Model Output Path", - "Model Size": "Model Size", - "Move": "Move", - "Move files successfully": "Move files successfully", - "No audio generated, please check the input text.": "No audio generated, please check the input text.", - "No selected options": "No selected options", - "Number of Workers": "Number of Workers", - "Open Inference Server": "Open Inference Server", - "Open Labeler WebUI": "Open Labeler WebUI", - "Open Tensorboard": "Open Tensorboard", - "Opened labeler in browser": "Opened labeler in browser", - "Optional Label Language": "Optional Label Language", - "Optional online ver": "Optional online ver", - "Output Path": "Output Path", - "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", - "Precision": "Precision", - "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", - "Put your text here.": "Put your text here.", - "Reference Audio": "Reference Audio", - "Reference Text": "Reference Text", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", - "Remove Selected Data": "Remove Selected Data", - "Removed path successfully!": "Removed path successfully!", - "Repetition Penalty": "Repetition Penalty", - "Save model every n steps": "Save model every n steps", - "Select LLAMA 
ckpt": "Select LLAMA ckpt", - "Select VITS ckpt": "Select VITS ckpt", - "Select VQGAN ckpt": "Select VQGAN ckpt", - "Select source file processing method": "Select source file processing method", - "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", - "Selected: {}": "Selected: {}", - "Speaker": "Speaker", - "Speaker is identified by the folder name": "Speaker is identified by the folder name", - "Start Training": "Start Training", - "Streaming Audio": "Streaming Audio", - "Streaming Generate": "Streaming Generate", - "Tensorboard Host": "Tensorboard Host", - "Tensorboard Log Path": "Tensorboard Log Path", - "Tensorboard Port": "Tensorboard Port", - "Tensorboard interface is closed": "Tensorboard interface is closed", - "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", - "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.", - "Training Configuration": "Training Configuration", - "Training Error": "Training Error", - "Training stopped": "Training stopped", - "Type name of the speaker": "Type name of the speaker", - "Type the path or select from the dropdown": "Type the path or select from the dropdown", - "Use LoRA": "Use LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", - "Use filelist": "Use filelist", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", - "VITS Configuration": "VITS Configuration", - "VQGAN Configuration": "VQGAN Configuration", - "Validation Batch Size": "Validation Batch Size", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", - "WebUI Host": "WebUI Host", - "WebUI Port": "WebUI Port", - "Whisper Model": "Whisper Model", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", - "latest": "latest", - "new": "new", - "Realtime Transform Text": "Realtime Transform Text", - "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", - "Text Normalization": "Text Normalization" -} diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json deleted file mode 100644 index 3285341f6893fe3e2ccbee6490dd8c90ed21854e..0000000000000000000000000000000000000000 --- 
a/fish_speech/i18n/locale/es_ES.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Acumular lotes de gradientes", - "Add to Processing Area": "Agregar al Área de Procesamiento", - "Added path successfully!": "¡Ruta agregada exitosamente!", - "Advanced Config": "Configuración Avanzada", - "Base LLAMA Model": "Modelo Base LLAMA", - "Batch Inference": "Inferencia por Lote", - "Batch Size": "Tamaño del Lote", - "Changing with the Model Path": "Cambiando con la Ruta del Modelo", - "Chinese": "Chino", - "Compile Model": "Compilar Modelo", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío", - "Copy": "Copiar", - "Data Preprocessing": "Preprocesamiento de Datos", - "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", - "Data Source": "Fuente de Datos", - "Decoder Model Config": "Configuración del modelo decodificador", - "Decoder Model Path": "Ruta del modelo decodificador", - "Disabled": "Desactivado", - "Enable Reference Audio": "Habilitar Audio de Referencia", - "English": "Inglés", - "Error Message": "Mensaje de Error", - "File Preprocessing": "Preprocesamiento de Archivos", - "Generate": "Generar", - "Generated Audio": "Audio Generado", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", - "Infer interface is closed": "La interfaz de inferencia está cerrada", - "Inference Configuration": "Configuración de Inferencia", - "Inference Server Configuration": "Configuración del Servidor de Inferencia", - "Inference Server Error": "Error del Servidor de Inferencia", - "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", - "Initial Learning Rate": "Tasa de Aprendizaje Inicial", - "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", - "Input Text": "Texto de Entrada", - "Invalid path: {}": "Ruta inválida: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", - "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", - "Japanese": "Japonés", - "LLAMA Configuration": "Configuración de LLAMA", - "LLAMA Model Config": "Configuración del Modelo LLAMA", - "LLAMA Model Path": "Ruta del Modelo LLAMA", - "Labeling Device": "Dispositivo de Etiquetado", - "LoRA Model to be merged": "Modelo LoRA a fusionar", - "Maximum Audio Duration": "Duración máxima de audio", - "Maximum Length per Sample": "Longitud Máxima por Muestra", - "Maximum Training Steps": "Pasos Máximos de Entrenamiento", - "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", - "Merge": "Fusionar", - "Merge LoRA": "Fusionar LoRA", - "Merge successfully": "Fusionado 
exitosamente", - "Minimum Audio Duration": "Duración mínima de audio", - "Model Output Path": "Ruta de Salida del Modelo", - "Model Size": "Tamaño del Modelo", - "Move": "Mover", - "Move files successfully": "Archivos movidos exitosamente", - "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", - "No selected options": "No hay opciones seleccionadas", - "Number of Workers": "Número de Trabajadores", - "Open Inference Server": "Abrir Servidor de Inferencia", - "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", - "Open Tensorboard": "Abrir Tensorboard", - "Opened labeler in browser": "Se abrió el etiquetador en el navegador", - "Optional Label Language": "Idioma de Etiquetado Opcional", - "Optional online ver": "Ver en línea opcional", - "Output Path": "Ruta de Salida", - "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", - "Precision": "Precisión", - "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", - "Put your text here.": "Ponga su texto aquí.", - "Reference Audio": "Audio de Referencia", - "Reference Text": "Texto de Referencia", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", - "Remove Selected Data": "Eliminar Datos Seleccionados", - "Removed path successfully!": "¡Ruta eliminada exitosamente!", - "Repetition Penalty": "Penalización por Repetición", - "Save model every n steps": "Guardar modelo cada n pasos", - "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", - "Select VITS ckpt": "Seleccionar punto de control VITS", - "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", - "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", - "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", - "Selected: {}": "Seleccionado: {}", - "Speaker": "Hablante", - "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", - "Start Training": "Iniciar Entrenamiento", - "Streaming Audio": "transmisión de audio", - "Streaming Generate": "síntesis en flujo", - "Tensorboard Host": "Host de Tensorboard", - "Tensorboard Log Path": "Ruta de Registro de Tensorboard", - "Tensorboard Port": "Puerto de Tensorboard", - "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", - "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", - "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", - "Training Configuration": "Configuración de Entrenamiento", - "Training Error": "Error de Entrenamiento", - "Training stopped": "Entrenamiento detenido", - "Type name of the speaker": "Escriba el nombre del hablante", - "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", - "Use LoRA": "Usar LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", - "Use filelist": "Usar lista de archivos", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", - "VITS Configuration": "Configuración de VITS", - "VQGAN Configuration": "Configuración de VQGAN", - "Validation Batch Size": "Tamaño del Lote de Validación", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", - "WebUI Host": "Host de WebUI", - "WebUI Port": "Puerto de WebUI", - "Whisper Model": "Modelo Whisper", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", - "latest": "más reciente", - "new": "nuevo", - "Realtime Transform Text": "Transformación de Texto en Tiempo Real", - "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", - "Text Normalization": "Normalización de Texto" -} diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json deleted file mode 100644 index d30bac7bcdf4f4c65b1f78b4dcf9d705c1d8eb39..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/locale/ja_JP.json +++ /dev/null @@ -1,123 +0,0 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", - "Accumulate Gradient Batches": "勾配バッチの累積", - "Add to Processing Area": "処理エリアに追加", - "Added path successfully!": "パスの追加に成功しました!", - "Advanced Config": "詳細設定", - "Base LLAMA Model": "基本LLAMAモデル", - "Batch Inference": "バッチ推論", - "Batch Size": "バッチサイズ", - "Changing with the Model Path": "モデルのパスに伴って変化する", - "Chinese": "中国語", - "Compile Model": "モデルのコンパイル", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", - "Copy": "コピー", - "Data Preprocessing": "データ前処理", - "Data Preprocessing Path": "データ前処理パス", - "Data Source": "データソース", - 
"Decoder Model Config": "デコーダーモデルの構成", - "Decoder Model Path": "デコーダーモデルのパス", - "Disabled": "無効", - "Enable Reference Audio": "リファレンスオーディオを有効にする", - "English": "英語", - "Error Message": "エラーメッセージ", - "File Preprocessing": "文書前处理", - "Generate": "生成", - "Generated Audio": "生成されたオーディオ", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", - "Infer interface is closed": "推論インターフェースが閉じられています", - "Inference Configuration": "推論設定", - "Inference Server Configuration": "推論サーバー設定", - "Inference Server Error": "推論サーバーエラー", - "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", - "Initial Learning Rate": "初期学習率", - "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", - "Input Text": "入力テキスト", - "Invalid path: {}": "無効なパス: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", - "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", - "Japanese": "日本語", - "LLAMA Configuration": "LLAMA設定", - "LLAMA Model Config": "LLAMAモデル設定", - "LLAMA Model Path": "LLAMAモデルパス", - "Labeling Device": "ラベリングデバイス", - "LoRA Model to be merged": "マージするLoRAモデル", - "Maximum Audio Duration": "最大オーディオの長さ", - "Maximum Length per Sample": "サンプルあたりの最大長", - "Maximum Training Steps": "最大トレーニングステップ数", - "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", - "Merge": "マージ", - "Merge LoRA": "LoRAのマージ", - "Merge successfully": "マージに成功しました", - "Minimum Audio Duration": "最小オーディオの長さ", - "Model Output Path": "モデル出力パス", - "Model Size": "モデルサイズ", - "Move": "移動", - "Move files successfully": "ファイルの移動に成功しました", - "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", - "No selected options": "選択されたオプションはありません", - "Number of Workers": "ワーカー数", - "Open Inference Server": "推論サーバーを開く", - "Open Labeler WebUI": "ラベラーWebUIを開く", - "Open Tensorboard": "Tensorboardを開く", - "Opened labeler in browser": "ブラウザでラベラーを開きました", - "Optional Label Language": "オプションのラベル言語", - "Optional online ver": "オプションのオンラインバージョン", - "Output Path": "出力パス", - "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", - "Precision": "精度", - "Probability of applying Speaker Condition": "話者条件を適用する確率", - "Put your text here.": "ここにテキストを入力してください。", - "Reference Audio": "リファレンスオーディオ", - "Reference Text": "リファレンステキスト", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", - "Remove Selected Data": "選択したデータを削除", - "Removed path successfully!": "パスの削除に成功しました!", - "Repetition Penalty": "反復ペナルティ", - "Save model every n steps": "nステップごとにモデルを保存", - "Select LLAMA ckpt": " LLAMA チェックポイントを選択", - "Select VITS ckpt": "VITS チェックポイントを選択", - "Select VQGAN ckpt": "VQGAN チェックポイントを選択", - "Select source file processing method": "ソースファイルの処理方法を選択", - "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", - "Selected: {}": "選択済み: {}", - "Speaker": "話者", - "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", - "Start Training": "トレーニング開始", - "Streaming Audio": "ストリーミングオーディオ", - "Streaming Generate": "ストリーミング合成", - "Tensorboard Host": "Tensorboardホスト", - "Tensorboard Log Path": "Tensorboardログパス", - "Tensorboard Port": "Tensorboardポート", - "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", - "Tensorboard interface is launched at {}": 
"Tensorboardインターフェースが{}で起動されました", - "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", - "Training Configuration": "トレーニング設定", - "Training Error": "トレーニングエラー", - "Training stopped": "トレーニングが停止しました", - "Type name of the speaker": "話者の名前を入力", - "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", - "Use LoRA": "LoRAを使用", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", - "Use filelist": "ファイルリストを使用", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", - "VITS Configuration": "VITS の構成", - "VQGAN Configuration": "VQGAN の構成", - "Validation Batch Size": "検証バッチサイズ", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", - "WebUI Host": "WebUIホスト", - "WebUI Port": "WebUIポート", - "Whisper Model": "Whisperモデル", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", - "latest": "最新", - "new": "新規", - "Realtime Transform Text": "リアルタイム変換テキスト", - "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", - "Text Normalization": "テキスト正規化" - -} diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json deleted file mode 100644 index 385f20272e19053ab9b6cf6463a84c8ece768c68..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/locale/pt_BR.json +++ /dev/null @@ -1,133 +0,0 @@ -{ - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", - "Add to Processing Area": "Adicionar à Área de Processamento", - "Added path successfully!": "Caminho adicionado com sucesso!", - "Advanced Config": "Configuração Avançada", - "Base LLAMA Model": "Modelo LLAMA Base", - "Batch Inference": "Inferência em Lote", - "Batch Size": "Tamanho do Lote", - "Changing with the Model Path": "Alterando com o Caminho do Modelo", - - "Compile Model": "Compilar Modelo", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", - "Copy": "Copiar", - "Data Preprocessing": "Pré-processamento de Dados", - "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", - "Data Source": "Fonte de Dados", - "Decoder Model Config": "Configuração do Modelo Decodificador", - "Decoder Model Path": 
"Caminho do Modelo Decodificador", - "Disabled": "Desativado", - "Enable Initial Prompt": "Habilitar Prompt Inicial", - "Enable Reference Audio": "Habilitar Áudio de Referência", - "English": "Inglês", - "Japanese": "Japonês", - "Chinese": "Chinês", - "Portuguese": "Português", - "Spanish": "Espanhol", - "Error Message": "Mensagem de Erro", - "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", - "File Preprocessing": "Pré-processamento de Arquivos", - "Generate": "Gerar", - "Generated Audio": "Áudio Gerado", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", - "Infer interface is closed": "A interface de inferência foi fechada", - "Inference Configuration": "Configuração de Inferência", - "Inference Server Configuration": "Configuração do Servidor de Inferência", - "Inference Server Error": "Erro do Servidor de Inferência", - "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}", - "Initial Learning Rate": "Taxa de Aprendizagem Inicial", - "Initial Prompt": "Prompt Inicial", - "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", - "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", - "Input Text": "Texto de Entrada", - "Invalid path: {}": "Caminho inválido: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", - "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", - "LLAMA Configuration": "Configuração do LLAMA", - "LLAMA Model Config": "Configuração do Modelo LLAMA", - "LLAMA Model Path": "Caminho do Modelo LLAMA", - "Labeling Device": "Dispositivo de Rotulagem", - "LoRA Model to be merged": "Modelo LoRA para mesclagem", - "Maximum Length per Sample": "Comprimento Máximo por Amostra", - "Maximum Training Steps": "Etapas Máximas de Treinamento", - "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", - "Merge": "Mesclar", - "Merge LoRA": "Mesclar LoRA", - "Merge successfully": "Mesclado com sucesso", - "Model Output Path": "Caminho de Saída do Modelo", - "Model Quantization": "Quantização do Modelo", - "Model Size": "Tamanho do Modelo", - "Move": "Mover", - "Move files successfully": "Arquivos movidos com sucesso", - "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", - "No selected options": "Nenhuma opção selecionada", - "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", - "Number of Workers": "Número de Processos", - "Open Inference Server": "Abrir Servidor de Inferência", - "Open Labeler WebUI": "Abrir WebUI de Rotulagem", - "Open Tensorboard": "Abrir Tensorboard", - "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", - "Optional Label Language": "Idioma do Rótulo (Opcional)", - "Optional online ver": "Versão online (opcional)", - "Output Path": "Caminho de Saída", - "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", - 
"Post-quantification Precision": "Precisão Pós-quantização", - "Precision": "Precisão", - "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", - "Put your text here.": "Insira seu texto aqui.", - "Quantify": "Quantizar", - "Quantify successfully": "Quantizado com sucesso", - "Realtime Transform Text": "Transformar Texto em Tempo Real", - "Reference Audio": "Áudio de Referência", - "Reference Text": "Texto de Referência", - "warning": "Aviso", - "Pre-processing begins...": "O pré-processamento começou!", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", - "Remove Selected Data": "Remover Dados Selecionados", - "Removed path successfully!": "Caminho removido com sucesso!", - "Repetition Penalty": "Penalidade de Repetição", - "Save model every n steps": "Salvar modelo a cada n etapas", - "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", - "Select source file processing method": "Escolha como processar o arquivo de origem", - "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", - "Selected: {}": "Selecionado: {}", - "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", - "Start Training": "Iniciar Treinamento", - "Streaming Audio": "Áudio em Streaming", - "Streaming Generate": "Geração em Streaming", - "Tensorboard Host": "Host do Tensorboard", - "Tensorboard Log Path": "Caminho de Log do Tensorboard", - "Tensorboard Port": "Porta do Tensorboard", - "Tensorboard interface is closed": "A interface do Tensorboard está fechada", - "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", - "Text Normalization": "Normalização de Texto", - "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", - "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", - "Training Configuration": "Configuração de Treinamento", - "Training Error": "Erro de Treinamento", - "Training stopped": "Treinamento interrompido!", - "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", - "Use LoRA": "Usar LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", - "Use filelist": "Usar lista de arquivos", - "VQGAN Configuration": "Configuração do VQGAN", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", - "WebUI Host": "Host da WebUI", - "WebUI Port": "Porta da WebUI", - "Whisper Model": "Modelo Whisper", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", - "auto": "automático", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", - "latest": "mais recente", - "new": "novo", - "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", - "You don't need to train this model!": "Não é necessário treinar este modelo!", - "Yes": "Sim", - "No": "Não", - "version:": "versão:", - "author:": "autor:" -} diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json deleted file mode 100644 index 3dd1a5cd1ccf3860ca508238cc64a68ca4fc3276..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/locale/zh_CN.json +++ /dev/null @@ -1,122 +0,0 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", - "Accumulate Gradient Batches": "梯度累积批次", - "Add to Processing Area": "加入处理区", - "Added path successfully!": "添加路径成功!", - "Advanced Config": "高级参数", - "Base LLAMA Model": "基础 LLAMA 模型", - "Batch Inference": "批量推理", - "Batch Size": "批次大小", - "Changing with the Model Path": "随模型路径变化", - "Chinese": "中文", - "Compile Model": "编译模型", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", - "Copy": "复制", - "Data Preprocessing": "数据预处理", - "Data Preprocessing Path": "数据预处理路径", - "Data Source": "数据源", - "Decoder Model Config": "解码器模型配置", - "Decoder Model Path": "解码器模型路径", - "Disabled": "禁用", - "Enable Reference Audio": "启用参考音频", - "English": "英文", - "Error Message": "错误信息", - "File Preprocessing": "文件预处理", - "Generate": "生成", - "Generated Audio": "音频", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", - "Infer interface is closed": "推理界面已关闭", - "Inference Configuration": "推理配置", - "Inference Server Configuration": "推理服务器配置", - "Inference Server Error": "推理服务器错误", - "Inferring interface is launched at {}": "推理界面已在 {} 上启动", - "Initial Learning Rate": "初始学习率", - "Input Audio & Source Path for Transcription": "输入音频和转录源路径", - "Input Text": "输入文本", - "Invalid path: {}": "无效路径: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", - "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", - "Japanese": "日文", - "LLAMA Configuration": "LLAMA 配置", - "LLAMA Model Config": "LLAMA 模型配置", - "LLAMA Model Path": "LLAMA 模型路径", - "Labeling Device": "标注加速设备", - "LoRA Model to be merged": "要合并的 LoRA 模型", - "Maximum Audio Duration": "最大音频时长", - "Maximum Length per Sample": 
"每个样本的最大长度", - "Maximum Training Steps": "最大训练步数", - "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", - "Merge": "合并", - "Merge LoRA": "合并 LoRA", - "Merge successfully": "合并成功", - "Minimum Audio Duration": "最小音频时长", - "Model Output Path": "模型输出路径", - "Model Size": "模型规模", - "Move": "移动", - "Move files successfully": "移动文件成功", - "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", - "No selected options": "没有选择的选项", - "Number of Workers": "数据加载进程数", - "Open Inference Server": "打开推理服务器", - "Open Labeler WebUI": "打开标注工具", - "Open Tensorboard": "打开 Tensorboard", - "Opened labeler in browser": "在浏览器中打开标注工具", - "Optional Label Language": "[可选] 标注语言", - "Optional online ver": "[可选] 使用在线版", - "Output Path": "输出路径", - "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", - "Precision": "精度", - "Probability of applying Speaker Condition": "应用说话人条件的概率", - "Put your text here.": "在此处输入文本.", - "Reference Audio": "参考音频", - "Reference Text": "参考文本", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", - "Remove Selected Data": "移除选中数据", - "Removed path successfully!": "移除路径成功!", - "Repetition Penalty": "重复惩罚", - "Save model every n steps": "每 n 步保存模型", - "Select LLAMA ckpt": "选择 LLAMA 检查点", - "Select VITS ckpt": "选择 VITS 检查点", - "Select VQGAN ckpt": "选择 VQGAN 检查点", - "Select source file processing method": "选择源文件处理方法", - "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", - "Selected: {}": "已选择: {}", - "Speaker": "说话人", - "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", - "Start Training": "开始训练", - "Streaming Audio": "流式音频", - "Streaming Generate": "流式合成", - "Tensorboard Host": "Tensorboard 监听地址", - "Tensorboard Log Path": "Tensorboard 日志路径", - "Tensorboard Port": "Tensorboard 端口", - "Tensorboard interface is closed": "Tensorboard 界面已关闭", - "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", - "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", - "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", - "Training Configuration": "训练配置", - "Training Error": "训练错误", - "Training stopped": "训练已停止", - "Type name of the speaker": "输入说话人的名称", - "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", - "Use LoRA": "使用 LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", - "Use filelist": "使用文件列表", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", - "VITS Configuration": "VITS 配置", - "VQGAN Configuration": "VQGAN 配置", - "Validation Batch Size": "验证批次大小", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", - "WebUI Host": "WebUI 监听地址", - "WebUI Port": "WebUI 端口", - "Whisper Model": "Whisper 模型", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", - "latest": "最近的检查点", - "new": "创建新的检查点", - "Realtime Transform Text": "实时规范化文本", - "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览 (目前仅支持中文)", - "Text Normalization": "文本规范化" -} diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py deleted file mode 100644 index d0194c0f1a31dc95309c64626d13f04751a44ba1..0000000000000000000000000000000000000000 --- a/fish_speech/i18n/scan.py +++ /dev/null @@ -1,122 +0,0 @@ -import ast -import glob -import json -from collections import OrderedDict -from pathlib import Path - -from loguru import logger - -from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH - - -def extract_i18n_strings(node): - i18n_strings = [] - - if ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Name) - and node.func.id == "i18n" - ): - for arg in node.args: - if isinstance(arg, ast.Str): - i18n_strings.append(arg.s) - - for child_node in ast.iter_child_nodes(node): - i18n_strings.extend(extract_i18n_strings(child_node)) - - return i18n_strings - - -# scan the directory for all .py files (recursively) -# for each file, parse the code into an AST -# for each AST, extract the i18n strings - -strings = [] -folders = ["fish_speech", "tools"] -# for filename in glob.iglob("**/*.py", recursive=True): -for folder in folders: - for f in Path(folder).rglob("*.py"): - code = f.read_text(encoding="utf-8") - if "i18n(" in code: - tree = ast.parse(code) - i18n_strings = extract_i18n_strings(tree) - logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") - strings.extend(i18n_strings) - -code_keys = set(strings) -logger.info(f"Total unique: {len(code_keys)}") - - -standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" -with open(standard_file, "r", encoding="utf-8") as f: - standard_data = json.load(f, object_pairs_hook=OrderedDict) -standard_keys = set(standard_data.keys()) - -# Find keys in the standard file that are no longer used in the code -unused_keys = standard_keys - code_keys -logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") -for unused_key in unused_keys: - logger.info(f"\t{unused_key}") -
-missing_keys = code_keys - standard_keys -logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") -for missing_key in missing_keys: - logger.info(f"\t{missing_key}") - -code_keys_dict = OrderedDict() -for s in strings: - code_keys_dict[s] = s - -# write back -with open(standard_file, "w", encoding="utf-8") as f: - json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) - f.write("\n") - -logger.info(f"Updated {standard_file}") - - -# Define the standard file name -standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" - -# Find all JSON files in the directory -dir_path = I18N_FILE_PATH -languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] - -# Load the standard file -with open(standard_file, "r", encoding="utf-8") as f: - standard_data = json.load(f, object_pairs_hook=OrderedDict) - -# Loop through each language file -for lang_file in languages: - # Load the language file - with open(lang_file, "r", encoding="utf-8") as f: - lang_data = json.load(f, object_pairs_hook=OrderedDict) - - # Find keys present in the standard file but missing from the language file - diff = set(standard_data.keys()) - set(lang_data.keys()) - - # Find extra keys in the language file that are no longer in the standard file - miss = set(lang_data.keys()) - set(standard_data.keys()) - - # Add any missing keys to the language file - for key in diff: - lang_data[key] = "#!" + key - logger.info(f"Added missing key: {key} to {lang_file}") - - # Delete any extra keys from the language file - for key in miss: - del lang_data[key] - logger.info(f"Deleted extra key: {key} from {lang_file}") - - # Sort the keys of the language file to match the order of the standard file - lang_data = OrderedDict( - sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) - ) - - # Save the updated language file - with open(lang_file, "w", encoding="utf-8") as f: - json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) - f.write("\n") - - logger.info(f"Updated {lang_file}") - -logger.info("Done") diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 2660e31d27b749e906716f846ad0303f28c5d3ae..0000000000000000000000000000000000000000 Binary files a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc deleted file mode 100644 index 10287acff4182fbf2964bed0c6512b752fe087bc..0000000000000000000000000000000000000000 Binary files a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc deleted file mode 100644 index c46a1595d473a5422c0f1a526faf162604c66191..0000000000000000000000000000000000000000 Binary files a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc deleted file mode 100644 index
277545bc846fa08418ba7e846e38c006877bf95d..0000000000000000000000000000000000000000 Binary files a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py deleted file mode 100644 index df970400f8a073be4c4166a697245fabdf6b09b0..0000000000000000000000000000000000000000 --- a/fish_speech/models/text2semantic/lit_module.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import Any, Optional - -import lightning as L -import torch -import torch.nn.functional as F -from lightning.pytorch.utilities.types import OptimizerLRScheduler - -import fish_speech.utils as utils -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID -from fish_speech.models.text2semantic.llama import NaiveTransformer - -log = utils.RankedLogger(__name__, rank_zero_only=True) - - -class TextToSemantic(L.LightningModule): - def __init__( - self, - model: NaiveTransformer, - optimizer: Any, - lr_scheduler: Any, - ): - super().__init__() - - self.model = model - self.optimizer_builder = optimizer - self.lr_scheduler_builder = lr_scheduler - - def forward(self, x): - return self.model(x) - - def on_save_checkpoint(self, checkpoint): - # Save only LoRA parameters - state_dict = checkpoint["state_dict"] - use_lora = any("lora" in name for name in state_dict.keys()) - if not use_lora: - return - - for name in list(state_dict.keys()): - if "lora" not in name: - state_dict.pop(name) - - def configure_optimizers(self) -> OptimizerLRScheduler: - # Get weight decay parameters - weight_decay_parameters, other_parameters = [], [] - for name, param in self.named_parameters(): - if ".bias" in name or "norm.weight" in name or ".embeddings." in name: - other_parameters.append(param) - else: - weight_decay_parameters.append(param) - - optimizer = self.optimizer_builder( - [ - {"params": weight_decay_parameters}, - {"params": other_parameters, "weight_decay": 0.0}, - ] - ) - - # Print the parameters and their weight decay - for i in optimizer.param_groups: - log.info( - f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" - ) - - lr_scheduler = self.lr_scheduler_builder(optimizer) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": "step", - }, - } - - # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 - def get_batch_logps( - self, - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) - average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
- """ - assert logits.shape[:-1] == labels.shape - - labels = labels.clone() - loss_mask = labels != -100 - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == -100] = 0 - - per_token_logps = torch.gather( - logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) - ).squeeze(-1) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def _step(self, batch, batch_idx, stage: str): - is_train = stage == "train" - - if is_train: - # Key part to make lora work - # Otherwise the parameters are merged, which lead to incorrect gradients - self.model.train() - - # Do positive and negative samples in the same batch to speed up training - labels = batch["labels"] - outputs = self.model( - inp=batch["inputs"], - key_padding_mask=batch["attention_masks"], - ) - token_logits = outputs.token_logits - codebook_logits = outputs.codebook_logits - - # Generate labels - base_loss = F.cross_entropy( - token_logits.view(-1, token_logits.size(-1)), - labels[:, 0].reshape(-1), - ignore_index=-100, - ) - - codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT - semantic_loss = F.cross_entropy( - codebook_logits.view(-1, codebook_logits.size(-1)), - codebook_labels.reshape(-1), - ignore_index=-100, - ) - - loss = base_loss + semantic_loss - - self.log( - f"{stage}/loss", - loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=True, - logger=True, - sync_dist=not is_train, - ) - - self.log( - f"{stage}/base_loss", - base_loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=False, - logger=True, - sync_dist=not is_train, - ) - - self.log( - f"{stage}/semantic_loss", - semantic_loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=False, - logger=True, - sync_dist=not is_train, - ) - - # Top-5 accuracy - accuracy = self.get_accuracy(codebook_logits, codebook_labels) - self.log( - f"{stage}/top_5_accuracy", - accuracy, - on_step=is_train, - on_epoch=not is_train, - prog_bar=True, - logger=True, - sync_dist=not is_train, - ) - - return loss - - def get_accuracy(self, logits, labels): - mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) - if mask.sum() == 0: - return torch.tensor(0.0, device=logits.device) - - _, indices = logits.topk(5, dim=-1) - correct = indices.eq(labels.unsqueeze(-1)) - correct[~mask] = 0 - correct = correct.sum() - accuracy = correct / mask.sum() - - return accuracy - - def training_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "train") - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "val") diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py deleted file mode 100644 index 0725dfb9b78b1154753641b69c959a2faadba48c..0000000000000000000000000000000000000000 --- a/fish_speech/models/text2semantic/llama.py +++ /dev/null @@ -1,779 +0,0 @@ -import json -import math -from collections import OrderedDict -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import torch -import torch.nn as nn -from einops import rearrange -from loguru import logger -from torch import Tensor -from torch.nn import functional as F -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.utils.checkpoint import checkpoint -from transformers import AutoTokenizer - -from fish_speech.conversation import SEMANTIC_TOKEN -from fish_speech.utils import RankedLogger - -from .lora import LoraConfig, 
setup_lora - -log = RankedLogger(__name__, rank_zero_only=True) - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -@dataclass -class BaseModelArgs: - model_type: str = "base" - - vocab_size: int = 32000 - n_layer: int = 32 - n_head: int = 32 - dim: int = 4096 - intermediate_size: int = None - n_local_heads: int = -1 - head_dim: int = 64 - rope_base: float = 10000 - norm_eps: float = 1e-5 - max_seq_len: int = 2048 - dropout: float = 0.0 - tie_word_embeddings: bool = True - attention_qkv_bias: bool = False - - # Codebook configs - codebook_size: int = 160 - num_codebooks: int = 4 - - # Gradient checkpointing - use_gradient_checkpointing: bool = True - - # Initialize the model - initializer_range: float = 0.02 - - def __post_init__(self): - if self.n_local_heads == -1: - self.n_local_heads = self.n_head - if self.intermediate_size is None: - hidden_dim = 4 * self.dim - n_hidden = int(2 * hidden_dim / 3) - self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head - - @staticmethod - def from_pretrained(path: str): - path = Path(path) - - if path.is_dir(): - path = path / "config.json" - - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - - match data["model_type"]: - case "naive": - cls = NaiveModelArgs - case "dual_ar": - cls = DualARModelArgs - case _: - raise ValueError(f"Unknown model type: {data['model_type']}") - - return cls(**data) - - def save(self, path: str): - with open(path, "w") as f: - json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) - - -@dataclass -class NaiveModelArgs(BaseModelArgs): - model_type: str = "naive" - - -@dataclass -class DualARModelArgs(BaseModelArgs): - model_type: str = "dual_ar" - n_fast_layer: int = 4 - - -class KVCache(nn.Module): - def __init__( - self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 - ): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) - self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - - -@dataclass -class TransformerForwardResult: - token_logits: Tensor - codebook_logits: Tensor - - -@dataclass -class BaseTransformerForwardResult: - logits: Tensor - hidden_states: Tensor - - -class BaseTransformer(nn.Module): - def __init__( - self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True - ) -> None: - super().__init__() - self.config = config - self.tokenizer = tokenizer - - self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN) - - # Slow transformer - self.embeddings = nn.Embedding( - config.vocab_size, - config.dim, - ) - self.codebook_embeddings = nn.Embedding( - config.codebook_size * config.num_codebooks, - config.dim, - ) - self.layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) - ) - self.norm = RMSNorm(config.dim, eps=config.norm_eps) - - if self.config.tie_word_embeddings is False: - self.output = nn.Linear( - config.dim, - config.vocab_size, - bias=False, - ) - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - config.max_seq_len, - config.dim // config.n_head, - config.rope_base, 
- ), - persistent=False, - ) - self.register_buffer( - "causal_mask", - torch.tril( - torch.ones( - config.max_seq_len, - config.max_seq_len, - dtype=torch.bool, - ) - ), - persistent=False, - ) - - # For kv cache - self.max_batch_size = -1 - self.max_seq_len = -1 - - if init_weights: - self.apply(self._init_weights) - - def setup_caches( - self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 - ): - if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: - return - - head_dim = self.config.dim // self.config.n_head - max_seq_len = find_multiple(max_seq_len, 8) - self.max_seq_len = max_seq_len - self.max_batch_size = max_batch_size - - for b in self.layers: - b.attention.kv_cache = KVCache( - max_batch_size, - max_seq_len, - self.config.n_local_heads, - head_dim, - dtype=dtype, - ) - - def embed(self, x: Tensor) -> Tensor: - vocab_embeds = [self.embeddings(x[:, 0])] - for i in range(self.config.num_codebooks): - emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) - emb[x[:, 0] != self.semantic_token_id] = 0 - vocab_embeds.append(emb) - - x = torch.stack(vocab_embeds, dim=3) - x = x.sum(dim=3) - - return x - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> BaseTransformerForwardResult: - seq_len = inp.size(2) - - # Here we want to merge the embeddings of the codebooks - x = self.embed(inp) - - freqs_cis = self.freqs_cis[:seq_len] - - # Note that the causal mask here follows the definition of scaled_dot_product_attention - # That is, FALSE means masked out - # To maintain consistency, key_padding_mask uses TRUE to mask out - mask = None - if key_padding_mask is not None: - mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) - mask = mask & key_padding_mask[:, None, None, :].logical_not() - - for layer in self.layers: - if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) - else: - x = layer(x, freqs_cis, mask) - - # Apply the final norm to get slow_out - slow_out = self.norm(x) - - if self.config.tie_word_embeddings: - token_logits = F.linear(slow_out, self.embeddings.weight) - else: - token_logits = self.output(slow_out) - - return BaseTransformerForwardResult( - logits=token_logits, - hidden_states=x, - ) - - def forward_generate( - self, - x: Tensor, - input_pos: Optional[Tensor] = None, - return_all: bool = False, - ) -> BaseTransformerForwardResult: - # This is used for generation, optimized for torch compile - assert ( - self.max_seq_len != -1 and self.max_batch_size != -1 - ), "Please call setup_caches before forward_generate" - - x = self.embed(x) - - mask = self.causal_mask[ - None, None, input_pos, : self.max_seq_len - ] # (B, N, Q, K) - freqs_cis = self.freqs_cis[input_pos] - - for layer in self.layers: - x = layer(x, freqs_cis, mask, input_pos=input_pos) - - # If prefill, we only calculate the logits of the last token - if x.size(1) > 1 and not return_all: - x = x[:, -1:] - - # Apply the final norm to get slow_out - slow_out = self.norm(x) - - if self.config.tie_word_embeddings: - token_logits = F.linear(slow_out, self.embeddings.weight) - else: - token_logits = self.output(slow_out) - - return BaseTransformerForwardResult( - logits=token_logits, - hidden_states=x, - ) - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module,
nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @staticmethod - def from_pretrained( - path: str, - load_weights: bool = False, - max_length: int | None = None, - lora_config: LoraConfig | None = None, - rope_base: int | None = None, - ) -> "BaseTransformer": - config = BaseModelArgs.from_pretrained(str(path)) - if max_length is not None: - config.max_seq_len = max_length - log.info(f"Override max_seq_len to {max_length}") - - if rope_base is not None: - config.rope_base = rope_base - log.info(f"Override rope_base to {rope_base}") - - match config.model_type: - case "naive": - model_cls = NaiveTransformer - case "dual_ar": - model_cls = DualARTransformer - case _: - raise ValueError(f"Unknown model type: {config.model_type}") - - tokenizer = AutoTokenizer.from_pretrained(str(path)) - log.info(f"Loading model from {path}, config: {config}") - model = model_cls(config, tokenizer=tokenizer) - - if lora_config is not None: - setup_lora(model, lora_config) - log.info(f"LoRA setup: {lora_config}") - - if load_weights is False: - log.info("Randomly initialized model") - else: - - if "int8" in str(Path(path)): - logger.info("Using int8 weight-only quantization!") - from tools.llama.quantize import WeightOnlyInt8QuantHandler - - simple_quantizer = WeightOnlyInt8QuantHandler(model) - model = simple_quantizer.convert_for_runtime() - - if "int4" in str(Path(path)): - logger.info("Using int4 quantization!") - path_comps = path.name.split("-") - assert path_comps[-2].startswith("g") - groupsize = int(path_comps[-2][1:]) - from tools.llama.quantize import WeightOnlyInt4QuantHandler - - simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - model = simple_quantizer.convert_for_runtime() - - weights = torch.load( - Path(path) / "model.pth", map_location="cpu", mmap=True - ) - - if "state_dict" in weights: - logger.warning( - "Using a TextToSemantic LightningModule checkpoint, " - "please make sure it is a full model, not a LoRA model." - ) - weights = weights["state_dict"] - - if next(iter(weights.keys())).startswith("model."): - logger.info( - f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" - ) - new_weights = OrderedDict() - for k, v in weights.items(): - new_weights[k.replace("model.", "")] = v - weights = new_weights - - # Verify the name and shape of parameters since strict=False in load_state_dict. 
- for k, v in model.named_parameters(): - if k not in weights: - logger.warning(f"No weight for {k}") - elif v.shape != weights[k].shape: - logger.warning( - f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" - ) - - err = model.load_state_dict(weights, strict=False, assign=True) - log.info(f"Loaded weights with error: {err}") - - return model - - def save_pretrained(self, path: str, drop_lora: bool = False): - path = Path(path) - path.mkdir(parents=True, exist_ok=True) - - self.config.save(path / "config.json") - state_dict = self.state_dict() - - if drop_lora: - for key in list(state_dict.keys()): - if "lora" not in key: - continue - - state_dict.pop(key) - log.info(f"Drop LoRA parameter: {key}") - - torch.save(state_dict, path / "model.pth") - self.tokenizer.save_pretrained(path) - - -class NaiveTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: - super().__init__(config, init_weights=False, tokenizer=tokenizer) - - self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) - self.codebook_output = nn.Linear( - config.dim, - config.codebook_size * config.num_codebooks, - bias=False, - ) - - self.apply(self._init_weights) - - def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: - token_logits = result.logits - x = result.hidden_states - - # Codebook - codebook_logits = self.codebook_output(self.codebook_norm(x)) - codebook_logits = rearrange( - codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks - ) - - return TransformerForwardResult( - token_logits=token_logits, - codebook_logits=codebook_logits, - ) - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> TransformerForwardResult: - result = super().forward( - inp=inp, - key_padding_mask=key_padding_mask, - ) - return self.decode(result) - - def forward_generate( - self, x: Tensor, input_pos: Optional[Tensor] = None - ) -> TransformerForwardResult: - result = super().forward_generate(x, input_pos) - return self.decode(result) - - -class DualARTransformer(BaseTransformer): - def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None: - super().__init__(config, init_weights=False, tokenizer=tokenizer) - - # Fast transformer - self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) - - # The equivalent batch size is so large that sdpa doesn't work - self.fast_layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) - ) - self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) - self.fast_output = nn.Linear( - config.dim, - config.codebook_size, - bias=False, - ) - - self.apply(self._init_weights) - - def setup_caches( - self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 - ): - super().setup_caches(max_batch_size, max_seq_len, dtype) - - head_dim = self.config.dim // self.config.n_head - - # Fast transformer - # The max seq len here is the number of codebooks - for b in self.fast_layers: - b.attention.kv_cache = KVCache( - max_batch_size, - self.config.num_codebooks, - self.config.n_local_heads, - head_dim, - dtype=dtype, - ) - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> TransformerForwardResult: - parent_result = super().forward(inp, key_padding_mask) - token_logits = parent_result.logits - x = parent_result.hidden_states - - # Fast transformer - fast_seq_len = self.config.num_codebooks - fast_mask = self.causal_mask[ - None, None,
:fast_seq_len, :fast_seq_len - ] # (B, N, Q, K) - fast_freqs_cis = self.freqs_cis[:fast_seq_len] - - # Drop the last token and rotate left - codebooks = inp[:, 1:-1, 1:] - codebooks = F.pad(codebooks, (0, 1), value=0) - codebook_embeddings = self.fast_embeddings(codebooks) - x = torch.cat([x[:, None], codebook_embeddings], dim=1) - b, s = x.size(0), x.size(2) - x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len - - # Remove padded part - codebooks = rearrange(codebooks, "b n s -> (b s) n") - codebook_mask = (codebooks == 0).all(dim=-1) - - if torch.all(codebook_mask): - # If all codebooks are padded, we keep first 8 to make sure the model runs - codebook_mask[:8] = False - - x_bs, x_len = x.size(0), x.size(1) - x = x[~codebook_mask] - - for layer in self.fast_layers: - if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) - else: - x = layer(x, fast_freqs_cis, fast_mask) - - # unflatten the batch and num_codebooks - fast_out = self.fast_norm(x) - codebook_logits = self.fast_output(fast_out) - - # Re-pad the codebook_logits - buffer = torch.zeros( - x_bs, - x_len, - codebook_logits.size(-1), - device=codebook_logits.device, - dtype=codebook_logits.dtype, - ) - buffer[~codebook_mask] = codebook_logits - codebook_logits = buffer - - assert codebook_logits.shape[1] == self.config.num_codebooks - codebook_logits = rearrange( - codebook_logits, - "(b s) n d -> b s n d", - b=b, - s=s, - n=self.config.num_codebooks, - ) - - return TransformerForwardResult( - token_logits=token_logits, - codebook_logits=codebook_logits, - ) - - def forward_generate_fast( - self, x: Tensor, input_pos: Optional[Tensor] = None - ) -> Tensor: - # Fast transformer - x = x.view(1, 1, -1) - - fast_mask = self.causal_mask[ - None, None, input_pos, : self.config.num_codebooks - ] # (B, N, Q, K) - fast_freqs_cis = self.freqs_cis[input_pos] - - for layer in self.fast_layers: - x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) - - # unflatten the batch and num_codebooks - fast_out = self.fast_norm(x) # only take the last token - codebook_logits = self.fast_output(fast_out) - - return codebook_logits - - -class TransformerBlock(nn.Module): - def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: - super().__init__() - self.attention = Attention(config, use_sdpa=use_sdpa) - self.feed_forward = FeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) - - def forward( - self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None - ) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - -class Attention(nn.Module): - def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): - super().__init__() - assert config.dim % config.n_head == 0 - - total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim - # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear( - config.dim, total_head_dim, bias=config.attention_qkv_bias - ) - self.wo = nn.Linear(config.dim, config.dim, bias=False) - self.kv_cache = None - - self.dropout = config.dropout - self.n_head = config.n_head - self.head_dim = config.head_dim - self.n_local_heads = config.n_local_heads - self.dim = config.dim - self.use_sdpa = use_sdpa - self._register_load_state_dict_pre_hook(self.load_hook) - - def 
load_hook(self, state_dict, prefix, *args): - if prefix + "wq.weight" in state_dict: - wq = state_dict.pop(prefix + "wq.weight") - wk = state_dict.pop(prefix + "wk.weight") - wv = state_dict.pop(prefix + "wv.weight") - state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - - def forward( - self, - x: Tensor, - freqs_cis: Tensor, - mask: Tensor, - input_pos: Optional[Tensor] = None, - ) -> Tensor: - bsz, seqlen, _ = x.shape - - kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) - - q = q.view(bsz, seqlen, self.n_head, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) - - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) - - q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - - if self.use_sdpa: - if mask is None: - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - y = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.dropout if self.training else 0.0, - is_causal=True, - # No attn_mask is passed here so that flash attention can be used - ) - else: - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=self.dropout if self.training else 0.0, - ) - else: - y = self.eq_scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=self.dropout if self.training else 0.0, - ) - - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - return self.wo(y) - - def eq_scaled_dot_product_attention( - self, - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - ) -> torch.Tensor: - # This is a standard scaled dot-product attention implementation - # It is less efficient, but it does not raise CUDA errors - - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) - attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - - return attn_weight @ value - - -class FeedForward(nn.Module): - def __init__(self, config: BaseModelArgs) -> None: - super().__init__() - self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) - - def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor) -> Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: - freqs = 1.0 / ( - base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) - ) - t = torch.arange(seq_len,
device=freqs.device) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=torch.bfloat16) - - -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], - xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], - ], - -1, - ) - - x_out2 = x_out2.flatten(3) - return x_out2.type_as(x) diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py deleted file mode 100644 index 647ca6fcccf038e17d2cf91a2874281dff3e0938..0000000000000000000000000000000000000000 --- a/fish_speech/models/text2semantic/lora.py +++ /dev/null @@ -1,92 +0,0 @@ -from dataclasses import dataclass - -import loralib as lora - - -@dataclass -class LoraConfig: - r: int - lora_alpha: float - lora_dropout: float = 0.0 - - -def setup_lora(model, lora_config): - # Replace the embedding layer with a LoRA layer - model.embeddings = lora.Embedding( - num_embeddings=model.embeddings.num_embeddings, - embedding_dim=model.embeddings.embedding_dim, - padding_idx=model.embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - model.codebook_embeddings = lora.Embedding( - num_embeddings=model.codebook_embeddings.num_embeddings, - embedding_dim=model.codebook_embeddings.embedding_dim, - padding_idx=model.codebook_embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - # Replace the output layer with a LoRA layer (it only exists when word embeddings are untied) - linears = [(model, "output")] if hasattr(model, "output") else [] - - # Replace all linear layers with LoRA layers - for layer in model.layers: - linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) - linears.extend( - [ - (layer.feed_forward, "w1"), - (layer.feed_forward, "w2"), - (layer.feed_forward, "w3"), - ] - ) - - if hasattr(model, "fast_layers"): - model.fast_embeddings = lora.Embedding( - num_embeddings=model.fast_embeddings.num_embeddings, - embedding_dim=model.fast_embeddings.embedding_dim, - padding_idx=model.fast_embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - # Dual-AR model - linears.append((model, "fast_output")) - - for layer in model.fast_layers: - linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) - linears.extend( - [ - (layer.feed_forward, "w1"), - (layer.feed_forward, "w2"), - (layer.feed_forward, "w3"), - ] - ) - - for module, layer in linears: - updated_linear = lora.Linear( - in_features=getattr(module, layer).in_features, - out_features=getattr(module, layer).out_features, - bias=getattr(module, layer).bias is not None, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - ) - setattr(module, layer, updated_linear) - - # Mark only the LoRA layers as trainable - lora.mark_only_lora_as_trainable(model, bias="none") - - -def get_merged_state_dict(model): - # Switching to eval mode makes loralib merge the LoRA weights into the base weights - model.eval() - - # Then we need to remove the LoRA parameters from the state dict - state_dict = model.state_dict() - for name in list(state_dict.keys()): - if "lora" in name: - state_dict.pop(name) - - return state_dict diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py deleted file mode 100644 index
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 7370a6672b015f38616e92542abc71ddeeb7a87e..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc
deleted file mode 100644
index 588bdb4e0f0f6fc5f9838713164c8ed4158b3303..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc
deleted file mode 100644
index 22aab32a8842e848cceb650a0a9274c4402bfddb..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
deleted file mode 100644
index aa21839b544174d5d91378c5daf8fe1b376a154a..0000000000000000000000000000000000000000
--- a/fish_speech/models/vqgan/modules/firefly.py
+++ /dev/null
@@ -1,596 +0,0 @@
-import math
-from functools import partial
-from math import prod
-from typing import Callable
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
-from torch.utils.checkpoint import checkpoint
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv1d") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return (kernel_size * dilation - dilation) // 2
-
-
-def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
-    """Remove padding from x, properly handling zero padding. Only for 1d!"""
-    padding_left, padding_right = paddings
-    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
-    assert (padding_left + padding_right) <= x.shape[-1]
-    end = x.shape[-1] - padding_right
-    return x[..., padding_left:end]
-
-
-def get_extra_padding_for_conv1d(
-    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
-) -> int:
-    """See `pad_for_conv1d`."""
-    length = x.shape[-1]
-    n_frames = (length - kernel_size + padding_total) / stride + 1
-    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
-    return ideal_length - length
-
-
-def pad1d(
-    x: torch.Tensor,
-    paddings: tuple[int, int],
-    mode: str = "zeros",
-    value: float = 0.0,
-):
-    """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
-    In that case, we insert extra zero padding to the right before the reflection
-    happens.
-    """
-    length = x.shape[-1]
-    padding_left, padding_right = paddings
-    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
-    if mode == "reflect":
-        max_pad = max(padding_left, padding_right)
-        extra_pad = 0
-        if length <= max_pad:
-            extra_pad = max_pad - length + 1
-            x = F.pad(x, (0, extra_pad))
-        padded = F.pad(x, paddings, mode, value)
-        end = padded.shape[-1] - extra_pad
-        return padded[..., :end]
-    else:
-        return F.pad(x, paddings, mode, value)
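A quick round-trip illustration of the two helpers above (plain tensors, illustration only): in "constant" mode `pad1d` behaves like `F.pad`, `unpad1d` removes a `(left, right)` padding pair, and the reflect branch exists only to make short inputs safe:

```python
import torch

# pad1d in "constant" mode is plain F.pad; unpad1d is its inverse.
x = torch.arange(6.0).view(1, 1, 6)
padded = torch.nn.functional.pad(x, (2, 3))
restored = padded[..., 2 : padded.shape[-1] - 3]  # what unpad1d((2, 3)) does
print(torch.equal(x, restored))  # True

# The reflect case pad1d guards against: reflecting a length-2 signal by 3
# samples is undefined for F.pad alone, so pad1d first right-pads with zeros
# (extra_pad = max_pad - length + 1) and trims the surplus afterwards.
short = torch.ones(1, 1, 2)
# torch.nn.functional.pad(short, (3, 3), mode="reflect")  # would raise
```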
-
-
-class FishConvNet(nn.Module):
-    def __init__(
-        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
-    ):
-        super(FishConvNet, self).__init__()
-        self.conv = nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=stride,
-            dilation=dilation,
-            groups=groups,
-        )
-        self.stride = stride
-        self.kernel_size = (kernel_size - 1) * dilation + 1
-        self.dilation = dilation
-
-    def forward(self, x):
-        pad = self.kernel_size - self.stride
-        extra_padding = get_extra_padding_for_conv1d(
-            x, self.kernel_size, self.stride, pad
-        )
-        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
-        return self.conv(x).contiguous()
-
-    def weight_norm(self, name="weight", dim=0):
-        self.conv = weight_norm(self.conv, name=name, dim=dim)
-        return self
-
-    def remove_weight_norm(self):
-        self.conv = remove_parametrizations(self.conv, "weight")
-        return self
-
-
-class FishTransConvNet(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
-        super(FishTransConvNet, self).__init__()
-        self.conv = nn.ConvTranspose1d(
-            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
-        )
-        self.stride = stride
-        self.kernel_size = kernel_size
-
-    def forward(self, x):
-        x = self.conv(x)
-        pad = self.kernel_size - self.stride
-        padding_right = math.ceil(pad)
-        padding_left = pad - padding_right
-        x = unpad1d(x, (padding_left, padding_right))
-        return x.contiguous()
-
-    def weight_norm(self, name="weight", dim=0):
-        self.conv = weight_norm(self.conv, name=name, dim=dim)
-        return self
-
-    def remove_weight_norm(self):
-        self.conv = remove_parametrizations(self.conv, "weight")
-        return self
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super().__init__()
-
-        self.convs1 = nn.ModuleList(
-            [
-                FishConvNet(
-                    channels, channels, kernel_size, stride=1, dilation=dilation[0]
-                ).weight_norm(),
-                FishConvNet(
-                    channels, channels, kernel_size, stride=1, dilation=dilation[1]
-                ).weight_norm(),
-                FishConvNet(
-                    channels, channels, kernel_size, stride=1, dilation=dilation[2]
-                ).weight_norm(),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                FishConvNet(
-                    channels, channels, kernel_size, stride=1, dilation=dilation[0]
-                ).weight_norm(),
-                FishConvNet(
-                    channels, channels, kernel_size, stride=1, dilation=dilation[1]
-                ).weight_norm(),
-                FishConvNet(
-                    channels, channels, kernel_size, stride=1, dilation=dilation[2]
-                ).weight_norm(),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.silu(x)
-            xt = c1(xt)
-            xt = F.silu(xt)
-            xt = c2(xt)
-            x = xt + x
-        return x
-
-    def remove_parametrizations(self):
-        # The parametrized weight lives on the wrapped nn.Conv1d, not the wrapper
-        for conv in self.convs1:
-            remove_parametrizations(conv.conv, tensor_name="weight")
-        for conv in self.convs2:
-            remove_parametrizations(conv.conv, tensor_name="weight")
-
-
-class ParallelBlock(nn.Module):
-    def __init__(
-        self,
-        channels: int,
-        kernel_sizes: tuple[int] = (3, 7, 11),
-        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
-    ):
-        super().__init__()
-
-        assert len(kernel_sizes) == len(dilation_sizes)
-
-        self.blocks = nn.ModuleList()
-        for k, d in zip(kernel_sizes, dilation_sizes):
-            self.blocks.append(ResBlock1(channels, k, d))
-
-    def forward(self, x):
-        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
-
-    def remove_parametrizations(self):
-        for block in self.blocks:
-            block.remove_parametrizations()
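The point of the two wrappers is that output lengths stay exact multiples or divisions of the stride, which is what lets the quantizer's downsample/upsample stacks line up later on. A small shape check, assuming the deleted module were still importable:

```python
import torch

from fish_speech.models.vqgan.modules.firefly import FishConvNet, FishTransConvNet

# FishConvNet pads so that a strided conv yields ceil(length / stride) frames,
# regardless of kernel size or dilation.
conv = FishConvNet(in_channels=1, out_channels=1, kernel_size=7, stride=2)
x = torch.randn(1, 1, 50)
print(conv(x).shape)  # torch.Size([1, 1, 25])

# FishTransConvNet trims the transposed conv back to exactly length * stride,
# so the pair round-trips the time axis.
tconv = FishTransConvNet(in_channels=1, out_channels=1, kernel_size=4, stride=2)
print(tconv(conv(x)).shape)  # torch.Size([1, 1, 50])
```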
-
-
-class HiFiGANGenerator(nn.Module):
-    def __init__(
-        self,
-        *,
-        hop_length: int = 512,
-        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
-        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
-        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
-        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
-        num_mels: int = 128,
-        upsample_initial_channel: int = 512,
-        pre_conv_kernel_size: int = 7,
-        post_conv_kernel_size: int = 7,
-        post_activation: Callable = partial(nn.SiLU, inplace=True),
-    ):
-        super().__init__()
-
-        assert (
-            prod(upsample_rates) == hop_length
-        ), f"hop_length must be {prod(upsample_rates)}"
-
-        self.conv_pre = FishConvNet(
-            num_mels,
-            upsample_initial_channel,
-            pre_conv_kernel_size,
-            stride=1,
-        ).weight_norm()
-
-        self.num_upsamples = len(upsample_rates)
-        self.num_kernels = len(resblock_kernel_sizes)
-
-        # Gradient checkpointing is disabled by default; forward() checks this flag
-        self.checkpointing = False
-
-        self.noise_convs = nn.ModuleList()
-        self.ups = nn.ModuleList()
-
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                FishTransConvNet(
-                    upsample_initial_channel // (2**i),
-                    upsample_initial_channel // (2 ** (i + 1)),
-                    k,
-                    stride=u,
-                ).weight_norm()
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            self.resblocks.append(
-                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
-            )
-
-        self.activation_post = post_activation()
-        self.conv_post = FishConvNet(
-            ch, 1, post_conv_kernel_size, stride=1
-        ).weight_norm()
-        self.ups.apply(init_weights)
-        self.conv_post.apply(init_weights)
-
-    def forward(self, x):
-        x = self.conv_pre(x)
-
-        for i in range(self.num_upsamples):
-            x = F.silu(x, inplace=True)
-            x = self.ups[i](x)
-
-            if self.training and self.checkpointing:
-                x = checkpoint(
-                    self.resblocks[i],
-                    x,
-                    use_reentrant=False,
-                )
-            else:
-                x = self.resblocks[i](x)
-
-        x = self.activation_post(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_parametrizations(self):
-        for up in self.ups:
-            remove_parametrizations(up.conv, tensor_name="weight")
-        for block in self.resblocks:
-            block.remove_parametrizations()
-        remove_parametrizations(self.conv_pre.conv, tensor_name="weight")
-        remove_parametrizations(self.conv_post.conv, tensor_name="weight")
-
-
-# DropPath copied from timm library
-def drop_path(
-    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
-    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
-    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
-    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
-    'survival rate' as the argument.
- - """ # noqa: E501 - - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and scale_by_keep: - random_tensor.div_(keep_prob) - return x * random_tensor - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 - - def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) - - def extra_repr(self): - return f"drop_prob={round(self.drop_prob,3):0.3f}" - - -class LayerNorm(nn.Module): - r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ # noqa: E501 - - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm( - x, self.normalized_shape, self.weight, self.bias, self.eps - ) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None] * x + self.bias[:, None] - return x - - -# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py -class ConvNeXtBlock(nn.Module): - r"""ConvNeXt Block. There are two equivalent implementations: - (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) - (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back - We use (2) as we find it slightly faster in PyTorch - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. - kernel_size (int): Kernel size for depthwise conv. Default: 7. - dilation (int): Dilation for depthwise conv. Default: 1. 
- """ # noqa: E501 - - def __init__( - self, - dim: int, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-6, - mlp_ratio: float = 4.0, - kernel_size: int = 7, - dilation: int = 1, - ): - super().__init__() - - self.dwconv = FishConvNet( - dim, - dim, - kernel_size=kernel_size, - # padding=int(dilation * (kernel_size - 1) / 2), - groups=dim, - ) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear( - dim, int(mlp_ratio * dim) - ) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) - self.gamma = ( - nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) - if layer_scale_init_value > 0 - else None - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - def forward(self, x, apply_residual: bool = True): - input = x - - x = self.dwconv(x) - x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - - if self.gamma is not None: - x = self.gamma * x - - x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) - x = self.drop_path(x) - - if apply_residual: - x = input + x - - return x - - -class ConvNeXtEncoder(nn.Module): - def __init__( - self, - input_channels: int = 3, - depths: list[int] = [3, 3, 9, 3], - dims: list[int] = [96, 192, 384, 768], - drop_path_rate: float = 0.0, - layer_scale_init_value: float = 1e-6, - kernel_size: int = 7, - ): - super().__init__() - assert len(depths) == len(dims) - - self.downsample_layers = nn.ModuleList() - stem = nn.Sequential( - FishConvNet( - input_channels, - dims[0], - kernel_size=7, - # padding=3, - # padding_mode="replicate", - # padding_mode="zeros", - ), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), - ) - self.downsample_layers.append(stem) - - for i in range(len(depths) - 1): - mid_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), - ) - self.downsample_layers.append(mid_layer) - - self.stages = nn.ModuleList() - dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - - cur = 0 - for i in range(len(depths)): - stage = nn.Sequential( - *[ - ConvNeXtBlock( - dim=dims[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value, - kernel_size=kernel_size, - ) - for j in range(depths[i]) - ] - ) - self.stages.append(stage) - cur += depths[i] - - self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - nn.init.constant_(m.bias, 0) - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - for i in range(len(self.downsample_layers)): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - - return self.norm(x) - - -class FireflyArchitecture(nn.Module): - def __init__( - self, - backbone: nn.Module, - head: nn.Module, - quantizer: nn.Module, - spec_transform: nn.Module, - ): - super().__init__() - - self.backbone = backbone - self.head = head - self.quantizer = quantizer - self.spec_transform = spec_transform - self.downsample_factor = math.prod(self.quantizer.downsample_factor) - - def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor: - if self.spec_transform is not None: - x = self.spec_transform(x) - - x = self.backbone(x) - if mask is not None: - x = x * mask - - if 
self.quantizer is not None:
-            vq_result = self.quantizer(x)
-            x = vq_result.z
-
-            if mask is not None:
-                x = x * mask
-
-        x = self.head(x, template=template)
-
-        if x.ndim == 2:
-            x = x[:, None, :]
-
-        if self.quantizer is not None:
-            return x, vq_result
-
-        return x
-
-    def encode(self, audios, audio_lengths):
-        audios = audios.float()
-
-        mels = self.spec_transform(audios)
-        mel_lengths = audio_lengths // self.spec_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.backbone(mels) * mel_masks_float_conv
-        feature_lengths = mel_lengths // self.downsample_factor
-
-        return self.quantizer.encode(encoded_features), feature_lengths
-
-    def decode(self, indices, feature_lengths) -> torch.Tensor:
-        mel_masks = sequence_mask(
-            feature_lengths * self.downsample_factor,
-            indices.shape[2] * self.downsample_factor,
-        )
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        audio_lengths = (
-            feature_lengths * self.downsample_factor * self.spec_transform.hop_length
-        )
-
-        audio_masks = sequence_mask(
-            audio_lengths,
-            indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
-        )
-        audio_masks_float_conv = audio_masks[:, None, :].float()
-
-        z = self.quantizer.decode(indices) * mel_masks_float_conv
-        x = self.head(z) * audio_masks_float_conv
-
-        return x, audio_lengths
-
-    def remove_parametrizations(self):
-        if hasattr(self.backbone, "remove_parametrizations"):
-            self.backbone.remove_parametrizations()
-
-        if hasattr(self.head, "remove_parametrizations"):
-            self.head.remove_parametrizations()
-
-    @property
-    def device(self):
-        return next(self.parameters()).device
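To make the data flow of `encode`/`decode` above concrete, here is a hedged sketch; `audio_roundtrip` is a helper added here for illustration, and it assumes a fully assembled `FireflyArchitecture` (spec transform, backbone, quantizer and vocoder head) is available:

```python
import torch

from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture


def audio_roundtrip(
    model: FireflyArchitecture, audios: torch.Tensor, lengths: torch.Tensor
):
    """Waveform -> discrete codes -> waveform, using the methods defined above."""
    # encode: waveform -> mel -> backbone features -> codebook indices
    indices, feature_lengths = model.encode(audios, lengths)

    # decode: indices -> dequantized latents -> vocoder head -> waveform
    recon, recon_lengths = model.decode(indices, feature_lengths)
    return indices, recon, recon_lengths
```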
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
deleted file mode 100644
index 7ea4853376b6e663404ff48d6c6b5f664dde4094..0000000000000000000000000000000000000000
--- a/fish_speech/models/vqgan/modules/fsq.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from dataclasses import dataclass
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from vector_quantize_pytorch import GroupedResidualFSQ
-
-from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
-
-
-@dataclass
-class FSQResult:
-    z: torch.Tensor
-    codes: torch.Tensor
-    latents: torch.Tensor
-
-
-class DownsampleFiniteScalarQuantize(nn.Module):
-    def __init__(
-        self,
-        input_dim: int = 512,
-        n_codebooks: int = 9,
-        n_groups: int = 1,
-        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
-        downsample_factor: tuple[int] = (2, 2),
-        downsample_dims: tuple[int] | None = None,
-    ):
-        super().__init__()
-
-        if downsample_dims is None:
-            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
-
-        all_dims = (input_dim,) + tuple(downsample_dims)
-
-        self.residual_fsq = GroupedResidualFSQ(
-            dim=all_dims[-1],
-            levels=levels,
-            num_quantizers=n_codebooks,
-            groups=n_groups,
-        )
-
-        self.downsample_factor = downsample_factor
-        self.downsample_dims = downsample_dims
-
-        self.downsample = nn.Sequential(
-            *[
-                nn.Sequential(
-                    FishConvNet(
-                        all_dims[idx],
-                        all_dims[idx + 1],
-                        kernel_size=factor,
-                        stride=factor,
-                    ),
-                    ConvNeXtBlock(dim=all_dims[idx + 1]),
-                )
-                for idx, factor in enumerate(downsample_factor)
-            ]
-        )
-
-        self.upsample = nn.Sequential(
-            *[
-                nn.Sequential(
-                    FishTransConvNet(
-                        all_dims[idx + 1],
-                        all_dims[idx],
-                        kernel_size=factor,
-                        stride=factor,
-                    ),
-                    ConvNeXtBlock(dim=all_dims[idx]),
-                )
-                for idx, factor in reversed(list(enumerate(downsample_factor)))
-            ]
-        )
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv1d, nn.Linear)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            nn.init.constant_(m.bias, 0)
-
-    def forward(self, z) -> FSQResult:
-        original_shape = z.shape
-        z = self.downsample(z)
-        quantized, indices = self.residual_fsq(z.mT)
-        result = FSQResult(
-            z=quantized.mT,
-            codes=indices.mT,
-            latents=z,
-        )
-        result.z = self.upsample(result.z)
-
-        # Pad or crop z to match original shape
-        diff = original_shape[-1] - result.z.shape[-1]
-        left = diff // 2
-        right = diff - left
-
-        if diff > 0:
-            result.z = F.pad(result.z, (left, right))
-        elif diff < 0:
-            # left and right are negative here: trim -left samples from the
-            # start and -right samples from the end
-            result.z = result.z[..., -left : result.z.shape[-1] + right]
-
-        return result
-
-    def encode(self, z):
-        z = self.downsample(z)
-        _, indices = self.residual_fsq(z.mT)
-        indices = rearrange(indices, "g b l r -> b (g r) l")
-        return indices
-
-    def decode(self, indices: torch.Tensor):
-        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
-        z_q = self.residual_fsq.get_output_from_indices(indices)
-        z_q = self.upsample(z_q.mT)
-        return z_q
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
deleted file mode 100644
index b90c131d214006875476a161cdfd2dffa8949dac..0000000000000000000000000000000000000000
--- a/fish_speech/models/vqgan/utils.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import matplotlib
-import torch
-from matplotlib import pyplot as plt
-
-matplotlib.use("Agg")
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def plot_mel(data, titles=None):
-    fig, axes = plt.subplots(len(data), 1, squeeze=False)
-
-    if titles is None:
-        titles = [None for i in range(len(data))]
-
-    plt.tight_layout()
-
-    for i in range(len(data)):
-        mel = data[i]
-
-        if isinstance(mel, torch.Tensor):
-            mel = mel.float().detach().cpu().numpy()
-
-        axes[i][0].imshow(mel, origin="lower")
-        axes[i][0].set_aspect(2.5, adjustable="box")
-        axes[i][0].set_ylim(0, mel.shape[0])
-        axes[i][0].set_title(titles[i], fontsize="medium")
-        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
-        axes[i][0].set_anchor("W")
-
-    return fig
-
-
-def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-
-    return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
-    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
-    n_channels_int = n_channels[0]
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:,
n_channels_int:, :]) - acts = t_act * s_act - - return acts - - -def avg_with_mask(x, mask): - assert mask.dtype == torch.float, "Mask should be float" - - if mask.ndim == 2: - mask = mask.unsqueeze(1) - - if mask.shape[1] == 1: - mask = mask.expand_as(x) - - return (x * mask).sum() / mask.sum() diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py deleted file mode 100644 index 43bed6a2210723a7d5e1ea0a48ba61140047ca29..0000000000000000000000000000000000000000 --- a/fish_speech/scheduler.py +++ /dev/null @@ -1,40 +0,0 @@ -import math - - -def get_cosine_schedule_with_warmup_lr_lambda( - current_step: int, - *, - num_warmup_steps: int | float, - num_training_steps: int, - num_cycles: float = 0.5, - final_lr_ratio: float = 0.0, -): - if 0 < num_warmup_steps < 1: # float mode - num_warmup_steps = int(num_warmup_steps * num_training_steps) - - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - - progress = float(current_step - num_warmup_steps) / float( - max(1, num_training_steps - num_warmup_steps) - ) - - return max( - final_lr_ratio, - 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), - ) - - -def get_constant_schedule_with_warmup_lr_lambda( - current_step: int, - *, - num_warmup_steps: int | float, - num_training_steps: int | None = None, -): - if 0 < num_warmup_steps < 1: # float mode - num_warmup_steps = int(num_warmup_steps * num_training_steps) - - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - - return 1.0 diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py deleted file mode 100644 index d740bd8eed447d162e55b165965dec17130377ce..0000000000000000000000000000000000000000 --- a/fish_speech/text/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .clean import clean_text -from .spliter import split_text - -__all__ = ["clean_text", "split_text"] diff --git a/fish_speech/text/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index cbda0e48251bdcc53c332c821e4ea9519047d490..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/__pycache__/clean.cpython-310.pyc b/fish_speech/text/__pycache__/clean.cpython-310.pyc deleted file mode 100644 index f8c648bb945e8d4ff16146dff98a70a779dab7eb..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/__pycache__/clean.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/__pycache__/spliter.cpython-310.pyc b/fish_speech/text/__pycache__/spliter.cpython-310.pyc deleted file mode 100644 index 94114179529badf1ecd0fa37c19d5fdc6223dcf9..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/__pycache__/spliter.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore deleted file mode 100644 index 75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/.gitignore +++ /dev/null @@ -1,114 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script 
from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ - -# JetBrains PyCharm -.idea - -# Customize -references -url.txt - -# Git -.git diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md deleted file mode 100644 index 8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. - -# Chn Text Norm - -this is a repository for chinese text normalization (no longer maintained). - -## Quick Start ## - -### Git Clone Repo ### - -git clone this repo to the root directory of your project which need to use it. - - cd /path/to/proj - git clone https://github.com/Joee1995/chn-text-norm.git - -after that, your doc tree should be: -``` -proj # root of your project -|--- chn_text_norm # this chn-text-norm tool - |--- text.py - |--- ... -|--- text_normalize.py # your text normalization code -|--- ... -``` - -### How to Use ? 
###
-
-    # text_normalize.py
-    from chn_text_norm.text import *
-
-    raw_text = 'your raw text'
-    text = Text(raw_text=raw_text).normalize()
-
-### How to add quantifiers ###
-
-Open test.py and you will see how it is done.
diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 34ff30c1d86436d172d82a2afe4f2914407b2056..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc
deleted file mode 100644
index cf1f70fdddf166b1e08a881f44651789cde5665b..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc
deleted file mode 100644
index 7ba0d65d52f907c05d8672ab5ffeba1ef69b0a58..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc
deleted file mode 100644
index 565e2baec31a08a1d40e473dbaf8cc068c4b56eb..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc
deleted file mode 100644
index 5ac369bcff904eeb22fd4c359b7ef4d0dff2856b..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc
deleted file mode 100644
index cb55a304422a219e00687fc987b4cdcfd8283dcd..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc
deleted file mode 100644
index 57bcee45c211f05e127b6556c6c3e0dc05a43e9c..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc
deleted file mode 100644
index 2982394aae51d7fef63f6ce4c13444662ccde1af..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc
deleted file mode 100644
index
5cdaa0642dce2713d356ae35a4dea9955df67e9c..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc deleted file mode 100644 index 1572f267a79a5231149fc5bba14cb4b4d4907895..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc deleted file mode 100644 index b088af763cadeb63f0ee4308c56032c19da3ed1f..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc deleted file mode 100644 index f84f49e3880c7410255e8ab1038eeb94d5375656..0000000000000000000000000000000000000000 Binary files a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py deleted file mode 100644 index 58d8f8eb7fc85d0861f106667d8f4e3e52b54761..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/basic_class.py +++ /dev/null @@ -1,172 +0,0 @@ -# -*- coding: utf-8 -*- -"""基本类 -中文字符类 -中文数字/数位类 -中文数字类 -中文数位类 -中文数字系统类 -中文数学符号类 -*中文其他符号类 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-02" - -from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES - - -class ChineseChar(object): - """ - 中文字符 - 每个字符对应简体和繁体, - e.g. 简体 = '负', 繁体 = '負' - 转换时可转换为简体或繁体 - """ - - def __init__(self, simplified, traditional): - self.simplified = simplified - self.traditional = traditional - self.__repr__ = self.__str__ - - def __str__(self): - return self.simplified or self.traditional or None - - def __repr__(self): - return self.__str__() - - -class ChineseNumberUnit(ChineseChar): - """ - 中文数字/数位字符 - 每个字符除繁简体外还有一个额外的大写字符 - e.g. 
'陆' 和 '陸' - """ - - def __init__(self, power, simplified, traditional, big_s, big_t): - super(ChineseNumberUnit, self).__init__(simplified, traditional) - self.power = power - self.big_s = big_s - self.big_t = big_t - - def __str__(self): - return "10^{}".format(self.power) - - @classmethod - def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): - - if small_unit: - return ChineseNumberUnit( - power=index + 1, - simplified=value[0], - traditional=value[1], - big_s=value[1], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[0]: - return ChineseNumberUnit( - power=index + 8, - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[1]: - return ChineseNumberUnit( - power=(index + 2) * 4, - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[2]: - return ChineseNumberUnit( - power=pow(2, index + 3), - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - else: - raise ValueError( - "Counting type should be in {0} ({1} provided).".format( - NUMBERING_TYPES, numbering_type - ) - ) - - -class ChineseNumberDigit(ChineseChar): - """ - 中文数字字符 - """ - - def __init__( - self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None - ): - super(ChineseNumberDigit, self).__init__(simplified, traditional) - self.value = value - self.big_s = big_s - self.big_t = big_t - self.alt_s = alt_s - self.alt_t = alt_t - - def __str__(self): - return str(self.value) - - @classmethod - def create(cls, i, v): - return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) - - -class ChineseMath(ChineseChar): - """ - 中文数位字符 - """ - - def __init__(self, simplified, traditional, symbol, expression=None): - super(ChineseMath, self).__init__(simplified, traditional) - self.symbol = symbol - self.expression = expression - self.big_s = simplified - self.big_t = traditional - - -CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath - - -class NumberSystem(object): - """ - 中文数字系统 - """ - - pass - - -class MathSymbol(object): - """ - 用于中文数字系统的数学符号 (繁/简体), e.g. 
- positive = ['正', '正'] - negative = ['负', '負'] - point = ['点', '點'] - """ - - def __init__(self, positive, negative, point): - self.positive = positive - self.negative = negative - self.point = point - - def __iter__(self): - for v in self.__dict__.values(): - yield v - - -# class OtherSymbol(object): -# """ -# 其他符号 -# """ -# -# def __init__(self, sil): -# self.sil = sil -# -# def __iter__(self): -# for v in self.__dict__.values(): -# yield v diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py deleted file mode 100644 index 9a65991b9a9d349a0571c80508633951e52749ef..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/basic_constant.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -"""基本常量 -中文数字/数位/符号字符常量 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-02" - -CHINESE_DIGIS = "零一二三四五六七八九" -BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" -BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" -SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" -SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" -LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" -LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" -SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" -SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" - -ZERO_ALT = "〇" -ONE_ALT = "幺" -TWO_ALTS = ["两", "兩"] - -POSITIVE = ["正", "正"] -NEGATIVE = ["负", "負"] -POINT = ["点", "點"] -# PLUS = [u'加', u'加'] -# SIL = [u'杠', u'槓'] - -# 中文数字系统类型 -NUMBERING_TYPES = ["low", "mid", "high"] diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py deleted file mode 100644 index dbf6130be87f285eed9998186508ea489d3bac9e..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/basic_util.py +++ /dev/null @@ -1,342 +0,0 @@ -# -*- coding: utf-8 -*- -"""基本方法 -创建中文数字系统 方法 -中文字符串 <=> 数字串 方法 -数字串 <=> 中文字符串 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-02" - -from fish_speech.text.chn_text_norm.basic_class import * -from fish_speech.text.chn_text_norm.basic_constant import * - - -def create_system(numbering_type=NUMBERING_TYPES[1]): - """ - 根据数字系统类型返回创建相应的数字系统,默认为 mid - NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 - low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. - mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. - high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. 
- 返回对应的数字系统 - """ - - # chinese number units of '亿' and larger - all_larger_units = zip( - LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, - LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, - ) - larger_units = [ - CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) - ] - # chinese number units of '十, 百, 千, 万' - all_smaller_units = zip( - SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, - SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, - ) - smaller_units = [ - CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) - ] - # digis - chinese_digis = zip( - CHINESE_DIGIS, - CHINESE_DIGIS, - BIG_CHINESE_DIGIS_SIMPLIFIED, - BIG_CHINESE_DIGIS_TRADITIONAL, - ) - digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] - digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT - digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT - digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] - - # symbols - positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) - negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) - point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) - # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) - system = NumberSystem() - system.units = smaller_units + larger_units - system.digits = digits - system.math = MathSymbol(positive_cn, negative_cn, point_cn) - # system.symbols = OtherSymbol(sil_cn) - return system - - -def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): - - def get_symbol(char, system): - for u in system.units: - if char in [u.traditional, u.simplified, u.big_s, u.big_t]: - return u - for d in system.digits: - if char in [ - d.traditional, - d.simplified, - d.big_s, - d.big_t, - d.alt_s, - d.alt_t, - ]: - return d - for m in system.math: - if char in [m.traditional, m.simplified]: - return m - - def string2symbols(chinese_string, system): - int_string, dec_string = chinese_string, "" - for p in [system.math.point.simplified, system.math.point.traditional]: - if p in chinese_string: - int_string, dec_string = chinese_string.split(p) - break - return [get_symbol(c, system) for c in int_string], [ - get_symbol(c, system) for c in dec_string - ] - - def correct_symbols(integer_symbols, system): - """ - 一百八 to 一百八十 - 一亿一千三百万 to 一亿 一千万 三百万 - """ - - if integer_symbols and isinstance(integer_symbols[0], CNU): - if integer_symbols[0].power == 1: - integer_symbols = [system.digits[1]] + integer_symbols - - if len(integer_symbols) > 1: - if isinstance(integer_symbols[-1], CND) and isinstance( - integer_symbols[-2], CNU - ): - integer_symbols.append( - CNU(integer_symbols[-2].power - 1, None, None, None, None) - ) - - result = [] - unit_count = 0 - for s in integer_symbols: - if isinstance(s, CND): - result.append(s) - unit_count = 0 - elif isinstance(s, CNU): - current_unit = CNU(s.power, None, None, None, None) - unit_count += 1 - - if unit_count == 1: - result.append(current_unit) - elif unit_count > 1: - for i in range(len(result)): - if ( - isinstance(result[-i - 1], CNU) - and result[-i - 1].power < current_unit.power - ): - result[-i - 1] = CNU( - result[-i - 1].power + current_unit.power, - None, - None, - None, - None, - ) - return result - - def compute_value(integer_symbols): - """ - Compute the value. - When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. - e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 - """ - value = [0] - last_power = 0 - for s in integer_symbols: - if isinstance(s, CND): - value[-1] = s.value - elif isinstance(s, CNU): - value[-1] *= pow(10, s.power) - if s.power > last_power: - value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) - last_power = s.power - value.append(0) - return sum(value) - - system = create_system(numbering_type) - int_part, dec_part = string2symbols(chinese_string, system) - int_part = correct_symbols(int_part, system) - int_str = str(compute_value(int_part)) - dec_str = "".join([str(d.value) for d in dec_part]) - if dec_part: - return "{0}.{1}".format(int_str, dec_str) - else: - return int_str - - -def num2chn( - number_string, - numbering_type=NUMBERING_TYPES[1], - big=False, - traditional=False, - alt_zero=False, - alt_one=False, - alt_two=True, - use_zeros=True, - use_units=True, -): - - def get_value(value_string, use_zeros=True): - - striped_string = value_string.lstrip("0") - - # record nothing if all zeros - if not striped_string: - return [] - - # record one digits - elif len(striped_string) == 1: - if use_zeros and len(value_string) != len(striped_string): - return [system.digits[0], system.digits[int(striped_string)]] - else: - return [system.digits[int(striped_string)]] - - # recursively record multiple digits - else: - result_unit = next( - u for u in reversed(system.units) if u.power < len(striped_string) - ) - result_string = value_string[: -result_unit.power] - return ( - get_value(result_string) - + [result_unit] - + get_value(striped_string[-result_unit.power :]) - ) - - system = create_system(numbering_type) - - int_dec = number_string.split(".") - if len(int_dec) == 1: - int_string = int_dec[0] - dec_string = "" - elif len(int_dec) == 2: - int_string = int_dec[0] - dec_string = int_dec[1] - else: - raise ValueError( - "invalid input num string with more than one dot: {}".format(number_string) - ) - - if use_units and len(int_string) > 1: - result_symbols = get_value(int_string) - else: - result_symbols = [system.digits[int(c)] for c in int_string] - dec_symbols = [system.digits[int(c)] for c in dec_string] - if dec_string: - result_symbols += [system.math.point] + dec_symbols - - if alt_two: - liang = CND( - 2, - system.digits[2].alt_s, - system.digits[2].alt_t, - system.digits[2].big_s, - system.digits[2].big_t, - ) - for i, v in enumerate(result_symbols): - if isinstance(v, CND) and v.value == 2: - next_symbol = ( - result_symbols[i + 1] if i < len(result_symbols) - 1 else None - ) - previous_symbol = result_symbols[i - 1] if i > 0 else None - if isinstance(next_symbol, CNU) and isinstance( - previous_symbol, (CNU, type(None)) - ): - if next_symbol.power != 1 and ( - (previous_symbol is None) or (previous_symbol.power != 1) - ): - result_symbols[i] = liang - - # if big is True, '两' will not be used and `alt_two` has no impact on output - if big: - attr_name = "big_" - if traditional: - attr_name += "t" - else: - attr_name += "s" - else: - if traditional: - attr_name = "traditional" - else: - attr_name = "simplified" - - result = "".join([getattr(s, attr_name) for s in result_symbols]) - - # if not use_zeros: - # result = result.strip(getattr(system.digits[0], attr_name)) - - if alt_zero: - result = result.replace( - getattr(system.digits[0], attr_name), system.digits[0].alt_s - ) - - if alt_one: - result = result.replace( - getattr(system.digits[1], attr_name), system.digits[1].alt_s - ) - - for i, p in enumerate(POINT): - if result.startswith(p): - return CHINESE_DIGIS[0] + 
result
-
-    # ^10, 11, .., 19
-    if (
-        len(result) >= 2
-        and result[1]
-        in [
-            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
-            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
-        ]
-        and result[0]
-        in [
-            CHINESE_DIGIS[1],
-            BIG_CHINESE_DIGIS_SIMPLIFIED[1],
-            BIG_CHINESE_DIGIS_TRADITIONAL[1],
-        ]
-    ):
-        result = result[1:]
-
-    return result
-
-
-if __name__ == "__main__":
-
-    # Test program
-    all_chinese_number_string = (
-        CHINESE_DIGIS
-        + BIG_CHINESE_DIGIS_SIMPLIFIED
-        + BIG_CHINESE_DIGIS_TRADITIONAL
-        + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
-        + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
-        + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
-        + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
-        + ZERO_ALT
-        + ONE_ALT
-        + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
-    )
-
-    print("num:", chn2num("一万零四百零三点八零五"))
-    print("num:", chn2num("一亿六点三"))
-    print("num:", chn2num("一亿零六点三"))
-    print("num:", chn2num("两千零一亿六点三"))
-    # print('num:', chn2num('一零零八六'))
-    print("txt:", num2chn("10260.03", alt_zero=True))
-    print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
-    print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
-    print(
-        "txt:",
-        num2chn(
-            "059523810880",
-            alt_one=True,
-            alt_two=False,
-            use_zeros=True,
-            use_units=False,
-        ),
-    )
-
-    print(all_chinese_number_string)
diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py
deleted file mode 100644
index ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/cardinal.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- coding: utf-8 -*-
-"""CARDINAL class (also covers the DECIMAL case).
-Plain number <=> Chinese string conversion methods.
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-03"
-
-from fish_speech.text.chn_text_norm.basic_util import *
-
-
-class Cardinal:
-    """
-    CARDINAL class
-    """
-
-    def __init__(self, cardinal=None, chntext=None):
-        self.cardinal = cardinal
-        self.chntext = chntext
-
-    def chntext2cardinal(self):
-        return chn2num(self.chntext)
-
-    def cardinal2chntext(self):
-        return num2chn(self.cardinal)
-
-
-if __name__ == "__main__":
-
-    # Test program
-    print(Cardinal(cardinal="21357.230").cardinal2chntext())
diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py
deleted file mode 100644
index 77acfdb9a91df0fe3c615a0784f61aad87fbe56e..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/date.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# -*- coding: utf-8 -*-
-"""DATE class.
-Date <=> Chinese string conversion methods.
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-07"
-
-from fish_speech.text.chn_text_norm.cardinal import Cardinal
-from fish_speech.text.chn_text_norm.digit import Digit
-
-
-class Date:
-    """
-    DATE class
-    """
-
-    def __init__(self, date=None, chntext=None):
-        self.date = date
-        self.chntext = chntext
-
-    # def chntext2date(self):
-    #     chntext = self.chntext
-    #     try:
-    #         year, other = chntext.strip().split('年', maxsplit=1)
-    #         year = Digit(chntext=year).digit2chntext() + '年'
-    #     except ValueError:
-    #         other = chntext
-    #         year = ''
-    #     if other:
-    #         try:
-    #             month, day = other.strip().split('月', maxsplit=1)
-    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
-    #         except ValueError:
-    #             day = chntext
-    #             month = ''
-    #         if day:
-    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
-    #     else:
-    #         month = ''
-    #         day = ''
-    #     date = year + month + day
-    #     self.date = date
-    #     return self.date
-
-    def date2chntext(self):
-        date = self.date
-        try:
year, other = date.strip().split("年", maxsplit=1) - year = Digit(digit=year).digit2chntext() + "年" - except ValueError: - other = date - year = "" - if other: - try: - month, day = other.strip().split("月", maxsplit=1) - month = Cardinal(cardinal=month).cardinal2chntext() + "月" - except ValueError: - day = date - month = "" - if day: - day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] - else: - month = "" - day = "" - chntext = year + month + day - self.chntext = chntext - return self.chntext - - -if __name__ == "__main__": - - # 测试 - print(Date(date="09年3月16日").date2chntext()) diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py deleted file mode 100644 index 47c0cd4ad0c700635f84470bfdacfbdafb4a6185..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/digit.py +++ /dev/null @@ -1,32 +0,0 @@ -# -*- coding: utf-8 -*- -"""DIGIT类 -数字串 <=> 中文字符串 方法 -中文字符串 <=> 数字串 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Digit: - """ - DIGIT类 - """ - - def __init__(self, digit=None, chntext=None): - self.digit = digit - self.chntext = chntext - - # def chntext2digit(self): - # return chn2num(self.chntext) - - def digit2chntext(self): - return num2chn(self.digit, alt_two=False, use_units=False) - - -if __name__ == "__main__": - - # 测试程序 - print(Digit(digit="2016").digit2chntext()) diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py deleted file mode 100644 index b43b6a7feb634d346d59a2b4ab84b77ac88df103..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/fraction.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -"""FRACTION类 -分数 <=> 中文字符串 方法 -中文字符串 <=> 分数 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Fraction: - """ - FRACTION类 - """ - - def __init__(self, fraction=None, chntext=None): - self.fraction = fraction - self.chntext = chntext - - def chntext2fraction(self): - denominator, numerator = self.chntext.split("分之") - return chn2num(numerator) + "/" + chn2num(denominator) - - def fraction2chntext(self): - numerator, denominator = self.fraction.split("/") - return num2chn(denominator) + "分之" + num2chn(numerator) - - -if __name__ == "__main__": - - # 测试程序 - print(Fraction(fraction="2135/7230").fraction2chntext()) - print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py deleted file mode 100644 index b4c980d32134e1460e96e5bcbcc73d0d55974d2a..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/money.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -"""MONEY类 -金钱 <=> 中文字符串 方法 -中文字符串 <=> 金钱 方法 -""" -import re - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-08" - -from fish_speech.text.chn_text_norm.cardinal import Cardinal - - -class Money: - """ - MONEY类 - """ - - def __init__(self, money=None, chntext=None): - self.money = money - self.chntext = chntext - - # def chntext2money(self): - # return self.money - - def money2chntext(self): - money = self.money - pattern = re.compile(r"(\d+(\.\d+)?)") - matchers = pattern.findall(money) - if matchers: - for matcher in matchers: - money = money.replace( - matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() - ) - self.chntext = money - return self.chntext - - -if __name__ == 
"__main__": - - # 测试 - print(Money(money="21.5万元").money2chntext()) - print(Money(money="230块5毛").money2chntext()) diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py deleted file mode 100644 index 46abbf545af62eb951d8f6fe40bcf684587f81b0..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/percentage.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -"""PERCENTAGE类 -百分数 <=> 中文字符串 方法 -中文字符串 <=> 百分数 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-06" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Percentage: - """ - PERCENTAGE类 - """ - - def __init__(self, percentage=None, chntext=None): - self.percentage = percentage - self.chntext = chntext - - def chntext2percentage(self): - return chn2num(self.chntext.strip().strip("百分之")) + "%" - - def percentage2chntext(self): - return "百分之" + num2chn(self.percentage.strip().strip("%")) - - -if __name__ == "__main__": - - # 测试程序 - print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) - print(Percentage(percentage="65.3%").percentage2chntext()) diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py deleted file mode 100644 index e72b546db628a3b807dc6235b59b188cae3153ff..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/telephone.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -"""TELEPHONE类 -电话号码 <=> 中文字符串 方法 -中文字符串 <=> 电话号码 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class TelePhone: - """ - TELEPHONE类 - """ - - def __init__(self, telephone=None, raw_chntext=None, chntext=None): - self.telephone = telephone - self.raw_chntext = raw_chntext - self.chntext = chntext - - # def chntext2telephone(self): - # sil_parts = self.raw_chntext.split('') - # self.telephone = '-'.join([ - # str(chn2num(p)) for p in sil_parts - # ]) - # return self.telephone - - def telephone2chntext(self, fixed=False): - - if fixed: - sil_parts = self.telephone.split("-") - self.raw_chntext = "".join( - [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] - ) - self.chntext = self.raw_chntext.replace("", "") - else: - sp_parts = self.telephone.strip("+").split() - self.raw_chntext = "".join( - [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] - ) - self.chntext = self.raw_chntext.replace("", "") - return self.chntext - - -if __name__ == "__main__": - - # 测试程序 - print(TelePhone(telephone="0595-23980880").telephone2chntext()) - # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py deleted file mode 100644 index 54086fd933c01e14c3c55cee9adb52eefb58fd31..0000000000000000000000000000000000000000 --- a/fish_speech/text/chn_text_norm/text.py +++ /dev/null @@ -1,177 +0,0 @@ -# -*- coding: utf-8 -*- -""" -TEXT类 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -import re - -from fish_speech.text.chn_text_norm.cardinal import Cardinal -from fish_speech.text.chn_text_norm.date import Date -from fish_speech.text.chn_text_norm.digit import Digit -from fish_speech.text.chn_text_norm.fraction import Fraction -from fish_speech.text.chn_text_norm.money import Money -from fish_speech.text.chn_text_norm.percentage import Percentage -from fish_speech.text.chn_text_norm.telephone import TelePhone - -CURRENCY_NAMES = ( - 
"(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" - "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" -) -CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" -COM_QUANTIFIERS = ( - "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" - "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" - "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" - "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" - "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" - "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)" -) - - -class Text: - """ - Text类 - """ - - def __init__(self, raw_text, norm_text=None): - self.raw_text = "^" + raw_text + "$" - self.norm_text = norm_text - - def _particular(self): - text = self.norm_text - pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") - matchers = pattern.findall(text) - if matchers: - # print('particular') - for matcher in matchers: - text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) - self.norm_text = text - return self.norm_text - - def normalize(self): - text = self.raw_text - - # 规范化日期 - pattern = re.compile( - r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)" - ) - matchers = pattern.findall(text) - if matchers: - # print('date') - for matcher in matchers: - text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) - - # 规范化金钱 - pattern = re.compile( - r"\D+((\d+(\.\d+)?)[多余几]?" - + CURRENCY_UNITS - + "(\d" - + CURRENCY_UNITS - + "?)?)" - ) - matchers = pattern.findall(text) - if matchers: - # print('money') - for matcher in matchers: - text = text.replace( - matcher[0], Money(money=matcher[0]).money2chntext(), 1 - ) - - # 规范化固话/手机号码 - # 手机 - # http://www.jihaoba.com/news/show/13680 - # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 - # 联通:130、131、132、156、155、186、185、176 - # 电信:133、153、189、180、181、177 - pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") - matchers = pattern.findall(text) - if matchers: - # print('telephone') - for matcher in matchers: - text = text.replace( - matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 - ) - # 固话 - pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") - matchers = pattern.findall(text) - if matchers: - # print('fixed telephone') - for matcher in matchers: - text = text.replace( - matcher[0], - TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), - 1, - ) - - # 规范化分数 - pattern = re.compile(r"(\d+/\d+)") - matchers = pattern.findall(text) - if matchers: - # print('fraction') - for matcher in matchers: - text = text.replace( - matcher, Fraction(fraction=matcher).fraction2chntext(), 1 - ) - - # 规范化百分数 - text = text.replace("%", "%") - pattern = re.compile(r"(\d+(\.\d+)?%)") - matchers = pattern.findall(text) - if matchers: - # print('percentage') - for matcher in matchers: - text = text.replace( - matcher[0], - Percentage(percentage=matcher[0]).percentage2chntext(), - 1, - ) - - # 规范化纯数+量词 - pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py deleted file mode 100644 index c228dfcd13324e8b1abe4ead5f01f4bd8ed0c33a..0000000000000000000000000000000000000000 --- a/fish_speech/text/clean.py +++ /dev/null @@ -1,31 +0,0 @@ -import re - -SYMBOLS_MAPPING = { - "“": "'", - "”": "'", - "‘": "'", - "’": "'", - "【": "", - "】": "", - "[": "", - "]": "", - "(": "", - ")": "", - "(": "", - ")": "", - "・": "·", -} - -REPLACE_SYMBOL_REGEX = re.compile( - "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) -) - - -def clean_text(text): - # Clean the text - text = text.strip() - - # Replace all Chinese symbols with their English counterparts - text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) - - return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py deleted file mode 100644 index d4bb995487c4f53818c6b2a16cf0a886b4e02e84..0000000000000000000000000000000000000000 --- a/fish_speech/text/spliter.py +++ /dev/null @@ -1,130 +0,0 @@ -import re -import string - -from fish_speech.text.clean import clean_text - - -def utf_8_len(text): - return len(text.encode("utf-8")) - - -def break_text(texts, length, splits: set): - for text in texts: - if utf_8_len(text) <= length: - yield text - continue - - curr = "" - for char in text: - curr += char - - if char in splits: - yield curr - curr = "" - - if curr: - yield curr - - -def break_text_by_length(texts, length): - for text in texts: - if utf_8_len(text) <= length: - yield text - continue - - curr = "" - for char in text: - curr += char - - if utf_8_len(curr) >= length: - yield curr - curr = "" - - if curr: - yield curr - - -def add_cleaned(curr, segments): - curr = curr.strip() - if curr and not all(c.isspace() or c in string.punctuation for c in curr): - segments.append(curr) - - -def protect_float(text): - # Turns 3.14 into <3_f_14> to prevent splitting - return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) - - -def unprotect_float(text): - # Turns <3_f_14> back into 3.14 - return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) - - -def split_text(text, length): - text = clean_text(text) - - # Break the text into pieces with the following rules: - # 1.
Split the text at ".", "!", "?" if text is NOT a float - # 2. If the text is longer than length, split at "," - # 3. If the text is still longer than length, split at " " - # 4. If the text is still longer than length, split at any character to length - - texts = [text] - texts = map(protect_float, texts) - texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) - texts = map(unprotect_float, texts) - texts = break_text(texts, length, {",", ","}) - texts = break_text(texts, length, {" "}) - texts = list(break_text_by_length(texts, length)) - - # Then, merge the texts into segments with length <= length - segments = [] - curr = "" - - for text in texts: - if utf_8_len(curr) + utf_8_len(text) <= length: - curr += text - else: - add_cleaned(curr, segments) - curr = text - - if curr: - add_cleaned(curr, segments) - - return segments - - -if __name__ == "__main__": - # Test the split_text function - - text = "This is a test sentence. This is another test sentence. And a third one." - - assert split_text(text, 50) == [ - "This is a test sentence.", - "This is another test sentence. And a third one.", - ] - assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] - assert split_text(" ", 10) == [] - assert split_text("a", 10) == ["a"] - - text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." - assert split_text(text, 50) == [ - "This is a test sentence with only commas,", - "and no dots, and no exclamation marks,", - "and no question marks, and no newlines.", - ] - - text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." - # First half split at " ", second half split at "," - assert split_text(text, 50) == [ - "This is a test sentence This is a test sentence", - "This is a test sentence. This is a test sentence,", - "This is a test sentence, This is a test sentence.", - ] - - text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" - assert split_text(text, 50) == [ - "这是一段很长的中文文本,", - "而且没有句号,也没有感叹号,", - "也没有问号,也没有换行符.", - ] diff --git a/fish_speech/train.py b/fish_speech/train.py deleted file mode 100644 index e693f3adc4dda787bdd587aec29f53355f2b1653..0000000000000000000000000000000000000000 --- a/fish_speech/train.py +++ /dev/null @@ -1,141 +0,0 @@ -import os - -os.environ["USE_LIBUV"] = "0" -import sys -from typing import Optional - -import hydra -import lightning as L -import pyrootutils -import torch -from lightning import Callback, LightningDataModule, LightningModule, Trainer -from lightning.pytorch.loggers import Logger -from lightning.pytorch.strategies import DDPStrategy -from omegaconf import DictConfig, OmegaConf - -os.environ.pop("SLURM_NTASKS", None) -os.environ.pop("SLURM_JOB_NAME", None) -os.environ.pop("SLURM_NTASKS_PER_NODE", None) - -# register eval resolver and root -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) - -# Allow TF32 on Ampere GPUs -torch.set_float32_matmul_precision("high") -torch.backends.cudnn.allow_tf32 = True - -# register eval resolver -OmegaConf.register_new_resolver("eval", eval) - -import fish_speech.utils as utils - -log = utils.RankedLogger(__name__, rank_zero_only=True) - - -@utils.task_wrapper -def train(cfg: DictConfig) -> tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. 
- This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. - Args: - cfg (DictConfig): Configuration composed by Hydra. - Returns: - Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. - """ # noqa: E501 - - # set seed for random number generators in pytorch, numpy and python.random - if cfg.get("seed"): - L.seed_everything(cfg.seed, workers=False) - - if cfg.get("deterministic"): - torch.use_deterministic_algorithms(True) - - log.info(f"Instantiating datamodule <{cfg.data._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) - - log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model) - - log.info("Instantiating callbacks...") - callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) - - log.info("Instantiating loggers...") - logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) - - log.info(f"Instantiating trainer <{cfg.trainer._target_}>") - trainer: Trainer = hydra.utils.instantiate( - cfg.trainer, - callbacks=callbacks, - logger=logger, - ) - - object_dict = { - "cfg": cfg, - "datamodule": datamodule, - "model": model, - "callbacks": callbacks, - "logger": logger, - "trainer": trainer, - } - - if logger: - log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) - - if cfg.get("train"): - log.info("Starting training!") - - ckpt_path = cfg.get("ckpt_path") - auto_resume = False - - resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) - if resume_ckpt_path is not None: - ckpt_path = resume_ckpt_path - auto_resume = True - - if ckpt_path is not None: - log.info(f"Resuming from checkpoint: {ckpt_path}") - - # resume weights only is disabled for auto-resume - if cfg.get("resume_weights_only") and auto_resume is False: - log.info("Resuming weights only!") - ckpt = torch.load(ckpt_path, map_location=model.device) - if "state_dict" in ckpt: - ckpt = ckpt["state_dict"] - err = model.load_state_dict(ckpt, strict=False) - log.info(f"Error loading state dict: {err}") - ckpt_path = None - - trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - - train_metrics = trainer.callback_metrics - - if cfg.get("test"): - log.info("Starting testing!") - ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": - log.warning("Best ckpt not found! 
Using current weights for testing...") - ckpt_path = cfg.get("ckpt_path") - - trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - log.info(f"Best ckpt path: {ckpt_path}") - - test_metrics = trainer.callback_metrics - - # merge train and test metrics - metric_dict = {**train_metrics, **test_metrics} - - return metric_dict, object_dict - - -@hydra.main( - version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" -) -def main(cfg: DictConfig) -> Optional[float]: - # train the model - train(cfg) - - -if __name__ == "__main__": - main() diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py deleted file mode 100644 index 05378519dbd18361c639e33413d011e7307c9adb..0000000000000000000000000000000000000000 --- a/fish_speech/utils/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .braceexpand import braceexpand -from .context import autocast_exclude_mps -from .file import get_latest_checkpoint -from .instantiators import instantiate_callbacks, instantiate_loggers -from .logger import RankedLogger -from .logging_utils import log_hyperparameters -from .rich_utils import enforce_tags, print_config_tree -from .utils import extras, get_metric_value, task_wrapper - -__all__ = [ - "enforce_tags", - "extras", - "get_metric_value", - "RankedLogger", - "instantiate_callbacks", - "instantiate_loggers", - "log_hyperparameters", - "print_config_tree", - "task_wrapper", - "braceexpand", - "get_latest_checkpoint", - "autocast_exclude_mps", -] diff --git a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 1275a8478b5b6c8ca96cd20f18be4f300e5fba8d..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc deleted file mode 100644 index 611e658e7387832fb9d481e775466a60689e364c..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/context.cpython-310.pyc b/fish_speech/utils/__pycache__/context.cpython-310.pyc deleted file mode 100644 index 0701855f15ea618e6fca6bba156a480a26e06705..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/context.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/file.cpython-310.pyc b/fish_speech/utils/__pycache__/file.cpython-310.pyc deleted file mode 100644 index d52787c4c9346fa3ac90012057d87598170b1619..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/file.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc deleted file mode 100644 index 78c1b17fb8f7e05a50ed4056b404d7e60c2f104f..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/logger.cpython-310.pyc b/fish_speech/utils/__pycache__/logger.cpython-310.pyc deleted file mode 100644 index 32cfb48f1bac889f58a4059ebc3033b2ec328077..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/logger.cpython-310.pyc and /dev/null differ diff --git 
a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc deleted file mode 100644 index 5e24723cdd60dd27d036e1e3a72def349e22f5d8..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc deleted file mode 100644 index e8a99c143231037868f6d6593d101636c0955844..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc deleted file mode 100644 index d83d0c6d63e8a397e659056c7f4ecdc3299f9135..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/__pycache__/utils.cpython-310.pyc b/fish_speech/utils/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index fbad6b1a0fbbd0e58817cd597ae6b9ed26f7e53a..0000000000000000000000000000000000000000 Binary files a/fish_speech/utils/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py deleted file mode 100644 index f3ac739f01f7e10e039c68c1157d6c761064f974..0000000000000000000000000000000000000000 --- a/fish_speech/utils/braceexpand.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Bash-style brace expansion -Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py -License: MIT -""" - -import re -import string -from itertools import chain, product -from typing import Iterable, Iterator, Optional - -__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] - - -class UnbalancedBracesError(ValueError): - pass - - -alphabet = string.ascii_uppercase + string.ascii_lowercase - -int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") -char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") -escape_re = re.compile(r"\\(.)") - - -def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: - """braceexpand(pattern) -> iterator over generated strings - - Returns an iterator over the strings resulting from brace expansion - of pattern. This function implements Brace Expansion as described in - bash(1), with the following limitations: - - * A pattern containing unbalanced braces will raise an - UnbalancedBracesError exception. In bash, unbalanced braces will either - be partly expanded or ignored. - - * A mixed-case character range like '{Z..a}' or '{a..Z}' will not - include the characters '[]^_`' between 'Z' and 'a'. - - When escape is True (the default), characters in pattern can be - prefixed with a backslash to cause them not to be interpreted as - special characters for brace expansion (such as '{', '}', ','). - To pass through a a literal backslash, double it ('\\\\'). - - When escape is False, backslashes in pattern have no special - meaning and will be preserved in the output. 
- - Examples: - - >>> from braceexpand import braceexpand - - # Integer range - >>> list(braceexpand('item{1..3}')) - ['item1', 'item2', 'item3'] - - # Character range - >>> list(braceexpand('{a..c}')) - ['a', 'b', 'c'] - - # Sequence - >>> list(braceexpand('index.html{,.backup}')) - ['index.html', 'index.html.backup'] - - # Nested patterns - >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) - ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] - - # Prefixing an integer with zero causes all numbers to be padded to - # the same width. - >>> list(braceexpand('{07..10}')) - ['07', '08', '09', '10'] - - # An optional increment can be specified for ranges. - >>> list(braceexpand('{a..g..2}')) - ['a', 'c', 'e', 'g'] - - # Ranges can go in both directions. - >>> list(braceexpand('{4..1}')) - ['4', '3', '2', '1'] - - # Numbers can be negative - >>> list(braceexpand('{2..-1}')) - ['2', '1', '0', '-1'] - - # Unbalanced braces raise an exception. - >>> list(braceexpand('{1{2,3}')) - Traceback (most recent call last): - ... - UnbalancedBracesError: Unbalanced braces: '{1{2,3}' - - # By default, the backslash is the escape character. - >>> list(braceexpand(r'{1\\{2,3}')) - ['1{2', '3'] - - # Setting 'escape' to False disables backslash escaping. - >>> list(braceexpand(r'\\{1,2}', escape=False)) - ['\\\\1', '\\\\2'] - - """ - return ( - escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) - ) - - -def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: - start = 0 - pos = 0 - bracketdepth = 0 - items: list[Iterable[str]] = [] - - # print 'pattern:', pattern - while pos < len(pattern): - if escape and pattern[pos] == "\\": - pos += 2 - continue - elif pattern[pos] == "{": - if bracketdepth == 0 and pos > start: - # print 'literal:', pattern[start:pos] - items.append([pattern[start:pos]]) - start = pos - bracketdepth += 1 - elif pattern[pos] == "}": - bracketdepth -= 1 - if bracketdepth == 0: - # print 'expression:', pattern[start+1:pos] - expr = pattern[start + 1 : pos] - item = parse_expression(expr, escape) - if item is None: # not a range or sequence - items.extend([["{"], parse_pattern(expr, escape), ["}"]]) - else: - items.append(item) - start = pos + 1 # skip the closing brace - pos += 1 - - if bracketdepth != 0: # unbalanced braces - raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) - - if start < pos: - items.append([pattern[start:]]) - - return ("".join(item) for item in product(*items)) - - -def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: - int_range_match = int_range_re.match(expr) - if int_range_match: - return make_int_range(*int_range_match.groups()) - - char_range_match = char_range_re.match(expr) - if char_range_match: - return make_char_range(*char_range_match.groups()) - - return parse_sequence(expr, escape) - - -def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: - # sequence -> chain(*sequence_items) - start = 0 - pos = 0 - bracketdepth = 0 - items: list[Iterable[str]] = [] - - # print 'sequence:', seq - while pos < len(seq): - if escape and seq[pos] == "\\": - pos += 2 - continue - elif seq[pos] == "{": - bracketdepth += 1 - elif seq[pos] == "}": - bracketdepth -= 1 - elif seq[pos] == "," and bracketdepth == 0: - items.append(parse_pattern(seq[start:pos], escape)) - start = pos + 1 # skip the comma - pos += 1 - - if bracketdepth != 0: - raise UnbalancedBracesError - if not items: - return None - - # part after the last comma (may be the empty string) - 
items.append(parse_pattern(seq[start:], escape)) - return chain(*items) - - -def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: - if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): - padding = max(len(left), len(right)) - else: - padding = 0 - step = (int(incr) or 1) if incr else 1 - start = int(left) - end = int(right) - r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) - fmt = "%0{}d".format(padding) - return (fmt % i for i in r) - - -def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: - step = (int(incr) or 1) if incr else 1 - start = alphabet.index(left) - end = alphabet.index(right) - if start < end: - return alphabet[start : end + 1 : step] - else: - end = end or -len(alphabet) - return alphabet[start : end - 1 : -step] - - -if __name__ == "__main__": - import doctest - import sys - - failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) - if failed: - sys.exit(1) diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py deleted file mode 100644 index f04a99290ab32f7fe5b60656075a2d03af8468d6..0000000000000000000000000000000000000000 --- a/fish_speech/utils/context.py +++ /dev/null @@ -1,13 +0,0 @@ -from contextlib import nullcontext - -import torch - - -def autocast_exclude_mps( - device_type: str, dtype: torch.dtype -) -> nullcontext | torch.autocast: - return ( - nullcontext() - if torch.backends.mps.is_available() - else torch.autocast(device_type, dtype) - ) diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py deleted file mode 100644 index 78c82640a963fa556657107729f7543d2e7c3510..0000000000000000000000000000000000000000 --- a/fish_speech/utils/file.py +++ /dev/null @@ -1,16 +0,0 @@ -import os -from pathlib import Path - - -def get_latest_checkpoint(path: Path | str) -> Path | None: - # Find the latest checkpoint - ckpt_dir = Path(path) - - if ckpt_dir.exists() is False: - return None - - ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) - if len(ckpts) == 0: - return None - - return ckpts[-1] diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py deleted file mode 100644 index f6ee463924f588a35477937fbe3c3364043bdf3e..0000000000000000000000000000000000000000 --- a/fish_speech/utils/instantiators.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List - -import hydra -from omegaconf import DictConfig -from pytorch_lightning import Callback -from pytorch_lightning.loggers import Logger - -from .logger import RankedLogger - -log = RankedLogger(__name__, rank_zero_only=True) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config.""" - - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! 
Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py deleted file mode 100644 index 94f94f738d1d87404354d086c30ef0ad9ab04cdc..0000000000000000000000000000000000000000 --- a/fish_speech/utils/logger.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from typing import Mapping, Optional - -from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only - - -class RankedLogger(logging.LoggerAdapter): - """A multi-GPU-friendly python command line logger.""" - - def __init__( - self, - name: str = __name__, - rank_zero_only: bool = True, - extra: Optional[Mapping[str, object]] = None, - ) -> None: - """Initializes a multi-GPU-friendly python command line logger that logs on all processes - with their rank prefixed in the log message. - - :param name: The name of the logger. Default is ``__name__``. - :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. - :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. - """ - logger = logging.getLogger(name) - super().__init__(logger=logger, extra=extra) - self.rank_zero_only = rank_zero_only - - def log( - self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs - ) -> None: - """Delegate a log call to the underlying logger, after prefixing its message with the rank - of the process it's being logged from. If `'rank'` is provided, then the log will only - occur on that rank/process. - - :param level: The level to log at. Look at `logging.__init__.py` for more information. - :param msg: The message to log. - :param rank: The rank to log at. - :param args: Additional args to pass to the underlying logging function. - :param kwargs: Any additional keyword args to pass to the underlying logging function. - """ - if self.isEnabledFor(level): - msg, kwargs = self.process(msg, kwargs) - current_rank = getattr(rank_zero_only, "rank", None) - if current_rank is None: - raise RuntimeError( - "The `rank_zero_only.rank` needs to be set before use" - ) - msg = rank_prefixed_message(msg, current_rank) - if self.rank_zero_only: - if current_rank == 0: - self.logger.log(level, msg, *args, **kwargs) - else: - if rank is None: - self.logger.log(level, msg, *args, **kwargs) - elif current_rank == rank: - self.logger.log(level, msg, *args, **kwargs) diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py deleted file mode 100644 index 8e3b0a2519e12845f09e5fbe86dfccbf5b345429..0000000000000000000000000000000000000000 --- a/fish_speech/utils/logging_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -from lightning.pytorch.utilities import rank_zero_only - -from fish_speech.utils import logger as log - - -@rank_zero_only -def log_hyperparameters(object_dict: dict) -> None: - """Controls which config parts are saved by lightning loggers. - - Additionally saves: - - Number of model parameters - """ - - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! 
Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - hparams["data"] = cfg["data"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py deleted file mode 100644 index 6a465f54d610779766d51e3d1a020a3b1517fd1f..0000000000000000000000000000000000000000 --- a/fish_speech/utils/rich_utils.py +++ /dev/null @@ -1,100 +0,0 @@ -from pathlib import Path -from typing import Sequence - -import rich -import rich.syntax -import rich.tree -from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import DictConfig, OmegaConf, open_dict -from rich.prompt import Prompt - -from fish_speech.utils import logger as log - - -@rank_zero_only -def print_config_tree( - cfg: DictConfig, - print_order: Sequence[str] = ( - "data", - "model", - "callbacks", - "logger", - "trainer", - "paths", - "extras", - ), - resolve: bool = False, - save_to_file: bool = False, -) -> None: - """Prints content of DictConfig using Rich library and its tree structure. - - Args: - cfg (DictConfig): Configuration composed by Hydra. - print_order (Sequence[str], optional): Determines in what order config components are printed. - resolve (bool, optional): Whether to resolve reference fields of DictConfig. - save_to_file (bool, optional): Whether to export config to the hydra output folder. - """ # noqa: E501 - - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - queue = [] - - # add fields from `print_order` to queue - for field in print_order: - ( - queue.append(field) - if field in cfg - else log.warning( - f"Field '{field}' not found in config. " - + f"Skipping '{field}' config printing..." - ) - ) - - # add all the other fields to queue (not specified in `print_order`) - for field in cfg: - if field not in queue: - queue.append(field) - - # generate config tree from queue - for field in queue: - branch = tree.add(field, style=style, guide_style=style) - - config_group = cfg[field] - if isinstance(config_group, DictConfig): - branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) - else: - branch_content = str(config_group) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - # print config tree - rich.print(tree) - - # save config tree to file - if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: - rich.print(tree, file=file) - - -@rank_zero_only -def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 - - if not cfg.get("tags"): - if "id" in HydraConfig().cfg.hydra.job: - raise ValueError("Specify tags before launching a multirun!") - - log.warning("No tags provided in config. 
Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") - tags = [t.strip() for t in tags.split(",") if t != ""] - - with open_dict(cfg): - cfg.tags = tags - - log.info(f"Tags: {cfg.tags}") - - if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: - rich.print(cfg.tags, file=file) diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py deleted file mode 100644 index 01c3d7a2ab0f707ae92dbde0feb173927720c841..0000000000000000000000000000000000000000 --- a/fish_speech/utils/spectrogram.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import torchaudio.functional as F -from torch import Tensor, nn -from torchaudio.transforms import MelScale - - -class LinearSpectrogram(nn.Module): - def __init__( - self, - n_fft=2048, - win_length=2048, - hop_length=512, - center=False, - mode="pow2_sqrt", - ): - super().__init__() - - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.mode = mode - - self.register_buffer("window", torch.hann_window(win_length), persistent=False) - - def forward(self, y: Tensor) -> Tensor: - if y.ndim == 3: - y = y.squeeze(1) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - ( - (self.win_length - self.hop_length) // 2, - (self.win_length - self.hop_length + 1) // 2, - ), - mode="reflect", - ).squeeze(1) - - spec = torch.stft( - y, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - - spec = torch.view_as_real(spec) - - if self.mode == "pow2_sqrt": - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - return spec - - -class LogMelSpectrogram(nn.Module): - def __init__( - self, - sample_rate=44100, - n_fft=2048, - win_length=2048, - hop_length=512, - n_mels=128, - center=False, - f_min=0.0, - f_max=None, - ): - super().__init__() - - self.sample_rate = sample_rate - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.n_mels = n_mels - self.f_min = f_min - self.f_max = f_max or float(sample_rate // 2) - - self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) - - fb = F.melscale_fbanks( - n_freqs=self.n_fft // 2 + 1, - f_min=self.f_min, - f_max=self.f_max, - n_mels=self.n_mels, - sample_rate=self.sample_rate, - norm="slaney", - mel_scale="slaney", - ) - self.register_buffer( - "fb", - fb, - persistent=False, - ) - - def compress(self, x: Tensor) -> Tensor: - return torch.log(torch.clamp(x, min=1e-5)) - - def decompress(self, x: Tensor) -> Tensor: - return torch.exp(x) - - def apply_mel_scale(self, x: Tensor) -> Tensor: - return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) - - def forward( - self, x: Tensor, return_linear: bool = False, sample_rate: int = None - ) -> Tensor: - if sample_rate is not None and sample_rate != self.sample_rate: - x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) - - linear = self.spectrogram(x) - x = self.apply_mel_scale(linear) - x = self.compress(x) - - if return_linear: - return x, self.compress(linear) - - return x diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py deleted file mode 100644 index c546bfa1eddd2ac6bf484cce1ec06da1d33fb121..0000000000000000000000000000000000000000 --- a/fish_speech/utils/utils.py +++ /dev/null @@ -1,114 +0,0 @@ -import warnings -from importlib.util import find_spec -from 
typing import Callable - -from omegaconf import DictConfig - -from .logger import RankedLogger -from .rich_utils import enforce_tags, print_config_tree - -log = RankedLogger(__name__, rank_zero_only=True) - - -def extras(cfg: DictConfig) -> None: - """Applies optional utilities before the task is started. - - Utilities: - - Ignoring python warnings - - Setting tags from command line - - Rich config printing - """ - - # return if no `extras` config - if not cfg.get("extras"): - log.warning("Extras config not found! <cfg.extras=null>") - return - - # disable python warnings - if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>") - warnings.filterwarnings("ignore") - - # prompt user to input tags from command line if none are provided in the config - if cfg.extras.get("enforce_tags"): - log.info("Enforcing tags! <cfg.extras.enforce_tags=True>") - enforce_tags(cfg, save_to_file=True) - - # pretty print config tree using Rich library - if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! <cfg.extras.print_config=True>") - print_config_tree(cfg, resolve=True, save_to_file=True) - - -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the task function. - - This wrapper can be used to: - - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) - - save the exception to a `.log` file - - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) - - etc. (adjust depending on your needs) - - Example: - ``` - @utils.task_wrapper - def train(cfg: DictConfig) -> Tuple[dict, dict]: - - ... - - return metric_dict, object_dict - ``` - """ # noqa: E501 - - def wrap(cfg: DictConfig): - # execute the task - try: - metric_dict, object_dict = task_func(cfg=cfg) - - # things to do if exception occurs - except Exception as ex: - # save exception to `.log` file - log.exception("") - - # some hyperparameter combinations might be invalid or - # cause out-of-memory errors so when using hparam search - # plugins like Optuna, you might want to disable - # raising the below exception to avoid multirun failure - raise ex - - # things to always do after either success or exception - finally: - # display output dir path in terminal - log.info(f"Output dir: {cfg.paths.run_dir}") - - # always close wandb run (even if exception occurs so multirun won't fail) - if find_spec("wandb"): # check if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - return metric_dict, object_dict - - return wrap - - -def get_metric_value(metric_dict: dict, metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule.""" - - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise Exception( - f"Metric value not found! <metric_name={metric_name}>\n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value!
<{metric_name}={metric_value}>") - - return metric_value diff --git a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc deleted file mode 100644 index 0bd0b8af3ea645c95065dbbe9b037384e54ad614..0000000000000000000000000000000000000000 Binary files a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc and /dev/null differ diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css deleted file mode 100644 index 3c7a22ecc31881a65a76369b0fd889330a0874c7..0000000000000000000000000000000000000000 --- a/fish_speech/webui/css/style.css +++ /dev/null @@ -1,161 +0,0 @@ -:root { - --my-200: #80eeee; - --my-50: #ecfdf5; - --water-width: 300px; - --water-heigh: 300px; -} - - -/* general styled components */ -.tools { - align-items: center; - justify-content: center; -} - -.gradio-button { - max-width: 2.2em; - min-width: 2.2em !important; - height: 2.4em; - align-self: end; - line-height: 1em; - border-radius: 0.5em; - -} - -.gradio-button.secondary-down, .gradio-button.secondary-down:hover{ - box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; -} - -/* replace original footer with ours */ -a{ - font-weight: bold; - cursor: pointer; - color: #030C14 !important; -} - -footer { - display: none !important; -} - -#footer{ - text-align: center; -} - -#footer div{ - display: inline-block; -} - -#footer .versions{ - font-size: 85%; - opacity: 0.85; -} - -/*@keyframes moveBackground {*/ -/* 0% {*/ -/* background-position: 0 0;*/ -/* }*/ -/* 100% {*/ -/* background-position: -100px 100px;*/ -/* }*/ -/*}*/ -@keyframes moveJellyBackground { - 0% { - background-position: 0% 50%; - } - 50% { - background-position: 100% 50%; - } - 100% { - background-position: 0% 50%; - } -} - -.gradio-container { - position: absolute; - z-index: 10; -} - - -.quan { - position: absolute; - bottom: 0; - width: var(--water-width); - height: var(--water-heigh); - border-radius: 0; - /*border: 3px solid rgb(246, 247, 248);*/ - /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ - z-index: 0; - -} - -.quan:last-child { - margin-right: 0; -} - -.shui { - position: absolute; - top: 0; - left: 0; - width: 100%; - height: 100%; - background-color: rgb(23, 106, 201); - border-radius: 0; - overflow: hidden; - z-index: 0; -} - -.shui::after { - - content: ''; - position: absolute; - top: 20%; - left: 50%; - width: 150%; - height: 150%; - border-radius: 40%; - background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); - animation: shi 5s linear infinite; -} - -@keyframes shi { - 0% { - transform: translate(-50%, -65%) rotate(0deg); - } - 100% { - transform: translate(-50%, -65%) rotate(360deg); - } -} - -.shui::before { - content: ''; - position: absolute; - top: 20%; - left: 50%; - width: 150%; - height: 150%; - border-radius: 42%; - background-color: rgb(240, 228, 228, 0.2); - animation: xu 7s linear infinite; -} - -@keyframes xu { - 0% { - transform: translate(-50%, -60%) rotate(0deg); - } - 100% { - transform: translate(-50%, -60%) rotate(360deg); - } -} - -fieldset.data_src div.wrap label { - background: #f8bffee0 !important; -} - -.scrollable-component { - max-height: 100px; - overflow-y: auto; -} - -#file_accordion { - max-height: 220px !important; -} diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html deleted file mode 100644 index ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615..0000000000000000000000000000000000000000 --- a/fish_speech/webui/html/footer.html +++ 
/dev/null @@ -1,11 +0,0 @@ -
- API -  •  - Github -  •  - Gradio -
-
-
-{versions} -
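The {versions} token in this footer is a plain str.format slot; presumably the webui reads the template and fills it with the versions_html() string defined in launch_utils.py further down this diff. A hedged sketch of that wiring (the exact call site is not shown here):

```python
from pathlib import Path

from fish_speech.webui.launch_utils import versions_html

# Load the footer template and substitute the {versions} placeholder
# before handing the finished HTML fragment to the UI layer.
footer_template = Path("fish_speech/webui/html/footer.html").read_text(encoding="utf-8")
footer_html = footer_template.format(versions=versions_html())
```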
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js deleted file mode 100644 index 0637a541a8e704632a42b89bdf1471b26e7bb868..0000000000000000000000000000000000000000 --- a/fish_speech/webui/js/animate.js +++ /dev/null @@ -1,69 +0,0 @@ - -function createGradioAnimation() { - const params = new URLSearchParams(window.location.search); - if (!params.has('__theme')) { - params.set('__theme', 'light'); - window.location.search = params.toString(); - } - - var gradioApp = document.querySelector('gradio-app'); - if (gradioApp) { - - document.documentElement.style.setProperty('--my-200', '#80eeee'); - document.documentElement.style.setProperty('--my-50', '#ecfdf5'); - - // gradioApp.style.position = 'relative'; - // gradioApp.style.backgroundSize = '200% 200%'; - // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; - // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; - // gradioApp.style.display = 'flex'; - // gradioApp.style.justifyContent = 'flex-start'; - // gradioApp.style.flexWrap = 'nowrap'; - // gradioApp.style.overflowX = 'auto'; - - // for (let i = 0; i < 6; i++) { - // var quan = document.createElement('div'); - // quan.className = 'quan'; - // gradioApp.insertBefore(quan, gradioApp.firstChild); - // quan.id = 'quan' + i.toString(); - // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; - // var quanContainer = document.querySelector('.quan'); - // if (quanContainer) { - // var shui = document.createElement('div'); - // shui.className = 'shui'; - // quanContainer.insertBefore(shui, quanContainer.firstChild) - // } - // } - } - - var container = document.createElement('div'); - container.id = 'gradio-animation'; - container.style.fontSize = '2em'; - container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; - container.style.fontWeight = 'bold'; - container.style.textAlign = 'center'; - container.style.marginBottom = '20px'; - - var text = 'Welcome to Fish-Speech!'; - for (var i = 0; i < text.length; i++) { - (function(i){ - setTimeout(function(){ - var letter = document.createElement('span'); - letter.style.opacity = '0'; - letter.style.transition = 'opacity 0.5s'; - letter.innerText = text[i]; - - container.appendChild(letter); - - setTimeout(function() { - letter.style.opacity = '1'; - }, 50); - }, i * 200); - })(i); - } - - var gradioContainer = document.querySelector('.gradio-container'); - gradioContainer.insertBefore(container, gradioContainer.firstChild); - - return 'Animation created'; -} diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py deleted file mode 100644 index 2f57b595a20177800dbedd71faef573ee8398418..0000000000000000000000000000000000000000 --- a/fish_speech/webui/launch_utils.py +++ /dev/null @@ -1,120 +0,0 @@ -import importlib.util -import os -import subprocess -import sys -from functools import lru_cache -from pathlib import Path -from typing import Iterable - -import gradio as gr -from gradio.themes.base import Base -from gradio.themes.utils import colors, fonts, sizes - -GIT = ( - (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() - if sys.platform == "win32" - else "git" -) -GIT = str(GIT) - - -def is_module_installed(module_name: str) -> bool: - spec = importlib.util.find_spec(module_name) - return spec is not None - - -@lru_cache() -def commit_hash(): - try: - return subprocess.check_output( - [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" - ).strip() - except Exception: - 
return "" - - -def versions_html(): - import torch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = commit_hash() - hash = commit.strip("'").split(" ")[0] - - return f""" -version: {hash} - •  -python: {python_version} - •  -torch: {getattr(torch, '__long_version__',torch.__version__)} - •  -gradio: {gr.__version__} - •  -author: fishaudio -""" - - -def version_check(commit): - try: - import requests - - commits = requests.get( - "https://api.github.com/repos/fishaudio/fish-speech/branches/main" - ).json() - if commit != "" and commits["commit"]["sha"] != commit: - print("--------------------------------------------------------") - print("| You are not up to date with the most recent release. |") - print("| Consider running `git pull` to update. |") - print("--------------------------------------------------------") - elif commits["commit"]["sha"] == commit: - print("You are up to date with the most recent release.") - else: - print("Not a git clone, can't perform version check.") - except Exception as e: - print("version check failed", e) - - -class Seafoam(Base): - def __init__( - self, - *, - primary_hue: colors.Color | str = colors.emerald, - secondary_hue: colors.Color | str = colors.blue, - neutral_hue: colors.Color | str = colors.blue, - spacing_size: sizes.Size | str = sizes.spacing_md, - radius_size: sizes.Size | str = sizes.radius_md, - text_size: sizes.Size | str = sizes.text_lg, - font: fonts.Font | str | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("Quicksand"), - "ui-sans-serif", - "sans-serif", - ), - font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("IBM Plex Mono"), - "ui-monospace", - "monospace", - ), - ): - super().__init__( - primary_hue=primary_hue, - secondary_hue=secondary_hue, - neutral_hue=neutral_hue, - spacing_size=spacing_size, - radius_size=radius_size, - text_size=text_size, - font=font, - font_mono=font_mono, - ) - super().set( - button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", - button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", - button_primary_text_color="white", - button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", - slider_color="*secondary_300", - slider_color_dark="*secondary_600", - block_title_text_weight="600", - block_border_width="3px", - block_shadow="*shadow_drop_lg", - button_shadow="*shadow_drop_lg", - button_small_padding="0px", - button_large_padding="3px", - ) diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py deleted file mode 100644 index 4ec3fcac25de3cc7d239c4903403d1a4cd81567b..0000000000000000000000000000000000000000 --- a/fish_speech/webui/manage.py +++ /dev/null @@ -1,1239 +0,0 @@ -from __future__ import annotations - -import os - -os.environ["USE_LIBUV"] = "0" -import datetime -import html -import json -import platform -import shutil -import signal -import subprocess -import sys -from pathlib import Path - -import gradio as gr -import psutil -import yaml -from loguru import logger -from tqdm import tqdm - -PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") -sys.path.insert(0, "") -print(sys.path) -cur_work_dir = Path(os.getcwd()).resolve() -print("You are in ", str(cur_work_dir)) - -from fish_speech.i18n import i18n -from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html - -config_path = cur_work_dir / "fish_speech" / "configs" -vqgan_yml_path = config_path / 
"firefly_gan_vq.yaml" -llama_yml_path = config_path / "text2semantic_finetune.yaml" - -env = os.environ.copy() -env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" - -seafoam = Seafoam() - - -def build_html_error_message(error): - return f""" -
- {html.escape(error)} -
- """ - - -def build_html_ok_message(msg): - return f""" -
<div style="color: green; font-weight: bold;"> - {html.escape(msg)} - </div>
- """ - - -def build_html_href(link, desc, msg): - return f""" - - {html.escape(msg)} - {desc} - - """ - - -def load_data_in_raw(path): - with open(path, "r", encoding="utf-8") as file: - data = file.read() - return str(data) - - -def kill_proc_tree(pid, including_parent=True): - try: - parent = psutil.Process(pid) - except psutil.NoSuchProcess: - # Process already terminated - return - - children = parent.children(recursive=True) - for child in children: - try: - os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL - except OSError: - pass - if including_parent: - try: - os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL - except OSError: - pass - - -system = platform.system() -p_label = None -p_infer = None -p_tensorboard = None - - -def kill_process(pid): - if system == "Windows": - cmd = "taskkill /t /f /pid %s" % pid - # os.system(cmd) - subprocess.run(cmd) - else: - kill_proc_tree(pid) - - -def change_label(if_label): - global p_label - if if_label == True and p_label is None: - url = "http://localhost:3000" - remote_url = "https://text-labeler.pages.dev/" - try: - p_label = subprocess.Popen( - [ - ( - "asr-label-linux-x64" - if sys.platform == "linux" - else "asr-label-win-x64.exe" - ) - ] - ) - except FileNotFoundError: - logger.warning("asr-label execution not found!") - - yield build_html_href( - link=remote_url, - desc=i18n("Optional online ver"), - msg=i18n("Opened labeler in browser"), - ) - - elif if_label == False and p_label is not None: - kill_process(p_label.pid) - p_label = None - yield build_html_ok_message("Nothing") - - -def clean_infer_cache(): - import tempfile - - temp_dir = Path(tempfile.gettempdir()) - gradio_dir = str(temp_dir / "gradio") - try: - shutil.rmtree(gradio_dir) - logger.info(f"Deleted cached audios: {gradio_dir}") - except PermissionError: - logger.info(f"Permission denied: Unable to delete {gradio_dir}") - except FileNotFoundError: - logger.info(f"{gradio_dir} was not found") - except Exception as e: - logger.info(f"An error occurred: {e}") - - -def change_infer( - if_infer, - host, - port, - infer_decoder_model, - infer_decoder_config, - infer_llama_model, - infer_compile, -): - global p_infer - if if_infer == True and p_infer == None: - env = os.environ.copy() - - env["GRADIO_SERVER_NAME"] = host - env["GRADIO_SERVER_PORT"] = port - # 启动第二个进程 - url = f"http://{host}:{port}" - yield build_html_ok_message( - i18n("Inferring interface is launched at {}").format(url) - ) - - clean_infer_cache() - - p_infer = subprocess.Popen( - [ - PYTHON, - "tools/webui.py", - "--decoder-checkpoint-path", - infer_decoder_model, - "--decoder-config-name", - infer_decoder_config, - "--llama-checkpoint-path", - infer_llama_model, - ] - + (["--compile"] if infer_compile == "Yes" else []), - env=env, - ) - - elif if_infer == False and p_infer is not None: - kill_process(p_infer.pid) - p_infer = None - yield build_html_error_message(i18n("Infer interface is closed")) - - -js = load_data_in_raw("fish_speech/webui/js/animate.js") -css = load_data_in_raw("fish_speech/webui/css/style.css") - -data_pre_output = (cur_work_dir / "data").resolve() -default_model_output = (cur_work_dir / "results").resolve() -default_filelist = data_pre_output / "detect.list" -data_pre_output.mkdir(parents=True, exist_ok=True) - -items = [] -dict_items = {} - - -def load_yaml_data_in_fact(yml_path): - with open(yml_path, "r", encoding="utf-8") as file: - yml = yaml.safe_load(file) - return yml - - -def write_yaml_data_in_fact(yml, yml_path): - with open(yml_path, "w", encoding="utf-8") as 
file: - yaml.safe_dump(yml, file, allow_unicode=True) - return yml - - -def generate_tree(directory, depth=0, max_depth=None, prefix=""): - if max_depth is not None and depth > max_depth: - return "" - - tree_str = "" - files = [] - directories = [] - for item in os.listdir(directory): - if os.path.isdir(os.path.join(directory, item)): - directories.append(item) - else: - files.append(item) - - entries = directories + files - for i, entry in enumerate(entries): - connector = "├── " if i < len(entries) - 1 else "└── " - tree_str += f"{prefix}{connector}{entry}<br />
" - if i < len(directories): - extension = "│ " if i < len(entries) - 1 else " " - tree_str += generate_tree( - os.path.join(directory, entry), - depth + 1, - max_depth, - prefix=prefix + extension, - ) - return tree_str - - -def new_explorer(data_path, max_depth): - return gr.Markdown( - elem_classes=["scrollable-component"], - value=generate_tree(data_path, max_depth=max_depth), - ) - - -def add_item( - folder: str, - method: str, - label_lang: str, - if_initial_prompt: bool, - initial_prompt: str | None, -): - folder = folder.strip(" ").strip('"') - - folder_path = Path(folder) - - if folder and folder not in items and data_pre_output not in folder_path.parents: - if folder_path.is_dir(): - items.append(folder) - dict_items[folder] = dict( - type="folder", - method=method, - label_lang=label_lang, - initial_prompt=initial_prompt if if_initial_prompt else None, - ) - elif folder: - err = folder - return gr.Checkboxgroup(choices=items), build_html_error_message( - i18n("Invalid path: {}").format(err) - ) - - formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) - logger.info("After Adding: " + formatted_data) - gr.Info(formatted_data) - return gr.Checkboxgroup(choices=items), build_html_ok_message( - i18n("Added path successfully!") - ) - - -def remove_items(selected_items): - global items, dict_items - to_remove = [item for item in items if item in selected_items] - for item in to_remove: - del dict_items[item] - items = [item for item in items if item in dict_items.keys()] - formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) - logger.info(formatted_data) - gr.Warning("After Removing: " + formatted_data) - return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( - i18n("Removed path successfully!") - ) - - -def show_selected(options): - selected_options = ", ".join(options) - - if options: - return i18n("Selected: {}").format(selected_options) - else: - return i18n("No selected options") - - -from pydub import AudioSegment - - -def convert_to_mono_in_place(audio_path: Path): - audio = AudioSegment.from_file(audio_path) - if audio.channels > 1: - mono_audio = audio.set_channels(1) - mono_audio.export(audio_path, format=audio_path.suffix[1:]) - logger.info(f"Convert {audio_path} successfully") - - -def list_copy(list_file_path, method): - wav_root = data_pre_output - lst = [] - with list_file_path.open("r", encoding="utf-8") as file: - for line in tqdm(file, desc="Processing audio/transcript"): - wav_path, speaker_name, language, text = line.strip().split("|") - original_wav_path = Path(wav_path) - target_wav_path = ( - wav_root / original_wav_path.parent.name / original_wav_path.name - ) - lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") - if target_wav_path.is_file(): - continue - target_wav_path.parent.mkdir(parents=True, exist_ok=True) - if method == i18n("Copy"): - shutil.copy(original_wav_path, target_wav_path) - else: - shutil.move(original_wav_path, target_wav_path.parent) - convert_to_mono_in_place(target_wav_path) - original_lab_path = original_wav_path.with_suffix(".lab") - target_lab_path = ( - wav_root - / original_wav_path.parent.name - / original_wav_path.with_suffix(".lab").name - ) - if target_lab_path.is_file(): - continue - if method == i18n("Copy"): - shutil.copy(original_lab_path, target_lab_path) - else: - shutil.move(original_lab_path, target_lab_path.parent) - - if method == i18n("Move"): - with list_file_path.open("w", encoding="utf-8") as file: - file.writelines("\n".join(lst)) - - del lst - 
return build_html_ok_message(i18n("Use filelist")) - - -def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): - global dict_items - data_path = Path(data_path) - gr.Warning("Pre-processing begins...") - for item, content in dict_items.items(): - item_path = Path(item) - tar_path = data_path / item_path.name - - if content["type"] == "folder" and item_path.is_dir(): - if content["method"] == i18n("Copy"): - os.makedirs(tar_path, exist_ok=True) - shutil.copytree( - src=str(item_path), dst=str(tar_path), dirs_exist_ok=True - ) - elif not tar_path.is_dir(): - shutil.move(src=str(item_path), dst=str(tar_path)) - - for suf in ["wav", "flac", "mp3"]: - for audio_path in tar_path.glob(f"**/*.{suf}"): - convert_to_mono_in_place(audio_path) - - cur_lang = content["label_lang"] - initial_prompt = content["initial_prompt"] - - transcribe_cmd = [ - PYTHON, - "tools/whisper_asr.py", - "--model-size", - label_model, - "--device", - label_device, - "--audio-dir", - tar_path, - "--save-dir", - tar_path, - "--language", - cur_lang, - ] - - if initial_prompt is not None: - transcribe_cmd += ["--initial-prompt", initial_prompt] - - if cur_lang != "IGNORE": - try: - gr.Warning("Begin To Transcribe") - subprocess.run( - transcribe_cmd, - env=env, - ) - except Exception: - print("Transcription error occurred") - - elif content["type"] == "file" and item_path.is_file(): - list_copy(item_path, content["method"]) - - return build_html_ok_message(i18n("Move files successfully")), new_explorer( - data_path, max_depth=max_depth - ) - - -def generate_folder_name(): - now = datetime.datetime.now() - folder_name = now.strftime("%Y%m%d_%H%M%S") - return folder_name - - -def train_process( - data_path: str, - option: str, - # llama config - llama_ckpt, - llama_base_config, - llama_lr, - llama_maxsteps, - llama_data_num_workers, - llama_data_batch_size, - llama_data_max_length, - llama_precision, - llama_check_interval, - llama_grad_batches, - llama_use_speaker, - llama_use_lora, -): - - backend = "nccl" if sys.platform == "linux" else "gloo" - - new_project = generate_folder_name() - print("New Project Name: ", new_project) - - if option == "VQGAN": - msg = "Skipped VQGAN Training." - gr.Warning(msg) - logger.info(msg) - - if option == "LLAMA": - msg = "LLAMA Training begins..." 
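-# train_process chains three steps for LLAMA fine-tuning: extract VQ semantic
-# tokens from the preprocessed audio (tools/vqgan/extract_vq.py), pack the
-# token/transcript pairs into a protobuf dataset (tools/llama/build_dataset.py),
-# then launch fish_speech/train.py with Hydra overrides built from the UI.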
-def train_process(
-    data_path: str,
-    option: str,
-    # llama config
-    llama_ckpt,
-    llama_base_config,
-    llama_lr,
-    llama_maxsteps,
-    llama_data_num_workers,
-    llama_data_batch_size,
-    llama_data_max_length,
-    llama_precision,
-    llama_check_interval,
-    llama_grad_batches,
-    llama_use_speaker,
-    llama_use_lora,
-):
-    backend = "nccl" if sys.platform == "linux" else "gloo"
-
-    new_project = generate_folder_name()
-    logger.info(f"New project name: {new_project}")
-
-    if option == "VQGAN":
-        msg = "Skipped VQGAN training."
-        gr.Warning(msg)
-        logger.info(msg)
-
-    if option == "LLAMA":
-        msg = "LLAMA training begins..."
-        gr.Warning(msg)
-        logger.info(msg)
-
-        # Step 1: quantize the preprocessed audio into VQ semantic tokens.
-        subprocess.run(
-            [
-                PYTHON,
-                "tools/vqgan/extract_vq.py",
-                str(data_pre_output),
-                "--num-workers",
-                "1",
-                "--batch-size",
-                "16",
-                "--config-name",
-                "firefly_gan_vq",
-                "--checkpoint-path",
-                "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-            ]
-        )
-
-        # Step 2: pack the token/transcript pairs into a protobuf dataset.
-        subprocess.run(
-            [
-                PYTHON,
-                "tools/llama/build_dataset.py",
-                "--input",
-                str(data_pre_output),
-                "--text-extension",
-                ".lab",
-                "--num-workers",
-                "16",
-            ]
-        )
-
-        ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
-        lora_prefix = "lora_" if llama_use_lora else ""
-        llama_name = lora_prefix + "text2semantic_" + new_project
-        latest = next(
-            iter(
-                sorted(
-                    [
-                        str(p.relative_to("results"))
-                        for p in Path("results").glob(lora_prefix + "text2sem*/")
-                    ],
-                    reverse=True,
-                )
-            ),
-            llama_name,
-        )
-        project = (
-            llama_name
-            if llama_ckpt == i18n("new")
-            else (
-                latest
-                if llama_ckpt == i18n("latest")
-                else Path(llama_ckpt).relative_to("results")
-            )
-        )
-        logger.info(project)
-
-        # The checkpoint interval must never exceed the total step budget.
-        if llama_check_interval > llama_maxsteps:
-            llama_check_interval = llama_maxsteps
-
-        # Step 3: launch the trainer with Hydra overrides built from the UI.
-        # llama_base_config (its only choice is "text2semantic_finetune") is
-        # forwarded instead of hardcoding the same value.
-        train_cmd = [
-            PYTHON,
-            "fish_speech/train.py",
-            "--config-name",
-            llama_base_config,
-            f"project={project}",
-            f"trainer.strategy.process_group_backend={backend}",
-            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
-            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
-            f"model.optimizer.lr={llama_lr}",
-            f"trainer.max_steps={llama_maxsteps}",
-            f"data.num_workers={llama_data_num_workers}",
-            f"data.batch_size={llama_data_batch_size}",
-            f"max_length={llama_data_max_length}",
-            f"trainer.precision={llama_precision}",
-            f"trainer.val_check_interval={llama_check_interval}",
-            f"trainer.accumulate_grad_batches={llama_grad_batches}",
-            f"train_dataset.interactive_prob={llama_use_speaker}",
-        ] + (["+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
-        logger.info(train_cmd)
-        subprocess.run(train_cmd)
-
-    return build_html_ok_message(i18n("Training stopped"))
-
-
-def tensorboard_process(
-    if_tensorboard: bool,
-    tensorboard_dir: str,
-    host: str,
-    port: str,
-):
-    global p_tensorboard
-    if if_tensorboard and p_tensorboard is None:
-        url = f"http://{host}:{port}"
-        yield build_html_ok_message(
-            i18n("Tensorboard interface is launched at {}").format(url)
-        )
-        prefix = ["tensorboard"]
-        if Path("fishenv").exists():
-            # The portable Windows bundle ships its own tensorboard executable.
-            prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
-
-        p_tensorboard = subprocess.Popen(
-            prefix
-            + [
-                "--logdir",
-                tensorboard_dir,
-                "--host",
-                host,
-                "--port",
-                port,
-                "--reload_interval",
-                "120",
-            ]
-        )
-    elif not if_tensorboard and p_tensorboard is not None:
-        kill_process(p_tensorboard.pid)
-        p_tensorboard = None
-        yield build_html_error_message(i18n("Tensorboard interface is closed"))
-
-
-def fresh_tb_dir():
-    return gr.Dropdown(
-        choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
-    )
-
-
-def list_decoder_models():
-    paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
-    if not paths:
-        logger.warning("No decoder model found")
-    return paths
-
-
-def list_llama_models():
-    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
-    choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
-    choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
-    choices = sorted(choices, reverse=True)
-    if not choices:
-        logger.warning("No LLaMA model found")
-    return choices
-
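-# The helpers below populate the model dropdowns by globbing the local
-# checkpoints/ and results/ directories; they only read the filesystem.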
-
-def list_lora_llama_models():
-    choices = sorted(
-        [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
-    )
-    if not choices:
-        logger.warning("No LoRA LLaMA model found")
-    return choices
-
-
-def fresh_decoder_model():
-    return gr.Dropdown(choices=list_decoder_models())
-
-
-def fresh_llama_ckpt(llama_use_lora):
-    return gr.Dropdown(
-        choices=[i18n("latest"), i18n("new")]
-        + (
-            [str(p) for p in Path("results").glob("text2sem*/")]
-            if not llama_use_lora
-            else [str(p) for p in Path("results").glob("lora_*/")]
-        )
-    )
-
-
-def fresh_llama_model():
-    return gr.Dropdown(choices=list_llama_models())
-
-
-def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
-    if (
-        lora_weight is None
-        or not Path(lora_weight).exists()
-        or not Path(llama_weight).exists()
-    ):
-        return build_html_error_message(
-            i18n(
-                "Path error, please check the model file exists in the corresponding path"
-            )
-        )
-
-    gr.Warning("Merging begins...")
-    # The LoRA config name is fixed to r_8_alpha_16, the same override used
-    # during fine-tuning above.
-    merge_cmd = [
-        PYTHON,
-        "tools/llama/merge_lora.py",
-        "--lora-config",
-        "r_8_alpha_16",
-        # Forward the selected base model; assumes merge_lora.py exposes a
-        # --base-weight option (previously llama_weight was validated but unused).
-        "--base-weight",
-        llama_weight,
-        "--lora-weight",
-        lora_weight,
-        "--output",
-        llama_lora_output + "_" + generate_folder_name(),
-    ]
-    logger.info(merge_cmd)
-    subprocess.run(merge_cmd)
-    return build_html_ok_message(i18n("Merge successfully"))
-
-
-def llama_quantify(llama_weight, quantify_mode):
-    if llama_weight is None or not Path(llama_weight).exists():
-        return build_html_error_message(
-            i18n(
-                "Path error, please check the model file exists in the corresponding path"
-            )
-        )
-
-    gr.Warning("Quantization begins...")
-
-    now = generate_folder_name()
-    quantify_cmd = [
-        PYTHON,
-        "tools/llama/quantize.py",
-        "--checkpoint-path",
-        llama_weight,
-        "--mode",
-        quantify_mode,
-        "--timestamp",
-        now,
-    ]
-    logger.info(quantify_cmd)
-    subprocess.run(quantify_cmd)
-
-    # These paths must mirror the output naming scheme of tools/llama/quantize.py.
-    if quantify_mode == "int8":
-        quantize_path = str(
-            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
-        )
-    else:
-        quantize_path = str(
-            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
-        )
-    return build_html_ok_message(
-        i18n("Quantify successfully") + f" Path: {quantize_path}"
-    )
-
-
-init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
-init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
-
-with gr.Blocks(
-    head="",
-    js=js,
-    theme=seafoam,
-    analytics_enabled=False,
-    title="Fish Speech",
-) as demo:
-    with gr.Row():
-        with gr.Column():
-            with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
-                with gr.Row():
-                    textbox = gr.Textbox(
-                        label="\U0000270F "
-                        + i18n("Input Audio & Source Path for Transcription"),
-                        info=i18n("Speaker is identified by the folder name"),
-                        interactive=True,
-                    )
-                with gr.Row(equal_height=False):
-                    with gr.Column():
-                        output_radio = gr.Radio(
-                            label="\U0001F4C1 "
-                            + i18n("Select source file processing method"),
-                            choices=[i18n("Copy"), i18n("Move")],
-                            value=i18n("Copy"),
-                            interactive=True,
-                        )
-                    with gr.Column():
-                        error = gr.HTML(label=i18n("Error Message"))
-                        if_label = gr.Checkbox(
-                            label=i18n("Open Labeler WebUI"), scale=0, show_label=True
-                        )
-
-                with gr.Row():
-                    label_device = gr.Dropdown(
-                        label=i18n("Labeling Device"),
-                        info=i18n(
-                            "It is recommended to use CUDA, if you have low configuration, use CPU"
-                        ),
-                        choices=["cpu", "cuda"],
-                        value="cuda",
-                        interactive=True,
-                    )
-                    label_model = gr.Dropdown(
-                        label=i18n("Whisper Model"),
-                        info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
-                        choices=["large-v3", "medium"],
-                        value="large-v3",
-                        interactive=True,
-                    )
-                    label_radio = 
gr.Dropdown( - label=i18n("Optional Label Language"), - info=i18n( - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format" - ), - choices=[ - (i18n("Chinese"), "zh"), - (i18n("English"), "en"), - (i18n("Japanese"), "ja"), - (i18n("Disabled"), "IGNORE"), - (i18n("auto"), "auto"), - ], - value="IGNORE", - interactive=True, - ) - - with gr.Row(): - if_initial_prompt = gr.Checkbox( - value=False, - label=i18n("Enable Initial Prompt"), - min_width=120, - scale=0, - ) - initial_prompt = gr.Textbox( - label=i18n("Initial Prompt"), - info=i18n( - "Initial prompt can provide contextual or vocabulary-specific guidance to the model." - ), - placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.", - interactive=False, - ) - - with gr.Row(): - add_button = gr.Button( - "\U000027A1 " + i18n("Add to Processing Area"), - variant="primary", - ) - remove_button = gr.Button( - "\U000026D4 " + i18n("Remove Selected Data") - ) - - with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")): - with gr.Row(): - model_type_radio = gr.Radio( - label=i18n( - "Select the model to be trained (Depending on the Tab page you are on)" - ), - interactive=False, - choices=["VQGAN", "LLAMA"], - value="VQGAN", - ) - with gr.Row(): - with gr.Tabs(): - with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: - gr.HTML("You don't need to train this model!") - - with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page: - with gr.Row(equal_height=False): - llama_use_lora = gr.Checkbox( - label=i18n("Use LoRA"), - info=i18n( - "Use LoRA can save GPU memory, but may reduce the quality of the model" - ), - value=True, - interactive=True, - ) - llama_ckpt = gr.Dropdown( - label=i18n("Select LLAMA ckpt"), - choices=[i18n("latest"), i18n("new")] - + [ - str(p) - for p in Path("results").glob("text2sem*/") - ] - + [str(p) for p in Path("results").glob("lora*/")], - value=i18n("latest"), - interactive=True, - ) - with gr.Row(equal_height=False): - llama_lr_slider = gr.Slider( - label=i18n("Initial Learning Rate"), - info=i18n( - "lr smaller -> usually train slower but more stable" - ), - interactive=True, - minimum=1e-5, - maximum=1e-4, - step=1e-5, - value=5e-5, - ) - llama_maxsteps_slider = gr.Slider( - label=i18n("Maximum Training Steps"), - info=i18n( - "recommend: max_steps = num_audios // batch_size * (2 to 5)" - ), - interactive=True, - minimum=1, - maximum=10000, - step=1, - value=50, - ) - with gr.Row(equal_height=False): - llama_base_config = gr.Dropdown( - label=i18n("Model Size"), - choices=[ - "text2semantic_finetune", - ], - value="text2semantic_finetune", - ) - llama_data_num_workers_slider = gr.Slider( - label=i18n("Number of Workers"), - minimum=1, - maximum=32, - step=1, - value=4, - ) - with gr.Row(equal_height=False): - llama_data_batch_size_slider = gr.Slider( - label=i18n("Batch Size"), - interactive=True, - minimum=1, - maximum=32, - step=1, - value=2, - ) - llama_data_max_length_slider = gr.Slider( - label=i18n("Maximum Length per Sample"), - interactive=True, - minimum=1024, - maximum=4096, - step=128, - value=2048, - ) - with gr.Row(equal_height=False): - llama_precision_dropdown = gr.Dropdown( - label=i18n("Precision"), - info=i18n( - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU" - ), - interactive=True, - choices=["32", "bf16-true", "16-mixed"], - value="bf16-true", - ) - llama_check_interval_slider = gr.Slider( - label=i18n("Save model every n 
steps"), - info=i18n( - "make sure that it's not greater than max_steps" - ), - interactive=True, - minimum=1, - maximum=1000, - step=1, - value=50, - ) - with gr.Row(equal_height=False): - llama_grad_batches = gr.Slider( - label=i18n("Accumulate Gradient Batches"), - interactive=True, - minimum=1, - maximum=20, - step=1, - value=init_llama_yml["trainer"][ - "accumulate_grad_batches" - ], - ) - llama_use_speaker = gr.Slider( - label=i18n( - "Probability of applying Speaker Condition" - ), - interactive=True, - minimum=0.1, - maximum=1.0, - step=0.05, - value=init_llama_yml["train_dataset"][ - "interactive_prob" - ], - ) - - with gr.Tab(label=i18n("Merge LoRA"), id=4): - with gr.Row(equal_height=False): - llama_weight = gr.Dropdown( - label=i18n("Base LLAMA Model"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=[ - "checkpoints/fish-speech-1.4/model.pth", - ], - value="checkpoints/fish-speech-1.4/model.pth", - allow_custom_value=True, - interactive=True, - ) - with gr.Row(equal_height=False): - lora_weight = gr.Dropdown( - label=i18n("LoRA Model to be merged"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=[ - str(p) - for p in Path("results").glob("lora*/**/*.ckpt") - ], - allow_custom_value=True, - interactive=True, - ) - lora_llama_config = gr.Dropdown( - label=i18n("LLAMA Model Config"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=[ - "text2semantic_finetune", - ], - value="text2semantic_finetune", - allow_custom_value=True, - ) - with gr.Row(equal_height=False): - llama_lora_output = gr.Dropdown( - label=i18n("Output Path"), - info=i18n( - "Type the path or select from the dropdown" - ), - value="checkpoints/merged", - choices=["checkpoints/merged"], - allow_custom_value=True, - interactive=True, - ) - with gr.Row(equal_height=False): - llama_lora_merge_btn = gr.Button( - value=i18n("Merge"), variant="primary" - ) - - with gr.Tab(label=i18n("Model Quantization"), id=5): - with gr.Row(equal_height=False): - llama_weight_to_quantify = gr.Dropdown( - label=i18n("Base LLAMA Model"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=list_llama_models(), - value="checkpoints/fish-speech-1.4", - allow_custom_value=True, - interactive=True, - ) - quantify_mode = gr.Dropdown( - label=i18n("Post-quantification Precision"), - info=i18n( - "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase" - ), - choices=["int8", "int4"], - value="int8", - allow_custom_value=False, - interactive=True, - ) - with gr.Row(equal_height=False): - llama_quantify_btn = gr.Button( - value=i18n("Quantify"), variant="primary" - ) - - with gr.Tab(label="Tensorboard", id=6): - with gr.Row(equal_height=False): - tb_host = gr.Textbox( - label=i18n("Tensorboard Host"), value="127.0.0.1" - ) - tb_port = gr.Textbox( - label=i18n("Tensorboard Port"), value="11451" - ) - with gr.Row(equal_height=False): - tb_dir = gr.Dropdown( - label=i18n("Tensorboard Log Path"), - allow_custom_value=True, - choices=[ - str(p) - for p in Path("results").glob("**/tensorboard/") - ], - ) - with gr.Row(equal_height=False): - if_tb = gr.Checkbox( - label=i18n("Open Tensorboard"), - ) - - with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")): - with gr.Column(): - with gr.Row(): - with gr.Accordion( - label="\U0001F5A5 " - + i18n("Inference Server Configuration"), - open=False, - ): - with gr.Row(): - infer_host_textbox = gr.Textbox( - label=i18n("WebUI Host"), 
value="127.0.0.1" - ) - infer_port_textbox = gr.Textbox( - label=i18n("WebUI Port"), value="7862" - ) - with gr.Row(): - infer_decoder_model = gr.Dropdown( - label=i18n("Decoder Model Path"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=list_decoder_models(), - value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - allow_custom_value=True, - ) - infer_decoder_config = gr.Dropdown( - label=i18n("Decoder Model Config"), - info=i18n("Changing with the Model Path"), - value="firefly_gan_vq", - choices=[ - "firefly_gan_vq", - ], - allow_custom_value=True, - ) - with gr.Row(): - infer_llama_model = gr.Dropdown( - label=i18n("LLAMA Model Path"), - info=i18n( - "Type the path or select from the dropdown" - ), - value="checkpoints/fish-speech-1.4", - choices=list_llama_models(), - allow_custom_value=True, - ) - - with gr.Row(): - infer_compile = gr.Radio( - label=i18n("Compile Model"), - info=i18n( - "Compile the model can significantly reduce the inference time, but will increase cold start time" - ), - choices=["Yes", "No"], - value=( - "Yes" if (sys.platform == "linux") else "No" - ), - interactive=is_module_installed("triton"), - ) - - with gr.Row(): - infer_checkbox = gr.Checkbox( - label=i18n("Open Inference Server") - ) - infer_error = gr.HTML(label=i18n("Inference Server Error")) - - with gr.Column(): - train_error = gr.HTML(label=i18n("Training Error")) - checkbox_group = gr.CheckboxGroup( - label="\U0001F4CA " + i18n("Data Source"), - info=i18n( - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list." - ), - elem_classes=["data_src"], - ) - train_box = gr.Textbox( - label=i18n("Data Preprocessing Path"), - value=str(data_pre_output), - interactive=False, - ) - model_box = gr.Textbox( - label="\U0001F4BE " + i18n("Model Output Path"), - value=str(default_model_output), - interactive=False, - ) - - with gr.Accordion( - i18n( - "View the status of the preprocessing folder (use the slider to control the depth of the tree)" - ), - elem_classes=["scrollable-component"], - elem_id="file_accordion", - ): - tree_slider = gr.Slider( - minimum=0, - maximum=3, - value=0, - step=1, - show_label=False, - container=False, - ) - file_markdown = new_explorer(str(data_pre_output), 0) - with gr.Row(equal_height=False): - admit_btn = gr.Button( - "\U00002705 " + i18n("File Preprocessing"), - variant="primary", - ) - fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) - help_button = gr.Button("\U00002753", scale=0, min_width=80) # question - train_btn = gr.Button(i18n("Start Training"), variant="primary") - - footer = load_data_in_raw("fish_speech/webui/html/footer.html") - footer = footer.format( - versions=versions_html(), - api_docs="https://speech.fish.audio/inference/#http-api", - ) - gr.HTML(footer, elem_id="footer") - vqgan_page.select(lambda: "VQGAN", None, model_type_radio) - llama_page.select(lambda: "LLAMA", None, model_type_radio) - add_button.click( - fn=add_item, - inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt], - outputs=[checkbox_group, error], - ) - remove_button.click( - fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] - ) - checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) - help_button.click( - fn=None, - js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' - 'toolbar=no, menubar=no, scrollbars=no, resizable=no, 
location=no, status=no")}', - ) - if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) - if_initial_prompt.change( - fn=lambda x: gr.Textbox(value="", interactive=x), - inputs=[if_initial_prompt], - outputs=[initial_prompt], - ) - train_btn.click( - fn=train_process, - inputs=[ - train_box, - model_type_radio, - # llama config - llama_ckpt, - llama_base_config, - llama_lr_slider, - llama_maxsteps_slider, - llama_data_num_workers_slider, - llama_data_batch_size_slider, - llama_data_max_length_slider, - llama_precision_dropdown, - llama_check_interval_slider, - llama_grad_batches, - llama_use_speaker, - llama_use_lora, - ], - outputs=[train_error], - ) - if_tb.change( - fn=tensorboard_process, - inputs=[if_tb, tb_dir, tb_host, tb_port], - outputs=[train_error], - ) - tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir]) - infer_decoder_model.change( - fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model] - ) - infer_llama_model.change( - fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model] - ) - llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight]) - admit_btn.click( - fn=check_files, - inputs=[train_box, tree_slider, label_model, label_device], - outputs=[error, file_markdown], - ) - fresh_btn.click( - fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] - ) - llama_use_lora.change( - fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] - ) - llama_ckpt.change( - fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] - ) - lora_weight.change( - fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), - inputs=[], - outputs=[lora_weight], - ) - llama_lora_merge_btn.click( - fn=llama_lora_merge, - inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output], - outputs=[train_error], - ) - llama_quantify_btn.click( - fn=llama_quantify, - inputs=[llama_weight_to_quantify, quantify_mode], - outputs=[train_error], - ) - infer_checkbox.change( - fn=change_infer, - inputs=[ - infer_checkbox, - infer_host_textbox, - infer_port_textbox, - infer_decoder_model, - infer_decoder_config, - infer_llama_model, - infer_compile, - ], - outputs=[infer_error], - ) - -demo.launch(inbrowser=True)
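
The directory-tree renderer removed above depends only on the standard library, so it can be sanity-checked outside Gradio. A minimal standalone sketch (the `data` argument is a placeholder for any local directory):

```python
import os

# Standalone copy of the generate_tree helper from the removed WebUI.
def generate_tree(directory, depth=0, max_depth=None, prefix=""):
    if max_depth is not None and depth > max_depth:
        return ""

    tree_str = ""
    files, directories = [], []
    for item in os.listdir(directory):
        if os.path.isdir(os.path.join(directory, item)):
            directories.append(item)
        else:
            files.append(item)

    # Directories are listed first, so indices < len(directories) recurse.
    entries = directories + files
    for i, entry in enumerate(entries):
        connector = "├── " if i < len(entries) - 1 else "└── "
        tree_str += f"{prefix}{connector}{entry}\n"
        if i < len(directories):
            extension = "│   " if i < len(entries) - 1 else "    "
            tree_str += generate_tree(
                os.path.join(directory, entry),
                depth + 1,
                max_depth,
                prefix=prefix + extension,
            )
    return tree_str


if __name__ == "__main__":
    print(generate_tree("data", max_depth=2))  # "data" is a placeholder directory
```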