# Provenance (scraped Hugging Face blob-page header, kept as a comment so the
# file stays valid Python): uploaded by sonalkum, branch "stable",
# commit 9172422, links: raw / history / blame, file size 3.24 kB.
"""The 1D discrete wavelet transform for PyTorch."""
from einops import rearrange
import pywt
import torch
from torch import nn
from torch.nn import functional as F
from typing import Literal
def get_filter_bank(wavelet):
    """Return the four-filter bank of *wavelet* as a (4, taps) float tensor.

    Rows are pywt's (dec_lo, dec_hi, rec_lo, rec_hi). PyWavelets pads
    biorthogonal ("bior*") filters to an even length with a leading zero
    column; that column is dropped here so the modules below always see
    odd-length filters.
    """
    bank = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
    has_zero_head = wavelet.startswith("bior") and not torch.any(bank[:, 0])
    return bank[:, 1:] if has_zero_head else bank
class WaveletEncode1d(nn.Module):
    """Multi-level 1D wavelet analysis.

    Each level reflect-pads the current approximation band, applies the
    (time-reversed) decomposition filters as a strided conv1d — halving the
    length and doubling the channel count — and folds previously produced
    detail channels so the total element count is preserved.
    """

    def __init__(self,
                 channels,
                 levels,
                 wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        bank = get_filter_bank(wavelet)
        # Odd filter length is required for the symmetric reflect padding below.
        assert bank.shape[-1] % 2 == 1
        taps = bank.shape[-1]

        # Block-diagonal conv weight of shape (2*channels, channels, taps):
        # output channel band*channels + ch filters input channel ch with the
        # time-reversed decomposition filter `band` (0 = low-pass, 1 = high-pass).
        weight = torch.zeros(channels * 2, channels, taps)
        for band in range(2):
            flipped = torch.flip(bank[band], dims=(-1,))
            for ch in range(channels):
                weight[band * channels + ch, ch] = flipped
        self.register_buffer("kernel", weight)

    def forward(self, x):
        for _ in range(self.levels):
            approx = x[:, : self.channels]
            detail = x[:, self.channels :]

            # Symmetric reflect pad keeps the strided conv length-exact.
            half = self.kernel.shape[-1] // 2
            approx = F.pad(approx, (half, half), "reflect")
            approx = F.conv1d(approx, self.kernel, stride=2)

            # Fold existing detail bands: halve length, double channels, so
            # they stay aligned with the freshly decomposed approximation.
            detail = rearrange(
                detail, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
            )
            x = torch.cat([approx, detail], dim=1)
        return x
class WaveletDecode1d(nn.Module):
    """Multi-level 1D wavelet synthesis — the inverse of WaveletEncode1d.

    Each level runs the reconstruction filters as a transposed strided conv
    over the low/high pair, doubling the length and halving the channel
    count, then unfolds one packed detail group back into the signal.
    """

    def __init__(self,
                 channels,
                 levels,
                 wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        bank = get_filter_bank(wavelet)
        # Odd filter length is required by the padding arithmetic in forward.
        assert bank.shape[-1] % 2 == 1
        taps = bank.shape[-1]

        # Block-diagonal transposed-conv weight of shape
        # (2*channels, channels, taps): rows 2 and 3 of the bank are the
        # reconstruction low/high filters (not flipped, unlike the encoder).
        weight = torch.zeros(channels * 2, channels, taps)
        for band in range(2):
            for ch in range(channels):
                weight[band * channels + ch, ch] = bank[2 + band]
        self.register_buffer("kernel", weight)

    def forward(self, x):
        for _ in range(self.levels):
            packed = x[:, : self.channels * 2]
            detail = x[:, self.channels * 2 :]

            # Reflect-pad in the interleaved (sample) domain so boundary
            # handling matches the encoder, then restore band layout.
            margin = self.kernel.shape[-1] // 2 + 2
            packed = rearrange(packed, "n (l2 c) l -> n c (l l2)", l2=2)
            packed = F.pad(packed, (margin, margin), "reflect")
            packed = rearrange(packed, "n c (l l2) -> n (l2 c) l", l2=2)

            approx = F.conv_transpose1d(
                packed, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2
            )
            # Trim the synthesis transient introduced by the padding above.
            approx = approx[..., margin - 1 : -margin]

            # Unfold one packed detail group: double length, halve channels.
            detail = rearrange(
                detail, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
            )
            x = torch.cat([approx, detail], dim=1)
        return x