Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.
src/f5_tts/runtime/triton_trtllm/README.md
CHANGED
@@ -30,18 +30,40 @@ bash run.sh 0 4 F5TTS_Base
 python3 client_http.py
 ```
 
-### Benchmark using
+### Benchmark using Client-Server Mode
 ```sh
 num_task=2
 python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
 ```
 
+### Benchmark using Offline TRT-LLM Mode
+```sh
+batch_size=1
+split_name=wenetspeech4tts
+backend_type=trt
+log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+rm -r $log_dir
+ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+torchrun --nproc_per_node=1 \
+benchmark.py --output-dir $log_dir \
+--batch-size $batch_size \
+--enable-warmup \
+--split-name $split_name \
+--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+--vocoder-trt-engine-path $vocoder_trt_engine_path \
+--backend-type $backend_type \
+--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+```
+
 ### Benchmark Results
 Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
 
-| Model | Concurrency | Avg Latency | RTF |
-|-------|-------------|-------------|-----|
-| F5-TTS Base (Vocos) | …
+| Model | Concurrency | Avg Latency | RTF | Mode |
+|-------|-------------|-------------|-----|------|
+| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
+| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
+| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
 
 ### Credits
 1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
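For readers comparing the two offline rows, RTF (real-time factor) here is the ratio of wall-clock decoding time to the duration of audio produced, so lower is better and anything below 1.0 is faster than real time. A minimal sketch of that bookkeeping, using illustrative placeholder numbers rather than measured values, looks like this:

```python
# Minimal sketch of how the offline benchmark derives RTF:
# RTF = total wall-clock decoding time / total duration of generated audio.
# The numbers below are illustrative placeholders, not measured results.
total_decoding_time = 10.0   # seconds spent decoding (DiT + vocoder)
total_duration = 254.0       # seconds of audio synthesized across all utterances

rtf = total_decoding_time / total_duration
print(f"RTF: {rtf:.4f}")     # ~0.04 means roughly 25x faster than real time
```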
src/f5_tts/runtime/triton_trtllm/benchmark.py
ADDED
@@ -0,0 +1,557 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#               2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
    benchmark.py --output-dir $log_dir \
    --batch-size $batch_size \
    --enable-warmup \
    --split-name $split_name \
    --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
    --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
    --vocoder-trt-engine-path $vocoder_trt_engine_path \
    --backend-type $backend_type \
    --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
"""

import argparse
import json
import os
import time
from typing import List, Dict, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import jieba
from pypinyin import Style, lazy_pinyin
from datasets import load_dataset
import datasets
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
from f5_tts_trtllm import F5TTS
import tensorrt as trt
from tensorrt_llm.runtime.session import Session, TensorInfo
from tensorrt_llm.logger import logger
from tensorrt_llm._utils import trt_dtype_to_torch

torch.manual_seed(0)


def get_args():
    parser = argparse.ArgumentParser(description="extract speech code")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
    parser.add_argument(
        "--vocab-file",
        required=True,
        type=str,
        help="vocab file",
    )
    parser.add_argument(
        "--model-path",
        required=True,
        type=str,
        help="model path, to load text embedding",
    )
    parser.add_argument(
        "--tllm-model-dir",
        required=True,
        type=str,
        help="tllm model dir",
    )
    parser.add_argument(
        "--batch-size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
    parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
    parser.add_argument(
        "--vocoder",
        default="vocos",
        type=str,
        help="vocoder name",
    )
    parser.add_argument(
        "--vocoder-trt-engine-path",
        default=None,
        type=str,
        help="vocoder trt engine path",
    )
    parser.add_argument("--enable-warmup", action="store_true")
    parser.add_argument("--remove-input-padding", action="store_true")
    parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
    parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
    args = parser.parse_args()
    return args


def padded_mel_batch(ref_mels, max_seq_len):
    padded_ref_mels = []
    for mel in ref_mels:
        # pad along the time dimension up to the longest estimated sequence
        padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    return padded_ref_mels


def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
    if use_perf:
        torch.cuda.nvtx.range_push("data_collator")
    target_sample_rate = 24000
    target_rms = 0.1
    ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
        [],
        [],
        [],
        [],
        [],
    )
    for i, item in enumerate(batch):
        item_id, prompt_text, target_text = (
            item["id"],
            item["prompt_text"],
            item["target_text"],
        )
        ids.append(item_id)
        reference_target_texts_list.append(prompt_text + target_text)

        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        if ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        if use_perf:
            torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
        ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
        if use_perf:
            torch.cuda.nvtx.range_pop()
        ref_mel = ref_mel.squeeze()
        ref_mel_len = ref_mel.shape[0]
        assert ref_mel.shape[1] == 100

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

        estimated_reference_target_mel_len.append(int(ref_mel.shape[0] * (1 + len(target_text) / len(prompt_text))))

    max_seq_len = max(estimated_reference_target_mel_len)
    ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)

    pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
    text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)

    for i, item in enumerate(text_pad_sequence):
        text_pad_sequence[i] = F.pad(
            item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
        )
        text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
    text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
    text_pad_sequence = F.pad(
        text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
    )
    if use_perf:
        torch.cuda.nvtx.range_pop()
    return {
        "ids": ids,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
        "text_pad_sequence": text_pad_sequence,
        "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
    }


def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    # Initialize process group with explicit device IDs
    dist.init_process_group(
        "nccl",
    )
    return world_size, local_rank, rank


def get_tokenizer(vocab_file_path: str):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)
    """
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        vocab_char_map = {}
        for i, char in enumerate(f):
            vocab_char_map[char[:-1]] = i
    vocab_size = len(vocab_char_map)
    return vocab_char_map, vocab_size


def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
    final_reference_target_texts_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in reference_target_texts_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_reference_target_texts_list.append(char_list)

    return final_reference_target_texts_list


def list_str_to_idx(
    text: Union[List[str], List[List[str]]],
    vocab_char_map: Dict[str, int],  # {char: idx}
    padding_value=-1,
):
    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
    # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return list_idx_tensors


def load_vocoder(
    vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
    if vocoder_name == "vocos":
        if vocoder_trt_engine_path is not None:
            vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
        else:
            # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
            if is_local:
                print(f"Load vocos from local path {local_path}")
                config_path = f"{local_path}/config.yaml"
                model_path = f"{local_path}/pytorch_model.bin"
            else:
                print("Download Vocos from huggingface charactr/vocos-mel-24khz")
                repo_id = "charactr/vocos-mel-24khz"
                config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
                model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
            vocoder = Vocos.from_hparams(config_path)
            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
            from vocos.feature_extractors import EncodecFeatures

            if isinstance(vocoder.feature_extractor, EncodecFeatures):
                encodec_parameters = {
                    "feature_extractor.encodec." + key: value
                    for key, value in vocoder.feature_extractor.encodec.state_dict().items()
                }
                state_dict.update(encodec_parameters)
            vocoder.load_state_dict(state_dict)
            vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not implemented yet")
    return vocoder


def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
    if vocoder == "vocos":
        mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=24000,
            n_fft=1024,
            win_length=1024,
            hop_length=256,
            n_mels=100,
            power=1,
            center=True,
            normalized=False,
            norm=None,
        ).to(device)
        mel = mel_stft(waveform.to(device))
        mel = mel.clamp(min=1e-5).log()
        return mel.transpose(1, 2)


class VocosTensorRT:
    def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
        logger.info(f"Loading vae engine from {engine_path}")
        self.engine_path = engine_path
        with open(engine_path, "rb") as f:
            engine_buffer = f.read()
        self.session = Session.from_serialized_engine(engine_buffer)
        self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream

    def decode(self, mels):
        mels = mels.contiguous()
        inputs = {"mel": mels}
        output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
        outputs = {
            t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
        }
        ok = self.session.run(inputs, outputs, self.stream)

        assert ok, "Runtime execution failed for vae session"

        samples = outputs["waveform"]
        return samples


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)

    tllm_model_dir = args.tllm_model_dir
    config_file = os.path.join(tllm_model_dir, "config.json")
    with open(config_file) as f:
        config = json.load(f)
    if args.backend_type == "trt":
        model = F5TTS(
            config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
        )
    elif args.backend_type == "pytorch":
        import sys

        sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
        from f5_tts.model import DiT
        from f5_tts.infer.utils_infer import load_model

        F5TTS_model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
            pe_attn_head=1,
            text_mask_padding=False,
        )
        model = load_model(DiT, F5TTS_model_cfg, args.model_path)

    vocoder = load_vocoder(
        vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
    )

    dataset = load_dataset(
        "yuekai/seed_tts",
        split=args.split_name,
        trust_remote_code=True,
    )

    def add_estimated_duration(example):
        prompt_audio_len = example["prompt_audio"]["array"].shape[0]
        scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
        estimated_duration = prompt_audio_len * scale_factor
        example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
        return example

    dataset = dataset.map(add_estimated_duration)
    dataset = dataset.sort("estimated_duration", reverse=True)
    if args.use_perf:
        # dataset_list = [dataset.select(range(1)) for i in range(16)]  # seq_len 1000
        dataset_list_short = [dataset.select([24]) for i in range(8)]  # seq_len 719
        # dataset_list_long = [dataset.select([23]) for i in range(8)]  # seq_len 2002
        # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
        dataset = datasets.concatenate_datasets(dataset_list_short)
    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        # This would disable shuffling
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
    )

    total_steps = len(dataset)

    if args.enable_warmup:
        for batch in dataloader:
            ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
            text_pad_seq = batch["text_pad_sequence"].to(device)
            total_mel_lens = batch["estimated_reference_target_mel_len"]
            if args.backend_type == "trt":
                _ = model.sample(
                    text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
                )
            elif args.backend_type == "pytorch":
                with torch.inference_mode():
                    text_pad_seq -= 1
                    text_pad_seq[text_pad_seq == -2] = -1
                    total_mel_lens = torch.tensor(total_mel_lens, device=device)
                    generated, _ = model.sample(
                        cond=ref_mels,
                        text=text_pad_seq,
                        duration=total_mel_lens,
                        steps=16,
                        cfg_strength=2.0,
                        sway_sampling_coef=-1,
                    )

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    decoding_time = 0
    vocoder_time = 0
    total_duration = 0
    if args.use_perf:
        torch.cuda.cudart().cudaProfilerStart()
    total_decoding_time = time.time()
    for batch in dataloader:
        if args.use_perf:
            torch.cuda.nvtx.range_push("data sample")
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
        text_pad_seq = batch["text_pad_sequence"].to(device)
        total_mel_lens = batch["estimated_reference_target_mel_len"]

        if args.use_perf:
            torch.cuda.nvtx.range_pop()
        if args.backend_type == "trt":
            generated, cost_time = model.sample(
                text_pad_seq,
                ref_mels,
                ref_mel_lens,
                total_mel_lens,
                remove_input_padding=args.remove_input_padding,
                use_perf=args.use_perf,
            )
        elif args.backend_type == "pytorch":
            total_mel_lens = torch.tensor(total_mel_lens, device=device)
            with torch.inference_mode():
                start_time = time.time()
                text_pad_seq -= 1
                text_pad_seq[text_pad_seq == -2] = -1
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=text_pad_seq,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=16,
                    cfg_strength=2.0,
                    sway_sampling_coef=-1,
                )
                cost_time = time.time() - start_time
        decoding_time += cost_time
        vocoder_start_time = time.time()
        for i, gen in enumerate(generated):
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
            if args.vocoder == "vocos":
                if args.use_perf:
                    torch.cuda.nvtx.range_push("vocoder decode")
                generated_wave = vocoder.decode(gen_mel_spec).cpu()
                if args.use_perf:
                    torch.cuda.nvtx.range_pop()
            else:
                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
            target_rms = 0.1
            target_sample_rate = 24_000
            # if ref_rms_list[i] < target_rms:
            #     generated_wave = generated_wave * ref_rms_list[i] / target_rms
            rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
            if rms < target_rms:
                generated_wave = generated_wave * target_rms / rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )
            total_duration += generated_wave.shape[1] / target_sample_rate
        vocoder_time += time.time() - vocoder_start_time
        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))
    total_decoding_time = time.time() - total_decoding_time
    if rank == 0:
        progress_bar.close()
        rtf = total_decoding_time / total_duration
        s = f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
        s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
        s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
        s += f"batch size: {args.batch_size}\n"
        print(s)

        with open(f"{args.output_dir}/rtf.txt", "w") as f:
            f.write(s)

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
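One detail worth noting in the collator above is the token index shift: text is segmented with jieba, Chinese characters are converted to TONE3 pinyin, tokens are mapped to vocab indices (unknowns fall back to 0), and everything is then shifted by +1 so that 0 stays reserved as the padding id. A tiny standalone sketch with a made-up toy vocabulary (the real map is loaded from --vocab-file) illustrates the mapping:

```python
import torch

# Toy illustration of the text front end in benchmark.py:
# convert_char_to_pinyin() -> list_str_to_idx() -> +1 shift (0 stays the padding id).
# The vocab below is made up for the example; the real one comes from --vocab-file.
toy_vocab = {" ": 1, "h": 2, "i": 3, "ni3": 5, "hao3": 7}

# Roughly what convert_char_to_pinyin(["你好 hi"]) yields: pinyin tokens for Chinese
# characters (each preceded by a space) and raw characters for ASCII text.
tokens = [" ", "ni3", " ", "hao3", " ", "h", "i"]

# list_str_to_idx(): unknown tokens map to index 0, mirroring vocab_char_map.get(c, 0)
ids = torch.tensor([toy_vocab.get(t, 0) for t in tokens])
ids = ids + 1  # same workaround as the script: 0 is reserved for padding in F5-TTS
print(ids.tolist())  # [2, 6, 2, 8, 2, 3, 4]
```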
src/f5_tts/runtime/triton_trtllm/requirements-pytorch.txt
ADDED
@@ -0,0 +1,24 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2
src/f5_tts/runtime/triton_trtllm/run.sh
CHANGED
@@ -2,8 +2,8 @@ stage=$1
 stop_stage=$2
 model=$3 # F5TTS_Base
 if [ -z "$model" ]; then
-    echo "Model is none"
-    exit 1
+    echo "Model is none, using default model F5TTS_Base"
+    model=F5TTS_Base
 fi
 echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
 export CUDA_VISIBLE_DEVICES=0
@@ -68,3 +68,43 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
     python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
 fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+    echo "TRT-LLM: offline decoding benchmark test"
+    batch_size=1
+    split_name=wenetspeech4tts
+    backend_type=trt
+    log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+    rm -r $log_dir
+    ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+    torchrun --nproc_per_node=1 \
+        benchmark.py --output-dir $log_dir \
+        --batch-size $batch_size \
+        --enable-warmup \
+        --split-name $split_name \
+        --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+        --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+        --vocoder-trt-engine-path $vocoder_trt_engine_path \
+        --backend-type $backend_type \
+        --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+    echo "Native Pytorch: offline decoding benchmark test"
+    pip install -r requirements-pytorch.txt
+    batch_size=1
+    split_name=wenetspeech4tts
+    backend_type=pytorch
+    log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+    rm -r $log_dir
+    ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+    torchrun --nproc_per_node=1 \
+        benchmark.py --output-dir $log_dir \
+        --batch-size $batch_size \
+        --split-name $split_name \
+        --enable-warmup \
+        --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+        --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+        --backend-type $backend_type \
+        --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+fi