diff --git a/app.py b/app.py index a43c979e4e4da0111993bf025ed7cc6e70d8d682..82b472d64430ece5251278eba306bae9b5a4d69a 100644 --- a/app.py +++ b/app.py @@ -10,7 +10,7 @@ import gc # Download if not exists os.makedirs("checkpoints", exist_ok=True) -snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4") +snapshot_download(repo_id="fishaudio/fish-speech-1.5", local_dir="./checkpoints/fish-speech-1.5") print("All checkpoints downloaded") @@ -31,11 +31,11 @@ torchaudio.set_audio_backend("soundfile") from loguru import logger from transformers import AutoTokenizer -from tools.llama.generate import launch_thread_safe_queue -from tools.vqgan.inference import load_model as load_vqgan_model +from fish_speech.i18n import i18n from fish_speech.text.chn_text_norm.text import Text as ChnNormedText +from fish_speech.utils import autocast_exclude_mps, set_seed from tools.api import decode_vq_tokens, encode_reference -from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model +from tools.file import AUDIO_EXTENSIONS, list_files from tools.llama.generate import ( GenerateRequest, GenerateResponse, @@ -44,20 +44,43 @@ from tools.llama.generate import ( ) from tools.vqgan.inference import load_model as load_decoder_model +from tools.schema import ( + GLOBAL_NUM_SAMPLES, + ASRPackRequest, + ServeASRRequest, + ServeASRResponse, + ServeASRSegment, + ServeAudioPart, + ServeForwardMessage, + ServeMessage, + ServeRequest, + ServeResponse, + ServeStreamDelta, + ServeStreamResponse, + ServeTextPart, + ServeTimedASRResponse, + ServeTTSRequest, + ServeVQGANDecodeRequest, + ServeVQGANDecodeResponse, + ServeVQGANEncodeRequest, + ServeVQGANEncodeResponse, + ServeVQPart, + ServeReferenceAudio +) # Make einx happy os.environ["EINX_FILTER_TRACEBACK"] = "false" HEADER_MD = """# Fish Speech -## The demo in this space is version 1.4, Please check [Fish Audio](https://fish.audio) for the best model. -## 该 Demo 为 Fish Speech 1.4 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO. +## The demo in this space is version 1.5, Please check [Fish Audio](https://fish.audio) for the best model. +## 该 Demo 为 Fish Speech 1.5 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO. A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio). 由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成. -You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4). -你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.4) 找到模型. +You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5). +你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型. Related code and weights are released under CC BY-NC-SA 4.0 License. 相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布. @@ -65,8 +88,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License. We are not responsible for any misuse of the model, please consider your local laws and regulations before using it. 我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规. -The model running in this WebUI is Fish Speech V1.4 Medium. -在此 WebUI 中运行的模型是 Fish Speech V1.4 Medium. +The model running in this WebUI is Fish Speech V1.5 Medium. +在此 WebUI 中运行的模型是 Fish Speech V1.5 Medium. """ TEXTBOX_PLACEHOLDER = """Put your text here. 
在此处输入文本.""" @@ -95,48 +118,77 @@ def build_html_error_message(error): @GPU_DECORATOR @torch.inference_mode() -def inference( - text, - enable_reference_audio, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - streaming=False -): - if args.max_gradio_length > 0 and len(text) > args.max_gradio_length: - return ( - None, - None, - "Text is too long, please keep it under {} characters.".format( - args.max_gradio_length - ), +def inference(req: ServeTTSRequest): + + global prompt_tokens, prompt_texts + + idstr: str | None = req.reference_id + if idstr is not None: + ref_folder = Path("references") / idstr + ref_folder.mkdir(parents=True, exist_ok=True) + ref_audios = list_files( + ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False ) - # Parse reference audio aka prompt - prompt_tokens = encode_reference( - decoder_model=decoder_model, - reference_audio=reference_audio, - enable_reference_audio=enable_reference_audio, - ) + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + else: + logger.info("Use same references") + + else: + # Parse reference audio aka prompt + refs = req.references + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + for ref in refs + ] + prompt_texts = [ref.text for ref in refs] + else: + logger.info("Use same references") + + if req.seed is not None: + set_seed(req.seed) + logger.warning(f"set seed: {req.seed}") # LLAMA Inference request = dict( device=decoder_model.device, - max_new_tokens=max_new_tokens, - text=text, - top_p=top_p, - repetition_penalty=repetition_penalty, - temperature=temperature, + max_new_tokens=req.max_new_tokens, + text=( + req.text + if not req.normalize + else ChnNormedText(raw_text=req.text).normalize() + ), + top_p=req.top_p, + repetition_penalty=req.repetition_penalty, + temperature=req.temperature, compile=args.compile, - iterative_prompt=chunk_length > 0, - chunk_length=chunk_length, - max_length=2048, - prompt_tokens=prompt_tokens if enable_reference_audio else None, - prompt_text=reference_text if enable_reference_audio else None, + iterative_prompt=req.chunk_length > 0, + chunk_length=req.chunk_length, + max_length=4096, + prompt_tokens=prompt_tokens, + prompt_text=prompt_texts, ) response_queue = queue.Queue() @@ -152,19 +204,15 @@ def inference( while True: result: WrappedGenerateResponse = response_queue.get() if result.status == "error": - return None, None, build_html_error_message(result.response) + yield None, None, build_html_error_message(result.response) + break result: GenerateResponse = result.response if result.action == "next": break - with torch.autocast( - device_type=( - "cpu" - if decoder_model.device.type == "mps" - else decoder_model.device.type - ), - dtype=args.precision, + with autocast_exclude_mps( + device_type=decoder_model.device.type, dtype=args.precision ): fake_audios = decode_vq_tokens( decoder_model=decoder_model, @@ -179,79 +227,24 @@ def inference( None, None, build_html_error_message( - 
"No audio generated, please check the input text." + i18n("No audio generated, please check the input text.") ), ) - # Return the final audio + # No matter streaming or not, we need to return the final audio audio = np.concatenate(segments, axis=0) - return None, (decoder_model.spec_transform.sample_rate, audio), None + yield None, (decoder_model.spec_transform.sample_rate, audio), None if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - -def inference_with_auto_rerank( - text, - enable_reference_audio, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - use_auto_rerank, - streaming=False, -): - max_attempts = 2 if use_auto_rerank else 1 - best_wer = float("inf") - best_audio = None - best_sample_rate = None - - for attempt in range(max_attempts): - _, (sample_rate, audio), message = inference( - text, - enable_reference_audio, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - streaming=False, - ) - - if audio is None: - return None, None, message - - if not use_auto_rerank: - return None, (sample_rate, audio), None - - asr_result = batch_asr(asr_model, [audio], sample_rate)[0] - wer = calculate_wer(text, asr_result["text"]) - - if wer <= 0.3 and not asr_result["huge_gap"]: - return None, (sample_rate, audio), None - - if wer < best_wer: - best_wer = wer - best_audio = audio - best_sample_rate = sample_rate - - if attempt == max_attempts - 1: - break - - return None, (best_sample_rate, best_audio), None - - n_audios = 4 global_audio_list = [] global_error_list = [] + def inference_wrapper( text, enable_reference_audio, @@ -262,14 +255,14 @@ def inference_wrapper( top_p, repetition_penalty, temperature, + seed, batch_infer_num, - if_load_asr_model, ): audios = [] errors = [] for _ in range(batch_infer_num): - result = inference_with_auto_rerank( + result = inference( text, enable_reference_audio, reference_audio, @@ -279,10 +272,10 @@ def inference_wrapper( top_p, repetition_penalty, temperature, - if_load_asr_model, + seed, ) - _, audio_data, error_message = result + _, audio_data, error_message = next(result) audios.append( gr.Audio(value=audio_data if audio_data else None, visible=True), @@ -314,52 +307,17 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): buffer.close() return wav_header_bytes - def normalize_text(user_input, use_normalization): if use_normalization: return ChnNormedText(raw_text=user_input).normalize() else: return user_input - -asr_model = None - - -def change_if_load_asr_model(if_load): - global asr_model - - if if_load: - gr.Warning("Loading faster whisper model...") - if asr_model is None: - asr_model = load_model() - return gr.Checkbox(label="Unload faster whisper model", value=if_load) - - if if_load is False: - gr.Warning("Unloading faster whisper model...") - del asr_model - asr_model = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - return gr.Checkbox(label="Load faster whisper model", value=if_load) - - -def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text): - if if_load and asr_model is not None: - if ( - if_auto_label - and enable_ref - and ref_audio is not None - and ref_text.strip() == "" - ): - data, sample_rate = librosa.load(ref_audio) - res = batch_asr(asr_model, [data], sample_rate)[0] - ref_text = res["text"] - else: - gr.Warning("Whisper model not loaded!") - - return gr.Textbox(value=ref_text) - +def 
update_examples(): + examples_dir = Path("references") + examples_dir.mkdir(parents=True, exist_ok=True) + example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True) + return gr.Dropdown(choices=example_audios + [""]) def build_app(): with gr.Blocks(theme=gr.themes.Base()) as app: @@ -377,202 +335,185 @@ def build_app(): with gr.Row(): with gr.Column(scale=3): text = gr.Textbox( - label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10 + label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 ) refined_text = gr.Textbox( - label="Realtime Transform Text", - placeholder= - "Normalization Result Preview (Currently Only Chinese)", + label=i18n("Realtime Transform Text"), + placeholder=i18n( + "Normalization Result Preview (Currently Only Chinese)" + ), lines=5, interactive=False, ) with gr.Row(): - if_refine_text = gr.Checkbox( - label="Text Normalization (ZH)", - value=False, - scale=1, - ) - - if_load_asr_model = gr.Checkbox( - label="Load / Unload ASR model for auto-reranking", + normalize = gr.Checkbox( + label=i18n("Text Normalization"), value=False, - scale=3, ) with gr.Row(): - with gr.Tab(label="Advanced Config"): - chunk_length = gr.Slider( - label="Iterative Prompt Length, 0 means off", - minimum=0, - maximum=500, - value=200, - step=8, - ) - - max_new_tokens = gr.Slider( - label="Maximum tokens per batch, 0 means no limit", - minimum=0, - maximum=2048, - value=0, # 0 means no limit - step=8, - ) - - top_p = gr.Slider( - label="Top-P", - minimum=0.6, - maximum=0.9, - value=0.7, - step=0.01, - ) - - repetition_penalty = gr.Slider( - label="Repetition Penalty", - minimum=1, - maximum=1.5, - value=1.2, - step=0.01, - ) - - temperature = gr.Slider( - label="Temperature", - minimum=0.6, - maximum=0.9, - value=0.7, - step=0.01, - ) - - with gr.Tab(label="Reference Audio"): - gr.Markdown( - "5 to 10 seconds of reference audio, useful for specifying speaker." 
- ) - - enable_reference_audio = gr.Checkbox( - label="Enable Reference Audio", - ) - - # Add dropdown for selecting example audio files - example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")] - example_audio_dropdown = gr.Dropdown( - label="Select Example Audio", - choices=[""] + example_audio_files, - value="" - ) - - reference_audio = gr.Audio( - label="Reference Audio", - type="filepath", - ) - with gr.Row(): - if_auto_label = gr.Checkbox( - label="Auto Labeling", - min_width=100, - scale=0, - value=False, - ) - reference_text = gr.Textbox( - label="Reference Text", - lines=1, - placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", - value="", - ) - with gr.Tab(label="Batch Inference"): - batch_infer_num = gr.Slider( - label="Batch infer nums", - minimum=1, - maximum=n_audios, - step=1, - value=1, - ) + with gr.Column(): + with gr.Tab(label=i18n("Advanced Config")): + with gr.Row(): + chunk_length = gr.Slider( + label=i18n("Iterative Prompt Length, 0 means off"), + minimum=0, + maximum=300, + value=200, + step=8, + ) + + max_new_tokens = gr.Slider( + label=i18n( + "Maximum tokens per batch, 0 means no limit" + ), + minimum=0, + maximum=2048, + value=0, + step=8, + ) + + with gr.Row(): + top_p = gr.Slider( + label="Top-P", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + + repetition_penalty = gr.Slider( + label=i18n("Repetition Penalty"), + minimum=1, + maximum=1.5, + value=1.2, + step=0.01, + ) + + with gr.Row(): + temperature = gr.Slider( + label="Temperature", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + seed = gr.Number( + label="Seed", + info="0 means randomized inference, otherwise deterministic", + value=0, + ) + + with gr.Tab(label=i18n("Reference Audio")): + with gr.Row(): + gr.Markdown( + i18n( + "5 to 10 seconds of reference audio, useful for specifying speaker." 
+ ) + ) + with gr.Row(): + reference_id = gr.Textbox( + label=i18n("Reference ID"), + placeholder="Leave empty to use uploaded references", + ) + + with gr.Row(): + use_memory_cache = gr.Radio( + label=i18n("Use Memory Cache"), + choices=["never", "on-demand", "always"], + value="on-demand", + ) + + with gr.Row(): + reference_audio = gr.Audio( + label=i18n("Reference Audio"), + type="filepath", + ) + with gr.Row(): + reference_text = gr.Textbox( + label=i18n("Reference Text"), + lines=1, + placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", + value="", + ) with gr.Column(scale=3): - for _ in range(n_audios): - with gr.Row(): - error = gr.HTML( - label="Error Message", - visible=True if _ == 0 else False, - ) - global_error_list.append(error) - with gr.Row(): - audio = gr.Audio( - label="Generated Audio", - type="numpy", - interactive=False, - visible=True if _ == 0 else False, - ) - global_audio_list.append(audio) - with gr.Row(): - stream_audio = gr.Audio( - label="Streaming Audio", - streaming=True, - autoplay=True, + error = gr.HTML( + label=i18n("Error Message"), + visible=True, + ) + with gr.Row(): + audio = gr.Audio( + label=i18n("Generated Audio"), + type="numpy", interactive=False, - show_download_button=True, + visible=True, ) + with gr.Row(): with gr.Column(scale=3): generate = gr.Button( - value="\U0001F3A7 " + "Generate", variant="primary" - ) - generate_stream = gr.Button( - value="\U0001F3A7 " + "Streaming Generate", - variant="primary", + value="\U0001F3A7 " + i18n("Generate"), variant="primary" ) text.input( - fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text] + fn=normalize_text, inputs=[text, normalize], outputs=[refined_text] ) - if_load_asr_model.change( - fn=change_if_load_asr_model, - inputs=[if_load_asr_model], - outputs=[if_load_asr_model], - ) - - if_auto_label.change( - fn=lambda: gr.Textbox(value=""), - inputs=[], - outputs=[reference_text], - ).then( - fn=change_if_auto_label, - inputs=[ - if_load_asr_model, - if_auto_label, - enable_reference_audio, - reference_audio, - reference_text, - ], - outputs=[reference_text], - ) - - def select_example_audio(audio_file): - if audio_file: - audio_path = os.path.join("examples", audio_file) - lab_file = os.path.splitext(audio_file)[0] + ".lab" - lab_path = os.path.join("examples", lab_file) - - if os.path.exists(lab_path): - with open(lab_path, "r", encoding="utf-8") as f: - lab_content = f.read().strip() - else: - lab_content = "" - - return audio_path, lab_content, True - return None, "", False - - # Connect the dropdown to update reference audio and text - example_audio_dropdown.change( - fn=select_example_audio, - inputs=[example_audio_dropdown], - outputs=[reference_audio, reference_text, enable_reference_audio] - ) - # # Submit + def inference_wrapper( + text, + normalize, + reference_id, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + seed, + use_memory_cache, + ): + references = [] + if reference_audio: + # 将文件路径转换为字节 + with open(reference_audio, 'rb') as audio_file: + audio_bytes = audio_file.read() + references = [ + ServeReferenceAudio(audio=audio_bytes, text=reference_text) + ] + + req = ServeTTSRequest( + text=text, + normalize=normalize, + reference_id=reference_id if reference_id else None, + references=references, + max_new_tokens=max_new_tokens, + chunk_length=chunk_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + seed=int(seed) if seed else None, + 
use_memory_cache=use_memory_cache, + ) + + for result in inference(req): + if result[2]: # Error message + return None, result[2] + elif result[1]: # Audio data + return result[1], None + + return None, i18n("No audio generated") + + # Submit generate.click( inference_wrapper, [ refined_text, - enable_reference_audio, + normalize, + reference_id, reference_audio, reference_text, max_new_tokens, @@ -580,26 +521,28 @@ def build_app(): top_p, repetition_penalty, temperature, - batch_infer_num, - if_load_asr_model, + seed, + use_memory_cache, ], - [stream_audio, *global_audio_list, *global_error_list], + [audio, error], concurrency_limit=1, ) + return app + def parse_args(): parser = ArgumentParser() parser.add_argument( "--llama-checkpoint-path", type=Path, - default="checkpoints/fish-speech-1.4", + default="checkpoints/fish-speech-1.5", ) parser.add_argument( "--decoder-checkpoint-path", type=Path, - default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", ) parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") parser.add_argument("--device", type=str, default="cuda") @@ -634,17 +577,20 @@ if __name__ == "__main__": # Dry run to check if the model is loaded correctly and avoid the first-time latency list( - inference( - text="Hello, world!", - enable_reference_audio=False, - reference_audio=None, - reference_text="", - max_new_tokens=0, - chunk_length=200, - top_p=0.7, - repetition_penalty=1.2, - temperature=0.7, - ) + inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=0, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.5, + temperature=0.7, + emotion=None, + format="wav", + ) + ) ) logger.info("Warming up done, launching the web UI...") diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py index bbcf3f33656d180ca87cd14a21ede1544e5a61a3..8ed09366e9ee3736d3653903ab4b708e828732a6 100644 --- a/fish_speech/callbacks/__init__.py +++ b/fish_speech/callbacks/__init__.py @@ -1,3 +1,3 @@ -from .grad_norm import GradNormMonitor - -__all__ = ["GradNormMonitor"] +from .grad_norm import GradNormMonitor + +__all__ = ["GradNormMonitor"] diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py index dbc95ef2a3723323b2d976001ed1e3c79c00b21a..64ca083b5565da69e8c1de59fd90ff94608fc4ff 100644 --- a/fish_speech/callbacks/grad_norm.py +++ b/fish_speech/callbacks/grad_norm.py @@ -1,113 +1,113 @@ -from typing import Optional, Union - -import lightning.pytorch as pl -import torch -from lightning import LightningModule, Trainer -from lightning.pytorch.callbacks import Callback -from torch import Tensor, nn -from torch.utils._foreach_utils import ( - _group_tensors_by_device_and_dtype, - _has_foreach_support, -) - - -@torch.no_grad() -def grad_norm( - parameters: Union[Tensor, list[Tensor]], - norm_type: float = 2.0, -) -> float: - """ - Returns the norm of the gradients of the given parameters. - - Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - norm_type (float): type of the used p-norm. - - Returns: - Total norm of the parameter gradients (viewed as a single vector). 
- """ # noqa: E501 - - if isinstance(parameters, Tensor): - parameters = [parameters] - - grads = [p.grad for p in parameters if p.grad is not None] - if len(grads) == 0: - return None - - first_device = grads[0].device - grouped_grads: dict[ - tuple[torch.device, torch.dtype], list[list[Tensor]] - ] = _group_tensors_by_device_and_dtype( - [[g.detach() for g in grads]] - ) # type: ignore[assignment] - - norms = [] - for (device, _), ([grads], _) in grouped_grads.items(): - if _has_foreach_support(grads, device=device): - norms.extend(torch._foreach_norm(grads, norm_type)) - else: - norms.extend([torch.norm(g, norm_type) for g in grads]) - - return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) - - -class GradNormMonitor(Callback): - """ - Callback that computes the gradient norm of the model parameters. - """ - - def __init__( - self, - norm_type: float = 2.0, - logging_interval: str = "step", - sub_module: Optional[Union[str, list[str]]] = None, - ) -> None: - """ - Args: - norm_type (float): type of the used p-norm. - logging_interval (str): "step" or "epoch". - """ - super().__init__() - - self.norm_type = norm_type - self.logging_interval = logging_interval - self.sub_module = sub_module - - def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: - """ - Computes the gradient norm of the model parameters and logs it to the logger. - - Args: - trainer (Trainer): The trainer object - model (LightningModule): The current lightningModule - """ - - lightning_model = model - - if self.sub_module is None: - return self.log_sub_module_grad_norm(lightning_model, model, "") - - sub_modules = self.sub_module - if isinstance(sub_modules, str): - sub_modules = [sub_modules] - - for sub_module in sub_modules: - self.log_sub_module_grad_norm( - lightning_model, getattr(model, sub_module), f"/{sub_module}" - ) - - def log_sub_module_grad_norm( - self, lightning_model: LightningModule, model: nn.Module, path: str - ) -> None: - grad_norm_val = grad_norm(model.parameters(), self.norm_type) - if grad_norm_val is None: - return - - on_step = self.logging_interval == "step" - lightning_model.log( - f"train{path}/grad_norm", - grad_norm_val, - on_step=on_step, - on_epoch=not on_step, - ) +from typing import Optional, Union + +import lightning.pytorch as pl +import torch +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from torch import Tensor, nn +from torch.utils._foreach_utils import ( + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def grad_norm( + parameters: Union[Tensor, list[Tensor]], + norm_type: float = 2.0, +) -> float: + """ + Returns the norm of the gradients of the given parameters. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. + + Returns: + Total norm of the parameter gradients (viewed as a single vector). 
+ """ # noqa: E501 + + if isinstance(parameters, Tensor): + parameters = [parameters] + + grads = [p.grad for p in parameters if p.grad is not None] + if len(grads) == 0: + return None + + first_device = grads[0].device + grouped_grads: dict[ + tuple[torch.device, torch.dtype], list[list[Tensor]] + ] = _group_tensors_by_device_and_dtype( + [[g.detach() for g in grads]] + ) # type: ignore[assignment] + + norms = [] + for (device, _), ([grads], _) in grouped_grads.items(): + if _has_foreach_support(grads, device=device): + norms.extend(torch._foreach_norm(grads, norm_type)) + else: + norms.extend([torch.norm(g, norm_type) for g in grads]) + + return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + +class GradNormMonitor(Callback): + """ + Callback that computes the gradient norm of the model parameters. + """ + + def __init__( + self, + norm_type: float = 2.0, + logging_interval: str = "step", + sub_module: Optional[Union[str, list[str]]] = None, + ) -> None: + """ + Args: + norm_type (float): type of the used p-norm. + logging_interval (str): "step" or "epoch". + """ + super().__init__() + + self.norm_type = norm_type + self.logging_interval = logging_interval + self.sub_module = sub_module + + def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: + """ + Computes the gradient norm of the model parameters and logs it to the logger. + + Args: + trainer (Trainer): The trainer object + model (LightningModule): The current lightningModule + """ + + lightning_model = model + + if self.sub_module is None: + return self.log_sub_module_grad_norm(lightning_model, model, "") + + sub_modules = self.sub_module + if isinstance(sub_modules, str): + sub_modules = [sub_modules] + + for sub_module in sub_modules: + self.log_sub_module_grad_norm( + lightning_model, getattr(model, sub_module), f"/{sub_module}" + ) + + def log_sub_module_grad_norm( + self, lightning_model: LightningModule, model: nn.Module, path: str + ) -> None: + grad_norm_val = grad_norm(model.parameters(), self.norm_type) + if grad_norm_val is None: + return + + on_step = self.logging_interval == "step" + lightning_model.log( + f"train{path}/grad_norm", + grad_norm_val, + on_step=on_step, + on_epoch=not on_step, + ) diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml index 99e6dab54d3f57bce4f6d29a9129a19a523cad75..b6bf1c8cf5a79e7ddb58ebc725f2c002c26c9486 100644 --- a/fish_speech/configs/base.yaml +++ b/fish_speech/configs/base.yaml @@ -1,87 +1,87 @@ -# Base configuration for training a model -paths: - run_dir: results/${project} - ckpt_dir: ${paths.run_dir}/checkpoints - -hydra: - run: - dir: ${paths.run_dir} - -# Lightning Trainer -trainer: - _target_: lightning.pytorch.trainer.Trainer - - default_root_dir: ${paths.run_dir} - accelerator: gpu - num_nodes: 1 - devices: auto - strategy: - _target_: lightning.pytorch.strategies.DDPStrategy - process_group_backend: nccl # This should be override when training on windows - - precision: bf16-mixed - - # disable validation by epoch end - check_val_every_n_epoch: null - val_check_interval: 5000 - max_steps: 100_000 - - # Use torch.backends.cudnn.benchmark to speed up training - benchmark: true - -# Callbacks -callbacks: - model_checkpoint: - _target_: lightning.pytorch.callbacks.ModelCheckpoint - dirpath: ${paths.ckpt_dir} - filename: "step_{step:09d}" - save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt - save_top_k: 5 # save 5 latest checkpoints - monitor: step # 
use step to monitor checkpoints - mode: max # save the latest checkpoint with the highest global_step - every_n_epochs: null # don't save checkpoints by epoch end - every_n_train_steps: 5000 # save checkpoints every 5000 steps - auto_insert_metric_name: false - - model_summary: - _target_: lightning.pytorch.callbacks.ModelSummary - max_depth: 2 # the maximum depth of layer nesting that the summary will include - - learning_rate_monitor: - _target_: lightning.pytorch.callbacks.LearningRateMonitor - logging_interval: step - log_momentum: false - - grad_norm_monitor: - _target_: fish_speech.callbacks.GradNormMonitor - norm_type: 2 - logging_interval: step - -# Logger -logger: - tensorboard: - _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger - save_dir: "${paths.run_dir}/tensorboard/" - name: null - log_graph: false - default_hp_metric: true - prefix: "" - - # wandb: - # _target_: lightning.pytorch.loggers.wandb.WandbLogger - # # name: "" # name of the run (normally generated by wandb) - # save_dir: "${paths.run_dir}" - # offline: False - # id: null # pass correct id to resume experiment! - # anonymous: null # enable anonymous logging - # project: "fish-speech" - # log_model: False # upload lightning ckpts - # prefix: "" # a string to put at the beginning of metric keys - # # entity: "" # set to name of your wandb team - # group: "" - # tags: ["vq", "hq", "finetune"] - # job_type: "" - -# Loop -train: true -test: false +# Base configuration for training a model +paths: + run_dir: results/${project} + ckpt_dir: ${paths.run_dir}/checkpoints + +hydra: + run: + dir: ${paths.run_dir} + +# Lightning Trainer +trainer: + _target_: lightning.pytorch.trainer.Trainer + + default_root_dir: ${paths.run_dir} + accelerator: gpu + num_nodes: 1 + devices: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + process_group_backend: nccl # This should be override when training on windows + + precision: bf16-mixed + + # disable validation by epoch end + check_val_every_n_epoch: null + val_check_interval: 5000 + max_steps: 100_000 + + # Use torch.backends.cudnn.benchmark to speed up training + benchmark: true + +# Callbacks +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.ckpt_dir} + filename: "step_{step:09d}" + save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save 5 latest checkpoints + monitor: step # use step to monitor checkpoints + mode: max # save the latest checkpoint with the highest global_step + every_n_epochs: null # don't save checkpoints by epoch end + every_n_train_steps: 5000 # save checkpoints every 5000 steps + auto_insert_metric_name: false + + model_summary: + _target_: lightning.pytorch.callbacks.ModelSummary + max_depth: 2 # the maximum depth of layer nesting that the summary will include + + learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: step + log_momentum: false + + grad_norm_monitor: + _target_: fish_speech.callbacks.GradNormMonitor + norm_type: 2 + logging_interval: step + +# Logger +logger: + tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.run_dir}/tensorboard/" + name: null + log_graph: false + default_hp_metric: true + prefix: "" + + # wandb: + # _target_: lightning.pytorch.loggers.wandb.WandbLogger + # # name: "" # name of the run (normally generated by wandb) + # save_dir: "${paths.run_dir}" + # offline: False + # 
id: null # pass correct id to resume experiment! + # anonymous: null # enable anonymous logging + # project: "fish-speech" + # log_model: False # upload lightning ckpts + # prefix: "" # a string to put at the beginning of metric keys + # # entity: "" # set to name of your wandb team + # group: "" + # tags: ["vq", "hq", "finetune"] + # job_type: "" + +# Loop +train: true +test: false diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml index 10aa8d4a522f0859ed8f541f5d48672d84b39c8f..a6a2521e8b1c6663cfccf6d7d6f8a593f939e043 100644 --- a/fish_speech/configs/firefly_gan_vq.yaml +++ b/fish_speech/configs/firefly_gan_vq.yaml @@ -1,33 +1,33 @@ -_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture -spec_transform: - _target_: fish_speech.utils.spectrogram.LogMelSpectrogram - sample_rate: 44100 - n_mels: 160 - n_fft: 2048 - hop_length: 512 - win_length: 2048 -backbone: - _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder - input_channels: 160 - depths: [3, 3, 9, 3] - dims: [128, 256, 384, 512] - drop_path_rate: 0.2 - kernel_size: 7 -head: - _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator - hop_length: 512 - upsample_rates: [8, 8, 2, 2, 2] # aka. strides - upsample_kernel_sizes: [16, 16, 4, 4, 4] - resblock_kernel_sizes: [3, 7, 11] - resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] - num_mels: 512 - upsample_initial_channel: 512 - pre_conv_kernel_size: 13 - post_conv_kernel_size: 13 -quantizer: - _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize - input_dim: 512 - n_groups: 8 - n_codebooks: 1 - levels: [8, 5, 5, 5] - downsample_factor: [2, 2] +_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture +spec_transform: + _target_: fish_speech.utils.spectrogram.LogMelSpectrogram + sample_rate: 44100 + n_mels: 160 + n_fft: 2048 + hop_length: 512 + win_length: 2048 +backbone: + _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder + input_channels: 160 + depths: [3, 3, 9, 3] + dims: [128, 256, 384, 512] + drop_path_rate: 0.2 + kernel_size: 7 +head: + _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator + hop_length: 512 + upsample_rates: [8, 8, 2, 2, 2] # aka. 
strides + upsample_kernel_sizes: [16, 16, 4, 4, 4] + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + num_mels: 512 + upsample_initial_channel: 512 + pre_conv_kernel_size: 13 + post_conv_kernel_size: 13 +quantizer: + _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize + input_dim: 512 + n_groups: 8 + n_codebooks: 1 + levels: [8, 5, 5, 5] + downsample_factor: [2, 2] diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml index aecc4d9766a18fe31c55941e01b1f590c95e77c9..0bb13622fc9dcf3ea59fd9db215536946fb346fa 100644 --- a/fish_speech/configs/lora/r_8_alpha_16.yaml +++ b/fish_speech/configs/lora/r_8_alpha_16.yaml @@ -1,4 +1,4 @@ -_target_: fish_speech.models.text2semantic.lora.LoraConfig -r: 8 -lora_alpha: 16 -lora_dropout: 0.01 +_target_: fish_speech.models.text2semantic.lora.LoraConfig +r: 8 +lora_alpha: 16 +lora_dropout: 0.01 diff --git a/fish_speech/configs/model/dual_ar_2_codebook_large.yaml b/fish_speech/configs/model/dual_ar_2_codebook_large.yaml deleted file mode 100644 index d4504d8bba62009710ec7761b4dcc87f3172be01..0000000000000000000000000000000000000000 --- a/fish_speech/configs/model/dual_ar_2_codebook_large.yaml +++ /dev/null @@ -1,9 +0,0 @@ -defaults: - - dual_ar_2_codebook_small - - _self_ - -config: - n_layer: 30 - n_fast_layer: 6 - n_head: 24 - dim: 1536 diff --git a/fish_speech/configs/model/dual_ar_2_codebook_medium.yaml b/fish_speech/configs/model/dual_ar_2_codebook_medium.yaml deleted file mode 100644 index 0ad6c2a10ac82452e33685d08c188d6c1e735678..0000000000000000000000000000000000000000 --- a/fish_speech/configs/model/dual_ar_2_codebook_medium.yaml +++ /dev/null @@ -1,9 +0,0 @@ -defaults: - - dual_ar_2_codebook_small - - _self_ - -config: - n_layer: 24 - n_fast_layer: 6 - n_head: 16 - dim: 1024 diff --git a/fish_speech/configs/model/dual_ar_2_codebook_small.yaml b/fish_speech/configs/model/dual_ar_2_codebook_small.yaml deleted file mode 100644 index a56974083d1d247629b2d83cb81a33838c70c22d..0000000000000000000000000000000000000000 --- a/fish_speech/configs/model/dual_ar_2_codebook_small.yaml +++ /dev/null @@ -1,13 +0,0 @@ -_target_: fish_speech.models.text2semantic.llama.DualARTransformer -config: - _target_: fish_speech.models.text2semantic.llama.DualARModelArgs - max_seq_len: ${max_length} - vocab_size: 264 # pad 262 to 8x - n_layer: 12 - n_fast_layer: 4 - n_head: 12 - dim: 768 - rope_base: 10000 - norm_eps: 1e-5 - num_codebooks: 2 # input/output codebook size - codebook_size: 1032 # codebook size 1024 + 2 special tokens diff --git a/fish_speech/configs/model/naive_2_codebook_small.yaml b/fish_speech/configs/model/naive_2_codebook_small.yaml deleted file mode 100644 index 16d1737c90c9c4f88587202ea7e7bb5bd741b30f..0000000000000000000000000000000000000000 --- a/fish_speech/configs/model/naive_2_codebook_small.yaml +++ /dev/null @@ -1,12 +0,0 @@ -_target_: fish_speech.models.text2semantic.llama.NaiveTransformer -config: - _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs - max_seq_len: ${max_length} - vocab_size: 36408 - n_layer: 12 - n_head: 12 - dim: 768 - rope_base: 10000 - norm_eps: 1e-5 - num_codebooks: 2 # input/output codebook size - codebook_size: 1032 # codebook size 1024 + 2 special tokens diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml index f4c1993023099e122fc9e004bda55ec075ed5e1b..8c999411703e85c7ad5972134f7a33f42b279571 100644 --- 
a/fish_speech/configs/text2semantic_finetune.yaml +++ b/fish_speech/configs/text2semantic_finetune.yaml @@ -1,83 +1,83 @@ -defaults: - - base - - _self_ - -project: text2semantic_finetune_dual_ar -max_length: 4096 -pretrained_ckpt_path: checkpoints/fish-speech-1.4 - -# Lightning Trainer -trainer: - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - gradient_clip_algorithm: "norm" - max_steps: 1000 - precision: bf16-true - limit_val_batches: 10 - val_check_interval: 100 - -# Dataset Configuration -tokenizer: - _target_: transformers.AutoTokenizer.from_pretrained - pretrained_model_name_or_path: ${pretrained_ckpt_path} - -# Dataset Configuration -train_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset - proto_files: - - data/protos - tokenizer: ${tokenizer} - causal: true - max_length: ${max_length} - use_speaker: false - interactive_prob: 0.7 - -val_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset - proto_files: - - data/protos - tokenizer: ${tokenizer} - causal: true - max_length: ${max_length} - use_speaker: false - interactive_prob: 0.7 - -data: - _target_: fish_speech.datasets.semantic.SemanticDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 8 - tokenizer: ${tokenizer} - max_length: ${max_length} - -# Model Configuration -model: - _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic - model: - _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained - path: ${pretrained_ckpt_path} - load_weights: true - max_length: ${max_length} - lora_config: null - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 1e-4 - weight_decay: 0 - betas: [0.9, 0.95] - eps: 1e-5 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 10 - -# Callbacks -callbacks: - model_checkpoint: - every_n_train_steps: ${trainer.val_check_interval} +defaults: + - base + - _self_ + +project: text2semantic_finetune_dual_ar +max_length: 4096 +pretrained_ckpt_path: checkpoints/fish-speech-1.4 + +# Lightning Trainer +trainer: + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" + max_steps: 1000 + precision: bf16-true + limit_val_batches: 10 + val_check_interval: 100 + +# Dataset Configuration +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: ${pretrained_ckpt_path} + +# Dataset Configuration +train_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +val_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +data: + _target_: fish_speech.datasets.semantic.SemanticDataModule + train_dataset: ${train_dataset} + val_dataset: ${val_dataset} + num_workers: 4 + batch_size: 8 + tokenizer: ${tokenizer} + max_length: ${max_length} + +# Model Configuration +model: + _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic + model: + _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained + path: ${pretrained_ckpt_path} + 
load_weights: true + max_length: ${max_length} + lora_config: null + + optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + weight_decay: 0 + betas: [0.9, 0.95] + eps: 1e-5 + + lr_scheduler: + _target_: torch.optim.lr_scheduler.LambdaLR + _partial_: true + lr_lambda: + _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda + _partial_: true + num_warmup_steps: 10 + +# Callbacks +callbacks: + model_checkpoint: + every_n_train_steps: ${trainer.val_check_interval} diff --git a/fish_speech/configs/text2semantic_finetune_lora.yaml b/fish_speech/configs/text2semantic_finetune_lora.yaml deleted file mode 100644 index 36da21ceac8fd9e0e684b5863d8200156b53e0de..0000000000000000000000000000000000000000 --- a/fish_speech/configs/text2semantic_finetune_lora.yaml +++ /dev/null @@ -1,13 +0,0 @@ -defaults: - - text2semantic_finetune - - _self_ - -project: text2semantic_finetune_dual_ar_lora - -# Model Configuration -model: - save_lora_only: true - lora_config: - _target_: fish_speech.models.text2semantic.lit_module.LoraConfig - r: 8 - lora_alpha: 16 diff --git a/fish_speech/configs/text2semantic_pretrain.yaml b/fish_speech/configs/text2semantic_pretrain.yaml deleted file mode 100644 index 98983f4c2417028f980d24402aaecae39049d295..0000000000000000000000000000000000000000 --- a/fish_speech/configs/text2semantic_pretrain.yaml +++ /dev/null @@ -1,74 +0,0 @@ -defaults: - - base - - model@model.model: dual_ar_2_codebook_small - - _self_ - -project: text2semantic_pretrain_dual_ar_debug -max_length: 2048 - -# Lightning Trainer -trainer: - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - gradient_clip_algorithm: 'norm' - max_steps: 1_000_000 - precision: bf16-true - limit_val_batches: 10 - -# Dataset Configuration -tokenizer: - _target_: transformers.AutoTokenizer.from_pretrained - pretrained_model_name_or_path: fishaudio/fish-speech-1 - -# Dataset Configuration -train_dataset: - _target_: fish_speech.datasets.text.AutoAugTextDataset - proto_files: - - data/protos/train - tokenizer: ${tokenizer} - max_length: ${max_length} - num_codebooks: ${model.model.config.num_codebooks} - use_speaker: false - interactive_prob: 0.5 - -val_dataset: - _target_: fish_speech.datasets.text.AutoAugTextDataset - proto_files: - - data/protos/test - tokenizer: ${tokenizer} - max_length: ${max_length} - num_codebooks: ${model.model.config.num_codebooks} - use_speaker: false - interactive_prob: 0.5 - -data: - _target_: fish_speech.datasets.text.TextDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 8 - tokenizer: ${tokenizer} - max_length: ${max_length} - -# Model Configuration -model: - _target_: fish_speech.models.text2semantic.TextToSemantic - model: {} - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 3e-4 - weight_decay: 0.01 - betas: [0.9, 0.95] - eps: 1e-5 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 2000 - num_training_steps: ${trainer.max_steps} - final_lr_ratio: 0.1 diff --git a/fish_speech/configs/text2semantic_sft.yaml b/fish_speech/configs/text2semantic_sft.yaml deleted file mode 100644 index 9a3cf2675fc1d5f8feb3355b7e9c914a8f6254f0..0000000000000000000000000000000000000000 --- a/fish_speech/configs/text2semantic_sft.yaml +++ /dev/null @@ -1,87 +0,0 @@ -defaults: - - base - - model@model.model: dual_ar_8_codebook_small - - _self_ - -project: 
text2semantic_sft_medium_dual_ar -max_length: 4096 -ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt -resume_weights_only: true - -# Lightning Trainer -trainer: - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - gradient_clip_algorithm: 'norm' - max_steps: 10_000 - precision: bf16-true - limit_val_batches: 10 - val_check_interval: 500 - -# Dataset Configuration -tokenizer: - _target_: transformers.AutoTokenizer.from_pretrained - pretrained_model_name_or_path: fishaudio/speech-lm-v1 - -# Dataset Configuration -train_dataset: - _target_: fish_speech.datasets.text.AutoAugTextDataset - use_data_server: false - proto_files: - - data/protos/sft/train_Genshin.protos - - data/protos/sft/sft.protos - tokenizer: ${tokenizer} - max_length: ${max_length} - num_codebooks: ${model.model.config.num_codebooks} - use_speaker: false - phones_prob: 0.5 - interactive_prob: 0.5 - -val_dataset: - _target_: fish_speech.datasets.text.AutoAugTextDataset - use_data_server: false - proto_files: - - data/protos/sft/val_Genshin.protos - tokenizer: ${tokenizer} - max_length: ${max_length} - num_codebooks: ${model.model.config.num_codebooks} - use_speaker: false - phones_prob: 0.5 - interactive_prob: 0.5 - -data: - _target_: fish_speech.datasets.text.TextDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 8 - tokenizer: ${tokenizer} - max_length: ${max_length} - -# Model Configuration -model: - _target_: fish_speech.models.text2semantic.TextToSemantic - model: {} - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 4e-5 - weight_decay: 0 - betas: [0.9, 0.95] - eps: 1e-5 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 100 - num_training_steps: ${trainer.max_steps} - final_lr_ratio: 0 - -callbacks: - model_checkpoint: - every_n_train_steps: 1000 - save_top_k: 10 diff --git a/fish_speech/configs/vqgan_finetune.yaml b/fish_speech/configs/vqgan_finetune.yaml deleted file mode 100644 index 138ab975b718e6cd93c92cd91c5b0bd3c6ad206e..0000000000000000000000000000000000000000 --- a/fish_speech/configs/vqgan_finetune.yaml +++ /dev/null @@ -1,135 +0,0 @@ -defaults: - - base - - _self_ - -project: vq-gan-finetune -ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth -resume_weights_only: true - -# Lightning Trainer -trainer: - accelerator: gpu - devices: auto - precision: bf16-mixed - max_steps: 100_000 - val_check_interval: 5000 - strategy: ddp_find_unused_parameters_true - -sample_rate: 44100 -hop_length: 512 -num_mels: 128 -n_fft: 2048 -win_length: 2048 -freeze_encoder: true - -# Dataset Configuration -train_dataset: - _target_: fish_speech.datasets.vqgan.VQGANDataset - filelist: data/filelist.train.txt - sample_rate: ${sample_rate} - hop_length: ${hop_length} - slice_frames: 512 - -val_dataset: - _target_: fish_speech.datasets.vqgan.VQGANDataset - filelist: data/filelist.val.txt - sample_rate: ${sample_rate} - hop_length: ${hop_length} - -data: - _target_: fish_speech.datasets.vqgan.VQGANDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 16 - val_batch_size: 16 - -# Model Configuration -model: - _target_: fish_speech.models.vqgan.VQGAN - - sampling_rate: ${sample_rate} - weight_adv: 0.2 - weight_vq: 1.0 - weight_mel: 1.0 - freeze_encoder: false - - encoder: - _target_: 
fish_speech.models.vqgan.modules.wavenet.WaveNet - input_channels: ${num_mels} - residual_channels: 768 - residual_layers: 20 - dilation_cycle: 4 - - quantizer: - _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize - input_dim: 768 - n_codebooks: 1 - n_groups: 2 - levels: [8, 5, 5, 5] - - decoder: - _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet - output_channels: ${num_mels} - residual_channels: 768 - residual_layers: 20 - dilation_cycle: 4 - condition_channels: 768 - - discriminator: - _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator - - vocoder: - _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase - ckpt_path: null # You may download the pretrained vocoder and set the path here - - encode_mel_transform: - _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram - sample_rate: ${sample_rate} - n_fft: ${n_fft} - hop_length: ${hop_length} - win_length: ${win_length} - n_mels: ${num_mels} - f_min: 0.0 - f_max: 8000.0 - - gt_mel_transform: - _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram - sample_rate: ${sample_rate} - n_fft: ${n_fft} - hop_length: ${hop_length} - win_length: ${win_length} - n_mels: ${num_mels} - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 4e-5 - betas: [0.8, 0.99] - eps: 1e-5 - weight_decay: 0.01 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 100 - num_training_steps: ${trainer.max_steps} - final_lr_ratio: 0 - -callbacks: - model_summary: - _target_: lightning.pytorch.callbacks.ModelSummary - max_depth: 1 - - model_checkpoint: - every_n_train_steps: ${trainer.val_check_interval} - - grad_norm_monitor: - sub_module: - - encoder - - decoder - - quantizer - - discriminator diff --git a/fish_speech/configs/vqgan_pretrain.yaml b/fish_speech/configs/vqgan_pretrain.yaml deleted file mode 100644 index 271b97dd0ddb69e343557751d74a529005e35667..0000000000000000000000000000000000000000 --- a/fish_speech/configs/vqgan_pretrain.yaml +++ /dev/null @@ -1,139 +0,0 @@ -defaults: - - base - - _self_ - -project: vq-gan-pretrain - -# Lightning Trainer -trainer: - accelerator: gpu - devices: auto - precision: bf16-mixed - max_steps: 1_000_000 - val_check_interval: 5000 - strategy: ddp_find_unused_parameters_true - -sample_rate: 44100 -hop_length: 512 -num_mels: 128 -n_fft: 2048 -win_length: 2048 - -# Dataset Configuration -train_dataset: - _target_: torch.utils.data.ConcatDataset - datasets: - - _target_: fish_speech.datasets.vqgan.VQGANDataset - filelist: data/gigaspeech/vq_train_filelist.txt - sample_rate: ${sample_rate} - hop_length: ${hop_length} - slice_frames: 512 - - _target_: fish_speech.datasets.vqgan.VQGANDataset - filelist: data/sft/vq_train_filelist.txt - sample_rate: ${sample_rate} - hop_length: ${hop_length} - slice_frames: 512 - -val_dataset: - _target_: fish_speech.datasets.vqgan.VQGANDataset - filelist: data/sft/vq_val_filelist.txt - sample_rate: ${sample_rate} - hop_length: ${hop_length} - -data: - _target_: fish_speech.datasets.vqgan.VQGANDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 32 - val_batch_size: 32 - -# Model Configuration -model: - _target_: fish_speech.models.vqgan.VQGAN - - sampling_rate: ${sample_rate} - weight_adv: 0.2 - weight_vq: 1.0 - weight_mel: 1.0 - freeze_encoder: false - - encoder: - _target_: 
fish_speech.models.vqgan.modules.wavenet.WaveNet - input_channels: ${num_mels} - residual_channels: 768 - residual_layers: 20 - dilation_cycle: 4 - - quantizer: - _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize - input_dim: 768 - n_codebooks: 1 - n_groups: 2 - levels: [8, 5, 5, 5] - - decoder: - _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet - output_channels: ${num_mels} - residual_channels: 768 - residual_layers: 20 - dilation_cycle: 4 - condition_channels: 768 - - discriminator: - _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator - - vocoder: - _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase - ckpt_path: null # You may download the pretrained vocoder and set the path here - - encode_mel_transform: - _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram - sample_rate: ${sample_rate} - n_fft: ${n_fft} - hop_length: ${hop_length} - win_length: ${win_length} - n_mels: ${num_mels} - f_min: 0.0 - f_max: 8000.0 - - gt_mel_transform: - _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram - sample_rate: ${sample_rate} - n_fft: ${n_fft} - hop_length: ${hop_length} - win_length: ${win_length} - n_mels: ${num_mels} - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 1e-4 - betas: [0.8, 0.99] - eps: 1e-5 - weight_decay: 0.01 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 100 - num_training_steps: ${trainer.max_steps} - final_lr_ratio: 0 - -callbacks: - model_summary: - _target_: lightning.pytorch.callbacks.ModelSummary - max_depth: 1 - - model_checkpoint: - every_n_train_steps: ${trainer.val_check_interval} - - grad_norm_monitor: - sub_module: - - encoder - - decoder - - quantizer - - discriminator diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py index c9ca0ef9181754eda7e6b49e01abeafbe07fb00f..b9b648ff62e953317551abf73f9665bda3ccf3c1 100644 --- a/fish_speech/conversation.py +++ b/fish_speech/conversation.py @@ -1,2 +1,267 @@ -SEMANTIC_TOKEN = "<|semantic|>" -CODEBOOK_PAD_TOKEN_ID = 0 +from dataclasses import dataclass, field +from typing import Literal + +import torch + +from .tokenizer import MODALITY_TOKENS, FishTokenizer + +CODEBOOK_PAD_TOKEN_ID = 0 + + +@dataclass(kw_only=True) +class BasePart: + pass + + +@dataclass(kw_only=True) +class VQPart(BasePart): + codes: torch.Tensor + + +@dataclass(kw_only=True) +class TextPart(BasePart): + text: str + + +@dataclass(kw_only=True) +class EncodedMessage: + tokens: torch.Tensor + labels: torch.Tensor + vq_mask_tokens: torch.Tensor | None = None + vq_mask_labels: torch.Tensor | None = None + vq_parts: list[torch.Tensor] + vq_require_losses: torch.Tensor | None = None + + +@dataclass(kw_only=True) +class Message: + role: Literal["system", "user", "assistant"] + parts: list[VQPart | TextPart] = field(default_factory=list) + add_im_start: bool = True + add_im_end: bool = True + cal_loss: bool = False + modality: Literal["text", "voice", "interleave"] | None = None + + # By default, ignore the loss of the auto-generated im_start token + ignore_im_start_loss: bool = True + + def encode( + self: "Message", + tokenizer: FishTokenizer, + ) -> EncodedMessage: + all_tokens = [] + all_labels = [] + + # Multi-modal tokens + vq_parts = [] + vq_masks = [] + + parts = self.parts.copy() + if self.add_im_start: + modality_token = MODALITY_TOKENS[self.modality] if 
self.modality else "" + parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}")) + + if self.add_im_end: + parts.append(TextPart(text="<|im_end|>")) + + for part in parts: + if isinstance(part, TextPart): + tokens = torch.tensor( + tokenizer.encode(part.text), + dtype=torch.int, + ) + elif isinstance(part, VQPart): + curr_codes = part.codes.clone() + tokens = torch.tensor( + [ + tokenizer.semantic_id_to_token_id[i.item()] + for i in curr_codes[0].int() + ], + dtype=torch.int, + ) + vq_parts.append(curr_codes) + else: + raise ValueError(f"Unsupported part type: {type(part)}") + + all_tokens.append(tokens) + if isinstance(part, VQPart): + vq_masks.append(torch.ones_like(tokens, dtype=torch.bool)) + else: + vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) + + if self.cal_loss: + all_labels.append(tokens.clone()) + else: + all_labels.append(torch.full_like(tokens, -100)) + + tokens = torch.cat(all_tokens, dim=0) + labels = torch.cat(all_labels, dim=0) + vq_masks = torch.cat(vq_masks, dim=0) + + assert tokens.shape == labels.shape == vq_masks.shape + + if self.ignore_im_start_loss and self.add_im_start: + labels[: len(all_tokens[0])] = -100 + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + vq_mask_tokens=vq_masks, + vq_mask_labels=vq_masks, + ) + + +@dataclass +class Conversation: + messages: list[Message] + + def __init__(self: "Conversation", messages: list[Message] | None = None): + self.messages = messages or [] + + def encode( + self: "Conversation", + tokenizer: FishTokenizer, + add_shift: bool = True, + ignore_loss_tokens: list[str] = [], + ) -> EncodedMessage: + # Build the input_ids and labels + tokens = [] + labels = [] + vq_parts = [] + vq_mask_tokens = [] + vq_mask_labels = [] + vq_require_losses = [] + ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens] + + for message in self.messages: + encoded = message.encode( + tokenizer, + ) + tokens.append(encoded.tokens) + labels.append(encoded.labels) + vq_parts.extend(encoded.vq_parts) + vq_mask_tokens.append(encoded.vq_mask_tokens) + vq_mask_labels.append(encoded.vq_mask_labels) + vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts)) + + tokens = torch.cat(tokens, dim=0) + labels = torch.cat(labels, dim=0) + vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0) + vq_mask_labels = torch.cat(vq_mask_labels, dim=0) + vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) + + if add_shift: + tokens = tokens[:-1] + labels = labels[1:] + vq_mask_tokens = vq_mask_tokens[:-1] + vq_mask_labels = vq_mask_labels[1:] + + for i in ignore_loss_token_ids: + assert i != -100 and i is not None + labels[labels == i] = -100 + + assert tokens.dtype in [ + torch.int, + torch.long, + ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}" + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + vq_mask_tokens=vq_mask_tokens, + vq_mask_labels=vq_mask_labels, + vq_require_losses=vq_require_losses, + ) + + def encode_for_inference( + self: "Conversation", + tokenizer: FishTokenizer, + num_codebooks: int, + ) -> EncodedMessage: + # self.visualize(tokenizer) + + encoded = self.encode(tokenizer, add_shift=False) + tokens = encoded.tokens + values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) + values[0] = tokens + + if encoded.vq_parts is None or len(encoded.vq_parts) == 0: + return values + + vq_parts = encoded.vq_parts + vq_parts = [part.to(values.device) for part in vq_parts] + vq_parts = 
torch.cat(vq_parts, dim=1) + values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id + values[1:, encoded.vq_mask_tokens] = vq_parts + + return values + + def visualize( + self: "Conversation", + tokenizer: FishTokenizer, + ignore_loss_tokens: list[str] = [], + ): + encoded = self.encode( + tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens + ) + + # Colors for alternating tokens + colors = { + "blue": "\033[94m", # Light blue + "cyan": "\033[96m", # Cyan + "green": "\033[92m", # Light green + "dark_green": "\033[32m", # Dark green + } + blue_idx = 0 + green_idx = 0 + + def print_in_blue(x): + nonlocal blue_idx + color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"] + print(f"{color}{x}\033[0m", end="") + blue_idx += 1 + + def print_in_green(x): + nonlocal green_idx + color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"] + print(f"{color}{x}\033[0m", end="") + green_idx += 1 + + for tok, lab in zip(encoded.tokens, encoded.labels): + val = tokenizer.decode([tok]) + + if lab == -100: + print_in_green(val) + else: + print_in_blue(val) + + print() + + def append(self: "Conversation", message: Message): + self.messages.append(message) + + +if __name__ == "__main__": + message0 = Message( + role="user", + parts=[ + TextPart(text="Hello, how are you?"), + VQPart(codes=torch.zeros((4, 10))), + ], + cal_loss=False, + ) + + message1 = Message( + role="assistant", + parts=[TextPart(text="I'm fine, thank you.")], + cal_loss=True, + ) + conversation = Conversation([message0, message1]) + tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct") + conversation.visualize(tokenizer) + + encoded = conversation.encode(tokenizer) + print(encoded) + print(tokenizer.batch_decode(encoded.tokens)) diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py index 4aa596b95a572ee15c5570cbdb792c9a78e62dfa..b74a0fc93119e45405ff89a299324393ae7de57b 100644 --- a/fish_speech/datasets/concat_repeat.py +++ b/fish_speech/datasets/concat_repeat.py @@ -1,53 +1,53 @@ -import bisect -import random -from typing import Iterable - -from torch.utils.data import Dataset, IterableDataset - - -class ConcatRepeatDataset(Dataset): - datasets: list[Dataset] - cumulative_sizes: list[int] - repeats: list[int] - - @staticmethod - def cumsum(sequence, repeats): - r, s = [], 0 - for dataset, repeat in zip(sequence, repeats): - l = len(dataset) * repeat - r.append(l + s) - s += l - return r - - def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): - super().__init__() - - self.datasets = list(datasets) - self.repeats = repeats - - assert len(self.datasets) > 0, "datasets should not be an empty iterable" - assert len(self.datasets) == len( - repeats - ), "datasets and repeats should have the same length" - - for d in self.datasets: - assert not isinstance( - d, IterableDataset - ), "ConcatRepeatDataset does not support IterableDataset" - - self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) - - def __len__(self): - return self.cumulative_sizes[-1] - - def __getitem__(self, idx): - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - - if dataset_idx == 0: - sample_idx = idx - else: - sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] - - dataset = self.datasets[dataset_idx] - - return dataset[sample_idx % len(dataset)] +import bisect +import random +from typing import Iterable + +from torch.utils.data import Dataset, IterableDataset + + +class ConcatRepeatDataset(Dataset): + datasets: 
list[Dataset] + cumulative_sizes: list[int] + repeats: list[int] + + @staticmethod + def cumsum(sequence, repeats): + r, s = [], 0 + for dataset, repeat in zip(sequence, repeats): + l = len(dataset) * repeat + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): + super().__init__() + + self.datasets = list(datasets) + self.repeats = repeats + + assert len(self.datasets) > 0, "datasets should not be an empty iterable" + assert len(self.datasets) == len( + repeats + ), "datasets and repeats should have the same length" + + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatRepeatDataset does not support IterableDataset" + + self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + dataset = self.datasets[dataset_idx] + + return dataset[sample_idx % len(dataset)] diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto index 5eb26d94aa3be1e21066f2bf38c90d54e85a8379..97ce77f8f2b10a7be81e20ed69f0bcd098abe3f3 100644 --- a/fish_speech/datasets/protos/text-data.proto +++ b/fish_speech/datasets/protos/text-data.proto @@ -1,24 +1,24 @@ -syntax = "proto3"; - -package text_data; - -message Semantics { - repeated uint32 values = 1; -} - -message Sentence { - repeated string texts = 1; - repeated Semantics semantics = 3; -} - -message TextData { - string source = 1; - string name = 2; - repeated Sentence sentences = 4; -} - -message SampledData { - string source = 1; - string name = 2; - repeated Sentence samples = 3; -} +syntax = "proto3"; + +package text_data; + +message Semantics { + repeated uint32 values = 1; +} + +message Sentence { + repeated string texts = 1; + repeated Semantics semantics = 3; +} + +message TextData { + string source = 1; + string name = 2; + repeated Sentence sentences = 4; +} + +message SampledData { + string source = 1; + string name = 2; + repeated Sentence samples = 3; +} diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py index bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e..4d0091190426115d7a3e3ffd41ad8ab0ebdada73 100644 --- a/fish_speech/datasets/protos/text_data_pb2.py +++ b/fish_speech/datasets/protos/text_data_pb2.py @@ -1,33 +1,33 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: text-data.proto -# Protobuf Python Version: 4.25.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' -) - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_SEMANTICS"]._serialized_start = 30 - _globals["_SEMANTICS"]._serialized_end = 57 - _globals["_SENTENCE"]._serialized_start = 59 - _globals["_SENTENCE"]._serialized_end = 125 - _globals["_TEXTDATA"]._serialized_start = 127 - _globals["_TEXTDATA"]._serialized_end = 207 - _globals["_SAMPLEDDATA"]._serialized_start = 209 - _globals["_SAMPLEDDATA"]._serialized_end = 290 -# @@protoc_insertion_point(module_scope) +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: text-data.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_SEMANTICS"]._serialized_start = 30 + _globals["_SEMANTICS"]._serialized_end = 57 + _globals["_SENTENCE"]._serialized_start = 59 + _globals["_SENTENCE"]._serialized_end = 125 + _globals["_TEXTDATA"]._serialized_start = 127 + _globals["_TEXTDATA"]._serialized_end = 207 + _globals["_SAMPLEDDATA"]._serialized_start = 209 + _globals["_SAMPLEDDATA"]._serialized_end = 290 +# @@protoc_insertion_point(module_scope) diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py index 
ec3c25bcd764e8245de47dcdf9686d6adfb5a107..03767b8ce11a51f19c91ab17496bf57df51cec45 100644 --- a/fish_speech/datasets/protos/text_data_stream.py +++ b/fish_speech/datasets/protos/text_data_stream.py @@ -1,36 +1,36 @@ -import struct - -from .text_data_pb2 import TextData - - -def read_pb_stream(f): - while True: - buf = f.read(4) - if len(buf) == 0: - break - size = struct.unpack("I", buf)[0] - buf = f.read(size) - text_data = TextData() - text_data.ParseFromString(buf) - yield text_data - - -def write_pb_stream(f, text_data): - buf = text_data.SerializeToString() - f.write(struct.pack("I", len(buf))) - f.write(buf) - - -def pack_pb_stream(text_data): - buf = text_data.SerializeToString() - return struct.pack("I", len(buf)) + buf - - -def split_pb_stream(f): - while True: - head = f.read(4) - if len(head) == 0: - break - size = struct.unpack("I", head)[0] - buf = f.read(size) - yield head + buf +import struct + +from .text_data_pb2 import TextData + + +def read_pb_stream(f): + while True: + buf = f.read(4) + if len(buf) == 0: + break + size = struct.unpack("I", buf)[0] + buf = f.read(size) + text_data = TextData() + text_data.ParseFromString(buf) + yield text_data + + +def write_pb_stream(f, text_data): + buf = text_data.SerializeToString() + f.write(struct.pack("I", len(buf))) + f.write(buf) + + +def pack_pb_stream(text_data): + buf = text_data.SerializeToString() + return struct.pack("I", len(buf)) + buf + + +def split_pb_stream(f): + while True: + head = f.read(4) + if len(head) == 0: + break + size = struct.unpack("I", head)[0] + buf = f.read(size) + yield head + buf diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py index 3c64e01077ae253bdc4e4d9cd948f8fb50df7418..0fb00c6ab1516621e961580393e81b369082276f 100644 --- a/fish_speech/datasets/semantic.py +++ b/fish_speech/datasets/semantic.py @@ -1,496 +1,496 @@ -import random -from dataclasses import dataclass -from itertools import chain -from pathlib import Path -from random import Random -from typing import Optional, Union - -import numpy as np -import pyarrow.parquet as pq -import torch -import torch.nn.functional as F -from datasets.download.streaming_download_manager import xopen -from huggingface_hub import HfApi -from lightning import LightningDataModule -from torch.distributed import get_rank, get_world_size, is_initialized -from torch.utils.data import DataLoader, IterableDataset, get_worker_info -from transformers import AutoTokenizer - -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID -from fish_speech.datasets.protos.text_data_pb2 import SampledData -from fish_speech.datasets.protos.text_data_stream import read_pb_stream -from fish_speech.text.clean import clean_text -from fish_speech.utils import RankedLogger -from fish_speech.utils.braceexpand import braceexpand - -log = RankedLogger(__name__, rank_zero_only=True) - - -def split_by_rank_worker(files): - # We need to know the total number of devices - # to split the data properly - - total_devices = 1 - if is_initialized(): - total_devices = get_world_size() - - worker_info = get_worker_info() - if worker_info is not None: - total_devices *= worker_info.num_workers - - if len(files) < total_devices: - # Repeat the files N times to match the number of devices - files = files * (total_devices // len(files) + 1) - - # DDP - if is_initialized(): - files = files[get_rank() :: get_world_size()] - - # Split by worker - if worker_info is not None: - files = files[worker_info.id :: worker_info.num_workers] - - return files - - -class 
AutoTextSemanticInstructionDataset(IterableDataset): - """ - Auto Augment Dataset by Speaker - - 1. Random concatenate multiple sentences from the same speaker to form a longer sentence - 2. Automatically normalize the text - - For interactive mode, we use the following format (multiple sequences): - [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] - - For non-interactive mode, we use the following format (one long sequence): - [INST] text [/INST] ... - """ - - def __init__( - self, - proto_files: list[str], - seed: int = 42, - interactive_prob: float = 0.5, - max_length: int = 1024, - tokenizer: AutoTokenizer = None, - use_speaker: bool | float = True, - causal: bool = True, - num_codebooks: Optional[int] = None, - skip_text_prob: float = 0.0, - ): - """ - Args: - proto_files: proto buf files if using local data - seed: random seed - interactive_prob: probability to use interactive mode - max_length: max length of the text - tokenizer: tokenizer - use_speaker: include speaker information in the prompt - causal: use causal sampling when using local data, disable will lead to random sampling - num_codebooks: number of codebooks, if None, it will be automatically detected - skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode - """ - - super().__init__() - - assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" - - self.seed = seed - self.max_length = max_length - self.tokenizer = tokenizer - self.interactive_prob = interactive_prob - self.use_speaker = use_speaker - self.proto_files = proto_files - self.causal = causal - self.num_codebooks = num_codebooks - self.skip_text_prob = skip_text_prob - - self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") - self.groups = None - - def init_mock_data_server(self): - if self.groups is not None: - return - - # Expand the proto files - expanded_proto_files = [] - for filename in self.proto_files: - for i in braceexpand(filename): - i = Path(i) - if i.is_file(): - expanded_proto_files.append(i) - elif i.is_dir(): - expanded_proto_files.extend(i.rglob("*.proto")) - expanded_proto_files.extend(i.rglob("*.protos")) - else: - raise ValueError(f"{i} is not a file or directory") - - expanded_proto_files = sorted(expanded_proto_files) - Random(self.seed).shuffle(expanded_proto_files) - - self.groups = [] - shard_proto_files = split_by_rank_worker(expanded_proto_files) - log.info( - f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" - ) - - count = 0 - for filename in shard_proto_files: - with open(filename, "rb") as f: - for text_data in read_pb_stream(f): - self.groups.append(text_data) - count += 1 - - log.info(f"Read total {count} groups of data") - - # Shuffle the lines - Random(self.seed).shuffle(self.groups) - self.group_weights = [len(i.sentences) for i in self.groups] - - def __iter__(self): - while True: - yield self.augment() - - def tokenize_sentence(self, sentence: str): - sentence = clean_text(sentence) - tokens = self.tokenizer.encode( - f"{sentence}", - max_length=10**6, - add_special_tokens=False, - truncation=False, - ) - return sentence, len(tokens) - - def sample_data(self): - if self.groups is None: - self.init_mock_data_server() - - # Shuffle unique lines, estimate that each sample is at least 20 tokens - num_samples = self.max_length // 20 - - # choice group based on their number of samples - group = random.choices(self.groups, weights=self.group_weights, k=1)[0] - - if self.causal: - # Sample in order - if num_samples >= 
len(group.sentences): - samples = group.sentences - else: - begin = random.randint(0, len(group.sentences) - num_samples) - samples = group.sentences[begin : begin + num_samples] - else: - samples = random.choices( - group.sentences, k=min(num_samples, len(group.sentences)) - ) - - return SampledData( - source=group.source, - name=group.name, - samples=samples, - ) - - def augment(self): - final_text, final_semantic = [], [] - response = self.sample_data() - if len(response.samples) == 0: - # Invalid group - return None - - samples = list(response.samples) - idx = 0 - use_interactive = random.random() < self.interactive_prob - - if use_interactive is False: - # Random sample based on speaker using a truncated normal distribution - a = torch.tensor([0], dtype=torch.float32) - torch.nn.init.trunc_normal_( - a, - mean=self.max_length // 2, - std=self.max_length // 4, - a=10, - b=self.max_length, - ) - remaining_tokens = a.long().item() - 4 - else: - remaining_tokens = self.max_length - - # Use speaker - if isinstance(self.use_speaker, float): - use_speaker = random.random() < self.use_speaker - else: - use_speaker = self.use_speaker - - all_tokens, all_labels = [], [] - while remaining_tokens > 0 and len(samples) > 0: - sentence = samples.pop(0) - - text = random.choice(sentence.texts) - text, length = self.tokenize_sentence(text) - remaining_tokens -= length + len(sentence.semantics[0].values) - - if use_interactive is False: - final_text.append(text) - final_semantic.append(sentence.semantics) - else: - # For interactive mode, we only apply speaker for the first sentence - # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] - tokens, labels = self.pack_sentences( - sentences=[text], - semantics=[sentence.semantics], - speaker=response.name if use_speaker else None, - skip_text=random.random() < self.skip_text_prob, - ) - - all_tokens.append(tokens) - all_labels.append(labels) - - idx += 1 - - if use_interactive is False: - tokens, labels = self.pack_sentences( - final_text, - semantics=final_semantic, - speaker=response.name if use_speaker else None, - ) - all_tokens.append(tokens) - all_labels.append(labels) - - tokens = torch.cat(all_tokens, dim=1) - labels = torch.cat(all_labels, dim=1) - - # Verify that the length is correct - assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" - - data = {"tokens": tokens, "labels": labels} - - return data - - def pack_sentences( - self, - sentences: list[str], - semantics: list, - speaker: Optional[str] = None, - skip_text: bool = False, - ): - if speaker is None: - speaker = "assistant" - - cated_sentences = " ".join(sentences) - if skip_text: - cated_sentences = "<|skip_text|>" - - final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>" - final_text = final_text + f"<|im_start|>{speaker}\n" - - encoded = self.tokenizer.encode( - final_text, - add_special_tokens=False, - truncation=False, - max_length=10**6, - ) - semantic_length = sum([len(i[0].values) for i in semantics]) - prompt_length = len(encoded) - num_codebooks = ( - len(semantics[0]) if self.num_codebooks is None else self.num_codebooks - ) - - # Pack the tokens and semantics (add and to semantic tokens) - tokens = ( - encoded - + [self.semantic_token_id] * semantic_length - + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"]) - ) - - # Codebook bos/padding: 0, eos: 1 - codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)] - for segment in semantics: - for book_idx, book in zip(range(num_codebooks), segment): - for 
j in book.values: - codes[book_idx].append(int(j) + 1) - - for book in codes: - book.extend([CODEBOOK_PAD_TOKEN_ID] * 1) - - tokens = [tokens] + codes - - tokens = torch.tensor(tokens, dtype=torch.long) - labels = tokens.clone() - - if skip_text: - # If text is not provided, the sentence is used for condition only, all labels are -100 - torch.fill_(labels, -100) - return tokens, labels - - # Mask out the tokens for semantic, predict semantic tokens only - # Since we don't mask out the input tokens, the language modeling still works - labels[1:, :prompt_length] = -100 - - tokens = tokens[:, :-1] - labels = labels[:, 1:] - - # Verify the padding is correct, and the last token is eos - assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all() - assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() - - return tokens, labels - - -@dataclass -class TextDataCollator: - tokenizer: AutoTokenizer - max_length: int = 1024 - - def __call__(self, examples): - if "negative_tokens" in examples: - positive_examples = [] - negative_examples = [] - - for i in examples: - positive_examples.append( - { - "tokens": i["tokens"], - "labels": i["labels"], - } - ) - negative_examples.append( - { - "tokens": i["negative_tokens"], - "labels": i["negative_labels"], - } - ) - - examples = positive_examples + negative_examples - - return self.batchify(examples) - - def batchify(self, examples, tokens_key="tokens", labels_key="labels"): - tokens, attention_masks, labels = [], [], [] - - # Calculate the max length - max_tokens_length = 0 - for example in examples: - max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) - max_tokens_length = min(max_tokens_length, self.max_length) - - for example in examples: - _tokens = example[tokens_key][:, :max_tokens_length] - _labels = example[labels_key][:, :max_tokens_length] - _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) - tokens_length = _tokens.size(1) - _attention_mask[:tokens_length] = False - - assert tokens_length == _labels.size( - 1 - ), f"{tokens_length} != {_labels.size(1)}" - - if tokens_length < max_tokens_length: - _tokens = F.pad( - _tokens, - (0, max_tokens_length - tokens_length), - value=self.tokenizer.eos_token_id, - ) - _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID - _labels = F.pad( - _labels, (0, max_tokens_length - _labels.size(1)), value=-100 - ) - - tokens.append(_tokens) - attention_masks.append(_attention_mask) - labels.append(_labels) - - tokens = torch.stack(tokens, dim=0) - attention_masks = torch.stack(attention_masks, dim=0) - labels = torch.stack(labels, dim=0) - - return { - "inputs": tokens, - "attention_masks": attention_masks, - "labels": labels, - } - - -class InterleaveDataset(IterableDataset): - def __init__( - self, - datasets: list[IterableDataset], - probabilities: list[float], - seed: int = 42, - ): - super().__init__() - - self.datasets = datasets - self.probabilities = probabilities - self.seed = seed - - def __iter__(self): - rng = np.random.default_rng(self.seed) - dataset_iterators = [iter(dataset) for dataset in self.datasets] - - while True: - # Random choice one - dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) - dataset_iterator = dataset_iterators[dataset_idx] - - try: - yield next(dataset_iterator) - except StopIteration: - # Exhausted, create a new iterator - dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) - yield next(dataset_iterators[dataset_idx]) - - -class SemanticDataModule(LightningDataModule): - def __init__( - self, - 
train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], - val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], - batch_size: int = 32, - tokenizer: AutoTokenizer = None, - max_length: int = 1024, - num_workers: int = 4, - ): - super().__init__() - - self.train_dataset = train_dataset - self.val_dataset = val_dataset - self.batch_size = batch_size - self.tokenizer = tokenizer - self.max_length = max_length - self.num_workers = num_workers - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - collate_fn=TextDataCollator(self.tokenizer, self.max_length), - num_workers=self.num_workers, - persistent_workers=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - collate_fn=TextDataCollator(self.tokenizer, self.max_length), - num_workers=self.num_workers, - persistent_workers=True, - ) - - -if __name__ == "__main__": - from tqdm import tqdm - - ds = AutoTextSemanticInstructionDataset( - ["data/protos"], - tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), - use_speaker=False, - interactive_prob=1.0, - skip_text_prob=0.5, - ) - - for i in ds: - print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) - # i["labels"][0][i["labels"][0] == -100] = 0 - # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) - break +import random +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from random import Random +from typing import Optional, Union + +import numpy as np +import pyarrow.parquet as pq +import torch +import torch.nn.functional as F +from datasets.download.streaming_download_manager import xopen +from huggingface_hub import HfApi +from lightning import LightningDataModule +from torch.distributed import get_rank, get_world_size, is_initialized +from torch.utils.data import DataLoader, IterableDataset, get_worker_info +from transformers import AutoTokenizer + +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.datasets.protos.text_data_pb2 import SampledData +from fish_speech.datasets.protos.text_data_stream import read_pb_stream +from fish_speech.text.clean import clean_text +from fish_speech.utils import RankedLogger +from fish_speech.utils.braceexpand import braceexpand + +log = RankedLogger(__name__, rank_zero_only=True) + + +def split_by_rank_worker(files): + # We need to know the total number of devices + # to split the data properly + + total_devices = 1 + if is_initialized(): + total_devices = get_world_size() + + worker_info = get_worker_info() + if worker_info is not None: + total_devices *= worker_info.num_workers + + if len(files) < total_devices: + # Repeat the files N times to match the number of devices + files = files * (total_devices // len(files) + 1) + + # DDP + if is_initialized(): + files = files[get_rank() :: get_world_size()] + + # Split by worker + if worker_info is not None: + files = files[worker_info.id :: worker_info.num_workers] + + return files + + +class AutoTextSemanticInstructionDataset(IterableDataset): + """ + Auto Augment Dataset by Speaker + + 1. Random concatenate multiple sentences from the same speaker to form a longer sentence + 2. Automatically normalize the text + + For interactive mode, we use the following format (multiple sequences): + [INST] [SPK: speaker] text [/INST] ... 
[INST] text [/INST] + + For non-interactive mode, we use the following format (one long sequence): + [INST] text [/INST] ... + """ + + def __init__( + self, + proto_files: list[str], + seed: int = 42, + interactive_prob: float = 0.5, + max_length: int = 1024, + tokenizer: AutoTokenizer = None, + use_speaker: bool | float = True, + causal: bool = True, + num_codebooks: Optional[int] = None, + skip_text_prob: float = 0.0, + ): + """ + Args: + proto_files: proto buf files if using local data + seed: random seed + interactive_prob: probability to use interactive mode + max_length: max length of the text + tokenizer: tokenizer + use_speaker: include speaker information in the prompt + causal: use causal sampling when using local data, disable will lead to random sampling + num_codebooks: number of codebooks, if None, it will be automatically detected + skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode + """ + + super().__init__() + + assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" + + self.seed = seed + self.max_length = max_length + self.tokenizer = tokenizer + self.interactive_prob = interactive_prob + self.use_speaker = use_speaker + self.proto_files = proto_files + self.causal = causal + self.num_codebooks = num_codebooks + self.skip_text_prob = skip_text_prob + + self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") + self.groups = None + + def init_mock_data_server(self): + if self.groups is not None: + return + + # Expand the proto files + expanded_proto_files = [] + for filename in self.proto_files: + for i in braceexpand(filename): + i = Path(i) + if i.is_file(): + expanded_proto_files.append(i) + elif i.is_dir(): + expanded_proto_files.extend(i.rglob("*.proto")) + expanded_proto_files.extend(i.rglob("*.protos")) + else: + raise ValueError(f"{i} is not a file or directory") + + expanded_proto_files = sorted(expanded_proto_files) + Random(self.seed).shuffle(expanded_proto_files) + + self.groups = [] + shard_proto_files = split_by_rank_worker(expanded_proto_files) + log.info( + f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" + ) + + count = 0 + for filename in shard_proto_files: + with open(filename, "rb") as f: + for text_data in read_pb_stream(f): + self.groups.append(text_data) + count += 1 + + log.info(f"Read total {count} groups of data") + + # Shuffle the lines + Random(self.seed).shuffle(self.groups) + self.group_weights = [len(i.sentences) for i in self.groups] + + def __iter__(self): + while True: + yield self.augment() + + def tokenize_sentence(self, sentence: str): + sentence = clean_text(sentence) + tokens = self.tokenizer.encode( + f"{sentence}", + max_length=10**6, + add_special_tokens=False, + truncation=False, + ) + return sentence, len(tokens) + + def sample_data(self): + if self.groups is None: + self.init_mock_data_server() + + # Shuffle unique lines, estimate that each sample is at least 20 tokens + num_samples = self.max_length // 20 + + # choice group based on their number of samples + group = random.choices(self.groups, weights=self.group_weights, k=1)[0] + + if self.causal: + # Sample in order + if num_samples >= len(group.sentences): + samples = group.sentences + else: + begin = random.randint(0, len(group.sentences) - num_samples) + samples = group.sentences[begin : begin + num_samples] + else: + samples = random.choices( + group.sentences, k=min(num_samples, len(group.sentences)) + ) + + return SampledData( + source=group.source, + 
name=group.name, + samples=samples, + ) + + def augment(self): + final_text, final_semantic = [], [] + response = self.sample_data() + if len(response.samples) == 0: + # Invalid group + return None + + samples = list(response.samples) + idx = 0 + use_interactive = random.random() < self.interactive_prob + + if use_interactive is False: + # Random sample based on speaker using a truncated normal distribution + a = torch.tensor([0], dtype=torch.float32) + torch.nn.init.trunc_normal_( + a, + mean=self.max_length // 2, + std=self.max_length // 4, + a=10, + b=self.max_length, + ) + remaining_tokens = a.long().item() - 4 + else: + remaining_tokens = self.max_length + + # Use speaker + if isinstance(self.use_speaker, float): + use_speaker = random.random() < self.use_speaker + else: + use_speaker = self.use_speaker + + all_tokens, all_labels = [], [] + while remaining_tokens > 0 and len(samples) > 0: + sentence = samples.pop(0) + + text = random.choice(sentence.texts) + text, length = self.tokenize_sentence(text) + remaining_tokens -= length + len(sentence.semantics[0].values) + + if use_interactive is False: + final_text.append(text) + final_semantic.append(sentence.semantics) + else: + # For interactive mode, we only apply speaker for the first sentence + # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + tokens, labels = self.pack_sentences( + sentences=[text], + semantics=[sentence.semantics], + speaker=response.name if use_speaker else None, + skip_text=random.random() < self.skip_text_prob, + ) + + all_tokens.append(tokens) + all_labels.append(labels) + + idx += 1 + + if use_interactive is False: + tokens, labels = self.pack_sentences( + final_text, + semantics=final_semantic, + speaker=response.name if use_speaker else None, + ) + all_tokens.append(tokens) + all_labels.append(labels) + + tokens = torch.cat(all_tokens, dim=1) + labels = torch.cat(all_labels, dim=1) + + # Verify that the length is correct + assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" + + data = {"tokens": tokens, "labels": labels} + + return data + + def pack_sentences( + self, + sentences: list[str], + semantics: list, + speaker: Optional[str] = None, + skip_text: bool = False, + ): + if speaker is None: + speaker = "assistant" + + cated_sentences = " ".join(sentences) + if skip_text: + cated_sentences = "<|skip_text|>" + + final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>" + final_text = final_text + f"<|im_start|>{speaker}\n" + + encoded = self.tokenizer.encode( + final_text, + add_special_tokens=False, + truncation=False, + max_length=10**6, + ) + semantic_length = sum([len(i[0].values) for i in semantics]) + prompt_length = len(encoded) + num_codebooks = ( + len(semantics[0]) if self.num_codebooks is None else self.num_codebooks + ) + + # Pack the tokens and semantics (add and to semantic tokens) + tokens = ( + encoded + + [self.semantic_token_id] * semantic_length + + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"]) + ) + + # Codebook bos/padding: 0, eos: 1 + codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)] + for segment in semantics: + for book_idx, book in zip(range(num_codebooks), segment): + for j in book.values: + codes[book_idx].append(int(j) + 1) + + for book in codes: + book.extend([CODEBOOK_PAD_TOKEN_ID] * 1) + + tokens = [tokens] + codes + + tokens = torch.tensor(tokens, dtype=torch.long) + labels = tokens.clone() + + if skip_text: + # If text is not provided, the sentence is used for condition only, all labels 
are -100 + torch.fill_(labels, -100) + return tokens, labels + + # Mask out the tokens for semantic, predict semantic tokens only + # Since we don't mask out the input tokens, the language modeling still works + labels[1:, :prompt_length] = -100 + + tokens = tokens[:, :-1] + labels = labels[:, 1:] + + # Verify the padding is correct, and the last token is eos + assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all() + assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() + + return tokens, labels + + +@dataclass +class TextDataCollator: + tokenizer: AutoTokenizer + max_length: int = 1024 + + def __call__(self, examples): + if "negative_tokens" in examples: + positive_examples = [] + negative_examples = [] + + for i in examples: + positive_examples.append( + { + "tokens": i["tokens"], + "labels": i["labels"], + } + ) + negative_examples.append( + { + "tokens": i["negative_tokens"], + "labels": i["negative_labels"], + } + ) + + examples = positive_examples + negative_examples + + return self.batchify(examples) + + def batchify(self, examples, tokens_key="tokens", labels_key="labels"): + tokens, attention_masks, labels = [], [], [] + + # Calculate the max length + max_tokens_length = 0 + for example in examples: + max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) + max_tokens_length = min(max_tokens_length, self.max_length) + + for example in examples: + _tokens = example[tokens_key][:, :max_tokens_length] + _labels = example[labels_key][:, :max_tokens_length] + _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) + tokens_length = _tokens.size(1) + _attention_mask[:tokens_length] = False + + assert tokens_length == _labels.size( + 1 + ), f"{tokens_length} != {_labels.size(1)}" + + if tokens_length < max_tokens_length: + _tokens = F.pad( + _tokens, + (0, max_tokens_length - tokens_length), + value=self.tokenizer.eos_token_id, + ) + _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID + _labels = F.pad( + _labels, (0, max_tokens_length - _labels.size(1)), value=-100 + ) + + tokens.append(_tokens) + attention_masks.append(_attention_mask) + labels.append(_labels) + + tokens = torch.stack(tokens, dim=0) + attention_masks = torch.stack(attention_masks, dim=0) + labels = torch.stack(labels, dim=0) + + return { + "inputs": tokens, + "attention_masks": attention_masks, + "labels": labels, + } + + +class InterleaveDataset(IterableDataset): + def __init__( + self, + datasets: list[IterableDataset], + probabilities: list[float], + seed: int = 42, + ): + super().__init__() + + self.datasets = datasets + self.probabilities = probabilities + self.seed = seed + + def __iter__(self): + rng = np.random.default_rng(self.seed) + dataset_iterators = [iter(dataset) for dataset in self.datasets] + + while True: + # Random choice one + dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) + dataset_iterator = dataset_iterators[dataset_idx] + + try: + yield next(dataset_iterator) + except StopIteration: + # Exhausted, create a new iterator + dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) + yield next(dataset_iterators[dataset_idx]) + + +class SemanticDataModule(LightningDataModule): + def __init__( + self, + train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + batch_size: int = 32, + tokenizer: AutoTokenizer = None, + max_length: int = 1024, + num_workers: int = 4, + ): + super().__init__() + + self.train_dataset = train_dataset + 
self.val_dataset = val_dataset + self.batch_size = batch_size + self.tokenizer = tokenizer + self.max_length = max_length + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + from tqdm import tqdm + + ds = AutoTextSemanticInstructionDataset( + ["data/protos"], + tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), + use_speaker=False, + interactive_prob=1.0, + skip_text_prob=0.5, + ) + + for i in ds: + print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) + # i["labels"][0][i["labels"][0] == -100] = 0 + # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) + break diff --git a/fish_speech/datasets/text.py b/fish_speech/datasets/text.py deleted file mode 100644 index 44b3bfaa5582930d4f2c3950aa9e03ea4219017b..0000000000000000000000000000000000000000 --- a/fish_speech/datasets/text.py +++ /dev/null @@ -1,661 +0,0 @@ -import random -from dataclasses import dataclass -from itertools import chain -from pathlib import Path -from random import Random -from typing import Optional, Union - -import grpc -import numpy as np -import pyarrow.parquet as pq -import torch -import torch.nn.functional as F -from datasets.download.streaming_download_manager import xopen -from huggingface_hub import HfApi -from lightning import LightningDataModule -from torch.distributed import get_rank, get_world_size, is_initialized -from torch.utils.data import DataLoader, IterableDataset, get_worker_info -from transformers import AutoTokenizer - -from fish_speech.datasets.protos.text_data_pb2 import SampledData -from fish_speech.datasets.protos.text_data_stream import read_pb_stream -from fish_speech.text.clean import clean_text -from fish_speech.utils import RankedLogger -from fish_speech.utils.braceexpand import braceexpand - -log = RankedLogger(__name__, rank_zero_only=True) - -CODEBOOK_PAD_TOKEN_ID = 0 -CODEBOOK_EOS_TOKEN_ID = 1 - - -def split_by_rank_worker(files): - # We need to know the total number of devices - # to split the data properly - - total_devices = 1 - if is_initialized(): - total_devices = get_world_size() - - worker_info = get_worker_info() - if worker_info is not None: - total_devices *= worker_info.num_workers - - if len(files) < total_devices: - # Repeat the files N times to match the number of devices - files = files * (total_devices // len(files) + 1) - - # DDP - if is_initialized(): - files = files[get_rank() :: get_world_size()] - - # Split by worker - if worker_info is not None: - files = files[worker_info.id :: worker_info.num_workers] - - return files - - -class StreamTextDataset(IterableDataset): - def __init__( - self, - files: Optional[Union[list[str], str]] = None, - prefix: Optional[str] = None, - seed: int = 42, - parquet_batch_size: int = 10000, - repo: str = "uonlp/CulturaX", - max_length: int = 1024, - tokenizer: AutoTokenizer = None, - ): - super().__init__() - - self.seed = seed - self.parquet_batch_size = parquet_batch_size - self.repo = repo - self.max_length = max_length - self.tokenizer = tokenizer - - if files is None and prefix is None: - raise ValueError("Either 
files or prefix must be specified") - - if prefix is not None: - files = HfApi().list_repo_files(repo, repo_type="dataset") - files = [ - f for f in files if f.startswith(prefix) and f.endswith(".parquet") - ] - log.info(f"Found {len(files)} files in {repo} with prefix {prefix}") - else: - if isinstance(files, str): - files = [files] - - files = list(chain.from_iterable(map(braceexpand, files))) - log.info(f"Expanded {len(files)} files in {repo}") - - # Get sharded files - self.files = sorted(files) - Random(seed).shuffle(self.files) - - def __iter__(self): - files = split_by_rank_worker(self.files) - random.shuffle(files) - - for filename in files: - try: - yield from self.parse_data(filename) - except Exception as e: - log.exception(f"Failed to parse {filename}: {e}") - - def parse_data(self, filename: str): - for data in self.parse_data_internal(filename): - text = data["text"] - - # encode - tokens = self.tokenizer.encode( - text, - add_special_tokens=False, - truncation=False, - max_length=10**6, - ) - - # Random choice self.max_length - if len(tokens) > self.max_length: - start = random.randint(0, len(tokens) - self.max_length) - tokens = tokens[start : start + self.max_length - 1] - - tokens = ( - [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] - ) - # Pad dims - placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long) - - tokens = torch.concat( - [ - torch.tensor([tokens], dtype=torch.long), - placeholder_multi_codebook, - ], - dim=0, - ) - labels = tokens.clone() - tokens = tokens[:, :-1] - labels = labels[:, 1:] - labels[1:] = -100 # remove all placeholders - - yield {"tokens": tokens, "labels": labels} - - def parse_data_internal(self, filename: str): - url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}" - - with xopen(url, mode="rb") as stream: - parquet_file = pq.ParquetFile(stream) - - for batch in parquet_file.iter_batches( - batch_size=self.parquet_batch_size, columns=["text"] - ): - # In-batch shuffling - texts = [{"text": text.as_py()} for text in batch["text"]] - random.shuffle(texts) - yield from texts - - -class AutoAugTextDataset(IterableDataset): - """ - Auto Augment Dataset by Speaker - - 1. Random concatenate multiple sentences from the same speaker to form a longer sentence - 2. Automatically normalize the text - - For interactive mode, we use the following format (multiple sequences): - [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] - - For non-interactive mode, we use the following format (one long sequence): - [INST] text [/INST] ... 
- """ - - def __init__( - self, - proto_files: list[str], - seed: int = 42, - interactive_prob: float = 0.5, - max_length: int = 1024, - tokenizer: AutoTokenizer = None, - use_speaker: bool = True, - causual: bool = True, - use_negative_samples: bool = False, - num_codebooks: Optional[int] = None, - ): - """ - Args: - proto_files: proto buf files if using local data - seed: random seed - interactive_prob: probability to use interactive mode - max_length: max length of the text - tokenizer: tokenizer - use_speaker: include speaker information in the prompt - causual: use causual sampling when using local data, disable will lead to random sampling - use_negative_samples: generate negative samples - num_codebooks: number of codebooks, if None, it will be automatically detected - """ - - super().__init__() - - assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" - - self.seed = seed - self.max_length = max_length - self.tokenizer = tokenizer - self.interactive_prob = interactive_prob - self.use_speaker = use_speaker - self.proto_files = proto_files - self.causual = causual - self.use_negative_samples = use_negative_samples - self.num_codebooks = num_codebooks - - self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") - self.groups = None - - def init_mock_data_server(self): - if self.groups is not None: - return - - # Expand the proto files - expanded_proto_files = [] - for filename in self.proto_files: - for i in braceexpand(filename): - i = Path(i) - if i.is_file(): - expanded_proto_files.append(i) - elif i.is_dir(): - expanded_proto_files.extend(i.rglob("*.proto")) - expanded_proto_files.extend(i.rglob("*.protos")) - else: - raise ValueError(f"{i} is not a file or directory") - - expanded_proto_files = sorted(expanded_proto_files) - Random(self.seed).shuffle(expanded_proto_files) - - self.groups = [] - shard_proto_files = split_by_rank_worker(expanded_proto_files) - log.info( - f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" - ) - - count = 0 - for filename in shard_proto_files: - with open(filename, "rb") as f: - for text_data in read_pb_stream(f): - self.groups.append(text_data) - count += 1 - - log.info(f"Read total {count} groups of data") - - # Shuffle the lines - Random(self.seed).shuffle(self.groups) - self.group_weights = [len(i.sentences) for i in self.groups] - - def __iter__(self): - while True: - yield self.augment() - - def tokenize_sentence(self, sentence: str): - sentence = clean_text(sentence) - tokens = self.tokenizer.encode( - f"{sentence}", - max_length=10**6, - add_special_tokens=False, - truncation=False, - ) - return sentence, len(tokens) - - def sample_data(self): - if self.groups is None: - self.init_mock_data_server() - - # Shuffle unique lines, estimate that each sample is at least 20 tokens - num_samples = self.max_length // 20 - - # choice group based on their number of samples - group = random.choices(self.groups, weights=self.group_weights, k=1)[0] - - if self.causual: - # Sample in order - if num_samples >= len(group.sentences): - samples = group.sentences - else: - begin = random.randint(0, len(group.sentences) - num_samples) - samples = group.sentences[begin : begin + num_samples] - else: - samples = random.choices( - group.sentences, k=min(num_samples, len(group.sentences)) - ) - - return SampledData( - source=group.source, - name=group.name, - samples=samples, - ) - - def augment(self): - # Random sample based on speaker using a truncated normal distribution - a = torch.tensor([0], 
dtype=torch.float32) - torch.nn.init.trunc_normal_( - a, - mean=self.max_length // 2, - std=self.max_length // 4, - a=10, - b=self.max_length, - ) - remaining_tokens = a.long().item() - 4 - - final_text, final_semantic = [], [] - response = self.sample_data() - if len(response.samples) == 0: - # Invalid group - return None - - samples = list(response.samples) - idx = 0 - use_interactive = random.random() < self.interactive_prob - - all_tokens, all_labels = [], [] - while remaining_tokens > 0 and len(samples) > 0: - sentence = samples.pop(0) - - text = random.choice(sentence.texts) - text, length = self.tokenize_sentence(text) - remaining_tokens -= length + len(sentence.semantics[0].values) - - if use_interactive is False: - final_text.append(text) - final_semantic.append(sentence.semantics) - else: - # For interactive mode, we only apply speaker for the first sentence - # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] - tokens, labels = self.pack_sentences( - sentences=[text], - semantics=[sentence.semantics], - speaker=response.name if (self.use_speaker and idx == 0) else None, - add_bos=idx == 0, - ) - - all_tokens.append(tokens) - all_labels.append(labels) - - idx += 1 - - if use_interactive is False: - tokens, labels = self.pack_sentences( - final_text, - semantics=final_semantic, - speaker=response.name if self.use_speaker else None, - add_bos=True, - ) - all_tokens.append(tokens) - all_labels.append(labels) - - tokens = torch.cat(all_tokens, dim=1) - labels = torch.cat(all_labels, dim=1) - - # Verify that the length is correct - assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" - - # Verify bos token - assert tokens[0, 0] == self.tokenizer.bos_token_id - - data = {"tokens": tokens, "labels": labels} - - if self.use_negative_samples: - negative_samples = self.generate_negative_samples(all_tokens, all_labels) - data.update(negative_samples) - - return data - - def generate_negative_samples(self, all_tokens, all_labels): - new_tokens, new_labels = [], [] - - for tokens, labels in zip(all_tokens, all_labels): - # If all codebooks are not -100, we find where it starts - start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0] - assert (labels[1:, start:] != -100).all() # This shouldn't happen - - mode = random.choice(["repeat", "lost", "noise"]) - begin = random.randint(start, labels.size(1) - 1) - end = random.randint(begin, labels.size(1) - 1) - - if mode == "repeat": - tokens = torch.cat( - [ - tokens[:, :begin], - tokens[:, begin:end], - tokens[:, begin:end], - tokens[:, end:], - ], - dim=1, - ) - labels = torch.cat( - [ - labels[:, :begin], - labels[:, begin:end], - labels[:, begin:end], - labels[:, end:], - ], - dim=1, - ) - elif mode == "lost": - tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1) - labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1) - elif mode == "noise": - middle_tokens, middle_labels = ( - tokens[:, begin:end], - labels[:, begin:end], - ) - random_order0 = torch.randperm(middle_tokens.size(1)) - random_order1 = torch.randperm(middle_tokens.size(1)) - middle_tokens = middle_tokens[:, random_order0] - middle_labels = middle_labels[:, random_order1] - tokens = torch.cat( - [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1 - ) - labels = torch.cat( - [labels[:, :begin], middle_labels, labels[:, end:]], dim=1 - ) - - new_tokens.append(tokens) - new_labels.append(labels) - - tokens = torch.cat(new_tokens, dim=1) - labels = torch.cat(new_labels, dim=1) - - # Verify that the 
length is correct - assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" - - return {"negative_tokens": tokens, "negative_labels": labels} - - def pack_sentences( - self, - sentences: list[str], - semantics=list, - speaker: Optional[str] = None, - add_bos: bool = True, - ): - if speaker is not None: - sentences = [f"[SPK: {speaker}]"] + sentences - - final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>" - final_text = final_text + "<|im_start|>assistant<|im_sep|>" - - encoded = self.tokenizer.encode( - final_text, - add_special_tokens=False, - truncation=False, - max_length=10**6, - ) - semantic_length = sum([len(i[0].values) for i in semantics]) - prompt_length = len(encoded) - num_codebooks = ( - len(semantics[0]) if self.num_codebooks is None else self.num_codebooks - ) - - bos_bias = 1 if add_bos else 0 - - # Pack the tokens and semantics (add and to semantic tokens) - tokens = ( - encoded - + [self.semantic_token_id] * semantic_length - + self.tokenizer.convert_tokens_to_ids( - ["<|im_end|>", "<|end_of_sequence|>"] - ) - ) - - if add_bos: - tokens = [self.tokenizer.bos_token_id] + tokens - - # Codebook bos/padding: 0, eos: 1 - codes = [ - [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias) - for _ in range(num_codebooks) - ] - for segment in semantics: - for book_idx, book in zip(range(num_codebooks), segment): - for j in book.values: - codes[book_idx].append(int(j) + 2) - - for book in codes: - book.extend([CODEBOOK_EOS_TOKEN_ID] * 2) - - tokens = [tokens] + codes - - tokens = torch.tensor(tokens, dtype=torch.long) - labels = tokens.clone() - - # Mask out the tokens for semantic, predict semantic tokens only - # Since we don't mask out the input tokens, the language modeling still works - labels[1:, : (prompt_length + bos_bias)] = -100 - - tokens = tokens[:, :-1] - labels = labels[:, 1:] - - # Verify the padding is correct, and the last token is eos - assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id - assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all() - assert labels[0, -1] == self.tokenizer.eos_token_id - assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all() - - return tokens, labels - - -@dataclass -class TextDataCollator: - tokenizer: AutoTokenizer - max_length: int = 1024 - - def __call__(self, examples): - if "negative_tokens" in examples: - positive_examples = [] - negative_examples = [] - - for i in examples: - positive_examples.append( - { - "tokens": i["tokens"], - "labels": i["labels"], - } - ) - negative_examples.append( - { - "tokens": i["negative_tokens"], - "labels": i["negative_labels"], - } - ) - - examples = positive_examples + negative_examples - - return self.batchify(examples) - - def batchify(self, examples, tokens_key="tokens", labels_key="labels"): - tokens, attention_masks, labels = [], [], [] - - # Calculate the max length - max_tokens_length = 0 - for example in examples: - max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) - max_tokens_length = min(max_tokens_length, self.max_length) - - for example in examples: - _tokens = example[tokens_key][:, :max_tokens_length] - _labels = example[labels_key][:, :max_tokens_length] - _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) - tokens_length = _tokens.size(1) - _attention_mask[:tokens_length] = False - - assert tokens_length == _labels.size( - 1 - ), f"{tokens_length} != {_labels.size(1)}" - - if tokens_length < max_tokens_length: - _tokens = F.pad( - _tokens, - (0, 
max_tokens_length - tokens_length), - value=self.tokenizer.eos_token_id, - ) - _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID - _labels = F.pad( - _labels, (0, max_tokens_length - _labels.size(1)), value=-100 - ) - - tokens.append(_tokens) - attention_masks.append(_attention_mask) - labels.append(_labels) - - tokens = torch.stack(tokens, dim=0) - attention_masks = torch.stack(attention_masks, dim=0) - labels = torch.stack(labels, dim=0) - - return { - "inputs": tokens, - "attention_masks": attention_masks, - "labels": labels, - } - - -class InterleaveDataset(IterableDataset): - def __init__( - self, - datasets: list[IterableDataset], - probabilities: list[float], - seed: int = 42, - ): - super().__init__() - - self.datasets = datasets - self.probabilities = probabilities - self.seed = seed - - def __iter__(self): - rng = np.random.default_rng(self.seed) - dataset_iterators = [iter(dataset) for dataset in self.datasets] - - while True: - # Random choice one - dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) - dataset_iterator = dataset_iterators[dataset_idx] - - try: - yield next(dataset_iterator) - except StopIteration: - # Exhausted, create a new iterator - dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) - yield next(dataset_iterators[dataset_idx]) - - -class TextDataModule(LightningDataModule): - def __init__( - self, - train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset], - val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset], - batch_size: int = 32, - tokenizer: AutoTokenizer = None, - max_length: int = 1024, - num_workers: int = 4, - ): - super().__init__() - - self.train_dataset = train_dataset - self.val_dataset = val_dataset - self.batch_size = batch_size - self.tokenizer = tokenizer - self.max_length = max_length - self.num_workers = num_workers - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - collate_fn=TextDataCollator(self.tokenizer, self.max_length), - num_workers=self.num_workers, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - collate_fn=TextDataCollator(self.tokenizer, self.max_length), - num_workers=self.num_workers, - ) - - -if __name__ == "__main__": - from tqdm import tqdm - - ds = AutoAugTextDataset( - ["data/protos"], - tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), - use_speaker=False, - interactive_prob=1.0, - use_negative_samples=False, - ) - - for i in ds: - print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) - # i["labels"][0][i["labels"][0] == -100] = 0 - # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) - break diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py index a45583d22efb0feb9dc1e823bae1ef74534b299e..050f694481170329a80110d522fe116ff6b2b135 100644 --- a/fish_speech/datasets/vqgan.py +++ b/fish_speech/datasets/vqgan.py @@ -1,147 +1,147 @@ -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import librosa -import numpy as np -import torch -from lightning import LightningDataModule -from torch.utils.data import DataLoader, Dataset - -from fish_speech.utils import RankedLogger - -logger = RankedLogger(__name__, rank_zero_only=False) - - -class VQGANDataset(Dataset): - def __init__( - self, - filelist: str, - sample_rate: int = 32000, - hop_length: int = 640, - slice_frames: Optional[int] = None, - ): - super().__init__() - - 
filelist = Path(filelist) - root = filelist.parent - - self.files = [ - root / line.strip() - for line in filelist.read_text(encoding="utf-8").splitlines() - if line.strip() - ] - self.sample_rate = sample_rate - self.hop_length = hop_length - self.slice_frames = slice_frames - - def __len__(self): - return len(self.files) - - def get_item(self, idx): - file = self.files[idx] - - audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) - - # Slice audio and features - if ( - self.slice_frames is not None - and audio.shape[0] > self.slice_frames * self.hop_length - ): - start = np.random.randint( - 0, audio.shape[0] - self.slice_frames * self.hop_length - ) - audio = audio[start : start + self.slice_frames * self.hop_length] - - if len(audio) == 0: - return None - - max_value = np.abs(audio).max() - if max_value > 1.0: - audio = audio / max_value - - return { - "audio": torch.from_numpy(audio), - } - - def __getitem__(self, idx): - try: - return self.get_item(idx) - except Exception as e: - import traceback - - traceback.print_exc() - logger.error(f"Error loading {self.files[idx]}: {e}") - return None - - -@dataclass -class VQGANCollator: - def __call__(self, batch): - batch = [x for x in batch if x is not None] - - audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) - audio_maxlen = audio_lengths.max() - - # Rounds up to nearest multiple of 2 (audio_lengths) - audios = [] - for x in batch: - audios.append( - torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) - ) - - return { - "audios": torch.stack(audios), - "audio_lengths": audio_lengths, - } - - -class VQGANDataModule(LightningDataModule): - def __init__( - self, - train_dataset: VQGANDataset, - val_dataset: VQGANDataset, - batch_size: int = 32, - num_workers: int = 4, - val_batch_size: Optional[int] = None, - ): - super().__init__() - - self.train_dataset = train_dataset - self.val_dataset = val_dataset - self.batch_size = batch_size - self.val_batch_size = val_batch_size or batch_size - self.num_workers = num_workers - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - collate_fn=VQGANCollator(), - num_workers=self.num_workers, - shuffle=True, - persistent_workers=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset, - batch_size=self.val_batch_size, - collate_fn=VQGANCollator(), - num_workers=self.num_workers, - persistent_workers=True, - ) - - -if __name__ == "__main__": - dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") - dataloader = DataLoader( - dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() - ) - - for batch in dataloader: - print(batch["audios"].shape) - print(batch["features"].shape) - print(batch["audio_lengths"]) - print(batch["feature_lengths"]) - break +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import librosa +import numpy as np +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + +from fish_speech.utils import RankedLogger + +logger = RankedLogger(__name__, rank_zero_only=False) + + +class VQGANDataset(Dataset): + def __init__( + self, + filelist: str, + sample_rate: int = 32000, + hop_length: int = 640, + slice_frames: Optional[int] = None, + ): + super().__init__() + + filelist = Path(filelist) + root = filelist.parent + + self.files = [ + root / line.strip() + for line in filelist.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + self.sample_rate = 
sample_rate + self.hop_length = hop_length + self.slice_frames = slice_frames + + def __len__(self): + return len(self.files) + + def get_item(self, idx): + file = self.files[idx] + + audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) + + # Slice audio and features + if ( + self.slice_frames is not None + and audio.shape[0] > self.slice_frames * self.hop_length + ): + start = np.random.randint( + 0, audio.shape[0] - self.slice_frames * self.hop_length + ) + audio = audio[start : start + self.slice_frames * self.hop_length] + + if len(audio) == 0: + return None + + max_value = np.abs(audio).max() + if max_value > 1.0: + audio = audio / max_value + + return { + "audio": torch.from_numpy(audio), + } + + def __getitem__(self, idx): + try: + return self.get_item(idx) + except Exception as e: + import traceback + + traceback.print_exc() + logger.error(f"Error loading {self.files[idx]}: {e}") + return None + + +@dataclass +class VQGANCollator: + def __call__(self, batch): + batch = [x for x in batch if x is not None] + + audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) + audio_maxlen = audio_lengths.max() + + # Rounds up to nearest multiple of 2 (audio_lengths) + audios = [] + for x in batch: + audios.append( + torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) + ) + + return { + "audios": torch.stack(audios), + "audio_lengths": audio_lengths, + } + + +class VQGANDataModule(LightningDataModule): + def __init__( + self, + train_dataset: VQGANDataset, + val_dataset: VQGANDataset, + batch_size: int = 32, + num_workers: int = 4, + val_batch_size: Optional[int] = None, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.val_batch_size = val_batch_size or batch_size + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + shuffle=True, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") + dataloader = DataLoader( + dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() + ) + + for batch in dataloader: + print(batch["audios"].shape) + print(batch["features"].shape) + print(batch["audio_lengths"]) + print(batch["feature_lengths"]) + break diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md index 700902b09db20911ef1ad678cbdce5644b84aea2..7d4612883379a16fd2d0945c431d9fb8b04b249a 100644 --- a/fish_speech/i18n/README.md +++ b/fish_speech/i18n/README.md @@ -1,27 +1,27 @@ -## i18n Folder Attribution - -The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. 
In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: - -### fish_speech/i18n/core.py - -**Related code from RVC:** -[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) - -**Initial commit:** -add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) - -**Initial author:** -[@L4Ph](https://github.com/L4Ph) - -### fish_speech/i18n/scan.py - -**Related code from RVC:** -[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) - -**Initial commit:** -File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) - -**Initial author:** -[@towzeur](https://github.com/towzeur) - -We appreciate the contributions of the RVC project and its authors. +## i18n Folder Attribution + +The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: + +### fish_speech/i18n/core.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) + +**Initial commit:** +add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) + +**Initial author:** +[@L4Ph](https://github.com/L4Ph) + +### fish_speech/i18n/scan.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) + +**Initial commit:** +File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) + +**Initial author:** +[@towzeur](https://github.com/towzeur) + +We appreciate the contributions of the RVC project and its authors. 
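For context on how the relocated `fish_speech.i18n` package is consumed (the WebUI now imports `i18n` from it, as seen in the `app.py` import changes), here is a minimal usage sketch. It only assumes the `I18nAuto` behaviour shown in `fish_speech/i18n/core.py` below: the module-level `i18n` callable looks a key up in the active locale's JSON map and falls back to the key itself when no translation exists.

```python
# Minimal sketch, assuming fish_speech is importable and a locale JSON such as
# en_US.json exists under fish_speech/i18n/locale (as added in this patch).
from fish_speech.i18n import i18n

# Returns the localized string for the detected locale (e.g. the ja_JP entry),
# or the key itself if the current locale has no translation for this key.
print(i18n("Text Normalization"))

# __repr__ reports which locale was selected, e.g. "Use Language: en_US".
print(repr(i18n))
```

The locale is resolved once at import time: a `.locale` file in the working directory takes priority, otherwise the system default locale is used, and anything without a matching JSON file falls back to `en_US`.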
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py index 981dbb3b3ecf28043ec9ff5757f947182821a246..0ac9702a707223257997e283ebc259b78996ad5c 100644 --- a/fish_speech/i18n/__init__.py +++ b/fish_speech/i18n/__init__.py @@ -1,3 +1,3 @@ -from .core import i18n - -__all__ = ["i18n"] +from .core import i18n + +__all__ = ["i18n"] diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py index 9f793ec95669228f7f4e8f9a7a5fe38da85c74bd..8375d9ddb4e3c2b3ec25c426d2786f2a7506a0ab 100644 --- a/fish_speech/i18n/core.py +++ b/fish_speech/i18n/core.py @@ -1,40 +1,40 @@ -import json -import locale -from pathlib import Path - -I18N_FILE_PATH = Path(__file__).parent / "locale" -DEFAULT_LANGUAGE = "en_US" - - -def load_language_list(language): - with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: - language_list = json.load(f) - - return language_list - - -class I18nAuto: - def __init__(self): - i18n_file = Path(".locale") - - if i18n_file.exists(): - with open(i18n_file, "r", encoding="utf-8") as f: - language = f.read().strip() - else: - # getlocale can't identify the system's language ((None, None)) - language = locale.getdefaultlocale()[0] - - if (I18N_FILE_PATH / f"{language}.json").exists() is False: - language = DEFAULT_LANGUAGE - - self.language = language - self.language_map = load_language_list(language) - - def __call__(self, key): - return self.language_map.get(key, key) - - def __repr__(self): - return "Use Language: " + self.language - - -i18n = I18nAuto() +import json +import locale +from pathlib import Path + +I18N_FILE_PATH = Path(__file__).parent / "locale" +DEFAULT_LANGUAGE = "en_US" + + +def load_language_list(language): + with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: + language_list = json.load(f) + + return language_list + + +class I18nAuto: + def __init__(self): + i18n_file = Path(".locale") + + if i18n_file.exists(): + with open(i18n_file, "r", encoding="utf-8") as f: + language = f.read().strip() + else: + # getlocale can't identify the system's language ((None, None)) + language = locale.getdefaultlocale()[0] + + if (I18N_FILE_PATH / f"{language}.json").exists() is False: + language = DEFAULT_LANGUAGE + + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + + +i18n = I18nAuto() diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json index 6e280c236e9c79de2087ec33c7bf6f8e1a5296c4..32d58983a5df52d032411fd50ef1a9e3ecdeb859 100644 --- a/fish_speech/i18n/locale/en_US.json +++ b/fish_speech/i18n/locale/en_US.json @@ -1,122 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Accumulate Gradient Batches", - "Add to Processing Area": "Add to Processing Area", - "Added path successfully!": "Added path successfully!", - "Advanced Config": "Advanced Config", - "Base LLAMA Model": "Base LLAMA Model", - "Batch Inference": "Batch Inference", - "Batch Size": "Batch Size", - "Changing with the Model Path": 
"Changing with the Model Path", - "Chinese": "Chinese", - "Compile Model": "Compile Model", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", - "Copy": "Copy", - "Data Preprocessing": "Data Preprocessing", - "Data Preprocessing Path": "Data Preprocessing Path", - "Data Source": "Data Source", - "Decoder Model Config": "Decoder Model Config", - "Decoder Model Path": "Decoder Model Path", - "Disabled": "Disabled", - "Enable Reference Audio": "Enable Reference Audio", - "English": "English", - "Error Message": "Error Message", - "File Preprocessing": "File Preprocessing", - "Generate": "Generate", - "Generated Audio": "Generated Audio", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", - "Infer interface is closed": "Infer interface is closed", - "Inference Configuration": "Inference Configuration", - "Inference Server Configuration": "Inference Server Configuration", - "Inference Server Error": "Inference Server Error", - "Inferring interface is launched at {}": "Inferring interface is launched at {}", - "Initial Learning Rate": "Initial Learning Rate", - "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", - "Input Text": "Input Text", - "Invalid path: {}": "Invalid path: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", - "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", - "Japanese": "Japanese", - "LLAMA Configuration": "LLAMA Configuration", - "LLAMA Model Config": "LLAMA Model Config", - "LLAMA Model Path": "LLAMA Model Path", - "Labeling Device": "Labeling Device", - "LoRA Model to be merged": "LoRA Model to be merged", - "Maximum Audio Duration": "Maximum Audio Duration", - "Maximum Length per Sample": "Maximum Length per Sample", - "Maximum Training Steps": "Maximum Training Steps", - "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", - "Merge": "Merge", - "Merge LoRA": "Merge LoRA", - "Merge successfully": "Merge successfully", - "Minimum Audio Duration": "Minimum Audio Duration", - "Model Output Path": "Model Output Path", - "Model Size": "Model Size", - "Move": "Move", - "Move files successfully": "Move files successfully", - "No audio generated, please check the input text.": "No audio generated, please check the input text.", - "No selected options": "No selected options", - "Number of Workers": "Number of Workers", - "Open Inference Server": "Open Inference Server", - "Open Labeler WebUI": "Open Labeler WebUI", - "Open Tensorboard": "Open Tensorboard", - "Opened labeler in browser": "Opened labeler in browser", - "Optional Label Language": "Optional Label Language", - "Optional online ver": "Optional online ver", - "Output Path": "Output Path", - "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", - "Precision": "Precision", - "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", - "Put your text here.": "Put your text here.", - "Reference Audio": "Reference Audio", - "Reference Text": "Reference Text", - "Related code and 
weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", - "Remove Selected Data": "Remove Selected Data", - "Removed path successfully!": "Removed path successfully!", - "Repetition Penalty": "Repetition Penalty", - "Save model every n steps": "Save model every n steps", - "Select LLAMA ckpt": "Select LLAMA ckpt", - "Select VITS ckpt": "Select VITS ckpt", - "Select VQGAN ckpt": "Select VQGAN ckpt", - "Select source file processing method": "Select source file processing method", - "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", - "Selected: {}": "Selected: {}", - "Speaker": "Speaker", - "Speaker is identified by the folder name": "Speaker is identified by the folder name", - "Start Training": "Start Training", - "Streaming Audio": "Streaming Audio", - "Streaming Generate": "Streaming Generate", - "Tensorboard Host": "Tensorboard Host", - "Tensorboard Log Path": "Tensorboard Log Path", - "Tensorboard Port": "Tensorboard Port", - "Tensorboard interface is closed": "Tensorboard interface is closed", - "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", - "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.", - "Training Configuration": "Training Configuration", - "Training Error": "Training Error", - "Training stopped": "Training stopped", - "Type name of the speaker": "Type name of the speaker", - "Type the path or select from the dropdown": "Type the path or select from the dropdown", - "Use LoRA": "Use LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", - "Use filelist": "Use filelist", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", - "VITS Configuration": "VITS Configuration", - "VQGAN Configuration": "VQGAN Configuration", - "Validation Batch Size": "Validation Batch Size", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", - "WebUI Host": "WebUI Host", - "WebUI Port": "WebUI Port", - "Whisper Model": "Whisper Model", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", - "latest": "latest", - "new": "new", - "Realtime Transform Text": "Realtime Transform Text", - "Normalization Result 
Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", - "Text Normalization": "Text Normalization" -} +{ + "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Accumulate Gradient Batches", + "Add to Processing Area": "Add to Processing Area", + "Added path successfully!": "Added path successfully!", + "Advanced Config": "Advanced Config", + "Base LLAMA Model": "Base LLAMA Model", + "Batch Inference": "Batch Inference", + "Batch Size": "Batch Size", + "Changing with the Model Path": "Changing with the Model Path", + "Chinese": "Chinese", + "Compile Model": "Compile Model", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", + "Copy": "Copy", + "Data Preprocessing": "Data Preprocessing", + "Data Preprocessing Path": "Data Preprocessing Path", + "Data Source": "Data Source", + "Decoder Model Config": "Decoder Model Config", + "Decoder Model Path": "Decoder Model Path", + "Disabled": "Disabled", + "Enable Reference Audio": "Enable Reference Audio", + "English": "English", + "Error Message": "Error Message", + "File Preprocessing": "File Preprocessing", + "Generate": "Generate", + "Generated Audio": "Generated Audio", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", + "Infer interface is closed": "Infer interface is closed", + "Inference Configuration": "Inference Configuration", + "Inference Server Configuration": "Inference Server Configuration", + "Inference Server Error": "Inference Server Error", + "Inferring interface is launched at {}": "Inferring interface is launched at {}", + "Initial Learning Rate": "Initial Learning Rate", + "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", + "Input Text": "Input Text", + "Invalid path: {}": "Invalid path: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", + "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", + "Japanese": "Japanese", + "LLAMA Configuration": "LLAMA Configuration", + "LLAMA Model Config": "LLAMA Model Config", + "LLAMA Model Path": "LLAMA Model Path", + "Labeling Device": "Labeling Device", + "LoRA Model to be merged": "LoRA Model to be merged", + "Maximum Audio Duration": "Maximum Audio Duration", + "Maximum Length per Sample": "Maximum Length per Sample", + "Maximum Training Steps": "Maximum Training Steps", + "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", + "Merge": "Merge", + "Merge LoRA": "Merge LoRA", + "Merge successfully": "Merge successfully", + "Minimum Audio Duration": "Minimum Audio Duration", + "Model Output Path": "Model Output Path", + "Model Size": "Model Size", + "Move": "Move", + "Move files successfully": "Move files successfully", + "No audio 
generated, please check the input text.": "No audio generated, please check the input text.", + "No selected options": "No selected options", + "Number of Workers": "Number of Workers", + "Open Inference Server": "Open Inference Server", + "Open Labeler WebUI": "Open Labeler WebUI", + "Open Tensorboard": "Open Tensorboard", + "Opened labeler in browser": "Opened labeler in browser", + "Optional Label Language": "Optional Label Language", + "Optional online ver": "Optional online ver", + "Output Path": "Output Path", + "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", + "Precision": "Precision", + "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", + "Put your text here.": "Put your text here.", + "Reference Audio": "Reference Audio", + "Reference Text": "Reference Text", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", + "Remove Selected Data": "Remove Selected Data", + "Removed path successfully!": "Removed path successfully!", + "Repetition Penalty": "Repetition Penalty", + "Save model every n steps": "Save model every n steps", + "Select LLAMA ckpt": "Select LLAMA ckpt", + "Select VITS ckpt": "Select VITS ckpt", + "Select VQGAN ckpt": "Select VQGAN ckpt", + "Select source file processing method": "Select source file processing method", + "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", + "Selected: {}": "Selected: {}", + "Speaker": "Speaker", + "Speaker is identified by the folder name": "Speaker is identified by the folder name", + "Start Training": "Start Training", + "Streaming Audio": "Streaming Audio", + "Streaming Generate": "Streaming Generate", + "Tensorboard Host": "Tensorboard Host", + "Tensorboard Log Path": "Tensorboard Log Path", + "Tensorboard Port": "Tensorboard Port", + "Tensorboard interface is closed": "Tensorboard interface is closed", + "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", + "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.", + "Training Configuration": "Training Configuration", + "Training Error": "Training Error", + "Training stopped": "Training stopped", + "Type name of the speaker": "Type name of the speaker", + "Type the path or select from the dropdown": "Type the path or select from the dropdown", + "Use LoRA": "Use LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", + "Use filelist": "Use filelist", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", + "VITS Configuration": "VITS Configuration", + "VQGAN Configuration": "VQGAN Configuration", + "Validation Batch Size": "Validation Batch Size", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", + "WebUI Host": "WebUI Host", + "WebUI Port": "WebUI Port", + "Whisper Model": "Whisper Model", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", + "latest": "latest", + "new": "new", + "Realtime Transform Text": "Realtime Transform Text", + "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", + "Text Normalization": "Text Normalization", + "Select Example Audio": "Select Example Audio" +} diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json index 3285341f6893fe3e2ccbee6490dd8c90ed21854e..0bde8404cdaff94e7c49be580cbba99b8f41ce29 100644 --- a/fish_speech/i18n/locale/es_ES.json +++ b/fish_speech/i18n/locale/es_ES.json @@ -1,122 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Acumular lotes de gradientes", - "Add to Processing Area": "Agregar al Área de Procesamiento", - "Added path successfully!": "¡Ruta agregada exitosamente!", - "Advanced Config": "Configuración Avanzada", - "Base LLAMA Model": "Modelo Base LLAMA", - "Batch Inference": "Inferencia por Lote", - "Batch Size": "Tamaño del Lote", - "Changing with the Model Path": "Cambiando con la Ruta del Modelo", - "Chinese": "Chino", - "Compile Model": "Compilar Modelo", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de 
inferencia, pero aumentará el tiempo de inicio en frío", - "Copy": "Copiar", - "Data Preprocessing": "Preprocesamiento de Datos", - "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", - "Data Source": "Fuente de Datos", - "Decoder Model Config": "Configuración del modelo decodificador", - "Decoder Model Path": "Ruta del modelo decodificador", - "Disabled": "Desactivado", - "Enable Reference Audio": "Habilitar Audio de Referencia", - "English": "Inglés", - "Error Message": "Mensaje de Error", - "File Preprocessing": "Preprocesamiento de Archivos", - "Generate": "Generar", - "Generated Audio": "Audio Generado", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", - "Infer interface is closed": "La interfaz de inferencia está cerrada", - "Inference Configuration": "Configuración de Inferencia", - "Inference Server Configuration": "Configuración del Servidor de Inferencia", - "Inference Server Error": "Error del Servidor de Inferencia", - "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", - "Initial Learning Rate": "Tasa de Aprendizaje Inicial", - "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", - "Input Text": "Texto de Entrada", - "Invalid path: {}": "Ruta inválida: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", - "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", - "Japanese": "Japonés", - "LLAMA Configuration": "Configuración de LLAMA", - "LLAMA Model Config": "Configuración del Modelo LLAMA", - "LLAMA Model Path": "Ruta del Modelo LLAMA", - "Labeling Device": "Dispositivo de Etiquetado", - "LoRA Model to be merged": "Modelo LoRA a fusionar", - "Maximum Audio Duration": "Duración máxima de audio", - "Maximum Length per Sample": "Longitud Máxima por Muestra", - "Maximum Training Steps": "Pasos Máximos de Entrenamiento", - "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", - "Merge": "Fusionar", - "Merge LoRA": "Fusionar LoRA", - "Merge successfully": "Fusionado exitosamente", - "Minimum Audio Duration": "Duración mínima de audio", - "Model Output Path": "Ruta de Salida del Modelo", - "Model Size": "Tamaño del Modelo", - "Move": "Mover", - "Move files successfully": "Archivos movidos exitosamente", - "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", - "No selected options": "No hay opciones seleccionadas", - "Number of Workers": "Número de Trabajadores", - "Open Inference Server": "Abrir Servidor de Inferencia", - "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", - "Open Tensorboard": "Abrir Tensorboard", - "Opened labeler in browser": "Se abrió el etiquetador en el navegador", - "Optional Label Language": "Idioma de Etiquetado Opcional", - "Optional online ver": "Ver en línea opcional", - "Output Path": "Ruta de Salida", - "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", - "Precision": "Precisión", - "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", - "Put your text here.": "Ponga su 
texto aquí.", - "Reference Audio": "Audio de Referencia", - "Reference Text": "Texto de Referencia", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", - "Remove Selected Data": "Eliminar Datos Seleccionados", - "Removed path successfully!": "¡Ruta eliminada exitosamente!", - "Repetition Penalty": "Penalización por Repetición", - "Save model every n steps": "Guardar modelo cada n pasos", - "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", - "Select VITS ckpt": "Seleccionar punto de control VITS", - "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", - "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", - "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", - "Selected: {}": "Seleccionado: {}", - "Speaker": "Hablante", - "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", - "Start Training": "Iniciar Entrenamiento", - "Streaming Audio": "transmisión de audio", - "Streaming Generate": "síntesis en flujo", - "Tensorboard Host": "Host de Tensorboard", - "Tensorboard Log Path": "Ruta de Registro de Tensorboard", - "Tensorboard Port": "Puerto de Tensorboard", - "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", - "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", - "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", - "Training Configuration": "Configuración de Entrenamiento", - "Training Error": "Error de Entrenamiento", - "Training stopped": "Entrenamiento detenido", - "Type name of the speaker": "Escriba el nombre del hablante", - "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", - "Use LoRA": "Usar LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", - "Use filelist": "Usar lista de archivos", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", - "VITS Configuration": "Configuración de VITS", - "VQGAN Configuration": "Configuración de VQGAN", - "Validation Batch Size": "Tamaño del Lote de Validación", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", - "WebUI Host": "Host de WebUI", - "WebUI Port": "Puerto de WebUI", - "Whisper Model": "Modelo Whisper", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", - "latest": "más reciente", - "new": "nuevo", - "Realtime Transform Text": "Transformación de Texto en Tiempo Real", - "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", - "Text Normalization": "Normalización de Texto" -} +{ + "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular lotes de gradientes", + "Add to Processing Area": "Agregar al Área de Procesamiento", + "Added path successfully!": "¡Ruta agregada exitosamente!", + "Advanced Config": "Configuración Avanzada", + "Base LLAMA Model": "Modelo Base LLAMA", + "Batch Inference": "Inferencia por Lote", + "Batch Size": "Tamaño del Lote", + "Changing with the Model Path": "Cambiando con la Ruta del Modelo", + "Chinese": "Chino", + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío", + "Copy": "Copiar", + "Data Preprocessing": "Preprocesamiento de Datos", + "Data 
Preprocessing Path": "Ruta de Preprocesamiento de Datos", + "Data Source": "Fuente de Datos", + "Decoder Model Config": "Configuración del modelo decodificador", + "Decoder Model Path": "Ruta del modelo decodificador", + "Disabled": "Desactivado", + "Enable Reference Audio": "Habilitar Audio de Referencia", + "English": "Inglés", + "Error Message": "Mensaje de Error", + "File Preprocessing": "Preprocesamiento de Archivos", + "Generate": "Generar", + "Generated Audio": "Audio Generado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", + "Infer interface is closed": "La interfaz de inferencia está cerrada", + "Inference Configuration": "Configuración de Inferencia", + "Inference Server Configuration": "Configuración del Servidor de Inferencia", + "Inference Server Error": "Error del Servidor de Inferencia", + "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", + "Initial Learning Rate": "Tasa de Aprendizaje Inicial", + "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Ruta inválida: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", + "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", + "Japanese": "Japonés", + "LLAMA Configuration": "Configuración de LLAMA", + "LLAMA Model Config": "Configuración del Modelo LLAMA", + "LLAMA Model Path": "Ruta del Modelo LLAMA", + "Labeling Device": "Dispositivo de Etiquetado", + "LoRA Model to be merged": "Modelo LoRA a fusionar", + "Maximum Audio Duration": "Duración máxima de audio", + "Maximum Length per Sample": "Longitud Máxima por Muestra", + "Maximum Training Steps": "Pasos Máximos de Entrenamiento", + "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", + "Merge": "Fusionar", + "Merge LoRA": "Fusionar LoRA", + "Merge successfully": "Fusionado exitosamente", + "Minimum Audio Duration": "Duración mínima de audio", + "Model Output Path": "Ruta de Salida del Modelo", + "Model Size": "Tamaño del Modelo", + "Move": "Mover", + "Move files successfully": "Archivos movidos exitosamente", + "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", + "No selected options": "No hay opciones seleccionadas", + "Number of Workers": "Número de Trabajadores", + "Open Inference Server": "Abrir Servidor de Inferencia", + "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "Se abrió el etiquetador en el navegador", + "Optional Label Language": "Idioma de Etiquetado Opcional", + "Optional online ver": "Ver en línea opcional", + "Output Path": "Ruta de Salida", + "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", + "Precision": "Precisión", + "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", + "Put your text here.": "Ponga su texto aquí.", + "Reference Audio": "Audio de Referencia", + "Reference Text": "Texto de Referencia", + "Related code and weights are released 
under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", + "Remove Selected Data": "Eliminar Datos Seleccionados", + "Removed path successfully!": "¡Ruta eliminada exitosamente!", + "Repetition Penalty": "Penalización por Repetición", + "Save model every n steps": "Guardar modelo cada n pasos", + "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", + "Select VITS ckpt": "Seleccionar punto de control VITS", + "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", + "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", + "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", + "Selected: {}": "Seleccionado: {}", + "Speaker": "Hablante", + "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", + "Start Training": "Iniciar Entrenamiento", + "Streaming Audio": "transmisión de audio", + "Streaming Generate": "síntesis en flujo", + "Tensorboard Host": "Host de Tensorboard", + "Tensorboard Log Path": "Ruta de Registro de Tensorboard", + "Tensorboard Port": "Puerto de Tensorboard", + "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", + "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", + "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", + "Training Configuration": "Configuración de Entrenamiento", + "Training Error": "Error de Entrenamiento", + "Training stopped": "Entrenamiento detenido", + "Type name of the speaker": "Escriba el nombre del hablante", + "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", + "Use filelist": "Usar lista de archivos", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", + "VITS Configuration": "Configuración de VITS", + "VQGAN Configuration": "Configuración de VQGAN", + "Validation Batch Size": "Tamaño del Lote de Validación", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", + "WebUI Host": "Host de WebUI", + "WebUI Port": "Puerto de WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", + "latest": "más reciente", + "new": "nuevo", + "Realtime Transform Text": "Transformación de Texto en Tiempo Real", + "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", + "Text Normalization": "Normalización de Texto", + "Select Example Audio": "Selecionar áudio de exemplo" +} diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json index d30bac7bcdf4f4c65b1f78b4dcf9d705c1d8eb39..9d0baeb73ffd9cc1af7570ef0ac7e6018ce9527b 100644 --- a/fish_speech/i18n/locale/ja_JP.json +++ b/fish_speech/i18n/locale/ja_JP.json @@ -1,123 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", - "Accumulate Gradient Batches": "勾配バッチの累積", - "Add to Processing Area": "処理エリアに追加", - "Added path successfully!": "パスの追加に成功しました!", - "Advanced Config": "詳細設定", - "Base LLAMA Model": "基本LLAMAモデル", - "Batch Inference": "バッチ推論", - "Batch Size": "バッチサイズ", - "Changing with the Model Path": "モデルのパスに伴って変化する", - "Chinese": "中国語", - "Compile Model": "モデルのコンパイル", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", - "Copy": "コピー", - "Data Preprocessing": "データ前処理", - 
"Data Preprocessing Path": "データ前処理パス", - "Data Source": "データソース", - "Decoder Model Config": "デコーダーモデルの構成", - "Decoder Model Path": "デコーダーモデルのパス", - "Disabled": "無効", - "Enable Reference Audio": "リファレンスオーディオを有効にする", - "English": "英語", - "Error Message": "エラーメッセージ", - "File Preprocessing": "文書前处理", - "Generate": "生成", - "Generated Audio": "生成されたオーディオ", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", - "Infer interface is closed": "推論インターフェースが閉じられています", - "Inference Configuration": "推論設定", - "Inference Server Configuration": "推論サーバー設定", - "Inference Server Error": "推論サーバーエラー", - "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", - "Initial Learning Rate": "初期学習率", - "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", - "Input Text": "入力テキスト", - "Invalid path: {}": "無効なパス: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", - "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", - "Japanese": "日本語", - "LLAMA Configuration": "LLAMA設定", - "LLAMA Model Config": "LLAMAモデル設定", - "LLAMA Model Path": "LLAMAモデルパス", - "Labeling Device": "ラベリングデバイス", - "LoRA Model to be merged": "マージするLoRAモデル", - "Maximum Audio Duration": "最大オーディオの長さ", - "Maximum Length per Sample": "サンプルあたりの最大長", - "Maximum Training Steps": "最大トレーニングステップ数", - "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", - "Merge": "マージ", - "Merge LoRA": "LoRAのマージ", - "Merge successfully": "マージに成功しました", - "Minimum Audio Duration": "最小オーディオの長さ", - "Model Output Path": "モデル出力パス", - "Model Size": "モデルサイズ", - "Move": "移動", - "Move files successfully": "ファイルの移動に成功しました", - "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", - "No selected options": "選択されたオプションはありません", - "Number of Workers": "ワーカー数", - "Open Inference Server": "推論サーバーを開く", - "Open Labeler WebUI": "ラベラーWebUIを開く", - "Open Tensorboard": "Tensorboardを開く", - "Opened labeler in browser": "ブラウザでラベラーを開きました", - "Optional Label Language": "オプションのラベル言語", - "Optional online ver": "オプションのオンラインバージョン", - "Output Path": "出力パス", - "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", - "Precision": "精度", - "Probability of applying Speaker Condition": "話者条件を適用する確率", - "Put your text here.": "ここにテキストを入力してください。", - "Reference Audio": "リファレンスオーディオ", - "Reference Text": "リファレンステキスト", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", - "Remove Selected Data": "選択したデータを削除", - "Removed path successfully!": "パスの削除に成功しました!", - "Repetition Penalty": "反復ペナルティ", - "Save model every n steps": "nステップごとにモデルを保存", - "Select LLAMA ckpt": " LLAMA チェックポイントを選択", - "Select VITS ckpt": "VITS チェックポイントを選択", - "Select VQGAN ckpt": "VQGAN チェックポイントを選択", - "Select source file processing method": "ソースファイルの処理方法を選択", - "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", - "Selected: {}": "選択済み: {}", - "Speaker": "話者", - "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", - "Start Training": "トレーニング開始", - "Streaming Audio": "ストリーミングオーディオ", - "Streaming Generate": "ストリーミング合成", - "Tensorboard Host": "Tensorboardホスト", - "Tensorboard Log Path": "Tensorboardログパス", - "Tensorboard Port": "Tensorboardポート", - "Tensorboard interface is closed": 
"Tensorboardインターフェースが閉じられています", - "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", - "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", - "Training Configuration": "トレーニング設定", - "Training Error": "トレーニングエラー", - "Training stopped": "トレーニングが停止しました", - "Type name of the speaker": "話者の名前を入力", - "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", - "Use LoRA": "LoRAを使用", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", - "Use filelist": "ファイルリストを使用", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", - "VITS Configuration": "VITS の構成", - "VQGAN Configuration": "VQGAN の構成", - "Validation Batch Size": "検証バッチサイズ", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", - "WebUI Host": "WebUIホスト", - "WebUI Port": "WebUIポート", - "Whisper Model": "Whisperモデル", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", - "latest": "最新", - "new": "新規", - "Realtime Transform Text": "リアルタイム変換テキスト", - "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", - "Text Normalization": "テキスト正規化" - -} +{ + "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", + "Accumulate Gradient Batches": "勾配バッチの累積", + "Add to Processing Area": "処理エリアに追加", + "Added path successfully!": "パスの追加に成功しました!", + "Advanced Config": "詳細設定", + "Base LLAMA Model": "基本LLAMAモデル", + "Batch Inference": "バッチ推論", + "Batch Size": "バッチサイズ", + "Changing with the Model Path": "モデルのパスに伴って変化する", + "Chinese": "中国語", + "Compile Model": "モデルのコンパイル", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", + "Copy": "コピー", + "Data Preprocessing": "データ前処理", + "Data Preprocessing Path": "データ前処理パス", + "Data Source": "データソース", + "Decoder Model Config": "デコーダーモデルの構成", + "Decoder Model Path": "デコーダーモデルのパス", + "Disabled": "無効", + "Enable Reference Audio": "リファレンスオーディオを有効にする", + "English": "英語", + "Error Message": "エラーメッセージ", + "File Preprocessing": "文書前处理", + "Generate": "生成", + "Generated Audio": "生成されたオーディオ", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", + "Infer interface is closed": "推論インターフェースが閉じられています", + 
"Inference Configuration": "推論設定", + "Inference Server Configuration": "推論サーバー設定", + "Inference Server Error": "推論サーバーエラー", + "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", + "Initial Learning Rate": "初期学習率", + "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", + "Input Text": "入力テキスト", + "Invalid path: {}": "無効なパス: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", + "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", + "Japanese": "日本語", + "LLAMA Configuration": "LLAMA設定", + "LLAMA Model Config": "LLAMAモデル設定", + "LLAMA Model Path": "LLAMAモデルパス", + "Labeling Device": "ラベリングデバイス", + "LoRA Model to be merged": "マージするLoRAモデル", + "Maximum Audio Duration": "最大オーディオの長さ", + "Maximum Length per Sample": "サンプルあたりの最大長", + "Maximum Training Steps": "最大トレーニングステップ数", + "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", + "Merge": "マージ", + "Merge LoRA": "LoRAのマージ", + "Merge successfully": "マージに成功しました", + "Minimum Audio Duration": "最小オーディオの長さ", + "Model Output Path": "モデル出力パス", + "Model Size": "モデルサイズ", + "Move": "移動", + "Move files successfully": "ファイルの移動に成功しました", + "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", + "No selected options": "選択されたオプションはありません", + "Number of Workers": "ワーカー数", + "Open Inference Server": "推論サーバーを開く", + "Open Labeler WebUI": "ラベラーWebUIを開く", + "Open Tensorboard": "Tensorboardを開く", + "Opened labeler in browser": "ブラウザでラベラーを開きました", + "Optional Label Language": "オプションのラベル言語", + "Optional online ver": "オプションのオンラインバージョン", + "Output Path": "出力パス", + "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", + "Precision": "精度", + "Probability of applying Speaker Condition": "話者条件を適用する確率", + "Put your text here.": "ここにテキストを入力してください。", + "Reference Audio": "リファレンスオーディオ", + "Reference Text": "リファレンステキスト", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", + "Remove Selected Data": "選択したデータを削除", + "Removed path successfully!": "パスの削除に成功しました!", + "Repetition Penalty": "反復ペナルティ", + "Save model every n steps": "nステップごとにモデルを保存", + "Select LLAMA ckpt": " LLAMA チェックポイントを選択", + "Select VITS ckpt": "VITS チェックポイントを選択", + "Select VQGAN ckpt": "VQGAN チェックポイントを選択", + "Select source file processing method": "ソースファイルの処理方法を選択", + "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", + "Selected: {}": "選択済み: {}", + "Speaker": "話者", + "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", + "Start Training": "トレーニング開始", + "Streaming Audio": "ストリーミングオーディオ", + "Streaming Generate": "ストリーミング合成", + "Tensorboard Host": "Tensorboardホスト", + "Tensorboard Log Path": "Tensorboardログパス", + "Tensorboard Port": "Tensorboardポート", + "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", + "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", + "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", + "Training Configuration": "トレーニング設定", + "Training Error": "トレーニングエラー", + "Training stopped": "トレーニングが停止しました", + "Type name of the speaker": "話者の名前を入力", + "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", + "Use LoRA": "LoRAを使用", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", + "Use filelist": "ファイルリストを使用", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", + "VITS Configuration": "VITS の構成", + "VQGAN Configuration": "VQGAN の構成", + "Validation Batch Size": "検証バッチサイズ", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", + "WebUI Host": "WebUIホスト", + "WebUI Port": "WebUIポート", + "Whisper Model": "Whisperモデル", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", + "latest": "最新", + "new": "新規", + "Realtime Transform Text": "リアルタイム変換テキスト", + "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", + "Text Normalization": "テキスト正規化", + "Select Example Audio": "サンプル音声を選択" +} diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json new file mode 100644 index 0000000000000000000000000000000000000000..f4bf1841b7c847993707ec3b8e32f5174de77214 --- /dev/null +++ b/fish_speech/i18n/locale/ko_KR.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.", + "Accumulate Gradient Batches": "그라디언트 배치 누적", + "Add to Processing Area": "처리 영역에 추가", + "Added path successfully!": "경로가 성공적으로 추가되었습니다!", + "Advanced Config": "고급 설정", + "Base LLAMA Model": "기본 LLAMA 모델", + "Batch Inference": "배치 추론", + "Batch Size": "배치 크기", + "Changing with the Model Path": "모델 경로에 따라 변경 중", + "Chinese": "중국어", + "Compile Model": "모델 컴파일", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.", + "Copy": "복사", + "Data Preprocessing": "데이터 전처리", + "Data Preprocessing Path": "데이터 전처리 경로", + "Data Source": "데이터 소스", + "Decoder Model Config": "디코더 모델 설정", + "Decoder Model Path": "디코더 모델 경로", + "Disabled": "비활성화 됨", + "Enable Reference Audio": "참고 음성 활성화", + "English": "영어", + "Error Message": "오류 메시지", + "File Preprocessing": "파일 전처리", + "Generate": "생성", + "Generated Audio": "생성된 오디오", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.", + "Infer 
interface is closed": "추론 인터페이스가 닫혔습니다.", + "Inference Configuration": "추론 설정", + "Inference Server Configuration": "추론 서버 설정", + "Inference Server Error": "추론 서버 오류", + "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.", + "Initial Learning Rate": "초기 학습률", + "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로", + "Input Text": "입력 텍스트", + "Invalid path: {}": "유효하지 않은 경로: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.", + "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)", + "Japanese": "일본어", + "LLAMA Configuration": "LLAMA 설정", + "LLAMA Model Config": "LLAMA 모델 설정", + "LLAMA Model Path": "LLAMA 모델 경로", + "Labeling Device": "라벨링 장치", + "LoRA Model to be merged": "병합할 LoRA 모델", + "Maximum Audio Duration": "최대 오디오 길이", + "Maximum Length per Sample": "샘플당 최대 길이", + "Maximum Training Steps": "최대 학습 단계", + "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)", + "Merge": "병합", + "Merge LoRA": "LoRA 병합", + "Merge successfully": "성공적으로 병합 되었습니다.", + "Minimum Audio Duration": "최소 오디오 길이", + "Model Output Path": "모델 출력 경로", + "Model Size": "모델 크기", + "Move": "이동", + "Move files successfully": "파일이 성공적으로 이동되었습니다.", + "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.", + "No selected options": "옵션이 선택되지 않았습니다.", + "Number of Workers": "작업자 수", + "Open Inference Server": "추론 서버 열기", + "Open Labeler WebUI": "라벨러 WebUI 열기", + "Open Tensorboard": "Tensorboard 열기", + "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.", + "Optional Label Language": "선택적 라벨 언어", + "Optional online ver": "온라인 버전 선택", + "Output Path": "출력 경로", + "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.", + "Precision": "정밀도", + "Probability of applying Speaker Condition": "화자 조건 적용 확률", + "Put your text here.": "여기에 텍스트를 입력하세요.", + "Reference Audio": "참고 오디오", + "Reference Text": "참고 텍스트", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.", + "Remove Selected Data": "선택한 데이터 제거", + "Removed path successfully!": "경로가 성공적으로 제거되었습니다!", + "Repetition Penalty": "반복 패널티", + "Save model every n steps": "n 단계마다 모델 저장", + "Select LLAMA ckpt": "LLAMA ckpt 선택", + "Select VITS ckpt": "VITS ckpt 선택", + "Select VQGAN ckpt": "VQGAN ckpt 선택", + "Select source file processing method": "소스 파일 처리 방법 선택", + "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)", + "Selected: {}": "선택됨: {}", + "Speaker": "화자", + "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다", + "Start Training": "학습 시작", + "Streaming Audio": "스트리밍 오디오", + "Streaming Generate": "스트리밍 생성", + "Tensorboard Host": "Tensorboard 호스트", + "Tensorboard Log Path": "Tensorboard 로그 경로", + "Tensorboard Port": "Tensorboard 포트", + "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다", + "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.", + "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 
체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.", + "Training Configuration": "학습 설정", + "Training Error": "학습 오류", + "Training stopped": "학습이 중지되었습니다.", + "Type name of the speaker": "화자의 이름을 입력하세요.", + "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.", + "Use LoRA": "LoRA 사용", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.", + "Use filelist": "파일 목록 사용", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.", + "VITS Configuration": "VITS 설정", + "VQGAN Configuration": "VQGAN 설정", + "Validation Batch Size": "검증 배치 크기", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.", + "WebUI Host": "WebUI 호스트", + "WebUI Port": "WebUI 포트", + "Whisper Model": "Whisper 모델", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다", + "latest": "최신", + "new": "새로운", + "Realtime Transform Text": "실시간 텍스트 변환", + "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)", + "Text Normalization": "텍스트 정규화", + "Select Example Audio": "예시 오디오 선택" +} diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json index 385f20272e19053ab9b6cf6463a84c8ece768c68..a5278e29fac737d9fd3da3f8e82e49ac22a96ac3 100644 --- a/fish_speech/i18n/locale/pt_BR.json +++ b/fish_speech/i18n/locale/pt_BR.json @@ -1,133 +1,133 @@ -{ - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", - "Add to Processing Area": "Adicionar à Área de Processamento", - "Added path successfully!": "Caminho adicionado com sucesso!", - "Advanced Config": "Configuração Avançada", - "Base LLAMA Model": "Modelo LLAMA Base", - "Batch Inference": "Inferência em Lote", - "Batch Size": "Tamanho do Lote", - "Changing with the Model Path": "Alterando com o Caminho do Modelo", - - "Compile Model": "Compilar Modelo", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", - "Copy": "Copiar", - "Data Preprocessing": "Pré-processamento de Dados", - "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", - "Data Source": "Fonte de Dados", - "Decoder Model Config": "Configuração do Modelo Decodificador", - "Decoder Model Path": "Caminho do Modelo Decodificador", - "Disabled": "Desativado", - "Enable Initial Prompt": "Habilitar Prompt Inicial", - "Enable Reference Audio": "Habilitar Áudio de Referência", - "English": 
"Inglês", - "Japanese": "Japonês", - "Chinese": "Chinês", - "Portuguese": "Português", - "Spanish": "Espanhol", - "Error Message": "Mensagem de Erro", - "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", - "File Preprocessing": "Pré-processamento de Arquivos", - "Generate": "Gerar", - "Generated Audio": "Áudio Gerado", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", - "Infer interface is closed": "A interface de inferência foi fechada", - "Inference Configuration": "Configuração de Inferência", - "Inference Server Configuration": "Configuração do Servidor de Inferência", - "Inference Server Error": "Erro do Servidor de Inferência", - "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}", - "Initial Learning Rate": "Taxa de Aprendizagem Inicial", - "Initial Prompt": "Prompt Inicial", - "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", - "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", - "Input Text": "Texto de Entrada", - "Invalid path: {}": "Caminho inválido: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", - "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", - "LLAMA Configuration": "Configuração do LLAMA", - "LLAMA Model Config": "Configuração do Modelo LLAMA", - "LLAMA Model Path": "Caminho do Modelo LLAMA", - "Labeling Device": "Dispositivo de Rotulagem", - "LoRA Model to be merged": "Modelo LoRA para mesclagem", - "Maximum Length per Sample": "Comprimento Máximo por Amostra", - "Maximum Training Steps": "Etapas Máximas de Treinamento", - "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", - "Merge": "Mesclar", - "Merge LoRA": "Mesclar LoRA", - "Merge successfully": "Mesclado com sucesso", - "Model Output Path": "Caminho de Saída do Modelo", - "Model Quantization": "Quantização do Modelo", - "Model Size": "Tamanho do Modelo", - "Move": "Mover", - "Move files successfully": "Arquivos movidos com sucesso", - "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", - "No selected options": "Nenhuma opção selecionada", - "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", - "Number of Workers": "Número de Processos", - "Open Inference Server": "Abrir Servidor de Inferência", - "Open Labeler WebUI": "Abrir WebUI de Rotulagem", - "Open Tensorboard": "Abrir Tensorboard", - "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", - "Optional Label Language": "Idioma do Rótulo (Opcional)", - "Optional online ver": "Versão online (opcional)", - "Output Path": "Caminho de Saída", - "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", - "Post-quantification Precision": "Precisão Pós-quantização", - "Precision": "Precisão", - "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", - "Put your text 
here.": "Insira seu texto aqui.", - "Quantify": "Quantizar", - "Quantify successfully": "Quantizado com sucesso", - "Realtime Transform Text": "Transformar Texto em Tempo Real", - "Reference Audio": "Áudio de Referência", - "Reference Text": "Texto de Referência", - "warning": "Aviso", - "Pre-processing begins...": "O pré-processamento começou!", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", - "Remove Selected Data": "Remover Dados Selecionados", - "Removed path successfully!": "Caminho removido com sucesso!", - "Repetition Penalty": "Penalidade de Repetição", - "Save model every n steps": "Salvar modelo a cada n etapas", - "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", - "Select source file processing method": "Escolha como processar o arquivo de origem", - "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", - "Selected: {}": "Selecionado: {}", - "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", - "Start Training": "Iniciar Treinamento", - "Streaming Audio": "Áudio em Streaming", - "Streaming Generate": "Geração em Streaming", - "Tensorboard Host": "Host do Tensorboard", - "Tensorboard Log Path": "Caminho de Log do Tensorboard", - "Tensorboard Port": "Porta do Tensorboard", - "Tensorboard interface is closed": "A interface do Tensorboard está fechada", - "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", - "Text Normalization": "Normalização de Texto", - "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", - "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", - "Training Configuration": "Configuração de Treinamento", - "Training Error": "Erro de Treinamento", - "Training stopped": "Treinamento interrompido!", - "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", - "Use LoRA": "Usar LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", - "Use filelist": "Usar lista de arquivos", - "VQGAN Configuration": "Configuração do VQGAN", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", - "WebUI Host": "Host da WebUI", - "WebUI Port": "Porta da WebUI", - "Whisper Model": "Modelo Whisper", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", - "auto": "automático", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", - "latest": "mais recente", - "new": "novo", - "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", - "You don't need to train this model!": "Não é necessário treinar este modelo!", - "Yes": "Sim", - "No": "Não", - "version:": "versão:", - "author:": "autor:" -} +{ + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", + "Add to Processing Area": "Adicionar à Área de Processamento", + "Added path successfully!": "Caminho adicionado com sucesso!", + "Advanced Config": "Configuração Avançada", + "Base LLAMA Model": "Modelo LLAMA Base", + "Batch Inference": "Inferência em Lote", + "Batch Size": "Tamanho do Lote", + "Changing with the Model Path": "Alterando com o Caminho do Modelo", + + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", + "Copy": "Copiar", + "Data Preprocessing": "Pré-processamento de Dados", + "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", + "Data Source": "Fonte de Dados", + "Decoder Model Config": "Configuração do Modelo Decodificador", + "Decoder Model Path": "Caminho do Modelo Decodificador", + "Disabled": "Desativado", + "Enable Initial Prompt": "Habilitar Prompt Inicial", + "Enable Reference Audio": "Habilitar Áudio de Referência", + "English": "Inglês", + "Japanese": "Japonês", + "Chinese": "Chinês", + "Portuguese": "Português", + "Spanish": "Espanhol", + "Error Message": "Mensagem de Erro", + "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", + "File Preprocessing": "Pré-processamento de Arquivos", + "Generate": "Gerar", + "Generated Audio": "Áudio Gerado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", + "Infer interface is closed": "A interface de inferência foi fechada", + "Inference Configuration": "Configuração de Inferência", + "Inference Server Configuration": "Configuração do Servidor de Inferência", + "Inference Server Error": "Erro do Servidor de Inferência", + "Inferring interface is launched at {}": "A interface de inferência foi iniciada em 
{}", + "Initial Learning Rate": "Taxa de Aprendizagem Inicial", + "Initial Prompt": "Prompt Inicial", + "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", + "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Caminho inválido: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", + "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", + "LLAMA Configuration": "Configuração do LLAMA", + "LLAMA Model Config": "Configuração do Modelo LLAMA", + "LLAMA Model Path": "Caminho do Modelo LLAMA", + "Labeling Device": "Dispositivo de Rotulagem", + "LoRA Model to be merged": "Modelo LoRA para mesclagem", + "Maximum Length per Sample": "Comprimento Máximo por Amostra", + "Maximum Training Steps": "Etapas Máximas de Treinamento", + "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", + "Merge": "Mesclar", + "Merge LoRA": "Mesclar LoRA", + "Merge successfully": "Mesclado com sucesso", + "Model Output Path": "Caminho de Saída do Modelo", + "Model Quantization": "Quantização do Modelo", + "Model Size": "Tamanho do Modelo", + "Move": "Mover", + "Move files successfully": "Arquivos movidos com sucesso", + "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", + "No selected options": "Nenhuma opção selecionada", + "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", + "Number of Workers": "Número de Processos", + "Open Inference Server": "Abrir Servidor de Inferência", + "Open Labeler WebUI": "Abrir WebUI de Rotulagem", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", + "Optional Label Language": "Idioma do Rótulo (Opcional)", + "Optional online ver": "Versão online (opcional)", + "Output Path": "Caminho de Saída", + "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", + "Post-quantification Precision": "Precisão Pós-quantização", + "Precision": "Precisão", + "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", + "Put your text here.": "Insira seu texto aqui.", + "Quantify": "Quantizar", + "Quantify successfully": "Quantizado com sucesso", + "Realtime Transform Text": "Transformar Texto em Tempo Real", + "Reference Audio": "Áudio de Referência", + "Reference Text": "Texto de Referência", + "warning": "Aviso", + "Pre-processing begins...": "O pré-processamento começou!", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", + "Remove Selected Data": "Remover Dados Selecionados", + "Removed path successfully!": "Caminho removido com sucesso!", + "Repetition Penalty": "Penalidade de Repetição", + "Save model every n steps": "Salvar modelo a cada n etapas", + "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", + "Select source file processing method": "Escolha como processar o arquivo de origem", + "Select the model to be trained 
(Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", + "Selected: {}": "Selecionado: {}", + "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", + "Start Training": "Iniciar Treinamento", + "Streaming Audio": "Áudio em Streaming", + "Streaming Generate": "Geração em Streaming", + "Tensorboard Host": "Host do Tensorboard", + "Tensorboard Log Path": "Caminho de Log do Tensorboard", + "Tensorboard Port": "Porta do Tensorboard", + "Tensorboard interface is closed": "A interface do Tensorboard está fechada", + "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", + "Text Normalization": "Normalização de Texto", + "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", + "Training Configuration": "Configuração de Treinamento", + "Training Error": "Erro de Treinamento", + "Training stopped": "Treinamento interrompido!", + "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", + "Use filelist": "Usar lista de arquivos", + "VQGAN Configuration": "Configuração do VQGAN", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", + "WebUI Host": "Host da WebUI", + "WebUI Port": "Porta da WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", + "auto": "automático", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", + "latest": "mais recente", + "new": "novo", + "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", + "You don't need to train this model!": "Não é necessário treinar este modelo!", + "Yes": "Sim", + "No": "Não", + "version:": "versão:", + "author:": "autor:" +} diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json index 3dd1a5cd1ccf3860ca508238cc64a68ca4fc3276..df7cd5477bfa035ea66ae7322dad46f5d054d9b0 100644 --- a/fish_speech/i18n/locale/zh_CN.json +++ b/fish_speech/i18n/locale/zh_CN.json @@ -1,122 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", - "Accumulate Gradient Batches": "梯度累积批次", - "Add to Processing Area": "加入处理区", - "Added path successfully!": "添加路径成功!", - "Advanced Config": "高级参数", - "Base LLAMA Model": "基础 LLAMA 模型", - "Batch Inference": "批量推理", - "Batch Size": "批次大小", - "Changing with the Model Path": "随模型路径变化", - "Chinese": "中文", - "Compile Model": "编译模型", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", - "Copy": "复制", - "Data Preprocessing": "数据预处理", - "Data Preprocessing Path": "数据预处理路径", - "Data Source": "数据源", - "Decoder Model Config": "解码器模型配置", - "Decoder Model Path": "解码器模型路径", - "Disabled": "禁用", - "Enable Reference Audio": "启用参考音频", - "English": "英文", - "Error Message": "错误信息", - "File Preprocessing": "文件预处理", - "Generate": "生成", - "Generated Audio": "音频", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", - "Infer interface is closed": "推理界面已关闭", - "Inference Configuration": "推理配置", - "Inference Server Configuration": "推理服务器配置", - "Inference Server Error": "推理服务器错误", - "Inferring interface is launched at {}": "推理界面已在 {} 上启动", - "Initial Learning Rate": "初始学习率", - "Input Audio & Source Path for Transcription": "输入音频和转录源路径", - "Input Text": "输入文本", - "Invalid path: {}": "无效路径: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", - "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", - "Japanese": "日文", - "LLAMA Configuration": "LLAMA 配置", - "LLAMA Model Config": "LLAMA 模型配置", - "LLAMA Model Path": "LLAMA 模型路径", - "Labeling Device": "标注加速设备", - "LoRA Model to be merged": "要合并的 LoRA 模型", - "Maximum Audio Duration": "最大音频时长", - "Maximum Length per 
Sample": "每个样本的最大长度", - "Maximum Training Steps": "最大训练步数", - "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", - "Merge": "合并", - "Merge LoRA": "合并 LoRA", - "Merge successfully": "合并成功", - "Minimum Audio Duration": "最小音频时长", - "Model Output Path": "模型输出路径", - "Model Size": "模型规模", - "Move": "移动", - "Move files successfully": "移动文件成功", - "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", - "No selected options": "没有选择的选项", - "Number of Workers": "数据加载进程数", - "Open Inference Server": "打开推理服务器", - "Open Labeler WebUI": "打开标注工具", - "Open Tensorboard": "打开 Tensorboard", - "Opened labeler in browser": "在浏览器中打开标注工具", - "Optional Label Language": "[可选] 标注语言", - "Optional online ver": "[可选] 使用在线版", - "Output Path": "输出路径", - "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", - "Precision": "精度", - "Probability of applying Speaker Condition": "应用说话人条件的概率", - "Put your text here.": "在此处输入文本.", - "Reference Audio": "参考音频", - "Reference Text": "参考文本", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", - "Remove Selected Data": "移除选中数据", - "Removed path successfully!": "移除路径成功!", - "Repetition Penalty": "重复惩罚", - "Save model every n steps": "每 n 步保存模型", - "Select LLAMA ckpt": "选择 LLAMA 检查点", - "Select VITS ckpt": "选择 VITS 检查点", - "Select VQGAN ckpt": "选择 VQGAN 检查点", - "Select source file processing method": "选择源文件处理方法", - "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", - "Selected: {}": "已选择: {}", - "Speaker": "说话人", - "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", - "Start Training": "开始训练", - "Streaming Audio": "流式音频", - "Streaming Generate": "流式合成", - "Tensorboard Host": "Tensorboard 监听地址", - "Tensorboard Log Path": "Tensorboard 日志路径", - "Tensorboard Port": "Tensorboard 端口", - "Tensorboard interface is closed": "Tensorboard 界面已关闭", - "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", - "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", - "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", - "Training Configuration": "训练配置", - "Training Error": "训练错误", - "Training stopped": "训练已停止", - "Type name of the speaker": "输入说话人的名称", - "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", - "Use LoRA": "使用 LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", - "Use filelist": "使用文件列表", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", - "VITS Configuration": "VITS 配置", - "VQGAN Configuration": "VQGAN 配置", - "Validation Batch Size": "验证批次大小", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", - "WebUI Host": "WebUI 监听地址", - "WebUI Port": "WebUI 端口", - "Whisper Model": "Whisper 模型", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", - "latest": "最近的检查点", - "new": "创建新的检查点", - "Realtime Transform Text": "实时规范化文本", - "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", - "Text Normalization": "文本规范化" -} +{ + "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", + "Accumulate Gradient Batches": "梯度累积批次", + "Add to Processing Area": "加入处理区", + "Added path successfully!": "添加路径成功!", + "Advanced Config": "高级参数", + "Base LLAMA Model": "基础 LLAMA 模型", + "Batch Inference": "批量推理", + "Batch Size": "批次大小", + "Changing with the Model Path": "随模型路径变化", + "Chinese": "中文", + "Compile Model": "编译模型", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", + "Copy": "复制", + "Data Preprocessing": "数据预处理", + "Data Preprocessing Path": "数据预处理路径", + "Data Source": "数据源", + "Decoder Model Config": "解码器模型配置", + "Decoder Model Path": "解码器模型路径", + "Disabled": "禁用", + "Enable Reference Audio": "启用参考音频", + "English": "英文", + "Error Message": "错误信息", + "File Preprocessing": "文件预处理", + "Generate": "生成", + "Generated Audio": "音频", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", + "Infer interface is closed": "推理界面已关闭", + "Inference Configuration": "推理配置", + "Inference Server Configuration": "推理服务器配置", + "Inference Server Error": "推理服务器错误", + "Inferring interface is launched at {}": "推理界面已在 {} 上启动", + "Initial Learning Rate": "初始学习率", + "Input Audio & Source Path for Transcription": "输入音频和转录源路径", + "Input Text": "输入文本", + "Invalid path: {}": "无效路径: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", + "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", + 
"Japanese": "日文", + "LLAMA Configuration": "LLAMA 配置", + "LLAMA Model Config": "LLAMA 模型配置", + "LLAMA Model Path": "LLAMA 模型路径", + "Labeling Device": "标注加速设备", + "LoRA Model to be merged": "要合并的 LoRA 模型", + "Maximum Audio Duration": "最大音频时长", + "Maximum Length per Sample": "每个样本的最大长度", + "Maximum Training Steps": "最大训练步数", + "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", + "Merge": "合并", + "Merge LoRA": "合并 LoRA", + "Merge successfully": "合并成功", + "Minimum Audio Duration": "最小音频时长", + "Model Output Path": "模型输出路径", + "Model Size": "模型规模", + "Move": "移动", + "Move files successfully": "移动文件成功", + "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", + "No selected options": "没有选择的选项", + "Number of Workers": "数据加载进程数", + "Open Inference Server": "打开推理服务器", + "Open Labeler WebUI": "打开标注工具", + "Open Tensorboard": "打开 Tensorboard", + "Opened labeler in browser": "在浏览器中打开标注工具", + "Optional Label Language": "[可选] 标注语言", + "Optional online ver": "[可选] 使用在线版", + "Output Path": "输出路径", + "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", + "Precision": "精度", + "Probability of applying Speaker Condition": "应用说话人条件的概率", + "Put your text here.": "在此处输入文本.", + "Reference Audio": "参考音频", + "Reference Text": "参考文本", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", + "Remove Selected Data": "移除选中数据", + "Removed path successfully!": "移除路径成功!", + "Repetition Penalty": "重复惩罚", + "Save model every n steps": "每 n 步保存模型", + "Select LLAMA ckpt": "选择 LLAMA 检查点", + "Select VITS ckpt": "选择 VITS 检查点", + "Select VQGAN ckpt": "选择 VQGAN 检查点", + "Select source file processing method": "选择源文件处理方法", + "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", + "Selected: {}": "已选择: {}", + "Speaker": "说话人", + "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", + "Start Training": "开始训练", + "Streaming Audio": "流式音频", + "Streaming Generate": "流式合成", + "Tensorboard Host": "Tensorboard 监听地址", + "Tensorboard Log Path": "Tensorboard 日志路径", + "Tensorboard Port": "Tensorboard 端口", + "Tensorboard interface is closed": "Tensorboard 界面已关闭", + "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", + "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", + "Training Configuration": "训练配置", + "Training Error": "训练错误", + "Training stopped": "训练已停止", + "Type name of the speaker": "输入说话人的名称", + "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", + "Use LoRA": "使用 LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", + "Use filelist": "使用文件列表", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", + "VITS Configuration": "VITS 配置", + "VQGAN Configuration": "VQGAN 配置", + "Validation Batch Size": "验证批次大小", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", + "WebUI Host": "WebUI 监听地址", + "WebUI Port": "WebUI 端口", + "Whisper Model": "Whisper 模型", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", + "latest": "最近的检查点", + "new": "创建新的检查点", + "Realtime Transform Text": "实时规范化文本", + "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", + "Text Normalization": "文本规范化", + "Select Example Audio": "选择参考音频" +} diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py index d0194c0f1a31dc95309c64626d13f04751a44ba1..00a39a4d08b8a19c91a8518e90bafe6ceea2231c 100644 --- a/fish_speech/i18n/scan.py +++ b/fish_speech/i18n/scan.py @@ -1,122 +1,122 @@ -import ast -import glob -import json -from collections import OrderedDict -from pathlib import Path - -from loguru import logger - -from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH - - -def extract_i18n_strings(node): - i18n_strings = [] - - if ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Name) - and node.func.id == "i18n" - ): - for arg in node.args: - if isinstance(arg, ast.Str): - i18n_strings.append(arg.s) - - for child_node in ast.iter_child_nodes(node): - i18n_strings.extend(extract_i18n_strings(child_node)) - - return i18n_strings - - -# scan the directory for all .py files (recursively) -# for each file, parse the code into an AST -# for each AST, extract the i18n strings - -strings = [] -folders = ["fish_speech", "tools"] -# for filename in glob.iglob("**/*.py", recursive=True): -for folder in folders: - for f in Path(folder).rglob("*.py"): - code = f.read_text(encoding="utf-8") - if "i18n(" in code: - tree = ast.parse(code) - i18n_strings = extract_i18n_strings(tree) - logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") - strings.extend(i18n_strings) - -code_keys = set(strings) -logger.info(f"Total unique: {len(code_keys)}") - - -standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" -with open(standard_file, "r", encoding="utf-8") as f: - standard_data = json.load(f, object_pairs_hook=OrderedDict) -standard_keys = set(standard_data.keys()) - -# Define the standard file name -unused_keys = standard_keys - code_keys -logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") -for unused_key in unused_keys: - 
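The locale files above all key their translations by the English source string, and the scan.py changes that follow keep those tables aligned with the i18n(...) calls found in the code. The actual lookup helper is not part of this diff, so the standalone sketch below only illustrates the assumed key-to-translation lookup, including a fallback for entries that scan.py marks as untranslated; the locale path in the usage comment is an example, not taken from the patch.

import json
from pathlib import Path


def lookup(key: str, locale_file: Path) -> str:
    """Return the translation for `key`, falling back to the English source string."""
    table = json.loads(locale_file.read_text(encoding="utf-8"))
    value = table.get(key, key)
    # scan.py adds missing keys as "#!<key>"; treat them as untranslated.
    return key if value.startswith("#!") else value


# Hypothetical usage (path assumed, not taken from the diff):
# lookup("Generate", Path("fish_speech/i18n/locale/zh_CN.json"))  # -> "生成"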
logger.info(f"\t{unused_key}") - -missing_keys = code_keys - standard_keys -logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") -for missing_key in missing_keys: - logger.info(f"\t{missing_key}") - -code_keys_dict = OrderedDict() -for s in strings: - code_keys_dict[s] = s - -# write back -with open(standard_file, "w", encoding="utf-8") as f: - json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) - f.write("\n") - -logger.info(f"Updated {standard_file}") - - -# Define the standard file name -standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" - -# Find all JSON files in the directory -dir_path = I18N_FILE_PATH -languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] - -# Load the standard file -with open(standard_file, "r", encoding="utf-8") as f: - standard_data = json.load(f, object_pairs_hook=OrderedDict) - -# Loop through each language file -for lang_file in languages: - # Load the language file - with open(lang_file, "r", encoding="utf-8") as f: - lang_data = json.load(f, object_pairs_hook=OrderedDict) - - # Find the difference between the language file and the standard file - diff = set(standard_data.keys()) - set(lang_data.keys()) - - miss = set(lang_data.keys()) - set(standard_data.keys()) - - # Add any missing keys to the language file - for key in diff: - lang_data[key] = "#!" + key - logger.info(f"Added missing key: {key} to {lang_file}") - - # Del any extra keys to the language file - for key in miss: - del lang_data[key] - logger.info(f"Del extra key: {key} from {lang_file}") - - # Sort the keys of the language file to match the order of the standard file - lang_data = OrderedDict( - sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) - ) - - # Save the updated language file - with open(lang_file, "w", encoding="utf-8") as f: - json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) - f.write("\n") - - logger.info(f"Updated {lang_file}") - -logger.info("Done") +import ast +import glob +import json +from collections import OrderedDict +from pathlib import Path + +from loguru import logger + +from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH + + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + + +# scan the directory for all .py files (recursively) +# for each file, parse the code into an AST +# for each AST, extract the i18n strings + +strings = [] +folders = ["fish_speech", "tools"] +# for filename in glob.iglob("**/*.py", recursive=True): +for folder in folders: + for f in Path(folder).rglob("*.py"): + code = f.read_text(encoding="utf-8") + if "i18n(" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") + strings.extend(i18n_strings) + +code_keys = set(strings) +logger.info(f"Total unique: {len(code_keys)}") + + +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) +standard_keys = set(standard_data.keys()) + +# Define the standard file name +unused_keys = standard_keys - code_keys +logger.info(f"Found {len(unused_keys)} unused 
keys in {standard_file}") +for unused_key in unused_keys: + logger.info(f"\t{unused_key}") + +missing_keys = code_keys - standard_keys +logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") +for missing_key in missing_keys: + logger.info(f"\t{missing_key}") + +code_keys_dict = OrderedDict() +for s in strings: + code_keys_dict[s] = s + +# write back +with open(standard_file, "w", encoding="utf-8") as f: + json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + +logger.info(f"Updated {standard_file}") + + +# Define the standard file name +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" + +# Find all JSON files in the directory +dir_path = I18N_FILE_PATH +languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] + +# Load the standard file +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) + +# Loop through each language file +for lang_file in languages: + # Load the language file + with open(lang_file, "r", encoding="utf-8") as f: + lang_data = json.load(f, object_pairs_hook=OrderedDict) + + # Find the difference between the language file and the standard file + diff = set(standard_data.keys()) - set(lang_data.keys()) + + miss = set(lang_data.keys()) - set(standard_data.keys()) + + # Add any missing keys to the language file + for key in diff: + lang_data[key] = "#!" + key + logger.info(f"Added missing key: {key} to {lang_file}") + + # Del any extra keys to the language file + for key in miss: + del lang_data[key] + logger.info(f"Del extra key: {key} from {lang_file}") + + # Sort the keys of the language file to match the order of the standard file + lang_data = OrderedDict( + sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) + ) + + # Save the updated language file + with open(lang_file, "w", encoding="utf-8") as f: + json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + + logger.info(f"Updated {lang_file}") + +logger.info("Done") diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py index df970400f8a073be4c4166a697245fabdf6b09b0..7b26793532c9e6b42de189c628aac59a477d0f66 100644 --- a/fish_speech/models/text2semantic/lit_module.py +++ b/fish_speech/models/text2semantic/lit_module.py @@ -1,202 +1,202 @@ -from typing import Any, Optional - -import lightning as L -import torch -import torch.nn.functional as F -from lightning.pytorch.utilities.types import OptimizerLRScheduler - -import fish_speech.utils as utils -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID -from fish_speech.models.text2semantic.llama import NaiveTransformer - -log = utils.RankedLogger(__name__, rank_zero_only=True) - - -class TextToSemantic(L.LightningModule): - def __init__( - self, - model: NaiveTransformer, - optimizer: Any, - lr_scheduler: Any, - ): - super().__init__() - - self.model = model - self.optimizer_builder = optimizer - self.lr_scheduler_builder = lr_scheduler - - def forward(self, x): - return self.model(x) - - def on_save_checkpoint(self, checkpoint): - # Save only LoRA parameters - state_dict = checkpoint["state_dict"] - use_lora = any("lora" in name for name in state_dict.keys()) - if not use_lora: - return - - for name in list(state_dict.keys()): - if "lora" not in name: - state_dict.pop(name) - - def configure_optimizers(self) -> OptimizerLRScheduler: - # Get weight decay parameters - weight_decay_parameters, 
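A minimal in-memory sketch (not part of the patch) of the per-locale synchronisation that scan.py above performs: keys missing from a translation are added with a "#!" marker, keys that no longer appear in the code are dropped, and entries are reordered to follow the default-language file. File I/O and the final sorted JSON dump are left out; `standard` and `lang` stand for already-loaded dictionaries.

from collections import OrderedDict


def sync_locale(standard: dict, lang: dict) -> OrderedDict:
    missing = set(standard) - set(lang)  # keys the translation lacks
    extra = set(lang) - set(standard)    # keys no longer used in the code

    merged = dict(lang)
    for key in missing:
        merged[key] = "#!" + key         # mark untranslated entries
    for key in extra:
        del merged[key]

    # Keep the key order of the default-language file
    return OrderedDict((k, merged[k]) for k in standard)


# Example: an outdated table gains the newly scanned key and loses a stale one.
standard = {"Generate": "Generate", "Select Example Audio": "Select Example Audio"}
lang = {"Generate": "生成", "Old Key": "旧键"}
print(sync_locale(standard, lang))
# -> "Generate" keeps its translation; "Select Example Audio" becomes "#!Select Example Audio"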
other_parameters = [], [] - for name, param in self.named_parameters(): - if ".bias" in name or "norm.weight" in name or ".embeddings." in name: - other_parameters.append(param) - else: - weight_decay_parameters.append(param) - - optimizer = self.optimizer_builder( - [ - {"params": weight_decay_parameters}, - {"params": other_parameters, "weight_decay": 0.0}, - ] - ) - - # Print the parameters and their weight decay - for i in optimizer.param_groups: - log.info( - f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" - ) - - lr_scheduler = self.lr_scheduler_builder(optimizer) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": "step", - }, - } - - # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 - def get_batch_logps( - self, - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) - average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. - """ - assert logits.shape[:-1] == labels.shape - - labels = labels.clone() - loss_mask = labels != -100 - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == -100] = 0 - - per_token_logps = torch.gather( - logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) - ).squeeze(-1) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def _step(self, batch, batch_idx, stage: str): - is_train = stage == "train" - - if is_train: - # Key part to make lora work - # Otherwise the parameters are merged, which lead to incorrect gradients - self.model.train() - - # Do positive and negative samples in the same batch to speed up training - labels = batch["labels"] - outputs = self.model( - inp=batch["inputs"], - key_padding_mask=batch["attention_masks"], - ) - token_logits = outputs.token_logits - codebook_logits = outputs.codebook_logits - - # Generate labels - base_loss = F.cross_entropy( - token_logits.view(-1, token_logits.size(-1)), - labels[:, 0].reshape(-1), - ignore_index=-100, - ) - - codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT - semantic_loss = F.cross_entropy( - codebook_logits.view(-1, codebook_logits.size(-1)), - codebook_labels.reshape(-1), - ignore_index=-100, - ) - - loss = base_loss + semantic_loss - - self.log( - f"{stage}/loss", - loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=True, - logger=True, - sync_dist=not is_train, - ) - - self.log( - f"{stage}/base_loss", - base_loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=False, - logger=True, - sync_dist=not is_train, - ) - - self.log( - f"{stage}/semantic_loss", - semantic_loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=False, - logger=True, - sync_dist=not is_train, - ) - - # Top-5 accuracy - accuracy = 
self.get_accuracy(codebook_logits, codebook_labels) - self.log( - f"{stage}/top_5_accuracy", - accuracy, - on_step=is_train, - on_epoch=not is_train, - prog_bar=True, - logger=True, - sync_dist=not is_train, - ) - - return loss - - def get_accuracy(self, logits, labels): - mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) - if mask.sum() == 0: - return torch.tensor(0.0, device=logits.device) - - _, indices = logits.topk(5, dim=-1) - correct = indices.eq(labels.unsqueeze(-1)) - correct[~mask] = 0 - correct = correct.sum() - accuracy = correct / mask.sum() - - return accuracy - - def training_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "train") - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "val") +from typing import Any, Optional + +import lightning as L +import torch +import torch.nn.functional as F +from lightning.pytorch.utilities.types import OptimizerLRScheduler + +import fish_speech.utils as utils +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.models.text2semantic.llama import NaiveTransformer + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +class TextToSemantic(L.LightningModule): + def __init__( + self, + model: NaiveTransformer, + optimizer: Any, + lr_scheduler: Any, + ): + super().__init__() + + self.model = model + self.optimizer_builder = optimizer + self.lr_scheduler_builder = lr_scheduler + + def forward(self, x): + return self.model(x) + + def on_save_checkpoint(self, checkpoint): + # Save only LoRA parameters + state_dict = checkpoint["state_dict"] + use_lora = any("lora" in name for name in state_dict.keys()) + if not use_lora: + return + + for name in list(state_dict.keys()): + if "lora" not in name: + state_dict.pop(name) + + def configure_optimizers(self) -> OptimizerLRScheduler: + # Get weight decay parameters + weight_decay_parameters, other_parameters = [], [] + for name, param in self.named_parameters(): + if ".bias" in name or "norm.weight" in name or ".embeddings." in name: + other_parameters.append(param) + else: + weight_decay_parameters.append(param) + + optimizer = self.optimizer_builder( + [ + {"params": weight_decay_parameters}, + {"params": other_parameters, "weight_decay": 0.0}, + ] + ) + + # Print the parameters and their weight decay + for i in optimizer.param_groups: + log.info( + f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" + ) + + lr_scheduler = self.lr_scheduler_builder(optimizer) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + }, + } + + # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + assert logits.shape[:-1] == labels.shape + + labels = labels.clone() + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def _step(self, batch, batch_idx, stage: str): + is_train = stage == "train" + + if is_train: + # Key part to make lora work + # Otherwise the parameters are merged, which lead to incorrect gradients + self.model.train() + + # Do positive and negative samples in the same batch to speed up training + labels = batch["labels"] + outputs = self.model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.view(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + ) + + codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.view(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + ) + + loss = base_loss + semantic_loss + + self.log( + f"{stage}/loss", + loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/base_loss", + base_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/semantic_loss", + semantic_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + # Top-5 accuracy + accuracy = self.get_accuracy(codebook_logits, codebook_labels) + self.log( + f"{stage}/top_5_accuracy", + accuracy, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + return loss + + def get_accuracy(self, logits, labels): + mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) + if mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + _, indices = logits.topk(5, dim=-1) + correct = indices.eq(labels.unsqueeze(-1)) + correct[~mask] = 0 + correct = correct.sum() + accuracy = correct / mask.sum() + + return accuracy + + def training_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "val") diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py index 0725dfb9b78b1154753641b69c959a2faadba48c..dcd0f0fa96a3bf43768bbb8087f976068e48a8e0 100644 --- a/fish_speech/models/text2semantic/llama.py +++ b/fish_speech/models/text2semantic/llama.py @@ -1,779 +1,887 @@ -import json -import math -from collections import OrderedDict -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import torch -import torch.nn as nn -from einops import rearrange -from loguru import logger -from torch import Tensor -from torch.nn import functional as F -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.utils.checkpoint import checkpoint -from transformers import AutoTokenizer - -from fish_speech.conversation import SEMANTIC_TOKEN -from fish_speech.utils import RankedLogger - -from .lora import 
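For readers following TextToSemantic.get_batch_logps above, here is a tiny self-contained example (not part of the patch) of the masked log-probability gather. The codebook dimension is dropped for brevity, so shapes are simply (batch, seq, vocab); the numbers are arbitrary.

import torch

logits = torch.randn(2, 3, 5)             # (batch, seq, vocab)
labels = torch.tensor([[1, 4, -100],
                       [0, -100, 2]])     # -100 marks ignored positions

loss_mask = labels != -100
safe_labels = labels.clone()
safe_labels[safe_labels == -100] = 0      # dummy index, masked out below

per_token_logps = torch.gather(
    logits.log_softmax(-1), dim=-1, index=safe_labels.unsqueeze(-1)
).squeeze(-1)                             # (batch, seq)

summed = (per_token_logps * loss_mask).sum(-1)    # sum over valid tokens
averaged = summed / loss_mask.sum(-1)             # mean over valid tokens
print(summed.shape, averaged.shape)               # torch.Size([2]) for both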
LoraConfig, setup_lora - -log = RankedLogger(__name__, rank_zero_only=True) - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -@dataclass -class BaseModelArgs: - model_type: str = "base" - - vocab_size: int = 32000 - n_layer: int = 32 - n_head: int = 32 - dim: int = 4096 - intermediate_size: int = None - n_local_heads: int = -1 - head_dim: int = 64 - rope_base: float = 10000 - norm_eps: float = 1e-5 - max_seq_len: int = 2048 - dropout: float = 0.0 - tie_word_embeddings: bool = True - attention_qkv_bias: bool = False - - # Codebook configs - codebook_size: int = 160 - num_codebooks: int = 4 - - # Gradient checkpointing - use_gradient_checkpointing: bool = True - - # Initialize the model - initializer_range: float = 0.02 - - def __post_init__(self): - if self.n_local_heads == -1: - self.n_local_heads = self.n_head - if self.intermediate_size is None: - hidden_dim = 4 * self.dim - n_hidden = int(2 * hidden_dim / 3) - self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head - - @staticmethod - def from_pretrained(path: str): - path = Path(path) - - if path.is_dir(): - path = path / "config.json" - - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - - match data["model_type"]: - case "naive": - cls = NaiveModelArgs - case "dual_ar": - cls = DualARModelArgs - case _: - raise ValueError(f"Unknown model type: {data['model_type']}") - - return cls(**data) - - def save(self, path: str): - with open(path, "w") as f: - json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) - - -@dataclass -class NaiveModelArgs(BaseModelArgs): - model_type: str = "naive" - - -@dataclass -class DualARModelArgs(BaseModelArgs): - model_type: str = "dual_ar" - n_fast_layer: int = 4 - - -class KVCache(nn.Module): - def __init__( - self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 - ): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) - self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - - -@dataclass -class TransformerForwardResult: - token_logits: Tensor - codebook_logits: Tensor - - -@dataclass -class BaseTransformerForwardResult: - logits: Tensor - hidden_states: Tensor - - -class BaseTransformer(nn.Module): - def __init__( - self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True - ) -> None: - super().__init__() - self.config = config - self.tokenizer = tokenizer - - self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN) - - # Slow transformer - self.embeddings = nn.Embedding( - config.vocab_size, - config.dim, - ) - self.codebook_embeddings = nn.Embedding( - config.codebook_size * config.num_codebooks, - config.dim, - ) - self.layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) - ) - self.norm = RMSNorm(config.dim, eps=config.norm_eps) - - if self.config.tie_word_embeddings is False: - self.output = nn.Linear( - config.dim, - config.vocab_size, - bias=False, - ) - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - config.max_seq_len, - config.dim // config.n_head, - 
config.rope_base, - ), - persistent=False, - ) - self.register_buffer( - "causal_mask", - torch.tril( - torch.ones( - config.max_seq_len, - config.max_seq_len, - dtype=torch.bool, - ) - ), - persistent=False, - ) - - # For kv cache - self.max_batch_size = -1 - self.max_seq_len = -1 - - if init_weights: - self.apply(self._init_weights) - - def setup_caches( - self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 - ): - if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: - return - - head_dim = self.config.dim // self.config.n_head - max_seq_len = find_multiple(max_seq_len, 8) - self.max_seq_len = max_seq_len - self.max_batch_size = max_batch_size - - for b in self.layers: - b.attention.kv_cache = KVCache( - max_batch_size, - max_seq_len, - self.config.n_local_heads, - head_dim, - dtype=dtype, - ) - - def embed(self, x: Tensor) -> Tensor: - vocab_embeds = [self.embeddings(x[:, 0])] - for i in range(self.config.num_codebooks): - emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) - emb[x[:, 0] != self.semantic_token_id] = 0 - vocab_embeds.append(emb) - - x = torch.stack(vocab_embeds, dim=3) - x = x.sum(dim=3) - - return x - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> BaseTransformerForwardResult: - seq_len = inp.size(2) - - # Here we want to merge the embeddings of the codebooks - x = self.embed(inp) - - freqs_cis = self.freqs_cis[:seq_len] - - # Not that the causal mask here follows the definition of scaled_dot_product_attention - # That is, FALSE means masked out - # To maintain consistency, key_padding_mask use TRUE to mask out - mask = None - if key_padding_mask is not None: - mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) - mask = mask & key_padding_mask[:, None, None, :].logical_not() - - for layer in self.layers: - if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) - else: - x = layer(x, freqs_cis, mask) - - # We got slow_out here - slow_out = self.norm(x) - - if self.config.tie_word_embeddings: - token_logits = F.linear(slow_out, self.embeddings.weight) - else: - token_logits = self.output(slow_out) - - return BaseTransformerForwardResult( - logits=token_logits, - hidden_states=x, - ) - - def forward_generate( - self, - x: Tensor, - input_pos: Optional[Tensor] = None, - return_all: bool = False, - ) -> BaseTransformerForwardResult: - # This is used for generation, optimized for torch compile - assert ( - self.max_seq_len != -1 and self.max_batch_size != -1 - ), "Please call setup_caches before forward_generate" - - x = self.embed(x) - - mask = self.causal_mask[ - None, None, input_pos, : self.max_seq_len - ] # (B, N, Q, K) - freqs_cis = self.freqs_cis[input_pos] - - for layer in self.layers: - x = layer(x, freqs_cis, mask, input_pos=input_pos) - - # If prefill, we only calculate the logits of last token - if x.size(1) > 1 and not return_all: - x = x[:, -1:] - - # We got slow_out here - slow_out = self.norm(x) - - if self.config.tie_word_embeddings: - token_logits = F.linear(slow_out, self.embeddings.weight) - else: - token_logits = self.output(slow_out) - - return BaseTransformerForwardResult( - logits=token_logits, - hidden_states=x, - ) - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif 
isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @staticmethod - def from_pretrained( - path: str, - load_weights: bool = False, - max_length: int | None = None, - lora_config: LoraConfig | None = None, - rope_base: int | None = None, - ) -> "BaseTransformer": - config = BaseModelArgs.from_pretrained(str(path)) - if max_length is not None: - config.max_seq_len = max_length - log.info(f"Override max_seq_len to {max_length}") - - if rope_base is not None: - config.rope_base = rope_base - log.info(f"Override rope_base to {rope_base}") - - match config.model_type: - case "naive": - model_cls = NaiveTransformer - case "dual_ar": - model_cls = DualARTransformer - case _: - raise ValueError(f"Unknown model type: {config.model_type}") - - tokenizer = AutoTokenizer.from_pretrained(str(path)) - log.info(f"Loading model from {path}, config: {config}") - model = model_cls(config, tokenizer=tokenizer) - - if lora_config is not None: - setup_lora(model, lora_config) - log.info(f"LoRA setup: {lora_config}") - - if load_weights is False: - log.info("Randomly initialized model") - else: - - if "int8" in str(Path(path)): - logger.info("Using int8 weight-only quantization!") - from tools.llama.quantize import WeightOnlyInt8QuantHandler - - simple_quantizer = WeightOnlyInt8QuantHandler(model) - model = simple_quantizer.convert_for_runtime() - - if "int4" in str(Path(path)): - logger.info("Using int4 quantization!") - path_comps = path.name.split("-") - assert path_comps[-2].startswith("g") - groupsize = int(path_comps[-2][1:]) - from tools.llama.quantize import WeightOnlyInt4QuantHandler - - simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - model = simple_quantizer.convert_for_runtime() - - weights = torch.load( - Path(path) / "model.pth", map_location="cpu", mmap=True - ) - - if "state_dict" in weights: - logger.warning( - "Using a TextToSemantic LightningModule checkpoint, " - "please make sure it is a full model, not a LoRA model." - ) - weights = weights["state_dict"] - - if next(iter(weights.keys())).startswith("model."): - logger.info( - f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" - ) - new_weights = OrderedDict() - for k, v in weights.items(): - new_weights[k.replace("model.", "")] = v - weights = new_weights - - # Verify the name and shape of parameters since strict=False in load_state_dict. 
- for k, v in model.named_parameters(): - if k not in weights: - logger.warning(f"No weight for {k}") - elif v.shape != weights[k].shape: - logger.warning( - f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" - ) - - err = model.load_state_dict(weights, strict=False, assign=True) - log.info(f"Loaded weights with error: {err}") - - return model - - def save_pretrained(self, path: str, drop_lora: bool = False): - path = Path(path) - path.mkdir(parents=True, exist_ok=True) - - self.config.save(path / "config.json") - state_dict = self.state_dict() - - if drop_lora: - for key in list(state_dict.keys()): - if "lora" not in key: - continue - - state_dict.pop(key) - log.info(f"Drop LoRA parameter: {key}") - - torch.save(state_dict, path / "model.pth") - self.tokenizer.save_pretrained(path) - - -class NaiveTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: - super().__init__(config, init_weights=False, tokenizer=tokenizer) - - self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) - self.codebook_output = nn.Linear( - config.dim, - config.codebook_size * config.num_codebooks, - bias=False, - ) - - self.apply(self._init_weights) - - def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: - token_logits = result.logits - x = result.hidden_states - - # Codebook - codebook_logits = self.codebook_output(self.codebook_norm(x)) - codebook_logits = rearrange( - codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks - ) - - return TransformerForwardResult( - token_logits=token_logits, - codebook_logits=codebook_logits, - ) - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> TransformerForwardResult: - result = super().forward( - inp=inp, - key_padding_mask=key_padding_mask, - ) - return self.decode(result) - - def forward_generate( - self, x: Tensor, input_pos: Optional[Tensor] = None - ) -> TransformerForwardResult: - result = super().forward_generate(x, input_pos) - return self.decode(result) - - -class DualARTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: - super().__init__(config, init_weights=False, tokenizer=tokenizer) - - # Fast transformer - self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) - - # The equivalent bs is so large that sdpa doesn't work - self.fast_layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) - ) - self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) - self.fast_output = nn.Linear( - config.dim, - config.codebook_size, - bias=False, - ) - - self.apply(self._init_weights) - - def setup_caches( - self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 - ): - super().setup_caches(max_batch_size, max_seq_len, dtype) - - head_dim = self.config.dim // self.config.n_head - - # Fast transformer - # The max seq len here is the number of codebooks - for b in self.fast_layers: - b.attention.kv_cache = KVCache( - max_batch_size, - self.config.num_codebooks, - self.config.n_local_heads, - head_dim, - dtype=dtype, - ) - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> TransformerForwardResult: - parent_result = super().forward(inp, key_padding_mask) - token_logits = parent_result.logits - x = parent_result.hidden_states - - # Fast transformer - fast_seq_len = self.config.num_codebooks - fast_mask = self.causal_mask[ - None, None, 
:fast_seq_len, :fast_seq_len - ] # (B, N, Q, K) - fast_freqs_cis = self.freqs_cis[:fast_seq_len] - - # Drop the last token and rotate left - codebooks = inp[:, 1:-1, 1:] - codebooks = F.pad(codebooks, (0, 1), value=0) - codebook_embeddings = self.fast_embeddings(codebooks) - x = torch.cat([x[:, None], codebook_embeddings], dim=1) - b, s = x.size(0), x.size(2) - x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len - - # Remove padded part - codebooks = rearrange(codebooks, "b n s -> (b s) n") - codebook_mask = (codebooks == 0).all(dim=-1) - - if torch.all(codebook_mask): - # If all codebooks are padded, we keep first 8 to make sure the model runs - codebook_mask[:8] = False - - x_bs, x_len = x.size(0), x.size(1) - x = x[~codebook_mask] - - for layer in self.fast_layers: - if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) - else: - x = layer(x, fast_freqs_cis, fast_mask) - - # unflatten the batch and num_codebooks - fast_out = self.fast_norm(x) - codebook_logits = self.fast_output(fast_out) - - # Re-pad the codebook_logits - buffer = torch.zeros( - x_bs, - x_len, - codebook_logits.size(-1), - device=codebook_logits.device, - dtype=codebook_logits.dtype, - ) - buffer[~codebook_mask] = codebook_logits - codebook_logits = buffer - - assert codebook_logits.shape[1] == self.config.num_codebooks - codebook_logits = rearrange( - codebook_logits, - "(b s) n d -> b s n d", - b=b, - s=s, - n=self.config.num_codebooks, - ) - - return TransformerForwardResult( - token_logits=token_logits, - codebook_logits=codebook_logits, - ) - - def forward_generate_fast( - self, x: Tensor, input_pos: Optional[Tensor] = None - ) -> Tensor: - # Fast transformer - x = x.view(1, 1, -1) - - fast_mask = self.causal_mask[ - None, None, input_pos, : self.config.num_codebooks - ] # (B, N, Q, K) - fast_freqs_cis = self.freqs_cis[input_pos] - - for layer in self.fast_layers: - x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) - - # unflatten the batch and num_codebooks - fast_out = self.fast_norm(x) # only take the last token - codebook_logits = self.fast_output(fast_out) - - return codebook_logits - - -class TransformerBlock(nn.Module): - def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: - super().__init__() - self.attention = Attention(config, use_sdpa=use_sdpa) - self.feed_forward = FeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) - - def forward( - self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None - ) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - -class Attention(nn.Module): - def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): - super().__init__() - assert config.dim % config.n_head == 0 - - total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim - # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear( - config.dim, total_head_dim, bias=config.attention_qkv_bias - ) - self.wo = nn.Linear(config.dim, config.dim, bias=False) - self.kv_cache = None - - self.dropout = config.dropout - self.n_head = config.n_head - self.head_dim = config.head_dim - self.n_local_heads = config.n_local_heads - self.dim = config.dim - self.use_sdpa = use_sdpa - self._register_load_state_dict_pre_hook(self.load_hook) - - def 
load_hook(self, state_dict, prefix, *args): - if prefix + "wq.weight" in state_dict: - wq = state_dict.pop(prefix + "wq.weight") - wk = state_dict.pop(prefix + "wk.weight") - wv = state_dict.pop(prefix + "wv.weight") - state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - - def forward( - self, - x: Tensor, - freqs_cis: Tensor, - mask: Tensor, - input_pos: Optional[Tensor] = None, - ) -> Tensor: - bsz, seqlen, _ = x.shape - - kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) - - q = q.view(bsz, seqlen, self.n_head, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) - - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) - - q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - - if self.use_sdpa: - if mask is None: - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - y = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.dropout if self.training else 0.0, - is_causal=True, - # No third party attn_mask here to use flash_attention - ) - else: - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=self.dropout if self.training else 0.0, - ) - else: - y = self.eq_scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=self.dropout if self.training else 0.0, - ) - - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - return self.wo(y) - - def eq_scaled_dot_product_attention( - self, - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - ) -> torch.Tensor: - # This is a standard scaled dot product attention - # It's low efficient, but it doesn't raise cuda error - - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) - attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - - return attn_weight @ value - - -class FeedForward(nn.Module): - def __init__(self, config: BaseModelArgs) -> None: - super().__init__() - self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) - - def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor) -> Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: - freqs = 1.0 / ( - base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) - ) - t = torch.arange(seq_len, 
device=freqs.device) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=torch.bfloat16) - - -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], - xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], - ], - -1, - ) - - x_out2 = x_out2.flatten(3) - return x_out2.type_as(x) +import dataclasses +import json +import math +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from loguru import logger +from torch import Tensor +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.checkpoint import checkpoint +from transformers import AutoTokenizer + +from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer +from fish_speech.utils import RankedLogger + +from .lora import LoraConfig, setup_lora + +log = RankedLogger(__name__, rank_zero_only=True) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class BaseModelArgs: + model_type: str = "base" + + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + max_seq_len: int = 2048 + dropout: float = 0.0 + tie_word_embeddings: bool = True + attention_qkv_bias: bool = False + + # Codebook configs + codebook_size: int = 160 + num_codebooks: int = 4 + + # Gradient checkpointing + use_gradient_checkpointing: bool = True + + # Initialize the model + initializer_range: float = 0.02 + + # Dummy vars + is_reward_model: bool = False + share_codebook_embeddings: bool = True + scale_codebook_embeddings: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + if path.is_dir(): + path = path / "config.json" + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + match data["model_type"]: + case "naive": + cls = NaiveModelArgs + case "dual_ar": + cls = DualARModelArgs + case _: + raise ValueError(f"Unknown model type: {data['model_type']}") + + return cls(**data) + + def save(self, path: str): + with open(path, "w") as f: + json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) + + +@dataclass +class NaiveModelArgs(BaseModelArgs): + model_type: str = "naive" + + +@dataclass +class DualARModelArgs(BaseModelArgs): + model_type: str = "dual_ar" + n_fast_layer: int = 4 + fast_dim: int | None = None + fast_n_head: int | None = None + fast_n_local_heads: int | None = None + fast_head_dim: int | None = None + fast_intermediate_size: int | None = None + fast_attention_qkv_bias: bool | None = None + + def __post_init__(self): + super().__post_init__() + + self.fast_dim = self.fast_dim or self.dim + 
self.fast_n_head = self.fast_n_head or self.n_head + self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads + self.fast_head_dim = self.fast_head_dim or self.head_dim + self.fast_intermediate_size = ( + self.fast_intermediate_size or self.intermediate_size + ) + self.fast_attention_qkv_bias = ( + self.fast_attention_qkv_bias + if self.fast_attention_qkv_bias is not None + else self.attention_qkv_bias + ) + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +@dataclass +class TransformerForwardResult: + token_logits: Tensor + codebook_logits: Tensor + + +@dataclass +class BaseTransformerForwardResult: + logits: Tensor + hidden_states: Tensor + + +class BaseTransformer(nn.Module): + def __init__( + self, + config: BaseModelArgs, + tokenizer: FishTokenizer | AutoTokenizer, + init_weights: bool = True, + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer + self.semantic_token_ids = [ + tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS + ] + + # Slow transformer + self.embeddings = nn.Embedding( + config.vocab_size, + config.dim, + ) + self.codebook_embeddings = nn.Embedding( + config.codebook_size * config.num_codebooks, + config.dim, + ) + self.layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + if self.config.tie_word_embeddings is False: + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + config.max_seq_len, + config.dim // config.n_head, + config.rope_base, + ), + persistent=False, + ) + self.register_buffer( + "causal_mask", + torch.tril( + torch.ones( + config.max_seq_len, + config.max_seq_len, + dtype=torch.bool, + ) + ), + persistent=False, + ) + + # For kv cache + self.max_batch_size = -1 + self.max_seq_len = -1 + + if init_weights: + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: + return + + head_dim = self.config.dim // self.config.n_head + max_seq_len = find_multiple(max_seq_len, 8) + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_len, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def embed(self, x: Tensor) -> Tensor: + vocab_embeds = [self.embeddings(x[:, 0])] + for i in range(self.config.num_codebooks): + emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) + semantic_token_ids_tensor = torch.tensor( + self.semantic_token_ids, device=x.device + ) + emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0 + + x = torch.stack(vocab_embeds, dim=3) + x = x.sum(dim=3) + + return x + + def forward( + self, + inp: 
Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> BaseTransformerForwardResult: + seq_len = inp.size(2) + + # Here we want to merge the embeddings of the codebooks + x = self.embed(inp) + + freqs_cis = self.freqs_cis[:seq_len] + + # Not that the causal mask here follows the definition of scaled_dot_product_attention + # That is, FALSE means masked out + # To maintain consistency, key_padding_mask use TRUE to mask out + mask = None + if key_padding_mask is not None: + mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) + mask = mask & key_padding_mask[:, None, None, :].logical_not() + + for layer in self.layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) + else: + x = layer(x, freqs_cis, mask) + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def forward_generate( + self, + inp: Tensor, + input_pos: Optional[Tensor] = None, + vq_masks: Optional[Tensor] = None, # this is not used in fact + return_all: bool = False, + ) -> BaseTransformerForwardResult: + # This is used for generation, optimized for torch compile + # assert ( + # self.max_seq_len != -1 and self.max_batch_size != -1 + # ), "Please call setup_caches before forward_generate" + + embeds = [] + for i in range(self.config.num_codebooks): + if self.config.share_codebook_embeddings: + _tokens = inp[:, i + 1] + i * self.config.codebook_size + else: + _tokens = inp[:, i + 1] + + emb = self.codebook_embeddings(_tokens) + embeds.append(emb) + + vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1) + # if self.config.use_codebook_mlp: + # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks + # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum) + + vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & ( + inp[:, 0] <= self.tokenizer.semantic_end_id + ) + + vq_embeds_sum[~vq_masks] = 0 + x = self.embeddings(inp[:, 0]) + vq_embeds_sum + + if input_pos is None: + input_pos = torch.arange(inp.shape[-1], device=x.device) + max_seq_len = inp.shape[-1] + else: + max_seq_len = self.max_seq_len + + mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K) + freqs_cis = self.freqs_cis[input_pos] + + for layer in self.layers: + x = layer(x, freqs_cis, mask, input_pos=input_pos) + + # If prefill, we only calculate the logits of last token + if x.size(1) > 1 and not return_all: + x = x[:, -1:] + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.is_reward_model: + token_logits = self.score_output(slow_out) + elif self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @staticmethod + def from_pretrained( + path: str, + load_weights: bool = False, + max_length: int | None = None, + lora_config: LoraConfig 
| None = None, + rope_base: int | None = None, + is_agent: bool = False, + ) -> "BaseTransformer": + config = BaseModelArgs.from_pretrained(str(path)) + if max_length is not None: + config.max_seq_len = max_length + log.info(f"Override max_seq_len to {max_length}") + + if rope_base is not None: + config.rope_base = rope_base + log.info(f"Override rope_base to {rope_base}") + + match config.model_type: + case "naive": + model_cls = NaiveTransformer + case "dual_ar": + model_cls = DualARTransformer + case _: + raise ValueError(f"Unknown model type: {config.model_type}") + + if is_agent: + tokenizer = AutoTokenizer.from_pretrained(str(path)) + else: + tokenizer_path = str(path) + "/tokenizer.tiktoken" + tokenizer = FishTokenizer(tokenizer_path) + + log.info(f"Loading model from {path}, config: {config}") + model = model_cls(config, tokenizer=tokenizer) + + if lora_config is not None: + setup_lora(model, lora_config) + log.info(f"LoRA setup: {lora_config}") + + if load_weights is False: + log.info("Randomly initialized model") + else: + + if "int8" in str(Path(path)): + logger.info("Using int8 weight-only quantization!") + from tools.llama.quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(Path(path)): + logger.info("Using int4 quantization!") + path_comps = path.name.split("-") + assert path_comps[-2].startswith("g") + groupsize = int(path_comps[-2][1:]) + from tools.llama.quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + weights = torch.load( + Path(path) / "model.pth", + map_location="cpu", + mmap=True, + weights_only=True, + ) + + if "state_dict" in weights: + logger.warning( + "Using a TextToSemantic LightningModule checkpoint, " + "please make sure it is a full model, not a LoRA model." + ) + weights = weights["state_dict"] + + if next(iter(weights.keys())).startswith("model."): + logger.info( + f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" + ) + new_weights = OrderedDict() + for k, v in weights.items(): + new_weights[k.replace("model.", "")] = v + weights = new_weights + + # Verify the name and shape of parameters since strict=False in load_state_dict. 
+ for k, v in model.named_parameters(): + if k not in weights: + logger.warning(f"No weight for {k}") + elif v.shape != weights[k].shape: + logger.warning( + f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" + ) + + err = model.load_state_dict(weights, strict=False, assign=True) + log.info(f"Loaded weights with error: {err}") + + return model + + def save_pretrained(self, path: str, drop_lora: bool = False): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + self.config.save(path / "config.json") + state_dict = self.state_dict() + + if drop_lora: + for key in list(state_dict.keys()): + if "lora" not in key: + continue + + state_dict.pop(key) + log.info(f"Drop LoRA parameter: {key}") + + torch.save(state_dict, path / "model.pth") + self.tokenizer.save_pretrained(path) + + +class NaiveTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.codebook_output = nn.Linear( + config.dim, + config.codebook_size * config.num_codebooks, + bias=False, + ) + + self.apply(self._init_weights) + + def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: + token_logits = result.logits + x = result.hidden_states + + # Codebook + codebook_logits = self.codebook_output(self.codebook_norm(x)) + codebook_logits = rearrange( + codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + result = super().forward( + inp=inp, + key_padding_mask=key_padding_mask, + ) + return self.decode(result) + + def forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + result = super().forward_generate(x, input_pos) + return self.decode(result) + + +class DualARTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + # Project to fast dim if needed + if config.fast_dim is not None and config.fast_dim != config.dim: + self.fast_project_in = nn.Linear(config.dim, config.fast_dim) + else: + self.fast_project_in = nn.Identity() + + # Fast transformer + self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim) + + # The equivalent bs is so large that sdpa doesn't work + override_config = dataclasses.replace( + config, + dim=config.fast_dim, + n_head=config.fast_n_head, + n_local_heads=config.fast_n_local_heads, + head_dim=config.fast_head_dim, + intermediate_size=config.fast_intermediate_size, + attention_qkv_bias=config.fast_attention_qkv_bias, + ) + + self.fast_layers = nn.ModuleList( + TransformerBlock(override_config, use_sdpa=False) + for _ in range(config.n_fast_layer) + ) + self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps) + self.fast_output = nn.Linear( + config.fast_dim, + config.codebook_size, + bias=False, + ) + + self.register_buffer( + "fast_freqs_cis", + precompute_freqs_cis( + config.num_codebooks, + config.fast_dim // config.fast_n_head, + config.rope_base, + ), + persistent=False, + ) + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + 
super().setup_caches(max_batch_size, max_seq_len, dtype) + + head_dim = self.config.fast_dim // self.config.fast_n_head + + # Fast transformer + # The max seq len here is the number of codebooks + for b in self.fast_layers: + b.attention.kv_cache = KVCache( + max_batch_size, + self.config.num_codebooks, + self.config.fast_n_local_heads, + head_dim, + dtype=dtype, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + parent_result = super().forward(inp, key_padding_mask) + token_logits = parent_result.logits + x = parent_result.hidden_states + x = self.fast_project_in(x) + + # Fast transformer + fast_seq_len = self.config.num_codebooks + fast_mask = self.causal_mask[ + None, None, :fast_seq_len, :fast_seq_len + ] # (B, N, Q, K) + + # Drop the last token and rotate left + codebooks = inp[:, 1:-1, 1:] + codebooks = F.pad(codebooks, (0, 1), value=0) + codebook_embeddings = self.fast_embeddings(codebooks) + x = torch.cat([x[:, None], codebook_embeddings], dim=1) + b, s = x.size(0), x.size(2) + x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len + + # Remove padded part + codebooks = rearrange(codebooks, "b n s -> (b s) n") + codebook_mask = (codebooks == 0).all(dim=-1) + + if torch.all(codebook_mask): + # If all codebooks are padded, we keep first 8 to make sure the model runs + codebook_mask[:8] = False + + x_bs, x_len = x.size(0), x.size(1) + x = x[~codebook_mask] + + for layer in self.fast_layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint( + layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True + ) + else: + x = layer(x, self.fast_freqs_cis, fast_mask) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) + codebook_logits = self.fast_output(fast_out) + + # Re-pad the codebook_logits + buffer = torch.zeros( + x_bs, + x_len, + codebook_logits.size(-1), + device=codebook_logits.device, + dtype=codebook_logits.dtype, + ) + buffer[~codebook_mask] = codebook_logits + codebook_logits = buffer + + assert codebook_logits.shape[1] == self.config.num_codebooks + codebook_logits = rearrange( + codebook_logits, + "(b s) n d -> b s n d", + b=b, + s=s, + n=self.config.num_codebooks, + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward_generate_fast( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> Tensor: + # Fast transformer + x = x.view(1, 1, -1) + + fast_mask = self.causal_mask[ + None, None, input_pos, : self.config.num_codebooks + ] # (B, N, Q, K) + fast_freqs_cis = self.fast_freqs_cis[input_pos] + + for layer in self.fast_layers: + x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) # only take the last token + codebook_logits = self.fast_output(fast_out) + + return codebook_logits + + def forward_generate( + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + vq_masks: Optional[Tensor] = None, + ) -> TransformerForwardResult: + x = super().forward_generate(x, input_pos, vq_masks) + x.hidden_states = self.fast_project_in(x.hidden_states) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: + super().__init__() + self.attention = Attention(config, use_sdpa=use_sdpa) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = 
RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear( + config.dim, total_head_dim, bias=config.attention_qkv_bias + ) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.dropout = config.dropout + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.use_sdpa = use_sdpa + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.use_sdpa: + if mask is None: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + # No third party attn_mask here to use flash_attention + ) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + else: + y = self.eq_scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + return self.wo(y) + + def eq_scaled_dot_product_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + ) -> torch.Tensor: + # This is a standard scaled dot product attention + # It's low efficient, but it doesn't raise cuda error + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + 
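Editor's note, not part of the diff: the `eq_scaled_dot_product_attention` fallback above spells out softmax(QK^T / sqrt(d))·V explicitly because, per the comments in this patch, the fast (codebook) transformer's effective batch size is too large for the fused SDPA kernels. The sketch below is a minimal, self-contained check that such a manual formulation agrees with `torch.nn.functional.scaled_dot_product_attention`; the helper name `manual_sdpa`, the tensor shapes, and the omission of dropout are illustrative assumptions and do not appear in the repository.

```python
# Standalone sanity check, assuming PyTorch >= 2.0; not part of the patch.
import math

import torch
import torch.nn.functional as F


def manual_sdpa(query, key, value, attn_mask=None):
    # Explicit softmax(Q K^T / sqrt(d)) V, mirroring the eq_scaled_dot_product_attention
    # fallback above (dropout omitted so the comparison is deterministic).
    L, S = query.size(-2), key.size(-2)
    scale = 1 / math.sqrt(query.size(-1))
    attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype)
    if attn_mask is not None:
        # Boolean mask: True means "attend", False means "mask out".
        attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale + attn_bias, dim=-1)
    return attn_weight @ value


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # (B, H, S, D)
    causal = torch.tril(torch.ones(16, 16, dtype=torch.bool))[None, None]

    reference = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
    manual = manual_sdpa(q, k, v, attn_mask=causal)
    print(torch.allclose(reference, manual, atol=1e-5))  # expected: True
```

If the check holds on your setup, the two code paths in `Attention.forward` are interchangeable up to floating-point tolerance, and the choice between them is purely about kernel availability and memory behavior, not about the result.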
+class FeedForward(nn.Module): + def __init__(self, config: BaseModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py index 647ca6fcccf038e17d2cf91a2874281dff3e0938..bb4a6192c469bce1535b5f93e147f89ce05cca04 100644 --- a/fish_speech/models/text2semantic/lora.py +++ b/fish_speech/models/text2semantic/lora.py @@ -1,92 +1,92 @@ -from dataclasses import dataclass - -import loralib as lora - - -@dataclass -class LoraConfig: - r: int - lora_alpha: float - lora_dropout: float = 0.0 - - -def setup_lora(model, lora_config): - # Replace the embedding layer with a LoRA layer - model.embeddings = lora.Embedding( - num_embeddings=model.embeddings.num_embeddings, - embedding_dim=model.embeddings.embedding_dim, - padding_idx=model.embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - model.codebook_embeddings = lora.Embedding( - num_embeddings=model.codebook_embeddings.num_embeddings, - embedding_dim=model.codebook_embeddings.embedding_dim, - padding_idx=model.codebook_embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - # Replace output layer with a LoRA layer - linears = [(model, "output")] - - # Replace all linear layers with LoRA layers - for layer in model.layers: - linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) - linears.extend( - [ - (layer.feed_forward, "w1"), - (layer.feed_forward, "w2"), - (layer.feed_forward, "w3"), - ] - ) - - if hasattr(model, "fast_layers"): - model.fast_embeddings = lora.Embedding( - num_embeddings=model.fast_embeddings.num_embeddings, - embedding_dim=model.fast_embeddings.embedding_dim, - padding_idx=model.fast_embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - # Dual-AR model - linears.append((model, "fast_output")) - - for layer in model.fast_layers: - linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) - 
linears.extend( - [ - (layer.feed_forward, "w1"), - (layer.feed_forward, "w2"), - (layer.feed_forward, "w3"), - ] - ) - - for module, layer in linears: - updated_linear = lora.Linear( - in_features=getattr(module, layer).in_features, - out_features=getattr(module, layer).out_features, - bias=getattr(module, layer).bias, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - ) - setattr(module, layer, updated_linear) - - # Mark only the LoRA layers as trainable - lora.mark_only_lora_as_trainable(model, bias="none") - - -def get_merged_state_dict(model): - # This line will merge the state dict of the model and the LoRA parameters - model.eval() - - # Then we need to remove the LoRA parameters from the state dict - state_dict = model.state_dict() - for name in list(state_dict.keys()): - if "lora" in name: - state_dict.pop(name) - - return state_dict +from dataclasses import dataclass + +import loralib as lora + + +@dataclass +class LoraConfig: + r: int + lora_alpha: float + lora_dropout: float = 0.0 + + +def setup_lora(model, lora_config): + # Replace the embedding layer with a LoRA layer + model.embeddings = lora.Embedding( + num_embeddings=model.embeddings.num_embeddings, + embedding_dim=model.embeddings.embedding_dim, + padding_idx=model.embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + model.codebook_embeddings = lora.Embedding( + num_embeddings=model.codebook_embeddings.num_embeddings, + embedding_dim=model.codebook_embeddings.embedding_dim, + padding_idx=model.codebook_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Replace output layer with a LoRA layer + linears = [(model, "output")] + + # Replace all linear layers with LoRA layers + for layer in model.layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + if hasattr(model, "fast_layers"): + model.fast_embeddings = lora.Embedding( + num_embeddings=model.fast_embeddings.num_embeddings, + embedding_dim=model.fast_embeddings.embedding_dim, + padding_idx=model.fast_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Dual-AR model + linears.append((model, "fast_output")) + + for layer in model.fast_layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + for module, layer in linears: + updated_linear = lora.Linear( + in_features=getattr(module, layer).in_features, + out_features=getattr(module, layer).out_features, + bias=getattr(module, layer).bias, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + setattr(module, layer, updated_linear) + + # Mark only the LoRA layers as trainable + lora.mark_only_lora_as_trainable(model, bias="none") + + +def get_merged_state_dict(model): + # This line will merge the state dict of the model and the LoRA parameters + model.eval() + + # Then we need to remove the LoRA parameters from the state dict + state_dict = model.state_dict() + for name in list(state_dict.keys()): + if "lora" in name: + state_dict.pop(name) + + return state_dict diff --git a/fish_speech/models/vqgan/lit_module.py b/fish_speech/models/vqgan/lit_module.py deleted file mode 100644 index 
bd0733ba748ab69bb539eb6b596b36a365ac460f..0000000000000000000000000000000000000000 --- a/fish_speech/models/vqgan/lit_module.py +++ /dev/null @@ -1,442 +0,0 @@ -import itertools -import math -from typing import Any, Callable - -import lightning as L -import torch -import torch.nn.functional as F -import wandb -from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger -from matplotlib import pyplot as plt -from torch import nn - -from fish_speech.models.vqgan.modules.discriminator import Discriminator -from fish_speech.models.vqgan.modules.wavenet import WaveNet -from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask - - -class VQGAN(L.LightningModule): - def __init__( - self, - optimizer: Callable, - lr_scheduler: Callable, - encoder: WaveNet, - quantizer: nn.Module, - decoder: WaveNet, - discriminator: Discriminator, - vocoder: nn.Module, - encode_mel_transform: nn.Module, - gt_mel_transform: nn.Module, - weight_adv: float = 1.0, - weight_vq: float = 1.0, - weight_mel: float = 1.0, - sampling_rate: int = 44100, - freeze_encoder: bool = False, - ): - super().__init__() - - # Model parameters - self.optimizer_builder = optimizer - self.lr_scheduler_builder = lr_scheduler - - # Modules - self.encoder = encoder - self.quantizer = quantizer - self.decoder = decoder - self.vocoder = vocoder - self.discriminator = discriminator - self.encode_mel_transform = encode_mel_transform - self.gt_mel_transform = gt_mel_transform - - # A simple linear layer to project quality to condition channels - self.quality_projection = nn.Linear(1, 768) - - # Freeze vocoder - for param in self.vocoder.parameters(): - param.requires_grad = False - - # Loss weights - self.weight_adv = weight_adv - self.weight_vq = weight_vq - self.weight_mel = weight_mel - - # Other parameters - self.sampling_rate = sampling_rate - - # Disable strict loading - self.strict_loading = False - - # If encoder is frozen - if freeze_encoder: - for param in self.encoder.parameters(): - param.requires_grad = False - - for param in self.quantizer.parameters(): - param.requires_grad = False - - self.automatic_optimization = False - - def on_save_checkpoint(self, checkpoint): - # Do not save vocoder - state_dict = checkpoint["state_dict"] - for name in list(state_dict.keys()): - if "vocoder" in name: - state_dict.pop(name) - - def configure_optimizers(self): - optimizer_generator = self.optimizer_builder( - itertools.chain( - self.encoder.parameters(), - self.quantizer.parameters(), - self.decoder.parameters(), - self.quality_projection.parameters(), - ) - ) - optimizer_discriminator = self.optimizer_builder( - self.discriminator.parameters() - ) - - lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator) - lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator) - - return ( - { - "optimizer": optimizer_generator, - "lr_scheduler": { - "scheduler": lr_scheduler_generator, - "interval": "step", - "name": "optimizer/generator", - }, - }, - { - "optimizer": optimizer_discriminator, - "lr_scheduler": { - "scheduler": lr_scheduler_discriminator, - "interval": "step", - "name": "optimizer/discriminator", - }, - }, - ) - - def training_step(self, batch, batch_idx): - optim_g, optim_d = self.optimizers() - - audios, audio_lengths = batch["audios"], batch["audio_lengths"] - - audios = audios.float() - audios = audios[:, None, :] - - with torch.no_grad(): - encoded_mels = self.encode_mel_transform(audios) - gt_mels = self.gt_mel_transform(audios) - quality = ((gt_mels.mean(-1) > 
-8).sum(-1) - 90) / 10 - quality = quality.unsqueeze(-1) - - mel_lengths = audio_lengths // self.gt_mel_transform.hop_length - mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2]) - mel_masks_float_conv = mel_masks[:, None, :].float() - gt_mels = gt_mels * mel_masks_float_conv - encoded_mels = encoded_mels * mel_masks_float_conv - - # Encode - encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv - - # Quantize - vq_result = self.quantizer(encoded_features) - loss_vq = getattr("vq_result", "loss", 0.0) - vq_recon_features = vq_result.z * mel_masks_float_conv - vq_recon_features = ( - vq_recon_features + self.quality_projection(quality)[:, :, None] - ) - - # VQ Decode - gen_mel = ( - self.decoder( - torch.randn_like(vq_recon_features) * mel_masks_float_conv, - condition=vq_recon_features, - ) - * mel_masks_float_conv - ) - - # Discriminator - real_logits = self.discriminator(gt_mels) - fake_logits = self.discriminator(gen_mel.detach()) - d_mask = F.interpolate( - mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest" - ) - - loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask) - loss_fake = avg_with_mask(fake_logits**2, d_mask) - - loss_d = loss_real + loss_fake - - self.log( - "train/discriminator/loss", - loss_d, - on_step=True, - on_epoch=False, - prog_bar=True, - logger=True, - ) - - # Discriminator backward - optim_d.zero_grad() - self.manual_backward(loss_d) - self.clip_gradients( - optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm" - ) - optim_d.step() - - # Mel Loss, applying l1, using a weighted sum - mel_distance = ( - gen_mel - gt_mels - ).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5 - loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv) - loss_mel_mid_freq = avg_with_mask( - mel_distance[:, 40:70, :], mel_masks_float_conv - ) - loss_mel_high_freq = avg_with_mask( - mel_distance[:, 70:, :], mel_masks_float_conv - ) - loss_mel = ( - loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1 - ) - - # Adversarial Loss - fake_logits = self.discriminator(gen_mel) - loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask) - - # Total loss - loss = ( - self.weight_vq * loss_vq - + self.weight_mel * loss_mel - + self.weight_adv * loss_adv - ) - - # Log losses - self.log( - "train/generator/loss", - loss, - on_step=True, - on_epoch=False, - prog_bar=True, - logger=True, - ) - self.log( - "train/generator/loss_vq", - loss_vq, - on_step=True, - on_epoch=False, - prog_bar=False, - logger=True, - ) - self.log( - "train/generator/loss_mel", - loss_mel, - on_step=True, - on_epoch=False, - prog_bar=False, - logger=True, - ) - self.log( - "train/generator/loss_adv", - loss_adv, - on_step=True, - on_epoch=False, - prog_bar=False, - logger=True, - ) - - # Generator backward - optim_g.zero_grad() - self.manual_backward(loss) - self.clip_gradients( - optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm" - ) - optim_g.step() - - scheduler_g, scheduler_d = self.lr_schedulers() - scheduler_g.step() - scheduler_d.step() - - def validation_step(self, batch: Any, batch_idx: int): - audios, audio_lengths = batch["audios"], batch["audio_lengths"] - - audios = audios.float() - audios = audios[:, None, :] - - encoded_mels = self.encode_mel_transform(audios) - gt_mels = self.gt_mel_transform(audios) - - mel_lengths = audio_lengths // self.gt_mel_transform.hop_length - mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2]) - mel_masks_float_conv = mel_masks[:, None, :].float() - gt_mels = gt_mels 
* mel_masks_float_conv - encoded_mels = encoded_mels * mel_masks_float_conv - - # Encode - encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv - - # Quantize - vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv - vq_recon_features = ( - vq_recon_features - + self.quality_projection( - torch.ones( - vq_recon_features.shape[0], 1, device=vq_recon_features.device - ) - * 2 - )[:, :, None] - ) - - # VQ Decode - gen_aux_mels = ( - self.decoder( - torch.randn_like(vq_recon_features) * mel_masks_float_conv, - condition=vq_recon_features, - ) - * mel_masks_float_conv - ) - loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv) - - self.log( - "val/loss_mel", - loss_mel, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - sync_dist=True, - ) - - recon_audios = self.vocoder(gt_mels) - gen_aux_audios = self.vocoder(gen_aux_mels) - - # only log the first batch - if batch_idx != 0: - return - - for idx, ( - gt_mel, - gen_aux_mel, - audio, - gen_aux_audio, - recon_audio, - audio_len, - ) in enumerate( - zip( - gt_mels, - gen_aux_mels, - audios.cpu().float(), - gen_aux_audios.cpu().float(), - recon_audios.cpu().float(), - audio_lengths, - ) - ): - if idx > 4: - break - - mel_len = audio_len // self.gt_mel_transform.hop_length - - image_mels = plot_mel( - [ - gt_mel[:, :mel_len], - gen_aux_mel[:, :mel_len], - ], - [ - "Ground-Truth", - "Auxiliary", - ], - ) - - if isinstance(self.logger, WandbLogger): - self.logger.experiment.log( - { - "reconstruction_mel": wandb.Image(image_mels, caption="mels"), - "wavs": [ - wandb.Audio( - audio[0, :audio_len], - sample_rate=self.sampling_rate, - caption="gt", - ), - wandb.Audio( - gen_aux_audio[0, :audio_len], - sample_rate=self.sampling_rate, - caption="aux", - ), - wandb.Audio( - recon_audio[0, :audio_len], - sample_rate=self.sampling_rate, - caption="recon", - ), - ], - }, - ) - - if isinstance(self.logger, TensorBoardLogger): - self.logger.experiment.add_figure( - f"sample-{idx}/mels", - image_mels, - global_step=self.global_step, - ) - self.logger.experiment.add_audio( - f"sample-{idx}/wavs/gt", - audio[0, :audio_len], - self.global_step, - sample_rate=self.sampling_rate, - ) - self.logger.experiment.add_audio( - f"sample-{idx}/wavs/gen", - gen_aux_audio[0, :audio_len], - self.global_step, - sample_rate=self.sampling_rate, - ) - self.logger.experiment.add_audio( - f"sample-{idx}/wavs/recon", - recon_audio[0, :audio_len], - self.global_step, - sample_rate=self.sampling_rate, - ) - - plt.close(image_mels) - - def encode(self, audios, audio_lengths): - audios = audios.float() - - mels = self.encode_mel_transform(audios) - mel_lengths = audio_lengths // self.encode_mel_transform.hop_length - mel_masks = sequence_mask(mel_lengths, mels.shape[2]) - mel_masks_float_conv = mel_masks[:, None, :].float() - mels = mels * mel_masks_float_conv - - # Encode - encoded_features = self.encoder(mels) * mel_masks_float_conv - feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor) - - return self.quantizer.encode(encoded_features), feature_lengths - - def decode(self, indices, feature_lengths, return_audios=False): - factor = math.prod(self.quantizer.downsample_factor) - mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor) - mel_masks_float_conv = mel_masks[:, None, :].float() - - z = self.quantizer.decode(indices) * mel_masks_float_conv - z = ( - z - + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[ - :, :, None - ] - ) - - 
gen_mel = ( - self.decoder( - torch.randn_like(z) * mel_masks_float_conv, - condition=z, - ) - * mel_masks_float_conv - ) - - if return_audios: - return self.vocoder(gen_mel) - - return gen_mel diff --git a/fish_speech/models/vqgan/modules/discriminator.py b/fish_speech/models/vqgan/modules/discriminator.py deleted file mode 100644 index 69c7df41033f2cde22583468731f56b49eb594b7..0000000000000000000000000000000000000000 --- a/fish_speech/models/vqgan/modules/discriminator.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -from torch import nn -from torch.nn.utils.parametrizations import weight_norm - - -class Discriminator(nn.Module): - def __init__(self): - super().__init__() - - blocks = [] - convs = [ - (1, 64, (3, 9), 1, (1, 4)), - (64, 128, (3, 9), (1, 2), (1, 4)), - (128, 256, (3, 9), (1, 2), (1, 4)), - (256, 512, (3, 9), (1, 2), (1, 4)), - (512, 1024, (3, 3), 1, (1, 1)), - (1024, 1, (3, 3), 1, (1, 1)), - ] - - for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate( - convs - ): - blocks.append( - weight_norm( - nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) - ) - ) - - if idx != len(convs) - 1: - blocks.append(nn.SiLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, x): - return self.blocks(x[:, None])[:, 0] - - -if __name__ == "__main__": - model = Discriminator() - print(sum(p.numel() for p in model.parameters()) / 1_000_000) - x = torch.randn(1, 128, 1024) - y = model(x) - print(y.shape) - print(y) diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py index aa21839b544174d5d91378c5daf8fe1b376a154a..c6f386fedb8b41f0094c5edcb9f934c3293bca96 100644 --- a/fish_speech/models/vqgan/modules/firefly.py +++ b/fish_speech/models/vqgan/modules/firefly.py @@ -1,596 +1,596 @@ -import math -from functools import partial -from math import prod -from typing import Callable - -import torch -import torch.nn.functional as F -from torch import nn -from torch.nn.utils.parametrizations import weight_norm -from torch.nn.utils.parametrize import remove_parametrizations -from torch.utils.checkpoint import checkpoint - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv1D") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return (kernel_size * dilation - dilation) // 2 - - -def unpad1d(x: torch.Tensor, paddings: tuple[int, int]): - """Remove padding from x, handling properly zero padding. 
Only for 1d!""" - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - assert (padding_left + padding_right) <= x.shape[-1] - end = x.shape[-1] - padding_right - return x[..., padding_left:end] - - -def get_extra_padding_for_conv1d( - x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 -) -> int: - """See `pad_for_conv1d`.""" - length = x.shape[-1] - n_frames = (length - kernel_size + padding_total) / stride + 1 - ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) - return ideal_length - length - - -def pad1d( - x: torch.Tensor, - paddings: tuple[int, int], - mode: str = "zeros", - value: float = 0.0, -): - """Tiny wrapper around F.pad, just to allow for reflect padding on small input. - If this is the case, we insert extra 0 padding to the right - before the reflection happen. - """ - length = x.shape[-1] - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == "reflect": - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: - return F.pad(x, paddings, mode, value) - - -class FishConvNet(nn.Module): - def __init__( - self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1 - ): - super(FishConvNet, self).__init__() - self.conv = nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride=stride, - dilation=dilation, - groups=groups, - ) - self.stride = stride - self.kernel_size = (kernel_size - 1) * dilation + 1 - self.dilation = dilation - - def forward(self, x): - pad = self.kernel_size - self.stride - extra_padding = get_extra_padding_for_conv1d( - x, self.kernel_size, self.stride, pad - ) - x = pad1d(x, (pad, extra_padding), mode="constant", value=0) - return self.conv(x).contiguous() - - def weight_norm(self, name="weight", dim=0): - self.conv = weight_norm(self.conv, name=name, dim=dim) - return self - - def remove_weight_norm(self): - self.conv = remove_parametrizations(self.conv) - return self - - -class FishTransConvNet(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1): - super(FishTransConvNet, self).__init__() - self.conv = nn.ConvTranspose1d( - in_channels, out_channels, kernel_size, stride=stride, dilation=dilation - ) - self.stride = stride - self.kernel_size = kernel_size - - def forward(self, x): - x = self.conv(x) - pad = self.kernel_size - self.stride - padding_right = math.ceil(pad) - padding_left = pad - padding_right - x = unpad1d(x, (padding_left, padding_right)) - return x.contiguous() - - def weight_norm(self, name="weight", dim=0): - self.conv = weight_norm(self.conv, name=name, dim=dim) - return self - - def remove_weight_norm(self): - self.conv = remove_parametrizations(self.conv) - return self - - -class ResBlock1(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): - super().__init__() - - self.convs1 = nn.ModuleList( - [ - FishConvNet( - channels, channels, kernel_size, stride=1, dilation=dilation[0] - ).weight_norm(), - FishConvNet( - channels, channels, kernel_size, stride=1, dilation=dilation[1] - ).weight_norm(), - FishConvNet( - channels, channels, kernel_size, stride=1, dilation=dilation[2] - ).weight_norm(), - ] - ) - self.convs1.apply(init_weights) 
- - self.convs2 = nn.ModuleList( - [ - FishConvNet( - channels, channels, kernel_size, stride=1, dilation=dilation[0] - ).weight_norm(), - FishConvNet( - channels, channels, kernel_size, stride=1, dilation=dilation[1] - ).weight_norm(), - FishConvNet( - channels, channels, kernel_size, stride=1, dilation=dilation[2] - ).weight_norm(), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.silu(x) - xt = c1(xt) - xt = F.silu(xt) - xt = c2(xt) - x = xt + x - return x - - def remove_parametrizations(self): - for conv in self.convs1: - remove_parametrizations(conv, tensor_name="weight") - for conv in self.convs2: - remove_parametrizations(conv, tensor_name="weight") - - -class ParallelBlock(nn.Module): - def __init__( - self, - channels: int, - kernel_sizes: tuple[int] = (3, 7, 11), - dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), - ): - super().__init__() - - assert len(kernel_sizes) == len(dilation_sizes) - - self.blocks = nn.ModuleList() - for k, d in zip(kernel_sizes, dilation_sizes): - self.blocks.append(ResBlock1(channels, k, d)) - - def forward(self, x): - return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) - - def remove_parametrizations(self): - for block in self.blocks: - block.remove_parametrizations() - - -class HiFiGANGenerator(nn.Module): - def __init__( - self, - *, - hop_length: int = 512, - upsample_rates: tuple[int] = (8, 8, 2, 2, 2), - upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), - resblock_kernel_sizes: tuple[int] = (3, 7, 11), - resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), - num_mels: int = 128, - upsample_initial_channel: int = 512, - pre_conv_kernel_size: int = 7, - post_conv_kernel_size: int = 7, - post_activation: Callable = partial(nn.SiLU, inplace=True), - ): - super().__init__() - - assert ( - prod(upsample_rates) == hop_length - ), f"hop_length must be {prod(upsample_rates)}" - - self.conv_pre = FishConvNet( - num_mels, - upsample_initial_channel, - pre_conv_kernel_size, - stride=1, - ).weight_norm() - - self.num_upsamples = len(upsample_rates) - self.num_kernels = len(resblock_kernel_sizes) - - self.noise_convs = nn.ModuleList() - self.ups = nn.ModuleList() - - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append( - FishTransConvNet( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - stride=u, - ).weight_norm() - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - self.resblocks.append( - ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) - ) - - self.activation_post = post_activation() - self.conv_post = FishConvNet( - ch, 1, post_conv_kernel_size, stride=1 - ).weight_norm() - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - - for i in range(self.num_upsamples): - x = F.silu(x, inplace=True) - x = self.ups[i](x) - - if self.training and self.checkpointing: - x = checkpoint( - self.resblocks[i], - x, - use_reentrant=False, - ) - else: - x = self.resblocks[i](x) - - x = self.activation_post(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_parametrizations(self): - for up in self.ups: - remove_parametrizations(up, tensor_name="weight") - for block in self.resblocks: - block.remove_parametrizations() - remove_parametrizations(self.conv_pre, 
tensor_name="weight") - remove_parametrizations(self.conv_post, tensor_name="weight") - - -# DropPath copied from timm library -def drop_path( - x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True -): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. - - """ # noqa: E501 - - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and scale_by_keep: - random_tensor.div_(keep_prob) - return x * random_tensor - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 - - def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) - - def extra_repr(self): - return f"drop_prob={round(self.drop_prob,3):0.3f}" - - -class LayerNorm(nn.Module): - r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ # noqa: E501 - - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm( - x, self.normalized_shape, self.weight, self.bias, self.eps - ) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None] * x + self.bias[:, None] - return x - - -# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py -class ConvNeXtBlock(nn.Module): - r"""ConvNeXt Block. There are two equivalent implementations: - (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) - (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back - We use (2) as we find it slightly faster in PyTorch - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. 
- kernel_size (int): Kernel size for depthwise conv. Default: 7. - dilation (int): Dilation for depthwise conv. Default: 1. - """ # noqa: E501 - - def __init__( - self, - dim: int, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-6, - mlp_ratio: float = 4.0, - kernel_size: int = 7, - dilation: int = 1, - ): - super().__init__() - - self.dwconv = FishConvNet( - dim, - dim, - kernel_size=kernel_size, - # padding=int(dilation * (kernel_size - 1) / 2), - groups=dim, - ) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear( - dim, int(mlp_ratio * dim) - ) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) - self.gamma = ( - nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) - if layer_scale_init_value > 0 - else None - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - def forward(self, x, apply_residual: bool = True): - input = x - - x = self.dwconv(x) - x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - - if self.gamma is not None: - x = self.gamma * x - - x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) - x = self.drop_path(x) - - if apply_residual: - x = input + x - - return x - - -class ConvNeXtEncoder(nn.Module): - def __init__( - self, - input_channels: int = 3, - depths: list[int] = [3, 3, 9, 3], - dims: list[int] = [96, 192, 384, 768], - drop_path_rate: float = 0.0, - layer_scale_init_value: float = 1e-6, - kernel_size: int = 7, - ): - super().__init__() - assert len(depths) == len(dims) - - self.downsample_layers = nn.ModuleList() - stem = nn.Sequential( - FishConvNet( - input_channels, - dims[0], - kernel_size=7, - # padding=3, - # padding_mode="replicate", - # padding_mode="zeros", - ), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), - ) - self.downsample_layers.append(stem) - - for i in range(len(depths) - 1): - mid_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), - ) - self.downsample_layers.append(mid_layer) - - self.stages = nn.ModuleList() - dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - - cur = 0 - for i in range(len(depths)): - stage = nn.Sequential( - *[ - ConvNeXtBlock( - dim=dims[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value, - kernel_size=kernel_size, - ) - for j in range(depths[i]) - ] - ) - self.stages.append(stage) - cur += depths[i] - - self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - nn.init.constant_(m.bias, 0) - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - for i in range(len(self.downsample_layers)): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - - return self.norm(x) - - -class FireflyArchitecture(nn.Module): - def __init__( - self, - backbone: nn.Module, - head: nn.Module, - quantizer: nn.Module, - spec_transform: nn.Module, - ): - super().__init__() - - self.backbone = backbone - self.head = head - self.quantizer = quantizer - self.spec_transform = spec_transform - self.downsample_factor = math.prod(self.quantizer.downsample_factor) - - def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor: - if 
self.spec_transform is not None: - x = self.spec_transform(x) - - x = self.backbone(x) - if mask is not None: - x = x * mask - - if self.quantizer is not None: - vq_result = self.quantizer(x) - x = vq_result.z - - if mask is not None: - x = x * mask - - x = self.head(x, template=template) - - if x.ndim == 2: - x = x[:, None, :] - - if self.vq is not None: - return x, vq_result - - return x - - def encode(self, audios, audio_lengths): - audios = audios.float() - - mels = self.spec_transform(audios) - mel_lengths = audio_lengths // self.spec_transform.hop_length - mel_masks = sequence_mask(mel_lengths, mels.shape[2]) - mel_masks_float_conv = mel_masks[:, None, :].float() - mels = mels * mel_masks_float_conv - - # Encode - encoded_features = self.backbone(mels) * mel_masks_float_conv - feature_lengths = mel_lengths // self.downsample_factor - - return self.quantizer.encode(encoded_features), feature_lengths - - def decode(self, indices, feature_lengths) -> torch.Tensor: - mel_masks = sequence_mask( - feature_lengths * self.downsample_factor, - indices.shape[2] * self.downsample_factor, - ) - mel_masks_float_conv = mel_masks[:, None, :].float() - audio_lengths = ( - feature_lengths * self.downsample_factor * self.spec_transform.hop_length - ) - - audio_masks = sequence_mask( - audio_lengths, - indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length, - ) - audio_masks_float_conv = audio_masks[:, None, :].float() - - z = self.quantizer.decode(indices) * mel_masks_float_conv - x = self.head(z) * audio_masks_float_conv - - return x, audio_lengths - - def remove_parametrizations(self): - if hasattr(self.backbone, "remove_parametrizations"): - self.backbone.remove_parametrizations() - - if hasattr(self.head, "remove_parametrizations"): - self.head.remove_parametrizations() - - @property - def device(self): - return next(self.parameters()).device +import math +from functools import partial +from math import prod +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from torch.utils.checkpoint import checkpoint + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv1D") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +def unpad1d(x: torch.Tensor, paddings: tuple[int, int]): + """Remove padding from x, handling properly zero padding. 
Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "zeros", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right + before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class FishConvNet(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1 + ): + super(FishConvNet, self).__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + + def forward(self, x): + pad = self.kernel_size - self.stride + extra_padding = get_extra_padding_for_conv1d( + x, self.kernel_size, self.stride, pad + ) + x = pad1d(x, (pad, extra_padding), mode="constant", value=0) + return self.conv(x).contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) + return self + + +class FishTransConvNet(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1): + super(FishTransConvNet, self).__init__() + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride=stride, dilation=dilation + ) + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + x = self.conv(x) + pad = self.kernel_size - self.stride + padding_right = math.ceil(pad) + padding_left = pad - padding_right + x = unpad1d(x, (padding_left, padding_right)) + return x.contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) + return self + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + 
).weight_norm(), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.silu(x) + xt = c1(xt) + xt = F.silu(xt) + xt = c2(xt) + x = xt + x + return x + + def remove_parametrizations(self): + for conv in self.convs1: + conv.remove_parametrizations() + for conv in self.convs2: + conv.remove_parametrizations() + + +class ParallelBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_sizes: tuple[int] = (3, 7, 11), + dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + ): + super().__init__() + + assert len(kernel_sizes) == len(dilation_sizes) + + self.blocks = nn.ModuleList() + for k, d in zip(kernel_sizes, dilation_sizes): + self.blocks.append(ResBlock1(channels, k, d)) + + def forward(self, x): + return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) + + def remove_parametrizations(self): + for block in self.blocks: + block.remove_parametrizations() + + +class HiFiGANGenerator(nn.Module): + def __init__( + self, + *, + hop_length: int = 512, + upsample_rates: tuple[int] = (8, 8, 2, 2, 2), + upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), + resblock_kernel_sizes: tuple[int] = (3, 7, 11), + resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + num_mels: int = 128, + upsample_initial_channel: int = 512, + pre_conv_kernel_size: int = 7, + post_conv_kernel_size: int = 7, + post_activation: Callable = partial(nn.SiLU, inplace=True), + ): + super().__init__() + + assert ( + prod(upsample_rates) == hop_length + ), f"hop_length must be {prod(upsample_rates)}" + + self.conv_pre = FishConvNet( + num_mels, + upsample_initial_channel, + pre_conv_kernel_size, + stride=1, + ).weight_norm() + + self.num_upsamples = len(upsample_rates) + self.num_kernels = len(resblock_kernel_sizes) + + self.noise_convs = nn.ModuleList() + self.ups = nn.ModuleList() + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + FishTransConvNet( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + stride=u, + ).weight_norm() + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.resblocks.append( + ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) + ) + + self.activation_post = post_activation() + self.conv_post = FishConvNet( + ch, 1, post_conv_kernel_size, stride=1 + ).weight_norm() + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.silu(x, inplace=True) + x = self.ups[i](x) + + if self.training and self.checkpointing: + x = checkpoint( + self.resblocks[i], + x, + use_reentrant=False, + ) + else: + x = self.resblocks[i](x) + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_parametrizations(self): + for up in self.ups: + up.remove_parametrizations() + for block in self.resblocks: + block.remove_parametrizations() + self.conv_pre.remove_parametrizations() + 
self.conv_post.remove_parametrizations() + + +# DropPath copied from timm library +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ # noqa: E501 + + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ # noqa: E501 + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None] * x + self.bias[:, None] + return x + + +# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + kernel_size (int): Kernel size for depthwise conv. Default: 7. 
+ dilation (int): Dilation for depthwise conv. Default: 1. + """ # noqa: E501 + + def __init__( + self, + dim: int, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6, + mlp_ratio: float = 4.0, + kernel_size: int = 7, + dilation: int = 1, + ): + super().__init__() + + self.dwconv = FishConvNet( + dim, + dim, + kernel_size=kernel_size, + # padding=int(dilation * (kernel_size - 1) / 2), + groups=dim, + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, int(mlp_ratio * dim) + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, apply_residual: bool = True): + input = x + + x = self.dwconv(x) + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma * x + + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) + x = self.drop_path(x) + + if apply_residual: + x = input + x + + return x + + +class ConvNeXtEncoder(nn.Module): + def __init__( + self, + input_channels: int = 3, + depths: list[int] = [3, 3, 9, 3], + dims: list[int] = [96, 192, 384, 768], + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + kernel_size: int = 7, + ): + super().__init__() + assert len(depths) == len(dims) + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + FishConvNet( + input_channels, + dims[0], + kernel_size=7, + # padding=3, + # padding_mode="replicate", + # padding_mode="zeros", + ), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + + for i in range(len(depths) - 1): + mid_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), + ) + self.downsample_layers.append(mid_layer) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + cur = 0 + for i in range(len(depths)): + stage = nn.Sequential( + *[ + ConvNeXtBlock( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + kernel_size=kernel_size, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + for i in range(len(self.downsample_layers)): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + + return self.norm(x) + + +class FireflyArchitecture(nn.Module): + def __init__( + self, + backbone: nn.Module, + head: nn.Module, + quantizer: nn.Module, + spec_transform: nn.Module, + ): + super().__init__() + + self.backbone = backbone + self.head = head + self.quantizer = quantizer + self.spec_transform = spec_transform + self.downsample_factor = math.prod(self.quantizer.downsample_factor) + + def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor: + if self.spec_transform is not None: + x = self.spec_transform(x) + + x = 
self.backbone(x) + if mask is not None: + x = x * mask + + if self.quantizer is not None: + vq_result = self.quantizer(x) + x = vq_result.z + + if mask is not None: + x = x * mask + + x = self.head(x, template=template) + + if x.ndim == 2: + x = x[:, None, :] + + if self.vq is not None: + return x, vq_result + + return x + + def encode(self, audios, audio_lengths): + audios = audios.float() + + mels = self.spec_transform(audios) + mel_lengths = audio_lengths // self.spec_transform.hop_length + mel_masks = sequence_mask(mel_lengths, mels.shape[2]) + mel_masks_float_conv = mel_masks[:, None, :].float() + mels = mels * mel_masks_float_conv + + # Encode + encoded_features = self.backbone(mels) * mel_masks_float_conv + feature_lengths = mel_lengths // self.downsample_factor + + return self.quantizer.encode(encoded_features), feature_lengths + + def decode(self, indices, feature_lengths) -> torch.Tensor: + mel_masks = sequence_mask( + feature_lengths * self.downsample_factor, + indices.shape[2] * self.downsample_factor, + ) + mel_masks_float_conv = mel_masks[:, None, :].float() + audio_lengths = ( + feature_lengths * self.downsample_factor * self.spec_transform.hop_length + ) + + audio_masks = sequence_mask( + audio_lengths, + indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length, + ) + audio_masks_float_conv = audio_masks[:, None, :].float() + + z = self.quantizer.decode(indices) * mel_masks_float_conv + x = self.head(z) * audio_masks_float_conv + + return x, audio_lengths + + def remove_parametrizations(self): + if hasattr(self.backbone, "remove_parametrizations"): + self.backbone.remove_parametrizations() + + if hasattr(self.head, "remove_parametrizations"): + self.head.remove_parametrizations() + + @property + def device(self): + return next(self.parameters()).device diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py index 7ea4853376b6e663404ff48d6c6b5f664dde4094..4cc565f7a187b88f49f9c8a540971889bb10ce7e 100644 --- a/fish_speech/models/vqgan/modules/fsq.py +++ b/fish_speech/models/vqgan/modules/fsq.py @@ -1,116 +1,116 @@ -from dataclasses import dataclass - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from vector_quantize_pytorch import GroupedResidualFSQ - -from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet - - -@dataclass -class FSQResult: - z: torch.Tensor - codes: torch.Tensor - latents: torch.Tensor - - -class DownsampleFiniteScalarQuantize(nn.Module): - def __init__( - self, - input_dim: int = 512, - n_codebooks: int = 9, - n_groups: int = 1, - levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 - downsample_factor: tuple[int] = (2, 2), - downsample_dims: tuple[int] | None = None, - ): - super().__init__() - - if downsample_dims is None: - downsample_dims = [input_dim for _ in range(len(downsample_factor))] - - all_dims = (input_dim,) + tuple(downsample_dims) - - self.residual_fsq = GroupedResidualFSQ( - dim=all_dims[-1], - levels=levels, - num_quantizers=n_codebooks, - groups=n_groups, - ) - - self.downsample_factor = downsample_factor - self.downsample_dims = downsample_dims - - self.downsample = nn.Sequential( - *[ - nn.Sequential( - FishConvNet( - all_dims[idx], - all_dims[idx + 1], - kernel_size=factor, - stride=factor, - ), - ConvNeXtBlock(dim=all_dims[idx + 1]), - ) - for idx, factor in enumerate(downsample_factor) - ] - ) - - self.upsample = nn.Sequential( - *[ - nn.Sequential( - FishTransConvNet( - all_dims[idx + 1], - all_dims[idx], - 
kernel_size=factor, - stride=factor, - ), - ConvNeXtBlock(dim=all_dims[idx]), - ) - for idx, factor in reversed(list(enumerate(downsample_factor))) - ] - ) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - nn.init.constant_(m.bias, 0) - - def forward(self, z) -> FSQResult: - original_shape = z.shape - z = self.downsample(z) - quantized, indices = self.residual_fsq(z.mT) - result = FSQResult( - z=quantized.mT, - codes=indices.mT, - latents=z, - ) - result.z = self.upsample(result.z) - - # Pad or crop z to match original shape - diff = original_shape[-1] - result.z.shape[-1] - left = diff // 2 - right = diff - left - - if diff > 0: - result.z = F.pad(result.z, (left, right)) - elif diff < 0: - result.z = result.z[..., left:-right] - - return result - - def encode(self, z): - z = self.downsample(z) - _, indices = self.residual_fsq(z.mT) - indices = rearrange(indices, "g b l r -> b (g r) l") - return indices - - def decode(self, indices: torch.Tensor): - indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) - z_q = self.residual_fsq.get_output_from_indices(indices) - z_q = self.upsample(z_q.mT) - return z_q +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from vector_quantize_pytorch import GroupedResidualFSQ + +from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet + + +@dataclass +class FSQResult: + z: torch.Tensor + codes: torch.Tensor + latents: torch.Tensor + + +class DownsampleFiniteScalarQuantize(nn.Module): + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + n_groups: int = 1, + levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 + downsample_factor: tuple[int] = (2, 2), + downsample_dims: tuple[int] | None = None, + ): + super().__init__() + + if downsample_dims is None: + downsample_dims = [input_dim for _ in range(len(downsample_factor))] + + all_dims = (input_dim,) + tuple(downsample_dims) + + self.residual_fsq = GroupedResidualFSQ( + dim=all_dims[-1], + levels=levels, + num_quantizers=n_codebooks, + groups=n_groups, + ) + + self.downsample_factor = downsample_factor + self.downsample_dims = downsample_dims + + self.downsample = nn.Sequential( + *[ + nn.Sequential( + FishConvNet( + all_dims[idx], + all_dims[idx + 1], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx + 1]), + ) + for idx, factor in enumerate(downsample_factor) + ] + ) + + self.upsample = nn.Sequential( + *[ + nn.Sequential( + FishTransConvNet( + all_dims[idx + 1], + all_dims[idx], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx]), + ) + for idx, factor in reversed(list(enumerate(downsample_factor))) + ] + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, z) -> FSQResult: + original_shape = z.shape + z = self.downsample(z) + quantized, indices = self.residual_fsq(z.mT) + result = FSQResult( + z=quantized.mT, + codes=indices.mT, + latents=z, + ) + result.z = self.upsample(result.z) + + # Pad or crop z to match original shape + diff = original_shape[-1] - result.z.shape[-1] + left = diff // 2 + right = diff - left + + if diff > 0: + result.z = F.pad(result.z, (left, right)) + elif diff < 0: + result.z = result.z[..., -left:right] + + return result + + 
def encode(self, z): + z = self.downsample(z) + _, indices = self.residual_fsq(z.mT) + indices = rearrange(indices, "g b l r -> b (g r) l") + return indices + + def decode(self, indices: torch.Tensor): + indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) + z_q = self.residual_fsq.get_output_from_indices(indices) + z_q = self.upsample(z_q.mT) + return z_q diff --git a/fish_speech/models/vqgan/modules/reference.py b/fish_speech/models/vqgan/modules/reference.py deleted file mode 100644 index 034d5c5e3572bd3828649fc0f82a1856ccc6b9e1..0000000000000000000000000000000000000000 --- a/fish_speech/models/vqgan/modules/reference.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from .wavenet import WaveNet - - -class ReferenceEncoder(WaveNet): - def __init__( - self, - input_channels: Optional[int] = None, - output_channels: Optional[int] = None, - residual_channels: int = 512, - residual_layers: int = 20, - dilation_cycle: Optional[int] = 4, - num_heads: int = 8, - latent_len: int = 4, - ): - super().__init__( - input_channels=input_channels, - residual_channels=residual_channels, - residual_layers=residual_layers, - dilation_cycle=dilation_cycle, - ) - - self.head_dim = residual_channels // num_heads - self.num_heads = num_heads - - self.latent_len = latent_len - self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels)) - - self.q = nn.Linear(residual_channels, residual_channels, bias=True) - self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True) - self.q_norm = nn.LayerNorm(self.head_dim) - self.k_norm = nn.LayerNorm(self.head_dim) - self.proj = nn.Linear(residual_channels, residual_channels) - self.proj_drop = nn.Dropout(0.1) - - self.norm = nn.LayerNorm(residual_channels) - self.mlp = nn.Sequential( - nn.Linear(residual_channels, residual_channels * 4), - nn.SiLU(), - nn.Linear(residual_channels * 4, residual_channels), - ) - self.output_projection_attn = nn.Linear(residual_channels, output_channels) - - torch.nn.init.trunc_normal_(self.latent, std=0.02) - self.apply(self.init_weights) - - def init_weights(self, m): - if isinstance(m, nn.Linear): - torch.nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - - def forward(self, x, attn_mask=None): - x = super().forward(x).mT - B, N, C = x.shape - - # Calculate mask - if attn_mask is not None: - assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool - - attn_mask = attn_mask[:, None, None, :].expand( - B, self.num_heads, self.latent_len, N - ) - - q_latent = self.latent.expand(B, -1, -1) - q = ( - self.q(q_latent) - .reshape(B, self.latent_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - - kv = ( - self.kv(x) - .reshape(B, N, 2, self.num_heads, self.head_dim) - .permute(2, 0, 3, 1, 4) - ) - k, v = kv.unbind(0) - - q, k = self.q_norm(q), self.k_norm(k) - x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - - x = x.transpose(1, 2).reshape(B, self.latent_len, C) - x = self.proj(x) - x = self.proj_drop(x) - - x = x + self.mlp(self.norm(x)) - x = self.output_projection_attn(x) - x = x.mean(1) - - return x - - -if __name__ == "__main__": - with torch.autocast(device_type="cpu", dtype=torch.bfloat16): - model = ReferenceEncoder( - input_channels=128, - output_channels=64, - residual_channels=384, - residual_layers=20, - dilation_cycle=4, - num_heads=8, - ) - x = torch.randn(4, 128, 64) - mask = torch.ones(4, 
64, dtype=torch.bool) - y = model(x, mask) - print(y.shape) - loss = F.mse_loss(y, torch.randn(4, 64)) - loss.backward() diff --git a/fish_speech/models/vqgan/modules/wavenet.py b/fish_speech/models/vqgan/modules/wavenet.py deleted file mode 100644 index e7cc011c3e071067ff36e1aba12c05cff81d94f6..0000000000000000000000000000000000000000 --- a/fish_speech/models/vqgan/modules/wavenet.py +++ /dev/null @@ -1,225 +0,0 @@ -import math -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - - -class Mish(nn.Module): - def forward(self, x): - return x * torch.tanh(F.softplus(x)) - - -class DiffusionEmbedding(nn.Module): - """Diffusion Step Embedding""" - - def __init__(self, d_denoiser): - super(DiffusionEmbedding, self).__init__() - self.dim = d_denoiser - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -class LinearNorm(nn.Module): - """LinearNorm Projection""" - - def __init__(self, in_features, out_features, bias=False): - super(LinearNorm, self).__init__() - self.linear = nn.Linear(in_features, out_features, bias) - - nn.init.xavier_uniform_(self.linear.weight) - if bias: - nn.init.constant_(self.linear.bias, 0.0) - - def forward(self, x): - x = self.linear(x) - return x - - -class ConvNorm(nn.Module): - """1D Convolution""" - - def __init__( - self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=None, - dilation=1, - bias=True, - w_init_gain="linear", - ): - super(ConvNorm, self).__init__() - - if padding is None: - assert kernel_size % 2 == 1 - padding = int(dilation * (kernel_size - 1) / 2) - - self.conv = nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - nn.init.kaiming_normal_(self.conv.weight) - - def forward(self, signal): - conv_signal = self.conv(signal) - - return conv_signal - - -class ResidualBlock(nn.Module): - """Residual Block""" - - def __init__( - self, - residual_channels, - use_linear_bias=False, - dilation=1, - condition_channels=None, - ): - super(ResidualBlock, self).__init__() - self.conv_layer = ConvNorm( - residual_channels, - 2 * residual_channels, - kernel_size=3, - stride=1, - padding=dilation, - dilation=dilation, - ) - - if condition_channels is not None: - self.diffusion_projection = LinearNorm( - residual_channels, residual_channels, use_linear_bias - ) - self.condition_projection = ConvNorm( - condition_channels, 2 * residual_channels, kernel_size=1 - ) - - self.output_projection = ConvNorm( - residual_channels, 2 * residual_channels, kernel_size=1 - ) - - def forward(self, x, condition=None, diffusion_step=None): - y = x - - if diffusion_step is not None: - diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) - y = y + diffusion_step - - y = self.conv_layer(y) - - if condition is not None: - condition = self.condition_projection(condition) - y = y + condition - - gate, filter = torch.chunk(y, 2, dim=1) - y = torch.sigmoid(gate) * torch.tanh(filter) - - y = self.output_projection(y) - residual, skip = torch.chunk(y, 2, dim=1) - - return (x + residual) / math.sqrt(2.0), skip - - -class WaveNet(nn.Module): - def __init__( - self, - input_channels: Optional[int] = None, - output_channels: Optional[int] = None, - residual_channels: int = 512, - 
residual_layers: int = 20, - dilation_cycle: Optional[int] = 4, - is_diffusion: bool = False, - condition_channels: Optional[int] = None, - ): - super().__init__() - - # Input projection - self.input_projection = None - if input_channels is not None and input_channels != residual_channels: - self.input_projection = ConvNorm( - input_channels, residual_channels, kernel_size=1 - ) - - if input_channels is None: - input_channels = residual_channels - - self.input_channels = input_channels - - # Residual layers - self.residual_layers = nn.ModuleList( - [ - ResidualBlock( - residual_channels=residual_channels, - use_linear_bias=False, - dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1, - condition_channels=condition_channels, - ) - for i in range(residual_layers) - ] - ) - - # Skip projection - self.skip_projection = ConvNorm( - residual_channels, residual_channels, kernel_size=1 - ) - - # Output projection - self.output_projection = None - if output_channels is not None and output_channels != residual_channels: - self.output_projection = ConvNorm( - residual_channels, output_channels, kernel_size=1 - ) - - if is_diffusion: - self.diffusion_embedding = DiffusionEmbedding(residual_channels) - self.mlp = nn.Sequential( - LinearNorm(residual_channels, residual_channels * 4, False), - Mish(), - LinearNorm(residual_channels * 4, residual_channels, False), - ) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - if getattr(m, "bias", None) is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x, t=None, condition=None): - if self.input_projection is not None: - x = self.input_projection(x) - x = F.silu(x) - - if t is not None: - t = self.diffusion_embedding(t) - t = self.mlp(t) - - skip = [] - for layer in self.residual_layers: - x, skip_connection = layer(x, condition, t) - skip.append(skip_connection) - - x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) - x = self.skip_projection(x) - - if self.output_projection is not None: - x = F.silu(x) - x = self.output_projection(x) - - return x diff --git a/fish_speech/models/vqgan/spectrogram.py b/fish_speech/models/vqgan/spectrogram.py deleted file mode 100644 index 01c3d7a2ab0f707ae92dbde0feb173927720c841..0000000000000000000000000000000000000000 --- a/fish_speech/models/vqgan/spectrogram.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import torchaudio.functional as F -from torch import Tensor, nn -from torchaudio.transforms import MelScale - - -class LinearSpectrogram(nn.Module): - def __init__( - self, - n_fft=2048, - win_length=2048, - hop_length=512, - center=False, - mode="pow2_sqrt", - ): - super().__init__() - - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.mode = mode - - self.register_buffer("window", torch.hann_window(win_length), persistent=False) - - def forward(self, y: Tensor) -> Tensor: - if y.ndim == 3: - y = y.squeeze(1) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - ( - (self.win_length - self.hop_length) // 2, - (self.win_length - self.hop_length + 1) // 2, - ), - mode="reflect", - ).squeeze(1) - - spec = torch.stft( - y, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - - spec = torch.view_as_real(spec) - - if self.mode == "pow2_sqrt": - spec = 
torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - return spec - - -class LogMelSpectrogram(nn.Module): - def __init__( - self, - sample_rate=44100, - n_fft=2048, - win_length=2048, - hop_length=512, - n_mels=128, - center=False, - f_min=0.0, - f_max=None, - ): - super().__init__() - - self.sample_rate = sample_rate - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.n_mels = n_mels - self.f_min = f_min - self.f_max = f_max or float(sample_rate // 2) - - self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) - - fb = F.melscale_fbanks( - n_freqs=self.n_fft // 2 + 1, - f_min=self.f_min, - f_max=self.f_max, - n_mels=self.n_mels, - sample_rate=self.sample_rate, - norm="slaney", - mel_scale="slaney", - ) - self.register_buffer( - "fb", - fb, - persistent=False, - ) - - def compress(self, x: Tensor) -> Tensor: - return torch.log(torch.clamp(x, min=1e-5)) - - def decompress(self, x: Tensor) -> Tensor: - return torch.exp(x) - - def apply_mel_scale(self, x: Tensor) -> Tensor: - return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) - - def forward( - self, x: Tensor, return_linear: bool = False, sample_rate: int = None - ) -> Tensor: - if sample_rate is not None and sample_rate != self.sample_rate: - x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) - - linear = self.spectrogram(x) - x = self.apply_mel_scale(linear) - x = self.compress(x) - - if return_linear: - return x, self.compress(linear) - - return x diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py index b90c131d214006875476a161cdfd2dffa8949dac..6e9948fa462485d1404f9f2ef4fc1d15ba1438d8 100644 --- a/fish_speech/models/vqgan/utils.py +++ b/fish_speech/models/vqgan/utils.py @@ -1,94 +1,94 @@ -import matplotlib -import torch -from matplotlib import pyplot as plt - -matplotlib.use("Agg") - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def plot_mel(data, titles=None): - fig, axes = plt.subplots(len(data), 1, squeeze=False) - - if titles is None: - titles = [None for i in range(len(data))] - - plt.tight_layout() - - for i in range(len(data)): - mel = data[i] - - if isinstance(mel, torch.Tensor): - mel = mel.float().detach().cpu().numpy() - - axes[i][0].imshow(mel, origin="lower") - axes[i][0].set_aspect(2.5, adjustable="box") - axes[i][0].set_ylim(0, mel.shape[0]) - axes[i][0].set_title(titles[i], fontsize="medium") - axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) - axes[i][0].set_anchor("W") - - return fig - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) 
- ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(in_act, n_channels): - n_channels_int = n_channels[0] - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - - return acts - - -def avg_with_mask(x, mask): - assert mask.dtype == torch.float, "Mask should be float" - - if mask.ndim == 2: - mask = mask.unsqueeze(1) - - if mask.shape[1] == 1: - mask = mask.expand_as(x) - - return (x * mask).sum() / mask.sum() +import matplotlib +import torch +from matplotlib import pyplot as plt + +matplotlib.use("Agg") + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def plot_mel(data, titles=None): + fig, axes = plt.subplots(len(data), 1, squeeze=False) + + if titles is None: + titles = [None for i in range(len(data))] + + plt.tight_layout() + + for i in range(len(data)): + mel = data[i] + + if isinstance(mel, torch.Tensor): + mel = mel.float().detach().cpu().numpy() + + axes[i][0].imshow(mel, origin="lower") + axes[i][0].set_aspect(2.5, adjustable="box") + axes[i][0].set_ylim(0, mel.shape[0]) + axes[i][0].set_title(titles[i], fontsize="medium") + axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) + axes[i][0].set_anchor("W") + + return fig + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(in_act, n_channels): + n_channels_int = n_channels[0] + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + + return acts + + +def avg_with_mask(x, mask): + assert mask.dtype == torch.float, "Mask should be float" + + if mask.ndim == 2: + mask = mask.unsqueeze(1) + + if mask.shape[1] == 1: + mask = mask.expand_as(x) + + return (x * mask).sum() / mask.sum() diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py index 43bed6a2210723a7d5e1ea0a48ba61140047ca29..d5162e7b0e06f6af0ce12b4739fa0178b92a5c11 100644 --- a/fish_speech/scheduler.py +++ b/fish_speech/scheduler.py @@ -1,40 +1,40 @@ -import math - - -def get_cosine_schedule_with_warmup_lr_lambda( - current_step: int, - *, - num_warmup_steps: int | float, - num_training_steps: int, - num_cycles: float = 0.5, - final_lr_ratio: float = 0.0, -): - if 0 < num_warmup_steps < 1: # float mode - 
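The utils module restored verbatim above also provides sequence_mask, which turns per-item lengths into a boolean padding mask. A small illustrative check (the lengths are made up):

```python
import torch

from fish_speech.models.vqgan.utils import sequence_mask

lengths = torch.tensor([2, 4])
mask = sequence_mask(lengths)   # max_length defaults to lengths.max(), so shape is (2, 4)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```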
num_warmup_steps = int(num_warmup_steps * num_training_steps) - - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - - progress = float(current_step - num_warmup_steps) / float( - max(1, num_training_steps - num_warmup_steps) - ) - - return max( - final_lr_ratio, - 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), - ) - - -def get_constant_schedule_with_warmup_lr_lambda( - current_step: int, - *, - num_warmup_steps: int | float, - num_training_steps: int | None = None, -): - if 0 < num_warmup_steps < 1: # float mode - num_warmup_steps = int(num_warmup_steps * num_training_steps) - - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - - return 1.0 +import math + + +def get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int, + num_cycles: float = 0.5, + final_lr_ratio: float = 0.0, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + + return max( + final_lr_ratio, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + +def get_constant_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int | None = None, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + return 1.0 diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py index d740bd8eed447d162e55b165965dec17130377ce..c35b4e0dbd174b36229350f27a21b5acf0e9825b 100644 --- a/fish_speech/text/__init__.py +++ b/fish_speech/text/__init__.py @@ -1,4 +1,4 @@ -from .clean import clean_text -from .spliter import split_text - -__all__ = ["clean_text", "split_text"] +from .clean import clean_text +from .spliter import split_text + +__all__ = ["clean_text", "split_text"] diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore index 75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89..ddfe6a00ed1a83cc3607108e78dd4fe55117d1a8 100644 --- a/fish_speech/text/chn_text_norm/.gitignore +++ b/fish_speech/text/chn_text_norm/.gitignore @@ -1,114 +1,114 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
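fish_speech/scheduler.py is likewise restored unchanged. Its functions compute a per-step learning-rate multiplier, so they can be bound with functools.partial and handed to torch.optim.lr_scheduler.LambdaLR; the model, optimizer, and hyperparameter values in this sketch are placeholders for illustration only:

```python
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

model = torch.nn.Linear(8, 8)                               # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.1,      # a float in (0, 1) is treated as a fraction of num_training_steps
    num_training_steps=1000,
    final_lr_ratio=0.1,        # the cosine decay never drops below 10% of the base LR
)
scheduler = LambdaLR(optimizer, lr_lambda)

for _ in range(1000):
    optimizer.step()
    scheduler.step()
```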
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ - -# JetBrains PyCharm -.idea - -# Customize -references -url.txt - -# Git -.git +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# JetBrains PyCharm +.idea + +# Customize +references +url.txt + +# Git +.git diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md index 8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6..8bc7827df4a8e3c3af7e896a5af9e5e368a7c4ba 100644 --- a/fish_speech/text/chn_text_norm/README.md +++ b/fish_speech/text/chn_text_norm/README.md @@ -1,36 +1,36 @@ -# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. - -# Chn Text Norm - -this is a repository for chinese text normalization (no longer maintained). - -## Quick Start ## - -### Git Clone Repo ### - -git clone this repo to the root directory of your project which need to use it. - - cd /path/to/proj - git clone https://github.com/Joee1995/chn-text-norm.git - -after that, your doc tree should be: -``` -proj # root of your project -|--- chn_text_norm # this chn-text-norm tool - |--- text.py - |--- ... -|--- text_normalize.py # your text normalization code -|--- ... -``` - -### How to Use ? 
### - - # text_normalize.py - from chn_text_norm.text import * - - raw_text = 'your raw text' - text = Text(raw_text=raw_text).normalize() - -### How to add quantums ### - -打开test.py,然后你就知道怎么做了。 +# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. + +# Chn Text Norm + +this is a repository for chinese text normalization (no longer maintained). + +## Quick Start ## + +### Git Clone Repo ### + +git clone this repo to the root directory of your project which need to use it. + + cd /path/to/proj + git clone https://github.com/Joee1995/chn-text-norm.git + +after that, your doc tree should be: +``` +proj # root of your project +|--- chn_text_norm # this chn-text-norm tool + |--- text.py + |--- ... +|--- text_normalize.py # your text normalization code +|--- ... +``` + +### How to Use ? ### + + # text_normalize.py + from chn_text_norm.text import * + + raw_text = 'your raw text' + text = Text(raw_text=raw_text).normalize() + +### How to add quantums ### + +打开test.py,然后你就知道怎么做了。 diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py index 58d8f8eb7fc85d0861f106667d8f4e3e52b54761..1e1e3fa30b2d191d87374e0e21fd88dd430fde95 100644 --- a/fish_speech/text/chn_text_norm/basic_class.py +++ b/fish_speech/text/chn_text_norm/basic_class.py @@ -1,172 +1,172 @@ -# -*- coding: utf-8 -*- -"""基本类 -中文字符类 -中文数字/数位类 -中文数字类 -中文数位类 -中文数字系统类 -中文数学符号类 -*中文其他符号类 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-02" - -from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES - - -class ChineseChar(object): - """ - 中文字符 - 每个字符对应简体和繁体, - e.g. 简体 = '负', 繁体 = '負' - 转换时可转换为简体或繁体 - """ - - def __init__(self, simplified, traditional): - self.simplified = simplified - self.traditional = traditional - self.__repr__ = self.__str__ - - def __str__(self): - return self.simplified or self.traditional or None - - def __repr__(self): - return self.__str__() - - -class ChineseNumberUnit(ChineseChar): - """ - 中文数字/数位字符 - 每个字符除繁简体外还有一个额外的大写字符 - e.g. 
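The README above documents the upstream chn-text-norm project; in this repository the same code is vendored under fish_speech.text.chn_text_norm, so the README's snippet translates to the sketch below (the input string is just an example):

```python
from fish_speech.text.chn_text_norm.text import Text

raw_text = "固话:0595-23865596,日期:1999年2月20日。"  # example input
normalized = Text(raw_text=raw_text).normalize()
print(normalized)  # numbers, dates, money, etc. are spelled out as Chinese words
```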
'陆' 和 '陸' - """ - - def __init__(self, power, simplified, traditional, big_s, big_t): - super(ChineseNumberUnit, self).__init__(simplified, traditional) - self.power = power - self.big_s = big_s - self.big_t = big_t - - def __str__(self): - return "10^{}".format(self.power) - - @classmethod - def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): - - if small_unit: - return ChineseNumberUnit( - power=index + 1, - simplified=value[0], - traditional=value[1], - big_s=value[1], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[0]: - return ChineseNumberUnit( - power=index + 8, - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[1]: - return ChineseNumberUnit( - power=(index + 2) * 4, - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - elif numbering_type == NUMBERING_TYPES[2]: - return ChineseNumberUnit( - power=pow(2, index + 3), - simplified=value[0], - traditional=value[1], - big_s=value[0], - big_t=value[1], - ) - else: - raise ValueError( - "Counting type should be in {0} ({1} provided).".format( - NUMBERING_TYPES, numbering_type - ) - ) - - -class ChineseNumberDigit(ChineseChar): - """ - 中文数字字符 - """ - - def __init__( - self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None - ): - super(ChineseNumberDigit, self).__init__(simplified, traditional) - self.value = value - self.big_s = big_s - self.big_t = big_t - self.alt_s = alt_s - self.alt_t = alt_t - - def __str__(self): - return str(self.value) - - @classmethod - def create(cls, i, v): - return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) - - -class ChineseMath(ChineseChar): - """ - 中文数位字符 - """ - - def __init__(self, simplified, traditional, symbol, expression=None): - super(ChineseMath, self).__init__(simplified, traditional) - self.symbol = symbol - self.expression = expression - self.big_s = simplified - self.big_t = traditional - - -CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath - - -class NumberSystem(object): - """ - 中文数字系统 - """ - - pass - - -class MathSymbol(object): - """ - 用于中文数字系统的数学符号 (繁/简体), e.g. - positive = ['正', '正'] - negative = ['负', '負'] - point = ['点', '點'] - """ - - def __init__(self, positive, negative, point): - self.positive = positive - self.negative = negative - self.point = point - - def __iter__(self): - for v in self.__dict__.values(): - yield v - - -# class OtherSymbol(object): -# """ -# 其他符号 -# """ -# -# def __init__(self, sil): -# self.sil = sil -# -# def __iter__(self): -# for v in self.__dict__.values(): -# yield v +# -*- coding: utf-8 -*- +"""基本类 +中文字符类 +中文数字/数位类 +中文数字类 +中文数位类 +中文数字系统类 +中文数学符号类 +*中文其他符号类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES + + +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. 
'陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit( + power=index + 1, + simplified=value[0], + traditional=value[1], + big_s=value[1], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + else: + raise ValueError( + "Counting type should be in {0} ({1} provided).".format( + NUMBERING_TYPES, numbering_type + ) + ) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__( + self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None + ): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py index 9a65991b9a9d349a0571c80508633951e52749ef..213ca23e0f75e90abd091a7b61e5adff67d75377 100644 --- a/fish_speech/text/chn_text_norm/basic_constant.py +++ b/fish_speech/text/chn_text_norm/basic_constant.py @@ -1,30 +1,30 @@ -# -*- coding: utf-8 -*- -"""基本常量 -中文数字/数位/符号字符常量 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-02" - -CHINESE_DIGIS = "零一二三四五六七八九" -BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" -BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" -SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" -SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" -LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" -LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" -SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" -SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" - -ZERO_ALT = "〇" -ONE_ALT = "幺" -TWO_ALTS = ["两", "兩"] - -POSITIVE = ["正", "正"] -NEGATIVE = ["负", "負"] -POINT = ["点", "點"] -# PLUS = [u'加', u'加'] -# SIL = [u'杠', u'槓'] - -# 中文数字系统类型 -NUMBERING_TYPES = ["low", "mid", "high"] +# -*- coding: utf-8 -*- +"""基本常量 +中文数字/数位/符号字符常量 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py index dbf6130be87f285eed9998186508ea489d3bac9e..5dc91d2222136068d500624b46876c06b6277c2a 100644 --- a/fish_speech/text/chn_text_norm/basic_util.py +++ b/fish_speech/text/chn_text_norm/basic_util.py @@ -1,342 +1,342 @@ -# -*- coding: utf-8 -*- -"""基本方法 -创建中文数字系统 方法 -中文字符串 <=> 数字串 方法 -数字串 <=> 中文字符串 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-02" - -from fish_speech.text.chn_text_norm.basic_class import * -from fish_speech.text.chn_text_norm.basic_constant import * - - -def create_system(numbering_type=NUMBERING_TYPES[1]): - """ - 根据数字系统类型返回创建相应的数字系统,默认为 mid - NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 - low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. - mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. - high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. 
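As a quick cross-check of ChineseNumberUnit.create in basic_class.py above, the three numbering types assign these powers of ten to the large units 亿, 兆, 京 (index 0, 1, 2); the loop below just evaluates the same formulas:

```python
# Power-of-ten assignment in ChineseNumberUnit.create for index = 0, 1, 2 (亿, 兆, 京):
for index in range(3):
    low = index + 8          # low : 10^8, 10^9,  10^10
    mid = (index + 2) * 4    # mid : 10^8, 10^12, 10^16  (the default numbering type)
    high = 2 ** (index + 3)  # high: 10^8, 10^16, 10^32
    print(index, low, mid, high)
```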
- 返回对应的数字系统 - """ - - # chinese number units of '亿' and larger - all_larger_units = zip( - LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, - LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, - ) - larger_units = [ - CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) - ] - # chinese number units of '十, 百, 千, 万' - all_smaller_units = zip( - SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, - SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, - ) - smaller_units = [ - CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) - ] - # digis - chinese_digis = zip( - CHINESE_DIGIS, - CHINESE_DIGIS, - BIG_CHINESE_DIGIS_SIMPLIFIED, - BIG_CHINESE_DIGIS_TRADITIONAL, - ) - digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] - digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT - digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT - digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] - - # symbols - positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) - negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) - point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) - # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) - system = NumberSystem() - system.units = smaller_units + larger_units - system.digits = digits - system.math = MathSymbol(positive_cn, negative_cn, point_cn) - # system.symbols = OtherSymbol(sil_cn) - return system - - -def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): - - def get_symbol(char, system): - for u in system.units: - if char in [u.traditional, u.simplified, u.big_s, u.big_t]: - return u - for d in system.digits: - if char in [ - d.traditional, - d.simplified, - d.big_s, - d.big_t, - d.alt_s, - d.alt_t, - ]: - return d - for m in system.math: - if char in [m.traditional, m.simplified]: - return m - - def string2symbols(chinese_string, system): - int_string, dec_string = chinese_string, "" - for p in [system.math.point.simplified, system.math.point.traditional]: - if p in chinese_string: - int_string, dec_string = chinese_string.split(p) - break - return [get_symbol(c, system) for c in int_string], [ - get_symbol(c, system) for c in dec_string - ] - - def correct_symbols(integer_symbols, system): - """ - 一百八 to 一百八十 - 一亿一千三百万 to 一亿 一千万 三百万 - """ - - if integer_symbols and isinstance(integer_symbols[0], CNU): - if integer_symbols[0].power == 1: - integer_symbols = [system.digits[1]] + integer_symbols - - if len(integer_symbols) > 1: - if isinstance(integer_symbols[-1], CND) and isinstance( - integer_symbols[-2], CNU - ): - integer_symbols.append( - CNU(integer_symbols[-2].power - 1, None, None, None, None) - ) - - result = [] - unit_count = 0 - for s in integer_symbols: - if isinstance(s, CND): - result.append(s) - unit_count = 0 - elif isinstance(s, CNU): - current_unit = CNU(s.power, None, None, None, None) - unit_count += 1 - - if unit_count == 1: - result.append(current_unit) - elif unit_count > 1: - for i in range(len(result)): - if ( - isinstance(result[-i - 1], CNU) - and result[-i - 1].power < current_unit.power - ): - result[-i - 1] = CNU( - result[-i - 1].power + current_unit.power, - None, - None, - None, - None, - ) - return result - - def compute_value(integer_symbols): - """ - Compute the value. - When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. - e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 - """ - value = [0] - last_power = 0 - for s in integer_symbols: - if isinstance(s, CND): - value[-1] = s.value - elif isinstance(s, CNU): - value[-1] *= pow(10, s.power) - if s.power > last_power: - value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) - last_power = s.power - value.append(0) - return sum(value) - - system = create_system(numbering_type) - int_part, dec_part = string2symbols(chinese_string, system) - int_part = correct_symbols(int_part, system) - int_str = str(compute_value(int_part)) - dec_str = "".join([str(d.value) for d in dec_part]) - if dec_part: - return "{0}.{1}".format(int_str, dec_str) - else: - return int_str - - -def num2chn( - number_string, - numbering_type=NUMBERING_TYPES[1], - big=False, - traditional=False, - alt_zero=False, - alt_one=False, - alt_two=True, - use_zeros=True, - use_units=True, -): - - def get_value(value_string, use_zeros=True): - - striped_string = value_string.lstrip("0") - - # record nothing if all zeros - if not striped_string: - return [] - - # record one digits - elif len(striped_string) == 1: - if use_zeros and len(value_string) != len(striped_string): - return [system.digits[0], system.digits[int(striped_string)]] - else: - return [system.digits[int(striped_string)]] - - # recursively record multiple digits - else: - result_unit = next( - u for u in reversed(system.units) if u.power < len(striped_string) - ) - result_string = value_string[: -result_unit.power] - return ( - get_value(result_string) - + [result_unit] - + get_value(striped_string[-result_unit.power :]) - ) - - system = create_system(numbering_type) - - int_dec = number_string.split(".") - if len(int_dec) == 1: - int_string = int_dec[0] - dec_string = "" - elif len(int_dec) == 2: - int_string = int_dec[0] - dec_string = int_dec[1] - else: - raise ValueError( - "invalid input num string with more than one dot: {}".format(number_string) - ) - - if use_units and len(int_string) > 1: - result_symbols = get_value(int_string) - else: - result_symbols = [system.digits[int(c)] for c in int_string] - dec_symbols = [system.digits[int(c)] for c in dec_string] - if dec_string: - result_symbols += [system.math.point] + dec_symbols - - if alt_two: - liang = CND( - 2, - system.digits[2].alt_s, - system.digits[2].alt_t, - system.digits[2].big_s, - system.digits[2].big_t, - ) - for i, v in enumerate(result_symbols): - if isinstance(v, CND) and v.value == 2: - next_symbol = ( - result_symbols[i + 1] if i < len(result_symbols) - 1 else None - ) - previous_symbol = result_symbols[i - 1] if i > 0 else None - if isinstance(next_symbol, CNU) and isinstance( - previous_symbol, (CNU, type(None)) - ): - if next_symbol.power != 1 and ( - (previous_symbol is None) or (previous_symbol.power != 1) - ): - result_symbols[i] = liang - - # if big is True, '两' will not be used and `alt_two` has no impact on output - if big: - attr_name = "big_" - if traditional: - attr_name += "t" - else: - attr_name += "s" - else: - if traditional: - attr_name = "traditional" - else: - attr_name = "simplified" - - result = "".join([getattr(s, attr_name) for s in result_symbols]) - - # if not use_zeros: - # result = result.strip(getattr(system.digits[0], attr_name)) - - if alt_zero: - result = result.replace( - getattr(system.digits[0], attr_name), system.digits[0].alt_s - ) - - if alt_one: - result = result.replace( - getattr(system.digits[1], attr_name), system.digits[1].alt_s - ) - - for i, p in enumerate(POINT): - if result.startswith(p): - return CHINESE_DIGIS[0] + 
result - - # ^10, 11, .., 19 - if ( - len(result) >= 2 - and result[1] - in [ - SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], - SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], - ] - and result[0] - in [ - CHINESE_DIGIS[1], - BIG_CHINESE_DIGIS_SIMPLIFIED[1], - BIG_CHINESE_DIGIS_TRADITIONAL[1], - ] - ): - result = result[1:] - - return result - - -if __name__ == "__main__": - - # 测试程序 - all_chinese_number_string = ( - CHINESE_DIGIS - + BIG_CHINESE_DIGIS_SIMPLIFIED - + BIG_CHINESE_DIGIS_TRADITIONAL - + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED - + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL - + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED - + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL - + ZERO_ALT - + ONE_ALT - + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT) - ) - - print("num:", chn2num("一万零四百零三点八零五")) - print("num:", chn2num("一亿六点三")) - print("num:", chn2num("一亿零六点三")) - print("num:", chn2num("两千零一亿六点三")) - # print('num:', chn2num('一零零八六')) - print("txt:", num2chn("10260.03", alt_zero=True)) - print("txt:", num2chn("20037.090", numbering_type="low", traditional=True)) - print("txt:", num2chn("100860001.77", numbering_type="high", big=True)) - print( - "txt:", - num2chn( - "059523810880", - alt_one=True, - alt_two=False, - use_lzeros=True, - use_rzeros=True, - use_units=False, - ), - ) - - print(all_chinese_number_string) +# -*- coding: utf-8 -*- +"""基本方法 +创建中文数字系统 方法 +中文字符串 <=> 数字串 方法 +数字串 <=> 中文字符串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_class import * +from fish_speech.text.chn_text_norm.basic_constant import * + + +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + larger_units = [ + CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) + ] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + smaller_units = [ + CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) + ] + # digis + chinese_digis = zip( + CHINESE_DIGIS, + CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, + BIG_CHINESE_DIGIS_TRADITIONAL, + ) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." 
+ str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [ + d.traditional, + d.simplified, + d.big_s, + d.big_t, + d.alt_s, + d.alt_t, + ]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [ + get_symbol(c, system) for c in dec_string + ] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance( + integer_symbols[-2], CNU + ): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None) + ) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if ( + isinstance(result[-i - 1], CNU) + and result[-i - 1].power < current_unit.power + ): + result[-i - 1] = CNU( + result[-i - 1].power + current_unit.power, + None, + None, + None, + None, + ) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next( + u for u in reversed(system.units) if u.power < len(striped_string) + ) + result_string = value_string[: -result_unit.power] + return ( + get_value(result_string) + + [result_unit] + + get_value(striped_string[-result_unit.power :]) + ) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string) + ) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND( + 2, + system.digits[2].alt_s, + system.digits[2].alt_t, + system.digits[2].big_s, + system.digits[2].big_t, + ) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = ( + result_symbols[i + 1] if i < len(result_symbols) - 1 else None + ) + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance( + previous_symbol, (CNU, type(None)) + ): + if next_symbol.power != 1 and ( + (previous_symbol is None) or (previous_symbol.power != 1) + ): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s + ) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s + ) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + 
result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] + in [ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], + ] + and result[0] + in [ + CHINESE_DIGIS[1], + BIG_CHINESE_DIGIS_SIMPLIFIED[1], + BIG_CHINESE_DIGIS_TRADITIONAL[1], + ] + ): + result = result[1:] + + return result + + +if __name__ == "__main__": + + # 测试程序 + all_chinese_number_string = ( + CHINESE_DIGIS + + BIG_CHINESE_DIGIS_SIMPLIFIED + + BIG_CHINESE_DIGIS_TRADITIONAL + + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL + + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL + + ZERO_ALT + + ONE_ALT + + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT) + ) + + print("num:", chn2num("一万零四百零三点八零五")) + print("num:", chn2num("一亿六点三")) + print("num:", chn2num("一亿零六点三")) + print("num:", chn2num("两千零一亿六点三")) + # print('num:', chn2num('一零零八六')) + print("txt:", num2chn("10260.03", alt_zero=True)) + print("txt:", num2chn("20037.090", numbering_type="low", traditional=True)) + print("txt:", num2chn("100860001.77", numbering_type="high", big=True)) + print( + "txt:", + num2chn( + "059523810880", + alt_one=True, + alt_two=False, + use_lzeros=True, + use_rzeros=True, + use_units=False, + ), + ) + + print(all_chinese_number_string) diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py index ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616..8e9ac0b25e5f542bdee933959976297ca45d2646 100644 --- a/fish_speech/text/chn_text_norm/cardinal.py +++ b/fish_speech/text/chn_text_norm/cardinal.py @@ -1,32 +1,32 @@ -# -*- coding: utf-8 -*- -"""CARDINAL类 (包含小数DECIMAL类) -纯数 <=> 中文字符串 方法 -中文字符串 <=> 纯数 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Cardinal: - """ - CARDINAL类 - """ - - def __init__(self, cardinal=None, chntext=None): - self.cardinal = cardinal - self.chntext = chntext - - def chntext2cardinal(self): - return chn2num(self.chntext) - - def cardinal2chntext(self): - return num2chn(self.cardinal) - - -if __name__ == "__main__": - - # 测试程序 - print(Cardinal(cardinal="21357.230").cardinal2chntext()) +# -*- coding: utf-8 -*- +"""CARDINAL类 (包含小数DECIMAL类) +纯数 <=> 中文字符串 方法 +中文字符串 <=> 纯数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +if __name__ == "__main__": + + # 测试程序 + print(Cardinal(cardinal="21357.230").cardinal2chntext()) diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py index 77acfdb9a91df0fe3c615a0784f61aad87fbe56e..4473cc4cc24b412ed4fd6c9f2664495cc9b00c3c 100644 --- a/fish_speech/text/chn_text_norm/date.py +++ b/fish_speech/text/chn_text_norm/date.py @@ -1,75 +1,75 @@ -# -*- coding: utf-8 -*- -"""DATE类 -日期 <=> 中文字符串 方法 -中文字符串 <=> 日期 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-07" - -from fish_speech.text.chn_text_norm.cardinal import Cardinal -from fish_speech.text.chn_text_norm.digit import Digit - - -class Date: - """ - DATE类 - """ - - def __init__(self, date=None, chntext=None): - self.date = date - self.chntext = chntext - - # def chntext2date(self): - # chntext = 
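basic_util.py above carries the two workhorse conversions. A few hand-checked spot checks of their behaviour:

```python
from fish_speech.text.chn_text_norm.basic_util import chn2num, num2chn

assert chn2num("一百二十三") == "123"
assert num2chn("123") == "一百二十三"
assert chn2num("一万零四百零三点八零五") == "10403.805"
```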
self.chntext - # try: - # year, other = chntext.strip().split('年', maxsplit=1) - # year = Digit(chntext=year).digit2chntext() + '年' - # except ValueError: - # other = chntext - # year = '' - # if other: - # try: - # month, day = other.strip().split('月', maxsplit=1) - # month = Cardinal(chntext=month).chntext2cardinal() + '月' - # except ValueError: - # day = chntext - # month = '' - # if day: - # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] - # else: - # month = '' - # day = '' - # date = year + month + day - # self.date = date - # return self.date - - def date2chntext(self): - date = self.date - try: - year, other = date.strip().split("年", maxsplit=1) - year = Digit(digit=year).digit2chntext() + "年" - except ValueError: - other = date - year = "" - if other: - try: - month, day = other.strip().split("月", maxsplit=1) - month = Cardinal(cardinal=month).cardinal2chntext() + "月" - except ValueError: - day = date - month = "" - if day: - day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] - else: - month = "" - day = "" - chntext = year + month + day - self.chntext = chntext - return self.chntext - - -if __name__ == "__main__": - - # 测试 - print(Date(date="09年3月16日").date2chntext()) +# -*- coding: utf-8 -*- +"""DATE类 +日期 <=> 中文字符串 方法 +中文字符串 <=> 日期 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-07" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.digit import Digit + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split("年", maxsplit=1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", maxsplit=1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +if __name__ == "__main__": + + # 测试 + print(Date(date="09年3月16日").date2chntext()) diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py index 47c0cd4ad0c700635f84470bfdacfbdafb4a6185..37af7a0d746e6ecf36060f5331654b575b141dc6 100644 --- a/fish_speech/text/chn_text_norm/digit.py +++ b/fish_speech/text/chn_text_norm/digit.py @@ -1,32 +1,32 @@ -# -*- coding: utf-8 -*- -"""DIGIT类 -数字串 <=> 中文字符串 方法 -中文字符串 <=> 数字串 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Digit: - """ - DIGIT类 - """ - - def __init__(self, digit=None, chntext=None): - self.digit = digit - self.chntext = chntext - - # def chntext2digit(self): - # return 
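Date.date2chntext above reads the year digit by digit and the month and day as cardinals; tracing the module's own test input by hand:

```python
from fish_speech.text.chn_text_norm.date import Date

# "09" is spelled digit by digit, "3" and "16" as cardinals
# ("16" is further reduced from 一十六 to 十六 by num2chn's teens rule).
print(Date(date="09年3月16日").date2chntext())  # -> 零九年三月十六日
```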
chn2num(self.chntext) - - def digit2chntext(self): - return num2chn(self.digit, alt_two=False, use_units=False) - - -if __name__ == "__main__": - - # 测试程序 - print(Digit(digit="2016").digit2chntext()) +# -*- coding: utf-8 -*- +"""DIGIT类 +数字串 <=> 中文字符串 方法 +中文字符串 <=> 数字串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +if __name__ == "__main__": + + # 测试程序 + print(Digit(digit="2016").digit2chntext()) diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py index b43b6a7feb634d346d59a2b4ab84b77ac88df103..18c02c01638ad2fd1849b9f6dd884c05f536faca 100644 --- a/fish_speech/text/chn_text_norm/fraction.py +++ b/fish_speech/text/chn_text_norm/fraction.py @@ -1,35 +1,35 @@ -# -*- coding: utf-8 -*- -"""FRACTION类 -分数 <=> 中文字符串 方法 -中文字符串 <=> 分数 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Fraction: - """ - FRACTION类 - """ - - def __init__(self, fraction=None, chntext=None): - self.fraction = fraction - self.chntext = chntext - - def chntext2fraction(self): - denominator, numerator = self.chntext.split("分之") - return chn2num(numerator) + "/" + chn2num(denominator) - - def fraction2chntext(self): - numerator, denominator = self.fraction.split("/") - return num2chn(denominator) + "分之" + num2chn(numerator) - - -if __name__ == "__main__": - - # 测试程序 - print(Fraction(fraction="2135/7230").fraction2chntext()) - print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) +# -*- coding: utf-8 -*- +"""FRACTION类 +分数 <=> 中文字符串 方法 +中文字符串 <=> 分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +if __name__ == "__main__": + + # 测试程序 + print(Fraction(fraction="2135/7230").fraction2chntext()) + print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py index b4c980d32134e1460e96e5bcbcc73d0d55974d2a..a927e06e9ed34432b699fdcae23e0ed3dbf639f4 100644 --- a/fish_speech/text/chn_text_norm/money.py +++ b/fish_speech/text/chn_text_norm/money.py @@ -1,43 +1,43 @@ -# -*- coding: utf-8 -*- -"""MONEY类 -金钱 <=> 中文字符串 方法 -中文字符串 <=> 金钱 方法 -""" -import re - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-08" - -from fish_speech.text.chn_text_norm.cardinal import Cardinal - - -class Money: - """ - MONEY类 - """ - - def __init__(self, money=None, chntext=None): - self.money = money - self.chntext = chntext - - # def chntext2money(self): - # return self.money - - def money2chntext(self): - money = self.money - pattern = re.compile(r"(\d+(\.\d+)?)") - matchers = pattern.findall(money) - if matchers: - for matcher in matchers: - money = money.replace( - 
matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() - ) - self.chntext = money - return self.chntext - - -if __name__ == "__main__": - - # 测试 - print(Money(money="21.5万元").money2chntext()) - print(Money(money="230块5毛").money2chntext()) +# -*- coding: utf-8 -*- +"""MONEY类 +金钱 <=> 中文字符串 方法 +中文字符串 <=> 金钱 方法 +""" +import re + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-08" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() + ) + self.chntext = money + return self.chntext + + +if __name__ == "__main__": + + # 测试 + print(Money(money="21.5万元").money2chntext()) + print(Money(money="230块5毛").money2chntext()) diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py index 46abbf545af62eb951d8f6fe40bcf684587f81b0..ef36735a6551686c686f393f5c41525e1cf411e5 100644 --- a/fish_speech/text/chn_text_norm/percentage.py +++ b/fish_speech/text/chn_text_norm/percentage.py @@ -1,33 +1,33 @@ -# -*- coding: utf-8 -*- -"""PERCENTAGE类 -百分数 <=> 中文字符串 方法 -中文字符串 <=> 百分数 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-06" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class Percentage: - """ - PERCENTAGE类 - """ - - def __init__(self, percentage=None, chntext=None): - self.percentage = percentage - self.chntext = chntext - - def chntext2percentage(self): - return chn2num(self.chntext.strip().strip("百分之")) + "%" - - def percentage2chntext(self): - return "百分之" + num2chn(self.percentage.strip().strip("%")) - - -if __name__ == "__main__": - - # 测试程序 - print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) - print(Percentage(percentage="65.3%").percentage2chntext()) +# -*- coding: utf-8 -*- +"""PERCENTAGE类 +百分数 <=> 中文字符串 方法 +中文字符串 <=> 百分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-06" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +if __name__ == "__main__": + + # 测试程序 + print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) + print(Percentage(percentage="65.3%").percentage2chntext()) diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py index e72b546db628a3b807dc6235b59b188cae3153ff..46a673191f241caa0b25c9a5e54fc09cd283cc85 100644 --- a/fish_speech/text/chn_text_norm/telephone.py +++ b/fish_speech/text/chn_text_norm/telephone.py @@ -1,51 +1,51 @@ -# -*- coding: utf-8 -*- -"""TELEPHONE类 -电话号码 <=> 中文字符串 方法 -中文字符串 <=> 电话号码 方法 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -from fish_speech.text.chn_text_norm.basic_util import * - - -class TelePhone: - """ - TELEPHONE类 - """ - - def __init__(self, telephone=None, raw_chntext=None, chntext=None): - self.telephone = telephone - self.raw_chntext = raw_chntext - self.chntext = chntext - - # 
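Money and Percentage above rewrite only the numeric part and keep the surrounding unit text. Two hand-checked examples:

```python
from fish_speech.text.chn_text_norm.money import Money
from fish_speech.text.chn_text_norm.percentage import Percentage

print(Money(money="21.5万元").money2chntext())              # -> 二十一点五万元
print(Percentage(percentage="65.3%").percentage2chntext())  # -> 百分之六十五点三
```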
def chntext2telephone(self): - # sil_parts = self.raw_chntext.split('') - # self.telephone = '-'.join([ - # str(chn2num(p)) for p in sil_parts - # ]) - # return self.telephone - - def telephone2chntext(self, fixed=False): - - if fixed: - sil_parts = self.telephone.split("-") - self.raw_chntext = "".join( - [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] - ) - self.chntext = self.raw_chntext.replace("", "") - else: - sp_parts = self.telephone.strip("+").split() - self.raw_chntext = "".join( - [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] - ) - self.chntext = self.raw_chntext.replace("", "") - return self.chntext - - -if __name__ == "__main__": - - # 测试程序 - print(TelePhone(telephone="0595-23980880").telephone2chntext()) - # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) +# -*- coding: utf-8 -*- +"""TELEPHONE类 +电话号码 <=> 中文字符串 方法 +中文字符串 <=> 电话号码 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +if __name__ == "__main__": + + # 测试程序 + print(TelePhone(telephone="0595-23980880").telephone2chntext()) + # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py index 54086fd933c01e14c3c55cee9adb52eefb58fd31..dfae9cb476cf074d4134cdf4ada995132451b1ab 100644 --- a/fish_speech/text/chn_text_norm/text.py +++ b/fish_speech/text/chn_text_norm/text.py @@ -1,177 +1,177 @@ -# -*- coding: utf-8 -*- -""" -TEXT类 -""" - -__author__ = "Zhiyang Zhou " -__data__ = "2019-05-03" - -import re - -from fish_speech.text.chn_text_norm.cardinal import Cardinal -from fish_speech.text.chn_text_norm.date import Date -from fish_speech.text.chn_text_norm.digit import Digit -from fish_speech.text.chn_text_norm.fraction import Fraction -from fish_speech.text.chn_text_norm.money import Money -from fish_speech.text.chn_text_norm.percentage import Percentage -from fish_speech.text.chn_text_norm.telephone import TelePhone - -CURRENCY_NAMES = ( - "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" - "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" -) -CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" -COM_QUANTIFIERS = ( - "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" - "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" - "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" - "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" - "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" - "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)" -) - 
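TelePhone.telephone2chntext above spells every digit out individually rather than using place-value units; in the fixed-line path each hyphen-separated group is converted on its own and the groups are then concatenated. The expected output below is hand-derived, not taken from a test run:

```python
from fish_speech.text.chn_text_norm.telephone import TelePhone

print(TelePhone(telephone="0595-23980880").telephone2chntext(fixed=True))
# -> 零五九五二三九八零八八零
```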
- -class Text: - """ - Text类 - """ - - def __init__(self, raw_text, norm_text=None): - self.raw_text = "^" + raw_text + "$" - self.norm_text = norm_text - - def _particular(self): - text = self.norm_text - pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") - matchers = pattern.findall(text) - if matchers: - # print('particular') - for matcher in matchers: - text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) - self.norm_text = text - return self.norm_text - - def normalize(self): - text = self.raw_text - - # 规范化日期 - pattern = re.compile( - r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)" - ) - matchers = pattern.findall(text) - if matchers: - # print('date') - for matcher in matchers: - text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) - - # 规范化金钱 - pattern = re.compile( - r"\D+((\d+(\.\d+)?)[多余几]?" - + CURRENCY_UNITS - + "(\d" - + CURRENCY_UNITS - + "?)?)" - ) - matchers = pattern.findall(text) - if matchers: - # print('money') - for matcher in matchers: - text = text.replace( - matcher[0], Money(money=matcher[0]).money2chntext(), 1 - ) - - # 规范化固话/手机号码 - # 手机 - # http://www.jihaoba.com/news/show/13680 - # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 - # 联通:130、131、132、156、155、186、185、176 - # 电信:133、153、189、180、181、177 - pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") - matchers = pattern.findall(text) - if matchers: - # print('telephone') - for matcher in matchers: - text = text.replace( - matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 - ) - # 固话 - pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") - matchers = pattern.findall(text) - if matchers: - # print('fixed telephone') - for matcher in matchers: - text = text.replace( - matcher[0], - TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), - 1, - ) - - # 规范化分数 - pattern = re.compile(r"(\d+/\d+)") - matchers = pattern.findall(text) - if matchers: - # print('fraction') - for matcher in matchers: - text = text.replace( - matcher, Fraction(fraction=matcher).fraction2chntext(), 1 - ) - - # 规范化百分数 - text = text.replace("%", "%") - pattern = re.compile(r"(\d+(\.\d+)?%)") - matchers = pattern.findall(text) - if matchers: - # print('percentage') - for matcher in matchers: - text = text.replace( - matcher[0], - Percentage(percentage=matcher[0]).percentage2chntext(), - 1, - ) - - # 规范化纯数+量词 - pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) - matchers = pattern.findall(text) - if matchers: - # print('cardinal+quantifier') - for matcher in matchers: - text = text.replace( - matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 - ) - - # 规范化数字编号 - pattern = re.compile(r"(\d{4,32})") - matchers = pattern.findall(text) - if matchers: - # print('digit') - for matcher in matchers: - text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) - - # 规范化纯数 - pattern = re.compile(r"(\d+(\.\d+)?)") - matchers = pattern.findall(text) - if matchers: - # print('cardinal') - for matcher in matchers: - text = text.replace( - matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 - ) - - self.norm_text = text - self._particular() - - return self.norm_text.lstrip("^").rstrip("$") - - -if __name__ == "__main__": - - # 测试程序 - print(Text(raw_text="固话:0595-23865596或23880880。").normalize()) - print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize()) - print(Text(raw_text="分数:32477/76391。").normalize()) - print(Text(raw_text="百分数:80.03%。").normalize()) - print(Text(raw_text="编号:31520181154418。").normalize()) - print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize()) - print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize()) - print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize()) - print(Text(raw_text="特殊:O2O或B2C。").normalize()) +# -*- coding: utf-8 -*- +""" +TEXT类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +import re + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.date import Date +from fish_speech.text.chn_text_norm.digit import Digit +from fish_speech.text.chn_text_norm.fraction import Fraction +from fish_speech.text.chn_text_norm.money import Money +from fish_speech.text.chn_text_norm.percentage import Percentage +from fish_speech.text.chn_text_norm.telephone import TelePhone + +CURRENCY_NAMES = ( + "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" + "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +) +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)" +) + + +class Text: + """ + Text类 + """ + + def __init__(self, raw_text, norm_text=None): + self.raw_text = "^" + raw_text + "$" + self.norm_text = norm_text + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + self.norm_text = text + return self.norm_text + + def normalize(self): + text = self.raw_text + + # 规范化日期 + pattern = re.compile( + r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile( + r"\D+((\d+(\.\d+)?)[多余几]?" 
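Text.normalize applies its regex passes in a fixed order (dates, money, phone numbers, fractions, percentages, numbers with quantifiers, long digit strings, then plain cardinals), each pass rewriting its matches in place. Hand-tracing one of the module's own test inputs, where only the percentage pass finds anything to rewrite:

```python
from fish_speech.text.chn_text_norm.text import Text

print(Text(raw_text="百分数:80.03%。").normalize())  # -> 百分数:百分之八十点零三。
```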
+ + CURRENCY_UNITS + + "(\d" + + CURRENCY_UNITS + + "?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace( + matcher[0], Money(money=matcher[0]).money2chntext(), 1 + ) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace( + matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 + ) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), + 1, + ) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace( + matcher, Fraction(fraction=matcher).fraction2chntext(), 1 + ) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace( + matcher[0], + Percentage(percentage=matcher[0]).percentage2chntext(), + 1, + ) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + self.norm_text = text + self._particular() + + return self.norm_text.lstrip("^").rstrip("$") + + +if __name__ == "__main__": + + # 测试程序 + print(Text(raw_text="固话:0595-23865596或23880880。").normalize()) + print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize()) + print(Text(raw_text="分数:32477/76391。").normalize()) + print(Text(raw_text="百分数:80.03%。").normalize()) + print(Text(raw_text="编号:31520181154418。").normalize()) + print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize()) + print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize()) + print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize()) + print(Text(raw_text="特殊:O2O或B2C。").normalize()) diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py index dbaf843d781f113735043319cc00dc2aed5ae382..2aba28fc1bc7fc6054e37534ece06c743bff9f6c 100644 --- a/fish_speech/text/clean.py +++ b/fish_speech/text/clean.py @@ -1,62 +1,37 @@ -import re - -SYMBOLS_MAPPING = { - "\n": "", - "…": ".", - "“": "'", - "”": "'", - "‘": "'", - "’": "'", - "【": "", - "】": "", - "[": "", - "]": "", - "(": "", - ")": "", - "(": "", - ")": "", - "・": "", - "·": "", - "「": "'", - "」": "'", - "《": "'", - "》": "'", - "—": "", - "~": "", - "~": "", - ":": ",", - ";": ",", - ";": 
",", - ":": ",", -} - -REPLACE_SYMBOL_REGEX = re.compile( - "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) -) - - -EMOJI_REGEX = re.compile( - "[" - "\U0001F600-\U0001F64F" # emoticons - "\U0001F300-\U0001F5FF" # symbols & pictographs - "\U0001F680-\U0001F6FF" # transport & map symbols - "\U0001F1E0-\U0001F1FF" # flags (iOS) - "]+", - flags=re.UNICODE, -) - - -def clean_text(text): - # Clean the text - text = text.strip() - - # Replace all chinese symbols with their english counterparts - text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) - - # Remove emojis - text = EMOJI_REGEX.sub(r"", text) - - # Remove continuous periods (...) and commas (,,,) - text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text) - - return text +import re + +SYMBOLS_MAPPING = { + "‘": "'", + "’": "'", +} + +REPLACE_SYMBOL_REGEX = re.compile( + "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) +) + + +EMOJI_REGEX = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map symbols + "\U0001F1E0-\U0001F1FF" # flags (iOS) + "]+", + flags=re.UNICODE, +) + + +def clean_text(text): + # Clean the text + text = text.strip() + + # Replace all chinese symbols with their english counterparts + text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + + # Remove emojis + text = EMOJI_REGEX.sub(r"", text) + + # Remove continuous periods (...) and commas (,,,) + text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text) + + return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py index d4bb995487c4f53818c6b2a16cf0a886b4e02e84..30661e4ef3796250e539aa367467bac22ecbbfb8 100644 --- a/fish_speech/text/spliter.py +++ b/fish_speech/text/spliter.py @@ -1,130 +1,130 @@ -import re -import string - -from fish_speech.text.clean import clean_text - - -def utf_8_len(text): - return len(text.encode("utf-8")) - - -def break_text(texts, length, splits: set): - for text in texts: - if utf_8_len(text) <= length: - yield text - continue - - curr = "" - for char in text: - curr += char - - if char in splits: - yield curr - curr = "" - - if curr: - yield curr - - -def break_text_by_length(texts, length): - for text in texts: - if utf_8_len(text) <= length: - yield text - continue - - curr = "" - for char in text: - curr += char - - if utf_8_len(curr) >= length: - yield curr - curr = "" - - if curr: - yield curr - - -def add_cleaned(curr, segments): - curr = curr.strip() - if curr and not all(c.isspace() or c in string.punctuation for c in curr): - segments.append(curr) - - -def protect_float(text): - # Turns 3.14 into <3_f_14> to prevent splitting - return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) - - -def unprotect_float(text): - # Turns <3_f_14> into 3.14 - return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) - - -def split_text(text, length): - text = clean_text(text) - - # Break the text into pieces with following rules: - # 1. Split the text at ".", "!", "?" if text is NOT a float - # 2. If the text is longer than length, split at "," - # 3. If the text is still longer than length, split at " " - # 4. 
If the text is still longer than length, split at any character to length - - texts = [text] - texts = map(protect_float, texts) - texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) - texts = map(unprotect_float, texts) - texts = break_text(texts, length, {",", ","}) - texts = break_text(texts, length, {" "}) - texts = list(break_text_by_length(texts, length)) - - # Then, merge the texts into segments with length <= length - segments = [] - curr = "" - - for text in texts: - if utf_8_len(curr) + utf_8_len(text) <= length: - curr += text - else: - add_cleaned(curr, segments) - curr = text - - if curr: - add_cleaned(curr, segments) - - return segments - - -if __name__ == "__main__": - # Test the split_text function - - text = "This is a test sentence. This is another test sentence. And a third one." - - assert split_text(text, 50) == [ - "This is a test sentence.", - "This is another test sentence. And a third one.", - ] - assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] - assert split_text(" ", 10) == [] - assert split_text("a", 10) == ["a"] - - text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." - assert split_text(text, 50) == [ - "This is a test sentence with only commas,", - "and no dots, and no exclamation marks,", - "and no question marks, and no newlines.", - ] - - text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." - # First half split at " ", second half split at "," - assert split_text(text, 50) == [ - "This is a test sentence This is a test sentence", - "This is a test sentence. This is a test sentence,", - "This is a test sentence, This is a test sentence.", - ] - - text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" - assert split_text(text, 50) == [ - "这是一段很长的中文文本,", - "而且没有句号,也没有感叹号,", - "也没有问号,也没有换行符.", - ] +import re +import string + +from fish_speech.text.clean import clean_text + + +def utf_8_len(text: str): + return len(text.encode("utf-8")) + + +def break_text(texts, length, splits: set): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if char in splits: + yield curr + curr = "" + + if curr: + yield curr + + +def break_text_by_length(texts, length): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if utf_8_len(curr) >= length: + yield curr + curr = "" + + if curr: + yield curr + + +def add_cleaned(curr, segments): + curr = curr.strip() + if curr and not all(c.isspace() or c in string.punctuation for c in curr): + segments.append(curr) + + +def protect_float(text): + # Turns 3.14 into <3_f_14> to prevent splitting + return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) + + +def unprotect_float(text): + # Turns <3_f_14> into 3.14 + return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) + + +def split_text(text, length): + text = clean_text(text) + + # Break the text into pieces with following rules: + # 1. Split the text at ".", "!", "?" if text is NOT a float + # 2. If the text is longer than length, split at "," + # 3. If the text is still longer than length, split at " " + # 4. 
If the text is still longer than length, split at any character to length + + texts = [text] + texts = map(protect_float, texts) + texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) + texts = map(unprotect_float, texts) + texts = break_text(texts, length, {",", ","}) + texts = break_text(texts, length, {" "}) + texts = list(break_text_by_length(texts, length)) + + # Then, merge the texts into segments with length <= length + segments = [] + curr = "" + + for text in texts: + if utf_8_len(curr) + utf_8_len(text) <= length: + curr += text + else: + add_cleaned(curr, segments) + curr = text + + if curr: + add_cleaned(curr, segments) + + return segments + + +if __name__ == "__main__": + # Test the split_text function + + text = "This is a test sentence. This is another test sentence. And a third one." + + assert split_text(text, 50) == [ + "This is a test sentence.", + "This is another test sentence. And a third one.", + ] + assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] + assert split_text(" ", 10) == [] + assert split_text("a", 10) == ["a"] + + text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." + assert split_text(text, 50) == [ + "This is a test sentence with only commas,", + "and no dots, and no exclamation marks,", + "and no question marks, and no newlines.", + ] + + text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." + # First half split at " ", second half split at "," + assert split_text(text, 50) == [ + "This is a test sentence This is a test sentence", + "This is a test sentence. This is a test sentence,", + "This is a test sentence, This is a test sentence.", + ] + + text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" + assert split_text(text, 50) == [ + "这是一段很长的中文文本,", + "而且没有句号,也没有感叹号,", + "也没有问号,也没有换行符.", + ] diff --git a/fish_speech/tokenizer.py b/fish_speech/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d512d31263dcb2abc95c3a7bf3cd4bde8c4830 --- /dev/null +++ b/fish_speech/tokenizer.py @@ -0,0 +1,152 @@ +import base64 +import json +import logging +from pathlib import Path + +import tiktoken + +logger = logging.getLogger(__name__) + +# This is a modified version of the default pattern from GPT-4o, that better handles punctuations. +FISH_TIKTOKEN_PATTERN = "|".join( + [ + r"(?i:'s|'t|'re|'ve|'m|'ll|'d)", + r"\p{P}", + r"[^\r\n\p{L}\p{N}]?\p{L}+", + r"\p{N}", + r" ?[^\s\p{L}\p{N}]+[\r\n]*", + r"\s*[\r\n]+", + r"\s+(\?!\S)", + r"\s+", + ] +) +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +BOS_TOKEN = "<|begin_of_text|>" +EOS_TOKEN = "<|end_of_text|>" +PAD_TOKEN = "<|pad|>" +IM_START_TOKEN = "<|im_start|>" +IM_END_TOKEN = "<|im_end|>" + +MODALITY_TEXT_TOKEN = "<|text|>" +MODALITY_VOICE_TOKEN = "<|voice|>" +MODALITY_INTERLEAVE_TOKEN = "<|interleave|>" +MODALITY_TOKENS = { + "text": MODALITY_TEXT_TOKEN, + "voice": MODALITY_VOICE_TOKEN, + "interleave": MODALITY_INTERLEAVE_TOKEN, +} + +PLACEHOLDER_TOKEN = [""] * 4 +for i in range(4): + PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>" + +SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>" +SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)] + +# Warning: when you add a new special token, you should only add it to the end of the list. 
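A short note on why the warning above matters (an explanatory sketch added alongside the patch, not part of it): as `FishTokenizer.__init__` further down shows, special-token ids are assigned positionally right after the BPE merge ranks, so inserting or reordering an entry would shift the ids of every later token, including all `<|semantic:i|>` tokens, and silently break checkpoints trained against the old mapping.

# Sketch only (not in the patch): the positional id assignment used by FishTokenizer.__init__.
def assign_special_ids(num_bpe_ranks: int, special_tokens: list[str]) -> dict[str, int]:
    # id = number of BPE merge ranks + position in the special-token list
    return {tok: num_bpe_ranks + i for i, tok in enumerate(special_tokens)}

# With a hypothetical 100-entry BPE table:
#   assign_special_ids(100, ["<|begin_of_text|>", "<|end_of_text|>"])
#   -> {'<|begin_of_text|>': 100, '<|end_of_text|>': 101}
# Appending new tokens keeps existing ids stable; inserting one earlier shifts them all.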
+ALL_SPECIAL_TOKENS = [ + BOS_TOKEN, + EOS_TOKEN, + PAD_TOKEN, + IM_START_TOKEN, + IM_END_TOKEN, + PLACEHOLDER_TOKEN[0], + PLACEHOLDER_TOKEN[1], + PLACEHOLDER_TOKEN[2], + PLACEHOLDER_TOKEN[3], + MODALITY_TEXT_TOKEN, + MODALITY_VOICE_TOKEN, + MODALITY_INTERLEAVE_TOKEN, + *SEMANTIC_TOKENS, +] + + +class FishTokenizer: + def __init__(self, model_path: str) -> None: + mergeable_ranks = self.load_tiktoken_bpe(model_path) + special_token_begin = len(mergeable_ranks) + self.all_special_tokens_with_ids = { + token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS) + } + self.semantic_id_to_token_id = { + i: self.all_special_tokens_with_ids[token] + for i, token in enumerate(SEMANTIC_TOKENS) + } + self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]] + self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]] + + self.tkt_model = tiktoken.core.Encoding( + name=Path(model_path).stem, + pat_str=FISH_TIKTOKEN_PATTERN, + mergeable_ranks=mergeable_ranks, + special_tokens=self.all_special_tokens_with_ids, + ) + + @staticmethod + def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: + data = {} + for line in open(tiktoken_bpe_file).read().splitlines(): + if not line: + continue + token, rank = line.split() + data[base64.b64decode(token)] = int(rank) + return data + + def get_token_id(self, token: str) -> int: + return self.all_special_tokens_with_ids[token] + + def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]: + assert isinstance(s, str) + + subs = [] + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): + subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) + + if allowed_special is True: + allowed_special = self.tkt_model.special_tokens_set + elif allowed_special is False: + allowed_special = set() + + return sum( + self.tkt_model.encode_batch( + subs, allowed_special=allowed_special, disallowed_special=set() + ), + start=[], + ) + + def decode(self, tokens: list[int]) -> str: + return self.tkt_model.decode(tokens) + + def save_pretrained(self, path: str): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + with open(path / "tokenizer.tiktoken", "w") as f: + for token, rank in self.tkt_model._mergeable_ranks.items(): + f.write(f"{base64.b64encode(token).decode()} {rank}\n") + + with open(path / "special_tokens.json", "w") as f: + json.dump( + self.all_special_tokens_with_ids, + f, + indent=2, + ensure_ascii=False, + ) + + @staticmethod + def from_pretrained(path: str): + return FishTokenizer(Path(path) / "tokenizer.tiktoken") + + +if __name__ == "__main__": + tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken") + tokenizer.save_pretrained("checkpoints/fish-speech-0.5B") + tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B") + + print( + [ + tokenizer.decode([i]) + for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}") + ] + ) diff --git a/fish_speech/train.py b/fish_speech/train.py index e61b793c3af812b9e0d5add86b3be210cf27940e..7dc8d2dbb2e744ddcf000736e65f009cdbf66eae 100644 --- a/fish_speech/train.py +++ b/fish_speech/train.py @@ -1,139 +1,141 @@ -import os -import sys -from typing import Optional - -import hydra -import lightning as L -import pyrootutils -import torch -from lightning import Callback, LightningDataModule, LightningModule, Trainer -from lightning.pytorch.loggers import Logger -from lightning.pytorch.strategies import DDPStrategy -from omegaconf import DictConfig, OmegaConf - -os.environ.pop("SLURM_NTASKS", None) 
-os.environ.pop("SLURM_JOB_NAME", None) -os.environ.pop("SLURM_NTASKS_PER_NODE", None) - -# register eval resolver and root -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) - -# Allow TF32 on Ampere GPUs -torch.set_float32_matmul_precision("high") -torch.backends.cudnn.allow_tf32 = True - -# register eval resolver -OmegaConf.register_new_resolver("eval", eval) - -import fish_speech.utils as utils - -log = utils.RankedLogger(__name__, rank_zero_only=True) - - -@utils.task_wrapper -def train(cfg: DictConfig) -> tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. - Args: - cfg (DictConfig): Configuration composed by Hydra. - Returns: - Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. - """ # noqa: E501 - - # set seed for random number generators in pytorch, numpy and python.random - if cfg.get("seed"): - L.seed_everything(cfg.seed, workers=False) - - if cfg.get("deterministic"): - torch.use_deterministic_algorithms(True) - - log.info(f"Instantiating datamodule <{cfg.data._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) - - log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model) - - log.info("Instantiating callbacks...") - callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) - - log.info("Instantiating loggers...") - logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) - - log.info(f"Instantiating trainer <{cfg.trainer._target_}>") - trainer: Trainer = hydra.utils.instantiate( - cfg.trainer, - callbacks=callbacks, - logger=logger, - ) - - object_dict = { - "cfg": cfg, - "datamodule": datamodule, - "model": model, - "callbacks": callbacks, - "logger": logger, - "trainer": trainer, - } - - if logger: - log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) - - if cfg.get("train"): - log.info("Starting training!") - - ckpt_path = cfg.get("ckpt_path") - auto_resume = False - - resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) - if resume_ckpt_path is not None: - ckpt_path = resume_ckpt_path - auto_resume = True - - if ckpt_path is not None: - log.info(f"Resuming from checkpoint: {ckpt_path}") - - # resume weights only is disabled for auto-resume - if cfg.get("resume_weights_only") and auto_resume is False: - log.info("Resuming weights only!") - ckpt = torch.load(ckpt_path, map_location=model.device) - if "state_dict" in ckpt: - ckpt = ckpt["state_dict"] - err = model.load_state_dict(ckpt, strict=False) - log.info(f"Error loading state dict: {err}") - ckpt_path = None - - trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - - train_metrics = trainer.callback_metrics - - if cfg.get("test"): - log.info("Starting testing!") - ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": - log.warning("Best ckpt not found! 
Using current weights for testing...") - ckpt_path = cfg.get("ckpt_path") - - trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - log.info(f"Best ckpt path: {ckpt_path}") - - test_metrics = trainer.callback_metrics - - # merge train and test metrics - metric_dict = {**train_metrics, **test_metrics} - - return metric_dict, object_dict - - -@hydra.main( - version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" -) -def main(cfg: DictConfig) -> Optional[float]: - # train the model - train(cfg) - - -if __name__ == "__main__": - main() +import os + +os.environ["USE_LIBUV"] = "0" +import sys +from typing import Optional + +import hydra +import lightning as L +import pyrootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies import DDPStrategy +from omegaconf import DictConfig, OmegaConf + +os.environ.pop("SLURM_NTASKS", None) +os.environ.pop("SLURM_JOB_NAME", None) +os.environ.pop("SLURM_NTASKS_PER_NODE", None) + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# Allow TF32 on Ampere GPUs +torch.set_float32_matmul_precision("high") +torch.backends.cudnn.allow_tf32 = True + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + +import fish_speech.utils as utils + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + Args: + cfg (DictConfig): Configuration composed by Hydra. + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 
+ """ # noqa: E501 + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=False) + + if cfg.get("deterministic"): + torch.use_deterministic_algorithms(True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + + ckpt_path = cfg.get("ckpt_path") + auto_resume = False + + resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) + if resume_ckpt_path is not None: + ckpt_path = resume_ckpt_path + auto_resume = True + + if ckpt_path is not None: + log.info(f"Resuming from checkpoint: {ckpt_path}") + + # resume weights only is disabled for auto-resume + if cfg.get("resume_weights_only") and auto_resume is False: + log.info("Resuming weights only!") + ckpt = torch.load(ckpt_path, map_location=model.device) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + err = model.load_state_dict(ckpt, strict=False) + log.info(f"Error loading state dict: {err}") + ckpt_path = None + + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! 
Using current weights for testing...") + ckpt_path = cfg.get("ckpt_path") + + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main( + version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" +) +def main(cfg: DictConfig) -> Optional[float]: + # train the model + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py index 05378519dbd18361c639e33413d011e7307c9adb..185517110af6cacfbda3c3d9561e6081d825fca4 100644 --- a/fish_speech/utils/__init__.py +++ b/fish_speech/utils/__init__.py @@ -1,23 +1,24 @@ -from .braceexpand import braceexpand -from .context import autocast_exclude_mps -from .file import get_latest_checkpoint -from .instantiators import instantiate_callbacks, instantiate_loggers -from .logger import RankedLogger -from .logging_utils import log_hyperparameters -from .rich_utils import enforce_tags, print_config_tree -from .utils import extras, get_metric_value, task_wrapper - -__all__ = [ - "enforce_tags", - "extras", - "get_metric_value", - "RankedLogger", - "instantiate_callbacks", - "instantiate_loggers", - "log_hyperparameters", - "print_config_tree", - "task_wrapper", - "braceexpand", - "get_latest_checkpoint", - "autocast_exclude_mps", -] +from .braceexpand import braceexpand +from .context import autocast_exclude_mps +from .file import get_latest_checkpoint +from .instantiators import instantiate_callbacks, instantiate_loggers +from .logger import RankedLogger +from .logging_utils import log_hyperparameters +from .rich_utils import enforce_tags, print_config_tree +from .utils import extras, get_metric_value, set_seed, task_wrapper + +__all__ = [ + "enforce_tags", + "extras", + "get_metric_value", + "RankedLogger", + "instantiate_callbacks", + "instantiate_loggers", + "log_hyperparameters", + "print_config_tree", + "task_wrapper", + "braceexpand", + "get_latest_checkpoint", + "autocast_exclude_mps", + "set_seed", +] diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py index f3ac739f01f7e10e039c68c1157d6c761064f974..8888977ce194fc5caa9e85bcf548e3bc42a3c52c 100644 --- a/fish_speech/utils/braceexpand.py +++ b/fish_speech/utils/braceexpand.py @@ -1,217 +1,217 @@ -""" -Bash-style brace expansion -Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py -License: MIT -""" - -import re -import string -from itertools import chain, product -from typing import Iterable, Iterator, Optional - -__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] - - -class UnbalancedBracesError(ValueError): - pass - - -alphabet = string.ascii_uppercase + string.ascii_lowercase - -int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") -char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") -escape_re = re.compile(r"\\(.)") - - -def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: - """braceexpand(pattern) -> iterator over generated strings - - Returns an iterator over the strings resulting from brace expansion - of pattern. This function implements Brace Expansion as described in - bash(1), with the following limitations: - - * A pattern containing unbalanced braces will raise an - UnbalancedBracesError exception. 
In bash, unbalanced braces will either - be partly expanded or ignored. - - * A mixed-case character range like '{Z..a}' or '{a..Z}' will not - include the characters '[]^_`' between 'Z' and 'a'. - - When escape is True (the default), characters in pattern can be - prefixed with a backslash to cause them not to be interpreted as - special characters for brace expansion (such as '{', '}', ','). - To pass through a a literal backslash, double it ('\\\\'). - - When escape is False, backslashes in pattern have no special - meaning and will be preserved in the output. - - Examples: - - >>> from braceexpand import braceexpand - - # Integer range - >>> list(braceexpand('item{1..3}')) - ['item1', 'item2', 'item3'] - - # Character range - >>> list(braceexpand('{a..c}')) - ['a', 'b', 'c'] - - # Sequence - >>> list(braceexpand('index.html{,.backup}')) - ['index.html', 'index.html.backup'] - - # Nested patterns - >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) - ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] - - # Prefixing an integer with zero causes all numbers to be padded to - # the same width. - >>> list(braceexpand('{07..10}')) - ['07', '08', '09', '10'] - - # An optional increment can be specified for ranges. - >>> list(braceexpand('{a..g..2}')) - ['a', 'c', 'e', 'g'] - - # Ranges can go in both directions. - >>> list(braceexpand('{4..1}')) - ['4', '3', '2', '1'] - - # Numbers can be negative - >>> list(braceexpand('{2..-1}')) - ['2', '1', '0', '-1'] - - # Unbalanced braces raise an exception. - >>> list(braceexpand('{1{2,3}')) - Traceback (most recent call last): - ... - UnbalancedBracesError: Unbalanced braces: '{1{2,3}' - - # By default, the backslash is the escape character. - >>> list(braceexpand(r'{1\\{2,3}')) - ['1{2', '3'] - - # Setting 'escape' to False disables backslash escaping. 
- >>> list(braceexpand(r'\\{1,2}', escape=False)) - ['\\\\1', '\\\\2'] - - """ - return ( - escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) - ) - - -def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: - start = 0 - pos = 0 - bracketdepth = 0 - items: list[Iterable[str]] = [] - - # print 'pattern:', pattern - while pos < len(pattern): - if escape and pattern[pos] == "\\": - pos += 2 - continue - elif pattern[pos] == "{": - if bracketdepth == 0 and pos > start: - # print 'literal:', pattern[start:pos] - items.append([pattern[start:pos]]) - start = pos - bracketdepth += 1 - elif pattern[pos] == "}": - bracketdepth -= 1 - if bracketdepth == 0: - # print 'expression:', pattern[start+1:pos] - expr = pattern[start + 1 : pos] - item = parse_expression(expr, escape) - if item is None: # not a range or sequence - items.extend([["{"], parse_pattern(expr, escape), ["}"]]) - else: - items.append(item) - start = pos + 1 # skip the closing brace - pos += 1 - - if bracketdepth != 0: # unbalanced braces - raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) - - if start < pos: - items.append([pattern[start:]]) - - return ("".join(item) for item in product(*items)) - - -def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: - int_range_match = int_range_re.match(expr) - if int_range_match: - return make_int_range(*int_range_match.groups()) - - char_range_match = char_range_re.match(expr) - if char_range_match: - return make_char_range(*char_range_match.groups()) - - return parse_sequence(expr, escape) - - -def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: - # sequence -> chain(*sequence_items) - start = 0 - pos = 0 - bracketdepth = 0 - items: list[Iterable[str]] = [] - - # print 'sequence:', seq - while pos < len(seq): - if escape and seq[pos] == "\\": - pos += 2 - continue - elif seq[pos] == "{": - bracketdepth += 1 - elif seq[pos] == "}": - bracketdepth -= 1 - elif seq[pos] == "," and bracketdepth == 0: - items.append(parse_pattern(seq[start:pos], escape)) - start = pos + 1 # skip the comma - pos += 1 - - if bracketdepth != 0: - raise UnbalancedBracesError - if not items: - return None - - # part after the last comma (may be the empty string) - items.append(parse_pattern(seq[start:], escape)) - return chain(*items) - - -def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: - if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): - padding = max(len(left), len(right)) - else: - padding = 0 - step = (int(incr) or 1) if incr else 1 - start = int(left) - end = int(right) - r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) - fmt = "%0{}d".format(padding) - return (fmt % i for i in r) - - -def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: - step = (int(incr) or 1) if incr else 1 - start = alphabet.index(left) - end = alphabet.index(right) - if start < end: - return alphabet[start : end + 1 : step] - else: - end = end or -len(alphabet) - return alphabet[start : end - 1 : -step] - - -if __name__ == "__main__": - import doctest - import sys - - failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) - if failed: - sys.exit(1) +""" +Bash-style brace expansion +Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py +License: MIT +""" + +import re +import string +from itertools import chain, product +from typing import Iterable, Iterator, Optional + 
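A quick caller-side usage sketch ahead of the module's own doctest examples below (the shard paths are hypothetical and only illustrate the call): the function is re-exported from fish_speech.utils, as the __init__.py hunk above shows, so callers can expand bash-style patterns into concrete file lists.

# Hypothetical caller-side usage; the braceexpand() generator is defined later in this
# module and re-exported via fish_speech.utils.
from fish_speech.utils import braceexpand

shards = list(braceexpand("data/protos/shard-{000..002}.tar"))
# -> ['data/protos/shard-000.tar', 'data/protos/shard-001.tar', 'data/protos/shard-002.tar']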
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] + + +class UnbalancedBracesError(ValueError): + pass + + +alphabet = string.ascii_uppercase + string.ascii_lowercase + +int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") +char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") +escape_re = re.compile(r"\\(.)") + + +def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: + """braceexpand(pattern) -> iterator over generated strings + + Returns an iterator over the strings resulting from brace expansion + of pattern. This function implements Brace Expansion as described in + bash(1), with the following limitations: + + * A pattern containing unbalanced braces will raise an + UnbalancedBracesError exception. In bash, unbalanced braces will either + be partly expanded or ignored. + + * A mixed-case character range like '{Z..a}' or '{a..Z}' will not + include the characters '[]^_`' between 'Z' and 'a'. + + When escape is True (the default), characters in pattern can be + prefixed with a backslash to cause them not to be interpreted as + special characters for brace expansion (such as '{', '}', ','). + To pass through a a literal backslash, double it ('\\\\'). + + When escape is False, backslashes in pattern have no special + meaning and will be preserved in the output. + + Examples: + + >>> from braceexpand import braceexpand + + # Integer range + >>> list(braceexpand('item{1..3}')) + ['item1', 'item2', 'item3'] + + # Character range + >>> list(braceexpand('{a..c}')) + ['a', 'b', 'c'] + + # Sequence + >>> list(braceexpand('index.html{,.backup}')) + ['index.html', 'index.html.backup'] + + # Nested patterns + >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) + ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] + + # Prefixing an integer with zero causes all numbers to be padded to + # the same width. + >>> list(braceexpand('{07..10}')) + ['07', '08', '09', '10'] + + # An optional increment can be specified for ranges. + >>> list(braceexpand('{a..g..2}')) + ['a', 'c', 'e', 'g'] + + # Ranges can go in both directions. + >>> list(braceexpand('{4..1}')) + ['4', '3', '2', '1'] + + # Numbers can be negative + >>> list(braceexpand('{2..-1}')) + ['2', '1', '0', '-1'] + + # Unbalanced braces raise an exception. + >>> list(braceexpand('{1{2,3}')) + Traceback (most recent call last): + ... + UnbalancedBracesError: Unbalanced braces: '{1{2,3}' + + # By default, the backslash is the escape character. + >>> list(braceexpand(r'{1\\{2,3}')) + ['1{2', '3'] + + # Setting 'escape' to False disables backslash escaping. 
+ >>> list(braceexpand(r'\\{1,2}', escape=False)) + ['\\\\1', '\\\\2'] + + """ + return ( + escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) + ) + + +def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'pattern:', pattern + while pos < len(pattern): + if escape and pattern[pos] == "\\": + pos += 2 + continue + elif pattern[pos] == "{": + if bracketdepth == 0 and pos > start: + # print 'literal:', pattern[start:pos] + items.append([pattern[start:pos]]) + start = pos + bracketdepth += 1 + elif pattern[pos] == "}": + bracketdepth -= 1 + if bracketdepth == 0: + # print 'expression:', pattern[start+1:pos] + expr = pattern[start + 1 : pos] + item = parse_expression(expr, escape) + if item is None: # not a range or sequence + items.extend([["{"], parse_pattern(expr, escape), ["}"]]) + else: + items.append(item) + start = pos + 1 # skip the closing brace + pos += 1 + + if bracketdepth != 0: # unbalanced braces + raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) + + if start < pos: + items.append([pattern[start:]]) + + return ("".join(item) for item in product(*items)) + + +def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: + int_range_match = int_range_re.match(expr) + if int_range_match: + return make_int_range(*int_range_match.groups()) + + char_range_match = char_range_re.match(expr) + if char_range_match: + return make_char_range(*char_range_match.groups()) + + return parse_sequence(expr, escape) + + +def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: + # sequence -> chain(*sequence_items) + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'sequence:', seq + while pos < len(seq): + if escape and seq[pos] == "\\": + pos += 2 + continue + elif seq[pos] == "{": + bracketdepth += 1 + elif seq[pos] == "}": + bracketdepth -= 1 + elif seq[pos] == "," and bracketdepth == 0: + items.append(parse_pattern(seq[start:pos], escape)) + start = pos + 1 # skip the comma + pos += 1 + + if bracketdepth != 0: + raise UnbalancedBracesError + if not items: + return None + + # part after the last comma (may be the empty string) + items.append(parse_pattern(seq[start:], escape)) + return chain(*items) + + +def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: + if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): + padding = max(len(left), len(right)) + else: + padding = 0 + step = (int(incr) or 1) if incr else 1 + start = int(left) + end = int(right) + r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) + fmt = "%0{}d".format(padding) + return (fmt % i for i in r) + + +def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: + step = (int(incr) or 1) if incr else 1 + start = alphabet.index(left) + end = alphabet.index(right) + if start < end: + return alphabet[start : end + 1 : step] + else: + end = end or -len(alphabet) + return alphabet[start : end - 1 : -step] + + +if __name__ == "__main__": + import doctest + import sys + + failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) + if failed: + sys.exit(1) diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py index f04a99290ab32f7fe5b60656075a2d03af8468d6..618c4deceaa2578cd9f0672a65d1dd55430c7dcc 100644 --- a/fish_speech/utils/context.py +++ b/fish_speech/utils/context.py @@ -1,13 +1,13 @@ 
-from contextlib import nullcontext - -import torch - - -def autocast_exclude_mps( - device_type: str, dtype: torch.dtype -) -> nullcontext | torch.autocast: - return ( - nullcontext() - if torch.backends.mps.is_available() - else torch.autocast(device_type, dtype) - ) +from contextlib import nullcontext + +import torch + + +def autocast_exclude_mps( + device_type: str, dtype: torch.dtype +) -> nullcontext | torch.autocast: + return ( + nullcontext() + if torch.backends.mps.is_available() + else torch.autocast(device_type, dtype) + ) diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py index 78c82640a963fa556657107729f7543d2e7c3510..7516ad28b31e9836e6c02a991690d3e466d3ea62 100644 --- a/fish_speech/utils/file.py +++ b/fish_speech/utils/file.py @@ -1,16 +1,16 @@ -import os -from pathlib import Path - - -def get_latest_checkpoint(path: Path | str) -> Path | None: - # Find the latest checkpoint - ckpt_dir = Path(path) - - if ckpt_dir.exists() is False: - return None - - ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) - if len(ckpts) == 0: - return None - - return ckpts[-1] +import os +from pathlib import Path + + +def get_latest_checkpoint(path: Path | str) -> Path | None: + # Find the latest checkpoint + ckpt_dir = Path(path) + + if ckpt_dir.exists() is False: + return None + + ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) + if len(ckpts) == 0: + return None + + return ckpts[-1] diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py index f6ee463924f588a35477937fbe3c3364043bdf3e..d1a08fe2fd76bedc5b4ad40f8dddfa40e6951c58 100644 --- a/fish_speech/utils/instantiators.py +++ b/fish_speech/utils/instantiators.py @@ -1,50 +1,50 @@ -from typing import List - -import hydra -from omegaconf import DictConfig -from pytorch_lightning import Callback -from pytorch_lightning.loggers import Logger - -from .logger import RankedLogger - -log = RankedLogger(__name__, rank_zero_only=True) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config.""" - - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! 
Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger +from typing import List + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger + +from .logger import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py index 94f94f738d1d87404354d086c30ef0ad9ab04cdc..5e909c26380affa14ec2e8e92ce5ecb37dc0777e 100644 --- a/fish_speech/utils/logger.py +++ b/fish_speech/utils/logger.py @@ -1,55 +1,55 @@ -import logging -from typing import Mapping, Optional - -from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only - - -class RankedLogger(logging.LoggerAdapter): - """A multi-GPU-friendly python command line logger.""" - - def __init__( - self, - name: str = __name__, - rank_zero_only: bool = True, - extra: Optional[Mapping[str, object]] = None, - ) -> None: - """Initializes a multi-GPU-friendly python command line logger that logs on all processes - with their rank prefixed in the log message. - - :param name: The name of the logger. Default is ``__name__``. - :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. - :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. - """ - logger = logging.getLogger(name) - super().__init__(logger=logger, extra=extra) - self.rank_zero_only = rank_zero_only - - def log( - self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs - ) -> None: - """Delegate a log call to the underlying logger, after prefixing its message with the rank - of the process it's being logged from. If `'rank'` is provided, then the log will only - occur on that rank/process. - - :param level: The level to log at. Look at `logging.__init__.py` for more information. - :param msg: The message to log. - :param rank: The rank to log at. - :param args: Additional args to pass to the underlying logging function. 
- :param kwargs: Any additional keyword args to pass to the underlying logging function. - """ - if self.isEnabledFor(level): - msg, kwargs = self.process(msg, kwargs) - current_rank = getattr(rank_zero_only, "rank", None) - if current_rank is None: - raise RuntimeError( - "The `rank_zero_only.rank` needs to be set before use" - ) - msg = rank_prefixed_message(msg, current_rank) - if self.rank_zero_only: - if current_rank == 0: - self.logger.log(level, msg, *args, **kwargs) - else: - if rank is None: - self.logger.log(level, msg, *args, **kwargs) - elif current_rank == rank: - self.logger.log(level, msg, *args, **kwargs) +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = True, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py index 8e3b0a2519e12845f09e5fbe86dfccbf5b345429..ead61c20564687585e945bf6e88f13d803851bd2 100644 --- a/fish_speech/utils/logging_utils.py +++ b/fish_speech/utils/logging_utils.py @@ -1,48 +1,48 @@ -from lightning.pytorch.utilities import rank_zero_only - -from fish_speech.utils import logger as log - - -@rank_zero_only -def log_hyperparameters(object_dict: dict) -> None: - """Controls which config parts are saved by lightning loggers. - - Additionally saves: - - Number of model parameters - """ - - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! 
Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - hparams["data"] = cfg["data"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) +from lightning.pytorch.utilities import rank_zero_only + +from fish_speech.utils import logger as log + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py index 6a465f54d610779766d51e3d1a020a3b1517fd1f..5a11ba95b6e54461e9f4faba9ca1f98de6e194ab 100644 --- a/fish_speech/utils/rich_utils.py +++ b/fish_speech/utils/rich_utils.py @@ -1,100 +1,100 @@ -from pathlib import Path -from typing import Sequence - -import rich -import rich.syntax -import rich.tree -from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import DictConfig, OmegaConf, open_dict -from rich.prompt import Prompt - -from fish_speech.utils import logger as log - - -@rank_zero_only -def print_config_tree( - cfg: DictConfig, - print_order: Sequence[str] = ( - "data", - "model", - "callbacks", - "logger", - "trainer", - "paths", - "extras", - ), - resolve: bool = False, - save_to_file: bool = False, -) -> None: - """Prints content of DictConfig using Rich library and its tree structure. - - Args: - cfg (DictConfig): Configuration composed by Hydra. - print_order (Sequence[str], optional): Determines in what order config components are printed. - resolve (bool, optional): Whether to resolve reference fields of DictConfig. - save_to_file (bool, optional): Whether to export config to the hydra output folder. 
- """ # noqa: E501 - - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - queue = [] - - # add fields from `print_order` to queue - for field in print_order: - ( - queue.append(field) - if field in cfg - else log.warning( - f"Field '{field}' not found in config. " - + f"Skipping '{field}' config printing..." - ) - ) - - # add all the other fields to queue (not specified in `print_order`) - for field in cfg: - if field not in queue: - queue.append(field) - - # generate config tree from queue - for field in queue: - branch = tree.add(field, style=style, guide_style=style) - - config_group = cfg[field] - if isinstance(config_group, DictConfig): - branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) - else: - branch_content = str(config_group) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - # print config tree - rich.print(tree) - - # save config tree to file - if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: - rich.print(tree, file=file) - - -@rank_zero_only -def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 - - if not cfg.get("tags"): - if "id" in HydraConfig().cfg.hydra.job: - raise ValueError("Specify tags before launching a multirun!") - - log.warning("No tags provided in config. Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") - tags = [t.strip() for t in tags.split(",") if t != ""] - - with open_dict(cfg): - cfg.tags = tags - - log.info(f"Tags: {cfg.tags}") - - if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: - rich.print(cfg.tags, file=file) +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from fish_speech.utils import logger as log + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ # noqa: E501 + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. " + + f"Skipping '{field}' config printing..." 
+ ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py index 01c3d7a2ab0f707ae92dbde0feb173927720c841..81ce022e2e62781cc62016d70de33916c736f85d 100644 --- a/fish_speech/utils/spectrogram.py +++ b/fish_speech/utils/spectrogram.py @@ -1,122 +1,122 @@ -import torch -import torchaudio.functional as F -from torch import Tensor, nn -from torchaudio.transforms import MelScale - - -class LinearSpectrogram(nn.Module): - def __init__( - self, - n_fft=2048, - win_length=2048, - hop_length=512, - center=False, - mode="pow2_sqrt", - ): - super().__init__() - - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.mode = mode - - self.register_buffer("window", torch.hann_window(win_length), persistent=False) - - def forward(self, y: Tensor) -> Tensor: - if y.ndim == 3: - y = y.squeeze(1) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - ( - (self.win_length - self.hop_length) // 2, - (self.win_length - self.hop_length + 1) // 2, - ), - mode="reflect", - ).squeeze(1) - - spec = torch.stft( - y, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - - spec = torch.view_as_real(spec) - - if self.mode == "pow2_sqrt": - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - return spec - - -class LogMelSpectrogram(nn.Module): - def __init__( - self, - sample_rate=44100, - n_fft=2048, - win_length=2048, - hop_length=512, - n_mels=128, - center=False, - f_min=0.0, - f_max=None, - ): - super().__init__() - - self.sample_rate = sample_rate - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.n_mels = n_mels - self.f_min = f_min - self.f_max = f_max or float(sample_rate // 2) - - self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) - - fb = F.melscale_fbanks( - n_freqs=self.n_fft // 2 + 1, - f_min=self.f_min, - f_max=self.f_max, - n_mels=self.n_mels, - sample_rate=self.sample_rate, - 
norm="slaney", - mel_scale="slaney", - ) - self.register_buffer( - "fb", - fb, - persistent=False, - ) - - def compress(self, x: Tensor) -> Tensor: - return torch.log(torch.clamp(x, min=1e-5)) - - def decompress(self, x: Tensor) -> Tensor: - return torch.exp(x) - - def apply_mel_scale(self, x: Tensor) -> Tensor: - return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) - - def forward( - self, x: Tensor, return_linear: bool = False, sample_rate: int = None - ) -> Tensor: - if sample_rate is not None and sample_rate != self.sample_rate: - x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) - - linear = self.spectrogram(x) - x = self.apply_mel_scale(linear) - x = self.compress(x) - - if return_linear: - return x, self.compress(linear) - - return x +import torch +import torchaudio.functional as F +from torch import Tensor, nn +from torchaudio.transforms import MelScale + + +class LinearSpectrogram(nn.Module): + def __init__( + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", + ): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.mode = mode + + self.register_buffer("window", torch.hann_window(win_length), persistent=False) + + def forward(self, y: Tensor) -> Tensor: + if y.ndim == 3: + y = y.squeeze(1) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + (self.win_length - self.hop_length) // 2, + (self.win_length - self.hop_length + 1) // 2, + ), + mode="reflect", + ).squeeze(1) + + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + spec = torch.view_as_real(spec) + + if self.mode == "pow2_sqrt": + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + return spec + + +class LogMelSpectrogram(nn.Module): + def __init__( + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max or float(sample_rate // 2) + + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + + fb = F.melscale_fbanks( + n_freqs=self.n_fft // 2 + 1, + f_min=self.f_min, + f_max=self.f_max, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer( + "fb", + fb, + persistent=False, + ) + + def compress(self, x: Tensor) -> Tensor: + return torch.log(torch.clamp(x, min=1e-5)) + + def decompress(self, x: Tensor) -> Tensor: + return torch.exp(x) + + def apply_mel_scale(self, x: Tensor) -> Tensor: + return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) + + def forward( + self, x: Tensor, return_linear: bool = False, sample_rate: int = None + ) -> Tensor: + if sample_rate is not None and sample_rate != self.sample_rate: + x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) + + linear = self.spectrogram(x) + x = self.apply_mel_scale(linear) + x = self.compress(x) + + if return_linear: + return x, self.compress(linear) + + return x diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py index 
c546bfa1eddd2ac6bf484cce1ec06da1d33fb121..f5e02cd0c8f8dec1d002e7d634b00434e70873f9 100644 --- a/fish_speech/utils/utils.py +++ b/fish_speech/utils/utils.py @@ -1,114 +1,136 @@ -import warnings -from importlib.util import find_spec -from typing import Callable - -from omegaconf import DictConfig - -from .logger import RankedLogger -from .rich_utils import enforce_tags, print_config_tree - -log = RankedLogger(__name__, rank_zero_only=True) - - -def extras(cfg: DictConfig) -> None: - """Applies optional utilities before the task is started. - - Utilities: - - Ignoring python warnings - - Setting tags from command line - - Rich config printing - """ - - # return if no `extras` config - if not cfg.get("extras"): - log.warning("Extras config not found! ") - return - - # disable python warnings - if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! ") - warnings.filterwarnings("ignore") - - # prompt user to input tags from command line if none are provided in the config - if cfg.extras.get("enforce_tags"): - log.info("Enforcing tags! ") - enforce_tags(cfg, save_to_file=True) - - # pretty print config tree using Rich library - if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! ") - print_config_tree(cfg, resolve=True, save_to_file=True) - - -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the task function. - - This wrapper can be used to: - - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) - - save the exception to a `.log` file - - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) - - etc. (adjust depending on your needs) - - Example: - ``` - @utils.task_wrapper - def train(cfg: DictConfig) -> Tuple[dict, dict]: - - ... - - return metric_dict, object_dict - ``` - """ # noqa: E501 - - def wrap(cfg: DictConfig): - # execute the task - try: - metric_dict, object_dict = task_func(cfg=cfg) - - # things to do if exception occurs - except Exception as ex: - # save exception to `.log` file - log.exception("") - - # some hyperparameter combinations might be invalid or - # cause out-of-memory errors so when using hparam search - # plugins like Optuna, you might want to disable - # raising the below exception to avoid multirun failure - raise ex - - # things to always do after either success or exception - finally: - # display output dir path in terminal - log.info(f"Output dir: {cfg.paths.run_dir}") - - # always close wandb run (even if exception occurs so multirun won't fail) - if find_spec("wandb"): # check if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - return metric_dict, object_dict - - return wrap - - -def get_metric_value(metric_dict: dict, metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule.""" - - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise Exception( - f"Metric value not found! \n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") - - return metric_value +import random +import warnings +from importlib.util import find_spec +from typing import Callable + +import numpy as np +import torch +from omegaconf import DictConfig + +from .logger import RankedLogger +from .rich_utils import enforce_tags, print_config_tree + +log = RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: + + ... + + return metric_dict, object_dict + ``` + """ # noqa: E501 + + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or + # cause out-of-memory errors so when using hparam search + # plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.run_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +def set_seed(seed: int): + if seed < 0: + seed = -seed + if seed > (1 << 31): + seed = 1 << 31 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if torch.backends.cudnn.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css index 3c7a22ecc31881a65a76369b0fd889330a0874c7..0059de9aca35e9bc93ea8e7da7cd1b4a55c24410 100644 --- a/fish_speech/webui/css/style.css +++ b/fish_speech/webui/css/style.css @@ -1,161 +1,161 @@ -:root { - --my-200: #80eeee; - --my-50: #ecfdf5; - --water-width: 300px; - --water-heigh: 300px; -} - - -/* general styled components */ -.tools { - align-items: center; - justify-content: center; -} - -.gradio-button { - max-width: 2.2em; - min-width: 2.2em !important; - height: 2.4em; - align-self: end; - line-height: 1em; - border-radius: 0.5em; - -} - -.gradio-button.secondary-down, .gradio-button.secondary-down:hover{ - box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; -} - -/* replace original footer with ours */ -a{ - font-weight: bold; - cursor: pointer; - color: #030C14 !important; -} - -footer { - display: none !important; -} - -#footer{ - text-align: center; -} - -#footer div{ - display: inline-block; -} - -#footer .versions{ - font-size: 85%; - opacity: 0.85; -} - -/*@keyframes moveBackground {*/ -/* 0% {*/ -/* background-position: 0 0;*/ -/* }*/ -/* 100% {*/ -/* background-position: -100px 100px;*/ -/* }*/ -/*}*/ -@keyframes moveJellyBackground { - 0% { - background-position: 0% 50%; - } - 50% { - background-position: 100% 50%; - } - 100% { - background-position: 0% 50%; - } -} - -.gradio-container { - position: absolute; - z-index: 10; -} - - -.quan { - position: absolute; - bottom: 0; - width: var(--water-width); - height: var(--water-heigh); - border-radius: 0; - /*border: 3px solid rgb(246, 247, 248);*/ - /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ - z-index: 0; - -} - -.quan:last-child { - margin-right: 0; -} - -.shui { - position: absolute; - top: 0; - left: 0; - width: 100%; - height: 100%; - background-color: rgb(23, 106, 201); - border-radius: 0; - overflow: hidden; - z-index: 0; -} - -.shui::after { - - content: ''; - position: absolute; - top: 20%; - left: 50%; - width: 150%; - height: 150%; - border-radius: 40%; - background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); - animation: shi 5s linear infinite; -} - -@keyframes shi { - 0% { - transform: translate(-50%, -65%) rotate(0deg); - } - 100% { - transform: translate(-50%, -65%) rotate(360deg); - } -} - -.shui::before { - content: ''; - position: absolute; - top: 20%; - left: 50%; - width: 150%; - height: 150%; - border-radius: 42%; - background-color: rgb(240, 228, 228, 0.2); - animation: xu 7s linear infinite; -} - -@keyframes xu { - 0% { - transform: translate(-50%, -60%) rotate(0deg); - } - 100% { - transform: translate(-50%, -60%) rotate(360deg); - } -} - -fieldset.data_src div.wrap label { - background: #f8bffee0 !important; -} - -.scrollable-component { - max-height: 100px; - overflow-y: auto; -} - -#file_accordion { - max-height: 220px !important; -} +:root { + --my-200: #80eeee; + --my-50: #ecfdf5; + --water-width: 300px; + --water-heigh: 300px; +} + + +/* general styled components */ +.tools { + align-items: center; + 
justify-content: center; +} + +.gradio-button { + max-width: 2.2em; + min-width: 2.2em !important; + height: 2.4em; + align-self: end; + line-height: 1em; + border-radius: 0.5em; + +} + +.gradio-button.secondary-down, .gradio-button.secondary-down:hover{ + box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; +} + +/* replace original footer with ours */ +a{ + font-weight: bold; + cursor: pointer; + color: #030C14 !important; +} + +footer { + display: none !important; +} + +#footer{ + text-align: center; +} + +#footer div{ + display: inline-block; +} + +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + +/*@keyframes moveBackground {*/ +/* 0% {*/ +/* background-position: 0 0;*/ +/* }*/ +/* 100% {*/ +/* background-position: -100px 100px;*/ +/* }*/ +/*}*/ +@keyframes moveJellyBackground { + 0% { + background-position: 0% 50%; + } + 50% { + background-position: 100% 50%; + } + 100% { + background-position: 0% 50%; + } +} + +.gradio-container { + position: absolute; + z-index: 10; +} + + +.quan { + position: absolute; + bottom: 0; + width: var(--water-width); + height: var(--water-heigh); + border-radius: 0; + /*border: 3px solid rgb(246, 247, 248);*/ + /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ + z-index: 0; + +} + +.quan:last-child { + margin-right: 0; +} + +.shui { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgb(23, 106, 201); + border-radius: 0; + overflow: hidden; + z-index: 0; +} + +.shui::after { + + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 40%; + background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); + animation: shi 5s linear infinite; +} + +@keyframes shi { + 0% { + transform: translate(-50%, -65%) rotate(0deg); + } + 100% { + transform: translate(-50%, -65%) rotate(360deg); + } +} + +.shui::before { + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 42%; + background-color: rgb(240, 228, 228, 0.2); + animation: xu 7s linear infinite; +} + +@keyframes xu { + 0% { + transform: translate(-50%, -60%) rotate(0deg); + } + 100% { + transform: translate(-50%, -60%) rotate(360deg); + } +} + +fieldset.data_src div.wrap label { + background: #f8bffee0 !important; +} + +.scrollable-component { + max-height: 100px; + overflow-y: auto; +} + +#file_accordion { + max-height: 220px !important; +} diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html index ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615..ad53df1c34f5cc3ee05bf713c6aa82e78c14e47e 100644 --- a/fish_speech/webui/html/footer.html +++ b/fish_speech/webui/html/footer.html @@ -1,11 +1,11 @@ -
- API -  •  - Github -  •  - Gradio -
-
-
-{versions} -
+
+ API +  •  + Github +  •  + Gradio +
+
+
+{versions} +
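The footer template above carries only the `{versions}` and `{api_docs}` placeholders; the manager WebUI (`fish_speech/webui/manage.py`, further down in this diff) loads it as raw text, fills both placeholders, and hands the result to a `gr.HTML` component with `elem_id="footer"` so the `#footer` rules in `style.css` apply (the stock Gradio footer is hidden by `footer { display: none !important; }`). A minimal, self-contained sketch of that path, assuming the helpers behave exactly as they are defined elsewhere in this diff (`load_data_in_raw` is re-declared here only to keep the sketch runnable; in the repo it lives in `manage.py`):

```python
# Sketch only: mirrors the calls that appear in fish_speech/webui/manage.py below.
import gradio as gr

from fish_speech.webui.launch_utils import versions_html


def load_data_in_raw(path: str) -> str:
    # Same helper as in manage.py: read the template as plain text.
    with open(path, "r", encoding="utf-8") as file:
        return str(file.read())


footer = load_data_in_raw("fish_speech/webui/html/footer.html")
# Fill the {versions} and {api_docs} placeholders before rendering.
footer = footer.format(
    versions=versions_html(),
    api_docs="https://speech.fish.audio/inference/#http-api",
)

with gr.Blocks() as demo:
    # style.css hides Gradio's default footer and styles this element via #footer.
    gr.HTML(footer, elem_id="footer")

demo.launch()
```

The same `versions_html()` helper (defined in `fish_speech/webui/launch_utils.py` below) supplies the git hash, Python, torch, and Gradio versions that end up in the `{versions}` slot.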
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js index 0637a541a8e704632a42b89bdf1471b26e7bb868..143e10d979317319b0ba257fd2c08d17fc5514e3 100644 --- a/fish_speech/webui/js/animate.js +++ b/fish_speech/webui/js/animate.js @@ -1,69 +1,69 @@ - -function createGradioAnimation() { - const params = new URLSearchParams(window.location.search); - if (!params.has('__theme')) { - params.set('__theme', 'light'); - window.location.search = params.toString(); - } - - var gradioApp = document.querySelector('gradio-app'); - if (gradioApp) { - - document.documentElement.style.setProperty('--my-200', '#80eeee'); - document.documentElement.style.setProperty('--my-50', '#ecfdf5'); - - // gradioApp.style.position = 'relative'; - // gradioApp.style.backgroundSize = '200% 200%'; - // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; - // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; - // gradioApp.style.display = 'flex'; - // gradioApp.style.justifyContent = 'flex-start'; - // gradioApp.style.flexWrap = 'nowrap'; - // gradioApp.style.overflowX = 'auto'; - - // for (let i = 0; i < 6; i++) { - // var quan = document.createElement('div'); - // quan.className = 'quan'; - // gradioApp.insertBefore(quan, gradioApp.firstChild); - // quan.id = 'quan' + i.toString(); - // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; - // var quanContainer = document.querySelector('.quan'); - // if (quanContainer) { - // var shui = document.createElement('div'); - // shui.className = 'shui'; - // quanContainer.insertBefore(shui, quanContainer.firstChild) - // } - // } - } - - var container = document.createElement('div'); - container.id = 'gradio-animation'; - container.style.fontSize = '2em'; - container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; - container.style.fontWeight = 'bold'; - container.style.textAlign = 'center'; - container.style.marginBottom = '20px'; - - var text = 'Welcome to Fish-Speech!'; - for (var i = 0; i < text.length; i++) { - (function(i){ - setTimeout(function(){ - var letter = document.createElement('span'); - letter.style.opacity = '0'; - letter.style.transition = 'opacity 0.5s'; - letter.innerText = text[i]; - - container.appendChild(letter); - - setTimeout(function() { - letter.style.opacity = '1'; - }, 50); - }, i * 200); - })(i); - } - - var gradioContainer = document.querySelector('.gradio-container'); - gradioContainer.insertBefore(container, gradioContainer.firstChild); - - return 'Animation created'; -} + +function createGradioAnimation() { + const params = new URLSearchParams(window.location.search); + if (!params.has('__theme')) { + params.set('__theme', 'light'); + window.location.search = params.toString(); + } + + var gradioApp = document.querySelector('gradio-app'); + if (gradioApp) { + + document.documentElement.style.setProperty('--my-200', '#80eeee'); + document.documentElement.style.setProperty('--my-50', '#ecfdf5'); + + // gradioApp.style.position = 'relative'; + // gradioApp.style.backgroundSize = '200% 200%'; + // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; + // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; + // gradioApp.style.display = 'flex'; + // gradioApp.style.justifyContent = 'flex-start'; + // gradioApp.style.flexWrap = 'nowrap'; + // gradioApp.style.overflowX = 'auto'; + + // for (let i = 0; i < 6; i++) { + // var quan = document.createElement('div'); + // 
quan.className = 'quan'; + // gradioApp.insertBefore(quan, gradioApp.firstChild); + // quan.id = 'quan' + i.toString(); + // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; + // var quanContainer = document.querySelector('.quan'); + // if (quanContainer) { + // var shui = document.createElement('div'); + // shui.className = 'shui'; + // quanContainer.insertBefore(shui, quanContainer.firstChild) + // } + // } + } + + var container = document.createElement('div'); + container.id = 'gradio-animation'; + container.style.fontSize = '2em'; + container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; + container.style.fontWeight = 'bold'; + container.style.textAlign = 'center'; + container.style.marginBottom = '20px'; + + var text = 'Welcome to Fish-Speech!'; + for (var i = 0; i < text.length; i++) { + (function(i){ + setTimeout(function(){ + var letter = document.createElement('span'); + letter.style.opacity = '0'; + letter.style.transition = 'opacity 0.5s'; + letter.innerText = text[i]; + + container.appendChild(letter); + + setTimeout(function() { + letter.style.opacity = '1'; + }, 50); + }, i * 200); + })(i); + } + + var gradioContainer = document.querySelector('.gradio-container'); + gradioContainer.insertBefore(container, gradioContainer.firstChild); + + return 'Animation created'; +} diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py index 2f57b595a20177800dbedd71faef573ee8398418..88164c6c8e38714dcd7f7a2bc850ab56819053d6 100644 --- a/fish_speech/webui/launch_utils.py +++ b/fish_speech/webui/launch_utils.py @@ -1,120 +1,120 @@ -import importlib.util -import os -import subprocess -import sys -from functools import lru_cache -from pathlib import Path -from typing import Iterable - -import gradio as gr -from gradio.themes.base import Base -from gradio.themes.utils import colors, fonts, sizes - -GIT = ( - (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() - if sys.platform == "win32" - else "git" -) -GIT = str(GIT) - - -def is_module_installed(module_name: str) -> bool: - spec = importlib.util.find_spec(module_name) - return spec is not None - - -@lru_cache() -def commit_hash(): - try: - return subprocess.check_output( - [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" - ).strip() - except Exception: - return "" - - -def versions_html(): - import torch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = commit_hash() - hash = commit.strip("'").split(" ")[0] - - return f""" -version: {hash} - •  -python: {python_version} - •  -torch: {getattr(torch, '__long_version__',torch.__version__)} - •  -gradio: {gr.__version__} - •  -author: fishaudio -""" - - -def version_check(commit): - try: - import requests - - commits = requests.get( - "https://api.github.com/repos/fishaudio/fish-speech/branches/main" - ).json() - if commit != "" and commits["commit"]["sha"] != commit: - print("--------------------------------------------------------") - print("| You are not up to date with the most recent release. |") - print("| Consider running `git pull` to update. 
|") - print("--------------------------------------------------------") - elif commits["commit"]["sha"] == commit: - print("You are up to date with the most recent release.") - else: - print("Not a git clone, can't perform version check.") - except Exception as e: - print("version check failed", e) - - -class Seafoam(Base): - def __init__( - self, - *, - primary_hue: colors.Color | str = colors.emerald, - secondary_hue: colors.Color | str = colors.blue, - neutral_hue: colors.Color | str = colors.blue, - spacing_size: sizes.Size | str = sizes.spacing_md, - radius_size: sizes.Size | str = sizes.radius_md, - text_size: sizes.Size | str = sizes.text_lg, - font: fonts.Font | str | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("Quicksand"), - "ui-sans-serif", - "sans-serif", - ), - font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("IBM Plex Mono"), - "ui-monospace", - "monospace", - ), - ): - super().__init__( - primary_hue=primary_hue, - secondary_hue=secondary_hue, - neutral_hue=neutral_hue, - spacing_size=spacing_size, - radius_size=radius_size, - text_size=text_size, - font=font, - font_mono=font_mono, - ) - super().set( - button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", - button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", - button_primary_text_color="white", - button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", - slider_color="*secondary_300", - slider_color_dark="*secondary_600", - block_title_text_weight="600", - block_border_width="3px", - block_shadow="*shadow_drop_lg", - button_shadow="*shadow_drop_lg", - button_small_padding="0px", - button_large_padding="3px", - ) +import importlib.util +import os +import subprocess +import sys +from functools import lru_cache +from pathlib import Path +from typing import Iterable + +import gradio as gr +from gradio.themes.base import Base +from gradio.themes.utils import colors, fonts, sizes + +GIT = ( + (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() + if sys.platform == "win32" + else "git" +) +GIT = str(GIT) + + +def is_module_installed(module_name: str) -> bool: + spec = importlib.util.find_spec(module_name) + return spec is not None + + +@lru_cache() +def commit_hash(): + try: + return subprocess.check_output( + [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" + ).strip() + except Exception: + return "" + + +def versions_html(): + import torch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = commit_hash() + hash = commit.strip("'").split(" ")[0] + + return f""" +version: {hash} + •  +python: {python_version} + •  +torch: {getattr(torch, '__long_version__',torch.__version__)} + •  +gradio: {gr.__version__} + •  +author: fishaudio +""" + + +def version_check(commit): + try: + import requests + + commits = requests.get( + "https://api.github.com/repos/fishaudio/fish-speech/branches/main" + ).json() + if commit != "" and commits["commit"]["sha"] != commit: + print("--------------------------------------------------------") + print("| You are not up to date with the most recent release. |") + print("| Consider running `git pull` to update. 
|") + print("--------------------------------------------------------") + elif commits["commit"]["sha"] == commit: + print("You are up to date with the most recent release.") + else: + print("Not a git clone, can't perform version check.") + except Exception as e: + print("version check failed", e) + + +class Seafoam(Base): + def __init__( + self, + *, + primary_hue: colors.Color | str = colors.emerald, + secondary_hue: colors.Color | str = colors.blue, + neutral_hue: colors.Color | str = colors.blue, + spacing_size: sizes.Size | str = sizes.spacing_md, + radius_size: sizes.Size | str = sizes.radius_md, + text_size: sizes.Size | str = sizes.text_lg, + font: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("Quicksand"), + "ui-sans-serif", + "sans-serif", + ), + font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("IBM Plex Mono"), + "ui-monospace", + "monospace", + ), + ): + super().__init__( + primary_hue=primary_hue, + secondary_hue=secondary_hue, + neutral_hue=neutral_hue, + spacing_size=spacing_size, + radius_size=radius_size, + text_size=text_size, + font=font, + font_mono=font_mono, + ) + super().set( + button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", + button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", + button_primary_text_color="white", + button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", + slider_color="*secondary_300", + slider_color_dark="*secondary_600", + block_title_text_weight="600", + block_border_width="3px", + block_shadow="*shadow_drop_lg", + # button_shadow="*shadow_drop_lg", + button_small_padding="0px", + button_large_padding="3px", + ) diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py index 09c5f001ad247f508d7a867207755e983e18d667..fc20fbc8daf61f22a4b29638740f8676254aa608 100644 --- a/fish_speech/webui/manage.py +++ b/fish_speech/webui/manage.py @@ -1,1237 +1,1239 @@ -from __future__ import annotations - -import datetime -import html -import json -import os -import platform -import shutil -import signal -import subprocess -import sys -from pathlib import Path - -import gradio as gr -import psutil -import yaml -from loguru import logger -from tqdm import tqdm - -PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") -sys.path.insert(0, "") -print(sys.path) -cur_work_dir = Path(os.getcwd()).resolve() -print("You are in ", str(cur_work_dir)) - -from fish_speech.i18n import i18n -from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html - -config_path = cur_work_dir / "fish_speech" / "configs" -vqgan_yml_path = config_path / "firefly_gan_vq.yaml" -llama_yml_path = config_path / "text2semantic_finetune.yaml" - -env = os.environ.copy() -env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" - -seafoam = Seafoam() - - -def build_html_error_message(error): - return f""" -
- {html.escape(error)} -
- """ - - -def build_html_ok_message(msg): - return f""" -
- {html.escape(msg)} -
- """ - - -def build_html_href(link, desc, msg): - return f""" - - {html.escape(msg)} - {desc} - - """ - - -def load_data_in_raw(path): - with open(path, "r", encoding="utf-8") as file: - data = file.read() - return str(data) - - -def kill_proc_tree(pid, including_parent=True): - try: - parent = psutil.Process(pid) - except psutil.NoSuchProcess: - # Process already terminated - return - - children = parent.children(recursive=True) - for child in children: - try: - os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL - except OSError: - pass - if including_parent: - try: - os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL - except OSError: - pass - - -system = platform.system() -p_label = None -p_infer = None -p_tensorboard = None - - -def kill_process(pid): - if system == "Windows": - cmd = "taskkill /t /f /pid %s" % pid - # os.system(cmd) - subprocess.run(cmd) - else: - kill_proc_tree(pid) - - -def change_label(if_label): - global p_label - if if_label == True and p_label is None: - url = "http://localhost:3000" - remote_url = "https://text-labeler.pages.dev/" - try: - p_label = subprocess.Popen( - [ - ( - "asr-label-linux-x64" - if sys.platform == "linux" - else "asr-label-win-x64.exe" - ) - ] - ) - except FileNotFoundError: - logger.warning("asr-label execution not found!") - - yield build_html_href( - link=remote_url, - desc=i18n("Optional online ver"), - msg=i18n("Opened labeler in browser"), - ) - - elif if_label == False and p_label is not None: - kill_process(p_label.pid) - p_label = None - yield build_html_ok_message("Nothing") - - -def clean_infer_cache(): - import tempfile - - temp_dir = Path(tempfile.gettempdir()) - gradio_dir = str(temp_dir / "gradio") - try: - shutil.rmtree(gradio_dir) - logger.info(f"Deleted cached audios: {gradio_dir}") - except PermissionError: - logger.info(f"Permission denied: Unable to delete {gradio_dir}") - except FileNotFoundError: - logger.info(f"{gradio_dir} was not found") - except Exception as e: - logger.info(f"An error occurred: {e}") - - -def change_infer( - if_infer, - host, - port, - infer_decoder_model, - infer_decoder_config, - infer_llama_model, - infer_compile, -): - global p_infer - if if_infer == True and p_infer == None: - env = os.environ.copy() - - env["GRADIO_SERVER_NAME"] = host - env["GRADIO_SERVER_PORT"] = port - # 启动第二个进程 - url = f"http://{host}:{port}" - yield build_html_ok_message( - i18n("Inferring interface is launched at {}").format(url) - ) - - clean_infer_cache() - - p_infer = subprocess.Popen( - [ - PYTHON, - "tools/webui.py", - "--decoder-checkpoint-path", - infer_decoder_model, - "--decoder-config-name", - infer_decoder_config, - "--llama-checkpoint-path", - infer_llama_model, - ] - + (["--compile"] if infer_compile == "Yes" else []), - env=env, - ) - - elif if_infer == False and p_infer is not None: - kill_process(p_infer.pid) - p_infer = None - yield build_html_error_message(i18n("Infer interface is closed")) - - -js = load_data_in_raw("fish_speech/webui/js/animate.js") -css = load_data_in_raw("fish_speech/webui/css/style.css") - -data_pre_output = (cur_work_dir / "data").resolve() -default_model_output = (cur_work_dir / "results").resolve() -default_filelist = data_pre_output / "detect.list" -data_pre_output.mkdir(parents=True, exist_ok=True) - -items = [] -dict_items = {} - - -def load_yaml_data_in_fact(yml_path): - with open(yml_path, "r", encoding="utf-8") as file: - yml = yaml.safe_load(file) - return yml - - -def write_yaml_data_in_fact(yml, yml_path): - with open(yml_path, "w", encoding="utf-8") as 
file: - yaml.safe_dump(yml, file, allow_unicode=True) - return yml - - -def generate_tree(directory, depth=0, max_depth=None, prefix=""): - if max_depth is not None and depth > max_depth: - return "" - - tree_str = "" - files = [] - directories = [] - for item in os.listdir(directory): - if os.path.isdir(os.path.join(directory, item)): - directories.append(item) - else: - files.append(item) - - entries = directories + files - for i, entry in enumerate(entries): - connector = "├── " if i < len(entries) - 1 else "└── " - tree_str += f"{prefix}{connector}{entry}
" - if i < len(directories): - extension = "│ " if i < len(entries) - 1 else " " - tree_str += generate_tree( - os.path.join(directory, entry), - depth + 1, - max_depth, - prefix=prefix + extension, - ) - return tree_str - - -def new_explorer(data_path, max_depth): - return gr.Markdown( - elem_classes=["scrollable-component"], - value=generate_tree(data_path, max_depth=max_depth), - ) - - -def add_item( - folder: str, - method: str, - label_lang: str, - if_initial_prompt: bool, - initial_prompt: str | None, -): - folder = folder.strip(" ").strip('"') - - folder_path = Path(folder) - - if folder and folder not in items and data_pre_output not in folder_path.parents: - if folder_path.is_dir(): - items.append(folder) - dict_items[folder] = dict( - type="folder", - method=method, - label_lang=label_lang, - initial_prompt=initial_prompt if if_initial_prompt else None, - ) - elif folder: - err = folder - return gr.Checkboxgroup(choices=items), build_html_error_message( - i18n("Invalid path: {}").format(err) - ) - - formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) - logger.info("After Adding: " + formatted_data) - gr.Info(formatted_data) - return gr.Checkboxgroup(choices=items), build_html_ok_message( - i18n("Added path successfully!") - ) - - -def remove_items(selected_items): - global items, dict_items - to_remove = [item for item in items if item in selected_items] - for item in to_remove: - del dict_items[item] - items = [item for item in items if item in dict_items.keys()] - formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) - logger.info(formatted_data) - gr.Warning("After Removing: " + formatted_data) - return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( - i18n("Removed path successfully!") - ) - - -def show_selected(options): - selected_options = ", ".join(options) - - if options: - return i18n("Selected: {}").format(selected_options) - else: - return i18n("No selected options") - - -from pydub import AudioSegment - - -def convert_to_mono_in_place(audio_path: Path): - audio = AudioSegment.from_file(audio_path) - if audio.channels > 1: - mono_audio = audio.set_channels(1) - mono_audio.export(audio_path, format=audio_path.suffix[1:]) - logger.info(f"Convert {audio_path} successfully") - - -def list_copy(list_file_path, method): - wav_root = data_pre_output - lst = [] - with list_file_path.open("r", encoding="utf-8") as file: - for line in tqdm(file, desc="Processing audio/transcript"): - wav_path, speaker_name, language, text = line.strip().split("|") - original_wav_path = Path(wav_path) - target_wav_path = ( - wav_root / original_wav_path.parent.name / original_wav_path.name - ) - lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") - if target_wav_path.is_file(): - continue - target_wav_path.parent.mkdir(parents=True, exist_ok=True) - if method == i18n("Copy"): - shutil.copy(original_wav_path, target_wav_path) - else: - shutil.move(original_wav_path, target_wav_path.parent) - convert_to_mono_in_place(target_wav_path) - original_lab_path = original_wav_path.with_suffix(".lab") - target_lab_path = ( - wav_root - / original_wav_path.parent.name - / original_wav_path.with_suffix(".lab").name - ) - if target_lab_path.is_file(): - continue - if method == i18n("Copy"): - shutil.copy(original_lab_path, target_lab_path) - else: - shutil.move(original_lab_path, target_lab_path.parent) - - if method == i18n("Move"): - with list_file_path.open("w", encoding="utf-8") as file: - file.writelines("\n".join(lst)) - - del lst - 
return build_html_ok_message(i18n("Use filelist")) - - -def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): - global dict_items - data_path = Path(data_path) - gr.Warning("Pre-processing begins...") - for item, content in dict_items.items(): - item_path = Path(item) - tar_path = data_path / item_path.name - - if content["type"] == "folder" and item_path.is_dir(): - if content["method"] == i18n("Copy"): - os.makedirs(tar_path, exist_ok=True) - shutil.copytree( - src=str(item_path), dst=str(tar_path), dirs_exist_ok=True - ) - elif not tar_path.is_dir(): - shutil.move(src=str(item_path), dst=str(tar_path)) - - for suf in ["wav", "flac", "mp3"]: - for audio_path in tar_path.glob(f"**/*.{suf}"): - convert_to_mono_in_place(audio_path) - - cur_lang = content["label_lang"] - initial_prompt = content["initial_prompt"] - - transcribe_cmd = [ - PYTHON, - "tools/whisper_asr.py", - "--model-size", - label_model, - "--device", - label_device, - "--audio-dir", - tar_path, - "--save-dir", - tar_path, - "--language", - cur_lang, - ] - - if initial_prompt is not None: - transcribe_cmd += ["--initial-prompt", initial_prompt] - - if cur_lang != "IGNORE": - try: - gr.Warning("Begin To Transcribe") - subprocess.run( - transcribe_cmd, - env=env, - ) - except Exception: - print("Transcription error occurred") - - elif content["type"] == "file" and item_path.is_file(): - list_copy(item_path, content["method"]) - - return build_html_ok_message(i18n("Move files successfully")), new_explorer( - data_path, max_depth=max_depth - ) - - -def generate_folder_name(): - now = datetime.datetime.now() - folder_name = now.strftime("%Y%m%d_%H%M%S") - return folder_name - - -def train_process( - data_path: str, - option: str, - # llama config - llama_ckpt, - llama_base_config, - llama_lr, - llama_maxsteps, - llama_data_num_workers, - llama_data_batch_size, - llama_data_max_length, - llama_precision, - llama_check_interval, - llama_grad_batches, - llama_use_speaker, - llama_use_lora, -): - - backend = "nccl" if sys.platform == "linux" else "gloo" - - new_project = generate_folder_name() - print("New Project Name: ", new_project) - - if option == "VQGAN": - msg = "Skipped VQGAN Training." - gr.Warning(msg) - logger.info(msg) - - if option == "LLAMA": - msg = "LLAMA Training begins..." 
- gr.Warning(msg) - logger.info(msg) - subprocess.run( - [ - PYTHON, - "tools/vqgan/extract_vq.py", - str(data_pre_output), - "--num-workers", - "1", - "--batch-size", - "16", - "--config-name", - "firefly_gan_vq", - "--checkpoint-path", - "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - ] - ) - - subprocess.run( - [ - PYTHON, - "tools/llama/build_dataset.py", - "--input", - str(data_pre_output), - "--text-extension", - ".lab", - "--num-workers", - "16", - ] - ) - ckpt_path = "checkpoints/fish-speech-1.4/model.pth" - lora_prefix = "lora_" if llama_use_lora else "" - llama_name = lora_prefix + "text2semantic_" + new_project - latest = next( - iter( - sorted( - [ - str(p.relative_to("results")) - for p in Path("results").glob(lora_prefix + "text2sem*/") - ], - reverse=True, - ) - ), - llama_name, - ) - project = ( - llama_name - if llama_ckpt == i18n("new") - else ( - latest - if llama_ckpt == i18n("latest") - else Path(llama_ckpt).relative_to("results") - ) - ) - logger.info(project) - - if llama_check_interval > llama_maxsteps: - llama_check_interval = llama_maxsteps - - train_cmd = [ - PYTHON, - "fish_speech/train.py", - "--config-name", - "text2semantic_finetune", - f"project={project}", - f"trainer.strategy.process_group_backend={backend}", - f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", - f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", - f"model.optimizer.lr={llama_lr}", - f"trainer.max_steps={llama_maxsteps}", - f"data.num_workers={llama_data_num_workers}", - f"data.batch_size={llama_data_batch_size}", - f"max_length={llama_data_max_length}", - f"trainer.precision={llama_precision}", - f"trainer.val_check_interval={llama_check_interval}", - f"trainer.accumulate_grad_batches={llama_grad_batches}", - f"train_dataset.interactive_prob={llama_use_speaker}", - ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) - logger.info(train_cmd) - subprocess.run(train_cmd) - - return build_html_ok_message(i18n("Training stopped")) - - -def tensorboard_process( - if_tensorboard: bool, - tensorboard_dir: str, - host: str, - port: str, -): - global p_tensorboard - if if_tensorboard == True and p_tensorboard == None: - url = f"http://{host}:{port}" - yield build_html_ok_message( - i18n("Tensorboard interface is launched at {}").format(url) - ) - prefix = ["tensorboard"] - if Path("fishenv").exists(): - prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"] - - p_tensorboard = subprocess.Popen( - prefix - + [ - "--logdir", - tensorboard_dir, - "--host", - host, - "--port", - port, - "--reload_interval", - "120", - ] - ) - elif if_tensorboard == False and p_tensorboard != None: - kill_process(p_tensorboard.pid) - p_tensorboard = None - yield build_html_error_message(i18n("Tensorboard interface is closed")) - - -def fresh_tb_dir(): - return gr.Dropdown( - choices=[str(p) for p in Path("results").glob("**/tensorboard/")] - ) - - -def list_decoder_models(): - paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")] - if not paths: - logger.warning("No decoder model found") - return paths - - -def list_llama_models(): - choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")] - choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")] - choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")] - choices = sorted(choices, reverse=True) - if not choices: - logger.warning("No LLaMA model found") - return choices - 
- -def list_lora_llama_models(): - choices = sorted( - [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True - ) - if not choices: - logger.warning("No LoRA LLaMA model found") - return choices - - -def fresh_decoder_model(): - return gr.Dropdown(choices=list_decoder_models()) - - -def fresh_llama_ckpt(llama_use_lora): - return gr.Dropdown( - choices=[i18n("latest"), i18n("new")] - + ( - [str(p) for p in Path("results").glob("text2sem*/")] - if not llama_use_lora - else [str(p) for p in Path("results").glob("lora_*/")] - ) - ) - - -def fresh_llama_model(): - return gr.Dropdown(choices=list_llama_models()) - - -def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output): - if ( - lora_weight is None - or not Path(lora_weight).exists() - or not Path(llama_weight).exists() - ): - return build_html_error_message( - i18n( - "Path error, please check the model file exists in the corresponding path" - ) - ) - gr.Warning("Merging begins...") - merge_cmd = [ - PYTHON, - "tools/llama/merge_lora.py", - "--lora-config", - "r_8_alpha_16", - "--lora-weight", - lora_weight, - "--output", - llama_lora_output + "_" + generate_folder_name(), - ] - logger.info(merge_cmd) - subprocess.run(merge_cmd) - return build_html_ok_message(i18n("Merge successfully")) - - -def llama_quantify(llama_weight, quantify_mode): - if llama_weight is None or not Path(llama_weight).exists(): - return build_html_error_message( - i18n( - "Path error, please check the model file exists in the corresponding path" - ) - ) - - gr.Warning("Quantifying begins...") - - now = generate_folder_name() - quantify_cmd = [ - PYTHON, - "tools/llama/quantize.py", - "--checkpoint-path", - llama_weight, - "--mode", - quantify_mode, - "--timestamp", - now, - ] - logger.info(quantify_cmd) - subprocess.run(quantify_cmd) - if quantify_mode == "int8": - quantize_path = str( - Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}" - ) - else: - quantize_path = str( - Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}" - ) - return build_html_ok_message( - i18n("Quantify successfully") + f"Path: {quantize_path}" - ) - - -init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path) -init_llama_yml = load_yaml_data_in_fact(llama_yml_path) - -with gr.Blocks( - head="", - js=js, - theme=seafoam, - analytics_enabled=False, - title="Fish Speech", -) as demo: - with gr.Row(): - with gr.Column(): - with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")): - with gr.Row(): - textbox = gr.Textbox( - label="\U0000270F " - + i18n("Input Audio & Source Path for Transcription"), - info=i18n("Speaker is identified by the folder name"), - interactive=True, - ) - with gr.Row(equal_height=False): - with gr.Column(): - output_radio = gr.Radio( - label="\U0001F4C1 " - + i18n("Select source file processing method"), - choices=[i18n("Copy"), i18n("Move")], - value=i18n("Copy"), - interactive=True, - ) - with gr.Column(): - error = gr.HTML(label=i18n("Error Message")) - if_label = gr.Checkbox( - label=i18n("Open Labeler WebUI"), scale=0, show_label=True - ) - - with gr.Row(): - label_device = gr.Dropdown( - label=i18n("Labeling Device"), - info=i18n( - "It is recommended to use CUDA, if you have low configuration, use CPU" - ), - choices=["cpu", "cuda"], - value="cuda", - interactive=True, - ) - label_model = gr.Dropdown( - label=i18n("Whisper Model"), - info=i18n("Faster Whisper, Up to 5g GPU memory usage"), - choices=["large-v3", "medium"], - value="large-v3", - interactive=True, - ) - label_radio = 
gr.Dropdown( - label=i18n("Optional Label Language"), - info=i18n( - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format" - ), - choices=[ - (i18n("Chinese"), "zh"), - (i18n("English"), "en"), - (i18n("Japanese"), "ja"), - (i18n("Disabled"), "IGNORE"), - (i18n("auto"), "auto"), - ], - value="IGNORE", - interactive=True, - ) - - with gr.Row(): - if_initial_prompt = gr.Checkbox( - value=False, - label=i18n("Enable Initial Prompt"), - min_width=120, - scale=0, - ) - initial_prompt = gr.Textbox( - label=i18n("Initial Prompt"), - info=i18n( - "Initial prompt can provide contextual or vocabulary-specific guidance to the model." - ), - placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.", - interactive=False, - ) - - with gr.Row(): - add_button = gr.Button( - "\U000027A1 " + i18n("Add to Processing Area"), - variant="primary", - ) - remove_button = gr.Button( - "\U000026D4 " + i18n("Remove Selected Data") - ) - - with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")): - with gr.Row(): - model_type_radio = gr.Radio( - label=i18n( - "Select the model to be trained (Depending on the Tab page you are on)" - ), - interactive=False, - choices=["VQGAN", "LLAMA"], - value="VQGAN", - ) - with gr.Row(): - with gr.Tabs(): - with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: - gr.HTML("You don't need to train this model!") - - with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page: - with gr.Row(equal_height=False): - llama_use_lora = gr.Checkbox( - label=i18n("Use LoRA"), - info=i18n( - "Use LoRA can save GPU memory, but may reduce the quality of the model" - ), - value=True, - interactive=True, - ) - llama_ckpt = gr.Dropdown( - label=i18n("Select LLAMA ckpt"), - choices=[i18n("latest"), i18n("new")] - + [ - str(p) - for p in Path("results").glob("text2sem*/") - ] - + [str(p) for p in Path("results").glob("lora*/")], - value=i18n("latest"), - interactive=True, - ) - with gr.Row(equal_height=False): - llama_lr_slider = gr.Slider( - label=i18n("Initial Learning Rate"), - info=i18n( - "lr smaller -> usually train slower but more stable" - ), - interactive=True, - minimum=1e-5, - maximum=1e-4, - step=1e-5, - value=5e-5, - ) - llama_maxsteps_slider = gr.Slider( - label=i18n("Maximum Training Steps"), - info=i18n( - "recommend: max_steps = num_audios // batch_size * (2 to 5)" - ), - interactive=True, - minimum=1, - maximum=10000, - step=1, - value=50, - ) - with gr.Row(equal_height=False): - llama_base_config = gr.Dropdown( - label=i18n("Model Size"), - choices=[ - "text2semantic_finetune", - ], - value="text2semantic_finetune", - ) - llama_data_num_workers_slider = gr.Slider( - label=i18n("Number of Workers"), - minimum=1, - maximum=32, - step=1, - value=4, - ) - with gr.Row(equal_height=False): - llama_data_batch_size_slider = gr.Slider( - label=i18n("Batch Size"), - interactive=True, - minimum=1, - maximum=32, - step=1, - value=4, - ) - llama_data_max_length_slider = gr.Slider( - label=i18n("Maximum Length per Sample"), - interactive=True, - minimum=1024, - maximum=4096, - step=128, - value=1024, - ) - with gr.Row(equal_height=False): - llama_precision_dropdown = gr.Dropdown( - label=i18n("Precision"), - info=i18n( - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU" - ), - interactive=True, - choices=["32", "bf16-true", "16-mixed"], - value="bf16-true", - ) - llama_check_interval_slider = gr.Slider( - label=i18n("Save model every n 
steps"), - info=i18n( - "make sure that it's not greater than max_steps" - ), - interactive=True, - minimum=1, - maximum=1000, - step=1, - value=50, - ) - with gr.Row(equal_height=False): - llama_grad_batches = gr.Slider( - label=i18n("Accumulate Gradient Batches"), - interactive=True, - minimum=1, - maximum=20, - step=1, - value=init_llama_yml["trainer"][ - "accumulate_grad_batches" - ], - ) - llama_use_speaker = gr.Slider( - label=i18n( - "Probability of applying Speaker Condition" - ), - interactive=True, - minimum=0.1, - maximum=1.0, - step=0.05, - value=init_llama_yml["train_dataset"][ - "interactive_prob" - ], - ) - - with gr.Tab(label=i18n("Merge LoRA"), id=4): - with gr.Row(equal_height=False): - llama_weight = gr.Dropdown( - label=i18n("Base LLAMA Model"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=[ - "checkpoints/fish-speech-1.4/model.pth", - ], - value="checkpoints/fish-speech-1.4/model.pth", - allow_custom_value=True, - interactive=True, - ) - with gr.Row(equal_height=False): - lora_weight = gr.Dropdown( - label=i18n("LoRA Model to be merged"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=[ - str(p) - for p in Path("results").glob("lora*/**/*.ckpt") - ], - allow_custom_value=True, - interactive=True, - ) - lora_llama_config = gr.Dropdown( - label=i18n("LLAMA Model Config"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=[ - "text2semantic_finetune", - ], - value="text2semantic_finetune", - allow_custom_value=True, - ) - with gr.Row(equal_height=False): - llama_lora_output = gr.Dropdown( - label=i18n("Output Path"), - info=i18n( - "Type the path or select from the dropdown" - ), - value="checkpoints/merged", - choices=["checkpoints/merged"], - allow_custom_value=True, - interactive=True, - ) - with gr.Row(equal_height=False): - llama_lora_merge_btn = gr.Button( - value=i18n("Merge"), variant="primary" - ) - - with gr.Tab(label=i18n("Model Quantization"), id=5): - with gr.Row(equal_height=False): - llama_weight_to_quantify = gr.Dropdown( - label=i18n("Base LLAMA Model"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=list_llama_models(), - value="checkpoints/fish-speech-1.4", - allow_custom_value=True, - interactive=True, - ) - quantify_mode = gr.Dropdown( - label=i18n("Post-quantification Precision"), - info=i18n( - "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase" - ), - choices=["int8", "int4"], - value="int8", - allow_custom_value=False, - interactive=True, - ) - with gr.Row(equal_height=False): - llama_quantify_btn = gr.Button( - value=i18n("Quantify"), variant="primary" - ) - - with gr.Tab(label="Tensorboard", id=6): - with gr.Row(equal_height=False): - tb_host = gr.Textbox( - label=i18n("Tensorboard Host"), value="127.0.0.1" - ) - tb_port = gr.Textbox( - label=i18n("Tensorboard Port"), value="11451" - ) - with gr.Row(equal_height=False): - tb_dir = gr.Dropdown( - label=i18n("Tensorboard Log Path"), - allow_custom_value=True, - choices=[ - str(p) - for p in Path("results").glob("**/tensorboard/") - ], - ) - with gr.Row(equal_height=False): - if_tb = gr.Checkbox( - label=i18n("Open Tensorboard"), - ) - - with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")): - with gr.Column(): - with gr.Row(): - with gr.Accordion( - label="\U0001F5A5 " - + i18n("Inference Server Configuration"), - open=False, - ): - with gr.Row(): - infer_host_textbox = gr.Textbox( - label=i18n("WebUI Host"), 
value="127.0.0.1" - ) - infer_port_textbox = gr.Textbox( - label=i18n("WebUI Port"), value="7862" - ) - with gr.Row(): - infer_decoder_model = gr.Dropdown( - label=i18n("Decoder Model Path"), - info=i18n( - "Type the path or select from the dropdown" - ), - choices=list_decoder_models(), - value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - allow_custom_value=True, - ) - infer_decoder_config = gr.Dropdown( - label=i18n("Decoder Model Config"), - info=i18n("Changing with the Model Path"), - value="firefly_gan_vq", - choices=[ - "firefly_gan_vq", - ], - allow_custom_value=True, - ) - with gr.Row(): - infer_llama_model = gr.Dropdown( - label=i18n("LLAMA Model Path"), - info=i18n( - "Type the path or select from the dropdown" - ), - value="checkpoints/fish-speech-1.4", - choices=list_llama_models(), - allow_custom_value=True, - ) - - with gr.Row(): - infer_compile = gr.Radio( - label=i18n("Compile Model"), - info=i18n( - "Compile the model can significantly reduce the inference time, but will increase cold start time" - ), - choices=["Yes", "No"], - value=( - "Yes" if (sys.platform == "linux") else "No" - ), - interactive=is_module_installed("triton"), - ) - - with gr.Row(): - infer_checkbox = gr.Checkbox( - label=i18n("Open Inference Server") - ) - infer_error = gr.HTML(label=i18n("Inference Server Error")) - - with gr.Column(): - train_error = gr.HTML(label=i18n("Training Error")) - checkbox_group = gr.CheckboxGroup( - label="\U0001F4CA " + i18n("Data Source"), - info=i18n( - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list." - ), - elem_classes=["data_src"], - ) - train_box = gr.Textbox( - label=i18n("Data Preprocessing Path"), - value=str(data_pre_output), - interactive=False, - ) - model_box = gr.Textbox( - label="\U0001F4BE " + i18n("Model Output Path"), - value=str(default_model_output), - interactive=False, - ) - - with gr.Accordion( - i18n( - "View the status of the preprocessing folder (use the slider to control the depth of the tree)" - ), - elem_classes=["scrollable-component"], - elem_id="file_accordion", - ): - tree_slider = gr.Slider( - minimum=0, - maximum=3, - value=0, - step=1, - show_label=False, - container=False, - ) - file_markdown = new_explorer(str(data_pre_output), 0) - with gr.Row(equal_height=False): - admit_btn = gr.Button( - "\U00002705 " + i18n("File Preprocessing"), - variant="primary", - ) - fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) - help_button = gr.Button("\U00002753", scale=0, min_width=80) # question - train_btn = gr.Button(i18n("Start Training"), variant="primary") - - footer = load_data_in_raw("fish_speech/webui/html/footer.html") - footer = footer.format( - versions=versions_html(), - api_docs="https://speech.fish.audio/inference/#http-api", - ) - gr.HTML(footer, elem_id="footer") - vqgan_page.select(lambda: "VQGAN", None, model_type_radio) - llama_page.select(lambda: "LLAMA", None, model_type_radio) - add_button.click( - fn=add_item, - inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt], - outputs=[checkbox_group, error], - ) - remove_button.click( - fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] - ) - checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) - help_button.click( - fn=None, - js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' - 'toolbar=no, menubar=no, scrollbars=no, resizable=no, 
location=no, status=no")}', - ) - if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) - if_initial_prompt.change( - fn=lambda x: gr.Textbox(value="", interactive=x), - inputs=[if_initial_prompt], - outputs=[initial_prompt], - ) - train_btn.click( - fn=train_process, - inputs=[ - train_box, - model_type_radio, - # llama config - llama_ckpt, - llama_base_config, - llama_lr_slider, - llama_maxsteps_slider, - llama_data_num_workers_slider, - llama_data_batch_size_slider, - llama_data_max_length_slider, - llama_precision_dropdown, - llama_check_interval_slider, - llama_grad_batches, - llama_use_speaker, - llama_use_lora, - ], - outputs=[train_error], - ) - if_tb.change( - fn=tensorboard_process, - inputs=[if_tb, tb_dir, tb_host, tb_port], - outputs=[train_error], - ) - tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir]) - infer_decoder_model.change( - fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model] - ) - infer_llama_model.change( - fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model] - ) - llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight]) - admit_btn.click( - fn=check_files, - inputs=[train_box, tree_slider, label_model, label_device], - outputs=[error, file_markdown], - ) - fresh_btn.click( - fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] - ) - llama_use_lora.change( - fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] - ) - llama_ckpt.change( - fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] - ) - lora_weight.change( - fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), - inputs=[], - outputs=[lora_weight], - ) - llama_lora_merge_btn.click( - fn=llama_lora_merge, - inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output], - outputs=[train_error], - ) - llama_quantify_btn.click( - fn=llama_quantify, - inputs=[llama_weight_to_quantify, quantify_mode], - outputs=[train_error], - ) - infer_checkbox.change( - fn=change_infer, - inputs=[ - infer_checkbox, - infer_host_textbox, - infer_port_textbox, - infer_decoder_model, - infer_decoder_config, - infer_llama_model, - infer_compile, - ], - outputs=[infer_error], - ) - -demo.launch(inbrowser=True) +from __future__ import annotations + +import os + +os.environ["USE_LIBUV"] = "0" +import datetime +import html +import json +import platform +import shutil +import signal +import subprocess +import sys +from pathlib import Path + +import gradio as gr +import psutil +import yaml +from loguru import logger +from tqdm import tqdm + +PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") +sys.path.insert(0, "") +print(sys.path) +cur_work_dir = Path(os.getcwd()).resolve() +print("You are in ", str(cur_work_dir)) + +from fish_speech.i18n import i18n +from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html + +config_path = cur_work_dir / "fish_speech" / "configs" +vqgan_yml_path = config_path / "firefly_gan_vq.yaml" +llama_yml_path = config_path / "text2semantic_finetune.yaml" + +env = os.environ.copy() +env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" + +seafoam = Seafoam() + + +def build_html_error_message(error): + return f""" +
+ <div style="color: red; font-weight: bold;"> + {html.escape(error)} + </div>
+ """ + + +def build_html_ok_message(msg): + return f""" + <div style="color: green; font-weight: bold;">
+ {html.escape(msg)} + </div>
+ """ + + +def build_html_href(link, desc, msg): + return f""" + + {html.escape(msg)} + {desc} + + """ + + +def load_data_in_raw(path): + with open(path, "r", encoding="utf-8") as file: + data = file.read() + return str(data) + + +def kill_proc_tree(pid, including_parent=True): + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + # Process already terminated + return + + children = parent.children(recursive=True) + for child in children: + try: + os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + if including_parent: + try: + os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + + +system = platform.system() +p_label = None +p_infer = None +p_tensorboard = None + + +def kill_process(pid): + if system == "Windows": + cmd = "taskkill /t /f /pid %s" % pid + # os.system(cmd) + subprocess.run(cmd) + else: + kill_proc_tree(pid) + + +def change_label(if_label): + global p_label + if if_label == True and p_label is None: + url = "http://localhost:3000" + remote_url = "https://text-labeler.pages.dev/" + try: + p_label = subprocess.Popen( + [ + ( + "asr-label-linux-x64" + if sys.platform == "linux" + else "asr-label-win-x64.exe" + ) + ] + ) + except FileNotFoundError: + logger.warning("asr-label execution not found!") + + yield build_html_href( + link=remote_url, + desc=i18n("Optional online ver"), + msg=i18n("Opened labeler in browser"), + ) + + elif if_label == False and p_label is not None: + kill_process(p_label.pid) + p_label = None + yield build_html_ok_message("Nothing") + + +def clean_infer_cache(): + import tempfile + + temp_dir = Path(tempfile.gettempdir()) + gradio_dir = str(temp_dir / "gradio") + try: + shutil.rmtree(gradio_dir) + logger.info(f"Deleted cached audios: {gradio_dir}") + except PermissionError: + logger.info(f"Permission denied: Unable to delete {gradio_dir}") + except FileNotFoundError: + logger.info(f"{gradio_dir} was not found") + except Exception as e: + logger.info(f"An error occurred: {e}") + + +def change_infer( + if_infer, + host, + port, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, +): + global p_infer + if if_infer == True and p_infer == None: + env = os.environ.copy() + + env["GRADIO_SERVER_NAME"] = host + env["GRADIO_SERVER_PORT"] = port + # 启动第二个进程 + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Inferring interface is launched at {}").format(url) + ) + + clean_infer_cache() + + p_infer = subprocess.Popen( + [ + PYTHON, + "tools/webui.py", + "--decoder-checkpoint-path", + infer_decoder_model, + "--decoder-config-name", + infer_decoder_config, + "--llama-checkpoint-path", + infer_llama_model, + ] + + (["--compile"] if infer_compile == "Yes" else []), + env=env, + ) + + elif if_infer == False and p_infer is not None: + kill_process(p_infer.pid) + p_infer = None + yield build_html_error_message(i18n("Infer interface is closed")) + + +js = load_data_in_raw("fish_speech/webui/js/animate.js") +css = load_data_in_raw("fish_speech/webui/css/style.css") + +data_pre_output = (cur_work_dir / "data").resolve() +default_model_output = (cur_work_dir / "results").resolve() +default_filelist = data_pre_output / "detect.list" +data_pre_output.mkdir(parents=True, exist_ok=True) + +items = [] +dict_items = {} + + +def load_yaml_data_in_fact(yml_path): + with open(yml_path, "r", encoding="utf-8") as file: + yml = yaml.safe_load(file) + return yml + + +def write_yaml_data_in_fact(yml, yml_path): + with open(yml_path, "w", encoding="utf-8") as 
file: + yaml.safe_dump(yml, file, allow_unicode=True) + return yml + + +def generate_tree(directory, depth=0, max_depth=None, prefix=""): + if max_depth is not None and depth > max_depth: + return "" + + tree_str = "" + files = [] + directories = [] + for item in os.listdir(directory): + if os.path.isdir(os.path.join(directory, item)): + directories.append(item) + else: + files.append(item) + + entries = directories + files + for i, entry in enumerate(entries): + connector = "├── " if i < len(entries) - 1 else "└── " + tree_str += f"{prefix}{connector}{entry}
" + if i < len(directories): + extension = "│ " if i < len(entries) - 1 else " " + tree_str += generate_tree( + os.path.join(directory, entry), + depth + 1, + max_depth, + prefix=prefix + extension, + ) + return tree_str + + +def new_explorer(data_path, max_depth): + return gr.Markdown( + elem_classes=["scrollable-component"], + value=generate_tree(data_path, max_depth=max_depth), + ) + + +def add_item( + folder: str, + method: str, + label_lang: str, + if_initial_prompt: bool, + initial_prompt: str | None, +): + folder = folder.strip(" ").strip('"') + + folder_path = Path(folder) + + if folder and folder not in items and data_pre_output not in folder_path.parents: + if folder_path.is_dir(): + items.append(folder) + dict_items[folder] = dict( + type="folder", + method=method, + label_lang=label_lang, + initial_prompt=initial_prompt if if_initial_prompt else None, + ) + elif folder: + err = folder + return gr.Checkboxgroup(choices=items), build_html_error_message( + i18n("Invalid path: {}").format(err) + ) + + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info("After Adding: " + formatted_data) + gr.Info(formatted_data) + return gr.Checkboxgroup(choices=items), build_html_ok_message( + i18n("Added path successfully!") + ) + + +def remove_items(selected_items): + global items, dict_items + to_remove = [item for item in items if item in selected_items] + for item in to_remove: + del dict_items[item] + items = [item for item in items if item in dict_items.keys()] + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info(formatted_data) + gr.Warning("After Removing: " + formatted_data) + return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( + i18n("Removed path successfully!") + ) + + +def show_selected(options): + selected_options = ", ".join(options) + + if options: + return i18n("Selected: {}").format(selected_options) + else: + return i18n("No selected options") + + +from pydub import AudioSegment + + +def convert_to_mono_in_place(audio_path: Path): + audio = AudioSegment.from_file(audio_path) + if audio.channels > 1: + mono_audio = audio.set_channels(1) + mono_audio.export(audio_path, format=audio_path.suffix[1:]) + logger.info(f"Convert {audio_path} successfully") + + +def list_copy(list_file_path, method): + wav_root = data_pre_output + lst = [] + with list_file_path.open("r", encoding="utf-8") as file: + for line in tqdm(file, desc="Processing audio/transcript"): + wav_path, speaker_name, language, text = line.strip().split("|") + original_wav_path = Path(wav_path) + target_wav_path = ( + wav_root / original_wav_path.parent.name / original_wav_path.name + ) + lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") + if target_wav_path.is_file(): + continue + target_wav_path.parent.mkdir(parents=True, exist_ok=True) + if method == i18n("Copy"): + shutil.copy(original_wav_path, target_wav_path) + else: + shutil.move(original_wav_path, target_wav_path.parent) + convert_to_mono_in_place(target_wav_path) + original_lab_path = original_wav_path.with_suffix(".lab") + target_lab_path = ( + wav_root + / original_wav_path.parent.name + / original_wav_path.with_suffix(".lab").name + ) + if target_lab_path.is_file(): + continue + if method == i18n("Copy"): + shutil.copy(original_lab_path, target_lab_path) + else: + shutil.move(original_lab_path, target_lab_path.parent) + + if method == i18n("Move"): + with list_file_path.open("w", encoding="utf-8") as file: + file.writelines("\n".join(lst)) + + del lst + 
return build_html_ok_message(i18n("Use filelist")) + + +def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): + global dict_items + data_path = Path(data_path) + gr.Warning("Pre-processing begins...") + for item, content in dict_items.items(): + item_path = Path(item) + tar_path = data_path / item_path.name + + if content["type"] == "folder" and item_path.is_dir(): + if content["method"] == i18n("Copy"): + os.makedirs(tar_path, exist_ok=True) + shutil.copytree( + src=str(item_path), dst=str(tar_path), dirs_exist_ok=True + ) + elif not tar_path.is_dir(): + shutil.move(src=str(item_path), dst=str(tar_path)) + + for suf in ["wav", "flac", "mp3"]: + for audio_path in tar_path.glob(f"**/*.{suf}"): + convert_to_mono_in_place(audio_path) + + cur_lang = content["label_lang"] + initial_prompt = content["initial_prompt"] + + transcribe_cmd = [ + PYTHON, + "tools/whisper_asr.py", + "--model-size", + label_model, + "--device", + label_device, + "--audio-dir", + tar_path, + "--save-dir", + tar_path, + "--language", + cur_lang, + ] + + if initial_prompt is not None: + transcribe_cmd += ["--initial-prompt", initial_prompt] + + if cur_lang != "IGNORE": + try: + gr.Warning("Begin To Transcribe") + subprocess.run( + transcribe_cmd, + env=env, + ) + except Exception: + print("Transcription error occurred") + + elif content["type"] == "file" and item_path.is_file(): + list_copy(item_path, content["method"]) + + return build_html_ok_message(i18n("Move files successfully")), new_explorer( + data_path, max_depth=max_depth + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +def train_process( + data_path: str, + option: str, + # llama config + llama_ckpt, + llama_base_config, + llama_lr, + llama_maxsteps, + llama_data_num_workers, + llama_data_batch_size, + llama_data_max_length, + llama_precision, + llama_check_interval, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, +): + + backend = "nccl" if sys.platform == "linux" else "gloo" + + new_project = generate_folder_name() + print("New Project Name: ", new_project) + + if option == "VQGAN": + msg = "Skipped VQGAN Training." + gr.Warning(msg) + logger.info(msg) + + if option == "LLAMA": + msg = "LLAMA Training begins..." 
+ gr.Warning(msg) + logger.info(msg) + subprocess.run( + [ + PYTHON, + "tools/vqgan/extract_vq.py", + str(data_pre_output), + "--num-workers", + "1", + "--batch-size", + "16", + "--config-name", + "firefly_gan_vq", + "--checkpoint-path", + "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ] + ) + + subprocess.run( + [ + PYTHON, + "tools/llama/build_dataset.py", + "--input", + str(data_pre_output), + "--text-extension", + ".lab", + "--num-workers", + "16", + ] + ) + ckpt_path = "checkpoints/fish-speech-1.4/model.pth" + lora_prefix = "lora_" if llama_use_lora else "" + llama_name = lora_prefix + "text2semantic_" + new_project + latest = next( + iter( + sorted( + [ + str(p.relative_to("results")) + for p in Path("results").glob(lora_prefix + "text2sem*/") + ], + reverse=True, + ) + ), + llama_name, + ) + project = ( + llama_name + if llama_ckpt == i18n("new") + else ( + latest + if llama_ckpt == i18n("latest") + else Path(llama_ckpt).relative_to("results") + ) + ) + logger.info(project) + + if llama_check_interval > llama_maxsteps: + llama_check_interval = llama_maxsteps + + train_cmd = [ + PYTHON, + "fish_speech/train.py", + "--config-name", + "text2semantic_finetune", + f"project={project}", + f"trainer.strategy.process_group_backend={backend}", + f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"model.optimizer.lr={llama_lr}", + f"trainer.max_steps={llama_maxsteps}", + f"data.num_workers={llama_data_num_workers}", + f"data.batch_size={llama_data_batch_size}", + f"max_length={llama_data_max_length}", + f"trainer.precision={llama_precision}", + f"trainer.val_check_interval={llama_check_interval}", + f"trainer.accumulate_grad_batches={llama_grad_batches}", + f"train_dataset.interactive_prob={llama_use_speaker}", + ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) + logger.info(train_cmd) + subprocess.run(train_cmd) + + return build_html_ok_message(i18n("Training stopped")) + + +def tensorboard_process( + if_tensorboard: bool, + tensorboard_dir: str, + host: str, + port: str, +): + global p_tensorboard + if if_tensorboard == True and p_tensorboard == None: + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Tensorboard interface is launched at {}").format(url) + ) + prefix = ["tensorboard"] + if Path("fishenv").exists(): + prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"] + + p_tensorboard = subprocess.Popen( + prefix + + [ + "--logdir", + tensorboard_dir, + "--host", + host, + "--port", + port, + "--reload_interval", + "120", + ] + ) + elif if_tensorboard == False and p_tensorboard != None: + kill_process(p_tensorboard.pid) + p_tensorboard = None + yield build_html_error_message(i18n("Tensorboard interface is closed")) + + +def fresh_tb_dir(): + return gr.Dropdown( + choices=[str(p) for p in Path("results").glob("**/tensorboard/")] + ) + + +def list_decoder_models(): + paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")] + if not paths: + logger.warning("No decoder model found") + return paths + + +def list_llama_models(): + choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")] + choices = sorted(choices, reverse=True) + if not choices: + logger.warning("No LLaMA model found") + return choices + 
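+# Note: the list_* helpers above and below only rescan local folders
+# (checkpoints/, results/) and return plain path strings; the WebUI refreshes
+# its dropdowns by re-running them inside the fresh_* callbacks, e.g.
+#     lora_weight.change(fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), ...)
+# so a refresh picks up newly merged or trained checkpoints without restarting.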
+ +def list_lora_llama_models(): + choices = sorted( + [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True + ) + if not choices: + logger.warning("No LoRA LLaMA model found") + return choices + + +def fresh_decoder_model(): + return gr.Dropdown(choices=list_decoder_models()) + + +def fresh_llama_ckpt(llama_use_lora): + return gr.Dropdown( + choices=[i18n("latest"), i18n("new")] + + ( + [str(p) for p in Path("results").glob("text2sem*/")] + if not llama_use_lora + else [str(p) for p in Path("results").glob("lora_*/")] + ) + ) + + +def fresh_llama_model(): + return gr.Dropdown(choices=list_llama_models()) + + +def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output): + if ( + lora_weight is None + or not Path(lora_weight).exists() + or not Path(llama_weight).exists() + ): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + gr.Warning("Merging begins...") + merge_cmd = [ + PYTHON, + "tools/llama/merge_lora.py", + "--lora-config", + "r_8_alpha_16", + "--lora-weight", + lora_weight, + "--output", + llama_lora_output + "_" + generate_folder_name(), + ] + logger.info(merge_cmd) + subprocess.run(merge_cmd) + return build_html_ok_message(i18n("Merge successfully")) + + +def llama_quantify(llama_weight, quantify_mode): + if llama_weight is None or not Path(llama_weight).exists(): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + + gr.Warning("Quantifying begins...") + + now = generate_folder_name() + quantify_cmd = [ + PYTHON, + "tools/llama/quantize.py", + "--checkpoint-path", + llama_weight, + "--mode", + quantify_mode, + "--timestamp", + now, + ] + logger.info(quantify_cmd) + subprocess.run(quantify_cmd) + if quantify_mode == "int8": + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}" + ) + else: + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}" + ) + return build_html_ok_message( + i18n("Quantify successfully") + f"Path: {quantize_path}" + ) + + +init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path) +init_llama_yml = load_yaml_data_in_fact(llama_yml_path) + +with gr.Blocks( + head="", + js=js, + theme=seafoam, + analytics_enabled=False, + title="Fish Speech", +) as demo: + with gr.Row(): + with gr.Column(): + with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")): + with gr.Row(): + textbox = gr.Textbox( + label="\U0000270F " + + i18n("Input Audio & Source Path for Transcription"), + info=i18n("Speaker is identified by the folder name"), + interactive=True, + ) + with gr.Row(equal_height=False): + with gr.Column(): + output_radio = gr.Radio( + label="\U0001F4C1 " + + i18n("Select source file processing method"), + choices=[i18n("Copy"), i18n("Move")], + value=i18n("Copy"), + interactive=True, + ) + with gr.Column(): + error = gr.HTML(label=i18n("Error Message")) + if_label = gr.Checkbox( + label=i18n("Open Labeler WebUI"), scale=0, show_label=True + ) + + with gr.Row(): + label_device = gr.Dropdown( + label=i18n("Labeling Device"), + info=i18n( + "It is recommended to use CUDA, if you have low configuration, use CPU" + ), + choices=["cpu", "cuda"], + value="cuda", + interactive=True, + ) + label_model = gr.Dropdown( + label=i18n("Whisper Model"), + info=i18n("Faster Whisper, Up to 5g GPU memory usage"), + choices=["large-v3", "medium"], + value="large-v3", + interactive=True, + ) + label_radio = 
gr.Dropdown( + label=i18n("Optional Label Language"), + info=i18n( + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format" + ), + choices=[ + (i18n("Chinese"), "zh"), + (i18n("English"), "en"), + (i18n("Japanese"), "ja"), + (i18n("Disabled"), "IGNORE"), + (i18n("auto"), "auto"), + ], + value="IGNORE", + interactive=True, + ) + + with gr.Row(): + if_initial_prompt = gr.Checkbox( + value=False, + label=i18n("Enable Initial Prompt"), + min_width=120, + scale=0, + ) + initial_prompt = gr.Textbox( + label=i18n("Initial Prompt"), + info=i18n( + "Initial prompt can provide contextual or vocabulary-specific guidance to the model." + ), + placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.", + interactive=False, + ) + + with gr.Row(): + add_button = gr.Button( + "\U000027A1 " + i18n("Add to Processing Area"), + variant="primary", + ) + remove_button = gr.Button( + "\U000026D4 " + i18n("Remove Selected Data") + ) + + with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")): + with gr.Row(): + model_type_radio = gr.Radio( + label=i18n( + "Select the model to be trained (Depending on the Tab page you are on)" + ), + interactive=False, + choices=["VQGAN", "LLAMA"], + value="VQGAN", + ) + with gr.Row(): + with gr.Column(): + with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: + gr.HTML("You don't need to train this model!") + + with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page: + with gr.Row(equal_height=False): + llama_use_lora = gr.Checkbox( + label=i18n("Use LoRA"), + info=i18n( + "Use LoRA can save GPU memory, but may reduce the quality of the model" + ), + value=True, + interactive=True, + ) + llama_ckpt = gr.Dropdown( + label=i18n("Select LLAMA ckpt"), + choices=[i18n("latest"), i18n("new")] + + [ + str(p) + for p in Path("results").glob("text2sem*/") + ] + + [str(p) for p in Path("results").glob("lora*/")], + value=i18n("latest"), + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lr_slider = gr.Slider( + label=i18n("Initial Learning Rate"), + info=i18n( + "lr smaller -> usually train slower but more stable" + ), + interactive=True, + minimum=1e-5, + maximum=1e-4, + step=1e-5, + value=5e-5, + ) + llama_maxsteps_slider = gr.Slider( + label=i18n("Maximum Training Steps"), + info=i18n( + "recommend: max_steps = num_audios // batch_size * (2 to 5)" + ), + interactive=True, + minimum=1, + maximum=10000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_base_config = gr.Dropdown( + label=i18n("Model Size"), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + ) + llama_data_num_workers_slider = gr.Slider( + label=i18n("Number of Workers"), + minimum=1, + maximum=32, + step=1, + value=4, + ) + with gr.Row(equal_height=False): + llama_data_batch_size_slider = gr.Slider( + label=i18n("Batch Size"), + interactive=True, + minimum=1, + maximum=32, + step=1, + value=2, + ) + llama_data_max_length_slider = gr.Slider( + label=i18n("Maximum Length per Sample"), + interactive=True, + minimum=1024, + maximum=4096, + step=128, + value=2048, + ) + with gr.Row(equal_height=False): + llama_precision_dropdown = gr.Dropdown( + label=i18n("Precision"), + info=i18n( + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU" + ), + interactive=True, + choices=["32", "bf16-true", "16-mixed"], + value="bf16-true", + ) + llama_check_interval_slider = gr.Slider( + label=i18n("Save model every 
n steps"), + info=i18n( + "make sure that it's not greater than max_steps" + ), + interactive=True, + minimum=1, + maximum=1000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_grad_batches = gr.Slider( + label=i18n("Accumulate Gradient Batches"), + interactive=True, + minimum=1, + maximum=20, + step=1, + value=init_llama_yml["trainer"][ + "accumulate_grad_batches" + ], + ) + llama_use_speaker = gr.Slider( + label=i18n( + "Probability of applying Speaker Condition" + ), + interactive=True, + minimum=0.1, + maximum=1.0, + step=0.05, + value=init_llama_yml["train_dataset"][ + "interactive_prob" + ], + ) + + with gr.Tab(label=i18n("Merge LoRA"), id=4): + with gr.Row(equal_height=False): + llama_weight = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "checkpoints/fish-speech-1.4/model.pth", + ], + value="checkpoints/fish-speech-1.4/model.pth", + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + lora_weight = gr.Dropdown( + label=i18n("LoRA Model to be merged"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + str(p) + for p in Path("results").glob("lora*/**/*.ckpt") + ], + allow_custom_value=True, + interactive=True, + ) + lora_llama_config = gr.Dropdown( + label=i18n("LLAMA Model Config"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + allow_custom_value=True, + ) + with gr.Row(equal_height=False): + llama_lora_output = gr.Dropdown( + label=i18n("Output Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/merged", + choices=["checkpoints/merged"], + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lora_merge_btn = gr.Button( + value=i18n("Merge"), variant="primary" + ) + + with gr.Tab(label=i18n("Model Quantization"), id=5): + with gr.Row(equal_height=False): + llama_weight_to_quantify = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_llama_models(), + value="checkpoints/fish-speech-1.4", + allow_custom_value=True, + interactive=True, + ) + quantify_mode = gr.Dropdown( + label=i18n("Post-quantification Precision"), + info=i18n( + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase" + ), + choices=["int8", "int4"], + value="int8", + allow_custom_value=False, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_quantify_btn = gr.Button( + value=i18n("Quantify"), variant="primary" + ) + + with gr.Tab(label="Tensorboard", id=6): + with gr.Row(equal_height=False): + tb_host = gr.Textbox( + label=i18n("Tensorboard Host"), value="127.0.0.1" + ) + tb_port = gr.Textbox( + label=i18n("Tensorboard Port"), value="11451" + ) + with gr.Row(equal_height=False): + tb_dir = gr.Dropdown( + label=i18n("Tensorboard Log Path"), + allow_custom_value=True, + choices=[ + str(p) + for p in Path("results").glob("**/tensorboard/") + ], + ) + with gr.Row(equal_height=False): + if_tb = gr.Checkbox( + label=i18n("Open Tensorboard"), + ) + + with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")): + with gr.Column(): + with gr.Row(): + with gr.Accordion( + label="\U0001F5A5 " + + i18n("Inference Server Configuration"), + open=False, + ): + with gr.Row(): + infer_host_textbox = gr.Textbox( + label=i18n("WebUI Host"), 
value="127.0.0.1" + ) + infer_port_textbox = gr.Textbox( + label=i18n("WebUI Port"), value="7862" + ) + with gr.Row(): + infer_decoder_model = gr.Dropdown( + label=i18n("Decoder Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_decoder_models(), + value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + allow_custom_value=True, + ) + infer_decoder_config = gr.Dropdown( + label=i18n("Decoder Model Config"), + info=i18n("Changing with the Model Path"), + value="firefly_gan_vq", + choices=[ + "firefly_gan_vq", + ], + allow_custom_value=True, + ) + with gr.Row(): + infer_llama_model = gr.Dropdown( + label=i18n("LLAMA Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/fish-speech-1.4", + choices=list_llama_models(), + allow_custom_value=True, + ) + + with gr.Row(): + infer_compile = gr.Radio( + label=i18n("Compile Model"), + info=i18n( + "Compile the model can significantly reduce the inference time, but will increase cold start time" + ), + choices=["Yes", "No"], + value=( + "Yes" if (sys.platform == "linux") else "No" + ), + interactive=is_module_installed("triton"), + ) + + with gr.Row(): + infer_checkbox = gr.Checkbox( + label=i18n("Open Inference Server") + ) + infer_error = gr.HTML(label=i18n("Inference Server Error")) + + with gr.Column(): + train_error = gr.HTML(label=i18n("Training Error")) + checkbox_group = gr.CheckboxGroup( + label="\U0001F4CA " + i18n("Data Source"), + info=i18n( + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list." + ), + elem_classes=["data_src"], + ) + train_box = gr.Textbox( + label=i18n("Data Preprocessing Path"), + value=str(data_pre_output), + interactive=False, + ) + model_box = gr.Textbox( + label="\U0001F4BE " + i18n("Model Output Path"), + value=str(default_model_output), + interactive=False, + ) + + with gr.Accordion( + i18n( + "View the status of the preprocessing folder (use the slider to control the depth of the tree)" + ), + elem_classes=["scrollable-component"], + elem_id="file_accordion", + ): + tree_slider = gr.Slider( + minimum=0, + maximum=3, + value=0, + step=1, + show_label=False, + container=False, + ) + file_markdown = new_explorer(str(data_pre_output), 0) + with gr.Row(equal_height=False): + admit_btn = gr.Button( + "\U00002705 " + i18n("File Preprocessing"), + variant="primary", + ) + fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) + help_button = gr.Button("\U00002753", scale=0, min_width=80) # question + train_btn = gr.Button(i18n("Start Training"), variant="primary") + + footer = load_data_in_raw("fish_speech/webui/html/footer.html") + footer = footer.format( + versions=versions_html(), + api_docs="https://speech.fish.audio/inference/#http-api", + ) + gr.HTML(footer, elem_id="footer") + vqgan_page.select(lambda: "VQGAN", None, model_type_radio) + llama_page.select(lambda: "LLAMA", None, model_type_radio) + add_button.click( + fn=add_item, + inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt], + outputs=[checkbox_group, error], + ) + remove_button.click( + fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] + ) + checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) + help_button.click( + fn=None, + js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' + 'toolbar=no, menubar=no, scrollbars=no, resizable=no, 
location=no, status=no")}', + ) + if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) + if_initial_prompt.change( + fn=lambda x: gr.Textbox(value="", interactive=x), + inputs=[if_initial_prompt], + outputs=[initial_prompt], + ) + train_btn.click( + fn=train_process, + inputs=[ + train_box, + model_type_radio, + # llama config + llama_ckpt, + llama_base_config, + llama_lr_slider, + llama_maxsteps_slider, + llama_data_num_workers_slider, + llama_data_batch_size_slider, + llama_data_max_length_slider, + llama_precision_dropdown, + llama_check_interval_slider, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, + ], + outputs=[train_error], + ) + if_tb.change( + fn=tensorboard_process, + inputs=[if_tb, tb_dir, tb_host, tb_port], + outputs=[train_error], + ) + tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir]) + infer_decoder_model.change( + fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model] + ) + infer_llama_model.change( + fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model] + ) + llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight]) + admit_btn.click( + fn=check_files, + inputs=[train_box, tree_slider, label_model, label_device], + outputs=[error, file_markdown], + ) + fresh_btn.click( + fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] + ) + llama_use_lora.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + llama_ckpt.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + lora_weight.change( + fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), + inputs=[], + outputs=[lora_weight], + ) + llama_lora_merge_btn.click( + fn=llama_lora_merge, + inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output], + outputs=[train_error], + ) + llama_quantify_btn.click( + fn=llama_quantify, + inputs=[llama_weight_to_quantify, quantify_mode], + outputs=[train_error], + ) + infer_checkbox.change( + fn=change_infer, + inputs=[ + infer_checkbox, + infer_host_textbox, + infer_port_textbox, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, + ], + outputs=[infer_error], + ) + +demo.launch(inbrowser=True) diff --git a/tools/api.py b/tools/api.py index 8245656d97f404c2c4620932f0b0a6ccaf6b4e77..5b83e636a710832df14a85eb6e3ed16c20732eaa 100644 --- a/tools/api.py +++ b/tools/api.py @@ -1,440 +1,953 @@ -import base64 -import io -import json -import queue -import random -import sys -import traceback -import wave -from argparse import ArgumentParser -from http import HTTPStatus -from pathlib import Path -from typing import Annotated, Any, Literal, Optional - -import numpy as np -import ormsgpack -import pyrootutils -import soundfile as sf -import torch -import torchaudio -from baize.datastructures import ContentType -from kui.asgi import ( - Body, - FactoryClass, - HTTPException, - HttpRequest, - HttpView, - JSONResponse, - Kui, - OpenAPI, - StreamResponse, -) -from kui.asgi.routing import MultimethodRoutes -from loguru import logger -from pydantic import BaseModel, Field, conint - -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) - -# from fish_speech.models.vqgan.lit_module import VQGAN -from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture -from fish_speech.text.chn_text_norm.text import Text as ChnNormedText -from fish_speech.utils import autocast_exclude_mps -from tools.commons import ServeReferenceAudio, ServeTTSRequest -from tools.file import 
AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text -from tools.llama.generate import ( - GenerateRequest, - GenerateResponse, - WrappedGenerateResponse, - launch_thread_safe_queue, -) -from tools.vqgan.inference import load_model as load_decoder_model - - -def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): - buffer = io.BytesIO() - - with wave.open(buffer, "wb") as wav_file: - wav_file.setnchannels(channels) - wav_file.setsampwidth(bit_depth // 8) - wav_file.setframerate(sample_rate) - - wav_header_bytes = buffer.getvalue() - buffer.close() - return wav_header_bytes - - -# Define utils for web server -async def http_execption_handler(exc: HTTPException): - return JSONResponse( - dict( - statusCode=exc.status_code, - message=exc.content, - error=HTTPStatus(exc.status_code).phrase, - ), - exc.status_code, - exc.headers, - ) - - -async def other_exception_handler(exc: "Exception"): - traceback.print_exc() - - status = HTTPStatus.INTERNAL_SERVER_ERROR - return JSONResponse( - dict(statusCode=status, message=str(exc), error=status.phrase), - status, - ) - - -def load_audio(reference_audio, sr): - if len(reference_audio) > 255 or not Path(reference_audio).exists(): - audio_data = reference_audio - reference_audio = io.BytesIO(audio_data) - - waveform, original_sr = torchaudio.load( - reference_audio, backend="soundfile" if sys.platform == "linux" else "soundfile" - ) - - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - if original_sr != sr: - resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr) - waveform = resampler(waveform) - - audio = waveform.squeeze().numpy() - return audio - - -def encode_reference(*, decoder_model, reference_audio, enable_reference_audio): - if enable_reference_audio and reference_audio is not None: - # Load audios, and prepare basic info here - reference_audio_content = load_audio( - reference_audio, decoder_model.spec_transform.sample_rate - ) - - audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[ - None, None, : - ] - audio_lengths = torch.tensor( - [audios.shape[2]], device=decoder_model.device, dtype=torch.long - ) - logger.info( - f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds" - ) - - # VQ Encoder - if isinstance(decoder_model, FireflyArchitecture): - prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0] - - logger.info(f"Encoded prompt: {prompt_tokens.shape}") - else: - prompt_tokens = None - logger.info("No reference audio provided") - - return prompt_tokens - - -def decode_vq_tokens( - *, - decoder_model, - codes, -): - feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device) - logger.info(f"VQ features: {codes.shape}") - - if isinstance(decoder_model, FireflyArchitecture): - # VQGAN Inference - return decoder_model.decode( - indices=codes[None], - feature_lengths=feature_lengths, - )[0].squeeze() - - raise ValueError(f"Unknown model type: {type(decoder_model)}") - - -routes = MultimethodRoutes(base_class=HttpView) - - -def get_content_type(audio_format): - if audio_format == "wav": - return "audio/wav" - elif audio_format == "flac": - return "audio/flac" - elif audio_format == "mp3": - return "audio/mpeg" - else: - return "application/octet-stream" - - -@torch.inference_mode() -def inference(req: ServeTTSRequest): - - idstr: str | None = req.reference_id - if idstr is not None: - ref_folder = Path("references") / idstr - ref_folder.mkdir(parents=True, exist_ok=True) - 
ref_audios = list_files( - ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False - ) - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=audio_to_bytes(str(ref_audio)), - enable_reference_audio=True, - ) - for ref_audio in ref_audios - ] - prompt_texts = [ - read_ref_text(str(ref_audio.with_suffix(".lab"))) - for ref_audio in ref_audios - ] - - else: - # Parse reference audio aka prompt - refs = req.references - if refs is None: - refs = [] - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=ref.audio, - enable_reference_audio=True, - ) - for ref in refs - ] - prompt_texts = [ref.text for ref in refs] - - # LLAMA Inference - request = dict( - device=decoder_model.device, - max_new_tokens=req.max_new_tokens, - text=( - req.text - if not req.normalize - else ChnNormedText(raw_text=req.text).normalize() - ), - top_p=req.top_p, - repetition_penalty=req.repetition_penalty, - temperature=req.temperature, - compile=args.compile, - iterative_prompt=req.chunk_length > 0, - chunk_length=req.chunk_length, - max_length=2048, - prompt_tokens=prompt_tokens, - prompt_text=prompt_texts, - ) - - response_queue = queue.Queue() - llama_queue.put( - GenerateRequest( - request=request, - response_queue=response_queue, - ) - ) - - if req.streaming: - yield wav_chunk_header() - - segments = [] - while True: - result: WrappedGenerateResponse = response_queue.get() - if result.status == "error": - raise result.response - break - - result: GenerateResponse = result.response - if result.action == "next": - break - - with autocast_exclude_mps( - device_type=decoder_model.device.type, dtype=args.precision - ): - fake_audios = decode_vq_tokens( - decoder_model=decoder_model, - codes=result.codes, - ) - - fake_audios = fake_audios.float().cpu().numpy() - - if req.streaming: - yield (fake_audios * 32768).astype(np.int16).tobytes() - else: - segments.append(fake_audios) - - if req.streaming: - return - - if len(segments) == 0: - raise HTTPException( - HTTPStatus.INTERNAL_SERVER_ERROR, - content="No audio generated, please check the input text.", - ) - - fake_audios = np.concatenate(segments, axis=0) - yield fake_audios - - -async def inference_async(req: ServeTTSRequest): - for chunk in inference(req): - yield chunk - - -async def buffer_to_async_generator(buffer): - yield buffer - - -@routes.http.post("/v1/tts") -async def api_invoke_model( - req: Annotated[ServeTTSRequest, Body(exclusive=True)], -): - """ - Invoke model and generate audio - """ - - if args.max_text_length > 0 and len(req.text) > args.max_text_length: - raise HTTPException( - HTTPStatus.BAD_REQUEST, - content=f"Text is too long, max length is {args.max_text_length}", - ) - - if req.streaming and req.format != "wav": - raise HTTPException( - HTTPStatus.BAD_REQUEST, - content="Streaming only supports WAV format", - ) - - if req.streaming: - return StreamResponse( - iterable=inference_async(req), - headers={ - "Content-Disposition": f"attachment; filename=audio.{req.format}", - }, - content_type=get_content_type(req.format), - ) - else: - fake_audios = next(inference(req)) - buffer = io.BytesIO() - sf.write( - buffer, - fake_audios, - decoder_model.spec_transform.sample_rate, - format=req.format, - ) - - return StreamResponse( - iterable=buffer_to_async_generator(buffer.getvalue()), - headers={ - "Content-Disposition": f"attachment; filename=audio.{req.format}", - }, - content_type=get_content_type(req.format), - ) - - -@routes.http.post("/v1/health") -async def api_health(): - """ 
- Health check - """ - - return JSONResponse({"status": "ok"}) - - -def parse_args(): - parser = ArgumentParser() - parser.add_argument( - "--llama-checkpoint-path", - type=str, - default="checkpoints/fish-speech-1.4", - ) - parser.add_argument( - "--decoder-checkpoint-path", - type=str, - default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - ) - parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--half", action="store_true") - parser.add_argument("--compile", action="store_true") - parser.add_argument("--max-text-length", type=int, default=0) - parser.add_argument("--listen", type=str, default="127.0.0.1:8080") - parser.add_argument("--workers", type=int, default=1) - parser.add_argument("--use-auto-rerank", type=bool, default=True) - - return parser.parse_args() - - -# Define Kui app -openapi = OpenAPI( - { - "title": "Fish Speech API", - }, -).routes - - -class MsgPackRequest(HttpRequest): - async def data(self) -> Annotated[Any, ContentType("application/msgpack")]: - if self.content_type == "application/msgpack": - return ormsgpack.unpackb(await self.body) - - raise HTTPException( - HTTPStatus.UNSUPPORTED_MEDIA_TYPE, - headers={"Accept": "application/msgpack"}, - ) - - -app = Kui( - routes=routes + openapi[1:], # Remove the default route - exception_handlers={ - HTTPException: http_execption_handler, - Exception: other_exception_handler, - }, - factory_class=FactoryClass(http=MsgPackRequest), - cors_config={}, -) - - -if __name__ == "__main__": - - import uvicorn - - args = parse_args() - args.precision = torch.half if args.half else torch.bfloat16 - - logger.info("Loading Llama model...") - llama_queue = launch_thread_safe_queue( - checkpoint_path=args.llama_checkpoint_path, - device=args.device, - precision=args.precision, - compile=args.compile, - ) - logger.info("Llama model loaded, loading VQ-GAN model...") - - decoder_model = load_decoder_model( - config_name=args.decoder_config_name, - checkpoint_path=args.decoder_checkpoint_path, - device=args.device, - ) - - logger.info("VQ-GAN model loaded, warming up...") - - # Dry run to check if the model is loaded correctly and avoid the first-time latency - list( - inference( - ServeTTSRequest( - text="Hello world.", - references=[], - reference_id=None, - max_new_tokens=0, - top_p=0.7, - repetition_penalty=1.2, - temperature=0.7, - emotion=None, - format="wav", - ) - ) - ) - - logger.info(f"Warming up done, starting server at http://{args.listen}") - host, port = args.listen.split(":") - uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info") +import io +import json +import os +import queue +import re +import time +import traceback +import wave +from argparse import ArgumentParser +from http import HTTPStatus +from pathlib import Path +from typing import Annotated, Any + +import librosa +import numpy as np +import ormsgpack +import pyrootutils +import soundfile as sf +import torch +import torchaudio +from baize.datastructures import ContentType +from kui.asgi import ( + Body, + FactoryClass, + HTTPException, + HttpRequest, + HttpView, + JSONResponse, + Kui, + OpenAPI, + StreamResponse, + request, +) +from kui.asgi.routing import MultimethodRoutes +from loguru import logger + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +import struct +from threading import Lock + +import httpx +from cachetools import LRUCache, cached +from funasr 
import AutoModel +from silero_vad import get_speech_timestamps, load_silero_vad + +from fish_speech.models.text2semantic.llama import BaseModelArgs + +# from fish_speech.models.vqgan.lit_module import VQGAN +from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText + +# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN +from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer +from fish_speech.utils import autocast_exclude_mps, set_seed +from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text +from tools.llama.generate import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, + launch_thread_safe_queue, + launch_thread_safe_queue_agent, +) +from tools.schema import ( + GLOBAL_NUM_SAMPLES, + ASRPackRequest, + ServeASRRequest, + ServeASRResponse, + ServeASRSegment, + ServeAudioPart, + ServeForwardMessage, + ServeMessage, + ServeRequest, + ServeResponse, + ServeStreamDelta, + ServeStreamResponse, + ServeTextPart, + ServeTimedASRResponse, + ServeTTSRequest, + ServeVQGANDecodeRequest, + ServeVQGANDecodeResponse, + ServeVQGANEncodeRequest, + ServeVQGANEncodeResponse, + ServeVQPart, +) +from tools.vqgan.inference import load_model as load_decoder_model + +global_lock = Lock() + +# Whether to disable keepalive (which is helpful if the server is in the same cluster) +DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true" +async_client = httpx.AsyncClient( + timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None) +) +backends = torchaudio.list_audio_backends() + +if "ffmpeg" in backends: + backend = "ffmpeg" +else: + backend = "soundfile" + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +# Define utils for web server +async def http_execption_handler(exc: HTTPException): + return JSONResponse( + dict( + statusCode=exc.status_code, + message=exc.content, + error=HTTPStatus(exc.status_code).phrase, + ), + exc.status_code, + exc.headers, + ) + + +async def other_exception_handler(exc: "Exception"): + traceback.print_exc() + + status = HTTPStatus.INTERNAL_SERVER_ERROR + return JSONResponse( + dict(statusCode=status, message=str(exc), error=status.phrase), + status, + ) + + +def load_audio(reference_audio, sr): + if len(reference_audio) > 255 or not Path(reference_audio).exists(): + audio_data = reference_audio + reference_audio = io.BytesIO(audio_data) + + waveform, original_sr = torchaudio.load(reference_audio, backend=backend) + + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if original_sr != sr: + resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr) + waveform = resampler(waveform) + + audio = waveform.squeeze().numpy() + return audio + + +def encode_reference(*, decoder_model, reference_audio, enable_reference_audio): + if enable_reference_audio and reference_audio is not None: + # Load audios, and prepare basic info here + reference_audio_content = load_audio( + reference_audio, decoder_model.spec_transform.sample_rate + ) + + audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[ + None, None, : + ] + audio_lengths = 
torch.tensor( + [audios.shape[2]], device=decoder_model.device, dtype=torch.long + ) + logger.info( + f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds" + ) + + # VQ Encoder + if isinstance(decoder_model, FireflyArchitecture): + prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0] + + logger.info(f"Encoded prompt: {prompt_tokens.shape}") + else: + prompt_tokens = None + logger.info("No reference audio provided") + + return prompt_tokens + + +def decode_vq_tokens( + *, + decoder_model, + codes, +): + feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device) + logger.info(f"VQ features: {codes.shape}") + + if isinstance(decoder_model, FireflyArchitecture): + # VQGAN Inference + return decoder_model.decode( + indices=codes[None], + feature_lengths=feature_lengths, + )[0].squeeze() + + raise ValueError(f"Unknown model type: {type(decoder_model)}") + + +routes = MultimethodRoutes(base_class=HttpView) + + +def get_content_type(audio_format): + if audio_format == "wav": + return "audio/wav" + elif audio_format == "flac": + return "audio/flac" + elif audio_format == "mp3": + return "audio/mpeg" + else: + return "application/octet-stream" + + +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def batch_encode(model, audios: list[bytes | torch.Tensor]): + audios = [ + ( + torch.from_numpy( + librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] + )[None] + if isinstance(audio, bytes) + else audio + ) + for audio in audios + ] + + # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios): + # raise ValueError("Single audio length is too long (>120s)") + + max_length = max(audio.shape[-1] for audio in audios) + print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") + + lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1])) + for audio in audios + ] + ).to(model.device) + + features, feature_lengths = model.encode(padded, audio_lengths=lengths) + features, feature_lengths = features.cpu(), feature_lengths.cpu() + + return [feature[..., :length] for feature, length in zip(features, feature_lengths)] + + +@cached( + cache=LRUCache(maxsize=10000), + key=lambda model, audios: (model.device, tuple(audios)), +) +def cached_vqgan_batch_encode(model, audios: list[bytes]): + return batch_encode(model, audios) + + +@routes.http.post("/v1/vqgan/encode") +def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]): + + start_time = time.time() + tokens = cached_vqgan_batch_encode(decoder_model, payload.audios) + logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms") + + return ormsgpack.packb( + ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def vqgan_decode(model, features): + lengths = torch.tensor( + [feature.shape[-1] for feature in features], device=model.device + ) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) + for feature in features + ] + ).to(model.device) + + # If bs too large, we do micro batch decode + audios, audio_lengths = [], [] + for i in range(0, padded.shape[0], 8): + audio, 
audio_length = model.decode( + padded[i : i + 8], feature_lengths=lengths[i : i + 8] + ) + audios.append(audio) + audio_lengths.append(audio_length) + audios = torch.cat(audios, dim=0) + audio_lengths = torch.cat(audio_lengths, dim=0) + audios, audio_lengths = audios.cpu(), audio_lengths.cpu() + + return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] + + +@routes.http.post("/v1/vqgan/decode") +def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]): + tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens] + start_time = time.time() + audios = vqgan_decode(decoder_model, tokens) + logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms") + audios = [audio.astype(np.float16).tobytes() for audio in audios] + return ormsgpack.packb( + ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC + ) + + +@torch.no_grad() +def batch_asr(model, audios, sr, language="auto"): + resampled_audios = [] + for audio in audios: + audio = torchaudio.functional.resample(audio, sr, 16000) + assert audio.ndim == 1 + resampled_audios.append(audio) + + with global_lock: + res = model.generate( + input=resampled_audios, + batch_size=len(resampled_audios), + language=language, + use_itn=True, + ) + + results = [] + for r, audio in zip(res, audios): + text = r["text"] + text = re.sub(r"<\|.*?\|>", "", text) + duration = len(audio) / sr * 1000 + huge_gap = False + + if "timestamp" in r and len(r["timestamp"]) > 2: + for timestamp_a, timestamp_b in zip( + r["timestamp"][:-1], r["timestamp"][1:] + ): + # If there is a gap of more than 5 seconds, we consider it as a huge gap + if timestamp_b[0] - timestamp_a[1] > 5000: + huge_gap = True + break + + # Doesn't make sense to have a huge gap at the end + if duration - r["timestamp"][-1][1] > 3000: + huge_gap = True + + results.append( + { + "text": text, + "duration": duration, + "huge_gap": huge_gap, + } + ) + + return results + + +@routes.http.post("/v1/asr") +def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]): + start_time = time.time() + audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios] + audios = [torch.from_numpy(audio).float() for audio in audios] + + if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios): + raise HTTPException(status_code=400, detail="Audio length is too long") + + transcriptions = batch_asr( + asr_model, audios=audios, sr=payload.sample_rate, language=payload.language + ) + logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms") + + return ormsgpack.packb( + ServeASRResponse(transcriptions=transcriptions), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +from fish_speech.conversation import Conversation, Message + + +def execute_request( + input_queue: queue.Queue, + tokenizer: FishTokenizer, + config: BaseModelArgs, + request: ServeRequest, + device: str = "cuda:0", +): + + im_end_id = tokenizer.get_token_id(IM_END_TOKEN) + messages = [] + for message in request.messages: + messages.append(message.to_conversation_message()) + + assert len(messages) >= 1, "At least one message is required" + # assert messages[-1].role == "user", "The last message must be from the user" + + if messages[-1].role == "user": + messages.append( + Message(role="assistant", parts=[], add_im_end=False, modality="voice") + ) + elif messages[-1].role == "raw": + messages[-1].add_im_start = False + messages[-1].add_im_end = False + 
messages[-1].modality = "voice" + else: + assert ( + messages[-1].role == "assistant" + ), "The last message must be from the assistant" + messages[-1].add_im_end = False + + conv = Conversation(messages=messages) + + # conv.visualize(tokenizer) + prompt = conv.encode_for_inference( + tokenizer=tokenizer, num_codebooks=config.num_codebooks + ).to(device) + + if request.streaming: + for i in range(request.num_samples): + yield ServeStreamResponse( + sample_id=i, + delta=ServeStreamDelta( + role="assistant", + ), + ) + + req = { + "prompt": prompt, + "max_new_tokens": request.max_new_tokens, + "im_end_id": im_end_id, + "temperature": request.temperature, + "top_p": request.top_p, + "repetition_penalty": request.repetition_penalty, + "num_samples": request.num_samples, + "early_stop_threshold": request.early_stop_threshold, + } + + start = time.time() + response_queue = queue.Queue() + input_queue.put(GenerateRequest(req, response_queue)) + + # Decoding + decode_buffer = [[] for _ in range(request.num_samples)] + parts = [[] for _ in range(request.num_samples)] + + def send_reset_buffer(sample_id): + nonlocal decode_buffer + if len(decode_buffer[sample_id]) == 0: + return + + decoded = tokenizer.decode(decode_buffer[sample_id]) + part = ServeTextPart(text=decoded) + + if request.streaming: + yield ServeStreamResponse(delta=ServeStreamDelta(part=part)) + else: + parts[sample_id].append(part) + + decode_buffer[sample_id] = [] + + # Decode process + finished = [False for _ in range(request.num_samples)] + stats = {} + idx = 0 + while True: + response = response_queue.get() + + if response in ["stop", "error"]: + break + + for sample_id, tokens in enumerate(response): + if finished[sample_id]: + continue + + if tokens[0] == im_end_id: + finished[sample_id] = True + if request.streaming: + yield from send_reset_buffer(sample_id) + yield ServeStreamResponse( + sample_id=sample_id, + finish_reason="stop", + stats=stats, + ) + continue + + is_semantic = ( + tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id + ) + if is_semantic and request.streaming: + yield from send_reset_buffer(sample_id) + # Streaming vq + _tokens = tokens[1:].clone() + + if config.share_codebook_embeddings is False: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + yield ServeStreamResponse( + sample_id=sample_id, + delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())), + ) + continue + + # Not streaming vq + if is_semantic: + yield from send_reset_buffer(sample_id) + # None streaming vq + if len(parts[sample_id]) == 0 or not isinstance( + parts[sample_id][-1], ServeVQPart + ): + _tokens = tokens[1:].clone() + + if config.share_codebook_embeddings is False: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) + else: + for codebook_id, value in enumerate(tokens[1:, :]): + val = value.item() + if config.share_codebook_embeddings is False: + val -= config.codebook_size * codebook_id + + parts[sample_id][-1].codes[codebook_id].append(val) + continue + + if not is_semantic: + # Stream text decode is not supported now + decode_buffer[sample_id].append(tokens[0, 0]) + + if idx == 0: + stats["time_to_first_token"] = (time.time() - start) * 1000 + + idx += 1 + + for sample_id in range(request.num_samples): + yield from send_reset_buffer(sample_id) + + stats["total_time"] = (time.time() - start) * 1000 + stats["total_tokens"] = idx + + if request.streaming: + for sample_id in 
range(request.num_samples): + if finished[sample_id]: + continue + yield ServeStreamResponse( + finish_reason=response, stats=stats, sample_id=sample_id + ) + return + + yield ServeResponse( + messages=[ + ServeMessage(role="assistant", parts=parts[i]) + for i in range(request.num_samples) + ], + finish_reason=response, + stats=stats, + ) + + +@routes.http.post("/v1/chat") +def api_invoke_chat( + req: Annotated[ServeRequest, Body(exclusive=True)], +): + """ + Invoke model and generate audio + """ + + # This makes torch compile happy + assert ( + req.num_samples == GLOBAL_NUM_SAMPLES + ), f"num_samples must be {GLOBAL_NUM_SAMPLES}" + + content_type = request.headers.get("Content-Type", "application/json") + json_mode = "application/json" in content_type + + async def wrapped_generator(): + generator = execute_request(llama_queue, tokenizer, config, req, args.device) + + for i in generator: + if json_mode: + body = i.model_dump_json().encode("utf-8") + yield b"data: " + body + b"\n\n" + else: + body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) + yield struct.pack("I", len(body)) + body + + # Naive mode + if req.streaming is False: + result = next(execute_request(llama_queue, tokenizer, config, req, args.device)) + + if json_mode: + return JSONResponse(result.model_dump()) + else: + return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) + + return StreamResponse( + iterable=wrapped_generator(), content_type="text/event-stream" + ) + + +@torch.inference_mode() +def inference(req: ServeTTSRequest): + + global prompt_tokens, prompt_texts + + idstr: str | None = req.reference_id + if idstr is not None: + ref_folder = Path("references") / idstr + ref_folder.mkdir(parents=True, exist_ok=True) + ref_audios = list_files( + ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False + ) + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + else: + logger.info("Use same references") + + else: + # Parse reference audio aka prompt + refs = req.references + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + for ref in refs + ] + prompt_texts = [ref.text for ref in refs] + else: + logger.info("Use same references") + + if req.seed is not None: + set_seed(req.seed) + logger.warning(f"set seed: {req.seed}") + + # LLAMA Inference + request = dict( + device=decoder_model.device, + max_new_tokens=req.max_new_tokens, + text=( + req.text + if not req.normalize + else ChnNormedText(raw_text=req.text).normalize() + ), + top_p=req.top_p, + repetition_penalty=req.repetition_penalty, + temperature=req.temperature, + compile=args.compile, + iterative_prompt=req.chunk_length > 0, + chunk_length=req.chunk_length, + max_length=4096, + prompt_tokens=prompt_tokens, + prompt_text=prompt_texts, + ) + + response_queue = queue.Queue() + llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + if req.streaming: + yield wav_chunk_header() + + segments = [] + while True: + result: WrappedGenerateResponse = 
response_queue.get() + if result.status == "error": + raise result.response + break + + result: GenerateResponse = result.response + if result.action == "next": + break + + with autocast_exclude_mps( + device_type=decoder_model.device.type, dtype=args.precision + ): + fake_audios = decode_vq_tokens( + decoder_model=decoder_model, + codes=result.codes, + ) + + fake_audios = fake_audios.float().cpu().numpy() + + if req.streaming: + yield (fake_audios * 32768).astype(np.int16).tobytes() + else: + segments.append(fake_audios) + + if req.streaming: + return + + if len(segments) == 0: + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content="No audio generated, please check the input text.", + ) + + fake_audios = np.concatenate(segments, axis=0) + yield fake_audios + + +async def inference_async(req: ServeTTSRequest): + for chunk in inference(req): + yield chunk + + +async def buffer_to_async_generator(buffer): + yield buffer + + +@routes.http.post("/v1/tts") +async def api_invoke_model( + req: Annotated[ServeTTSRequest, Body(exclusive=True)], +): + """ + Invoke model and generate audio + """ + + if args.max_text_length > 0 and len(req.text) > args.max_text_length: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Text is too long, max length is {args.max_text_length}", + ) + + if req.streaming and req.format != "wav": + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content="Streaming only supports WAV format", + ) + + if req.streaming: + return StreamResponse( + iterable=inference_async(req), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + else: + fake_audios = next(inference(req)) + buffer = io.BytesIO() + sf.write( + buffer, + fake_audios, + decoder_model.spec_transform.sample_rate, + format=req.format, + ) + + return StreamResponse( + iterable=buffer_to_async_generator(buffer.getvalue()), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + + +@routes.http.post("/v1/health") +async def api_health(): + """ + Health check + """ + return JSONResponse({"status": "ok"}) + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") + parser.add_argument("--load-asr-model", action="store_true") + parser.add_argument( + "--llama-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.4", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-text-length", type=int, default=0) + parser.add_argument("--listen", type=str, default="127.0.0.1:8080") + parser.add_argument("--workers", type=int, default=1) + + return parser.parse_args() + + +# Define Kui app +openapi = OpenAPI( + { + "title": "Fish Speech API", + "version": "1.4.2", + }, +).routes + + +class MsgPackRequest(HttpRequest): + async def data( + self, + ) -> Annotated[ + Any, ContentType("application/msgpack"), ContentType("application/json") + ]: + if self.content_type == "application/msgpack": + return ormsgpack.unpackb(await self.body) + + elif self.content_type == "application/json": + 
return await self.json + + raise HTTPException( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + headers={"Accept": "application/msgpack, application/json"}, + ) + + +app = Kui( + routes=routes + openapi[1:], # Remove the default route + exception_handlers={ + HTTPException: http_execption_handler, + Exception: other_exception_handler, + }, + factory_class=FactoryClass(http=MsgPackRequest), + cors_config={}, +) + + +def load_asr_model(*, device="cuda", hub="ms"): + return AutoModel( + model="iic/SenseVoiceSmall", + device=device, + disable_pbar=True, + hub=hub, + ) + + +# Each worker process created by Uvicorn has its own memory space, +# meaning that models and variables are not shared between processes. +# Therefore, any global variables (like `llama_queue` or `decoder_model`) +# will not be shared across workers. + + +# Multi-threading for deep learning can cause issues, such as inconsistent +# outputs if multiple threads access the same buffers simultaneously. +# Instead, it's better to use multiprocessing or independent models per thread. +@app.on_startup +def initialize_app(app: Kui): + + global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts + + prompt_tokens, prompt_texts = [], [] + + args = parse_args() # args same as ones in other processes + args.precision = torch.half if args.half else torch.bfloat16 + + if args.load_asr_model: + logger.info(f"Loading ASR model...") + asr_model = load_asr_model(device=args.device) + + logger.info("Loading Llama model...") + + if args.mode == "tts": + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + else: + llama_queue, tokenizer, config = launch_thread_safe_queue_agent( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + + logger.info("Llama model loaded, loading VQ-GAN model...") + + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("VQ-GAN model loaded, warming up...") + + vad_model = load_silero_vad() + + logger.info("VAD model loaded, warming up...") + + if args.mode == "tts": + # Dry run to ensure models work and avoid first-time latency + list( + inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=0, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.5, + temperature=0.7, + emotion=None, + format="wav", + ) + ) + ) + + logger.info(f"Warming up done, starting server at http://{args.listen}") + + +if __name__ == "__main__": + + import uvicorn + + args = parse_args() + host, port = args.listen.split(":") + uvicorn.run( + "tools.api:app", + host=host, + port=int(port), + workers=args.workers, + log_level="info", + ) diff --git a/tools/auto_rerank.py b/tools/auto_rerank.py deleted file mode 100644 index 0297d63d77c67586c5c465b1225d022d668eeee5..0000000000000000000000000000000000000000 --- a/tools/auto_rerank.py +++ /dev/null @@ -1,159 +0,0 @@ -import os - -os.environ["MODELSCOPE_CACHE"] = ".cache/" - -import string -import time -from threading import Lock - -import librosa -import numpy as np -import opencc -import torch -from faster_whisper import WhisperModel - -t2s_converter = opencc.OpenCC("t2s") - - -def load_model(*, device="cuda"): - model = WhisperModel( - "medium", - device=device, - compute_type="float16", - 
download_root="faster_whisper", - ) - print("faster_whisper loaded!") - return model - - -@torch.no_grad() -def batch_asr_internal(model: WhisperModel, audios, sr): - resampled_audios = [] - for audio in audios: - - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio).float() - - if audio.dim() > 1: - audio = audio.squeeze() - - assert audio.dim() == 1 - audio_np = audio.numpy() - resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000) - resampled_audios.append(resampled_audio) - - trans_results = [] - - for resampled_audio in resampled_audios: - segments, info = model.transcribe( - resampled_audio, - language=None, - beam_size=5, - initial_prompt="Punctuation is needed in any language.", - ) - trans_results.append(list(segments)) - - results = [] - for trans_res, audio in zip(trans_results, audios): - - duration = len(audio) / sr * 1000 - huge_gap = False - max_gap = 0.0 - - text = None - last_tr = None - - for tr in trans_res: - delta = tr.text.strip() - if tr.id > 1: - max_gap = max(tr.start - last_tr.end, max_gap) - text += delta - else: - text = delta - - last_tr = tr - if max_gap > 3.0: - huge_gap = True - break - - sim_text = t2s_converter.convert(text) - results.append( - { - "text": sim_text, - "duration": duration, - "huge_gap": huge_gap, - } - ) - - return results - - -global_lock = Lock() - - -def batch_asr(model, audios, sr): - return batch_asr_internal(model, audios, sr) - - -def is_chinese(text): - return True - - -def calculate_wer(text1, text2, debug=False): - chars1 = remove_punctuation(text1) - chars2 = remove_punctuation(text2) - - m, n = len(chars1), len(chars2) - - if m > n: - chars1, chars2 = chars2, chars1 - m, n = n, m - - prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...] - curr = [0] * (m + 1) - - for j in range(1, n + 1): - curr[0] = j - for i in range(1, m + 1): - if chars1[i - 1] == chars2[j - 1]: - curr[i] = prev[i - 1] - else: - curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1 - prev, curr = curr, prev - - edits = prev[m] - tot = max(len(chars1), len(chars2)) - wer = edits / tot - - if debug: - print(" gt: ", chars1) - print(" pred: ", chars2) - print(" edits/tot = wer: ", edits, "/", tot, "=", wer) - - return wer - - -def remove_punctuation(text): - chinese_punctuation = ( - " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—" - '‛""„‟…‧﹏' - ) - all_punctuation = string.punctuation + chinese_punctuation - translator = str.maketrans("", "", all_punctuation) - text_without_punctuation = text.translate(translator) - return text_without_punctuation - - -if __name__ == "__main__": - model = load_model() - audios = [ - librosa.load("44100.wav", sr=44100)[0], - librosa.load("lengyue.wav", sr=44100)[0], - ] - print(np.array(audios[0])) - print(batch_asr(model, audios, 44100)) - - start_time = time.time() - for _ in range(10): - print(batch_asr(model, audios, 44100)) - print("Time taken:", time.time() - start_time) diff --git a/tools/commons.py b/tools/commons.py deleted file mode 100644 index f81cadec1efd6e4f749c279e64a65ea9caaa3f53..0000000000000000000000000000000000000000 --- a/tools/commons.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Annotated, Literal, Optional - -from pydantic import BaseModel, Field, conint - - -class ServeReferenceAudio(BaseModel): - audio: bytes - text: str - - -class ServeTTSRequest(BaseModel): - text: str - chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 - # Audio format - format: Literal["wav", "pcm", "mp3"] = "wav" - mp3_bitrate: 
Literal[64, 128, 192] = 128 - # References audios for in-context learning - references: list[ServeReferenceAudio] = [] - # Reference id - # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ - # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 - reference_id: str | None = None - # Normalize text for en & zh, this increase stability for numbers - normalize: bool = True - mp3_bitrate: Optional[int] = 64 - opus_bitrate: Optional[int] = -1000 - # Balance mode will reduce latency to 300ms, but may decrease stability - latency: Literal["normal", "balanced"] = "normal" - # not usually used below - streaming: bool = False - emotion: Optional[str] = None - max_new_tokens: int = 1024 - top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 - repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 - temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 diff --git a/tools/download_models.py b/tools/download_models.py index 9e79c34c43b424a8e47c43dd3edf003634fc667e..fc735d36e5e07645d46faa035cd5cd3ad88ebdb3 100644 --- a/tools/download_models.py +++ b/tools/download_models.py @@ -1,55 +1,55 @@ -import os - -from huggingface_hub import hf_hub_download - - -# Download -def check_and_download_files(repo_id, file_list, local_dir): - os.makedirs(local_dir, exist_ok=True) - for file in file_list: - file_path = os.path.join(local_dir, file) - if not os.path.exists(file_path): - print(f"{file} 不存在,从 Hugging Face 仓库下载...") - hf_hub_download( - repo_id=repo_id, - filename=file, - resume_download=True, - local_dir=local_dir, - local_dir_use_symlinks=False, - ) - else: - print(f"{file} 已存在,跳过下载。") - - -# 1st -repo_id_1 = "fishaudio/fish-speech-1.4" -local_dir_1 = "./checkpoints/fish-speech-1.4" -files_1 = [ - "model.pth", - "README.md", - "special_tokens_map.json", - "tokenizer_config.json", - "tokenizer.json", - "config.json", - "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", -] - -# 3rd -repo_id_3 = "fishaudio/fish-speech-1" -local_dir_3 = "./" -files_3 = [ - "ffmpeg.exe", - "ffprobe.exe", -] - -# 4th -repo_id_4 = "SpicyqSama007/fish-speech-packed" -local_dir_4 = "./" -files_4 = [ - "asr-label-win-x64.exe", -] - -check_and_download_files(repo_id_1, files_1, local_dir_1) - -check_and_download_files(repo_id_3, files_3, local_dir_3) -check_and_download_files(repo_id_4, files_4, local_dir_4) +import os + +from huggingface_hub import hf_hub_download + + +# Download +def check_and_download_files(repo_id, file_list, local_dir): + os.makedirs(local_dir, exist_ok=True) + for file in file_list: + file_path = os.path.join(local_dir, file) + if not os.path.exists(file_path): + print(f"{file} 不存在,从 Hugging Face 仓库下载...") + hf_hub_download( + repo_id=repo_id, + filename=file, + resume_download=True, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + else: + print(f"{file} 已存在,跳过下载。") + + +# 1st +repo_id_1 = "fishaudio/fish-speech-1.4" +local_dir_1 = "./checkpoints/fish-speech-1.4" +files_1 = [ + "model.pth", + "README.md", + "special_tokens_map.json", + "tokenizer_config.json", + "tokenizer.json", + "config.json", + "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +] + +# 3rd +repo_id_3 = "fishaudio/fish-speech-1" +local_dir_3 = "./" +files_3 = [ + "ffmpeg.exe", + "ffprobe.exe", +] + +# 4th +repo_id_4 = "SpicyqSama007/fish-speech-packed" +local_dir_4 = "./" +files_4 = [ + "asr-label-win-x64.exe", +] + +check_and_download_files(repo_id_1, files_1, local_dir_1) + +check_and_download_files(repo_id_3, files_3, local_dir_3) 
+check_and_download_files(repo_id_4, files_4, local_dir_4) diff --git a/tools/e2e_webui.py b/tools/e2e_webui.py new file mode 100644 index 0000000000000000000000000000000000000000..2331904d97ab0babca61fc0142574678b87c1253 --- /dev/null +++ b/tools/e2e_webui.py @@ -0,0 +1,232 @@ +import io +import re +import wave + +import gradio as gr +import numpy as np + +from .fish_e2e import FishE2EAgent, FishE2EEventType +from .schema import ServeMessage, ServeTextPart, ServeVQPart + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +class ChatState: + def __init__(self): + self.conversation = [] + self.added_systext = False + self.added_sysaudio = False + + def get_history(self): + results = [] + for msg in self.conversation: + results.append({"role": msg.role, "content": self.repr_message(msg)}) + + # Process assistant messages to extract questions and update user messages + for i, msg in enumerate(results): + if msg["role"] == "assistant": + match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"]) + if match and i > 0 and results[i - 1]["role"] == "user": + # Update previous user message with extracted question + results[i - 1]["content"] += "\n" + match.group(1) + # Remove the Question/Answer format from assistant message + msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1] + return results + + def repr_message(self, msg: ServeMessage): + response = "" + for part in msg.parts: + if isinstance(part, ServeTextPart): + response += part.text + elif isinstance(part, ServeVQPart): + response += f"