Spaces:
Build error
Build error
File size: 4,929 Bytes
51da11a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
from deepafx_st.processors.proxy.proxy_system import ProxySystem
from deepafx_st.utils import DSPMode
class ProxyChannel(torch.nn.Module):
    """Chain of proxy processors applied to the input one after another.

    Proxies are either restored from ``proxy_system_ckpts`` or, when that list
    is empty, freshly constructed according to ``num_tcns``. The control
    parameter tensor passed to :meth:`forward` is split column-wise across the
    proxies, in the order they were added, using each proxy's
    ``processor.num_control_params``.

    Args:
        proxy_system_ckpts: Checkpoint paths restored with
            ``ProxySystem.load_from_checkpoint``. If empty, new proxies are
            built instead.
        freeze_proxies: If True, ``requires_grad`` is set to False on every
            parameter of the proxies loaded from checkpoints. Freshly built
            proxies are not frozen.
        dsp_mode: Stored as ``self.dsp_mode``.
            NOTE(review): ``forward`` uses its own ``dsp_mode`` argument and
            never reads this attribute — confirm that is intended.
        num_tcns: Only used when no checkpoints are given.
            2 builds a "peq" proxy followed by a "comp" proxy.
            1 builds a single "channel" proxy.
        tcn_nblocks: Passed to ``ProxySystem`` as ``nblocks`` for new proxies.
        tcn_dilation_growth: Passed as ``dilation_growth`` for new proxies.
        tcn_channel_width: Passed as ``channel_width`` for new proxies.
        tcn_kernel_size: Passed as ``kernel_size`` for new proxies.
        sample_rate: Passed as ``sample_rate`` for new proxies.

    Raises:
        ValueError: If no checkpoints are given and ``num_tcns`` is not 1 or 2.
    """

    def __init__(
        self,
        proxy_system_ckpts: list,
        freeze_proxies: bool = True,
        dsp_mode: DSPMode = DSPMode.NONE,
        num_tcns: int = 2,
        tcn_nblocks: int = 4,
        tcn_dilation_growth: int = 8,
        tcn_channel_width: int = 64,
        tcn_kernel_size: int = 13,
        sample_rate: int = 24000,
    ):
        super().__init__()
        self.freeze_proxies = freeze_proxies
        self.dsp_mode = dsp_mode
        self.num_tcns = num_tcns

        # Registered proxies, applied in order by forward().
        self.proxies = torch.nn.ModuleList()
        # Total number of columns of `p` consumed by all proxies.
        self.num_control_params = 0
        # Port descriptions gathered from each proxy's processor.
        self.ports = []

        # Restore proxies from checkpoints.
        for proxy_system_ckpt in proxy_system_ckpts:
            proxy = ProxySystem.load_from_checkpoint(proxy_system_ckpt)
            if freeze_proxies:
                for param in proxy.parameters():
                    param.requires_grad = False
            self._add_proxy(
                proxy, is_channel=proxy.hparams.processor == "channel"
            )

        if proxy_system_ckpts:
            return

        # No checkpoints: build fresh proxies sharing the same TCN settings.
        tcn_kwargs = dict(
            nblocks=tcn_nblocks,
            dilation_growth=tcn_dilation_growth,
            kernel_size=tcn_kernel_size,
            channel_width=tcn_channel_width,
            sample_rate=sample_rate,
        )
        if self.num_tcns == 2:
            # Construction order (peq, then comp) is kept, since it
            # determines the column layout of `p` in forward().
            self._add_proxy(
                ProxySystem(processor="peq", output_gain=False, **tcn_kwargs),
                is_channel=False,
            )
            self._add_proxy(
                ProxySystem(processor="comp", output_gain=True, **tcn_kwargs),
                is_channel=False,
            )
        elif self.num_tcns == 1:
            self._add_proxy(
                ProxySystem(processor="channel", output_gain=True, **tcn_kwargs),
                is_channel=True,
            )
        else:
            raise ValueError(
                f"num_tcns must be 1 or 2. Asked for {self.num_tcns}."
            )

    def _add_proxy(self, proxy, is_channel: bool) -> None:
        """Add ``proxy`` to the chain and record its ports and parameter count."""
        self.proxies.append(proxy)
        if is_channel:
            # A "channel" processor's `ports` is iterated as a sequence of
            # port lists, so splice them in rather than nesting. `extend`
            # copies the entries instead of aliasing the processor's own
            # list, and does not discard ports recorded from earlier proxies.
            self.ports.extend(proxy.processor.ports)
        else:
            self.ports.append(proxy.processor.ports)
        self.num_control_params += proxy.processor.num_control_params

    def forward(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        dsp_mode: DSPMode = DSPMode.NONE,
        sample_rate: int = 24000,
        **kwargs,
    ):
        """Run ``x`` through every proxy in sequence.

        Args:
            x: Input to the first proxy. Each proxy's output feeds the next.
                Shape is defined by ``ProxySystem`` (not visible here).
            p: Control parameters, indexed as ``p[:, start:stop]``, so
                presumably (batch, num_control_params) — confirm. Each proxy
                receives the next ``processor.num_control_params`` columns.
            dsp_mode: Selects how each proxy is called.
                NONE: proxy only (``use_dsp=False``).
                INFER: DSP only (``use_dsp=True``).
                TRAIN_INFER: DSP output in the forward pass, with gradients
                taken from the proxy (straight-through estimator).
            sample_rate: Passed to the proxy in the INFER and TRAIN_INFER
                modes.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The output of the last proxy, or ``x`` unchanged if there are
            no proxies.

        Raises:
            ValueError: If ``dsp_mode`` is not NONE, INFER or TRAIN_INFER.
        """
        # Modes are compared by `.name` rather than identity — presumably so
        # DSPMode values from a different import/unpickling still match.
        mode = dsp_mode.name
        # Validate up front so an invalid mode is reported even when there
        # are no proxies, and so the check is not stripped under `python -O`.
        if mode not in (
            DSPMode.NONE.name,
            DSPMode.INFER.name,
            DSPMode.TRAIN_INFER.name,
        ):
            raise ValueError("invalid dsp mode for proxy")

        stop_idx = 0
        for proxy in self.proxies:
            start_idx = stop_idx
            stop_idx += proxy.processor.num_control_params
            p_subset = p[:, start_idx:stop_idx]

            if mode == DSPMode.NONE.name:
                x = proxy(x, p_subset, use_dsp=False)
            elif mode == DSPMode.INFER.name:
                x = proxy(x, p_subset, use_dsp=True, sample_rate=sample_rate)
            else:  # TRAIN_INFER
                # Mimic the gumbel-softmax gradient-replacement trick, similar to
                # https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
                # The forward value is the DSP output (x_hard). Gradients flow
                # only through the proxy output (x_soft).
                x_hard = proxy(x, p_subset, use_dsp=True, sample_rate=sample_rate)
                x_soft = proxy(x, p_subset, use_dsp=False, sample_rate=sample_rate)
                x = (x_hard - x_soft).detach() + x_soft
        return x
|