import logging

import torch
import torch.utils.data

logger = logging.getLogger(__name__)


class TextAudioCollateMultiNSFsid:
    """Zero-pads model inputs and targets"""

    def __init__(self):
        pass

    def __call__(self, batch):
        """Collate's training batch from normalized text and aduio
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized]
        """
        device = batch[0]["spec"].device

        with device:
            # Right zero-pad all one-hot text sequences to max input length
            _, ids_sorted_decreasing = torch.sort(
                torch.tensor([x["spec"].size(1) for x in batch], dtype=torch.int32),
                dim=0,
                descending=True,
            )

            max_spec_len = max([x["spec"].size(1) for x in batch])
            max_wave_len = max([x["wav_gt"]["array"].size(0) for x in batch])
            spec_lengths = torch.zeros(len(batch), dtype=torch.int32)
            wave_lengths = torch.zeros(len(batch), dtype=torch.int32)
            spec_padded = torch.zeros(
                len(batch), batch[0]["spec"].size(0), max_spec_len, dtype=torch.float32
            )
            wave_padded = torch.zeros(len(batch), 1, max_wave_len, dtype=torch.float32)

            max_phone_len = max([x["hubert_feats"].size(0) for x in batch])
            phone_lengths = torch.zeros(len(batch), dtype=torch.int32)
            phone_padded = torch.zeros(
                len(batch),
                max_phone_len,
                batch[0]["hubert_feats"].shape[1],
                dtype=torch.float32,
            )  # (spec, wav, phone, pitch)
            pitch_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.int32)
            pitchf_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.float32)
            # dv = torch.FloatTensor(len(batch), 256)#gin=256
            sid = torch.zeros(len(batch), dtype=torch.int32)

            for i in range(len(ids_sorted_decreasing)):
                row = batch[ids_sorted_decreasing[i]]

                spec = row["spec"]
                spec_padded[i, :, : spec.size(1)] = spec
                spec_lengths[i] = spec.size(1)

                wave = row["wav_gt"]["array"]
                wave_padded[i, :, : wave.size(0)] = wave
                wave_lengths[i] = wave.size(0)

                phone = row["hubert_feats"]
                phone_padded[i, : phone.size(0), :] = phone
                phone_lengths[i] = phone.size(0)

                pitch = row["f0"]
                pitch_padded[i, : pitch.size(0)] = pitch
                pitchf = row["f0nsf"]
                pitchf_padded[i, : pitchf.size(0)] = pitchf

                sid[i] = torch.tensor([0], dtype=torch.int32)

            return (
                phone_padded,
                phone_lengths,
                pitch_padded,
                pitchf_padded,
                spec_padded,
                spec_lengths,
                wave_padded,
                wave_lengths,
                sid,
            )