Spaces:
Runtime error
Runtime error
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
src/f5_tts/train/finetune_gradio.py
CHANGED
|
@@ -34,6 +34,7 @@ python_executable = sys.executable or "python"
|
|
| 34 |
tts_api = None
|
| 35 |
last_checkpoint = ""
|
| 36 |
last_device = ""
|
|
|
|
| 37 |
|
| 38 |
path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
|
| 39 |
path_data = os.path.join(path_basic, "data")
|
|
@@ -800,7 +801,7 @@ def vocab_extend(project_name, symbols, model_type):
|
|
| 800 |
return "Symbols are okay no need to extend."
|
| 801 |
|
| 802 |
size_vocab = len(vocab)
|
| 803 |
-
|
| 804 |
for item in miss_symbols:
|
| 805 |
vocab.append(item)
|
| 806 |
|
|
@@ -915,8 +916,8 @@ def get_random_sample_infer(project_name):
|
|
| 915 |
)
|
| 916 |
|
| 917 |
|
| 918 |
-
def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
|
| 919 |
-
global last_checkpoint, last_device, tts_api
|
| 920 |
|
| 921 |
if not os.path.isfile(file_checkpoint):
|
| 922 |
return None, "checkpoint not found!"
|
|
@@ -926,15 +927,19 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
|
|
| 926 |
else:
|
| 927 |
device_test = None
|
| 928 |
|
| 929 |
-
if last_checkpoint != file_checkpoint or last_device != device_test:
|
| 930 |
if last_checkpoint != file_checkpoint:
|
| 931 |
last_checkpoint = file_checkpoint
|
|
|
|
| 932 |
if last_device != device_test:
|
| 933 |
last_device = device_test
|
| 934 |
|
| 935 |
-
|
|
|
|
|
|
|
|
|
|
| 936 |
|
| 937 |
-
print("update", device_test, file_checkpoint)
|
| 938 |
|
| 939 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
| 940 |
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
|
|
@@ -1273,7 +1278,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
| 1273 |
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
|
| 1274 |
|
| 1275 |
nfe_step = gr.Number(label="n_step", value=32)
|
| 1276 |
-
|
| 1277 |
with gr.Row():
|
| 1278 |
cm_checkpoint = gr.Dropdown(
|
| 1279 |
choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
|
|
@@ -1285,6 +1290,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
| 1285 |
ref_text = gr.Textbox(label="ref text")
|
| 1286 |
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
| 1287 |
gen_text = gr.Textbox(label="gen text")
|
|
|
|
| 1288 |
random_sample_infer.click(
|
| 1289 |
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
|
| 1290 |
)
|
|
@@ -1297,7 +1303,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
| 1297 |
|
| 1298 |
check_button_infer.click(
|
| 1299 |
fn=infer,
|
| 1300 |
-
inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
|
| 1301 |
outputs=[gen_audio, txt_info_gpu],
|
| 1302 |
)
|
| 1303 |
|
|
|
|
| 34 |
tts_api = None
|
| 35 |
last_checkpoint = ""
|
| 36 |
last_device = ""
|
| 37 |
+
last_ema = None
|
| 38 |
|
| 39 |
path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
|
| 40 |
path_data = os.path.join(path_basic, "data")
|
|
|
|
| 801 |
return "Symbols are okay no need to extend."
|
| 802 |
|
| 803 |
size_vocab = len(vocab)
|
| 804 |
+
vocab.pop() # fix empty space leave
|
| 805 |
for item in miss_symbols:
|
| 806 |
vocab.append(item)
|
| 807 |
|
|
|
|
| 916 |
)
|
| 917 |
|
| 918 |
|
| 919 |
+
def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema):
|
| 920 |
+
global last_checkpoint, last_device, tts_api, last_ema
|
| 921 |
|
| 922 |
if not os.path.isfile(file_checkpoint):
|
| 923 |
return None, "checkpoint not found!"
|
|
|
|
| 927 |
else:
|
| 928 |
device_test = None
|
| 929 |
|
| 930 |
+
if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema:
|
| 931 |
if last_checkpoint != file_checkpoint:
|
| 932 |
last_checkpoint = file_checkpoint
|
| 933 |
+
|
| 934 |
if last_device != device_test:
|
| 935 |
last_device = device_test
|
| 936 |
|
| 937 |
+
if last_ema != use_ema:
|
| 938 |
+
last_ema = use_ema
|
| 939 |
+
|
| 940 |
+
tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test, use_ema=use_ema)
|
| 941 |
|
| 942 |
+
print("update >> ", device_test, file_checkpoint, use_ema)
|
| 943 |
|
| 944 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
| 945 |
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
|
|
|
|
| 1278 |
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
|
| 1279 |
|
| 1280 |
nfe_step = gr.Number(label="n_step", value=32)
|
| 1281 |
+
ch_use_ema = gr.Checkbox(label="use ema", value=True)
|
| 1282 |
with gr.Row():
|
| 1283 |
cm_checkpoint = gr.Dropdown(
|
| 1284 |
choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
|
|
|
|
| 1290 |
ref_text = gr.Textbox(label="ref text")
|
| 1291 |
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
| 1292 |
gen_text = gr.Textbox(label="gen text")
|
| 1293 |
+
|
| 1294 |
random_sample_infer.click(
|
| 1295 |
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
|
| 1296 |
)
|
|
|
|
| 1303 |
|
| 1304 |
check_button_infer.click(
|
| 1305 |
fn=infer,
|
| 1306 |
+
inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema],
|
| 1307 |
outputs=[gen_audio, txt_info_gpu],
|
| 1308 |
)
|
| 1309 |
|