from __future__ import annotations

import os
import sys
from collections import OrderedDict

import tensorrt as trt

from tensorrt_llm._common import default_net

from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from .modules import (
    TimestepEmbedding,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
)

current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)


class InputEmbedding(Module):

    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x, cond):
        # Project the concatenated noisy mel and conditioning features, then
        # add the convolutional positional embedding as a residual.
        x = self.proj(concat([x, cond], dim=-1))
        return self.conv_pos_embed(x) + x


class F5TTS(PretrainedModel):

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.dtype = str_dtype_to_trt(config.dtype)

        self.time_embed = TimestepEmbedding(config.hidden_size)
        self.input_embed = InputEmbedding(config.mel_dim, config.text_dim,
                                          config.hidden_size)

        self.dim = config.hidden_size
        self.depth = config.num_hidden_layers
        self.transformer_blocks = ModuleList([
            DiTBlock(
                dim=self.dim,
                heads=config.num_attention_heads,
                dim_head=config.dim_head,
                ff_mult=config.ff_mult,
                dropout=config.dropout,
            ) for _ in range(self.depth)
        ])

        self.norm_out = AdaLayerNormZero_Final(
            config.hidden_size)  # final modulation
        self.proj_out = Linear(config.hidden_size, config.mel_dim)

    def forward(
        self,
        noise,  # noised input audio
        cond,  # masked cond audio
        time,  # time step
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
    ):
        t = self.time_embed(time)
        x = self.input_embed(noise, cond)

        for block in self.transformer_blocks:
            x = block(x,
                      t,
                      rope_cos=rope_cos,
                      rope_sin=rope_sin,
                      input_lengths=input_lengths,
                      scale=scale)

        denoise = self.proj_out(self.norm_out(x, t))
        denoise.mark_output("denoised", self.dtype)

        return denoise

    def prepare_inputs(self, **kwargs):
        max_batch_size = kwargs["max_batch_size"]
        # Hard-coded feature sizes and ranges used to build the TensorRT
        # optimization profiles.
        batch_size_range = [2, 2, max_batch_size]
        mel_size = 100
        max_seq_len = 3000
        num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
        hidden_size = 512
        concat_feature_dim = mel_size + hidden_size
        freq_embed_dim = 256
        head_dim = 64

        mapping = self.config.mapping
        if mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(mapping, 1)

        if default_net().plugin_config.remove_input_padding:
            # Packed (ragged) layout: sequences are concatenated along a
            # single dynamic num_frames dimension.
            noise = Tensor(
                name="noise",
                dtype=self.dtype,
                shape=[-1, mel_size],
                dim_range=OrderedDict([
                    ("num_frames", [num_frames_range]),
                    ("n_mels", [mel_size]),
                ]),
            )
            cond = Tensor(
                name="cond",
                dtype=self.dtype,
                shape=[-1, concat_feature_dim],
                dim_range=OrderedDict([
                    ("num_frames", [num_frames_range]),
                    ("embedded_length", [concat_feature_dim]),
                ]),
            )
            time = Tensor(
                name="time",
                dtype=self.dtype,
                shape=[-1, freq_embed_dim],
                dim_range=OrderedDict([
                    ("num_frames", [num_frames_range]),
                    ("freq_dim", [freq_embed_dim]),
                ]),
            )
            rope_cos = Tensor(
                name="rope_cos",
                dtype=self.dtype,
                shape=[-1, head_dim],
                dim_range=OrderedDict([
                    ("num_frames", [num_frames_range]),
                    ("head_dim", [head_dim]),
                ]),
            )
            rope_sin = Tensor(
                name="rope_sin",
                dtype=self.dtype,
                shape=[-1, head_dim],
                dim_range=OrderedDict([
                    ("num_frames", [num_frames_range]),
                    ("head_dim", [head_dim]),
                ]),
            )
        else:
            # Padded layout: [batch_size, max_duration, feature] tensors.
            noise = Tensor(
                name="noise",
                dtype=self.dtype,
                shape=[-1, -1, mel_size],
                dim_range=OrderedDict([
("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("n_mels", [mel_size]), ] ), ) cond = Tensor( name="cond", dtype=self.dtype, shape=[-1, -1, concat_feature_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("embeded_length", [concat_feature_dim]), ] ), ) time = Tensor( name="time", dtype=self.dtype, shape=[-1, freq_embed_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("freq_dim", [freq_embed_dim]), ] ), ) rope_cos = Tensor( name="rope_cos", dtype=self.dtype, shape=[-1, -1, head_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("head_dim", [head_dim]), ] ), ) rope_sin = Tensor( name="rope_sin", dtype=self.dtype, shape=[-1, -1, head_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("head_dim", [head_dim]), ] ), ) input_lengths = Tensor( name="input_lengths", dtype=trt.int32, shape=[-1], dim_range=OrderedDict([("batch_size", [batch_size_range])]), ) return { "noise": noise, "cond": cond, "time": time, "rope_cos": rope_cos, "rope_sin": rope_sin, "input_lengths": input_lengths, }