diff --git a/fish_speech/__pycache__/conversation.cpython-310.pyc b/fish_speech/__pycache__/conversation.cpython-310.pyc
deleted file mode 100644
index b4dc1336106c5d496e7a1c091e609089eb30d096..0000000000000000000000000000000000000000
Binary files a/fish_speech/__pycache__/conversation.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/__pycache__/scheduler.cpython-310.pyc b/fish_speech/__pycache__/scheduler.cpython-310.pyc
deleted file mode 100644
index 5ce90919af88b3c612722a85c3799f2cc4d58d76..0000000000000000000000000000000000000000
Binary files a/fish_speech/__pycache__/scheduler.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py
deleted file mode 100644
index bbcf3f33656d180ca87cd14a21ede1544e5a61a3..0000000000000000000000000000000000000000
--- a/fish_speech/callbacks/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .grad_norm import GradNormMonitor
-
-__all__ = ["GradNormMonitor"]
diff --git a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 033bf77b0edc8dbe764c3e4386c005136b1ee50c..0000000000000000000000000000000000000000
Binary files a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc
deleted file mode 100644
index 2058510bf280afba7b92dc027d4629aa030b72fc..0000000000000000000000000000000000000000
Binary files a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py
deleted file mode 100644
index dbc95ef2a3723323b2d976001ed1e3c79c00b21a..0000000000000000000000000000000000000000
--- a/fish_speech/callbacks/grad_norm.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from typing import Optional, Union
-
-import lightning.pytorch as pl
-import torch
-from lightning import LightningModule, Trainer
-from lightning.pytorch.callbacks import Callback
-from torch import Tensor, nn
-from torch.utils._foreach_utils import (
- _group_tensors_by_device_and_dtype,
- _has_foreach_support,
-)
-
-
-@torch.no_grad()
-def grad_norm(
- parameters: Union[Tensor, list[Tensor]],
- norm_type: float = 2.0,
-) -> float:
- """
- Returns the norm of the gradients of the given parameters.
-
- Args:
- parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
- single Tensor that will have gradients normalized
- norm_type (float): type of the used p-norm.
-
- Returns:
- Total norm of the parameter gradients (viewed as a single vector).
- """ # noqa: E501
-
- if isinstance(parameters, Tensor):
- parameters = [parameters]
-
- grads = [p.grad for p in parameters if p.grad is not None]
- if len(grads) == 0:
- return None
-
- first_device = grads[0].device
- grouped_grads: dict[
- tuple[torch.device, torch.dtype], list[list[Tensor]]
- ] = _group_tensors_by_device_and_dtype(
- [[g.detach() for g in grads]]
- ) # type: ignore[assignment]
-
- norms = []
- for (device, _), ([grads], _) in grouped_grads.items():
- if _has_foreach_support(grads, device=device):
- norms.extend(torch._foreach_norm(grads, norm_type))
- else:
- norms.extend([torch.norm(g, norm_type) for g in grads])
-
- return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
-
-
-class GradNormMonitor(Callback):
- """
- Callback that computes the gradient norm of the model parameters.
- """
-
- def __init__(
- self,
- norm_type: float = 2.0,
- logging_interval: str = "step",
- sub_module: Optional[Union[str, list[str]]] = None,
- ) -> None:
- """
- Args:
- norm_type (float): type of the used p-norm.
- logging_interval (str): "step" or "epoch".
- """
- super().__init__()
-
- self.norm_type = norm_type
- self.logging_interval = logging_interval
- self.sub_module = sub_module
-
- def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
- """
- Computes the gradient norm of the model parameters and logs it to the logger.
-
- Args:
- trainer (Trainer): The trainer object
- model (LightningModule): The current lightningModule
- """
-
- lightning_model = model
-
- if self.sub_module is None:
- return self.log_sub_module_grad_norm(lightning_model, model, "")
-
- sub_modules = self.sub_module
- if isinstance(sub_modules, str):
- sub_modules = [sub_modules]
-
- for sub_module in sub_modules:
- self.log_sub_module_grad_norm(
- lightning_model, getattr(model, sub_module), f"/{sub_module}"
- )
-
- def log_sub_module_grad_norm(
- self, lightning_model: LightningModule, model: nn.Module, path: str
- ) -> None:
- grad_norm_val = grad_norm(model.parameters(), self.norm_type)
- if grad_norm_val is None:
- return
-
- on_step = self.logging_interval == "step"
- lightning_model.log(
- f"train{path}/grad_norm",
- grad_norm_val,
- on_step=on_step,
- on_epoch=not on_step,
- )
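For context, the removed GradNormMonitor plugged into training as an ordinary Lightning callback; a minimal usage sketch against the pre-removal import path (the LightningModule being fitted is left as a placeholder):

    import lightning.pytorch as pl

    from fish_speech.callbacks import GradNormMonitor

    # Log the 2-norm of all parameter gradients after every backward pass.
    monitor = GradNormMonitor(norm_type=2.0, logging_interval="step")

    # Attach it like any other Lightning callback; base.yaml below wires up
    # the same object through Hydra's `_target_` mechanism.
    trainer = pl.Trainer(callbacks=[monitor], max_steps=1_000)
    # trainer.fit(model)  # `model` is a placeholder LightningModule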
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
deleted file mode 100644
index 99e6dab54d3f57bce4f6d29a9129a19a523cad75..0000000000000000000000000000000000000000
--- a/fish_speech/configs/base.yaml
+++ /dev/null
@@ -1,87 +0,0 @@
-# Base configuration for training a model
-paths:
- run_dir: results/${project}
- ckpt_dir: ${paths.run_dir}/checkpoints
-
-hydra:
- run:
- dir: ${paths.run_dir}
-
-# Lightning Trainer
-trainer:
- _target_: lightning.pytorch.trainer.Trainer
-
- default_root_dir: ${paths.run_dir}
- accelerator: gpu
- num_nodes: 1
- devices: auto
- strategy:
- _target_: lightning.pytorch.strategies.DDPStrategy
- process_group_backend: nccl # This should be override when training on windows
-
- precision: bf16-mixed
-
- # disable validation by epoch end
- check_val_every_n_epoch: null
- val_check_interval: 5000
- max_steps: 100_000
-
- # Use torch.backends.cudnn.benchmark to speed up training
- benchmark: true
-
-# Callbacks
-callbacks:
- model_checkpoint:
- _target_: lightning.pytorch.callbacks.ModelCheckpoint
- dirpath: ${paths.ckpt_dir}
- filename: "step_{step:09d}"
- save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
- save_top_k: 5 # save 5 latest checkpoints
- monitor: step # use step to monitor checkpoints
- mode: max # save the latest checkpoint with the highest global_step
- every_n_epochs: null # don't save checkpoints by epoch end
- every_n_train_steps: 5000 # save checkpoints every 5000 steps
- auto_insert_metric_name: false
-
- model_summary:
- _target_: lightning.pytorch.callbacks.ModelSummary
- max_depth: 2 # the maximum depth of layer nesting that the summary will include
-
- learning_rate_monitor:
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
- logging_interval: step
- log_momentum: false
-
- grad_norm_monitor:
- _target_: fish_speech.callbacks.GradNormMonitor
- norm_type: 2
- logging_interval: step
-
-# Logger
-logger:
- tensorboard:
- _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
- save_dir: "${paths.run_dir}/tensorboard/"
- name: null
- log_graph: false
- default_hp_metric: true
- prefix: ""
-
- # wandb:
- # _target_: lightning.pytorch.loggers.wandb.WandbLogger
- # # name: "" # name of the run (normally generated by wandb)
- # save_dir: "${paths.run_dir}"
- # offline: False
- # id: null # pass correct id to resume experiment!
- # anonymous: null # enable anonymous logging
- # project: "fish-speech"
- # log_model: False # upload lightning ckpts
- # prefix: "" # a string to put at the beginning of metric keys
- # # entity: "" # set to name of your wandb team
- # group: ""
- # tags: ["vq", "hq", "finetune"]
- # job_type: ""
-
-# Loop
-train: true
-test: false
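Every `_target_` node in this config is resolved by Hydra at run time; a minimal sketch of that resolution for the grad_norm_monitor entry, assuming hydra-core and omegaconf are available (both are standard in a Lightning/Hydra setup like this one):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # A standalone copy of the callbacks.grad_norm_monitor node above.
    node = OmegaConf.create(
        {
            "_target_": "fish_speech.callbacks.GradNormMonitor",
            "norm_type": 2,
            "logging_interval": "step",
        }
    )

    # instantiate() imports the target class and calls it with the remaining
    # keys, yielding the same object the trainer receives as a callback.
    grad_norm_monitor = instantiate(node)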
diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml
deleted file mode 100644
index 10aa8d4a522f0859ed8f541f5d48672d84b39c8f..0000000000000000000000000000000000000000
--- a/fish_speech/configs/firefly_gan_vq.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
-spec_transform:
- _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
- sample_rate: 44100
- n_mels: 160
- n_fft: 2048
- hop_length: 512
- win_length: 2048
-backbone:
- _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
- input_channels: 160
- depths: [3, 3, 9, 3]
- dims: [128, 256, 384, 512]
- drop_path_rate: 0.2
- kernel_size: 7
-head:
- _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
- hop_length: 512
- upsample_rates: [8, 8, 2, 2, 2] # aka. strides
- upsample_kernel_sizes: [16, 16, 4, 4, 4]
- resblock_kernel_sizes: [3, 7, 11]
- resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
- num_mels: 512
- upsample_initial_channel: 512
- pre_conv_kernel_size: 13
- post_conv_kernel_size: 13
-quantizer:
- _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
- input_dim: 512
- n_groups: 8
- n_codebooks: 1
- levels: [8, 5, 5, 5]
- downsample_factor: [2, 2]
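Two arithmetic invariants are baked into this config: the HiFiGAN head's upsample factors must multiply out to exactly one spectrogram hop, and the FSQ codebook size per group is the product of its levels. A small check using only the numbers above:

    import math

    upsample_rates = [8, 8, 2, 2, 2]
    assert math.prod(upsample_rates) == 512   # matches hop_length of head and spec_transform

    levels = [8, 5, 5, 5]
    print(math.prod(levels))                  # 1000 codes per group, with n_groups = 8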
diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml
deleted file mode 100644
index aecc4d9766a18fe31c55941e01b1f590c95e77c9..0000000000000000000000000000000000000000
--- a/fish_speech/configs/lora/r_8_alpha_16.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-_target_: fish_speech.models.text2semantic.lora.LoraConfig
-r: 8
-lora_alpha: 16
-lora_dropout: 0.01
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
deleted file mode 100644
index f4c1993023099e122fc9e004bda55ec075ed5e1b..0000000000000000000000000000000000000000
--- a/fish_speech/configs/text2semantic_finetune.yaml
+++ /dev/null
@@ -1,83 +0,0 @@
-defaults:
- - base
- - _self_
-
-project: text2semantic_finetune_dual_ar
-max_length: 4096
-pretrained_ckpt_path: checkpoints/fish-speech-1.4
-
-# Lightning Trainer
-trainer:
- accumulate_grad_batches: 1
- gradient_clip_val: 1.0
- gradient_clip_algorithm: "norm"
- max_steps: 1000
- precision: bf16-true
- limit_val_batches: 10
- val_check_interval: 100
-
-# Dataset Configuration
-tokenizer:
- _target_: transformers.AutoTokenizer.from_pretrained
- pretrained_model_name_or_path: ${pretrained_ckpt_path}
-
-# Dataset Configuration
-train_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
- proto_files:
- - data/protos
- tokenizer: ${tokenizer}
- causal: true
- max_length: ${max_length}
- use_speaker: false
- interactive_prob: 0.7
-
-val_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
- proto_files:
- - data/protos
- tokenizer: ${tokenizer}
- causal: true
- max_length: ${max_length}
- use_speaker: false
- interactive_prob: 0.7
-
-data:
- _target_: fish_speech.datasets.semantic.SemanticDataModule
- train_dataset: ${train_dataset}
- val_dataset: ${val_dataset}
- num_workers: 4
- batch_size: 8
- tokenizer: ${tokenizer}
- max_length: ${max_length}
-
-# Model Configuration
-model:
- _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
- model:
- _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
- path: ${pretrained_ckpt_path}
- load_weights: true
- max_length: ${max_length}
- lora_config: null
-
- optimizer:
- _target_: torch.optim.AdamW
- _partial_: true
- lr: 1e-4
- weight_decay: 0
- betas: [0.9, 0.95]
- eps: 1e-5
-
- lr_scheduler:
- _target_: torch.optim.lr_scheduler.LambdaLR
- _partial_: true
- lr_lambda:
- _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
- _partial_: true
- num_warmup_steps: 10
-
-# Callbacks
-callbacks:
- model_checkpoint:
- every_n_train_steps: ${trainer.val_check_interval}
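The lr_lambda above points at fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda; a rough standalone equivalent of that schedule (linear warmup, then constant), assuming the usual warmup convention rather than the exact removed implementation:

    import torch
    from torch.optim.lr_scheduler import LambdaLR


    def constant_schedule_with_warmup(step: int, num_warmup_steps: int = 10) -> float:
        # Scale the base lr linearly during warmup, then hold it at 1.0.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return 1.0


    # Mirrors the optimizer / lr_scheduler pair configured above.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=0, betas=(0.9, 0.95), eps=1e-5)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda s: constant_schedule_with_warmup(s, 10))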
diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py
deleted file mode 100644
index c9ca0ef9181754eda7e6b49e01abeafbe07fb00f..0000000000000000000000000000000000000000
--- a/fish_speech/conversation.py
+++ /dev/null
@@ -1,2 +0,0 @@
-SEMANTIC_TOKEN = "<|semantic|>"
-CODEBOOK_PAD_TOKEN_ID = 0
diff --git a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc
deleted file mode 100644
index ca763c1c12b41234a939ddbe343575b67a79bb92..0000000000000000000000000000000000000000
Binary files a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py
deleted file mode 100644
index 4aa596b95a572ee15c5570cbdb792c9a78e62dfa..0000000000000000000000000000000000000000
--- a/fish_speech/datasets/concat_repeat.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import bisect
-import random
-from typing import Iterable
-
-from torch.utils.data import Dataset, IterableDataset
-
-
-class ConcatRepeatDataset(Dataset):
- datasets: list[Dataset]
- cumulative_sizes: list[int]
- repeats: list[int]
-
- @staticmethod
- def cumsum(sequence, repeats):
- r, s = [], 0
- for dataset, repeat in zip(sequence, repeats):
- l = len(dataset) * repeat
- r.append(l + s)
- s += l
- return r
-
- def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
- super().__init__()
-
- self.datasets = list(datasets)
- self.repeats = repeats
-
- assert len(self.datasets) > 0, "datasets should not be an empty iterable"
- assert len(self.datasets) == len(
- repeats
- ), "datasets and repeats should have the same length"
-
- for d in self.datasets:
- assert not isinstance(
- d, IterableDataset
- ), "ConcatRepeatDataset does not support IterableDataset"
-
- self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
-
- def __len__(self):
- return self.cumulative_sizes[-1]
-
- def __getitem__(self, idx):
- dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
-
- if dataset_idx == 0:
- sample_idx = idx
- else:
- sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
-
- dataset = self.datasets[dataset_idx]
-
- return dataset[sample_idx % len(dataset)]
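The removed ConcatRepeatDataset exposes each member dataset `repeats[i]` times and uses bisect on the cumulative sizes to route a flat index back to a sample; a minimal sketch with toy datasets (RangeDataset is made up for the example), using the pre-removal import path:

    from torch.utils.data import Dataset

    from fish_speech.datasets.concat_repeat import ConcatRepeatDataset


    class RangeDataset(Dataset):
        # Toy dataset: item i is base + i.
        def __init__(self, base: int, length: int):
            self.base, self.length = base, length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            return self.base + idx


    ds = ConcatRepeatDataset([RangeDataset(0, 3), RangeDataset(100, 5)], repeats=[2, 1])
    assert len(ds) == 3 * 2 + 5 * 1                                # 11 virtual samples
    assert [ds[i] for i in range(6)] == [0, 1, 2, 0, 1, 2]         # first dataset, repeated
    assert [ds[i] for i in range(6, 11)] == [100, 101, 102, 103, 104]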
diff --git a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc
deleted file mode 100644
index 2b7bb23609e78b12b6e608581f1e8d764bd9db3a..0000000000000000000000000000000000000000
Binary files a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc
deleted file mode 100644
index 6e22635c991e5d669704c3cf95dd528a39b8b822..0000000000000000000000000000000000000000
Binary files a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto
deleted file mode 100644
index 5eb26d94aa3be1e21066f2bf38c90d54e85a8379..0000000000000000000000000000000000000000
--- a/fish_speech/datasets/protos/text-data.proto
+++ /dev/null
@@ -1,24 +0,0 @@
-syntax = "proto3";
-
-package text_data;
-
-message Semantics {
- repeated uint32 values = 1;
-}
-
-message Sentence {
- repeated string texts = 1;
- repeated Semantics semantics = 3;
-}
-
-message TextData {
- string source = 1;
- string name = 2;
- repeated Sentence sentences = 4;
-}
-
-message SampledData {
- string source = 1;
- string name = 2;
- repeated Sentence samples = 3;
-}
diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py
deleted file mode 100644
index bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e..0000000000000000000000000000000000000000
--- a/fish_speech/datasets/protos/text_data_pb2.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: text-data.proto
-# Protobuf Python Version: 4.25.1
-"""Generated protocol buffer code."""
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
-)
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
-if _descriptor._USE_C_DESCRIPTORS == False:
- DESCRIPTOR._options = None
- _globals["_SEMANTICS"]._serialized_start = 30
- _globals["_SEMANTICS"]._serialized_end = 57
- _globals["_SENTENCE"]._serialized_start = 59
- _globals["_SENTENCE"]._serialized_end = 125
- _globals["_TEXTDATA"]._serialized_start = 127
- _globals["_TEXTDATA"]._serialized_end = 207
- _globals["_SAMPLEDDATA"]._serialized_start = 209
- _globals["_SAMPLEDDATA"]._serialized_end = 290
-# @@protoc_insertion_point(module_scope)
diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py
deleted file mode 100644
index ec3c25bcd764e8245de47dcdf9686d6adfb5a107..0000000000000000000000000000000000000000
--- a/fish_speech/datasets/protos/text_data_stream.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import struct
-
-from .text_data_pb2 import TextData
-
-
-def read_pb_stream(f):
- while True:
- buf = f.read(4)
- if len(buf) == 0:
- break
- size = struct.unpack("I", buf)[0]
- buf = f.read(size)
- text_data = TextData()
- text_data.ParseFromString(buf)
- yield text_data
-
-
-def write_pb_stream(f, text_data):
- buf = text_data.SerializeToString()
- f.write(struct.pack("I", len(buf)))
- f.write(buf)
-
-
-def pack_pb_stream(text_data):
- buf = text_data.SerializeToString()
- return struct.pack("I", len(buf)) + buf
-
-
-def split_pb_stream(f):
- while True:
- head = f.read(4)
- if len(head) == 0:
- break
- size = struct.unpack("I", head)[0]
- buf = f.read(size)
- yield head + buf
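The schema in text-data.proto, the generated text_data_pb2 module, and these stream helpers fit together as a simple length-prefixed record format; a round-trip sketch over an in-memory buffer:

    import io

    from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
    from fish_speech.datasets.protos.text_data_stream import read_pb_stream, write_pb_stream

    # One record following the text-data.proto schema above.
    record = TextData(
        source="demo",
        name="speaker_0",
        sentences=[Sentence(texts=["hello world"], semantics=[Semantics(values=[1, 2, 3])])],
    )

    # write_pb_stream prefixes each serialized message with its 4-byte length,
    # and read_pb_stream yields the messages back until EOF.
    buf = io.BytesIO()
    write_pb_stream(buf, record)
    buf.seek(0)
    decoded = list(read_pb_stream(buf))
    assert decoded[0].name == "speaker_0"
    assert list(decoded[0].sentences[0].semantics[0].values) == [1, 2, 3]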
diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py
deleted file mode 100644
index 3c64e01077ae253bdc4e4d9cd948f8fb50df7418..0000000000000000000000000000000000000000
--- a/fish_speech/datasets/semantic.py
+++ /dev/null
@@ -1,496 +0,0 @@
-import random
-from dataclasses import dataclass
-from itertools import chain
-from pathlib import Path
-from random import Random
-from typing import Optional, Union
-
-import numpy as np
-import pyarrow.parquet as pq
-import torch
-import torch.nn.functional as F
-from datasets.download.streaming_download_manager import xopen
-from huggingface_hub import HfApi
-from lightning import LightningDataModule
-from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
-from transformers import AutoTokenizer
-
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
-from fish_speech.datasets.protos.text_data_pb2 import SampledData
-from fish_speech.datasets.protos.text_data_stream import read_pb_stream
-from fish_speech.text.clean import clean_text
-from fish_speech.utils import RankedLogger
-from fish_speech.utils.braceexpand import braceexpand
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def split_by_rank_worker(files):
- # We need to know the total number of devices
- # to split the data properly
-
- total_devices = 1
- if is_initialized():
- total_devices = get_world_size()
-
- worker_info = get_worker_info()
- if worker_info is not None:
- total_devices *= worker_info.num_workers
-
- if len(files) < total_devices:
- # Repeat the files N times to match the number of devices
- files = files * (total_devices // len(files) + 1)
-
- # DDP
- if is_initialized():
- files = files[get_rank() :: get_world_size()]
-
- # Split by worker
- if worker_info is not None:
- files = files[worker_info.id :: worker_info.num_workers]
-
- return files
-
-
-class AutoTextSemanticInstructionDataset(IterableDataset):
- """
- Auto Augment Dataset by Speaker
-
- 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
- 2. Automatically normalize the text
-
- For interactive mode, we use the following format (multiple sequences):
- [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
-
- For non-interactive mode, we use the following format (one long sequence):
- [INST] text [/INST] ...
- """
-
- def __init__(
- self,
- proto_files: list[str],
- seed: int = 42,
- interactive_prob: float = 0.5,
- max_length: int = 1024,
- tokenizer: AutoTokenizer = None,
- use_speaker: bool | float = True,
- causal: bool = True,
- num_codebooks: Optional[int] = None,
- skip_text_prob: float = 0.0,
- ):
- """
- Args:
- proto_files: proto buf files if using local data
- seed: random seed
- interactive_prob: probability to use interactive mode
- max_length: max length of the text
- tokenizer: tokenizer
- use_speaker: include speaker information in the prompt
- causal: use causal sampling when using local data, disable will lead to random sampling
- num_codebooks: number of codebooks, if None, it will be automatically detected
- skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
- """
-
- super().__init__()
-
- assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
-
- self.seed = seed
- self.max_length = max_length
- self.tokenizer = tokenizer
- self.interactive_prob = interactive_prob
- self.use_speaker = use_speaker
- self.proto_files = proto_files
- self.causal = causal
- self.num_codebooks = num_codebooks
- self.skip_text_prob = skip_text_prob
-
- self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
- self.groups = None
-
- def init_mock_data_server(self):
- if self.groups is not None:
- return
-
- # Expand the proto files
- expanded_proto_files = []
- for filename in self.proto_files:
- for i in braceexpand(filename):
- i = Path(i)
- if i.is_file():
- expanded_proto_files.append(i)
- elif i.is_dir():
- expanded_proto_files.extend(i.rglob("*.proto"))
- expanded_proto_files.extend(i.rglob("*.protos"))
- else:
- raise ValueError(f"{i} is not a file or directory")
-
- expanded_proto_files = sorted(expanded_proto_files)
- Random(self.seed).shuffle(expanded_proto_files)
-
- self.groups = []
- shard_proto_files = split_by_rank_worker(expanded_proto_files)
- log.info(
- f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
- )
-
- count = 0
- for filename in shard_proto_files:
- with open(filename, "rb") as f:
- for text_data in read_pb_stream(f):
- self.groups.append(text_data)
- count += 1
-
- log.info(f"Read total {count} groups of data")
-
- # Shuffle the lines
- Random(self.seed).shuffle(self.groups)
- self.group_weights = [len(i.sentences) for i in self.groups]
-
- def __iter__(self):
- while True:
- yield self.augment()
-
- def tokenize_sentence(self, sentence: str):
- sentence = clean_text(sentence)
- tokens = self.tokenizer.encode(
- f"{sentence}",
- max_length=10**6,
- add_special_tokens=False,
- truncation=False,
- )
- return sentence, len(tokens)
-
- def sample_data(self):
- if self.groups is None:
- self.init_mock_data_server()
-
- # Shuffle unique lines, estimate that each sample is at least 20 tokens
- num_samples = self.max_length // 20
-
- # choice group based on their number of samples
- group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
-
- if self.causal:
- # Sample in order
- if num_samples >= len(group.sentences):
- samples = group.sentences
- else:
- begin = random.randint(0, len(group.sentences) - num_samples)
- samples = group.sentences[begin : begin + num_samples]
- else:
- samples = random.choices(
- group.sentences, k=min(num_samples, len(group.sentences))
- )
-
- return SampledData(
- source=group.source,
- name=group.name,
- samples=samples,
- )
-
- def augment(self):
- final_text, final_semantic = [], []
- response = self.sample_data()
- if len(response.samples) == 0:
- # Invalid group
- return None
-
- samples = list(response.samples)
- idx = 0
- use_interactive = random.random() < self.interactive_prob
-
- if use_interactive is False:
- # Random sample based on speaker using a truncated normal distribution
- a = torch.tensor([0], dtype=torch.float32)
- torch.nn.init.trunc_normal_(
- a,
- mean=self.max_length // 2,
- std=self.max_length // 4,
- a=10,
- b=self.max_length,
- )
- remaining_tokens = a.long().item() - 4
- else:
- remaining_tokens = self.max_length
-
- # Use speaker
- if isinstance(self.use_speaker, float):
- use_speaker = random.random() < self.use_speaker
- else:
- use_speaker = self.use_speaker
-
- all_tokens, all_labels = [], []
- while remaining_tokens > 0 and len(samples) > 0:
- sentence = samples.pop(0)
-
- text = random.choice(sentence.texts)
- text, length = self.tokenize_sentence(text)
- remaining_tokens -= length + len(sentence.semantics[0].values)
-
- if use_interactive is False:
- final_text.append(text)
- final_semantic.append(sentence.semantics)
- else:
- # For interactive mode, we only apply speaker for the first sentence
- # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
- tokens, labels = self.pack_sentences(
- sentences=[text],
- semantics=[sentence.semantics],
- speaker=response.name if use_speaker else None,
- skip_text=random.random() < self.skip_text_prob,
- )
-
- all_tokens.append(tokens)
- all_labels.append(labels)
-
- idx += 1
-
- if use_interactive is False:
- tokens, labels = self.pack_sentences(
- final_text,
- semantics=final_semantic,
- speaker=response.name if use_speaker else None,
- )
- all_tokens.append(tokens)
- all_labels.append(labels)
-
- tokens = torch.cat(all_tokens, dim=1)
- labels = torch.cat(all_labels, dim=1)
-
- # Verify that the length is correct
- assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
- data = {"tokens": tokens, "labels": labels}
-
- return data
-
- def pack_sentences(
- self,
- sentences: list[str],
- semantics: list,
- speaker: Optional[str] = None,
- skip_text: bool = False,
- ):
- if speaker is None:
- speaker = "assistant"
-
- cated_sentences = " ".join(sentences)
- if skip_text:
- cated_sentences = "<|skip_text|>"
-
- final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
- final_text = final_text + f"<|im_start|>{speaker}\n"
-
- encoded = self.tokenizer.encode(
- final_text,
- add_special_tokens=False,
- truncation=False,
- max_length=10**6,
- )
- semantic_length = sum([len(i[0].values) for i in semantics])
- prompt_length = len(encoded)
- num_codebooks = (
- len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
- )
-
- # Pack the tokens and semantics (add and to semantic tokens)
- tokens = (
- encoded
- + [self.semantic_token_id] * semantic_length
- + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
- )
-
- # Codebook bos/padding: 0, eos: 1
- codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
- for segment in semantics:
- for book_idx, book in zip(range(num_codebooks), segment):
- for j in book.values:
- codes[book_idx].append(int(j) + 1)
-
- for book in codes:
- book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
-
- tokens = [tokens] + codes
-
- tokens = torch.tensor(tokens, dtype=torch.long)
- labels = tokens.clone()
-
- if skip_text:
- # If text is not provided, the sentence is used for condition only, all labels are -100
- torch.fill_(labels, -100)
- return tokens, labels
-
- # Mask out the tokens for semantic, predict semantic tokens only
- # Since we don't mask out the input tokens, the language modeling still works
- labels[1:, :prompt_length] = -100
-
- tokens = tokens[:, :-1]
- labels = labels[:, 1:]
-
- # Verify the padding is correct, and the last token is eos
- assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
- assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
-
- return tokens, labels
-
-
-@dataclass
-class TextDataCollator:
- tokenizer: AutoTokenizer
- max_length: int = 1024
-
- def __call__(self, examples):
- if "negative_tokens" in examples:
- positive_examples = []
- negative_examples = []
-
- for i in examples:
- positive_examples.append(
- {
- "tokens": i["tokens"],
- "labels": i["labels"],
- }
- )
- negative_examples.append(
- {
- "tokens": i["negative_tokens"],
- "labels": i["negative_labels"],
- }
- )
-
- examples = positive_examples + negative_examples
-
- return self.batchify(examples)
-
- def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
- tokens, attention_masks, labels = [], [], []
-
- # Calculate the max length
- max_tokens_length = 0
- for example in examples:
- max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
- max_tokens_length = min(max_tokens_length, self.max_length)
-
- for example in examples:
- _tokens = example[tokens_key][:, :max_tokens_length]
- _labels = example[labels_key][:, :max_tokens_length]
- _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
- tokens_length = _tokens.size(1)
- _attention_mask[:tokens_length] = False
-
- assert tokens_length == _labels.size(
- 1
- ), f"{tokens_length} != {_labels.size(1)}"
-
- if tokens_length < max_tokens_length:
- _tokens = F.pad(
- _tokens,
- (0, max_tokens_length - tokens_length),
- value=self.tokenizer.eos_token_id,
- )
- _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
- _labels = F.pad(
- _labels, (0, max_tokens_length - _labels.size(1)), value=-100
- )
-
- tokens.append(_tokens)
- attention_masks.append(_attention_mask)
- labels.append(_labels)
-
- tokens = torch.stack(tokens, dim=0)
- attention_masks = torch.stack(attention_masks, dim=0)
- labels = torch.stack(labels, dim=0)
-
- return {
- "inputs": tokens,
- "attention_masks": attention_masks,
- "labels": labels,
- }
-
-
-class InterleaveDataset(IterableDataset):
- def __init__(
- self,
- datasets: list[IterableDataset],
- probabilities: list[float],
- seed: int = 42,
- ):
- super().__init__()
-
- self.datasets = datasets
- self.probabilities = probabilities
- self.seed = seed
-
- def __iter__(self):
- rng = np.random.default_rng(self.seed)
- dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
- while True:
- # Random choice one
- dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
- dataset_iterator = dataset_iterators[dataset_idx]
-
- try:
- yield next(dataset_iterator)
- except StopIteration:
- # Exhausted, create a new iterator
- dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
- yield next(dataset_iterators[dataset_idx])
-
-
-class SemanticDataModule(LightningDataModule):
- def __init__(
- self,
- train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
- val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
- batch_size: int = 32,
- tokenizer: AutoTokenizer = None,
- max_length: int = 1024,
- num_workers: int = 4,
- ):
- super().__init__()
-
- self.train_dataset = train_dataset
- self.val_dataset = val_dataset
- self.batch_size = batch_size
- self.tokenizer = tokenizer
- self.max_length = max_length
- self.num_workers = num_workers
-
- def train_dataloader(self):
- return DataLoader(
- self.train_dataset,
- batch_size=self.batch_size,
- collate_fn=TextDataCollator(self.tokenizer, self.max_length),
- num_workers=self.num_workers,
- persistent_workers=True,
- )
-
- def val_dataloader(self):
- return DataLoader(
- self.val_dataset,
- batch_size=self.batch_size,
- collate_fn=TextDataCollator(self.tokenizer, self.max_length),
- num_workers=self.num_workers,
- persistent_workers=True,
- )
-
-
-if __name__ == "__main__":
- from tqdm import tqdm
-
- ds = AutoTextSemanticInstructionDataset(
- ["data/protos"],
- tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
- use_speaker=False,
- interactive_prob=1.0,
- skip_text_prob=0.5,
- )
-
- for i in ds:
- print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
- # i["labels"][0][i["labels"][0] == -100] = 0
- # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
- break
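The removed TextDataCollator batches variable-length (1 text row + N codebook rows) matrices by right-padding to the longest sample in the batch, padding the text row with eos and the codebook rows with CODEBOOK_PAD_TOKEN_ID, and marking padded positions with True in the attention mask. A minimal sketch with a stub tokenizer (the collator only reads eos_token_id):

    import torch

    from fish_speech.datasets.semantic import TextDataCollator


    class StubTokenizer:
        eos_token_id = 2  # the only attribute the collator touches


    # Two fake examples: 1 text row + 2 codebook rows, lengths 5 and 8.
    examples = [
        {"tokens": torch.zeros((3, 5), dtype=torch.long), "labels": torch.zeros((3, 5), dtype=torch.long)},
        {"tokens": torch.zeros((3, 8), dtype=torch.long), "labels": torch.zeros((3, 8), dtype=torch.long)},
    ]

    batch = TextDataCollator(StubTokenizer(), max_length=16)(examples)
    assert batch["inputs"].shape == (2, 3, 8)          # padded to the longest sample
    assert batch["attention_masks"].shape == (2, 8)    # False = real token, True = padding
    assert batch["labels"].shape == (2, 3, 8)          # label padding value is -100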
diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py
deleted file mode 100644
index a45583d22efb0feb9dc1e823bae1ef74534b299e..0000000000000000000000000000000000000000
--- a/fish_speech/datasets/vqgan.py
+++ /dev/null
@@ -1,147 +0,0 @@
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
-
-import librosa
-import numpy as np
-import torch
-from lightning import LightningDataModule
-from torch.utils.data import DataLoader, Dataset
-
-from fish_speech.utils import RankedLogger
-
-logger = RankedLogger(__name__, rank_zero_only=False)
-
-
-class VQGANDataset(Dataset):
- def __init__(
- self,
- filelist: str,
- sample_rate: int = 32000,
- hop_length: int = 640,
- slice_frames: Optional[int] = None,
- ):
- super().__init__()
-
- filelist = Path(filelist)
- root = filelist.parent
-
- self.files = [
- root / line.strip()
- for line in filelist.read_text(encoding="utf-8").splitlines()
- if line.strip()
- ]
- self.sample_rate = sample_rate
- self.hop_length = hop_length
- self.slice_frames = slice_frames
-
- def __len__(self):
- return len(self.files)
-
- def get_item(self, idx):
- file = self.files[idx]
-
- audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
-
- # Slice audio and features
- if (
- self.slice_frames is not None
- and audio.shape[0] > self.slice_frames * self.hop_length
- ):
- start = np.random.randint(
- 0, audio.shape[0] - self.slice_frames * self.hop_length
- )
- audio = audio[start : start + self.slice_frames * self.hop_length]
-
- if len(audio) == 0:
- return None
-
- max_value = np.abs(audio).max()
- if max_value > 1.0:
- audio = audio / max_value
-
- return {
- "audio": torch.from_numpy(audio),
- }
-
- def __getitem__(self, idx):
- try:
- return self.get_item(idx)
- except Exception as e:
- import traceback
-
- traceback.print_exc()
- logger.error(f"Error loading {self.files[idx]}: {e}")
- return None
-
-
-@dataclass
-class VQGANCollator:
- def __call__(self, batch):
- batch = [x for x in batch if x is not None]
-
- audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
- audio_maxlen = audio_lengths.max()
-
- # Rounds up to nearest multiple of 2 (audio_lengths)
- audios = []
- for x in batch:
- audios.append(
- torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
- )
-
- return {
- "audios": torch.stack(audios),
- "audio_lengths": audio_lengths,
- }
-
-
-class VQGANDataModule(LightningDataModule):
- def __init__(
- self,
- train_dataset: VQGANDataset,
- val_dataset: VQGANDataset,
- batch_size: int = 32,
- num_workers: int = 4,
- val_batch_size: Optional[int] = None,
- ):
- super().__init__()
-
- self.train_dataset = train_dataset
- self.val_dataset = val_dataset
- self.batch_size = batch_size
- self.val_batch_size = val_batch_size or batch_size
- self.num_workers = num_workers
-
- def train_dataloader(self):
- return DataLoader(
- self.train_dataset,
- batch_size=self.batch_size,
- collate_fn=VQGANCollator(),
- num_workers=self.num_workers,
- shuffle=True,
- persistent_workers=True,
- )
-
- def val_dataloader(self):
- return DataLoader(
- self.val_dataset,
- batch_size=self.val_batch_size,
- collate_fn=VQGANCollator(),
- num_workers=self.num_workers,
- persistent_workers=True,
- )
-
-
-if __name__ == "__main__":
- dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
- dataloader = DataLoader(
- dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
- )
-
- for batch in dataloader:
- print(batch["audios"].shape)
- print(batch["features"].shape)
- print(batch["audio_lengths"])
- print(batch["feature_lengths"])
- break
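Note that VQGANCollator only ever returns "audios" and "audio_lengths", so the "features"/"feature_lengths" prints in the __main__ block above would fail; a sketch using the keys the collator actually emits (the filelist path is a placeholder):

    from torch.utils.data import DataLoader

    from fish_speech.datasets.vqgan import VQGANCollator, VQGANDataset

    # "data/vq_train_filelist.txt" is a placeholder: one audio path per line,
    # resolved relative to the filelist's own directory.
    dataset = VQGANDataset("data/vq_train_filelist.txt", sample_rate=32000, hop_length=640)
    loader = DataLoader(dataset, batch_size=4, collate_fn=VQGANCollator())

    for batch in loader:
        print(batch["audios"].shape)     # (batch, max_samples), zero-padded
        print(batch["audio_lengths"])    # original lengths before padding
        break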
diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md
deleted file mode 100644
index 700902b09db20911ef1ad678cbdce5644b84aea2..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/README.md
+++ /dev/null
@@ -1,27 +0,0 @@
-## i18n Folder Attribution
-
-The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
-
-### fish_speech/i18n/core.py
-
-**Related code from RVC:**
-[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
-
-**Initial commit:**
-add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
-
-**Initial author:**
-[@L4Ph](https://github.com/L4Ph)
-
-### fish_speech/i18n/scan.py
-
-**Related code from RVC:**
-[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
-
-**Initial commit:**
-File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
-
-**Initial author:**
-[@towzeur](https://github.com/towzeur)
-
-We appreciate the contributions of the RVC project and its authors.
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py
deleted file mode 100644
index 981dbb3b3ecf28043ec9ff5757f947182821a246..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .core import i18n
-
-__all__ = ["i18n"]
diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index ba5a935b26a69595794d6840da906e6615c3a52f..0000000000000000000000000000000000000000
Binary files a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/i18n/__pycache__/core.cpython-310.pyc b/fish_speech/i18n/__pycache__/core.cpython-310.pyc
deleted file mode 100644
index 66d2787af00a38bc8ffebb84ed30565a71e94b01..0000000000000000000000000000000000000000
Binary files a/fish_speech/i18n/__pycache__/core.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py
deleted file mode 100644
index 9f793ec95669228f7f4e8f9a7a5fe38da85c74bd..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/core.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import json
-import locale
-from pathlib import Path
-
-I18N_FILE_PATH = Path(__file__).parent / "locale"
-DEFAULT_LANGUAGE = "en_US"
-
-
-def load_language_list(language):
- with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
- language_list = json.load(f)
-
- return language_list
-
-
-class I18nAuto:
- def __init__(self):
- i18n_file = Path(".locale")
-
- if i18n_file.exists():
- with open(i18n_file, "r", encoding="utf-8") as f:
- language = f.read().strip()
- else:
- # getlocale can't identify the system's language ((None, None))
- language = locale.getdefaultlocale()[0]
-
- if (I18N_FILE_PATH / f"{language}.json").exists() is False:
- language = DEFAULT_LANGUAGE
-
- self.language = language
- self.language_map = load_language_list(language)
-
- def __call__(self, key):
- return self.language_map.get(key, key)
-
- def __repr__(self):
- return "Use Language: " + self.language
-
-
-i18n = I18nAuto()
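I18nAuto resolves the UI language from a `.locale` file in the working directory, then the system locale, then en_US, and unknown keys fall back to the key text itself; a minimal sketch:

    from fish_speech.i18n import i18n

    print(i18n.language)            # e.g. "en_US" when no .locale file or matching JSON exists
    print(i18n("Start Training"))   # translated if the active locale JSON contains the key
    print(i18n("no-such-key"))      # falls back to the key itself: "no-such-key"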
diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json
deleted file mode 100644
index 6e280c236e9c79de2087ec33c7bf6f8e1a5296c4..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/locale/en_US.json
+++ /dev/null
@@ -1,122 +0,0 @@
-{
- "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
- "Accumulate Gradient Batches": "Accumulate Gradient Batches",
- "Add to Processing Area": "Add to Processing Area",
- "Added path successfully!": "Added path successfully!",
- "Advanced Config": "Advanced Config",
- "Base LLAMA Model": "Base LLAMA Model",
- "Batch Inference": "Batch Inference",
- "Batch Size": "Batch Size",
- "Changing with the Model Path": "Changing with the Model Path",
- "Chinese": "Chinese",
- "Compile Model": "Compile Model",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
- "Copy": "Copy",
- "Data Preprocessing": "Data Preprocessing",
- "Data Preprocessing Path": "Data Preprocessing Path",
- "Data Source": "Data Source",
- "Decoder Model Config": "Decoder Model Config",
- "Decoder Model Path": "Decoder Model Path",
- "Disabled": "Disabled",
- "Enable Reference Audio": "Enable Reference Audio",
- "English": "English",
- "Error Message": "Error Message",
- "File Preprocessing": "File Preprocessing",
- "Generate": "Generate",
- "Generated Audio": "Generated Audio",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
- "Infer interface is closed": "Infer interface is closed",
- "Inference Configuration": "Inference Configuration",
- "Inference Server Configuration": "Inference Server Configuration",
- "Inference Server Error": "Inference Server Error",
- "Inferring interface is launched at {}": "Inferring interface is launched at {}",
- "Initial Learning Rate": "Initial Learning Rate",
- "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
- "Input Text": "Input Text",
- "Invalid path: {}": "Invalid path: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
- "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
- "Japanese": "Japanese",
- "LLAMA Configuration": "LLAMA Configuration",
- "LLAMA Model Config": "LLAMA Model Config",
- "LLAMA Model Path": "LLAMA Model Path",
- "Labeling Device": "Labeling Device",
- "LoRA Model to be merged": "LoRA Model to be merged",
- "Maximum Audio Duration": "Maximum Audio Duration",
- "Maximum Length per Sample": "Maximum Length per Sample",
- "Maximum Training Steps": "Maximum Training Steps",
- "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
- "Merge": "Merge",
- "Merge LoRA": "Merge LoRA",
- "Merge successfully": "Merge successfully",
- "Minimum Audio Duration": "Minimum Audio Duration",
- "Model Output Path": "Model Output Path",
- "Model Size": "Model Size",
- "Move": "Move",
- "Move files successfully": "Move files successfully",
- "No audio generated, please check the input text.": "No audio generated, please check the input text.",
- "No selected options": "No selected options",
- "Number of Workers": "Number of Workers",
- "Open Inference Server": "Open Inference Server",
- "Open Labeler WebUI": "Open Labeler WebUI",
- "Open Tensorboard": "Open Tensorboard",
- "Opened labeler in browser": "Opened labeler in browser",
- "Optional Label Language": "Optional Label Language",
- "Optional online ver": "Optional online ver",
- "Output Path": "Output Path",
- "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
- "Precision": "Precision",
- "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
- "Put your text here.": "Put your text here.",
- "Reference Audio": "Reference Audio",
- "Reference Text": "Reference Text",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
- "Remove Selected Data": "Remove Selected Data",
- "Removed path successfully!": "Removed path successfully!",
- "Repetition Penalty": "Repetition Penalty",
- "Save model every n steps": "Save model every n steps",
- "Select LLAMA ckpt": "Select LLAMA ckpt",
- "Select VITS ckpt": "Select VITS ckpt",
- "Select VQGAN ckpt": "Select VQGAN ckpt",
- "Select source file processing method": "Select source file processing method",
- "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
- "Selected: {}": "Selected: {}",
- "Speaker": "Speaker",
- "Speaker is identified by the folder name": "Speaker is identified by the folder name",
- "Start Training": "Start Training",
- "Streaming Audio": "Streaming Audio",
- "Streaming Generate": "Streaming Generate",
- "Tensorboard Host": "Tensorboard Host",
- "Tensorboard Log Path": "Tensorboard Log Path",
- "Tensorboard Port": "Tensorboard Port",
- "Tensorboard interface is closed": "Tensorboard interface is closed",
- "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
- "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
- "Training Configuration": "Training Configuration",
- "Training Error": "Training Error",
- "Training stopped": "Training stopped",
- "Type name of the speaker": "Type name of the speaker",
- "Type the path or select from the dropdown": "Type the path or select from the dropdown",
- "Use LoRA": "Use LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
- "Use filelist": "Use filelist",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
- "VITS Configuration": "VITS Configuration",
- "VQGAN Configuration": "VQGAN Configuration",
- "Validation Batch Size": "Validation Batch Size",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
- "WebUI Host": "WebUI Host",
- "WebUI Port": "WebUI Port",
- "Whisper Model": "Whisper Model",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
- "latest": "latest",
- "new": "new",
- "Realtime Transform Text": "Realtime Transform Text",
- "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
- "Text Normalization": "Text Normalization"
-}
diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json
deleted file mode 100644
index 3285341f6893fe3e2ccbee6490dd8c90ed21854e..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/locale/es_ES.json
+++ /dev/null
@@ -1,122 +0,0 @@
-{
- "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
- "Accumulate Gradient Batches": "Acumular lotes de gradientes",
- "Add to Processing Area": "Agregar al Área de Procesamiento",
- "Added path successfully!": "¡Ruta agregada exitosamente!",
- "Advanced Config": "Configuración Avanzada",
- "Base LLAMA Model": "Modelo Base LLAMA",
- "Batch Inference": "Inferencia por Lote",
- "Batch Size": "Tamaño del Lote",
- "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
- "Chinese": "Chino",
- "Compile Model": "Compilar Modelo",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
- "Copy": "Copiar",
- "Data Preprocessing": "Preprocesamiento de Datos",
- "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
- "Data Source": "Fuente de Datos",
- "Decoder Model Config": "Configuración del modelo decodificador",
- "Decoder Model Path": "Ruta del modelo decodificador",
- "Disabled": "Desactivado",
- "Enable Reference Audio": "Habilitar Audio de Referencia",
- "English": "Inglés",
- "Error Message": "Mensaje de Error",
- "File Preprocessing": "Preprocesamiento de Archivos",
- "Generate": "Generar",
- "Generated Audio": "Audio Generado",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
- "Infer interface is closed": "La interfaz de inferencia está cerrada",
- "Inference Configuration": "Configuración de Inferencia",
- "Inference Server Configuration": "Configuración del Servidor de Inferencia",
- "Inference Server Error": "Error del Servidor de Inferencia",
- "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
- "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
- "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
- "Input Text": "Texto de Entrada",
- "Invalid path: {}": "Ruta inválida: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
- "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
- "Japanese": "Japonés",
- "LLAMA Configuration": "Configuración de LLAMA",
- "LLAMA Model Config": "Configuración del Modelo LLAMA",
- "LLAMA Model Path": "Ruta del Modelo LLAMA",
- "Labeling Device": "Dispositivo de Etiquetado",
- "LoRA Model to be merged": "Modelo LoRA a fusionar",
- "Maximum Audio Duration": "Duración máxima de audio",
- "Maximum Length per Sample": "Longitud Máxima por Muestra",
- "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
- "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
- "Merge": "Fusionar",
- "Merge LoRA": "Fusionar LoRA",
- "Merge successfully": "Fusionado exitosamente",
- "Minimum Audio Duration": "Duración mínima de audio",
- "Model Output Path": "Ruta de Salida del Modelo",
- "Model Size": "Tamaño del Modelo",
- "Move": "Mover",
- "Move files successfully": "Archivos movidos exitosamente",
- "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
- "No selected options": "No hay opciones seleccionadas",
- "Number of Workers": "Número de Trabajadores",
- "Open Inference Server": "Abrir Servidor de Inferencia",
- "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
- "Open Tensorboard": "Abrir Tensorboard",
- "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
- "Optional Label Language": "Idioma de Etiquetado Opcional",
- "Optional online ver": "Ver en línea opcional",
- "Output Path": "Ruta de Salida",
- "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
- "Precision": "Precisión",
- "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
- "Put your text here.": "Ponga su texto aquí.",
- "Reference Audio": "Audio de Referencia",
- "Reference Text": "Texto de Referencia",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
- "Remove Selected Data": "Eliminar Datos Seleccionados",
- "Removed path successfully!": "¡Ruta eliminada exitosamente!",
- "Repetition Penalty": "Penalización por Repetición",
- "Save model every n steps": "Guardar modelo cada n pasos",
- "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
- "Select VITS ckpt": "Seleccionar punto de control VITS",
- "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
- "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
- "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
- "Selected: {}": "Seleccionado: {}",
- "Speaker": "Hablante",
- "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
- "Start Training": "Iniciar Entrenamiento",
- "Streaming Audio": "transmisión de audio",
- "Streaming Generate": "síntesis en flujo",
- "Tensorboard Host": "Host de Tensorboard",
- "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
- "Tensorboard Port": "Puerto de Tensorboard",
- "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
- "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
- "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
- "Training Configuration": "Configuración de Entrenamiento",
- "Training Error": "Error de Entrenamiento",
- "Training stopped": "Entrenamiento detenido",
- "Type name of the speaker": "Escriba el nombre del hablante",
- "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
- "Use LoRA": "Usar LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
- "Use filelist": "Usar lista de archivos",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
- "VITS Configuration": "Configuración de VITS",
- "VQGAN Configuration": "Configuración de VQGAN",
- "Validation Batch Size": "Tamaño del Lote de Validación",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
- "WebUI Host": "Host de WebUI",
- "WebUI Port": "Puerto de WebUI",
- "Whisper Model": "Modelo Whisper",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
- "latest": "más reciente",
- "new": "nuevo",
- "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
- "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
- "Text Normalization": "Normalización de Texto"
-}
diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json
deleted file mode 100644
index d30bac7bcdf4f4c65b1f78b4dcf9d705c1d8eb39..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/locale/ja_JP.json
+++ /dev/null
@@ -1,123 +0,0 @@
-{
- "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
- "Accumulate Gradient Batches": "勾配バッチの累積",
- "Add to Processing Area": "処理エリアに追加",
- "Added path successfully!": "パスの追加に成功しました!",
- "Advanced Config": "詳細設定",
- "Base LLAMA Model": "基本LLAMAモデル",
- "Batch Inference": "バッチ推論",
- "Batch Size": "バッチサイズ",
- "Changing with the Model Path": "モデルのパスに伴って変化する",
- "Chinese": "中国語",
- "Compile Model": "モデルのコンパイル",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
- "Copy": "コピー",
- "Data Preprocessing": "データ前処理",
- "Data Preprocessing Path": "データ前処理パス",
- "Data Source": "データソース",
- "Decoder Model Config": "デコーダーモデルの構成",
- "Decoder Model Path": "デコーダーモデルのパス",
- "Disabled": "無効",
- "Enable Reference Audio": "リファレンスオーディオを有効にする",
- "English": "英語",
- "Error Message": "エラーメッセージ",
- "File Preprocessing": "文書前处理",
- "Generate": "生成",
- "Generated Audio": "生成されたオーディオ",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
- "Infer interface is closed": "推論インターフェースが閉じられています",
- "Inference Configuration": "推論設定",
- "Inference Server Configuration": "推論サーバー設定",
- "Inference Server Error": "推論サーバーエラー",
- "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
- "Initial Learning Rate": "初期学習率",
- "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
- "Input Text": "入力テキスト",
- "Invalid path: {}": "無効なパス: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
- "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
- "Japanese": "日本語",
- "LLAMA Configuration": "LLAMA設定",
- "LLAMA Model Config": "LLAMAモデル設定",
- "LLAMA Model Path": "LLAMAモデルパス",
- "Labeling Device": "ラベリングデバイス",
- "LoRA Model to be merged": "マージするLoRAモデル",
- "Maximum Audio Duration": "最大オーディオの長さ",
- "Maximum Length per Sample": "サンプルあたりの最大長",
- "Maximum Training Steps": "最大トレーニングステップ数",
- "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
- "Merge": "マージ",
- "Merge LoRA": "LoRAのマージ",
- "Merge successfully": "マージに成功しました",
- "Minimum Audio Duration": "最小オーディオの長さ",
- "Model Output Path": "モデル出力パス",
- "Model Size": "モデルサイズ",
- "Move": "移動",
- "Move files successfully": "ファイルの移動に成功しました",
- "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
- "No selected options": "選択されたオプションはありません",
- "Number of Workers": "ワーカー数",
- "Open Inference Server": "推論サーバーを開く",
- "Open Labeler WebUI": "ラベラーWebUIを開く",
- "Open Tensorboard": "Tensorboardを開く",
- "Opened labeler in browser": "ブラウザでラベラーを開きました",
- "Optional Label Language": "オプションのラベル言語",
- "Optional online ver": "オプションのオンラインバージョン",
- "Output Path": "出力パス",
- "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
- "Precision": "精度",
- "Probability of applying Speaker Condition": "話者条件を適用する確率",
- "Put your text here.": "ここにテキストを入力してください。",
- "Reference Audio": "リファレンスオーディオ",
- "Reference Text": "リファレンステキスト",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
- "Remove Selected Data": "選択したデータを削除",
- "Removed path successfully!": "パスの削除に成功しました!",
- "Repetition Penalty": "反復ペナルティ",
- "Save model every n steps": "nステップごとにモデルを保存",
- "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
- "Select VITS ckpt": "VITS チェックポイントを選択",
- "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
- "Select source file processing method": "ソースファイルの処理方法を選択",
- "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
- "Selected: {}": "選択済み: {}",
- "Speaker": "話者",
- "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
- "Start Training": "トレーニング開始",
- "Streaming Audio": "ストリーミングオーディオ",
- "Streaming Generate": "ストリーミング合成",
- "Tensorboard Host": "Tensorboardホスト",
- "Tensorboard Log Path": "Tensorboardログパス",
- "Tensorboard Port": "Tensorboardポート",
- "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
- "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
- "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
- "Training Configuration": "トレーニング設定",
- "Training Error": "トレーニングエラー",
- "Training stopped": "トレーニングが停止しました",
- "Type name of the speaker": "話者の名前を入力",
- "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
- "Use LoRA": "LoRAを使用",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
- "Use filelist": "ファイルリストを使用",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
- "VITS Configuration": "VITS の構成",
- "VQGAN Configuration": "VQGAN の構成",
- "Validation Batch Size": "検証バッチサイズ",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
- "WebUI Host": "WebUIホスト",
- "WebUI Port": "WebUIポート",
- "Whisper Model": "Whisperモデル",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
- "latest": "最新",
- "new": "新規",
- "Realtime Transform Text": "リアルタイム変換テキスト",
- "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
- "Text Normalization": "テキスト正規化"
-
-}
diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json
deleted file mode 100644
index 385f20272e19053ab9b6cf6463a84c8ece768c68..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/locale/pt_BR.json
+++ /dev/null
@@ -1,133 +0,0 @@
-{
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
- "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
- "Add to Processing Area": "Adicionar à Área de Processamento",
- "Added path successfully!": "Caminho adicionado com sucesso!",
- "Advanced Config": "Configuração Avançada",
- "Base LLAMA Model": "Modelo LLAMA Base",
- "Batch Inference": "Inferência em Lote",
- "Batch Size": "Tamanho do Lote",
- "Changing with the Model Path": "Alterando com o Caminho do Modelo",
-
- "Compile Model": "Compilar Modelo",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
- "Copy": "Copiar",
- "Data Preprocessing": "Pré-processamento de Dados",
- "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
- "Data Source": "Fonte de Dados",
- "Decoder Model Config": "Configuração do Modelo Decodificador",
- "Decoder Model Path": "Caminho do Modelo Decodificador",
- "Disabled": "Desativado",
- "Enable Initial Prompt": "Habilitar Prompt Inicial",
- "Enable Reference Audio": "Habilitar Áudio de Referência",
- "English": "Inglês",
- "Japanese": "Japonês",
- "Chinese": "Chinês",
- "Portuguese": "Português",
- "Spanish": "Espanhol",
- "Error Message": "Mensagem de Erro",
- "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
- "File Preprocessing": "Pré-processamento de Arquivos",
- "Generate": "Gerar",
- "Generated Audio": "Áudio Gerado",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
- "Infer interface is closed": "A interface de inferência foi fechada",
- "Inference Configuration": "Configuração de Inferência",
- "Inference Server Configuration": "Configuração do Servidor de Inferência",
- "Inference Server Error": "Erro do Servidor de Inferência",
- "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
- "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
- "Initial Prompt": "Prompt Inicial",
- "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
- "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
- "Input Text": "Texto de Entrada",
- "Invalid path: {}": "Caminho inválido: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
- "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
- "LLAMA Configuration": "Configuração do LLAMA",
- "LLAMA Model Config": "Configuração do Modelo LLAMA",
- "LLAMA Model Path": "Caminho do Modelo LLAMA",
- "Labeling Device": "Dispositivo de Rotulagem",
- "LoRA Model to be merged": "Modelo LoRA para mesclagem",
- "Maximum Length per Sample": "Comprimento Máximo por Amostra",
- "Maximum Training Steps": "Etapas Máximas de Treinamento",
- "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
- "Merge": "Mesclar",
- "Merge LoRA": "Mesclar LoRA",
- "Merge successfully": "Mesclado com sucesso",
- "Model Output Path": "Caminho de Saída do Modelo",
- "Model Quantization": "Quantização do Modelo",
- "Model Size": "Tamanho do Modelo",
- "Move": "Mover",
- "Move files successfully": "Arquivos movidos com sucesso",
- "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
- "No selected options": "Nenhuma opção selecionada",
- "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
- "Number of Workers": "Número de Processos",
- "Open Inference Server": "Abrir Servidor de Inferência",
- "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
- "Open Tensorboard": "Abrir Tensorboard",
- "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
- "Optional Label Language": "Idioma do Rótulo (Opcional)",
- "Optional online ver": "Versão online (opcional)",
- "Output Path": "Caminho de Saída",
- "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
- "Post-quantification Precision": "Precisão Pós-quantização",
- "Precision": "Precisão",
- "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
- "Put your text here.": "Insira seu texto aqui.",
- "Quantify": "Quantizar",
- "Quantify successfully": "Quantizado com sucesso",
- "Realtime Transform Text": "Transformar Texto em Tempo Real",
- "Reference Audio": "Áudio de Referência",
- "Reference Text": "Texto de Referência",
- "warning": "Aviso",
- "Pre-processing begins...": "O pré-processamento começou!",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
- "Remove Selected Data": "Remover Dados Selecionados",
- "Removed path successfully!": "Caminho removido com sucesso!",
- "Repetition Penalty": "Penalidade de Repetição",
- "Save model every n steps": "Salvar modelo a cada n etapas",
- "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
- "Select source file processing method": "Escolha como processar o arquivo de origem",
- "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
- "Selected: {}": "Selecionado: {}",
- "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
- "Start Training": "Iniciar Treinamento",
- "Streaming Audio": "Áudio em Streaming",
- "Streaming Generate": "Geração em Streaming",
- "Tensorboard Host": "Host do Tensorboard",
- "Tensorboard Log Path": "Caminho de Log do Tensorboard",
- "Tensorboard Port": "Porta do Tensorboard",
- "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
- "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
- "Text Normalization": "Normalização de Texto",
- "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
- "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
- "Training Configuration": "Configuração de Treinamento",
- "Training Error": "Erro de Treinamento",
- "Training stopped": "Treinamento interrompido!",
- "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
- "Use LoRA": "Usar LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
- "Use filelist": "Usar lista de arquivos",
- "VQGAN Configuration": "Configuração do VQGAN",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
- "WebUI Host": "Host da WebUI",
- "WebUI Port": "Porta da WebUI",
- "Whisper Model": "Modelo Whisper",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
- "auto": "automático",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
- "latest": "mais recente",
- "new": "novo",
- "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
- "You don't need to train this model!": "Não é necessário treinar este modelo!",
- "Yes": "Sim",
- "No": "Não",
- "version:": "versão:",
- "author:": "autor:"
-}
diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json
deleted file mode 100644
index 3dd1a5cd1ccf3860ca508238cc64a68ca4fc3276..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/locale/zh_CN.json
+++ /dev/null
@@ -1,122 +0,0 @@
-{
- "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
- "Accumulate Gradient Batches": "梯度累积批次",
- "Add to Processing Area": "加入处理区",
- "Added path successfully!": "添加路径成功!",
- "Advanced Config": "高级参数",
- "Base LLAMA Model": "基础 LLAMA 模型",
- "Batch Inference": "批量推理",
- "Batch Size": "批次大小",
- "Changing with the Model Path": "随模型路径变化",
- "Chinese": "中文",
- "Compile Model": "编译模型",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
- "Copy": "复制",
- "Data Preprocessing": "数据预处理",
- "Data Preprocessing Path": "数据预处理路径",
- "Data Source": "数据源",
- "Decoder Model Config": "解码器模型配置",
- "Decoder Model Path": "解码器模型路径",
- "Disabled": "禁用",
- "Enable Reference Audio": "启用参考音频",
- "English": "英文",
- "Error Message": "错误信息",
- "File Preprocessing": "文件预处理",
- "Generate": "生成",
- "Generated Audio": "音频",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
- "Infer interface is closed": "推理界面已关闭",
- "Inference Configuration": "推理配置",
- "Inference Server Configuration": "推理服务器配置",
- "Inference Server Error": "推理服务器错误",
- "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
- "Initial Learning Rate": "初始学习率",
- "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
- "Input Text": "输入文本",
- "Invalid path: {}": "无效路径: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
- "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
- "Japanese": "日文",
- "LLAMA Configuration": "LLAMA 配置",
- "LLAMA Model Config": "LLAMA 模型配置",
- "LLAMA Model Path": "LLAMA 模型路径",
- "Labeling Device": "标注加速设备",
- "LoRA Model to be merged": "要合并的 LoRA 模型",
- "Maximum Audio Duration": "最大音频时长",
- "Maximum Length per Sample": "每个样本的最大长度",
- "Maximum Training Steps": "最大训练步数",
- "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
- "Merge": "合并",
- "Merge LoRA": "合并 LoRA",
- "Merge successfully": "合并成功",
- "Minimum Audio Duration": "最小音频时长",
- "Model Output Path": "模型输出路径",
- "Model Size": "模型规模",
- "Move": "移动",
- "Move files successfully": "移动文件成功",
- "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
- "No selected options": "没有选择的选项",
- "Number of Workers": "数据加载进程数",
- "Open Inference Server": "打开推理服务器",
- "Open Labeler WebUI": "打开标注工具",
- "Open Tensorboard": "打开 Tensorboard",
- "Opened labeler in browser": "在浏览器中打开标注工具",
- "Optional Label Language": "[可选] 标注语言",
- "Optional online ver": "[可选] 使用在线版",
- "Output Path": "输出路径",
- "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
- "Precision": "精度",
- "Probability of applying Speaker Condition": "应用说话人条件的概率",
- "Put your text here.": "在此处输入文本.",
- "Reference Audio": "参考音频",
- "Reference Text": "参考文本",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
- "Remove Selected Data": "移除选中数据",
- "Removed path successfully!": "移除路径成功!",
- "Repetition Penalty": "重复惩罚",
- "Save model every n steps": "每 n 步保存模型",
- "Select LLAMA ckpt": "选择 LLAMA 检查点",
- "Select VITS ckpt": "选择 VITS 检查点",
- "Select VQGAN ckpt": "选择 VQGAN 检查点",
- "Select source file processing method": "选择源文件处理方法",
- "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
- "Selected: {}": "已选择: {}",
- "Speaker": "说话人",
- "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
- "Start Training": "开始训练",
- "Streaming Audio": "流式音频",
- "Streaming Generate": "流式合成",
- "Tensorboard Host": "Tensorboard 监听地址",
- "Tensorboard Log Path": "Tensorboard 日志路径",
- "Tensorboard Port": "Tensorboard 端口",
- "Tensorboard interface is closed": "Tensorboard 界面已关闭",
- "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
- "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
- "Training Configuration": "训练配置",
- "Training Error": "训练错误",
- "Training stopped": "训练已停止",
- "Type name of the speaker": "输入说话人的名称",
- "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
- "Use LoRA": "使用 LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
- "Use filelist": "使用文件列表",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
- "VITS Configuration": "VITS 配置",
- "VQGAN Configuration": "VQGAN 配置",
- "Validation Batch Size": "验证批次大小",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
- "WebUI Host": "WebUI 监听地址",
- "WebUI Port": "WebUI 端口",
- "Whisper Model": "Whisper 模型",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
- "latest": "最近的检查点",
- "new": "创建新的检查点",
- "Realtime Transform Text": "实时规范化文本",
- "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
- "Text Normalization": "文本规范化"
-}
diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py
deleted file mode 100644
index d0194c0f1a31dc95309c64626d13f04751a44ba1..0000000000000000000000000000000000000000
--- a/fish_speech/i18n/scan.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import ast
-import glob
-import json
-from collections import OrderedDict
-from pathlib import Path
-
-from loguru import logger
-
-from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
-
-
-def extract_i18n_strings(node):
- i18n_strings = []
-
- if (
- isinstance(node, ast.Call)
- and isinstance(node.func, ast.Name)
- and node.func.id == "i18n"
- ):
- for arg in node.args:
- if isinstance(arg, ast.Str):
- i18n_strings.append(arg.s)
-
- for child_node in ast.iter_child_nodes(node):
- i18n_strings.extend(extract_i18n_strings(child_node))
-
- return i18n_strings
-
-
-# scan the directory for all .py files (recursively)
-# for each file, parse the code into an AST
-# for each AST, extract the i18n strings
-
-strings = []
-folders = ["fish_speech", "tools"]
-# for filename in glob.iglob("**/*.py", recursive=True):
-for folder in folders:
- for f in Path(folder).rglob("*.py"):
- code = f.read_text(encoding="utf-8")
- if "i18n(" in code:
- tree = ast.parse(code)
- i18n_strings = extract_i18n_strings(tree)
- logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
- strings.extend(i18n_strings)
-
-code_keys = set(strings)
-logger.info(f"Total unique: {len(code_keys)}")
-
-
-standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
-with open(standard_file, "r", encoding="utf-8") as f:
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
-standard_keys = set(standard_data.keys())
-
-# Find keys that exist in the standard file but are no longer referenced in code
-unused_keys = standard_keys - code_keys
-logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
-for unused_key in unused_keys:
- logger.info(f"\t{unused_key}")
-
-missing_keys = code_keys - standard_keys
-logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
-for missing_key in missing_keys:
- logger.info(f"\t{missing_key}")
-
-code_keys_dict = OrderedDict()
-for s in strings:
- code_keys_dict[s] = s
-
-# write back
-with open(standard_file, "w", encoding="utf-8") as f:
- json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
- f.write("\n")
-
-logger.info(f"Updated {standard_file}")
-
-
-# Define the standard file name
-standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
-
-# Find all JSON files in the directory
-dir_path = I18N_FILE_PATH
-languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
-
-# Load the standard file
-with open(standard_file, "r", encoding="utf-8") as f:
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
-
-# Loop through each language file
-for lang_file in languages:
- # Load the language file
- with open(lang_file, "r", encoding="utf-8") as f:
- lang_data = json.load(f, object_pairs_hook=OrderedDict)
-
- # Find the difference between the language file and the standard file
- diff = set(standard_data.keys()) - set(lang_data.keys())
-
- miss = set(lang_data.keys()) - set(standard_data.keys())
-
- # Add any missing keys to the language file
- for key in diff:
- lang_data[key] = "#!" + key
- logger.info(f"Added missing key: {key} to {lang_file}")
-
-    # Delete any extra keys from the language file
- for key in miss:
- del lang_data[key]
- logger.info(f"Del extra key: {key} from {lang_file}")
-
- # Sort the keys of the language file to match the order of the standard file
- lang_data = OrderedDict(
- sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
- )
-
- # Save the updated language file
- with open(lang_file, "w", encoding="utf-8") as f:
- json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
- f.write("\n")
-
- logger.info(f"Updated {lang_file}")
-
-logger.info("Done")
diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 2660e31d27b749e906716f846ad0303f28c5d3ae..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc
deleted file mode 100644
index 10287acff4182fbf2964bed0c6512b752fe087bc..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc
deleted file mode 100644
index c46a1595d473a5422c0f1a526faf162604c66191..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc
deleted file mode 100644
index 277545bc846fa08418ba7e846e38c006877bf95d..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
deleted file mode 100644
index df970400f8a073be4c4166a697245fabdf6b09b0..0000000000000000000000000000000000000000
--- a/fish_speech/models/text2semantic/lit_module.py
+++ /dev/null
@@ -1,202 +0,0 @@
-from typing import Any, Optional
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-from lightning.pytorch.utilities.types import OptimizerLRScheduler
-
-import fish_speech.utils as utils
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
-from fish_speech.models.text2semantic.llama import NaiveTransformer
-
-log = utils.RankedLogger(__name__, rank_zero_only=True)
-
-
-class TextToSemantic(L.LightningModule):
- def __init__(
- self,
- model: NaiveTransformer,
- optimizer: Any,
- lr_scheduler: Any,
- ):
- super().__init__()
-
- self.model = model
- self.optimizer_builder = optimizer
- self.lr_scheduler_builder = lr_scheduler
-
- def forward(self, x):
- return self.model(x)
-
- def on_save_checkpoint(self, checkpoint):
- # Save only LoRA parameters
- state_dict = checkpoint["state_dict"]
- use_lora = any("lora" in name for name in state_dict.keys())
- if not use_lora:
- return
-
- for name in list(state_dict.keys()):
- if "lora" not in name:
- state_dict.pop(name)
-
- def configure_optimizers(self) -> OptimizerLRScheduler:
- # Get weight decay parameters
- weight_decay_parameters, other_parameters = [], []
- for name, param in self.named_parameters():
- if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
- other_parameters.append(param)
- else:
- weight_decay_parameters.append(param)
-
- optimizer = self.optimizer_builder(
- [
- {"params": weight_decay_parameters},
- {"params": other_parameters, "weight_decay": 0.0},
- ]
- )
-
- # Print the parameters and their weight decay
- for i in optimizer.param_groups:
- log.info(
- f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
- )
-
- lr_scheduler = self.lr_scheduler_builder(optimizer)
-
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": lr_scheduler,
- "interval": "step",
- },
- }
-
- # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
- def get_batch_logps(
- self,
- logits: torch.FloatTensor,
- labels: torch.LongTensor,
- average_log_prob: bool = False,
- ) -> torch.FloatTensor:
- """Compute the log probabilities of the given labels under the given logits.
-
- Args:
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
- labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-
- Returns:
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
- """
- assert logits.shape[:-1] == labels.shape
-
- labels = labels.clone()
- loss_mask = labels != -100
-
- # dummy token; we'll ignore the losses on these tokens later
- labels[labels == -100] = 0
-
- per_token_logps = torch.gather(
- logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
- ).squeeze(-1)
-
- if average_log_prob:
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
- else:
- return (per_token_logps * loss_mask).sum(-1)
-
- def _step(self, batch, batch_idx, stage: str):
- is_train = stage == "train"
-
- if is_train:
-            # Key step to make LoRA work: keep the model in train mode,
-            # otherwise the LoRA parameters are merged into the base weights, which leads to incorrect gradients
- self.model.train()
-
-        # Forward the whole batch; labels[:, 0] holds the text tokens, the remaining rows hold codebook tokens
- labels = batch["labels"]
- outputs = self.model(
- inp=batch["inputs"],
- key_padding_mask=batch["attention_masks"],
- )
- token_logits = outputs.token_logits
- codebook_logits = outputs.codebook_logits
-
-        # Base loss on the text-token predictions
- base_loss = F.cross_entropy(
- token_logits.view(-1, token_logits.size(-1)),
- labels[:, 0].reshape(-1),
- ignore_index=-100,
- )
-
- codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
- semantic_loss = F.cross_entropy(
- codebook_logits.view(-1, codebook_logits.size(-1)),
- codebook_labels.reshape(-1),
- ignore_index=-100,
- )
-
- loss = base_loss + semantic_loss
-
- self.log(
- f"{stage}/loss",
- loss,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=True,
- logger=True,
- sync_dist=not is_train,
- )
-
- self.log(
- f"{stage}/base_loss",
- base_loss,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=False,
- logger=True,
- sync_dist=not is_train,
- )
-
- self.log(
- f"{stage}/semantic_loss",
- semantic_loss,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=False,
- logger=True,
- sync_dist=not is_train,
- )
-
- # Top-5 accuracy
- accuracy = self.get_accuracy(codebook_logits, codebook_labels)
- self.log(
- f"{stage}/top_5_accuracy",
- accuracy,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=True,
- logger=True,
- sync_dist=not is_train,
- )
-
- return loss
-
- def get_accuracy(self, logits, labels):
- mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
- if mask.sum() == 0:
- return torch.tensor(0.0, device=logits.device)
-
- _, indices = logits.topk(5, dim=-1)
- correct = indices.eq(labels.unsqueeze(-1))
- correct[~mask] = 0
- correct = correct.sum()
- accuracy = correct / mask.sum()
-
- return accuracy
-
- def training_step(self, batch, batch_idx):
- return self._step(batch, batch_idx, "train")
-
- def validation_step(self, batch, batch_idx):
- return self._step(batch, batch_idx, "val")
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
deleted file mode 100644
index 0725dfb9b78b1154753641b69c959a2faadba48c..0000000000000000000000000000000000000000
--- a/fish_speech/models/text2semantic/llama.py
+++ /dev/null
@@ -1,779 +0,0 @@
-import json
-import math
-from collections import OrderedDict
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
-
-import torch
-import torch.nn as nn
-from einops import rearrange
-from loguru import logger
-from torch import Tensor
-from torch.nn import functional as F
-from torch.nn.attention import SDPBackend, sdpa_kernel
-from torch.utils.checkpoint import checkpoint
-from transformers import AutoTokenizer
-
-from fish_speech.conversation import SEMANTIC_TOKEN
-from fish_speech.utils import RankedLogger
-
-from .lora import LoraConfig, setup_lora
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def find_multiple(n: int, k: int) -> int:
- if n % k == 0:
- return n
- return n + k - (n % k)
-
-
-@dataclass
-class BaseModelArgs:
- model_type: str = "base"
-
- vocab_size: int = 32000
- n_layer: int = 32
- n_head: int = 32
- dim: int = 4096
- intermediate_size: int = None
- n_local_heads: int = -1
- head_dim: int = 64
- rope_base: float = 10000
- norm_eps: float = 1e-5
- max_seq_len: int = 2048
- dropout: float = 0.0
- tie_word_embeddings: bool = True
- attention_qkv_bias: bool = False
-
- # Codebook configs
- codebook_size: int = 160
- num_codebooks: int = 4
-
- # Gradient checkpointing
- use_gradient_checkpointing: bool = True
-
- # Initialize the model
- initializer_range: float = 0.02
-
- def __post_init__(self):
- if self.n_local_heads == -1:
- self.n_local_heads = self.n_head
- if self.intermediate_size is None:
- hidden_dim = 4 * self.dim
- n_hidden = int(2 * hidden_dim / 3)
- self.intermediate_size = find_multiple(n_hidden, 256)
- self.head_dim = self.dim // self.n_head
-
- @staticmethod
- def from_pretrained(path: str):
- path = Path(path)
-
- if path.is_dir():
- path = path / "config.json"
-
- with open(path, "r", encoding="utf-8") as f:
- data = json.load(f)
-
- match data["model_type"]:
- case "naive":
- cls = NaiveModelArgs
- case "dual_ar":
- cls = DualARModelArgs
- case _:
- raise ValueError(f"Unknown model type: {data['model_type']}")
-
- return cls(**data)
-
- def save(self, path: str):
- with open(path, "w") as f:
- json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
-
-
-@dataclass
-class NaiveModelArgs(BaseModelArgs):
- model_type: str = "naive"
-
-
-@dataclass
-class DualARModelArgs(BaseModelArgs):
- model_type: str = "dual_ar"
- n_fast_layer: int = 4
-
-
-class KVCache(nn.Module):
- def __init__(
- self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
- ):
- super().__init__()
- cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
- self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
- self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
-
- def update(self, input_pos, k_val, v_val):
- # input_pos: [S], k_val: [B, H, S, D]
- assert input_pos.shape[0] == k_val.shape[2]
-
- k_out = self.k_cache
- v_out = self.v_cache
- k_out[:, :, input_pos] = k_val
- v_out[:, :, input_pos] = v_val
-
- return k_out, v_out
-
-
-@dataclass
-class TransformerForwardResult:
- token_logits: Tensor
- codebook_logits: Tensor
-
-
-@dataclass
-class BaseTransformerForwardResult:
- logits: Tensor
- hidden_states: Tensor
-
-
-class BaseTransformer(nn.Module):
- def __init__(
- self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
- ) -> None:
- super().__init__()
- self.config = config
- self.tokenizer = tokenizer
-
- self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
-
- # Slow transformer
- self.embeddings = nn.Embedding(
- config.vocab_size,
- config.dim,
- )
- self.codebook_embeddings = nn.Embedding(
- config.codebook_size * config.num_codebooks,
- config.dim,
- )
- self.layers = nn.ModuleList(
- TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
- )
- self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-
- if self.config.tie_word_embeddings is False:
- self.output = nn.Linear(
- config.dim,
- config.vocab_size,
- bias=False,
- )
-
- self.register_buffer(
- "freqs_cis",
- precompute_freqs_cis(
- config.max_seq_len,
- config.dim // config.n_head,
- config.rope_base,
- ),
- persistent=False,
- )
- self.register_buffer(
- "causal_mask",
- torch.tril(
- torch.ones(
- config.max_seq_len,
- config.max_seq_len,
- dtype=torch.bool,
- )
- ),
- persistent=False,
- )
-
- # For kv cache
- self.max_batch_size = -1
- self.max_seq_len = -1
-
- if init_weights:
- self.apply(self._init_weights)
-
- def setup_caches(
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
- ):
- if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
- return
-
- head_dim = self.config.dim // self.config.n_head
- max_seq_len = find_multiple(max_seq_len, 8)
- self.max_seq_len = max_seq_len
- self.max_batch_size = max_batch_size
-
- for b in self.layers:
- b.attention.kv_cache = KVCache(
- max_batch_size,
- max_seq_len,
- self.config.n_local_heads,
- head_dim,
- dtype=dtype,
- )
-
- def embed(self, x: Tensor) -> Tensor:
- vocab_embeds = [self.embeddings(x[:, 0])]
- for i in range(self.config.num_codebooks):
- emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
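-            # Zero the codebook embeddings at positions whose text token is not the semantic token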
- emb[x[:, 0] != self.semantic_token_id] = 0
- vocab_embeds.append(emb)
-
- x = torch.stack(vocab_embeds, dim=3)
- x = x.sum(dim=3)
-
- return x
-
- def forward(
- self,
- inp: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- ) -> BaseTransformerForwardResult:
- seq_len = inp.size(2)
-
- # Here we want to merge the embeddings of the codebooks
- x = self.embed(inp)
-
- freqs_cis = self.freqs_cis[:seq_len]
-
-        # Note that the causal mask here follows the definition of scaled_dot_product_attention
-        # That is, FALSE means masked out
-        # To maintain consistency, key_padding_mask uses TRUE to mask out
- mask = None
- if key_padding_mask is not None:
- mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
- mask = mask & key_padding_mask[:, None, None, :].logical_not()
-
- for layer in self.layers:
- if self.config.use_gradient_checkpointing and self.training:
- x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
- else:
- x = layer(x, freqs_cis, mask)
-
- # We got slow_out here
- slow_out = self.norm(x)
-
- if self.config.tie_word_embeddings:
- token_logits = F.linear(slow_out, self.embeddings.weight)
- else:
- token_logits = self.output(slow_out)
-
- return BaseTransformerForwardResult(
- logits=token_logits,
- hidden_states=x,
- )
-
- def forward_generate(
- self,
- x: Tensor,
- input_pos: Optional[Tensor] = None,
- return_all: bool = False,
- ) -> BaseTransformerForwardResult:
- # This is used for generation, optimized for torch compile
- assert (
- self.max_seq_len != -1 and self.max_batch_size != -1
- ), "Please call setup_caches before forward_generate"
-
- x = self.embed(x)
-
- mask = self.causal_mask[
- None, None, input_pos, : self.max_seq_len
- ] # (B, N, Q, K)
- freqs_cis = self.freqs_cis[input_pos]
-
- for layer in self.layers:
- x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
-        # During prefill, we only calculate the logits of the last token
- if x.size(1) > 1 and not return_all:
- x = x[:, -1:]
-
- # We got slow_out here
- slow_out = self.norm(x)
-
- if self.config.tie_word_embeddings:
- token_logits = F.linear(slow_out, self.embeddings.weight)
- else:
- token_logits = self.output(slow_out)
-
- return BaseTransformerForwardResult(
- logits=token_logits,
- hidden_states=x,
- )
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- @staticmethod
- def from_pretrained(
- path: str,
- load_weights: bool = False,
- max_length: int | None = None,
- lora_config: LoraConfig | None = None,
- rope_base: int | None = None,
- ) -> "BaseTransformer":
- config = BaseModelArgs.from_pretrained(str(path))
- if max_length is not None:
- config.max_seq_len = max_length
- log.info(f"Override max_seq_len to {max_length}")
-
- if rope_base is not None:
- config.rope_base = rope_base
- log.info(f"Override rope_base to {rope_base}")
-
- match config.model_type:
- case "naive":
- model_cls = NaiveTransformer
- case "dual_ar":
- model_cls = DualARTransformer
- case _:
- raise ValueError(f"Unknown model type: {config.model_type}")
-
- tokenizer = AutoTokenizer.from_pretrained(str(path))
- log.info(f"Loading model from {path}, config: {config}")
- model = model_cls(config, tokenizer=tokenizer)
-
- if lora_config is not None:
- setup_lora(model, lora_config)
- log.info(f"LoRA setup: {lora_config}")
-
- if load_weights is False:
- log.info("Randomly initialized model")
- else:
-
- if "int8" in str(Path(path)):
- logger.info("Using int8 weight-only quantization!")
- from tools.llama.quantize import WeightOnlyInt8QuantHandler
-
- simple_quantizer = WeightOnlyInt8QuantHandler(model)
- model = simple_quantizer.convert_for_runtime()
-
- if "int4" in str(Path(path)):
- logger.info("Using int4 quantization!")
- path_comps = path.name.split("-")
- assert path_comps[-2].startswith("g")
- groupsize = int(path_comps[-2][1:])
- from tools.llama.quantize import WeightOnlyInt4QuantHandler
-
- simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
- model = simple_quantizer.convert_for_runtime()
-
- weights = torch.load(
- Path(path) / "model.pth", map_location="cpu", mmap=True
- )
-
- if "state_dict" in weights:
- logger.warning(
- "Using a TextToSemantic LightningModule checkpoint, "
- "please make sure it is a full model, not a LoRA model."
- )
- weights = weights["state_dict"]
-
- if next(iter(weights.keys())).startswith("model."):
- logger.info(
- f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
- )
- new_weights = OrderedDict()
- for k, v in weights.items():
-                    new_weights[k.removeprefix("model.")] = v
- weights = new_weights
-
- # Verify the name and shape of parameters since strict=False in load_state_dict.
- for k, v in model.named_parameters():
- if k not in weights:
- logger.warning(f"No weight for {k}")
- elif v.shape != weights[k].shape:
- logger.warning(
- f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
- )
-
- err = model.load_state_dict(weights, strict=False, assign=True)
- log.info(f"Loaded weights with error: {err}")
-
- return model
-
- def save_pretrained(self, path: str, drop_lora: bool = False):
- path = Path(path)
- path.mkdir(parents=True, exist_ok=True)
-
- self.config.save(path / "config.json")
- state_dict = self.state_dict()
-
- if drop_lora:
- for key in list(state_dict.keys()):
- if "lora" not in key:
- continue
-
- state_dict.pop(key)
- log.info(f"Drop LoRA parameter: {key}")
-
- torch.save(state_dict, path / "model.pth")
- self.tokenizer.save_pretrained(path)
-
-
-class NaiveTransformer(BaseTransformer):
- def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
-
- self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
- self.codebook_output = nn.Linear(
- config.dim,
- config.codebook_size * config.num_codebooks,
- bias=False,
- )
-
- self.apply(self._init_weights)
-
- def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
- token_logits = result.logits
- x = result.hidden_states
-
- # Codebook
- codebook_logits = self.codebook_output(self.codebook_norm(x))
- codebook_logits = rearrange(
- codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
- )
-
- return TransformerForwardResult(
- token_logits=token_logits,
- codebook_logits=codebook_logits,
- )
-
- def forward(
- self,
- inp: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- ) -> TransformerForwardResult:
- result = super().forward(
- inp=inp,
- key_padding_mask=key_padding_mask,
- )
- return self.decode(result)
-
- def forward_generate(
- self, x: Tensor, input_pos: Optional[Tensor] = None
- ) -> TransformerForwardResult:
- result = super().forward_generate(x, input_pos)
- return self.decode(result)
-
-
-class DualARTransformer(BaseTransformer):
- def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
-
- # Fast transformer
- self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
-
-        # The equivalent batch size is so large that SDPA does not work well here
- self.fast_layers = nn.ModuleList(
- TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
- )
- self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
- self.fast_output = nn.Linear(
- config.dim,
- config.codebook_size,
- bias=False,
- )
-
- self.apply(self._init_weights)
-
- def setup_caches(
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
- ):
- super().setup_caches(max_batch_size, max_seq_len, dtype)
-
- head_dim = self.config.dim // self.config.n_head
-
- # Fast transformer
- # The max seq len here is the number of codebooks
- for b in self.fast_layers:
- b.attention.kv_cache = KVCache(
- max_batch_size,
- self.config.num_codebooks,
- self.config.n_local_heads,
- head_dim,
- dtype=dtype,
- )
-
- def forward(
- self,
- inp: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- ) -> TransformerForwardResult:
- parent_result = super().forward(inp, key_padding_mask)
- token_logits = parent_result.logits
- x = parent_result.hidden_states
-
- # Fast transformer
- fast_seq_len = self.config.num_codebooks
- fast_mask = self.causal_mask[
- None, None, :fast_seq_len, :fast_seq_len
- ] # (B, N, Q, K)
- fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-
- # Drop the last token and rotate left
- codebooks = inp[:, 1:-1, 1:]
- codebooks = F.pad(codebooks, (0, 1), value=0)
- codebook_embeddings = self.fast_embeddings(codebooks)
- x = torch.cat([x[:, None], codebook_embeddings], dim=1)
- b, s = x.size(0), x.size(2)
- x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
-
- # Remove padded part
- codebooks = rearrange(codebooks, "b n s -> (b s) n")
- codebook_mask = (codebooks == 0).all(dim=-1)
-
- if torch.all(codebook_mask):
-            # If all codebooks are padded, keep the first 8 so the fast transformer still receives input
- codebook_mask[:8] = False
-
- x_bs, x_len = x.size(0), x.size(1)
- x = x[~codebook_mask]
-
- for layer in self.fast_layers:
- if self.config.use_gradient_checkpointing and self.training:
- x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
- else:
- x = layer(x, fast_freqs_cis, fast_mask)
-
- # unflatten the batch and num_codebooks
- fast_out = self.fast_norm(x)
- codebook_logits = self.fast_output(fast_out)
-
- # Re-pad the codebook_logits
- buffer = torch.zeros(
- x_bs,
- x_len,
- codebook_logits.size(-1),
- device=codebook_logits.device,
- dtype=codebook_logits.dtype,
- )
- buffer[~codebook_mask] = codebook_logits
- codebook_logits = buffer
-
- assert codebook_logits.shape[1] == self.config.num_codebooks
- codebook_logits = rearrange(
- codebook_logits,
- "(b s) n d -> b s n d",
- b=b,
- s=s,
- n=self.config.num_codebooks,
- )
-
- return TransformerForwardResult(
- token_logits=token_logits,
- codebook_logits=codebook_logits,
- )
-
- def forward_generate_fast(
- self, x: Tensor, input_pos: Optional[Tensor] = None
- ) -> Tensor:
- # Fast transformer
- x = x.view(1, 1, -1)
-
- fast_mask = self.causal_mask[
- None, None, input_pos, : self.config.num_codebooks
- ] # (B, N, Q, K)
- fast_freqs_cis = self.freqs_cis[input_pos]
-
- for layer in self.fast_layers:
- x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
-
- # unflatten the batch and num_codebooks
- fast_out = self.fast_norm(x) # only take the last token
- codebook_logits = self.fast_output(fast_out)
-
- return codebook_logits
-
-
-class TransformerBlock(nn.Module):
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
- super().__init__()
- self.attention = Attention(config, use_sdpa=use_sdpa)
- self.feed_forward = FeedForward(config)
- self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
- self.attention_norm = RMSNorm(config.dim, config.norm_eps)
-
- def forward(
- self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
- ) -> Tensor:
- h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
- out = h + self.feed_forward(self.ffn_norm(h))
- return out
-
-
-class Attention(nn.Module):
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
- super().__init__()
- assert config.dim % config.n_head == 0
-
- total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
- # key, query, value projections for all heads, but in a batch
- self.wqkv = nn.Linear(
- config.dim, total_head_dim, bias=config.attention_qkv_bias
- )
- self.wo = nn.Linear(config.dim, config.dim, bias=False)
- self.kv_cache = None
-
- self.dropout = config.dropout
- self.n_head = config.n_head
- self.head_dim = config.head_dim
- self.n_local_heads = config.n_local_heads
- self.dim = config.dim
- self.use_sdpa = use_sdpa
- self._register_load_state_dict_pre_hook(self.load_hook)
-
- def load_hook(self, state_dict, prefix, *args):
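-        # Fuse separate wq/wk/wv weights from older checkpoints into the combined wqkv projection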
- if prefix + "wq.weight" in state_dict:
- wq = state_dict.pop(prefix + "wq.weight")
- wk = state_dict.pop(prefix + "wk.weight")
- wv = state_dict.pop(prefix + "wv.weight")
- state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
-
- def forward(
- self,
- x: Tensor,
- freqs_cis: Tensor,
- mask: Tensor,
- input_pos: Optional[Tensor] = None,
- ) -> Tensor:
- bsz, seqlen, _ = x.shape
-
- kv_size = self.n_local_heads * self.head_dim
- q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
-
- q = q.view(bsz, seqlen, self.n_head, self.head_dim)
- k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
- v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-
- q = apply_rotary_emb(q, freqs_cis)
- k = apply_rotary_emb(k, freqs_cis)
-
- q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
-
- if self.kv_cache is not None:
- k, v = self.kv_cache.update(input_pos, k, v)
-
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-
- if self.use_sdpa:
- if mask is None:
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
- y = F.scaled_dot_product_attention(
- q,
- k,
- v,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=True,
- # No third party attn_mask here to use flash_attention
- )
- else:
- y = F.scaled_dot_product_attention(
- q,
- k,
- v,
- attn_mask=mask,
- dropout_p=self.dropout if self.training else 0.0,
- )
- else:
- y = self.eq_scaled_dot_product_attention(
- q,
- k,
- v,
- attn_mask=mask,
- dropout_p=self.dropout if self.training else 0.0,
- )
-
- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
- return self.wo(y)
-
- def eq_scaled_dot_product_attention(
- self,
- query,
- key,
- value,
- attn_mask=None,
- dropout_p=0.0,
- ) -> torch.Tensor:
-        # This is a standard scaled dot product attention implementation
-        # It's less efficient, but it doesn't raise CUDA errors
-
- L, S = query.size(-2), key.size(-2)
- scale_factor = 1 / math.sqrt(query.size(-1))
- attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
-
- if attn_mask is not None:
- if attn_mask.dtype == torch.bool:
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
- else:
- attn_bias += attn_mask
-
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
- attn_weight += attn_bias
- attn_weight = torch.softmax(attn_weight, dim=-1)
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-
- return attn_weight @ value
-
-
-class FeedForward(nn.Module):
- def __init__(self, config: BaseModelArgs) -> None:
- super().__init__()
- self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
- self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
- self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
-
- def forward(self, x: Tensor) -> Tensor:
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
-class RMSNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1e-5):
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(dim))
-
- def _norm(self, x):
- return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
-
- def forward(self, x: Tensor) -> Tensor:
- output = self._norm(x.float()).type_as(x)
- return output * self.weight
-
-
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
- freqs = 1.0 / (
- base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
- )
- t = torch.arange(seq_len, device=freqs.device)
- freqs = torch.outer(t, freqs)
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
- cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
- return cache.to(dtype=torch.bfloat16)
-
-
-def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
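-    # Interpret the last dim as (real, imag) pairs and rotate them by the precomputed (cos, sin) values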
- xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
- freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
- x_out2 = torch.stack(
- [
- xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
- xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
- ],
- -1,
- )
-
- x_out2 = x_out2.flatten(3)
- return x_out2.type_as(x)
diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py
deleted file mode 100644
index 647ca6fcccf038e17d2cf91a2874281dff3e0938..0000000000000000000000000000000000000000
--- a/fish_speech/models/text2semantic/lora.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from dataclasses import dataclass
-
-import loralib as lora
-
-
-@dataclass
-class LoraConfig:
- r: int
- lora_alpha: float
- lora_dropout: float = 0.0
-
-
-def setup_lora(model, lora_config):
- # Replace the embedding layer with a LoRA layer
- model.embeddings = lora.Embedding(
- num_embeddings=model.embeddings.num_embeddings,
- embedding_dim=model.embeddings.embedding_dim,
- padding_idx=model.embeddings.padding_idx,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- )
-
- model.codebook_embeddings = lora.Embedding(
- num_embeddings=model.codebook_embeddings.num_embeddings,
- embedding_dim=model.codebook_embeddings.embedding_dim,
- padding_idx=model.codebook_embeddings.padding_idx,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- )
-
- # Replace output layer with a LoRA layer
- linears = [(model, "output")]
-
- # Replace all linear layers with LoRA layers
- for layer in model.layers:
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
- linears.extend(
- [
- (layer.feed_forward, "w1"),
- (layer.feed_forward, "w2"),
- (layer.feed_forward, "w3"),
- ]
- )
-
- if hasattr(model, "fast_layers"):
- model.fast_embeddings = lora.Embedding(
- num_embeddings=model.fast_embeddings.num_embeddings,
- embedding_dim=model.fast_embeddings.embedding_dim,
- padding_idx=model.fast_embeddings.padding_idx,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- )
-
- # Dual-AR model
- linears.append((model, "fast_output"))
-
- for layer in model.fast_layers:
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
- linears.extend(
- [
- (layer.feed_forward, "w1"),
- (layer.feed_forward, "w2"),
- (layer.feed_forward, "w3"),
- ]
- )
-
- for module, layer in linears:
- updated_linear = lora.Linear(
- in_features=getattr(module, layer).in_features,
- out_features=getattr(module, layer).out_features,
- bias=getattr(module, layer).bias,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- lora_dropout=lora_config.lora_dropout,
- )
- setattr(module, layer, updated_linear)
-
- # Mark only the LoRA layers as trainable
- lora.mark_only_lora_as_trainable(model, bias="none")
-
-
-def get_merged_state_dict(model):
- # Switching to eval mode makes loralib merge the LoRA updates into the base weights
- model.eval()
-
- # Then we need to remove the LoRA parameters from the state dict
- state_dict = model.state_dict()
- for name in list(state_dict.keys()):
- if "lora" in name:
- state_dict.pop(name)
-
- return state_dict
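`setup_lora` above swaps every embedding and linear projection for its loralib counterpart and then freezes everything except the low-rank factors; `get_merged_state_dict` relies on loralib folding those factors into the base weight when the model is put in eval mode. A minimal sketch of that mechanism on a toy single-layer module (assumed names, not the repository's actual model):

```python
import loralib as lora
import torch
from torch import nn

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        # Same replacement setup_lora performs: a lora.Linear in place of nn.Linear.
        self.output = lora.Linear(16, 8, r=4, lora_alpha=16, lora_dropout=0.0)

    def forward(self, x):
        return self.output(x)

model = TinyHead()
lora.mark_only_lora_as_trainable(model, bias="none")

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the lora_A / lora_B factors remain trainable

# eval() merges the scaled low-rank update (B @ A) into the frozen weight
# (merge_weights=True by default), so dropping the "lora" keys afterwards
# yields a plain state dict, mirroring get_merged_state_dict above.
model.eval()
state_dict = {k: v for k, v in model.state_dict().items() if "lora" not in k}
print(list(state_dict.keys()))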
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 7370a6672b015f38616e92542abc71ddeeb7a87e..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc
deleted file mode 100644
index 588bdb4e0f0f6fc5f9838713164c8ed4158b3303..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc
deleted file mode 100644
index 22aab32a8842e848cceb650a0a9274c4402bfddb..0000000000000000000000000000000000000000
Binary files a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
deleted file mode 100644
index aa21839b544174d5d91378c5daf8fe1b376a154a..0000000000000000000000000000000000000000
--- a/fish_speech/models/vqgan/modules/firefly.py
+++ /dev/null
@@ -1,596 +0,0 @@
-import math
-from functools import partial
-from math import prod
-from typing import Callable
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
-from torch.utils.checkpoint import checkpoint
-
-
-def sequence_mask(length, max_length=None):
- if max_length is None:
- max_length = length.max()
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
- return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def init_weights(m, mean=0.0, std=0.01):
- classname = m.__class__.__name__
- if classname.find("Conv1D") != -1:
- m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
- return (kernel_size * dilation - dilation) // 2
-
-
-def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
- """Remove padding from x, handling properly zero padding. Only for 1d!"""
- padding_left, padding_right = paddings
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
- assert (padding_left + padding_right) <= x.shape[-1]
- end = x.shape[-1] - padding_right
- return x[..., padding_left:end]
-
-
-def get_extra_padding_for_conv1d(
- x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
-) -> int:
- """See `pad_for_conv1d`."""
- length = x.shape[-1]
- n_frames = (length - kernel_size + padding_total) / stride + 1
- ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
- return ideal_length - length
-
-
-def pad1d(
- x: torch.Tensor,
- paddings: tuple[int, int],
- mode: str = "zeros",
- value: float = 0.0,
-):
- """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
- If this is the case, we insert extra 0 padding to the right
- before the reflection happen.
- """
- length = x.shape[-1]
- padding_left, padding_right = paddings
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
- if mode == "reflect":
- max_pad = max(padding_left, padding_right)
- extra_pad = 0
- if length <= max_pad:
- extra_pad = max_pad - length + 1
- x = F.pad(x, (0, extra_pad))
- padded = F.pad(x, paddings, mode, value)
- end = padded.shape[-1] - extra_pad
- return padded[..., :end]
- else:
- return F.pad(x, paddings, mode, value)
-
-
-class FishConvNet(nn.Module):
- def __init__(
- self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
- ):
- super(FishConvNet, self).__init__()
- self.conv = nn.Conv1d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- dilation=dilation,
- groups=groups,
- )
- self.stride = stride
- self.kernel_size = (kernel_size - 1) * dilation + 1
- self.dilation = dilation
-
- def forward(self, x):
- pad = self.kernel_size - self.stride
- extra_padding = get_extra_padding_for_conv1d(
- x, self.kernel_size, self.stride, pad
- )
- x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
- return self.conv(x).contiguous()
-
- def weight_norm(self, name="weight", dim=0):
- self.conv = weight_norm(self.conv, name=name, dim=dim)
- return self
-
- def remove_weight_norm(self):
- self.conv = remove_parametrizations(self.conv)
- return self
-
-
-class FishTransConvNet(nn.Module):
- def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
- super(FishTransConvNet, self).__init__()
- self.conv = nn.ConvTranspose1d(
- in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
- )
- self.stride = stride
- self.kernel_size = kernel_size
-
- def forward(self, x):
- x = self.conv(x)
- pad = self.kernel_size - self.stride
- padding_right = math.ceil(pad)
- padding_left = pad - padding_right
- x = unpad1d(x, (padding_left, padding_right))
- return x.contiguous()
-
- def weight_norm(self, name="weight", dim=0):
- self.conv = weight_norm(self.conv, name=name, dim=dim)
- return self
-
- def remove_weight_norm(self):
- self.conv = remove_parametrizations(self.conv)
- return self
-
-
-class ResBlock1(torch.nn.Module):
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
- super().__init__()
-
- self.convs1 = nn.ModuleList(
- [
- FishConvNet(
- channels, channels, kernel_size, stride=1, dilation=dilation[0]
- ).weight_norm(),
- FishConvNet(
- channels, channels, kernel_size, stride=1, dilation=dilation[1]
- ).weight_norm(),
- FishConvNet(
- channels, channels, kernel_size, stride=1, dilation=dilation[2]
- ).weight_norm(),
- ]
- )
- self.convs1.apply(init_weights)
-
- self.convs2 = nn.ModuleList(
- [
- FishConvNet(
- channels, channels, kernel_size, stride=1, dilation=dilation[0]
- ).weight_norm(),
- FishConvNet(
- channels, channels, kernel_size, stride=1, dilation=dilation[1]
- ).weight_norm(),
- FishConvNet(
- channels, channels, kernel_size, stride=1, dilation=dilation[2]
- ).weight_norm(),
- ]
- )
- self.convs2.apply(init_weights)
-
- def forward(self, x):
- for c1, c2 in zip(self.convs1, self.convs2):
- xt = F.silu(x)
- xt = c1(xt)
- xt = F.silu(xt)
- xt = c2(xt)
- x = xt + x
- return x
-
- def remove_parametrizations(self):
- for conv in self.convs1:
- remove_parametrizations(conv, tensor_name="weight")
- for conv in self.convs2:
- remove_parametrizations(conv, tensor_name="weight")
-
-
-class ParallelBlock(nn.Module):
- def __init__(
- self,
- channels: int,
- kernel_sizes: tuple[int] = (3, 7, 11),
- dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
- ):
- super().__init__()
-
- assert len(kernel_sizes) == len(dilation_sizes)
-
- self.blocks = nn.ModuleList()
- for k, d in zip(kernel_sizes, dilation_sizes):
- self.blocks.append(ResBlock1(channels, k, d))
-
- def forward(self, x):
- return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
-
- def remove_parametrizations(self):
- for block in self.blocks:
- block.remove_parametrizations()
-
-
-class HiFiGANGenerator(nn.Module):
- def __init__(
- self,
- *,
- hop_length: int = 512,
- upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
- upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
- resblock_kernel_sizes: tuple[int] = (3, 7, 11),
- resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
- num_mels: int = 128,
- upsample_initial_channel: int = 512,
- pre_conv_kernel_size: int = 7,
- post_conv_kernel_size: int = 7,
- post_activation: Callable = partial(nn.SiLU, inplace=True),
- ):
- super().__init__()
-
- assert (
- prod(upsample_rates) == hop_length
- ), f"hop_length must be {prod(upsample_rates)}"
-
- self.conv_pre = FishConvNet(
- num_mels,
- upsample_initial_channel,
- pre_conv_kernel_size,
- stride=1,
- ).weight_norm()
-
- self.num_upsamples = len(upsample_rates)
- self.num_kernels = len(resblock_kernel_sizes)
-
- self.noise_convs = nn.ModuleList()
- self.ups = nn.ModuleList()
-
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
- self.ups.append(
- FishTransConvNet(
- upsample_initial_channel // (2**i),
- upsample_initial_channel // (2 ** (i + 1)),
- k,
- stride=u,
- ).weight_norm()
- )
-
- self.resblocks = nn.ModuleList()
- for i in range(len(self.ups)):
- ch = upsample_initial_channel // (2 ** (i + 1))
- self.resblocks.append(
- ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
- )
-
- self.activation_post = post_activation()
- self.conv_post = FishConvNet(
- ch, 1, post_conv_kernel_size, stride=1
- ).weight_norm()
- self.ups.apply(init_weights)
- self.conv_post.apply(init_weights)
-
- def forward(self, x):
- x = self.conv_pre(x)
-
- for i in range(self.num_upsamples):
- x = F.silu(x, inplace=True)
- x = self.ups[i](x)
-
- if self.training and self.checkpointing:
- x = checkpoint(
- self.resblocks[i],
- x,
- use_reentrant=False,
- )
- else:
- x = self.resblocks[i](x)
-
- x = self.activation_post(x)
- x = self.conv_post(x)
- x = torch.tanh(x)
-
- return x
-
- def remove_parametrizations(self):
- for up in self.ups:
- remove_parametrizations(up, tensor_name="weight")
- for block in self.resblocks:
- block.remove_parametrizations()
- remove_parametrizations(self.conv_pre, tensor_name="weight")
- remove_parametrizations(self.conv_post, tensor_name="weight")
-
-
-# DropPath copied from timm library
-def drop_path(
- x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
- 'survival rate' as the argument.
-
- """ # noqa: E501
-
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (
- x.ndim - 1
- ) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return x * random_tensor
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
-
- def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
-
- def extra_repr(self):
- return f"drop_prob={round(self.drop_prob,3):0.3f}"
-
-
-class LayerNorm(nn.Module):
- r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
- The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
- shape (batch_size, height, width, channels) while channels_first corresponds to inputs
- with shape (batch_size, channels, height, width).
- """ # noqa: E501
-
- def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
- super().__init__()
- self.weight = nn.Parameter(torch.ones(normalized_shape))
- self.bias = nn.Parameter(torch.zeros(normalized_shape))
- self.eps = eps
- self.data_format = data_format
- if self.data_format not in ["channels_last", "channels_first"]:
- raise NotImplementedError
- self.normalized_shape = (normalized_shape,)
-
- def forward(self, x):
- if self.data_format == "channels_last":
- return F.layer_norm(
- x, self.normalized_shape, self.weight, self.bias, self.eps
- )
- elif self.data_format == "channels_first":
- u = x.mean(1, keepdim=True)
- s = (x - u).pow(2).mean(1, keepdim=True)
- x = (x - u) / torch.sqrt(s + self.eps)
- x = self.weight[:, None] * x + self.bias[:, None]
- return x
-
-
-# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
-class ConvNeXtBlock(nn.Module):
- r"""ConvNeXt Block. There are two equivalent implementations:
- (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
- (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
- We use (2) as we find it slightly faster in PyTorch
-
- Args:
- dim (int): Number of input channels.
- drop_path (float): Stochastic depth rate. Default: 0.0
- layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
- kernel_size (int): Kernel size for depthwise conv. Default: 7.
- dilation (int): Dilation for depthwise conv. Default: 1.
- """ # noqa: E501
-
- def __init__(
- self,
- dim: int,
- drop_path: float = 0.0,
- layer_scale_init_value: float = 1e-6,
- mlp_ratio: float = 4.0,
- kernel_size: int = 7,
- dilation: int = 1,
- ):
- super().__init__()
-
- self.dwconv = FishConvNet(
- dim,
- dim,
- kernel_size=kernel_size,
- # padding=int(dilation * (kernel_size - 1) / 2),
- groups=dim,
- ) # depthwise conv
- self.norm = LayerNorm(dim, eps=1e-6)
- self.pwconv1 = nn.Linear(
- dim, int(mlp_ratio * dim)
- ) # pointwise/1x1 convs, implemented with linear layers
- self.act = nn.GELU()
- self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
- self.gamma = (
- nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
- if layer_scale_init_value > 0
- else None
- )
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- def forward(self, x, apply_residual: bool = True):
- input = x
-
- x = self.dwconv(x)
- x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
- x = self.norm(x)
- x = self.pwconv1(x)
- x = self.act(x)
- x = self.pwconv2(x)
-
- if self.gamma is not None:
- x = self.gamma * x
-
- x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
- x = self.drop_path(x)
-
- if apply_residual:
- x = input + x
-
- return x
-
-
-class ConvNeXtEncoder(nn.Module):
- def __init__(
- self,
- input_channels: int = 3,
- depths: list[int] = [3, 3, 9, 3],
- dims: list[int] = [96, 192, 384, 768],
- drop_path_rate: float = 0.0,
- layer_scale_init_value: float = 1e-6,
- kernel_size: int = 7,
- ):
- super().__init__()
- assert len(depths) == len(dims)
-
- self.downsample_layers = nn.ModuleList()
- stem = nn.Sequential(
- FishConvNet(
- input_channels,
- dims[0],
- kernel_size=7,
- # padding=3,
- # padding_mode="replicate",
- # padding_mode="zeros",
- ),
- LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
- )
- self.downsample_layers.append(stem)
-
- for i in range(len(depths) - 1):
- mid_layer = nn.Sequential(
- LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
- nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
- )
- self.downsample_layers.append(mid_layer)
-
- self.stages = nn.ModuleList()
- dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
-
- cur = 0
- for i in range(len(depths)):
- stage = nn.Sequential(
- *[
- ConvNeXtBlock(
- dim=dims[i],
- drop_path=dp_rates[cur + j],
- layer_scale_init_value=layer_scale_init_value,
- kernel_size=kernel_size,
- )
- for j in range(depths[i])
- ]
- )
- self.stages.append(stage)
- cur += depths[i]
-
- self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv1d, nn.Linear)):
- nn.init.trunc_normal_(m.weight, std=0.02)
- nn.init.constant_(m.bias, 0)
-
- def forward(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- for i in range(len(self.downsample_layers)):
- x = self.downsample_layers[i](x)
- x = self.stages[i](x)
-
- return self.norm(x)
-
-
-class FireflyArchitecture(nn.Module):
- def __init__(
- self,
- backbone: nn.Module,
- head: nn.Module,
- quantizer: nn.Module,
- spec_transform: nn.Module,
- ):
- super().__init__()
-
- self.backbone = backbone
- self.head = head
- self.quantizer = quantizer
- self.spec_transform = spec_transform
- self.downsample_factor = math.prod(self.quantizer.downsample_factor)
-
- def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
- if self.spec_transform is not None:
- x = self.spec_transform(x)
-
- x = self.backbone(x)
- if mask is not None:
- x = x * mask
-
- if self.quantizer is not None:
- vq_result = self.quantizer(x)
- x = vq_result.z
-
- if mask is not None:
- x = x * mask
-
- x = self.head(x, template=template)
-
- if x.ndim == 2:
- x = x[:, None, :]
-
- if self.quantizer is not None:
- return x, vq_result
-
- return x
-
- def encode(self, audios, audio_lengths):
- audios = audios.float()
-
- mels = self.spec_transform(audios)
- mel_lengths = audio_lengths // self.spec_transform.hop_length
- mel_masks = sequence_mask(mel_lengths, mels.shape[2])
- mel_masks_float_conv = mel_masks[:, None, :].float()
- mels = mels * mel_masks_float_conv
-
- # Encode
- encoded_features = self.backbone(mels) * mel_masks_float_conv
- feature_lengths = mel_lengths // self.downsample_factor
-
- return self.quantizer.encode(encoded_features), feature_lengths
-
- def decode(self, indices, feature_lengths) -> torch.Tensor:
- mel_masks = sequence_mask(
- feature_lengths * self.downsample_factor,
- indices.shape[2] * self.downsample_factor,
- )
- mel_masks_float_conv = mel_masks[:, None, :].float()
- audio_lengths = (
- feature_lengths * self.downsample_factor * self.spec_transform.hop_length
- )
-
- audio_masks = sequence_mask(
- audio_lengths,
- indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
- )
- audio_masks_float_conv = audio_masks[:, None, :].float()
-
- z = self.quantizer.decode(indices) * mel_masks_float_conv
- x = self.head(z) * audio_masks_float_conv
-
- return x, audio_lengths
-
- def remove_parametrizations(self):
- if hasattr(self.backbone, "remove_parametrizations"):
- self.backbone.remove_parametrizations()
-
- if hasattr(self.head, "remove_parametrizations"):
- self.head.remove_parametrizations()
-
- @property
- def device(self):
- return next(self.parameters()).device
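`encode` and `decode` above keep three aligned length scales: raw audio samples, mel frames (audio length divided by the STFT hop), and quantizer frames (mel frames divided by the product of the FSQ downsample factors). A small arithmetic sketch with illustrative defaults taken from elsewhere in this diff (hop_length 512, downsample_factor (2, 2)):

```python
hop_length = 512          # spec_transform hop (HiFiGANGenerator default above)
downsample_factor = 4     # prod((2, 2)) from DownsampleFiniteScalarQuantize below

audio_len = 3 * 44100     # ~3 s of audio samples
mel_len = audio_len // hop_length            # frames seen by the backbone
feature_len = mel_len // downsample_factor   # frames emitted by the quantizer

# decode() inverts the same chain to size its masks and the output waveform:
decoded_audio_len = feature_len * downsample_factor * hop_length
print(mel_len, feature_len, decoded_audio_len)  # 258 64 131072
```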
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
deleted file mode 100644
index 7ea4853376b6e663404ff48d6c6b5f664dde4094..0000000000000000000000000000000000000000
--- a/fish_speech/models/vqgan/modules/fsq.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from dataclasses import dataclass
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from vector_quantize_pytorch import GroupedResidualFSQ
-
-from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
-
-
-@dataclass
-class FSQResult:
- z: torch.Tensor
- codes: torch.Tensor
- latents: torch.Tensor
-
-
-class DownsampleFiniteScalarQuantize(nn.Module):
- def __init__(
- self,
- input_dim: int = 512,
- n_codebooks: int = 9,
- n_groups: int = 1,
- levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
- downsample_factor: tuple[int] = (2, 2),
- downsample_dims: tuple[int] | None = None,
- ):
- super().__init__()
-
- if downsample_dims is None:
- downsample_dims = [input_dim for _ in range(len(downsample_factor))]
-
- all_dims = (input_dim,) + tuple(downsample_dims)
-
- self.residual_fsq = GroupedResidualFSQ(
- dim=all_dims[-1],
- levels=levels,
- num_quantizers=n_codebooks,
- groups=n_groups,
- )
-
- self.downsample_factor = downsample_factor
- self.downsample_dims = downsample_dims
-
- self.downsample = nn.Sequential(
- *[
- nn.Sequential(
- FishConvNet(
- all_dims[idx],
- all_dims[idx + 1],
- kernel_size=factor,
- stride=factor,
- ),
- ConvNeXtBlock(dim=all_dims[idx + 1]),
- )
- for idx, factor in enumerate(downsample_factor)
- ]
- )
-
- self.upsample = nn.Sequential(
- *[
- nn.Sequential(
- FishTransConvNet(
- all_dims[idx + 1],
- all_dims[idx],
- kernel_size=factor,
- stride=factor,
- ),
- ConvNeXtBlock(dim=all_dims[idx]),
- )
- for idx, factor in reversed(list(enumerate(downsample_factor)))
- ]
- )
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv1d, nn.Linear)):
- nn.init.trunc_normal_(m.weight, std=0.02)
- nn.init.constant_(m.bias, 0)
-
- def forward(self, z) -> FSQResult:
- original_shape = z.shape
- z = self.downsample(z)
- quantized, indices = self.residual_fsq(z.mT)
- result = FSQResult(
- z=quantized.mT,
- codes=indices.mT,
- latents=z,
- )
- result.z = self.upsample(result.z)
-
- # Pad or crop z to match original shape
- diff = original_shape[-1] - result.z.shape[-1]
- left = diff // 2
- right = diff - left
-
- if diff > 0:
- result.z = F.pad(result.z, (left, right))
- elif diff < 0:
- result.z = result.z[..., left:-right]
-
- return result
-
- def encode(self, z):
- z = self.downsample(z)
- _, indices = self.residual_fsq(z.mT)
- indices = rearrange(indices, "g b l r -> b (g r) l")
- return indices
-
- def decode(self, indices: torch.Tensor):
- indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
- z_q = self.residual_fsq.get_output_from_indices(indices)
- z_q = self.upsample(z_q.mT)
- return z_q
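Because the strided downsample/upsample pair in `DownsampleFiniteScalarQuantize` does not always reproduce the exact input length, `forward` pads or center-crops the reconstruction back to the original number of frames. A standalone helper that captures that pad-or-crop intent (written fresh here, not imported from the deleted module):

```python
import torch
import torch.nn.functional as F

def match_length(z: torch.Tensor, target_len: int) -> torch.Tensor:
    """Pad or center-crop the last dimension of z to target_len."""
    diff = target_len - z.shape[-1]
    if diff > 0:                      # reconstruction too short -> pad both sides
        left = diff // 2
        return F.pad(z, (left, diff - left))
    if diff < 0:                      # reconstruction too long -> crop the center
        left = (-diff) // 2
        return z[..., left : left + target_len]
    return z

z = torch.randn(1, 512, 101)
print(match_length(z, 104).shape, match_length(z, 96).shape)
# torch.Size([1, 512, 104]) torch.Size([1, 512, 96])
```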
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
deleted file mode 100644
index b90c131d214006875476a161cdfd2dffa8949dac..0000000000000000000000000000000000000000
--- a/fish_speech/models/vqgan/utils.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import matplotlib
-import torch
-from matplotlib import pyplot as plt
-
-matplotlib.use("Agg")
-
-
-def convert_pad_shape(pad_shape):
- l = pad_shape[::-1]
- pad_shape = [item for sublist in l for item in sublist]
- return pad_shape
-
-
-def sequence_mask(length, max_length=None):
- if max_length is None:
- max_length = length.max()
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
- return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def init_weights(m, mean=0.0, std=0.01):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
- return int((kernel_size * dilation - dilation) / 2)
-
-
-def plot_mel(data, titles=None):
- fig, axes = plt.subplots(len(data), 1, squeeze=False)
-
- if titles is None:
- titles = [None for i in range(len(data))]
-
- plt.tight_layout()
-
- for i in range(len(data)):
- mel = data[i]
-
- if isinstance(mel, torch.Tensor):
- mel = mel.float().detach().cpu().numpy()
-
- axes[i][0].imshow(mel, origin="lower")
- axes[i][0].set_aspect(2.5, adjustable="box")
- axes[i][0].set_ylim(0, mel.shape[0])
- axes[i][0].set_title(titles[i], fontsize="medium")
- axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
- axes[i][0].set_anchor("W")
-
- return fig
-
-
-def slice_segments(x, ids_str, segment_size=4):
- ret = torch.zeros_like(x[:, :, :segment_size])
- for i in range(x.size(0)):
- idx_str = ids_str[i]
- idx_end = idx_str + segment_size
- ret[i] = x[i, :, idx_str:idx_end]
-
- return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
- b, d, t = x.size()
- if x_lengths is None:
- x_lengths = t
- ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
- ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
- ret = slice_segments(x, ids_str, segment_size)
- return ret, ids_str
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
- n_channels_int = n_channels[0]
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
- acts = t_act * s_act
-
- return acts
-
-
-def avg_with_mask(x, mask):
- assert mask.dtype == torch.float, "Mask should be float"
-
- if mask.ndim == 2:
- mask = mask.unsqueeze(1)
-
- if mask.shape[1] == 1:
- mask = mask.expand_as(x)
-
- return (x * mask).sum() / mask.sum()
diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py
deleted file mode 100644
index 43bed6a2210723a7d5e1ea0a48ba61140047ca29..0000000000000000000000000000000000000000
--- a/fish_speech/scheduler.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import math
-
-
-def get_cosine_schedule_with_warmup_lr_lambda(
- current_step: int,
- *,
- num_warmup_steps: int | float,
- num_training_steps: int,
- num_cycles: float = 0.5,
- final_lr_ratio: float = 0.0,
-):
- if 0 < num_warmup_steps < 1: # float mode
- num_warmup_steps = int(num_warmup_steps * num_training_steps)
-
- if current_step < num_warmup_steps:
- return float(current_step) / float(max(1, num_warmup_steps))
-
- progress = float(current_step - num_warmup_steps) / float(
- max(1, num_training_steps - num_warmup_steps)
- )
-
- return max(
- final_lr_ratio,
- 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
- )
-
-
-def get_constant_schedule_with_warmup_lr_lambda(
- current_step: int,
- *,
- num_warmup_steps: int | float,
- num_training_steps: int | None = None,
-):
- if 0 < num_warmup_steps < 1: # float mode
- num_warmup_steps = int(num_warmup_steps * num_training_steps)
-
- if current_step < num_warmup_steps:
- return float(current_step) / float(max(1, num_warmup_steps))
-
- return 1.0
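Both lambdas above are meant to be wrapped in `torch.optim.lr_scheduler.LambdaLR`, with `functools.partial` supplying the schedule hyper-parameters. A minimal usage sketch; the import assumes the pre-deletion `fish_speech.scheduler` module is still available:

```python
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

# Assumes the module above is importable; this commit removes it.
from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.1,       # float mode: 10% of the training steps
    num_training_steps=1000,
    final_lr_ratio=0.05,
)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```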
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
deleted file mode 100644
index d740bd8eed447d162e55b165965dec17130377ce..0000000000000000000000000000000000000000
--- a/fish_speech/text/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .clean import clean_text
-from .spliter import split_text
-
-__all__ = ["clean_text", "split_text"]
diff --git a/fish_speech/text/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index cbda0e48251bdcc53c332c821e4ea9519047d490..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/__pycache__/clean.cpython-310.pyc b/fish_speech/text/__pycache__/clean.cpython-310.pyc
deleted file mode 100644
index f8c648bb945e8d4ff16146dff98a70a779dab7eb..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/__pycache__/clean.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/__pycache__/spliter.cpython-310.pyc b/fish_speech/text/__pycache__/spliter.cpython-310.pyc
deleted file mode 100644
index 94114179529badf1ecd0fa37c19d5fdc6223dcf9..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/__pycache__/spliter.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore
deleted file mode 100644
index 75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/.gitignore
+++ /dev/null
@@ -1,114 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# pyenv
-.python-version
-
-# celery beat schedule file
-celerybeat-schedule
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-
-# JetBrains PyCharm
-.idea
-
-# Customize
-references
-url.txt
-
-# Git
-.git
diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md
deleted file mode 100644
index 8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/README.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
-
-# Chn Text Norm
-
-This is a repository for Chinese text normalization (no longer maintained).
-
-## Quick Start ##
-
-### Git Clone Repo ###
-
-Git clone this repo into the root directory of the project that needs to use it.
-
- cd /path/to/proj
- git clone https://github.com/Joee1995/chn-text-norm.git
-
-After that, your directory tree should be:
-```
-proj # root of your project
-|--- chn_text_norm # this chn-text-norm tool
- |--- text.py
- |--- ...
-|--- text_normalize.py # your text normalization code
-|--- ...
-```
-
-### How to Use? ###
-
- # text_normalize.py
- from chn_text_norm.text import *
-
- raw_text = 'your raw text'
- text = Text(raw_text=raw_text).normalize()
-
-### How to add quantifiers ###
-
-Open test.py and you will see how it is done.
diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 34ff30c1d86436d172d82a2afe4f2914407b2056..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc
deleted file mode 100644
index cf1f70fdddf166b1e08a881f44651789cde5665b..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc
deleted file mode 100644
index 7ba0d65d52f907c05d8672ab5ffeba1ef69b0a58..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc
deleted file mode 100644
index 565e2baec31a08a1d40e473dbaf8cc068c4b56eb..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc
deleted file mode 100644
index 5ac369bcff904eeb22fd4c359b7ef4d0dff2856b..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc
deleted file mode 100644
index cb55a304422a219e00687fc987b4cdcfd8283dcd..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc
deleted file mode 100644
index 57bcee45c211f05e127b6556c6c3e0dc05a43e9c..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc
deleted file mode 100644
index 2982394aae51d7fef63f6ce4c13444662ccde1af..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc
deleted file mode 100644
index 5cdaa0642dce2713d356ae35a4dea9955df67e9c..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc
deleted file mode 100644
index 1572f267a79a5231149fc5bba14cb4b4d4907895..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc
deleted file mode 100644
index b088af763cadeb63f0ee4308c56032c19da3ed1f..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc
deleted file mode 100644
index f84f49e3880c7410255e8ab1038eeb94d5375656..0000000000000000000000000000000000000000
Binary files a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py
deleted file mode 100644
index 58d8f8eb7fc85d0861f106667d8f4e3e52b54761..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/basic_class.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# -*- coding: utf-8 -*-
-"""基本类
-中文字符类
-中文数字/数位类
-中文数字类
-中文数位类
-中文数字系统类
-中文数学符号类
-*中文其他符号类
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-02"
-
-from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
-
-
-class ChineseChar(object):
- """
- Chinese character.
- Each character has a simplified and a traditional form,
- e.g. simplified = '负', traditional = '負'.
- It can be converted to either the simplified or the traditional form.
- """
-
- def __init__(self, simplified, traditional):
- self.simplified = simplified
- self.traditional = traditional
- self.__repr__ = self.__str__
-
- def __str__(self):
- return self.simplified or self.traditional or None
-
- def __repr__(self):
- return self.__str__()
-
-
-class ChineseNumberUnit(ChineseChar):
- """
- Chinese number/unit character.
- Besides the simplified and traditional forms, each character also has an uppercase (banker's) form,
- e.g. '陆' and '陸'.
- """
-
- def __init__(self, power, simplified, traditional, big_s, big_t):
- super(ChineseNumberUnit, self).__init__(simplified, traditional)
- self.power = power
- self.big_s = big_s
- self.big_t = big_t
-
- def __str__(self):
- return "10^{}".format(self.power)
-
- @classmethod
- def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
-
- if small_unit:
- return ChineseNumberUnit(
- power=index + 1,
- simplified=value[0],
- traditional=value[1],
- big_s=value[1],
- big_t=value[1],
- )
- elif numbering_type == NUMBERING_TYPES[0]:
- return ChineseNumberUnit(
- power=index + 8,
- simplified=value[0],
- traditional=value[1],
- big_s=value[0],
- big_t=value[1],
- )
- elif numbering_type == NUMBERING_TYPES[1]:
- return ChineseNumberUnit(
- power=(index + 2) * 4,
- simplified=value[0],
- traditional=value[1],
- big_s=value[0],
- big_t=value[1],
- )
- elif numbering_type == NUMBERING_TYPES[2]:
- return ChineseNumberUnit(
- power=pow(2, index + 3),
- simplified=value[0],
- traditional=value[1],
- big_s=value[0],
- big_t=value[1],
- )
- else:
- raise ValueError(
- "Counting type should be in {0} ({1} provided).".format(
- NUMBERING_TYPES, numbering_type
- )
- )
-
-
-class ChineseNumberDigit(ChineseChar):
- """
- Chinese digit character.
- """
-
- def __init__(
- self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
- ):
- super(ChineseNumberDigit, self).__init__(simplified, traditional)
- self.value = value
- self.big_s = big_s
- self.big_t = big_t
- self.alt_s = alt_s
- self.alt_t = alt_t
-
- def __str__(self):
- return str(self.value)
-
- @classmethod
- def create(cls, i, v):
- return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
-
-
-class ChineseMath(ChineseChar):
- """
- Chinese math symbol character.
- """
-
- def __init__(self, simplified, traditional, symbol, expression=None):
- super(ChineseMath, self).__init__(simplified, traditional)
- self.symbol = symbol
- self.expression = expression
- self.big_s = simplified
- self.big_t = traditional
-
-
-CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
-
-
-class NumberSystem(object):
- """
- Chinese number system.
- """
-
- pass
-
-
-class MathSymbol(object):
- """
- Math symbols (simplified/traditional) used by the Chinese number system, e.g.
- positive = ['正', '正']
- negative = ['负', '負']
- point = ['点', '點']
- """
-
- def __init__(self, positive, negative, point):
- self.positive = positive
- self.negative = negative
- self.point = point
-
- def __iter__(self):
- for v in self.__dict__.values():
- yield v
-
-
-# class OtherSymbol(object):
-# """
-# Other symbols
-# """
-#
-# def __init__(self, sil):
-# self.sil = sil
-#
-# def __iter__(self):
-# for v in self.__dict__.values():
-# yield v
diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py
deleted file mode 100644
index 9a65991b9a9d349a0571c80508633951e52749ef..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/basic_constant.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# -*- coding: utf-8 -*-
-"""基本常量
-中文数字/数位/符号字符常量
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-02"
-
-CHINESE_DIGIS = "零一二三四五六七八九"
-BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
-BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
-SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
-SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
-LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
-LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
-SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
-SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
-
-ZERO_ALT = "〇"
-ONE_ALT = "幺"
-TWO_ALTS = ["两", "兩"]
-
-POSITIVE = ["正", "正"]
-NEGATIVE = ["负", "負"]
-POINT = ["点", "點"]
-# PLUS = [u'加', u'加']
-# SIL = [u'杠', u'槓']
-
-# Chinese numbering system types
-NUMBERING_TYPES = ["low", "mid", "high"]
diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py
deleted file mode 100644
index dbf6130be87f285eed9998186508ea489d3bac9e..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/basic_util.py
+++ /dev/null
@@ -1,342 +0,0 @@
-# -*- coding: utf-8 -*-
-"""基本方法
-创建中文数字系统 方法
-中文字符串 <=> 数字串 方法
-数字串 <=> 中文字符串 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-02"
-
-from fish_speech.text.chn_text_norm.basic_class import *
-from fish_speech.text.chn_text_norm.basic_constant import *
-
-
-def create_system(numbering_type=NUMBERING_TYPES[1]):
- """
- Create and return the number system matching the given numbering type (default: mid).
- NUMBERING_TYPES = ['low', 'mid', 'high']: Chinese numbering system types
- low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
- mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
- high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
- Returns the corresponding number system.
- """
-
- # chinese number units of '亿' and larger
- all_larger_units = zip(
- LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
- LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
- )
- larger_units = [
- CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
- ]
- # chinese number units of '十, 百, 千, 万'
- all_smaller_units = zip(
- SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
- SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
- )
- smaller_units = [
- CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
- ]
- # digits
- chinese_digis = zip(
- CHINESE_DIGIS,
- CHINESE_DIGIS,
- BIG_CHINESE_DIGIS_SIMPLIFIED,
- BIG_CHINESE_DIGIS_TRADITIONAL,
- )
- digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
- digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
- digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
- digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
-
- # symbols
- positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
- negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
- point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
- # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
- system = NumberSystem()
- system.units = smaller_units + larger_units
- system.digits = digits
- system.math = MathSymbol(positive_cn, negative_cn, point_cn)
- # system.symbols = OtherSymbol(sil_cn)
- return system
-
-
-def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
-
- def get_symbol(char, system):
- for u in system.units:
- if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
- return u
- for d in system.digits:
- if char in [
- d.traditional,
- d.simplified,
- d.big_s,
- d.big_t,
- d.alt_s,
- d.alt_t,
- ]:
- return d
- for m in system.math:
- if char in [m.traditional, m.simplified]:
- return m
-
- def string2symbols(chinese_string, system):
- int_string, dec_string = chinese_string, ""
- for p in [system.math.point.simplified, system.math.point.traditional]:
- if p in chinese_string:
- int_string, dec_string = chinese_string.split(p)
- break
- return [get_symbol(c, system) for c in int_string], [
- get_symbol(c, system) for c in dec_string
- ]
-
- def correct_symbols(integer_symbols, system):
- """
- 一百八 to 一百八十 (restore the implied trailing unit)
- 一亿一千三百万 to 一亿 一千万 三百万 (expand each digit-unit group)
- """
-
- if integer_symbols and isinstance(integer_symbols[0], CNU):
- if integer_symbols[0].power == 1:
- integer_symbols = [system.digits[1]] + integer_symbols
-
- if len(integer_symbols) > 1:
- if isinstance(integer_symbols[-1], CND) and isinstance(
- integer_symbols[-2], CNU
- ):
- integer_symbols.append(
- CNU(integer_symbols[-2].power - 1, None, None, None, None)
- )
-
- result = []
- unit_count = 0
- for s in integer_symbols:
- if isinstance(s, CND):
- result.append(s)
- unit_count = 0
- elif isinstance(s, CNU):
- current_unit = CNU(s.power, None, None, None, None)
- unit_count += 1
-
- if unit_count == 1:
- result.append(current_unit)
- elif unit_count > 1:
- for i in range(len(result)):
- if (
- isinstance(result[-i - 1], CNU)
- and result[-i - 1].power < current_unit.power
- ):
- result[-i - 1] = CNU(
- result[-i - 1].power + current_unit.power,
- None,
- None,
- None,
- None,
- )
- return result
-
- def compute_value(integer_symbols):
- """
- Compute the value.
- When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
- e.g. '两千万' = 2000 * 10000 not 2000 + 10000
- """
- value = [0]
- last_power = 0
- for s in integer_symbols:
- if isinstance(s, CND):
- value[-1] = s.value
- elif isinstance(s, CNU):
- value[-1] *= pow(10, s.power)
- if s.power > last_power:
- value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
- last_power = s.power
- value.append(0)
- return sum(value)
-
- system = create_system(numbering_type)
- int_part, dec_part = string2symbols(chinese_string, system)
- int_part = correct_symbols(int_part, system)
- int_str = str(compute_value(int_part))
- dec_str = "".join([str(d.value) for d in dec_part])
- if dec_part:
- return "{0}.{1}".format(int_str, dec_str)
- else:
- return int_str
-
-
-def num2chn(
- number_string,
- numbering_type=NUMBERING_TYPES[1],
- big=False,
- traditional=False,
- alt_zero=False,
- alt_one=False,
- alt_two=True,
- use_zeros=True,
- use_units=True,
-):
-
- def get_value(value_string, use_zeros=True):
-
- striped_string = value_string.lstrip("0")
-
- # record nothing if all zeros
- if not striped_string:
- return []
-
- # record a single digit
- elif len(striped_string) == 1:
- if use_zeros and len(value_string) != len(striped_string):
- return [system.digits[0], system.digits[int(striped_string)]]
- else:
- return [system.digits[int(striped_string)]]
-
- # recursively record multiple digits
- else:
- result_unit = next(
- u for u in reversed(system.units) if u.power < len(striped_string)
- )
- result_string = value_string[: -result_unit.power]
- return (
- get_value(result_string)
- + [result_unit]
- + get_value(striped_string[-result_unit.power :])
- )
-
- system = create_system(numbering_type)
-
- int_dec = number_string.split(".")
- if len(int_dec) == 1:
- int_string = int_dec[0]
- dec_string = ""
- elif len(int_dec) == 2:
- int_string = int_dec[0]
- dec_string = int_dec[1]
- else:
- raise ValueError(
- "invalid input num string with more than one dot: {}".format(number_string)
- )
-
- if use_units and len(int_string) > 1:
- result_symbols = get_value(int_string)
- else:
- result_symbols = [system.digits[int(c)] for c in int_string]
- dec_symbols = [system.digits[int(c)] for c in dec_string]
- if dec_string:
- result_symbols += [system.math.point] + dec_symbols
-
- if alt_two:
- liang = CND(
- 2,
- system.digits[2].alt_s,
- system.digits[2].alt_t,
- system.digits[2].big_s,
- system.digits[2].big_t,
- )
- for i, v in enumerate(result_symbols):
- if isinstance(v, CND) and v.value == 2:
- next_symbol = (
- result_symbols[i + 1] if i < len(result_symbols) - 1 else None
- )
- previous_symbol = result_symbols[i - 1] if i > 0 else None
- if isinstance(next_symbol, CNU) and isinstance(
- previous_symbol, (CNU, type(None))
- ):
- if next_symbol.power != 1 and (
- (previous_symbol is None) or (previous_symbol.power != 1)
- ):
- result_symbols[i] = liang
-
- # if big is True, '两' will not be used and `alt_two` has no impact on output
- if big:
- attr_name = "big_"
- if traditional:
- attr_name += "t"
- else:
- attr_name += "s"
- else:
- if traditional:
- attr_name = "traditional"
- else:
- attr_name = "simplified"
-
- result = "".join([getattr(s, attr_name) for s in result_symbols])
-
- # if not use_zeros:
- # result = result.strip(getattr(system.digits[0], attr_name))
-
- if alt_zero:
- result = result.replace(
- getattr(system.digits[0], attr_name), system.digits[0].alt_s
- )
-
- if alt_one:
- result = result.replace(
- getattr(system.digits[1], attr_name), system.digits[1].alt_s
- )
-
- for i, p in enumerate(POINT):
- if result.startswith(p):
- return CHINESE_DIGIS[0] + result
-
- # ^10, 11, .., 19
- if (
- len(result) >= 2
- and result[1]
- in [
- SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
- SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
- ]
- and result[0]
- in [
- CHINESE_DIGIS[1],
- BIG_CHINESE_DIGIS_SIMPLIFIED[1],
- BIG_CHINESE_DIGIS_TRADITIONAL[1],
- ]
- ):
- result = result[1:]
-
- return result
-
-
-if __name__ == "__main__":
-
- # Test program
- all_chinese_number_string = (
- CHINESE_DIGIS
- + BIG_CHINESE_DIGIS_SIMPLIFIED
- + BIG_CHINESE_DIGIS_TRADITIONAL
- + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
- + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
- + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
- + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
- + ZERO_ALT
- + ONE_ALT
- + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
- )
-
- print("num:", chn2num("一万零四百零三点八零五"))
- print("num:", chn2num("一亿六点三"))
- print("num:", chn2num("一亿零六点三"))
- print("num:", chn2num("两千零一亿六点三"))
- # print('num:', chn2num('一零零八六'))
- print("txt:", num2chn("10260.03", alt_zero=True))
- print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
- print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
- print(
- "txt:",
- num2chn(
- "059523810880",
- alt_one=True,
- alt_two=False,
- use_zeros=True,
- use_units=False,
- ),
- )
-
- print(all_chinese_number_string)
diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py
deleted file mode 100644
index ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/cardinal.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- coding: utf-8 -*-
-"""CARDINAL类 (包含小数DECIMAL类)
-纯数 <=> 中文字符串 方法
-中文字符串 <=> 纯数 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-03"
-
-from fish_speech.text.chn_text_norm.basic_util import *
-
-
-class Cardinal:
- """
- CARDINAL class.
- """
-
- def __init__(self, cardinal=None, chntext=None):
- self.cardinal = cardinal
- self.chntext = chntext
-
- def chntext2cardinal(self):
- return chn2num(self.chntext)
-
- def cardinal2chntext(self):
- return num2chn(self.cardinal)
-
-
-if __name__ == "__main__":
-
- # Test program
- print(Cardinal(cardinal="21357.230").cardinal2chntext())
diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py
deleted file mode 100644
index 77acfdb9a91df0fe3c615a0784f61aad87fbe56e..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/date.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# -*- coding: utf-8 -*-
-"""DATE类
-日期 <=> 中文字符串 方法
-中文字符串 <=> 日期 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-07"
-
-from fish_speech.text.chn_text_norm.cardinal import Cardinal
-from fish_speech.text.chn_text_norm.digit import Digit
-
-
-class Date:
- """
- DATE class.
- """
-
- def __init__(self, date=None, chntext=None):
- self.date = date
- self.chntext = chntext
-
- # def chntext2date(self):
- # chntext = self.chntext
- # try:
- # year, other = chntext.strip().split('年', maxsplit=1)
- # year = Digit(chntext=year).digit2chntext() + '年'
- # except ValueError:
- # other = chntext
- # year = ''
- # if other:
- # try:
- # month, day = other.strip().split('月', maxsplit=1)
- # month = Cardinal(chntext=month).chntext2cardinal() + '月'
- # except ValueError:
- # day = chntext
- # month = ''
- # if day:
- # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
- # else:
- # month = ''
- # day = ''
- # date = year + month + day
- # self.date = date
- # return self.date
-
- def date2chntext(self):
- date = self.date
- try:
- year, other = date.strip().split("年", maxsplit=1)
- year = Digit(digit=year).digit2chntext() + "年"
- except ValueError:
- other = date
- year = ""
- if other:
- try:
- month, day = other.strip().split("月", maxsplit=1)
- month = Cardinal(cardinal=month).cardinal2chntext() + "月"
- except ValueError:
- day = date
- month = ""
- if day:
- day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
- else:
- month = ""
- day = ""
- chntext = year + month + day
- self.chntext = chntext
- return self.chntext
-
-
-if __name__ == "__main__":
-
- # Test
- print(Date(date="09年3月16日").date2chntext())
diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py
deleted file mode 100644
index 47c0cd4ad0c700635f84470bfdacfbdafb4a6185..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/digit.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- coding: utf-8 -*-
-"""DIGIT类
-数字串 <=> 中文字符串 方法
-中文字符串 <=> 数字串 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-03"
-
-from fish_speech.text.chn_text_norm.basic_util import *
-
-
-class Digit:
- """
- DIGIT class.
- """
-
- def __init__(self, digit=None, chntext=None):
- self.digit = digit
- self.chntext = chntext
-
- # def chntext2digit(self):
- # return chn2num(self.chntext)
-
- def digit2chntext(self):
- return num2chn(self.digit, alt_two=False, use_units=False)
-
-
-if __name__ == "__main__":
-
- # Test program
- print(Digit(digit="2016").digit2chntext())
diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py
deleted file mode 100644
index b43b6a7feb634d346d59a2b4ab84b77ac88df103..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/fraction.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# -*- coding: utf-8 -*-
-"""FRACTION类
-分数 <=> 中文字符串 方法
-中文字符串 <=> 分数 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-03"
-
-from fish_speech.text.chn_text_norm.basic_util import *
-
-
-class Fraction:
- """
- FRACTION class.
- """
-
- def __init__(self, fraction=None, chntext=None):
- self.fraction = fraction
- self.chntext = chntext
-
- def chntext2fraction(self):
- denominator, numerator = self.chntext.split("分之")
- return chn2num(numerator) + "/" + chn2num(denominator)
-
- def fraction2chntext(self):
- numerator, denominator = self.fraction.split("/")
- return num2chn(denominator) + "分之" + num2chn(numerator)
-
-
-if __name__ == "__main__":
-
- # Test program
- print(Fraction(fraction="2135/7230").fraction2chntext())
- print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py
deleted file mode 100644
index b4c980d32134e1460e96e5bcbcc73d0d55974d2a..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/money.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# -*- coding: utf-8 -*-
-"""MONEY类
-金钱 <=> 中文字符串 方法
-中文字符串 <=> 金钱 方法
-"""
-import re
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-08"
-
-from fish_speech.text.chn_text_norm.cardinal import Cardinal
-
-
-class Money:
- """
- MONEY class.
- """
-
- def __init__(self, money=None, chntext=None):
- self.money = money
- self.chntext = chntext
-
- # def chntext2money(self):
- # return self.money
-
- def money2chntext(self):
- money = self.money
- pattern = re.compile(r"(\d+(\.\d+)?)")
- matchers = pattern.findall(money)
- if matchers:
- for matcher in matchers:
- money = money.replace(
- matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
- )
- self.chntext = money
- return self.chntext
-
-
-if __name__ == "__main__":
-
- # Test
- print(Money(money="21.5万元").money2chntext())
- print(Money(money="230块5毛").money2chntext())
diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py
deleted file mode 100644
index 46abbf545af62eb951d8f6fe40bcf684587f81b0..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/percentage.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# -*- coding: utf-8 -*-
-"""PERCENTAGE类
-百分数 <=> 中文字符串 方法
-中文字符串 <=> 百分数 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-06"
-
-from fish_speech.text.chn_text_norm.basic_util import *
-
-
-class Percentage:
- """
- PERCENTAGE class.
- """
-
- def __init__(self, percentage=None, chntext=None):
- self.percentage = percentage
- self.chntext = chntext
-
- def chntext2percentage(self):
- return chn2num(self.chntext.strip().strip("百分之")) + "%"
-
- def percentage2chntext(self):
- return "百分之" + num2chn(self.percentage.strip().strip("%"))
-
-
-if __name__ == "__main__":
-
-    # Test program
- print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
- print(Percentage(percentage="65.3%").percentage2chntext())
diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py
deleted file mode 100644
index e72b546db628a3b807dc6235b59b188cae3153ff..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/telephone.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# -*- coding: utf-8 -*-
-"""TELEPHONE类
-电话号码 <=> 中文字符串 方法
-中文字符串 <=> 电话号码 方法
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-03"
-
-from fish_speech.text.chn_text_norm.basic_util import *
-
-
-class TelePhone:
- """
-    TELEPHONE class
- """
-
- def __init__(self, telephone=None, raw_chntext=None, chntext=None):
- self.telephone = telephone
- self.raw_chntext = raw_chntext
- self.chntext = chntext
-
- # def chntext2telephone(self):
- # sil_parts = self.raw_chntext.split('')
- # self.telephone = '-'.join([
- # str(chn2num(p)) for p in sil_parts
- # ])
- # return self.telephone
-
- def telephone2chntext(self, fixed=False):
-
- if fixed:
- sil_parts = self.telephone.split("-")
- self.raw_chntext = "".join(
- [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
- )
- self.chntext = self.raw_chntext.replace("", "")
- else:
- sp_parts = self.telephone.strip("+").split()
- self.raw_chntext = "".join(
- [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
- )
- self.chntext = self.raw_chntext.replace("", "")
- return self.chntext
-
-
-if __name__ == "__main__":
-
-    # Test program
- print(TelePhone(telephone="0595-23980880").telephone2chntext())
- # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py
deleted file mode 100644
index 54086fd933c01e14c3c55cee9adb52eefb58fd31..0000000000000000000000000000000000000000
--- a/fish_speech/text/chn_text_norm/text.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-TEXT class
-"""
-
-__author__ = "Zhiyang Zhou "
-__data__ = "2019-05-03"
-
-import re
-
-from fish_speech.text.chn_text_norm.cardinal import Cardinal
-from fish_speech.text.chn_text_norm.date import Date
-from fish_speech.text.chn_text_norm.digit import Digit
-from fish_speech.text.chn_text_norm.fraction import Fraction
-from fish_speech.text.chn_text_norm.money import Money
-from fish_speech.text.chn_text_norm.percentage import Percentage
-from fish_speech.text.chn_text_norm.telephone import TelePhone
-
-CURRENCY_NAMES = (
- "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
- "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
-)
-CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
-COM_QUANTIFIERS = (
- "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
- "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
- "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
- "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
- "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
- "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
-)
-
-
-class Text:
- """
-    Text class
- """
-
- def __init__(self, raw_text, norm_text=None):
- self.raw_text = "^" + raw_text + "$"
- self.norm_text = norm_text
-
- def _particular(self):
- text = self.norm_text
- pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
- matchers = pattern.findall(text)
- if matchers:
- # print('particular')
- for matcher in matchers:
- text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
- self.norm_text = text
- return self.norm_text
-
- def normalize(self):
- text = self.raw_text
-
-        # Normalize dates
- pattern = re.compile(
- r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
- )
- matchers = pattern.findall(text)
- if matchers:
- # print('date')
- for matcher in matchers:
- text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
-
-        # Normalize money amounts
- pattern = re.compile(
- r"\D+((\d+(\.\d+)?)[多余几]?"
- + CURRENCY_UNITS
- + "(\d"
- + CURRENCY_UNITS
- + "?)?)"
- )
- matchers = pattern.findall(text)
- if matchers:
- # print('money')
- for matcher in matchers:
- text = text.replace(
- matcher[0], Money(money=matcher[0]).money2chntext(), 1
- )
-
-        # Normalize landline / mobile phone numbers
-        # Mobile numbers
-        # http://www.jihaoba.com/news/show/13680
-        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
-        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
-        # China Telecom: 133, 153, 189, 180, 181, 177
- pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
- matchers = pattern.findall(text)
- if matchers:
- # print('telephone')
- for matcher in matchers:
- text = text.replace(
- matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
- )
-        # Landline numbers
- pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
- matchers = pattern.findall(text)
- if matchers:
- # print('fixed telephone')
- for matcher in matchers:
- text = text.replace(
- matcher[0],
- TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
- 1,
- )
-
-        # Normalize fractions
- pattern = re.compile(r"(\d+/\d+)")
- matchers = pattern.findall(text)
- if matchers:
- # print('fraction')
- for matcher in matchers:
- text = text.replace(
- matcher, Fraction(fraction=matcher).fraction2chntext(), 1
- )
-
-        # Normalize percentages
- text = text.replace("%", "%")
- pattern = re.compile(r"(\d+(\.\d+)?%)")
- matchers = pattern.findall(text)
- if matchers:
- # print('percentage')
- for matcher in matchers:
- text = text.replace(
- matcher[0],
- Percentage(percentage=matcher[0]).percentage2chntext(),
- 1,
- )
-
-        # Normalize plain numbers followed by quantifiers
- pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
- matchers = pattern.findall(text)
- if matchers:
- # print('cardinal+quantifier')
- for matcher in matchers:
- text = text.replace(
- matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
- )
-
-        # Normalize numeric codes (long digit sequences)
- pattern = re.compile(r"(\d{4,32})")
- matchers = pattern.findall(text)
- if matchers:
- # print('digit')
- for matcher in matchers:
- text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
-
-        # Normalize plain cardinal numbers
- pattern = re.compile(r"(\d+(\.\d+)?)")
- matchers = pattern.findall(text)
- if matchers:
- # print('cardinal')
- for matcher in matchers:
- text = text.replace(
- matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
- )
-
- self.norm_text = text
- self._particular()
-
- return self.norm_text.lstrip("^").rstrip("$")
-
-
-if __name__ == "__main__":
-
-    # Test program
- print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
- print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
- print(Text(raw_text="分数:32477/76391。").normalize())
- print(Text(raw_text="百分数:80.03%。").normalize())
- print(Text(raw_text="编号:31520181154418。").normalize())
- print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
- print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
- print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
- print(Text(raw_text="特殊:O2O或B2C。").normalize())
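
For reference, Text.normalize() wraps the raw input in "^"/"$" sentinels because several of the patterns above (dates, money, phone numbers) require a non-digit character before a match. A minimal usage sketch of the normalizer removed here, using the module path shown in this diff:

    from fish_speech.text.chn_text_norm.text import Text

    # Leading digits still normalize because "^" is prepended internally.
    print(Text(raw_text="电话:0595-23980880,价格:12.5元。").normalize())
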
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
deleted file mode 100644
index c228dfcd13324e8b1abe4ead5f01f4bd8ed0c33a..0000000000000000000000000000000000000000
--- a/fish_speech/text/clean.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import re
-
-SYMBOLS_MAPPING = {
- "“": "'",
- "”": "'",
- "‘": "'",
- "’": "'",
- "【": "",
- "】": "",
- "[": "",
- "]": "",
- "(": "",
- ")": "",
- "(": "",
- ")": "",
- "・": "·",
-}
-
-REPLACE_SYMBOL_REGEX = re.compile(
- "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
-)
-
-
-def clean_text(text):
- # Clean the text
- text = text.strip()
-
-    # Replace all Chinese symbols with their English counterparts
- text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
-
- return text
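
A quick illustration of the mapping table above (full-width punctuation is either normalized to an ASCII equivalent or dropped):

    from fish_speech.text.clean import clean_text

    # Curly quotes become ASCII apostrophes; bracket characters are removed outright.
    print(clean_text("“你好”【测试】"))  # -> '你好'测试
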
diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py
deleted file mode 100644
index d4bb995487c4f53818c6b2a16cf0a886b4e02e84..0000000000000000000000000000000000000000
--- a/fish_speech/text/spliter.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import re
-import string
-
-from fish_speech.text.clean import clean_text
-
-
-def utf_8_len(text):
- return len(text.encode("utf-8"))
-
-
-def break_text(texts, length, splits: set):
- for text in texts:
- if utf_8_len(text) <= length:
- yield text
- continue
-
- curr = ""
- for char in text:
- curr += char
-
- if char in splits:
- yield curr
- curr = ""
-
- if curr:
- yield curr
-
-
-def break_text_by_length(texts, length):
- for text in texts:
- if utf_8_len(text) <= length:
- yield text
- continue
-
- curr = ""
- for char in text:
- curr += char
-
- if utf_8_len(curr) >= length:
- yield curr
- curr = ""
-
- if curr:
- yield curr
-
-
-def add_cleaned(curr, segments):
- curr = curr.strip()
- if curr and not all(c.isspace() or c in string.punctuation for c in curr):
- segments.append(curr)
-
-
-def protect_float(text):
- # Turns 3.14 into <3_f_14> to prevent splitting
- return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
-
-
-def unprotect_float(text):
- # Turns <3_f_14> into 3.14
- return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
-
-
-def split_text(text, length):
- text = clean_text(text)
-
- # Break the text into pieces with following rules:
- # 1. Split the text at ".", "!", "?" if text is NOT a float
- # 2. If the text is longer than length, split at ","
- # 3. If the text is still longer than length, split at " "
- # 4. If the text is still longer than length, split at any character to length
-
- texts = [text]
- texts = map(protect_float, texts)
- texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
- texts = map(unprotect_float, texts)
- texts = break_text(texts, length, {",", ","})
- texts = break_text(texts, length, {" "})
- texts = list(break_text_by_length(texts, length))
-
- # Then, merge the texts into segments with length <= length
- segments = []
- curr = ""
-
- for text in texts:
- if utf_8_len(curr) + utf_8_len(text) <= length:
- curr += text
- else:
- add_cleaned(curr, segments)
- curr = text
-
- if curr:
- add_cleaned(curr, segments)
-
- return segments
-
-
-if __name__ == "__main__":
- # Test the split_text function
-
- text = "This is a test sentence. This is another test sentence. And a third one."
-
- assert split_text(text, 50) == [
- "This is a test sentence.",
- "This is another test sentence. And a third one.",
- ]
- assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
- assert split_text(" ", 10) == []
- assert split_text("a", 10) == ["a"]
-
- text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
- assert split_text(text, 50) == [
- "This is a test sentence with only commas,",
- "and no dots, and no exclamation marks,",
- "and no question marks, and no newlines.",
- ]
-
- text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
- # First half split at " ", second half split at ","
- assert split_text(text, 50) == [
- "This is a test sentence This is a test sentence",
- "This is a test sentence. This is a test sentence,",
- "This is a test sentence, This is a test sentence.",
- ]
-
- text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
- assert split_text(text, 50) == [
- "这是一段很长的中文文本,",
- "而且没有句号,也没有感叹号,",
- "也没有问号,也没有换行符.",
- ]
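
Note that length in split_text() counts UTF-8 bytes rather than characters, so CJK text is split sooner than the same number of Latin letters; a small sanity check against the functions above:

    from fish_speech.text.spliter import split_text, utf_8_len

    print(utf_8_len("你好"))              # 6: two characters at three bytes each
    print(split_text("你好。" * 10, 30))  # four segments, each at most 30 UTF-8 bytes
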
diff --git a/fish_speech/train.py b/fish_speech/train.py
deleted file mode 100644
index e693f3adc4dda787bdd587aec29f53355f2b1653..0000000000000000000000000000000000000000
--- a/fish_speech/train.py
+++ /dev/null
@@ -1,141 +0,0 @@
-import os
-
-os.environ["USE_LIBUV"] = "0"
-import sys
-from typing import Optional
-
-import hydra
-import lightning as L
-import pyrootutils
-import torch
-from lightning import Callback, LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger
-from lightning.pytorch.strategies import DDPStrategy
-from omegaconf import DictConfig, OmegaConf
-
-os.environ.pop("SLURM_NTASKS", None)
-os.environ.pop("SLURM_JOB_NAME", None)
-os.environ.pop("SLURM_NTASKS_PER_NODE", None)
-
-# register eval resolver and root
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
-# Allow TF32 on Ampere GPUs
-torch.set_float32_matmul_precision("high")
-torch.backends.cudnn.allow_tf32 = True
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-import fish_speech.utils as utils
-
-log = utils.RankedLogger(__name__, rank_zero_only=True)
-
-
-@utils.task_wrapper
-def train(cfg: DictConfig) -> tuple[dict, dict]:
- """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
- training.
-    This method is wrapped in an optional @task_wrapper decorator that controls the behavior during
- failure. Useful for multiruns, saving info about the crash, etc.
- Args:
- cfg (DictConfig): Configuration composed by Hydra.
- Returns:
- Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
- """ # noqa: E501
-
- # set seed for random number generators in pytorch, numpy and python.random
- if cfg.get("seed"):
- L.seed_everything(cfg.seed, workers=False)
-
- if cfg.get("deterministic"):
- torch.use_deterministic_algorithms(True)
-
- log.info(f"Instantiating datamodule <{cfg.data._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
-
- log.info(f"Instantiating model <{cfg.model._target_}>")
- model: LightningModule = hydra.utils.instantiate(cfg.model)
-
- log.info("Instantiating callbacks...")
- callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
-
- log.info("Instantiating loggers...")
- logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
-
- log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
- trainer: Trainer = hydra.utils.instantiate(
- cfg.trainer,
- callbacks=callbacks,
- logger=logger,
- )
-
- object_dict = {
- "cfg": cfg,
- "datamodule": datamodule,
- "model": model,
- "callbacks": callbacks,
- "logger": logger,
- "trainer": trainer,
- }
-
- if logger:
- log.info("Logging hyperparameters!")
- utils.log_hyperparameters(object_dict)
-
- if cfg.get("train"):
- log.info("Starting training!")
-
- ckpt_path = cfg.get("ckpt_path")
- auto_resume = False
-
- resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
- if resume_ckpt_path is not None:
- ckpt_path = resume_ckpt_path
- auto_resume = True
-
- if ckpt_path is not None:
- log.info(f"Resuming from checkpoint: {ckpt_path}")
-
- # resume weights only is disabled for auto-resume
- if cfg.get("resume_weights_only") and auto_resume is False:
- log.info("Resuming weights only!")
- ckpt = torch.load(ckpt_path, map_location=model.device)
- if "state_dict" in ckpt:
- ckpt = ckpt["state_dict"]
- err = model.load_state_dict(ckpt, strict=False)
- log.info(f"Error loading state dict: {err}")
- ckpt_path = None
-
- trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
-
- train_metrics = trainer.callback_metrics
-
- if cfg.get("test"):
- log.info("Starting testing!")
- ckpt_path = trainer.checkpoint_callback.best_model_path
- if ckpt_path == "":
- log.warning("Best ckpt not found! Using current weights for testing...")
- ckpt_path = cfg.get("ckpt_path")
-
- trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
- log.info(f"Best ckpt path: {ckpt_path}")
-
- test_metrics = trainer.callback_metrics
-
- # merge train and test metrics
- metric_dict = {**train_metrics, **test_metrics}
-
- return metric_dict, object_dict
-
-
-@hydra.main(
- version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
-)
-def main(cfg: DictConfig) -> Optional[float]:
- # train the model
- train(cfg)
-
-
-if __name__ == "__main__":
- main()
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
deleted file mode 100644
index 05378519dbd18361c639e33413d011e7307c9adb..0000000000000000000000000000000000000000
--- a/fish_speech/utils/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from .braceexpand import braceexpand
-from .context import autocast_exclude_mps
-from .file import get_latest_checkpoint
-from .instantiators import instantiate_callbacks, instantiate_loggers
-from .logger import RankedLogger
-from .logging_utils import log_hyperparameters
-from .rich_utils import enforce_tags, print_config_tree
-from .utils import extras, get_metric_value, task_wrapper
-
-__all__ = [
- "enforce_tags",
- "extras",
- "get_metric_value",
- "RankedLogger",
- "instantiate_callbacks",
- "instantiate_loggers",
- "log_hyperparameters",
- "print_config_tree",
- "task_wrapper",
- "braceexpand",
- "get_latest_checkpoint",
- "autocast_exclude_mps",
-]
diff --git a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 1275a8478b5b6c8ca96cd20f18be4f300e5fba8d..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc
deleted file mode 100644
index 611e658e7387832fb9d481e775466a60689e364c..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/context.cpython-310.pyc b/fish_speech/utils/__pycache__/context.cpython-310.pyc
deleted file mode 100644
index 0701855f15ea618e6fca6bba156a480a26e06705..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/context.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/file.cpython-310.pyc b/fish_speech/utils/__pycache__/file.cpython-310.pyc
deleted file mode 100644
index d52787c4c9346fa3ac90012057d87598170b1619..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/file.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc
deleted file mode 100644
index 78c1b17fb8f7e05a50ed4056b404d7e60c2f104f..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/logger.cpython-310.pyc b/fish_speech/utils/__pycache__/logger.cpython-310.pyc
deleted file mode 100644
index 32cfb48f1bac889f58a4059ebc3033b2ec328077..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/logger.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc
deleted file mode 100644
index 5e24723cdd60dd27d036e1e3a72def349e22f5d8..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc
deleted file mode 100644
index e8a99c143231037868f6d6593d101636c0955844..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc
deleted file mode 100644
index d83d0c6d63e8a397e659056c7f4ecdc3299f9135..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/__pycache__/utils.cpython-310.pyc b/fish_speech/utils/__pycache__/utils.cpython-310.pyc
deleted file mode 100644
index fbad6b1a0fbbd0e58817cd597ae6b9ed26f7e53a..0000000000000000000000000000000000000000
Binary files a/fish_speech/utils/__pycache__/utils.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
deleted file mode 100644
index f3ac739f01f7e10e039c68c1157d6c761064f974..0000000000000000000000000000000000000000
--- a/fish_speech/utils/braceexpand.py
+++ /dev/null
@@ -1,217 +0,0 @@
-"""
-Bash-style brace expansion
-Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
-License: MIT
-"""
-
-import re
-import string
-from itertools import chain, product
-from typing import Iterable, Iterator, Optional
-
-__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
-
-
-class UnbalancedBracesError(ValueError):
- pass
-
-
-alphabet = string.ascii_uppercase + string.ascii_lowercase
-
-int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
-char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
-escape_re = re.compile(r"\\(.)")
-
-
-def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
- """braceexpand(pattern) -> iterator over generated strings
-
- Returns an iterator over the strings resulting from brace expansion
- of pattern. This function implements Brace Expansion as described in
- bash(1), with the following limitations:
-
- * A pattern containing unbalanced braces will raise an
- UnbalancedBracesError exception. In bash, unbalanced braces will either
- be partly expanded or ignored.
-
- * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
- include the characters '[]^_`' between 'Z' and 'a'.
-
- When escape is True (the default), characters in pattern can be
- prefixed with a backslash to cause them not to be interpreted as
- special characters for brace expansion (such as '{', '}', ',').
-    To pass through a literal backslash, double it ('\\\\').
-
- When escape is False, backslashes in pattern have no special
- meaning and will be preserved in the output.
-
- Examples:
-
- >>> from braceexpand import braceexpand
-
- # Integer range
- >>> list(braceexpand('item{1..3}'))
- ['item1', 'item2', 'item3']
-
- # Character range
- >>> list(braceexpand('{a..c}'))
- ['a', 'b', 'c']
-
- # Sequence
- >>> list(braceexpand('index.html{,.backup}'))
- ['index.html', 'index.html.backup']
-
- # Nested patterns
- >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
- ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
-
- # Prefixing an integer with zero causes all numbers to be padded to
- # the same width.
- >>> list(braceexpand('{07..10}'))
- ['07', '08', '09', '10']
-
- # An optional increment can be specified for ranges.
- >>> list(braceexpand('{a..g..2}'))
- ['a', 'c', 'e', 'g']
-
- # Ranges can go in both directions.
- >>> list(braceexpand('{4..1}'))
- ['4', '3', '2', '1']
-
- # Numbers can be negative
- >>> list(braceexpand('{2..-1}'))
- ['2', '1', '0', '-1']
-
- # Unbalanced braces raise an exception.
- >>> list(braceexpand('{1{2,3}'))
- Traceback (most recent call last):
- ...
- UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
-
- # By default, the backslash is the escape character.
- >>> list(braceexpand(r'{1\\{2,3}'))
- ['1{2', '3']
-
- # Setting 'escape' to False disables backslash escaping.
- >>> list(braceexpand(r'\\{1,2}', escape=False))
- ['\\\\1', '\\\\2']
-
- """
- return (
- escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
- )
-
-
-def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
- start = 0
- pos = 0
- bracketdepth = 0
- items: list[Iterable[str]] = []
-
- # print 'pattern:', pattern
- while pos < len(pattern):
- if escape and pattern[pos] == "\\":
- pos += 2
- continue
- elif pattern[pos] == "{":
- if bracketdepth == 0 and pos > start:
- # print 'literal:', pattern[start:pos]
- items.append([pattern[start:pos]])
- start = pos
- bracketdepth += 1
- elif pattern[pos] == "}":
- bracketdepth -= 1
- if bracketdepth == 0:
- # print 'expression:', pattern[start+1:pos]
- expr = pattern[start + 1 : pos]
- item = parse_expression(expr, escape)
- if item is None: # not a range or sequence
- items.extend([["{"], parse_pattern(expr, escape), ["}"]])
- else:
- items.append(item)
- start = pos + 1 # skip the closing brace
- pos += 1
-
- if bracketdepth != 0: # unbalanced braces
- raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
-
- if start < pos:
- items.append([pattern[start:]])
-
- return ("".join(item) for item in product(*items))
-
-
-def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
- int_range_match = int_range_re.match(expr)
- if int_range_match:
- return make_int_range(*int_range_match.groups())
-
- char_range_match = char_range_re.match(expr)
- if char_range_match:
- return make_char_range(*char_range_match.groups())
-
- return parse_sequence(expr, escape)
-
-
-def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
- # sequence -> chain(*sequence_items)
- start = 0
- pos = 0
- bracketdepth = 0
- items: list[Iterable[str]] = []
-
- # print 'sequence:', seq
- while pos < len(seq):
- if escape and seq[pos] == "\\":
- pos += 2
- continue
- elif seq[pos] == "{":
- bracketdepth += 1
- elif seq[pos] == "}":
- bracketdepth -= 1
- elif seq[pos] == "," and bracketdepth == 0:
- items.append(parse_pattern(seq[start:pos], escape))
- start = pos + 1 # skip the comma
- pos += 1
-
- if bracketdepth != 0:
- raise UnbalancedBracesError
- if not items:
- return None
-
- # part after the last comma (may be the empty string)
- items.append(parse_pattern(seq[start:], escape))
- return chain(*items)
-
-
-def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
- if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
- padding = max(len(left), len(right))
- else:
- padding = 0
- step = (int(incr) or 1) if incr else 1
- start = int(left)
- end = int(right)
- r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
- fmt = "%0{}d".format(padding)
- return (fmt % i for i in r)
-
-
-def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
- step = (int(incr) or 1) if incr else 1
- start = alphabet.index(left)
- end = alphabet.index(right)
- if start < end:
- return alphabet[start : end + 1 : step]
- else:
- end = end or -len(alphabet)
- return alphabet[start : end - 1 : -step]
-
-
-if __name__ == "__main__":
- import doctest
- import sys
-
- failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
- if failed:
- sys.exit(1)
diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py
deleted file mode 100644
index f04a99290ab32f7fe5b60656075a2d03af8468d6..0000000000000000000000000000000000000000
--- a/fish_speech/utils/context.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from contextlib import nullcontext
-
-import torch
-
-
-def autocast_exclude_mps(
- device_type: str, dtype: torch.dtype
-) -> nullcontext | torch.autocast:
- return (
- nullcontext()
- if torch.backends.mps.is_available()
- else torch.autocast(device_type, dtype)
- )
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
deleted file mode 100644
index 78c82640a963fa556657107729f7543d2e7c3510..0000000000000000000000000000000000000000
--- a/fish_speech/utils/file.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import os
-from pathlib import Path
-
-
-def get_latest_checkpoint(path: Path | str) -> Path | None:
- # Find the latest checkpoint
- ckpt_dir = Path(path)
-
- if ckpt_dir.exists() is False:
- return None
-
- ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
- if len(ckpts) == 0:
- return None
-
- return ckpts[-1]
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
deleted file mode 100644
index f6ee463924f588a35477937fbe3c3364043bdf3e..0000000000000000000000000000000000000000
--- a/fish_speech/utils/instantiators.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from typing import List
-
-import hydra
-from omegaconf import DictConfig
-from pytorch_lightning import Callback
-from pytorch_lightning.loggers import Logger
-
-from .logger import RankedLogger
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
- """Instantiates callbacks from config."""
-
- callbacks: List[Callback] = []
-
- if not callbacks_cfg:
- log.warning("No callback configs found! Skipping..")
- return callbacks
-
- if not isinstance(callbacks_cfg, DictConfig):
- raise TypeError("Callbacks config must be a DictConfig!")
-
- for _, cb_conf in callbacks_cfg.items():
- if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
- log.info(f"Instantiating callback <{cb_conf._target_}>")
- callbacks.append(hydra.utils.instantiate(cb_conf))
-
- return callbacks
-
-
-def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
- """Instantiates loggers from config."""
-
- logger: List[Logger] = []
-
- if not logger_cfg:
- log.warning("No logger configs found! Skipping...")
- return logger
-
- if not isinstance(logger_cfg, DictConfig):
- raise TypeError("Logger config must be a DictConfig!")
-
- for _, lg_conf in logger_cfg.items():
- if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
- log.info(f"Instantiating logger <{lg_conf._target_}>")
- logger.append(hydra.utils.instantiate(lg_conf))
-
- return logger
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
deleted file mode 100644
index 94f94f738d1d87404354d086c30ef0ad9ab04cdc..0000000000000000000000000000000000000000
--- a/fish_speech/utils/logger.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import logging
-from typing import Mapping, Optional
-
-from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
-
-
-class RankedLogger(logging.LoggerAdapter):
- """A multi-GPU-friendly python command line logger."""
-
- def __init__(
- self,
- name: str = __name__,
- rank_zero_only: bool = True,
- extra: Optional[Mapping[str, object]] = None,
- ) -> None:
- """Initializes a multi-GPU-friendly python command line logger that logs on all processes
- with their rank prefixed in the log message.
-
- :param name: The name of the logger. Default is ``__name__``.
-        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``True``.
- :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
- """
- logger = logging.getLogger(name)
- super().__init__(logger=logger, extra=extra)
- self.rank_zero_only = rank_zero_only
-
- def log(
- self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
- ) -> None:
- """Delegate a log call to the underlying logger, after prefixing its message with the rank
- of the process it's being logged from. If `'rank'` is provided, then the log will only
- occur on that rank/process.
-
- :param level: The level to log at. Look at `logging.__init__.py` for more information.
- :param msg: The message to log.
- :param rank: The rank to log at.
- :param args: Additional args to pass to the underlying logging function.
- :param kwargs: Any additional keyword args to pass to the underlying logging function.
- """
- if self.isEnabledFor(level):
- msg, kwargs = self.process(msg, kwargs)
- current_rank = getattr(rank_zero_only, "rank", None)
- if current_rank is None:
- raise RuntimeError(
- "The `rank_zero_only.rank` needs to be set before use"
- )
- msg = rank_prefixed_message(msg, current_rank)
- if self.rank_zero_only:
- if current_rank == 0:
- self.logger.log(level, msg, *args, **kwargs)
- else:
- if rank is None:
- self.logger.log(level, msg, *args, **kwargs)
- elif current_rank == rank:
- self.logger.log(level, msg, *args, **kwargs)
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
deleted file mode 100644
index 8e3b0a2519e12845f09e5fbe86dfccbf5b345429..0000000000000000000000000000000000000000
--- a/fish_speech/utils/logging_utils.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from lightning.pytorch.utilities import rank_zero_only
-
-from fish_speech.utils import logger as log
-
-
-@rank_zero_only
-def log_hyperparameters(object_dict: dict) -> None:
- """Controls which config parts are saved by lightning loggers.
-
- Additionally saves:
- - Number of model parameters
- """
-
- hparams = {}
-
- cfg = object_dict["cfg"]
- model = object_dict["model"]
- trainer = object_dict["trainer"]
-
- if not trainer.logger:
- log.warning("Logger not found! Skipping hyperparameter logging...")
- return
-
- hparams["model"] = cfg["model"]
-
- # save number of model parameters
- hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
- hparams["model/params/trainable"] = sum(
- p.numel() for p in model.parameters() if p.requires_grad
- )
- hparams["model/params/non_trainable"] = sum(
- p.numel() for p in model.parameters() if not p.requires_grad
- )
-
- hparams["data"] = cfg["data"]
- hparams["trainer"] = cfg["trainer"]
-
- hparams["callbacks"] = cfg.get("callbacks")
- hparams["extras"] = cfg.get("extras")
-
- hparams["task_name"] = cfg.get("task_name")
- hparams["tags"] = cfg.get("tags")
- hparams["ckpt_path"] = cfg.get("ckpt_path")
- hparams["seed"] = cfg.get("seed")
-
- # send hparams to all loggers
- for logger in trainer.loggers:
- logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
deleted file mode 100644
index 6a465f54d610779766d51e3d1a020a3b1517fd1f..0000000000000000000000000000000000000000
--- a/fish_speech/utils/rich_utils.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from pathlib import Path
-from typing import Sequence
-
-import rich
-import rich.syntax
-import rich.tree
-from hydra.core.hydra_config import HydraConfig
-from lightning.pytorch.utilities import rank_zero_only
-from omegaconf import DictConfig, OmegaConf, open_dict
-from rich.prompt import Prompt
-
-from fish_speech.utils import logger as log
-
-
-@rank_zero_only
-def print_config_tree(
- cfg: DictConfig,
- print_order: Sequence[str] = (
- "data",
- "model",
- "callbacks",
- "logger",
- "trainer",
- "paths",
- "extras",
- ),
- resolve: bool = False,
- save_to_file: bool = False,
-) -> None:
- """Prints content of DictConfig using Rich library and its tree structure.
-
- Args:
- cfg (DictConfig): Configuration composed by Hydra.
- print_order (Sequence[str], optional): Determines in what order config components are printed.
- resolve (bool, optional): Whether to resolve reference fields of DictConfig.
- save_to_file (bool, optional): Whether to export config to the hydra output folder.
- """ # noqa: E501
-
- style = "dim"
- tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
-
- queue = []
-
- # add fields from `print_order` to queue
- for field in print_order:
- (
- queue.append(field)
- if field in cfg
- else log.warning(
- f"Field '{field}' not found in config. "
- + f"Skipping '{field}' config printing..."
- )
- )
-
- # add all the other fields to queue (not specified in `print_order`)
- for field in cfg:
- if field not in queue:
- queue.append(field)
-
- # generate config tree from queue
- for field in queue:
- branch = tree.add(field, style=style, guide_style=style)
-
- config_group = cfg[field]
- if isinstance(config_group, DictConfig):
- branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
- else:
- branch_content = str(config_group)
-
- branch.add(rich.syntax.Syntax(branch_content, "yaml"))
-
- # print config tree
- rich.print(tree)
-
- # save config tree to file
- if save_to_file:
- with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
- rich.print(tree, file=file)
-
-
-@rank_zero_only
-def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
- """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
-
- if not cfg.get("tags"):
- if "id" in HydraConfig().cfg.hydra.job:
- raise ValueError("Specify tags before launching a multirun!")
-
- log.warning("No tags provided in config. Prompting user to input tags...")
- tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
- tags = [t.strip() for t in tags.split(",") if t != ""]
-
- with open_dict(cfg):
- cfg.tags = tags
-
- log.info(f"Tags: {cfg.tags}")
-
- if save_to_file:
- with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
- rich.print(cfg.tags, file=file)
diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
deleted file mode 100644
index 01c3d7a2ab0f707ae92dbde0feb173927720c841..0000000000000000000000000000000000000000
--- a/fish_speech/utils/spectrogram.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import torch
-import torchaudio.functional as F
-from torch import Tensor, nn
-from torchaudio.transforms import MelScale
-
-
-class LinearSpectrogram(nn.Module):
- def __init__(
- self,
- n_fft=2048,
- win_length=2048,
- hop_length=512,
- center=False,
- mode="pow2_sqrt",
- ):
- super().__init__()
-
- self.n_fft = n_fft
- self.win_length = win_length
- self.hop_length = hop_length
- self.center = center
- self.mode = mode
-
- self.register_buffer("window", torch.hann_window(win_length), persistent=False)
-
- def forward(self, y: Tensor) -> Tensor:
- if y.ndim == 3:
- y = y.squeeze(1)
-
- y = torch.nn.functional.pad(
- y.unsqueeze(1),
- (
- (self.win_length - self.hop_length) // 2,
- (self.win_length - self.hop_length + 1) // 2,
- ),
- mode="reflect",
- ).squeeze(1)
-
- spec = torch.stft(
- y,
- self.n_fft,
- hop_length=self.hop_length,
- win_length=self.win_length,
- window=self.window,
- center=self.center,
- pad_mode="reflect",
- normalized=False,
- onesided=True,
- return_complex=True,
- )
-
- spec = torch.view_as_real(spec)
-
- if self.mode == "pow2_sqrt":
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-
- return spec
-
-
-class LogMelSpectrogram(nn.Module):
- def __init__(
- self,
- sample_rate=44100,
- n_fft=2048,
- win_length=2048,
- hop_length=512,
- n_mels=128,
- center=False,
- f_min=0.0,
- f_max=None,
- ):
- super().__init__()
-
- self.sample_rate = sample_rate
- self.n_fft = n_fft
- self.win_length = win_length
- self.hop_length = hop_length
- self.center = center
- self.n_mels = n_mels
- self.f_min = f_min
- self.f_max = f_max or float(sample_rate // 2)
-
- self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
-
- fb = F.melscale_fbanks(
- n_freqs=self.n_fft // 2 + 1,
- f_min=self.f_min,
- f_max=self.f_max,
- n_mels=self.n_mels,
- sample_rate=self.sample_rate,
- norm="slaney",
- mel_scale="slaney",
- )
- self.register_buffer(
- "fb",
- fb,
- persistent=False,
- )
-
- def compress(self, x: Tensor) -> Tensor:
- return torch.log(torch.clamp(x, min=1e-5))
-
- def decompress(self, x: Tensor) -> Tensor:
- return torch.exp(x)
-
- def apply_mel_scale(self, x: Tensor) -> Tensor:
- return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
-
- def forward(
- self, x: Tensor, return_linear: bool = False, sample_rate: int = None
- ) -> Tensor:
- if sample_rate is not None and sample_rate != self.sample_rate:
- x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
-
- linear = self.spectrogram(x)
- x = self.apply_mel_scale(linear)
- x = self.compress(x)
-
- if return_linear:
- return x, self.compress(linear)
-
- return x
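
For orientation, a minimal sketch of running this mel front-end on a batch of raw audio; the shapes follow from the defaults above:

    import torch

    from fish_speech.utils.spectrogram import LogMelSpectrogram

    mel = LogMelSpectrogram(sample_rate=44100, n_mels=128)
    wav = torch.randn(2, 44100)  # (batch, samples): one second of audio per item
    spec = mel(wav)              # log-compressed mel features of shape (batch, 128, frames)
    print(spec.shape)
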
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
deleted file mode 100644
index c546bfa1eddd2ac6bf484cce1ec06da1d33fb121..0000000000000000000000000000000000000000
--- a/fish_speech/utils/utils.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import warnings
-from importlib.util import find_spec
-from typing import Callable
-
-from omegaconf import DictConfig
-
-from .logger import RankedLogger
-from .rich_utils import enforce_tags, print_config_tree
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def extras(cfg: DictConfig) -> None:
- """Applies optional utilities before the task is started.
-
- Utilities:
- - Ignoring python warnings
- - Setting tags from command line
- - Rich config printing
- """
-
- # return if no `extras` config
- if not cfg.get("extras"):
- log.warning("Extras config not found! ")
- return
-
- # disable python warnings
- if cfg.extras.get("ignore_warnings"):
- log.info("Disabling python warnings! ")
- warnings.filterwarnings("ignore")
-
- # prompt user to input tags from command line if none are provided in the config
- if cfg.extras.get("enforce_tags"):
- log.info("Enforcing tags! ")
- enforce_tags(cfg, save_to_file=True)
-
- # pretty print config tree using Rich library
- if cfg.extras.get("print_config"):
- log.info("Printing config tree with Rich! ")
- print_config_tree(cfg, resolve=True, save_to_file=True)
-
-
-def task_wrapper(task_func: Callable) -> Callable:
- """Optional decorator that controls the failure behavior when executing the task function.
-
- This wrapper can be used to:
- - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- - save the exception to a `.log` file
- - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- - etc. (adjust depending on your needs)
-
- Example:
- ```
- @utils.task_wrapper
- def train(cfg: DictConfig) -> Tuple[dict, dict]:
-
- ...
-
- return metric_dict, object_dict
- ```
- """ # noqa: E501
-
- def wrap(cfg: DictConfig):
- # execute the task
- try:
- metric_dict, object_dict = task_func(cfg=cfg)
-
- # things to do if exception occurs
- except Exception as ex:
- # save exception to `.log` file
- log.exception("")
-
- # some hyperparameter combinations might be invalid or
- # cause out-of-memory errors so when using hparam search
- # plugins like Optuna, you might want to disable
- # raising the below exception to avoid multirun failure
- raise ex
-
- # things to always do after either success or exception
- finally:
- # display output dir path in terminal
- log.info(f"Output dir: {cfg.paths.run_dir}")
-
- # always close wandb run (even if exception occurs so multirun won't fail)
- if find_spec("wandb"): # check if wandb is installed
- import wandb
-
- if wandb.run:
- log.info("Closing wandb!")
- wandb.finish()
-
- return metric_dict, object_dict
-
- return wrap
-
-
-def get_metric_value(metric_dict: dict, metric_name: str) -> float:
- """Safely retrieves value of the metric logged in LightningModule."""
-
- if not metric_name:
- log.info("Metric name is None! Skipping metric value retrieval...")
- return None
-
- if metric_name not in metric_dict:
- raise Exception(
- f"Metric value not found! \n"
- "Make sure metric name logged in LightningModule is correct!\n"
- "Make sure `optimized_metric` name in `hparams_search` config is correct!"
- )
-
- metric_value = metric_dict[metric_name].item()
- log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
-
- return metric_value
diff --git a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc
deleted file mode 100644
index 0bd0b8af3ea645c95065dbbe9b037384e54ad614..0000000000000000000000000000000000000000
Binary files a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc and /dev/null differ
diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css
deleted file mode 100644
index 3c7a22ecc31881a65a76369b0fd889330a0874c7..0000000000000000000000000000000000000000
--- a/fish_speech/webui/css/style.css
+++ /dev/null
@@ -1,161 +0,0 @@
-:root {
- --my-200: #80eeee;
- --my-50: #ecfdf5;
- --water-width: 300px;
- --water-heigh: 300px;
-}
-
-
-/* general styled components */
-.tools {
- align-items: center;
- justify-content: center;
-}
-
-.gradio-button {
- max-width: 2.2em;
- min-width: 2.2em !important;
- height: 2.4em;
- align-self: end;
- line-height: 1em;
- border-radius: 0.5em;
-
-}
-
-.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
- box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
-}
-
-/* replace original footer with ours */
-a{
- font-weight: bold;
- cursor: pointer;
- color: #030C14 !important;
-}
-
-footer {
- display: none !important;
-}
-
-#footer{
- text-align: center;
-}
-
-#footer div{
- display: inline-block;
-}
-
-#footer .versions{
- font-size: 85%;
- opacity: 0.85;
-}
-
-/*@keyframes moveBackground {*/
-/* 0% {*/
-/* background-position: 0 0;*/
-/* }*/
-/* 100% {*/
-/* background-position: -100px 100px;*/
-/* }*/
-/*}*/
-@keyframes moveJellyBackground {
- 0% {
- background-position: 0% 50%;
- }
- 50% {
- background-position: 100% 50%;
- }
- 100% {
- background-position: 0% 50%;
- }
-}
-
-.gradio-container {
- position: absolute;
- z-index: 10;
-}
-
-
-.quan {
- position: absolute;
- bottom: 0;
- width: var(--water-width);
- height: var(--water-heigh);
- border-radius: 0;
- /*border: 3px solid rgb(246, 247, 248);*/
- /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
- z-index: 0;
-
-}
-
-.quan:last-child {
- margin-right: 0;
-}
-
-.shui {
- position: absolute;
- top: 0;
- left: 0;
- width: 100%;
- height: 100%;
- background-color: rgb(23, 106, 201);
- border-radius: 0;
- overflow: hidden;
- z-index: 0;
-}
-
-.shui::after {
-
- content: '';
- position: absolute;
- top: 20%;
- left: 50%;
- width: 150%;
- height: 150%;
- border-radius: 40%;
- background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
- animation: shi 5s linear infinite;
-}
-
-@keyframes shi {
- 0% {
- transform: translate(-50%, -65%) rotate(0deg);
- }
- 100% {
- transform: translate(-50%, -65%) rotate(360deg);
- }
-}
-
-.shui::before {
- content: '';
- position: absolute;
- top: 20%;
- left: 50%;
- width: 150%;
- height: 150%;
- border-radius: 42%;
- background-color: rgb(240, 228, 228, 0.2);
- animation: xu 7s linear infinite;
-}
-
-@keyframes xu {
- 0% {
- transform: translate(-50%, -60%) rotate(0deg);
- }
- 100% {
- transform: translate(-50%, -60%) rotate(360deg);
- }
-}
-
-fieldset.data_src div.wrap label {
- background: #f8bffee0 !important;
-}
-
-.scrollable-component {
- max-height: 100px;
- overflow-y: auto;
-}
-
-#file_accordion {
- max-height: 220px !important;
-}
diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html
deleted file mode 100644
index ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615..0000000000000000000000000000000000000000
--- a/fish_speech/webui/html/footer.html
+++ /dev/null
@@ -1,11 +0,0 @@
-
-
-
-{versions}
-
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js
deleted file mode 100644
index 0637a541a8e704632a42b89bdf1471b26e7bb868..0000000000000000000000000000000000000000
--- a/fish_speech/webui/js/animate.js
+++ /dev/null
@@ -1,69 +0,0 @@
-
-function createGradioAnimation() {
- const params = new URLSearchParams(window.location.search);
- if (!params.has('__theme')) {
- params.set('__theme', 'light');
- window.location.search = params.toString();
- }
-
- var gradioApp = document.querySelector('gradio-app');
- if (gradioApp) {
-
- document.documentElement.style.setProperty('--my-200', '#80eeee');
- document.documentElement.style.setProperty('--my-50', '#ecfdf5');
-
- // gradioApp.style.position = 'relative';
- // gradioApp.style.backgroundSize = '200% 200%';
- // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
- // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
- // gradioApp.style.display = 'flex';
- // gradioApp.style.justifyContent = 'flex-start';
- // gradioApp.style.flexWrap = 'nowrap';
- // gradioApp.style.overflowX = 'auto';
-
- // for (let i = 0; i < 6; i++) {
- // var quan = document.createElement('div');
- // quan.className = 'quan';
- // gradioApp.insertBefore(quan, gradioApp.firstChild);
- // quan.id = 'quan' + i.toString();
- // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
- // var quanContainer = document.querySelector('.quan');
- // if (quanContainer) {
- // var shui = document.createElement('div');
- // shui.className = 'shui';
- // quanContainer.insertBefore(shui, quanContainer.firstChild)
- // }
- // }
- }
-
- var container = document.createElement('div');
- container.id = 'gradio-animation';
- container.style.fontSize = '2em';
- container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
- container.style.fontWeight = 'bold';
- container.style.textAlign = 'center';
- container.style.marginBottom = '20px';
-
- var text = 'Welcome to Fish-Speech!';
- for (var i = 0; i < text.length; i++) {
- (function(i){
- setTimeout(function(){
- var letter = document.createElement('span');
- letter.style.opacity = '0';
- letter.style.transition = 'opacity 0.5s';
- letter.innerText = text[i];
-
- container.appendChild(letter);
-
- setTimeout(function() {
- letter.style.opacity = '1';
- }, 50);
- }, i * 200);
- })(i);
- }
-
- var gradioContainer = document.querySelector('.gradio-container');
- gradioContainer.insertBefore(container, gradioContainer.firstChild);
-
- return 'Animation created';
-}
diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py
deleted file mode 100644
index 2f57b595a20177800dbedd71faef573ee8398418..0000000000000000000000000000000000000000
--- a/fish_speech/webui/launch_utils.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import importlib.util
-import os
-import subprocess
-import sys
-from functools import lru_cache
-from pathlib import Path
-from typing import Iterable
-
-import gradio as gr
-from gradio.themes.base import Base
-from gradio.themes.utils import colors, fonts, sizes
-
-GIT = (
- (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
- if sys.platform == "win32"
- else "git"
-)
-GIT = str(GIT)
-
-
-def is_module_installed(module_name: str) -> bool:
- spec = importlib.util.find_spec(module_name)
- return spec is not None
-
-
-@lru_cache()
-def commit_hash():
- try:
- return subprocess.check_output(
- [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
- ).strip()
- except Exception:
- return ""
-
-
-def versions_html():
- import torch
-
- python_version = ".".join([str(x) for x in sys.version_info[0:3]])
- commit = commit_hash()
- hash = commit.strip("'").split(" ")[0]
-
- return f"""
-version: {hash}
- •
-python: {python_version}
- •
-torch: {getattr(torch, '__long_version__',torch.__version__)}
- •
-gradio: {gr.__version__}
- •
-author: fishaudio
-"""
-
-
-def version_check(commit):
- try:
- import requests
-
- commits = requests.get(
- "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
- ).json()
- if commit != "" and commits["commit"]["sha"] != commit:
- print("--------------------------------------------------------")
- print("| You are not up to date with the most recent release. |")
- print("| Consider running `git pull` to update. |")
- print("--------------------------------------------------------")
- elif commits["commit"]["sha"] == commit:
- print("You are up to date with the most recent release.")
- else:
- print("Not a git clone, can't perform version check.")
- except Exception as e:
- print("version check failed", e)
-
-
-class Seafoam(Base):
- def __init__(
- self,
- *,
- primary_hue: colors.Color | str = colors.emerald,
- secondary_hue: colors.Color | str = colors.blue,
- neutral_hue: colors.Color | str = colors.blue,
- spacing_size: sizes.Size | str = sizes.spacing_md,
- radius_size: sizes.Size | str = sizes.radius_md,
- text_size: sizes.Size | str = sizes.text_lg,
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
- fonts.GoogleFont("Quicksand"),
- "ui-sans-serif",
- "sans-serif",
- ),
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
- fonts.GoogleFont("IBM Plex Mono"),
- "ui-monospace",
- "monospace",
- ),
- ):
- super().__init__(
- primary_hue=primary_hue,
- secondary_hue=secondary_hue,
- neutral_hue=neutral_hue,
- spacing_size=spacing_size,
- radius_size=radius_size,
- text_size=text_size,
- font=font,
- font_mono=font_mono,
- )
- super().set(
- button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
- button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
- button_primary_text_color="white",
- button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
- slider_color="*secondary_300",
- slider_color_dark="*secondary_600",
- block_title_text_weight="600",
- block_border_width="3px",
- block_shadow="*shadow_drop_lg",
- button_shadow="*shadow_drop_lg",
- button_small_padding="0px",
- button_large_padding="3px",
- )
diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py
deleted file mode 100644
index 4ec3fcac25de3cc7d239c4903403d1a4cd81567b..0000000000000000000000000000000000000000
--- a/fish_speech/webui/manage.py
+++ /dev/null
@@ -1,1239 +0,0 @@
-from __future__ import annotations
-
-import os
-
-os.environ["USE_LIBUV"] = "0"
-import datetime
-import html
-import json
-import platform
-import shutil
-import signal
-import subprocess
-import sys
-from pathlib import Path
-
-import gradio as gr
-import psutil
-import yaml
-from loguru import logger
-from tqdm import tqdm
-
-PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
-sys.path.insert(0, "")
-print(sys.path)
-cur_work_dir = Path(os.getcwd()).resolve()
-print("You are in ", str(cur_work_dir))
-
-from fish_speech.i18n import i18n
-from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
-
-config_path = cur_work_dir / "fish_speech" / "configs"
-vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
-llama_yml_path = config_path / "text2semantic_finetune.yaml"
-
-env = os.environ.copy()
-env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
-
-seafoam = Seafoam()
-
-
-def build_html_error_message(error):
- return f"""
-
- {html.escape(error)}
-
- """
-
-
-def build_html_ok_message(msg):
- return f"""
-
- {html.escape(msg)}
-
- """
-
-
-def build_html_href(link, desc, msg):
- return f"""
-
- {html.escape(msg)}
- {desc}
-
- """
-
-
-def load_data_in_raw(path):
- with open(path, "r", encoding="utf-8") as file:
- data = file.read()
- return str(data)
-
-
-def kill_proc_tree(pid, including_parent=True):
- try:
- parent = psutil.Process(pid)
- except psutil.NoSuchProcess:
- # Process already terminated
- return
-
- children = parent.children(recursive=True)
- for child in children:
- try:
- os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
- except OSError:
- pass
- if including_parent:
- try:
- os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
- except OSError:
- pass
-
-
-system = platform.system()
-p_label = None
-p_infer = None
-p_tensorboard = None
-
-
-def kill_process(pid):
- if system == "Windows":
- cmd = "taskkill /t /f /pid %s" % pid
- # os.system(cmd)
- subprocess.run(cmd)
- else:
- kill_proc_tree(pid)
-
-
-def change_label(if_label):
- global p_label
- if if_label == True and p_label is None:
- url = "http://localhost:3000"
- remote_url = "https://text-labeler.pages.dev/"
- try:
- p_label = subprocess.Popen(
- [
- (
- "asr-label-linux-x64"
- if sys.platform == "linux"
- else "asr-label-win-x64.exe"
- )
- ]
- )
- except FileNotFoundError:
-            logger.warning("asr-label executable not found!")
-
- yield build_html_href(
- link=remote_url,
- desc=i18n("Optional online ver"),
- msg=i18n("Opened labeler in browser"),
- )
-
- elif if_label == False and p_label is not None:
- kill_process(p_label.pid)
- p_label = None
- yield build_html_ok_message("Nothing")
-
-
-def clean_infer_cache():
- import tempfile
-
- temp_dir = Path(tempfile.gettempdir())
- gradio_dir = str(temp_dir / "gradio")
- try:
- shutil.rmtree(gradio_dir)
- logger.info(f"Deleted cached audios: {gradio_dir}")
- except PermissionError:
- logger.info(f"Permission denied: Unable to delete {gradio_dir}")
- except FileNotFoundError:
- logger.info(f"{gradio_dir} was not found")
- except Exception as e:
- logger.info(f"An error occurred: {e}")
-
-
-def change_infer(
- if_infer,
- host,
- port,
- infer_decoder_model,
- infer_decoder_config,
- infer_llama_model,
- infer_compile,
-):
- global p_infer
- if if_infer == True and p_infer == None:
- env = os.environ.copy()
-
- env["GRADIO_SERVER_NAME"] = host
- env["GRADIO_SERVER_PORT"] = port
-        # Start the second (inference) process
- url = f"http://{host}:{port}"
- yield build_html_ok_message(
- i18n("Inferring interface is launched at {}").format(url)
- )
-
- clean_infer_cache()
-
- p_infer = subprocess.Popen(
- [
- PYTHON,
- "tools/webui.py",
- "--decoder-checkpoint-path",
- infer_decoder_model,
- "--decoder-config-name",
- infer_decoder_config,
- "--llama-checkpoint-path",
- infer_llama_model,
- ]
- + (["--compile"] if infer_compile == "Yes" else []),
- env=env,
- )
-
- elif if_infer == False and p_infer is not None:
- kill_process(p_infer.pid)
- p_infer = None
- yield build_html_error_message(i18n("Infer interface is closed"))
-
-
-js = load_data_in_raw("fish_speech/webui/js/animate.js")
-css = load_data_in_raw("fish_speech/webui/css/style.css")
-
-data_pre_output = (cur_work_dir / "data").resolve()
-default_model_output = (cur_work_dir / "results").resolve()
-default_filelist = data_pre_output / "detect.list"
-data_pre_output.mkdir(parents=True, exist_ok=True)
-
-items = []
-dict_items = {}
-
-
-def load_yaml_data_in_fact(yml_path):
- with open(yml_path, "r", encoding="utf-8") as file:
- yml = yaml.safe_load(file)
- return yml
-
-
-def write_yaml_data_in_fact(yml, yml_path):
- with open(yml_path, "w", encoding="utf-8") as file:
- yaml.safe_dump(yml, file, allow_unicode=True)
- return yml
-
-
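-# Render a directory as a text tree (directories first, then files). Illustrative
-# output for the data folder might look like:
-# ├── speaker_01
-# │   └── 0001.wav
-# └── detect.list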
-def generate_tree(directory, depth=0, max_depth=None, prefix=""):
- if max_depth is not None and depth > max_depth:
- return ""
-
- tree_str = ""
- files = []
- directories = []
- for item in os.listdir(directory):
- if os.path.isdir(os.path.join(directory, item)):
- directories.append(item)
- else:
- files.append(item)
-
- entries = directories + files
- for i, entry in enumerate(entries):
- connector = "├── " if i < len(entries) - 1 else "└── "
- tree_str += f"{prefix}{connector}{entry}\n"
- if i < len(directories):
- extension = "│ " if i < len(entries) - 1 else " "
- tree_str += generate_tree(
- os.path.join(directory, entry),
- depth + 1,
- max_depth,
- prefix=prefix + extension,
- )
- return tree_str
-
-
-def new_explorer(data_path, max_depth):
- return gr.Markdown(
- elem_classes=["scrollable-component"],
- value=generate_tree(data_path, max_depth=max_depth),
- )
-
-
-def add_item(
- folder: str,
- method: str,
- label_lang: str,
- if_initial_prompt: bool,
- initial_prompt: str | None,
-):
- folder = folder.strip(" ").strip('"')
-
- folder_path = Path(folder)
-
- if folder and folder not in items and data_pre_output not in folder_path.parents:
- if folder_path.is_dir():
- items.append(folder)
- dict_items[folder] = dict(
- type="folder",
- method=method,
- label_lang=label_lang,
- initial_prompt=initial_prompt if if_initial_prompt else None,
- )
- elif folder:
- err = folder
- return gr.Checkboxgroup(choices=items), build_html_error_message(
- i18n("Invalid path: {}").format(err)
- )
-
- formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
- logger.info("After Adding: " + formatted_data)
- gr.Info(formatted_data)
- return gr.Checkboxgroup(choices=items), build_html_ok_message(
- i18n("Added path successfully!")
- )
-
-
-def remove_items(selected_items):
- global items, dict_items
- to_remove = [item for item in items if item in selected_items]
- for item in to_remove:
- del dict_items[item]
- items = [item for item in items if item in dict_items.keys()]
- formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
- logger.info(formatted_data)
- gr.Warning("After Removing: " + formatted_data)
- return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
- i18n("Removed path successfully!")
- )
-
-
-def show_selected(options):
- selected_options = ", ".join(options)
-
- if options:
- return i18n("Selected: {}").format(selected_options)
- else:
- return i18n("No selected options")
-
-
-from pydub import AudioSegment
-
-
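-# Downmix multi-channel recordings to mono in place; the export format is inferred
-# from the file extension, so .wav/.flac/.mp3 sources keep their original container.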
-def convert_to_mono_in_place(audio_path: Path):
- audio = AudioSegment.from_file(audio_path)
- if audio.channels > 1:
- mono_audio = audio.set_channels(1)
- mono_audio.export(audio_path, format=audio_path.suffix[1:])
- logger.info(f"Converted {audio_path} to mono successfully")
-
-
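-# Import a filelist dataset: each line is "wav_path|speaker|language|text". Audio and
-# the matching .lab transcript are copied (or moved) under the data/ directory, and
-# when files are moved the list is rewritten to point at their new locations.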
-def list_copy(list_file_path, method):
- wav_root = data_pre_output
- lst = []
- with list_file_path.open("r", encoding="utf-8") as file:
- for line in tqdm(file, desc="Processing audio/transcript"):
- wav_path, speaker_name, language, text = line.strip().split("|")
- original_wav_path = Path(wav_path)
- target_wav_path = (
- wav_root / original_wav_path.parent.name / original_wav_path.name
- )
- lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
- if target_wav_path.is_file():
- continue
- target_wav_path.parent.mkdir(parents=True, exist_ok=True)
- if method == i18n("Copy"):
- shutil.copy(original_wav_path, target_wav_path)
- else:
- shutil.move(original_wav_path, target_wav_path.parent)
- convert_to_mono_in_place(target_wav_path)
- original_lab_path = original_wav_path.with_suffix(".lab")
- target_lab_path = (
- wav_root
- / original_wav_path.parent.name
- / original_wav_path.with_suffix(".lab").name
- )
- if target_lab_path.is_file():
- continue
- if method == i18n("Copy"):
- shutil.copy(original_lab_path, target_lab_path)
- else:
- shutil.move(original_lab_path, target_lab_path.parent)
-
- if method == i18n("Move"):
- with list_file_path.open("w", encoding="utf-8") as file:
- file.writelines("\n".join(lst))
-
- del lst
- return build_html_ok_message(i18n("Use filelist"))
-
-
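-# Preprocessing entry point: copy or move every registered source into the data
-# directory, downmix audio to mono, and run tools/whisper_asr.py to produce .lab
-# transcripts unless labeling is disabled ("IGNORE").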
-def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
- global dict_items
- data_path = Path(data_path)
- gr.Warning("Pre-processing begins...")
- for item, content in dict_items.items():
- item_path = Path(item)
- tar_path = data_path / item_path.name
-
- if content["type"] == "folder" and item_path.is_dir():
- if content["method"] == i18n("Copy"):
- os.makedirs(tar_path, exist_ok=True)
- shutil.copytree(
- src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
- )
- elif not tar_path.is_dir():
- shutil.move(src=str(item_path), dst=str(tar_path))
-
- for suf in ["wav", "flac", "mp3"]:
- for audio_path in tar_path.glob(f"**/*.{suf}"):
- convert_to_mono_in_place(audio_path)
-
- cur_lang = content["label_lang"]
- initial_prompt = content["initial_prompt"]
-
- transcribe_cmd = [
- PYTHON,
- "tools/whisper_asr.py",
- "--model-size",
- label_model,
- "--device",
- label_device,
- "--audio-dir",
- tar_path,
- "--save-dir",
- tar_path,
- "--language",
- cur_lang,
- ]
-
- if initial_prompt is not None:
- transcribe_cmd += ["--initial-prompt", initial_prompt]
-
- if cur_lang != "IGNORE":
- try:
- gr.Warning("Transcription begins...")
- subprocess.run(
- transcribe_cmd,
- env=env,
- )
- except Exception:
- print("Transcription error occurred")
-
- elif content["type"] == "file" and item_path.is_file():
- list_copy(item_path, content["method"])
-
- return build_html_ok_message(i18n("Move files successfully")), new_explorer(
- data_path, max_depth=max_depth
- )
-
-
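-# New runs are named by timestamp, e.g. generate_folder_name() -> "20240810_153000"
-# (the exact value depends on the current time).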
-def generate_folder_name():
- now = datetime.datetime.now()
- folder_name = now.strftime("%Y%m%d_%H%M%S")
- return folder_name
-
-
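-# Training entry point. For "LLAMA", semantic tokens are first extracted with
-# tools/vqgan/extract_vq.py, a protobuf dataset is built with
-# tools/llama/build_dataset.py, and fish_speech/train.py is launched with the
-# settings collected from the UI; "VQGAN" is intentionally skipped.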
-def train_process(
- data_path: str,
- option: str,
- # llama config
- llama_ckpt,
- llama_base_config,
- llama_lr,
- llama_maxsteps,
- llama_data_num_workers,
- llama_data_batch_size,
- llama_data_max_length,
- llama_precision,
- llama_check_interval,
- llama_grad_batches,
- llama_use_speaker,
- llama_use_lora,
-):
-
- backend = "nccl" if sys.platform == "linux" else "gloo"
-
- new_project = generate_folder_name()
- print("New Project Name: ", new_project)
-
- if option == "VQGAN":
- msg = "Skipped VQGAN Training."
- gr.Warning(msg)
- logger.info(msg)
-
- if option == "LLAMA":
- msg = "LLAMA Training begins..."
- gr.Warning(msg)
- logger.info(msg)
- subprocess.run(
- [
- PYTHON,
- "tools/vqgan/extract_vq.py",
- str(data_pre_output),
- "--num-workers",
- "1",
- "--batch-size",
- "16",
- "--config-name",
- "firefly_gan_vq",
- "--checkpoint-path",
- "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- ]
- )
-
- subprocess.run(
- [
- PYTHON,
- "tools/llama/build_dataset.py",
- "--input",
- str(data_pre_output),
- "--text-extension",
- ".lab",
- "--num-workers",
- "16",
- ]
- )
- ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
- lora_prefix = "lora_" if llama_use_lora else ""
- llama_name = lora_prefix + "text2semantic_" + new_project
- latest = next(
- iter(
- sorted(
- [
- str(p.relative_to("results"))
- for p in Path("results").glob(lora_prefix + "text2sem*/")
- ],
- reverse=True,
- )
- ),
- llama_name,
- )
- project = (
- llama_name
- if llama_ckpt == i18n("new")
- else (
- latest
- if llama_ckpt == i18n("latest")
- else Path(llama_ckpt).relative_to("results")
- )
- )
- logger.info(project)
-
- if llama_check_interval > llama_maxsteps:
- llama_check_interval = llama_maxsteps
-
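- # The run is configured through Hydra-style "key=value" overrides assembled
- # below; the optional "+lora@model.model.lora_config=r_8_alpha_16" entry
- # switches the job to LoRA fine-tuning.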
- train_cmd = [
- PYTHON,
- "fish_speech/train.py",
- "--config-name",
- "text2semantic_finetune",
- f"project={project}",
- f"trainer.strategy.process_group_backend={backend}",
- f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
- f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
- f"model.optimizer.lr={llama_lr}",
- f"trainer.max_steps={llama_maxsteps}",
- f"data.num_workers={llama_data_num_workers}",
- f"data.batch_size={llama_data_batch_size}",
- f"max_length={llama_data_max_length}",
- f"trainer.precision={llama_precision}",
- f"trainer.val_check_interval={llama_check_interval}",
- f"trainer.accumulate_grad_batches={llama_grad_batches}",
- f"train_dataset.interactive_prob={llama_use_speaker}",
- ] + (["+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
- logger.info(train_cmd)
- subprocess.run(train_cmd)
-
- return build_html_ok_message(i18n("Training stopped"))
-
-
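-# Toggle a TensorBoard child process for the selected log directory; when a
-# self-contained "fishenv" environment exists, its bundled executables are used
-# instead of the tensorboard found on PATH.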
-def tensorboard_process(
- if_tensorboard: bool,
- tensorboard_dir: str,
- host: str,
- port: str,
-):
- global p_tensorboard
- if if_tensorboard and p_tensorboard is None:
- url = f"http://{host}:{port}"
- yield build_html_ok_message(
- i18n("Tensorboard interface is launched at {}").format(url)
- )
- prefix = ["tensorboard"]
- if Path("fishenv").exists():
- prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
-
- p_tensorboard = subprocess.Popen(
- prefix
- + [
- "--logdir",
- tensorboard_dir,
- "--host",
- host,
- "--port",
- port,
- "--reload_interval",
- "120",
- ]
- )
- elif not if_tensorboard and p_tensorboard is not None:
- kill_process(p_tensorboard.pid)
- p_tensorboard = None
- yield build_html_error_message(i18n("Tensorboard interface is closed"))
-
-
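-# The list_*/fresh_* helpers below rescan checkpoints/ and results/ so dropdown
-# choices reflect newly trained, merged or quantized models without restarting the UI.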
-def fresh_tb_dir():
- return gr.Dropdown(
- choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
- )
-
-
-def list_decoder_models():
- paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
- if not paths:
- logger.warning("No decoder model found")
- return paths
-
-
-def list_llama_models():
- choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
- choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
- choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
- choices = sorted(choices, reverse=True)
- if not choices:
- logger.warning("No LLaMA model found")
- return choices
-
-
-def list_lora_llama_models():
- choices = sorted(
- [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
- )
- if not choices:
- logger.warning("No LoRA LLaMA model found")
- return choices
-
-
-def fresh_decoder_model():
- return gr.Dropdown(choices=list_decoder_models())
-
-
-def fresh_llama_ckpt(llama_use_lora):
- return gr.Dropdown(
- choices=[i18n("latest"), i18n("new")]
- + (
- [str(p) for p in Path("results").glob("text2sem*/")]
- if not llama_use_lora
- else [str(p) for p in Path("results").glob("lora_*/")]
- )
- )
-
-
-def fresh_llama_model():
- return gr.Dropdown(choices=list_llama_models())
-
-
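-# Merge a LoRA checkpoint into the base LLAMA weights via tools/llama/merge_lora.py;
-# the merged model is written to "<output path>_<timestamp>".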
-def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
- if (
- lora_weight is None
- or not Path(lora_weight).exists()
- or not Path(llama_weight).exists()
- ):
- return build_html_error_message(
- i18n(
- "Path error, please check the model file exists in the corresponding path"
- )
- )
- gr.Warning("Merging begins...")
- merge_cmd = [
- PYTHON,
- "tools/llama/merge_lora.py",
- "--lora-config",
- "r_8_alpha_16",
- "--lora-weight",
- lora_weight,
- "--output",
- llama_lora_output + "_" + generate_folder_name(),
- ]
- logger.info(merge_cmd)
- subprocess.run(merge_cmd)
- return build_html_ok_message(i18n("Merge successfully"))
-
-
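-# Quantize a LLAMA checkpoint via tools/llama/quantize.py; the resulting folder name
-# encodes the mode and timestamp (int4 additionally carries a "-g128" group suffix).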
-def llama_quantify(llama_weight, quantify_mode):
- if llama_weight is None or not Path(llama_weight).exists():
- return build_html_error_message(
- i18n(
- "Path error, please check the model file exists in the corresponding path"
- )
- )
-
- gr.Warning("Quantifying begins...")
-
- now = generate_folder_name()
- quantify_cmd = [
- PYTHON,
- "tools/llama/quantize.py",
- "--checkpoint-path",
- llama_weight,
- "--mode",
- quantify_mode,
- "--timestamp",
- now,
- ]
- logger.info(quantify_cmd)
- subprocess.run(quantify_cmd)
- if quantify_mode == "int8":
- quantize_path = str(
- Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
- )
- else:
- quantize_path = str(
- Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
- )
- return build_html_ok_message(
- i18n("Quantify successfully") + f"Path: {quantize_path}"
- )
-
-
-init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
-init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
-
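-# Gradio layout: the left column holds tabs for data preprocessing, training
-# configuration (LLAMA / LoRA merge / quantization / TensorBoard) and inference;
-# the right column shows status, paths and the preprocessing folder tree. The event
-# handlers wired up at the end connect these widgets to the helper functions above.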
-with gr.Blocks(
- head="",
- js=js,
- theme=seafoam,
- analytics_enabled=False,
- title="Fish Speech",
-) as demo:
- with gr.Row():
- with gr.Column():
- with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
- with gr.Row():
- textbox = gr.Textbox(
- label="\U0000270F "
- + i18n("Input Audio & Source Path for Transcription"),
- info=i18n("Speaker is identified by the folder name"),
- interactive=True,
- )
- with gr.Row(equal_height=False):
- with gr.Column():
- output_radio = gr.Radio(
- label="\U0001F4C1 "
- + i18n("Select source file processing method"),
- choices=[i18n("Copy"), i18n("Move")],
- value=i18n("Copy"),
- interactive=True,
- )
- with gr.Column():
- error = gr.HTML(label=i18n("Error Message"))
- if_label = gr.Checkbox(
- label=i18n("Open Labeler WebUI"), scale=0, show_label=True
- )
-
- with gr.Row():
- label_device = gr.Dropdown(
- label=i18n("Labeling Device"),
- info=i18n(
- "It is recommended to use CUDA, if you have low configuration, use CPU"
- ),
- choices=["cpu", "cuda"],
- value="cuda",
- interactive=True,
- )
- label_model = gr.Dropdown(
- label=i18n("Whisper Model"),
- info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
- choices=["large-v3", "medium"],
- value="large-v3",
- interactive=True,
- )
- label_radio = gr.Dropdown(
- label=i18n("Optional Label Language"),
- info=i18n(
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
- ),
- choices=[
- (i18n("Chinese"), "zh"),
- (i18n("English"), "en"),
- (i18n("Japanese"), "ja"),
- (i18n("Disabled"), "IGNORE"),
- (i18n("auto"), "auto"),
- ],
- value="IGNORE",
- interactive=True,
- )
-
- with gr.Row():
- if_initial_prompt = gr.Checkbox(
- value=False,
- label=i18n("Enable Initial Prompt"),
- min_width=120,
- scale=0,
- )
- initial_prompt = gr.Textbox(
- label=i18n("Initial Prompt"),
- info=i18n(
- "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
- ),
- placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
- interactive=False,
- )
-
- with gr.Row():
- add_button = gr.Button(
- "\U000027A1 " + i18n("Add to Processing Area"),
- variant="primary",
- )
- remove_button = gr.Button(
- "\U000026D4 " + i18n("Remove Selected Data")
- )
-
- with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
- with gr.Row():
- model_type_radio = gr.Radio(
- label=i18n(
- "Select the model to be trained (Depending on the Tab page you are on)"
- ),
- interactive=False,
- choices=["VQGAN", "LLAMA"],
- value="VQGAN",
- )
- with gr.Row():
- with gr.Tabs():
- with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
- gr.HTML("You don't need to train this model!")
-
- with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
- with gr.Row(equal_height=False):
- llama_use_lora = gr.Checkbox(
- label=i18n("Use LoRA"),
- info=i18n(
- "Use LoRA can save GPU memory, but may reduce the quality of the model"
- ),
- value=True,
- interactive=True,
- )
- llama_ckpt = gr.Dropdown(
- label=i18n("Select LLAMA ckpt"),
- choices=[i18n("latest"), i18n("new")]
- + [
- str(p)
- for p in Path("results").glob("text2sem*/")
- ]
- + [str(p) for p in Path("results").glob("lora*/")],
- value=i18n("latest"),
- interactive=True,
- )
- with gr.Row(equal_height=False):
- llama_lr_slider = gr.Slider(
- label=i18n("Initial Learning Rate"),
- info=i18n(
- "lr smaller -> usually train slower but more stable"
- ),
- interactive=True,
- minimum=1e-5,
- maximum=1e-4,
- step=1e-5,
- value=5e-5,
- )
- llama_maxsteps_slider = gr.Slider(
- label=i18n("Maximum Training Steps"),
- info=i18n(
- "recommend: max_steps = num_audios // batch_size * (2 to 5)"
- ),
- interactive=True,
- minimum=1,
- maximum=10000,
- step=1,
- value=50,
- )
- with gr.Row(equal_height=False):
- llama_base_config = gr.Dropdown(
- label=i18n("Model Size"),
- choices=[
- "text2semantic_finetune",
- ],
- value="text2semantic_finetune",
- )
- llama_data_num_workers_slider = gr.Slider(
- label=i18n("Number of Workers"),
- minimum=1,
- maximum=32,
- step=1,
- value=4,
- )
- with gr.Row(equal_height=False):
- llama_data_batch_size_slider = gr.Slider(
- label=i18n("Batch Size"),
- interactive=True,
- minimum=1,
- maximum=32,
- step=1,
- value=2,
- )
- llama_data_max_length_slider = gr.Slider(
- label=i18n("Maximum Length per Sample"),
- interactive=True,
- minimum=1024,
- maximum=4096,
- step=128,
- value=2048,
- )
- with gr.Row(equal_height=False):
- llama_precision_dropdown = gr.Dropdown(
- label=i18n("Precision"),
- info=i18n(
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
- ),
- interactive=True,
- choices=["32", "bf16-true", "16-mixed"],
- value="bf16-true",
- )
- llama_check_interval_slider = gr.Slider(
- label=i18n("Save model every n steps"),
- info=i18n(
- "make sure that it's not greater than max_steps"
- ),
- interactive=True,
- minimum=1,
- maximum=1000,
- step=1,
- value=50,
- )
- with gr.Row(equal_height=False):
- llama_grad_batches = gr.Slider(
- label=i18n("Accumulate Gradient Batches"),
- interactive=True,
- minimum=1,
- maximum=20,
- step=1,
- value=init_llama_yml["trainer"][
- "accumulate_grad_batches"
- ],
- )
- llama_use_speaker = gr.Slider(
- label=i18n(
- "Probability of applying Speaker Condition"
- ),
- interactive=True,
- minimum=0.1,
- maximum=1.0,
- step=0.05,
- value=init_llama_yml["train_dataset"][
- "interactive_prob"
- ],
- )
-
- with gr.Tab(label=i18n("Merge LoRA"), id=4):
- with gr.Row(equal_height=False):
- llama_weight = gr.Dropdown(
- label=i18n("Base LLAMA Model"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- choices=[
- "checkpoints/fish-speech-1.4/model.pth",
- ],
- value="checkpoints/fish-speech-1.4/model.pth",
- allow_custom_value=True,
- interactive=True,
- )
- with gr.Row(equal_height=False):
- lora_weight = gr.Dropdown(
- label=i18n("LoRA Model to be merged"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- choices=[
- str(p)
- for p in Path("results").glob("lora*/**/*.ckpt")
- ],
- allow_custom_value=True,
- interactive=True,
- )
- lora_llama_config = gr.Dropdown(
- label=i18n("LLAMA Model Config"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- choices=[
- "text2semantic_finetune",
- ],
- value="text2semantic_finetune",
- allow_custom_value=True,
- )
- with gr.Row(equal_height=False):
- llama_lora_output = gr.Dropdown(
- label=i18n("Output Path"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- value="checkpoints/merged",
- choices=["checkpoints/merged"],
- allow_custom_value=True,
- interactive=True,
- )
- with gr.Row(equal_height=False):
- llama_lora_merge_btn = gr.Button(
- value=i18n("Merge"), variant="primary"
- )
-
- with gr.Tab(label=i18n("Model Quantization"), id=5):
- with gr.Row(equal_height=False):
- llama_weight_to_quantify = gr.Dropdown(
- label=i18n("Base LLAMA Model"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- choices=list_llama_models(),
- value="checkpoints/fish-speech-1.4",
- allow_custom_value=True,
- interactive=True,
- )
- quantify_mode = gr.Dropdown(
- label=i18n("Post-quantification Precision"),
- info=i18n(
- "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
- ),
- choices=["int8", "int4"],
- value="int8",
- allow_custom_value=False,
- interactive=True,
- )
- with gr.Row(equal_height=False):
- llama_quantify_btn = gr.Button(
- value=i18n("Quantify"), variant="primary"
- )
-
- with gr.Tab(label="Tensorboard", id=6):
- with gr.Row(equal_height=False):
- tb_host = gr.Textbox(
- label=i18n("Tensorboard Host"), value="127.0.0.1"
- )
- tb_port = gr.Textbox(
- label=i18n("Tensorboard Port"), value="11451"
- )
- with gr.Row(equal_height=False):
- tb_dir = gr.Dropdown(
- label=i18n("Tensorboard Log Path"),
- allow_custom_value=True,
- choices=[
- str(p)
- for p in Path("results").glob("**/tensorboard/")
- ],
- )
- with gr.Row(equal_height=False):
- if_tb = gr.Checkbox(
- label=i18n("Open Tensorboard"),
- )
-
- with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
- with gr.Column():
- with gr.Row():
- with gr.Accordion(
- label="\U0001F5A5 "
- + i18n("Inference Server Configuration"),
- open=False,
- ):
- with gr.Row():
- infer_host_textbox = gr.Textbox(
- label=i18n("WebUI Host"), value="127.0.0.1"
- )
- infer_port_textbox = gr.Textbox(
- label=i18n("WebUI Port"), value="7862"
- )
- with gr.Row():
- infer_decoder_model = gr.Dropdown(
- label=i18n("Decoder Model Path"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- choices=list_decoder_models(),
- value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- allow_custom_value=True,
- )
- infer_decoder_config = gr.Dropdown(
- label=i18n("Decoder Model Config"),
- info=i18n("Changing with the Model Path"),
- value="firefly_gan_vq",
- choices=[
- "firefly_gan_vq",
- ],
- allow_custom_value=True,
- )
- with gr.Row():
- infer_llama_model = gr.Dropdown(
- label=i18n("LLAMA Model Path"),
- info=i18n(
- "Type the path or select from the dropdown"
- ),
- value="checkpoints/fish-speech-1.4",
- choices=list_llama_models(),
- allow_custom_value=True,
- )
-
- with gr.Row():
- infer_compile = gr.Radio(
- label=i18n("Compile Model"),
- info=i18n(
- "Compile the model can significantly reduce the inference time, but will increase cold start time"
- ),
- choices=["Yes", "No"],
- value=(
- "Yes" if (sys.platform == "linux") else "No"
- ),
- interactive=is_module_installed("triton"),
- )
-
- with gr.Row():
- infer_checkbox = gr.Checkbox(
- label=i18n("Open Inference Server")
- )
- infer_error = gr.HTML(label=i18n("Inference Server Error"))
-
- with gr.Column():
- train_error = gr.HTML(label=i18n("Training Error"))
- checkbox_group = gr.CheckboxGroup(
- label="\U0001F4CA " + i18n("Data Source"),
- info=i18n(
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
- ),
- elem_classes=["data_src"],
- )
- train_box = gr.Textbox(
- label=i18n("Data Preprocessing Path"),
- value=str(data_pre_output),
- interactive=False,
- )
- model_box = gr.Textbox(
- label="\U0001F4BE " + i18n("Model Output Path"),
- value=str(default_model_output),
- interactive=False,
- )
-
- with gr.Accordion(
- i18n(
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
- ),
- elem_classes=["scrollable-component"],
- elem_id="file_accordion",
- ):
- tree_slider = gr.Slider(
- minimum=0,
- maximum=3,
- value=0,
- step=1,
- show_label=False,
- container=False,
- )
- file_markdown = new_explorer(str(data_pre_output), 0)
- with gr.Row(equal_height=False):
- admit_btn = gr.Button(
- "\U00002705 " + i18n("File Preprocessing"),
- variant="primary",
- )
- fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
- help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
- train_btn = gr.Button(i18n("Start Training"), variant="primary")
-
- footer = load_data_in_raw("fish_speech/webui/html/footer.html")
- footer = footer.format(
- versions=versions_html(),
- api_docs="https://speech.fish.audio/inference/#http-api",
- )
- gr.HTML(footer, elem_id="footer")
- vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
- llama_page.select(lambda: "LLAMA", None, model_type_radio)
- add_button.click(
- fn=add_item,
- inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
- outputs=[checkbox_group, error],
- )
- remove_button.click(
- fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
- )
- checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
- help_button.click(
- fn=None,
- js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
- 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
- )
- if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
- if_initial_prompt.change(
- fn=lambda x: gr.Textbox(value="", interactive=x),
- inputs=[if_initial_prompt],
- outputs=[initial_prompt],
- )
- train_btn.click(
- fn=train_process,
- inputs=[
- train_box,
- model_type_radio,
- # llama config
- llama_ckpt,
- llama_base_config,
- llama_lr_slider,
- llama_maxsteps_slider,
- llama_data_num_workers_slider,
- llama_data_batch_size_slider,
- llama_data_max_length_slider,
- llama_precision_dropdown,
- llama_check_interval_slider,
- llama_grad_batches,
- llama_use_speaker,
- llama_use_lora,
- ],
- outputs=[train_error],
- )
- if_tb.change(
- fn=tensorboard_process,
- inputs=[if_tb, tb_dir, tb_host, tb_port],
- outputs=[train_error],
- )
- tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
- infer_decoder_model.change(
- fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
- )
- infer_llama_model.change(
- fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
- )
- llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
- admit_btn.click(
- fn=check_files,
- inputs=[train_box, tree_slider, label_model, label_device],
- outputs=[error, file_markdown],
- )
- fresh_btn.click(
- fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
- )
- llama_use_lora.change(
- fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
- )
- llama_ckpt.change(
- fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
- )
- lora_weight.change(
- fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
- inputs=[],
- outputs=[lora_weight],
- )
- llama_lora_merge_btn.click(
- fn=llama_lora_merge,
- inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
- outputs=[train_error],
- )
- llama_quantify_btn.click(
- fn=llama_quantify,
- inputs=[llama_weight_to_quantify, quantify_mode],
- outputs=[train_error],
- )
- infer_checkbox.change(
- fn=change_infer,
- inputs=[
- infer_checkbox,
- infer_host_textbox,
- infer_port_textbox,
- infer_decoder_model,
- infer_decoder_config,
- infer_llama_model,
- infer_compile,
- ],
- outputs=[infer_error],
- )
-
-demo.launch(inbrowser=True)