Update app.py
app.py CHANGED
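In brief: this change drops the hard dependency on `spaces`, selects between `vall_e.*` and package-relative imports depending on whether the file runs on a HF Space, renames the sampler flags (`*-temp` to `*-temperature`, `max-ar-steps` to `max-duration`, `max-nar-levels` to `max-levels`), consolidates all sampler options into a single `sampling_kwargs` dict splatted into `tts.inference(...)`, and adds top-nσ sampling, CFG strength, a modality selector, and an Experimental Settings tab to the Gradio UI.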
@@ -12,7 +12,6 @@ import argparse
 import random
 import tempfile
 import functools
-import spaces
 
 import torch
 import numpy as np
@@ -22,14 +21,8 @@ import gradio as gr
 
 from pathlib import Path
 
-from vall_e.inference import TTS, cfg
-from vall_e.train import train
-from vall_e.utils import get_devices, setup_logging, timer
-from vall_e.utils.io import json_read, json_stringify
-from vall_e.emb.qnt import decode_to_wave
-from vall_e.data import get_lang_symmap, get_random_prompt
-from vall_e.models.arch import AVAILABLE_ATTENTIONS
 
+# agony with HF's ZeroGPU spaces
 try:
     import spaces
 
@@ -39,6 +32,24 @@ except Exception as e:
     USING_SPACES = False
     def spaces_zerogpu_decorator(func):
         return func
+# more agony, because gradio will not stay launched if directly called from the package, for who knows why
+# this allows me to directly copy this file rather than constantly edit it on the HF space repo
+if USING_SPACES:
+    from vall_e.inference import TTS, cfg
+    from vall_e.train import train
+    from vall_e.utils import get_devices, setup_logging, timer
+    from vall_e.utils.io import json_read, json_stringify
+    from vall_e.emb.qnt import decode_to_wave
+    from vall_e.data import get_lang_symmap, get_random_prompt
+    from vall_e.models.arch import AVAILABLE_ATTENTIONS
+else:
+    from .inference import TTS, cfg
+    from .train import train
+    from .utils import get_devices, setup_logging, timer
+    from .utils.io import json_read, json_stringify
+    from .emb.qnt import decode_to_wave
+    from .data import get_lang_symmap, get_random_prompt
+    from .models.arch import AVAILABLE_ATTENTIONS
 
 is_windows = sys.platform.startswith("win")
 
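For context on the hunk above: `spaces` is Hugging Face's ZeroGPU helper, and the guarded import plus no-op decorator keeps the same file runnable off Spaces; the `USING_SPACES` branch then picks absolute `vall_e.*` imports when the file is copied onto the Space as a plain script, and package-relative imports otherwise. A minimal sketch of the pattern, assuming `spaces.GPU` is the decorator being aliased on the success path (only the fallback appears in this diff):

```python
# Sketch of the guard used above. Assumption: spaces.GPU is what gets
# aliased when the import succeeds; the diff only shows the fallback.
try:
    import spaces
    USING_SPACES = True
    spaces_zerogpu_decorator = spaces.GPU  # offloads the wrapped call to ZeroGPU
except Exception:
    USING_SPACES = False
    def spaces_zerogpu_decorator(func):
        # off Spaces: leave the function untouched
        return func
```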
@@ -181,11 +192,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         raise Exception("No model loaded.")
 
     if kwargs.pop("dynamic-sampling", False):
-        kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
-        kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0
+        kwargs['min-ar-temperature'] = 0.01 if kwargs['ar-temperature'] > 0.01 else 0.0
+        kwargs['min-nar-temperature'] = 0.0 # 0.85 if kwargs['nar-temperature'] > 0.85 else 0.0 # should probably disable it for the NAR
     else:
-        kwargs['min-ar-temp'] = -1
-        kwargs['min-nar-temp'] = -1
+        kwargs['min-ar-temperature'] = -1
+        kwargs['min-nar-temperature'] = -1
 
     parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
     # I'm very sure I can procedurally generate this list
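The `min-*-temperature` values feed the "Dynamic Temperature" option: when enabled, the sampler is free to move between the minimum and the configured temperature based on per-step confidence, while `-1` disables the behavior. A sketch of the common entropy-based formulation (illustrative only; this repo's sampler may differ in detail):

```python
import torch

def dynamic_temperature(logits: torch.Tensor, max_temp: float = 1.0, min_temp: float = 0.01) -> torch.Tensor:
    # Scale temperature between min_temp and max_temp by the normalized
    # entropy of the next-token distribution: confident steps sample cold,
    # uncertain steps sample hot.
    probs = torch.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1)
    max_entropy = torch.log(torch.tensor(float(logits.shape[-1])))
    temp = min_temp + (max_temp - min_temp) * (entropy / max_entropy)
    return logits / temp.unsqueeze(-1)
```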
@@ -194,16 +205,18 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--references", type=str, default=kwargs["reference"])
     parser.add_argument("--language", type=str, default=kwargs["language"])
     parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-    parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
-    parser.add_argument("--max-ar-steps", type=int, …)
-    parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
-    parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
-    parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
-    parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
-    parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
-    parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
+    parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+    parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
+    parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
+    parser.add_argument("--max-steps", type=int, default=kwargs["max-steps"])
+    parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
+    parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
+    parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
+    parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
+    parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
     parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
     parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+    parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
     parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
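On the recurring "I'm very sure I can procedurally generate this list" comment: since every flag mirrors a key in `kwargs` (which itself mirrors the Gradio `layout` inputs), the parser could indeed be built in a loop. A hypothetical sketch; `build_parser` and its bool/number dispatch are not part of the codebase:

```python
import argparse

def build_parser(kwargs: dict) -> argparse.ArgumentParser:
    # One flag per UI input; booleans become store_true flags, everything
    # else keeps its current value as the default. argparse already maps
    # "--min-p" to args.min_p, matching how the flags are read later.
    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    for key, value in kwargs.items():
        if isinstance(value, bool):
            parser.add_argument(f"--{key}", action="store_true", default=value)
        else:
            parser.add_argument(f"--{key}", type=type(value) if value is not None else str, default=value)
    return parser
```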
@@ -216,10 +229,12 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
     parser.add_argument("--entropix-sampling", action="store_true")
     parser.add_argument("--layer-skip", action="store_true")
-    parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
-    parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
-    parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
+    parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
+    parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
+    parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
     parser.add_argument("--refine-on-stop", action="store_true")
+    parser.add_argument("--denoise-start", type=float, default=0.0)
+    parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
     args, unknown = parser.parse_known_args()
 
     if is_windows:
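`--layer-skip-entropy-threshold` and `--layer-skip-varentropy-threshold` gate the self-speculative early exit: an intermediate layer's logits are accepted once both their entropy and varentropy fall low enough. A sketch of that criterion under the assumed semantics (the actual exit logic lives in the model code, not in this file):

```python
import torch

def should_exit_early(logits: torch.Tensor, entropy_threshold: float = 0.1, varentropy_threshold: float = 0.1) -> torch.Tensor:
    # Entropy measures how peaked the distribution is; varentropy measures
    # how much the surprisal varies across tokens. Low values of both
    # suggest the layer is already confident, so deeper layers can be skipped.
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-8)
    entropy = -(probs * log_probs).sum(dim=-1)
    varentropy = (probs * (log_probs + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
    return (entropy < entropy_threshold) & (varentropy < varentropy_threshold)
```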
@@ -244,6 +259,40 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     tts = init_tts()
 
     gr.Info("Inferencing...")
+
+    # icky
+    modality = kwargs.get("modality")
+    if modality:
+        for name, engine in tts.engines.items():
+            if modality == "AR+NAR":
+                engine.hyper_config.capabilities = ["ar", "nar"]
+            elif modality == "NAR-len":
+                engine.hyper_config.capabilities = ["nar", "len"]
+
+    sampling_kwargs = dict(
+        max_steps=args.max_steps,
+        max_levels=args.max_levels,
+        max_duration=args.max_duration,
+        ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
+        min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
+        top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
+        repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+        length_penalty=args.length_penalty,
+        beam_width=args.beam_width,
+        mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
+        layer_skip=args.layer_skip,
+        layer_skip_exit_layer=args.layer_skip_exit_layer,
+        layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+        layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+        refine_on_stop=args.refine_on_stop,
+        denoise_start=args.denoise_start,
+        prefix_silence=args.prefix_silence,
+        input_prompt_prefix=args.input_prompt_prefix,
+        input_prompt_length=args.input_prompt_length,
+        cfg_strength=args.cfg_strength,
+    )
 
     with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
         wav, sr = tts.inference(
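The self-described "icky" modality override above flips each engine between the AR+NAR decoding path and the NAR-len (demasking) path by rewriting its declared capabilities. The same logic, isolated as a table-driven helper (a sketch; `set_modality` is hypothetical):

```python
MODALITY_CAPABILITIES = {
    "AR+NAR": ["ar", "nar"],
    "NAR-len": ["nar", "len"],
}

def set_modality(engines: dict, modality: str) -> None:
    # "Auto" (or any unknown value) leaves each engine as configured.
    capabilities = MODALITY_CAPABILITIES.get(modality)
    if capabilities is None:
        return
    for engine in engines.values():
        engine.hyper_config.capabilities = capabilities
```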
@@ -251,34 +300,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
             language=args.language,
             task=args.task,
             references=args.references.split(";") if args.references is not None else [],
-
-            max_ar_steps=args.max_ar_steps,
-            max_nar_levels=args.max_nar_levels,
-            input_prompt_length=args.input_prompt_length,
-            input_prompt_prefix=args.input_prompt_prefix,
-            prefix_silence=args.prefix_silence,
-            ar_temp=args.ar_temp,
-            nar_temp=args.nar_temp,
-            min_ar_temp=args.min_ar_temp,
-            min_nar_temp=args.min_nar_temp,
-            top_p=args.top_p,
-            top_k=args.top_k,
-            min_p=args.min_p,
-            beam_width=args.beam_width,
-            repetition_penalty=args.repetition_penalty,
-            repetition_penalty_decay=args.repetition_penalty_decay,
-            length_penalty=args.length_penalty,
-            mirostat_tau=args.mirostat_tau,
-            mirostat_eta=args.mirostat_eta,
-            dry_multiplier=args.dry_multiplier,
-            dry_base=args.dry_base,
-            dry_allowed_length=args.dry_allowed_length,
-            entropix_sampling=args.entropix_sampling,
-
-            layer_skip=args.layer_skip,
-            layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
-            layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
-            refine_on_stop=args.refine_on_stop,
+            **sampling_kwargs,
         )
 
     wav = wav.squeeze(0).cpu().numpy()
@@ -290,20 +312,21 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         raise Exception("No model loaded.")
 
     if kwargs.pop("dynamic-sampling", False):
-        kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
+        kwargs['min-ar-temperature'] = 0.85 if kwargs['ar-temperature'] > 0.85 else 0.0
     else:
-        kwargs['min-ar-temp'] = -1
+        kwargs['min-ar-temperature'] = -1
 
     parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
     # I'm very sure I can procedurally generate this list
+    parser.add_argument("--task", type=str, default="tts")
     parser.add_argument("--references", type=str, default=kwargs["reference"])
+    parser.add_argument("--max-duration", type=int, default=0)
     parser.add_argument("--language", type=str, default=kwargs["language"])
-    parser.add_argument("--max-ar-steps", type=int, default=0)
-    parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
-    parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
+    parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
+    parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
     parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
     parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
-    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
+    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -313,27 +336,37 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
     parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
-    parser.add_argument("--entropix-sampling", action="store_true")
     args, unknown = parser.parse_known_args()
 
-
     """
     if not args.references:
         raise Exception("No reference audio provided.")
     """
 
     args.references = args.references.split(";") if args.references is not None else []
-    if args.max_ar_steps == 0:
+    if args.max_duration == 0:
         for i, path in enumerate( args.references ):
             metadata = torchaudio.info(path)
             duration = metadata.num_frames / metadata.sample_rate
-            args.max_ar_steps += duration
-        args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
+            args.max_duration += duration
+        args.max_duration = math.floor( args.max_duration * 20 ) # assume 20 tokens per second
 
     if kwargs.pop("entropix-sampling", False):
         args.entropix_sampling = True
 
     tts = init_tts()
+
+    sampling_kwargs = dict(
+        max_duration=args.max_duration,
+        ar_temperature=args.ar_temperature,
+        min_ar_temperature=args.min_ar_temperature,
+        top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+        repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+        length_penalty=args.length_penalty,
+        beam_width=args.beam_width,
+        mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+    )
 
     gr.Info("Inferencing...")
     with timer("Inferenced in") as t:
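The STT fallback above derives a token budget when `--max-duration` is left at 0: it sums the reference clips' durations and converts seconds to tokens at an assumed 20 tokens per second. Standalone, that computation looks like this (sketch; `estimate_max_duration` is not a function in the repo):

```python
import math
import torchaudio

def estimate_max_duration(references: list[str], tokens_per_second: int = 20) -> int:
    # Sum the reference audio lengths in seconds, then convert to a token
    # budget at the assumed 20 tokens/second rate used above.
    seconds = 0.0
    for path in references:
        info = torchaudio.info(path)
        seconds += info.num_frames / info.sample_rate
    return math.floor(seconds * tokens_per_second)
```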
@@ -342,21 +375,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
             language=args.language,
             task="stt",
             references=args.references,
-
-            ar_temp=args.ar_temp,
-            min_ar_temp=args.min_ar_temp,
-            top_p=args.top_p,
-            top_k=args.top_k,
-            min_p=args.min_p,
-            repetition_penalty=args.repetition_penalty,
-            repetition_penalty_decay=args.repetition_penalty_decay,
-            length_penalty=args.length_penalty,
-            mirostat_tau=args.mirostat_tau,
-            mirostat_eta=args.mirostat_eta,
-            dry_multiplier=args.dry_multiplier,
-            dry_base=args.dry_base,
-            dry_allowed_length=args.dry_allowed_length,
-            entropix_sampling=args.entropix_sampling,
+            **sampling_kwargs,
         )
 
     return text
@@ -413,21 +432,22 @@ with ui:
             with gr.Column(scale=7):
                 with gr.Tab("Basic Settings"):
                     with gr.Row():
-                        layout["inference_tts"]["inputs"]["max-…
+                        layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
                         layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
                     with gr.Row():
-                        layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(…)
-                        layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(…)
+                        layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
+                        layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
                     with gr.Row():
+                        layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
                         layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
                 with gr.Tab("Sampler Settings"):
                     with gr.Row():
                         layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
                         layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+                        layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
                         layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
-                        layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
                     with gr.Row():
-                        layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.…
+                        layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
                         layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
                         layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
                     with gr.Row():
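The new "CFG Strength" slider exposes classifier-free guidance for generation. In its usual logits form, guidance pushes the conditional prediction away from an unconditional one by the chosen scale; a sketch of that mixing (the model's internal formulation may differ):

```python
import torch

def apply_cfg(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # cfg_strength = 0 reduces to the unconditional prediction; 1 recovers
    # the conditional one; larger values extrapolate beyond it.
    return uncond_logits + cfg_strength * (cond_logits - uncond_logits)
```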
@@ -437,24 +457,26 @@ with ui:
                         layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
                         layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
                         layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
-
-                with gr.…
-                … [the remaining 16 removed lines are blanked/truncated in the source diff view and could not be recovered]
+                with gr.Tab("Experimental Settings", visible=cfg.experimental):
+                    with gr.Row():
+                        layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
+                        layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+                    with gr.Row():
+                        layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
+                        layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
+                        layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
+                    with gr.Row():
+                        layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
+                        layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+                        layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+                    with gr.Row():
+                        layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
+                        layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
+                    with gr.Row():
+                        layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
+                        layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
+                        layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
+
 
         layout["inference_tts"]["buttons"]["inference"].click(
             fn=do_inference_tts,
@@ -474,7 +496,7 @@ with ui:
             with gr.Column(scale=7):
                 with gr.Tab("Basic Settings"):
                     with gr.Row():
-                        layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(…)
+                        layout["inference_stt"]["inputs"]["ar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
                    with gr.Row():
                         layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
                         layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
@@ -485,7 +507,7 @@ with ui:
                         layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
                         layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
                     with gr.Row():
-                        layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.…
+                        layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
                         layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
                         layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
                     with gr.Row():
|