Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import pytorch_lightning as pl | |
class MLP(pl.LightningModule):
    """Point-wise MLP made of kernel-size-1 ``nn.Conv1d`` layers.

    Layer ``i`` maps ``filter_channels[i]`` channels to
    ``filter_channels[i + 1]`` channels. Layers listed in ``res_layers``
    additionally receive the original input feature, concatenated along the
    channel dimension, so their input width grows by ``filter_channels[0]``.
    Every layer except the last one is followed by an optional normalization
    and a ``LeakyReLU``.
    """

    # Values of ``norm`` that add one entry per hidden layer to ``self.norms``.
    _NORM_LAYER_KINDS = ('batch', 'group', 'instance')

    def __init__(self,
                 filter_channels,
                 name=None,
                 res_layers=None,
                 norm='group',
                 last_op=None):
        """Build the convolution (and optional normalization) stack.

        args:
            filter_channels: sequence of channel counts; consecutive pairs
                define the input/output width of each layer.
            name: stored on ``self.name``; not used by this class itself.
            res_layers: indices of the layers whose input is concatenated
                with the original feature. ``None`` (default) means none.
            norm: 'group' (``GroupNorm`` with 32 groups, so hidden widths
                must be divisible by 32), 'batch', 'instance' or 'weight'
                (weight normalization on the hidden convolutions). Any other
                value disables normalization.
            last_op: optional callable applied to the output of the final
                layer.
        """
        super().__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        # Take a private copy: a shared default list (or the caller's own
        # list) must never be aliased between instances.
        self.res_layers = [] if res_layers is None else list(res_layers)
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        num_layers = len(filter_channels) - 1
        for idx in range(num_layers):
            in_channels = filter_channels[idx]
            out_channels = filter_channels[idx + 1]
            if idx in self.res_layers:
                # Skip connection: the raw input feature is concatenated in.
                in_channels = in_channels + filter_channels[0]
            self.filters.append(nn.Conv1d(in_channels, out_channels, 1))

            # The output layer gets neither normalization nor activation.
            if idx == num_layers - 1:
                continue

            if norm == 'group':
                self.norms.append(nn.GroupNorm(32, out_channels))
            elif norm == 'batch':
                self.norms.append(nn.BatchNorm1d(out_channels))
            elif norm == 'instance':
                self.norms.append(nn.InstanceNorm1d(out_channels))
            elif norm == 'weight':
                # NOTE: the legacy ``nn.utils.weight_norm`` is kept on
                # purpose; it registers ``weight_g`` / ``weight_v``
                # parameters, which existing checkpoints may rely on.
                self.filters[idx] = nn.utils.weight_norm(self.filters[idx],
                                                         name='weight')

    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        skip_input = feature

        # Loop invariants, computed once instead of per layer.
        has_norm_layers = self.norm in self._NORM_LAYER_KINDS
        last_index = len(self.filters) - 1

        for i, conv in enumerate(self.filters):
            if i in self.res_layers:
                y = torch.cat([y, skip_input], 1)
            y = conv(y)

            if i != last_index:
                if has_norm_layers:
                    y = self.norms[i](y)
                y = self.activate(y)

        if self.last_op is not None:
            y = self.last_op(y)

        return y