ecker committed
Commit ce32c07 · verified · Parent(s): acf492d

Update app.py

Files changed (1): app.py +124 −102
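Summary of the changes below: sampler options are renamed for consistency (`*-temp` → `*-temperature`, `max-seconds`/`max-ar-steps` → `max-duration`, `max-nar-levels` → `max-levels`); the long per-call argument lists to `tts.inference(...)` in both the TTS and STT paths are consolidated into a single `sampling_kwargs` dict that is splatted into the call; the `vall_e` imports become conditional, so the file can be copied verbatim to the HF Space (absolute imports) or run from inside the package (relative imports); new options are exposed (`top-no`, `cfg-strength`, `denoise-start`, `max-steps`, and a Modality dropdown); and the Experimental Settings tab is now always constructed but hidden via `visible=cfg.experimental`.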
app.py CHANGED
@@ -12,7 +12,6 @@ import argparse
 import random
 import tempfile
 import functools
-import spaces
 
 import torch
 import numpy as np
@@ -22,14 +21,8 @@ import gradio as gr
 
 from pathlib import Path
 
-from vall_e.inference import TTS, cfg
-from vall_e.train import train
-from vall_e.utils import get_devices, setup_logging, timer
-from vall_e.utils.io import json_read, json_stringify
-from vall_e.emb.qnt import decode_to_wave
-from vall_e.data import get_lang_symmap, get_random_prompt
-from vall_e.models.arch import AVAILABLE_ATTENTIONS
 
+# agony with HF's ZeroGPU spaces
 try:
 	import spaces
@@ -39,6 +32,24 @@ except Exception as e:
 	USING_SPACES = False
 	def spaces_zerogpu_decorator(func):
 		return func
+# more agony, because gradio will not stay launched if directly called from the package, for who knows why
+# this allows me to directly copy this file rather than constantly edit it on the HF space repo
+if USING_SPACES:
+	from vall_e.inference import TTS, cfg
+	from vall_e.train import train
+	from vall_e.utils import get_devices, setup_logging, timer
+	from vall_e.utils.io import json_read, json_stringify
+	from vall_e.emb.qnt import decode_to_wave
+	from vall_e.data import get_lang_symmap, get_random_prompt
+	from vall_e.models.arch import AVAILABLE_ATTENTIONS
+else:
+	from .inference import TTS, cfg
+	from .train import train
+	from .utils import get_devices, setup_logging, timer
+	from .utils.io import json_read, json_stringify
+	from .emb.qnt import decode_to_wave
+	from .data import get_lang_symmap, get_random_prompt
+	from .models.arch import AVAILABLE_ATTENTIONS
 
 is_windows = sys.platform.startswith("win")
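Note: this hunk leans on the guard set up just above it; only the fallback branch is visible in the context lines. A minimal sketch of the full pattern, assuming the usual `spaces.GPU` decorator — the try-branch here is an assumption, not part of this diff:

```python
# Hedged sketch of the ZeroGPU guard. Prefer HF's decorator when the
# `spaces` package is importable (i.e. running on a Space); otherwise
# fall back to a no-op so the same file also runs locally.
try:
	import spaces
	USING_SPACES = True
	spaces_zerogpu_decorator = spaces.GPU  # assumed name of the decorator
except Exception as e:
	USING_SPACES = False
	def spaces_zerogpu_decorator(func):
		return func
```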
 
@@ -181,11 +192,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		raise Exception("No model loaded.")
 
 	if kwargs.pop("dynamic-sampling", False):
-		kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
-		kwargs['min-nar-temp'] = 0.0 # 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
+		kwargs['min-ar-temperature'] = 0.01 if kwargs['ar-temperature'] > 0.01 else 0.0
+		kwargs['min-nar-temperature'] = 0.0 # 0.85 if kwargs['nar-temperature'] > 0.85 else 0.0 # should probably disable it for the NAR
 	else:
-		kwargs['min-ar-temp'] = -1
-		kwargs['min-nar-temp'] = -1
+		kwargs['min-ar-temperature'] = -1
+		kwargs['min-nar-temperature'] = -1
 
 	parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
 	# I'm very sure I can procedurally generate this list
@@ -194,16 +205,18 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-	parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"] if cfg.experimental else False)
-	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
-	parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"] if cfg.experimental else 0)
-	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
-	parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
-	parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
-	parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
-	parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"] if cfg.experimental else 0)
+	parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
+	parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
+	parser.add_argument("--max-steps", type=int, default=kwargs["max-steps"])
+	parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
+	parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
+	parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
+	parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
+	parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
 	parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
 	parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+	parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
 	parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
 	parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
 	parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
@@ -216,10 +229,12 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
 	parser.add_argument("--entropix-sampling", action="store_true")
 	parser.add_argument("--layer-skip", action="store_true")
-	parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
-	parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"] if cfg.experimental else 0.1)
-	parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"] if cfg.experimental else 0.1)
+	parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
+	parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
+	parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
 	parser.add_argument("--refine-on-stop", action="store_true")
+	parser.add_argument("--denoise-start", type=float, default=0.0)
+	parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
 	args, unknown = parser.parse_known_args()
 
 	if is_windows:
@@ -244,6 +259,40 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	tts = init_tts()
 
 	gr.Info("Inferencing...")
+
+	# icky
+	modality = kwargs.get("modality")
+	if modality:
+		for name, engine in tts.engines.items():
+			if modality == "AR+NAR":
+				engine.hyper_config.capabilities = ["ar", "nar"]
+			elif modality == "NAR-len":
+				engine.hyper_config.capabilities = ["nar", "len"]
+
+	sampling_kwargs = dict(
+		max_steps=args.max_steps,
+		max_levels=args.max_levels,
+		max_duration=args.max_duration,
+		ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
+		min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
+		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
+		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty,
+		beam_width=args.beam_width,
+		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling,
+		layer_skip=args.layer_skip,
+		layer_skip_exit_layer=args.layer_skip_exit_layer,
+		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+		refine_on_stop=args.refine_on_stop,
+		denoise_start=args.denoise_start,
+		prefix_silence=args.prefix_silence,
+		input_prompt_prefix=args.input_prompt_prefix,
+		input_prompt_length=args.input_prompt_length,
+		cfg_strength=args.cfg_strength,
+	)
 
 	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
 		wav, sr = tts.inference(
@@ -251,34 +300,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 			language=args.language,
 			task=args.task,
 			references=args.references.split(";") if args.references is not None else [],
-			out_path=tmp.name,
-			max_ar_steps=args.max_ar_steps,
-			max_nar_levels=args.max_nar_levels,
-			input_prompt_length=args.input_prompt_length,
-			input_prompt_prefix=args.input_prompt_prefix,
-			prefix_silence=args.prefix_silence,
-			ar_temp=args.ar_temp,
-			nar_temp=args.nar_temp,
-			min_ar_temp=args.min_ar_temp,
-			min_nar_temp=args.min_nar_temp,
-			top_p=args.top_p,
-			top_k=args.top_k,
-			min_p=args.min_p,
-			beam_width=args.beam_width,
-			repetition_penalty=args.repetition_penalty,
-			repetition_penalty_decay=args.repetition_penalty_decay,
-			length_penalty=args.length_penalty,
-			mirostat_tau=args.mirostat_tau,
-			mirostat_eta=args.mirostat_eta,
-			dry_multiplier=args.dry_multiplier,
-			dry_base=args.dry_base,
-			dry_allowed_length=args.dry_allowed_length,
-			entropix_sampling=args.entropix_sampling,
-
-			layer_skip=args.layer_skip,
-			layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
-			layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
-			refine_on_stop=args.refine_on_stop,
+			**sampling_kwargs,
 		)
 
 	wav = wav.squeeze(0).cpu().numpy()
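Note: the point of `sampling_kwargs` is that both call sites (TTS here, STT below) now share one source of truth and splat it with `**`. A toy sketch of the mechanism, with hypothetical names:

```python
# Hedged sketch: one dict of options, splatted into any callable that
# accepts them as keyword arguments.
def inference(text, *, ar_temperature=1.0, top_p=1.0, **rest):
	return f"{text!r} @ temperature={ar_temperature}, top_p={top_p}"

sampling_kwargs = dict(ar_temperature=0.5, top_p=0.9)
print(inference("hello world", **sampling_kwargs))
```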
@@ -290,20 +312,21 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		raise Exception("No model loaded.")
 
 	if kwargs.pop("dynamic-sampling", False):
-		kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
+		kwargs['min-ar-temperature'] = 0.85 if kwargs['ar-temperature'] > 0.85 else 0.0
 	else:
-		kwargs['min-ar-temp'] = -1
+		kwargs['min-ar-temperature'] = -1
 
 	parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
 	# I'm very sure I can procedurally generate this list
+	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
+	parser.add_argument("--max-duration", type=int, default=0)
 	parser.add_argument("--language", type=str, default=kwargs["language"])
-	parser.add_argument("--max-ar-steps", type=int, default=0)
-	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
-	parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
+	parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
+	parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
 	parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
 	parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
-	parser.add_argument("--min-p", type=int, default=kwargs["min-p"])
+	parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
 	parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
 	parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
 	parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -313,27 +336,37 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
 	parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
 	parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
-	parser.add_argument("--entropix-sampling", action="store_true")
 	args, unknown = parser.parse_known_args()
 
-
 	"""
 	if not args.references:
 		raise Exception("No reference audio provided.")
 	"""
 
 	args.references = args.references.split(";") if args.references is not None else []
-	if args.max_ar_steps == 0:
+	if args.max_duration == 0:
 		for i, path in enumerate( args.references ):
 			metadata = torchaudio.info(path)
 			duration = metadata.num_frames / metadata.sample_rate
-			args.max_ar_steps += duration
-		args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
+			args.max_duration += duration
+		args.max_duration = math.floor( args.max_duration * 20 ) # assume 20 tokens per second
 
 	if kwargs.pop("entropix-sampling", False):
 		args.entropix_sampling = True
 
 	tts = init_tts()
+
+	sampling_kwargs = dict(
+		max_duration=args.max_duration,
+		ar_temperature=args.ar_temperature,
+		min_ar_temperature=args.min_ar_temperature,
+		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty,
+		beam_width=args.beam_width,
+		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+	)
 
 	gr.Info("Inferencing...")
 	with timer("Inferenced in") as t:
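Note: when `--max-duration` is left at 0, the STT path above derives a token budget from the reference clips themselves. The same logic as a standalone helper (hypothetical name), keeping the hunk's stated assumption of 20 tokens per second:

```python
# Hedged restatement of the duration fallback in do_inference_stt.
import math
import torchaudio

def default_max_duration(references, tokens_per_second=20):
	seconds = 0.0
	for path in references:
		metadata = torchaudio.info(path)  # reads the header only, no decode
		seconds += metadata.num_frames / metadata.sample_rate
	return math.floor(seconds * tokens_per_second)
```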
@@ -342,21 +375,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 			language=args.language,
 			task="stt",
 			references=args.references,
-			max_ar_steps=args.max_ar_steps,
-			ar_temp=args.ar_temp,
-			min_ar_temp=args.min_ar_temp,
-			top_p=args.top_p,
-			top_k=args.top_k,
-			min_p=args.min_p,
-			repetition_penalty=args.repetition_penalty,
-			repetition_penalty_decay=args.repetition_penalty_decay,
-			length_penalty=args.length_penalty,
-			mirostat_tau=args.mirostat_tau,
-			mirostat_eta=args.mirostat_eta,
-			dry_multiplier=args.dry_multiplier,
-			dry_base=args.dry_base,
-			dry_allowed_length=args.dry_allowed_length,
-			entropix_sampling=args.entropix_sampling,
+			**sampling_kwargs,
 		)
 
 	return text
@@ -413,21 +432,22 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Tab("Basic Settings"):
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
+						layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
 						layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
-						layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
+						layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
+						layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 					with gr.Row():
+						layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
 						layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 				with gr.Tab("Sampler Settings"):
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
 						layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+						layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
 						layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
-						layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.5, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+						layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 						layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 						layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 					with gr.Row():
@@ -437,24 +457,26 @@ with ui:
 						layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
 						layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
 						layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
-				if cfg.experimental:
-					with gr.Tab("Experimental Settings"):
-						with gr.Row():
-							layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
-							layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
-						with gr.Row():
-							layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
-						with gr.Row():
-							layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-							layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
-						with gr.Row():
-							layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
-							layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
-						with gr.Row():
-							layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
-							layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
-							layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
-
+				with gr.Tab("Experimental Settings", visible=cfg.experimental):
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
+						layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
+						layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
+						layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
+						layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+						layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
+						layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
+						layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
+						layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
+
 
 		layout["inference_tts"]["buttons"]["inference"].click(
 			fn=do_inference_tts,
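Note: the experimental tab is now always constructed and merely hidden via `visible=cfg.experimental`, instead of being skipped with an `if`. Presumably this keeps every component registered in `layout[...]["inputs"]`, so the `.click(...)` wiring above works the same whether or not experimental mode is on. A minimal sketch:

```python
# Hedged sketch: a hidden-but-constructed gr.Tab still creates its child
# components, so they can be passed as inputs to event handlers.
import gradio as gr

experimental = False  # stand-in for cfg.experimental
with gr.Blocks() as demo:
	with gr.Tab("Experimental Settings", visible=experimental):
		layer_skip = gr.Checkbox(label="Layer Skip")  # exists even when hidden
```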
@@ -474,7 +496,7 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Tab("Basic Settings"):
 					with gr.Row():
-						layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+						layout["inference_stt"]["inputs"]["ar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 					with gr.Row():
 						layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
 						layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
@@ -485,7 +507,7 @@ with ui:
 						layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
 						layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 					with gr.Row():
-						layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+						layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 						layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 						layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 					with gr.Row():