# Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0
# You can find the license at LICENSES/LICENSE_AURALOSS.txt
import torch
import numpy as np
from typing import List, Any
import scipy.signal
def apply_reduction(losses, reduction="none"):
"""Apply reduction to collection of losses."""
if reduction == "mean":
losses = losses.mean()
elif reduction == "sum":
losses = losses.sum()
return losses
def get_window(win_type: str, win_length: int):
"""Return a window function.
Args:
win_type (str): Window type. Can either be one of the window function provided in PyTorch
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
win_length (int): Window length
Returns:
win: The window as a 1D torch tensor
"""
try:
win = getattr(torch, win_type)(win_length)
    except AttributeError:
win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
return win
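# Example (added; not part of the original auraloss module): a minimal sketch of the two
# code paths in get_window -- a native PyTorch window and the SciPy fallback. The window
# types and the length of 1024 below are illustrative assumptions.
def _example_get_window():
    hann = get_window("hann_window", 1024)  # resolved via torch.hann_window
    flattop = get_window("flattop", 1024)  # not a torch attribute, so SciPy is used
    return hann.shape, flattop.shape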
class SumAndDifference(torch.nn.Module):
"""Sum and difference signal extraction module."""
def __init__(self):
"""Initialize sum and difference extraction module."""
super(SumAndDifference, self).__init__()
def forward(self, x):
"""Calculate forward propagation.
Args:
            x (Tensor): Input stereo signal (B, 2, #samples).
Returns:
Tensor: Sum signal.
Tensor: Difference signal.
"""
if not (x.size(1) == 2): # inputs must be stereo
raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")
sum_sig = self.sum(x).unsqueeze(1)
diff_sig = self.diff(x).unsqueeze(1)
return sum_sig, diff_sig
@staticmethod
def sum(x):
return x[:, 0, :] + x[:, 1, :]
@staticmethod
def diff(x):
return x[:, 0, :] - x[:, 1, :]
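# Example (added; not part of the original auraloss module): a minimal sketch of mid/side
# extraction with SumAndDifference. The random stereo batch below is an illustrative assumption.
def _example_sum_and_difference():
    stereo = torch.randn(4, 2, 44100)  # (batch, 2 channels, samples)
    mid, side = SumAndDifference()(stereo)  # each has shape (4, 1, 44100)
    return mid.shape, side.shape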
class FIRFilter(torch.nn.Module):
"""FIR pre-emphasis filtering module.
Args:
filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
        coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
        fs (int): Sampling rate in Hz, used when fitting the A-weighting FIR filter. Default: 44100
        ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
        plot (bool): Plot the magnitude response of the filter. Default: False
    Based upon the perceptual loss pre-emphasis filters proposed by
[Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
A-weighting filter - "aw"
First-order highpass - "hp"
Folded differentiator - "fd"
    Note that the default coefficient value of 0.85 is optimized for
    a sampling rate of 44.1 kHz; consider adjusting this value at different sampling rates.
"""
def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
"""Initilize FIR pre-emphasis filtering module."""
super(FIRFilter, self).__init__()
self.filter_type = filter_type
self.coef = coef
self.fs = fs
self.ntaps = ntaps
self.plot = plot
if ntaps % 2 == 0:
raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
if filter_type == "hp":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
elif filter_type == "fd":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
elif filter_type == "aw":
# Definition of analog A-weighting filter according to IEC/CD 1672.
f1 = 20.598997
f2 = 107.65265
f3 = 737.86223
f4 = 12194.217
A1000 = 1.9997
NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
DENs = np.polymul(
[1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
[1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
)
DENs = np.polymul(
np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
)
# convert analog filter to digital filter
b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
# compute the digital filter frequency response
w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
# then we fit to 101 tap FIR filter with least squares
taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
# now implement this digital FIR filter as a Conv1d layer
self.fir = torch.nn.Conv1d(
1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
if plot:
from .plotting import compare_filters
compare_filters(b, a, taps, fs=fs)
def forward(self, input, target):
"""Calculate forward propagation.
        Args:
            input (Tensor): Predicted signal (B, 1, #samples).
            target (Tensor): Groundtruth signal (B, 1, #samples).
        Returns:
            Tensor, Tensor: Filtered input and target signals.
        """
input = torch.nn.functional.conv1d(
input, self.fir.weight.data, padding=self.ntaps // 2
)
target = torch.nn.functional.conv1d(
target, self.fir.weight.data, padding=self.ntaps // 2
)
return input, target
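# Example (added; not part of the original auraloss module): a minimal sketch of applying
# the A-weighting pre-emphasis filter to mono prediction/target pairs, mirroring what
# STFTLoss does internally when perceptual_weighting=True. The shapes and the 44.1 kHz
# sample rate are illustrative assumptions.
def _example_fir_prefilter():
    prefilter = FIRFilter(filter_type="aw", fs=44100)
    pred = torch.randn(8, 1, 44100)  # FIRFilter expects mono signals of shape (B, 1, T)
    ref = torch.randn(8, 1, 44100)
    pred_w, ref_w = prefilter(pred, ref)  # both filtered with the same fixed FIR taps
    return pred_w.shape, ref_w.shape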
class SpectralConvergenceLoss(torch.nn.Module):
"""Spectral convergence loss module.
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719).
"""
def __init__(self):
super(SpectralConvergenceLoss, self).__init__()
    def forward(self, x_mag, y_mag):
        return (
            torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2])
            / torch.norm(y_mag, p="fro", dim=[-1, -2])
        ).mean()
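# Example (added; not part of the original auraloss module): a minimal sketch of the
# spectral convergence term evaluated on two magnitude spectrograms. The spectrogram
# shapes below are illustrative assumptions.
def _example_spectral_convergence():
    x_mag = torch.rand(4, 513, 100)  # (batch, freq bins, frames)
    y_mag = torch.rand(4, 513, 100)
    return SpectralConvergenceLoss()(x_mag, y_mag)  # scalar loss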
class STFTMagnitudeLoss(torch.nn.Module):
"""STFT magnitude loss module.
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
compression strength (larger value results in more compression), and `log_eps` can be used
to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.
Args:
log (bool, optional): Log-scale the STFT magnitudes,
or use linear scale. Default: True
log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
Default: 0.0
log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
Default: 1.0
distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
reduction (str, optional): Reduction of the loss elements. Default: "mean"
"""
def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
super(STFTMagnitudeLoss, self).__init__()
self.log = log
self.log_eps = log_eps
self.log_fac = log_fac
if distance == "L1":
self.distance = torch.nn.L1Loss(reduction=reduction)
elif distance == "L2":
self.distance = torch.nn.MSELoss(reduction=reduction)
else:
raise ValueError(f"Invalid distance: '{distance}'.")
def forward(self, x_mag, y_mag):
if self.log:
x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
return self.distance(x_mag, y_mag)
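# Example (added; not part of the original auraloss module): a minimal sketch contrasting
# the log-magnitude and linear-magnitude distances used inside STFTLoss. The spectrogram
# shapes are illustrative assumptions.
def _example_stft_magnitude_loss():
    x_mag = torch.rand(4, 513, 100) + 1e-8  # keep magnitudes strictly positive for log()
    y_mag = torch.rand(4, 513, 100) + 1e-8
    log_l1 = STFTMagnitudeLoss(log=True, distance="L1")(x_mag, y_mag)
    lin_l2 = STFTMagnitudeLoss(log=False, distance="L2")(x_mag, y_mag)
    return log_l1, lin_l2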
class STFTLoss(torch.nn.Module):
"""STFT loss module.
See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472).
Args:
fft_size (int, optional): FFT size in samples. Default: 1024
hop_size (int, optional): Hop size of the FFT in samples. Default: 256
win_length (int, optional): Length of the FFT analysis window. Default: 1024
window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
Default: 'hann_window'
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
        w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
        sample_rate (int, optional): Sample rate. Required when scale is not None or perceptual_weighting=True. Default: None
scale (str, optional): Optional frequency scaling method, options include:
['mel', 'chroma']
Default: None
n_bins (int, optional): Number of scaling frequency bins. Default: None.
perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
        eps (float, optional): Small epsilon value for stability. Default: 1e-8
output (str, optional): Format of the loss returned.
'loss' : Return only the raw, aggregate loss term.
'full' : Return the raw loss, plus intermediate loss terms.
Default: 'loss'
reduction (str, optional): Specifies the reduction to apply to the output:
'none': no reduction will be applied,
'mean': the sum of the output will be divided by the number of elements in the output,
'sum': the output will be summed.
Default: 'mean'
        mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms. Default: "L1"
device (str, optional): Place the filterbanks on specified device. Default: None
Returns:
        loss:
            Aggregate loss term. Only returned if output='loss' (the default).
loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss:
Aggregate and intermediate loss terms. Only returned if output='full'.
"""
def __init__(
self,
fft_size: int = 1024,
hop_size: int = 256,
win_length: int = 1024,
window: str = "hann_window",
w_sc: float = 1.0,
w_log_mag: float = 1.0,
w_lin_mag: float = 0.0,
w_phs: float = 0.0,
sample_rate: float = None,
scale: str = None,
n_bins: int = None,
perceptual_weighting: bool = False,
scale_invariance: bool = False,
eps: float = 1e-8,
output: str = "loss",
reduction: str = "mean",
mag_distance: str = "L1",
device: Any = None,
**kwargs
):
super().__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.win_length = win_length
self.window = get_window(window, win_length)
self.w_sc = w_sc
self.w_log_mag = w_log_mag
self.w_lin_mag = w_lin_mag
self.w_phs = w_phs
self.sample_rate = sample_rate
self.scale = scale
self.n_bins = n_bins
self.perceptual_weighting = perceptual_weighting
self.scale_invariance = scale_invariance
self.eps = eps
self.output = output
self.reduction = reduction
self.mag_distance = mag_distance
self.device = device
self.phs_used = bool(self.w_phs)
self.spectralconv = SpectralConvergenceLoss()
self.logstft = STFTMagnitudeLoss(
log=True,
reduction=reduction,
distance=mag_distance,
**kwargs
)
self.linstft = STFTMagnitudeLoss(
log=False,
reduction=reduction,
distance=mag_distance,
**kwargs
)
# setup mel filterbank
if scale is not None:
try:
import librosa.filters
except Exception as e:
print(e)
print("Try `pip install auraloss[all]`.")
if self.scale == "mel":
                assert sample_rate is not None  # Must set sample rate to use mel scale
assert n_bins <= fft_size # Must be more FFT bins than Mel bins
fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
fb = torch.tensor(fb).unsqueeze(0)
elif self.scale == "chroma":
                assert sample_rate is not None  # Must set sample rate to use chroma scale
                assert n_bins <= fft_size  # Must be more FFT bins than chroma bins
                fb = librosa.filters.chroma(
                    sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
                )
                fb = torch.tensor(fb).unsqueeze(0)  # match the mel branch so register_buffer receives a tensor
else:
raise ValueError(
f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
)
self.register_buffer("fb", fb)
if scale is not None and device is not None:
self.fb = self.fb.to(self.device) # move filterbank to device
if self.perceptual_weighting:
if sample_rate is None:
raise ValueError(
f"`sample_rate` must be supplied when `perceptual_weighting = True`."
)
self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
def stft(self, x):
"""Perform STFT.
Args:
x (Tensor): Input signal tensor (B, T).
Returns:
Tensor: x_mag, x_phs
Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
"""
x_stft = torch.stft(
x,
self.fft_size,
self.hop_size,
self.win_length,
self.window,
return_complex=True,
)
x_mag = torch.sqrt(
torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
)
# torch.angle is expensive, so it is only evaluated if the values are used in the loss
if self.phs_used:
x_phs = torch.angle(x_stft)
else:
x_phs = None
return x_mag, x_phs
def forward(self, input: torch.Tensor, target: torch.Tensor):
bs, chs, seq_len = input.size()
if self.perceptual_weighting: # apply optional A-weighting via FIR filter
            # since FIRFilter only supports mono audio we will move channels to batch dim
input = input.view(bs * chs, 1, -1)
target = target.view(bs * chs, 1, -1)
# now apply the filter to both
self.prefilter.to(input.device)
input, target = self.prefilter(input, target)
# now move the channels back
input = input.view(bs, chs, -1)
target = target.view(bs, chs, -1)
# compute the magnitude and phase spectra of input and target
self.window = self.window.to(input.device)
x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))
# apply relevant transforms
if self.scale is not None:
self.fb = self.fb.to(input.device)
x_mag = torch.matmul(self.fb, x_mag)
y_mag = torch.matmul(self.fb, y_mag)
# normalize scales
if self.scale_invariance:
alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
y_mag = y_mag * alpha.unsqueeze(-1)
# compute loss terms
sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0
# combine loss terms
loss = (
(self.w_sc * sc_mag_loss)
+ (self.w_log_mag * log_mag_loss)
+ (self.w_lin_mag * lin_mag_loss)
+ (self.w_phs * phs_loss)
)
loss = apply_reduction(loss, reduction=self.reduction)
if self.output == "loss":
return loss
elif self.output == "full":
return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
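# Example (added; not part of the original auraloss module): a minimal sketch of a
# single-resolution STFT loss on mono audio, returning the individual terms with
# output="full". The waveform shapes and the 44.1 kHz sample rate are illustrative assumptions.
def _example_stft_loss():
    loss_fn = STFTLoss(
        fft_size=1024,
        hop_size=256,
        win_length=1024,
        w_log_mag=1.0,
        w_lin_mag=1.0,
        sample_rate=44100,
        output="full",
    )
    pred = torch.randn(2, 1, 44100)  # (batch, channels, samples)
    ref = torch.randn(2, 1, 44100)
    loss, sc_loss, log_mag_loss, lin_mag_loss, phs_loss = loss_fn(pred, ref)
    return loss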
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi-resolution STFT loss module.
See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480)
Args:
fft_sizes (list): List of FFT sizes.
hop_sizes (list): List of hop sizes.
win_lengths (list): List of window lengths.
window (str, optional): Window to apply before FFT, options include:
            ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
Default: 'hann_window'
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
scale (str, optional): Optional frequency scaling method, options include:
['mel', 'chroma']
Default: None
n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None.
scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
"""
def __init__(
self,
fft_sizes: List[int] = [1024, 2048, 512],
hop_sizes: List[int] = [120, 240, 50],
win_lengths: List[int] = [600, 1200, 240],
window: str = "hann_window",
w_sc: float = 1.0,
w_log_mag: float = 1.0,
w_lin_mag: float = 0.0,
w_phs: float = 0.0,
sample_rate: float = None,
scale: str = None,
n_bins: int = None,
perceptual_weighting: bool = False,
scale_invariance: bool = False,
**kwargs,
):
super().__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
self.fft_sizes = fft_sizes
self.hop_sizes = hop_sizes
self.win_lengths = win_lengths
self.stft_losses = torch.nn.ModuleList()
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
self.stft_losses += [
STFTLoss(
fs,
ss,
wl,
window,
w_sc,
w_log_mag,
w_lin_mag,
w_phs,
sample_rate,
scale,
n_bins,
perceptual_weighting,
scale_invariance,
**kwargs,
)
]
def forward(self, x, y):
mrstft_loss = 0.0
sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], []
for f in self.stft_losses:
if f.output == "full": # extract just first term
tmp_loss = f(x, y)
mrstft_loss += tmp_loss[0]
sc_mag_loss.append(tmp_loss[1])
log_mag_loss.append(tmp_loss[2])
lin_mag_loss.append(tmp_loss[3])
phs_loss.append(tmp_loss[4])
else:
mrstft_loss += f(x, y)
mrstft_loss /= len(self.stft_losses)
if f.output == "loss":
return mrstft_loss
else:
return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
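# Example (added; not part of the original auraloss module): a minimal sketch of the
# multi-resolution STFT loss with three resolutions. The resolutions and waveform shapes
# are illustrative assumptions.
def _example_multi_resolution_stft_loss():
    loss_fn = MultiResolutionSTFTLoss(
        fft_sizes=[1024, 2048, 512],
        hop_sizes=[256, 512, 128],
        win_lengths=[1024, 2048, 512],
    )
    pred = torch.randn(2, 1, 44100)
    ref = torch.randn(2, 1, 44100)
    return loss_fn(pred, ref)  # average over the three STFT resolutions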
class SumAndDifferenceSTFTLoss(torch.nn.Module):
    """Sum and difference stereo STFT loss module.
See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
Args:
fft_sizes (List[int]): List of FFT sizes.
hop_sizes (List[int]): List of hop sizes.
win_lengths (List[int]): List of window lengths.
window (str, optional): Window function type.
w_sum (float, optional): Weight of the sum loss component. Default: 1.0
w_diff (float, optional): Weight of the difference loss component. Default: 1.0
perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
        mel_stft (bool, optional): Use multi-resolution mel spectrograms. Default: False
n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
sample_rate (float, optional): Audio sample rate. Default: None
output (str, optional): Format of the loss returned.
'loss' : Return only the raw, aggregate loss term.
'full' : Return the raw loss, plus intermediate loss terms.
Default: 'loss'
"""
def __init__(
self,
fft_sizes: List[int],
hop_sizes: List[int],
win_lengths: List[int],
window: str = "hann_window",
w_sum: float = 1.0,
w_diff: float = 1.0,
output: str = "loss",
**kwargs,
):
super().__init__()
self.sd = SumAndDifference()
self.w_sum = w_sum
self.w_diff = w_diff
self.output = output
self.mrstft = MultiResolutionSTFTLoss(
fft_sizes,
hop_sizes,
win_lengths,
window,
**kwargs,
)
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""This loss function assumes batched input of stereo audio in the time domain.
Args:
input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
Returns:
            loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
Aggregate and intermediate loss terms. Only returned if output='full'.
"""
assert input.shape == target.shape # must have same shape
bs, chs, seq_len = input.size()
# compute sum and difference signals for both
input_sum, input_diff = self.sd(input)
target_sum, target_diff = self.sd(target)
# compute error in STFT domain
sum_loss = self.mrstft(input_sum, target_sum)
diff_loss = self.mrstft(input_diff, target_diff)
loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2
if self.output == "loss":
return loss
elif self.output == "full":
return loss, sum_loss, diff_loss |
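# Example (added; not part of the original auraloss module): a minimal sketch of the
# sum-and-difference (mid/side) loss on a stereo batch. The resolutions and waveform
# shapes are illustrative assumptions.
def _example_sum_and_difference_stft_loss():
    loss_fn = SumAndDifferenceSTFTLoss(
        fft_sizes=[1024, 2048, 512],
        hop_sizes=[256, 512, 128],
        win_lengths=[1024, 2048, 512],
        output="full",
    )
    pred = torch.randn(2, 2, 44100)  # stereo input is required
    ref = torch.randn(2, 2, 44100)
    loss, sum_loss, diff_loss = loss_fn(pred, ref)
    return loss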