import importlib
import json

import numpy as np
import torch
import torch.distributions
import torch.optim
import torch.utils.data
from resemblyzer import VoiceEncoder

import utils
from data_gen.tts.data_gen_utils import build_phone_encoder
from tasks.base_task import BaseDataset
from utils.cwt import get_lf0_cwt
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset
from utils.pitch_utils import norm_interp_f0, denorm_f0, f0_to_coarse

class BaseTTSDataset(BaseDataset):
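    """Base dataset for TTS, backed by a binarized IndexedDataset.

    Loads mel spectrograms, phone tokens and optional speaker information
    from `binary_data_dir`, drops items shorter than `min_frames`, and
    restricts the test split via `num_test_samples` / `test_ids`.
    """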
    def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None):
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir'] if data_dir is None else data_dir
        self.prefix = prefix
        self.hparams = hparams
        self.indexed_ds = None
        self.ext_mel2ph = None

        def load_size():
            self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')

        if prefix == 'test' or hparams['inference']:
            if test_items is not None:
                self.indexed_ds, self.sizes = test_items, test_sizes
            else:
                load_size()
            if hparams['num_test_samples'] > 0:
                self.avail_idxs = list(range(min(hparams['num_test_samples'], len(self.sizes))))
                if len(hparams['test_ids']) > 0:
                    self.avail_idxs = hparams['test_ids'] + self.avail_idxs
            else:
                self.avail_idxs = list(range(len(self.sizes)))
        else:
            load_size()
            self.avail_idxs = list(range(len(self.sizes)))

        if hparams['min_frames'] > 0:
            self.avail_idxs = [
                x for x in self.avail_idxs if self.sizes[x] >= hparams['min_frames']]
        self.sizes = [self.sizes[i] for i in self.avail_idxs]

    def _get_item(self, index):
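        # Map the dataset-relative index through avail_idxs and lazily open
        # the IndexedDataset on first access.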
        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
            index = self.avail_idxs[index]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
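        """Build one sample: mel truncated to `max_frames` (and to a multiple
        of `frames_multiple`), phone tokens, and optional speaker info."""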
        hparams = self.hparams
        item = self._get_item(index)
        assert len(item['mel']) == self.sizes[index], (len(item['mel']), self.sizes[index])
        max_frames = hparams['max_frames']
        spec = torch.Tensor(item['mel'])[:max_frames]
        max_frames = spec.shape[0] // hparams['frames_multiple'] * hparams['frames_multiple']
        spec = spec[:max_frames]
        phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']])
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "txt_token": phone,
            "mel": spec,
            "mel_nonpadding": spec.abs().sum(-1) > 0,
        }
        if hparams['use_spk_embed']:
            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        if hparams['use_spk_id']:
            sample["spk_id"] = item['spk_id']
        return sample

    def collater(self, samples):
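        """Pad and stack a list of samples into a single training batch."""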
        if len(samples) == 0:
            return {}
        hparams = self.hparams
        ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        text = [s['text'] for s in samples]
        txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0)
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples])
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        batch = {
            'id': ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'text': text,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'mels': mels,
            'mel_lengths': mel_lengths,
        }

        if hparams['use_spk_embed']:
            spk_embed = torch.stack([s['spk_embed'] for s in samples])
            batch['spk_embed'] = spk_embed
        if hparams['use_spk_id']:
            spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
            batch['spk_ids'] = spk_ids
        return batch


class FastSpeechDataset(BaseTTSDataset):
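    """BaseTTSDataset extended with FastSpeech-style supervision: per-frame
    energy, mel-to-phone alignment (mel2ph), normalized/interpolated f0 with
    a voiced/unvoiced mask, coarse pitch ids and, optionally, CWT pitch
    spectrograms."""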
    def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None):
        super().__init__(prefix, shuffle, test_items, test_sizes, data_dir)
        self.f0_mean, self.f0_std = hparams.get('f0_mean', None), hparams.get('f0_std', None)
        if prefix == 'test' and hparams['test_input_dir'] != '':
            self.data_dir = hparams['test_input_dir']
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
            self.indexed_ds = sorted(self.indexed_ds, key=lambda item: item['item_name'])
            # Group item indices by speaker (parsed from the item_name prefix)
            # and keep utterances from the most frequent speakers first.
            items = {}
            for i in range(len(self.indexed_ds)):
                speaker = self.indexed_ds[i]['item_name'].split('_')[0]
                items.setdefault(speaker, []).append(i)
            sorted_groups = sorted(items.values(), key=len, reverse=True)
            self.avail_idxs = [n for group in sorted_groups for n in group][:hparams['num_test_samples']]
            self.indexed_ds, self.sizes = self.load_test_inputs()
            self.avail_idxs = list(range(hparams['num_test_samples']))

        if hparams['pitch_type'] == 'cwt':
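            # Run the CWT once on a dummy signal just to obtain the scale set.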
            _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10))

    def __getitem__(self, index):
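        """Add energy, mel2ph, f0/uv and pitch features to the base sample."""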
        sample = super(FastSpeechDataset, self).__getitem__(index)
        item = self._get_item(index)
        hparams = self.hparams
        max_frames = hparams['max_frames']
        spec = sample['mel']
        T = spec.shape[0]
        phone = sample['txt_token']
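        # Per-frame energy: L2 norm over mel bins of the linear-scale
        # spectrogram (the stored mel is assumed log-scale, hence exp()).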
        sample['energy'] = (spec.exp() ** 2).sum(-1).sqrt()
        sample['mel2ph'] = mel2ph = torch.LongTensor(item['mel2ph'])[:T] if 'mel2ph' in item else None
        if hparams['use_pitch_embed']:
            assert 'f0' in item
            if hparams.get('normalize_pitch', False):
                f0 = item["f0"]
                if (f0 > 0).any() and f0[f0 > 0].std() > 0:
                    # Re-normalize voiced frames to the target f0 statistics,
                    # then clip to a plausible pitch range in Hz.
                    f0[f0 > 0] = (f0[f0 > 0] - f0[f0 > 0].mean()) / f0[f0 > 0].std() * hparams['f0_std'] + \
                                 hparams['f0_mean']
                    f0[f0 > 0] = f0[f0 > 0].clip(min=60, max=500)
                pitch = f0_to_coarse(f0)
                pitch = torch.LongTensor(pitch[:max_frames])
            else:
                pitch = torch.LongTensor(item.get("pitch"))[:max_frames] if "pitch" in item else None
            f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
            uv = torch.FloatTensor(uv)
            f0 = torch.FloatTensor(f0)
            if hparams['pitch_type'] == 'cwt':
                cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames]
                f0_mean = item.get('f0_mean', item.get('cwt_mean'))
                f0_std = item.get('f0_std', item.get('cwt_std'))
                sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
            elif hparams['pitch_type'] == 'ph':
                if "f0_ph" in item:
                    f0 = torch.FloatTensor(item['f0_ph'])
                else:
                    f0 = denorm_f0(f0, None, hparams)
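                # Average frame-level f0 to phone level: mel2ph is a 1-based
                # frame-to-phone alignment, so scatter_add over (mel2ph - 1)
                # sums f0 per phone and the count normalizes it to a mean.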
                f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0)
                f0_phlevel_num = torch.zeros_like(phone).float().scatter_add(
                    0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1)
                f0_ph = f0_phlevel_sum / f0_phlevel_num
                f0, uv = norm_interp_f0(f0_ph, hparams)
        else:
            f0 = uv = torch.zeros_like(mel2ph)
            pitch = None
        sample["f0"], sample["uv"], sample["pitch"] = f0, uv, pitch
        if hparams['use_spk_embed']:
            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        if hparams['use_spk_id']:
            sample["spk_id"] = item['spk_id']
        return sample

    def collater(self, samples):
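        """Collate pitch, energy and alignment features on top of the base batch."""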
        if len(samples) == 0:
            return {}
        hparams = self.hparams
        batch = super(FastSpeechDataset, self).collater(samples)
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples]) if samples[0]['pitch'] is not None else None
        uv = utils.collate_1d([s['uv'] for s in samples])
        energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
            if samples[0]['mel2ph'] is not None else None
        batch.update({
            'mel2ph': mel2ph,
            'energy': energy,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        })
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples])
            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
            f0_std = torch.Tensor([s['f0_std'] for s in samples])
            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
        return batch

    def load_test_inputs(self):
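        """Re-binarize raw test items on the fly with the configured binarizer.

        Each available item is re-processed from its TextGrid alignment and
        wav file; speaker embeddings are computed with resemblyzer's
        VoiceEncoder when `with_spk_embed` is enabled.
        """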
        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg = ".".join(binarizer_cls.split(".")[:-1])
        cls_name = binarizer_cls.split(".")[-1]
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
        with open(ph_set_fn, 'r') as f:
            ph_set = json.load(f)
        print("| phone set: ", ph_set)
        phone_encoder = build_phone_encoder(hparams['binary_data_dir'])
        word_encoder = None
        voice_encoder = VoiceEncoder().cuda()
        encoder = [phone_encoder, word_encoder]
        sizes = []
        items = []
        for i in range(len(self.avail_idxs)):
            item = self._get_item(i)

            item2tgfn = f"{hparams['test_input_dir'].replace('binary', 'processed')}/mfa_outputs/{item['item_name']}.TextGrid"
            item = binarizer_cls.process_item(item['item_name'], item['ph'], item['txt'], item2tgfn,
                                              item['wav_fn'], item['spk_id'], encoder, hparams['binarization_args'])
            item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
                if hparams['binarization_args']['with_spk_embed'] else None  # only store the speaker embedding when requested
            items.append(item)
            sizes.append(item['len'])
        return items, sizes


class FastSpeechWordDataset(FastSpeechDataset):
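    """FastSpeechDataset with word-level features: word tokens, the
    phone-to-word map (ph2word) and the mel-to-word alignment (mel2word)."""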
    def __getitem__(self, index):
        sample = super(FastSpeechWordDataset, self).__getitem__(index)
        item = self._get_item(index)
        max_frames = hparams['max_frames']
        sample["ph_words"] = item["ph_words"]
        sample["word_tokens"] = torch.LongTensor(item["word_tokens"])
        sample["mel2word"] = torch.LongTensor(item.get("mel2word"))[:max_frames]
        sample["ph2word"] = torch.LongTensor(item['ph2word'][:hparams['max_input_tokens']])
        return sample

    def collater(self, samples):
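        """Collate the word-level features on top of the FastSpeech batch."""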
        batch = super(FastSpeechWordDataset, self).collater(samples)
        ph_words = [s['ph_words'] for s in samples]
        batch['ph_words'] = ph_words
        word_tokens = utils.collate_1d([s['word_tokens'] for s in samples], 0)
        batch['word_tokens'] = word_tokens
        mel2word = utils.collate_1d([s['mel2word'] for s in samples], 0)
        batch['mel2word'] = mel2word
        ph2word = utils.collate_1d([s['ph2word'] for s in samples], 0)
        batch['ph2word'] = ph2word
        return batch