KingNish commited on
Commit
4e8aca1
·
verified ·
1 Parent(s): e02e40c

Upload ./vocos/loss.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vocos/loss.py +114 -0
vocos/loss.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torchaudio
5
+ from torch import nn
6
+
7
+ from vocos.modules import safe_log
8
+
9
+
10
+ class MelSpecReconstructionLoss(nn.Module):
11
+ """
12
+ L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
13
+ """
14
+
15
+ def __init__(
16
+ self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
17
+ ):
18
+ super().__init__()
19
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
20
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
21
+ )
22
+
23
+ def forward(self, y_hat, y) -> torch.Tensor:
24
+ """
25
+ Args:
26
+ y_hat (Tensor): Predicted audio waveform.
27
+ y (Tensor): Ground truth audio waveform.
28
+
29
+ Returns:
30
+ Tensor: L1 loss between the mel-scaled magnitude spectrograms.
31
+ """
32
+ mel_hat = safe_log(self.mel_spec(y_hat))
33
+ mel = safe_log(self.mel_spec(y))
34
+
35
+ loss = torch.nn.functional.l1_loss(mel, mel_hat)
36
+
37
+ return loss
38
+
39
+
40
+ class GeneratorLoss(nn.Module):
41
+ """
42
+ Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
43
+ """
44
+
45
+ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
46
+ """
47
+ Args:
48
+ disc_outputs (List[Tensor]): List of discriminator outputs.
49
+
50
+ Returns:
51
+ Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
52
+ the sub-discriminators
53
+ """
54
+ loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
55
+ gen_losses = []
56
+ for dg in disc_outputs:
57
+ l = torch.mean(torch.clamp(1 - dg, min=0))
58
+ gen_losses.append(l)
59
+ loss += l
60
+
61
+ return loss, gen_losses
62
+
63
+
64
+ class DiscriminatorLoss(nn.Module):
65
+ """
66
+ Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
67
+ """
68
+
69
+ def forward(
70
+ self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
71
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
72
+ """
73
+ Args:
74
+ disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
75
+ disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
76
+
77
+ Returns:
78
+ Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
79
+ the sub-discriminators for real outputs, and a list of
80
+ loss values for generated outputs.
81
+ """
82
+ loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
83
+ r_losses = []
84
+ g_losses = []
85
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
86
+ r_loss = torch.mean(torch.clamp(1 - dr, min=0))
87
+ g_loss = torch.mean(torch.clamp(1 + dg, min=0))
88
+ loss += r_loss + g_loss
89
+ r_losses.append(r_loss)
90
+ g_losses.append(g_loss)
91
+
92
+ return loss, r_losses, g_losses
93
+
94
+
95
+ class FeatureMatchingLoss(nn.Module):
96
+ """
97
+ Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
98
+ """
99
+
100
+ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
101
+ """
102
+ Args:
103
+ fmap_r (List[List[Tensor]]): List of feature maps from real samples.
104
+ fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
105
+
106
+ Returns:
107
+ Tensor: The calculated feature matching loss.
108
+ """
109
+ loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
110
+ for dr, dg in zip(fmap_r, fmap_g):
111
+ for rl, gl in zip(dr, dg):
112
+ loss += torch.mean(torch.abs(rl - gl))
113
+
114
+ return loss