File size: 625 Bytes
1e4a2ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch

import torch.nn.functional as F

class LayerNorm(torch.nn.Module):
    """Layer normalization over dimension 1 (the channel axis) of a tensor.

    ``F.layer_norm`` normalizes over trailing dimensions, so ``forward``
    swaps dim 1 with the last dim, normalizes over that last dim with the
    learnable affine parameters ``gamma`` / ``beta``, and swaps back. The
    output has the same shape as the input.

    Args:
        channels: Size of dim 1 of the input; also the length of ``gamma``
            and ``beta``.
        eps: Value added to the variance for numerical stability.
        onnx: Kept for backward compatibility. ``forward`` always uses the
            static ``channels`` as the normalized shape (the path this flag
            used to select); since ``gamma``/``beta`` already fix that size,
            the flag no longer changes the result.
    """

    def __init__(self, channels, eps=1e-5, onnx=False):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Retained so existing callers and code reading `.onnx` keep working.
        self.onnx = onnx
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalize ``x`` over dim 1 and return a tensor of the same shape.

        ``x`` needs at least 2 dims and ``x.size(1) == channels``. The layout
        is presumably (batch, channels, time, ...); confirm with callers.
        """
        # Move channels to the last axis, which is where F.layer_norm operates.
        x = x.transpose(1, -1)
        # gamma/beta have shape (channels,), so the normalized shape must be
        # (channels,) regardless; using the Python int instead of x.size(-1)
        # keeps the shape static when the module is traced/exported.
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)