uto1125 committed on
Commit 371f188 · verified · 1 Parent(s): cab1c25

Delete fish_speech

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. fish_speech/__pycache__/conversation.cpython-310.pyc +0 -0
  2. fish_speech/__pycache__/scheduler.cpython-310.pyc +0 -0
  3. fish_speech/callbacks/__init__.py +0 -3
  4. fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc +0 -0
  5. fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc +0 -0
  6. fish_speech/callbacks/grad_norm.py +0 -113
  7. fish_speech/configs/base.yaml +0 -87
  8. fish_speech/configs/firefly_gan_vq.yaml +0 -33
  9. fish_speech/configs/lora/r_8_alpha_16.yaml +0 -4
  10. fish_speech/configs/text2semantic_finetune.yaml +0 -83
  11. fish_speech/conversation.py +0 -2
  12. fish_speech/datasets/__pycache__/semantic.cpython-310.pyc +0 -0
  13. fish_speech/datasets/concat_repeat.py +0 -53
  14. fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc +0 -0
  15. fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc +0 -0
  16. fish_speech/datasets/protos/text-data.proto +0 -24
  17. fish_speech/datasets/protos/text_data_pb2.py +0 -33
  18. fish_speech/datasets/protos/text_data_stream.py +0 -36
  19. fish_speech/datasets/semantic.py +0 -496
  20. fish_speech/datasets/vqgan.py +0 -147
  21. fish_speech/i18n/README.md +0 -27
  22. fish_speech/i18n/__init__.py +0 -3
  23. fish_speech/i18n/__pycache__/__init__.cpython-310.pyc +0 -0
  24. fish_speech/i18n/__pycache__/core.cpython-310.pyc +0 -0
  25. fish_speech/i18n/core.py +0 -40
  26. fish_speech/i18n/locale/en_US.json +0 -122
  27. fish_speech/i18n/locale/es_ES.json +0 -122
  28. fish_speech/i18n/locale/ja_JP.json +0 -123
  29. fish_speech/i18n/locale/pt_BR.json +0 -133
  30. fish_speech/i18n/locale/zh_CN.json +0 -122
  31. fish_speech/i18n/scan.py +0 -122
  32. fish_speech/models/text2semantic/__init__.py +0 -0
  33. fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc +0 -0
  34. fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc +0 -0
  35. fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc +0 -0
  36. fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc +0 -0
  37. fish_speech/models/text2semantic/lit_module.py +0 -202
  38. fish_speech/models/text2semantic/llama.py +0 -779
  39. fish_speech/models/text2semantic/lora.py +0 -92
  40. fish_speech/models/vqgan/__init__.py +0 -0
  41. fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc +0 -0
  42. fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc +0 -0
  43. fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc +0 -0
  44. fish_speech/models/vqgan/modules/firefly.py +0 -596
  45. fish_speech/models/vqgan/modules/fsq.py +0 -116
  46. fish_speech/models/vqgan/utils.py +0 -94
  47. fish_speech/scheduler.py +0 -40
  48. fish_speech/text/__init__.py +0 -4
  49. fish_speech/text/__pycache__/__init__.cpython-310.pyc +0 -0
  50. fish_speech/text/__pycache__/clean.cpython-310.pyc +0 -0
fish_speech/__pycache__/conversation.cpython-310.pyc DELETED
Binary file (227 Bytes)
 
fish_speech/__pycache__/scheduler.cpython-310.pyc DELETED
Binary file (1.04 kB)
 
fish_speech/callbacks/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .grad_norm import GradNormMonitor
-
- __all__ = ["GradNormMonitor"]
 
fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (239 Bytes)
 
fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc DELETED
Binary file (3.79 kB)
 
fish_speech/callbacks/grad_norm.py DELETED
@@ -1,113 +0,0 @@
- from typing import Optional, Union
-
- import lightning.pytorch as pl
- import torch
- from lightning import LightningModule, Trainer
- from lightning.pytorch.callbacks import Callback
- from torch import Tensor, nn
- from torch.utils._foreach_utils import (
-     _group_tensors_by_device_and_dtype,
-     _has_foreach_support,
- )
-
-
- @torch.no_grad()
- def grad_norm(
-     parameters: Union[Tensor, list[Tensor]],
-     norm_type: float = 2.0,
- ) -> float:
-     """
-     Returns the norm of the gradients of the given parameters.
-
-     Args:
-         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-             single Tensor that will have gradients normalized
-         norm_type (float): type of the used p-norm.
-
-     Returns:
-         Total norm of the parameter gradients (viewed as a single vector).
-     """  # noqa: E501
-
-     if isinstance(parameters, Tensor):
-         parameters = [parameters]
-
-     grads = [p.grad for p in parameters if p.grad is not None]
-     if len(grads) == 0:
-         return None
-
-     first_device = grads[0].device
-     grouped_grads: dict[
-         tuple[torch.device, torch.dtype], list[list[Tensor]]
-     ] = _group_tensors_by_device_and_dtype(
-         [[g.detach() for g in grads]]
-     )  # type: ignore[assignment]
-
-     norms = []
-     for (device, _), ([grads], _) in grouped_grads.items():
-         if _has_foreach_support(grads, device=device):
-             norms.extend(torch._foreach_norm(grads, norm_type))
-         else:
-             norms.extend([torch.norm(g, norm_type) for g in grads])
-
-     return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
-
-
- class GradNormMonitor(Callback):
-     """
-     Callback that computes the gradient norm of the model parameters.
-     """
-
-     def __init__(
-         self,
-         norm_type: float = 2.0,
-         logging_interval: str = "step",
-         sub_module: Optional[Union[str, list[str]]] = None,
-     ) -> None:
-         """
-         Args:
-             norm_type (float): type of the used p-norm.
-             logging_interval (str): "step" or "epoch".
-         """
-         super().__init__()
-
-         self.norm_type = norm_type
-         self.logging_interval = logging_interval
-         self.sub_module = sub_module
-
-     def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
-         """
-         Computes the gradient norm of the model parameters and logs it to the logger.
-
-         Args:
-             trainer (Trainer): The trainer object
-             model (LightningModule): The current lightningModule
-         """
-
-         lightning_model = model
-
-         if self.sub_module is None:
-             return self.log_sub_module_grad_norm(lightning_model, model, "")
-
-         sub_modules = self.sub_module
-         if isinstance(sub_modules, str):
-             sub_modules = [sub_modules]
-
-         for sub_module in sub_modules:
-             self.log_sub_module_grad_norm(
-                 lightning_model, getattr(model, sub_module), f"/{sub_module}"
-             )
-
-     def log_sub_module_grad_norm(
-         self, lightning_model: LightningModule, model: nn.Module, path: str
-     ) -> None:
-         grad_norm_val = grad_norm(model.parameters(), self.norm_type)
-         if grad_norm_val is None:
-             return
-
-         on_step = self.logging_interval == "step"
-         lightning_model.log(
-             f"train{path}/grad_norm",
-             grad_norm_val,
-             on_step=on_step,
-             on_epoch=not on_step,
-         )
 
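For context, the callback removed above plugs into a standard Lightning Trainer. A minimal usage sketch (illustrative only: the trainer settings and the commented fit call are assumptions, and the import no longer resolves after this commit):

    import lightning.pytorch as pl

    from fish_speech.callbacks import GradNormMonitor  # removed by this commit

    # Log the 2-norm of all parameter gradients as "train/grad_norm" on every step.
    trainer = pl.Trainer(
        max_steps=100,
        callbacks=[GradNormMonitor(norm_type=2.0, logging_interval="step")],
    )
    # trainer.fit(model, datamodule)  # any LightningModule / DataModule pair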
fish_speech/configs/base.yaml DELETED
@@ -1,87 +0,0 @@
- # Base configuration for training a model
- paths:
-   run_dir: results/${project}
-   ckpt_dir: ${paths.run_dir}/checkpoints
-
- hydra:
-   run:
-     dir: ${paths.run_dir}
-
- # Lightning Trainer
- trainer:
-   _target_: lightning.pytorch.trainer.Trainer
-
-   default_root_dir: ${paths.run_dir}
-   accelerator: gpu
-   num_nodes: 1
-   devices: auto
-   strategy:
-     _target_: lightning.pytorch.strategies.DDPStrategy
-     process_group_backend: nccl # This should be override when training on windows
-
-   precision: bf16-mixed
-
-   # disable validation by epoch end
-   check_val_every_n_epoch: null
-   val_check_interval: 5000
-   max_steps: 100_000
-
-   # Use torch.backends.cudnn.benchmark to speed up training
-   benchmark: true
-
- # Callbacks
- callbacks:
-   model_checkpoint:
-     _target_: lightning.pytorch.callbacks.ModelCheckpoint
-     dirpath: ${paths.ckpt_dir}
-     filename: "step_{step:09d}"
-     save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
-     save_top_k: 5 # save 5 latest checkpoints
-     monitor: step # use step to monitor checkpoints
-     mode: max # save the latest checkpoint with the highest global_step
-     every_n_epochs: null # don't save checkpoints by epoch end
-     every_n_train_steps: 5000 # save checkpoints every 5000 steps
-     auto_insert_metric_name: false
-
-   model_summary:
-     _target_: lightning.pytorch.callbacks.ModelSummary
-     max_depth: 2 # the maximum depth of layer nesting that the summary will include
-
-   learning_rate_monitor:
-     _target_: lightning.pytorch.callbacks.LearningRateMonitor
-     logging_interval: step
-     log_momentum: false
-
-   grad_norm_monitor:
-     _target_: fish_speech.callbacks.GradNormMonitor
-     norm_type: 2
-     logging_interval: step
-
- # Logger
- logger:
-   tensorboard:
-     _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
-     save_dir: "${paths.run_dir}/tensorboard/"
-     name: null
-     log_graph: false
-     default_hp_metric: true
-     prefix: ""
-
-   # wandb:
-   #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
-   #   # name: "" # name of the run (normally generated by wandb)
-   #   save_dir: "${paths.run_dir}"
-   #   offline: False
-   #   id: null # pass correct id to resume experiment!
-   #   anonymous: null # enable anonymous logging
-   #   project: "fish-speech"
-   #   log_model: False # upload lightning ckpts
-   #   prefix: "" # a string to put at the beginning of metric keys
-   #   # entity: "" # set to name of your wandb team
-   #   group: ""
-   #   tags: ["vq", "hq", "finetune"]
-   #   job_type: ""
-
- # Loop
- train: true
- test: false
 
fish_speech/configs/firefly_gan_vq.yaml DELETED
@@ -1,33 +0,0 @@
- _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
- spec_transform:
-   _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-   sample_rate: 44100
-   n_mels: 160
-   n_fft: 2048
-   hop_length: 512
-   win_length: 2048
- backbone:
-   _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
-   input_channels: 160
-   depths: [3, 3, 9, 3]
-   dims: [128, 256, 384, 512]
-   drop_path_rate: 0.2
-   kernel_size: 7
- head:
-   _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
-   hop_length: 512
-   upsample_rates: [8, 8, 2, 2, 2] # aka. strides
-   upsample_kernel_sizes: [16, 16, 4, 4, 4]
-   resblock_kernel_sizes: [3, 7, 11]
-   resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-   num_mels: 512
-   upsample_initial_channel: 512
-   pre_conv_kernel_size: 13
-   post_conv_kernel_size: 13
- quantizer:
-   _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-   input_dim: 512
-   n_groups: 8
-   n_codebooks: 1
-   levels: [8, 5, 5, 5]
-   downsample_factor: [2, 2]
 
fish_speech/configs/lora/r_8_alpha_16.yaml DELETED
@@ -1,4 +0,0 @@
- _target_: fish_speech.models.text2semantic.lora.LoraConfig
- r: 8
- lora_alpha: 16
- lora_dropout: 0.01
 
fish_speech/configs/text2semantic_finetune.yaml DELETED
@@ -1,83 +0,0 @@
- defaults:
-   - base
-   - _self_
-
- project: text2semantic_finetune_dual_ar
- max_length: 4096
- pretrained_ckpt_path: checkpoints/fish-speech-1.4
-
- # Lightning Trainer
- trainer:
-   accumulate_grad_batches: 1
-   gradient_clip_val: 1.0
-   gradient_clip_algorithm: "norm"
-   max_steps: 1000
-   precision: bf16-true
-   limit_val_batches: 10
-   val_check_interval: 100
-
- # Dataset Configuration
- tokenizer:
-   _target_: transformers.AutoTokenizer.from_pretrained
-   pretrained_model_name_or_path: ${pretrained_ckpt_path}
-
- # Dataset Configuration
- train_dataset:
-   _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
-   proto_files:
-     - data/protos
-   tokenizer: ${tokenizer}
-   causal: true
-   max_length: ${max_length}
-   use_speaker: false
-   interactive_prob: 0.7
-
- val_dataset:
-   _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
-   proto_files:
-     - data/protos
-   tokenizer: ${tokenizer}
-   causal: true
-   max_length: ${max_length}
-   use_speaker: false
-   interactive_prob: 0.7
-
- data:
-   _target_: fish_speech.datasets.semantic.SemanticDataModule
-   train_dataset: ${train_dataset}
-   val_dataset: ${val_dataset}
-   num_workers: 4
-   batch_size: 8
-   tokenizer: ${tokenizer}
-   max_length: ${max_length}
-
- # Model Configuration
- model:
-   _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
-   model:
-     _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
-     path: ${pretrained_ckpt_path}
-     load_weights: true
-     max_length: ${max_length}
-     lora_config: null
-
-   optimizer:
-     _target_: torch.optim.AdamW
-     _partial_: true
-     lr: 1e-4
-     weight_decay: 0
-     betas: [0.9, 0.95]
-     eps: 1e-5
-
-   lr_scheduler:
-     _target_: torch.optim.lr_scheduler.LambdaLR
-     _partial_: true
-     lr_lambda:
-       _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
-       _partial_: true
-       num_warmup_steps: 10
-
- # Callbacks
- callbacks:
-   model_checkpoint:
-     every_n_train_steps: ${trainer.val_check_interval}
 
fish_speech/conversation.py DELETED
@@ -1,2 +0,0 @@
- SEMANTIC_TOKEN = "<|semantic|>"
- CODEBOOK_PAD_TOKEN_ID = 0
 
fish_speech/datasets/__pycache__/semantic.cpython-310.pyc DELETED
Binary file (12.4 kB)
 
fish_speech/datasets/concat_repeat.py DELETED
@@ -1,53 +0,0 @@
- import bisect
- import random
- from typing import Iterable
-
- from torch.utils.data import Dataset, IterableDataset
-
-
- class ConcatRepeatDataset(Dataset):
-     datasets: list[Dataset]
-     cumulative_sizes: list[int]
-     repeats: list[int]
-
-     @staticmethod
-     def cumsum(sequence, repeats):
-         r, s = [], 0
-         for dataset, repeat in zip(sequence, repeats):
-             l = len(dataset) * repeat
-             r.append(l + s)
-             s += l
-         return r
-
-     def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
-         super().__init__()
-
-         self.datasets = list(datasets)
-         self.repeats = repeats
-
-         assert len(self.datasets) > 0, "datasets should not be an empty iterable"
-         assert len(self.datasets) == len(
-             repeats
-         ), "datasets and repeats should have the same length"
-
-         for d in self.datasets:
-             assert not isinstance(
-                 d, IterableDataset
-             ), "ConcatRepeatDataset does not support IterableDataset"
-
-         self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
-
-     def __len__(self):
-         return self.cumulative_sizes[-1]
-
-     def __getitem__(self, idx):
-         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
-
-         if dataset_idx == 0:
-             sample_idx = idx
-         else:
-             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
-
-         dataset = self.datasets[dataset_idx]
-
-         return dataset[sample_idx % len(dataset)]
 
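A small usage sketch of the removed ConcatRepeatDataset (illustrative only: the toy RangeDataset and the repeat counts are made up, and the import no longer resolves after this commit):

    from torch.utils.data import Dataset

    from fish_speech.datasets.concat_repeat import ConcatRepeatDataset  # removed by this commit


    class RangeDataset(Dataset):
        """Toy map-style dataset returning the integers 0..n-1."""

        def __init__(self, n: int):
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return idx


    big = RangeDataset(1000)
    small = RangeDataset(10)

    # The small dataset is counted 100 times, so both contribute 1000 samples;
    # indexing into it wraps around via `sample_idx % len(dataset)`.
    combined = ConcatRepeatDataset([big, small], repeats=[1, 100])
    print(len(combined))    # 2000
    print(combined[1999])   # 9, the last sample of the repeated small dataset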
fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc DELETED
Binary file (1.26 kB)
 
fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc DELETED
Binary file (1.13 kB)
 
fish_speech/datasets/protos/text-data.proto DELETED
@@ -1,24 +0,0 @@
- syntax = "proto3";
-
- package text_data;
-
- message Semantics {
-   repeated uint32 values = 1;
- }
-
- message Sentence {
-   repeated string texts = 1;
-   repeated Semantics semantics = 3;
- }
-
- message TextData {
-   string source = 1;
-   string name = 2;
-   repeated Sentence sentences = 4;
- }
-
- message SampledData {
-   string source = 1;
-   string name = 2;
-   repeated Sentence samples = 3;
- }
 
fish_speech/datasets/protos/text_data_pb2.py DELETED
@@ -1,33 +0,0 @@
- # -*- coding: utf-8 -*-
- # Generated by the protocol buffer compiler. DO NOT EDIT!
- # source: text-data.proto
- # Protobuf Python Version: 4.25.1
- """Generated protocol buffer code."""
- from google.protobuf import descriptor as _descriptor
- from google.protobuf import descriptor_pool as _descriptor_pool
- from google.protobuf import symbol_database as _symbol_database
- from google.protobuf.internal import builder as _builder
-
- # @@protoc_insertion_point(imports)
-
- _sym_db = _symbol_database.Default()
-
-
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-     b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
- )
-
- _globals = globals()
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
- if _descriptor._USE_C_DESCRIPTORS == False:
-     DESCRIPTOR._options = None
-     _globals["_SEMANTICS"]._serialized_start = 30
-     _globals["_SEMANTICS"]._serialized_end = 57
-     _globals["_SENTENCE"]._serialized_start = 59
-     _globals["_SENTENCE"]._serialized_end = 125
-     _globals["_TEXTDATA"]._serialized_start = 127
-     _globals["_TEXTDATA"]._serialized_end = 207
-     _globals["_SAMPLEDDATA"]._serialized_start = 209
-     _globals["_SAMPLEDDATA"]._serialized_end = 290
- # @@protoc_insertion_point(module_scope)
 
fish_speech/datasets/protos/text_data_stream.py DELETED
@@ -1,36 +0,0 @@
- import struct
-
- from .text_data_pb2 import TextData
-
-
- def read_pb_stream(f):
-     while True:
-         buf = f.read(4)
-         if len(buf) == 0:
-             break
-         size = struct.unpack("I", buf)[0]
-         buf = f.read(size)
-         text_data = TextData()
-         text_data.ParseFromString(buf)
-         yield text_data
-
-
- def write_pb_stream(f, text_data):
-     buf = text_data.SerializeToString()
-     f.write(struct.pack("I", len(buf)))
-     f.write(buf)
-
-
- def pack_pb_stream(text_data):
-     buf = text_data.SerializeToString()
-     return struct.pack("I", len(buf)) + buf
-
-
- def split_pb_stream(f):
-     while True:
-         head = f.read(4)
-         if len(head) == 0:
-             break
-         size = struct.unpack("I", head)[0]
-         buf = f.read(size)
-         yield head + buf
 
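The helpers above implement a simple length-prefixed stream of TextData protobuf messages: a 4-byte native-endian unsigned length followed by the serialized payload. A minimal round-trip sketch (illustrative only: the file name and field values are made up, and both modules are removed by this commit):

    from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
    from fish_speech.datasets.protos.text_data_stream import read_pb_stream, write_pb_stream

    # One record: a single sentence with its text and semantic token ids.
    record = TextData(
        source="demo",
        name="speaker_0",
        sentences=[
            Sentence(texts=["hello world"], semantics=[Semantics(values=[1, 2, 3])])
        ],
    )

    # Append the record to a stream file, then read every record back.
    with open("demo.protos", "wb") as f:
        write_pb_stream(f, record)

    with open("demo.protos", "rb") as f:
        for item in read_pb_stream(f):
            print(item.name, item.sentences[0].texts[0])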
fish_speech/datasets/semantic.py DELETED
@@ -1,496 +0,0 @@
1
- import random
2
- from dataclasses import dataclass
3
- from itertools import chain
4
- from pathlib import Path
5
- from random import Random
6
- from typing import Optional, Union
7
-
8
- import numpy as np
9
- import pyarrow.parquet as pq
10
- import torch
11
- import torch.nn.functional as F
12
- from datasets.download.streaming_download_manager import xopen
13
- from huggingface_hub import HfApi
14
- from lightning import LightningDataModule
15
- from torch.distributed import get_rank, get_world_size, is_initialized
16
- from torch.utils.data import DataLoader, IterableDataset, get_worker_info
17
- from transformers import AutoTokenizer
18
-
19
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
- from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
- from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
- from fish_speech.text.clean import clean_text
23
- from fish_speech.utils import RankedLogger
24
- from fish_speech.utils.braceexpand import braceexpand
25
-
26
- log = RankedLogger(__name__, rank_zero_only=True)
27
-
28
-
29
- def split_by_rank_worker(files):
30
- # We need to know the total number of devices
31
- # to split the data properly
32
-
33
- total_devices = 1
34
- if is_initialized():
35
- total_devices = get_world_size()
36
-
37
- worker_info = get_worker_info()
38
- if worker_info is not None:
39
- total_devices *= worker_info.num_workers
40
-
41
- if len(files) < total_devices:
42
- # Repeat the files N times to match the number of devices
43
- files = files * (total_devices // len(files) + 1)
44
-
45
- # DDP
46
- if is_initialized():
47
- files = files[get_rank() :: get_world_size()]
48
-
49
- # Split by worker
50
- if worker_info is not None:
51
- files = files[worker_info.id :: worker_info.num_workers]
52
-
53
- return files
54
-
55
-
56
- class AutoTextSemanticInstructionDataset(IterableDataset):
57
- """
58
- Auto Augment Dataset by Speaker
59
-
60
- 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
61
- 2. Automatically normalize the text
62
-
63
- For interactive mode, we use the following format (multiple sequences):
64
- <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
65
-
66
- For non-interactive mode, we use the following format (one long sequence):
67
- <s> [INST] text [/INST] ... </s>
68
- """
69
-
70
- def __init__(
71
- self,
72
- proto_files: list[str],
73
- seed: int = 42,
74
- interactive_prob: float = 0.5,
75
- max_length: int = 1024,
76
- tokenizer: AutoTokenizer = None,
77
- use_speaker: bool | float = True,
78
- causal: bool = True,
79
- num_codebooks: Optional[int] = None,
80
- skip_text_prob: float = 0.0,
81
- ):
82
- """
83
- Args:
84
- proto_files: proto buf files if using local data
85
- seed: random seed
86
- interactive_prob: probability to use interactive mode
87
- max_length: max length of the text
88
- tokenizer: tokenizer
89
- use_speaker: include speaker information in the prompt
90
- causal: use causal sampling when using local data, disable will lead to random sampling
91
- num_codebooks: number of codebooks, if None, it will be automatically detected
92
- skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
93
- """
94
-
95
- super().__init__()
96
-
97
- assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
98
-
99
- self.seed = seed
100
- self.max_length = max_length
101
- self.tokenizer = tokenizer
102
- self.interactive_prob = interactive_prob
103
- self.use_speaker = use_speaker
104
- self.proto_files = proto_files
105
- self.causal = causal
106
- self.num_codebooks = num_codebooks
107
- self.skip_text_prob = skip_text_prob
108
-
109
- self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
110
- self.groups = None
111
-
112
- def init_mock_data_server(self):
113
- if self.groups is not None:
114
- return
115
-
116
- # Expand the proto files
117
- expanded_proto_files = []
118
- for filename in self.proto_files:
119
- for i in braceexpand(filename):
120
- i = Path(i)
121
- if i.is_file():
122
- expanded_proto_files.append(i)
123
- elif i.is_dir():
124
- expanded_proto_files.extend(i.rglob("*.proto"))
125
- expanded_proto_files.extend(i.rglob("*.protos"))
126
- else:
127
- raise ValueError(f"{i} is not a file or directory")
128
-
129
- expanded_proto_files = sorted(expanded_proto_files)
130
- Random(self.seed).shuffle(expanded_proto_files)
131
-
132
- self.groups = []
133
- shard_proto_files = split_by_rank_worker(expanded_proto_files)
134
- log.info(
135
- f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
136
- )
137
-
138
- count = 0
139
- for filename in shard_proto_files:
140
- with open(filename, "rb") as f:
141
- for text_data in read_pb_stream(f):
142
- self.groups.append(text_data)
143
- count += 1
144
-
145
- log.info(f"Read total {count} groups of data")
146
-
147
- # Shuffle the lines
148
- Random(self.seed).shuffle(self.groups)
149
- self.group_weights = [len(i.sentences) for i in self.groups]
150
-
151
- def __iter__(self):
152
- while True:
153
- yield self.augment()
154
-
155
- def tokenize_sentence(self, sentence: str):
156
- sentence = clean_text(sentence)
157
- tokens = self.tokenizer.encode(
158
- f"{sentence}",
159
- max_length=10**6,
160
- add_special_tokens=False,
161
- truncation=False,
162
- )
163
- return sentence, len(tokens)
164
-
165
- def sample_data(self):
166
- if self.groups is None:
167
- self.init_mock_data_server()
168
-
169
- # Shuffle unique lines, estimate that each sample is at least 20 tokens
170
- num_samples = self.max_length // 20
171
-
172
- # choice group based on their number of samples
173
- group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
174
-
175
- if self.causal:
176
- # Sample in order
177
- if num_samples >= len(group.sentences):
178
- samples = group.sentences
179
- else:
180
- begin = random.randint(0, len(group.sentences) - num_samples)
181
- samples = group.sentences[begin : begin + num_samples]
182
- else:
183
- samples = random.choices(
184
- group.sentences, k=min(num_samples, len(group.sentences))
185
- )
186
-
187
- return SampledData(
188
- source=group.source,
189
- name=group.name,
190
- samples=samples,
191
- )
192
-
193
- def augment(self):
194
- final_text, final_semantic = [], []
195
- response = self.sample_data()
196
- if len(response.samples) == 0:
197
- # Invalid group
198
- return None
199
-
200
- samples = list(response.samples)
201
- idx = 0
202
- use_interactive = random.random() < self.interactive_prob
203
-
204
- if use_interactive is False:
205
- # Random sample based on speaker using a truncated normal distribution
206
- a = torch.tensor([0], dtype=torch.float32)
207
- torch.nn.init.trunc_normal_(
208
- a,
209
- mean=self.max_length // 2,
210
- std=self.max_length // 4,
211
- a=10,
212
- b=self.max_length,
213
- )
214
- remaining_tokens = a.long().item() - 4
215
- else:
216
- remaining_tokens = self.max_length
217
-
218
- # Use speaker
219
- if isinstance(self.use_speaker, float):
220
- use_speaker = random.random() < self.use_speaker
221
- else:
222
- use_speaker = self.use_speaker
223
-
224
- all_tokens, all_labels = [], []
225
- while remaining_tokens > 0 and len(samples) > 0:
226
- sentence = samples.pop(0)
227
-
228
- text = random.choice(sentence.texts)
229
- text, length = self.tokenize_sentence(text)
230
- remaining_tokens -= length + len(sentence.semantics[0].values)
231
-
232
- if use_interactive is False:
233
- final_text.append(text)
234
- final_semantic.append(sentence.semantics)
235
- else:
236
- # For interactive mode, we only apply speaker for the first sentence
237
- # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
238
- tokens, labels = self.pack_sentences(
239
- sentences=[text],
240
- semantics=[sentence.semantics],
241
- speaker=response.name if use_speaker else None,
242
- skip_text=random.random() < self.skip_text_prob,
243
- )
244
-
245
- all_tokens.append(tokens)
246
- all_labels.append(labels)
247
-
248
- idx += 1
249
-
250
- if use_interactive is False:
251
- tokens, labels = self.pack_sentences(
252
- final_text,
253
- semantics=final_semantic,
254
- speaker=response.name if use_speaker else None,
255
- )
256
- all_tokens.append(tokens)
257
- all_labels.append(labels)
258
-
259
- tokens = torch.cat(all_tokens, dim=1)
260
- labels = torch.cat(all_labels, dim=1)
261
-
262
- # Verify that the length is correct
263
- assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
264
-
265
- data = {"tokens": tokens, "labels": labels}
266
-
267
- return data
268
-
269
- def pack_sentences(
270
- self,
271
- sentences: list[str],
272
- semantics: list,
273
- speaker: Optional[str] = None,
274
- skip_text: bool = False,
275
- ):
276
- if speaker is None:
277
- speaker = "assistant"
278
-
279
- cated_sentences = " ".join(sentences)
280
- if skip_text:
281
- cated_sentences = "<|skip_text|>"
282
-
283
- final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
284
- final_text = final_text + f"<|im_start|>{speaker}\n"
285
-
286
- encoded = self.tokenizer.encode(
287
- final_text,
288
- add_special_tokens=False,
289
- truncation=False,
290
- max_length=10**6,
291
- )
292
- semantic_length = sum([len(i[0].values) for i in semantics])
293
- prompt_length = len(encoded)
294
- num_codebooks = (
295
- len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
296
- )
297
-
298
- # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
299
- tokens = (
300
- encoded
301
- + [self.semantic_token_id] * semantic_length
302
- + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
303
- )
304
-
305
- # Codebook bos/padding: 0, eos: 1
306
- codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
307
- for segment in semantics:
308
- for book_idx, book in zip(range(num_codebooks), segment):
309
- for j in book.values:
310
- codes[book_idx].append(int(j) + 1)
311
-
312
- for book in codes:
313
- book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
314
-
315
- tokens = [tokens] + codes
316
-
317
- tokens = torch.tensor(tokens, dtype=torch.long)
318
- labels = tokens.clone()
319
-
320
- if skip_text:
321
- # If text is not provided, the sentence is used for condition only, all labels are -100
322
- torch.fill_(labels, -100)
323
- return tokens, labels
324
-
325
- # Mask out the <s> tokens for semantic, predict semantic tokens only
326
- # Since we don't mask out the input tokens, the language modeling still works
327
- labels[1:, :prompt_length] = -100
328
-
329
- tokens = tokens[:, :-1]
330
- labels = labels[:, 1:]
331
-
332
- # Verify the padding is correct, and the last token is eos
333
- assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
334
- assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
335
-
336
- return tokens, labels
337
-
338
-
339
- @dataclass
340
- class TextDataCollator:
341
- tokenizer: AutoTokenizer
342
- max_length: int = 1024
343
-
344
- def __call__(self, examples):
345
- if "negative_tokens" in examples:
346
- positive_examples = []
347
- negative_examples = []
348
-
349
- for i in examples:
350
- positive_examples.append(
351
- {
352
- "tokens": i["tokens"],
353
- "labels": i["labels"],
354
- }
355
- )
356
- negative_examples.append(
357
- {
358
- "tokens": i["negative_tokens"],
359
- "labels": i["negative_labels"],
360
- }
361
- )
362
-
363
- examples = positive_examples + negative_examples
364
-
365
- return self.batchify(examples)
366
-
367
- def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
368
- tokens, attention_masks, labels = [], [], []
369
-
370
- # Calculate the max length
371
- max_tokens_length = 0
372
- for example in examples:
373
- max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
374
- max_tokens_length = min(max_tokens_length, self.max_length)
375
-
376
- for example in examples:
377
- _tokens = example[tokens_key][:, :max_tokens_length]
378
- _labels = example[labels_key][:, :max_tokens_length]
379
- _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
380
- tokens_length = _tokens.size(1)
381
- _attention_mask[:tokens_length] = False
382
-
383
- assert tokens_length == _labels.size(
384
- 1
385
- ), f"{tokens_length} != {_labels.size(1)}"
386
-
387
- if tokens_length < max_tokens_length:
388
- _tokens = F.pad(
389
- _tokens,
390
- (0, max_tokens_length - tokens_length),
391
- value=self.tokenizer.eos_token_id,
392
- )
393
- _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
394
- _labels = F.pad(
395
- _labels, (0, max_tokens_length - _labels.size(1)), value=-100
396
- )
397
-
398
- tokens.append(_tokens)
399
- attention_masks.append(_attention_mask)
400
- labels.append(_labels)
401
-
402
- tokens = torch.stack(tokens, dim=0)
403
- attention_masks = torch.stack(attention_masks, dim=0)
404
- labels = torch.stack(labels, dim=0)
405
-
406
- return {
407
- "inputs": tokens,
408
- "attention_masks": attention_masks,
409
- "labels": labels,
410
- }
411
-
412
-
413
- class InterleaveDataset(IterableDataset):
414
- def __init__(
415
- self,
416
- datasets: list[IterableDataset],
417
- probabilities: list[float],
418
- seed: int = 42,
419
- ):
420
- super().__init__()
421
-
422
- self.datasets = datasets
423
- self.probabilities = probabilities
424
- self.seed = seed
425
-
426
- def __iter__(self):
427
- rng = np.random.default_rng(self.seed)
428
- dataset_iterators = [iter(dataset) for dataset in self.datasets]
429
-
430
- while True:
431
- # Random choice one
432
- dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
433
- dataset_iterator = dataset_iterators[dataset_idx]
434
-
435
- try:
436
- yield next(dataset_iterator)
437
- except StopIteration:
438
- # Exhausted, create a new iterator
439
- dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
440
- yield next(dataset_iterators[dataset_idx])
441
-
442
-
443
- class SemanticDataModule(LightningDataModule):
444
- def __init__(
445
- self,
446
- train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
447
- val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
448
- batch_size: int = 32,
449
- tokenizer: AutoTokenizer = None,
450
- max_length: int = 1024,
451
- num_workers: int = 4,
452
- ):
453
- super().__init__()
454
-
455
- self.train_dataset = train_dataset
456
- self.val_dataset = val_dataset
457
- self.batch_size = batch_size
458
- self.tokenizer = tokenizer
459
- self.max_length = max_length
460
- self.num_workers = num_workers
461
-
462
- def train_dataloader(self):
463
- return DataLoader(
464
- self.train_dataset,
465
- batch_size=self.batch_size,
466
- collate_fn=TextDataCollator(self.tokenizer, self.max_length),
467
- num_workers=self.num_workers,
468
- persistent_workers=True,
469
- )
470
-
471
- def val_dataloader(self):
472
- return DataLoader(
473
- self.val_dataset,
474
- batch_size=self.batch_size,
475
- collate_fn=TextDataCollator(self.tokenizer, self.max_length),
476
- num_workers=self.num_workers,
477
- persistent_workers=True,
478
- )
479
-
480
-
481
- if __name__ == "__main__":
482
- from tqdm import tqdm
483
-
484
- ds = AutoTextSemanticInstructionDataset(
485
- ["data/protos"],
486
- tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
487
- use_speaker=False,
488
- interactive_prob=1.0,
489
- skip_text_prob=0.5,
490
- )
491
-
492
- for i in ds:
493
- print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
494
- # i["labels"][0][i["labels"][0] == -100] = 0
495
- # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
496
- break
 
fish_speech/datasets/vqgan.py DELETED
@@ -1,147 +0,0 @@
1
- from dataclasses import dataclass
2
- from pathlib import Path
3
- from typing import Optional
4
-
5
- import librosa
6
- import numpy as np
7
- import torch
8
- from lightning import LightningDataModule
9
- from torch.utils.data import DataLoader, Dataset
10
-
11
- from fish_speech.utils import RankedLogger
12
-
13
- logger = RankedLogger(__name__, rank_zero_only=False)
14
-
15
-
16
- class VQGANDataset(Dataset):
17
- def __init__(
18
- self,
19
- filelist: str,
20
- sample_rate: int = 32000,
21
- hop_length: int = 640,
22
- slice_frames: Optional[int] = None,
23
- ):
24
- super().__init__()
25
-
26
- filelist = Path(filelist)
27
- root = filelist.parent
28
-
29
- self.files = [
30
- root / line.strip()
31
- for line in filelist.read_text(encoding="utf-8").splitlines()
32
- if line.strip()
33
- ]
34
- self.sample_rate = sample_rate
35
- self.hop_length = hop_length
36
- self.slice_frames = slice_frames
37
-
38
- def __len__(self):
39
- return len(self.files)
40
-
41
- def get_item(self, idx):
42
- file = self.files[idx]
43
-
44
- audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
45
-
46
- # Slice audio and features
47
- if (
48
- self.slice_frames is not None
49
- and audio.shape[0] > self.slice_frames * self.hop_length
50
- ):
51
- start = np.random.randint(
52
- 0, audio.shape[0] - self.slice_frames * self.hop_length
53
- )
54
- audio = audio[start : start + self.slice_frames * self.hop_length]
55
-
56
- if len(audio) == 0:
57
- return None
58
-
59
- max_value = np.abs(audio).max()
60
- if max_value > 1.0:
61
- audio = audio / max_value
62
-
63
- return {
64
- "audio": torch.from_numpy(audio),
65
- }
66
-
67
- def __getitem__(self, idx):
68
- try:
69
- return self.get_item(idx)
70
- except Exception as e:
71
- import traceback
72
-
73
- traceback.print_exc()
74
- logger.error(f"Error loading {self.files[idx]}: {e}")
75
- return None
76
-
77
-
78
- @dataclass
79
- class VQGANCollator:
80
- def __call__(self, batch):
81
- batch = [x for x in batch if x is not None]
82
-
83
- audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
84
- audio_maxlen = audio_lengths.max()
85
-
86
- # Rounds up to nearest multiple of 2 (audio_lengths)
87
- audios = []
88
- for x in batch:
89
- audios.append(
90
- torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
91
- )
92
-
93
- return {
94
- "audios": torch.stack(audios),
95
- "audio_lengths": audio_lengths,
96
- }
97
-
98
-
99
- class VQGANDataModule(LightningDataModule):
100
- def __init__(
101
- self,
102
- train_dataset: VQGANDataset,
103
- val_dataset: VQGANDataset,
104
- batch_size: int = 32,
105
- num_workers: int = 4,
106
- val_batch_size: Optional[int] = None,
107
- ):
108
- super().__init__()
109
-
110
- self.train_dataset = train_dataset
111
- self.val_dataset = val_dataset
112
- self.batch_size = batch_size
113
- self.val_batch_size = val_batch_size or batch_size
114
- self.num_workers = num_workers
115
-
116
- def train_dataloader(self):
117
- return DataLoader(
118
- self.train_dataset,
119
- batch_size=self.batch_size,
120
- collate_fn=VQGANCollator(),
121
- num_workers=self.num_workers,
122
- shuffle=True,
123
- persistent_workers=True,
124
- )
125
-
126
- def val_dataloader(self):
127
- return DataLoader(
128
- self.val_dataset,
129
- batch_size=self.val_batch_size,
130
- collate_fn=VQGANCollator(),
131
- num_workers=self.num_workers,
132
- persistent_workers=True,
133
- )
134
-
135
-
136
- if __name__ == "__main__":
137
- dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
138
- dataloader = DataLoader(
139
- dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
140
- )
141
-
142
- for batch in dataloader:
143
- print(batch["audios"].shape)
144
- print(batch["features"].shape)
145
- print(batch["audio_lengths"])
146
- print(batch["feature_lengths"])
147
- break
 
fish_speech/i18n/README.md DELETED
@@ -1,27 +0,0 @@
- ## i18n Folder Attribution
-
- The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
-
- ### fish_speech/i18n/core.py
-
- **Related code from RVC:**
- [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
-
- **Initial commit:**
- add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
-
- **Initial author:**
- [@L4Ph](https://github.com/L4Ph)
-
- ### fish_speech/i18n/scan.py
-
- **Related code from RVC:**
- [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
-
- **Initial commit:**
- File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
-
- **Initial author:**
- [@towzeur](https://github.com/towzeur)
-
- We appreciate the contributions of the RVC project and its authors.
 
fish_speech/i18n/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .core import i18n
-
- __all__ = ["i18n"]
 
fish_speech/i18n/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (218 Bytes)
 
fish_speech/i18n/__pycache__/core.cpython-310.pyc DELETED
Binary file (1.44 kB)
 
fish_speech/i18n/core.py DELETED
@@ -1,40 +0,0 @@
- import json
- import locale
- from pathlib import Path
-
- I18N_FILE_PATH = Path(__file__).parent / "locale"
- DEFAULT_LANGUAGE = "en_US"
-
-
- def load_language_list(language):
-     with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
-         language_list = json.load(f)
-
-     return language_list
-
-
- class I18nAuto:
-     def __init__(self):
-         i18n_file = Path(".locale")
-
-         if i18n_file.exists():
-             with open(i18n_file, "r", encoding="utf-8") as f:
-                 language = f.read().strip()
-         else:
-             # getlocale can't identify the system's language ((None, None))
-             language = locale.getdefaultlocale()[0]
-
-         if (I18N_FILE_PATH / f"{language}.json").exists() is False:
-             language = DEFAULT_LANGUAGE
-
-         self.language = language
-         self.language_map = load_language_list(language)
-
-     def __call__(self, key):
-         return self.language_map.get(key, key)
-
-     def __repr__(self):
-         return "Use Language: " + self.language
-
-
- i18n = I18nAuto()
 
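For reference, the removed helper resolves UI strings against the locale tables that follow, falling back to the key itself when no translation exists. A minimal sketch (illustrative only; the import no longer resolves after this commit):

    from fish_speech.i18n import i18n  # removed by this commit

    print(i18n("Start Training"))       # translated via the selected locale JSON
    print(i18n("A key with no entry"))  # unknown keys are returned unchanged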
fish_speech/i18n/locale/en_US.json DELETED
@@ -1,122 +0,0 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
- "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
- "Add to Processing Area": "Add to Processing Area",
7
- "Added path successfully!": "Added path successfully!",
8
- "Advanced Config": "Advanced Config",
9
- "Base LLAMA Model": "Base LLAMA Model",
10
- "Batch Inference": "Batch Inference",
11
- "Batch Size": "Batch Size",
12
- "Changing with the Model Path": "Changing with the Model Path",
13
- "Chinese": "Chinese",
14
- "Compile Model": "Compile Model",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
- "Copy": "Copy",
17
- "Data Preprocessing": "Data Preprocessing",
18
- "Data Preprocessing Path": "Data Preprocessing Path",
19
- "Data Source": "Data Source",
20
- "Decoder Model Config": "Decoder Model Config",
21
- "Decoder Model Path": "Decoder Model Path",
22
- "Disabled": "Disabled",
23
- "Enable Reference Audio": "Enable Reference Audio",
24
- "English": "English",
25
- "Error Message": "Error Message",
26
- "File Preprocessing": "File Preprocessing",
27
- "Generate": "Generate",
28
- "Generated Audio": "Generated Audio",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
- "Infer interface is closed": "Infer interface is closed",
31
- "Inference Configuration": "Inference Configuration",
32
- "Inference Server Configuration": "Inference Server Configuration",
33
- "Inference Server Error": "Inference Server Error",
34
- "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
- "Initial Learning Rate": "Initial Learning Rate",
36
- "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
- "Input Text": "Input Text",
38
- "Invalid path: {}": "Invalid path: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
- "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
- "Japanese": "Japanese",
42
- "LLAMA Configuration": "LLAMA Configuration",
43
- "LLAMA Model Config": "LLAMA Model Config",
44
- "LLAMA Model Path": "LLAMA Model Path",
45
- "Labeling Device": "Labeling Device",
46
- "LoRA Model to be merged": "LoRA Model to be merged",
47
- "Maximum Audio Duration": "Maximum Audio Duration",
48
- "Maximum Length per Sample": "Maximum Length per Sample",
49
- "Maximum Training Steps": "Maximum Training Steps",
50
- "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
- "Merge": "Merge",
52
- "Merge LoRA": "Merge LoRA",
53
- "Merge successfully": "Merge successfully",
54
- "Minimum Audio Duration": "Minimum Audio Duration",
55
- "Model Output Path": "Model Output Path",
56
- "Model Size": "Model Size",
57
- "Move": "Move",
58
- "Move files successfully": "Move files successfully",
59
- "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
- "No selected options": "No selected options",
61
- "Number of Workers": "Number of Workers",
62
- "Open Inference Server": "Open Inference Server",
63
- "Open Labeler WebUI": "Open Labeler WebUI",
64
- "Open Tensorboard": "Open Tensorboard",
65
- "Opened labeler in browser": "Opened labeler in browser",
66
- "Optional Label Language": "Optional Label Language",
67
- "Optional online ver": "Optional online ver",
68
- "Output Path": "Output Path",
69
- "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
- "Precision": "Precision",
71
- "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
- "Put your text here.": "Put your text here.",
73
- "Reference Audio": "Reference Audio",
74
- "Reference Text": "Reference Text",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
- "Remove Selected Data": "Remove Selected Data",
77
- "Removed path successfully!": "Removed path successfully!",
78
- "Repetition Penalty": "Repetition Penalty",
79
- "Save model every n steps": "Save model every n steps",
80
- "Select LLAMA ckpt": "Select LLAMA ckpt",
81
- "Select VITS ckpt": "Select VITS ckpt",
82
- "Select VQGAN ckpt": "Select VQGAN ckpt",
83
- "Select source file processing method": "Select source file processing method",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
- "Selected: {}": "Selected: {}",
86
- "Speaker": "Speaker",
87
- "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
- "Start Training": "Start Training",
89
- "Streaming Audio": "Streaming Audio",
90
- "Streaming Generate": "Streaming Generate",
91
- "Tensorboard Host": "Tensorboard Host",
92
- "Tensorboard Log Path": "Tensorboard Log Path",
93
- "Tensorboard Port": "Tensorboard Port",
94
- "Tensorboard interface is closed": "Tensorboard interface is closed",
95
- "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
- "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
- "Training Configuration": "Training Configuration",
99
- "Training Error": "Training Error",
100
- "Training stopped": "Training stopped",
101
- "Type name of the speaker": "Type name of the speaker",
102
- "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
- "Use LoRA": "Use LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
- "Use filelist": "Use filelist",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
- "VITS Configuration": "VITS Configuration",
108
- "VQGAN Configuration": "VQGAN Configuration",
109
- "Validation Batch Size": "Validation Batch Size",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
- "WebUI Host": "WebUI Host",
113
- "WebUI Port": "WebUI Port",
114
- "Whisper Model": "Whisper Model",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
- "latest": "latest",
118
- "new": "new",
119
- "Realtime Transform Text": "Realtime Transform Text",
120
- "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
- "Text Normalization": "Text Normalization"
122
- }
 
fish_speech/i18n/locale/es_ES.json DELETED
@@ -1,122 +0,0 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
- "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
- "Add to Processing Area": "Agregar al Área de Procesamiento",
7
- "Added path successfully!": "¡Ruta agregada exitosamente!",
8
- "Advanced Config": "Configuración Avanzada",
9
- "Base LLAMA Model": "Modelo Base LLAMA",
10
- "Batch Inference": "Inferencia por Lote",
11
- "Batch Size": "Tamaño del Lote",
12
- "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
- "Chinese": "Chino",
14
- "Compile Model": "Compilar Modelo",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
- "Copy": "Copiar",
17
- "Data Preprocessing": "Preprocesamiento de Datos",
18
- "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
- "Data Source": "Fuente de Datos",
20
- "Decoder Model Config": "Configuración del modelo decodificador",
21
- "Decoder Model Path": "Ruta del modelo decodificador",
22
- "Disabled": "Desactivado",
23
- "Enable Reference Audio": "Habilitar Audio de Referencia",
24
- "English": "Inglés",
25
- "Error Message": "Mensaje de Error",
26
- "File Preprocessing": "Preprocesamiento de Archivos",
27
- "Generate": "Generar",
28
- "Generated Audio": "Audio Generado",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
- "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
- "Inference Configuration": "Configuración de Inferencia",
32
- "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
- "Inference Server Error": "Error del Servidor de Inferencia",
34
- "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
- "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
- "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
- "Input Text": "Texto de Entrada",
38
- "Invalid path: {}": "Ruta inválida: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
- "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
- "Japanese": "Japonés",
42
- "LLAMA Configuration": "Configuración de LLAMA",
43
- "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
- "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
- "Labeling Device": "Dispositivo de Etiquetado",
46
- "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
- "Maximum Audio Duration": "Duración máxima de audio",
48
- "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
- "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
- "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
- "Merge": "Fusionar",
52
- "Merge LoRA": "Fusionar LoRA",
53
- "Merge successfully": "Fusionado exitosamente",
54
- "Minimum Audio Duration": "Duración mínima de audio",
55
- "Model Output Path": "Ruta de Salida del Modelo",
56
- "Model Size": "Tamaño del Modelo",
57
- "Move": "Mover",
58
- "Move files successfully": "Archivos movidos exitosamente",
59
- "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
- "No selected options": "No hay opciones seleccionadas",
61
- "Number of Workers": "Número de Trabajadores",
62
- "Open Inference Server": "Abrir Servidor de Inferencia",
63
- "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
- "Open Tensorboard": "Abrir Tensorboard",
65
- "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
- "Optional Label Language": "Idioma de Etiquetado Opcional",
67
- "Optional online ver": "Ver en línea opcional",
68
- "Output Path": "Ruta de Salida",
69
- "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
70
- "Precision": "Precisión",
71
- "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
- "Put your text here.": "Ponga su texto aquí.",
73
- "Reference Audio": "Audio de Referencia",
74
- "Reference Text": "Texto de Referencia",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
- "Remove Selected Data": "Eliminar Datos Seleccionados",
77
- "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
- "Repetition Penalty": "Penalización por Repetición",
79
- "Save model every n steps": "Guardar modelo cada n pasos",
80
- "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
- "Select VITS ckpt": "Seleccionar punto de control VITS",
82
- "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
- "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
- "Selected: {}": "Seleccionado: {}",
86
- "Speaker": "Hablante",
87
- "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
- "Start Training": "Iniciar Entrenamiento",
89
- "Streaming Audio": "transmisión de audio",
90
- "Streaming Generate": "síntesis en flujo",
91
- "Tensorboard Host": "Host de Tensorboard",
92
- "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
- "Tensorboard Port": "Puerto de Tensorboard",
94
- "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
- "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
- "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
- "Training Configuration": "Configuración de Entrenamiento",
99
- "Training Error": "Error de Entrenamiento",
100
- "Training stopped": "Entrenamiento detenido",
101
- "Type name of the speaker": "Escriba el nombre del hablante",
102
- "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
- "Use LoRA": "Usar LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
- "Use filelist": "Usar lista de archivos",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
- "VITS Configuration": "Configuración de VITS",
108
- "VQGAN Configuration": "Configuración de VQGAN",
109
- "Validation Batch Size": "Tamaño del Lote de Validación",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
- "WebUI Host": "Host de WebUI",
113
- "WebUI Port": "Puerto de WebUI",
114
- "Whisper Model": "Modelo Whisper",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
- "latest": "más reciente",
118
- "new": "nuevo",
119
- "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
- "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
- "Text Normalization": "Normalización de Texto"
122
- }
 
fish_speech/i18n/locale/ja_JP.json DELETED
@@ -1,123 +0,0 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
5
- "Accumulate Gradient Batches": "勾配バッチの累積",
6
- "Add to Processing Area": "処理エリアに追加",
7
- "Added path successfully!": "パスの追加に成功しました!",
8
- "Advanced Config": "詳細設定",
9
- "Base LLAMA Model": "基本LLAMAモデル",
10
- "Batch Inference": "バッチ推論",
11
- "Batch Size": "バッチサイズ",
12
- "Changing with the Model Path": "モデルのパスに伴って変化する",
13
- "Chinese": "中国語",
14
- "Compile Model": "モデルのコンパイル",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
- "Copy": "コピー",
17
- "Data Preprocessing": "データ前処理",
18
- "Data Preprocessing Path": "データ前処理パス",
19
- "Data Source": "データソース",
20
- "Decoder Model Config": "デコーダーモデルの構成",
21
- "Decoder Model Path": "デコーダーモデルのパス",
22
- "Disabled": "無効",
23
- "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
- "English": "英語",
25
- "Error Message": "エラーメッセージ",
26
- "File Preprocessing": "文書前处理",
27
- "Generate": "生成",
28
- "Generated Audio": "生成されたオーディオ",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
- "Infer interface is closed": "推論インターフェースが閉じられています",
31
- "Inference Configuration": "推論設定",
32
- "Inference Server Configuration": "推論サーバー設定",
33
- "Inference Server Error": "推論サーバーエラー",
34
- "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
- "Initial Learning Rate": "初期学習率",
36
- "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
- "Input Text": "入力テキスト",
38
- "Invalid path: {}": "無効なパス: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
- "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
- "Japanese": "日本語",
42
- "LLAMA Configuration": "LLAMA設定",
43
- "LLAMA Model Config": "LLAMAモデル設定",
44
- "LLAMA Model Path": "LLAMAモデルパス",
45
- "Labeling Device": "ラベリングデバイス",
46
- "LoRA Model to be merged": "マージするLoRAモデル",
47
- "Maximum Audio Duration": "最大オーディオの長さ",
48
- "Maximum Length per Sample": "サンプルあたりの最大長",
49
- "Maximum Training Steps": "最大トレーニングステップ数",
50
- "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
- "Merge": "マージ",
52
- "Merge LoRA": "LoRAのマージ",
53
- "Merge successfully": "マージに成功しました",
54
- "Minimum Audio Duration": "最小オーディオの長さ",
55
- "Model Output Path": "モデル出力パス",
56
- "Model Size": "モデルサイズ",
57
- "Move": "移動",
58
- "Move files successfully": "ファイルの移動に成功しました",
59
- "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
- "No selected options": "選択されたオプションはありません",
61
- "Number of Workers": "ワーカー数",
62
- "Open Inference Server": "推論サーバーを開く",
63
- "Open Labeler WebUI": "ラベラーWebUIを開く",
64
- "Open Tensorboard": "Tensorboardを開く",
65
- "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
- "Optional Label Language": "オプションのラベル言語",
67
- "Optional online ver": "オプションのオンラインバージョン",
68
- "Output Path": "出力パス",
69
- "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
- "Precision": "精度",
71
- "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
- "Put your text here.": "ここにテキストを入力してください。",
73
- "Reference Audio": "リファレンスオーディオ",
74
- "Reference Text": "リファレンステキスト",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
- "Remove Selected Data": "選択したデータを削除",
77
- "Removed path successfully!": "パスの削除に成功しました!",
78
- "Repetition Penalty": "反復ペナルティ",
79
- "Save model every n steps": "nステップごとにモデルを保存",
80
- "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
81
- "Select VITS ckpt": "VITS チェックポイントを選択",
82
- "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
- "Select source file processing method": "ソースファイルの処理方法を選択",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
- "Selected: {}": "選択済み: {}",
86
- "Speaker": "話者",
87
- "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
- "Start Training": "トレーニング開始",
89
- "Streaming Audio": "ストリーミングオーディオ",
90
- "Streaming Generate": "ストリーミング合成",
91
- "Tensorboard Host": "Tensorboardホスト",
92
- "Tensorboard Log Path": "Tensorboardログパス",
93
- "Tensorboard Port": "Tensorboardポート",
94
- "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
- "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
- "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
- "Training Configuration": "トレーニング設定",
99
- "Training Error": "トレーニングエラー",
100
- "Training stopped": "トレーニングが停止しました",
101
- "Type name of the speaker": "話者の名前を入力",
102
- "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
- "Use LoRA": "LoRAを使用",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
- "Use filelist": "ファイルリストを使用",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
- "VITS Configuration": "VITS の構成",
108
- "VQGAN Configuration": "VQGAN の構成",
109
- "Validation Batch Size": "検証バッチサイズ",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
- "WebUI Host": "WebUIホスト",
113
- "WebUI Port": "WebUIポート",
114
- "Whisper Model": "Whisperモデル",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
- "latest": "最新",
118
- "new": "新規",
119
- "Realtime Transform Text": "リアルタイム変換テキスト",
120
- "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
- "Text Normalization": "テキスト正規化"
122
-
123
- }
 
fish_speech/i18n/locale/pt_BR.json DELETED
@@ -1,133 +0,0 @@
1
- {
2
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
- "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
- "Add to Processing Area": "Adicionar à Área de Processamento",
6
- "Added path successfully!": "Caminho adicionado com sucesso!",
7
- "Advanced Config": "Configuração Avançada",
8
- "Base LLAMA Model": "Modelo LLAMA Base",
9
- "Batch Inference": "Inferência em Lote",
10
- "Batch Size": "Tamanho do Lote",
11
- "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
-
13
- "Compile Model": "Compilar Modelo",
14
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
- "Copy": "Copiar",
16
- "Data Preprocessing": "Pré-processamento de Dados",
17
- "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
- "Data Source": "Fonte de Dados",
19
- "Decoder Model Config": "Configuração do Modelo Decodificador",
20
- "Decoder Model Path": "Caminho do Modelo Decodificador",
21
- "Disabled": "Desativado",
22
- "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
- "Enable Reference Audio": "Habilitar Áudio de Referência",
24
- "English": "Inglês",
25
- "Japanese": "Japonês",
26
- "Chinese": "Chinês",
27
- "Portuguese": "Português",
28
- "Spanish": "Espanhol",
29
- "Error Message": "Mensagem de Erro",
30
- "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
- "File Preprocessing": "Pré-processamento de Arquivos",
32
- "Generate": "Gerar",
33
- "Generated Audio": "Áudio Gerado",
34
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
- "Infer interface is closed": "A interface de inferência foi fechada",
36
- "Inference Configuration": "Configuração de Inferência",
37
- "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
- "Inference Server Error": "Erro do Servidor de Inferência",
39
- "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
- "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
- "Initial Prompt": "Prompt Inicial",
42
- "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
- "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
- "Input Text": "Texto de Entrada",
45
- "Invalid path: {}": "Caminho inválido: {}",
46
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
- "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
- "LLAMA Configuration": "Configuração do LLAMA",
49
- "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
- "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
- "Labeling Device": "Dispositivo de Rotulagem",
52
- "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
- "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
- "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
- "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
- "Merge": "Mesclar",
57
- "Merge LoRA": "Mesclar LoRA",
58
- "Merge successfully": "Mesclado com sucesso",
59
- "Model Output Path": "Caminho de Saída do Modelo",
60
- "Model Quantization": "Quantização do Modelo",
61
- "Model Size": "Tamanho do Modelo",
62
- "Move": "Mover",
63
- "Move files successfully": "Arquivos movidos com sucesso",
64
- "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
- "No selected options": "Nenhuma opção selecionada",
66
- "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
- "Number of Workers": "Número de Processos",
68
- "Open Inference Server": "Abrir Servidor de Inferência",
69
- "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
- "Open Tensorboard": "Abrir Tensorboard",
71
- "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
- "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
- "Optional online ver": "Versão online (opcional)",
74
- "Output Path": "Caminho de Saída",
75
- "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
- "Post-quantification Precision": "Precisão Pós-quantização",
77
- "Precision": "Precisão",
78
- "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
- "Put your text here.": "Insira seu texto aqui.",
80
- "Quantify": "Quantizar",
81
- "Quantify successfully": "Quantizado com sucesso",
82
- "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
- "Reference Audio": "Áudio de Referência",
84
- "Reference Text": "Texto de Referência",
85
- "warning": "Aviso",
86
- "Pre-processing begins...": "O pré-processamento começou!",
87
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
- "Remove Selected Data": "Remover Dados Selecionados",
89
- "Removed path successfully!": "Caminho removido com sucesso!",
90
- "Repetition Penalty": "Penalidade de Repetição",
91
- "Save model every n steps": "Salvar modelo a cada n etapas",
92
- "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
- "Select source file processing method": "Escolha como processar o arquivo de origem",
94
- "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
- "Selected: {}": "Selecionado: {}",
96
- "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
- "Start Training": "Iniciar Treinamento",
98
- "Streaming Audio": "Áudio em Streaming",
99
- "Streaming Generate": "Geração em Streaming",
100
- "Tensorboard Host": "Host do Tensorboard",
101
- "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
- "Tensorboard Port": "Porta do Tensorboard",
103
- "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
- "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
- "Text Normalization": "Normalização de Texto",
106
- "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
- "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
- "Training Configuration": "Configuração de Treinamento",
110
- "Training Error": "Erro de Treinamento",
111
- "Training stopped": "Treinamento interrompido!",
112
- "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
- "Use LoRA": "Usar LoRA",
114
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
- "Use filelist": "Usar lista de arquivos",
116
- "VQGAN Configuration": "Configuração do VQGAN",
117
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
- "WebUI Host": "Host da WebUI",
120
- "WebUI Port": "Porta da WebUI",
121
- "Whisper Model": "Modelo Whisper",
122
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
- "auto": "automático",
124
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
- "latest": "mais recente",
126
- "new": "novo",
127
- "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
- "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
- "Yes": "Sim",
130
- "No": "Não",
131
- "version:": "versão:",
132
- "author:": "autor:"
133
- }
 
fish_speech/i18n/locale/zh_CN.json DELETED
@@ -1,122 +0,0 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
- "Accumulate Gradient Batches": "梯度累积批次",
6
- "Add to Processing Area": "加入处理区",
7
- "Added path successfully!": "添加路径成功!",
8
- "Advanced Config": "高级参数",
9
- "Base LLAMA Model": "基础 LLAMA 模型",
10
- "Batch Inference": "批量推理",
11
- "Batch Size": "批次大小",
12
- "Changing with the Model Path": "随模型路径变化",
13
- "Chinese": "中文",
14
- "Compile Model": "编译模型",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
- "Copy": "复制",
17
- "Data Preprocessing": "数据预处理",
18
- "Data Preprocessing Path": "数据预处理路径",
19
- "Data Source": "数据源",
20
- "Decoder Model Config": "解码器模型配置",
21
- "Decoder Model Path": "解码器模型路径",
22
- "Disabled": "禁用",
23
- "Enable Reference Audio": "启用参考音频",
24
- "English": "英文",
25
- "Error Message": "错误信息",
26
- "File Preprocessing": "文件预处理",
27
- "Generate": "生成",
28
- "Generated Audio": "音频",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
- "Infer interface is closed": "推理界面已关闭",
31
- "Inference Configuration": "推理配置",
32
- "Inference Server Configuration": "推理服务器配置",
33
- "Inference Server Error": "推理服务器错误",
34
- "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
- "Initial Learning Rate": "初始学习率",
36
- "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
- "Input Text": "输入文本",
38
- "Invalid path: {}": "无效路径: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
- "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
- "Japanese": "日文",
42
- "LLAMA Configuration": "LLAMA 配置",
43
- "LLAMA Model Config": "LLAMA 模型配置",
44
- "LLAMA Model Path": "LLAMA 模型路径",
45
- "Labeling Device": "标注加速设备",
46
- "LoRA Model to be merged": "要合并的 LoRA 模型",
47
- "Maximum Audio Duration": "最大音频时长",
48
- "Maximum Length per Sample": "每个样本的最大长度",
49
- "Maximum Training Steps": "最大训练步数",
50
- "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
- "Merge": "合并",
52
- "Merge LoRA": "合并 LoRA",
53
- "Merge successfully": "合并成功",
54
- "Minimum Audio Duration": "最小音频时长",
55
- "Model Output Path": "模型输出路径",
56
- "Model Size": "模型规模",
57
- "Move": "移动",
58
- "Move files successfully": "移动文件成功",
59
- "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
- "No selected options": "没有选择的选项",
61
- "Number of Workers": "数据加载进程数",
62
- "Open Inference Server": "打开推理服务器",
63
- "Open Labeler WebUI": "打开标注工具",
64
- "Open Tensorboard": "打开 Tensorboard",
65
- "Opened labeler in browser": "在浏览器中打开标注工具",
66
- "Optional Label Language": "[可选] 标注语言",
67
- "Optional online ver": "[可选] 使用在线版",
68
- "Output Path": "输出路径",
69
- "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
- "Precision": "精度",
71
- "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
- "Put your text here.": "在此处输入文本.",
73
- "Reference Audio": "参考音频",
74
- "Reference Text": "参考文本",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
- "Remove Selected Data": "移除选中数据",
77
- "Removed path successfully!": "移除路径成功!",
78
- "Repetition Penalty": "重复惩罚",
79
- "Save model every n steps": "每 n 步保存模型",
80
- "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
- "Select VITS ckpt": "选择 VITS 检查点",
82
- "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
- "Select source file processing method": "选择源文件处理方法",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
- "Selected: {}": "已选择: {}",
86
- "Speaker": "说话人",
87
- "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
- "Start Training": "开始训练",
89
- "Streaming Audio": "流式音频",
90
- "Streaming Generate": "流式合成",
91
- "Tensorboard Host": "Tensorboard 监听地址",
92
- "Tensorboard Log Path": "Tensorboard 日志路径",
93
- "Tensorboard Port": "Tensorboard 端口",
94
- "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
- "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
- "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
- "Training Configuration": "训练配置",
99
- "Training Error": "训练错误",
100
- "Training stopped": "训练已停止",
101
- "Type name of the speaker": "输入说话人的名称",
102
- "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
- "Use LoRA": "使用 LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
- "Use filelist": "使用文件列表",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
- "VITS Configuration": "VITS 配置",
108
- "VQGAN Configuration": "VQGAN 配置",
109
- "Validation Batch Size": "验证批次大小",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
- "WebUI Host": "WebUI 监听地址",
113
- "WebUI Port": "WebUI 端口",
114
- "Whisper Model": "Whisper 模型",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
- "latest": "最近的检查点",
118
- "new": "创建新的检查点",
119
- "Realtime Transform Text": "实时规范化文本",
120
- "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
- "Text Normalization": "文本规范化"
122
- }
 
fish_speech/i18n/scan.py DELETED
@@ -1,122 +0,0 @@
1
- import ast
2
- import glob
3
- import json
4
- from collections import OrderedDict
5
- from pathlib import Path
6
-
7
- from loguru import logger
8
-
9
- from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
-
11
-
12
- def extract_i18n_strings(node):
13
- i18n_strings = []
14
-
15
- if (
16
- isinstance(node, ast.Call)
17
- and isinstance(node.func, ast.Name)
18
- and node.func.id == "i18n"
19
- ):
20
- for arg in node.args:
21
- if isinstance(arg, ast.Str):
22
- i18n_strings.append(arg.s)
23
-
24
- for child_node in ast.iter_child_nodes(node):
25
- i18n_strings.extend(extract_i18n_strings(child_node))
26
-
27
- return i18n_strings
28
-
29
-
30
- # scan the directory for all .py files (recursively)
31
- # for each file, parse the code into an AST
32
- # for each AST, extract the i18n strings
33
-
34
- strings = []
35
- folders = ["fish_speech", "tools"]
36
- # for filename in glob.iglob("**/*.py", recursive=True):
37
- for folder in folders:
38
- for f in Path(folder).rglob("*.py"):
39
- code = f.read_text(encoding="utf-8")
40
- if "i18n(" in code:
41
- tree = ast.parse(code)
42
- i18n_strings = extract_i18n_strings(tree)
43
- logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
- strings.extend(i18n_strings)
45
-
46
- code_keys = set(strings)
47
- logger.info(f"Total unique: {len(code_keys)}")
48
-
49
-
50
- standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
- with open(standard_file, "r", encoding="utf-8") as f:
52
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
- standard_keys = set(standard_data.keys())
54
-
55
- # Compare the keys used in code against the standard locale file
56
- unused_keys = standard_keys - code_keys
57
- logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
- for unused_key in unused_keys:
59
- logger.info(f"\t{unused_key}")
60
-
61
- missing_keys = code_keys - standard_keys
62
- logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
- for missing_key in missing_keys:
64
- logger.info(f"\t{missing_key}")
65
-
66
- code_keys_dict = OrderedDict()
67
- for s in strings:
68
- code_keys_dict[s] = s
69
-
70
- # write back
71
- with open(standard_file, "w", encoding="utf-8") as f:
72
- json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
- f.write("\n")
74
-
75
- logger.info(f"Updated {standard_file}")
76
-
77
-
78
- # Define the standard file name
79
- standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
-
81
- # Find all JSON files in the directory
82
- dir_path = I18N_FILE_PATH
83
- languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
-
85
- # Load the standard file
86
- with open(standard_file, "r", encoding="utf-8") as f:
87
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
-
89
- # Loop through each language file
90
- for lang_file in languages:
91
- # Load the language file
92
- with open(lang_file, "r", encoding="utf-8") as f:
93
- lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
-
95
- # Find the difference between the language file and the standard file
96
- diff = set(standard_data.keys()) - set(lang_data.keys())
97
-
98
- miss = set(lang_data.keys()) - set(standard_data.keys())
99
-
100
- # Add any missing keys to the language file
101
- for key in diff:
102
- lang_data[key] = "#!" + key
103
- logger.info(f"Added missing key: {key} to {lang_file}")
104
-
105
- # Del any extra keys to the language file
106
- for key in miss:
107
- del lang_data[key]
108
- logger.info(f"Del extra key: {key} from {lang_file}")
109
-
110
- # Sort the keys of the language file to match the order of the standard file
111
- lang_data = OrderedDict(
112
- sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
- )
114
-
115
- # Save the updated language file
116
- with open(lang_file, "w", encoding="utf-8") as f:
117
- json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
- f.write("\n")
119
-
120
- logger.info(f"Updated {lang_file}")
121
-
122
- logger.info("Done")
 
fish_speech/models/text2semantic/__init__.py DELETED
File without changes
fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (179 Bytes)
 
fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc DELETED
Binary file (5.41 kB)
 
fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc DELETED
Binary file (20.8 kB)
 
fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc DELETED
Binary file (1.79 kB)
 
fish_speech/models/text2semantic/lit_module.py DELETED
@@ -1,202 +0,0 @@
1
- from typing import Any, Optional
2
-
3
- import lightning as L
4
- import torch
5
- import torch.nn.functional as F
6
- from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
-
8
- import fish_speech.utils as utils
9
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
- from fish_speech.models.text2semantic.llama import NaiveTransformer
11
-
12
- log = utils.RankedLogger(__name__, rank_zero_only=True)
13
-
14
-
15
- class TextToSemantic(L.LightningModule):
16
- def __init__(
17
- self,
18
- model: NaiveTransformer,
19
- optimizer: Any,
20
- lr_scheduler: Any,
21
- ):
22
- super().__init__()
23
-
24
- self.model = model
25
- self.optimizer_builder = optimizer
26
- self.lr_scheduler_builder = lr_scheduler
27
-
28
- def forward(self, x):
29
- return self.model(x)
30
-
31
- def on_save_checkpoint(self, checkpoint):
32
- # Save only LoRA parameters
33
- state_dict = checkpoint["state_dict"]
34
- use_lora = any("lora" in name for name in state_dict.keys())
35
- if not use_lora:
36
- return
37
-
38
- for name in list(state_dict.keys()):
39
- if "lora" not in name:
40
- state_dict.pop(name)
41
-
42
- def configure_optimizers(self) -> OptimizerLRScheduler:
43
- # Get weight decay parameters
44
- weight_decay_parameters, other_parameters = [], []
45
- for name, param in self.named_parameters():
46
- if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
47
- other_parameters.append(param)
48
- else:
49
- weight_decay_parameters.append(param)
50
-
51
- optimizer = self.optimizer_builder(
52
- [
53
- {"params": weight_decay_parameters},
54
- {"params": other_parameters, "weight_decay": 0.0},
55
- ]
56
- )
57
-
58
- # Print the parameters and their weight decay
59
- for i in optimizer.param_groups:
60
- log.info(
61
- f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
62
- )
63
-
64
- lr_scheduler = self.lr_scheduler_builder(optimizer)
65
-
66
- return {
67
- "optimizer": optimizer,
68
- "lr_scheduler": {
69
- "scheduler": lr_scheduler,
70
- "interval": "step",
71
- },
72
- }
73
-
74
- # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
75
- def get_batch_logps(
76
- self,
77
- logits: torch.FloatTensor,
78
- labels: torch.LongTensor,
79
- average_log_prob: bool = False,
80
- ) -> torch.FloatTensor:
81
- """Compute the log probabilities of the given labels under the given logits.
82
-
83
- Args:
84
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
85
- labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
86
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
87
-
88
- Returns:
89
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
90
- """
91
- assert logits.shape[:-1] == labels.shape
92
-
93
- labels = labels.clone()
94
- loss_mask = labels != -100
95
-
96
- # dummy token; we'll ignore the losses on these tokens later
97
- labels[labels == -100] = 0
98
-
99
- per_token_logps = torch.gather(
100
- logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
101
- ).squeeze(-1)
102
-
103
- if average_log_prob:
104
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
105
- else:
106
- return (per_token_logps * loss_mask).sum(-1)
107
-
108
- def _step(self, batch, batch_idx, stage: str):
109
- is_train = stage == "train"
110
-
111
- if is_train:
112
- # Key part to make lora work
113
- # Otherwise the parameters are merged, which leads to incorrect gradients
114
- self.model.train()
115
-
116
- # Do positive and negative samples in the same batch to speed up training
117
- labels = batch["labels"]
118
- outputs = self.model(
119
- inp=batch["inputs"],
120
- key_padding_mask=batch["attention_masks"],
121
- )
122
- token_logits = outputs.token_logits
123
- codebook_logits = outputs.codebook_logits
124
-
125
- # Generate labels
126
- base_loss = F.cross_entropy(
127
- token_logits.view(-1, token_logits.size(-1)),
128
- labels[:, 0].reshape(-1),
129
- ignore_index=-100,
130
- )
131
-
132
- codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
133
- semantic_loss = F.cross_entropy(
134
- codebook_logits.view(-1, codebook_logits.size(-1)),
135
- codebook_labels.reshape(-1),
136
- ignore_index=-100,
137
- )
138
-
139
- loss = base_loss + semantic_loss
140
-
141
- self.log(
142
- f"{stage}/loss",
143
- loss,
144
- on_step=is_train,
145
- on_epoch=not is_train,
146
- prog_bar=True,
147
- logger=True,
148
- sync_dist=not is_train,
149
- )
150
-
151
- self.log(
152
- f"{stage}/base_loss",
153
- base_loss,
154
- on_step=is_train,
155
- on_epoch=not is_train,
156
- prog_bar=False,
157
- logger=True,
158
- sync_dist=not is_train,
159
- )
160
-
161
- self.log(
162
- f"{stage}/semantic_loss",
163
- semantic_loss,
164
- on_step=is_train,
165
- on_epoch=not is_train,
166
- prog_bar=False,
167
- logger=True,
168
- sync_dist=not is_train,
169
- )
170
-
171
- # Top-5 accuracy
172
- accuracy = self.get_accuracy(codebook_logits, codebook_labels)
173
- self.log(
174
- f"{stage}/top_5_accuracy",
175
- accuracy,
176
- on_step=is_train,
177
- on_epoch=not is_train,
178
- prog_bar=True,
179
- logger=True,
180
- sync_dist=not is_train,
181
- )
182
-
183
- return loss
184
-
185
- def get_accuracy(self, logits, labels):
186
- mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
187
- if mask.sum() == 0:
188
- return torch.tensor(0.0, device=logits.device)
189
-
190
- _, indices = logits.topk(5, dim=-1)
191
- correct = indices.eq(labels.unsqueeze(-1))
192
- correct[~mask] = 0
193
- correct = correct.sum()
194
- accuracy = correct / mask.sum()
195
-
196
- return accuracy
197
-
198
- def training_step(self, batch, batch_idx):
199
- return self._step(batch, batch_idx, "train")
200
-
201
- def validation_step(self, batch, batch_idx):
202
- return self._step(batch, batch_idx, "val")
 
 
 
fish_speech/models/text2semantic/llama.py DELETED
@@ -1,779 +0,0 @@
1
- import json
2
- import math
3
- from collections import OrderedDict
4
- from dataclasses import dataclass
5
- from pathlib import Path
6
- from typing import Optional
7
-
8
- import torch
9
- import torch.nn as nn
10
- from einops import rearrange
11
- from loguru import logger
12
- from torch import Tensor
13
- from torch.nn import functional as F
14
- from torch.nn.attention import SDPBackend, sdpa_kernel
15
- from torch.utils.checkpoint import checkpoint
16
- from transformers import AutoTokenizer
17
-
18
- from fish_speech.conversation import SEMANTIC_TOKEN
19
- from fish_speech.utils import RankedLogger
20
-
21
- from .lora import LoraConfig, setup_lora
22
-
23
- log = RankedLogger(__name__, rank_zero_only=True)
24
-
25
-
26
- def find_multiple(n: int, k: int) -> int:
27
- if n % k == 0:
28
- return n
29
- return n + k - (n % k)
30
-
31
-
32
- @dataclass
33
- class BaseModelArgs:
34
- model_type: str = "base"
35
-
36
- vocab_size: int = 32000
37
- n_layer: int = 32
38
- n_head: int = 32
39
- dim: int = 4096
40
- intermediate_size: int = None
41
- n_local_heads: int = -1
42
- head_dim: int = 64
43
- rope_base: float = 10000
44
- norm_eps: float = 1e-5
45
- max_seq_len: int = 2048
46
- dropout: float = 0.0
47
- tie_word_embeddings: bool = True
48
- attention_qkv_bias: bool = False
49
-
50
- # Codebook configs
51
- codebook_size: int = 160
52
- num_codebooks: int = 4
53
-
54
- # Gradient checkpointing
55
- use_gradient_checkpointing: bool = True
56
-
57
- # Initialize the model
58
- initializer_range: float = 0.02
59
-
60
- def __post_init__(self):
61
- if self.n_local_heads == -1:
62
- self.n_local_heads = self.n_head
63
- if self.intermediate_size is None:
64
- hidden_dim = 4 * self.dim
65
- n_hidden = int(2 * hidden_dim / 3)
66
- self.intermediate_size = find_multiple(n_hidden, 256)
67
- self.head_dim = self.dim // self.n_head
68
-
69
- @staticmethod
70
- def from_pretrained(path: str):
71
- path = Path(path)
72
-
73
- if path.is_dir():
74
- path = path / "config.json"
75
-
76
- with open(path, "r", encoding="utf-8") as f:
77
- data = json.load(f)
78
-
79
- match data["model_type"]:
80
- case "naive":
81
- cls = NaiveModelArgs
82
- case "dual_ar":
83
- cls = DualARModelArgs
84
- case _:
85
- raise ValueError(f"Unknown model type: {data['model_type']}")
86
-
87
- return cls(**data)
88
-
89
- def save(self, path: str):
90
- with open(path, "w") as f:
91
- json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
92
-
93
-
94
- @dataclass
95
- class NaiveModelArgs(BaseModelArgs):
96
- model_type: str = "naive"
97
-
98
-
99
- @dataclass
100
- class DualARModelArgs(BaseModelArgs):
101
- model_type: str = "dual_ar"
102
- n_fast_layer: int = 4
103
-
104
-
105
- class KVCache(nn.Module):
106
- def __init__(
107
- self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
108
- ):
109
- super().__init__()
110
- cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
111
- self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
112
- self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
113
-
114
- def update(self, input_pos, k_val, v_val):
115
- # input_pos: [S], k_val: [B, H, S, D]
116
- assert input_pos.shape[0] == k_val.shape[2]
117
-
118
- k_out = self.k_cache
119
- v_out = self.v_cache
120
- k_out[:, :, input_pos] = k_val
121
- v_out[:, :, input_pos] = v_val
122
-
123
- return k_out, v_out
124
-
125
-
126
- @dataclass
127
- class TransformerForwardResult:
128
- token_logits: Tensor
129
- codebook_logits: Tensor
130
-
131
-
132
- @dataclass
133
- class BaseTransformerForwardResult:
134
- logits: Tensor
135
- hidden_states: Tensor
136
-
137
-
138
- class BaseTransformer(nn.Module):
139
- def __init__(
140
- self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
141
- ) -> None:
142
- super().__init__()
143
- self.config = config
144
- self.tokenizer = tokenizer
145
-
146
- self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
147
-
148
- # Slow transformer
149
- self.embeddings = nn.Embedding(
150
- config.vocab_size,
151
- config.dim,
152
- )
153
- self.codebook_embeddings = nn.Embedding(
154
- config.codebook_size * config.num_codebooks,
155
- config.dim,
156
- )
157
- self.layers = nn.ModuleList(
158
- TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
159
- )
160
- self.norm = RMSNorm(config.dim, eps=config.norm_eps)
161
-
162
- if self.config.tie_word_embeddings is False:
163
- self.output = nn.Linear(
164
- config.dim,
165
- config.vocab_size,
166
- bias=False,
167
- )
168
-
169
- self.register_buffer(
170
- "freqs_cis",
171
- precompute_freqs_cis(
172
- config.max_seq_len,
173
- config.dim // config.n_head,
174
- config.rope_base,
175
- ),
176
- persistent=False,
177
- )
178
- self.register_buffer(
179
- "causal_mask",
180
- torch.tril(
181
- torch.ones(
182
- config.max_seq_len,
183
- config.max_seq_len,
184
- dtype=torch.bool,
185
- )
186
- ),
187
- persistent=False,
188
- )
189
-
190
- # For kv cache
191
- self.max_batch_size = -1
192
- self.max_seq_len = -1
193
-
194
- if init_weights:
195
- self.apply(self._init_weights)
196
-
197
- def setup_caches(
198
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
199
- ):
200
- if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
201
- return
202
-
203
- head_dim = self.config.dim // self.config.n_head
204
- max_seq_len = find_multiple(max_seq_len, 8)
205
- self.max_seq_len = max_seq_len
206
- self.max_batch_size = max_batch_size
207
-
208
- for b in self.layers:
209
- b.attention.kv_cache = KVCache(
210
- max_batch_size,
211
- max_seq_len,
212
- self.config.n_local_heads,
213
- head_dim,
214
- dtype=dtype,
215
- )
216
-
217
- def embed(self, x: Tensor) -> Tensor:
218
- vocab_embeds = [self.embeddings(x[:, 0])]
219
- for i in range(self.config.num_codebooks):
220
- emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
221
- emb[x[:, 0] != self.semantic_token_id] = 0
222
- vocab_embeds.append(emb)
223
-
224
- x = torch.stack(vocab_embeds, dim=3)
225
- x = x.sum(dim=3)
226
-
227
- return x
228
-
229
- def forward(
230
- self,
231
- inp: Tensor,
232
- key_padding_mask: Optional[Tensor] = None,
233
- ) -> BaseTransformerForwardResult:
234
- seq_len = inp.size(2)
235
-
236
- # Here we want to merge the embeddings of the codebooks
237
- x = self.embed(inp)
238
-
239
- freqs_cis = self.freqs_cis[:seq_len]
240
-
241
- # Note that the causal mask here follows the definition of scaled_dot_product_attention
242
- # That is, FALSE means masked out
243
- # To maintain consistency, key_padding_mask use TRUE to mask out
244
- mask = None
245
- if key_padding_mask is not None:
246
- mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
247
- mask = mask & key_padding_mask[:, None, None, :].logical_not()
248
-
249
- for layer in self.layers:
250
- if self.config.use_gradient_checkpointing and self.training:
251
- x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
252
- else:
253
- x = layer(x, freqs_cis, mask)
254
-
255
- # We got slow_out here
256
- slow_out = self.norm(x)
257
-
258
- if self.config.tie_word_embeddings:
259
- token_logits = F.linear(slow_out, self.embeddings.weight)
260
- else:
261
- token_logits = self.output(slow_out)
262
-
263
- return BaseTransformerForwardResult(
264
- logits=token_logits,
265
- hidden_states=x,
266
- )
267
-
268
- def forward_generate(
269
- self,
270
- x: Tensor,
271
- input_pos: Optional[Tensor] = None,
272
- return_all: bool = False,
273
- ) -> BaseTransformerForwardResult:
274
- # This is used for generation, optimized for torch compile
275
- assert (
276
- self.max_seq_len != -1 and self.max_batch_size != -1
277
- ), "Please call setup_caches before forward_generate"
278
-
279
- x = self.embed(x)
280
-
281
- mask = self.causal_mask[
282
- None, None, input_pos, : self.max_seq_len
283
- ] # (B, N, Q, K)
284
- freqs_cis = self.freqs_cis[input_pos]
285
-
286
- for layer in self.layers:
287
- x = layer(x, freqs_cis, mask, input_pos=input_pos)
288
-
289
- # If prefill, we only calculate the logits of last token
290
- if x.size(1) > 1 and not return_all:
291
- x = x[:, -1:]
292
-
293
- # We got slow_out here
294
- slow_out = self.norm(x)
295
-
296
- if self.config.tie_word_embeddings:
297
- token_logits = F.linear(slow_out, self.embeddings.weight)
298
- else:
299
- token_logits = self.output(slow_out)
300
-
301
- return BaseTransformerForwardResult(
302
- logits=token_logits,
303
- hidden_states=x,
304
- )
305
-
306
- def _init_weights(self, module):
307
- std = self.config.initializer_range
308
- if isinstance(module, nn.Linear):
309
- module.weight.data.normal_(mean=0.0, std=std)
310
- if module.bias is not None:
311
- module.bias.data.zero_()
312
- elif isinstance(module, nn.Embedding):
313
- module.weight.data.normal_(mean=0.0, std=std)
314
- if module.padding_idx is not None:
315
- module.weight.data[module.padding_idx].zero_()
316
-
317
- @staticmethod
318
- def from_pretrained(
319
- path: str,
320
- load_weights: bool = False,
321
- max_length: int | None = None,
322
- lora_config: LoraConfig | None = None,
323
- rope_base: int | None = None,
324
- ) -> "BaseTransformer":
325
- config = BaseModelArgs.from_pretrained(str(path))
326
- if max_length is not None:
327
- config.max_seq_len = max_length
328
- log.info(f"Override max_seq_len to {max_length}")
329
-
330
- if rope_base is not None:
331
- config.rope_base = rope_base
332
- log.info(f"Override rope_base to {rope_base}")
333
-
334
- match config.model_type:
335
- case "naive":
336
- model_cls = NaiveTransformer
337
- case "dual_ar":
338
- model_cls = DualARTransformer
339
- case _:
340
- raise ValueError(f"Unknown model type: {config.model_type}")
341
-
342
- tokenizer = AutoTokenizer.from_pretrained(str(path))
343
- log.info(f"Loading model from {path}, config: {config}")
344
- model = model_cls(config, tokenizer=tokenizer)
345
-
346
- if lora_config is not None:
347
- setup_lora(model, lora_config)
348
- log.info(f"LoRA setup: {lora_config}")
349
-
350
- if load_weights is False:
351
- log.info("Randomly initialized model")
352
- else:
353
-
354
- if "int8" in str(Path(path)):
355
- logger.info("Using int8 weight-only quantization!")
356
- from tools.llama.quantize import WeightOnlyInt8QuantHandler
357
-
358
- simple_quantizer = WeightOnlyInt8QuantHandler(model)
359
- model = simple_quantizer.convert_for_runtime()
360
-
361
- if "int4" in str(Path(path)):
362
- logger.info("Using int4 quantization!")
363
- path_comps = path.name.split("-")
364
- assert path_comps[-2].startswith("g")
365
- groupsize = int(path_comps[-2][1:])
366
- from tools.llama.quantize import WeightOnlyInt4QuantHandler
367
-
368
- simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
369
- model = simple_quantizer.convert_for_runtime()
370
-
371
- weights = torch.load(
372
- Path(path) / "model.pth", map_location="cpu", mmap=True
373
- )
374
-
375
- if "state_dict" in weights:
376
- logger.warning(
377
- "Using a TextToSemantic LightningModule checkpoint, "
378
- "please make sure it is a full model, not a LoRA model."
379
- )
380
- weights = weights["state_dict"]
381
-
382
- if next(iter(weights.keys())).startswith("model."):
383
- logger.info(
384
- f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
385
- )
386
- new_weights = OrderedDict()
387
- for k, v in weights.items():
388
- new_weights[k.replace("model.", "")] = v
389
- weights = new_weights
390
-
391
- # Verify the name and shape of parameters since strict=False in load_state_dict.
392
- for k, v in model.named_parameters():
393
- if k not in weights:
394
- logger.warning(f"No weight for {k}")
395
- elif v.shape != weights[k].shape:
396
- logger.warning(
397
- f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
398
- )
399
-
400
- err = model.load_state_dict(weights, strict=False, assign=True)
401
- log.info(f"Loaded weights with error: {err}")
402
-
403
- return model
404
-
405
- def save_pretrained(self, path: str, drop_lora: bool = False):
406
- path = Path(path)
407
- path.mkdir(parents=True, exist_ok=True)
408
-
409
- self.config.save(path / "config.json")
410
- state_dict = self.state_dict()
411
-
412
- if drop_lora:
413
- for key in list(state_dict.keys()):
414
- if "lora" not in key:
415
- continue
416
-
417
- state_dict.pop(key)
418
- log.info(f"Drop LoRA parameter: {key}")
419
-
420
- torch.save(state_dict, path / "model.pth")
421
- self.tokenizer.save_pretrained(path)
422
-
423
-
424
- class NaiveTransformer(BaseTransformer):
425
- def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
426
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
427
-
428
- self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
429
- self.codebook_output = nn.Linear(
430
- config.dim,
431
- config.codebook_size * config.num_codebooks,
432
- bias=False,
433
- )
434
-
435
- self.apply(self._init_weights)
436
-
437
- def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
438
- token_logits = result.logits
439
- x = result.hidden_states
440
-
441
- # Codebook
442
- codebook_logits = self.codebook_output(self.codebook_norm(x))
443
- codebook_logits = rearrange(
444
- codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
445
- )
446
-
447
- return TransformerForwardResult(
448
- token_logits=token_logits,
449
- codebook_logits=codebook_logits,
450
- )
451
-
452
- def forward(
453
- self,
454
- inp: Tensor,
455
- key_padding_mask: Optional[Tensor] = None,
456
- ) -> TransformerForwardResult:
457
- result = super().forward(
458
- inp=inp,
459
- key_padding_mask=key_padding_mask,
460
- )
461
- return self.decode(result)
462
-
463
- def forward_generate(
464
- self, x: Tensor, input_pos: Optional[Tensor] = None
465
- ) -> TransformerForwardResult:
466
- result = super().forward_generate(x, input_pos)
467
- return self.decode(result)
468
-
469
-
470
- class DualARTransformer(BaseTransformer):
471
- def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
472
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
473
-
474
- # Fast transformer
475
- self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
476
-
477
- # The equivalent bs is so large that sdpa doesn't work
478
- self.fast_layers = nn.ModuleList(
479
- TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
480
- )
481
- self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
482
- self.fast_output = nn.Linear(
483
- config.dim,
484
- config.codebook_size,
485
- bias=False,
486
- )
487
-
488
- self.apply(self._init_weights)
489
-
490
- def setup_caches(
491
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
492
- ):
493
- super().setup_caches(max_batch_size, max_seq_len, dtype)
494
-
495
- head_dim = self.config.dim // self.config.n_head
496
-
497
- # Fast transformer
498
- # The max seq len here is the number of codebooks
499
- for b in self.fast_layers:
500
- b.attention.kv_cache = KVCache(
501
- max_batch_size,
502
- self.config.num_codebooks,
503
- self.config.n_local_heads,
504
- head_dim,
505
- dtype=dtype,
506
- )
507
-
508
- def forward(
509
- self,
510
- inp: Tensor,
511
- key_padding_mask: Optional[Tensor] = None,
512
- ) -> TransformerForwardResult:
513
- parent_result = super().forward(inp, key_padding_mask)
514
- token_logits = parent_result.logits
515
- x = parent_result.hidden_states
516
-
517
- # Fast transformer
518
- fast_seq_len = self.config.num_codebooks
519
- fast_mask = self.causal_mask[
520
- None, None, :fast_seq_len, :fast_seq_len
521
- ] # (B, N, Q, K)
522
- fast_freqs_cis = self.freqs_cis[:fast_seq_len]
523
-
524
- # Drop the last token and rotate left
525
- codebooks = inp[:, 1:-1, 1:]
526
- codebooks = F.pad(codebooks, (0, 1), value=0)
527
- codebook_embeddings = self.fast_embeddings(codebooks)
528
- x = torch.cat([x[:, None], codebook_embeddings], dim=1)
529
- b, s = x.size(0), x.size(2)
530
- x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
531
-
532
- # Remove padded part
533
- codebooks = rearrange(codebooks, "b n s -> (b s) n")
534
- codebook_mask = (codebooks == 0).all(dim=-1)
535
-
536
- if torch.all(codebook_mask):
537
- # If all codebooks are padded, we keep first 8 to make sure the model runs
538
- codebook_mask[:8] = False
539
-
540
- x_bs, x_len = x.size(0), x.size(1)
541
- x = x[~codebook_mask]
542
-
543
- for layer in self.fast_layers:
544
- if self.config.use_gradient_checkpointing and self.training:
545
- x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
546
- else:
547
- x = layer(x, fast_freqs_cis, fast_mask)
548
-
549
- # unflatten the batch and num_codebooks
550
- fast_out = self.fast_norm(x)
551
- codebook_logits = self.fast_output(fast_out)
552
-
553
- # Re-pad the codebook_logits
554
- buffer = torch.zeros(
555
- x_bs,
556
- x_len,
557
- codebook_logits.size(-1),
558
- device=codebook_logits.device,
559
- dtype=codebook_logits.dtype,
560
- )
561
- buffer[~codebook_mask] = codebook_logits
562
- codebook_logits = buffer
563
-
564
- assert codebook_logits.shape[1] == self.config.num_codebooks
565
- codebook_logits = rearrange(
566
- codebook_logits,
567
- "(b s) n d -> b s n d",
568
- b=b,
569
- s=s,
570
- n=self.config.num_codebooks,
571
- )
572
-
573
- return TransformerForwardResult(
574
- token_logits=token_logits,
575
- codebook_logits=codebook_logits,
576
- )
577
-
578
- def forward_generate_fast(
579
- self, x: Tensor, input_pos: Optional[Tensor] = None
580
- ) -> Tensor:
581
- # Fast transformer
582
- x = x.view(1, 1, -1)
583
-
584
- fast_mask = self.causal_mask[
585
- None, None, input_pos, : self.config.num_codebooks
586
- ] # (B, N, Q, K)
587
- fast_freqs_cis = self.freqs_cis[input_pos]
588
-
589
- for layer in self.fast_layers:
590
- x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
591
-
592
- # unflatten the batch and num_codebooks
593
- fast_out = self.fast_norm(x) # only take the last token
594
- codebook_logits = self.fast_output(fast_out)
595
-
596
- return codebook_logits
597
-
598
-
599
- class TransformerBlock(nn.Module):
600
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
601
- super().__init__()
602
- self.attention = Attention(config, use_sdpa=use_sdpa)
603
- self.feed_forward = FeedForward(config)
604
- self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
605
- self.attention_norm = RMSNorm(config.dim, config.norm_eps)
606
-
607
- def forward(
608
- self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
609
- ) -> Tensor:
610
- h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
611
- out = h + self.feed_forward(self.ffn_norm(h))
612
- return out
613
-
614
-
615
- class Attention(nn.Module):
616
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
617
- super().__init__()
618
- assert config.dim % config.n_head == 0
619
-
620
- total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
621
- # key, query, value projections for all heads, but in a batch
622
- self.wqkv = nn.Linear(
623
- config.dim, total_head_dim, bias=config.attention_qkv_bias
624
- )
625
- self.wo = nn.Linear(config.dim, config.dim, bias=False)
626
- self.kv_cache = None
627
-
628
- self.dropout = config.dropout
629
- self.n_head = config.n_head
630
- self.head_dim = config.head_dim
631
- self.n_local_heads = config.n_local_heads
632
- self.dim = config.dim
633
- self.use_sdpa = use_sdpa
634
- self._register_load_state_dict_pre_hook(self.load_hook)
635
-
636
- def load_hook(self, state_dict, prefix, *args):
637
- if prefix + "wq.weight" in state_dict:
638
- wq = state_dict.pop(prefix + "wq.weight")
639
- wk = state_dict.pop(prefix + "wk.weight")
640
- wv = state_dict.pop(prefix + "wv.weight")
641
- state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
642
-
643
- def forward(
644
- self,
645
- x: Tensor,
646
- freqs_cis: Tensor,
647
- mask: Tensor,
648
- input_pos: Optional[Tensor] = None,
649
- ) -> Tensor:
650
- bsz, seqlen, _ = x.shape
651
-
652
- kv_size = self.n_local_heads * self.head_dim
653
- q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
654
-
655
- q = q.view(bsz, seqlen, self.n_head, self.head_dim)
656
- k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
657
- v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
658
-
659
- q = apply_rotary_emb(q, freqs_cis)
660
- k = apply_rotary_emb(k, freqs_cis)
661
-
662
- q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
663
-
664
- if self.kv_cache is not None:
665
- k, v = self.kv_cache.update(input_pos, k, v)
666
-
667
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
668
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
669
-
670
- if self.use_sdpa:
671
- if mask is None:
672
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
673
- y = F.scaled_dot_product_attention(
674
- q,
675
- k,
676
- v,
677
- dropout_p=self.dropout if self.training else 0.0,
678
- is_causal=True,
679
- # No attn_mask is passed here, so the FlashAttention backend can be used
680
- )
681
- else:
682
- y = F.scaled_dot_product_attention(
683
- q,
684
- k,
685
- v,
686
- attn_mask=mask,
687
- dropout_p=self.dropout if self.training else 0.0,
688
- )
689
- else:
690
- y = self.eq_scaled_dot_product_attention(
691
- q,
692
- k,
693
- v,
694
- attn_mask=mask,
695
- dropout_p=self.dropout if self.training else 0.0,
696
- )
697
-
698
- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
699
-
700
- return self.wo(y)
701
-
702
- def eq_scaled_dot_product_attention(
703
- self,
704
- query,
705
- key,
706
- value,
707
- attn_mask=None,
708
- dropout_p=0.0,
709
- ) -> torch.Tensor:
710
- # This is a standard scaled dot product attention
711
- # It is less efficient, but it does not raise CUDA errors
712
-
713
- L, S = query.size(-2), key.size(-2)
714
- scale_factor = 1 / math.sqrt(query.size(-1))
715
- attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
716
-
717
- if attn_mask is not None:
718
- if attn_mask.dtype == torch.bool:
719
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
720
- else:
721
- attn_bias += attn_mask
722
-
723
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
724
- attn_weight += attn_bias
725
- attn_weight = torch.softmax(attn_weight, dim=-1)
726
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
727
-
728
- return attn_weight @ value
729
-
730
-
731
- class FeedForward(nn.Module):
732
- def __init__(self, config: BaseModelArgs) -> None:
733
- super().__init__()
734
- self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
735
- self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
736
- self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
737
-
738
- def forward(self, x: Tensor) -> Tensor:
739
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
740
-
741
-
742
- class RMSNorm(nn.Module):
743
- def __init__(self, dim: int, eps: float = 1e-5):
744
- super().__init__()
745
- self.eps = eps
746
- self.weight = nn.Parameter(torch.ones(dim))
747
-
748
- def _norm(self, x):
749
- return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
750
-
751
- def forward(self, x: Tensor) -> Tensor:
752
- output = self._norm(x.float()).type_as(x)
753
- return output * self.weight
754
-
755
-
756
- def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
757
- freqs = 1.0 / (
758
- base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
759
- )
760
- t = torch.arange(seq_len, device=freqs.device)
761
- freqs = torch.outer(t, freqs)
762
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
763
- cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
764
- return cache.to(dtype=torch.bfloat16)
765
-
766
-
767
- def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
768
- xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
769
- freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
770
- x_out2 = torch.stack(
771
- [
772
- xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
773
- xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
774
- ],
775
- -1,
776
- )
777
-
778
- x_out2 = x_out2.flatten(3)
779
- return x_out2.type_as(x)
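The rotary-embedding helpers at the end of this file expect activations laid out as (batch, seq, heads, head_dim) and a cos/sin cache of shape (seq_len, head_dim // 2, 2). A minimal shape sketch follows, assuming the deleted llama.py were still importable; all sizes are illustrative.

# Shape sketch for precompute_freqs_cis / apply_rotary_emb (illustrative sizes only,
# assuming the module above is still importable).
import torch

from fish_speech.models.text2semantic.llama import apply_rotary_emb, precompute_freqs_cis

seq_len, n_heads, head_dim = 16, 8, 64

freqs_cis = precompute_freqs_cis(seq_len, head_dim)  # (16, 32, 2): bfloat16 cos/sin pairs
q = torch.randn(2, seq_len, n_heads, head_dim)       # (batch, seq, heads, head_dim)
q_rot = apply_rotary_emb(q, freqs_cis)               # same shape, rotated per position

print(q_rot.shape)  # torch.Size([2, 16, 8, 64])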
fish_speech/models/text2semantic/lora.py DELETED
@@ -1,92 +0,0 @@
- from dataclasses import dataclass
-
- import loralib as lora
-
-
- @dataclass
- class LoraConfig:
-     r: int
-     lora_alpha: float
-     lora_dropout: float = 0.0
-
-
- def setup_lora(model, lora_config):
-     # Replace the embedding layer with a LoRA layer
-     model.embeddings = lora.Embedding(
-         num_embeddings=model.embeddings.num_embeddings,
-         embedding_dim=model.embeddings.embedding_dim,
-         padding_idx=model.embeddings.padding_idx,
-         r=lora_config.r,
-         lora_alpha=lora_config.lora_alpha,
-     )
-
-     model.codebook_embeddings = lora.Embedding(
-         num_embeddings=model.codebook_embeddings.num_embeddings,
-         embedding_dim=model.codebook_embeddings.embedding_dim,
-         padding_idx=model.codebook_embeddings.padding_idx,
-         r=lora_config.r,
-         lora_alpha=lora_config.lora_alpha,
-     )
-
-     # Replace output layer with a LoRA layer
-     linears = [(model, "output")]
-
-     # Replace all linear layers with LoRA layers
-     for layer in model.layers:
-         linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-         linears.extend(
-             [
-                 (layer.feed_forward, "w1"),
-                 (layer.feed_forward, "w2"),
-                 (layer.feed_forward, "w3"),
-             ]
-         )
-
-     if hasattr(model, "fast_layers"):
-         model.fast_embeddings = lora.Embedding(
-             num_embeddings=model.fast_embeddings.num_embeddings,
-             embedding_dim=model.fast_embeddings.embedding_dim,
-             padding_idx=model.fast_embeddings.padding_idx,
-             r=lora_config.r,
-             lora_alpha=lora_config.lora_alpha,
-         )
-
-         # Dual-AR model
-         linears.append((model, "fast_output"))
-
-         for layer in model.fast_layers:
-             linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-             linears.extend(
-                 [
-                     (layer.feed_forward, "w1"),
-                     (layer.feed_forward, "w2"),
-                     (layer.feed_forward, "w3"),
-                 ]
-             )
-
-     for module, layer in linears:
-         updated_linear = lora.Linear(
-             in_features=getattr(module, layer).in_features,
-             out_features=getattr(module, layer).out_features,
-             bias=getattr(module, layer).bias,
-             r=lora_config.r,
-             lora_alpha=lora_config.lora_alpha,
-             lora_dropout=lora_config.lora_dropout,
-         )
-         setattr(module, layer, updated_linear)
-
-     # Mark only the LoRA layers as trainable
-     lora.mark_only_lora_as_trainable(model, bias="none")
-
-
- def get_merged_state_dict(model):
-     # This line will merge the state dict of the model and the LoRA parameters
-     model.eval()
-
-     # Then we need to remove the LoRA parameters from the state dict
-     state_dict = model.state_dict()
-     for name in list(state_dict.keys()):
-         if "lora" in name:
-             state_dict.pop(name)
-
-     return state_dict
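setup_lora expects a text2semantic-style model whose projections are named wqkv/wo and w1/w2/w3. Below is a self-contained sketch of the fine-tune-then-merge flow on a tiny stand-in module; the stand-in, its sizes, and the output filename are hypothetical, and it assumes loralib is installed and the deleted lora.py is importable.

# Sketch of the LoRA flow above on a hypothetical stand-in model (not the real
# text2semantic architecture). Assumes loralib and the deleted lora.py are available.
import torch
from torch import nn

from fish_speech.models.text2semantic.lora import (
    LoraConfig,
    get_merged_state_dict,
    setup_lora,
)


class TinyLayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attention = nn.Module()
        self.attention.wqkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attention.wo = nn.Linear(dim, dim, bias=False)
        self.feed_forward = nn.Module()
        self.feed_forward.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.feed_forward.w2 = nn.Linear(4 * dim, dim, bias=False)
        self.feed_forward.w3 = nn.Linear(dim, 4 * dim, bias=False)


class TinyModel(nn.Module):
    def __init__(self, vocab: int = 32, dim: int = 16):
        super().__init__()
        self.embeddings = nn.Embedding(vocab, dim)
        self.codebook_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(TinyLayer(dim) for _ in range(2))
        self.output = nn.Linear(dim, vocab, bias=False)


model = TinyModel()
setup_lora(model, LoraConfig(r=8, lora_alpha=16))  # swap Embedding/Linear for LoRA versions

# After setup_lora, only the injected LoRA matrices remain trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert all("lora_" in name for name in trainable)

# model.eval() (inside get_merged_state_dict) folds the LoRA deltas back into the base
# weights, and the lora_* entries are dropped so the checkpoint matches the original model.
torch.save(get_merged_state_dict(model), "model.lora-merged.pth")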
fish_speech/models/vqgan/__init__.py DELETED
File without changes
fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (171 Bytes)
 
fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc DELETED
Binary file (18.3 kB)
 
fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc DELETED
Binary file (3.72 kB)
 
fish_speech/models/vqgan/modules/firefly.py DELETED
@@ -1,596 +0,0 @@
1
- import math
2
- from functools import partial
3
- from math import prod
4
- from typing import Callable
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- from torch import nn
9
- from torch.nn.utils.parametrizations import weight_norm
10
- from torch.nn.utils.parametrize import remove_parametrizations
11
- from torch.utils.checkpoint import checkpoint
12
-
13
-
14
- def sequence_mask(length, max_length=None):
15
- if max_length is None:
16
- max_length = length.max()
17
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18
- return x.unsqueeze(0) < length.unsqueeze(1)
19
-
20
-
21
- def init_weights(m, mean=0.0, std=0.01):
22
- classname = m.__class__.__name__
23
- if classname.find("Conv1D") != -1:
24
- m.weight.data.normal_(mean, std)
25
-
26
-
27
- def get_padding(kernel_size, dilation=1):
28
- return (kernel_size * dilation - dilation) // 2
29
-
30
-
31
- def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
32
- """Remove padding from x, handling properly zero padding. Only for 1d!"""
33
- padding_left, padding_right = paddings
34
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
35
- assert (padding_left + padding_right) <= x.shape[-1]
36
- end = x.shape[-1] - padding_right
37
- return x[..., padding_left:end]
38
-
39
-
40
- def get_extra_padding_for_conv1d(
41
- x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
42
- ) -> int:
43
- """See `pad_for_conv1d`."""
44
- length = x.shape[-1]
45
- n_frames = (length - kernel_size + padding_total) / stride + 1
46
- ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
47
- return ideal_length - length
48
-
49
-
50
- def pad1d(
51
- x: torch.Tensor,
52
- paddings: tuple[int, int],
53
- mode: str = "zeros",
54
- value: float = 0.0,
55
- ):
56
- """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
57
- If this is the case, we insert extra 0 padding to the right
58
- before the reflection happens.
59
- """
60
- length = x.shape[-1]
61
- padding_left, padding_right = paddings
62
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
63
- if mode == "reflect":
64
- max_pad = max(padding_left, padding_right)
65
- extra_pad = 0
66
- if length <= max_pad:
67
- extra_pad = max_pad - length + 1
68
- x = F.pad(x, (0, extra_pad))
69
- padded = F.pad(x, paddings, mode, value)
70
- end = padded.shape[-1] - extra_pad
71
- return padded[..., :end]
72
- else:
73
- return F.pad(x, paddings, mode, value)
74
-
75
-
76
- class FishConvNet(nn.Module):
77
- def __init__(
78
- self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
79
- ):
80
- super(FishConvNet, self).__init__()
81
- self.conv = nn.Conv1d(
82
- in_channels,
83
- out_channels,
84
- kernel_size,
85
- stride=stride,
86
- dilation=dilation,
87
- groups=groups,
88
- )
89
- self.stride = stride
90
- self.kernel_size = (kernel_size - 1) * dilation + 1
91
- self.dilation = dilation
92
-
93
- def forward(self, x):
94
- pad = self.kernel_size - self.stride
95
- extra_padding = get_extra_padding_for_conv1d(
96
- x, self.kernel_size, self.stride, pad
97
- )
98
- x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
99
- return self.conv(x).contiguous()
100
-
101
- def weight_norm(self, name="weight", dim=0):
102
- self.conv = weight_norm(self.conv, name=name, dim=dim)
103
- return self
104
-
105
- def remove_weight_norm(self):
106
- self.conv = remove_parametrizations(self.conv)
107
- return self
108
-
109
-
110
- class FishTransConvNet(nn.Module):
111
- def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
112
- super(FishTransConvNet, self).__init__()
113
- self.conv = nn.ConvTranspose1d(
114
- in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
115
- )
116
- self.stride = stride
117
- self.kernel_size = kernel_size
118
-
119
- def forward(self, x):
120
- x = self.conv(x)
121
- pad = self.kernel_size - self.stride
122
- padding_right = math.ceil(pad)
123
- padding_left = pad - padding_right
124
- x = unpad1d(x, (padding_left, padding_right))
125
- return x.contiguous()
126
-
127
- def weight_norm(self, name="weight", dim=0):
128
- self.conv = weight_norm(self.conv, name=name, dim=dim)
129
- return self
130
-
131
- def remove_weight_norm(self):
132
- self.conv = remove_parametrizations(self.conv)
133
- return self
134
-
135
-
136
- class ResBlock1(torch.nn.Module):
137
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
138
- super().__init__()
139
-
140
- self.convs1 = nn.ModuleList(
141
- [
142
- FishConvNet(
143
- channels, channels, kernel_size, stride=1, dilation=dilation[0]
144
- ).weight_norm(),
145
- FishConvNet(
146
- channels, channels, kernel_size, stride=1, dilation=dilation[1]
147
- ).weight_norm(),
148
- FishConvNet(
149
- channels, channels, kernel_size, stride=1, dilation=dilation[2]
150
- ).weight_norm(),
151
- ]
152
- )
153
- self.convs1.apply(init_weights)
154
-
155
- self.convs2 = nn.ModuleList(
156
- [
157
- FishConvNet(
158
- channels, channels, kernel_size, stride=1, dilation=dilation[0]
159
- ).weight_norm(),
160
- FishConvNet(
161
- channels, channels, kernel_size, stride=1, dilation=dilation[1]
162
- ).weight_norm(),
163
- FishConvNet(
164
- channels, channels, kernel_size, stride=1, dilation=dilation[2]
165
- ).weight_norm(),
166
- ]
167
- )
168
- self.convs2.apply(init_weights)
169
-
170
- def forward(self, x):
171
- for c1, c2 in zip(self.convs1, self.convs2):
172
- xt = F.silu(x)
173
- xt = c1(xt)
174
- xt = F.silu(xt)
175
- xt = c2(xt)
176
- x = xt + x
177
- return x
178
-
179
- def remove_parametrizations(self):
180
- for conv in self.convs1:
181
- remove_parametrizations(conv, tensor_name="weight")
182
- for conv in self.convs2:
183
- remove_parametrizations(conv, tensor_name="weight")
184
-
185
-
186
- class ParallelBlock(nn.Module):
187
- def __init__(
188
- self,
189
- channels: int,
190
- kernel_sizes: tuple[int] = (3, 7, 11),
191
- dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
192
- ):
193
- super().__init__()
194
-
195
- assert len(kernel_sizes) == len(dilation_sizes)
196
-
197
- self.blocks = nn.ModuleList()
198
- for k, d in zip(kernel_sizes, dilation_sizes):
199
- self.blocks.append(ResBlock1(channels, k, d))
200
-
201
- def forward(self, x):
202
- return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
203
-
204
- def remove_parametrizations(self):
205
- for block in self.blocks:
206
- block.remove_parametrizations()
207
-
208
-
209
- class HiFiGANGenerator(nn.Module):
210
- def __init__(
211
- self,
212
- *,
213
- hop_length: int = 512,
214
- upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
215
- upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
216
- resblock_kernel_sizes: tuple[int] = (3, 7, 11),
217
- resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
218
- num_mels: int = 128,
219
- upsample_initial_channel: int = 512,
220
- pre_conv_kernel_size: int = 7,
221
- post_conv_kernel_size: int = 7,
222
- post_activation: Callable = partial(nn.SiLU, inplace=True),
223
- ):
224
- super().__init__()
225
-
226
- assert (
227
- prod(upsample_rates) == hop_length
228
- ), f"hop_length must be {prod(upsample_rates)}"
229
-
230
- self.conv_pre = FishConvNet(
231
- num_mels,
232
- upsample_initial_channel,
233
- pre_conv_kernel_size,
234
- stride=1,
235
- ).weight_norm()
236
-
237
- self.num_upsamples = len(upsample_rates)
238
- self.num_kernels = len(resblock_kernel_sizes)
239
-
240
- self.noise_convs = nn.ModuleList()
241
- self.ups = nn.ModuleList()
242
-
243
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
244
- self.ups.append(
245
- FishTransConvNet(
246
- upsample_initial_channel // (2**i),
247
- upsample_initial_channel // (2 ** (i + 1)),
248
- k,
249
- stride=u,
250
- ).weight_norm()
251
- )
252
-
253
- self.resblocks = nn.ModuleList()
254
- for i in range(len(self.ups)):
255
- ch = upsample_initial_channel // (2 ** (i + 1))
256
- self.resblocks.append(
257
- ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
258
- )
259
-
260
- self.activation_post = post_activation()
261
- self.conv_post = FishConvNet(
262
- ch, 1, post_conv_kernel_size, stride=1
263
- ).weight_norm()
264
- self.ups.apply(init_weights)
265
- self.conv_post.apply(init_weights)
266
-
267
- def forward(self, x):
268
- x = self.conv_pre(x)
269
-
270
- for i in range(self.num_upsamples):
271
- x = F.silu(x, inplace=True)
272
- x = self.ups[i](x)
273
-
274
- if self.training and self.checkpointing:
275
- x = checkpoint(
276
- self.resblocks[i],
277
- x,
278
- use_reentrant=False,
279
- )
280
- else:
281
- x = self.resblocks[i](x)
282
-
283
- x = self.activation_post(x)
284
- x = self.conv_post(x)
285
- x = torch.tanh(x)
286
-
287
- return x
288
-
289
- def remove_parametrizations(self):
290
- for up in self.ups:
291
- remove_parametrizations(up, tensor_name="weight")
292
- for block in self.resblocks:
293
- block.remove_parametrizations()
294
- remove_parametrizations(self.conv_pre, tensor_name="weight")
295
- remove_parametrizations(self.conv_post, tensor_name="weight")
296
-
297
-
298
- # DropPath copied from timm library
299
- def drop_path(
300
- x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
301
- ):
302
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
303
-
304
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
305
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
306
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
307
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
308
- 'survival rate' as the argument.
309
-
310
- """ # noqa: E501
311
-
312
- if drop_prob == 0.0 or not training:
313
- return x
314
- keep_prob = 1 - drop_prob
315
- shape = (x.shape[0],) + (1,) * (
316
- x.ndim - 1
317
- ) # work with diff dim tensors, not just 2D ConvNets
318
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
319
- if keep_prob > 0.0 and scale_by_keep:
320
- random_tensor.div_(keep_prob)
321
- return x * random_tensor
322
-
323
-
324
- class DropPath(nn.Module):
325
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
326
-
327
- def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
328
- super(DropPath, self).__init__()
329
- self.drop_prob = drop_prob
330
- self.scale_by_keep = scale_by_keep
331
-
332
- def forward(self, x):
333
- return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
334
-
335
- def extra_repr(self):
336
- return f"drop_prob={round(self.drop_prob,3):0.3f}"
337
-
338
-
339
- class LayerNorm(nn.Module):
340
- r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
341
- The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
342
- shape (batch_size, height, width, channels) while channels_first corresponds to inputs
343
- with shape (batch_size, channels, height, width).
344
- """ # noqa: E501
345
-
346
- def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
347
- super().__init__()
348
- self.weight = nn.Parameter(torch.ones(normalized_shape))
349
- self.bias = nn.Parameter(torch.zeros(normalized_shape))
350
- self.eps = eps
351
- self.data_format = data_format
352
- if self.data_format not in ["channels_last", "channels_first"]:
353
- raise NotImplementedError
354
- self.normalized_shape = (normalized_shape,)
355
-
356
- def forward(self, x):
357
- if self.data_format == "channels_last":
358
- return F.layer_norm(
359
- x, self.normalized_shape, self.weight, self.bias, self.eps
360
- )
361
- elif self.data_format == "channels_first":
362
- u = x.mean(1, keepdim=True)
363
- s = (x - u).pow(2).mean(1, keepdim=True)
364
- x = (x - u) / torch.sqrt(s + self.eps)
365
- x = self.weight[:, None] * x + self.bias[:, None]
366
- return x
367
-
368
-
369
- # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
370
- class ConvNeXtBlock(nn.Module):
371
- r"""ConvNeXt Block. There are two equivalent implementations:
372
- (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
373
- (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
374
- We use (2) as we find it slightly faster in PyTorch
375
-
376
- Args:
377
- dim (int): Number of input channels.
378
- drop_path (float): Stochastic depth rate. Default: 0.0
379
- layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
380
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
381
- kernel_size (int): Kernel size for depthwise conv. Default: 7.
382
- dilation (int): Dilation for depthwise conv. Default: 1.
383
- """ # noqa: E501
384
-
385
- def __init__(
386
- self,
387
- dim: int,
388
- drop_path: float = 0.0,
389
- layer_scale_init_value: float = 1e-6,
390
- mlp_ratio: float = 4.0,
391
- kernel_size: int = 7,
392
- dilation: int = 1,
393
- ):
394
- super().__init__()
395
-
396
- self.dwconv = FishConvNet(
397
- dim,
398
- dim,
399
- kernel_size=kernel_size,
400
- # padding=int(dilation * (kernel_size - 1) / 2),
401
- groups=dim,
402
- ) # depthwise conv
403
- self.norm = LayerNorm(dim, eps=1e-6)
404
- self.pwconv1 = nn.Linear(
405
- dim, int(mlp_ratio * dim)
406
- ) # pointwise/1x1 convs, implemented with linear layers
407
- self.act = nn.GELU()
408
- self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
409
- self.gamma = (
410
- nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
411
- if layer_scale_init_value > 0
412
- else None
413
- )
414
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
415
-
416
- def forward(self, x, apply_residual: bool = True):
417
- input = x
418
-
419
- x = self.dwconv(x)
420
- x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
421
- x = self.norm(x)
422
- x = self.pwconv1(x)
423
- x = self.act(x)
424
- x = self.pwconv2(x)
425
-
426
- if self.gamma is not None:
427
- x = self.gamma * x
428
-
429
- x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
430
- x = self.drop_path(x)
431
-
432
- if apply_residual:
433
- x = input + x
434
-
435
- return x
436
-
437
-
438
- class ConvNeXtEncoder(nn.Module):
439
- def __init__(
440
- self,
441
- input_channels: int = 3,
442
- depths: list[int] = [3, 3, 9, 3],
443
- dims: list[int] = [96, 192, 384, 768],
444
- drop_path_rate: float = 0.0,
445
- layer_scale_init_value: float = 1e-6,
446
- kernel_size: int = 7,
447
- ):
448
- super().__init__()
449
- assert len(depths) == len(dims)
450
-
451
- self.downsample_layers = nn.ModuleList()
452
- stem = nn.Sequential(
453
- FishConvNet(
454
- input_channels,
455
- dims[0],
456
- kernel_size=7,
457
- # padding=3,
458
- # padding_mode="replicate",
459
- # padding_mode="zeros",
460
- ),
461
- LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
462
- )
463
- self.downsample_layers.append(stem)
464
-
465
- for i in range(len(depths) - 1):
466
- mid_layer = nn.Sequential(
467
- LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
468
- nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
469
- )
470
- self.downsample_layers.append(mid_layer)
471
-
472
- self.stages = nn.ModuleList()
473
- dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
474
-
475
- cur = 0
476
- for i in range(len(depths)):
477
- stage = nn.Sequential(
478
- *[
479
- ConvNeXtBlock(
480
- dim=dims[i],
481
- drop_path=dp_rates[cur + j],
482
- layer_scale_init_value=layer_scale_init_value,
483
- kernel_size=kernel_size,
484
- )
485
- for j in range(depths[i])
486
- ]
487
- )
488
- self.stages.append(stage)
489
- cur += depths[i]
490
-
491
- self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
492
- self.apply(self._init_weights)
493
-
494
- def _init_weights(self, m):
495
- if isinstance(m, (nn.Conv1d, nn.Linear)):
496
- nn.init.trunc_normal_(m.weight, std=0.02)
497
- nn.init.constant_(m.bias, 0)
498
-
499
- def forward(
500
- self,
501
- x: torch.Tensor,
502
- ) -> torch.Tensor:
503
- for i in range(len(self.downsample_layers)):
504
- x = self.downsample_layers[i](x)
505
- x = self.stages[i](x)
506
-
507
- return self.norm(x)
508
-
509
-
510
- class FireflyArchitecture(nn.Module):
511
- def __init__(
512
- self,
513
- backbone: nn.Module,
514
- head: nn.Module,
515
- quantizer: nn.Module,
516
- spec_transform: nn.Module,
517
- ):
518
- super().__init__()
519
-
520
- self.backbone = backbone
521
- self.head = head
522
- self.quantizer = quantizer
523
- self.spec_transform = spec_transform
524
- self.downsample_factor = math.prod(self.quantizer.downsample_factor)
525
-
526
- def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
527
- if self.spec_transform is not None:
528
- x = self.spec_transform(x)
529
-
530
- x = self.backbone(x)
531
- if mask is not None:
532
- x = x * mask
533
-
534
- if self.quantizer is not None:
535
- vq_result = self.quantizer(x)
536
- x = vq_result.z
537
-
538
- if mask is not None:
539
- x = x * mask
540
-
541
- x = self.head(x, template=template)
542
-
543
- if x.ndim == 2:
544
- x = x[:, None, :]
545
-
546
- if self.quantizer is not None:
547
- return x, vq_result
548
-
549
- return x
550
-
551
- def encode(self, audios, audio_lengths):
552
- audios = audios.float()
553
-
554
- mels = self.spec_transform(audios)
555
- mel_lengths = audio_lengths // self.spec_transform.hop_length
556
- mel_masks = sequence_mask(mel_lengths, mels.shape[2])
557
- mel_masks_float_conv = mel_masks[:, None, :].float()
558
- mels = mels * mel_masks_float_conv
559
-
560
- # Encode
561
- encoded_features = self.backbone(mels) * mel_masks_float_conv
562
- feature_lengths = mel_lengths // self.downsample_factor
563
-
564
- return self.quantizer.encode(encoded_features), feature_lengths
565
-
566
- def decode(self, indices, feature_lengths) -> torch.Tensor:
567
- mel_masks = sequence_mask(
568
- feature_lengths * self.downsample_factor,
569
- indices.shape[2] * self.downsample_factor,
570
- )
571
- mel_masks_float_conv = mel_masks[:, None, :].float()
572
- audio_lengths = (
573
- feature_lengths * self.downsample_factor * self.spec_transform.hop_length
574
- )
575
-
576
- audio_masks = sequence_mask(
577
- audio_lengths,
578
- indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
579
- )
580
- audio_masks_float_conv = audio_masks[:, None, :].float()
581
-
582
- z = self.quantizer.decode(indices) * mel_masks_float_conv
583
- x = self.head(z) * audio_masks_float_conv
584
-
585
- return x, audio_lengths
586
-
587
- def remove_parametrizations(self):
588
- if hasattr(self.backbone, "remove_parametrizations"):
589
- self.backbone.remove_parametrizations()
590
-
591
- if hasattr(self.head, "remove_parametrizations"):
592
- self.head.remove_parametrizations()
593
-
594
- @property
595
- def device(self):
596
- return next(self.parameters()).device
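The convolution wrappers defined near the top of this file (FishConvNet / FishTransConvNet) pad and trim so that sequence lengths stay aligned with the stride: the forward conv left-pads so its output has ceil(length / stride) frames, and the transposed conv trims back to length * stride. A quick shape check, assuming the deleted firefly.py were still importable; channel counts and lengths are illustrative.

# Length-preservation check for the padded conv wrappers above (illustrative sizes,
# assuming the deleted firefly.py is importable).
import torch

from fish_speech.models.vqgan.modules.firefly import FishConvNet, FishTransConvNet

x = torch.randn(1, 4, 100)                        # (batch, channels, time)

down = FishConvNet(4, 8, kernel_size=7, stride=2)
y = down(x)
print(y.shape)                                    # torch.Size([1, 8, 50])  == 100 / stride

up = FishTransConvNet(8, 4, kernel_size=4, stride=2)
z = up(y)
print(z.shape)                                    # torch.Size([1, 4, 100]) == 50 * stride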
fish_speech/models/vqgan/modules/fsq.py DELETED
@@ -1,116 +0,0 @@
- from dataclasses import dataclass
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from vector_quantize_pytorch import GroupedResidualFSQ
-
- from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
-
-
- @dataclass
- class FSQResult:
-     z: torch.Tensor
-     codes: torch.Tensor
-     latents: torch.Tensor
-
-
- class DownsampleFiniteScalarQuantize(nn.Module):
-     def __init__(
-         self,
-         input_dim: int = 512,
-         n_codebooks: int = 9,
-         n_groups: int = 1,
-         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
-         downsample_factor: tuple[int] = (2, 2),
-         downsample_dims: tuple[int] | None = None,
-     ):
-         super().__init__()
-
-         if downsample_dims is None:
-             downsample_dims = [input_dim for _ in range(len(downsample_factor))]
-
-         all_dims = (input_dim,) + tuple(downsample_dims)
-
-         self.residual_fsq = GroupedResidualFSQ(
-             dim=all_dims[-1],
-             levels=levels,
-             num_quantizers=n_codebooks,
-             groups=n_groups,
-         )
-
-         self.downsample_factor = downsample_factor
-         self.downsample_dims = downsample_dims
-
-         self.downsample = nn.Sequential(
-             *[
-                 nn.Sequential(
-                     FishConvNet(
-                         all_dims[idx],
-                         all_dims[idx + 1],
-                         kernel_size=factor,
-                         stride=factor,
-                     ),
-                     ConvNeXtBlock(dim=all_dims[idx + 1]),
-                 )
-                 for idx, factor in enumerate(downsample_factor)
-             ]
-         )
-
-         self.upsample = nn.Sequential(
-             *[
-                 nn.Sequential(
-                     FishTransConvNet(
-                         all_dims[idx + 1],
-                         all_dims[idx],
-                         kernel_size=factor,
-                         stride=factor,
-                     ),
-                     ConvNeXtBlock(dim=all_dims[idx]),
-                 )
-                 for idx, factor in reversed(list(enumerate(downsample_factor)))
-             ]
-         )
-
-         self.apply(self._init_weights)
-
-     def _init_weights(self, m):
-         if isinstance(m, (nn.Conv1d, nn.Linear)):
-             nn.init.trunc_normal_(m.weight, std=0.02)
-             nn.init.constant_(m.bias, 0)
-
-     def forward(self, z) -> FSQResult:
-         original_shape = z.shape
-         z = self.downsample(z)
-         quantized, indices = self.residual_fsq(z.mT)
-         result = FSQResult(
-             z=quantized.mT,
-             codes=indices.mT,
-             latents=z,
-         )
-         result.z = self.upsample(result.z)
-
-         # Pad or crop z to match original shape
-         diff = original_shape[-1] - result.z.shape[-1]
-         left = diff // 2
-         right = diff - left
-
-         if diff > 0:
-             result.z = F.pad(result.z, (left, right))
-         elif diff < 0:
-             result.z = result.z[..., left:-right]
-
-         return result
-
-     def encode(self, z):
-         z = self.downsample(z)
-         _, indices = self.residual_fsq(z.mT)
-         indices = rearrange(indices, "g b l r -> b (g r) l")
-         return indices
-
-     def decode(self, indices: torch.Tensor):
-         indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
-         z_q = self.residual_fsq.get_output_from_indices(indices)
-         z_q = self.upsample(z_q.mT)
-         return z_q
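The quantizer above downsamples features by prod(downsample_factor) before applying grouped residual FSQ, so with the defaults a (batch, 512, T) feature map becomes a (batch, n_codebooks, T / 4) index tensor and decodes back to the original resolution. A small round-trip sketch, assuming vector_quantize_pytorch is installed and the deleted fsq.py/firefly.py are importable; all sizes are illustrative.

# Round-trip sketch for DownsampleFiniteScalarQuantize (illustrative sizes, assuming
# the deleted modules and vector_quantize_pytorch are available).
import torch

from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize

quantizer = DownsampleFiniteScalarQuantize(input_dim=512, n_codebooks=9)

features = torch.randn(1, 512, 80)       # (batch, channels, frames)
indices = quantizer.encode(features)     # (1, 9, 20): two 2x downsamples -> 80 / 4 frames
recovered = quantizer.decode(indices)    # (1, 512, 80): upsampled back to input resolution

result = quantizer(features)             # FSQResult with .z, .codes, .latents
assert result.z.shape == features.shape  # forward() pads/crops back to the original length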
fish_speech/models/vqgan/utils.py DELETED
@@ -1,94 +0,0 @@
- import matplotlib
- import torch
- from matplotlib import pyplot as plt
-
- matplotlib.use("Agg")
-
-
- def convert_pad_shape(pad_shape):
-     l = pad_shape[::-1]
-     pad_shape = [item for sublist in l for item in sublist]
-     return pad_shape
-
-
- def sequence_mask(length, max_length=None):
-     if max_length is None:
-         max_length = length.max()
-     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-     return x.unsqueeze(0) < length.unsqueeze(1)
-
-
- def init_weights(m, mean=0.0, std=0.01):
-     classname = m.__class__.__name__
-     if classname.find("Conv") != -1:
-         m.weight.data.normal_(mean, std)
-
-
- def get_padding(kernel_size, dilation=1):
-     return int((kernel_size * dilation - dilation) / 2)
-
-
- def plot_mel(data, titles=None):
-     fig, axes = plt.subplots(len(data), 1, squeeze=False)
-
-     if titles is None:
-         titles = [None for i in range(len(data))]
-
-     plt.tight_layout()
-
-     for i in range(len(data)):
-         mel = data[i]
-
-         if isinstance(mel, torch.Tensor):
-             mel = mel.float().detach().cpu().numpy()
-
-         axes[i][0].imshow(mel, origin="lower")
-         axes[i][0].set_aspect(2.5, adjustable="box")
-         axes[i][0].set_ylim(0, mel.shape[0])
-         axes[i][0].set_title(titles[i], fontsize="medium")
-         axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
-         axes[i][0].set_anchor("W")
-
-     return fig
-
-
- def slice_segments(x, ids_str, segment_size=4):
-     ret = torch.zeros_like(x[:, :, :segment_size])
-     for i in range(x.size(0)):
-         idx_str = ids_str[i]
-         idx_end = idx_str + segment_size
-         ret[i] = x[i, :, idx_str:idx_end]
-
-     return ret
-
-
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
-     b, d, t = x.size()
-     if x_lengths is None:
-         x_lengths = t
-     ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
-     ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
-     ret = slice_segments(x, ids_str, segment_size)
-     return ret, ids_str
-
-
- @torch.jit.script
- def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
-     n_channels_int = n_channels[0]
-     t_act = torch.tanh(in_act[:, :n_channels_int, :])
-     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-     acts = t_act * s_act
-
-     return acts
-
-
- def avg_with_mask(x, mask):
-     assert mask.dtype == torch.float, "Mask should be float"
-
-     if mask.ndim == 2:
-         mask = mask.unsqueeze(1)
-
-     if mask.shape[1] == 1:
-         mask = mask.expand_as(x)
-
-     return (x * mask).sum() / mask.sum()
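Most of these helpers are shape-sensitive (masks are (batch, time), features are (batch, channels, time)). A short usage sketch, assuming the deleted utils.py were still importable; the tensors are illustrative.

# Usage sketch for the masking/slicing helpers above (illustrative tensors, assuming
# the deleted utils.py is importable).
import torch

from fish_speech.models.vqgan.utils import avg_with_mask, rand_slice_segments, sequence_mask

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths, max_length=6)  # (2, 6) bool, True inside each sequence
x = torch.randn(2, 8, 6)                     # (batch, channels, time)

mean = avg_with_mask(x, mask.float())        # masked global mean; the mask must be float

segments, start_ids = rand_slice_segments(x, lengths, segment_size=2)
print(segments.shape)                        # torch.Size([2, 8, 2])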
fish_speech/scheduler.py DELETED
@@ -1,40 +0,0 @@
- import math
-
-
- def get_cosine_schedule_with_warmup_lr_lambda(
-     current_step: int,
-     *,
-     num_warmup_steps: int | float,
-     num_training_steps: int,
-     num_cycles: float = 0.5,
-     final_lr_ratio: float = 0.0,
- ):
-     if 0 < num_warmup_steps < 1:  # float mode
-         num_warmup_steps = int(num_warmup_steps * num_training_steps)
-
-     if current_step < num_warmup_steps:
-         return float(current_step) / float(max(1, num_warmup_steps))
-
-     progress = float(current_step - num_warmup_steps) / float(
-         max(1, num_training_steps - num_warmup_steps)
-     )
-
-     return max(
-         final_lr_ratio,
-         0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
-     )
-
-
- def get_constant_schedule_with_warmup_lr_lambda(
-     current_step: int,
-     *,
-     num_warmup_steps: int | float,
-     num_training_steps: int | None = None,
- ):
-     if 0 < num_warmup_steps < 1:  # float mode
-         num_warmup_steps = int(num_warmup_steps * num_training_steps)
-
-     if current_step < num_warmup_steps:
-         return float(current_step) / float(max(1, num_warmup_steps))
-
-     return 1.0
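Both functions return a multiplier for the base learning rate, so they are meant to be bound with functools.partial and handed to torch.optim.lr_scheduler.LambdaLR. A minimal wiring sketch, assuming the deleted fish_speech.scheduler module were still importable; the optimizer and all hyperparameter values are illustrative.

# Wiring sketch: cosine schedule with warmup via LambdaLR (illustrative values,
# assuming the deleted scheduler module is importable).
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.01,     # a float in (0, 1) is treated as a fraction of total steps
    num_training_steps=10_000,
    final_lr_ratio=0.1,        # never decay below 10% of the base learning rate
)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in range(3):
    optimizer.step()
    scheduler.step()           # base_lr * lr_lambda(current_step)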
fish_speech/text/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .clean import clean_text
- from .spliter import split_text
-
- __all__ = ["clean_text", "split_text"]
fish_speech/text/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (274 Bytes)
 
fish_speech/text/__pycache__/clean.cpython-310.pyc DELETED
Binary file (840 Bytes)