mrfakename committed · Commit fbe6497 · verified · 1 Parent(s): 0b61e94

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions via the GitHub repo.

src/f5_tts/runtime/triton_trtllm/README.md CHANGED
@@ -30,18 +30,40 @@ bash run.sh 0 4 F5TTS_Base
  python3 client_http.py
  ```

- ### Benchmark using Dataset
+ ### Benchmark using Client-Server Mode
  ```sh
  num_task=2
  python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
  ```

+ ### Benchmark using Offline TRT-LLM Mode
+ ```sh
+ batch_size=1
+ split_name=wenetspeech4tts
+ backend_type=trt
+ log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+ rm -r $log_dir
+ ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+ torchrun --nproc_per_node=1 \
+ benchmark.py --output-dir $log_dir \
+ --batch-size $batch_size \
+ --enable-warmup \
+ --split-name $split_name \
+ --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+ --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+ --vocoder-trt-engine-path $vocoder_trt_engine_path \
+ --backend-type $backend_type \
+ --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ ```
+
  ### Benchmark Results
  Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.

- | Model | Concurrency | Avg Latency | RTF |
- |-------|-------------|-------------|--------|
- | F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394 |
+ | Model | Concurrency | Avg Latency | RTF | Mode |
+ |-------|-------------|-------------|--------|------|
+ | F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
+ | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
+ | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |

  ### Credits
  1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
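
A note on the RTF column above: in this benchmark, RTF (real-time factor) is total wall-clock decoding time divided by the total duration of the generated audio, as computed at the end of the new benchmark.py below; values below 1.0 mean faster-than-real-time synthesis. A minimal sketch with placeholder numbers (not measurements from the table):

```python
# RTF as reported by benchmark.py: wall-clock decoding time / generated audio duration.
total_decoding_time = 10.0  # seconds spent in DiT sampling + vocoding (placeholder)
total_duration = 254.0      # seconds of audio synthesized across all utterances (placeholder)

rtf = total_decoding_time / total_duration
print(f"RTF: {rtf:.4f}")    # an RTF of ~0.0394 means roughly 25x faster than real time
```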
src/f5_tts/runtime/triton_trtllm/benchmark.py ADDED
@@ -0,0 +1,557 @@
+ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
+ #               2025                (authors: Yuekai Zhang)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
+ """ Example Usage
+ torchrun --nproc_per_node=1 \
+ benchmark.py --output-dir $log_dir \
+ --batch-size $batch_size \
+ --enable-warmup \
+ --split-name $split_name \
+ --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+ --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+ --vocoder-trt-engine-path $vocoder_trt_engine_path \
+ --backend-type $backend_type \
+ --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ """
+
+ import argparse
+ import json
+ import os
+ import time
+ from typing import List, Dict, Union
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pad_sequence
+ import torchaudio
+ import jieba
+ from pypinyin import Style, lazy_pinyin
+ from datasets import load_dataset
+ import datasets
+ from huggingface_hub import hf_hub_download
+ from torch.utils.data import DataLoader, DistributedSampler
+ from tqdm import tqdm
+ from vocos import Vocos
+ from f5_tts_trtllm import F5TTS
+ import tensorrt as trt
+ from tensorrt_llm.runtime.session import Session, TensorInfo
+ from tensorrt_llm.logger import logger
+ from tensorrt_llm._utils import trt_dtype_to_torch
+
+ torch.manual_seed(0)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description="extract speech code")
+     parser.add_argument(
+         "--split-name",
+         type=str,
+         default="wenetspeech4tts",
+         choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
+         help="huggingface dataset split name",
+     )
+     parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
+     parser.add_argument(
+         "--vocab-file",
+         required=True,
+         type=str,
+         help="vocab file",
+     )
+     parser.add_argument(
+         "--model-path",
+         required=True,
+         type=str,
+         help="model path, to load text embedding",
+     )
+     parser.add_argument(
+         "--tllm-model-dir",
+         required=True,
+         type=str,
+         help="tllm model dir",
+     )
+     parser.add_argument(
+         "--batch-size",
+         required=True,
+         type=int,
+         help="batch size (per-device) for inference",
+     )
+     parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
+     parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
+     parser.add_argument(
+         "--vocoder",
+         default="vocos",
+         type=str,
+         help="vocoder name",
+     )
+     parser.add_argument(
+         "--vocoder-trt-engine-path",
+         default=None,
+         type=str,
+         help="vocoder trt engine path",
+     )
+     parser.add_argument("--enable-warmup", action="store_true")
+     parser.add_argument("--remove-input-padding", action="store_true")
+     parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
+     parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
+     args = parser.parse_args()
+     return args
+
+
+ def padded_mel_batch(ref_mels, max_seq_len):
+     padded_ref_mels = []
+     for mel in ref_mels:
+         # pad along the last dimension
+         padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
+         padded_ref_mels.append(padded_ref_mel)
+     padded_ref_mels = torch.stack(padded_ref_mels)
+     return padded_ref_mels
+
+
+ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
+     if use_perf:
+         torch.cuda.nvtx.range_push("data_collator")
+     target_sample_rate = 24000
+     target_rms = 0.1
+     ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
+         [],
+         [],
+         [],
+         [],
+         [],
+     )
+     for i, item in enumerate(batch):
+         item_id, prompt_text, target_text = (
+             item["id"],
+             item["prompt_text"],
+             item["target_text"],
+         )
+         ids.append(item_id)
+         reference_target_texts_list.append(prompt_text + target_text)
+
+         ref_audio_org, ref_sr = (
+             item["prompt_audio"]["array"],
+             item["prompt_audio"]["sampling_rate"],
+         )
+         ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
+         ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
+         if ref_rms < target_rms:
+             ref_audio_org = ref_audio_org * target_rms / ref_rms
+
+         if ref_sr != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
+             ref_audio = resampler(ref_audio_org)
+         else:
+             ref_audio = ref_audio_org
+
+         if use_perf:
+             torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
+         ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
+         if use_perf:
+             torch.cuda.nvtx.range_pop()
+         ref_mel = ref_mel.squeeze()
+         ref_mel_len = ref_mel.shape[0]
+         assert ref_mel.shape[1] == 100
+
+         ref_mel_list.append(ref_mel)
+         ref_mel_len_list.append(ref_mel_len)
+
+         estimated_reference_target_mel_len.append(int(ref_mel.shape[0] * (1 + len(target_text) / len(prompt_text))))
+
+     max_seq_len = max(estimated_reference_target_mel_len)
+     ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
+     ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
+
+     pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
+     text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
+
+     for i, item in enumerate(text_pad_sequence):
+         text_pad_sequence[i] = F.pad(
+             item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
+         )
+         text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
+     text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
+     text_pad_sequence = F.pad(
+         text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
+     )
+     if use_perf:
+         torch.cuda.nvtx.range_pop()
+     return {
+         "ids": ids,
+         "ref_mel_batch": ref_mel_batch,
+         "ref_mel_len_batch": ref_mel_len_batch,
+         "text_pad_sequence": text_pad_sequence,
+         "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
+     }
+
+
+ def init_distributed():
+     world_size = int(os.environ.get("WORLD_SIZE", 1))
+     local_rank = int(os.environ.get("LOCAL_RANK", 0))
+     rank = int(os.environ.get("RANK", 0))
+     print(
+         "Inference on multiple gpus, this gpu {}".format(local_rank)
+         + ", rank {}, world_size {}".format(rank, world_size)
+     )
+     torch.cuda.set_device(local_rank)
+     # Initialize process group with explicit device IDs
+     dist.init_process_group(
+         "nccl",
+     )
+     return world_size, local_rank, rank
+
+
+ def get_tokenizer(vocab_file_path: str):
+     """
+     tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
+                 - "char" for char-wise tokenizer, need .txt vocab_file
+                 - "byte" for utf-8 tokenizer
+                 - "custom" if you're directly passing in a path to the vocab.txt you want to use
+     vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
+                 - if use "char", derived from unfiltered character & symbol counts of custom dataset
+                 - if use "byte", set to 256 (unicode byte range)
+     """
+     with open(vocab_file_path, "r", encoding="utf-8") as f:
+         vocab_char_map = {}
+         for i, char in enumerate(f):
+             vocab_char_map[char[:-1]] = i
+     vocab_size = len(vocab_char_map)
+     return vocab_char_map, vocab_size
+
+
+ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
+     final_reference_target_texts_list = []
+     custom_trans = str.maketrans(
+         {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
+     )  # add custom trans here, to address oov
+
+     def is_chinese(c):
+         return "\u3100" <= c <= "\u9fff"  # common chinese characters
+
+     for text in reference_target_texts_list:
+         char_list = []
+         text = text.translate(custom_trans)
+         for seg in jieba.cut(text):
+             seg_byte_len = len(bytes(seg, "UTF-8"))
+             if seg_byte_len == len(seg):  # if pure alphabets and symbols
+                 if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
+                     char_list.append(" ")
+                 char_list.extend(seg)
+             elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
+                 seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
+                 for i, c in enumerate(seg):
+                     if is_chinese(c):
+                         char_list.append(" ")
+                     char_list.append(seg_[i])
+             else:  # if mixed characters, alphabets and symbols
+                 for c in seg:
+                     if ord(c) < 256:
+                         char_list.extend(c)
+                     elif is_chinese(c):
+                         char_list.append(" ")
+                         char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
+                     else:
+                         char_list.append(c)
+         final_reference_target_texts_list.append(char_list)
+
+     return final_reference_target_texts_list
+
+
+ def list_str_to_idx(
+     text: Union[List[str], List[List[str]]],
+     vocab_char_map: Dict[str, int],  # {char: idx}
+     padding_value=-1,
+ ):
+     list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
+     # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
+     return list_idx_tensors
+
+
+ def load_vocoder(
+     vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
+ ):
+     if vocoder_name == "vocos":
+         if vocoder_trt_engine_path is not None:
+             vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
+         else:
+             # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
+             if is_local:
+                 print(f"Load vocos from local path {local_path}")
+                 config_path = f"{local_path}/config.yaml"
+                 model_path = f"{local_path}/pytorch_model.bin"
+             else:
+                 print("Download Vocos from huggingface charactr/vocos-mel-24khz")
+                 repo_id = "charactr/vocos-mel-24khz"
+                 config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
+                 model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
+             vocoder = Vocos.from_hparams(config_path)
+             state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
+             from vocos.feature_extractors import EncodecFeatures
+
+             if isinstance(vocoder.feature_extractor, EncodecFeatures):
+                 encodec_parameters = {
+                     "feature_extractor.encodec." + key: value
+                     for key, value in vocoder.feature_extractor.encodec.state_dict().items()
+                 }
+                 state_dict.update(encodec_parameters)
+             vocoder.load_state_dict(state_dict)
+             vocoder = vocoder.eval().to(device)
+     elif vocoder_name == "bigvgan":
+         raise NotImplementedError("BigVGAN is not implemented yet")
+     return vocoder
+
+
+ def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
+     if vocoder == "vocos":
+         mel_stft = torchaudio.transforms.MelSpectrogram(
+             sample_rate=24000,
+             n_fft=1024,
+             win_length=1024,
+             hop_length=256,
+             n_mels=100,
+             power=1,
+             center=True,
+             normalized=False,
+             norm=None,
+         ).to(device)
+         mel = mel_stft(waveform.to(device))
+         mel = mel.clamp(min=1e-5).log()
+     return mel.transpose(1, 2)
+
+
+ class VocosTensorRT:
+     def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
+         TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+         trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
+         logger.info(f"Loading vae engine from {engine_path}")
+         self.engine_path = engine_path
+         with open(engine_path, "rb") as f:
+             engine_buffer = f.read()
+         self.session = Session.from_serialized_engine(engine_buffer)
+         self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
+
+     def decode(self, mels):
+         mels = mels.contiguous()
+         inputs = {"mel": mels}
+         output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
+         outputs = {
+             t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
+         }
+         ok = self.session.run(inputs, outputs, self.stream)
+
+         assert ok, "Runtime execution failed for vae session"
+
+         samples = outputs["waveform"]
+         return samples
+
+
+ def main():
+     args = get_args()
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     assert torch.cuda.is_available()
+     world_size, local_rank, rank = init_distributed()
+     device = torch.device(f"cuda:{local_rank}")
+
+     vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
+
+     tllm_model_dir = args.tllm_model_dir
+     config_file = os.path.join(tllm_model_dir, "config.json")
+     with open(config_file) as f:
+         config = json.load(f)
+     if args.backend_type == "trt":
+         model = F5TTS(
+             config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
+         )
+     elif args.backend_type == "pytorch":
+         import sys
+
+         sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
+         from f5_tts.model import DiT
+         from f5_tts.infer.utils_infer import load_model
+
+         F5TTS_model_cfg = dict(
+             dim=1024,
+             depth=22,
+             heads=16,
+             ff_mult=2,
+             text_dim=512,
+             conv_layers=4,
+             pe_attn_head=1,
+             text_mask_padding=False,
+         )
+         model = load_model(DiT, F5TTS_model_cfg, args.model_path)
+
+     vocoder = load_vocoder(
+         vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
+     )
+
+     dataset = load_dataset(
+         "yuekai/seed_tts",
+         split=args.split_name,
+         trust_remote_code=True,
+     )
+
+     def add_estimated_duration(example):
+         prompt_audio_len = example["prompt_audio"]["array"].shape[0]
+         scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
+         estimated_duration = prompt_audio_len * scale_factor
+         example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
+         return example
+
+     dataset = dataset.map(add_estimated_duration)
+     dataset = dataset.sort("estimated_duration", reverse=True)
+     if args.use_perf:
+         # dataset_list = [dataset.select(range(1)) for i in range(16)]  # seq_len 1000
+         dataset_list_short = [dataset.select([24]) for i in range(8)]  # seq_len 719
+         # dataset_list_long = [dataset.select([23]) for i in range(8)]  # seq_len 2002
+         # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
+         dataset = datasets.concatenate_datasets(dataset_list_short)
+     if world_size > 1:
+         sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+     else:
+         # This would disable shuffling
+         sampler = None
+
+     dataloader = DataLoader(
+         dataset,
+         batch_size=args.batch_size,
+         sampler=sampler,
+         shuffle=False,
+         num_workers=args.num_workers,
+         prefetch_factor=args.prefetch,
+         collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
+     )
+
+     total_steps = len(dataset)
+
+     if args.enable_warmup:
+         for batch in dataloader:
+             ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
+             text_pad_seq = batch["text_pad_sequence"].to(device)
+             total_mel_lens = batch["estimated_reference_target_mel_len"]
+             if args.backend_type == "trt":
+                 _ = model.sample(
+                     text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
+                 )
+             elif args.backend_type == "pytorch":
+                 with torch.inference_mode():
+                     text_pad_seq -= 1
+                     text_pad_seq[text_pad_seq == -2] = -1
+                     total_mel_lens = torch.tensor(total_mel_lens, device=device)
+                     generated, _ = model.sample(
+                         cond=ref_mels,
+                         text=text_pad_seq,
+                         duration=total_mel_lens,
+                         steps=16,
+                         cfg_strength=2.0,
+                         sway_sampling_coef=-1,
+                     )
+
+     if rank == 0:
+         progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+     decoding_time = 0
+     vocoder_time = 0
+     total_duration = 0
+     if args.use_perf:
+         torch.cuda.cudart().cudaProfilerStart()
+     total_decoding_time = time.time()
+     for batch in dataloader:
+         if args.use_perf:
+             torch.cuda.nvtx.range_push("data sample")
+         ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
+         text_pad_seq = batch["text_pad_sequence"].to(device)
+         total_mel_lens = batch["estimated_reference_target_mel_len"]
+
+         if args.use_perf:
+             torch.cuda.nvtx.range_pop()
+         if args.backend_type == "trt":
+             generated, cost_time = model.sample(
+                 text_pad_seq,
+                 ref_mels,
+                 ref_mel_lens,
+                 total_mel_lens,
+                 remove_input_padding=args.remove_input_padding,
+                 use_perf=args.use_perf,
+             )
+         elif args.backend_type == "pytorch":
+             total_mel_lens = torch.tensor(total_mel_lens, device=device)
+             with torch.inference_mode():
+                 start_time = time.time()
+                 text_pad_seq -= 1
+                 text_pad_seq[text_pad_seq == -2] = -1
+                 generated, _ = model.sample(
+                     cond=ref_mels,
+                     text=text_pad_seq,
+                     duration=total_mel_lens,
+                     lens=ref_mel_lens,
+                     steps=16,
+                     cfg_strength=2.0,
+                     sway_sampling_coef=-1,
+                 )
+                 cost_time = time.time() - start_time
+         decoding_time += cost_time
+         vocoder_start_time = time.time()
+         for i, gen in enumerate(generated):
+             gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
+             gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
+             if args.vocoder == "vocos":
+                 if args.use_perf:
+                     torch.cuda.nvtx.range_push("vocoder decode")
+                 generated_wave = vocoder.decode(gen_mel_spec).cpu()
+                 if args.use_perf:
+                     torch.cuda.nvtx.range_pop()
+             else:
+                 generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
+             target_rms = 0.1
+             target_sample_rate = 24_000
+             # if ref_rms_list[i] < target_rms:
+             #     generated_wave = generated_wave * ref_rms_list[i] / target_rms
+             rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
+             if rms < target_rms:
+                 generated_wave = generated_wave * target_rms / rms
+             utt = batch["ids"][i]
+             torchaudio.save(
+                 f"{args.output_dir}/{utt}.wav",
+                 generated_wave,
+                 target_sample_rate,
+             )
+             total_duration += generated_wave.shape[1] / target_sample_rate
+         vocoder_time += time.time() - vocoder_start_time
+         if rank == 0:
+             progress_bar.update(world_size * len(batch["ids"]))
+     total_decoding_time = time.time() - total_decoding_time
+     if rank == 0:
+         progress_bar.close()
+     rtf = total_decoding_time / total_duration
+     s = f"RTF: {rtf:.4f}\n"
+     s += f"total_duration: {total_duration:.3f} seconds\n"
+     s += f"({total_duration / 3600:.2f} hours)\n"
+     s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
+     s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
+     s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
+     s += f"batch size: {args.batch_size}\n"
+     print(s)
+
+     with open(f"{args.output_dir}/rtf.txt", "w") as f:
+         f.write(s)
+
+     dist.barrier()
+     dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     main()
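
One convention in the new benchmark.py that is easy to miss: data_collator pads each token sequence with -1 and then shifts every id by +1 (so 0 can act as the padding token hard-coded in F5-TTS / expected by the TRT-LLM engine), while the native PyTorch backend undoes that shift before calling model.sample(). A toy illustration of the round trip (the id values are made up, not from the repo):

```python
import torch
import torch.nn.functional as F

token_ids = torch.tensor([5, 9, 2])                # ids looked up from vocab_char_map (toy values)
per_item = F.pad(token_ids, (0, 2), value=-1) + 1  # per-utterance padding to the estimated mel length, then the +1 shift
batch = F.pad(per_item, (0, 2), value=-1)          # batch-level padding added after the shift stays at -1

# The PyTorch backend reverses the shift before model.sample(), as in benchmark.py:
restored = batch - 1
restored[restored == -2] = -1                      # batch-level padding (-1 - 1 = -2) mapped back to -1
assert torch.equal(restored, torch.tensor([5, 9, 2, -1, -1, -1, -1]))
```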
src/f5_tts/runtime/triton_trtllm/requirements-pytorch.txt ADDED
@@ -0,0 +1,24 @@
+ accelerate>=0.33.0
+ bitsandbytes>0.37.0
+ cached_path
+ click
+ datasets
+ ema_pytorch>=0.5.2
+ gradio>=3.45.2
+ hydra-core>=1.3.0
+ jieba
+ librosa
+ matplotlib
+ numpy<=1.26.4
+ pydub
+ pypinyin
+ safetensors
+ soundfile
+ tomli
+ torch>=2.0.0
+ # torchaudio>=2.0.0
+ torchdiffeq
+ tqdm>=4.65.0
+ transformers
+ x_transformers>=1.31.14
+ packaging>=24.2
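
These extras back the native PyTorch baseline; stage 8 of run.sh below installs them before launching the benchmark, and installing them manually is simply:

```sh
pip install -r requirements-pytorch.txt
```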
src/f5_tts/runtime/triton_trtllm/run.sh CHANGED
@@ -2,8 +2,8 @@ stage=$1
  stop_stage=$2
  model=$3 # F5TTS_Base
  if [ -z "$model" ]; then
- echo "Model is none"
- exit 1
+ echo "Model is none, using default model F5TTS_Base"
+ model=F5TTS_Base
  fi
  echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
  export CUDA_VISIBLE_DEVICES=0
@@ -68,3 +68,43 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
  python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
  fi
+
+ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ echo "TRT-LLM: offline decoding benchmark test"
+ batch_size=1
+ split_name=wenetspeech4tts
+ backend_type=trt
+ log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+ rm -r $log_dir
+ ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+ torchrun --nproc_per_node=1 \
+ benchmark.py --output-dir $log_dir \
+ --batch-size $batch_size \
+ --enable-warmup \
+ --split-name $split_name \
+ --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+ --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+ --vocoder-trt-engine-path $vocoder_trt_engine_path \
+ --backend-type $backend_type \
+ --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ fi
+
+ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ echo "Native Pytorch: offline decoding benchmark test"
+ pip install -r requirements-pytorch.txt
+ batch_size=1
+ split_name=wenetspeech4tts
+ backend_type=pytorch
+ log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+ rm -r $log_dir
+ ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+ torchrun --nproc_per_node=1 \
+ benchmark.py --output-dir $log_dir \
+ --batch-size $batch_size \
+ --split-name $split_name \
+ --enable-warmup \
+ --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+ --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+ --backend-type $backend_type \
+ --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ fi
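
Assuming the same stage/stop_stage convention the README already uses (`bash run.sh 0 4 F5TTS_Base`), the two new benchmark stages can presumably be run on their own, for example:

```sh
# Stage 7: offline TRT-LLM benchmark; stage 8: native PyTorch baseline
bash run.sh 7 7 F5TTS_Base
bash run.sh 8 8 F5TTS_Base
```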