diff --git a/app.py b/app.py index 65b0580ab3a941c29c182847c3c9d9612f48ff8b..d8ec89e8a796a8a8a146bc686b5cc39cdea6b096 100644 --- a/app.py +++ b/app.py @@ -1,130 +1,51 @@ -import gradio as gr -import torch -import torchaudio -import numpy as np +## IMPORTS ## +import os import tempfile import time from pathlib import Path -from huggingface_hub import hf_hub_download -import os + +import gradio as gr +import numpy as np import spaces +import torch +import torchaudio +from cached_path import cached_path +from huggingface_hub import hf_hub_download from transformers import pipeline -# Import the inference module from infer import DMOInference -# Global variables -model = None -asr_pipe = None +## CUDA DEVICE ## device = "cuda" if torch.cuda.is_available() else "cpu" -# Initialize ASR pipeline -def initialize_asr_pipeline(device=device, dtype=None): - """Initialize the ASR pipeline on startup.""" - global asr_pipe - - if dtype is None: - dtype = ( - torch.float16 - if "cuda" in device - and torch.cuda.is_available() - and torch.cuda.get_device_properties(device).major >= 7 - and not torch.cuda.get_device_name().endswith("[ZLUDA]") - else torch.float32 - ) - - print("Initializing ASR pipeline...") - try: - asr_pipe = pipeline( - "automatic-speech-recognition", - model="openai/whisper-large-v3-turbo", - torch_dtype=dtype, - device="cpu" # Keep ASR on CPU to save GPU memory - ) - print("ASR pipeline initialized successfully") - except Exception as e: - print(f"Error initializing ASR pipeline: {e}") - asr_pipe = None - -# Transcribe function +## LOAD MODELS ## +asr_pipe = pipeline( + "automatic-speech-recognition", model="openai/whisper-large-v3-turbo", device=device +) +model = DMOInference( + student_checkpoint_path=str(cached_path("hf://yl4579/DMOSpeech2/model_85000.pt")), + duration_predictor_path=str(cached_path("hf://yl4579/DMOSpeech2/model_1500.pt")), + device=device, + model_type="F5TTS_Base", +) + + def transcribe(ref_audio, language=None): """Transcribe audio using the pre-loaded ASR pipeline.""" - global asr_pipe - - if asr_pipe is None: - return "" # Return empty string if ASR is not available - - try: - result = asr_pipe( - ref_audio, - chunk_length_s=30, - batch_size=128, - generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"}, - return_timestamps=False, - ) - return result["text"].strip() - except Exception as e: - print(f"Transcription error: {e}") - return "" - -def download_models(): - """Download models from HuggingFace Hub.""" - try: - print("Downloading models from HuggingFace...") - - # Download student model - student_path = hf_hub_download( - repo_id="yl4579/DMOSpeech2", - filename="model_85000.pt", - cache_dir="./models" - ) - - # Download duration predictor - duration_path = hf_hub_download( - repo_id="yl4579/DMOSpeech2", - filename="model_1500.pt", - cache_dir="./models" - ) - - print(f"Student model: {student_path}") - print(f"Duration model: {duration_path}") - - return student_path, duration_path - - except Exception as e: - print(f"Error downloading models: {e}") - return None, None - -def initialize_model(): - """Initialize the model on startup.""" - global model - - try: - # Download models - student_path, duration_path = download_models() - - if not student_path or not duration_path: - return False, "Failed to download models from HuggingFace" - - # Initialize model - model = DMOInference( - student_checkpoint_path=student_path, - duration_predictor_path=duration_path, - device=device, - model_type="F5TTS_Base" - ) - - 
return True, f"Model loaded successfully on {device.upper()}" - - except Exception as e: - return False, f"Error initializing model: {str(e)}" - -# Initialize models on startup -print("Initializing models...") -model_loaded, status_message = initialize_model() -initialize_asr_pipeline() # Initialize ASR pipeline - -@spaces.GPU(duration=120) # Request GPU for up to 120 seconds + return asr_pipe( + ref_audio, + chunk_length_s=30, + batch_size=128, + generate_kwargs=( + {"task": "transcribe", "language": language} + if language + else {"task": "transcribe"} + ), + return_timestamps=False, + )["text"].strip() + + +@spaces.GPU(duration=120) def generate_speech( prompt_audio, prompt_text, @@ -134,128 +55,115 @@ def generate_speech( custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step, - verbose + verbose, ): - """Generate speech with different configurations.""" - - if not model_loaded or model is None: - return None, "Model not loaded! Please refresh the page.", "", "" - if prompt_audio is None: - return None, "Please upload a reference audio!", "", "" - + raise gr.Error("Please upload a reference audio!") + if not target_text: - return None, "Please enter text to generate!", "", "" - - try: - # Auto-transcribe if prompt_text is empty - if not prompt_text and prompt_text != "": - print("Auto-transcribing reference audio...") - prompt_text = transcribe(prompt_audio) - print(f"Transcribed: {prompt_text}") - - start_time = time.time() - - # Configure parameters based on mode - if mode == "Student Only (4 steps)": - teacher_steps = 0 - student_start_step = 0 - teacher_stopping_time = 1.0 - elif mode == "Teacher-Guided (8 steps)": - # Default configuration from the notebook - teacher_steps = 16 - teacher_stopping_time = 0.07 - student_start_step = 1 - elif mode == "High Diversity (16 steps)": - teacher_steps = 24 - teacher_stopping_time = 0.3 - student_start_step = 2 - else: # Custom - teacher_steps = custom_teacher_steps - teacher_stopping_time = custom_teacher_stopping_time - student_start_step = custom_student_start_step - - # Generate speech - generated_audio = model.generate( - gen_text=target_text, - audio_path=prompt_audio, - prompt_text=prompt_text if prompt_text else None, - teacher_steps=teacher_steps, - teacher_stopping_time=teacher_stopping_time, - student_start_step=student_start_step, - temperature=temperature, - verbose=verbose - ) - - end_time = time.time() - - # Calculate metrics - processing_time = end_time - start_time - audio_duration = generated_audio.shape[-1] / 24000 - rtf = processing_time / audio_duration - - # Save audio - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: - output_path = tmp_file.name - - if isinstance(generated_audio, np.ndarray): - generated_audio = torch.from_numpy(generated_audio) - - if generated_audio.dim() == 1: - generated_audio = generated_audio.unsqueeze(0) - - torchaudio.save(output_path, generated_audio, 24000) - - # Format metrics - metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio" - - return output_path, "Success!", metrics, f"Mode: {mode} | Transcribed: {prompt_text[:50]}..." 
if not prompt_text else f"Mode: {mode}" - - except Exception as e: - return None, f"Error: {str(e)}", "", "" + raise gr.Error("Please enter text to generate!") + + if not prompt_text and prompt_text != "": + prompt_text = transcribe(prompt_audio) + + + if mode == "Student Only (4 steps)": + teacher_steps = 0 + student_start_step = 0 + teacher_stopping_time = 1.0 + elif mode == "Teacher-Guided (8 steps)": + teacher_steps = 16 + teacher_stopping_time = 0.07 + student_start_step = 1 + elif mode == "High Diversity (16 steps)": + teacher_steps = 24 + teacher_stopping_time = 0.3 + student_start_step = 2 + else: # Custom + teacher_steps = custom_teacher_steps + teacher_stopping_time = custom_teacher_stopping_time + student_start_step = custom_student_start_step + + # Generate speech + generated_audio = model.generate( + gen_text=target_text, + audio_path=prompt_audio, + prompt_text=prompt_text if prompt_text else None, + teacher_steps=teacher_steps, + teacher_stopping_time=teacher_stopping_time, + student_start_step=student_start_step, + temperature=temperature, + verbose=verbose, + ) + + + # Save audio + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: + output_path = tmp_file.name + + if isinstance(generated_audio, np.ndarray): + generated_audio = torch.from_numpy(generated_audio) + + if generated_audio.dim() == 1: + generated_audio = generated_audio.unsqueeze(0) + + torchaudio.save(output_path, generated_audio, 24000) + + return ( + output_path, + "Success!", + ( + f"Mode: {mode} | Transcribed: {prompt_text[:50]}..." + if not prompt_text + else f"Mode: {mode}" + ), + ) + # Create Gradio interface with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo: - gr.Markdown(f""" + gr.Markdown( + f""" # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech Generate natural speech in any voice with just a short reference audio! 
- """) - + """ + ) + with gr.Row(): with gr.Column(scale=1): # Reference audio input prompt_audio = gr.Audio( label="📎 Reference Audio", type="filepath", - sources=["upload", "microphone"] + sources=["upload", "microphone"], ) - + prompt_text = gr.Textbox( label="📝 Reference Text (leave empty for auto-transcription)", placeholder="The text spoken in the reference audio...", - lines=2 + lines=2, ) - + target_text = gr.Textbox( label="✍️ Text to Generate", placeholder="Enter the text you want to synthesize...", - lines=4 + lines=4, ) - + # Generation mode mode = gr.Radio( choices=[ "Student Only (4 steps)", "Teacher-Guided (8 steps)", "High Diversity (16 steps)", - "Custom" + "Custom", ], value="Teacher-Guided (8 steps)", label="🚀 Generation Mode", - info="Choose speed vs quality/diversity tradeoff" + info="Choose speed vs quality/diversity tradeoff", ) - + # Advanced settings (collapsible) with gr.Accordion("⚙️ Advanced Settings", open=False): temperature = gr.Slider( @@ -264,9 +172,9 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo: value=0.0, step=0.1, label="Duration Temperature", - info="0 = deterministic, >0 = more variation in speech rhythm" + info="0 = deterministic, >0 = more variation in speech rhythm", ) - + with gr.Group(visible=False) as custom_settings: gr.Markdown("### Custom Mode Settings") custom_teacher_steps = gr.Slider( @@ -275,60 +183,50 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo: value=16, step=1, label="Teacher Steps", - info="More steps = higher quality" + info="More steps = higher quality", ) - + custom_teacher_stopping_time = gr.Slider( minimum=0.0, maximum=1.0, value=0.07, step=0.01, label="Teacher Stopping Time", - info="When to switch to student" + info="When to switch to student", ) - + custom_student_start_step = gr.Slider( minimum=0, maximum=4, value=1, step=1, label="Student Start Step", - info="Which student step to start from" + info="Which student step to start from", ) - + verbose = gr.Checkbox( value=False, label="Verbose Output", - info="Show detailed generation steps" + info="Show detailed generation steps", ) - + generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg") - + with gr.Column(scale=1): # Output output_audio = gr.Audio( - label="🔊 Generated Speech", - type="filepath", - autoplay=True + label="🔊 Generated Speech", type="filepath", autoplay=True ) - - status = gr.Textbox( - label="Status", - interactive=False - ) - - metrics = gr.Textbox( - label="Performance Metrics", - interactive=False - ) - - info = gr.Textbox( - label="Generation Info", - interactive=False - ) - + + status = gr.Textbox(label="Status", interactive=False) + + metrics = gr.Textbox(label="Performance Metrics", interactive=False) + + info = gr.Textbox(label="Generation Info", interactive=False) + # Tips - gr.Markdown(""" + gr.Markdown( + """ ### 💡 Quick Tips: - **Auto-transcription**: Leave reference text empty to auto-transcribe @@ -341,8 +239,9 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo: - Student Only: ~0.05x (20x faster than real-time) - Teacher-Guided: ~0.10x (10x faster) - High Diversity: ~0.20x (5x faster) - """) - + """ + ) + # Event handler generate_btn.click( generate_speech, @@ -355,21 +254,17 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo: custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step, - verbose + verbose, ], - outputs=[output_audio, status, metrics, info] + outputs=[output_audio, status, metrics, info], ) - + # Update visibility of custom 
settings based on mode def update_custom_visibility(mode): - is_custom = (mode == "Custom") + is_custom = mode == "Custom" return gr.update(visible=is_custom) - - mode.change( - update_custom_visibility, - inputs=[mode], - outputs=[custom_settings] - ) + + mode.change(update_custom_visibility, inputs=[mode], outputs=[custom_settings]) # Launch the app if __name__ == "__main__": @@ -377,5 +272,5 @@ if __name__ == "__main__": print(f"Warning: Model failed to load - {status_message}") if not asr_pipe: print("Warning: ASR pipeline not available - auto-transcription disabled") - - demo.launch() \ No newline at end of file + + demo.launch() diff --git a/ctcmodel.py b/ctcmodel.py index 0bc89d87b0dfc858fb743db82cc9bdaaa4177421..4149e0a386ef6d1383929da745a68cce80a362a5 100644 --- a/ctcmodel.py +++ b/ctcmodel.py @@ -1,36 +1,24 @@ -from torch import nn -import torch import copy - from pathlib import Path -from torchaudio.models import Conformer - -from f5_tts.model.utils import default -from f5_tts.model.utils import exists -from f5_tts.model.utils import list_str_to_idx -from f5_tts.model.utils import list_str_to_tensor -from f5_tts.model.utils import lens_to_mask -from f5_tts.model.utils import mask_from_frac_lengths +import torch +from torch import nn +from torchaudio.models import Conformer +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) class ResBlock(nn.Module): def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2): super().__init__() self._n_groups = 8 - self.blocks = nn.ModuleList([ - self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p) - for i in range(n_conv)]) - + self.blocks = nn.ModuleList( + [ + self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p) + for i in range(n_conv) + ] + ) def forward(self, x): for block in self.blocks: @@ -41,70 +29,71 @@ class ResBlock(nn.Module): def _get_conv(self, hidden_dim, dilation, dropout_p=0.2): layers = [ - nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), + nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + ), nn.ReLU(), nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), nn.Dropout(p=dropout_p), nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), nn.ReLU(), - nn.Dropout(p=dropout_p) + nn.Dropout(p=dropout_p), ] return nn.Sequential(*layers) class ConformerCTC(nn.Module): - def __init__(self, - vocab_size, - mel_dim=100, - num_heads=8, - d_hid=512, - nlayers=6): + def __init__(self, vocab_size, mel_dim=100, num_heads=8, d_hid=512, nlayers=6): super().__init__() - + self.mel_proj = nn.Conv1d(mel_dim, d_hid, kernel_size=3, padding=1) - + self.d_hid = d_hid - + self.resblock1 = nn.Sequential( - ResBlock(d_hid), - nn.GroupNorm(num_groups=1, num_channels=d_hid) - ) - + ResBlock(d_hid), nn.GroupNorm(num_groups=1, num_channels=d_hid) + ) + self.resblock2 = nn.Sequential( - ResBlock(d_hid), - nn.GroupNorm(num_groups=1, num_channels=d_hid) - ) - + ResBlock(d_hid), nn.GroupNorm(num_groups=1, num_channels=d_hid) + ) self.conf_pre = torch.nn.ModuleList( - [Conformer( - input_dim=d_hid, - num_heads=num_heads, - ffn_dim=d_hid * 2, - num_layers=1, - depthwise_conv_kernel_size=15, - use_group_norm=True,) + [ + Conformer( + input_dim=d_hid, + num_heads=num_heads, + ffn_dim=d_hid * 2, + num_layers=1, + 
depthwise_conv_kernel_size=15, + use_group_norm=True, + ) for _ in range(nlayers // 2) ] ) - + self.conf_after = torch.nn.ModuleList( - [Conformer( - input_dim=d_hid, - num_heads=num_heads, - ffn_dim=d_hid * 2, - num_layers=1, - depthwise_conv_kernel_size=7, - use_group_norm=True,) + [ + Conformer( + input_dim=d_hid, + num_heads=num_heads, + ffn_dim=d_hid * 2, + num_layers=1, + depthwise_conv_kernel_size=7, + use_group_norm=True, + ) for _ in range(nlayers // 2) ] ) - self.out = nn.Linear(d_hid, 1 + vocab_size) # 1 for blank + self.out = nn.Linear(d_hid, 1 + vocab_size) # 1 for blank self.ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True).cuda() - def forward(self, latent, text=None, text_lens=None): layers = [] @@ -125,20 +114,24 @@ class ConformerCTC(nn.Module): batch_size, time_steps, _ = x.shape # Create a dummy lengths tensor (all sequences are assumed to be full length). - input_lengths = torch.full((batch_size,), time_steps, device=x.device, dtype=torch.int64) + input_lengths = torch.full( + (batch_size,), time_steps, device=x.device, dtype=torch.int64 + ) - for layer in (self.conf_pre): + for layer in self.conf_pre: x, _ = layer(x, input_lengths) layers.append(x.transpose(1, 2)) - for layer in (self.conf_after): + for layer in self.conf_after: x, _ = layer(x, input_lengths) layers.append(x.transpose(1, 2)) x = self.out(x) if text_lens is not None and text is not None: - loss = self.ctc_loss(x.log_softmax(dim=2).transpose(0, 1), text, input_lengths, text_lens) + loss = self.ctc_loss( + x.log_softmax(dim=2).transpose(0, 1), text, input_lengths, text_lens + ) return x, layers, loss else: return x, layers @@ -147,9 +140,8 @@ class ConformerCTC(nn.Module): if __name__ == "__main__": from f5_tts.model.utils import get_tokenizer - bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -158,15 +150,17 @@ if __name__ == "__main__": else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - - model = ConformerCTC(vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda() - + + model = ConformerCTC( + vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6 + ).cuda() + text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 80).cuda() - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # handle text as string text_lens = torch.tensor([len(t) for t in text], device=device) if isinstance(text, list): @@ -198,7 +192,6 @@ if __name__ == "__main__": char_vocab_map = list(vocab_char_map.keys()) - for batch in best_path: decoded_sequence = [] previous_token = None @@ -212,10 +205,15 @@ if __name__ == "__main__": decoded_sequences.append(decoded_sequence) # Convert token indices to characters - decoded_texts = [''.join([char_vocab_map[token] for token in sequence]) for sequence in decoded_sequences] + decoded_texts = [ + "".join([char_vocab_map[token] for token in sequence]) + for sequence in decoded_sequences + ] gt_texts = [] for i in range(text_lens.size(0)): - gt_texts.append(''.join([char_vocab_map[token] for token in text[i, :text_lens[i]]])) - + gt_texts.append( + "".join([char_vocab_map[token] for token in text[i, : text_lens[i]]]) + ) + print(decoded_texts) - print(gt_texts) \ No newline at end of file + print(gt_texts) diff --git a/discriminator_conformer.py b/discriminator_conformer.py 
index 058e15106917895d9fe9cfa039fb6058b607b67b..c48df249f1f995a6c2497e2bc4abb68737668132 100644 --- a/discriminator_conformer.py +++ b/discriminator_conformer.py @@ -2,30 +2,28 @@ from __future__ import annotations +from pathlib import Path + import torch import torch.nn as nn import torch.nn.functional as F import torchaudio.transforms as trans -from pathlib import Path from torchaudio.models import Conformer -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) + class ResBlock(nn.Module): def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2): super().__init__() self._n_groups = 8 - self.blocks = nn.ModuleList([ - self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p) - for i in range(n_conv)]) - + self.blocks = nn.ModuleList( + [ + self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p) + for i in range(n_conv) + ] + ) def forward(self, x): for block in self.blocks: @@ -36,46 +34,67 @@ class ResBlock(nn.Module): def _get_conv(self, hidden_dim, dilation, dropout_p=0.2): layers = [ - nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), + nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + ), nn.ReLU(), nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), nn.Dropout(p=dropout_p), nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), nn.ReLU(), - nn.Dropout(p=dropout_p) + nn.Dropout(p=dropout_p), ] return nn.Sequential(*layers) + class ConformerDiscirminator(nn.Module): - def __init__(self, input_dim, channels=512, num_layers=3, num_heads=8, depthwise_conv_kernel_size=15, use_group_norm=True): + def __init__( + self, + input_dim, + channels=512, + num_layers=3, + num_heads=8, + depthwise_conv_kernel_size=15, + use_group_norm=True, + ): super().__init__() - + self.input_layer = nn.Conv1d(input_dim, channels, kernel_size=3, padding=1) self.resblock1 = nn.Sequential( - ResBlock(channels), - nn.GroupNorm(num_groups=1, num_channels=channels) - ) - + ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels) + ) + self.resblock2 = nn.Sequential( - ResBlock(channels), - nn.GroupNorm(num_groups=1, num_channels=channels) - ) - - self.conformer1 = Conformer(**{"input_dim": channels, - "num_heads": num_heads, - "ffn_dim": channels * 2, - "num_layers": 1, + ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels) + ) + + self.conformer1 = Conformer( + **{ + "input_dim": channels, + "num_heads": num_heads, + "ffn_dim": channels * 2, + "num_layers": 1, "depthwise_conv_kernel_size": depthwise_conv_kernel_size // 2, - "use_group_norm": use_group_norm}) - - self.conformer2 = Conformer(**{"input_dim": channels, - "num_heads": num_heads, - "ffn_dim": channels * 2, - "num_layers": num_layers - 1, + "use_group_norm": use_group_norm, + } + ) + + self.conformer2 = Conformer( + **{ + "input_dim": channels, + "num_heads": num_heads, + "ffn_dim": channels * 2, + "num_layers": num_layers - 1, "depthwise_conv_kernel_size": depthwise_conv_kernel_size, - "use_group_norm": use_group_norm}) - + "use_group_norm": use_group_norm, + } + ) + self.linear = nn.Conv1d(channels, 1, kernel_size=1) def forward(self, x): @@ -89,12 +108,14 @@ class ConformerDiscirminator(nn.Module): x = nn.functional.avg_pool1d(x, 2) x = self.resblock2(x) x = nn.functional.avg_pool1d(x, 2) 
- + # Transpose to (B, T, C) for the conformer. x = x.transpose(1, 2) batch_size, time_steps, _ = x.shape # Create a dummy lengths tensor (all sequences are assumed to be full length). - lengths = torch.full((batch_size,), time_steps, device=x.device, dtype=torch.int64) + lengths = torch.full( + (batch_size,), time_steps, device=x.device, dtype=torch.int64 + ) # The built-in Conformer returns (output, output_lengths); we discard lengths. x, _ = self.conformer1(x, lengths) @@ -107,12 +128,13 @@ class ConformerDiscirminator(nn.Module): return out + if __name__ == "__main__": - from f5_tts.model.utils import get_tokenizer from f5_tts.model import DiT + from f5_tts.model.utils import get_tokenizer bsz = 2 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -121,20 +143,28 @@ if __name__ == "__main__": else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - - - fake_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=80) + + fake_unet = DiT( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + text_num_embeds=vocab_size, + mel_dim=80, + ) fake_unet = fake_unet.cuda() text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 80).cuda() - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # handle text as string if isinstance(text, list): if exists(vocab_char_map): @@ -147,13 +177,17 @@ if __name__ == "__main__": if not exists(lens): lens = torch.full((batch,), seq_len, device=device) - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch frac_lengths_mask = (0.7, 1.0) - + # get a random span to mask out for training conditionally - frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) + frac_lengths = ( + torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) + ) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - + if exists(mask): rand_span_mask &= mask @@ -163,38 +197,41 @@ if __name__ == "__main__": x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) - + phi = (1 - t) * x0 + t * x1 flow = x1 - x0 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) layers = fake_unet( - x=phi, + x=phi, cond=cond, - text=text, - time=time, + text=text, + time=time, drop_audio_cond=False, drop_text=False, - classify_mode=True + classify_mode=True, ) # layers = torch.stack(layers, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2) # print(layers.shape) from ctcmodel import ConformerCTC - ctcmodel = ConformerCTC(vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda() + + ctcmodel = ConformerCTC( + vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6 + ).cuda() real_out, layer = ctcmodel(inp) - layer = layer[-3:] # only use the last 3 layers - layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer] + layer = layer[-3:] # only use the last 3 layers + layer = [ + F.interpolate(l, mode="nearest", scale_factor=4).transpose(-1, -2) + 
for l in layer + ] if layer[0].size(1) < layers[0].size(1): layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer] - + layers = layer + layers - model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, - channels=512 - ) - + model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, channels=512) model = model.cuda() print(model) diff --git a/dmd_trainer.py b/dmd_trainer.py index cf7f94f7dfd4cce56b02e4f568d40557cab88080..db54d6a7386fb73c8c8c6e29784d3f98eb07dca3 100644 --- a/dmd_trainer.py +++ b/dmd_trainer.py @@ -1,28 +1,26 @@ from __future__ import annotations -import os import gc -from tqdm import tqdm -import wandb +import math +import os import torch import torch.nn as nn -from torch.optim import AdamW -from torch.utils.data import DataLoader, Dataset, SequentialSampler -from torch.optim.lr_scheduler import LinearLR, SequentialLR - +import wandb from accelerate import Accelerator from accelerate.utils import DistributedDataParallelKwargs +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR, SequentialLR +from torch.utils.data import DataLoader, Dataset, SequentialSampler +from tqdm import tqdm -from unimodel import UniModel from f5_tts.model import CFM -from f5_tts.model.utils import exists, default from f5_tts.model.dataset import DynamicBatchSampler, collate_fn - +from f5_tts.model.utils import default, exists +from unimodel import UniModel # trainer -import math class RunningStats: def __init__(self): @@ -41,7 +39,7 @@ class RunningStats: @property def variance(self): """Return the sample variance. Returns NaN if fewer than two samples.""" - return self.M2 / (self.count - 1) if self.count > 1 else float('nan') + return self.M2 / (self.count - 1) if self.count > 1 else float("nan") @property def std(self): @@ -49,7 +47,6 @@ class RunningStats: return math.sqrt(self.variance) - class Trainer: def __init__( self, @@ -74,7 +71,6 @@ class Trainer: accelerate_kwargs: dict = dict(), bnb_optimizer: bool = False, scale: float = 1.0, - # training parameters for DMDSpeech num_student_step: int = 1, gen_update_ratio: int = 5, @@ -82,7 +78,6 @@ class Trainer: lambda_generator_loss: float = 1.0, lambda_ctc_loss: float = 1.0, lambda_sim_loss: float = 1.0, - num_GAN: int = 5000, num_D: int = 500, num_ctc: int = 5000, @@ -103,7 +98,13 @@ class Trainer: if logger == "wandb": if exists(wandb_resume_id): - init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} + init_kwargs = { + "wandb": { + "resume": "allow", + "name": wandb_run_name, + "id": wandb_resume_id, + } + } else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} self.accelerator.init_trackers( @@ -130,7 +131,9 @@ class Trainer: self.epochs = epochs self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates - self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) + self.last_per_steps = default( + last_per_steps, save_per_updates * grad_accumulation_steps + ) self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") self.batch_size = batch_size @@ -142,41 +145,56 @@ class Trainer: self.noise_scheduler = noise_scheduler self.duration_predictor = duration_predictor - + self.log_step = log_step - self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update - self.lambda_discriminator_loss = lambda_discriminator_loss # weight for discriminator loss (L_adv) - self.lambda_generator_loss = 
lambda_generator_loss # weight for generator loss (L_adv) - self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss - self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss - + self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update + self.lambda_discriminator_loss = ( + lambda_discriminator_loss # weight for discriminator loss (L_adv) + ) + self.lambda_generator_loss = ( + lambda_generator_loss # weight for generator loss (L_adv) + ) + self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss + self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss + # create distillation schedule for student model - self.student_steps = ( - torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]) - - self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training - self.num_GAN = num_GAN # number of steps before adversarial training - self.num_D = num_D # number of steps to train the discriminator before adversarial training - self.num_ctc = num_ctc # number of steps before CTC training - self.num_sim = num_sim # number of steps before similarity training - self.num_simu = num_simu # number of steps before using simulated data + self.student_steps = torch.linspace(0.0, 1.0, num_student_step + 1)[:-1] + + self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training + self.num_GAN = num_GAN # number of steps before adversarial training + self.num_D = num_D # number of steps to train the discriminator before adversarial training + self.num_ctc = num_ctc # number of steps before CTC training + self.num_sim = num_sim # number of steps before similarity training + self.num_simu = num_simu # number of steps before using simulated data # Assuming `self.model.fake_unet.parameters()` and `self.model.guidance_model.parameters()` are accessible if bnb_optimizer: import bitsandbytes as bnb - self.optimizer_generator = bnb.optim.AdamW8bit(self.model.feedforward_model.parameters(), lr=learning_rate) - self.optimizer_guidance = bnb.optim.AdamW8bit(self.model.guidance_model.parameters(), lr=learning_rate) + + self.optimizer_generator = bnb.optim.AdamW8bit( + self.model.feedforward_model.parameters(), lr=learning_rate + ) + self.optimizer_guidance = bnb.optim.AdamW8bit( + self.model.guidance_model.parameters(), lr=learning_rate + ) else: - self.optimizer_generator = AdamW(self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7) - self.optimizer_guidance = AdamW(self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7) + self.optimizer_generator = AdamW( + self.model.feedforward_model.parameters(), lr=learning_rate, eps=1e-7 + ) + self.optimizer_guidance = AdamW( + self.model.guidance_model.parameters(), lr=learning_rate, eps=1e-7 + ) - self.model, self.optimizer_generator, self.optimizer_guidance = self.accelerator.prepare(self.model, self.optimizer_generator, self.optimizer_guidance) + self.model, self.optimizer_generator, self.optimizer_guidance = ( + self.accelerator.prepare( + self.model, self.optimizer_generator, self.optimizer_guidance + ) + ) self.generator_norm = RunningStats() self.guidance_norm = RunningStats() - @property def is_main(self): return self.accelerator.is_main_process @@ -186,8 +204,12 @@ class Trainer: if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), - optimizer_generator_state_dict=self.accelerator.unwrap_model(self.optimizer_generator).state_dict(), - 
optimizer_guidance_state_dict=self.accelerator.unwrap_model(self.optimizer_guidance).state_dict(), + optimizer_generator_state_dict=self.accelerator.unwrap_model( + self.optimizer_generator + ).state_dict(), + optimizer_guidance_state_dict=self.accelerator.unwrap_model( + self.optimizer_guidance + ).state_dict(), scheduler_generator_state_dict=self.scheduler_generator.state_dict(), scheduler_guidance_state_dict=self.scheduler_guidance.state_dict(), step=step, @@ -196,10 +218,14 @@ class Trainer: if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_last.pt" + ) print(f"Saved last checkpoint at step {step}") else: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_{step}.pt" + ) def load_checkpoint(self): if ( @@ -218,9 +244,15 @@ class Trainer: key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") + checkpoint = torch.load( + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", + ) - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"], strict=False) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"], strict=False + ) # self.accelerator.unwrap_model(self.optimizer_generator).load_state_dict(checkpoint["optimizer_generator_state_dict"]) # self.accelerator.unwrap_model(self.optimizer_guidance).load_state_dict(checkpoint["optimizer_guidance_state_dict"]) # if self.scheduler_guidance: @@ -232,9 +264,14 @@ class Trainer: del checkpoint gc.collect() return step - - def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int = None, vocoder: nn.Module = None): + def train( + self, + train_dataset: Dataset, + num_workers=64, + resumable_with_seed: int = None, + vocoder: nn.Module = None, + ): if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) @@ -256,7 +293,11 @@ class Trainer: self.accelerator.even_batches = False sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) train_dataloader = DataLoader( train_dataset, @@ -267,29 +308,63 @@ class Trainer: batch_sampler=batch_sampler, ) else: - raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") + raise ValueError( + f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}" + ) # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices - warmup_steps = ( - self.num_warmup_updates * self.accelerator.num_processes - ) - + warmup_steps = self.num_warmup_updates * self.accelerator.num_processes + # consider a fixed warmup steps while using accelerate multi-gpu ddp # otherwise by default with 
split_batches=False, warmup steps change with num_processes total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps decay_steps = total_steps - warmup_steps - - warmup_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)) - decay_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps // (self.gen_update_ratio * self.grad_accumulation_steps)) - self.scheduler_generator = SequentialLR(self.optimizer_generator, schedulers=[warmup_scheduler_generator, decay_scheduler_generator], milestones=[warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)]) - warmup_scheduler_guidance = LinearLR(self.optimizer_guidance, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) - decay_scheduler_guidance = LinearLR(self.optimizer_guidance, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) - self.scheduler_guidance = SequentialLR(self.optimizer_guidance, schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance], milestones=[warmup_steps]) + warmup_scheduler_generator = LinearLR( + self.optimizer_generator, + start_factor=1e-8, + end_factor=1.0, + total_iters=warmup_steps + // (self.gen_update_ratio * self.grad_accumulation_steps), + ) + decay_scheduler_generator = LinearLR( + self.optimizer_generator, + start_factor=1.0, + end_factor=1e-8, + total_iters=decay_steps + // (self.gen_update_ratio * self.grad_accumulation_steps), + ) + self.scheduler_generator = SequentialLR( + self.optimizer_generator, + schedulers=[warmup_scheduler_generator, decay_scheduler_generator], + milestones=[ + warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps) + ], + ) + + warmup_scheduler_guidance = LinearLR( + self.optimizer_guidance, + start_factor=1e-8, + end_factor=1.0, + total_iters=warmup_steps, + ) + decay_scheduler_guidance = LinearLR( + self.optimizer_guidance, + start_factor=1.0, + end_factor=1e-8, + total_iters=decay_steps, + ) + self.scheduler_guidance = SequentialLR( + self.optimizer_guidance, + schedulers=[warmup_scheduler_guidance, decay_scheduler_guidance], + milestones=[warmup_steps], + ) - train_dataloader, self.scheduler_generator, self.scheduler_guidance = self.accelerator.prepare( - train_dataloader, self.scheduler_generator, self.scheduler_guidance + train_dataloader, self.scheduler_generator, self.scheduler_guidance = ( + self.accelerator.prepare( + train_dataloader, self.scheduler_generator, self.scheduler_guidance + ) ) # actual steps = 1 gpu steps / gpus start_step = self.load_checkpoint() global_step = start_step @@ -298,7 +373,9 @@ class Trainer: orig_epoch_step = len(train_dataloader) skipped_epoch = int(start_step // orig_epoch_step) skipped_batch = start_step % orig_epoch_step - skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) + skipped_dataloader = self.accelerator.skip_first_batches( + train_dataloader, num_batches=skipped_batch + ) else: skipped_epoch = 0 @@ -323,48 +400,59 @@ class Trainer: for batch in progress_bar: update_generator = global_step % self.gen_update_ratio == 0 - + with self.accelerator.accumulate(self.model): metrics = {} text_inputs = batch["text"] mel_spec = batch["mel"].permute(0, 2, 1) mel_lengths = batch["mel_lengths"] - + mel_spec = mel_spec / self.scale - - guidance_loss_dict, guidance_log_dict = self.model(inp=mel_spec, - text=text_inputs, - lens=mel_lengths, - 
student_steps=self.student_steps, - update_generator=False, - use_simulated=global_step >= self.num_simu, - ) + + guidance_loss_dict, guidance_log_dict = self.model( + inp=mel_spec, + text=text_inputs, + lens=mel_lengths, + student_steps=self.student_steps, + update_generator=False, + use_simulated=global_step >= self.num_simu, + ) # if self.GAN and update_generator: # # only add discriminator loss if GAN is enabled and generator is being updated # guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0) # metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"] # self.accelerator.backward(guidance_cls_loss, retain_graph=True) - + # if self.max_grad_norm > 0 and self.accelerator.sync_gradients: # metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) guidance_loss = 0 guidance_loss += guidance_loss_dict["loss_fake_mean"] - metrics['loss/fake_score'] = guidance_loss_dict["loss_fake_mean"] + metrics["loss/fake_score"] = guidance_loss_dict["loss_fake_mean"] metrics["loss/guidance_loss"] = guidance_loss if self.GAN and update_generator: # only add discriminator loss if GAN is enabled and generator is being updated - guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0) - metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"] + guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * ( + self.lambda_discriminator_loss + if global_step >= self.num_GAN and update_generator + else 0 + ) + metrics["loss/discriminator_loss"] = guidance_loss_dict[ + "guidance_cls_loss" + ] guidance_loss += guidance_cls_loss - + self.accelerator.backward(guidance_loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + metrics["grad_norm_guidance"] = ( + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) + ) # if self.guidance_norm.count < 100: # self.guidance_norm.update(metrics['grad_norm_guidance']) @@ -376,20 +464,20 @@ class Trainer: # elif self.guidance_norm.count >= 100: # self.guidance_norm.update(metrics['grad_norm_guidance']) - self.optimizer_guidance.step() self.scheduler_guidance.step() self.optimizer_guidance.zero_grad() self.optimizer_generator.zero_grad() # zero out the generator's gradient as well - + if update_generator: - generator_loss_dict, generator_log_dict = self.model(inp=mel_spec, - text=text_inputs, - lens=mel_lengths, - student_steps=self.student_steps, - update_generator=True, - use_simulated=global_step >= self.num_ctc, - ) + generator_loss_dict, generator_log_dict = self.model( + inp=mel_spec, + text=text_inputs, + lens=mel_lengths, + student_steps=self.student_steps, + update_generator=True, + use_simulated=global_step >= self.num_ctc, + ) # if self.GAN: # gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0) # metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"] @@ -402,32 +490,57 @@ class Trainer: generator_loss = 0 generator_loss += generator_loss_dict["loss_dm"] if "loss_mse" in generator_loss_dict: - generator_loss += generator_loss_dict["loss_mse"] - generator_loss += generator_loss_dict["loss_ctc"] * 
(self.lambda_ctc_loss if global_step >= self.num_ctc else 0) - generator_loss += generator_loss_dict["loss_sim"] * (self.lambda_sim_loss if global_step >= self.num_sim else 0) - generator_loss += generator_loss_dict["loss_kl"] * (self.lambda_ctc_loss if global_step >= self.num_ctc else 0) + generator_loss += generator_loss_dict["loss_mse"] + generator_loss += generator_loss_dict["loss_ctc"] * ( + self.lambda_ctc_loss if global_step >= self.num_ctc else 0 + ) + generator_loss += generator_loss_dict["loss_sim"] * ( + self.lambda_sim_loss if global_step >= self.num_sim else 0 + ) + generator_loss += generator_loss_dict["loss_kl"] * ( + self.lambda_ctc_loss if global_step >= self.num_ctc else 0 + ) if self.GAN: - gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0) - metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"] + gen_cls_loss = generator_loss_dict["gen_cls_loss"] * ( + self.lambda_generator_loss + if global_step >= (self.num_GAN + self.num_D) + and update_generator + else 0 + ) + metrics["loss/gen_cls_loss"] = generator_loss_dict[ + "gen_cls_loss" + ] generator_loss += gen_cls_loss - metrics['loss/dm_loss'] = generator_loss_dict["loss_dm"] - metrics['loss/ctc_loss'] = generator_loss_dict["loss_ctc"] - - metrics['loss/similarity_loss'] = generator_loss_dict["loss_sim"] - metrics['loss/generator_loss'] = generator_loss - - if "loss_mse" in generator_loss_dict and generator_loss_dict["loss_mse"] != 0: - metrics['loss/mse_loss'] = generator_loss_dict["loss_mse"] - if "loss_kl" in generator_loss_dict and generator_loss_dict["loss_kl"] != 0: - metrics['loss/kl_loss'] = generator_loss_dict["loss_kl"] + metrics["loss/dm_loss"] = generator_loss_dict["loss_dm"] + metrics["loss/ctc_loss"] = generator_loss_dict["loss_ctc"] + + metrics["loss/similarity_loss"] = generator_loss_dict[ + "loss_sim" + ] + metrics["loss/generator_loss"] = generator_loss + + if ( + "loss_mse" in generator_loss_dict + and generator_loss_dict["loss_mse"] != 0 + ): + metrics["loss/mse_loss"] = generator_loss_dict["loss_mse"] + if ( + "loss_kl" in generator_loss_dict + and generator_loss_dict["loss_kl"] != 0 + ): + metrics["loss/kl_loss"] = generator_loss_dict["loss_kl"] self.accelerator.backward(generator_loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - metrics['grad_norm_generator'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + metrics["grad_norm_generator"] = ( + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) + ) # self.generator_norm.update(metrics['grad_norm_generator']) - + # if metrics['grad_norm_generator'] > self.generator_norm.mean + 15 * self.generator_norm.std: # self.optimizer_generator.zero_grad() # self.optimizer_guidance.zero_grad() @@ -440,89 +553,165 @@ class Trainer: self.optimizer_generator.zero_grad() self.optimizer_guidance.zero_grad() # zero out the guidance's gradient as well - global_step += 1 if self.accelerator.is_local_main_process: - self.accelerator.log({**metrics, - "lr_generator": self.scheduler_generator.get_last_lr()[0], - "lr_guidance": self.scheduler_guidance.get_last_lr()[0], - } - , step=global_step) - - if global_step % self.log_step == 0 and self.accelerator.is_local_main_process and vocoder is not None: + self.accelerator.log( + { + **metrics, + "lr_generator": self.scheduler_generator.get_last_lr()[0], + "lr_guidance": self.scheduler_guidance.get_last_lr()[0], + }, + 
step=global_step, + ) + + if ( + global_step % self.log_step == 0 + and self.accelerator.is_local_main_process + and vocoder is not None + ): # log the first batch of the epoch with torch.no_grad(): - generator_input = generator_log_dict['generator_input'][0].unsqueeze(0).permute(0, 2, 1) * self.scale + generator_input = ( + generator_log_dict["generator_input"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale + ) generator_input = vocoder.decode(generator_input.float().cpu()) generator_input = wandb.Audio( generator_input.float().numpy().squeeze(), sample_rate=24000, - caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) + caption="time: " + + str(generator_log_dict["time"][0].float().cpu().numpy()), ) - generator_output = generator_log_dict['generator_output'][0].unsqueeze(0).permute(0, 2, 1) * self.scale - generator_output = vocoder.decode(generator_output.float().cpu()) + generator_output = ( + generator_log_dict["generator_output"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale + ) + generator_output = vocoder.decode( + generator_output.float().cpu() + ) generator_output = wandb.Audio( generator_output.float().numpy().squeeze(), sample_rate=24000, - caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) + caption="time: " + + str(generator_log_dict["time"][0].float().cpu().numpy()), + ) + + generator_cond = ( + generator_log_dict["generator_cond"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale ) - - generator_cond = generator_log_dict['generator_cond'][0].unsqueeze(0).permute(0, 2, 1) * self.scale generator_cond = vocoder.decode(generator_cond.float().cpu()) generator_cond = wandb.Audio( generator_cond.float().numpy().squeeze(), sample_rate=24000, - caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) + caption="time: " + + str(generator_log_dict["time"][0].float().cpu().numpy()), + ) + + ground_truth = ( + generator_log_dict["ground_truth"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale ) - - ground_truth = generator_log_dict['ground_truth'][0].unsqueeze(0).permute(0, 2, 1) * self.scale ground_truth = vocoder.decode(ground_truth.float().cpu()) ground_truth = wandb.Audio( ground_truth.float().numpy().squeeze(), sample_rate=24000, - caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) + caption="time: " + + str(generator_log_dict["time"][0].float().cpu().numpy()), + ) + + dmtrain_noisy_inp = ( + generator_log_dict["dmtrain_noisy_inp"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale + ) + dmtrain_noisy_inp = vocoder.decode( + dmtrain_noisy_inp.float().cpu() ) - - dmtrain_noisy_inp = generator_log_dict['dmtrain_noisy_inp'][0].unsqueeze(0).permute(0, 2, 1) * self.scale - dmtrain_noisy_inp = vocoder.decode(dmtrain_noisy_inp.float().cpu()) dmtrain_noisy_inp = wandb.Audio( dmtrain_noisy_inp.float().numpy().squeeze(), sample_rate=24000, - caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy()) + caption="dmtrain_time: " + + str( + generator_log_dict["dmtrain_time"][0] + .float() + .cpu() + .numpy() + ), + ) + + dmtrain_pred_real_image = ( + generator_log_dict["dmtrain_pred_real_image"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale + ) + dmtrain_pred_real_image = vocoder.decode( + dmtrain_pred_real_image.float().cpu() ) - - dmtrain_pred_real_image = generator_log_dict['dmtrain_pred_real_image'][0].unsqueeze(0).permute(0, 2, 1) * self.scale - dmtrain_pred_real_image = vocoder.decode(dmtrain_pred_real_image.float().cpu()) 
dmtrain_pred_real_image = wandb.Audio( dmtrain_pred_real_image.float().numpy().squeeze(), sample_rate=24000, - caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy()) + caption="dmtrain_time: " + + str( + generator_log_dict["dmtrain_time"][0] + .float() + .cpu() + .numpy() + ), + ) + + dmtrain_pred_fake_image = ( + generator_log_dict["dmtrain_pred_fake_image"][0] + .unsqueeze(0) + .permute(0, 2, 1) + * self.scale + ) + dmtrain_pred_fake_image = vocoder.decode( + dmtrain_pred_fake_image.float().cpu() ) - - dmtrain_pred_fake_image = generator_log_dict['dmtrain_pred_fake_image'][0].unsqueeze(0).permute(0, 2, 1) * self.scale - dmtrain_pred_fake_image = vocoder.decode(dmtrain_pred_fake_image.float().cpu()) dmtrain_pred_fake_image = wandb.Audio( dmtrain_pred_fake_image.float().numpy().squeeze(), sample_rate=24000, - caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy()) + caption="dmtrain_time: " + + str( + generator_log_dict["dmtrain_time"][0] + .float() + .cpu() + .numpy() + ), + ) + + self.accelerator.log( + { + "noisy_input": generator_input, + "output": generator_output, + "cond": generator_cond, + "ground_truth": ground_truth, + "dmtrain_noisy_inp": dmtrain_noisy_inp, + "dmtrain_pred_real_image": dmtrain_pred_real_image, + "dmtrain_pred_fake_image": dmtrain_pred_fake_image, + }, + step=global_step, ) - - - self.accelerator.log({"noisy_input": generator_input, - "output": generator_output, - "cond": generator_cond, - "ground_truth": ground_truth, - "dmtrain_noisy_inp": dmtrain_noisy_inp, - "dmtrain_pred_real_image": dmtrain_pred_real_image, - "dmtrain_pred_fake_image": dmtrain_pred_fake_image, - - }, step=global_step) progress_bar.set_postfix(step=str(global_step), metrics=metrics) - if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: + if ( + global_step % (self.save_per_updates * self.grad_accumulation_steps) + == 0 + ): self.save_checkpoint(global_step) if global_step % self.last_per_steps == 0: @@ -531,5 +720,3 @@ class Trainer: self.save_checkpoint(global_step, last=True) self.accelerator.end_training() - - diff --git a/duration_predictor.py b/duration_predictor.py index 03c73fccbb58597170c410cc9067299d7f675483..2eed7f0a71ea7951a36e962e6f64812d25d9eb77 100644 --- a/duration_predictor.py +++ b/duration_predictor.py @@ -3,6 +3,7 @@ import torch.nn as nn # from tts_encode import tts_encode + def calculate_remaining_lengths(mel_lengths): B = mel_lengths.shape[0] max_L = mel_lengths.max().item() # Get the maximum length in the batch @@ -21,64 +22,84 @@ class PositionalEncoding(nn.Module): super().__init__() pe = torch.zeros(max_len, hidden_dim) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / hidden_dim)) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2).float() + * (-torch.log(torch.tensor(10000.0)) / hidden_dim) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.pe = pe.unsqueeze(0) # Shape: (1, max_len, hidden_dim) def forward(self, x): - x = x + self.pe[:, :x.size(1)].to(x.device) + x = x + self.pe[:, : x.size(1)].to(x.device) return x class SpeechLengthPredictor(nn.Module): - def __init__(self, - vocab_size=2545, n_mel=100, hidden_dim=256, - n_text_layer=4, n_cross_layer=4, n_head=8, + def __init__( + self, + vocab_size=2545, + n_mel=100, + hidden_dim=256, + n_text_layer=4, + n_cross_layer=4, + n_head=8, 
output_dim=1, ): super().__init__() - + # Text Encoder: Embedding + Transformer Layers - self.text_embedder = nn.Embedding(vocab_size+1, hidden_dim, padding_idx=vocab_size) + self.text_embedder = nn.Embedding( + vocab_size + 1, hidden_dim, padding_idx=vocab_size + ) self.text_pe = PositionalEncoding(hidden_dim) encoder_layer = nn.TransformerEncoderLayer( - d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True + d_model=hidden_dim, + nhead=n_head, + dim_feedforward=hidden_dim * 2, + batch_first=True, + ) + self.text_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=n_text_layer ) - self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_text_layer) - + # Mel Spectrogram Embedder self.mel_embedder = nn.Linear(n_mel, hidden_dim) self.mel_pe = PositionalEncoding(hidden_dim) # Transformer Decoder Layers with Cross-Attention in Every Layer decoder_layer = nn.TransformerDecoderLayer( - d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True + d_model=hidden_dim, + nhead=n_head, + dim_feedforward=hidden_dim * 2, + batch_first=True, ) self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer) - + # Final Classification Layer self.predictor = nn.Linear(hidden_dim, output_dim) def forward(self, text_ids, mel): # Encode text text_embedded = self.text_pe(self.text_embedder(text_ids)) - text_features = self.text_encoder(text_embedded) # (B, L_text, D) - + text_features = self.text_encoder(text_embedded) # (B, L_text, D) + # Encode Mel spectrogram mel_features = self.mel_pe(self.mel_embedder(mel)) # (B, L_mel, D) - + # Causal Masking for Decoder seq_len = mel_features.size(1) - causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(mel.device) + causal_mask = ( + torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(mel.device) + ) # causal_mask = torch.triu( # torch.full((seq_len, seq_len), float('-inf'), device=mel.device), diagonal=1 # ) # Transformer Decoder with Cross-Attention in Each Layer decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask) - + # Length Prediction length_logits = self.predictor(decoder_out).squeeze(-1) return length_logits diff --git a/duration_trainer.py b/duration_trainer.py index 3757bbffb7d217932bb6b1abd1ccaccef71b10b0..e32e9e9984415e95db73c4a536e5d8957ff3e90e 100644 --- a/duration_trainer.py +++ b/duration_trainer.py @@ -1,11 +1,11 @@ from __future__ import annotations import gc -import os - import math +import os import torch +import torch.nn.functional as F import torchaudio import wandb from accelerate import Accelerator @@ -13,37 +13,28 @@ from accelerate.utils import DistributedDataParallelKwargs from ema_pytorch import EMA from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR -from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset # <-- Added Subset import +from torch.utils.data import Dataset # <-- Added Subset import +from torch.utils.data import DataLoader, SequentialSampler, Subset from tqdm import tqdm -import torch.nn.functional as F - -from f5_tts.model import CFM -from f5_tts.model.dataset import collate_fn, DynamicBatchSampler -from f5_tts.model.utils import default, exists - from duration_predictor import calculate_remaining_lengths +from f5_tts.model import CFM +from f5_tts.model.dataset import DynamicBatchSampler, collate_fn +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) # 
trainer -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) SAMPLE_RATE = 24_000 def masked_l1_loss(est_lengths, tar_lengths): - first_zero_idx = (tar_lengths == 0).int().argmax(dim=1) + first_zero_idx = (tar_lengths == 0).int().argmax(dim=1) B, L = tar_lengths.shape - range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L) + range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L) mask = range_tensor <= first_zero_idx[:, None] # Include the first 0 - loss = F.l1_loss(est_lengths, tar_lengths, reduction='none') # (B, L) + loss = F.l1_loss(est_lengths, tar_lengths, reduction="none") # (B, L) loss = loss * mask # Zero out ignored positions loss = loss.sum() / mask.sum() # Normalize by valid elements return loss @@ -55,9 +46,9 @@ def masked_cross_entropy_loss(est_length_logits, tar_length_labels): range_tensor = torch.arange(L, device=tar_length_labels.device).expand(B, L) mask = range_tensor <= first_zero_idx[:, None] # Include the first 0 loss = F.cross_entropy( - est_length_logits.reshape(-1, est_length_logits.size(-1)), - tar_length_labels.reshape(-1), - reduction='none' + est_length_logits.reshape(-1, est_length_logits.size(-1)), + tar_length_labels.reshape(-1), + reduction="none", ).reshape(B, L) loss = loss * mask loss = loss.sum() / mask.sum() @@ -71,7 +62,7 @@ class Trainer: vocab_size, vocab_char_map, process_token_to_id=True, - loss_fn='L1', + loss_fn="L1", lambda_L1=1, gumbel_tau=0.5, n_class=301, @@ -110,7 +101,13 @@ class Trainer: self.logger = logger if self.logger == "wandb": if exists(wandb_resume_id): - init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} + init_kwargs = { + "wandb": { + "resume": "allow", + "name": wandb_run_name, + "id": wandb_resume_id, + } + } else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} @@ -139,7 +136,7 @@ class Trainer: self.vocab_size = vocab_size self.vocab_char_map = vocab_char_map self.process_token_to_id = process_token_to_id - assert loss_fn in ['L1', 'CE', 'L1_and_CE'] + assert loss_fn in ["L1", "CE", "L1_and_CE"] self.loss_fn = loss_fn self.lambda_L1 = lambda_L1 self.n_class = n_class @@ -149,7 +146,9 @@ class Trainer: self.epochs = epochs self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates - self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) + self.last_per_steps = default( + last_per_steps, save_per_updates * grad_accumulation_steps + ) self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") self.batch_size = batch_size @@ -164,33 +163,44 @@ class Trainer: self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) - self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) - + self.model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) + @property def is_main(self): return self.accelerator.is_main_process def save_checkpoint(self, step, last=False): self.accelerator.wait_for_everyone() - if self.is_main: + if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), - optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), + optimizer_state_dict=self.accelerator.unwrap_model( + self.optimizer + ).state_dict(), scheduler_state_dict=self.scheduler.state_dict(), step=step, 
) if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_last.pt" + ) else: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_{step}.pt" + ) def load_checkpoint(self): if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path)) + or not any( + filename.endswith(".pt") + for filename in os.listdir(self.checkpoint_path) + ) ): return 0 @@ -203,21 +213,32 @@ class Trainer: key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] - print(f'To load from {latest_checkpoint}.') + print(f"To load from {latest_checkpoint}.") # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") + checkpoint = torch.load( + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", + ) - print(f'Loaded from {latest_checkpoint}.') + print(f"Loaded from {latest_checkpoint}.") if "step" in checkpoint: # patch for backward compatibility, 305e3ea - for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) - self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) + self.accelerator.unwrap_model(self.optimizer).load_state_dict( + checkpoint["optimizer_state_dict"] + ) if self.scheduler: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) step = checkpoint["step"] @@ -227,17 +248,18 @@ class Trainer: for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "step"] } - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) step = 0 - + del checkpoint gc.collect() - print(f'Exit load_checkpoint.') + print(f"Exit load_checkpoint.") return step - def validate(self, valid_dataloader, global_step): """ Runs evaluation on the validation set, computes the average loss, @@ -251,54 +273,61 @@ class Trainer: with torch.no_grad(): for batch in valid_dataloader: # Inputs - mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D) - text = batch['text'] + mel = batch["mel"].permute(0, 2, 1) # (B, L_mel, D) + text = batch["text"] if self.process_token_to_id: text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device) - text_ids = text_ids.masked_fill(text_ids==-1, self.vocab_size) + text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size) else: text_ids = text # Targets - mel_lengths = batch['mel_lengths'] + mel_lengths = batch["mel_lengths"] tar_lengths = calculate_remaining_lengths(mel_lengths) predictions = self.model(text_ids=text_ids, mel=mel) - if self.loss_fn == 'L1': + if self.loss_fn == "L1": 
est_lengths = predictions loss = masked_l1_loss( est_lengths=est_lengths, tar_lengths=tar_lengths ) frame_error = loss - elif self.loss_fn == 'CE': - tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + elif self.loss_fn == "CE": + tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp( + min=0, max=self.n_class - 1 + ) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = masked_cross_entropy_loss( - est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels + est_length_logits=est_length_logtis, + tar_length_labels=tar_length_labels, ) est_lengths = est_length_labels * self.n_frame_per_class frame_error = masked_l1_loss( est_lengths=est_lengths, tar_lengths=tar_lengths ) - elif self.loss_fn == 'L1_and_CE': - tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + elif self.loss_fn == "L1_and_CE": + tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp( + min=0, max=self.n_class - 1 + ) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_1hots = F.gumbel_softmax( est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1 ) - length_values = torch.arange( - self.n_class, device=est_length_1hots.device - ).float() * self.n_frame_per_class + length_values = ( + torch.arange( + self.n_class, device=est_length_1hots.device + ).float() + * self.n_frame_per_class + ) est_lengths = (est_length_1hots * length_values).sum(-1) loss_CE = masked_cross_entropy_loss( - est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels + est_length_logits=est_length_logtis, + tar_length_labels=tar_length_labels, ) loss_L1 = masked_l1_loss( @@ -321,18 +350,19 @@ class Trainer: avg_valid_sec_error = total_sec_error / count if count > 0 else 0.0 # Log validation metrics self.accelerator.log( - { - f"valid_loss": avg_valid_loss, - f"valid_sec_error": avg_valid_sec_error - }, - step=global_step + {f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error}, + step=global_step, ) - - self.model.train() + self.model.train() - def train(self, train_dataset: Dataset, valid_dataset: Dataset, - num_workers=64, resumable_with_seed: int = None): + def train( + self, + train_dataset: Dataset, + valid_dataset: Dataset, + num_workers=64, + resumable_with_seed: int = None, + ): if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) @@ -366,7 +396,11 @@ class Trainer: sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) train_dataloader = DataLoader( train_dataset, @@ -379,20 +413,26 @@ class Trainer: sampler = SequentialSampler(valid_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) # Create validation dataloader (always sequential, no shuffling) valid_dataloader = DataLoader( valid_dataset, collate_fn=collate_fn, num_workers=num_workers, - pin_memory=True, + pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) 
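(Aside, not part of the patch: a minimal standalone sketch of how the masked duration losses above behave. masked_l1_loss is copied from this file; the toy "remaining frames" targets, the bucket width n_frame_per_class=10, and the random logits are made-up stand-ins for the trainer's real inputs, and the 256/24000 factor mirrors the hop-size/sample-rate conversion used for sec_error.)

import torch
import torch.nn.functional as F

def masked_l1_loss(est_lengths, tar_lengths):
    # Positions after the first 0 in each target row are padding and ignored.
    first_zero_idx = (tar_lengths == 0).int().argmax(dim=1)
    B, L = tar_lengths.shape
    range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L)
    mask = range_tensor <= first_zero_idx[:, None]  # include the first 0
    loss = F.l1_loss(est_lengths, tar_lengths, reduction="none") * mask
    return loss.sum() / mask.sum()

n_class, n_frame_per_class = 301, 10  # assumed bucket width for the CE variants

# Toy target: remaining mel frames at each position, zero-padded after the end.
tar_lengths = torch.tensor([[30.0, 20.0, 10.0, 0.0, 0.0, 0.0],
                            [50.0, 40.0, 30.0, 20.0, 10.0, 0.0]])
est_lengths = tar_lengths + 5.0  # pretend regression outputs ("L1" branch)
print(masked_l1_loss(est_lengths, tar_lengths))  # ~5.0, averaged over valid positions only

# "CE" branch: frame counts are bucketed into n_class classes of n_frame_per_class frames each.
tar_length_labels = (tar_lengths // n_frame_per_class).long().clamp(0, n_class - 1)

# "L1_and_CE" branch: hard Gumbel-softmax one-hots are mapped back to frame counts, so an
# L1 term can be applied while gradients still flow through the class logits.
logits = torch.randn(2, 6, n_class, requires_grad=True)
one_hots = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)
length_values = torch.arange(n_class).float() * n_frame_per_class
est_from_classes = (one_hots * length_values).sum(-1)  # (B, L) frame estimates

sec_error = masked_l1_loss(est_from_classes, tar_lengths) * 256 / 24000  # frames -> seconds
print(tar_length_labels, sec_error)
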
else: - raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - + raise ValueError( + f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}" + ) + # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices warmup_steps = ( @@ -401,10 +441,16 @@ class Trainer: # otherwise by default with split_batches=False, warmup steps change with num_processes total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps decay_steps = total_steps - warmup_steps - warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) - decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) + warmup_scheduler = LinearLR( + self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps + ) + decay_scheduler = LinearLR( + self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps + ) self.scheduler = SequentialLR( - self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps] + self.optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[warmup_steps], ) train_dataloader, self.scheduler = self.accelerator.prepare( train_dataloader, self.scheduler @@ -418,7 +464,9 @@ class Trainer: orig_epoch_step = len(train_dataloader) skipped_epoch = int(start_step // orig_epoch_step) skipped_batch = start_step % orig_epoch_step - skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) + skipped_dataloader = self.accelerator.skip_first_batches( + train_dataloader, num_batches=skipped_batch + ) else: skipped_epoch = 0 @@ -444,21 +492,23 @@ class Trainer: for batch in progress_bar: with self.accelerator.accumulate(self.model): # Inputs - mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D) - text = batch['text'] + mel = batch["mel"].permute(0, 2, 1) # (B, L_mel, D) + text = batch["text"] if self.process_token_to_id: - text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device) - text_ids = text_ids.masked_fill(text_ids==-1, self.vocab_size) + text_ids = list_str_to_idx(text, self.vocab_char_map).to( + mel.device + ) + text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size) else: text_ids = text # Targets - mel_lengths = batch['mel_lengths'] + mel_lengths = batch["mel_lengths"] tar_lengths = calculate_remaining_lengths(mel_lengths) predictions = self.model(text_ids=text_ids, mel=mel) - if self.loss_fn == 'L1': + if self.loss_fn == "L1": est_lengths = predictions loss = masked_l1_loss( est_lengths=est_lengths, tar_lengths=tar_lengths @@ -469,19 +519,23 @@ class Trainer: sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_L1': loss.item(), - 'sec_error': sec_error.item(), - 'lr': self.scheduler.get_last_lr()[0] - } - - elif self.loss_fn == 'CE': - tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + "loss": loss.item(), + "loss_L1": loss.item(), + "sec_error": sec_error.item(), + "lr": self.scheduler.get_last_lr()[0], + } + + elif self.loss_fn == "CE": + tar_length_labels = ( + tar_lengths // self.n_frame_per_class + ).clamp( + min=0, max=self.n_class - 1 + ) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = masked_cross_entropy_loss( - 
est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels + est_length_logits=est_length_logtis, + tar_length_labels=tar_length_labels, ) with torch.no_grad(): est_lengths = est_length_labels * self.n_frame_per_class @@ -491,29 +545,36 @@ class Trainer: sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_CE': loss.item(), - 'sec_error': sec_error.item(), - 'lr': self.scheduler.get_last_lr()[0] - } - - elif self.loss_fn == 'L1_and_CE': - tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + "loss": loss.item(), + "loss_CE": loss.item(), + "sec_error": sec_error.item(), + "lr": self.scheduler.get_last_lr()[0], + } + + elif self.loss_fn == "L1_and_CE": + tar_length_labels = ( + tar_lengths // self.n_frame_per_class + ).clamp( + min=0, max=self.n_class - 1 + ) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_1hots = F.gumbel_softmax( est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1 ) - length_values = torch.arange( - self.n_class, device=est_length_1hots.device - ).float() * self.n_frame_per_class + length_values = ( + torch.arange( + self.n_class, device=est_length_1hots.device + ).float() + * self.n_frame_per_class + ) est_lengths = (est_length_1hots * length_values).sum(-1) loss_CE = masked_cross_entropy_loss( - est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels + est_length_logits=est_length_logtis, + tar_length_labels=tar_length_labels, ) - loss_L1 = masked_l1_loss( + loss_L1 = masked_l1_loss( est_lengths=est_lengths, tar_lengths=tar_lengths ) @@ -524,21 +585,22 @@ class Trainer: sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_L1': loss_L1.item(), - 'loss_CE': loss_CE.item(), - 'sec_error': sec_error.item(), - 'lr': self.scheduler.get_last_lr()[0] + "loss": loss.item(), + "loss_L1": loss_L1.item(), + "loss_CE": loss_CE.item(), + "sec_error": sec_error.item(), + "lr": self.scheduler.get_last_lr()[0], } else: raise NotImplementedError(self.loss_fn) - self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) self.optimizer.step() self.scheduler.step() @@ -550,7 +612,10 @@ class Trainer: self.accelerator.log(log_dict, step=global_step) progress_bar.set_postfix(step=str(global_step), loss=loss.item()) - if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: + if ( + global_step % (self.save_per_updates * self.grad_accumulation_steps) + == 0 + ): self.save_checkpoint(global_step) # if self.log_samples and self.accelerator.is_local_main_process: # Run validation at the end of each epoch (only on the main process) diff --git a/duration_trainer_with_prompt.py b/duration_trainer_with_prompt.py index cb964575c3d5114efde1fcc8868053e7cba3b714..bcaa1fe9c2e9a4ca2bb99b3bf8f7adf0e1862e66 100644 --- a/duration_trainer_with_prompt.py +++ b/duration_trainer_with_prompt.py @@ -1,11 +1,11 @@ from __future__ import annotations import gc -import os - import math +import os import torch +import torch.nn.functional as F import torchaudio import wandb from accelerate import Accelerator @@ -13,25 +13,17 @@ from accelerate.utils import DistributedDataParallelKwargs from ema_pytorch import EMA from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR -from 
torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset # <-- Added Subset import +from torch.utils.data import Dataset # <-- Added Subset import +from torch.utils.data import DataLoader, SequentialSampler, Subset from tqdm import tqdm -import torch.nn.functional as F - from f5_tts.model import CFM -from f5_tts.model.dataset import collate_fn, DynamicBatchSampler -from f5_tts.model.utils import default, exists +from f5_tts.model.dataset import DynamicBatchSampler, collate_fn +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) # trainer -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) SAMPLE_RATE = 24_000 @@ -43,7 +35,7 @@ class Trainer: vocab_size, vocab_char_map, process_token_to_id=True, - loss_fn='L1', + loss_fn="L1", lambda_L1=1, gumbel_tau=0.5, n_class=301, @@ -83,7 +75,13 @@ class Trainer: self.logger = logger if self.logger == "wandb": if exists(wandb_resume_id): - init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} + init_kwargs = { + "wandb": { + "resume": "allow", + "name": wandb_run_name, + "id": wandb_resume_id, + } + } else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} @@ -112,7 +110,7 @@ class Trainer: self.vocab_size = vocab_size self.vocab_char_map = vocab_char_map self.process_token_to_id = process_token_to_id - assert loss_fn in ['L1', 'CE', 'L1_and_CE'] + assert loss_fn in ["L1", "CE", "L1_and_CE"] self.loss_fn = loss_fn self.lambda_L1 = lambda_L1 self.n_class = n_class @@ -122,7 +120,9 @@ class Trainer: self.epochs = epochs self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates - self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) + self.last_per_steps = default( + last_per_steps, save_per_updates * grad_accumulation_steps + ) self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") self.batch_size = batch_size @@ -137,33 +137,44 @@ class Trainer: self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) - self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) - + self.model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) + @property def is_main(self): return self.accelerator.is_main_process def save_checkpoint(self, step, last=False): self.accelerator.wait_for_everyone() - if self.is_main: + if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), - optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), + optimizer_state_dict=self.accelerator.unwrap_model( + self.optimizer + ).state_dict(), scheduler_state_dict=self.scheduler.state_dict(), step=step, ) if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_last.pt" + ) else: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_{step}.pt" + ) def load_checkpoint(self): if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith(".pt") for filename in 
os.listdir(self.checkpoint_path)) + or not any( + filename.endswith(".pt") + for filename in os.listdir(self.checkpoint_path) + ) ): return 0 @@ -176,21 +187,32 @@ class Trainer: key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] - print(f'To load from {latest_checkpoint}.') + print(f"To load from {latest_checkpoint}.") # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") + checkpoint = torch.load( + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", + ) - print(f'Loaded from {latest_checkpoint}.') + print(f"Loaded from {latest_checkpoint}.") if "step" in checkpoint: # patch for backward compatibility, 305e3ea - for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) - self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) + self.accelerator.unwrap_model(self.optimizer).load_state_dict( + checkpoint["optimizer_state_dict"] + ) if self.scheduler: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) step = checkpoint["step"] @@ -200,17 +222,18 @@ class Trainer: for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "step"] } - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) step = 0 - + del checkpoint gc.collect() - print(f'Exit load_checkpoint.') + print(f"Exit load_checkpoint.") return step - def validate(self, valid_dataloader, global_step): """ Runs evaluation on the validation set, computes the average loss, @@ -226,31 +249,40 @@ class Trainer: for batch in valid_dataloader: # Inputs - prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D) - prompt_text = batch['pmt_text'] - text = batch['text'] + prompt_mel = batch["pmt_mel_specs"].permute(0, 2, 1) # (B, L_mel, D) + prompt_text = batch["pmt_text"] + text = batch["text"] - target_ids = list_str_to_idx(text, self.vocab_char_map).to(prompt_mel.device) - target_ids = target_ids.masked_fill(target_ids==-1, vocab_size) + target_ids = list_str_to_idx(text, self.vocab_char_map).to( + prompt_mel.device + ) + target_ids = target_ids.masked_fill(target_ids == -1, vocab_size) - prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(prompt_mel.device) - prompt_ids = prompt_ids.masked_fill(prompt_ids==-1, vocab_size) + prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to( + prompt_mel.device + ) + prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, vocab_size) # Targets - tar_lengths = batch['mel_lengths'] + tar_lengths = batch["mel_lengths"] # Forward - predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C) - - if self.loss_fn == 'CE': - tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + predictions = SLP( + target_ids=target_ids, 
prompt_ids=prompt_ids, prompt_mel=prompt_mel + ) # (B, C) + + if self.loss_fn == "CE": + tar_length_labels = (tar_lengths // self.n_frame_per_class).clamp( + min=0, max=self.n_class - 1 + ) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = F.cross_entropy(est_length_logtis, tar_length_labels) - + est_lengths = est_length_labels * self.n_frame_per_class - frame_error = (est_lengths.float() - tar_lengths.float()).abs().mean() + frame_error = ( + (est_lengths.float() - tar_lengths.float()).abs().mean() + ) sec_error = frame_error * 256 / 24000 total_sec_error += sec_error.item() @@ -262,18 +294,19 @@ class Trainer: # Log validation metrics self.accelerator.log( - { - f"valid_loss": avg_valid_loss, - f"valid_sec_error": avg_valid_sec_error - }, - step=global_step + {f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error}, + step=global_step, ) - - self.model.train() + self.model.train() - def train(self, train_dataset: Dataset, valid_dataset: Dataset, - num_workers=64, resumable_with_seed: int = None): + def train( + self, + train_dataset: Dataset, + valid_dataset: Dataset, + num_workers=64, + resumable_with_seed: int = None, + ): if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) @@ -307,7 +340,11 @@ class Trainer: sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) train_dataloader = DataLoader( train_dataset, @@ -320,20 +357,26 @@ class Trainer: sampler = SequentialSampler(valid_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) # Create validation dataloader (always sequential, no shuffling) valid_dataloader = DataLoader( valid_dataset, collate_fn=collate_fn, num_workers=num_workers, - pin_memory=True, + pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) else: - raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - + raise ValueError( + f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}" + ) + # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices warmup_steps = ( @@ -342,10 +385,16 @@ class Trainer: # otherwise by default with split_batches=False, warmup steps change with num_processes total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps decay_steps = total_steps - warmup_steps - warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) - decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) + warmup_scheduler = LinearLR( + self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps + ) + decay_scheduler = LinearLR( + self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps + ) self.scheduler = SequentialLR( - self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps] + self.optimizer, + 
schedulers=[warmup_scheduler, decay_scheduler], + milestones=[warmup_steps], ) train_dataloader, self.scheduler = self.accelerator.prepare( train_dataloader, self.scheduler @@ -359,7 +408,9 @@ class Trainer: orig_epoch_step = len(train_dataloader) skipped_epoch = int(start_step // orig_epoch_step) skipped_batch = start_step % orig_epoch_step - skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) + skipped_dataloader = self.accelerator.skip_first_batches( + train_dataloader, num_batches=skipped_batch + ) else: skipped_epoch = 0 @@ -385,49 +436,65 @@ class Trainer: for batch in progress_bar: with self.accelerator.accumulate(self.model): # Inputs - prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D) - prompt_text = batch['pmt_text'] - text = batch['text'] - - target_ids = list_str_to_idx(text, self.vocab_char_map).to(prompt_mel.device) - target_ids = target_ids.masked_fill(target_ids==-1, vocab_size) - - prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(prompt_mel.device) - prompt_ids = prompt_ids.masked_fill(prompt_ids==-1, vocab_size) + prompt_mel = batch["pmt_mel_specs"].permute( + 0, 2, 1 + ) # (B, L_mel, D) + prompt_text = batch["pmt_text"] + text = batch["text"] + + target_ids = list_str_to_idx(text, self.vocab_char_map).to( + prompt_mel.device + ) + target_ids = target_ids.masked_fill(target_ids == -1, vocab_size) + + prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to( + prompt_mel.device + ) + prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, vocab_size) # Targets - tar_lengths = batch['mel_lengths'] + tar_lengths = batch["mel_lengths"] # Forward - predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C) - - if self.loss_fn == 'CE': - tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + predictions = SLP( + target_ids=target_ids, + prompt_ids=prompt_ids, + prompt_mel=prompt_mel, + ) # (B, C) + + if self.loss_fn == "CE": + tar_length_labels = ( + tar_lengths // self.n_frame_per_class + ).clamp( + min=0, max=self.n_class - 1 + ) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = F.cross_entropy(est_length_logtis, tar_length_labels) - + with torch.no_grad(): est_lengths = est_length_labels * self.n_frame_per_class - frame_error = (est_lengths.float() - tar_lengths.float()).abs().mean() + frame_error = ( + (est_lengths.float() - tar_lengths.float()).abs().mean() + ) sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_CE': loss.item(), - 'sec_error': sec_error.item(), - 'lr': self.scheduler.get_last_lr()[0] - } + "loss": loss.item(), + "loss_CE": loss.item(), + "sec_error": sec_error.item(), + "lr": self.scheduler.get_last_lr()[0], + } else: raise NotImplementedError(self.loss_fn) - self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) self.optimizer.step() self.scheduler.step() @@ -439,7 +506,10 @@ class Trainer: self.accelerator.log(log_dict, step=global_step) progress_bar.set_postfix(step=str(global_step), loss=loss.item()) - if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: + if ( + global_step % (self.save_per_updates * 
self.grad_accumulation_steps) + == 0 + ): self.save_checkpoint(global_step) # if self.log_samples and self.accelerator.is_local_main_process: # Run validation at the end of each epoch (only on the main process) diff --git a/ecapa_tdnn.py b/ecapa_tdnn.py index b55aaf2fcc29117a8ad2341a2262bf79839c3100..15583583bc9d7cc240413c8e18544f706e355ff6 100644 --- a/ecapa_tdnn.py +++ b/ecapa_tdnn.py @@ -1,23 +1,34 @@ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN +# from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool +from pathlib import Path + import torch import torch.nn as nn import torch.nn.functional as F import torchaudio.transforms as trans + from ctcmodel import ConformerCTC -# from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool -from pathlib import Path -''' Res2Conv1d + BatchNorm1d + ReLU -''' +""" Res2Conv1d + BatchNorm1d + ReLU +""" class Res2Conv1dReluBn(nn.Module): - ''' + """ in_channels == out_channels == channels - ''' - - def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): + """ + + def __init__( + self, + channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + scale=4, + ): super().__init__() assert channels % scale == 0, "{} % {} != 0".format(channels, scale) self.scale = scale @@ -27,7 +38,17 @@ class Res2Conv1dReluBn(nn.Module): self.convs = [] self.bns = [] for i in range(self.nums): - self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)) + self.convs.append( + nn.Conv1d( + self.width, + self.width, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + ) self.bns.append(nn.BatchNorm1d(self.width)) self.convs = nn.ModuleList(self.convs) self.bns = nn.ModuleList(self.bns) @@ -51,22 +72,33 @@ class Res2Conv1dReluBn(nn.Module): return out -''' Conv1d + BatchNorm1d + ReLU -''' +""" Conv1d + BatchNorm1d + ReLU +""" class Conv1dReluBn(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + ): super().__init__() - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) self.bn = nn.BatchNorm1d(out_channels) def forward(self, x): return self.bn(F.relu(self.conv(x))) -''' The SE connection of 1D case. -''' +""" The SE connection of 1D case. +""" class SE_Connect(nn.Module): @@ -84,15 +116,32 @@ class SE_Connect(nn.Module): return out -''' SE-Res2Block of the ECAPA-TDNN architecture. -''' +""" SE-Res2Block of the ECAPA-TDNN architecture. 
+""" + class SE_Res2Block(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + scale, + se_bottleneck_dim, + ): super().__init__() - self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) - self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.Conv1dReluBn1 = Conv1dReluBn( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.Res2Conv1dReluBn = Res2Conv1dReluBn( + out_channels, kernel_size, stride, padding, dilation, scale=scale + ) + self.Conv1dReluBn2 = Conv1dReluBn( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) self.shortcut = None @@ -116,8 +165,9 @@ class SE_Res2Block(nn.Module): return x + residual -''' Attentive weighted mean and standard deviation pooling. -''' +""" Attentive weighted mean and standard deviation pooling. +""" + class AttentiveStatsPool(nn.Module): def __init__(self, in_dim, attention_channels=128, global_context_att=False): @@ -126,16 +176,24 @@ class AttentiveStatsPool(nn.Module): # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. if global_context_att: - self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper + self.linear1 = nn.Conv1d( + in_dim * 3, attention_channels, kernel_size=1 + ) # equals W and b in the paper else: - self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper - self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper + self.linear1 = nn.Conv1d( + in_dim, attention_channels, kernel_size=1 + ) # equals W and b in the paper + self.linear2 = nn.Conv1d( + attention_channels, in_dim, kernel_size=1 + ) # equals V and k in the paper def forward(self, x): if self.global_context_att: context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) - context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10 + ).expand_as(x) x_in = torch.cat((x, context_mean, context_std), dim=1) else: x_in = x @@ -145,42 +203,52 @@ class AttentiveStatsPool(nn.Module): # alpha = F.relu(self.linear1(x_in)) alpha = torch.softmax(self.linear2(alpha), dim=2) mean = torch.sum(alpha * x, dim=2) - residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 + residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 std = torch.sqrt(residuals.clamp(min=1e-9)) return torch.cat([mean, std], dim=1) class ECAPA_TDNN(nn.Module): - def __init__(self, channels=512, emb_dim=512, - global_context_att=False, use_fp16=True, + def __init__( + self, + channels=512, + emb_dim=512, + global_context_att=False, + use_fp16=True, ctc_cls=ConformerCTC, - ctc_path='/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt', - ctc_args={'vocab_size': 2545, 'mel_dim': 100, 'num_heads': 8, 'd_hid': 512, 'nlayers': 6}, - ctc_no_grad=False + ctc_path="/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt", + ctc_args={ + "vocab_size": 2545, + "mel_dim": 100, + "num_heads": 8, + "d_hid": 512, + "nlayers": 6, + }, + ctc_no_grad=False, ): 
super().__init__() if ctc_path != None: ctc_path = Path(ctc_path) model = ctc_cls(**ctc_args) - state_dict = torch.load(ctc_path, map_location='cpu') - model.load_state_dict(state_dict['model_state_dict']) + state_dict = torch.load(ctc_path, map_location="cpu") + model.load_state_dict(state_dict["model_state_dict"]) print(f"Initialized pretrained ConformerCTC backbone from {ctc_path}.") else: raise ValueError(ctc_path) self.ctc_model = model self.ctc_model.out.requires_grad_(False) - + if ctc_cls == ConformerCTC: - self.feat_num = ctc_args['nlayers'] + 2 + 1 + self.feat_num = ctc_args["nlayers"] + 2 + 1 # elif ctc_cls == ConformerCTCNoPool: # self.feat_num = ctc_args['nlayers'] + 1 else: raise ValueError(ctc_cls) - feat_dim = ctc_args['d_hid'] + feat_dim = ctc_args["d_hid"] self.emb_dim = emb_dim - + self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) self.instance_norm = nn.InstanceNorm1d(feat_dim) @@ -188,14 +256,45 @@ class ECAPA_TDNN(nn.Module): self.channels = [channels] * 4 + [1536] self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) - self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128) - self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128) - self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128) + self.layer2 = SE_Res2Block( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=1, + padding=2, + dilation=2, + scale=8, + se_bottleneck_dim=128, + ) + self.layer3 = SE_Res2Block( + self.channels[1], + self.channels[2], + kernel_size=3, + stride=1, + padding=3, + dilation=3, + scale=8, + se_bottleneck_dim=128, + ) + self.layer4 = SE_Res2Block( + self.channels[2], + self.channels[3], + kernel_size=3, + stride=1, + padding=4, + dilation=4, + scale=8, + se_bottleneck_dim=128, + ) # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) cat_channels = channels * 3 self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) - self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att) + self.pooling = AttentiveStatsPool( + self.channels[-1], + attention_channels=128, + global_context_att=global_context_att, + ) self.bn = nn.BatchNorm1d(self.channels[-1] * 2) self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) @@ -206,21 +305,26 @@ class ECAPA_TDNN(nn.Module): else: self.ctc_model = self.ctc_model.train() self.ctc_no_grad = ctc_no_grad - print('ctc_no_grad: ', self.ctc_no_grad) + print("ctc_no_grad: ", self.ctc_no_grad) - def forward(self, latent, input_lengths, return_asr=False): + def forward(self, latent, input_lengths, return_asr=False): if self.ctc_no_grad: with torch.no_grad(): asr, h = self.ctc_model(latent, input_lengths) else: asr, h = self.ctc_model(latent, input_lengths) - + x = torch.stack(h, dim=0) - norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + norm_weights = ( + F.softmax(self.feature_weight, dim=-1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + ) x = (norm_weights * x).sum(dim=0) x = x + 1e-6 # x = torch.transpose(x, 1, 2) + 1e-6 - + x = self.instance_norm(x) # x = torch.transpose(x, 1, 2) @@ -238,9 +342,10 @@ class ECAPA_TDNN(nn.Module): return out, asr return out + if __name__ == "__main__": - from 
diffspeech.ldm.model import DiT from diffspeech.data.collate import get_mask_from_lengths + from diffspeech.ldm.model import DiT from diffspeech.tools.text.vocab import IPA bsz = 3 @@ -265,4 +370,4 @@ if __name__ == "__main__": emb = model(latent, latent_mask.sum(axis=-1)) - print(emb.shape) \ No newline at end of file + print(emb.shape) diff --git a/f5_tts/api.py b/f5_tts/api.py index d73ee1be113643feff86938a601065813016f98e..2702a41b7ab294cbeca38aa01205c8ab60e5baca 100644 --- a/f5_tts/api.py +++ b/f5_tts/api.py @@ -8,15 +8,10 @@ from cached_path import cached_path from hydra.utils import get_class from omegaconf import OmegaConf -from f5_tts.infer.utils_infer import ( - infer_process, - load_model, - load_vocoder, - preprocess_ref_audio_text, - remove_silence_for_generated_wav, - save_spectrogram, - transcribe, -) +from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder, + preprocess_ref_audio_text, + remove_silence_for_generated_wav, + save_spectrogram, transcribe) from f5_tts.model.utils import seed_everything @@ -32,7 +27,9 @@ class F5TTS: device=None, hf_cache_dir=None, ): - model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) + model_cfg = OmegaConf.load( + str(files("f5_tts").joinpath(f"configs/{model}.yaml")) + ) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch @@ -50,16 +47,20 @@ class F5TTS: self.device = ( "cuda" if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" + else ( + "xpu" + if torch.xpu.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) ) # Load models self.vocoder = load_vocoder( - self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir + self.mel_spec_type, + vocoder_local_path is not None, + vocoder_local_path, + self.device, + hf_cache_dir, ) repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" @@ -77,10 +78,20 @@ class F5TTS: if not ckpt_file: ckpt_file = str( - cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir) + cached_path( + f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", + cache_dir=hf_cache_dir, + ) ) self.ema_model = load_model( - model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device + model_cls, + model_arc, + ckpt_file, + self.mel_spec_type, + vocab_file, + self.ode_method, + self.use_ema, + self.device, ) def transcribe(self, ref_audio, language=None): diff --git a/f5_tts/eval/ecapa_tdnn.py b/f5_tts/eval/ecapa_tdnn.py index f0e4c9cc20c3a9f251fd22f4f27851fd04238962..4d06d1aa0eeaf0c9bc46df2e06c89b6e4db7c478 100644 --- a/f5_tts/eval/ecapa_tdnn.py +++ b/f5_tts/eval/ecapa_tdnn.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - """ Res2Conv1d + BatchNorm1d + ReLU """ @@ -19,7 +18,16 @@ class Res2Conv1dReluBn(nn.Module): in_channels == out_channels == channels """ - def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): + def __init__( + self, + channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + scale=4, + ): super().__init__() assert channels % scale == 0, "{} % {} != 0".format(channels, scale) self.scale = scale @@ -29,7 +37,17 @@ class Res2Conv1dReluBn(nn.Module): self.convs = [] self.bns = [] for i in range(self.nums): - self.convs.append(nn.Conv1d(self.width, 
self.width, kernel_size, stride, padding, dilation, bias=bias)) + self.convs.append( + nn.Conv1d( + self.width, + self.width, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + ) self.bns.append(nn.BatchNorm1d(self.width)) self.convs = nn.ModuleList(self.convs) self.bns = nn.ModuleList(self.bns) @@ -58,9 +76,20 @@ class Res2Conv1dReluBn(nn.Module): class Conv1dReluBn(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + ): super().__init__() - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) self.bn = nn.BatchNorm1d(out_channels) def forward(self, x): @@ -99,11 +128,27 @@ class SE_Connect(nn.Module): class SE_Res2Block(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + scale, + se_bottleneck_dim, + ): super().__init__() - self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) - self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.Conv1dReluBn1 = Conv1dReluBn( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.Res2Conv1dReluBn = Res2Conv1dReluBn( + out_channels, kernel_size, stride, padding, dilation, scale=scale + ) + self.Conv1dReluBn2 = Conv1dReluBn( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) self.shortcut = None @@ -138,15 +183,23 @@ class AttentiveStatsPool(nn.Module): # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. 
if global_context_att: - self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper + self.linear1 = nn.Conv1d( + in_dim * 3, attention_channels, kernel_size=1 + ) # equals W and b in the paper else: - self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper - self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper + self.linear1 = nn.Conv1d( + in_dim, attention_channels, kernel_size=1 + ) # equals W and b in the paper + self.linear2 = nn.Conv1d( + attention_channels, in_dim, kernel_size=1 + ) # equals V and k in the paper def forward(self, x): if self.global_context_att: context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) - context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10 + ).expand_as(x) x_in = torch.cat((x, context_mean, context_std), dim=1) else: x_in = x @@ -184,24 +237,36 @@ class ECAPA_TDNN(nn.Module): torch.hub._validate_not_a_forked_repo = lambda a, b, c: True try: local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main") - self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path) + self.feature_extract = torch.hub.load( + local_s3prl_path, feat_type, source="local", config_path=config_path + ) except: # noqa: E722 self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" ): - self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False + self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = ( + False + ) if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" ): - self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False + self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = ( + False + ) self.feat_num = self.get_feat_num() self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) if feat_type != "fbank" and feat_type != "mfcc": - freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"] + freeze_list = [ + "final_proj", + "label_embs_concat", + "mask_emb", + "project_q", + "quantizer", + ] for name, param in self.feature_extract.named_parameters(): for freeze_val in freeze_list: if freeze_val in name: @@ -252,7 +317,9 @@ class ECAPA_TDNN(nn.Module): cat_channels = channels * 3 self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) self.pooling = AttentiveStatsPool( - self.channels[-1], attention_channels=128, global_context_att=global_context_att + self.channels[-1], + attention_channels=128, + global_context_att=global_context_att, ) self.bn = nn.BatchNorm1d(self.channels[-1] * 2) self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) @@ -287,7 +354,12 @@ class ECAPA_TDNN(nn.Module): x = torch.stack(x, dim=0) else: x = x.unsqueeze(0) - norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + norm_weights = ( + F.softmax(self.feature_weight, dim=-1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + ) x = (norm_weights * x).sum(dim=0) x = torch.transpose(x, 1, 2) + 1e-6 diff --git a/f5_tts/eval/eval_infer_batch.py b/f5_tts/eval/eval_infer_batch.py index 
cea5b7a19fdaad5c60f8ce7fd74b18493fc36780..56bd815b269a4e9c0e5cb4a2ee37d5bce122b24d 100644 --- a/f5_tts/eval/eval_infer_batch.py +++ b/f5_tts/eval/eval_infer_batch.py @@ -1,7 +1,6 @@ import os import sys - sys.path.append(os.getcwd()) import argparse @@ -15,16 +14,13 @@ from hydra.utils import get_class from omegaconf import OmegaConf from tqdm import tqdm -from f5_tts.eval.utils_eval import ( - get_inference_prompt, - get_librispeech_test_clean_metainfo, - get_seedtts_testset_metainfo, -) +from f5_tts.eval.utils_eval import (get_inference_prompt, + get_librispeech_test_clean_metainfo, + get_seedtts_testset_metainfo) from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder from f5_tts.model import CFM from f5_tts.model.utils import get_tokenizer - accelerator = Accelerator() device = f"cuda:{accelerator.process_index}" @@ -67,7 +63,9 @@ def main(): use_truth_duration = False no_ref_audio = False - model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) + model_cfg = OmegaConf.load( + str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")) + ) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch @@ -83,8 +81,12 @@ def main(): if testset == "ls_pc_test_clean": metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" - librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path - metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path) + librispeech_test_clean_path = ( + "/LibriSpeech/test-clean" # test-clean path + ) + metainfo = get_librispeech_test_clean_metainfo( + metalst, librispeech_test_clean_path + ) elif testset == "seedtts_test_zh": metalst = rel_path + "/data/seedtts_testset/zh/meta.lst" @@ -126,14 +128,18 @@ def main(): vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" elif mel_spec_type == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" - vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) + vocoder = load_vocoder( + vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path + ) # Tokenizer vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) # Model model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls( + **model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, @@ -154,7 +160,9 @@ def main(): elif os.path.exists(ckpt_prefix + ".safetensors"): ckpt_path = ckpt_prefix + ".safetensors" else: - print("Loading from self-organized training checkpoints rather than released pretrained.") + print( + "Loading from self-organized training checkpoints rather than released pretrained." 
+ ) ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt" dtype = torch.float32 if mel_spec_type == "bigvgan" else None @@ -169,7 +177,14 @@ def main(): with accelerator.split_between_processes(prompts_all) as prompts: for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt + ( + utts, + ref_rms_list, + ref_mels, + ref_mel_lens, + total_mel_lens, + final_text_list, + ) = prompt ref_mels = ref_mels.to(device) ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) @@ -198,7 +213,11 @@ def main(): if ref_rms_list[i] < target_rms: generated_wave = generated_wave * ref_rms_list[i] / target_rms - torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate) + torchaudio.save( + f"{output_dir}/{utts[i]}.wav", + generated_wave, + target_sample_rate, + ) accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/f5_tts/eval/eval_librispeech_test_clean.py b/f5_tts/eval/eval_librispeech_test_clean.py index 0fef8019855faa5446983674b700e2e9b05c861f..5f973daef739bce711c5b7554a7bac2c1be2791b 100644 --- a/f5_tts/eval/eval_librispeech_test_clean.py +++ b/f5_tts/eval/eval_librispeech_test_clean.py @@ -5,7 +5,6 @@ import json import os import sys - sys.path.append(os.getcwd()) import multiprocessing as mp @@ -15,18 +14,23 @@ import numpy as np from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim - rel_path = str(files("f5_tts").joinpath("../../")) def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) + parser.add_argument( + "-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"] + ) parser.add_argument("-l", "--lang", type=str, default="en") parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True) - parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use") - parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") + parser.add_argument( + "-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use" + ) + parser.add_argument( + "--local", action="store_true", help="Use local custom checkpoint directory" + ) return parser.parse_args() @@ -39,7 +43,9 @@ def main(): metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" gpus = list(range(args.gpu_nums)) - test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) + test_set = get_librispeech_test( + metalst, gen_wav_dir, gpus, librispeech_test_clean_path + ) ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, ## leading to a low similarity for the ground truth in some cases. 
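(Aside, not part of the patch: the evaluation hunks just below split the test set into one (rank, sub_test_set) chunk per GPU and fan the chunks out with multiprocessing.Pool, collecting per-utterance results afterwards. A minimal sketch of that fan-out pattern follows; worker() is a made-up stand-in for run_asr_wer / run_sim, and the chunking only mirrors the (rank, sub_test_set) shape produced by get_librispeech_test / get_seed_tts_test.)

import multiprocessing as mp

def worker(args):
    rank, sub_test_set = args
    # A real worker would pin itself to GPU `rank` and score every utterance in
    # sub_test_set (ASR-based WER or speaker similarity); here we return dummies.
    return [f"rank{rank}:{utt}" for utt in sub_test_set]

if __name__ == "__main__":
    gpus = list(range(4))
    utts = [f"utt{i:03d}" for i in range(32)]
    chunks = [(rank, utts[rank::len(gpus)]) for rank in gpus]  # one chunk per GPU

    full_results = []
    with mp.Pool(processes=len(gpus)) as pool:
        for r in pool.map(worker, chunks):
            full_results.extend(r)
    print(len(full_results))  # 32
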
@@ -59,13 +65,19 @@ def main(): if eval_task == "wer": with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] + args = [ + (rank, lang, sub_test_set, asr_ckpt_dir) + for (rank, sub_test_set) in test_set + ] results = pool.map(run_asr_wer, args) for r in results: full_results.extend(r) elif eval_task == "sim": with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] + args = [ + (rank, sub_test_set, wavlm_ckpt_dir) + for (rank, sub_test_set) in test_set + ] results = pool.map(run_sim, args) for r in results: full_results.extend(r) diff --git a/f5_tts/eval/eval_seedtts_testset.py b/f5_tts/eval/eval_seedtts_testset.py index 158a3dd81bcb99d54ad484e73878b115b37f1c80..69afc5a70a7569ddb72361f9bcfa31cd0f1f27bd 100644 --- a/f5_tts/eval/eval_seedtts_testset.py +++ b/f5_tts/eval/eval_seedtts_testset.py @@ -5,7 +5,6 @@ import json import os import sys - sys.path.append(os.getcwd()) import multiprocessing as mp @@ -15,17 +14,22 @@ import numpy as np from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim - rel_path = str(files("f5_tts").joinpath("../../")) def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) + parser.add_argument( + "-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"] + ) parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"]) parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) - parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use") - parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") + parser.add_argument( + "-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use" + ) + parser.add_argument( + "--local", action="store_true", help="Use local custom checkpoint directory" + ) return parser.parse_args() @@ -58,13 +62,19 @@ def main(): if eval_task == "wer": with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] + args = [ + (rank, lang, sub_test_set, asr_ckpt_dir) + for (rank, sub_test_set) in test_set + ] results = pool.map(run_asr_wer, args) for r in results: full_results.extend(r) elif eval_task == "sim": with mp.Pool(processes=len(gpus)) as pool: - args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] + args = [ + (rank, sub_test_set, wavlm_ckpt_dir) + for (rank, sub_test_set) in test_set + ] results = pool.map(run_sim, args) for r in results: full_results.extend(r) diff --git a/f5_tts/eval/eval_utmos.py b/f5_tts/eval/eval_utmos.py index b6166e8ab073a6134b23936e15e440332991bab2..267569d969cb5888bb955cb111abc39a0a256732 100644 --- a/f5_tts/eval/eval_utmos.py +++ b/f5_tts/eval/eval_utmos.py @@ -13,9 +13,15 @@ def main(): parser.add_argument("--ext", type=str, default="wav", help="Audio extension.") args = parser.parse_args() - device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu" - - predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True) + device = ( + "cuda" + if torch.cuda.is_available() + else "xpu" if torch.xpu.is_available() else "cpu" + ) + + predictor = torch.hub.load( + "tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True + ) predictor = predictor.to(device) audio_paths = 
list(Path(args.audio_dir).rglob(f"*.{args.ext}")) diff --git a/f5_tts/eval/utils_eval.py b/f5_tts/eval/utils_eval.py index c5ac834f357dcaa780bb5ab699ec701f5794b5be..112b04f56bdf41186d4b90cac05ca0b3ca5970e4 100644 --- a/f5_tts/eval/utils_eval.py +++ b/f5_tts/eval/utils_eval.py @@ -43,11 +43,15 @@ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path): # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") - ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") + ref_wav = os.path.join( + librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac" + ) # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") - gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") + gen_wav = os.path.join( + librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac" + ) metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)) @@ -106,13 +110,17 @@ def get_inference_prompt( mel_spec_type=mel_spec_type, ) - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): # Audio ref_audio, ref_sr = torchaudio.load(prompt_wav) ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) if ref_rms < target_rms: ref_audio = ref_audio * target_rms / ref_rms - assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + assert ( + ref_audio.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." if ref_sr != target_sample_rate: resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) ref_audio = resampler(ref_audio) @@ -145,14 +153,18 @@ def get_inference_prompt( else: ref_text_len = len(prompt_text.encode("utf-8")) gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed) + total_mel_len = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) # deal with batch assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert min_tokens <= total_mel_len <= max_tokens, ( - f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]." 
+ bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets ) - bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) utts[bucket_i].append(utt) ref_rms_list[bucket_i].append(ref_rms) @@ -183,7 +195,14 @@ def get_inference_prompt( ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i], - ) = [], [], [], [], [], [] + ) = ( + [], + [], + [], + [], + [], + [], + ) # add residual for bucket_i, bucket_frames in enumerate(batch_accum): @@ -244,7 +263,9 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus): # get librispeech test-clean cross sentence test -def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False): +def get_librispeech_test( + metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False +): f = open(metalst) lines = f.readlines() f.close() @@ -255,14 +276,21 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path if eval_ground_truth: gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") - gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") + gen_wav = os.path.join( + librispeech_test_clean_path, + gen_spk_id, + gen_chaptr_id, + gen_utt + ".flac", + ) else: if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")): raise FileNotFoundError(f"Generated wav not found: {gen_utt}") gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav") ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") - ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") + ref_wav = os.path.join( + librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac" + ) test_set_.append((gen_wav, ref_wav, gen_txt)) @@ -382,7 +410,9 @@ def run_sim(args): device = f"cuda:{rank}" model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None) - state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage) + state_dict = torch.load( + ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage + ) model.load_state_dict(state_dict["model"], strict=False) use_gpu = True if torch.cuda.is_available() else False diff --git a/f5_tts/infer/infer_cli.py b/f5_tts/infer/infer_cli.py index 7d511706cb57a0c4857d68d312785863d6f2f168..4f14d9927f4ad8efce1a517b9c83a3e21cf19bf6 100644 --- a/f5_tts/infer/infer_cli.py +++ b/f5_tts/infer/infer_cli.py @@ -14,23 +14,12 @@ from hydra.utils import get_class from omegaconf import OmegaConf from unidecode import unidecode -from f5_tts.infer.utils_infer import ( - cfg_strength, - cross_fade_duration, - device, - fix_duration, - infer_process, - load_model, - load_vocoder, - mel_spec_type, - nfe_step, - preprocess_ref_audio_text, - remove_silence_for_generated_wav, - speed, - sway_sampling_coef, - target_rms, -) - +from f5_tts.infer.utils_infer import (cfg_strength, cross_fade_duration, + device, fix_duration, infer_process, + load_model, load_vocoder, mel_spec_type, + nfe_step, preprocess_ref_audio_text, + remove_silence_for_generated_wav, speed, + sway_sampling_coef, target_rms) parser = argparse.ArgumentParser( prog="python3 infer-cli.py", @@ -41,7 +30,9 @@ parser.add_argument( "-c", "--config", type=str, - default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), + default=os.path.join( + files("f5_tts").joinpath("infer/examples/basic"), "basic.toml" + ), help="The configuration file, default see 
infer/examples/basic/basic.toml", ) @@ -188,13 +179,17 @@ model = args.model or config.get("model", "F5TTS_v1_Base") ckpt_file = args.ckpt_file or config.get("ckpt_file", "") vocab_file = args.vocab_file or config.get("vocab_file", "") -ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav") +ref_audio = args.ref_audio or config.get( + "ref_audio", "infer/examples/basic/basic_ref_en.wav" +) ref_text = ( args.ref_text if args.ref_text is not None else config.get("ref_text", "Some call me nature, others call me mother nature.") ) -gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.") +gen_text = args.gen_text or config.get( + "gen_text", "Here we generate something just for test." +) gen_file = args.gen_file or config.get("gen_file", "") output_dir = args.output_dir or config.get("output_dir", "tests") @@ -203,21 +198,29 @@ output_file = args.output_file or config.get( ) save_chunk = args.save_chunk or config.get("save_chunk", False) -use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg +use_legacy_text = args.no_legacy_text or config.get( + "no_legacy_text", False +) # no_legacy_text is a store_false arg if save_chunk and use_legacy_text: print( "\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n" ) remove_silence = args.remove_silence or config.get("remove_silence", False) -load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False) +load_vocoder_from_local = args.load_vocoder_from_local or config.get( + "load_vocoder_from_local", False +) vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type) target_rms = args.target_rms or config.get("target_rms", target_rms) -cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration) +cross_fade_duration = args.cross_fade_duration or config.get( + "cross_fade_duration", cross_fade_duration +) nfe_step = args.nfe_step or config.get("nfe_step", nfe_step) cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength) -sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef) +sway_sampling_coef = args.sway_sampling_coef or config.get( + "sway_sampling_coef", sway_sampling_coef +) speed = args.speed or config.get("speed", speed) fix_duration = args.fix_duration or config.get("fix_duration", fix_duration) device = args.device or config.get("device", device) @@ -232,7 +235,9 @@ if "voices" in config: for voice in config["voices"]: voice_ref_audio = config["voices"][voice]["ref_audio"] if "infer/examples/" in voice_ref_audio: - config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}")) + config["voices"][voice]["ref_audio"] = str( + files("f5_tts").joinpath(f"{voice_ref_audio}") + ) # ignore gen_text if gen_file provided @@ -259,14 +264,18 @@ elif vocoder_name == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" vocoder = load_vocoder( - vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device + vocoder_name=vocoder_name, + is_local=load_vocoder_from_local, + local_path=vocoder_local_path, + device=device, ) # load TTS model model_cfg = OmegaConf.load( - args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) + 
args.model_cfg + or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) ) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch @@ -288,11 +297,18 @@ elif model == "E2TTS_Base": ckpt_step = 1200000 if not ckpt_file: - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")) + ckpt_file = str( + cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}") + ) print(f"Using {model}...") ema_model = load_model( - model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device + model_cls, + model_arc, + ckpt_file, + mel_spec_type=vocoder_name, + vocab_file=vocab_file, + device=device, ) @@ -309,8 +325,10 @@ def main(): for voice in voices: print("Voice:", voice) print("ref_audio ", voices[voice]["ref_audio"]) - voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( - voices[voice]["ref_audio"], voices[voice]["ref_text"] + voices[voice]["ref_audio"], voices[voice]["ref_text"] = ( + preprocess_ref_audio_text( + voices[voice]["ref_audio"], voices[voice]["ref_text"] + ) ) print("ref_audio_", voices[voice]["ref_audio"], "\n\n") @@ -360,7 +378,10 @@ def main(): if use_legacy_text: gen_text_ = unidecode(gen_text_) sf.write( - os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"), + os.path.join( + output_chunk_dir, + f"{len(generated_audio_segments) - 1}_{gen_text_}.wav", + ), audio_segment, final_sample_rate, ) diff --git a/f5_tts/infer/infer_gradio.py b/f5_tts/infer/infer_gradio.py index f4c3aefe7e7ddcb8ffa3287a8a6569911498fd32..ec4b691547999a2ed73a46715832c032c3b0ba3a 100644 --- a/f5_tts/infer/infer_gradio.py +++ b/f5_tts/infer/infer_gradio.py @@ -19,7 +19,6 @@ import torchaudio from cached_path import cached_path from transformers import AutoModelForCausalLM, AutoTokenizer - try: import spaces @@ -35,25 +34,21 @@ def gpu_decorator(func): return func -from f5_tts.infer.utils_infer import ( - infer_process, - load_model, - load_vocoder, - preprocess_ref_audio_text, - remove_silence_for_generated_wav, - save_spectrogram, - tempfile_kwargs, -) +from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder, + preprocess_ref_audio_text, + remove_silence_for_generated_wav, + save_spectrogram, tempfile_kwargs) from f5_tts.model import DiT, UNetT - DEFAULT_TTS_MODEL = "F5-TTS_v1" tts_model_choice = DEFAULT_TTS_MODEL DEFAULT_TTS_MODEL_CFG = [ "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors", "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt", - json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), + json.dumps( + dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) + ), ] @@ -69,8 +64,12 @@ def load_f5tts(): def load_e2tts(): - ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) - E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1) + ckpt_path = str( + cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors") + ) + E2TTS_model_cfg = dict( + dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1 + ) return load_model(UNetT, E2TTS_model_cfg, ckpt_path) @@ -113,7 +112,8 @@ def chat_model_inference(messages, model, tokenizer): ) generated_ids = [ - output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + output_ids[len(input_ids) :] + for 
input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -157,7 +157,9 @@ def infer( gr.Warning("Please enter text to generate or upload a text file.") return gr.update(), gr.update(), ref_text - ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info) + ref_audio, ref_text = preprocess_ref_audio_text( + ref_audio_orig, ref_text, show_info=show_info + ) if model == DEFAULT_TTS_MODEL: ema_model = F5TTS_ema_model @@ -172,7 +174,9 @@ def infer( global custom_ema_model, pre_custom_path if pre_custom_path != model[1]: show_info("Loading Custom TTS model...") - custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3]) + custom_ema_model = load_custom( + model[1], vocab_path=model[2], model_cfg=model[3] + ) pre_custom_path = model[1] ema_model = custom_ema_model @@ -202,7 +206,9 @@ def infer( final_wave = final_wave.squeeze().cpu().numpy() # Save the spectrogram - with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram: + with tempfile.NamedTemporaryFile( + suffix=".png", **tempfile_kwargs + ) as tmp_spectrogram: spectrogram_path = tmp_spectrogram.name save_spectrogram(combined_spectrogram, spectrogram_path) @@ -219,7 +225,9 @@ with gr.Blocks() as app_tts: max_lines=40, scale=4, ) - gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1) + gen_text_file = gr.File( + label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1 + ) generate_btn = gr.Button("Synthesize", variant="primary") with gr.Accordion("Advanced Settings", open=False): with gr.Row(): @@ -229,7 +237,11 @@ with gr.Blocks() as app_tts: lines=2, scale=4, ) - ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1) + ref_text_file = gr.File( + label="Load Reference Text from File (.txt)", + file_types=[".txt"], + scale=1, + ) with gr.Row(): randomize_seed = gr.Checkbox( label="Randomize Seed", @@ -417,13 +429,25 @@ with gr.Blocks() as app_multistyle: regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4) with gr.Row(): regular_seed_slider = gr.Slider( - show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random" + show_label=False, + minimum=-1, + maximum=999, + value=-1, + step=1, + info="Seed, -1 for random", ) regular_speed_slider = gr.Slider( - show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed" + show_label=False, + minimum=0.3, + maximum=2.0, + value=1.0, + step=0.1, + info="Adjust the speed", ) with gr.Column(scale=1, min_width=160): - regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"]) + regular_ref_text_file = gr.File( + label="Load Reference Text from File (.txt)", file_types=[".txt"] + ) # Regular speech type (max 100) max_speech_types = 100 @@ -450,13 +474,25 @@ with gr.Blocks() as app_multistyle: ref_text_input = gr.Textbox(label="Reference Text", lines=4) with gr.Row(): seed_input = gr.Slider( - show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random" + show_label=False, + minimum=-1, + maximum=999, + value=-1, + step=1, + info="Seed. 
-1 for random", ) speed_input = gr.Slider( - show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed" + show_label=False, + minimum=0.3, + maximum=2.0, + value=1.0, + step=0.1, + info="Adjust the speed", ) with gr.Column(scale=1, min_width=160): - ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"]) + ref_text_file_input = gr.File( + label="Load Reference Text from File (.txt)", file_types=[".txt"] + ) speech_type_rows.append(row) speech_type_names.append(name_input) speech_type_audios.append(audio_input) @@ -494,7 +530,9 @@ with gr.Blocks() as app_multistyle: row_updates[speech_type_count] = gr.update(visible=True) speech_type_count += 1 else: - gr.Warning("Exhausted maximum number of speech types. Consider restart the app.") + gr.Warning( + "Exhausted maximum number of speech types. Consider restart the app." + ) return row_updates add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows) @@ -525,10 +563,14 @@ with gr.Blocks() as app_multistyle: scale=4, placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!", ) - gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1) + gen_text_file_multistyle = gr.File( + label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1 + ) def make_insert_speech_type_fn(index): - def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed): + def insert_speech_type_fn( + current_text, speech_type_name, speech_type_seed, speech_type_speed + ): current_text = current_text or "" if not speech_type_name: gr.Warning("Please enter speech type name before insert.") @@ -547,7 +589,12 @@ with gr.Blocks() as app_multistyle: insert_fn = make_insert_speech_type_fn(i) insert_btn.click( insert_fn, - inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]], + inputs=[ + gen_text_input_multistyle, + speech_type_names[i], + speech_type_seeds[i], + speech_type_speeds[i], + ], outputs=gen_text_input_multistyle, ) @@ -567,7 +614,9 @@ with gr.Blocks() as app_multistyle: ) # Generate button - generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary") + generate_multistyle_btn = gr.Button( + "Generate Multi-Style Speech", variant="primary" + ) # Output audio audio_output_multistyle = gr.Audio(label="Synthesized Audio") @@ -613,7 +662,10 @@ with gr.Blocks() as app_multistyle: speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list ): if name_input and audio_input: - speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input} + speech_types[name_input] = { + "audio": audio_input, + "ref_text": ref_text_input, + } else: speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""} ref_text_idx += 1 @@ -635,14 +687,22 @@ with gr.Blocks() as app_multistyle: if name in speech_types: current_type_name = name else: - gr.Warning(f"Type {name} is not available, will use Regular as default.") + gr.Warning( + f"Type {name} is not available, will use Regular as default." 
+ ) current_type_name = "Regular" try: ref_audio = speech_types[current_type_name]["audio"] except KeyError: - gr.Warning(f"Please provide reference audio for type {current_type_name}.") - return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None] + gr.Warning( + f"Please provide reference audio for type {current_type_name}." + ) + return ( + [None] + + [speech_types[name]["ref_text"] for name in speech_types] + + [None] + ) ref_text = speech_types[current_type_name].get("ref_text", "") if seed_input == -1: @@ -664,7 +724,9 @@ with gr.Blocks() as app_multistyle: generated_audio_segments.append(audio_data) speech_types[current_type_name]["ref_text"] = ref_text_out - inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n" + inference_meta_data += ( + json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n" + ) # Concatenate all audio segments if generated_audio_segments: @@ -676,7 +738,11 @@ with gr.Blocks() as app_multistyle: ) else: gr.Warning("No audio generated.") - return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None] + return ( + [None] + + [speech_types[name]["ref_text"] for name in speech_types] + + [None] + ) generate_multistyle_btn.click( generate_multistyle_speech, @@ -689,7 +755,9 @@ with gr.Blocks() as app_multistyle: + [ remove_silence_multistyle, ], - outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle], + outputs=[audio_output_multistyle] + + speech_type_ref_texts + + [cherrypick_interface_multistyle], ) # Validation function to disable Generate button if speech types are missing @@ -753,7 +821,9 @@ Have a conversation with an AI using your reference voice! torch.cuda.empty_cache() show_info(f"Loading chat model: {chat_model_name}") - chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto") + chat_model_state = AutoModelForCausalLM.from_pretrained( + chat_model_name, torch_dtype="auto", device_map="auto" + ) chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name) show_info(f"Chat model {chat_model_name} loaded successfully!") @@ -769,7 +839,9 @@ Have a conversation with an AI using your reference voice! info="Enter the name of a HuggingFace chat model", allow_custom_value=not USING_SPACES, ) - load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES) + load_chat_model_btn = gr.Button( + "Load Chat Model", variant="primary", visible=not USING_SPACES + ) chat_interface_container = gr.Column(visible=USING_SPACES) chat_model_name_input.change( @@ -779,7 +851,9 @@ Have a conversation with an AI using your reference voice! show_progress="hidden", ) load_chat_model_btn.click( - load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container] + load_chat_model, + inputs=[chat_model_name_input], + outputs=[load_chat_model_btn, chat_interface_container], ) with chat_interface_container: @@ -796,7 +870,9 @@ Have a conversation with an AI using your reference voice! scale=3, ) ref_text_file_chat = gr.File( - label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1 + label="Load Reference Text from File (.txt)", + file_types=[".txt"], + scale=1, ) with gr.Row(): randomize_seed_chat = gr.Checkbox( @@ -805,7 +881,9 @@ Have a conversation with an AI using your reference voice! 
info="Uncheck to use the seed specified.", scale=3, ) - seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1) + seed_input_chat = gr.Number( + show_label=False, value=0, precision=0, scale=1 + ) remove_silence_chat = gr.Checkbox( label="Remove Silences", value=True, @@ -855,13 +933,17 @@ Have a conversation with an AI using your reference voice! """Generate text response from AI""" system_prompt_state = [{"role": "system", "content": system_prompt}] - response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state) + response = chat_model_inference( + system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state + ) conv_state.append({"role": "assistant", "content": response}) return conv_state @gpu_decorator - def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input): + def generate_audio_response( + conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input + ): """Generate TTS audio for AI response""" if not conv_state or not ref_audio: return None, ref_text, seed_input @@ -896,7 +978,11 @@ Have a conversation with an AI using your reference voice! outputs=[ref_text_chat], ) - for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]: + for user_operation in [ + audio_input_chat.stop_recording, + text_input_chat.submit, + send_btn_chat.click, + ]: user_operation( process_audio_input, inputs=[chatbot_interface, audio_input_chat, text_input_chat], @@ -923,7 +1009,11 @@ Have a conversation with an AI using your reference voice! ) # Handle clear button or system prompt change and reset conversation - for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]: + for user_operation in [ + clear_btn_chat.click, + system_prompt_chat.change, + chatbot_interface.clear, + ]: user_operation( clear_conversation, outputs=[chatbot_interface, audio_output_chat], @@ -931,13 +1021,15 @@ Have a conversation with an AI using your reference voice! 
with gr.Blocks() as app_credits: - gr.Markdown(""" + gr.Markdown( + """ # Credits * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat -""") +""" + ) with gr.Blocks() as app: @@ -958,7 +1050,9 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip """ ) - last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt") + last_used_custom = files("f5_tts").joinpath( + "infer/.cache/last_used_custom_model_info_v1.txt" + ) def load_last_used_custom(): try: @@ -974,8 +1068,15 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip def switch_tts_model(new_choice): global tts_model_choice if new_choice == "Custom": # override in case webpage is refreshed - custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom() - tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg) + custom_ckpt_path, custom_vocab_path, custom_model_cfg = ( + load_last_used_custom() + ) + tts_model_choice = ( + "Custom", + custom_ckpt_path, + custom_vocab_path, + custom_model_cfg, + ) return ( gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path), @@ -983,22 +1084,42 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip ) else: tts_model_choice = new_choice - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) + return ( + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + ) def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice - tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg) + tts_model_choice = ( + "Custom", + custom_ckpt_path, + custom_vocab_path, + custom_model_cfg, + ) with open(last_used_custom, "w", encoding="utf-8") as f: - f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n") + f.write( + custom_ckpt_path + + "\n" + + custom_vocab_path + + "\n" + + custom_model_cfg + + "\n" + ) with gr.Row(): if not USING_SPACES: choose_tts_model = gr.Radio( - choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL + choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], + label="Choose TTS Model", + value=DEFAULT_TTS_MODEL, ) else: choose_tts_model = gr.Radio( - choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL + choices=[DEFAULT_TTS_MODEL, "E2-TTS"], + label="Choose TTS Model", + value=DEFAULT_TTS_MODEL, ) custom_ckpt_path = gr.Dropdown( choices=[DEFAULT_TTS_MODEL_CFG[0]], diff --git a/f5_tts/infer/speech_edit.py b/f5_tts/infer/speech_edit.py index fdeda9fee6ab57a13a43b06c7c802227222791e3..57db29aa07e1bae04db866c9e3577ecde3d25002 100644 --- a/f5_tts/infer/speech_edit.py +++ b/f5_tts/infer/speech_edit.py @@ -1,6 +1,5 @@ import os - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility from importlib.resources import files @@ -12,19 +11,19 @@ from cached_path import cached_path from hydra.utils import get_class from omegaconf import OmegaConf -from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram +from f5_tts.infer.utils_infer import 
(load_checkpoint, load_vocoder, + save_spectrogram) from f5_tts.model import CFM from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer - device = ( "cuda" if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" + else ( + "xpu" + if torch.xpu.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) ) @@ -59,7 +58,9 @@ n_fft = model_cfg.model.mel_spec.n_fft # ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors" -ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors")) +ckpt_path = str( + cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors") +) output_dir = "tests" @@ -103,14 +104,18 @@ if mel_spec_type == "vocos": vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" elif mel_spec_type == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" -vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) +vocoder = load_vocoder( + vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path +) # Tokenizer vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) # Model model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls( + **model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, @@ -146,7 +151,14 @@ for part in parts_to_edit: part_dur = end - start if fix_duration is None else fix_duration.pop(0) part_dur = part_dur * target_sample_rate start = start * target_sample_rate - audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1) + audio_ = torch.cat( + ( + audio_, + audio[:, round(offset) : round(start)], + torch.zeros(1, round(part_dur)), + ), + dim=-1, + ) edit_mask = torch.cat( ( edit_mask, @@ -157,7 +169,9 @@ for part in parts_to_edit: ) offset = end * target_sample_rate audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1) -edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True) +edit_mask = F.pad( + edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True +) audio = audio.to(device) edit_mask = edit_mask.to(device) @@ -201,5 +215,7 @@ with torch.inference_mode(): generated_wave = generated_wave * rms / target_rms save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png") - torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate) + torchaudio.save( + f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate + ) print(f"Generated wav: {generated_wave.shape}") diff --git a/f5_tts/infer/utils_infer.py b/f5_tts/infer/utils_infer.py index a1f311113bcbdbd9b9999e451fb38cbf9d96a146..090ad57e57f54c1a0ae7ba7029685ea7ff89279d 100644 --- a/f5_tts/infer/utils_infer.py +++ b/f5_tts/infer/utils_infer.py @@ -4,9 +4,10 @@ import os import sys from concurrent.futures import ThreadPoolExecutor - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility -sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/") +sys.path.append( + f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/" +) import hashlib import re @@ -15,7 +16,6 @@ from importlib.resources import files 
import matplotlib - matplotlib.use("Agg") import matplotlib.pylab as plt @@ -31,21 +31,22 @@ from vocos import Vocos from f5_tts.model import CFM from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer - _ref_audio_cache = {} _ref_text_cache = {} device = ( "cuda" if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" + else ( + "xpu" + if torch.xpu.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) ) -tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False} +tempfile_kwargs = ( + {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False} +) # ----------------------------------------- @@ -87,12 +88,23 @@ def chunk_text(text, max_chars=135): sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text) for sentence in sentences: - if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars: - current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence + if ( + len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) + <= max_chars + ): + current_chunk += ( + sentence + " " + if sentence and len(sentence[-1].encode("utf-8")) == 1 + else sentence + ) else: if current_chunk: chunks.append(current_chunk.strip()) - current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence + current_chunk = ( + sentence + " " + if sentence and len(sentence[-1].encode("utf-8")) == 1 + else sentence + ) if current_chunk: chunks.append(current_chunk.strip()) @@ -101,7 +113,13 @@ def chunk_text(text, max_chars=135): # load vocoder -def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None): +def load_vocoder( + vocoder_name="vocos", + is_local=False, + local_path="", + device=device, + hf_cache_dir=None, +): if vocoder_name == "vocos": # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) if is_local: @@ -111,8 +129,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" - config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") - model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") + config_path = hf_hub_download( + repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml" + ) + model_path = hf_hub_download( + repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin" + ) vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) from vocos.feature_extractors import EncodecFeatures @@ -129,13 +151,17 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev try: from third_party.BigVGAN import bigvgan except ImportError: - print("You need to follow the README to init submodule and change the BigVGAN source code.") + print( + "You need to follow the README to init submodule and change the BigVGAN source code." 
+ ) if is_local: # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) else: vocoder = bigvgan.BigVGAN.from_pretrained( - "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir + "nvidia/bigvgan_v2_24khz_100band_256x", + use_cuda_kernel=False, + cache_dir=hf_cache_dir, ) vocoder.remove_weight_norm() @@ -177,7 +203,11 @@ def transcribe(ref_audio, language=None): ref_audio, chunk_length_s=30, batch_size=128, - generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"}, + generate_kwargs=( + {"task": "transcribe", "language": language} + if language + else {"task": "transcribe"} + ), return_timestamps=False, )["text"].strip() @@ -214,7 +244,10 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True): } # patch for backward compatibility, 305e3ea - for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] @@ -253,7 +286,9 @@ def load_model( vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer) model = CFM( - transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls( + **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, @@ -276,7 +311,9 @@ def load_model( def remove_silence_edges(audio, silence_threshold=-42): # Remove silence from the start - non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold) + non_silent_start_idx = silence.detect_leading_silence( + audio, silence_threshold=silence_threshold + ) audio = audio[non_silent_start_idx:] # Remove silence from the end @@ -315,11 +352,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): # 1. try to find long silence for clipping non_silent_segs = silence.split_on_silence( - aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10 + aseg, + min_silence_len=1000, + silence_thresh=-50, + keep_silence=1000, + seek_step=10, ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: - if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: + if ( + len(non_silent_wave) > 6000 + and len(non_silent_wave + non_silent_seg) > 12000 + ): show_info("Audio is over 12s, clipping short. (1)") break non_silent_wave += non_silent_seg @@ -327,11 +371,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): # 2. try to find short silence for clipping if 1. failed if len(non_silent_wave) > 12000: non_silent_segs = silence.split_on_silence( - aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10 + aseg, + min_silence_len=100, + silence_thresh=-40, + keep_silence=1000, + seek_step=10, ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: - if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: + if ( + len(non_silent_wave) > 6000 + and len(non_silent_wave + non_silent_seg) > 12000 + ): show_info("Audio is over 12s, clipping short. 
(2)") break non_silent_wave += non_silent_seg @@ -399,7 +450,12 @@ def infer_process( ): # Split the input text into batches audio, sr = torchaudio.load(ref_audio) - max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed) + max_chars = int( + len(ref_text.encode("utf-8")) + / (audio.shape[-1] / sr) + * (22 - audio.shape[-1] / sr) + * speed + ) gen_text_batches = chunk_text(gen_text, max_chars=max_chars) for i, gen_text in enumerate(gen_text_batches): print(f"gen_text {i}", gen_text) @@ -483,7 +539,9 @@ def infer_batch_process( # Calculate duration ref_text_len = len(ref_text.encode("utf-8")) gen_text_len = len(gen_text.encode("utf-8")) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) + duration = ref_audio_len + int( + ref_audio_len / ref_text_len * gen_text_len / local_speed + ) # inference with torch.inference_mode(): @@ -519,12 +577,19 @@ def infer_batch_process( yield generated_wave, generated_cpu if streaming: - for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: + for gen_text in ( + progress.tqdm(gen_text_batches) + if progress is not None + else gen_text_batches + ): for chunk in process_batch(gen_text): yield chunk else: with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] + futures = [ + executor.submit(process_batch, gen_text) + for gen_text in gen_text_batches + ] for future in progress.tqdm(futures) if progress is not None else futures: result = future.result() if result: @@ -545,7 +610,9 @@ def infer_batch_process( # Calculate cross-fade samples, ensuring it does not exceed wave lengths cross_fade_samples = int(cross_fade_duration * target_sample_rate) - cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) + cross_fade_samples = min( + cross_fade_samples, len(prev_wave), len(next_wave) + ) if cross_fade_samples <= 0: # No overlap possible, concatenate @@ -561,11 +628,17 @@ def infer_batch_process( fade_in = np.linspace(0, 1, cross_fade_samples) # Cross-faded overlap - cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in + cross_faded_overlap = ( + prev_overlap * fade_out + next_overlap * fade_in + ) # Combine new_wave = np.concatenate( - [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] + [ + prev_wave[:-cross_fade_samples], + cross_faded_overlap, + next_wave[cross_fade_samples:], + ] ) final_wave = new_wave diff --git a/f5_tts/model/__init__.py b/f5_tts/model/__init__.py index 59cf691c9f73f357dd17b43faf08d549dcbb9550..4e11f92db3e0a997c2eb52c5bec7b17cc5b17710 100644 --- a/f5_tts/model/__init__.py +++ b/f5_tts/model/__init__.py @@ -1,10 +1,7 @@ -from f5_tts.model.cfm import CFM - -from f5_tts.model.backbones.unett import UNetT from f5_tts.model.backbones.dit import DiT from f5_tts.model.backbones.mmdit import MMDiT - +from f5_tts.model.backbones.unett import UNetT +from f5_tts.model.cfm import CFM from f5_tts.model.trainer import Trainer - __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] diff --git a/f5_tts/model/backbones/dit.py b/f5_tts/model/backbones/dit.py index 70d4cdbf46330555b6a6631b23157897a4df2541..a429337a7d5e06b6697a070707c8a264af32e0c5 100644 --- a/f5_tts/model/backbones/dit.py +++ b/f5_tts/model/backbones/dit.py @@ -10,21 +10,14 @@ d - dimension from __future__ import annotations import torch -from torch import nn import torch.nn.functional as F - +from torch import nn 
from x_transformers.x_transformers import RotaryEmbedding -from f5_tts.model.modules import ( - TimestepEmbedding, - ConvNeXtV2Block, - ConvPositionEmbedding, - DiTBlock, - AdaLayerNormZero_Final, - precompute_freqs_cis, - get_pos_embed_indices, -) - +from f5_tts.model.modules import (AdaLayerNormZero_Final, ConvNeXtV2Block, + ConvPositionEmbedding, DiTBlock, + TimestepEmbedding, get_pos_embed_indices, + precompute_freqs_cis) # Text embedding @@ -32,34 +25,49 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio - self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, self.precompute_max_pos), + persistent=False, + ) self.text_blocks = nn.Sequential( - *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] ) else: self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() - text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + text = ( + text + 1 + ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() + text = text[ + :, :seq_len + ] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) if drop_text: # cfg for text text = torch.zeros_like(text) - + text = self.text_embed(text) # b n -> b n d # possible extra modeling if self.extra_modeling: # sinus pos emb batch_start = torch.zeros((batch,), dtype=torch.long) - pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + pos_idx = get_pos_embed_indices( + batch_start, seq_len, max_pos=self.precompute_max_pos + ) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed @@ -78,7 +86,13 @@ class InputEmbedding(nn.Module): self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward( + self, + x: float["b n d"], + cond: float["b n d"], + text_embed: float["b n d"], + drop_audio_cond=False, + ): # noqa: F722 if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) @@ -114,17 +128,23 @@ class DiT(nn.Module): if second_time: self.time_embed2 = TimestepEmbedding(dim) # Zero-init the weights and biases of the first and last Linear layers in time_mlp - nn.init.zeros_(self.time_embed2.time_mlp[0].weight) # First Linear layer weights - nn.init.zeros_(self.time_embed2.time_mlp[0].bias) # First Linear layer bias - nn.init.zeros_(self.time_embed2.time_mlp[-1].weight) # Last Linear layer weights - nn.init.zeros_(self.time_embed2.time_mlp[-1].bias) # Last Linear layer bias + nn.init.zeros_( + self.time_embed2.time_mlp[0].weight + ) # First Linear layer weights + 
nn.init.zeros_(self.time_embed2.time_mlp[0].bias) # First Linear layer bias + nn.init.zeros_( + self.time_embed2.time_mlp[-1].weight + ) # Last Linear layer weights + nn.init.zeros_(self.time_embed2.time_mlp[-1].bias) # Last Linear layer bias else: self.time_embed2 = None - + if text_dim is None: text_dim = mel_dim self.vocab_size = text_num_embeds - self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, conv_layers=conv_layers + ) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -133,9 +153,20 @@ class DiT(nn.Module): self.depth = depth self.transformer_blocks = nn.ModuleList( - [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] + [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + ) + for _ in range(depth) + ] + ) + self.long_skip_connection = ( + nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None ) - self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) @@ -171,7 +202,7 @@ class DiT(nn.Module): if second_time is not None and self.time_embed2 is not None: t2 = self.time_embed2(second_time) t = t + t2 - + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) @@ -185,7 +216,9 @@ class DiT(nn.Module): for block in self.transformer_blocks: if self.checkpoint_activations: - x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope) + x = torch.utils.checkpoint.checkpoint( + self.ckpt_wrapper(block), x, t, mask, rope + ) else: x = block(x, t, mask=mask, rope=rope) diff --git a/f5_tts/model/backbones/mmdit.py b/f5_tts/model/backbones/mmdit.py index d3f9a91ab7ab7ac4dd9588113e1c557c48269d70..27336a358256ce425535291ce5b39bfe92aaa9cd 100644 --- a/f5_tts/model/backbones/mmdit.py +++ b/f5_tts/model/backbones/mmdit.py @@ -10,41 +10,37 @@ d - dimension from __future__ import annotations import torch -from torch import nn import torch.nn.functional as F - +from torch import nn from x_transformers.x_transformers import RotaryEmbedding -from f5_tts.model.modules import ( - TimestepEmbedding, - ConvPositionEmbedding, - MMDiTBlock, - DiTBlock, - AdaLayerNormZero_Final, - precompute_freqs_cis, - get_pos_embed_indices, -) - -from f5_tts.model.utils import ( - default, - exists, - lens_to_mask, - list_str_to_idx, - list_str_to_tensor, - mask_from_frac_lengths, -) +from f5_tts.model.modules import (AdaLayerNormZero_Final, + ConvPositionEmbedding, DiTBlock, MMDiTBlock, + TimestepEmbedding, get_pos_embed_indices, + precompute_freqs_cis) +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) + # text embedding class TextEmbedding(nn.Module): def __init__(self, out_dim, text_num_embeds): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token + self.text_embed = nn.Embedding( + text_num_embeds + 1, out_dim + ) # will use 0 as filler token self.precompute_max_pos = 1024 - self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(out_dim, 
self.precompute_max_pos), + persistent=False, + ) - def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 + def forward( + self, text: int["b nt"], drop_text=False + ) -> int["b nt d"]: # noqa: F722 text = text + 1 if drop_text: text = torch.zeros_like(text) @@ -53,7 +49,9 @@ class TextEmbedding(nn.Module): # sinus pos emb batch_start = torch.zeros((text.shape[0],), dtype=torch.long) batch_text_len = text.shape[1] - pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) + pos_idx = get_pos_embed_indices( + batch_start, batch_text_len, max_pos=self.precompute_max_pos + ) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed @@ -70,7 +68,9 @@ class AudioEmbedding(nn.Module): self.linear = nn.Linear(2 * in_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward( + self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False + ): # noqa: F722 if drop_audio_cond: cond = torch.zeros_like(cond) x = torch.cat((x, cond), dim=-1) @@ -97,23 +97,24 @@ class MMDiT(nn.Module): mel_dim=100, checkpoint_activations=False, text_encoder=True, - ): super().__init__() self.time_embed = TimestepEmbedding(dim) if text_encoder: - self.text_encoder = TextEncoder(text_num_embeds=text_num_embeds, - text_dim=dim, - depth=text_depth, - heads=heads, - dim_head=dim_head, - ff_mult=ff_mult, - dropout=dropout) + self.text_encoder = TextEncoder( + text_num_embeds=text_num_embeds, + text_dim=dim, + depth=text_depth, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + ) else: self.text_encoder = None self.text_embed = TextEmbedding(dim, text_num_embeds) - + self.audio_embed = AudioEmbedding(mel_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -136,9 +137,8 @@ class MMDiT(nn.Module): ) self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) - - self.checkpoint_activations = checkpoint_activations + self.checkpoint_activations = checkpoint_activations def forward( self, @@ -161,45 +161,53 @@ class MMDiT(nn.Module): c = self.text_encoder(text, t, mask=text_mask, drop_text=drop_text) else: c = self.text_embed(text, drop_text=drop_text) - + x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) seq_len = x.shape[1] text_len = text.shape[1] rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) rope_text = self.rotary_embed.forward_from_seq_len(text_len) - + # if mask is not None: # rope_audio = self.rotary_embed.forward_from_seq_len(seq_len + 1) - + # dummy_token = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype) # x = torch.cat([x, dummy_token], dim=1) # shape is now [b, nw+1, d] - + # # pad the mask so that new dummy token is always masked out # # mask: [b, nw] -> [b, nw+1] # false_col = torch.zeros((x.shape[0], 1), dtype=torch.bool, device=x.device) # mask = torch.cat([mask, false_col], dim=1) - + # if text_mask is not None: # rope_text = self.rotary_embed.forward_from_seq_len(text_len + 1) # dummy_token = torch.zeros((c.shape[0], 1, c.shape[-1]), device=c.device, dtype=c.dtype) # c = torch.cat([c, dummy_token], dim=1) # shape is now [b, nt+1, d] - + # # pad the text mask so that new dummy token is always masked out # # text_mask: [b, nt] -> [b, nt+1] # false_col = torch.zeros((c.shape[0], 1), dtype=torch.bool, device=c.device) # text_mask = torch.cat([text_mask, false_col], dim=1) - + for block 
in self.transformer_blocks: - c, x = block(x, c, t, mask=mask, src_mask=text_mask, rope=rope_audio, c_rope=rope_text) + c, x = block( + x, + c, + t, + mask=mask, + src_mask=text_mask, + rope=rope_audio, + c_rope=rope_text, + ) x = self.norm_out(x, t) output = self.proj_out(x) - return output + class TextEncoder(nn.Module): def __init__( self, @@ -219,7 +227,7 @@ class TextEncoder(nn.Module): # Embeddings self.text_embed = TextEmbedding(text_dim, text_num_embeds) self.rotary_embed = RotaryEmbedding(dim_head) - + # Example stack of DiTBlocks or any custom blocks self.transformer_blocks = nn.ModuleList( [ @@ -239,7 +247,7 @@ class TextEncoder(nn.Module): text: int["b nt"], # noqa: F821 time: float["b"] | float[""], # time step # noqa: F821 F722 mask: bool["b nt"] | None = None, # noqa: F821 F722 - drop_text: bool = False + drop_text: bool = False, ): """ Encode text into hidden states of shape [b, nt, d]. @@ -251,7 +259,7 @@ class TextEncoder(nn.Module): # Basic embedding hidden_states = self.text_embed(text, seq_len) # [b, nt, d] - + # lens and mask rope = self.rotary_embed.forward_from_seq_len(seq_len) @@ -260,17 +268,18 @@ class TextEncoder(nn.Module): # Here, you likely want standard self-attn, so no cross-attn hidden_states = block( x=hidden_states, - t=time, # no time embedding for the text encoder by default - mask=mask, # or pass a text mask if needed - rope=rope # pass a rope if you want rotary embeddings for text + t=time, # no time embedding for the text encoder by default + mask=mask, # or pass a text mask if needed + rope=rope, # pass a rope if you want rotary embeddings for text ) return hidden_states + if __name__ == "__main__": from f5_tts.model.utils import get_tokenizer bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -279,23 +288,22 @@ if __name__ == "__main__": else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - + text = ["hello world"] * bsz - text_lens = torch.ones((bsz, ), dtype=torch.long) * len("hello world") + text_lens = torch.ones((bsz,), dtype=torch.long) * len("hello world") text_lens[-1] = 5 device = "cuda" batch = bsz time_embed = TimestepEmbedding(512).to(device) - - + # handle text as string if isinstance(text, list): if exists(vocab_char_map): text = list_str_to_idx(text, vocab_char_map).to(device) else: text = list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - + assert text.shape[0] == batch + time = torch.rand((batch,), device=device) text_mask = lens_to_mask(text_lens).to(device) @@ -311,7 +319,7 @@ if __name__ == "__main__": # ).to('cuda') # hidden_states = text_encoder(text, time_embed(time), mask) # print(hidden_states.shape) # [bsz, seq_len, text_dim] - + # test MMDiT mel_dim = 80 model = MMDiT( @@ -323,14 +331,23 @@ if __name__ == "__main__": dropout=0.1, ff_mult=4, text_num_embeds=vocab_size, - mel_dim=mel_dim + mel_dim=mel_dim, ).to(device) - + x = torch.rand((batch, 100, mel_dim), device=device) cond = torch.rand((batch, 100, mel_dim), device=device) lens = torch.ones((batch,), dtype=torch.long) * 100 mask = lens_to_mask(lens).to(device) - - output = model(x, cond, text, time, drop_audio_cond=False, drop_text=False, mask=mask, text_mask=text_mask) - - print(output.shape) # [bsz, seq_len, mel_dim] \ No newline at end of file + + output = model( + x, + cond, + text, + time, + drop_audio_cond=False, + 
drop_text=False, + mask=mask, + text_mask=text_mask, + ) + + print(output.shape) # [bsz, seq_len, mel_dim] diff --git a/f5_tts/model/backbones/unett.py b/f5_tts/model/backbones/unett.py index acf649a52448e87a34a2af4bc14051caaba74c86..4cd3facaef93931ffaa00be3123c18cbf21b29d2 100644 --- a/f5_tts/model/backbones/unett.py +++ b/f5_tts/model/backbones/unett.py @@ -8,26 +8,19 @@ d - dimension """ from __future__ import annotations + from typing import Literal import torch -from torch import nn import torch.nn.functional as F - +from torch import nn from x_transformers import RMSNorm from x_transformers.x_transformers import RotaryEmbedding -from f5_tts.model.modules import ( - TimestepEmbedding, - ConvNeXtV2Block, - ConvPositionEmbedding, - Attention, - AttnProcessor, - FeedForward, - precompute_freqs_cis, - get_pos_embed_indices, -) - +from f5_tts.model.modules import (Attention, AttnProcessor, ConvNeXtV2Block, + ConvPositionEmbedding, FeedForward, + TimestepEmbedding, get_pos_embed_indices, + precompute_freqs_cis) # Text embedding @@ -35,21 +28,34 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio - self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, self.precompute_max_pos), + persistent=False, + ) self.text_blocks = nn.Sequential( - *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] ) else: self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() - text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + text = ( + text + 1 + ) # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[ + :, :seq_len + ] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) @@ -62,7 +68,9 @@ class TextEmbedding(nn.Module): if self.extra_modeling: # sinus pos emb batch_start = torch.zeros((batch,), dtype=torch.long) - pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + pos_idx = get_pos_embed_indices( + batch_start, seq_len, max_pos=self.precompute_max_pos + ) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed @@ -81,7 +89,13 @@ class InputEmbedding(nn.Module): self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward( + self, + x: float["b n d"], + cond: float["b n d"], + text_embed: float["b n d"], + drop_audio_cond=False, + ): # noqa: F722 if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) @@ -115,7 +129,9 @@ class UNetT(nn.Module): self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, conv_layers=conv_layers + ) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -144,7 +160,11 @@ class UNetT(nn.Module): ff_norm = RMSNorm(dim) ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") - skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None + skip_proj = ( + nn.Linear(dim * 2, dim, bias=False) + if needs_skip_proj and is_later_half + else None + ) self.layers.append( nn.ModuleList( @@ -190,7 +210,9 @@ class UNetT(nn.Module): # flat unet transformer skip_connect_type = self.skip_connect_type skips = [] - for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): + for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate( + self.layers + ): layer = idx + 1 # skip connection logic diff --git a/f5_tts/model/cfm.py b/f5_tts/model/cfm.py index 9de0f3eaeeef46ce66cfbadaae74340c0104a54a..154667665de4c50f1bf5478413f61d5ff853800c 100644 --- a/f5_tts/model/cfm.py +++ b/f5_tts/model/cfm.py @@ -19,14 +19,8 @@ from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint from f5_tts.model.modules import MelSpec -from f5_tts.model.utils import ( - default, - exists, - lens_to_mask, - list_str_to_idx, - list_str_to_tensor, - mask_from_frac_lengths, -) +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) class CFM(nn.Module): @@ -74,7 +68,7 @@ class CFM(nn.Module): # vocab map for tokenization self.vocab_char_map = vocab_char_map - + self.scale = scale @property @@ -109,11 +103,11 @@ class CFM(nn.Module): assert cond.shape[-1] == self.num_channels cond = cond.to(next(self.parameters()).dtype) - + print(self.scale) cond = cond / self.scale - + batch, cond_seq_len, device = *cond.shape[:2], cond.device if not exists(lens): lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) @@ -129,7 +123,9 @@ class CFM(nn.Module): if exists(text): text_lens = (text != -1).sum(dim=-1) - lens = torch.maximum(text_lens, lens) # make sure lengths are at least 
those of the text characters + lens = torch.maximum( + text_lens, lens + ) # make sure lengths are at least those of the text characters # duration @@ -140,19 +136,25 @@ class CFM(nn.Module): if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) - duration = torch.maximum(lens + 1, duration) # just add one token so something is generated + duration = torch.maximum( + lens + 1, duration + ) # just add one token so something is generated duration = duration.clamp(max=max_duration) max_duration = duration.amax() # duplicate test corner for inner time step oberservation if duplicate_test: - test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) + test_cond = F.pad( + cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 + ) cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) if no_ref_audio: cond = torch.zeros_like(cond) - cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = F.pad( + cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + ) cond_mask = cond_mask.unsqueeze(-1) step_cond = torch.where( cond_mask, cond, torch.zeros_like(cond) @@ -171,13 +173,25 @@ class CFM(nn.Module): # predict flow pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, ) if cfg_strength < 1e-5: return pred null_pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=True, + drop_text=True, ) return pred + (pred - null_pred) * cfg_strength @@ -188,7 +202,11 @@ class CFM(nn.Module): for dur in duration: if exists(seed): torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0.append( + torch.randn( + dur, self.num_channels, device=self.device, dtype=step_cond.dtype + ) + ) y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 @@ -199,7 +217,9 @@ class CFM(nn.Module): y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) - t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + t = torch.linspace( + t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype + ) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) @@ -210,7 +230,7 @@ class CFM(nn.Module): out = torch.where(cond_mask, cond, out) out = out * self.scale - + if exists(vocoder): out = out.permute(0, 2, 1) out = vocoder(out) @@ -231,7 +251,12 @@ class CFM(nn.Module): inp = inp.permute(0, 2, 1) assert inp.shape[-1] == self.num_channels - batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma + batch, seq_len, dtype, device, _σ1 = ( + *inp.shape[:2], + inp.dtype, + self.device, + self.sigma, + ) # handle text as string if isinstance(text, list): @@ -245,10 +270,16 @@ class CFM(nn.Module): if not exists(lens): lens = torch.full((batch,), seq_len, device=device) - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch # get a random span to mask out for training conditionally - frac_lengths = 
torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) + frac_lengths = ( + torch.zeros((batch,), device=self.device) + .float() + .uniform_(*self.frac_lengths_mask) + ) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) if exists(mask): @@ -283,11 +314,16 @@ class CFM(nn.Module): # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences pred = self.transformer( - x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text + x=φ, + cond=cond, + text=text, + time=time, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, ) # flow matching loss loss = F.mse_loss(pred, flow, reduction="none") loss = loss[rand_span_mask] - return loss.mean(), cond, pred, t \ No newline at end of file + return loss.mean(), cond, pred, t diff --git a/f5_tts/model/dataset.py b/f5_tts/model/dataset.py index 13734a1c479aa9be0c5fb13fab3e8b96afd1c2a8..f7508b403b05e79a0725bd05a048a3b7026312fd 100644 --- a/f5_tts/model/dataset.py +++ b/f5_tts/model/dataset.py @@ -1,6 +1,6 @@ -import re import json import random +import re from importlib.resources import files import torch @@ -15,8 +15,9 @@ from tqdm import tqdm from f5_tts.model.modules import MelSpec from f5_tts.model.utils import default + def get_speaker_id(path): - parts = path.split('/') + parts = path.split("/") speaker_id = parts[-3] return speaker_id @@ -40,7 +41,7 @@ class CustomDataset(Dataset): return_wavform=False, remove_starting_space=True, need_prompt_speech=False, - prompt_repository: dict=None, + prompt_repository: dict = None, ): self.data = custom_dataset self.durations = durations @@ -63,42 +64,32 @@ class CustomDataset(Dataset): mel_spec_type=mel_spec_type, ), ) - + self.validation = validation self.validation_num = validation_num if (not validation) and data_augmentation: - print('Using data augmentation.') - self.augment = Compose([ - AddBackgroundNoise( - sounds_path="/data5/ESC-50-master", - min_snr_db=3.0, - max_snr_db=30.0, - noise_transform=PolarityInversion(), - p=0.5 - ), - AddGaussianNoise( - min_amplitude=0.001, - max_amplitude=0.015, - p=0.5 - ), - PitchShift( - min_semitones=-12.0, - max_semitones=12.0, - p=0.8 - ), - ApplyImpulseResponse(ir_path="/data5/Audio", p=1.0), - Aliasing(min_sample_rate=4000, max_sample_rate=30000, p=0.3), - BandPassFilter(min_center_freq=100.0, max_center_freq=6000, p=0.2), - SevenBandParametricEQ(p=0.2), - TanhDistortion( - min_distortion=0.01, - max_distortion=0.7, - p=0.2 - ), - ]) + print("Using data augmentation.") + self.augment = Compose( + [ + AddBackgroundNoise( + sounds_path="/data5/ESC-50-master", + min_snr_db=3.0, + max_snr_db=30.0, + noise_transform=PolarityInversion(), + p=0.5, + ), + AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5), + PitchShift(min_semitones=-12.0, max_semitones=12.0, p=0.8), + ApplyImpulseResponse(ir_path="/data5/Audio", p=1.0), + Aliasing(min_sample_rate=4000, max_sample_rate=30000, p=0.3), + BandPassFilter(min_center_freq=100.0, max_center_freq=6000, p=0.2), + SevenBandParametricEQ(p=0.2), + TanhDistortion(min_distortion=0.01, max_distortion=0.7, p=0.2), + ] + ) else: - print('No data augmentation.') + print("No data augmentation.") self.augment = None self.return_wavform = return_wavform @@ -112,7 +103,7 @@ class CustomDataset(Dataset): text = row["text"] duration = row["duration"] spk_id = get_speaker_id(audio_path) - assert 
spk_id != None and spk_id != 'mp3' + assert spk_id != None and spk_id != "mp3" if spk_id not in self.prompt_repository: self.prompt_repository[spk_id] = [row] else: @@ -120,13 +111,14 @@ class CustomDataset(Dataset): else: self.prompt_repository = prompt_repository - print(f'Grouped samples into {len(self.prompt_repository.keys())} speakers.') + print( + f"Grouped samples into {len(self.prompt_repository.keys())} speakers." + ) self.need_prompt_speech = True else: self.need_prompt_speech = False - def get_frame_len(self, index): if self.validation: index += len(self.data) - self.validation_num @@ -164,9 +156,9 @@ class CustomDataset(Dataset): index = (index + 1) % len(self.data) if self.remove_starting_space: - while len(text) > 1 and text[0] == ' ': + while len(text) > 1 and text[0] == " ": text = text[1:] - + if self.preprocessed_mel: mel_spec = torch.tensor(row["mel_spec"]) else: @@ -178,31 +170,37 @@ class CustomDataset(Dataset): # resample if necessary if source_sample_rate != self.target_sample_rate: - resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) + resampler = torchaudio.transforms.Resample( + source_sample_rate, self.target_sample_rate + ) audio = resampler(audio) if not self.validation: if self.augment != None: - audio = self.augment(audio.squeeze().numpy(), sample_rate=self.target_sample_rate) + audio = self.augment( + audio.squeeze().numpy(), sample_rate=self.target_sample_rate + ) audio = torch.from_numpy(audio).float().unsqueeze(0) # to mel spectrogram mel_spec = self.mel_spectrogram(audio) mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' - out['mel_spec'] = mel_spec - out['text'] = text - out['duration'] = duration - out['target_text'] = self.data[(index + len(self.data) // 2) % len(self.data)]["text"] + out["mel_spec"] = mel_spec + out["text"] = text + out["duration"] = duration + out["target_text"] = self.data[(index + len(self.data) // 2) % len(self.data)][ + "text" + ] if self.return_wavform: - out['wav'] = audio + out["wav"] = audio if return_path: - out['path'] = audio_path + out["path"] = audio_path if return_row: - out['row'] = row + out["row"] = row # Sample a prompt speech of the same speaker # From prompt_repository @@ -212,9 +210,9 @@ class CustomDataset(Dataset): _count = 100 while True: pmt_row = random.choice(spk_repository) - pmt_audio_path = pmt_row['audio_path'] - pmt_text = pmt_row['text'] - pmt_duration = pmt_row['duration'] + pmt_audio_path = pmt_row["audio_path"] + pmt_text = pmt_row["text"] + pmt_duration = pmt_row["duration"] if not isinstance(pmt_text, list): pmt_text = list(pmt_text) @@ -223,14 +221,14 @@ class CustomDataset(Dataset): if 0.3 <= pmt_duration <= 30 and (0 < len(pmt_text) < 2048): if pmt_text != text: break - _count = _count - 1 + _count = _count - 1 if _count <= 0: break if self.remove_starting_space: - while len(pmt_text) > 1 and pmt_text[0] == ' ': + while len(pmt_text) > 1 and pmt_text[0] == " ": pmt_text = pmt_text[1:] - + if self.preprocessed_mel: pmt_mel_spec = torch.tensor(pmt_row["mel_spec"]) else: @@ -242,30 +240,35 @@ class CustomDataset(Dataset): # resample if necessary if source_sample_rate != self.target_sample_rate: - resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) + resampler = torchaudio.transforms.Resample( + source_sample_rate, self.target_sample_rate + ) pmt_audio = resampler(pmt_audio) if not self.validation: if self.augment != None: - pmt_audio = self.augment(pmt_audio.squeeze().numpy(), sample_rate=self.target_sample_rate) + pmt_audio 
= self.augment( + pmt_audio.squeeze().numpy(), + sample_rate=self.target_sample_rate, + ) pmt_audio = torch.from_numpy(pmt_audio).float().unsqueeze(0) # to mel spectrogram pmt_mel_spec = self.mel_spectrogram(pmt_audio) pmt_mel_spec = pmt_mel_spec.squeeze(0) # '1 d t -> d t' - out['pmt_mel_spec'] = pmt_mel_spec - out['pmt_text'] = pmt_text - out['pmt_duration'] = pmt_duration + out["pmt_mel_spec"] = pmt_mel_spec + out["pmt_text"] = pmt_text + out["pmt_duration"] = pmt_duration if self.return_wavform: - out['pmt_wav'] = pmt_audio + out["pmt_wav"] = pmt_audio if return_path: - out['pmt_path'] = pmt_audio_path + out["pmt_path"] = pmt_audio_path if return_row: - out['pmt_row'] = pmt_row + out["pmt_row"] = pmt_row return out @@ -280,7 +283,12 @@ class DynamicBatchSampler(Sampler[list[int]]): """ def __init__( - self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False + self, + sampler: Sampler[int], + frames_threshold: int, + max_samples=0, + random_seed=None, + drop_last: bool = False, ): self.sampler = sampler self.frames_threshold = frames_threshold @@ -302,7 +310,9 @@ class DynamicBatchSampler(Sampler[list[int]]): # indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu" # ): for idx, frame_len in indices: - if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): + if batch_frames + frame_len <= self.frames_threshold and ( + max_samples == 0 or len(batch) < max_samples + ): batch.append(idx) batch_frames += frame_len else: @@ -337,6 +347,7 @@ class DynamicBatchSampler(Sampler[list[int]]): # Load dataset + def load_dataset( dataset_name: str, tokenizer: str = "pinyin", @@ -349,7 +360,7 @@ def load_dataset( return_wavform: bool = False, remove_starting_space: bool = True, need_prompt_speech: bool = False, - prompt_repository: dict = None + prompt_repository: dict = None, ) -> CustomDataset: """ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset @@ -359,9 +370,13 @@ def load_dataset( print("Loading dataset ...") if dataset_type == "CustomDataset": - rel_data_path = str(f'/home/yl4579/F5-TTS-diff/F5-TTS-DMD-flow-ds/data/{dataset_name}_{tokenizer}') - if 'LibriTTS_100_360_500_char_pinyin' in rel_data_path: - rel_data_path = rel_data_path.replace('LibriTTS_100_360_500_char_pinyin', 'LibriTTS_100_360_500_char') + rel_data_path = str( + f"/home/yl4579/F5-TTS-diff/F5-TTS-DMD-flow-ds/data/{dataset_name}_{tokenizer}" + ) + if "LibriTTS_100_360_500_char_pinyin" in rel_data_path: + rel_data_path = rel_data_path.replace( + "LibriTTS_100_360_500_char_pinyin", "LibriTTS_100_360_500_char" + ) if audio_type == "raw": try: train_dataset = load_from_disk(f"{rel_data_path}/raw") @@ -385,7 +400,7 @@ def load_dataset( return_wavform=return_wavform, remove_starting_space=remove_starting_space, need_prompt_speech=need_prompt_speech, - prompt_repository=prompt_repository + prompt_repository=prompt_repository, ) elif dataset_type == "CustomDatasetPath": @@ -398,7 +413,10 @@ def load_dataset( data_dict = json.load(f) durations = data_dict["duration"] train_dataset = CustomDataset( - train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs + train_dataset, + durations=durations, + preprocessed_mel=preprocessed_mel, + **mel_spec_kwargs, ) return train_dataset @@ -410,7 +428,7 @@ def collate_fn(batch): mel_specs = [item["mel_spec"].squeeze(0) for item in batch] mel_lengths = 
torch.LongTensor([spec.shape[-1] for spec in mel_specs]) max_mel_length = mel_lengths.amax() - + # Pad mel_specs padded_mel_specs = [] for spec in mel_specs: # TODO. maybe records mask for attention here @@ -419,8 +437,8 @@ def collate_fn(batch): padded_mel_specs.append(padded_spec) mel_specs = torch.stack(padded_mel_specs) - text = [item['text'] for item in batch] - target_text = [item['target_text'] for item in batch] + text = [item["text"] for item in batch] + target_text = [item["target_text"] for item in batch] text_lengths = torch.LongTensor([len(item) for item in text]) @@ -432,26 +450,26 @@ def collate_fn(batch): target_text=target_text, ) - if 'pmt_mel_spec' in batch[0]: + if "pmt_mel_spec" in batch[0]: pmt_mel_specs = [item["pmt_mel_spec"].squeeze(0) for item in batch] pmt_mel_lengths = torch.LongTensor([spec.shape[-1] for spec in pmt_mel_specs]) max_pmt_mel_length = pmt_mel_lengths.amax() - + # Pad mel_specs padded_pmt_mel_specs = [] - for spec in pmt_mel_specs: + for spec in pmt_mel_specs: padding = (0, max_pmt_mel_length - spec.size(-1)) padded_spec = F.pad(spec, padding, value=0) padded_pmt_mel_specs.append(padded_spec) pmt_mel_specs = torch.stack(padded_pmt_mel_specs) - out['pmt_mel_specs'] = pmt_mel_specs + out["pmt_mel_specs"] = pmt_mel_specs - if 'pmt_text' in batch[0]: - pmt_text = [item['pmt_text'] for item in batch] + if "pmt_text" in batch[0]: + pmt_text = [item["pmt_text"] for item in batch] pmt_text_lengths = torch.LongTensor([len(item) for item in pmt_text]) - out['pmt_text'] = pmt_text - out['pmt_text_lengths'] = pmt_text_lengths + out["pmt_text"] = pmt_text + out["pmt_text_lengths"] = pmt_text_lengths - return out \ No newline at end of file + return out diff --git a/f5_tts/model/modules.py b/f5_tts/model/modules.py index 62507a4c802f7386cdba2005b62f076940aa8fd8..363435e2d6b70d56c959ac7a20d4a31a5e4a9f8e 100644 --- a/f5_tts/model/modules.py +++ b/f5_tts/model/modules.py @@ -19,7 +19,6 @@ from librosa.filters import mel as librosa_mel_fn from torch import nn from x_transformers.x_transformers import apply_rotary_pos_emb - # raw wav to mel spec @@ -42,15 +41,25 @@ def get_bigvgan_mel_spectrogram( key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" if key not in mel_basis_cache: - mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax) - mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? + mel = librosa_mel_fn( + sr=target_sample_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=fmin, + fmax=fmax, + ) + mel_basis_cache[key] = ( + torch.from_numpy(mel).float().to(device) + ) # TODO: why they need .float()? 
hann_window_cache[key] = torch.hann_window(win_length).to(device) mel_basis = mel_basis_cache[key] hann_window = hann_window_cache[key] padding = (n_fft - hop_length) // 2 - waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) spec = torch.stft( waveform, @@ -112,7 +121,9 @@ class MelSpec(nn.Module): mel_spec_type="vocos", ): super().__init__() - assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") + assert mel_spec_type in ["vocos", "bigvgan"], print( + "We only support two mel extraction backends: vocos or bigvgan" + ) self.n_fft = n_fft self.hop_length = hop_length @@ -193,7 +204,9 @@ class ConvPositionEmbedding(nn.Module): # rotary positional embedding related -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ @@ -209,10 +222,15 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.0): # length = length if isinstance(length, int) else length.max() - scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + scale = scale * torch.ones_like( + start, dtype=torch.float32 + ) # in case scale is a scalar pos = ( start.unsqueeze(1) - + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() ) # avoid extra long error. 
pos = torch.where(pos < max_pos, pos, max_pos - 1) @@ -251,7 +269,9 @@ class ConvNeXtV2Block(nn.Module): dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.grn = GRN(intermediate_dim) self.pwconv2 = nn.Linear(intermediate_dim, dim) @@ -284,7 +304,9 @@ class AdaLayerNormZero(nn.Module): def forward(self, x, emb=None): emb = self.linear(self.silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( + emb, 6, dim=1 + ) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp @@ -315,14 +337,18 @@ class AdaLayerNormZero_Final(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + def __init__( + self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" + ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim activation = nn.GELU(approximate=approximate) project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) - self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + self.ff = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) def forward(self, x): return self.ff(x) @@ -346,7 +372,9 @@ class Attention(nn.Module): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) self.processor = processor @@ -385,7 +413,9 @@ class Attention(nn.Module): c_rope=None, # rotary position embedding for c ) -> torch.Tensor: if c is not None: - return self.processor(self, x, c=c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope) + return self.processor( + self, x, c=c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope + ) else: return self.processor(self, x, mask=mask, rope=rope) @@ -414,7 +444,9 @@ class AttnProcessor: # apply rotary position embedding if rope is not None: freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) @@ -430,11 +462,15 @@ class AttnProcessor: if mask is not None: attn_mask = mask attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' - attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) else: attn_mask = None - x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) x = x.to(query.dtype) @@ -461,12 +497,12 @@ class JointAttnProcessor: def __call__( self, attn: Attention, - x: float["b n d"], # noised input x - c: float["b nt d"] = None, # context c, here text + x: float["b n d"], # noised input x + c: float["b nt d"] = None, # context c, here text mask: bool["b n"] | None = None, src_mask: bool["b nt"] | None = None, - rope=None, # rotary position embedding for x - c_rope=None, # rotary position embedding for c + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c ) -> torch.FloatTensor: residual = x batch_size = c.shape[0] @@ -484,14 +520,18 @@ class JointAttnProcessor: # apply rope for x if rope is not None: freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) # apply rope for c if c_rope is not None: freqs, xpos_scale = c_rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) @@ -515,17 +555,23 @@ class JointAttnProcessor: attn_mask_c = F.pad(src_mask, (x.shape[1], 0), value=True) attn_mask = attn_mask & attn_mask_c attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) - attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) else: if src_mask is not None: # if there's no mask for x but there's src_mask attn_mask = F.pad(src_mask, (x.shape[1], 0), value=True) attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) - attn_mask = 
attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) else: attn_mask = None - x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) x = x.to(query.dtype) @@ -546,7 +592,6 @@ class JointAttnProcessor: return x, c - # DiT Block @@ -564,7 +609,9 @@ class DiTBlock(nn.Module): ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding # pre-norm & modulation for attention input @@ -596,12 +643,16 @@ class MMDiTBlock(nn.Module): context_pre_only: last layer only do prenorm + modulation cuz no more ffn """ - def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): + def __init__( + self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False + ): super().__init__() self.context_pre_only = context_pre_only - self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + self.attn_norm_c = ( + AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + ) self.attn_norm_x = AdaLayerNormZero(dim) self.attn = Attention( processor=JointAttnProcessor(), @@ -615,23 +666,35 @@ class MMDiTBlock(nn.Module): if not context_pre_only: self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff_c = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) else: self.ff_norm_c = None self.ff_c = None self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff_x = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) - def forward(self, x, c, t, mask=None, src_mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding + def forward( + self, x, c, t, mask=None, src_mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding # pre-norm & modulation for attention input if self.context_pre_only: norm_c = self.attn_norm_c(c, t) else: - norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) - norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c( + c, emb=t + ) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x( + x, emb=t + ) # attention - x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope) + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, src_mask=src_mask, rope=rope, c_rope=c_rope + ) # process attention output for context c if self.context_pre_only: @@ -639,7 +702,9 @@ class MMDiTBlock(nn.Module): else: # if not last layer c = c + c_gate_msa.unsqueeze(1) * c_attn_output - norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + 
norm_c = ( + self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) c_ff_output = self.ff_c(norm_c) c = c + c_gate_mlp.unsqueeze(1) * c_ff_output @@ -660,7 +725,9 @@ class TimestepEmbedding(nn.Module): def __init__(self, dim, freq_embed_dim=256): super().__init__() self.time_embed = SinusPositionEmbedding(freq_embed_dim) - self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) def forward(self, timestep: float["b"]): # noqa: F821 time_hidden = self.time_embed(timestep) diff --git a/f5_tts/model/trainer.py b/f5_tts/model/trainer.py index 63c74a0e37f2f3afd7acf3fb79d234da1b903ea7..a3ae90e6efbb226ae564318d9aa1dba3d2a5460a 100644 --- a/f5_tts/model/trainer.py +++ b/f5_tts/model/trainer.py @@ -67,7 +67,13 @@ class Trainer: self.logger = logger if self.logger == "wandb": if exists(wandb_resume_id): - init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} + init_kwargs = { + "wandb": { + "resume": "allow", + "name": wandb_run_name, + "id": wandb_resume_id, + } + } else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} @@ -102,7 +108,9 @@ class Trainer: self.epochs = epochs self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates - self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) + self.last_per_steps = default( + last_per_steps, save_per_updates * grad_accumulation_steps + ) self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") self.batch_size = batch_size @@ -126,8 +134,10 @@ class Trainer: self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) - self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) - + self.model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) + self.scale = None self.count = 0 @@ -137,10 +147,12 @@ class Trainer: def save_checkpoint(self, step, last=False): self.accelerator.wait_for_everyone() - if self.is_main: + if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), - optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), + optimizer_state_dict=self.accelerator.unwrap_model( + self.optimizer + ).state_dict(), ema_model_state_dict=self.ema_model.state_dict(), scheduler_state_dict=self.scheduler.state_dict(), step=step, @@ -150,16 +162,23 @@ class Trainer: if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_last.pt" + ) print(f"Saved last checkpoint at step {step}") else: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_{step}.pt" + ) def load_checkpoint(self): if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path)) + or not any( + filename.endswith(".pt") + for filename in os.listdir(self.checkpoint_path) + ) ): return 0 @@ -172,10 +191,17 @@ class Trainer: key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] # checkpoint = 
torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") + checkpoint = torch.load( + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", + ) # patch for backward compatibility, 305e3ea - for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "ema_model.mel_spec.mel_stft.mel_scale.fb", + "ema_model.mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["ema_model_state_dict"]: del checkpoint["ema_model_state_dict"][key] @@ -184,12 +210,19 @@ class Trainer: if "step" in checkpoint: # patch for backward compatibility, 305e3ea - for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) - self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) + self.accelerator.unwrap_model(self.optimizer).load_state_dict( + checkpoint["optimizer_state_dict"] + ) if self.scheduler: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) step = checkpoint["step"] @@ -199,28 +232,37 @@ class Trainer: for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "step"] } - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) step = 0 if "scale" in checkpoint: self.scale = float(checkpoint["scale"]) self.model.scale = self.scale - + if "count" in checkpoint: self.count = int(checkpoint["count"]) - + del checkpoint gc.collect() return step - def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): + def train( + self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None + ): if self.log_samples: - from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef + from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder, + nfe_step, sway_sampling_coef) vocoder = load_vocoder( - vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path + vocoder_name=self.vocoder_name, + is_local=self.is_local_vocoder, + local_path=self.local_vocoder_path, ) - target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate + target_sample_rate = self.accelerator.unwrap_model( + self.model + ).mel_spec.target_sample_rate log_samples_path = f"{self.checkpoint_path}/samples" os.makedirs(log_samples_path, exist_ok=True) @@ -245,7 +287,11 @@ class Trainer: self.accelerator.even_batches = False sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) train_dataloader = DataLoader( train_dataset, @@ -256,7 +302,9 @@ class Trainer: 
batch_sampler=batch_sampler, ) else: - raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") + raise ValueError( + f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}" + ) # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices @@ -266,10 +314,16 @@ class Trainer: # otherwise by default with split_batches=False, warmup steps change with num_processes total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps decay_steps = total_steps - warmup_steps - warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) - decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) + warmup_scheduler = LinearLR( + self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps + ) + decay_scheduler = LinearLR( + self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps + ) self.scheduler = SequentialLR( - self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps] + self.optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[warmup_steps], ) train_dataloader, self.scheduler = self.accelerator.prepare( train_dataloader, self.scheduler @@ -281,7 +335,9 @@ class Trainer: orig_epoch_step = len(train_dataloader) skipped_epoch = int(start_step // orig_epoch_step) skipped_batch = start_step % orig_epoch_step - skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) + skipped_dataloader = self.accelerator.skip_first_batches( + train_dataloader, num_batches=skipped_batch + ) else: skipped_epoch = 0 @@ -309,28 +365,40 @@ class Trainer: text_inputs = batch["text"] mel_spec = batch["mel"].permute(0, 2, 1) mel_lengths = batch["mel_lengths"] - + self.count += 1 - + if self.scale is None: self.scale = mel_spec.std() else: self.scale += (mel_spec.std() - self.scale) / self.count - - mel_spec = mel_spec / self.scale # normalize mel spectrogram - + + mel_spec = mel_spec / self.scale # normalize mel spectrogram + # TODO. 
add duration predictor training - if self.duration_predictor is not None and self.accelerator.is_local_main_process: - dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations")) - self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step) + if ( + self.duration_predictor is not None + and self.accelerator.is_local_main_process + ): + dur_loss = self.duration_predictor( + mel_spec, lens=batch.get("durations") + ) + self.accelerator.log( + {"duration loss": dur_loss.item()}, step=global_step + ) loss, cond, pred, t = self.model( - mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler + mel_spec, + text=text_inputs, + lens=mel_lengths, + noise_scheduler=self.noise_scheduler, ) self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) self.optimizer.step() self.scheduler.step() @@ -342,18 +410,30 @@ class Trainer: global_step += 1 if self.accelerator.is_local_main_process: - self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step) + self.accelerator.log( + {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, + step=global_step, + ) if self.logger == "tensorboard": self.writer.add_scalar("loss", loss.item(), global_step) - self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step) + self.writer.add_scalar( + "lr", self.scheduler.get_last_lr()[0], global_step + ) progress_bar.set_postfix(step=str(global_step), loss=loss.item()) - if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: + if ( + global_step % (self.save_per_updates * self.grad_accumulation_steps) + == 0 + ): self.save_checkpoint(global_step) if self.log_samples and self.accelerator.is_local_main_process: - gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1) * self.scale - ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1) * self.scale + gen_mel_spec = ( + pred[0].unsqueeze(0).permute(0, 2, 1) * self.scale + ) + ref_mel_spec = ( + cond[0].unsqueeze(0).permute(0, 2, 1) * self.scale + ) with torch.inference_mode(): if self.vocoder_name == "vocos": gen_audio = vocoder.decode(gen_mel_spec).cpu() @@ -361,51 +441,56 @@ class Trainer: elif self.vocoder_name == "bigvgan": gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu() ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu() - + gen_audio = wandb.Audio( gen_audio.float().numpy().squeeze(), sample_rate=24000, - caption="time: " + str(t[0].squeeze().float().cpu().numpy()) + caption="time: " + + str(t[0].squeeze().float().cpu().numpy()), ) ref_audio = wandb.Audio( ref_audio.float().numpy().squeeze(), sample_rate=24000, - caption="time: " + str(t[0].squeeze().float().cpu().numpy()) + caption="time: " + + str(t[0].squeeze().float().cpu().numpy()), + ) + + self.accelerator.log( + { + "gen_audio": gen_audio, + "ref_audio": ref_audio, + }, + step=global_step, ) - self.accelerator.log({"gen_audio": gen_audio, - "ref_audio": ref_audio, - }, step=global_step) - - -# if self.log_samples and self.accelerator.is_local_main_process: -# ref_audio_len = mel_lengths[0] -# infer_text = [ -# text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0] -# ] -# with torch.inference_mode(): -# # generated, _ = self.accelerator.unwrap_model(self.model).sample( -# # cond=mel_spec[0][:ref_audio_len].unsqueeze(0), -# # text=infer_text, -# # 
duration=ref_audio_len * 2, -# # steps=nfe_step, -# # cfg_strength=cfg_strength, -# # sway_sampling_coef=sway_sampling_coef, -# # ) -# # generated = generated.to(torch.float32) -# # gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device) -# # ref_mel_spec = batch["mel"][0].unsqueeze(0) -# gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1) -# ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1) -# if self.vocoder_name == "vocos": -# gen_audio = vocoder.decode(gen_mel_spec).cpu() -# ref_audio = vocoder.decode(ref_mel_spec).cpu() -# elif self.vocoder_name == "bigvgan": -# gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu() -# ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu() - -# torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate) -# torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate) + # if self.log_samples and self.accelerator.is_local_main_process: + # ref_audio_len = mel_lengths[0] + # infer_text = [ + # text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0] + # ] + # with torch.inference_mode(): + # # generated, _ = self.accelerator.unwrap_model(self.model).sample( + # # cond=mel_spec[0][:ref_audio_len].unsqueeze(0), + # # text=infer_text, + # # duration=ref_audio_len * 2, + # # steps=nfe_step, + # # cfg_strength=cfg_strength, + # # sway_sampling_coef=sway_sampling_coef, + # # ) + # # generated = generated.to(torch.float32) + # # gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device) + # # ref_mel_spec = batch["mel"][0].unsqueeze(0) + # gen_mel_spec = pred[0].unsqueeze(0).permute(0, 2, 1) + # ref_mel_spec = cond[0].unsqueeze(0).permute(0, 2, 1) + # if self.vocoder_name == "vocos": + # gen_audio = vocoder.decode(gen_mel_spec).cpu() + # ref_audio = vocoder.decode(ref_mel_spec).cpu() + # elif self.vocoder_name == "bigvgan": + # gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu() + # ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu() + + # torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate) + # torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate) if global_step % self.last_per_steps == 0: self.save_checkpoint(global_step, last=True) diff --git a/f5_tts/model/utils.py b/f5_tts/model/utils.py index 171f9739052a99a711f7049a485cfd352117c080..f933e42041ac66a52ac437d5d3220bf7f4595aa5 100644 --- a/f5_tts/model/utils.py +++ b/f5_tts/model/utils.py @@ -5,13 +5,11 @@ import random from collections import defaultdict from importlib.resources import files +import jieba import torch +from pypinyin import Style, lazy_pinyin from torch.nn.utils.rnn import pad_sequence -import jieba -from pypinyin import lazy_pinyin, Style - - # seed everything @@ -39,7 +37,9 @@ def default(v, d): # tensor helpers -def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 +def lens_to_mask( + t: int["b"], length: int | None = None +) -> bool["b n"]: # noqa: F722 F821 if not exists(length): length = t.amax() @@ -47,7 +47,9 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa return seq[None, :] < t[:, None] -def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 +def mask_from_start_end_indices( + seq_len: int["b"], start: int["b"], end: int["b"] +): # noqa: F722 F821 max_seq_len = seq_len.max().item() seq = torch.arange(max_seq_len, 
device=start.device).long() start_mask = seq[None, :] >= start[:, None] @@ -55,7 +57,9 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b" return start_mask & end_mask -def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 +def mask_from_frac_lengths( + seq_len: int["b"], frac_lengths: float["b"] +): # noqa: F722 F821 lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths @@ -66,7 +70,9 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa return mask_from_start_end_indices(seq_len, start, end) -def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 +def maybe_masked_mean( + t: float["b n d"], mask: bool["b n"] = None +) -> float["b d"]: # noqa: F722 if not exists(mask): return t.mean(dim=1) @@ -90,7 +96,9 @@ def list_str_to_idx( vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ) -> int["b nt"]: # noqa: F722 - list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return text @@ -109,13 +117,17 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - if use "byte", set to 256 (unicode byte range) """ if tokenizer in ["pinyin", "char"]: - tokenizer_path = os.path.join(files("f5_tts").joinpath("../data"), f"{dataset_name}_{tokenizer}/vocab.txt") + tokenizer_path = os.path.join( + files("f5_tts").joinpath("../data"), f"{dataset_name}_{tokenizer}/vocab.txt" + ) with open(tokenizer_path, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) - assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" + assert ( + vocab_char_map[" "] == 0 + ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" elif tokenizer == "byte": vocab_char_map = None @@ -131,7 +143,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): return vocab_char_map, vocab_size - # convert char to pinyin jieba.initialize() @@ -145,9 +156,7 @@ def convert_char_to_pinyin(text_list, polyphone=True): ) # add custom trans here, to address oov def is_chinese(c): - return ( - "\u3100" <= c <= "\u9fff" # common chinese characters - ) + return "\u3100" <= c <= "\u9fff" # common chinese characters for text in text_list: char_list = [] @@ -158,7 +167,9 @@ def convert_char_to_pinyin(text_list, polyphone=True): if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): @@ -170,7 +181,9 @@ def convert_char_to_pinyin(text_list, polyphone=True): char_list.extend(c) elif is_chinese(c): char_list.append(" ") - char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) else: char_list.append(c) final_text_list.append(char_list) @@ -224,7 +237,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True): def 
sample_consecutive_steps(float_list): idx = torch.randint(0, len(float_list), size=(1,)) next_idx = idx - 1 - + if next_idx < 0: next_idx = 0 else: diff --git a/f5_tts/model_new/__init__.py b/f5_tts/model_new/__init__.py index d7d5fbed96dda0db5855e6650e5e878eafeff105..7c6b8a55ee012dd4972eb48e3248c425c3dcb365 100644 --- a/f5_tts/model_new/__init__.py +++ b/f5_tts/model_new/__init__.py @@ -4,5 +4,4 @@ from f5_tts.model_new.backbones.unett import UNetT from f5_tts.model_new.cfm import CFM from f5_tts.model_new.trainer import Trainer - __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] diff --git a/f5_tts/model_new/backbones/dit.py b/f5_tts/model_new/backbones/dit.py index d20434cee873db27ba557b041149c2d8f54903db..76dc1e63a8563cd3c512580d31260f1cd6134c2a 100644 --- a/f5_tts/model_new/backbones/dit.py +++ b/f5_tts/model_new/backbones/dit.py @@ -14,40 +14,49 @@ import torch.nn.functional as F from torch import nn from x_transformers.x_transformers import RotaryEmbedding -from f5_tts.model_new.modules import ( - AdaLayerNorm_Final, - ConvNeXtV2Block, - ConvPositionEmbedding, - DiTBlock, - TimestepEmbedding, - get_pos_embed_indices, - precompute_freqs_cis, -) - +from f5_tts.model_new.modules import (AdaLayerNorm_Final, ConvNeXtV2Block, + ConvPositionEmbedding, DiTBlock, + TimestepEmbedding, get_pos_embed_indices, + precompute_freqs_cis) # Text embedding class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): + def __init__( + self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2 + ): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token self.mask_padding = mask_padding # mask filler and batch padding tokens or not if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio - self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, self.precompute_max_pos), + persistent=False, + ) self.text_blocks = nn.Sequential( - *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] ) else: self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() - text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + text = ( + text + 1 + ) # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[ + :, :seq_len + ] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) if self.mask_padding: @@ -62,16 +71,22 @@ class TextEmbedding(nn.Module): if self.extra_modeling: # sinus pos emb batch_start = torch.zeros((batch,), dtype=torch.long) - pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + pos_idx = get_pos_embed_indices( + batch_start, seq_len, max_pos=self.precompute_max_pos + ) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed # convnextv2 blocks if self.mask_padding: - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + text = text.masked_fill( + text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0 + ) for block in self.text_blocks: text = block(text) - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + text = text.masked_fill( + text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0 + ) else: text = self.text_blocks(text) @@ -87,7 +102,13 @@ class InputEmbedding(nn.Module): self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward( + self, + x: float["b n d"], + cond: float["b n d"], + text_embed: float["b n d"], + drop_audio_cond=False, + ): # noqa: F722 if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) @@ -127,7 +148,10 @@ class DiT(nn.Module): if text_dim is None: text_dim = mel_dim self.text_embed = TextEmbedding( - text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + text_num_embeds, + text_dim, + mask_padding=text_mask_padding, + conv_layers=conv_layers, ) self.text_cond, self.text_uncond = None, None # text cache self.input_embed = InputEmbedding(mel_dim, text_dim, dim) @@ -153,7 +177,9 @@ class DiT(nn.Module): for _ in range(depth) ] ) - self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + self.long_skip_connection = ( + nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + ) self.norm_out = AdaLayerNorm_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) @@ -230,13 +256,24 @@ class DiT(nn.Module): # t: conditioning time, text: text, x: noised audio + cond audio + text t = self.time_embed(time) if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d - x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) - x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x_cond = self.get_input_embed( + x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache + ) + x_uncond = self.get_input_embed( + x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache + ) x = torch.cat((x_cond, x_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None else: - x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + x = self.get_input_embed( + x, + cond, + text, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + cache=cache, + ) rope = self.rotary_embed.forward_from_seq_len(seq_len) @@ -246,7 +283,9 @@ class DiT(nn.Module): for block in 
self.transformer_blocks: if self.checkpoint_activations: # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint - x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) + x = torch.utils.checkpoint.checkpoint( + self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False + ) else: x = block(x, t, mask=mask, rope=rope) diff --git a/f5_tts/model_new/backbones/mmdit.py b/f5_tts/model_new/backbones/mmdit.py index 0a785d2298debd552b23d49dd80b10985baca899..8ebee6d68284f9e564e19e4e405dff0680a7ea99 100644 --- a/f5_tts/model_new/backbones/mmdit.py +++ b/f5_tts/model_new/backbones/mmdit.py @@ -13,15 +13,10 @@ import torch from torch import nn from x_transformers.x_transformers import RotaryEmbedding -from f5_tts.model_new.modules import ( - AdaLayerNorm_Final, - ConvPositionEmbedding, - MMDiTBlock, - TimestepEmbedding, - get_pos_embed_indices, - precompute_freqs_cis, -) - +from f5_tts.model_new.modules import (AdaLayerNorm_Final, + ConvPositionEmbedding, MMDiTBlock, + TimestepEmbedding, get_pos_embed_indices, + precompute_freqs_cis) # text embedding @@ -29,15 +24,25 @@ from f5_tts.model_new.modules import ( class TextEmbedding(nn.Module): def __init__(self, out_dim, text_num_embeds, mask_padding=True): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token + self.text_embed = nn.Embedding( + text_num_embeds + 1, out_dim + ) # will use 0 as filler token self.mask_padding = mask_padding # mask filler and batch padding tokens or not self.precompute_max_pos = 1024 - self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(out_dim, self.precompute_max_pos), + persistent=False, + ) - def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 - text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() + def forward( + self, text: int["b nt"], drop_text=False + ) -> int["b nt d"]: # noqa: F722 + text = ( + text + 1 + ) # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() if self.mask_padding: text_mask = text == 0 @@ -49,13 +54,17 @@ class TextEmbedding(nn.Module): # sinus pos emb batch_start = torch.zeros((text.shape[0],), dtype=torch.long) batch_text_len = text.shape[1] - pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) + pos_idx = get_pos_embed_indices( + batch_start, batch_text_len, max_pos=self.precompute_max_pos + ) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed if self.mask_padding: - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + text = text.masked_fill( + text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0 + ) return text @@ -69,7 +78,9 @@ class AudioEmbedding(nn.Module): self.linear = nn.Linear(2 * in_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward( + self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False + ): # noqa: F722 if drop_audio_cond: cond = torch.zeros_like(cond) x = torch.cat((x, cond), dim=-1) @@ -99,7 +110,9 @@ class MMDiT(nn.Module): super().__init__() self.time_embed = TimestepEmbedding(dim) - self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding) + self.text_embed = TextEmbedding( + dim, text_num_embeds, mask_padding=text_mask_padding + ) self.text_cond, self.text_uncond = None, None # text cache self.audio_embed = AudioEmbedding(mel_dim, dim) @@ -187,15 +200,24 @@ class MMDiT(nn.Module): # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d - x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) - x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x_cond, c_cond = self.get_input_embed( + x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache + ) + x_uncond, c_uncond = self.get_input_embed( + x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache + ) x = torch.cat((x_cond, x_uncond), dim=0) c = torch.cat((c_cond, c_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None else: x, c = self.get_input_embed( - x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache + x, + cond, + text, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + cache=cache, ) seq_len = x.shape[1] diff --git a/f5_tts/model_new/backbones/unett.py b/f5_tts/model_new/backbones/unett.py index f9fda55d4ed6995fda7ec13a793ef4cb26a96904..60e212aebda5c2a5906034ab3826be48d00b1bba 100644 --- a/f5_tts/model_new/backbones/unett.py +++ b/f5_tts/model_new/backbones/unett.py @@ -17,41 +17,50 @@ from torch import nn from x_transformers import RMSNorm from x_transformers.x_transformers import RotaryEmbedding -from f5_tts.model_new.modules import ( - Attention, - AttnProcessor, - ConvNeXtV2Block, - ConvPositionEmbedding, - FeedForward, - TimestepEmbedding, - get_pos_embed_indices, - precompute_freqs_cis, -) - +from f5_tts.model_new.modules import (Attention, AttnProcessor, + ConvNeXtV2Block, ConvPositionEmbedding, + FeedForward, TimestepEmbedding, + get_pos_embed_indices, + precompute_freqs_cis) # Text embedding class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, mask_padding=True, 
conv_layers=0, conv_mult=2): + def __init__( + self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2 + ): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token self.mask_padding = mask_padding # mask filler and batch padding tokens or not if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio - self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, self.precompute_max_pos), + persistent=False, + ) self.text_blocks = nn.Sequential( - *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] ) else: self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() - text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + text = ( + text + 1 + ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() + text = text[ + :, :seq_len + ] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) if self.mask_padding: @@ -66,16 +75,22 @@ class TextEmbedding(nn.Module): if self.extra_modeling: # sinus pos emb batch_start = torch.zeros((batch,), dtype=torch.long) - pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + pos_idx = get_pos_embed_indices( + batch_start, seq_len, max_pos=self.precompute_max_pos + ) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed # convnextv2 blocks if self.mask_padding: - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + text = text.masked_fill( + text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0 + ) for block in self.text_blocks: text = block(text) - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + text = text.masked_fill( + text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0 + ) else: text = self.text_blocks(text) @@ -91,7 +106,13 @@ class InputEmbedding(nn.Module): self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward( + self, + x: float["b n d"], + cond: float["b n d"], + text_embed: float["b n d"], + drop_audio_cond=False, + ): # noqa: F722 if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) @@ -129,7 +150,10 @@ class UNetT(nn.Module): if text_dim is None: text_dim = mel_dim self.text_embed = TextEmbedding( - text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + text_num_embeds, + text_dim, + mask_padding=text_mask_padding, + conv_layers=conv_layers, ) self.text_cond, self.text_uncond = None, None # text cache self.input_embed = InputEmbedding(mel_dim, text_dim, dim) @@ -161,7 +185,11 @@ class UNetT(nn.Module): ff_norm = RMSNorm(dim) ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") - skip_proj = nn.Linear(dim * 2, dim, 
bias=False) if needs_skip_proj and is_later_half else None + skip_proj = ( + nn.Linear(dim * 2, dim, bias=False) + if needs_skip_proj and is_later_half + else None + ) self.layers.append( nn.ModuleList( @@ -226,13 +254,24 @@ class UNetT(nn.Module): # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d - x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) - x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x_cond = self.get_input_embed( + x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache + ) + x_uncond = self.get_input_embed( + x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache + ) x = torch.cat((x_cond, x_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None else: - x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + x = self.get_input_embed( + x, + cond, + text, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + cache=cache, + ) # postfix time t to input x, [b n d] -> [b n+1 d] x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x @@ -244,7 +283,9 @@ class UNetT(nn.Module): # flat unet transformer skip_connect_type = self.skip_connect_type skips = [] - for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): + for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate( + self.layers + ): layer = idx + 1 # skip connection logic diff --git a/f5_tts/model_new/cfm.py b/f5_tts/model_new/cfm.py index a803faf041da479a29b4c144062ea7698bd55b1f..d4a0a4bf445f77d028290710c3a53f306815f676 100644 --- a/f5_tts/model_new/cfm.py +++ b/f5_tts/model_new/cfm.py @@ -19,15 +19,9 @@ from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint from f5_tts.model_new.modules import MelSpec -from f5_tts.model_new.utils import ( - default, - exists, - get_epss_timesteps, - lens_to_mask, - list_str_to_idx, - list_str_to_tensor, - mask_from_frac_lengths, -) +from f5_tts.model_new.utils import (default, exists, get_epss_timesteps, + lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) class CFM(nn.Module): @@ -139,13 +133,17 @@ class CFM(nn.Module): # duplicate test corner for inner time step oberservation if duplicate_test: - test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) + test_cond = F.pad( + cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 + ) cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) if no_ref_audio: cond = torch.zeros_like(cond) - cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = F.pad( + cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + ) cond_mask = cond_mask.unsqueeze(-1) step_cond = torch.where( cond_mask, cond, torch.zeros_like(cond) @@ -196,7 +194,11 @@ class CFM(nn.Module): for dur in duration: if exists(seed): torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0.append( + torch.randn( + dur, self.num_channels, device=self.device, dtype=step_cond.dtype + ) + ) y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 @@ -207,10 +209,14 @@ class CFM(nn.Module): y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - 
t_start)) - if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE + if ( + t_start == 0 and use_epss + ): # use Empirically Pruned Step Sampling for low NFE t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) else: - t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + t = torch.linspace( + t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype + ) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) @@ -241,7 +247,12 @@ class CFM(nn.Module): inp = inp.permute(0, 2, 1) assert inp.shape[-1] == self.num_channels - batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma + batch, seq_len, dtype, device, _σ1 = ( + *inp.shape[:2], + inp.dtype, + self.device, + self.sigma, + ) # handle text as string if isinstance(text, list): @@ -255,10 +266,16 @@ class CFM(nn.Module): if not exists(lens): lens = torch.full((batch,), seq_len, device=device) - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch # get a random span to mask out for training conditionally - frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) + frac_lengths = ( + torch.zeros((batch,), device=self.device) + .float() + .uniform_(*self.frac_lengths_mask) + ) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) if exists(mask): @@ -292,7 +309,13 @@ class CFM(nn.Module): # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold pred = self.transformer( - x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask + x=φ, + cond=cond, + text=text, + time=time, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + mask=mask, ) # flow matching loss diff --git a/f5_tts/model_new/dataset.py b/f5_tts/model_new/dataset.py index 50448c348b082a8bdfd1735ed0dd7d3f9d8f4bda..7aa81a918a7252d7c3bdbecde510706b322a55e6 100644 --- a/f5_tts/model_new/dataset.py +++ b/f5_tts/model_new/dataset.py @@ -62,7 +62,9 @@ class HFDataset(Dataset): audio_tensor = torch.from_numpy(audio).float() if sample_rate != self.target_sample_rate: - resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) + resampler = torchaudio.transforms.Resample( + sample_rate, self.target_sample_rate + ) audio_tensor = resampler(audio_tensor) audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') @@ -149,7 +151,9 @@ class CustomDataset(Dataset): # resample if necessary if source_sample_rate != self.target_sample_rate: - resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) + resampler = torchaudio.transforms.Resample( + source_sample_rate, self.target_sample_rate + ) audio = resampler(audio) # to mel spectrogram @@ -173,7 +177,12 @@ class DynamicBatchSampler(Sampler[list[int]]): """ def __init__( - self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False + self, + sampler: Sampler[int], + frames_threshold: int, + max_samples=0, + random_seed=None, + drop_residual: bool = False, ): self.sampler = sampler self.frames_threshold = frames_threshold @@ -185,7 +194,8 @@ class DynamicBatchSampler(Sampler[list[int]]): data_source = self.sampler.data_source for idx in tqdm( - self.sampler, desc="Sorting with sampler... 
if slow, check whether dataset is provided with duration" + self.sampler, + desc="Sorting with sampler... if slow, check whether dataset is provided with duration", ): indices.append((idx, data_source.get_frame_len(idx))) indices.sort(key=lambda elem: elem[1]) @@ -193,9 +203,12 @@ class DynamicBatchSampler(Sampler[list[int]]): batch = [] batch_frames = 0 for idx, frame_len in tqdm( - indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu" + indices, + desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu", ): - if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): + if batch_frames + frame_len <= self.frames_threshold and ( + max_samples == 0 or len(batch) < max_samples + ): batch.append(idx) batch_frames += frame_len else: @@ -256,7 +269,9 @@ def load_dataset( print("Loading dataset ...") if dataset_type == "CustomDataset": - rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")) + rel_data_path = str( + files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}") + ) if audio_type == "raw": try: train_dataset = load_from_disk(f"{rel_data_path}/raw") @@ -287,7 +302,10 @@ def load_dataset( data_dict = json.load(f) durations = data_dict["duration"] train_dataset = CustomDataset( - train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs + train_dataset, + durations=durations, + preprocessed_mel=preprocessed_mel, + **mel_spec_kwargs, ) elif dataset_type == "HFDataset": @@ -297,7 +315,11 @@ def load_dataset( ) pre, post = dataset_name.split("_") train_dataset = HFDataset( - load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))), + load_dataset( + f"{pre}/{pre}", + split=f"train.{post}", + cache_dir=str(files("f5_tts").joinpath("../../data")), + ), ) return train_dataset diff --git a/f5_tts/model_new/modules.py b/f5_tts/model_new/modules.py index 655a3b69fd8a6882db422e66cdb5656d2c5367d1..33d961734751fe37a5604e6d6e9cc65862246eb8 100644 --- a/f5_tts/model_new/modules.py +++ b/f5_tts/model_new/modules.py @@ -6,6 +6,7 @@ nt - text sequence nw - raw wave length d - dimension """ + # flake8: noqa from __future__ import annotations @@ -22,7 +23,6 @@ from x_transformers.x_transformers import apply_rotary_pos_emb from f5_tts.model_new.utils import is_package_available - # raw wav to mel spec @@ -45,15 +45,25 @@ def get_bigvgan_mel_spectrogram( key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" if key not in mel_basis_cache: - mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax) - mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? + mel = librosa_mel_fn( + sr=target_sample_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=fmin, + fmax=fmax, + ) + mel_basis_cache[key] = ( + torch.from_numpy(mel).float().to(device) + ) # TODO: why they need .float()? 
hann_window_cache[key] = torch.hann_window(win_length).to(device) mel_basis = mel_basis_cache[key] hann_window = hann_window_cache[key] padding = (n_fft - hop_length) // 2 - waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) spec = torch.stft( waveform, @@ -115,7 +125,9 @@ class MelSpec(nn.Module): mel_spec_type="vocos", ): super().__init__() - assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") + assert mel_spec_type in ["vocos", "bigvgan"], print( + "We only support two extract mel backend: vocos or bigvgan" + ) self.n_fft = n_fft self.hop_length = hop_length @@ -196,7 +208,9 @@ class ConvPositionEmbedding(nn.Module): # rotary positional embedding related -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ @@ -212,10 +226,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca def get_pos_embed_indices(start, length, max_pos, scale=1.0): # length = length if isinstance(length, int) else length.max() - scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + scale = scale * torch.ones_like( + start, dtype=torch.float32 + ) # in case scale is a scalar pos = ( start.unsqueeze(1) - + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() ) # avoid extra long error. 
pos = torch.where(pos < max_pos, pos, max_pos - 1) @@ -254,7 +273,9 @@ class ConvNeXtV2Block(nn.Module): dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.grn = GRN(intermediate_dim) self.pwconv2 = nn.Linear(intermediate_dim, dim) @@ -286,7 +307,9 @@ class RMSNorm(nn.Module): if self.native_rms_norm: if self.weight.dtype in [torch.float16, torch.bfloat16]: x = x.to(self.weight.dtype) - x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) + x = F.rms_norm( + x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps + ) else: variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) @@ -312,7 +335,9 @@ class AdaLayerNorm(nn.Module): def forward(self, x, emb=None): emb = self.linear(self.silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( + emb, 6, dim=1 + ) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp @@ -343,14 +368,18 @@ class AdaLayerNorm_Final(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + def __init__( + self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" + ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim activation = nn.GELU(approximate=approximate) project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) - self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + self.ff = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) def forward(self, x): return self.ff(x) @@ -375,7 +404,9 @@ class Attention(nn.Module): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) self.processor = processor @@ -435,19 +466,23 @@ class Attention(nn.Module): # Attention processor if is_package_available("flash_attn"): + from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn import flash_attn_varlen_func, flash_attn_func class AttnProcessor: def __init__( self, - pe_attn_head: int | None = None, # number of attention head to apply rope, None for all + pe_attn_head: ( + int | None + ) = None, # number of attention head to apply rope, None for all attn_backend: str = "torch", # "torch" or "flash_attn" attn_mask_enabled: bool = True, ): if attn_backend == "flash_attn": - assert is_package_available("flash_attn"), "Please install flash-attn first." + assert is_package_available( + "flash_attn" + ), "Please install flash-attn first." 
self.pe_attn_head = pe_attn_head self.attn_backend = attn_backend @@ -483,12 +518,18 @@ class AttnProcessor: # apply rotary position embedding if rope is not None: freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) if self.pe_attn_head is not None: pn = self.pe_attn_head - query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale) - key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale) + query[:, :pn, :, :] = apply_rotary_pos_emb( + query[:, :pn, :, :], freqs, q_xpos_scale + ) + key[:, :pn, :, :] = apply_rotary_pos_emb( + key[:, :pn, :, :], freqs, k_xpos_scale + ) else: query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) @@ -498,10 +539,14 @@ class AttnProcessor: if self.attn_mask_enabled and mask is not None: attn_mask = mask attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' - attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) else: attn_mask = None - x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) elif self.attn_backend == "flash_attn": @@ -509,7 +554,9 @@ class AttnProcessor: key = key.transpose(1, 2) value = value.transpose(1, 2) if self.attn_mask_enabled and mask is not None: - query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask) + query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input( + query, mask + ) key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask) value, _, _, _, _ = unpad_input(value, mask) x = flash_attn_varlen_func( @@ -595,12 +642,16 @@ class JointAttnProcessor: # apply rope for context and noised input independently if rope is not None: freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) if c_rope is not None: freqs, xpos_scale = c_rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + ) c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) @@ -613,11 +664,15 @@ class JointAttnProcessor: if mask is not None: attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' - attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) else: attn_mask = None - x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = F.scaled_dot_product_attention( + query, key, value, 
attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) x = x.to(query.dtype) @@ -675,7 +730,9 @@ class DiTBlock(nn.Module): ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding # pre-norm & modulation for attention input @@ -708,14 +765,26 @@ class MMDiTBlock(nn.Module): """ def __init__( - self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None + self, + dim, + heads, + dim_head, + ff_mult=4, + dropout=0.1, + context_dim=None, + context_pre_only=False, + qk_norm=None, ): super().__init__() if context_dim is None: context_dim = dim self.context_pre_only = context_pre_only - self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim) + self.attn_norm_c = ( + AdaLayerNorm_Final(context_dim) + if context_pre_only + else AdaLayerNorm(context_dim) + ) self.attn_norm_x = AdaLayerNorm(dim) self.attn = Attention( processor=JointAttnProcessor(), @@ -729,24 +798,38 @@ class MMDiTBlock(nn.Module): ) if not context_pre_only: - self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6) - self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff_norm_c = nn.LayerNorm( + context_dim, elementwise_affine=False, eps=1e-6 + ) + self.ff_c = FeedForward( + dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) else: self.ff_norm_c = None self.ff_c = None self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff_x = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) - def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding + def forward( + self, x, c, t, mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding # pre-norm & modulation for attention input if self.context_pre_only: norm_c = self.attn_norm_c(c, t) else: - norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) - norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c( + c, emb=t + ) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x( + x, emb=t + ) # attention - x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope) + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope + ) # process attention output for context c if self.context_pre_only: @@ -754,7 +837,9 @@ class MMDiTBlock(nn.Module): else: # if not last layer c = c + c_gate_msa.unsqueeze(1) * c_attn_output - norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_c = ( + self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) c_ff_output = self.ff_c(norm_c) c = c + c_gate_mlp.unsqueeze(1) * c_ff_output @@ -775,7 +860,9 @@ class TimestepEmbedding(nn.Module): def __init__(self, dim, freq_embed_dim=256): super().__init__() self.time_embed = 
SinusPositionEmbedding(freq_embed_dim) - self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) def forward(self, timestep: float["b"]): time_hidden = self.time_embed(timestep) diff --git a/f5_tts/model_new/trainer.py b/f5_tts/model_new/trainer.py index 45cbc530e2985fe1f8e4adaa5454363ef7722e73..9417b4bb2c0b62f478d2bf6f4e1bbc980aa51372 100644 --- a/f5_tts/model_new/trainer.py +++ b/f5_tts/model_new/trainer.py @@ -19,7 +19,6 @@ from f5_tts.model import CFM from f5_tts.model.dataset import DynamicBatchSampler, collate_fn from f5_tts.model.utils import default, exists - # trainer @@ -70,7 +69,13 @@ class Trainer: self.logger = logger if self.logger == "wandb": if exists(wandb_resume_id): - init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} + init_kwargs = { + "wandb": { + "resume": "allow", + "name": wandb_run_name, + "id": wandb_resume_id, + } + } else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} @@ -138,7 +143,9 @@ class Trainer: self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) - self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + self.model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) @property def is_main(self): @@ -157,12 +164,16 @@ class Trainer: if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_last.pt" + ) print(f"Saved last checkpoint at update {update}") else: if self.keep_last_n_checkpoints == 0: return - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_{update}.pt" + ) if self.keep_last_n_checkpoints > 0: # Updated logic to exclude pretrained model from rotation checkpoints = [ @@ -183,7 +194,10 @@ class Trainer: if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path)) + or not any( + filename.endswith((".pt", ".safetensors")) + for filename in os.listdir(self.checkpoint_path) + ) ): return 0 @@ -195,11 +209,16 @@ class Trainer: all_checkpoints = [ f for f in os.listdir(self.checkpoint_path) - if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors")) + if (f.startswith("model_") or f.startswith("pretrained_")) + and f.endswith((".pt", ".safetensors")) ] # First try to find regular training checkpoints - training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"] + training_checkpoints = [ + f + for f in all_checkpoints + if f.startswith("model_") and f != "model_last.pt" + ] if training_checkpoints: latest_checkpoint = sorted( training_checkpoints, @@ -207,21 +226,30 @@ class Trainer: )[-1] else: # If no training checkpoints, use pretrained model - latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_")) + latest_checkpoint = next( + f for f in all_checkpoints if f.startswith("pretrained_") + ) if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint from safetensors.torch import 
load_file - checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu") + checkpoint = load_file( + f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu" + ) checkpoint = {"ema_model_state_dict": checkpoint} elif latest_checkpoint.endswith(".pt"): # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ checkpoint = torch.load( - f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu" + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", ) # patch for backward compatibility, 305e3ea - for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "ema_model.mel_spec.mel_stft.mel_scale.fb", + "ema_model.mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["ema_model_state_dict"]: del checkpoint["ema_model_state_dict"][key] @@ -231,17 +259,24 @@ class Trainer: if "update" in checkpoint or "step" in checkpoint: # patch for backward compatibility, with before f992c4e if "step" in checkpoint: - checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps + checkpoint["update"] = ( + checkpoint["step"] // self.grad_accumulation_steps + ) if self.grad_accumulation_steps > 1 and self.is_main: print( "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour." ) # patch for backward compatibility, 305e3ea - for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) if self.scheduler: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) @@ -252,21 +287,30 @@ class Trainer: for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "update", "step"] } - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) update = 0 del checkpoint gc.collect() return update - def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): + def train( + self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None + ): if self.log_samples: - from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef + from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder, + nfe_step, sway_sampling_coef) vocoder = load_vocoder( - vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path + vocoder_name=self.vocoder_name, + is_local=self.is_local_vocoder, + local_path=self.local_vocoder_path, ) - target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate + target_sample_rate = self.accelerator.unwrap_model( + self.model + ).mel_spec.target_sample_rate log_samples_path = f"{self.checkpoint_path}/samples" os.makedirs(log_samples_path, exist_ok=True) @@ -306,7 
+350,9 @@ class Trainer: batch_sampler=batch_sampler, ) else: - raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") + raise ValueError( + f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}" + ) # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices @@ -314,12 +360,24 @@ class Trainer: self.num_warmup_updates * self.accelerator.num_processes ) # consider a fixed warmup steps while using accelerate multi-gpu ddp # otherwise by default with split_batches=False, warmup steps change with num_processes - total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs + total_updates = ( + math.ceil(len(train_dataloader) / self.grad_accumulation_steps) + * self.epochs + ) decay_updates = total_updates - warmup_updates - warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates) - decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates) + warmup_scheduler = LinearLR( + self.optimizer, + start_factor=1e-8, + end_factor=1.0, + total_iters=warmup_updates, + ) + decay_scheduler = LinearLR( + self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates + ) self.scheduler = SequentialLR( - self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates] + self.optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[warmup_updates], ) train_dataloader, self.scheduler = self.accelerator.prepare( train_dataloader, self.scheduler @@ -332,21 +390,27 @@ class Trainer: start_step = start_update * self.grad_accumulation_steps skipped_epoch = int(start_step // orig_epoch_step) skipped_batch = start_step % orig_epoch_step - skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) + skipped_dataloader = self.accelerator.skip_first_batches( + train_dataloader, num_batches=skipped_batch + ) else: skipped_epoch = 0 for epoch in range(skipped_epoch, self.epochs): self.model.train() if exists(resumable_with_seed) and epoch == skipped_epoch: - progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps) + progress_bar_initial = math.ceil( + skipped_batch / self.grad_accumulation_steps + ) current_dataloader = skipped_dataloader else: progress_bar_initial = 0 current_dataloader = train_dataloader # Set epoch for the batch sampler if it exists - if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"): + if hasattr(train_dataloader, "batch_sampler") and hasattr( + train_dataloader.batch_sampler, "set_epoch" + ): train_dataloader.batch_sampler.set_epoch(epoch) progress_bar = tqdm( @@ -364,17 +428,29 @@ class Trainer: mel_lengths = batch["mel_lengths"] # TODO. 
add duration predictor training - if self.duration_predictor is not None and self.accelerator.is_local_main_process: - dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations")) - self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update) + if ( + self.duration_predictor is not None + and self.accelerator.is_local_main_process + ): + dur_loss = self.duration_predictor( + mel_spec, lens=batch.get("durations") + ) + self.accelerator.log( + {"duration loss": dur_loss.item()}, step=global_update + ) loss, cond, pred = self.model( - mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler + mel_spec, + text=text_inputs, + lens=mel_lengths, + noise_scheduler=self.noise_scheduler, ) self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) self.optimizer.step() self.scheduler.step() @@ -386,29 +462,44 @@ class Trainer: global_update += 1 progress_bar.update(1) - progress_bar.set_postfix(update=str(global_update), loss=loss.item()) + progress_bar.set_postfix( + update=str(global_update), loss=loss.item() + ) if self.accelerator.is_local_main_process: self.accelerator.log( - {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update + {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, + step=global_update, ) if self.logger == "tensorboard": self.writer.add_scalar("loss", loss.item(), global_update) - self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update) + self.writer.add_scalar( + "lr", self.scheduler.get_last_lr()[0], global_update + ) - if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients: + if ( + global_update % self.last_per_updates == 0 + and self.accelerator.sync_gradients + ): self.save_checkpoint(global_update, last=True) - if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients: + if ( + global_update % self.save_per_updates == 0 + and self.accelerator.sync_gradients + ): self.save_checkpoint(global_update) if self.log_samples and self.accelerator.is_local_main_process: ref_audio_len = mel_lengths[0] infer_text = [ - text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0] + text_inputs[0] + + ([" "] if isinstance(text_inputs[0], list) else " ") + + text_inputs[0] ] with torch.inference_mode(): - generated, _ = self.accelerator.unwrap_model(self.model).sample( + generated, _ = self.accelerator.unwrap_model( + self.model + ).sample( cond=mel_spec[0][:ref_audio_len].unsqueeze(0), text=infer_text, duration=ref_audio_len * 2, @@ -417,7 +508,11 @@ class Trainer: sway_sampling_coef=sway_sampling_coef, ) generated = generated.to(torch.float32) - gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device) + gen_mel_spec = ( + generated[:, ref_audio_len:, :] + .permute(0, 2, 1) + .to(self.accelerator.device) + ) ref_mel_spec = batch["mel"][0].unsqueeze(0) if self.vocoder_name == "vocos": gen_audio = vocoder.decode(gen_mel_spec).cpu() @@ -427,10 +522,14 @@ class Trainer: ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu() torchaudio.save( - f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate + f"{log_samples_path}/update_{global_update}_gen.wav", + gen_audio, + target_sample_rate, ) torchaudio.save( - 
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate + f"{log_samples_path}/update_{global_update}_ref.wav", + ref_audio, + target_sample_rate, ) self.model.train() diff --git a/f5_tts/model_new/utils.py b/f5_tts/model_new/utils.py index c5c38292ec56180190ddd8ea1f3cb0d4ca9a3247..10ff360f3915c6efc6f855d334568ad3a691c99b 100644 --- a/f5_tts/model_new/utils.py +++ b/f5_tts/model_new/utils.py @@ -10,7 +10,6 @@ import torch from pypinyin import Style, lazy_pinyin from torch.nn.utils.rnn import pad_sequence - # seed everything @@ -48,7 +47,9 @@ def is_package_available(package_name: str) -> bool: # tensor helpers -def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 +def lens_to_mask( + t: int["b"], length: int | None = None +) -> bool["b n"]: # noqa: F722 F821 if not exists(length): length = t.amax() @@ -56,7 +57,9 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa return seq[None, :] < t[:, None] -def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 +def mask_from_start_end_indices( + seq_len: int["b"], start: int["b"], end: int["b"] +): # noqa: F722 F821 max_seq_len = seq_len.max().item() seq = torch.arange(max_seq_len, device=start.device).long() start_mask = seq[None, :] >= start[:, None] @@ -64,7 +67,9 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b" return start_mask & end_mask -def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 +def mask_from_frac_lengths( + seq_len: int["b"], frac_lengths: float["b"] +): # noqa: F722 F821 lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths @@ -75,7 +80,9 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa return mask_from_start_end_indices(seq_len, start, end) -def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 +def maybe_masked_mean( + t: float["b n d"], mask: bool["b n"] = None +) -> float["b d"]: # noqa: F722 if not exists(mask): return t.mean(dim=1) @@ -99,7 +106,9 @@ def list_str_to_idx( vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ) -> int["b nt"]: # noqa: F722 - list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return text @@ -118,13 +127,18 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - if use "byte", set to 256 (unicode byte range) """ if tokenizer in ["pinyin", "char"]: - tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") + tokenizer_path = os.path.join( + files("f5_tts").joinpath("../../data"), + f"{dataset_name}_{tokenizer}/vocab.txt", + ) with open(tokenizer_path, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) - assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" + assert ( + vocab_char_map[" "] == 0 + ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" elif tokenizer == "byte": vocab_char_map = None @@ -154,9 +168,7 @@ def convert_char_to_pinyin(text_list, polyphone=True): ) # add custom trans 
here, to address oov def is_chinese(c): - return ( - "\u3100" <= c <= "\u9fff" # common chinese characters - ) + return "\u3100" <= c <= "\u9fff" # common chinese characters for text in text_list: char_list = [] @@ -167,7 +179,9 @@ def convert_char_to_pinyin(text_list, polyphone=True): if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): @@ -179,7 +193,9 @@ def convert_char_to_pinyin(text_list, polyphone=True): char_list.extend(c) elif is_chinese(c): char_list.append(" ") - char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) else: char_list.append(c) final_text_list.append(char_list) diff --git a/f5_tts/runtime/triton_trtllm/benchmark.py b/f5_tts/runtime/triton_trtllm/benchmark.py index cb054ec6b64c47409ff29d49f065e4787de4f8b5..9044718e4e367571c505b9c66a19bba00a825e4f 100644 --- a/f5_tts/runtime/triton_trtllm/benchmark.py +++ b/f5_tts/runtime/triton_trtllm/benchmark.py @@ -51,7 +51,6 @@ from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from vocos import Vocos - torch.manual_seed(0) @@ -64,7 +63,9 @@ def get_args(): choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], help="huggingface dataset split name", ) - parser.add_argument("--output-dir", required=True, type=str, help="dir to save result") + parser.add_argument( + "--output-dir", required=True, type=str, help="dir to save result" + ) parser.add_argument( "--vocab-file", required=True, @@ -89,8 +90,12 @@ def get_args(): type=int, help="batch size (per-device) for inference", ) - parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader") - parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader") + parser.add_argument( + "--num-workers", type=int, default=0, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=None, help="prefetch for dataloader" + ) parser.add_argument( "--vocoder", default="vocos", @@ -105,8 +110,16 @@ def get_args(): ) parser.add_argument("--enable-warmup", action="store_true") parser.add_argument("--remove-input-padding", action="store_true") - parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance") - parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type") + parser.add_argument( + "--use-perf", action="store_true", help="use nvtx to record performance" + ) + parser.add_argument( + "--backend-type", + type=str, + default="triton", + choices=["trt", "pytorch"], + help="backend type", + ) args = parser.parse_args() return args @@ -126,7 +139,13 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False): torch.cuda.nvtx.range_push("data_collator") target_sample_rate = 24000 target_rms = 0.1 - ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = ( + ( + ids, + ref_mel_list, + ref_mel_len_list, + estimated_reference_target_mel_len, + reference_target_texts_list, + ) = ( [], [], [], @@ -170,7 +189,14 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False): 
ref_mel_len_list.append(ref_mel_len) estimated_reference_target_mel_len.append( - int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8")))) + int( + ref_mel.shape[0] + * ( + 1 + + len(target_text.encode("utf-8")) + / len(prompt_text.encode("utf-8")) + ) + ) ) max_seq_len = max(estimated_reference_target_mel_len) @@ -182,12 +208,22 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False): for i, item in enumerate(text_pad_sequence): text_pad_sequence[i] = F.pad( - item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1 + item, + (0, estimated_reference_target_mel_len[i] - len(item)), + mode="constant", + value=-1, ) - text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS - text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device) + text_pad_sequence[ + i + ] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS + text_pad_sequence = pad_sequence( + text_pad_sequence, padding_value=-1, batch_first=True + ).to(device) text_pad_sequence = F.pad( - text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1 + text_pad_sequence, + (0, max_seq_len - text_pad_sequence.shape[1]), + mode="constant", + value=-1, ) if use_perf: torch.cuda.nvtx.range_pop() @@ -252,7 +288,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): @@ -264,7 +302,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): char_list.extend(c) elif is_chinese(c): char_list.append(" ") - char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) else: char_list.append(c) final_reference_target_texts_list.append(char_list) @@ -277,13 +317,20 @@ def list_str_to_idx( vocab_char_map: Dict[str, int], # {char: idx} padding_value=-1, ): - list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return list_idx_tensors def load_vocoder( - vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None + vocoder_name="vocos", + is_local=False, + local_path="", + device="cuda", + hf_cache_dir=None, + vocoder_trt_engine_path=None, ): if vocoder_name == "vocos": if vocoder_trt_engine_path is not None: @@ -297,8 +344,14 @@ def load_vocoder( else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" - config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") - model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") + config_path = hf_hub_download( + repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml" + ) + model_path = hf_hub_download( + repo_id=repo_id, + cache_dir=hf_cache_dir, + 
filename="pytorch_model.bin", + ) vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) from vocos.feature_extractors import EncodecFeatures @@ -343,14 +396,21 @@ class VocosTensorRT: with open(engine_path, "rb") as f: engine_buffer = f.read() self.session = Session.from_serialized_engine(engine_buffer) - self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream + self.stream = ( + stream if stream is not None else torch.cuda.current_stream().cuda_stream + ) def decode(self, mels): mels = mels.contiguous() inputs = {"mel": mels} - output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)]) + output_info = self.session.infer_shapes( + [TensorInfo("mel", trt.DataType.FLOAT, mels.shape)] + ) outputs = { - t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info + t.name: torch.empty( + tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda" + ) + for t in output_info } ok = self.session.run(inputs, outputs, self.stream) @@ -376,12 +436,18 @@ def main(): config = json.load(f) if args.backend_type == "trt": model = F5TTS( - config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size + config, + debug_mode=False, + tllm_model_dir=tllm_model_dir, + model_path=args.model_path, + vocab_size=vocab_size, ) elif args.backend_type == "pytorch": import sys - sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/") + sys.path.append( + f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/" + ) from f5_tts.infer.utils_infer import load_model from f5_tts.model import DiT @@ -398,7 +464,9 @@ def main(): model = load_model(DiT, F5TTS_model_cfg, args.model_path) vocoder = load_vocoder( - vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path + vocoder_name=args.vocoder, + device=device, + vocoder_trt_engine_path=args.vocoder_trt_engine_path, ) dataset = load_dataset( @@ -411,7 +479,9 @@ def main(): prompt_audio_len = example["prompt_audio"]["array"].shape[0] scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"]) estimated_duration = prompt_audio_len * scale_factor - example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"] + example["estimated_duration"] = ( + estimated_duration / example["prompt_audio"]["sampling_rate"] + ) return example dataset = dataset.map(add_estimated_duration) @@ -442,12 +512,18 @@ def main(): if args.enable_warmup: for batch in dataloader: - ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) + ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[ + "ref_mel_len_batch" + ].to(device) text_pad_seq = batch["text_pad_sequence"].to(device) total_mel_lens = batch["estimated_reference_target_mel_len"] if args.backend_type == "trt": _ = model.sample( - text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding + text_pad_seq, + ref_mels, + ref_mel_lens, + total_mel_lens, + remove_input_padding=args.remove_input_padding, ) elif args.backend_type == "pytorch": with torch.inference_mode(): @@ -475,7 +551,9 @@ def main(): for batch in dataloader: if args.use_perf: torch.cuda.nvtx.range_push("data sample") - ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) + ref_mels, ref_mel_lens = 
batch["ref_mel_batch"].to(device), batch[ + "ref_mel_len_batch" + ].to(device) text_pad_seq = batch["text_pad_sequence"].to(device) total_mel_lens = batch["estimated_reference_target_mel_len"] diff --git a/f5_tts/runtime/triton_trtllm/client_grpc.py b/f5_tts/runtime/triton_trtllm/client_grpc.py index 1eb9b5b3fb4b5e35e41a4c5181a3dbde2df9763f..28c4d61346c0ca19713c182b695e001a49125304 100644 --- a/f5_tts/runtime/triton_trtllm/client_grpc.py +++ b/f5_tts/runtime/triton_trtllm/client_grpc.py @@ -64,8 +64,12 @@ def write_triton_stats(stats, summary_file): "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n" ) summary_f.write("To learn more about the log, please refer to: \n") - summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") - summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") + summary_f.write( + "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n" + ) + summary_f.write( + "2. https://github.com/triton-inference-server/server/issues/5374 \n\n" + ) summary_f.write( "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" ) @@ -86,7 +90,9 @@ def write_triton_stats(stats, summary_file): total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 - total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 + total_output_time_s = ( + int(model_inference_stats["compute_output"]["ns"]) / 1e9 + ) summary_f.write( f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa ) @@ -97,7 +103,11 @@ def write_triton_stats(stats, summary_file): compute_output = batch["compute_output"] compute_infer = batch["compute_infer"] batch_count = int(compute_infer["count"]) - assert compute_infer["count"] == compute_output["count"] == compute_input["count"] + assert ( + compute_infer["count"] + == compute_output["count"] + == compute_input["count"] + ) compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 compute_input_time_ms = int(compute_input["ns"]) / 1e6 compute_output_time_ms = int(compute_output["ns"]) / 1e6 @@ -113,7 +123,9 @@ def write_triton_stats(stats, summary_file): def get_args(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( "--server-addr", @@ -254,7 +266,9 @@ async def send( for i, item in enumerate(manifest_item_list): if i % log_interval == 0: print(f"{name}: {i}/{len(manifest_item_list)}") - waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000) + waveform, sample_rate = load_audio( + item["audio_filepath"], target_sample_rate=24000 + ) duration = len(waveform) / sample_rate lengths = np.array([[len(waveform)]], dtype=np.int32) @@ -269,7 +283,10 @@ async def send( 1, padding_duration * sample_rate - * ((int(estimated_target_duration + duration) // padding_duration) + 1), + * ( + (int(estimated_target_duration + duration) // padding_duration) + + 1 + ), ), dtype=np.float32, ) @@ -281,8 +298,12 @@ async def send( 
samples = samples.reshape(1, -1).astype(np.float32) inputs = [ - protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), - protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)), + protocol_client.InferInput( + "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype) + ), + protocol_client.InferInput( + "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) + ), protocol_client.InferInput("reference_text", [1, 1], "BYTES"), protocol_client.InferInput("target_text", [1, 1], "BYTES"), ] @@ -301,13 +322,17 @@ async def send( sequence_id = 100000000 + i + task_id * 10 start = time.time() - response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) + response = await triton_client.infer( + model_name, inputs, request_id=str(sequence_id), outputs=outputs + ) audio = response.as_numpy("waveform").reshape(-1) end = time.time() - start - audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") + audio_save_path = os.path.join( + audio_save_dir, f"{item['target_audio_path']}.wav" + ) sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") actual_duration = len(audio) / save_sample_rate @@ -341,7 +366,9 @@ def load_manifests(manifest_path): def split_data(data, k): n = len(data) if n < k: - print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.") + print( + f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}." + ) k = n quotient = n // k @@ -461,7 +488,9 @@ async def main(): stats = await triton_client.get_inference_statistics(model_name="", as_json=True) write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") - metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True) + metadata = await triton_client.get_model_config( + model_name=args.model_name, as_json=True + ) with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: json.dump(metadata, f, indent=4) diff --git a/f5_tts/runtime/triton_trtllm/client_http.py b/f5_tts/runtime/triton_trtllm/client_http.py index 804ba5c605e08352027971ae42a22de12375956b..96f8784c04c2841d211de59f0248c3849d251cbd 100644 --- a/f5_tts/runtime/triton_trtllm/client_http.py +++ b/f5_tts/runtime/triton_trtllm/client_http.py @@ -31,7 +31,9 @@ import soundfile as sf def get_args(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( "--server-url", @@ -91,15 +93,30 @@ def prepare_request( data = { "inputs": [ - {"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()}, + { + "name": "reference_wav", + "shape": samples.shape, + "datatype": "FP32", + "data": samples.tolist(), + }, { "name": "reference_wav_len", "shape": lengths.shape, "datatype": "INT32", "data": lengths.tolist(), }, - {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]}, - {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]}, + { + "name": "reference_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [reference_text], + }, + { + "name": "target_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [target_text], + }, ] } @@ -135,7 +152,11 @@ if __name__ == "__main__": data = prepare_request(samples, args.reference_text, args.target_text) rsp 
= requests.post( - url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"} + url, + headers={"Content-Type": "application/json"}, + json=data, + verify=False, + params={"request_id": "0"}, ) result = rsp.json() audio = result["outputs"][0]["data"] diff --git a/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py b/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py index cace21987091bbf8a4bbd34b4e80d22525be92f0..4435416dcd785cd7897fb8f750630ee4599ed06f 100644 --- a/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +++ b/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py @@ -17,7 +17,9 @@ from tensorrt_llm.runtime.session import Session def remove_tensor_padding(input_tensor, input_tensor_lengths=None): # Audio tensor case: batch, seq_len, feature_len # position_ids case: batch, seq_len - assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor" + assert ( + input_tensor_lengths is not None + ), "input_tensor_lengths must be provided for 3D input_tensor" # Initialize a list to collect valid sequences valid_sequences = [] @@ -32,11 +34,29 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None): class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096): + def __init__( + self, + text_num_embeds, + text_dim, + conv_layers=0, + conv_mult=2, + precompute_max_pos=4096, + ): super().__init__() - self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token - self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False) - self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, precompute_max_pos), + persistent=False, + ) + self.text_blocks = nn.Sequential( + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] + ) def forward(self, text): # only keep tensors with value not -1 @@ -80,7 +100,9 @@ class ConvNeXtV2Block(nn.Module): dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.grn = GRN(intermediate_dim) self.pwconv2 = nn.Linear(intermediate_dim, dim) @@ -98,7 +120,9 @@ class ConvNeXtV2Block(nn.Module): return residual + x -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ @@ -125,7 +149,9 @@ def load_checkpoint(ckpt_path, use_ema=True): for key in dict_state.keys(): # transformer.text_embed.text_embed.weight -> text_embed.weight if "text_embed" in key: - text_embed_dict[key.replace("transformer.text_embed.", 
"")] = dict_state[key] + text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[ + key + ] return text_embed_dict @@ -148,7 +174,12 @@ class F5TTS(object): pp_size = config["pretrained_config"]["mapping"]["pp_size"] assert pp_size == 1 self.mapping = tensorrt_llm.Mapping( - world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1 + world_size=world_size, + rank=rank, + cp_size=cp_size, + tp_size=tp_size, + pp_size=1, + gpus_per_node=1, ) local_rank = rank % self.mapping.gpus_per_node @@ -176,10 +207,23 @@ class F5TTS(object): self.outputs = {} self.buffer_allocated = False - expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"] - - found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)] - if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names): + expected_tensor_names = [ + "noise", + "cond", + "time", + "rope_cos", + "rope_sin", + "input_lengths", + "denoised", + ] + + found_tensor_names = [ + self.session.engine.get_tensor_name(i) + for i in range(self.session.engine.num_io_tensors) + ] + if not self.debug_mode and set(expected_tensor_names) != set( + found_tensor_names + ): logger.error( f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}" ) @@ -190,11 +234,16 @@ class F5TTS(object): logger.error(f"Found tensor names: {found_tensor_names}") raise RuntimeError("Tensor names in engine are not the same as expected.") if self.debug_mode: - self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names)) + self.debug_tensors = list( + set(found_tensor_names) - set(expected_tensor_names) + ) self.max_mel_len = 4096 self.text_embedding = TextEmbedding( - text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len + text_num_embeds=vocab_size, + text_dim=512, + conv_layers=4, + precompute_max_pos=self.max_mel_len, ).to(self.device) self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True) @@ -208,9 +257,16 @@ class F5TTS(object): self.head_dim = 64 self.base_rescale_factor = 1.0 self.interpolation_factor = 1.0 - base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)) - freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor + base = 10000.0 * self.base_rescale_factor ** ( + self.head_dim / (self.head_dim - 2) + ) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim) + ) + freqs = ( + torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) + / self.interpolation_factor + ) self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0) self.rope_cos = self.freqs.cos().half() self.rope_sin = self.freqs.sin().half() @@ -223,7 +279,9 @@ class F5TTS(object): time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32) half_dim = tmp_dim // 2 emb_factor = math.log(10000) / (half_dim - 1) - emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor) + emb_factor = 1000.0 * torch.exp( + torch.arange(half_dim, dtype=torch.float32) * -emb_factor + ) for i in range(self.nfe_steps): emb = time_step[i] * emb_factor time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1) @@ -242,7 +300,9 @@ class F5TTS(object): shape = 
list(self.session.engine.get_tensor_shape(name)) shape[0] = batch_size shape[1] = seq_len - self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device) + self.outputs[name] = torch.empty( + shape, dtype=self._tensor_dtype(name), device=self.device + ) self.buffer_allocated = True @@ -356,17 +416,29 @@ class F5TTS(object): max_seq_len = ref_mel_batch.shape[1] text_pad_sequence_drop = torch.cat( - (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0 + ( + text_pad_sequence, + torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to( + self.device + ), + ), + dim=0, ) text_embedding_drop_list = [] for i in range(batch + 1): - text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device))) + text_embedding_drop_list.append( + self.text_embedding( + text_pad_sequence_drop[i].unsqueeze(0).to(self.device) + ) + ) text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0) text_embedding = text_embedding_drop_condition[:-1] # text_embedding_drop B,T,C batch should be the same - text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1) + text_embedding_drop = ( + text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1) + ) noise = torch.randn_like(ref_mel_batch).to(self.device) rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) @@ -375,7 +447,9 @@ class F5TTS(object): cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1) cat_mel_text_drop = torch.cat( ( - torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), + torch.zeros( + (batch, max_seq_len, self.n_mel_channels), dtype=torch.float32 + ).to(self.device), text_embedding_drop, ), dim=-1, @@ -384,7 +458,9 @@ class F5TTS(object): time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous() # Convert estimated_reference_target_mel_len to tensor - input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32) + input_lengths = torch.tensor( + estimated_reference_target_mel_len, dtype=torch.int32 + ) # combine above along the batch dimension inputs = { @@ -393,20 +469,34 @@ class F5TTS(object): "time_expand": time_expand, "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(), "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(), - "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(), + "input_lengths": torch.cat( + (input_lengths, input_lengths), dim=0 + ).contiguous(), "delta_t": self.delta_t, } if use_perf and remove_input_padding: torch.cuda.nvtx.range_push("remove input padding") if remove_input_padding: max_seq_len = inputs["cond"].shape[1] - inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"]) - inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"]) + inputs["noise"] = remove_tensor_padding( + inputs["noise"], inputs["input_lengths"] + ) + inputs["cond"] = remove_tensor_padding( + inputs["cond"], inputs["input_lengths"] + ) # for time_expand, convert from B,D to B,T,D by repeat - inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1) - inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"]) - inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"]) - inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"]) + 
inputs["time_expand"] = ( + inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1) + ) + inputs["time_expand"] = remove_tensor_padding( + inputs["time_expand"], inputs["input_lengths"] + ) + inputs["rope_cos"] = remove_tensor_padding( + inputs["rope_cos"], inputs["input_lengths"] + ) + inputs["rope_sin"] = remove_tensor_padding( + inputs["rope_sin"], inputs["input_lengths"] + ) if use_perf and remove_input_padding: torch.cuda.nvtx.range_pop() for key in inputs: @@ -422,7 +512,9 @@ class F5TTS(object): denoised_list = [] start_idx = 0 for i in range(batch): - denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]]) + denoised_list.append( + denoised[start_idx : start_idx + inputs["input_lengths"][i]] + ) start_idx += inputs["input_lengths"][i] if use_perf and remove_input_padding: torch.cuda.nvtx.range_pop() diff --git a/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py b/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py index 6926d87e8bf068d621fc1e555ecd298ddefa7c5b..efc97faa6b2394e9de6e3435e703e243c3b4d34d 100644 --- a/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +++ b/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py @@ -73,7 +73,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): @@ -85,7 +87,9 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): char_list.extend(c) elif is_chinese(c): char_list.append(" ") - char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) else: char_list.append(c) final_reference_target_texts_list.append(char_list) @@ -98,7 +102,9 @@ def list_str_to_idx( vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ): # noqa: F722 - list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style return list_idx_tensors @@ -121,7 +127,9 @@ class TritonPythonModel: self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"]) self.reference_sample_rate = int(parameters["reference_audio_sample_rate"]) - self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate) + self.resampler = torchaudio.transforms.Resample( + self.reference_sample_rate, self.target_audio_sample_rate + ) self.tllm_model_dir = parameters["tllm_model_dir"] config_file = os.path.join(self.tllm_model_dir, "config.json") @@ -163,13 +171,17 @@ class TritonPythonModel: input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)) inference_request = pb_utils.InferenceRequest( - model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0] + model_name="vocoder", + requested_output_names=["waveform"], + inputs=[input_tensor_0], ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) else: - waveform = 
pb_utils.get_output_tensor_by_name(inference_response, "waveform") + waveform = pb_utils.get_output_tensor_by_name( + inference_response, "waveform" + ) waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() return waveform @@ -181,7 +193,13 @@ class TritonPythonModel: reference_target_texts_list, estimated_reference_target_mel_len, reference_mel_len, - ) = [], [], [], [], [] + ) = ( + [], + [], + [], + [], + [], + ) mel_features_list = [] if self.use_perf: torch.cuda.nvtx.range_push("preprocess") @@ -189,10 +207,14 @@ class TritonPythonModel: wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav") wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = pb_utils.get_input_tensor_by_name( + request, "reference_text" + ).as_numpy() reference_text = reference_text[0][0].decode("utf-8") reference_text_list.append(reference_text) - target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = pb_utils.get_input_tensor_by_name( + request, "target_text" + ).as_numpy() target_text = target_text[0][0].decode("utf-8") target_text_list.append(target_text) @@ -221,30 +243,49 @@ class TritonPythonModel: reference_mel_len.append(mel_features.shape[1]) estimated_reference_target_mel_len.append( int( - mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8"))) + mel_features.shape[1] + * ( + 1 + + len(target_text.encode("utf-8")) + / len(reference_text.encode("utf-8")) + ) ) ) max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len) batch = len(requests) - mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device) + mel_features = torch.zeros( + (batch, max_seq_len, self.n_mel_channels), dtype=torch.float16 + ).to(self.device) for i, mel in enumerate(mel_features_list): mel_features[i, : mel.shape[1], :] = mel reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device) - pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) + pinyin_list = convert_char_to_pinyin( + reference_target_texts_list, polyphone=True + ) text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map) for i, item in enumerate(text_pad_sequence): text_pad_sequence[i] = F.pad( - item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1 + item, + (0, estimated_reference_target_mel_len[i] - len(item)), + mode="constant", + value=-1, ) - text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS - text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device) + text_pad_sequence[ + i + ] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS + text_pad_sequence = pad_sequence( + text_pad_sequence, padding_value=-1, batch_first=True + ).to(self.device) text_pad_sequence = F.pad( - text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1 + text_pad_sequence, + (0, max_seq_len - text_pad_sequence.shape[1]), + mode="constant", + value=-1, ) if self.use_perf: torch.cuda.nvtx.range_pop() @@ -264,7 +305,11 @@ class TritonPythonModel: for i in range(batch): ref_me_len = reference_mel_len[i] estimated_mel_len = estimated_reference_target_mel_len[i] - denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2) + 
denoised_one_item = ( + denoised[i, ref_me_len:estimated_mel_len, :] + .unsqueeze(0) + .transpose(1, 2) + ) audio = self.forward_vocoder(denoised_one_item) rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < self.target_rms: diff --git a/f5_tts/runtime/triton_trtllm/patch/__init__.py b/f5_tts/runtime/triton_trtllm/patch/__init__.py index ab19e9f181f7e5f49e6f2dbe36d32250fcd8761a..dc4aedfb3e23e4c570927949d4600033b19a2e47 100644 --- a/f5_tts/runtime/triton_trtllm/patch/__init__.py +++ b/f5_tts/runtime/triton_trtllm/patch/__init__.py @@ -13,14 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from .baichuan.model import BaichuanForCausalLM -from .bert.model import ( - BertForQuestionAnswering, - BertForSequenceClassification, - BertModel, - RobertaForQuestionAnswering, - RobertaForSequenceClassification, - RobertaModel, -) +from .bert.model import (BertForQuestionAnswering, + BertForSequenceClassification, BertModel, + RobertaForQuestionAnswering, + RobertaForSequenceClassification, RobertaModel) from .bloom.model import BloomForCausalLM, BloomModel from .chatglm.config import ChatGLMConfig from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel @@ -51,17 +47,17 @@ from .mamba.model import MambaForCausalLM from .medusa.config import MedusaConfig from .medusa.model import MedusaForCausalLm from .mllama.model import MLLaMAModel -from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode +from .modeling_utils import (PretrainedConfig, PretrainedModel, + SpeculativeDecodingMode) from .mpt.model import MPTForCausalLM, MPTModel from .nemotron_nas.model import DeciLMForCausalLM from .opt.model import OPTForCausalLM, OPTModel -from .phi.model import PhiForCausalLM, PhiModel from .phi3.model import Phi3ForCausalLM, Phi3Model +from .phi.model import PhiForCausalLM, PhiModel from .qwen.model import QWenForCausalLM from .recurrentgemma.model import RecurrentGemmaForCausalLM from .redrafter.model import ReDrafterForCausalLM - __all__ = [ "BertModel", "BertForQuestionAnswering", diff --git a/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py index 2f8007ff14a669a3db3302d3853711b3f9fd6106..b9e064509e4d130711258636382805f0982b6564 100644 --- a/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +++ b/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py @@ -13,8 +13,8 @@ from ...layers import Linear from ...module import Module, ModuleList from ...plugin import current_all_reduce_helper from ..modeling_utils import PretrainedConfig, PretrainedModel -from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding - +from .modules import (AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, + TimestepEmbedding) current_file_path = os.path.abspath(__file__) parent_dir = os.path.dirname(current_file_path) @@ -38,7 +38,9 @@ class F5TTS(PretrainedModel): self.dtype = str_dtype_to_trt(config.dtype) self.time_embed = TimestepEmbedding(config.hidden_size) - self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size) + self.input_embed = InputEmbedding( + config.mel_dim, config.text_dim, config.hidden_size + ) self.dim = config.hidden_size self.depth = config.num_hidden_layers @@ -71,7 +73,14 @@ class F5TTS(PretrainedModel): t = self.time_embed(time) x = self.input_embed(noise, cond) for block in self.transformer_blocks: - x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, 
scale=scale) + x = block( + x, + t, + rope_cos=rope_cos, + rope_sin=rope_sin, + input_lengths=input_lengths, + scale=scale, + ) denoise = self.proj_out(self.norm_out(x, t)) denoise.mark_output("denoised", self.dtype) return denoise diff --git a/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py index 2121d2879b7667359eb6585ef6c3b8ebbc9fc4a9..a9d1eb47b184e0dce738646e1cf0405f51778d9a 100644 --- a/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +++ b/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py @@ -9,28 +9,10 @@ import torch.nn.functional as F from tensorrt_llm._common import default_net from ..._utils import str_dtype_to_trt, trt_dtype_to_np -from ...functional import ( - Tensor, - bert_attention, - cast, - chunk, - concat, - constant, - expand, - expand_dims, - expand_dims_like, - expand_mask, - gelu, - matmul, - permute, - shape, - silu, - slice, - softmax, - squeeze, - unsqueeze, - view, -) +from ...functional import (Tensor, bert_attention, cast, chunk, concat, + constant, expand, expand_dims, expand_dims_like, + expand_mask, gelu, matmul, permute, shape, silu, + slice, softmax, squeeze, unsqueeze, view) from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear from ...module import Module @@ -57,7 +39,9 @@ class AdaLayerNormZero(Module): def forward(self, x, emb=None): emb = self.linear(silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk( + emb, 6, dim=1 + ) x = self.norm(x) ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: @@ -91,8 +75,12 @@ class ConvPositionEmbedding(Module): def __init__(self, dim, kernel_size=31, groups=16): super().__init__() assert kernel_size % 2 != 0 - self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) - self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) + self.conv1d1 = Conv1d( + dim, dim, kernel_size, groups=groups, padding=kernel_size // 2 + ) + self.conv1d2 = Conv1d( + dim, dim, kernel_size, groups=groups, padding=kernel_size // 2 + ) self.mish = Mish() def forward(self, x, mask=None): # noqa: F722 @@ -120,7 +108,9 @@ class Attention(Module): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) self.processor = processor @@ -191,16 +181,32 @@ class Attention(Module): c_rope=None, # rotary position embedding for c ) -> torch.Tensor: if c is not None: - return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope) + return self.processor( + self, + x, + c=c, + input_lengths=input_lengths, + scale=scale, + rope=rope, + c_rope=c_rope, + ) else: return self.processor( - self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale + self, + x, + rope_cos=rope_cos, + rope_sin=rope_sin, + input_lengths=input_lengths, + scale=scale, ) def rotate_every_two_3dim(tensor: Tensor) -> Tensor: shape_tensor = concat( - [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())] + [ + shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) + for i in range(tensor.ndim()) + ] ) if default_net().plugin_config.remove_input_padding: assert tensor.ndim() == 2 @@ -208,7 +214,9 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor: x2 = slice(tensor, [0, 1], shape_tensor, [1, 2]) x1 = expand_dims(x1, 2) x2 = expand_dims(x2, 2) - zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + zero = constant( + np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))) + ) x2 = zero - x2 x = concat([x2, x1], 2) out = view(x, concat([shape(x, 0), shape(x, 1) * 2])) @@ -219,7 +227,9 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor: x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2]) x1 = expand_dims(x1, 3) x2 = expand_dims(x2, 3) - zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + zero = constant( + np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))) + ) x2 = zero - x2 x = concat([x2, x1], 3) out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2])) @@ -235,15 +245,23 @@ def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin): end_dim = shape(x, -1) - shape(rope_cos, -1) new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960) x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1]) - out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1) + out = concat( + [x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1 + ) else: rot_dim = shape(rope_cos, 2) # 64 new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64) x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1]) end_dim = shape(x, 2) - shape(rope_cos, 2) - new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960) - x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1]) - out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1) + new_t_unrotated_shape = concat( + [shape(x, 0), shape(x, 1), end_dim] + ) # (2, -1, 960) + x_unrotated = slice( + x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1] + ) + out = concat( + [x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1 + ) return out @@ -279,8 +297,12 @@ class AttnProcessor: seq_len_2d = concat([1, N]) max_position_embeddings = 4096 # create position ids - position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0)) - tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d) + position_ids_buffer = constant( + np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0) + ) + 
tmp_position_ids = slice( + position_ids_buffer, starts=[0, 0], sizes=seq_len_2d + ) tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1 tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL @@ -315,14 +337,28 @@ class AttnProcessor: assert not default_net().plugin_config.remove_input_padding def transpose_for_scores(x): - new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) + new_x_shape = concat( + [ + shape(x, 0), + shape(x, 1), + attn.num_attention_heads, + attn.attention_head_size, + ] + ) y = x.view(new_x_shape) y = y.transpose(1, 2) return y def transpose_for_scores_k(x): - new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) + new_x_shape = concat( + [ + shape(x, 0), + shape(x, 1), + attn.num_attention_heads, + attn.attention_head_size, + ] + ) y = x.view(new_x_shape) y = y.permute([0, 2, 3, 1]) @@ -342,7 +378,11 @@ class AttnProcessor: attention_probs = softmax(attention_scores, dim=-1) context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2) - context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size])) + context = context.view( + concat( + [shape(context, 0), shape(context, 1), attn.attention_hidden_size] + ) + ) context = attn.to_out(context) if mask is not None: mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1])) @@ -370,13 +410,26 @@ class DiTBlock(Module): self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout) def forward( - self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError + self, + x, + t, + rope_cos, + rope_sin, + input_lengths, + scale=1.0, + rope=ModuleNotFoundError, ): # x: noised input, t: time embedding # pre-norm & modulation for attention input norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) # attention # norm ----> (2,1226,1024) - attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) + attn_output = self.attn( + x=norm, + rope_cos=rope_cos, + rope_sin=rope_sin, + input_lengths=input_lengths, + scale=scale, + ) # process attention output for input x if default_net().plugin_config.remove_input_padding: @@ -387,7 +440,9 @@ class DiTBlock(Module): if default_net().plugin_config.remove_input_padding: norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp else: - norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1) + norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze( + shift_mlp, 1 + ) # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp ff_output = self.ff(norm) if default_net().plugin_config.remove_input_padding: diff --git a/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py b/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py index 993e47243fc6998ae2c15e8d22aa3e56a1aae1cf..ef54876f4b01e4542546b2694110cc13329a1204 100644 --- a/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +++ b/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py @@ -40,7 +40,6 @@ import torch as th import torch.nn.functional as F from scipy.signal import check_COLA, get_window - support_clp_op = None if th.__version__ >= "1.7.0": from torch.fft import rfft as fft @@ -124,7 +123,9 @@ class STFT(th.nn.Module): ifft_kernel = th.pinverse(fft_kernel)[:, None, :] window = get_window(self.win_type, self.win_len) - self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - 
self.win_hop) + self.perfect_reconstruct = check_COLA( + window, self.win_len, self.win_len - self.win_hop + ) window = th.FloatTensor(window) if self.mode == "continue": left_pad = (self.fft_len - self.win_len) // 2 diff --git a/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py b/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py index fcf55922f932c2c03442dd4d6e789e6472bc165f..53107a05d6905e8d255ee17c5d107e97a0217746 100644 --- a/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +++ b/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -179,19 +179,47 @@ def parse_arguments(): ) # TODO: support F5TTS_v1_Base parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt") parser.add_argument( - "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint" + "--output_dir", + type=str, + default="./tllm_checkpoint", + help="The path to save the TensorRT-LLM checkpoint", + ) + parser.add_argument( + "--hidden_size", type=int, default=1024, help="The hidden size of DiT" + ) + parser.add_argument( + "--depth", type=int, default=22, help="The number of DiTBlock layers" + ) + parser.add_argument( + "--num_heads", + type=int, + default=16, + help="The number of heads of attention module", ) - parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT") - parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers") - parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module") parser.add_argument("--cfg_scale", type=float, default=4.0) - parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size") - parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size") - parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size") - parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"]) - parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers") parser.add_argument( - "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel" + "--tp_size", type=int, default=1, help="N-way tensor parallelism size" + ) + parser.add_argument( + "--cp_size", type=int, default=1, help="Context parallelism size" + ) + parser.add_argument( + "--pp_size", type=int, default=1, help="N-way pipeline parallelism size" + ) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float32", "bfloat16", "float16"], + ) + parser.add_argument( + "--fp8_linear", action="store_true", help="Whether use FP8 for linear layers" + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="The number of workers for converting checkpoint in parallel", ) args = parser.parse_args() return args @@ -205,10 +233,15 @@ def convert_timm_dit(args, mapping, dtype="float32"): model_params = dict(torch.load(args.timm_ckpt)) model_params = { - k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer") + k: v + for k, v in model_params["ema_model_state_dict"].items() + if k.startswith("ema_model.transformer") } prefix = "ema_model.transformer." 
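convert_timm_dit() above keeps only the EMA weights that live under ema_model.transformer and then, in the comprehension that follows, strips that prefix so the remaining names can be looked up in the TensorRT-LLM name mapping. A standalone sketch of that key filtering and renaming, using hypothetical checkpoint keys:

# Filter the EMA state dict down to transformer weights, then drop the prefix.
ckpt = {
    "ema_model_state_dict": {
        "ema_model.transformer.time_embed.mlp.0.weight": 1,                 # kept
        "ema_model.transformer.transformer_blocks.0.attn.to_q.weight": 2,   # kept
        "ema_model.mel_spec.window": 3,                                     # dropped
        "step": 4,                                                          # dropped
    }
}
model_params = {
    k: v
    for k, v in ckpt["ema_model_state_dict"].items()
    if k.startswith("ema_model.transformer")
}
prefix = "ema_model.transformer."
model_params = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in model_params.items()
}
print(sorted(model_params))
# ['time_embed.mlp.0.weight', 'transformer_blocks.0.attn.to_q.weight']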
- model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()} + model_params = { + key[len(prefix) :] if key.startswith(prefix) else key: value + for key, value in model_params.items() + } timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING @@ -223,8 +256,13 @@ def convert_timm_dit(args, mapping, dtype="float32"): weights = dict() for name, param in model_params.items(): - if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight": - weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1) + if ( + name == "input_embed.conv_pos_embed.conv1d.0.weight" + or name == "input_embed.conv_pos_embed.conv1d.2.weight" + ): + weights[get_trtllm_name(name)] = ( + param.contiguous().to(torch_dtype).unsqueeze(-1) + ) else: weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype) @@ -239,25 +277,37 @@ def convert_timm_dit(args, mapping, dtype="float32"): for k, v in weights.items(): if re.match("^transformer_blocks.*.attn.to_k.weight$", k): weights[k] *= scale_factor - weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] = split_q_tp( + v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank + ) elif re.match("^transformer_blocks.*.attn.to_k.bias$", k): weights[k] *= scale_factor - weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] = split_q_bias_tp( + v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank + ) elif re.match("^transformer_blocks.*.attn.to_q.weight$", k): - weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] = split_q_tp( + v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank + ) weights[k] *= scale_factor elif re.match("^transformer_blocks.*.attn.to_q.bias$", k): - weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] = split_q_bias_tp( + v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank + ) weights[k] *= scale_factor elif re.match("^transformer_blocks.*.attn.to_v.weight$", k): - weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] = split_q_tp( + v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank + ) elif re.match("^transformer_blocks.*.attn.to_v.bias$", k): - weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] = split_q_bias_tp( + v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank + ) elif re.match("^transformer_blocks.*.attn.to_out.weight$", k): weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1) @@ -317,7 +367,9 @@ def covert_and_save(args, rank): weights = convert_timm_dit(args, mapping, dtype=args.dtype) - safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors")) + safetensors.torch.save_file( + weights, os.path.join(args.output_dir, f"rank{rank}.safetensors") + ) def execute(workers, func, args): @@ -334,7 +386,9 @@ def execute(workers, func, args): except Exception as e: traceback.print_exc() exceptions.append(e) - assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log." + assert ( + len(exceptions) == 0 + ), "Checkpoint conversion failed, please check error log." 
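The execute() helper whose error handling closes the hunk above fans the per-rank conversion jobs out to worker processes, collects any exception each worker raised, and only succeeds when the exception list stays empty. A standalone sketch of that dispatch-and-collect pattern; the executor choice and the demo job are assumptions for illustration, not the file's exact implementation:

# Run one conversion job per rank, gather failures, and fail loudly at the end.
import traceback
from concurrent.futures import ProcessPoolExecutor


def demo_job(rank: int) -> str:
    # Stand-in for a per-rank conversion such as covert_and_save(args, rank).
    return f"rank{rank}.safetensors"


def execute_sketch(workers: int, func, world_size: int) -> None:
    exceptions = []
    with ProcessPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(func, rank) for rank in range(world_size)]
        for future in futures:
            try:
                print("saved", future.result())
            except Exception as e:
                traceback.print_exc()
                exceptions.append(e)
    assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."


if __name__ == "__main__":
    execute_sketch(workers=1, func=demo_job, world_size=1)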
def main(): diff --git a/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py b/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py index 6743aeca06073cfa66796a2efbe7d1d371ee18fe..b2a0444645983d52dbe8fc689bbaa48ae0ec7254 100644 --- a/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +++ b/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py @@ -20,12 +20,13 @@ from conv_stft import STFT from huggingface_hub import hf_hub_download from vocos import Vocos - opset_version = 17 def get_args(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( "--vocoder", type=str, @@ -108,7 +109,9 @@ def export_VocosVocoder(vocos_vocoder, output_path, verbose): print("Exported to {}".format(output_path)) -def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None): +def load_vocoder( + vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None +): if vocoder_name == "vocos": # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) if is_local: @@ -118,8 +121,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cp else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" - config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") - model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") + config_path = hf_hub_download( + repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml" + ) + model_path = hf_hub_download( + repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin" + ) vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) vocoder.load_state_dict(state_dict) diff --git a/f5_tts/runtime/triton_trtllm/scripts/fill_template.py b/f5_tts/runtime/triton_trtllm/scripts/fill_template.py index 105cfac85984d104d1204f0ead52e2bd1048c4be..6b2df01e9784d364ad683474da37d5f003e71a87 100644 --- a/f5_tts/runtime/triton_trtllm/scripts/fill_template.py +++ b/f5_tts/runtime/triton_trtllm/scripts/fill_template.py @@ -29,8 +29,12 @@ if __name__ == "__main__": "substitutions", help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...", ) - parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place") - parser.add_argument("--participant_ids", help="Participant IDs for the model", default="") + parser.add_argument( + "--in_place", "-i", action="store_true", help="do the operation in-place" + ) + parser.add_argument( + "--participant_ids", help="Participant IDs for the model", default="" + ) args = parser.parse_args() main(**vars(args)) diff --git a/f5_tts/scripts/count_max_epoch.py b/f5_tts/scripts/count_max_epoch.py index 5e62b76abf5374d90aae7c8d7f8eba8b6913a062..b1c723e12c48b3af09954835dab98f48bc900131 100644 --- a/f5_tts/scripts/count_max_epoch.py +++ b/f5_tts/scripts/count_max_epoch.py @@ -24,10 +24,14 @@ updates_per_epoch = total_hours / mini_batch_hours # result epochs = wanted_max_updates / updates_per_epoch -print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})") +print( + f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})" +) print(f"progress_bar 
should show approx. 0/{updates_per_epoch:.0f} updates") # print(f" or approx. 0/{steps_per_epoch:.0f} steps") # others print(f"total {total_hours:.0f} hours") -print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch") +print( + f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch" +) diff --git a/f5_tts/scripts/count_params_gflops.py b/f5_tts/scripts/count_params_gflops.py index d706388add443641bdac3ed36f0f876d4f547063..130ab97951345a1f13cc81371549ad5325da0864 100644 --- a/f5_tts/scripts/count_params_gflops.py +++ b/f5_tts/scripts/count_params_gflops.py @@ -1,7 +1,6 @@ import os import sys - sys.path.append(os.getcwd()) import thop @@ -9,7 +8,6 @@ import torch from f5_tts.model import CFM, DiT - """ ~155M """ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4) # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4) @@ -34,7 +32,11 @@ frame_length = int(duration * target_sample_rate / hop_length) text_length = 150 flops, params = thop.profile( - model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)) + model, + inputs=( + torch.randn(1, frame_length, n_mel_channels), + torch.zeros(1, text_length, dtype=torch.long), + ), ) print(f"FLOPs: {flops / 1e9} G") print(f"Params: {params / 1e6} M") diff --git a/f5_tts/socket_client.py b/f5_tts/socket_client.py index c47ad4482600c4338afddb8eb9d40e08e6c3cf7c..382e1fe2119879317ab15cd73da20dde002449ce 100644 --- a/f5_tts/socket_client.py +++ b/f5_tts/socket_client.py @@ -6,14 +6,15 @@ import time import numpy as np import pyaudio - logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port))) + await asyncio.get_event_loop().run_in_executor( + None, client_socket.connect, (server_ip, int(server_port)) + ) start_time = time.time() first_chunk_time = None @@ -21,11 +22,19 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): async def play_audio_stream(): nonlocal first_chunk_time p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048) + stream = p.open( + format=pyaudio.paFloat32, + channels=1, + rate=24000, + output=True, + frames_per_buffer=2048, + ) try: while True: - data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192) + data = await asyncio.get_event_loop().run_in_executor( + None, client_socket.recv, 8192 + ) if not data: break if data == b"END": @@ -47,7 +56,9 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): try: data_to_send = f"{text}".encode("utf-8") - await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send) + await asyncio.get_event_loop().run_in_executor( + None, client_socket.sendall, data_to_send + ) await play_audio_stream() except Exception as e: diff --git a/f5_tts/socket_server.py b/f5_tts/socket_server.py index 3fd780ae1f275575da42360167c47395389b271e..802ad27a868aa977e5819e5d38e5c009a7116277 100644 --- a/f5_tts/socket_server.py +++ b/f5_tts/socket_server.py @@ -16,14 +16,9 @@ from huggingface_hub import hf_hub_download from hydra.utils import get_class from omegaconf import OmegaConf -from 
f5_tts.infer.utils_infer import ( - chunk_text, - infer_batch_process, - load_model, - load_vocoder, - preprocess_ref_audio_text, -) - +from f5_tts.infer.utils_infer import (chunk_text, infer_batch_process, + load_model, load_vocoder, + preprocess_ref_audio_text) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -70,17 +65,28 @@ class AudioFileWriterThread(threading.Thread): class TTSStreamingProcessor: - def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): + def __init__( + self, + model, + ckpt_file, + vocab_file, + ref_audio, + ref_text, + device=None, + dtype=torch.float32, + ): self.device = device or ( "cuda" if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" + else ( + "xpu" + if torch.xpu.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + ) + model_cfg = OmegaConf.load( + str(files("f5_tts").joinpath(f"configs/{model}.yaml")) ) - model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") self.model_arc = model_cfg.model.arch self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type @@ -107,7 +113,12 @@ class TTSStreamingProcessor: ).to(self.device, dtype=dtype) def load_vocoder_model(self): - return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device) + return load_vocoder( + vocoder_name=self.mel_spec_type, + is_local=False, + local_path=None, + device=self.device, + ) def update_reference(self, ref_audio, ref_text): self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text) @@ -115,9 +126,15 @@ class TTSStreamingProcessor: ref_audio_duration = self.audio.shape[-1] / self.sr ref_text_byte_len = len(self.ref_text.encode("utf-8")) - self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration)) - self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2) - self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4) + self.max_chars = int( + ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) + ) + self.few_chars = int( + ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2 + ) + self.min_chars = int( + ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4 + ) def _warm_up(self): logger.info("Warming up the model...") @@ -138,8 +155,12 @@ class TTSStreamingProcessor: def generate_stream(self, text, conn): text_batches = chunk_text(text, max_chars=self.max_chars) if self.first_package: - text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:] - text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:] + text_batches = ( + chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:] + ) + text_batches = ( + chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:] + ) self.first_package = False audio_stream = infer_batch_process( @@ -157,7 +178,9 @@ class TTSStreamingProcessor: # Reset the file writer thread if self.file_writer_thread is not None: self.file_writer_thread.stop() - self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate) + self.file_writer_thread = AudioFileWriterThread( + "output.wav", self.sampling_rate + ) self.file_writer_thread.start() for 
audio_chunk, _ in audio_stream: @@ -224,7 +247,12 @@ if __name__ == "__main__": ) parser.add_argument( "--ckpt_file", - default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")), + default=str( + hf_hub_download( + repo_id="SWivid/F5-TTS", + filename="F5TTS_v1_Base/model_1250000.safetensors", + ) + ), help="Path to the model checkpoint file", ) parser.add_argument( @@ -245,7 +273,9 @@ if __name__ == "__main__": ) parser.add_argument("--device", default=None, help="Device to run the model on") - parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference") + parser.add_argument( + "--dtype", default=torch.float32, help="Data type to use for model inference" + ) args = parser.parse_args() diff --git a/f5_tts/train/datasets/prepare_csv_wavs.py b/f5_tts/train/datasets/prepare_csv_wavs.py index 26ad6f8a44a7f22b3b9ba10612f9246a0d3dee5f..9c7acbf58a4525192462a17660e173fad51e4ea5 100644 --- a/f5_tts/train/datasets/prepare_csv_wavs.py +++ b/f5_tts/train/datasets/prepare_csv_wavs.py @@ -7,7 +7,6 @@ import subprocess # For invoking ffprobe import sys from contextlib import contextmanager - sys.path.append(os.getcwd()) import argparse @@ -22,8 +21,9 @@ from tqdm import tqdm from f5_tts.model.utils import convert_char_to_pinyin - -PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt") +PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath( + "../../data/Emilia_ZH_EN_pinyin/vocab.txt" +) def is_csv_wavs_format(input_dataset_dir): @@ -75,7 +75,9 @@ def process_audio_file(audio_path, text, polyphone): raise ValueError(f"Duration {audio_duration} is non-positive.") return (audio_path, text, audio_duration) except Exception as e: - print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.") + print( + f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file." 
+ ) return None @@ -100,7 +102,9 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None): total_files = len(audio_path_text_pairs) # Use provided worker count or calculate optimal number - worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files) + worker_count = ( + num_workers if num_workers is not None else min(MAX_WORKERS, total_files) + ) print(f"\nProcessing {total_files} audio files using {worker_count} workers...") with graceful_exit(): @@ -115,7 +119,10 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None): for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE): chunk = audio_path_text_pairs[i : i + CHUNK_SIZE] # Submit futures in order - chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk] + chunk_futures = [ + executor.submit(process_audio_file, pair[0], pair[1], polyphone) + for pair in chunk + ] # Iterate over futures in the original submission order to preserve ordering for future in tqdm( @@ -147,7 +154,9 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None): vocab_set = set() for (audio_path, _, duration), conv_text in zip(processed, converted_texts): - sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration}) + sub_result.append( + {"audio_path": audio_path, "text": conv_text, "duration": duration} + ) durations.append(duration) vocab_set.update(list(conv_text)) @@ -171,19 +180,28 @@ def get_audio_duration(audio_path, timeout=5): audio_path, ] result = subprocess.run( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + timeout=timeout, ) duration_str = result.stdout.strip() if duration_str: return float(duration_str) raise ValueError("Empty duration string from ffprobe.") except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e: - print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.") + print( + f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio." 
+ ) try: audio, sample_rate = torchaudio.load(audio_path) return audio.shape[1] / sample_rate except Exception as e: - raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}") + raise RuntimeError( + f"Both ffprobe and torchaudio failed for {audio_path}: {e}" + ) def read_audio_text_pairs(csv_file_path): @@ -235,10 +253,16 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") -def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None): +def prepare_and_save_set( + inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None +): if is_finetune: - assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}" - sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers) + assert ( + PRETRAINED_VOCAB_PATH.exists() + ), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}" + sub_result, durations, vocab_set = prepare_csv_wavs_dir( + inp_dir, num_workers=num_workers + ) save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune) @@ -265,13 +289,30 @@ Examples: python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4 """, ) - parser.add_argument("inp_dir", type=str, help="Input directory containing the data.") - parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.") - parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune") - parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})") + parser.add_argument( + "inp_dir", type=str, help="Input directory containing the data." + ) + parser.add_argument( + "out_dir", type=str, help="Output directory to save the prepared data." + ) + parser.add_argument( + "--pretrain", + action="store_true", + help="Enable for new pretrain, otherwise is a fine-tune", + ) + parser.add_argument( + "--workers", + type=int, + help=f"Number of worker threads (default: {MAX_WORKERS})", + ) args = parser.parse_args() - prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers) + prepare_and_save_set( + args.inp_dir, + args.out_dir, + is_finetune=not args.pretrain, + num_workers=args.workers, + ) except KeyboardInterrupt: print("\nOperation cancelled by user. 
Cleaning up...") if executor is not None: diff --git a/f5_tts/train/datasets/prepare_emilia.py b/f5_tts/train/datasets/prepare_emilia.py index 4c4a771d722b141959cadaf306b0e31d5aae303f..27b188b1cfba75917cc519d06e873de29b239d9a 100644 --- a/f5_tts/train/datasets/prepare_emilia.py +++ b/f5_tts/train/datasets/prepare_emilia.py @@ -7,7 +7,6 @@ import os import sys - sys.path.append(os.getcwd()) import json @@ -20,7 +19,6 @@ from tqdm import tqdm from f5_tts.model.utils import convert_char_to_pinyin, repetition_found - out_zh = { "ZH_B00041_S06226", "ZH_B00042_S09204", @@ -120,7 +118,11 @@ def deal_with_audio_dir(audio_dir): obj = json.loads(line) text = obj["text"] if obj["language"] == "zh": - if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text): + if ( + obj["wav"].split("/")[1] in out_zh + or any(f in text for f in zh_filters) + or repetition_found(text) + ): bad_case_zh += 1 continue else: @@ -138,7 +140,13 @@ def deal_with_audio_dir(audio_dir): if tokenizer == "pinyin": text = convert_char_to_pinyin([text], polyphone=polyphone)[0] duration = obj["duration"] - sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration}) + sub_result.append( + { + "audio_path": str(audio_dir.parent / obj["wav"]), + "text": text, + "duration": duration, + } + ) durations.append(duration) vocab_set.update(list(text)) return sub_result, durations, vocab_set, bad_case_zh, bad_case_en diff --git a/f5_tts/train/datasets/prepare_emilia_v2.py b/f5_tts/train/datasets/prepare_emilia_v2.py index 50322c0e7602f4335473abb109e2dbeb79ccfbe1..673a9004c9fe16c980fa50215b81618286c4f0c5 100644 --- a/f5_tts/train/datasets/prepare_emilia_v2.py +++ b/f5_tts/train/datasets/prepare_emilia_v2.py @@ -12,7 +12,6 @@ from tqdm import tqdm from f5_tts.model.utils import repetition_found - # Define filters for exclusion out_en = set() en_filters = ["ا", "い", "て"] @@ -27,14 +26,22 @@ def process_audio_directory(audio_dir): with open(file, "r") as f: obj = json.load(f) text = obj["text"] - if any(f in text for f in en_filters) or repetition_found(text, length=4): + if any(f in text for f in en_filters) or repetition_found( + text, length=4 + ): bad_case_en += 1 continue duration = obj["duration"] audio_file = file.with_suffix(".mp3") if audio_file.exists(): - sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration}) + sub_result.append( + { + "audio_path": str(audio_file), + "text": text, + "duration": duration, + } + ) durations.append(duration) vocab_set.update(list(text)) diff --git a/f5_tts/train/datasets/prepare_libritts.py b/f5_tts/train/datasets/prepare_libritts.py index a892dd67e9481447f688c2eabc2ee07b9ef67595..82fbc6d428095470117a99f5c7fb9e1dbfa0ea52 100644 --- a/f5_tts/train/datasets/prepare_libritts.py +++ b/f5_tts/train/datasets/prepare_libritts.py @@ -1,7 +1,6 @@ import os import sys - sys.path.append(os.getcwd()) import json @@ -84,7 +83,9 @@ if __name__ == "__main__": SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"] dataset_dir = "/LibriTTS" - dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "") + dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace( + "train-clean-", "" + ).replace("train-other-", "") save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") main() diff --git a/f5_tts/train/datasets/prepare_ljspeech.py 
b/f5_tts/train/datasets/prepare_ljspeech.py index 9f64b0a059807fb07977bca8ce7ff3f6d2628f6a..88e0d16a5d0f379630aee2cce4923b7f5d8a59bd 100644 --- a/f5_tts/train/datasets/prepare_ljspeech.py +++ b/f5_tts/train/datasets/prepare_ljspeech.py @@ -1,7 +1,6 @@ import os import sys - sys.path.append(os.getcwd()) import json @@ -27,7 +26,9 @@ def main(): duration = sf.info(wav_path).duration if duration < 0.4 or duration > 30: continue - result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration}) + result.append( + {"audio_path": str(wav_path), "text": norm_text, "duration": duration} + ) duration_list.append(duration) text_vocab_set.update(list(norm_text)) diff --git a/f5_tts/train/datasets/prepare_wenetspeech4tts.py b/f5_tts/train/datasets/prepare_wenetspeech4tts.py index 6498421b5402649b01ae3211e726dcdcf5e64116..dea706f06651b05e5b5da7ba4e3678db7e50645a 100644 --- a/f5_tts/train/datasets/prepare_wenetspeech4tts.py +++ b/f5_tts/train/datasets/prepare_wenetspeech4tts.py @@ -4,7 +4,6 @@ import os import sys - sys.path.append(os.getcwd()) import json @@ -55,9 +54,15 @@ def main(): futures = [] for dataset_path in dataset_paths: sub_items = os.listdir(dataset_path) - sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))] + sub_paths = [ + item + for item in sub_items + if os.path.isdir(os.path.join(dataset_path, item)) + ] for sub_path in sub_paths: - futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path)) + futures.append( + executor.submit(deal_with_sub_path_files, dataset_path, sub_path) + ) for future in tqdm(futures, total=len(futures)): audio_paths, texts, durations = future.result() audio_path_list.extend(audio_paths) @@ -69,7 +74,9 @@ def main(): os.makedirs("data") print(f"\nSaving to {save_dir} ...") - dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) + dataset = Dataset.from_dict( + {"audio_path": audio_path_list, "text": text_list, "duration": duration_list} + ) dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: @@ -84,7 +91,9 @@ def main(): # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 
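# A minimal, self-contained sketch of the on-disk layout that the prepare_* dataset
# scripts in this patch all converge on: an Arrow dataset of audio_path/text/duration
# records plus duration.json and vocab.txt. Paths and sample records here are
# illustrative only, not taken from any of the scripts.
import json
import os

from datasets import Dataset

save_dir = "data/my_dataset_char"  # assumed output directory
os.makedirs(save_dir, exist_ok=True)

records = [
    {"audio_path": "/data/wavs/segment_0.wav", "text": "hello world", "duration": 1.8},
    {"audio_path": "/data/wavs/segment_1.wav", "text": "how are you", "duration": 2.3},
]
duration_list = [r["duration"] for r in records]
text_vocab_set = set("".join(r["text"] for r in records))

dataset = Dataset.from_dict(
    {
        "audio_path": [r["audio_path"] for r in records],
        "text": [r["text"] for r in records],
        "duration": duration_list,
    }
)
dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")  # arrow format

with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
    json.dump({"duration": duration_list}, f, ensure_ascii=False)

with open(f"{save_dir}/vocab.txt", "w") as f:
    for vocab in sorted(text_vocab_set):
        f.write(vocab + "\n")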
if tokenizer == "pinyin": - text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) + text_vocab_set.update( + [chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)] + ) with open(f"{save_dir}/vocab.txt", "w") as f: for vocab in sorted(text_vocab_set): @@ -101,7 +110,11 @@ if __name__ == "__main__": dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic dataset_name = ( - ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1] + [ + "WenetSpeech4TTS_Premium", + "WenetSpeech4TTS_Standard", + "WenetSpeech4TTS_Basic", + ][dataset_choice - 1] + "_" + tokenizer ) diff --git a/f5_tts/train/finetune_cli.py b/f5_tts/train/finetune_cli.py index cdf42a9ace27a92cf369ce1523a3e640fd61ea3e..c58bb06795a3c6f479d0a5915a0210455d1c8f69 100644 --- a/f5_tts/train/finetune_cli.py +++ b/f5_tts/train/finetune_cli.py @@ -9,7 +9,6 @@ from f5_tts.model import CFM, DiT, Trainer, UNetT from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer - # -------------------------- Dataset Settings --------------------------- # target_sample_rate = 24000 n_mel_channels = 100 @@ -30,29 +29,74 @@ def parse_args(): choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], help="Experiment name", ) - parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") - parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training") - parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU") parser.add_argument( - "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type" + "--dataset_name", + type=str, + default="Emilia_ZH_EN", + help="Name of the dataset to use", + ) + parser.add_argument( + "--learning_rate", type=float, default=1e-5, help="Learning rate for training" + ) + parser.add_argument( + "--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU" + ) + parser.add_argument( + "--batch_size_type", + type=str, + default="frame", + choices=["frame", "sample"], + help="Batch size type", + ) + parser.add_argument( + "--max_samples", type=int, default=64, help="Max sequences per batch" + ) + parser.add_argument( + "--grad_accumulation_steps", + type=int, + default=1, + help="Gradient accumulation steps", + ) + parser.add_argument( + "--max_grad_norm", + type=float, + default=1.0, + help="Max gradient norm for clipping", + ) + parser.add_argument( + "--epochs", type=int, default=100, help="Number of training epochs" + ) + parser.add_argument( + "--num_warmup_updates", type=int, default=20000, help="Warmup updates" + ) + parser.add_argument( + "--save_per_updates", + type=int, + default=50000, + help="Save checkpoint every N updates", ) - parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch") - parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") - parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping") - parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") - parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates") - parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates") parser.add_argument( "--keep_last_n_checkpoints", type=int, default=-1, help="-1 to keep all, 0 to not save 
intermediate, > 0 to keep last N checkpoints", ) - parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates") + parser.add_argument( + "--last_per_updates", + type=int, + default=5000, + help="Save last checkpoint every N updates", + ) parser.add_argument("--finetune", action="store_true", help="Use Finetune") - parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") parser.add_argument( - "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" + "--pretrain", type=str, default=None, help="the path to the checkpoint" + ) + parser.add_argument( + "--tokenizer", + type=str, + default="pinyin", + choices=["pinyin", "char", "custom"], + help="Tokenizer type", ) parser.add_argument( "--tokenizer_path", @@ -65,7 +109,13 @@ def parse_args(): action="store_true", help="Log inferenced samples per ckpt save updates", ) - parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger") + parser.add_argument( + "--logger", + type=str, + default=None, + choices=[None, "wandb", "tensorboard"], + help="logger", + ) parser.add_argument( "--bnb_optimizer", action="store_true", @@ -98,7 +148,11 @@ def main(): ) if args.finetune: if args.pretrain is None: - ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) + ckpt_path = str( + cached_path( + "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors" + ) + ) else: ckpt_path = args.pretrain @@ -117,7 +171,9 @@ def main(): ) if args.finetune: if args.pretrain is None: - ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) + ckpt_path = str( + cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt") + ) else: ckpt_path = args.pretrain @@ -134,7 +190,9 @@ def main(): ) if args.finetune: if args.pretrain is None: - ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) + ckpt_path = str( + cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt") + ) else: ckpt_path = args.pretrain @@ -143,7 +201,9 @@ def main(): os.makedirs(checkpoint_path, exist_ok=True) file_checkpoint = os.path.basename(ckpt_path) - if not file_checkpoint.startswith("pretrained_"): # Change: Add 'pretrained_' prefix to copied model + if not file_checkpoint.startswith( + "pretrained_" + ): # Change: Add 'pretrained_' prefix to copied model file_checkpoint = "pretrained_" + file_checkpoint file_checkpoint = os.path.join(checkpoint_path, file_checkpoint) if not os.path.isfile(file_checkpoint): @@ -155,7 +215,9 @@ def main(): tokenizer = args.tokenizer if tokenizer == "custom": if not args.tokenizer_path: - raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.") + raise ValueError( + "Custom tokenizer selected, but no tokenizer_path provided." 
+ ) tokenizer_path = args.tokenizer_path else: tokenizer_path = args.dataset_name @@ -175,7 +237,9 @@ def main(): ) model = CFM( - transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls( + **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), mel_spec_kwargs=mel_spec_kwargs, vocab_char_map=vocab_char_map, ) @@ -202,7 +266,9 @@ def main(): bnb_optimizer=args.bnb_optimizer, ) - train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) + train_dataset = load_dataset( + args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs + ) trainer.train( train_dataset, diff --git a/f5_tts/train/finetune_gradio.py b/f5_tts/train/finetune_gradio.py index eee2a3fe67e4ebd713e1aac34897bf2e79dafa10..a33ba74b39ab78268d02943dac0f816cbca7ada3 100644 --- a/f5_tts/train/finetune_gradio.py +++ b/f5_tts/train/finetune_gradio.py @@ -32,7 +32,6 @@ from f5_tts.api import F5TTS from f5_tts.infer.utils_infer import transcribe from f5_tts.model.utils import convert_char_to_pinyin - training_process = None system = platform.system() python_executable = sys.executable or "python" @@ -49,11 +48,11 @@ file_train = str(files("f5_tts").joinpath("train/finetune_cli.py")) device = ( "cuda" if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" + else ( + "xpu" + if torch.xpu.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) ) @@ -189,9 +188,13 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. max_sil_kept: int = 2000, ): if not min_length >= min_interval >= hop_size: - raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size") + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) if not max_sil_kept >= hop_size: - raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size") + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) min_interval = sr * min_interval / 1000 self.threshold = 10 ** (threshold / 20.0) self.hop_size = round(sr * hop_size / 1000) @@ -202,9 +205,13 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. def _apply_slice(self, waveform, begin, end): if len(waveform.shape) > 1: - return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)] + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] else: - return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)] + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] # @timeit def slice(self, waveform): @@ -214,7 +221,9 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. samples = waveform if samples.shape[0] <= self.min_length: return [waveform] - rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) + rms_list = librosa.feature.rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) sil_tags = [] silence_start = None clip_start = 0 @@ -230,7 +239,10 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. 
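# The cuda/xpu/mps/cpu selection chain reformatted above appears in both
# socket_server.py and finetune_gradio.py. An equivalent early-return helper,
# shown only as a readability sketch and not part of the diff:
import torch

def pick_device() -> str:
    # Same priority order as the conditional expression in the two modules.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()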
continue # Clear recorded silence start if interval is not enough or clip is too short is_leading_silence = silence_start == 0 and i > self.max_sil_kept - need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) if not is_leading_silence and not need_slice_middle: silence_start = None continue @@ -243,10 +255,21 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. sil_tags.append((pos, pos)) clip_start = pos elif i - silence_start <= self.max_sil_kept * 2: - pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin() + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() pos += i - self.max_sil_kept - pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start - pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) if silence_start == 0: sil_tags.append((0, pos_r)) clip_start = pos_r @@ -254,8 +277,17 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. sil_tags.append((min(pos_l, pos), max(pos_r, pos))) clip_start = max(pos_r, pos) else: - pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start - pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) if silence_start == 0: sil_tags.append((0, pos_r)) else: @@ -264,7 +296,10 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. silence_start = None # Deal with trailing silence. total_frames = rms_list.shape[0] - if silence_start is not None and total_frames - silence_start >= self.min_interval: + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): silence_end = min(total_frames, silence_start + self.max_sil_kept) pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start sil_tags.append((pos, total_frames + 1)) @@ -274,7 +309,13 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2. 
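# Condensed sketch of the frame-level RMS primitive the Slicer logic above is built on.
# Parameter values are illustrative; the real class additionally enforces min_length,
# min_interval and max_sil_kept when turning silent runs into cut points.
import librosa
import numpy as np

def silent_frame_indices(samples, sr, threshold_db=-40.0, hop_ms=20, win_ms=40):
    hop_size = round(sr * hop_ms / 1000)
    win_size = round(sr * win_ms / 1000)
    rms = librosa.feature.rms(
        y=samples, frame_length=win_size, hop_length=hop_size
    ).squeeze(0)
    threshold = 10 ** (threshold_db / 20.0)  # dB threshold -> linear amplitude
    return np.flatnonzero(rms < threshold)   # frames the slicer would treat as silence

# e.g. audio, sr = librosa.load("speech.wav", sr=24000, mono=True)
#      silence = silent_frame_indices(audio, sr)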
else: chunks = [] if sil_tags[0][0] > 0: - chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)]) + chunks.append( + [ + self._apply_slice(waveform, 0, sil_tags[0][0]), + 0, + int(sil_tags[0][0] * self.hop_size), + ] + ) for i in range(len(sil_tags) - 1): chunks.append( [ @@ -368,12 +409,18 @@ def start_training( file_raw = os.path.join(path_project, "raw.arrow") if not os.path.isfile(file_raw): - yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False) + yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update( + interactive=False + ) return # Check if a training process is already running if training_process is not None: - return "Train run already!", gr.update(interactive=False), gr.update(interactive=True) + return ( + "Train run already!", + gr.update(interactive=False), + gr.update(interactive=True), + ) yield "start train", gr.update(interactive=False), gr.update(interactive=False) @@ -460,7 +507,9 @@ def start_training( training_process = subprocess.Popen(cmd, shell=True) time.sleep(5) - yield "train start", gr.update(interactive=False), gr.update(interactive=True) + yield "train start", gr.update(interactive=False), gr.update( + interactive=True + ) # Wait for the training process to finish training_process.wait() @@ -479,15 +528,27 @@ def start_training( env["PYTHONUNBUFFERED"] = "1" training_process = subprocess.Popen( - cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + env=env, + ) + yield "Training started ...", gr.update(interactive=False), gr.update( + interactive=True ) - yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True) stdout_queue = queue.Queue() stderr_queue = queue.Queue() - stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue)) - stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue)) + stdout_thread = threading.Thread( + target=stream_output, args=(training_process.stdout, stdout_queue) + ) + stderr_thread = threading.Thread( + target=stream_output, args=(training_process.stderr, stderr_queue) + ) stdout_thread.daemon = True stderr_thread.daemon = True stdout_thread.start() @@ -499,7 +560,9 @@ def start_training( time.sleep(0.5) if training_process.poll() is None: training_process.kill() - yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False) + yield "Training stopped by user.", gr.update( + interactive=True + ), gr.update(interactive=False) break process_status = training_process.poll() @@ -510,7 +573,8 @@ def start_training( output = stdout_queue.get_nowait() print(output, end="") match = re.search( - r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", output + r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", + output, ) if match: current_epoch = match.group(1) @@ -526,9 +590,13 @@ def start_training( f"Loss: {loss}, " f"Update: {current_update}" ) - yield message, gr.update(interactive=False), gr.update(interactive=True) + yield message, gr.update(interactive=False), gr.update( + interactive=True + ) elif output.strip(): - yield output, gr.update(interactive=False), gr.update(interactive=True) + yield output, gr.update(interactive=False), gr.update( + interactive=True + ) except queue.Empty: 
pass @@ -538,11 +606,17 @@ def start_training( error_output = stderr_queue.get_nowait() print(error_output, end="") if error_output.strip(): - yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True) + yield f"{error_output.strip()}", gr.update( + interactive=False + ), gr.update(interactive=True) except queue.Empty: pass - if process_status is not None and stdout_queue.empty() and stderr_queue.empty(): + if ( + process_status is not None + and stdout_queue.empty() + and stderr_queue.empty() + ): if process_status != 0: yield ( f"Process crashed with exit code {process_status}!", @@ -585,7 +659,11 @@ def stop_training(): global training_process, stop_signal if training_process is None: - return "Train not running !", gr.update(interactive=True), gr.update(interactive=False) + return ( + "Train not running !", + gr.update(interactive=True), + gr.update(interactive=False), + ) terminate_process_tree(training_process.pid) # training_process = None stop_signal = True @@ -616,7 +694,9 @@ def create_data_project(name, tokenizer_type): return gr.update(choices=project_list, value=name) -def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()): +def transcribe_all( + name_project, audio_files, language, user=False, progress=gr.Progress() +): path_project = os.path.join(path_data, name_project) path_dataset = os.path.join(path_project, "dataset") path_project_wavs = os.path.join(path_project, "wavs") @@ -652,11 +732,15 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr. num = 0 error_num = 0 data = "" - for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))): + for file_audio in progress.tqdm( + file_audios, desc="transcribe files", total=len((file_audios)) + ): audio, _ = librosa.load(file_audio, sr=24000, mono=True) list_slicer = slicer.slice(audio) - for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"): + for chunk, start, end in progress.tqdm( + list_slicer, total=len(list_slicer), desc="slicer files" + ): name_segment = os.path.join(f"segment_{num}") file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav") @@ -684,7 +768,9 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr. 
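# Standalone sketch of the pattern start_training uses above: daemon reader threads
# pump the trainer's stdout/stderr into queues, and the UI loop drains the queues while
# polling the process. The pump() helper and the command are placeholders, not the
# actual stream_output helper from finetune_gradio.py.
import queue
import subprocess
import threading

def pump(stream, q):
    for line in iter(stream.readline, ""):
        q.put(line)
    stream.close()

proc = subprocess.Popen(
    ["python", "-u", "-c", "print('Epoch 1/10:  10%|... loss=0.123, update=50')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
    bufsize=1,
)
out_q = queue.Queue()
threading.Thread(target=pump, args=(proc.stdout, out_q), daemon=True).start()

while proc.poll() is None or not out_q.empty():
    try:
        line = out_q.get(timeout=0.5)  # in the UI this line is regex-parsed and yielded
        print(line, end="")
    except queue.Empty:
        pass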
else: error_text = "" - return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}" + return ( + f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}" + ) def format_seconds_to_hms(seconds): @@ -697,7 +783,18 @@ def format_seconds_to_hms(seconds): def get_correct_audio_path( audio_input, base_path="wavs", - supported_formats=("wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"), + supported_formats=( + "wav", + "mp3", + "aac", + "flac", + "m4a", + "alac", + "ogg", + "aiff", + "wma", + "amr", + ), ): file_audio = None @@ -721,7 +818,9 @@ def get_correct_audio_path( file_audio = potential_file break else: - file_audio = os.path.join(base_path, f"{audio_input}.{supported_formats[0]}") + file_audio = os.path.join( + base_path, f"{audio_input}.{supported_formats[0]}" + ) return file_audio @@ -791,7 +890,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): lenght += duration if duration_list == []: - return f"Error: No audio files found in the specified path : {path_project_wavs}", "" + return ( + f"Error: No audio files found in the specified path : {path_project_wavs}", + "", + ) min_second = round(min(duration_list), 2) max_second = round(max(duration_list), 2) @@ -806,7 +908,9 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): new_vocal = "" if not ch_tokenizer: if not os.path.isfile(file_vocab): - file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") + file_vocab_finetune = os.path.join( + path_data, "Emilia_ZH_EN_pinyin/vocab.txt" + ) if not os.path.isfile(file_vocab_finetune): return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", "" shutil.copy2(file_vocab_finetune, file_vocab) @@ -893,7 +997,9 @@ def calculate_train( # rough estimate of batch size if batch_size_type == "frame": - batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length)) + batch_size_per_gpu = max( + int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length) + ) elif batch_size_type == "sample": batch_size_per_gpu = int(200 / (total_duration / total_samples)) @@ -906,7 +1012,9 @@ def calculate_train( max_updates = 1200000 if batch_size_type == "frame": - mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate + mini_batch_duration = ( + batch_size_per_gpu * gpu_count * hop_length / sampling_rate + ) updates_per_epoch = total_duration / mini_batch_duration elif batch_size_type == "sample": updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count @@ -928,7 +1036,9 @@ def calculate_train( ) -def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str: +def prune_checkpoint( + checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool +) -> str: try: checkpoint = torch.load(checkpoint_path, weights_only=True) print("Original Checkpoint Keys:", checkpoint.keys()) @@ -1039,7 +1149,9 @@ def vocab_extend(project_name, symbols, model_type): f.write("\n".join(vocab)) if model_type == "F5TTS_v1_Base": - ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) + ckpt_path = str( + cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors") + ) elif model_type == "F5TTS_Base": ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) elif model_type == "E2TTS_Base": @@ -1052,9 +1164,13 @@ def vocab_extend(project_name, symbols, model_type): os.makedirs(new_ckpt_path, 
exist_ok=True) # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py - new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path)) + new_ckpt_file = os.path.join( + new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path) + ) - size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new) + size = expand_model_embeddings( + ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new + ) vocab_new = "\n".join(miss_symbols) return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}" @@ -1159,7 +1275,17 @@ def get_random_sample_infer(project_name): def infer( - project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence + project, + file_checkpoint, + exp_name, + ref_text, + ref_audio, + gen_text, + nfe_step, + use_ema, + speed, + seed, + remove_silence, ): global last_checkpoint, last_device, tts_api, last_ema @@ -1171,7 +1297,12 @@ def infer( else: device_test = None - if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None: + if ( + last_checkpoint != file_checkpoint + or last_device != device_test + or last_ema != use_ema + or tts_api is None + ): if last_checkpoint != file_checkpoint: last_checkpoint = file_checkpoint @@ -1184,7 +1315,11 @@ def infer( vocab_file = os.path.join(path_data, project, "vocab.txt") tts_api = F5TTS( - model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema + model=exp_name, + ckpt_file=file_checkpoint, + vocab_file=vocab_file, + device=device_test, + use_ema=use_ema, ) print("update >> ", device_test, file_checkpoint, use_ema) @@ -1207,7 +1342,11 @@ def infer( def check_finetune(finetune): - return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune) + return ( + gr.update(interactive=finetune), + gr.update(interactive=finetune), + gr.update(interactive=finetune), + ) def get_checkpoints_project(project_name, is_gradio=True): @@ -1218,21 +1357,29 @@ def get_checkpoints_project(project_name, is_gradio=True): if os.path.isdir(path_project_ckpts): files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt")) # Separate pretrained and regular checkpoints - pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)] + pretrained_checkpoints = [ + f for f in files_checkpoints if "pretrained_" in os.path.basename(f) + ] regular_checkpoints = [ f for f in files_checkpoints - if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f) + if "pretrained_" not in os.path.basename(f) + and "model_last.pt" not in os.path.basename(f) + ] + last_checkpoint = [ + f for f in files_checkpoints if "model_last.pt" in os.path.basename(f) ] - last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)] # Sort regular checkpoints by number regular_checkpoints = sorted( - regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]) + regular_checkpoints, + key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]), ) # Combine in order: pretrained, regular, last - files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint + files_checkpoints = ( + pretrained_checkpoints + regular_checkpoints + last_checkpoint + ) else: files_checkpoints = [] @@ -1250,10 +1397,19 @@ def 
get_audio_project(project_name, is_gradio=True): project_name = project_name.replace("_pinyin", "").replace("_char", "") if os.path.isdir(path_project_ckpts): - files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav")) - files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])) + files_audios = glob( + os.path.join(path_project_ckpts, project_name, "samples", "*.wav") + ) + files_audios = sorted( + files_audios, + key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]), + ) - files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")] + files_audios = [ + item.replace("_gen.wav", "") + for item in files_audios + if item.endswith("_gen.wav") + ] else: files_audios = [] @@ -1375,23 +1531,35 @@ For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion with gr.Row(): projects, projects_selelect = get_list_projects() - tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char", "custom"], value="pinyin") + tokenizer_type = gr.Radio( + label="Tokenizer Type", choices=["pinyin", "char", "custom"], value="pinyin" + ) project_name = gr.Textbox(label="Project Name", value="my_speak") bt_create = gr.Button("Create a New Project") with gr.Row(): cm_project = gr.Dropdown( - choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6 + choices=projects, + value=projects_selelect, + label="Project", + allow_custom_value=True, + scale=6, ) ch_refresh_project = gr.Button("Refresh", scale=1) - bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project]) + bt_create.click( + fn=create_data_project, + inputs=[project_name, tokenizer_type], + outputs=[cm_project], + ) with gr.Tabs(): with gr.TabItem("Transcribe Data"): - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext Skip this step if you have your dataset, metadata.csv, and a folder wavs with all the audio files. -```""") +```""" + ) ch_manual = gr.Checkbox(label="Audio from Path", value=False) @@ -1409,7 +1577,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al visible=False, ) - audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple") + audio_speaker = gr.File( + label="Voice", type="filepath", file_count="multiple" + ) txt_lang = gr.Textbox(label="Language", value="English") bt_transcribe = bt_create = gr.Button("Transcribe") txt_info_transcribe = gr.Textbox(label="Info", value="") @@ -1418,7 +1588,11 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al inputs=[cm_project, audio_speaker, txt_lang, ch_manual], outputs=[txt_info_transcribe], ) - ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe]) + ch_manual.change( + fn=check_user, + inputs=[ch_manual], + outputs=[audio_speaker, mark_info_transcribe], + ) random_sample_transcribe = gr.Button("Random Sample") @@ -1433,19 +1607,25 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al ) with gr.TabItem("Vocab Check"): - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. For fine-tuning a new language. 
-```""") +```""" + ) check_button = gr.Button("Check Vocab") txt_info_check = gr.Textbox(label="Info", value="") - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder. -```""") +```""" + ) exp_name_extend = gr.Radio( - label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" + label="Model", + choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], + value="F5TTS_v1_Base", ) with gr.Row(): @@ -1460,18 +1640,26 @@ Using the extended model, you can finetune to a new language that is missing sym extend_button = gr.Button("Extend") txt_info_extend = gr.Textbox(label="Info", value="") - txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol]) + txt_extend.change( + vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol] + ) check_button.click( - fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend] + fn=vocab_check, + inputs=[cm_project, tokenizer_type], + outputs=[txt_info_check, txt_extend], ) extend_button.click( - fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend] + fn=vocab_extend, + inputs=[cm_project, txt_extend, exp_name_extend], + outputs=[txt_info_extend], ) with gr.TabItem("Prepare Data"): - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt -```""") +```""" + ) gr.Markdown( """```plaintext @@ -1497,14 +1685,18 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt ```""" ) - ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False) + ch_tokenizern = gr.Checkbox( + label="Create Vocabulary", value=False, visible=False + ) bt_prepare = bt_create = gr.Button("Prepare") txt_info_prepare = gr.Textbox(label="Info", value="") txt_vocab_prepare = gr.Textbox(label="Vocab", value="") bt_prepare.click( - fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare] + fn=create_metadata, + inputs=[cm_project, ch_tokenizern], + outputs=[txt_info_prepare, txt_vocab_prepare], ) random_sample_prepare = gr.Button("Random Sample") @@ -1514,18 +1706,26 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt random_audio_prepare = gr.Audio(label="Audio", type="filepath") random_sample_prepare.click( - fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare] + fn=get_random_sample_prepare, + inputs=[cm_project], + outputs=[random_text_prepare, random_audio_prepare], ) with gr.TabItem("Train Model"): - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space. If you encounter a memory error, try reducing the batch size per GPU to a smaller number. 
-```""") +```""" + ) with gr.Row(): - exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"]) + exp_name = gr.Radio( + label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"] + ) tokenizer_file = gr.Textbox(label="Tokenizer File") - file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint") + file_checkpoint_train = gr.Textbox( + label="Path to the Pretrained Checkpoint" + ) with gr.Row(): ch_finetune = bt_create = gr.Checkbox(label="Finetune") @@ -1544,11 +1744,17 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle choices=["frame", "sample"], info="frame is calculated as seconds * sampling_rate / hop_length", ) - batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples") + batch_size_per_gpu = gr.Number( + label="Batch Size per GPU", info="N frames or N samples" + ) grad_accumulation_steps = gr.Number( - label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value" + label="Gradient Accumulation Steps", + info="Effective batch size is multiplied by this value", + ) + max_samples = gr.Number( + label="Max Samples", + info="Maximum number of samples per single GPU batch", ) - max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch") with gr.Row(): save_per_updates = gr.Number( @@ -1572,8 +1778,12 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle with gr.Row(): ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer") - mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"]) - cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"]) + mixed_precision = gr.Radio( + label="Mixed Precision", choices=["none", "fp16", "bf16"] + ) + cd_logger = gr.Radio( + label="Logger", choices=["none", "wandb", "tensorboard"] + ) with gr.Column(): start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) @@ -1644,12 +1854,20 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle interactive=True, ) bt_stream_audio = gr.Button("Refresh", scale=1) - bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio]) - cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio]) + bt_stream_audio.click( + fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio] + ) + cm_project.change( + fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio] + ) with gr.Row(): - audio_ref_stream = gr.Audio(label="Original", type="filepath", value=select_audio_ref) - audio_gen_stream = gr.Audio(label="Generate", type="filepath", value=select_audio_gen) + audio_ref_stream = gr.Audio( + label="Original", type="filepath", value=select_audio_ref + ) + audio_gen_stream = gr.Audio( + label="Generate", type="filepath", value=select_audio_gen + ) ch_list_audio.change( fn=get_audio_select, @@ -1684,7 +1902,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle ], outputs=[txt_info_train, start_button, stop_button], ) - stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button]) + stop_button.click( + fn=stop_training, outputs=[txt_info_train, start_button, stop_button] + ) bt_calculate.click( fn=calculate_train, @@ -1709,7 +1929,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle ) ch_finetune.change( - check_finetune, 
inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type] + check_finetune, + inputs=[ch_finetune], + outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type], ) def setup_load_settings(): @@ -1751,26 +1973,39 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle ) with gr.TabItem("Test Model"): - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random. -```""") +```""" + ) exp_name = gr.Radio( - label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" + label="Model", + choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], + value="F5TTS_v1_Base", + ) + list_checkpoints, checkpoint_select = get_checkpoints_project( + projects_selelect, False ) - list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) with gr.Row(): nfe_step = gr.Number(label="NFE Step", value=32) - speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1) + speed = gr.Slider( + label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1 + ) seed = gr.Number(label="Random Seed", value=-1, minimum=-1) remove_silence = gr.Checkbox(label="Remove Silence") with gr.Row(): ch_use_ema = gr.Checkbox( - label="Use EMA", value=True, info="Turn off at early stage might offer better results" + label="Use EMA", + value=True, + info="Turn off at early stage might offer better results", ) cm_checkpoint = gr.Dropdown( - choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True + choices=list_checkpoints, + value=checkpoint_select, + label="Checkpoints", + allow_custom_value=True, ) bt_checkpoint_refresh = gr.Button("Refresh") @@ -1781,7 +2016,9 @@ Check the use_ema setting (True or False) for your model to see what works best gen_text = gr.Textbox(label="Text to Generate") random_sample_infer.click( - fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio] + fn=get_random_sample_infer, + inputs=[cm_project], + outputs=[ref_text, gen_text, ref_audio], ) with gr.Row(): @@ -1809,23 +2046,36 @@ Check the use_ema setting (True or False) for your model to see what works best outputs=[gen_audio, txt_info_gpu, seed_info], ) - bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) - cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) + bt_checkpoint_refresh.click( + fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint] + ) + cm_project.change( + fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint] + ) with gr.TabItem("Prune Checkpoint"): - gr.Markdown("""```plaintext + gr.Markdown( + """```plaintext Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining. 
-```""") +```""" + ) txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:") txt_path_checkpoint_small = gr.Textbox(label="Path to Output:") with gr.Row(): ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True) - ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True) + ch_safetensors = gr.Checkbox( + label="Save with safetensors format", value=True + ) txt_info_reduse = gr.Textbox(label="Info", value="") reduse_button = gr.Button("Prune") reduse_button.click( fn=prune_checkpoint, - inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors], + inputs=[ + txt_path_checkpoint, + txt_path_checkpoint_small, + ch_save_ema, + ch_safetensors, + ], outputs=[txt_info_reduse], ) @@ -1858,7 +2108,9 @@ Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out def main(port, host, share, api): global app print("Starting app...") - app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api) + app.queue(api_open=api).launch( + server_name=host, server_port=port, share=share, show_api=api + ) if __name__ == "__main__": diff --git a/f5_tts/train/train.py b/f5_tts/train/train.py index b948ab194e6e46e14d46392e4a23ff9b76c1d1ae..200e29216323f0c0e9cd2284ae2a79c1aea40800 100644 --- a/f5_tts/train/train.py +++ b/f5_tts/train/train.py @@ -10,11 +10,16 @@ from f5_tts.model import CFM, Trainer from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer +os.chdir( + str(files("f5_tts").joinpath("../..")) +) # change working directory to root of project (local editable) -os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable) - -@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None) +@hydra.main( + version_base="1.3", + config_path=str(files("f5_tts").joinpath("configs")), + config_name=None, +) def main(model_cfg): model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch @@ -33,7 +38,11 @@ def main(model_cfg): # set model model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels), + transformer=model_cls( + **model_arc, + text_num_embeds=vocab_size, + mel_dim=model_cfg.model.mel_spec.n_mel_channels, + ), mel_spec_kwargs=model_cfg.model.mel_spec, vocab_char_map=vocab_char_map, ) @@ -46,7 +55,9 @@ def main(model_cfg): num_warmup_updates=model_cfg.optim.num_warmup_updates, save_per_updates=model_cfg.ckpts.save_per_updates, keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints, - checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")), + checkpoint_path=str( + files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}") + ), batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu, batch_size_type=model_cfg.datasets.batch_size_type, max_samples=model_cfg.datasets.max_samples, @@ -65,7 +76,9 @@ def main(model_cfg): model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True), ) - train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec) + train_dataset = load_dataset( + model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec + ) trainer.train( train_dataset, num_workers=model_cfg.datasets.num_workers, diff --git a/grpo_duration_trainer.py b/grpo_duration_trainer.py index 
34b98c2be168b3fe491b855071dbd4f6d74340a8..dde60f2d0c4100e508c250d19d92fd89c2d61f86 100644 --- a/grpo_duration_trainer.py +++ b/grpo_duration_trainer.py @@ -1,25 +1,24 @@ -import os +import copy import gc +import io import json +import os import random import time -import io -import copy -from typing import List, Dict, Any, Optional, Callable, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +import wandb +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset from tqdm import tqdm -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs -import wandb - -from f5_tts.model.dataset import collate_fn, DynamicBatchSampler +from f5_tts.model.dataset import DynamicBatchSampler, collate_fn from f5_tts.model.utils import list_str_to_idx # torch.autograd.set_detect_anomaly(True) @@ -33,16 +32,16 @@ def safe_sample(logits, temperature=1.0): """ # Apply temperature scaling scaled_logits = logits / temperature - + # Compute categorical distribution probs = F.softmax(scaled_logits, dim=-1) - + # Sample from the distribution once per batch element samples = torch.multinomial(probs, num_samples=1) # (B, 1) - + # Convert to one-hot encoding one_hot_samples = torch.zeros_like(probs).scatter_(1, samples, 1) - + return one_hot_samples @@ -51,51 +50,45 @@ class GRPODurationTrainer: Trainer class that implements GRPO (Generative Reinforcement Learning from Preference Optimization) for a duration predictor in text-to-speech synthesis. 
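A quick usage check for `safe_sample` above: temperature-scaled categorical sampling that returns one one-hot row per batch element (toy logits, not model outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 0.0, 3.0]])

# Same steps as safe_sample: scale, softmax, one categorical draw, one-hot encode.
probs = F.softmax(logits / 0.5, dim=-1)             # temperature = 0.5
samples = torch.multinomial(probs, num_samples=1)   # (B, 1) sampled class indices
one_hot = torch.zeros_like(probs).scatter_(1, samples, 1)

assert torch.all(one_hot.sum(dim=-1) == 1)          # exactly one class per row
print(samples.squeeze(-1), one_hot)
```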
""" + def __init__( self, - model, # Duration predictor model - inference_fn, # Function to generate speech - reward_fn, # Function to compute rewards from generated speech - - vocab_size: int, # Size of the vocabulary - vocab_char_map: dict, # Mapping from characters to token IDs - + model, # Duration predictor model + inference_fn, # Function to generate speech + reward_fn, # Function to compute rewards from generated speech + vocab_size: int, # Size of the vocabulary + vocab_char_map: dict, # Mapping from characters to token IDs # Duration model parameters - n_class: int = 301, # Number of duration classes - n_frame_per_class: int = 10, # Number of frames per class + n_class: int = 301, # Number of duration classes + n_frame_per_class: int = 10, # Number of frames per class gumbel_tau: int = 0.7, - # GRPO parameters - beta: float = 0.04, # KL regularization weight - clip_param: float = 0.2, # PPO clip parameter - num_pre_samples: int = 8, # Number of samples per prompt - compute_gen_logps: bool = True, # Whether to compute generation log probabilities - + beta: float = 0.04, # KL regularization weight + clip_param: float = 0.2, # PPO clip parameter + num_pre_samples: int = 8, # Number of samples per prompt + compute_gen_logps: bool = True, # Whether to compute generation log probabilities # Training parameters learning_rate: float = 5e-6, num_warmup_updates: int = 10000, save_per_updates: int = 10000, checkpoint_path: Optional[str] = None, - all_steps: int = 100000, # Total training steps - + all_steps: int = 100000, # Total training steps # Batch parameters batch_size: int = 8, batch_size_type: str = "sample", max_samples: int = 16, grad_accumulation_steps: int = 2, max_grad_norm: float = 1.0, - # Logging parameters logger: Optional[str] = "wandb", wandb_project: str = "tts-duration-grpo", wandb_run_name: str = "grpo_run", wandb_resume_id: Optional[str] = None, - accelerate_kwargs: dict = dict(), ): # Initialize accelerator for distributed training ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) - + if logger == "wandb" and not wandb.api.api_key: logger = None print(f"Using logger: {logger}") @@ -110,7 +103,13 @@ class GRPODurationTrainer: self.logger = logger if self.logger == "wandb": if wandb_resume_id: - init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} + init_kwargs = { + "wandb": { + "resume": "allow", + "name": wandb_run_name, + "id": wandb_resume_id, + } + } else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} @@ -134,24 +133,27 @@ class GRPODurationTrainer: ) elif self.logger == "tensorboard": from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}") # Store model, inference function, and reward function self.model = model - + # Create reference model (frozen clone of the initial model) self.ref_model = copy.deepcopy(model) for param in self.ref_model.parameters(): param.requires_grad = False self.ref_model.eval() - + # prepare inference_fn self.inference_fn = inference_fn self.inference_fn.scale = self.inference_fn.scale.to(self.accelerator.device) - self.inference_fn.tts_model = self.inference_fn.tts_model.to(self.accelerator.device) + self.inference_fn.tts_model = self.inference_fn.tts_model.to( + self.accelerator.device + ) # prepare reward_fn self.reward_fn = reward_fn - + # Store vocabulary and mapping self.vocab_size = vocab_size self.vocab_char_map = vocab_char_map @@ -160,70 +162,87 @@ class GRPODurationTrainer: self.n_class 
= n_class self.n_frame_per_class = n_frame_per_class self.gumbel_tau = gumbel_tau - + # Store GRPO parameters self.beta = beta self.clip_param = clip_param self.num_pre_samples = num_pre_samples self.compute_gen_logps = compute_gen_logps - + # Store training parameters self.learning_rate = learning_rate self.num_warmup_updates: int = num_warmup_updates self.save_per_updates = save_per_updates self.checkpoint_path = checkpoint_path or f"ckpts/{wandb_run_name}" self.all_steps = all_steps - + # Store batch parameters self.batch_size = batch_size self.batch_size_type = batch_size_type self.max_samples = max_samples self.grad_accumulation_steps = grad_accumulation_steps self.max_grad_norm = max_grad_norm - + # Initialize optimizer self.optimizer = AdamW(model.parameters(), lr=learning_rate) - + # Prepare model and optimizer with accelerator - self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + self.model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) self.ref_model = self.accelerator.prepare(self.ref_model) - self.reward_fn, self.inference_fn = self.accelerator.prepare(self.reward_fn, self.inference_fn) - + self.reward_fn, self.inference_fn = self.accelerator.prepare( + self.reward_fn, self.inference_fn + ) + # GRPO batch queue self.batch_queue = [] - + # Store distributed rank - self.rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + self.rank = ( + torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + ) + + self.device = f"cuda:{self.rank}" - self.device = f'cuda:{self.rank}' - @property def is_main(self): return self.accelerator.is_main_process - + def save_checkpoint(self, step, last=False): """Save model and optimizer state""" self.accelerator.wait_for_everyone() if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), - optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), - scheduler_state_dict=self.scheduler.state_dict() if hasattr(self, 'scheduler') else None, + optimizer_state_dict=self.accelerator.unwrap_model( + self.optimizer + ).state_dict(), + scheduler_state_dict=( + self.scheduler.state_dict() if hasattr(self, "scheduler") else None + ), step=step, ) if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_last.pt" + ) else: - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") - + self.accelerator.save( + checkpoint, f"{self.checkpoint_path}/model_{step}.pt" + ) + def load_checkpoint(self): """Load latest checkpoint if available""" if ( not self.checkpoint_path or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path)) + or not any( + filename.endswith(".pt") + for filename in os.listdir(self.checkpoint_path) + ) ): return 0 @@ -236,29 +255,35 @@ class GRPODurationTrainer: key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] - print(f'Loading checkpoint: {latest_checkpoint}') + print(f"Loading checkpoint: {latest_checkpoint}") checkpoint = torch.load( - f"{self.checkpoint_path}/{latest_checkpoint}", - weights_only=True, - map_location="cpu" + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", ) if "step" in checkpoint: - 
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) - self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"]) - if hasattr(self, 'scheduler') and checkpoint["scheduler_state_dict"]: + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) + self.accelerator.unwrap_model(self.optimizer).load_state_dict( + checkpoint["optimizer_state_dict"] + ) + if hasattr(self, "scheduler") and checkpoint["scheduler_state_dict"]: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) step = checkpoint["step"] else: - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.model).load_state_dict( + checkpoint["model_state_dict"] + ) step = 0 - + del checkpoint gc.collect() - - print(f'Successfully loaded checkpoint at step {step}') + + print(f"Successfully loaded checkpoint at step {step}") return step - + @torch.no_grad() def get_ref_logps(self, text_ids, mel, sampled_classes): """ @@ -268,31 +293,29 @@ class GRPODurationTrainer: K = self.num_pre_samples with torch.no_grad(): ref_logits = self.ref_model(text_ids=text_ids, mel=mel)[:, -1, :] - ref_logits = ref_logits.unsqueeze(1).repeat(1, K, 1).view(B*K, -1) + ref_logits = ref_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, -1) ref_log_probs = F.log_softmax(ref_logits, dim=-1) ref_logps = torch.gather( - ref_log_probs, - dim=-1, - index=sampled_classes.unsqueeze(-1) + ref_log_probs, dim=-1, index=sampled_classes.unsqueeze(-1) ).squeeze(-1) return ref_logps - + @torch.no_grad() def generate_duration_samples(self, batch_inputs): """ Generate multiple duration predictions from the model for each input and evaluate them using the inference function and reward model - + Args: batch_inputs: Dictionary with text, prompt audio, etc. 
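The `torch.gather` pattern in `get_ref_logps` simply reads off each sampled duration class's log-probability after repeating every prompt's logits K times. A standalone shape check with hypothetical sizes (B=2 prompts, K=3 samples each):

```python
import torch
import torch.nn.functional as F

B, K, n_class = 2, 3, 5
ref_logits = torch.randn(B, n_class)

# Repeat each prompt's logits K times so they align with the K sampled classes per prompt.
ref_logits = ref_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, n_class)
sampled_classes = torch.randint(0, n_class, (B * K,))

ref_log_probs = F.log_softmax(ref_logits, dim=-1)
ref_logps = torch.gather(ref_log_probs, dim=-1, index=sampled_classes.unsqueeze(-1)).squeeze(-1)

assert ref_logps.shape == (B * K,)   # one reference log-prob per sampled duration
```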
- + Returns: Dictionary with duration samples, rewards, and reference logits """ if self.rank == 0: print("Generating duration samples...") - + # all_logits = [] all_text_ids = [] all_mels = [] @@ -306,30 +329,44 @@ class GRPODurationTrainer: # Fetch batch inputs # prompt_mel = batch_inputs['mel'].permute(0, 2, 1).to(self.device) - prompt_mel = batch_inputs['mel'].permute(0, 2, 1) # (B, T, 100) - prompt_text = batch_inputs['text'] + prompt_mel = batch_inputs["mel"].permute(0, 2, 1) # (B, T, 100) + prompt_text = batch_inputs["text"] batch_size = prompt_mel.shape[0] # Shift text to unpair 'mel' and 'text'; The shifted text will be synthesized - target_text = batch_inputs['target_text'] - target_text_lengths = torch.LongTensor([len(t) for t in target_text]).to(prompt_mel.device) + target_text = batch_inputs["target_text"] + target_text_lengths = torch.LongTensor([len(t) for t in target_text]).to( + prompt_mel.device + ) try: - full_text = [prompt+[' ']+target for prompt, target in zip(prompt_text, target_text)] + full_text = [ + prompt + [" "] + target + for prompt, target in zip(prompt_text, target_text) + ] except: - target_text = [batch_inputs['text'][-1]] + batch_inputs['text'][:-1] - target_text_lengths = batch_inputs['text_lengths'].clone().roll(1, 0) - full_text = [prompt+[' ']+target for prompt, target in zip(prompt_text, target_text)] + target_text = [batch_inputs["text"][-1]] + batch_inputs["text"][:-1] + target_text_lengths = batch_inputs["text_lengths"].clone().roll(1, 0) + full_text = [ + prompt + [" "] + target + for prompt, target in zip(prompt_text, target_text) + ] # Goes to reward model - target_text_ids = list_str_to_idx(target_text, self.vocab_char_map).to(self.accelerator.device) # to device, the dataloader only gives list + target_text_ids = list_str_to_idx(target_text, self.vocab_char_map).to( + self.accelerator.device + ) # to device, the dataloader only gives list # Goes to duration model and TTS - full_text_ids = list_str_to_idx(full_text, self.vocab_char_map).to(self.accelerator.device) + full_text_ids = list_str_to_idx(full_text, self.vocab_char_map).to( + self.accelerator.device + ) # Deepcopy to separate text_ids for SLP and TTS slp_text_ids = full_text_ids.detach().clone() - slp_text_ids = slp_text_ids.masked_fill(slp_text_ids==-1, self.vocab_size) # (B, L) + slp_text_ids = slp_text_ids.masked_fill( + slp_text_ids == -1, self.vocab_size + ) # (B, L) # Pre-compute duration logits K = self.num_pre_samples @@ -340,40 +377,50 @@ class GRPODurationTrainer: # Run model once for B inputs old_logits = self.model( - text_ids=slp_text_ids, # (B, L) - mel=prompt_mel # (B, T, 100) - )[:, -1, :] # (B, n_class) + text_ids=slp_text_ids, mel=prompt_mel # (B, L) # (B, T, 100) + )[ + :, -1, : + ] # (B, n_class) # Repeat each result K times along batch dimension - old_logits = old_logits.unsqueeze(1).repeat(1, K, 1) # (B, K, n_class) + old_logits = old_logits.unsqueeze(1).repeat(1, K, 1) # (B, K, n_class) # logits_nograd = logits_grad.detach().clone().view(B, K, -1) # (B, K, n_class) - for _full_text_ids, _target_text_ids, _target_text_lengths, \ - _prompt_mel, _old_logits in zip( - full_text_ids, target_text_ids, target_text_lengths, - prompt_mel, old_logits - ): + for ( + _full_text_ids, + _target_text_ids, + _target_text_lengths, + _prompt_mel, + _old_logits, + ) in zip( + full_text_ids, target_text_ids, target_text_lengths, prompt_mel, old_logits + ): - duration_sample = F.gumbel_softmax(_old_logits, tau=self.gumbel_tau, hard=True, dim=-1) - duration2frames = 
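The `masked_fill(text_ids == -1, self.vocab_size)` step above exists because the tokenizer output is padded with `-1`, which cannot index an embedding table; padding is remapped to the extra row at index `vocab_size`. A minimal illustration with a hypothetical 4-token vocabulary (not the repo's tokenizer):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

vocab_size = 4                                     # hypothetical tiny vocab, real tokens 0..3
seqs = [torch.tensor([0, 2, 3]), torch.tensor([1, 3])]

# Pad with -1 (as list_str_to_idx does), then remap padding to the reserved extra index.
text_ids = pad_sequence(seqs, batch_first=True, padding_value=-1)
model_ids = text_ids.masked_fill(text_ids == -1, vocab_size)

emb = torch.nn.Embedding(vocab_size + 1, 8)        # one extra embedding row for padding
out = emb(model_ids)                               # works; emb(text_ids) would fail on -1
print(model_ids)
```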
torch.arange(self.n_class).float().to(self.accelerator.device) * self.n_frame_per_class - est_frames = (duration_sample * duration2frames).sum(-1) # (K, ) + duration_sample = F.gumbel_softmax( + _old_logits, tau=self.gumbel_tau, hard=True, dim=-1 + ) + duration2frames = ( + torch.arange(self.n_class).float().to(self.accelerator.device) + * self.n_frame_per_class + ) + est_frames = (duration_sample * duration2frames).sum(-1) # (K, ) # Compute log probabilities of the samples sampled_classes = duration_sample.argmax(dim=-1) log_probs = F.log_softmax(_old_logits, dim=-1) gen_logps = torch.gather( - log_probs, - dim=-1, - index=sampled_classes.unsqueeze(-1) - ).squeeze(-1) # Shape: [K, n_class] - + log_probs, dim=-1, index=sampled_classes.unsqueeze(-1) + ).squeeze( + -1 + ) # Shape: [K, n_class] + # Generate speech using the sampled durations sampled_rewards = [] for i in range(K): cur_duration = est_frames[i] if cur_duration == 0: - cur_duration = cur_duration + 50 # prevent 0 duration + cur_duration = cur_duration + 50 # prevent 0 duration infer_full_text_ids = _full_text_ids.unsqueeze(0) infer_prompt_mel = _prompt_mel.unsqueeze(0) cur_duration = cur_duration.unsqueeze(0) @@ -382,22 +429,22 @@ class GRPODurationTrainer: with torch.inference_mode(): try: _est_mel = self.inference_fn( - full_text_ids=infer_full_text_ids, - prompt_mel=infer_prompt_mel, - target_duration=cur_duration, - teacher_steps=0 + full_text_ids=infer_full_text_ids, + prompt_mel=infer_prompt_mel, + target_duration=cur_duration, + teacher_steps=0, ) - _est_mel = _est_mel.permute(0, 2, 1) # (1, T, 100) - + _est_mel = _est_mel.permute(0, 2, 1) # (1, T, 100) + loss_dict = self.reward_fn( prompt_mel=infer_prompt_mel, est_mel=_est_mel, target_text_id=infer_target_text_ids, - target_text_length=infer_target_text_lengths + target_text_length=infer_target_text_lengths, ) # #TODO reweight the loss for reward - reward_sim = loss_dict['loss_sim'] # 0 to 1 - reward_ctc = loss_dict['loss_ctc'] + reward_sim = loss_dict["loss_sim"] # 0 to 1 + reward_ctc = loss_dict["loss_ctc"] reward = -(reward_ctc + reward_sim * 3) all_ctc_loss.append(reward_ctc) all_sv_loss.append(reward_sim) @@ -405,13 +452,15 @@ class GRPODurationTrainer: if self.rank == 0: print(f"Error in speech synthesis: {e}") reward = torch.tensor(-1.0).to(cur_duration.device) - + sampled_rewards.append(reward) - # list with length of K + # list with length of K sampled_rewards = torch.stack(sampled_rewards) # (K, ) # Normalize rewards if (sampled_rewards.max() - sampled_rewards.min()).item() > 1e-6: - sampled_rewards = (sampled_rewards - sampled_rewards.mean()) / (sampled_rewards.std() + 1e-8) + sampled_rewards = (sampled_rewards - sampled_rewards.mean()) / ( + sampled_rewards.std() + 1e-8 + ) # Store all data # all_logits.append(duration_logits) @@ -421,27 +470,32 @@ class GRPODurationTrainer: all_durations.append(est_frames) all_gen_logps.append(gen_logps) all_rewards.extend(sampled_rewards) # list with length of B*K - + # Concatenate all data # logits = torch.cat(all_logits, dim=0) # text_ids = torch.cat(all_text_ids, dim=0) # mels = torch.cat(all_mels, dim=0) sampled_classes = torch.cat(all_sampled_classes, dim=0) durations = torch.cat(all_durations, dim=0) - rewards = torch.stack(all_rewards) # use stack to keep the same device of elements + rewards = torch.stack( + all_rewards + ) # use stack to keep the same device of elements gen_logps = torch.cat(all_gen_logps, dim=0) ctc_losses = torch.stack(all_ctc_loss) sv_losses = torch.stack(all_sv_loss) - + if self.is_main: - 
self.accelerator.log({ - "ctc_loss": ctc_losses.mean().item(), - "sv_loss": sv_losses.mean().item(), - "reward": rewards.mean().item(), - "reward_min": rewards.min().item(), - "reward_max": rewards.max().item(), - }, step=self.global_step) + self.accelerator.log( + { + "ctc_loss": ctc_losses.mean().item(), + "sv_loss": sv_losses.mean().item(), + "reward": rewards.mean().item(), + "reward_min": rewards.min().item(), + "reward_max": rewards.max().item(), + }, + step=self.global_step, + ) # # Normalize rewards # if (rewards.max() - rewards.min()).item() > 1e-6: @@ -459,88 +513,89 @@ class GRPODurationTrainer: "sampled_classes": sampled_classes, "durations": durations, } - + if self.compute_gen_logps: batch_outputs["gen_logps"] = gen_logps - + if self.rank == 0: - print(f"Generated {len(rewards)} samples with reward min/mean/max: {rewards.min().item():.4f}/{rewards.mean().item():.4f}/{rewards.max().item():.4f}") - + print( + f"Generated {len(rewards)} samples with reward min/mean/max: {rewards.min().item():.4f}/{rewards.mean().item():.4f}/{rewards.max().item():.4f}" + ) + return batch_outputs - + def GRPO_step(self, batch): """ Perform a GRPO update step - + Args: batch: Dictionary with inputs, rewards, reference logits, etc. - + Returns: Loss value """ # Extract batch data # NOTE: why .unsqueeze(1) ??? - rewards = batch['rewards'] #.unsqueeze(1) - ref_logps = batch['refs'] # (B) - sampled_classes = batch['sampled_classes'] # (B) - prompt_mel = batch['prompt_mel'] - text_ids = batch['text_ids'] + rewards = batch["rewards"] # .unsqueeze(1) + ref_logps = batch["refs"] # (B) + sampled_classes = batch["sampled_classes"] # (B) + prompt_mel = batch["prompt_mel"] + text_ids = batch["text_ids"] # Forward pass to get current model logits K = self.num_pre_samples B, T, _ = prompt_mel.shape _, L = text_ids.shape cur_logits = self.model( - text_ids=text_ids, # (B, L) - mel=prompt_mel # (B, T, 100) + text_ids=text_ids, mel=prompt_mel # (B, L) # (B, T, 100) )[:, -1, :] - cur_logits = cur_logits.unsqueeze(1).repeat(1, K, 1).view(B*K, -1) + cur_logits = cur_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, -1) # Compute current log probabilities for sampled actions log_probs = F.log_softmax(cur_logits, dim=-1) cur_logps = torch.gather( - log_probs, - dim=-1, - index=sampled_classes.unsqueeze(-1) - ).squeeze(-1) # (B) + log_probs, dim=-1, index=sampled_classes.unsqueeze(-1) + ).squeeze( + -1 + ) # (B) # KL divergence computation (same as in Qwen2.5 code) # KL = exp(ref - cur) - (ref - cur) - 1 - kl_div = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1 # (B) - + kl_div = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1 # (B) + # Compute probability ratio for PPO if "gen_logps" in batch: - gen_logps = batch['gen_logps'] + gen_logps = batch["gen_logps"] ratio = torch.exp(cur_logps - gen_logps) clipped_ratio = torch.clamp(ratio, 1 - self.clip_param, 1 + self.clip_param) loss = torch.min(ratio * rewards, clipped_ratio * rewards) else: # Simplification if gen_logps not available loss = torch.exp(cur_logps - cur_logps.detach()) * rewards - + # Final GRPO loss with KL regularization - loss = -(loss - self.beta * kl_div) # (B) + loss = -(loss - self.beta * kl_div) # (B) loss = loss.mean() - + return loss - + def get_batch(self): """Get a batch from the queue or return None if empty""" if not self.batch_queue: return None return self.batch_queue.pop(0) - + def generate_mode(self, num_batches=5): """ Generate samples and add them to the batch queue - + Args: dataset: Dataset to sample from 
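The heart of `GRPO_step` (GRPO is usually expanded as Group Relative Policy Optimization) on toy tensors: rewards normalized within a group act as advantages, the PPO-style ratio is clipped, and `exp(ref - cur) - (ref - cur) - 1` is the standard non-negative KL estimator against the frozen reference. The numbers below are made up purely for illustration:

```python
import torch

beta, clip_param = 0.04, 0.2

# Per-sample log-probs of the chosen duration class (toy numbers).
cur_logps = torch.tensor([-1.2, -0.8, -2.0])   # current policy
gen_logps = torch.tensor([-1.0, -1.0, -1.9])   # policy that generated the samples
ref_logps = torch.tensor([-1.1, -0.9, -2.1])   # frozen reference policy

# Group-relative advantages: rewards normalized within the K samples of one prompt.
rewards = torch.tensor([0.2, -0.5, 1.3])
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

ratio = torch.exp(cur_logps - gen_logps)
clipped = torch.clamp(ratio, 1 - clip_param, 1 + clip_param)
surrogate = torch.min(ratio * rewards, clipped * rewards)

kl = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1
loss = -(surrogate - beta * kl).mean()
print(loss)
```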
num_batches: Number of batches to generate """ if self.rank == 0: print("Entering generate mode...") - + tic = time.time() for _ in range(num_batches): try: @@ -559,14 +614,16 @@ class GRPODurationTrainer: continue # Add batch to queue self.batch_queue.append(batch_outputs) - + if self.rank == 0: print(f"Exiting generate mode: {time.time() - tic:.3f}s") - - def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_with_seed=666): + + def train( + self, train_dataset, valid_dataset=None, num_workers=64, resumable_with_seed=666 + ): """ Train the model using GRPO - + Args: train_dataset: Training dataset valid_dataset: Validation dataset (optional) @@ -597,13 +654,17 @@ class GRPODurationTrainer: self.train_iterator = iter(self.train_dataloader) self.valid_iterator = iter(self.valid_dataloader) - + elif self.batch_size_type == "frame": self.accelerator.even_batches = False sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) self.train_dataloader = DataLoader( train_dataset, @@ -616,114 +677,132 @@ class GRPODurationTrainer: sampler = SequentialSampler(valid_dataset) batch_sampler = DynamicBatchSampler( - sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + sampler, + self.batch_size, + max_samples=self.max_samples, + random_seed=resumable_with_seed, + drop_last=False, ) # Create validation dataloader (always sequential, no shuffling) self.valid_dataloader = DataLoader( valid_dataset, collate_fn=collate_fn, num_workers=num_workers, - pin_memory=True, + pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) - - self.train_dataloader, self.valid_dataloader = self.accelerator.prepare(self.train_dataloader, self.valid_dataloader) + + self.train_dataloader, self.valid_dataloader = self.accelerator.prepare( + self.train_dataloader, self.valid_dataloader + ) self.train_iterator = iter(self.train_dataloader) self.valid_iterator = iter(self.valid_dataloader) else: - raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - + raise ValueError( + f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}" + ) # Setup schedulers warmup_steps = self.num_warmup_updates * self.accelerator.num_processes total_steps = self.all_steps decay_steps = total_steps - warmup_steps - - warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) - decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) - + + warmup_scheduler = LinearLR( + self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps + ) + decay_scheduler = LinearLR( + self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps + ) + self.scheduler = SequentialLR( - self.optimizer, - schedulers=[warmup_scheduler, decay_scheduler], - milestones=[warmup_steps] + self.optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[warmup_steps], ) - + self.scheduler = self.accelerator.prepare(self.scheduler) - + # Load checkpoint if available start_step = self.load_checkpoint() self.global_step = start_step - + # Generate initial batches self.generate_mode() - + # Training loop progress = range(1, 
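The learning-rate schedule built above is two `LinearLR` ramps chained by `SequentialLR` (warm up to the base LR, then decay back toward zero). A self-contained sketch with a dummy parameter and tiny step counts so the shape is easy to see:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))
opt = AdamW([param], lr=5e-6)

warmup_steps, decay_steps = 5, 15
warmup = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(opt, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
sched = SequentialLR(opt, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    opt.step()                                # in real training this follows backward()
    sched.step()
    print(step, sched.get_last_lr()[0])       # ramps up to 5e-6, then back down
```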
self.all_steps + 1) - + # Skip steps that are already done progress = [step for step in progress if step > start_step] if self.is_main: progress = tqdm(progress, desc="Training", unit="step") - + for step in progress: # Get batch from queue or generate more batch = self.get_batch() while batch is None: self.generate_mode() batch = self.get_batch() - + # GRPO update with self.accelerator.accumulate(self.model): loss = self.GRPO_step(batch) # for param in self.model.parameters(): - # custom_loss = loss + 0 * param.sum() + # custom_loss = loss + 0 * param.sum() self.accelerator.backward(loss) - + if self.max_grad_norm > 0 and self.accelerator.sync_gradients: - total_norm = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + total_norm = self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) else: total_norm = torch.norm( - torch.stack([ - torch.norm(p.grad.detach(), 2) - for p in self.model.parameters() - if p.grad is not None - ]), - 2 + torch.stack( + [ + torch.norm(p.grad.detach(), 2) + for p in self.model.parameters() + if p.grad is not None + ] + ), + 2, ) - - self.accelerator.log({ - "grad_norm": total_norm.item() - }, step=self.global_step) + + self.accelerator.log( + {"grad_norm": total_norm.item()}, step=self.global_step + ) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() - + self.global_step += 1 - + # Log metrics if self.is_main: - self.accelerator.log({ - "loss": loss.item(), - "lr": self.scheduler.get_last_lr()[0], - # "avg_reward": batch["rewards"].mean().item(), - # "max_reward": batch["rewards"].max().item(), - # "min_reward": batch["rewards"].min().item(), - }, step=self.global_step) + self.accelerator.log( + { + "loss": loss.item(), + "lr": self.scheduler.get_last_lr()[0], + # "avg_reward": batch["rewards"].mean().item(), + # "max_reward": batch["rewards"].max().item(), + # "min_reward": batch["rewards"].min().item(), + }, + step=self.global_step, + ) progress.set_postfix( loss=f"{loss.item():.4f}", - lr=f"{self.scheduler.get_last_lr()[0]:.8f}" + lr=f"{self.scheduler.get_last_lr()[0]:.8f}", ) - + # Save checkpoint if self.global_step % self.save_per_updates == 0: self.save_checkpoint(self.global_step) - + # Optional validation logic could be added here - + # Save final checkpoint self.save_checkpoint(self.global_step, last=True) self.accelerator.end_training() diff --git a/guidance_model.py b/guidance_model.py index cf064fe6edecd90a0cd77fcaaf233e157dbc2395..339755bb8cc68aed73c07f02ed96259b97c6c379 100644 --- a/guidance_model.py +++ b/guidance_model.py @@ -8,28 +8,22 @@ d - dimension """ from __future__ import annotations -from typing import Callable + from random import random -import numpy as np +from typing import Callable +import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn -from f5_tts.model import DiT - -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) - -from discriminator_conformer import ConformerDiscirminator from ctcmodel import ConformerCTC +from discriminator_conformer import ConformerDiscirminator from ecapa_tdnn import ECAPA_TDNN +from f5_tts.model import DiT +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) + class NoOpContext: def __enter__(self): @@ -38,118 +32,147 @@ class NoOpContext: def __exit__(self, *args): pass -def predict_flow(transformer, # flow 
model - x, # noisy input - cond, # mask (prompt mask + length mask) - text, # text input - time, # time step - second_time=None, - cfg_strength=1.0 + +def predict_flow( + transformer, # flow model + x, # noisy input + cond, # mask (prompt mask + length mask) + text, # text input + time, # time step + second_time=None, + cfg_strength=1.0, ): pred = transformer( - x=x, - cond=cond, - text=text, time=time, + x=x, + cond=cond, + text=text, + time=time, second_time=second_time, - drop_audio_cond=False, - drop_text=False + drop_audio_cond=False, + drop_text=False, ) - + if cfg_strength < 1e-5: return pred - + null_pred = transformer( - x=x, - cond=cond, - text=text, time=time, - second_time=second_time, - drop_audio_cond=True, - drop_text=True + x=x, + cond=cond, + text=text, + time=time, + second_time=second_time, + drop_audio_cond=True, + drop_text=True, ) return pred + (pred - null_pred) * cfg_strength + def _kl_dist_func(x, y): log_probs = F.log_softmax(x, dim=2) - target_probs = F.log_softmax(y, dim=2) - return torch.nn.functional.kl_div(log_probs, target_probs, reduction="batchmean", log_target=True) + target_probs = F.log_softmax(y, dim=2) + return torch.nn.functional.kl_div( + log_probs, target_probs, reduction="batchmean", log_target=True + ) class Guidance(nn.Module): - def __init__(self, - real_unet: DiT, # teacher flow model - fake_unet: DiT, # student flow model - - use_fp16: bool = True, - real_guidance_scale: float = 0.0, - fake_guidance_scale: float = 0.0, - gen_cls_loss: bool = False, - - sv_path_en: str = "", - sv_path_zh: str = "", - ctc_path: str = "", - sway_coeff: float = 0.0, - scale: float = 1.0, - ): + def __init__( + self, + real_unet: DiT, # teacher flow model + fake_unet: DiT, # student flow model + use_fp16: bool = True, + real_guidance_scale: float = 0.0, + fake_guidance_scale: float = 0.0, + gen_cls_loss: bool = False, + sv_path_en: str = "", + sv_path_zh: str = "", + ctc_path: str = "", + sway_coeff: float = 0.0, + scale: float = 1.0, + ): super().__init__() self.vocab_size = real_unet.vocab_size - + if ctc_path != "": - model = ConformerCTC(vocab_size=real_unet.vocab_size, mel_dim=real_unet.mel_dim, num_heads=8, d_hid=512, nlayers=6) + model = ConformerCTC( + vocab_size=real_unet.vocab_size, + mel_dim=real_unet.mel_dim, + num_heads=8, + d_hid=512, + nlayers=6, + ) self.ctc_model = model.eval() self.ctc_model.requires_grad_(False) - self.ctc_model.load_state_dict(torch.load(ctc_path, weights_only=True, map_location='cpu')['model_state_dict']) + self.ctc_model.load_state_dict( + torch.load(ctc_path, weights_only=True, map_location="cpu")[ + "model_state_dict" + ] + ) if sv_path_en != "": model = ECAPA_TDNN() self.sv_model_en = model.eval() self.sv_model_en.requires_grad_(False) - self.sv_model_en.load_state_dict(torch.load(sv_path, weights_only=True, map_location='cpu')['model_state_dict']) + self.sv_model_en.load_state_dict( + torch.load(sv_path, weights_only=True, map_location="cpu")[ + "model_state_dict" + ] + ) if sv_path_zh != "": model = ECAPA_TDNN() self.sv_model_zh = model.eval() self.sv_model_zh.requires_grad_(False) - self.sv_model_zh.load_state_dict(torch.load(sv_path_zh, weights_only=True, map_location='cpu')['model_state_dict']) + self.sv_model_zh.load_state_dict( + torch.load(sv_path_zh, weights_only=True, map_location="cpu")[ + "model_state_dict" + ] + ) self.scale = scale - + self.real_unet = real_unet - self.real_unet.requires_grad_(False) # no update on the teacher model + self.real_unet.requires_grad_(False) # no update on the teacher model 
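`predict_flow` applies classifier-free guidance by extrapolating the conditional prediction away from the unconditional one, `pred + (pred - null_pred) * cfg_strength`, and falls back to the plain conditional output when the strength is near zero. A tiny numeric illustration with toy vectors:

```python
import torch


def cfg_combine(pred: torch.Tensor, null_pred: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    """Classifier-free guidance combination as in predict_flow above."""
    if cfg_strength < 1e-5:
        return pred
    return pred + (pred - null_pred) * cfg_strength


pred = torch.tensor([1.0, 2.0])        # conditional flow prediction (toy)
null_pred = torch.tensor([0.5, 2.5])   # prediction with audio/text conditioning dropped (toy)

print(cfg_combine(pred, null_pred, 0.0))  # tensor([1.0, 2.0])
print(cfg_combine(pred, null_pred, 2.0))  # tensor([2.0, 1.0]) -- pushed away from null_pred
```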
self.fake_unet = fake_unet - self.fake_unet.requires_grad_(True) # update the student model - - self.real_guidance_scale = real_guidance_scale + self.fake_unet.requires_grad_(True) # update the student model + + self.real_guidance_scale = real_guidance_scale self.fake_guidance_scale = fake_guidance_scale - + assert self.fake_guidance_scale == 0, "no guidance for fake" self.use_fp16 = use_fp16 - self.gen_cls_loss = gen_cls_loss - + self.gen_cls_loss = gen_cls_loss + self.sway_coeff = sway_coeff - + if self.gen_cls_loss: self.cls_pred_branch = ConformerDiscirminator( - input_dim=(self.fake_unet.depth + 1) * self.fake_unet.dim + 3 * 512, # 3 is the number of layers from the CTC model + input_dim=(self.fake_unet.depth + 1) * self.fake_unet.dim + + 3 * 512, # 3 is the number of layers from the CTC model num_layers=3, channels=self.fake_unet.dim // 2, ) self.cls_pred_branch.requires_grad_(True) - - self.network_context_manager = torch.autocast(device_type="cuda", dtype=torch.float16) if self.use_fp16 else NoOpContext() + self.network_context_manager = ( + torch.autocast(device_type="cuda", dtype=torch.float16) + if self.use_fp16 + else NoOpContext() + ) - from f5_tts.model.utils import get_tokenizer from torch.utils.data import DataLoader, Dataset, SequentialSampler - from f5_tts.model.dataset import load_dataset - from f5_tts.model.dataset import DynamicBatchSampler, collate_fn + + from f5_tts.model.dataset import (DynamicBatchSampler, collate_fn, + load_dataset) + from f5_tts.model.utils import get_tokenizer bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -161,16 +184,15 @@ class Guidance(nn.Module): self.vocab_char_map = vocab_char_map - - def compute_distribution_matching_loss( - self, - + self, inp: float["b n d"] | float["b nw"], # mel or raw wave, ground truth latent text: int["b nt"] | list[str], # text input *, second_time: torch.Tensor | None = None, # second time step for flow prediction - rand_span_mask: bool["b n d"] | bool["b nw"] | None = None, # combined mask (prompt mask + padding mask) + rand_span_mask: ( + bool["b n d"] | bool["b nw"] | None + ) = None, # combined mask (prompt mask + padding mask) ): """ Compute DMD loss (L_DMD) between the student distribution and teacher distribution. @@ -183,12 +205,12 @@ class Guidance(nn.Module): The code is adapted from F5-TTS but conceptualized per DMD: L_DMD encourages p_theta to match p_data via the difference between teacher and student predictions. 
""" - + original_inp = inp - + with torch.no_grad(): batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # mel is x1 x1 = inp @@ -197,68 +219,77 @@ class Guidance(nn.Module): # time step time = torch.rand((batch,), dtype=dtype, device=device) - + # get flow t = time.unsqueeze(-1).unsqueeze(-1) # t = t + self.sway_coeff * (torch.cos(torch.pi / 2 * t) - 1 + t) sigma_t, alpha_t = (1 - t), t - phi = (1 - t) * x0 + t * x1 # noisy x - flow = x1 - x0 # flow target - + phi = (1 - t) * x0 + t * x1 # noisy x + flow = x1 - x0 # flow target + # only predict what is within the random mask span for infilling cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - - # run at full precision as autocast and no_grad doesn't work well together + + # run at full precision as autocast and no_grad doesn't work well together with self.network_context_manager: pred_fake = predict_flow( - self.fake_unet, - phi, - cond, # mask (prompt mask + length mask) - text, # text input - time, # time step + self.fake_unet, + phi, + cond, # mask (prompt mask + length mask) + text, # text input + time, # time step second_time=second_time, - cfg_strength=self.fake_guidance_scale + cfg_strength=self.fake_guidance_scale, ) - # pred = (x1 - x0), thus phi + (1-t) * pred = (1 - t) * x0 + t * x1 + (1 - t) * (x1 - x0) = (1 - t) * x1 + t * x1 = x1 + # pred = (x1 - x0), thus phi + (1-t) * pred = (1 - t) * x0 + t * x1 + (1 - t) * (x1 - x0) = (1 - t) * x1 + t * x1 = x1 pred_fake_image = phi + (1 - t) * pred_fake pred_fake_image[~rand_span_mask] = inp[~rand_span_mask] - + with self.network_context_manager: pred_real = predict_flow( - self.real_unet, phi, cond, text, time, cfg_strength=self.real_guidance_scale + self.real_unet, + phi, + cond, + text, + time, + cfg_strength=self.real_guidance_scale, ) - + pred_real_image = phi + (1 - t) * pred_real pred_real_image[~rand_span_mask] = inp[~rand_span_mask] - p_real = (inp - pred_real_image) - p_fake = (inp - pred_fake_image) - - grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2], keepdim=True) + p_real = inp - pred_real_image + p_fake = inp - pred_fake_image + + grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2], keepdim=True) grad = torch.nan_to_num(grad) - + # grad = grad / sigma_t # pred_fake - pred_real # grad = grad * (1 + sigma_t / alpha_t) - + # grad = grad / (1 + sigma_t / alpha_t) # noise # grad = grad / sigma_t # score difference # grad = grad * alpha_t # grad = grad * (sigma_t ** 2 / alpha_t) - + # grad = grad * (alpha_t + sigma_t ** 2 / alpha_t) - + # The DMD loss: MSE to move student distribution closer to teacher distribution # Only optimize over the masked region - loss = 0.5 * F.mse_loss(original_inp.float(), (original_inp-grad).detach().float(), reduction="none") * rand_span_mask.unsqueeze( - -1 + loss = ( + 0.5 + * F.mse_loss( + original_inp.float(), + (original_inp - grad).detach().float(), + reduction="none", ) + * rand_span_mask.unsqueeze(-1) + ) loss = loss.sum() / (rand_span_mask.sum() * grad.size(-1)) - - loss_dict = { - "loss_dm": loss - } + + loss_dict = {"loss_dm": loss} dm_log_dict = { "dmtrain_time": time.detach().float(), @@ -266,16 +297,15 @@ class Guidance(nn.Module): "dmtrain_pred_real_image": pred_real_image.detach().float(), "dmtrain_pred_fake_image": pred_fake_image.detach().float(), "dmtrain_grad": grad.detach().float(), - "dmtrain_gradient_norm": torch.norm(grad).item() + "dmtrain_gradient_norm": torch.norm(grad).item(), } return loss_dict, dm_log_dict - - + def compute_ctc_sv_loss( self, - 
real_inp: torch.Tensor, # real data latent - fake_inp: torch.Tensor, # student-generated data latent + real_inp: torch.Tensor, # real data latent + fake_inp: torch.Tensor, # student-generated data latent text: torch.Tensor, text_lens: torch.Tensor, rand_span_mask: torch.Tensor, @@ -290,16 +320,20 @@ class Guidance(nn.Module): """ # compute CTC loss - out, layer, ctc_loss = self.ctc_model(fake_inp * self.scale, text, text_lens) # lengths from rand_span_mask or known + out, layer, ctc_loss = self.ctc_model( + fake_inp * self.scale, text, text_lens + ) # lengths from rand_span_mask or known with torch.no_grad(): - real_out, real_layers, ctc_loss_test = self.ctc_model(real_inp * self.scale, text, text_lens) + real_out, real_layers, ctc_loss_test = self.ctc_model( + real_inp * self.scale, text, text_lens + ) real_logits = real_out.log_softmax(dim=2) - # emb_real = self.sv_model(real_inp * self.scale) # snippet from prompt region - + # emb_real = self.sv_model(real_inp * self.scale) # snippet from prompt region + fake_logits = out.log_softmax(dim=2) kl_loss = F.kl_div(fake_logits, real_logits, reduction="mean", log_target=True) - + # For SV: # Extract speaker embeddings from real (prompt) and fake: # emb_fake = self.sv_model(fake_inp * self.scale) @@ -325,8 +359,8 @@ class Guidance(nn.Module): random_start = np.random.randint(0, mel_length - mel_len) else: random_start = np.random.randint(prompt_start, prompt_end - mel_len) - - chunks_fake.append(fake_inp[bib, random_start:random_start + mel_len, :]) + + chunks_fake.append(fake_inp[bib, random_start : random_start + mel_len, :]) chunks_real.append(real_inp[bib, :mel_len, :]) chunks_real = torch.stack(chunks_real, dim=0) @@ -346,27 +380,25 @@ class Guidance(nn.Module): sv_loss = (sv_loss_en + sv_loss_zh) / 2 - return { - "loss_ctc": ctc_loss, - 'loss_kl': kl_loss, - "loss_sim": sv_loss - }, layer, real_layers + return ( + {"loss_ctc": ctc_loss, "loss_kl": kl_loss, "loss_sim": sv_loss}, + layer, + real_layers, + ) - - def compute_loss_fake( self, - inp: torch.Tensor, # student generator output + inp: torch.Tensor, # student generator output text: torch.Tensor | list[str], rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, ): """ Compute flow loss for the fake flow model, which is trained to estimate the flow (score) of the student distribution. - - This is the same as L_diff in the paper. + + This is the same as L_diff in the paper. 
""" - + # Similar to distribution matching, but only train fake to predict flow directly batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device @@ -383,29 +415,27 @@ class Guidance(nn.Module): x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) - + phi = (1 - t) * x0 + t * x1 flow = x1 - x0 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) with self.network_context_manager: pred = self.fake_unet( - x=phi, + x=phi, cond=cond, - text=text, - time=time, + text=text, + time=time, second_time=second_time, - drop_audio_cond=False, - drop_text=False # make sure the cfg=1 + drop_audio_cond=False, + drop_text=False, # make sure the cfg=1 ) # Compute MSE between predicted flow and actual flow, masked by rand_span_mask loss = F.mse_loss(pred, flow, reduction="none") loss = loss[rand_span_mask].mean() - - loss_dict = { - "loss_fake_mean": loss - } + + loss_dict = {"loss_fake_mean": loss} log_dict = { "faketrain_noisy_inp": phi.detach().float(), "faketrain_x1": x1.detach().float(), @@ -416,19 +446,19 @@ class Guidance(nn.Module): def compute_cls_logits( self, - inp: torch.Tensor, # student generator output + inp: torch.Tensor, # student generator output layer: torch.Tensor, text: torch.Tensor, rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, guidance: bool = False, ): - ''' + """ Compute adversarial loss logits for the generator. - + This is used to compute L_adv in the paper. - - ''' + + """ context_no_grad = torch.no_grad if guidance else NoOpContext with context_no_grad(): @@ -438,7 +468,7 @@ class Guidance(nn.Module): # For classification, we need some representation: # We'll mimic the logic from compute_loss_fake - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device if isinstance(text, list): if exists(self.vocab_char_map): @@ -453,26 +483,26 @@ class Guidance(nn.Module): x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) - + phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) with self.network_context_manager: layers = self.fake_unet( - x=phi, + x=phi, cond=cond, - text=text, - time=time, + text=text, + time=time, second_time=second_time, - drop_audio_cond=False, - drop_text=False, # make sure the cfg=1 - classify_mode=True + drop_audio_cond=False, + drop_text=False, # make sure the cfg=1 + classify_mode=True, ) # layers = torch.stack(layers, dim=0) if guidance: layers = [layer.detach() for layer in layers] - layer = layer[-3:] # only use the last 3 layers + layer = layer[-3:] # only use the last 3 layers layer = [l.transpose(-1, -2) for l in layer] # layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer] if layer[0].size(1) < layers[0].size(1): @@ -484,10 +514,9 @@ class Guidance(nn.Module): return logits, layers - def compute_generator_cls_loss( self, - inp: torch.Tensor, # student generator output + inp: torch.Tensor, # student generator output layer: torch.Tensor, real_layers: torch.Tensor, text: torch.Tensor, @@ -496,20 +525,22 @@ class Guidance(nn.Module): mse_loss: bool = False, mse_inp: torch.Tensor | None = None, ): - ''' - Compute the adversarial loss for the generator. - ''' - + """ + Compute the adversarial loss for the generator. 
+ """ + # Compute classification loss for generator: if not self.gen_cls_loss: return {"gen_cls_loss": 0} - logits, fake_layers = self.compute_cls_logits(inp, layer, text, rand_span_mask, second_time, guidance=False) + logits, fake_layers = self.compute_cls_logits( + inp, layer, text, rand_span_mask, second_time, guidance=False + ) loss = ((1 - logits) ** 2).mean() return {"gen_cls_loss": loss, "loss_mse": 0} - + def compute_guidance_cls_loss( self, fake_inp: torch.Tensor, @@ -518,17 +549,19 @@ class Guidance(nn.Module): real_data: dict, second_time: torch.Tensor | None = None, ): - ''' + """ This function computes the adversarial loss for the discirminator. The discriminator is trained to classify the generator output as real or fake. - ''' + """ with torch.no_grad(): # get layers from CTC model _, layer = self.ctc_model(fake_inp * self.scale) - logits_fake, _ = self.compute_cls_logits(fake_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True) + logits_fake, _ = self.compute_cls_logits( + fake_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True + ) loss_fake = (logits_fake**2).mean() real_inp = real_data["inp"] @@ -536,18 +569,18 @@ class Guidance(nn.Module): with torch.no_grad(): # get layers from CTC model _, layer = self.ctc_model(real_inp * self.scale) - - logits_real, _ = self.compute_cls_logits(real_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True) - loss_real = ((1 - logits_real)**2).mean() + + logits_real, _ = self.compute_cls_logits( + real_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True + ) + loss_real = ((1 - logits_real) ** 2).mean() classification_loss = loss_real + loss_fake - loss_dict = { - "guidance_cls_loss": classification_loss - } + loss_dict = {"guidance_cls_loss": classification_loss} log_dict = { "pred_realism_on_real": loss_real.detach().item(), - "pred_realism_on_fake": loss_fake.detach().item() + "pred_realism_on_fake": loss_fake.detach().item(), } return loss_dict, log_dict @@ -560,21 +593,25 @@ class Guidance(nn.Module): text_normalized: torch.Tensor, text_normalized_lens: torch.Tensor, rand_span_mask: torch.Tensor, - real_data: dict | None = None, # ground truth data (primarily prompt) to compute SV loss + real_data: ( + dict | None + ) = None, # ground truth data (primarily prompt) to compute SV loss second_time: torch.Tensor | None = None, mse_loss: bool = False, ): - ''' + """ Forward pass for the generator. - + This function computes the loss for the generator, which includes: - Distribution matching loss (L_DMD) - Adversarial generator loss (L_adv(G; D)) - CTC/SV loss (L_ctc + L_sv) - ''' - + """ + # 1. Compute DM loss - dm_loss_dict, dm_log_dict = self.compute_distribution_matching_loss(inp, text, rand_span_mask=rand_span_mask, second_time=second_time) + dm_loss_dict, dm_log_dict = self.compute_distribution_matching_loss( + inp, text, rand_span_mask=rand_span_mask, second_time=second_time + ) ctc_sv_loss_dict = {} cls_loss_dict = {} @@ -582,17 +619,27 @@ class Guidance(nn.Module): # 2. Compute optional CTC/SV loss if real_data provided if real_data is not None: real_inp = real_data["inp"] - ctc_sv_loss_dict, layer, real_layers = self.compute_ctc_sv_loss(real_inp, inp, text_normalized, text_normalized_lens, rand_span_mask, second_time=second_time) + ctc_sv_loss_dict, layer, real_layers = self.compute_ctc_sv_loss( + real_inp, + inp, + text_normalized, + text_normalized_lens, + rand_span_mask, + second_time=second_time, + ) # 3. 
Compute optional classification loss if self.gen_cls_loss: - cls_loss_dict = self.compute_generator_cls_loss(inp, layer, real_layers, text, - rand_span_mask=rand_span_mask, - second_time=second_time, - mse_inp = real_data["inp"] if real_data is not None else None, - mse_loss = mse_loss, - ) - + cls_loss_dict = self.compute_generator_cls_loss( + inp, + layer, + real_layers, + text, + rand_span_mask=rand_span_mask, + second_time=second_time, + mse_inp=real_data["inp"] if real_data is not None else None, + mse_loss=mse_loss, + ) loss_dict = {**dm_loss_dict, **cls_loss_dict, **ctc_sv_loss_dict} log_dict = {**dm_log_dict} @@ -608,35 +655,39 @@ class Guidance(nn.Module): real_data: dict | None = None, second_time: torch.Tensor | None = None, ): - ''' + """ Forward pass for the guidnce module (discriminator + fake flow function). - + This function computes the loss for the guidance module, which includes: - Flow matching loss (L_diff) - Adversarial discrminator loss (L_adv(D; G)) - - ''' - + + """ + # Compute fake loss (like epsilon prediction loss in Guidance) - fake_loss_dict, fake_log_dict = self.compute_loss_fake(fake_inp, text, rand_span_mask=rand_span_mask, second_time=second_time) + fake_loss_dict, fake_log_dict = self.compute_loss_fake( + fake_inp, text, rand_span_mask=rand_span_mask, second_time=second_time + ) # If gen_cls_loss, compute guidance cls loss cls_loss_dict = {} cls_log_dict = {} if self.gen_cls_loss and real_data is not None: - cls_loss_dict, cls_log_dict = self.compute_guidance_cls_loss(fake_inp, text, rand_span_mask, real_data, second_time=second_time) + cls_loss_dict, cls_log_dict = self.compute_guidance_cls_loss( + fake_inp, text, rand_span_mask, real_data, second_time=second_time + ) loss_dict = {**fake_loss_dict, **cls_loss_dict} log_dict = {**fake_log_dict, **cls_log_dict} return loss_dict, log_dict - + def forward( self, generator_turn=False, guidance_turn=False, generator_data_dict=None, - guidance_data_dict=None + guidance_data_dict=None, ): if generator_turn: loss_dict, log_dict = self.generator_forward( @@ -660,19 +711,18 @@ class Guidance(nn.Module): second_time=guidance_data_dict.get("second_time", None), ) else: - raise NotImplementedError("Must specify either generator_turn or guidance_turn") + raise NotImplementedError( + "Must specify either generator_turn or guidance_turn" + ) return loss_dict, log_dict - - if __name__ == "__main__": from f5_tts.model.utils import get_tokenizer - bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -681,25 +731,43 @@ if __name__ == "__main__": else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - - - real_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100) - fake_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100) - - guidance = Guidance(real_unet, - fake_unet, - real_guidance_scale=1.0, - fake_guidance_scale=0.0, - use_fp16=True, - gen_cls_loss=True, - ).cuda() - + + real_unet = DiT( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + text_num_embeds=vocab_size, + mel_dim=100, + ) + fake_unet = DiT( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + text_num_embeds=vocab_size, + mel_dim=100, + ) + + 
guidance = Guidance( + real_unet, + fake_unet, + real_guidance_scale=1.0, + fake_guidance_scale=0.0, + use_fp16=True, + gen_cls_loss=True, + ).cuda() + text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 80).cuda() - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # handle text as string if isinstance(text, list): if exists(vocab_char_map): @@ -712,16 +780,20 @@ if __name__ == "__main__": if not exists(lens): lens = torch.full((batch,), seq_len, device=device) - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch frac_lengths_mask = (0.7, 1.0) - + # get a random span to mask out for training conditionally - frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) + frac_lengths = ( + torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) + ) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - + if exists(mask): rand_span_mask &= mask - + # Construct data dicts for generator and guidance phases # For flow, `real_data` can just be the ground truth if available; here we simulate it real_data_dict = { @@ -732,21 +804,24 @@ if __name__ == "__main__": "inp": inp, "text": text, "rand_span_mask": rand_span_mask, - "real_data": real_data_dict + "real_data": real_data_dict, } guidance_data_dict = { "inp": inp, "text": text, "rand_span_mask": rand_span_mask, - "real_data": real_data_dict + "real_data": real_data_dict, } - # Generator forward pass - loss_dict, log_dict = guidance(generator_turn=True, generator_data_dict=generator_data_dict) + loss_dict, log_dict = guidance( + generator_turn=True, generator_data_dict=generator_data_dict + ) print("Generator turn losses:", loss_dict) # Guidance forward pass - loss_dict, log_dict = guidance(guidance_turn=True, guidance_data_dict=guidance_data_dict) + loss_dict, log_dict = guidance( + guidance_turn=True, guidance_data_dict=guidance_data_dict + ) print("Guidance turn losses:", loss_dict) diff --git a/infer.py b/infer.py index 91678efeee60979b121b1a923fc8ab5e279c56e2..3306c78122b66044707405770de1213e5b543ba6 100644 --- a/infer.py +++ b/infer.py @@ -1,32 +1,31 @@ import os + import torch -import torchaudio import torch.nn.functional as F +import torchaudio +from safetensors.torch import load_file from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint -from safetensors.torch import load_file +from duration_predictor import SpeechLengthPredictor +from f5_tts.infer.utils_infer import (chunk_text, convert_char_to_pinyin, + hop_length, load_vocoder, + preprocess_ref_audio_text, speed, + target_rms, target_sample_rate, + transcribe) # Import F5-TTS modules -from f5_tts.model import CFM, UNetT, DiT +from f5_tts.model import CFM, DiT, UNetT from f5_tts.model.modules import MelSpec -from f5_tts.model.utils import ( - default, exists, list_str_to_idx, list_str_to_tensor, - lens_to_mask, mask_from_frac_lengths, get_tokenizer -) -from f5_tts.infer.utils_infer import ( - load_vocoder, preprocess_ref_audio_text, chunk_text, - convert_char_to_pinyin, transcribe, target_rms, - target_sample_rate, hop_length, speed -) - +from f5_tts.model.utils import (default, exists, get_tokenizer, lens_to_mask, + list_str_to_idx, list_str_to_tensor, + mask_from_frac_lengths) # Import custom modules from unimodel import UniModel -from duration_predictor import 
SpeechLengthPredictor class DMOInference: """F5-TTS Inference wrapper class for easy text-to-speech generation.""" - + def __init__( self, student_checkpoint_path="", @@ -38,7 +37,7 @@ class DMOInference: ): """ Initialize F5-TTS inference model. - + Args: student_checkpoint_path: Path to student model checkpoint duration_predictor_path: Path to duration predictor checkpoint @@ -48,12 +47,12 @@ class DMOInference: dataset_name: Dataset name for tokenizer cuda_device_id: CUDA device ID to use """ - + self.device = device self.model_type = model_type self.tokenizer = tokenizer self.dataset_name = dataset_name - + # Model parameters self.target_sample_rate = 24000 self.n_mel_channels = 100 @@ -62,39 +61,47 @@ class DMOInference: self.fake_guidance_scale = 0 self.gen_cls_loss = False self.num_student_step = 4 - + # Initialize components self._setup_tokenizer() self._setup_models(student_checkpoint_path) self._setup_mel_spec() self._setup_vocoder() self._setup_duration_predictor(duration_predictor_path) - + def _setup_tokenizer(self): """Setup tokenizer and vocabulary.""" if self.tokenizer == "custom": tokenizer_path = self.tokenizer_path else: tokenizer_path = self.dataset_name - - self.vocab_char_map, self.vocab_size = get_tokenizer(tokenizer_path, self.tokenizer) - + + self.vocab_char_map, self.vocab_size = get_tokenizer( + tokenizer_path, self.tokenizer + ) + def _setup_models(self, student_checkpoint_path): """Initialize teacher and student models.""" # Model configuration if self.model_type == "F5TTS_Base": model_cls = DiT - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) + model_cfg = dict( + dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4 + ) elif self.model_type == "E2TTS_Base": model_cls = UNetT model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) else: raise ValueError(f"Unknown model type: {self.model_type}") - + # Initialize UniModel (student) self.model = UniModel( - model_cls(**model_cfg, text_num_embeds=self.vocab_size, mel_dim=self.n_mel_channels, - second_time=self.num_student_step > 1), + model_cls( + **model_cfg, + text_num_embeds=self.vocab_size, + mel_dim=self.n_mel_channels, + second_time=self.num_student_step > 1, + ), checkpoint_path="", vocab_char_map=self.vocab_char_map, frac_lengths_mask=(0.5, 0.9), @@ -103,17 +110,17 @@ class DMOInference: gen_cls_loss=self.gen_cls_loss, sway_coeff=0, ) - + # Load student checkpoint - checkpoint = torch.load(student_checkpoint_path, map_location='cpu') - self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) - + checkpoint = torch.load(student_checkpoint_path, map_location="cpu") + self.model.load_state_dict(checkpoint["model_state_dict"], strict=False) + # Setup generator and teacher self.generator = self.model.feedforward_model.to(self.device) self.teacher = self.model.guidance_model.real_unet.to(self.device) - - self.scale = checkpoint['scale'] - + + self.scale = checkpoint["scale"] + def _setup_mel_spec(self): """Initialize mel spectrogram module.""" mel_spec_kwargs = dict( @@ -122,12 +129,12 @@ class DMOInference: hop_length=self.hop_length, ) self.mel_spec = MelSpec(**mel_spec_kwargs) - + def _setup_vocoder(self): """Initialize vocoder.""" self.vocos = load_vocoder(is_local=False, local_path="") self.vocos = self.vocos.to(self.device) - + def _setup_duration_predictor(self, checkpoint_path): """Initialize duration predictor.""" self.wav2mel = MelSpec( @@ -136,9 +143,9 @@ class DMOInference: hop_length=256, win_length=1024, n_fft=1024, - 
mel_spec_type='vocos' + mel_spec_type="vocos", ).to(self.device) - + self.SLP = SpeechLengthPredictor( vocab_size=2545, n_mel=100, @@ -146,16 +153,20 @@ class DMOInference: n_text_layer=4, n_cross_layer=4, n_head=8, - output_dim=301 + output_dim=301, ).to(self.device) - + self.SLP.eval() - self.SLP.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model_state_dict']) - - def predict_duration(self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0.7, temperature=0): + self.SLP.load_state_dict( + torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] + ) + + def predict_duration( + self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0.7, temperature=0 + ): """ Predict duration for target text based on prompt audio. - + Args: pmt_wav_path: Path to prompt audio tar_text: Target text to generate @@ -172,50 +183,52 @@ class DMOInference: if pmt_wav.size(0) > 1: pmt_wav = pmt_wav[0].unsqueeze(0) pmt_wav = pmt_wav.to(self.device) - + pmt_mel = self.wav2mel(pmt_wav).permute(0, 2, 1) tar_tokens = self._convert_to_pinyin(list(tar_text)) pmt_tokens = self._convert_to_pinyin(list(pmt_text)) - + # Calculate duration ref_text_len = len(pmt_tokens) gen_text_len = len(tar_tokens) ref_audio_len = pmt_mel.size(1) duration = int(ref_audio_len / ref_text_len * gen_text_len / speed) duration = duration // 10 - + min_duration = max(int(duration * dp_softmax_range), 0) max_duration = min(int(duration * (1 + dp_softmax_range)), 301) - - all_tokens = pmt_tokens + [' '] + tar_tokens - + + all_tokens = pmt_tokens + [" "] + tar_tokens + text_ids = list_str_to_idx([all_tokens], self.vocab_char_map).to(self.device) text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size) - + with torch.no_grad(): predictions = self.SLP(text_ids=text_ids, mel=pmt_mel) predictions = predictions[:, -1, :] - predictions[:, :min_duration] = float('-inf') - predictions[:, max_duration:] = float('-inf') - + predictions[:, :min_duration] = float("-inf") + predictions[:, max_duration:] = float("-inf") + if temperature == 0: est_label = predictions.argmax(-1)[..., -1].item() * 10 else: probs = torch.softmax(predictions / temperature, dim=-1) - sampled_idx = torch.multinomial(probs.squeeze(0), num_samples=1) # Remove the -1 index + sampled_idx = torch.multinomial( + probs.squeeze(0), num_samples=1 + ) # Remove the -1 index est_label = sampled_idx.item() * 10 - + return est_label - + def _convert_to_pinyin(self, char_list): """Convert character list to pinyin.""" result = [] for x in convert_char_to_pinyin(char_list): result = result + x - while result[0] == ' ' and len(result) > 1: + while result[0] == " " and len(result) > 1: result = result[1:] return result - + def generate( self, gen_text, @@ -230,11 +243,11 @@ class DMOInference: eta=1.0, cfg_strength=2.0, sway_coefficient=-1.0, - verbose=False + verbose=False, ): """ Generate speech from text using teacher-student distillation. 
- + Args: gen_text: Text to generate audio_path: Path to prompt audio @@ -249,82 +262,101 @@ class DMOInference: cfg_strength: Classifier-free guidance strength sway_coefficient: Sway sampling coefficient verbose: Output sampling steps - + Returns: Generated audio waveform """ if prompt_text is None: prompt_text = transcribe(audio_path) - + # Predict duration if not provided if duration is None: - duration = self.predict_duration(audio_path, gen_text, prompt_text, dp_softmax_range, temperature) - + duration = self.predict_duration( + audio_path, gen_text, prompt_text, dp_softmax_range, temperature + ) + # Preprocess audio and text ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text) audio, sr = torchaudio.load(ref_audio) - + if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) - + # Normalize audio rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms - + if sr != self.target_sample_rate: resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate) audio = resampler(audio) - + audio = audio.to(self.device) - + # Prepare text text_list = [ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) - + # Calculate durations ref_audio_len = audio.shape[-1] // self.hop_length if duration is None: ref_text_len = len(ref_text.encode("utf-8")) gen_text_len = len(gen_text.encode("utf-8")) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) + duration = ref_audio_len + int( + ref_audio_len / ref_text_len * gen_text_len / speed + ) else: duration = ref_audio_len + duration - + if verbose: - print('audio:', audio.shape) - print('text:', final_text_list) - print('duration:', duration) - print('eta (stochasticity):', eta) # Print eta value for debugging + print("audio:", audio.shape) + print("text:", final_text_list) + print("duration:", duration) + print("eta (stochasticity):", eta) # Print eta value for debugging # Run inference with torch.inference_mode(): - cond, text, step_cond, cond_mask, max_duration, duration_tensor = self._prepare_inputs( - audio, final_text_list, duration + cond, text, step_cond, cond_mask, max_duration, duration_tensor = ( + self._prepare_inputs(audio, final_text_list, duration) ) - + # Teacher-student sampling if teacher_steps > 0 and student_start_step > 0: if verbose: - print('Start teacher sampling with hybrid DDIM/DDPM (eta={})....'.format(eta)) + print( + "Start teacher sampling with hybrid DDIM/DDPM (eta={})....".format( + eta + ) + ) x1 = self._teacher_sampling( - step_cond, text, cond_mask, max_duration, duration_tensor, # Use duration_tensor - teacher_steps, teacher_stopping_time, eta, cfg_strength, verbose, sway_coefficient + step_cond, + text, + cond_mask, + max_duration, + duration_tensor, # Use duration_tensor + teacher_steps, + teacher_stopping_time, + eta, + cfg_strength, + verbose, + sway_coefficient, ) else: x1 = step_cond - + if verbose: - print('Start student sampling...') + print("Start student sampling...") # Student sampling - x1 = self._student_sampling(x1, cond, text, student_start_step, verbose, sway_coefficient) - + x1 = self._student_sampling( + x1, cond, text, student_start_step, verbose, sway_coefficient + ) + # Decode to audio mel = x1.permute(0, 2, 1) * self.scale - generated_wave = self.vocos.decode(mel[..., cond_mask.sum():]) - + generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :]) + return generated_wave.cpu().numpy().squeeze() - + def generate_teacher_only( self, gen_text, @@ -334,11 
+366,11 @@ class DMOInference: duration=None, eta=1.0, cfg_strength=2.0, - sway_coefficient=-1.0 + sway_coefficient=-1.0, ): """ Generate speech using teacher model only (no student distillation). - + Args: gen_text: Text to generate audio_path: Path to prompt audio @@ -348,85 +380,95 @@ class DMOInference: eta: Stochasticity control (0=DDIM, 1=DDPM) cfg_strength: Classifier-free guidance strength sway_coefficient: Sway sampling coefficient - + Returns: Generated audio waveform """ if prompt_text is None: prompt_text = transcribe(audio_path) - + # Predict duration if not provided if duration is None: duration = self.predict_duration(audio_path, gen_text, prompt_text) - + # Preprocess audio and text ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text) audio, sr = torchaudio.load(ref_audio) - + if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) - + # Normalize audio rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms - + if sr != self.target_sample_rate: resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate) audio = resampler(audio) - + audio = audio.to(self.device) - + # Prepare text text_list = [ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) - + # Calculate durations ref_audio_len = audio.shape[-1] // self.hop_length if duration is None: ref_text_len = len(ref_text.encode("utf-8")) gen_text_len = len(gen_text.encode("utf-8")) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) + duration = ref_audio_len + int( + ref_audio_len / ref_text_len * gen_text_len / speed + ) else: duration = ref_audio_len + duration - + # Run inference with torch.inference_mode(): cond, text, step_cond, cond_mask, max_duration = self._prepare_inputs( audio, final_text_list, duration ) - + # Teacher-only sampling x1 = self._teacher_sampling( - step_cond, text, cond_mask, max_duration, duration, - teacher_steps, 1.0, eta, cfg_strength, sway_coefficient # stopping_time=1.0 for full sampling + step_cond, + text, + cond_mask, + max_duration, + duration, + teacher_steps, + 1.0, + eta, + cfg_strength, + sway_coefficient, # stopping_time=1.0 for full sampling ) - + # Decode to audio mel = x1.permute(0, 2, 1) * self.scale - generated_wave = self.vocos.decode(mel[..., cond_mask.sum():]) - + generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :]) + return generated_wave - + def _prepare_inputs(self, audio, text_list, duration): """Prepare inputs for generation.""" lens = None max_duration_limit = 4096 - + cond = audio text = text_list - + if cond.ndim == 2: cond = self.mel_spec(cond) cond = cond.permute(0, 2, 1) assert cond.shape[-1] == 100 - + cond = cond / self.scale batch, cond_seq_len, device = *cond.shape[:2], cond.device - + if not exists(lens): lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) - + # Process text if isinstance(text, list): if exists(self.vocab_char_map): @@ -434,89 +476,123 @@ class DMOInference: else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch - + if exists(text): text_lens = (text != -1).sum(dim=-1) lens = torch.maximum(text_lens, lens) - + # Process duration cond_mask = lens_to_mask(lens) - + if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) - + duration = torch.maximum(lens + 1, duration) duration = duration.clamp(max=max_duration_limit) max_duration = duration.amax() - + # Pad conditioning cond = F.pad(cond, (0, 0, 
0, max_duration - cond_seq_len), value=0.0) - cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = F.pad( + cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + ) cond_mask = cond_mask.unsqueeze(-1) step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - + return cond, text, step_cond, cond_mask, max_duration, duration - - def _teacher_sampling(self, step_cond, text, cond_mask, max_duration, duration, - teacher_steps, teacher_stopping_time, eta, cfg_strength, verbose, sway_sampling_coef = -1): + + def _teacher_sampling( + self, + step_cond, + text, + cond_mask, + max_duration, + duration, + teacher_steps, + teacher_stopping_time, + eta, + cfg_strength, + verbose, + sway_sampling_coef=-1, + ): """Perform teacher model sampling.""" device = step_cond.device - + # Pre-generate noise sequence for stochastic sampling noise_seq = None if eta > 0: - noise_seq = [torch.randn(1, max_duration, 100, device=device) - for _ in range(teacher_steps)] - + noise_seq = [ + torch.randn(1, max_duration, 100, device=device) + for _ in range(teacher_steps) + ] + def fn(t, x): with torch.inference_mode(): with torch.autocast(device_type="cuda", dtype=torch.float16): if verbose: - print(f'current t: {t}') + print(f"current t: {t}") step_frac = 1.0 - t.item() - step_idx = min(int(step_frac * len(noise_seq)), len(noise_seq) - 1) if noise_seq else 0 - + step_idx = ( + min(int(step_frac * len(noise_seq)), len(noise_seq) - 1) + if noise_seq + else 0 + ) + # Predict flow pred = self.teacher( - x=x, cond=step_cond, text=text, time=t, mask=None, - drop_audio_cond=False, drop_text=False + x=x, + cond=step_cond, + text=text, + time=t, + mask=None, + drop_audio_cond=False, + drop_text=False, ) - + if cfg_strength > 1e-5: null_pred = self.teacher( - x=x, cond=step_cond, text=text, time=t, mask=None, - drop_audio_cond=True, drop_text=True + x=x, + cond=step_cond, + text=text, + time=t, + mask=None, + drop_audio_cond=True, + drop_text=True, ) pred = pred + (pred - null_pred) * cfg_strength - + # Add stochasticity if eta > 0 if eta > 0 and noise_seq is not None: alpha_t = 1.0 - t.item() sigma_t = t.item() - noise_scale = torch.sqrt(torch.tensor( - (sigma_t**2) / (alpha_t**2 + sigma_t**2) * eta, - device=device - )) + noise_scale = torch.sqrt( + torch.tensor( + (sigma_t**2) / (alpha_t**2 + sigma_t**2) * eta, + device=device, + ) + ) return pred + noise_scale * noise_seq[step_idx] else: return pred - + # Initialize noise y0 = [] for dur in duration: y0.append(torch.randn(dur, 100, device=device, dtype=step_cond.dtype)) y0 = pad_sequence(y0, padding_value=0, batch_first=True) - + # Setup time steps - t = torch.linspace(0, 1, teacher_steps + 1, device=device, dtype=step_cond.dtype) + t = torch.linspace( + 0, 1, teacher_steps + 1, device=device, dtype=step_cond.dtype + ) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) - t = t[:(t > teacher_stopping_time).float().argmax() + 2] + t = t[: (t > teacher_stopping_time).float().argmax() + 2] t = t[:-1] - + # Solve ODE trajectory = odeint(fn, y0, t, method="euler") - + if teacher_stopping_time < 1.0: # If early stopping, compute final step pred = fn(t[-1], trajectory[-1]) @@ -524,22 +600,24 @@ class DMOInference: return test_out else: return trajectory[-1] - - def _student_sampling(self, x1, cond, text, student_start_step, verbose, sway_coeff = -1): + + def _student_sampling( + self, x1, cond, text, student_start_step, verbose, sway_coeff=-1 + ): """Perform student model 
sampling.""" steps = torch.Tensor([0, 0.25, 0.5, 0.75]) steps = steps + sway_coeff * (torch.cos(torch.pi / 2 * steps) - 1 + steps) steps = steps[student_start_step:] - + for step in steps: time = torch.Tensor([step]).to(x1.device) - + x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 - + if verbose: - print(f'current step: {step}') + print(f"current step: {step}") with torch.no_grad(): pred = self.generator( x=phi, @@ -547,12 +625,12 @@ class DMOInference: text=text, time=time, drop_audio_cond=False, - drop_text=False + drop_text=False, ) - + # Predicted mel spectrogram output = phi + (1 - t) * pred - + x1 = output - + return x1 diff --git a/requirements.txt b/requirements.txt index a39bd415a46dcfd75c81816e7c667872540851ce..d213748c35dc44bafbf444fae77ccdaea0c77c84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,5 @@ x_transformers>=1.31.14 # modelscope # zhconv # zhon + +cached_path \ No newline at end of file diff --git a/unimodel.py b/unimodel.py index afdb142752e2edee651c41a414928b10d3c8cfe3..559cd4bf5e129c90832b5b932523a03f4abf49eb 100644 --- a/unimodel.py +++ b/unimodel.py @@ -1,43 +1,39 @@ from __future__ import annotations -from typing import Callable -from random import random import contextlib - -from torch import nn -import torch import copy import os +from pathlib import Path +from random import random +from typing import Callable + +import torch +from torch import nn from f5_tts.model import DiT, UNetT -from pathlib import Path +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths, + sample_consecutive_steps, sample_from_list) from guidance_model import Guidance -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, - sample_consecutive_steps, - sample_from_list, -) + class UniModel(nn.Module): - def __init__(self, - model: DiT, # teacher model (dit model) - checkpoint_path: str = "", - second_time: bool = True, - use_fp16: bool = True, - real_guidance_scale: float = 2.0, - fake_guidance_scale: float = 0.0, - gen_cls_loss: bool = False, - sway_coeff: float = -1.0, - vocab_char_map: dict[str, int] | None = None, - frac_lengths_mask: tuple[float, float] = (0.7, 1.0)): - + def __init__( + self, + model: DiT, # teacher model (dit model) + checkpoint_path: str = "", + second_time: bool = True, + use_fp16: bool = True, + real_guidance_scale: float = 2.0, + fake_guidance_scale: float = 0.0, + gen_cls_loss: bool = False, + sway_coeff: float = -1.0, + vocab_char_map: dict[str, int] | None = None, + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + ): + super().__init__() - + if checkpoint_path != "": if "model_last.pt" in os.listdir(checkpoint_path): latest_checkpoint = "model_last.pt" @@ -46,7 +42,11 @@ class UniModel(nn.Module): [f for f in os.listdir(checkpoint_path) if f.endswith(".pt")], key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] - checkpoint = torch.load(f"{checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") + checkpoint = torch.load( + f"{checkpoint_path}/{latest_checkpoint}", + weights_only=True, + map_location="cpu", + ) if "scale" in checkpoint: self.scale = checkpoint["scale"] @@ -74,12 +74,12 @@ class UniModel(nn.Module): model.load_state_dict(filtered_state_dict, strict=False) else: self.scale = 1.0 - + real_unet = copy.deepcopy(model) real_unet.time_embed2 = None - + fake_unet = copy.deepcopy(model) - + # Instantiate 
Guidance, which internally uses real_unet and fake_unet initialized from the teacher self.guidance_model = Guidance( real_unet=real_unet, @@ -90,23 +90,24 @@ class UniModel(nn.Module): gen_cls_loss=gen_cls_loss, sway_coeff=sway_coeff, ) - - self.feedforward_model = copy.deepcopy(model) # initialize the student model + + self.feedforward_model = copy.deepcopy(model) # initialize the student model self.feedforward_model.requires_grad_(True) self.feedforward_model.time_embed2 = None self.vocab_char_map = vocab_char_map self.frac_lengths_mask = frac_lengths_mask - - self.second_time = second_time # fake_unet.time_embed2 is not None - - def forward(self, - inp: float["b n d"], # mel - text: int["b nt"] | list[str], - *, - lens: int["b"] | None = None, - student_steps: list[int] = [0, 0.25, 0.5, 0.75], - update_generator: bool = False, + + self.second_time = second_time # fake_unet.time_embed2 is not None + + def forward( + self, + inp: float["b n d"], # mel + text: int["b nt"] | list[str], + *, + lens: int["b"] | None = None, + student_steps: list[int] = [0, 0.25, 0.5, 0.75], + update_generator: bool = False, ): """ Forward pass that routes to either generator_forward or guidance_forward @@ -126,7 +127,7 @@ class UniModel(nn.Module): "rand_span_mask": Tensor (B, N) - boolean mask "real_data": dict with keys like: "inp", "text", "rand_span_mask" - + Returns: -------- loss_dict: dict[str, Tensor] @@ -134,7 +135,7 @@ class UniModel(nn.Module): log_dict: dict[str, Tensor or float] Dictionary of logging tensors or values. """ - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device # handle text as string @@ -149,7 +150,9 @@ class UniModel(nn.Module): if not exists(lens): lens = torch.full((batch,), seq_len, device=device) - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch # sample from the list of student steps time = sample_from_list(student_steps, batch).to(device) @@ -157,12 +160,15 @@ class UniModel(nn.Module): time = torch.ones_like(time) * c_time p_time = torch.ones_like(time) * p_time - frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*self.frac_lengths_mask) + frac_lengths = ( + torch.zeros((batch,), device=device) + .float() + .uniform_(*self.frac_lengths_mask) + ) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - + if exists(mask): rand_span_mask &= mask - # # use generated output from previous step as input with torch.no_grad(): @@ -171,41 +177,41 @@ class UniModel(nn.Module): t = p_time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - + pred = self.feedforward_model( - x=phi, + x=phi, cond=cond, - text=text, - time=p_time, - drop_audio_cond=False, - drop_text=False # make sure the cfg=1 - ) # flow prediction - + text=text, + time=p_time, + drop_audio_cond=False, + drop_text=False, # make sure the cfg=1 + ) # flow prediction + # predicted mel spectrogram - output = phi + (1 - t) * pred + output = phi + (1 - t) * pred output[~rand_span_mask] = inp[~rand_span_mask] - + # forward diffusion x1 = output x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - + with torch.no_grad() if not update_generator else contextlib.nullcontext(): pred = self.feedforward_model( - x=phi, + x=phi, cond=cond, 
- text=text, - time=time, - drop_audio_cond=False, - drop_text=False # make sure no cfg is used + text=text, + time=time, + drop_audio_cond=False, + drop_text=False, # make sure no cfg is used ) - + # predicted mel spectrogram output = phi + (1 - t) * pred output[~rand_span_mask] = inp[~rand_span_mask] - + if update_generator: generator_data_dict = { "inp": output, @@ -216,10 +222,10 @@ class UniModel(nn.Module): "real_data": { "inp": inp, "text": text, - "rand_span_mask": rand_span_mask - } + "rand_span_mask": rand_span_mask, + }, } - + # avoid any side effects of gradient accumulation # self.guidance_model.requires_grad_(False) # self.feedforward_model.requires_grad_(True) @@ -227,15 +233,15 @@ class UniModel(nn.Module): generator_turn=True, guidance_turn=False, generator_data_dict=generator_data_dict, - guidance_data_dict=None + guidance_data_dict=None, ) - - generator_log_dict['ground_truth'] = x1 - generator_log_dict['generator_input'] = phi - generator_log_dict['generator_output'] = output - generator_log_dict['generator_cond'] = cond - generator_log_dict['time'] = time - + + generator_log_dict["ground_truth"] = x1 + generator_log_dict["generator_input"] = phi + generator_log_dict["generator_output"] = output + generator_log_dict["generator_cond"] = cond + generator_log_dict["time"] = time + return generator_loss_dict, generator_log_dict else: guidance_data_dict = { @@ -246,10 +252,10 @@ class UniModel(nn.Module): "real_data": { "inp": inp, "text": text, - "rand_span_mask": rand_span_mask - } + "rand_span_mask": rand_span_mask, + }, } - + # avoid any side effects of gradient accumulation # self.feedforward_model.requires_grad_(False) # self.guidance_model.requires_grad_(True) @@ -257,24 +263,25 @@ class UniModel(nn.Module): generator_turn=False, guidance_turn=True, generator_data_dict=None, - guidance_data_dict=guidance_data_dict + guidance_data_dict=guidance_data_dict, ) # self.feedforward_model.requires_grad_(True) - + return guidance_loss_dict, guidance_log_dict - + # return guidance_loss_dict, guidance_log_dict, generator_loss_dict, generator_log_dict - + if __name__ == "__main__": - - from f5_tts.model.utils import get_tokenizer + from torch.utils.data import DataLoader, Dataset, SequentialSampler - from f5_tts.model.dataset import load_dataset - from f5_tts.model.dataset import DynamicBatchSampler, collate_fn + + from f5_tts.model.dataset import (DynamicBatchSampler, collate_fn, + load_dataset) + from f5_tts.model.utils import get_tokenizer bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -284,34 +291,59 @@ if __name__ == "__main__": tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - dit = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100) - - model = UniModel(dit, - checkpoint_path="/data4/F5TTS/ckpts/F5TTS_Base_norm_flow_8GPU_vocos_pinyin_Emilia_ZH_EN", - gen_cls_loss=True, - vocab_char_map=vocab_char_map, - frac_lengths_mask=(0.7, 1.0) - ).cuda() - + dit = DiT( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + text_num_embeds=vocab_size, + mel_dim=100, + ) + + model = UniModel( + dit, + checkpoint_path="/data4/F5TTS/ckpts/F5TTS_Base_norm_flow_8GPU_vocos_pinyin_Emilia_ZH_EN", + gen_cls_loss=True, + vocab_char_map=vocab_char_map, + frac_lengths_mask=(0.7, 1.0), + 
).cuda() + # batch = next(iter(train_dataloader)) # torch.save(batch, "batch.pt") batch = torch.load("batch.pt") - inp, text, lens = batch["mel"].permute(0, 2, 1).cuda(), batch["text"], batch["mel_lengths"].cuda() + inp, text, lens = ( + batch["mel"].permute(0, 2, 1).cuda(), + batch["text"], + batch["mel_lengths"].cuda(), + ) - # text = ["hello world"] * bsz # lens = torch.randint(1, 1000, (bsz,)).cuda() # inp = torch.randn(bsz, lens.max(), 100).cuda() with torch.autocast(device_type="cuda", dtype=torch.float16): num_student_step = 4 - guidance_loss_dict, guidance_log_dict = model(inp, text, lens=lens, update_generator=False, student_steps=(torch.linspace(0.0, 1.0, num_student_step + 1)[:-1])) + guidance_loss_dict, guidance_log_dict = model( + inp, + text, + lens=lens, + update_generator=False, + student_steps=(torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]), + ) + + generator_loss_dict, generator_log_dict = model( + inp, + text, + lens=lens, + update_generator=True, + student_steps=(torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]), + ) - generator_loss_dict, generator_log_dict = model(inp, text, lens=lens, update_generator=True, student_steps=(torch.linspace(0.0, 1.0, num_student_step + 1)[:-1])) - print(guidance_loss_dict) print(generator_loss_dict) - + guidance_loss = 0 guidance_loss += guidance_loss_dict["loss_fake_mean"] guidance_loss += guidance_loss_dict["guidance_cls_loss"] @@ -324,4 +356,4 @@ if __name__ == "__main__": generator_loss += generator_loss_dict["loss_mse"] guidance_loss.backward() - generator_loss.backward() \ No newline at end of file + generator_loss.backward()
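
Usage note (not part of the patch): the refactored DMOInference wrapper in infer.py can be exercised end to end roughly as below. This is a minimal sketch; the checkpoint and audio paths are placeholders, and only arguments that appear in the class above are used.

import torch
import torchaudio

from infer import DMOInference

# Placeholder paths -- substitute your own checkpoints and reference clip.
tts = DMOInference(
    student_checkpoint_path="checkpoints/model_85000.pt",
    duration_predictor_path="checkpoints/model_1500.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_type="F5TTS_Base",
)

# prompt_text is optional: generate() transcribes the prompt audio when it is omitted.
wav = tts.generate(
    gen_text="The quick brown fox jumps over the lazy dog.",
    audio_path="prompt.wav",
    verbose=True,
)

# generate() returns a 1-D numpy waveform at the model's 24 kHz target sample rate.
torchaudio.save("output.wav", torch.from_numpy(wav).unsqueeze(0), tts.target_sample_rate)

# The duration predictor can also be queried on its own; it predicts the target length in
# 10-frame bins of mel frames (hop 256 at 24 kHz), and generate() accepts the result via `duration`.
frames = tts.predict_duration(
    pmt_wav_path="prompt.wav",
    tar_text="The quick brown fox jumps over the lazy dog.",
    pmt_text="Transcript of the prompt clip.",
    temperature=0,  # 0 = argmax over duration bins; >0 samples from the softmax
)
print(f"predicted length: {frames} frames (~{frames * 256 / 24000:.2f} s)")

Setting teacher_steps=0 runs only the four distilled student steps, eta interpolates between DDIM-style (eta=0) and DDPM-style (eta=1) teacher sampling, and generate_teacher_only() exposes the teacher path by itself.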
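
Both samplers warp their nominally uniform time grid with the same "sway" transform, t' = t + s * (cos(pi/2 * t) - 1 + t); with the default s = -1 this pulls the steps toward t = 0, the noise end of the flow. A small sketch of the warp (values approximate; not part of the patch):

import torch

def sway(t: torch.Tensor, s: float = -1.0) -> torch.Tensor:
    # Same warp as in _teacher_sampling / _student_sampling: t + s * (cos(pi/2 * t) - 1 + t)
    return t + s * (torch.cos(torch.pi / 2 * t) - 1 + t)

# The 4-step student schedule [0, 0.25, 0.5, 0.75] becomes roughly [0, 0.076, 0.293, 0.617],
# so consecutive steps are spaced more tightly early in the trajectory.
print(sway(torch.tensor([0.0, 0.25, 0.5, 0.75])))

# The teacher grid is linspace(0, 1, teacher_steps + 1) passed through the same warp,
# optionally truncated at teacher_stopping_time before handing off to the student.
print(sway(torch.linspace(0.0, 1.0, 17)))  # e.g. teacher_steps = 16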
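
The __main__ blocks in guidance_model.py and unimodel.py each exercise one generator turn and one guidance turn; during training the two turns alternate with separate optimizers. The loop below is a hypothetical sketch of that alternation, reusing the `model` UniModel instance built in the __main__ block above: the optimizers, learning rates, the `loader` variable, and the way the generator-side losses are summed are assumptions, not something this PR defines.

import torch

gen_opt = torch.optim.AdamW(model.feedforward_model.parameters(), lr=1e-5)
guid_opt = torch.optim.AdamW(model.guidance_model.fake_unet.parameters(), lr=1e-5)

student_steps = torch.linspace(0.0, 1.0, 5)[:-1]  # [0, 0.25, 0.5, 0.75], as in __main__

for step, batch in enumerate(loader):  # assumed DataLoader yielding batches shaped like batch.pt above
    inp = batch["mel"].permute(0, 2, 1).cuda()
    text, lens = batch["text"], batch["mel_lengths"].cuda()

    if step % 2 == 0:
        # Guidance (critic) turn: student outputs are produced under no_grad inside forward().
        loss_dict, _ = model(inp, text, lens=lens, update_generator=False,
                             student_steps=student_steps)
        loss = loss_dict["loss_fake_mean"] + loss_dict["guidance_cls_loss"]
        guid_opt.zero_grad()
        loss.backward()
        guid_opt.step()
    else:
        # Generator (student) turn: combine the generator terms as the __main__ block does
        # (only "loss_mse" is shown there explicitly; the full key set and weighting are model-specific).
        loss_dict, _ = model(inp, text, lens=lens, update_generator=True,
                             student_steps=student_steps)
        loss = sum(loss_dict.values())
        gen_opt.zero_grad()
        loss.backward()
        gen_opt.step()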