"""
Code is from https://github.com/sony/bigvsan/blob/main/san_modules.py

Paper:  Shibuya, T., Takida, Y., Mitsufuji, Y., "BigVSAN: Enhancing GAN-based Neural Vocoders with Slicing Adversarial Network," Preprint.
        https://arxiv.org/pdf/2309.02836.pdf
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def _normalize(tensor, dim):
    denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-12)
    return tensor / denom


class SANConv1d(nn.Conv1d):
    """1-D convolution whose kernels are stored as unit-norm direction times scale.

    Every output channel's kernel is kept as a direction (``weight``, L2-normalized
    over the input-channel and kernel axes) and a per-output-channel magnitude
    (``scale``).  Unlike ``nn.Conv1d``, the bias has ``in_channels`` entries and is
    added to the *input* before the convolution; the convolution itself is bias-free.
    ``groups`` is fixed to 1.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: size of the convolving kernel.
        stride: stride of the convolution.
        padding: padding applied to the input (same forms as ``nn.Conv1d``).
        dilation: spacing between kernel elements.
        bias: if True, a learnable input-side bias (initialised to zero) is used.
        padding_mode: ``'zeros'`` or any other mode supported by ``nn.Conv1d``.
        device: device for the parameters.
        dtype: dtype for the parameters.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None
                 ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation,
            groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
        # Split the freshly initialised kernel into its per-output-channel norm
        # and a unit-norm direction.  This is pure re-parameterisation, so it is
        # done without recording autograd history.
        with torch.no_grad():
            scale = self.weight.norm(p=2.0, dim=[1, 2], keepdim=True).clamp_min(1e-12)
            direction = self.weight / scale
        self.weight = nn.parameter.Parameter(direction)
        self.scale = nn.parameter.Parameter(scale.view(out_channels))
        if bias:
            # Replaces the (out_channels,) bias created by nn.Conv1d with an
            # input-side bias of shape (in_channels,).
            self.bias = nn.parameter.Parameter(torch.zeros(in_channels, device=device, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.normalize_weight()

    def forward(self, input, flg_train=False):
        """Apply the convolution.

        Args:
            input: tensor accepted by ``F.conv1d`` with ``in_channels`` channels.
            flg_train: if False, return a single tensor.  If True, return a
                two-element list ``[out_fun, out_dir]`` holding the same values
                but different gradient paths: ``out_fun`` back-propagates to the
                input, bias and ``scale`` (direction detached); ``out_dir``
                back-propagates to the direction ``weight`` only (input and
                ``scale`` detached).

        Returns:
            A tensor, or a list of two tensors when ``flg_train`` is True.
        """
        if self.bias is not None:
            input = input + self.bias.view(self.in_channels, 1)
        if self.padding_mode != 'zeros':
            # F.conv1d can only zero-pad.  For every other mode pad explicitly,
            # exactly as nn.Conv1d does, and then convolve without padding.
            input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            padding = 0
        else:
            padding = self.padding
        normalized_weight = self._get_normalized_weight()
        scale = self.scale.view(self.out_channels, 1)
        if flg_train:
            out_fun = F.conv1d(input, normalized_weight.detach(), None, self.stride,
                               padding, self.dilation, self.groups)
            out_dir = F.conv1d(input.detach(), normalized_weight, None, self.stride,
                               padding, self.dilation, self.groups)
            out = [out_fun * scale, out_dir * scale.detach()]
        else:
            out = F.conv1d(input, normalized_weight, None, self.stride,
                           padding, self.dilation, self.groups)
            out = out * scale
        return out

    @torch.no_grad()
    def normalize_weight(self):
        """Overwrite ``weight`` in place with its unit-norm version."""
        self.weight.data = self._get_normalized_weight()

    def _get_normalized_weight(self):
        """Return ``weight`` with each output-channel kernel scaled to unit L2 norm."""
        norm = self.weight.norm(p=2.0, dim=[1, 2], keepdim=True).clamp_min(1e-12)
        return self.weight / norm


class SANConv2d(nn.Conv2d):
    """2-D convolution whose kernels are stored as unit-norm direction times scale.

    Every output channel's kernel is kept as a direction (``weight``, L2-normalized
    over the input-channel and both kernel axes) and a per-output-channel magnitude
    (``scale``).  Unlike ``nn.Conv2d``, the bias has ``in_channels`` entries and is
    added to the *input* before the convolution; the convolution itself is bias-free.
    ``groups`` is fixed to 1.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: size of the convolving kernel.
        stride: stride of the convolution.
        padding: padding applied to the input (same forms as ``nn.Conv2d``).
        dilation: spacing between kernel elements.
        bias: if True, a learnable input-side bias (initialised to zero) is used.
        padding_mode: ``'zeros'`` or any other mode supported by ``nn.Conv2d``.
        device: device for the parameters.
        dtype: dtype for the parameters.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None
                 ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation,
            groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
        # Split the freshly initialised kernel into its per-output-channel norm
        # and a unit-norm direction.  This is pure re-parameterisation, so it is
        # done without recording autograd history.
        with torch.no_grad():
            scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-12)
            direction = self.weight / scale
        self.weight = nn.parameter.Parameter(direction)
        self.scale = nn.parameter.Parameter(scale.view(out_channels))
        if bias:
            # Replaces the (out_channels,) bias created by nn.Conv2d with an
            # input-side bias of shape (in_channels,).
            self.bias = nn.parameter.Parameter(torch.zeros(in_channels, device=device, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.normalize_weight()

    def forward(self, input, flg_train=False):
        """Apply the convolution.

        Args:
            input: tensor accepted by ``F.conv2d`` with ``in_channels`` channels.
            flg_train: if False, return a single tensor.  If True, return a
                two-element list ``[out_fun, out_dir]`` holding the same values
                but different gradient paths: ``out_fun`` back-propagates to the
                input, bias and ``scale`` (direction detached); ``out_dir``
                back-propagates to the direction ``weight`` only (input and
                ``scale`` detached).

        Returns:
            A tensor, or a list of two tensors when ``flg_train`` is True.
        """
        if self.bias is not None:
            input = input + self.bias.view(self.in_channels, 1, 1)
        if self.padding_mode != 'zeros':
            # F.conv2d can only zero-pad.  For every other mode pad explicitly,
            # exactly as nn.Conv2d does, and then convolve without padding.
            input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            padding = 0
        else:
            padding = self.padding
        normalized_weight = self._get_normalized_weight()
        scale = self.scale.view(self.out_channels, 1, 1)
        if flg_train:
            out_fun = F.conv2d(input, normalized_weight.detach(), None, self.stride,
                               padding, self.dilation, self.groups)
            out_dir = F.conv2d(input.detach(), normalized_weight, None, self.stride,
                               padding, self.dilation, self.groups)
            out = [out_fun * scale, out_dir * scale.detach()]
        else:
            out = F.conv2d(input, normalized_weight, None, self.stride,
                           padding, self.dilation, self.groups)
            out = out * scale
        return out

    @torch.no_grad()
    def normalize_weight(self):
        """Overwrite ``weight`` in place with its unit-norm version."""
        self.weight.data = self._get_normalized_weight()

    def _get_normalized_weight(self):
        """Return ``weight`` with each output-channel kernel scaled to unit L2 norm."""
        norm = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-12)
        return self.weight / norm


class SANEmbedding(nn.Embedding):
    """Embedding table stored as unit-norm rows plus a per-row scale.

    ``weight`` holds the L2-normalized embedding vectors and ``scale`` (shape
    ``(num_embeddings, 1)``) holds the norm each row had at construction time.
    ``padding_idx`` and ``max_norm`` are always passed to ``nn.Embedding`` as None.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 device=None, dtype=None):
        super().__init__(
            num_embeddings, embedding_dim,
            padding_idx=None, max_norm=None, norm_type=2.,
            scale_grad_by_freq=scale_grad_by_freq, sparse=sparse,
            _weight=_weight, device=device, dtype=dtype)
        # Per-row L2 norm, floored to avoid dividing by zero; keepdim lets it
        # broadcast against the (num_embeddings, embedding_dim) table.
        row_norm = self.weight.norm(p=2.0, dim=1, keepdim=True).clamp_min(1e-12)
        unit_rows = self.weight / row_norm
        self.weight = nn.parameter.Parameter(unit_rows)
        self.scale = nn.parameter.Parameter(row_norm)

    def forward(self, input, flg_train=False):
        """Look up the embeddings for ``input``.

        Args:
            input: index tensor passed to ``F.embedding``.
            flg_train: if False, return normalized embedding times scale as one
                tensor.  If True, return a two-element list with the same
                values: the first has the direction detached (gradient reaches
                ``scale`` only), the second has the scale detached (gradient
                reaches ``weight`` only).

        Returns:
            A tensor, or a list of two tensors when ``flg_train`` is True.
        """
        def lookup(table):
            # Same lookup settings for both the direction and the scale tables.
            return F.embedding(
                input, table, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse)

        direction = _normalize(lookup(self.weight), dim=-1)
        magnitude = lookup(self.scale)
        if not flg_train:
            return direction * magnitude
        return [direction.detach() * magnitude, direction * magnitude.detach()]

    def normalize_weight(self):
        """Overwrite ``weight`` in place with its row-wise unit-norm version."""
        with torch.no_grad():
            self.weight.data = _normalize(self.weight, dim=1)