from modules.commons.common_layers import *
from modules.commons.common_layers import Embedding
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
    EnergyPredictor, FastspeechEncoder
from utils.cwt import cwt2f0
from utils.hparams import hparams
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0

FS_ENCODERS = {
    'fft': lambda hp: FastspeechEncoder(
        hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads']),
}

FS_DECODERS = {
    'fft': lambda hp: FastspeechDecoder(
        hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
}


class FastSpeech2(nn.Module):
    def __init__(self, dictionary, out_dims=None):
        super().__init__()
        # self.dictionary = dictionary
        self.padding_idx = 0
        if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
            self.enc_layers = hparams['enc_layers']
            self.dec_layers = hparams['dec_layers']
            self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams)
            self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
        self.hidden_size = hparams['hidden_size']
        # self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
        #=========not used===========
        # if hparams['use_spk_id']:
        #     self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        #     if hparams['use_split_spk_id']:
        #         self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        #         self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        # elif hparams['use_spk_embed']:
        #     self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        # self.dur_predictor = DurationPredictor(
        #     self.hidden_size,
        #     n_chans=predictor_hidden,
        #     n_layers=hparams['dur_predictor_layers'],
        #     dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
        #     kernel_size=hparams['dur_predictor_kernel'])
        # self.length_regulator = LengthRegulator()
        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            if hparams['pitch_type'] == 'cwt':
                h = hparams['cwt_hidden_size']
                cwt_out_dims = 10
                if hparams['use_uv']:
                    cwt_out_dims = cwt_out_dims + 1
                self.cwt_predictor = nn.Sequential(
                    nn.Linear(self.hidden_size, h),
                    PitchPredictor(
                        h,
                        n_chans=predictor_hidden,
                        n_layers=hparams['predictor_layers'],
                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
                self.cwt_stats_layers = nn.Sequential(
                    nn.Linear(self.hidden_size, h), nn.ReLU(),
                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
                )
            else:
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size,
                    n_chans=predictor_hidden,
                    n_layers=hparams['predictor_layers'],
                    dropout_rate=hparams['predictor_dropout'],
                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_energy_embed']:
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
            # self.energy_predictor = EnergyPredictor(
            #     self.hidden_size,
            #     n_chans=predictor_hidden,
            #     n_layers=hparams['predictor_layers'],
            #     dropout_rate=hparams['predictor_dropout'], odim=1,
            #     padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    # def build_embedding(self, dictionary, embed_dim):
    #     num_embeddings = len(dictionary)
    #     emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
    #     return emb

    def forward(self, hubert, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=True,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        ret = {}
        if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
            encoder_out =self.encoder(hubert)  # [B, T, C]
        else:
            encoder_out =hubert
        src_nonpadding = (hubert!=0).any(-1)[:,:,None]

        # add ref style embed
        # Not implemented
        # variance encoder
        var_embed = 0

        # encoder_out_dur denotes encoder outputs for duration predictor
        # in speech adaptation, duration predictor use old speaker embedding
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        # dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding

        # mel2ph = self.add_dur(dur_inp, mel2ph, hubert, ret)
        ret['mel2ph'] = mel2ph

        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])

        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # add pitch and energy embed
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
        if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
            if skip_decoder:
                return ret
            ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        return ret

    def add_dur(self, dur_input, mel2ph, hubert, ret):
        src_padding = (hubert==0).all(-1)
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
        if mel2ph is None:
            dur, xs = self.dur_predictor.inference(dur_input, src_padding)
            ret['dur'] = xs
            ret['dur_choice'] = dur
            mel2ph = self.length_regulator(dur, src_padding).detach()
        else:
            ret['dur'] = self.dur_predictor(dur_input, src_padding)
        ret['mel2ph'] = mel2ph
        return mel2ph

    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
        x = decoder_inp  # [B, T, H]
        x = self.decoder(x)
        x = self.mel_out(x)
        return x * tgt_nonpadding

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None, hparams)
        return f0_norm

    def out2mel(self, out):
        return out
    
    def add_pitch(self,decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        # if hparams['pitch_type'] == 'ph':
        #     pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
        #     pitch_padding = (encoder_out.sum().abs() == 0)
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        #     ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
        #     pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
        #     pitch = F.pad(pitch, [1, 0])
        #     pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
        #     pitch_embedding = pitch_embed(pitch)
        #     return pitch_embedding
        
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())

        pitch_padding = (mel2ph == 0)

        # if hparams['pitch_type'] == 'cwt':
        #     # NOTE: this part of script is *isolated* from other scripts, which means
        #     #       it may not be compatible with the current version.    
        #     pass
        #     # pitch_padding = None
        #     # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
        #     # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
        #     # mean = ret['f0_mean'] = stats_out[:, 0]
        #     # std = ret['f0_std'] = stats_out[:, 1]
        #     # cwt_spec = cwt_out[:, :, :10]
        #     # if f0 is None:
        #     #     std = std * hparams['cwt_std_scale']
        #     #     f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
        #     #     if hparams['use_uv']:
        #     #         assert cwt_out.shape[-1] == 11
        #     #         uv = cwt_out[:, :, -1] > 0
        # elif hparams['pitch_ar']:
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if is_training else None)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        # else:
        #ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
        # if f0 is None:
        #     f0 = pitch_pred[:, :, 0]
        # if hparams['use_uv'] and uv is None:
        #     uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        if pitch_padding is not None:
            f0[pitch_padding] = 0
        
        pitch = f0_to_coarse(f0_denorm,hparams)  # start from 0
        ret['pitch_pred']=pitch.unsqueeze(-1)
        # print(ret['pitch_pred'].shape)
        # print(pitch.shape)
        pitch_embedding = self.pitch_embed(pitch)
        return pitch_embedding

    def add_energy(self,decoder_inp, energy, ret):
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy#energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        # if energy is None:
        #     energy = energy_pred
        energy = torch.clamp(energy * 256 // 4, max=255).long() # energy_to_coarse
        energy_embedding = self.energy_embed(energy)
        return energy_embedding

    @staticmethod
    def mel_norm(x):
        return (x + 5.5) / (6.3 / 2) - 1

    @staticmethod
    def mel_denorm(x):
        return (x + 1) * (6.3 / 2) - 5.5