import torch
import torch.nn as nn

from model.crop import centre_crop
from model.resample import Resample1d
from model.conv import ConvLayer
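
# centre_crop(x, target), imported above, is assumed to trim a tensor
# symmetrically along the time axis until it matches the reference tensor's
# length, as in the reference Wave-U-Net implementation. A minimal sketch of
# that assumed behaviour, kept under a private name so the real import stays in use:
def _centre_crop_sketch(x, target):
    diff = x.shape[-1] - target.shape[-1]  # total frames to trim, split evenly between both sides
    assert diff >= 0 and diff % 2 == 0, "x must be longer than target by an even margin"
    return x if diff == 0 else x[:, :, diff // 2:-(diff // 2)]
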
class UpsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(UpsamplingBlock, self).__init__()
        assert(stride > 1)

        # CONV 1 for UPSAMPLING
        if res == "fixed":
            self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
        else:
            self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)

        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONVS to combine high- with low-level information (from shortcut)
        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
    def forward(self, x, shortcut):
        # UPSAMPLE HIGH-LEVEL FEATURES
        upsampled = self.upconv(x)

        for conv in self.pre_shortcut_convs:
            upsampled = conv(upsampled)

        # Prepare shortcut connection
        combined = centre_crop(shortcut, upsampled)

        # Combine high- and low-level features
        for conv in self.post_shortcut_convs:
            combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
        return combined
    def get_output_size(self, input_size):
        curr_size = self.upconv.get_output_size(input_size)

        # Upsampling convs
        for conv in self.pre_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        # Combine convolutions
        for conv in self.post_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)
        return curr_size
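
# Note (illustrative): with stride=2 an UpsamplingBlock roughly doubles the time
# resolution and then shortens the signal again through its valid convolutions,
# which is why the encoder shortcut is centre-cropped before concatenation.
# get_output_size() above gives the exact post-trim length for a given input length.
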
class DownsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(DownsamplingBlock, self).__init__()
        assert(stride > 1)

        self.kernel_size = kernel_size
        self.stride = stride

        # CONV 1
        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONV 2 with decimation
        if res == "fixed":
            self.downconv = Resample1d(n_outputs, 15, stride)  # Resampling with fixed-size sinc lowpass filter
        else:
            self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)
    def forward(self, x):
        # PREPARING SHORTCUT FEATURES
        shortcut = x
        for conv in self.pre_shortcut_convs:
            shortcut = conv(shortcut)

        # PREPARING FOR DOWNSAMPLING
        out = shortcut
        for conv in self.post_shortcut_convs:
            out = conv(out)

        # DOWNSAMPLING
        out = self.downconv(out)
        return out, shortcut
    def get_input_size(self, output_size):
        curr_size = self.downconv.get_input_size(output_size)

        for conv in reversed(self.post_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)

        for conv in reversed(self.pre_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)
        return curr_size
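
# A quick sanity check of the size bookkeeping, as a hedged sketch: the
# settings below are hypothetical, and the helper assumes ConvLayer and
# Resample1d behave as in the reference Wave-U-Net repository. It is not part
# of the original model and is only defined, never called, at import time.
def _check_downsampling_sizes(output_size=512):
    block = DownsamplingBlock(4, 4, 8, kernel_size=5, stride=2, depth=1,
                              conv_type="normal", res="fixed")
    needed = block.get_input_size(output_size)  # input frames required for the desired output length
    out, shortcut = block(torch.randn(1, 4, needed))
    assert out.shape[-1] == output_size  # get_input_size() inverts the forward geometry exactly
    return needed, out.shape, shortcut.shape
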
class Waveunet(nn.Module):
    def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
        super(Waveunet, self).__init__()

        self.num_levels = len(num_channels)
        self.strides = strides
        self.kernel_size = kernel_size
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.depth = depth
        self.instruments = instruments
        self.separate = separate

        # Only odd filter kernels allowed
        assert(kernel_size % 2 == 1)

        self.waveunets = nn.ModuleDict()

        model_list = instruments if separate else ["ALL"]
        # Create one Wave-U-Net per source if separate=True, otherwise a single joint model (model_list=["ALL"])
        for instrument in model_list:
            module = nn.Module()

            module.downsampling_blocks = nn.ModuleList()
            module.upsampling_blocks = nn.ModuleList()

            for i in range(self.num_levels - 1):
                in_ch = num_inputs if i == 0 else num_channels[i]
                module.downsampling_blocks.append(
                    DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res))

            for i in range(self.num_levels - 1):
                module.upsampling_blocks.append(
                    UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res))

            module.bottlenecks = nn.ModuleList(
                [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])

            # Output conv
            outputs = num_outputs if separate else num_outputs * len(instruments)
            module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)

            self.waveunets[instrument] = module

        self.set_output_size(target_output_size)
    def set_output_size(self, target_output_size):
        self.target_output_size = target_output_size

        self.input_size, self.output_size = self.check_padding(target_output_size)
        print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")

        assert((self.input_size - self.output_size) % 2 == 0)
        self.shapes = {"output_start_frame": (self.input_size - self.output_size) // 2,
                       "output_end_frame": (self.input_size - self.output_size) // 2 + self.output_size,
                       "output_frames": self.output_size,
                       "input_frames": self.input_size}
    def check_padding(self, target_output_size):
        # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
        bottleneck = 1

        while True:
            out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
            if out is not False:
                return out
            bottleneck += 1
    def check_padding_for_bottleneck(self, bottleneck, target_output_size):
        # All sub-networks share the same architecture, so any one of them can be used for the size computation
        module = self.waveunets[next(iter(self.waveunets))]
        try:
            curr_size = bottleneck
            for idx, block in enumerate(module.upsampling_blocks):
                curr_size = block.get_output_size(curr_size)
            output_size = curr_size

            # Bottleneck-Conv
            curr_size = bottleneck
            for block in reversed(module.bottlenecks):
                curr_size = block.get_input_size(curr_size)

            for idx, block in enumerate(reversed(module.downsampling_blocks)):
                curr_size = block.get_input_size(curr_size)

            assert(output_size >= target_output_size)
            return curr_size, output_size
        except AssertionError:
            return False
    def forward_module(self, x, module):
        '''
        A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
        :param x: Input mix
        :param module: Network module to be used for prediction
        :return: Source estimates
        '''
        shortcuts = []
        out = x

        # DOWNSAMPLING BLOCKS
        for block in module.downsampling_blocks:
            out, short = block(out)
            shortcuts.append(short)

        # BOTTLENECK CONVOLUTION
        for conv in module.bottlenecks:
            out = conv(out)

        # UPSAMPLING BLOCKS
        for idx, block in enumerate(module.upsampling_blocks):
            out = block(out, shortcuts[-1 - idx])

        # OUTPUT CONV
        out = module.output_conv(out)
        if not self.training:  # At test time clip predictions to valid amplitude range
            out = out.clamp(min=-1.0, max=1.0)
        return out
    def forward(self, x, inst=None):
        curr_input_size = x.shape[-1]
        # The caller must feed input of the pre-computed size (NOT the originally requested output size)
        assert(curr_input_size == self.input_size)

        if self.separate:
            return {inst: self.forward_module(x, self.waveunets[inst])}
        else:
            assert(len(self.waveunets) == 1)
            out = self.forward_module(x, self.waveunets["ALL"])

            out_dict = {}
            for idx, inst in enumerate(self.instruments):
                out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
            return out_dict
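
# Minimal smoke test, as a hedged sketch: the hyperparameters and instrument
# names below are hypothetical (not the paper's defaults), and the test assumes
# the sibling model.* modules behave as in the reference Wave-U-Net repository.
if __name__ == "__main__":
    instruments = ["bass", "vocals"]  # hypothetical source names
    model = Waveunet(num_inputs=2, num_channels=[8, 16, 32], num_outputs=2,
                     instruments=instruments, kernel_size=5,
                     target_output_size=4096, conv_type="normal", res="fixed",
                     separate=False, depth=1, strides=2)
    mix = torch.randn(1, 2, model.shapes["input_frames"])  # batch of one stereo mix
    estimates = model(mix)
    for inst, estimate in estimates.items():
        print(inst, tuple(estimate.shape))  # (1, 2, model.shapes["output_frames"])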