import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.FastDiff.module.util import calc_noise_scale_embedding
def swish(x):
    """Swish / SiLU activation: element-wise ``x * sigmoid(x)``.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    gate = x.sigmoid()
    return x * gate


# dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    """Weight-normalised 1-D convolution with Kaiming-normal initialisation.

    The padding ``dilation * (kernel_size - 1) // 2`` keeps the time length
    unchanged when ``kernel_size`` is odd.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size (default 3).
        dilation: dilation factor (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super(Conv, self).__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
        # Initialise BEFORE applying weight_norm. Afterwards `conv.weight` is no
        # longer a parameter: it is recomputed from weight_g / weight_v by a
        # forward pre-hook, so initialising it there is silently discarded.
        nn.init.kaiming_normal_(conv.weight)
        self.conv = nn.utils.weight_norm(conv)

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)


# conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
    """1x1 ``Conv1d`` whose weight and bias both start at zero.

    The module therefore outputs all zeros until it is trained.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
    """

    def __init__(self, in_channel, out_channel):
        super(ZeroConv1d, self).__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        for param in (self.conv.weight, self.conv.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)


# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class Residual_block(nn.Module):
    """One residual layer of the WaveNet-style denoiser.

    The layer works in the following steps:

    1. Add a per-layer projection of the noise-scale embedding to the input.
    2. Apply a kernel-size-3 non-causal dilated conv.
    3. Add an upsampled mel-spectrogram conditioner.
    4. Pass the result through a gated tanh/sigmoid unit.
    5. Emit a residual output (for the next layer) and a skip output.

    Args:
        res_channels: channel count of the residual stream.
        skip_channels: channel count of the skip output.
        dilation: dilation of the kernel-size-3 conv.
        noise_scale_embed_dim_out: input size of the per-layer ``fc_t``
            projection of the noise-scale embedding.
        multiband: when exactly ``True`` each of the two transposed-conv
            upsamplers has time stride 8 (64x total); otherwise stride 16
            (256x total).
        num_mels: number of mel bins in the conditioner (default 80).
    """

    def __init__(self, res_channels, skip_channels, dilation,
                 noise_scale_embed_dim_out, multiband=True, num_mels=80):
        super(Residual_block, self).__init__()
        self.res_channels = res_channels

        # the layer-specific fc for noise scale embedding
        self.fc_t = nn.Linear(noise_scale_embed_dim_out, self.res_channels)

        # dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

        # Mel upsampler: two stacked ConvTranspose2d layers, each stretching the
        # time axis by `s`. The product of the strides must make the upsampled
        # length >= the audio length L (asserted in forward); presumably it
        # equals the mel hop size — NOTE(review): confirm against the features.
        self.upsample_conv2d = torch.nn.ModuleList()
        s = 8 if multiband is True else 16
        for _ in range(2):
            conv_trans2d = torch.nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            # Initialise BEFORE weight_norm. Afterwards `.weight` is a derived
            # tensor recomputed from weight_g / weight_v on every forward, so
            # initialising it there has no effect.
            torch.nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(torch.nn.utils.weight_norm(conv_trans2d))
        self.mel_conv = Conv(num_mels, 2 * self.res_channels, kernel_size=1)

        # residual conv1x1 layer, connect to next residual layer
        res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        nn.init.kaiming_normal_(res_conv.weight)  # before weight_norm (see above)
        self.res_conv = nn.utils.weight_norm(res_conv)

        # skip conv1x1 layer, add to all skip outputs through skip connections
        skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        nn.init.kaiming_normal_(skip_conv.weight)  # before weight_norm (see above)
        self.skip_conv = nn.utils.weight_norm(skip_conv)

    def forward(self, input_data):
        """Run one residual layer.

        Args:
            input_data: tuple ``(x, mel_spec, noise_scale_embed)``.

                - ``x`` has shape ``(B, res_channels, L)`` (asserted).
                - ``mel_spec`` is unsqueezed to a 1-channel image before
                  upsampling, so presumably it is ``(B, num_mels, T')``.
                  TODO confirm.
                - ``noise_scale_embed`` must project (via ``fc_t``) to
                  ``B * res_channels`` values.

        Returns:
            Tuple ``(residual, skip)``. ``residual`` is
            ``(x + res) * sqrt(0.5)`` with shape ``(B, res_channels, L)``.
            ``skip`` has ``skip_channels`` channels.
        """
        x, mel_spec, noise_scale_embed = input_data
        B, C, L = x.shape   # B, res_channels, L
        assert C == self.res_channels

        # Add the noise-scale embedding: one value per channel, broadcast over time.
        part_t = self.fc_t(noise_scale_embed).view([B, self.res_channels, 1])
        # Out-of-place on purpose. The former `h = x; h += part_t` aliased `x`,
        # so it mutated the caller's tensor in place. It also made backward()
        # fail whenever `x` had been saved by autograd (e.g. a ReLU output).
        # Rebinding `x` keeps the forward result identical: the residual
        # connection below still uses the embedding-shifted input.
        # NOTE(review): DiffWave adds the un-shifted input to the residual;
        # kept as-is for checkpoint compatibility — confirm intent.
        x = x + part_t

        # dilated conv layer
        h = self.dilated_conv_layer(x)

        # add mel spectrogram as (local) conditioner
        assert mel_spec is not None

        # Upsample the spectrogram along time to (at least) the audio length.
        mel_spec = torch.unsqueeze(mel_spec, dim=1)  # (B, 1, num_mels, T')
        for upsample in self.upsample_conv2d:
            mel_spec = F.leaky_relu(upsample(mel_spec), 0.4)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= L
        mel_spec = mel_spec[:, :, :L]  # trim any overshoot beyond the audio length

        h = h + self.mel_conv(mel_spec)

        # gated-tanh nonlinearity: first half of channels -> tanh, second half -> sigmoid gate
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

        # residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        return (x + res) * math.sqrt(0.5), skip  # normalize for training stability


