# NOTE: removed non-code file-viewer header (web-extraction artifact, not part of the module).
"""The 1D discrete wavelet transform for PyTorch."""
from einops import rearrange
import pywt
import torch
from torch import nn
from torch.nn import functional as F
from typing import Literal
def get_filter_bank(wavelet):
    """Return the 4-row filter bank (dec_lo, dec_hi, rec_lo, rec_hi) for *wavelet*.

    PyWavelets pads biorthogonal ("bior*") banks with a leading zero tap; when
    every filter starts with zero that column is trimmed so the filters end up
    with odd length (required by the encoder/decoder padding arithmetic below).
    """
    bank = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
    leading_taps_are_zero = torch.all(bank[:, 0] == 0)
    if wavelet.startswith("bior") and leading_taps_are_zero:
        bank = bank[:, 1:]
    return bank
class WaveletEncode1d(nn.Module):
    """Multi-level 1D DWT analysis module.

    Each level halves the sequence length and doubles the channel count by
    splitting the lowest-frequency band into low/high sub-bands with a
    strided convolution against the wavelet's decomposition filters.
    """

    def __init__(self,
                 channels,
                 levels,
                 wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels
        bank = get_filter_bank(wavelet)
        # Odd filter length keeps the reflect padding symmetric around each tap.
        assert bank.shape[-1] % 2 == 1
        # Decomposition filters (dec_lo, dec_hi), time-reversed because
        # F.conv1d computes cross-correlation rather than convolution.
        taps = torch.flip(bank[:2, None], dims=(-1,))
        # Build a (2*channels, channels, taps) weight where output channel
        # f*channels + c applies filter f to input channel c only.
        rows = torch.repeat_interleave(torch.arange(2), channels)
        cols = torch.tile(torch.arange(channels), (2,))
        weight = torch.zeros(channels * 2, channels, bank.shape[-1])
        weight[rows * channels + cols, cols] = taps[rows, 0]
        self.register_buffer("kernel", weight)

    def forward(self, x):
        for _ in range(self.levels):
            # The first `channels` channels hold the current approximation band;
            # everything after them is detail coefficients from earlier levels.
            approx, detail = x[:, : self.channels], x[:, self.channels :]
            pad = self.kernel.shape[-1] // 2
            approx = F.pad(approx, (pad, pad), "reflect")
            approx = F.conv1d(approx, self.kernel, stride=2)
            # Fold the length axis of the detail bands by 2 to match the new,
            # halved sequence length of the approximation band.
            detail = rearrange(
                detail, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
            )
            x = torch.cat([approx, detail], dim=1)
        return x
class WaveletDecode1d(nn.Module):
    """Multi-level 1D DWT synthesis module — the inverse of WaveletEncode1d.

    Each level halves the channel count and doubles the sequence length by
    recombining low/high sub-bands through a transposed convolution against
    the wavelet's reconstruction filters.
    """

    def __init__(self,
                 channels,
                 levels,
                 wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels
        filt = get_filter_bank(wavelet)
        # Odd filter length keeps the padding/cropping arithmetic symmetric.
        assert filt.shape[-1] % 2 == 1
        # Reconstruction filters (rec_lo, rec_hi); conv_transpose1d already
        # flips the kernel, so no torch.flip here (unlike the encoder).
        kernel = filt[2:, None]
        # Build a (2*channels, channels, taps) weight where input channel
        # f*channels + c contributes filter f to output channel c only.
        index_i = torch.repeat_interleave(torch.arange(2), channels)
        index_j = torch.tile(torch.arange(channels), (2,))
        kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
        kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
        self.register_buffer("kernel", kernel_final)

    def forward(self, x):
        for i in range(self.levels):
            # The first 2*channels channels are the low/high pair to merge;
            # the remainder is detail data for shallower levels.
            low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :]
            pad = self.kernel.shape[-1] // 2 + 2
            # Reflect-pad along length while the sub-bands are interleaved,
            # then restore the (l2 c) channel layout conv_transpose1d expects.
            low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2)
            low = F.pad(low, (pad, pad), "reflect")
            low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2)
            low = F.conv_transpose1d(
                low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2
            )
            # Crop the transposed-conv overshoot introduced by the padding.
            low = low[..., pad - 1 : -pad]
            # Unfold detail bands back to the doubled sequence length.
            rest = rearrange(
                rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
            )
            x = torch.cat([low, rest], dim=1)
        return x