Spaces:
Build error
Build error
File size: 4,026 Bytes
51da11a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import torch
import auraloss
import resampy
import torchaudio
from pesq import pesq
import pyloudnorm as pyln
def crest_factor(x):
    """Compute the crest factor (peak-to-RMS ratio) of a waveform in dB.

    The peak and RMS are both reduced over the last dimension of ``x``;
    any leading dimensions are preserved in the result.

    Args:
        x: Floating-point tensor of samples.

    Returns:
        Tensor with the last dimension removed, holding
        ``20 * log10(peak / rms)``.
    """
    eps = 1e-8
    peak, _ = x.abs().max(dim=-1)
    rms = torch.sqrt((x ** 2).mean(dim=-1))
    # Floor both terms so an all-zero input gives 0 dB rather than -inf,
    # which would turn any later difference of two such values into NaN.
    ratio = peak.clamp(min=eps) / rms.clamp(min=eps)
    # Decibels are defined with a base-10 logarithm, not the natural log.
    return 20 * torch.log10(ratio)
def rms_energy(x):
    """Compute the RMS energy of a waveform in dB.

    The RMS is reduced over the last dimension of ``x``; any leading
    dimensions are preserved in the result.

    Args:
        x: Floating-point tensor of samples.

    Returns:
        Tensor with the last dimension removed, holding ``20 * log10(rms)``.
        The RMS is floored at 1e-8, so silence maps to -160 dB.
    """
    rms = torch.sqrt((x ** 2).mean(dim=-1))
    # Decibels are defined with a base-10 logarithm, not the natural log.
    return 20 * torch.log10(rms.clamp(min=1e-8))
def spectral_centroid(x):
    """Compute the normalized spectral centroid of a waveform.

    The magnitude of a real FFT over the last dimension is normalized to
    sum to one and used to weight frequencies spaced evenly on [0, 1],
    where 0 is the first (DC) bin and 1 is the last rfft bin. Both the
    normalization and the final sum run over every element, so the result
    is always a single scalar, even for batched input.

    See: https://gist.github.com/endolith/359724

    Args:
        x: Floating-point tensor of samples.

    Returns:
        0-dim tensor holding the centroid; 0 for an all-zero input.
    """
    magnitudes = torch.fft.rfft(x).abs()
    # Guard the normalization so an all-zero spectrum yields 0, not NaN.
    total = magnitudes.sum().clamp(min=torch.finfo(magnitudes.dtype).tiny)
    weights = magnitudes / total
    # Build the frequency axis on the same device/dtype as the spectrum so
    # the product below also works for inputs that are not on the CPU.
    frequencies = torch.linspace(
        0,
        1,
        magnitudes.shape[-1],
        dtype=magnitudes.dtype,
        device=magnitudes.device,
    )
    return torch.sum(frequencies * weights)
def loudness(x, sample_rate):
    """Compute the loudness in dB LUFS of waveform.

    Args:
        x: 2-D tensor whose first dimension is treated as channels. An
            input with fewer than two rows is tiled to two rows before
            metering.
        sample_rate: Sample rate handed to ``pyln.Meter``.

    Returns:
        0-dim tensor wrapping the value from ``integrated_loudness``.
    """
    meter = pyln.Meter(sample_rate)
    # add stereo dim if needed
    if x.shape[0] < 2:
        x = x.repeat(2, 1)
    # The meter is fed the transposed (dim 1, dim 0) layout as a NumPy
    # array. Detach and move to CPU first: ``.numpy()`` raises on tensors
    # that track gradients or live on another device.
    samples = x.detach().cpu().permute(1, 0).numpy()
    return torch.tensor(meter.integrated_loudness(samples))
