Update TTS/tts/layers/xtts/dvae.py
Browse files- TTS/tts/layers/xtts/dvae.py +10 -18
TTS/tts/layers/xtts/dvae.py
CHANGED
|
@@ -24,9 +24,7 @@ def eval_decorator(fn):
|
|
| 24 |
return inner
|
| 25 |
|
| 26 |
|
| 27 |
-
def dvae_wav_to_mel(
|
| 28 |
-
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
|
| 29 |
-
):
|
| 30 |
mel_stft = torchaudio.transforms.MelSpectrogram(
|
| 31 |
n_fft=1024,
|
| 32 |
hop_length=256,
|
|
@@ -44,7 +42,7 @@ def dvae_wav_to_mel(
|
|
| 44 |
mel = torch.log(torch.clamp(mel, min=1e-5))
|
| 45 |
if mel_norms is None:
|
| 46 |
mel_norms = torch.load(mel_norms_file, map_location=device)
|
| 47 |
-
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
|
| 48 |
return mel
|
| 49 |
|
| 50 |
|
|
@@ -112,7 +110,7 @@ class Quantize(nn.Module):
|
|
| 112 |
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
| 113 |
n = self.cluster_size.sum()
|
| 114 |
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
| 115 |
-
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
| 116 |
self.embed.data.copy_(embed_normalized)
|
| 117 |
|
| 118 |
diff = (quantize.detach() - input).pow(2).mean()
|
|
@@ -198,6 +196,7 @@ class UpsampledConv(nn.Module):
|
|
| 198 |
|
| 199 |
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
| 200 |
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
|
|
|
| 201 |
class DiscreteVAE(nn.Module):
|
| 202 |
def __init__(
|
| 203 |
self,
|
|
@@ -215,7 +214,7 @@ class DiscreteVAE(nn.Module):
|
|
| 215 |
activation="relu",
|
| 216 |
smooth_l1_loss=False,
|
| 217 |
straight_through=False,
|
| 218 |
-
normalization=None,
|
| 219 |
record_codes=False,
|
| 220 |
discretization_loss_averaging_steps=100,
|
| 221 |
lr_quantizer_args={},
|
|
@@ -231,7 +230,7 @@ class DiscreteVAE(nn.Module):
|
|
| 231 |
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
|
| 232 |
)
|
| 233 |
|
| 234 |
-
assert positional_dims > 0 and positional_dims < 3
|
| 235 |
if positional_dims == 2:
|
| 236 |
conv = nn.Conv2d
|
| 237 |
conv_transpose = nn.ConvTranspose2d
|
|
@@ -246,7 +245,7 @@ class DiscreteVAE(nn.Module):
|
|
| 246 |
elif activation == "silu":
|
| 247 |
act = nn.SiLU
|
| 248 |
else:
|
| 249 |
-
|
| 250 |
|
| 251 |
enc_layers = []
|
| 252 |
dec_layers = []
|
|
@@ -293,7 +292,6 @@ class DiscreteVAE(nn.Module):
|
|
| 293 |
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
| 294 |
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
| 295 |
|
| 296 |
-
# take care of normalization within class
|
| 297 |
self.normalization = normalization
|
| 298 |
self.record_codes = record_codes
|
| 299 |
if record_codes:
|
|
@@ -303,19 +301,18 @@ class DiscreteVAE(nn.Module):
|
|
| 303 |
self.internal_step = 0
|
| 304 |
|
| 305 |
def norm(self, images):
|
| 306 |
-
if
|
| 307 |
return images
|
| 308 |
|
| 309 |
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
| 310 |
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
| 311 |
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
| 312 |
images = images.clone()
|
| 313 |
-
images.sub_(means).div_(stds)
|
| 314 |
return images
|
| 315 |
|
| 316 |
def get_debug_values(self, step, __):
|
| 317 |
if self.record_codes and self.total_codes > 0:
|
| 318 |
-
# Report annealing schedule
|
| 319 |
return {"histogram_codes": self.codes[: self.total_codes]}
|
| 320 |
else:
|
| 321 |
return {}
|
|
@@ -356,9 +353,6 @@ class DiscreteVAE(nn.Module):
|
|
| 356 |
sampled, codes, commitment_loss = self.codebook(logits)
|
| 357 |
return self.decode(codes)
|
| 358 |
|
| 359 |
-
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
| 360 |
-
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
| 361 |
-
# more lossy (but useful for determining network performance).
|
| 362 |
def forward(self, img):
|
| 363 |
img = self.norm(img)
|
| 364 |
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
|
@@ -371,16 +365,13 @@ class DiscreteVAE(nn.Module):
|
|
| 371 |
out = d(out)
|
| 372 |
self.log_codes(codes)
|
| 373 |
else:
|
| 374 |
-
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
| 375 |
out, _ = self.decode(codes)
|
| 376 |
|
| 377 |
-
# reconstruction loss
|
| 378 |
recon_loss = self.loss_fn(img, out, reduction="none")
|
| 379 |
|
| 380 |
return recon_loss, commitment_loss, out
|
| 381 |
|
| 382 |
def log_codes(self, codes):
|
| 383 |
-
# This is so we can debug the distribution of codes being learned.
|
| 384 |
if self.record_codes and self.internal_step % 10 == 0:
|
| 385 |
codes = codes.flatten()
|
| 386 |
l = codes.shape[0]
|
|
@@ -391,3 +382,4 @@ class DiscreteVAE(nn.Module):
|
|
| 391 |
self.code_ind = 0
|
| 392 |
self.total_codes += 1
|
| 393 |
self.internal_step += 1
|
|
|
|
|
|
| 24 |
return inner
|
| 25 |
|
| 26 |
|
| 27 |
+
def dvae_wav_to_mel(wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")):
|
|
|
|
|
|
|
| 28 |
mel_stft = torchaudio.transforms.MelSpectrogram(
|
| 29 |
n_fft=1024,
|
| 30 |
hop_length=256,
|
|
|
|
| 42 |
mel = torch.log(torch.clamp(mel, min=1e-5))
|
| 43 |
if mel_norms is None:
|
| 44 |
mel_norms = torch.load(mel_norms_file, map_location=device)
|
| 45 |
+
mel = mel / (mel_norms.unsqueeze(0).unsqueeze(-1) + 1e-8) # Adicionando um valor pequeno para evitar divisão por zero
|
| 46 |
return mel
|
| 47 |
|
| 48 |
|
|
|
|
| 110 |
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
| 111 |
n = self.cluster_size.sum()
|
| 112 |
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
| 113 |
+
embed_normalized = self.embed_avg / (cluster_size.unsqueeze(0) + self.eps) # Adicionando eps para evitar divisão por zero
|
| 114 |
self.embed.data.copy_(embed_normalized)
|
| 115 |
|
| 116 |
diff = (quantize.detach() - input).pow(2).mean()
|
|
|
|
| 196 |
|
| 197 |
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
| 198 |
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
| 199 |
+
|
| 200 |
class DiscreteVAE(nn.Module):
|
| 201 |
def __init__(
|
| 202 |
self,
|
|
|
|
| 214 |
activation="relu",
|
| 215 |
smooth_l1_loss=False,
|
| 216 |
straight_through=False,
|
| 217 |
+
normalization=None,
|
| 218 |
record_codes=False,
|
| 219 |
discretization_loss_averaging_steps=100,
|
| 220 |
lr_quantizer_args={},
|
|
|
|
| 230 |
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
|
| 231 |
)
|
| 232 |
|
| 233 |
+
assert positional_dims > 0 and positional_dims < 3
|
| 234 |
if positional_dims == 2:
|
| 235 |
conv = nn.Conv2d
|
| 236 |
conv_transpose = nn.ConvTranspose2d
|
|
|
|
| 245 |
elif activation == "silu":
|
| 246 |
act = nn.SiLU
|
| 247 |
else:
|
| 248 |
+
raise NotImplementedError()
|
| 249 |
|
| 250 |
enc_layers = []
|
| 251 |
dec_layers = []
|
|
|
|
| 292 |
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
| 293 |
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
| 294 |
|
|
|
|
| 295 |
self.normalization = normalization
|
| 296 |
self.record_codes = record_codes
|
| 297 |
if record_codes:
|
|
|
|
| 301 |
self.internal_step = 0
|
| 302 |
|
| 303 |
def norm(self, images):
|
| 304 |
+
if self.normalization is None:
|
| 305 |
return images
|
| 306 |
|
| 307 |
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
| 308 |
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
| 309 |
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
| 310 |
images = images.clone()
|
| 311 |
+
images.sub_(means).div_(stds + 1e-8) # Adicionando eps para evitar divisão por zero
|
| 312 |
return images
|
| 313 |
|
| 314 |
def get_debug_values(self, step, __):
|
| 315 |
if self.record_codes and self.total_codes > 0:
|
|
|
|
| 316 |
return {"histogram_codes": self.codes[: self.total_codes]}
|
| 317 |
else:
|
| 318 |
return {}
|
|
|
|
| 353 |
sampled, codes, commitment_loss = self.codebook(logits)
|
| 354 |
return self.decode(codes)
|
| 355 |
|
|
|
|
|
|
|
|
|
|
| 356 |
def forward(self, img):
|
| 357 |
img = self.norm(img)
|
| 358 |
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
|
|
|
| 365 |
out = d(out)
|
| 366 |
self.log_codes(codes)
|
| 367 |
else:
|
|
|
|
| 368 |
out, _ = self.decode(codes)
|
| 369 |
|
|
|
|
| 370 |
recon_loss = self.loss_fn(img, out, reduction="none")
|
| 371 |
|
| 372 |
return recon_loss, commitment_loss, out
|
| 373 |
|
| 374 |
def log_codes(self, codes):
|
|
|
|
| 375 |
if self.record_codes and self.internal_step % 10 == 0:
|
| 376 |
codes = codes.flatten()
|
| 377 |
l = codes.shape[0]
|
|
|
|
| 382 |
self.code_ind = 0
|
| 383 |
self.total_codes += 1
|
| 384 |
self.internal_step += 1
|
| 385 |
+
|