Spaces:
Paused
Paused
| from torch import nn | |
| class MDNBlock(nn.Module): | |
| """Mixture of Density Network implementation | |
| https://arxiv.org/pdf/2003.01950.pdf | |
| """ | |
| def __init__(self, in_channels, out_channels): | |
| super().__init__() | |
| self.out_channels = out_channels | |
| self.conv1 = nn.Conv1d(in_channels, in_channels, 1) | |
| self.norm = nn.LayerNorm(in_channels) | |
| self.relu = nn.ReLU() | |
| self.dropout = nn.Dropout(0.1) | |
| self.conv2 = nn.Conv1d(in_channels, out_channels, 1) | |
| def forward(self, x): | |
| o = self.conv1(x) | |
| o = o.transpose(1, 2) | |
| o = self.norm(o) | |
| o = o.transpose(1, 2) | |
| o = self.relu(o) | |
| o = self.dropout(o) | |
| mu_sigma = self.conv2(o) | |
| # TODO: check this sigmoid | |
| # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) | |
| mu = mu_sigma[:, : self.out_channels // 2, :] | |
| log_sigma = mu_sigma[:, self.out_channels // 2 :, :] | |
| return mu, log_sigma | |