class MelSpectralDistance(torch.nn.Module):
    """Mel-scaled STFT distance between two signals, computed by auraloss."""

    def __init__(self, sample_rate, length=65536):
        super().__init__()
        # FFT size, hop size and window length all share `length`.
        framing = {
            "fft_size": length,
            "hop_size": length,
            "win_length": length,
        }
        # Term weights passed straight to the auraloss loss.
        weights = {"w_sc": 0, "w_log_mag": 1, "w_lin_mag": 1}
        # I think scale invariance may not work well,
        # since aspects of the phase may be considered?
        self.error = auraloss.freq.MelSTFTLoss(
            sample_rate,
            n_mels=128,
            scale_invariance=False,
            **framing,
            **weights,
        )

    def forward(self, input, target):
        """Return whatever the wrapped loss yields for (input, target)."""
        distance = self.error(input, target)
        return distance
class PESQ(torch.nn.Module):
    """Wideband PESQ score of ``input`` measured against ``target``.

    Signals are scored at 16 kHz; when ``sample_rate`` differs they are
    resampled with resampy first.
    """

    def __init__(self, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate

    @staticmethod
    def _as_flat_array(signal):
        """Flatten a tensor to a 1-D NumPy array; pass other objects through."""
        if isinstance(signal, torch.Tensor):
            # reshape (not view) also handles non-contiguous tensors, and
            # detach/cpu keeps ``.numpy()`` from raising.
            return signal.detach().cpu().reshape(-1).numpy()
        return signal

    def forward(self, input, target):
        """Return the value of ``pesq(16000, target, input, "wb")``.

        Args:
            input: Degraded signal (tensor).
            target: Reference signal (tensor).
        """
        # Convert unconditionally: previously only the resampling branch
        # flattened to NumPy, so 16 kHz tensors reached ``pesq`` unconverted.
        target = self._as_flat_array(target)
        input = self._as_flat_array(input)
        if self.sample_rate != 16000:
            target = resampy.resample(target, self.sample_rate, 16000)
            input = resampy.resample(input, self.sample_rate, 16000)
        return pesq(16000, target, input, "wb")
class CrestFactorError(torch.nn.Module):
    """Mean absolute difference between the crest factors of two signals."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the L1 distance of the two crest factors as a Python number."""
        estimate = crest_factor(input)
        reference = crest_factor(target)
        difference = torch.nn.functional.l1_loss(estimate, reference)
        return difference.item()
class RMSEnergyError(torch.nn.Module):
    """Mean absolute difference between the RMS energies of two signals."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the L1 distance of the two RMS energies as a Python number."""
        estimate = rms_energy(input)
        reference = rms_energy(target)
        difference = torch.nn.functional.l1_loss(estimate, reference)
        return difference.item()
class SpectralCentroidError(torch.nn.Module):
    """Absolute difference between the mean spectral centroids of two signals."""

    def __init__(self, sample_rate, n_fft=2048, hop_length=512):
        super().__init__()
        extractor = torchaudio.transforms.SpectralCentroid(
            sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
        )
        self.spectral_centroid = extractor

    def forward(self, input, target):
        """Return the L1 distance of the averaged centroids as a Python number."""
        # NOTE(review): the 1e-16 offset presumably keeps an all-zero signal
        # from producing an undefined centroid - confirm.
        estimate = self.spectral_centroid(input + 1e-16).mean()
        reference = self.spectral_centroid(target + 1e-16).mean()
        difference = torch.nn.functional.l1_loss(estimate, reference)
        return difference.item()
class LoudnessError(torch.nn.Module):
    """Absolute difference in loudness between two signals.

    With ``peak_normalize`` enabled, each signal is first divided by its
    own absolute maximum.
    """

    def __init__(self, sample_rate: int, peak_normalize: bool = False):
        super().__init__()
        self.sample_rate = sample_rate
        self.peak_normalize = peak_normalize

    def forward(self, input, target):
        """Return the L1 distance of the two loudness values as a Python number."""
        estimate, reference = input, target
        if self.peak_normalize:
            # peak normalize
            estimate = input / input.abs().max()
            reference = target / target.abs().max()
        # Each signal is flattened to a single row before metering.
        measured = loudness(estimate.view(1, -1), self.sample_rate)
        expected = loudness(reference.view(1, -1), self.sample_rate)
        return torch.nn.functional.l1_loss(measured, expected).item()
|