Update API to have more expressive interface for controlling various generation knobs
- Also adds typical decoding support; unfortunately, this does not work well with the current model.
- api.py +25 -16
- eval_multiple.py +3 -2
- models/autoregressive.py +5 -4
- utils/typical_sampling.py +33 -0
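At the call-site level, every new knob is surfaced through TextToSpeech.tts(). A minimal sketch of how a caller might exercise the expanded interface, assuming the default TextToSpeech() constructor, an import path inferred from the file layout, and a hypothetical voice clip; the values simply mirror the new defaults:

from api import TextToSpeech
from utils.audio import load_audio

tts = TextToSpeech()                          # assumes the default constructor
cond = load_audio('voice_sample.wav', 22050)  # hypothetical conditioning clip
wav = tts.tts('This is a test sentence.', [cond, cond], k=1,
              # autoregressive knobs
              num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0,
              top_k=50, top_p=.95, typical_sampling=False, typical_mass=.9,
              # diffusion knobs
              diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1)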
api.py
CHANGED
@@ -49,13 +49,13 @@ def download_models():
     print('Done.')
 
 
-def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
     """
     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
     """
     return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
-                           conditioning_free=cond_free, conditioning_free_k=
+                           conditioning_free=cond_free, conditioning_free_k=cond_free_k)
 
 
 def load_conditioning(clip, cond_length=132300):
@@ -96,7 +96,7 @@ def fix_autoregressive_output(codes, stop_token):
     return codes
 
 
-def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input,
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1):
     """
     Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
     """
@@ -111,11 +111,10 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
 
     output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
     precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
-    mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
+
+    noise = torch.randn(output_shape, device=mel_codes.device) * temperature
+    mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
+                                 model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
     return denormalize_tacotron_mel(mel)[:,:,:msl*4]
@@ -150,7 +149,12 @@ class TextToSpeech:
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
-    def tts(self, text, voice_samples,
+    def tts(self, text, voice_samples, k=1,
+            # autoregressive generation parameters follow
+            num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0, top_k=50, top_p=.95,
+            typical_sampling=False, typical_mass=.9,
+            # diffusion generation parameters follow
+            diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1,):
         text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
         text = F.pad(text, (0, 1)) # This may not be necessary.
@@ -167,7 +171,7 @@ class TextToSpeech:
         else:
             cond_diffusion = cond_diffusion[:, :88200]
 
-        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free)
+        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
 
         with torch.no_grad():
             samples = []
@@ -175,11 +179,16 @@ class TextToSpeech:
             stop_mel_token = self.autoregressive.stop_mel_token
             self.autoregressive = self.autoregressive.cuda()
             for b in tqdm(range(num_batches)):
-                codes = self.autoregressive.inference_speech(conds, text,
+                codes = self.autoregressive.inference_speech(conds, text,
+                                                             do_sample=True,
+                                                             top_k=top_k,
+                                                             top_p=top_p,
+                                                             temperature=temperature,
+                                                             num_return_sequences=self.autoregressive_batch_size,
+                                                             length_penalty=length_penalty,
+                                                             repetition_penalty=repetition_penalty,
+                                                             typical_sampling=typical_sampling,
+                                                             typical_mass=typical_mass)
                 padding_needed = 250 - codes.shape[1]
                 codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                 samples.append(codes)
@@ -203,7 +212,7 @@ class TextToSpeech:
         self.vocoder = self.vocoder.cuda()
         for b in range(best_results.shape[0]):
             code = best_results[b].unsqueeze(0)
-            mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion,
+            mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature)
             wav = self.vocoder.inference(mel)
             wav_candidates.append(wav.cpu())
         self.diffusion = self.diffusion.cpu()
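Note that diffusion_temperature works by scaling the Gaussian noise that seeds p_sample_loop, so values below 1 start the reverse diffusion from a smaller-variance sample. A standalone sketch of just that scaling step (the shape is illustrative, not taken from the model):

import torch

output_shape = (1, 100, 400)    # (batch, mel channels, frames) - illustrative shape only
diffusion_temperature = 0.7     # < 1 shrinks the starting noise, > 1 widens it
noise = torch.randn(output_shape) * diffusion_temperature
# the scaled noise is then handed to diffuser.p_sample_loop(..., noise=noise, ...) as in do_spectrogram_diffusion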
eval_multiple.py
CHANGED
@@ -7,7 +7,7 @@ from utils.audio import load_audio
 
 if __name__ == '__main__':
     fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\redo_outlier'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath, exist_ok=True)
@@ -24,7 +24,8 @@ if __name__ == '__main__':
         path = os.path.join(os.path.dirname(fname), line[1])
         cond_audio = load_audio(path, 22050)
         torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, k=1, diffusion_iterations=200, cond_free=False,
+                         top_k=None, top_p=.95, typical_sampling=False, temperature=.7, length_penalty=.5, repetition_penalty=1)
         down = torchaudio.functional.resample(sample, 24000, 22050)
         fout_path = os.path.join(outpath, os.path.basename(line[1]))
         torchaudio.save(fout_path, down.squeeze(0), 22050)
models/autoregressive.py
CHANGED
@@ -3,11 +3,11 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import GPT2Config, GPT2PreTrainedModel
+from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 from models.arch_util import AttentionBlock
-
+from utils.typical_sampling import TypicalLogitsWarper
 
 
 def null_position_embeddings(range, dim):
@@ -497,7 +497,7 @@ class UnifiedVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
 
-    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
@@ -530,8 +530,9 @@ class UnifiedVoice(nn.Module):
         fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token
 
+        logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=
+                                            max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
utils/typical_sampling.py
ADDED
@@ -0,0 +1,33 @@
+import torch
+from transformers import LogitsWarper
+
+
+class TypicalLogitsWarper(LogitsWarper):
+    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind[last_ind < 0] = 0
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
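Because TypicalLogitsWarper implements the standard LogitsWarper interface, it can be dropped into any Hugging Face generate() call through a LogitsProcessorList, exactly as autoregressive.py does above. A small sketch against a stock GPT-2 checkpoint (the model choice and prompt are illustrative, not part of this repo):

from transformers import GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessorList
from utils.typical_sampling import TypicalLogitsWarper

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
inputs = tokenizer('Typical decoding keeps tokens whose information content is', return_tensors='pt')
gen = model.generate(inputs['input_ids'], do_sample=True, max_length=50,
                     logits_processor=LogitsProcessorList([TypicalLogitsWarper(mass=0.9)]))
print(tokenizer.decode(gen[0]))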