import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from scipy.fftpack import fft
import wave


class Model(nn.Module):
    """CNN classifier over spectrogram-like input of shape
    [batch, input_dim, time, hidden_dim], producing per-frame syllable logits."""

    def __init__(self, input_dim=1, hidden_dim=256, tone_class=5, syllable_class=1000):
        super().__init__()
        self.input_dim = input_dim
        self.tone_class = tone_class  # stored for reference; not used by this head
        self.syllable_class = syllable_class

        conv_layers = []
        in_channels = input_dim  # input channels for the first conv layer
        # Integers are Conv2d output channels; 'p2' inserts a 2x2 max pool,
        # 'p1' a 1x1 max pool (effectively an identity placeholder).
        channel_list = [16, 16, 'p2', 32, 32, 'p2', 64, 64, 'p1', 64]
        for out_channels in channel_list:
            if out_channels == 'p2':
                conv_layers.append(nn.MaxPool2d(kernel_size=2))
                continue
            elif out_channels == 'p1':
                conv_layers.append(nn.MaxPool2d(kernel_size=1))
                continue
            conv_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            conv_layers.append(nn.BatchNorm2d(out_channels))
            conv_layers.append(nn.ReLU(inplace=True))
            conv_layers.append(nn.Dropout(0.1))
            in_channels = out_channels
        self.conv = nn.Sequential(*conv_layers)

        # The two 2x2 max pools quarter the frequency axis, and the conv stack
        # ends with `in_channels` (= 64) feature maps, so each time frame
        # flattens to in_channels * (hidden_dim // 4) features
        # (4096 with the default hidden_dim=256).
        flat_dim = in_channels * (hidden_dim // 4)
        self.output = nn.Sequential(
            nn.Linear(flat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, syllable_class),
        )

    def forward(self, x):
        x = self.conv(x)             # [batch, channel, time // 4, hidden_dim // 4]
        x = x.permute((0, 2, 1, 3))  # [batch, time // 4, channel, hidden_dim // 4]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # flatten channel x frequency per frame
        return self.output(x)
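

# A minimal usage sketch (not part of the original module): it assumes
# spectrogram-like input shaped [batch, input_dim, time, hidden_dim] with the
# default hidden_dim of 256; the batch size of 2 and the 80 time frames below
# are arbitrary choices for a quick shape check.
if __name__ == "__main__":
    model = Model()
    dummy = torch.randn(2, 1, 80, 256)  # [batch, channels, time, freq bins]
    logits = model(dummy)
    # The two 2x2 max pools halve the time axis twice: 80 -> 20 frames.
    print(logits.shape)                 # expected: torch.Size([2, 20, 1000])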