Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo rather than to this Space.
src/f5_tts/runtime/triton_trtllm/client_grpc.py
CHANGED
|
@@ -220,8 +220,8 @@ def get_args():
|
|
| 220 |
return parser.parse_args()
|
| 221 |
|
| 222 |
|
| 223 |
-
def load_audio(wav_path, target_sample_rate=
|
| 224 |
-
assert target_sample_rate ==
|
| 225 |
if isinstance(wav_path, dict):
|
| 226 |
waveform = wav_path["array"]
|
| 227 |
sample_rate = wav_path["sampling_rate"]
|
|
@@ -244,7 +244,7 @@ async def send(
|
|
| 244 |
model_name: str,
|
| 245 |
padding_duration: int = None,
|
| 246 |
audio_save_dir: str = "./",
|
| 247 |
-
save_sample_rate: int =
|
| 248 |
):
|
| 249 |
total_duration = 0.0
|
| 250 |
latency_data = []
|
|
@@ -254,7 +254,7 @@ async def send(
|
|
| 254 |
for i, item in enumerate(manifest_item_list):
|
| 255 |
if i % log_interval == 0:
|
| 256 |
print(f"{name}: {i}/{len(manifest_item_list)}")
|
| 257 |
-
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=
|
| 258 |
duration = len(waveform) / sample_rate
|
| 259 |
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
| 260 |
|
|
@@ -417,7 +417,7 @@ async def main():
|
|
| 417 |
model_name=args.model_name,
|
| 418 |
audio_save_dir=args.log_dir,
|
| 419 |
padding_duration=1,
|
| 420 |
-
save_sample_rate=24000
|
| 421 |
)
|
| 422 |
)
|
| 423 |
tasks.append(task)
|
|
|
|
| 220 |
return parser.parse_args()
|
| 221 |
|
| 222 |
|
| 223 |
+
def load_audio(wav_path, target_sample_rate=24000):
|
| 224 |
+
assert target_sample_rate == 24000, "hard coding in server"
|
| 225 |
if isinstance(wav_path, dict):
|
| 226 |
waveform = wav_path["array"]
|
| 227 |
sample_rate = wav_path["sampling_rate"]
|
|
|
|
| 244 |
model_name: str,
|
| 245 |
padding_duration: int = None,
|
| 246 |
audio_save_dir: str = "./",
|
| 247 |
+
save_sample_rate: int = 24000,
|
| 248 |
):
|
| 249 |
total_duration = 0.0
|
| 250 |
latency_data = []
|
|
|
|
| 254 |
for i, item in enumerate(manifest_item_list):
|
| 255 |
if i % log_interval == 0:
|
| 256 |
print(f"{name}: {i}/{len(manifest_item_list)}")
|
| 257 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
|
| 258 |
duration = len(waveform) / sample_rate
|
| 259 |
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
| 260 |
|
|
|
|
| 417 |
model_name=args.model_name,
|
| 418 |
audio_save_dir=args.log_dir,
|
| 419 |
padding_duration=1,
|
| 420 |
+
save_sample_rate=24000,
|
| 421 |
)
|
| 422 |
)
|
| 423 |
tasks.append(task)
|
src/f5_tts/runtime/triton_trtllm/client_http.py
CHANGED
|
@@ -82,7 +82,7 @@ def prepare_request(
|
|
| 82 |
samples,
|
| 83 |
reference_text,
|
| 84 |
target_text,
|
| 85 |
-
sample_rate=
|
| 86 |
audio_save_dir: str = "./",
|
| 87 |
):
|
| 88 |
assert len(samples.shape) == 1, "samples should be 1D"
|
|
@@ -106,8 +106,8 @@ def prepare_request(
|
|
| 106 |
return data
|
| 107 |
|
| 108 |
|
| 109 |
-
def load_audio(wav_path, target_sample_rate=
|
| 110 |
-
assert target_sample_rate ==
|
| 111 |
if isinstance(wav_path, dict):
|
| 112 |
samples = wav_path["array"]
|
| 113 |
sample_rate = wav_path["sampling_rate"]
|
|
@@ -129,7 +129,7 @@ if __name__ == "__main__":
|
|
| 129 |
|
| 130 |
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
| 131 |
samples, sr = load_audio(args.reference_audio)
|
| 132 |
-
assert sr ==
|
| 133 |
|
| 134 |
samples = np.array(samples, dtype=np.float32)
|
| 135 |
data = prepare_request(samples, args.reference_text, args.target_text)
|
|
|
|
| 82 |
samples,
|
| 83 |
reference_text,
|
| 84 |
target_text,
|
| 85 |
+
sample_rate=24000,
|
| 86 |
audio_save_dir: str = "./",
|
| 87 |
):
|
| 88 |
assert len(samples.shape) == 1, "samples should be 1D"
|
|
|
|
| 106 |
return data
|
| 107 |
|
| 108 |
|
| 109 |
+
def load_audio(wav_path, target_sample_rate=24000):
|
| 110 |
+
assert target_sample_rate == 24000, "hard coding in server"
|
| 111 |
if isinstance(wav_path, dict):
|
| 112 |
samples = wav_path["array"]
|
| 113 |
sample_rate = wav_path["sampling_rate"]
|
|
|
|
| 129 |
|
| 130 |
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
| 131 |
samples, sr = load_audio(args.reference_audio)
|
| 132 |
+
assert sr == 24000, "sample rate hardcoded in server"
|
| 133 |
|
| 134 |
samples = np.array(samples, dtype=np.float32)
|
| 135 |
data = prepare_request(samples, args.reference_text, args.target_text)
|
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
CHANGED
|
@@ -33,7 +33,7 @@ parameters [
|
|
| 33 |
},
|
| 34 |
{
|
| 35 |
key: "reference_audio_sample_rate",
|
| 36 |
-
value: {string_value:"
|
| 37 |
},
|
| 38 |
{
|
| 39 |
key: "vocoder",
|
|
|
|
| 33 |
},
|
| 34 |
{
|
| 35 |
key: "reference_audio_sample_rate",
|
| 36 |
+
value: {string_value:"24000"}
|
| 37 |
},
|
| 38 |
{
|
| 39 |
key: "vocoder",
|