Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from einops import rearrange | |
from scipy.optimize import fmin | |
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord | |
class PQMF(nn.Module): | |
""" | |
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. | |
Uses polyphase representation which is computationally more efficient for real-time. | |
Parameters: | |
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. | |
- num_bands (int): Number of desired frequency bands. It must be a power of 2. | |
""" | |
def __init__(self, attenuation, num_bands): | |
super(PQMF, self).__init__() | |
# Ensure num_bands is a power of 2 | |
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) | |
assert is_power_of_2, "'num_bands' must be a power of 2." | |
# Create the prototype filter | |
prototype_filter = design_prototype_filter(attenuation, num_bands) | |
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) | |
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) | |
# Register filters and settings | |
self.register_buffer("filter_bank", padded_filter_bank) | |
self.register_buffer("prototype", prototype_filter) | |
self.num_bands = num_bands | |
def forward(self, signal): | |
"""Decompose the signal into multiple frequency bands.""" | |
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it | |
signal = prepare_signal_dimensions(signal) | |
# The signal length must be a multiple of num_bands. Pad it with zeros. | |
signal = pad_signal(signal, self.num_bands) | |
# run it | |
signal = polyphase_analysis(signal, self.filter_bank) | |
return apply_alias_cancellation(signal) | |
def inverse(self, bands): | |
"""Reconstruct the original signal from the frequency bands.""" | |
bands = apply_alias_cancellation(bands) | |
return polyphase_synthesis(bands, self.filter_bank) | |
def prepare_signal_dimensions(signal): | |
""" | |
Rearrange signal into Batch x Channels x Length. | |
Parameters | |
---------- | |
signal : torch.Tensor or numpy.ndarray | |
The input signal. | |
Returns | |
------- | |
torch.Tensor | |
Preprocessed signal tensor. | |
""" | |
# Convert numpy to torch tensor | |
if isinstance(signal, np.ndarray): | |
signal = torch.from_numpy(signal) | |
# Ensure tensor | |
if not isinstance(signal, torch.Tensor): | |
raise ValueError("Input should be either a numpy array or a PyTorch tensor.") | |
# Modify dimension of signal to Batch x Channels x Length | |
if signal.dim() == 1: | |
# This is just a mono signal. Unsqueeze to 1 x 1 x Length | |
signal = signal.unsqueeze(0).unsqueeze(0) | |
elif signal.dim() == 2: | |
# This is a multi-channel signal (e.g. stereo) | |
# Rearrange so that larger dimension (Length) is last | |
if signal.shape[0] > signal.shape[1]: | |
signal = signal.T | |
# Unsqueeze to 1 x Channels x Length | |
signal = signal.unsqueeze(0) | |
return signal | |
def pad_signal(signal, num_bands): | |
""" | |
Pads the signal to make its length divisible by the given number of bands. | |
Parameters | |
---------- | |
signal : torch.Tensor | |
The input signal tensor, where the last dimension represents the signal length. | |
num_bands : int | |
The number of bands by which the signal length should be divisible. | |
Returns | |
------- | |
torch.Tensor | |
The padded signal tensor. If the original signal length was already divisible | |
by num_bands, returns the original signal unchanged. | |
""" | |
remainder = signal.shape[-1] % num_bands | |
if remainder > 0: | |
padding_size = num_bands - remainder | |
signal = nn.functional.pad(signal, (0, padding_size)) | |
return signal | |
def generate_modulated_filter_bank(prototype_filter, num_bands): | |
""" | |
Generate a QMF bank of cosine modulated filters based on a given prototype filter. | |
Parameters | |
---------- | |
prototype_filter : torch.Tensor | |
The prototype filter used as the basis for modulation. | |
num_bands : int | |
The number of desired subbands or filters. | |
Returns | |
------- | |
torch.Tensor | |
A bank of cosine modulated filters. | |
""" | |
# Initialize indices for modulation. | |
subband_indices = torch.arange(num_bands).reshape(-1, 1) | |
# Calculate the length of the prototype filter. | |
filter_length = prototype_filter.shape[-1] | |
# Generate symmetric time indices centered around zero. | |
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) | |
# Calculate phase offsets to ensure orthogonality between subbands. | |
phase_offsets = (-1)**subband_indices * np.pi / 4 | |
# Compute the cosine modulation function. | |
modulation = torch.cos( | |
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets | |
) | |
# Apply modulation to the prototype filter. | |
modulated_filters = 2 * prototype_filter * modulation | |
return modulated_filters | |
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): | |
""" | |
Design a lowpass filter using the Kaiser window. | |
Parameters | |
---------- | |
angular_cutoff : float | |
The angular frequency cutoff of the filter. | |
attenuation : float | |
The desired stopband attenuation in decibels (dB). | |
filter_length : int, optional | |
Desired length of the filter. If not provided, it's computed based on the given specs. | |
Returns | |
------- | |
ndarray | |
The designed lowpass filter coefficients. | |
""" | |
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) | |
# Ensure the estimated length is odd. | |
estimated_length = 2 * (estimated_length // 2) + 1 | |
if filter_length is None: | |
filter_length = estimated_length | |
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) | |
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): | |
""" | |
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 | |
Parameters | |
---------- | |
angular_cutoff : float | |
Angular frequency cutoff of the filter. | |
attenuation : float | |
Desired stopband attenuation in dB. | |
num_bands : int | |
Number of bands for the multiband filter system. | |
filter_length : int, optional | |
Desired length of the filter. | |
Returns | |
------- | |
float | |
The computed objective (loss) value for the given filter specs. | |
""" | |
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) | |
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") | |
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) | |
def design_prototype_filter(attenuation, num_bands, filter_length=None): | |
""" | |
Design the optimal prototype filter for a multiband system given the desired specs. | |
Parameters | |
---------- | |
attenuation : float | |
The desired stopband attenuation in dB. | |
num_bands : int | |
Number of bands for the multiband filter system. | |
filter_length : int, optional | |
Desired length of the filter. If not provided, it's computed based on the given specs. | |
Returns | |
------- | |
ndarray | |
The optimal prototype filter coefficients. | |
""" | |
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), | |
1 / num_bands, disp=0)[0] | |
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) | |
return torch.tensor(prototype_filter, dtype=torch.float32) | |
def pad_to_nearest_power_of_two(x): | |
""" | |
Pads the input tensor 'x' on both sides such that its last dimension | |
becomes the nearest larger power of two. | |
Parameters: | |
----------- | |
x : torch.Tensor | |
The input tensor to be padded. | |
Returns: | |
-------- | |
torch.Tensor | |
The padded tensor. | |
""" | |
current_length = x.shape[-1] | |
target_length = 2**math.ceil(math.log2(current_length)) | |
total_padding = target_length - current_length | |
left_padding = total_padding // 2 | |
right_padding = total_padding - left_padding | |
return nn.functional.pad(x, (left_padding, right_padding)) | |
def apply_alias_cancellation(x): | |
""" | |
Applies alias cancellation by inverting the sign of every | |
second element of every second row, starting from the second | |
row's first element in a tensor. | |
This operation helps ensure that the aliasing introduced in | |
each band during the decomposition will be counteracted during | |
the reconstruction. | |
Parameters: | |
----------- | |
x : torch.Tensor | |
The input tensor. | |
Returns: | |
-------- | |
torch.Tensor | |
Tensor with specific elements' sign inverted for alias cancellation. | |
""" | |
# Create a mask of the same shape as 'x', initialized with all ones | |
mask = torch.ones_like(x) | |
# Update specific elements in the mask to -1 to perform inversion | |
mask[..., 1::2, ::2] = -1 | |
# Apply the mask to the input tensor 'x' | |
return x * mask | |
def ensure_odd_length(tensor): | |
""" | |
Pads the last dimension of a tensor to ensure its size is odd. | |
Parameters: | |
----------- | |
tensor : torch.Tensor | |
Input tensor whose last dimension might need padding. | |
Returns: | |
-------- | |
torch.Tensor | |
The original tensor if its last dimension was already odd, | |
or the padded tensor with an odd-sized last dimension. | |
""" | |
last_dim_size = tensor.shape[-1] | |
if last_dim_size % 2 == 0: | |
tensor = nn.functional.pad(tensor, (0, 1)) | |
return tensor | |
def polyphase_analysis(signal, filter_bank): | |
""" | |
Applies the polyphase method to efficiently analyze the signal using a filter bank. | |
Parameters: | |
----------- | |
signal : torch.Tensor | |
Input signal tensor with shape (Batch x Channels x Length). | |
filter_bank : torch.Tensor | |
Filter bank tensor with shape (Bands x Length). | |
Returns: | |
-------- | |
torch.Tensor | |
Signal split into sub-bands. (Batch x Channels x Bands x Length) | |
""" | |
num_bands = filter_bank.shape[0] | |
num_channels = signal.shape[1] | |
# Rearrange signal for polyphase processing. | |
# Also combine Batch x Channel into one dimension for now. | |
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) | |
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) | |
# Rearrange the filter bank for matching signal shape | |
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) | |
# Apply convolution with appropriate padding to maintain spatial dimensions | |
padding = filter_bank.shape[-1] // 2 | |
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) | |
# Truncate the last dimension post-convolution to adjust the output shape | |
filtered_signal = filtered_signal[..., :-1] | |
# Rearrange the first dimension back into Batch x Channels | |
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) | |
return filtered_signal | |
def polyphase_synthesis(signal, filter_bank): | |
""" | |
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. | |
Parameters | |
---------- | |
signal : torch.Tensor | |
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). | |
filter_bank : torch.Tensor | |
Analysis filter bank (shape: Bands x Length). | |
should_rearrange : bool, optional | |
Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. | |
Returns | |
------- | |
torch.Tensor | |
Reconstructed signal (shape: Batch x Channels X Length) | |
""" | |
num_bands = filter_bank.shape[0] | |
num_channels = signal.shape[1] | |
# Rearrange the filter bank | |
filter_bank = filter_bank.flip(-1) | |
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) | |
# Combine Batch x Channels into one dimension for now. | |
signal = rearrange(signal, "b c n t -> (b c) n t") | |
# Apply convolution with appropriate padding | |
padding_amount = filter_bank.shape[-1] // 2 + 1 | |
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) | |
# Scale the result | |
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands | |
# Reorganize the output and truncate | |
reconstructed_signal = reconstructed_signal.flip(1) | |
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) | |
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] | |
return reconstructed_signal |