class Residual_group(nn.Module):
    """Stack of ``num_res_layers`` residual blocks sharing one noise-scale MLP.

    The noise scales are embedded with ``calc_noise_scale_embedding``. They then
    pass through two shared ``Linear`` + swish layers and feed every block.
    Block ``i`` uses dilation ``2 ** (i % dilation_cycle)``.

    Args:
        res_channels: residual-stream channel count of every block.
        skip_channels: skip-output channel count of every block.
        num_res_layers: number of residual blocks.
        dilation_cycle: period of the dilation schedule.
        noise_scale_embed_dim_in: size requested from
            ``calc_noise_scale_embedding`` and input size of ``fc_t1``.
        noise_scale_embed_dim_mid: hidden size of the shared MLP.
        noise_scale_embed_dim_out: output size of the shared MLP.
        multiband: forwarded to each ``Residual_block``.
    """

    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                 noise_scale_embed_dim_in,
                 noise_scale_embed_dim_mid,
                 noise_scale_embed_dim_out, multiband):
        super(Residual_group, self).__init__()
        self.num_res_layers = num_res_layers
        self.noise_scale_embed_dim_in = noise_scale_embed_dim_in

        # shared two-layer MLP for the noise-scale embedding
        self.fc_t1 = nn.Linear(noise_scale_embed_dim_in, noise_scale_embed_dim_mid)
        self.fc_t2 = nn.Linear(noise_scale_embed_dim_mid, noise_scale_embed_dim_out)

        # dilations repeat 1, 2, 4, ..., 2 ** (dilation_cycle - 1)
        self.residual_blocks = nn.ModuleList([
            Residual_block(res_channels, skip_channels,
                           dilation=2 ** (layer_idx % dilation_cycle),
                           noise_scale_embed_dim_out=noise_scale_embed_dim_out,
                           multiband=multiband)
            for layer_idx in range(self.num_res_layers)
        ])

    def forward(self, input_data):
        """Run all residual blocks and return their scaled, summed skip outputs.

        Args:
            input_data: tuple ``(x, mel_spectrogram, noise_scales)``.

        Returns:
            Sum of all skip outputs multiplied by ``sqrt(1 / num_res_layers)``.
        """
        x, mel_spectrogram, noise_scales = input_data

        # embed noise scale, then apply the shared MLP (each layer followed by swish)
        embed = calc_noise_scale_embedding(noise_scales, self.noise_scale_embed_dim_in)
        for fc in (self.fc_t1, self.fc_t2):
            embed = swish(fc(embed))

        # thread the hidden state through every block, accumulating skips
        hidden, skip_total = x, 0
        for block in self.residual_blocks:
            hidden, block_skip = block((hidden, mel_spectrogram, embed))
            skip_total += block_skip

        return skip_total * math.sqrt(1.0 / self.num_res_layers)  # normalize for training stability


class WaveNet_vocoder(nn.Module):
    """WaveNet-style denoising network conditioned on mel spectrograms.

    The pipeline is:

    1. A 1x1 conv followed by ReLU.
    2. A ``Residual_group`` of dilated residual blocks.
    3. A 1x1 conv, ReLU, then a zero-initialised 1x1 conv to ``out_channels``.

    Because the last layer starts at zero, the network outputs zeros at
    initialisation.

    Args:
        in_channels: channels of the input signal.
        res_channels: residual-stream channel count.
        skip_channels: skip-connection channel count.
        out_channels: channels of the output signal.
        num_res_layers: number of residual blocks.
        dilation_cycle: period of the dilation schedule.
        noise_scale_embed_dim_in: size of the noise-scale embedding.
        noise_scale_embed_dim_mid: hidden size of the embedding MLP.
        noise_scale_embed_dim_out: output size of the embedding MLP.
        multiband: forwarded to the residual blocks (selects the upsampling factor).
    """

    def __init__(self, in_channels, res_channels, skip_channels, out_channels,
                 num_res_layers, dilation_cycle,
                 noise_scale_embed_dim_in,
                 noise_scale_embed_dim_mid,
                 noise_scale_embed_dim_out, multiband):
        super(WaveNet_vocoder, self).__init__()

        # initial conv1x1 with relu
        self.init_conv = nn.Sequential(
            Conv(in_channels, res_channels, kernel_size=1),
            nn.ReLU(),
        )

        # all residual layers
        self.residual_layer = Residual_group(
            res_channels=res_channels,
            skip_channels=skip_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            noise_scale_embed_dim_in=noise_scale_embed_dim_in,
            noise_scale_embed_dim_mid=noise_scale_embed_dim_mid,
            noise_scale_embed_dim_out=noise_scale_embed_dim_out,
            multiband=multiband,
        )

        # final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(),
            ZeroConv1d(skip_channels, out_channels),
        )

    def forward(self, input_data):
        """Predict the network output for a noisy signal.

        Args:
            input_data: tuple ``(audio, mel_spectrogram, noise_scales)``.
                According to the original inline note, the shapes are
                ``b x band x T``, ``b x 80 x T'`` and ``b x 1``.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        audio, mel_spectrogram, noise_scales = input_data
        hidden = self.init_conv(audio)
        skips = self.residual_layer((hidden, mel_spectrogram, noise_scales))
        return self.final_conv(skips)