# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from librosa.filters import mel as librosa_mel_fn


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x``.

    Values are first floored at ``clip_val`` (so the logarithm never sees
    zero or negatives), then multiplied by ``C``, then passed through the
    natural logarithm.
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)


def spectral_normalize_torch(magnitudes):
    """Log-compress a magnitude spectrogram.

    Delegates to ``dynamic_range_compression_torch`` with its default gain
    (``C=1``) and floor (``clip_val=1e-5``).
    """
    return dynamic_range_compression_torch(magnitudes)


def extract_linear_features(y, cfg, center=False):
    """Compute a linear-frequency STFT magnitude spectrogram.

    Args:
        y (tensor): audio samples. Values outside [-1, 1] are only reported
            with a print, not clipped. The signal is unsqueezed at dim 1 before
            reflect padding, so a 2-D ``(batch, samples)`` layout is presumably
            expected (TODO confirm).
        cfg: object providing ``n_fft``, ``hop_size`` and ``win_size``.
        center (bool, optional): forwarded to ``torch.stft``. Defaults to False.

    Returns:
        tensor: ``sqrt(re**2 + im**2 + 1e-9)`` for every onesided STFT bin,
        with dimension 0 removed when it has size 1.

    Side effects:
        Stores a freshly built Hann window in the module-level ``hann_window``
        dict under ``str(y.device)`` on every call.
    """
    lowest = torch.min(y)
    if lowest < -1.0:
        print("min value is ", lowest)
    highest = torch.max(y)
    if highest > 1.0:
        print("max value is ", highest)

    global hann_window
    device_key = str(y.device)
    # Rebuilt on every call rather than looked up, so a different win_size in
    # cfg is always honoured.
    hann_window[device_key] = torch.hann_window(cfg.win_size).to(y.device)

    # Symmetric reflect padding of (n_fft - hop_size) / 2 samples per side.
    margin = int((cfg.n_fft - cfg.hop_size) / 2)
    padded = torch.nn.functional.pad(
        y.unsqueeze(1), (margin, margin), mode="reflect"
    ).squeeze(1)

    # Complex output is requested, then split with view_as_real, for forward
    # compatibility with PyTorch's removal of real-valued stft output.
    stft_out = torch.stft(
        padded,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[device_key],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Trailing axis of view_as_real is (real, imag); summing squares gives power.
    power = torch.view_as_real(stft_out).pow(2).sum(-1)
    return torch.sqrt(power + (1e-9)).squeeze(0)


def mel_spectrogram_torch(y, cfg, center=False):
    """Compute a log-compressed mel spectrogram of ``y``.

    The signal is reflect-padded by ``(n_fft - hop_size) / 2`` samples per
    side. It is then transformed with ``torch.stft`` using a Hann window and
    converted to magnitude (with a 1e-6 floor inside the square root). The
    magnitude is projected onto a librosa mel filterbank and log-compressed
    with ``spectral_normalize_torch``.

    The filterbank and window are cached in the module-level ``mel_basis`` /
    ``hann_window`` dicts. Each key includes every parameter the cached tensor
    depends on, plus the device, so a cache hit is always valid for this call.

    Args:
        y (tensor): audio samples. Values outside [-1, 1] are only reported
            with a print, not clipped. Presumably shaped ``(batch, samples)``
            (TODO confirm).
        cfg: object providing ``sample_rate``, ``n_fft``, ``n_mel``,
            ``hop_size``, ``win_size``, ``fmin`` and ``fmax``.
        center (bool, optional): forwarded to ``torch.stft``. Defaults to False.

    Returns:
        tensor: log-mel spectrogram with ``n_mel`` rows per frame.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    device = str(y.device)
    # The former check `cfg.fmax not in mel_basis` compared a number against
    # string keys, so it never matched and everything was rebuilt per call.
    # Keying on all shaping parameters makes real caching safe.
    mel_key = (
        f"{cfg.fmax}_{cfg.fmin}_{cfg.n_mel}_{cfg.n_fft}_{cfg.sample_rate}_{device}"
    )
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=cfg.sample_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mel,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    window_key = f"{cfg.win_size}_{device}"
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(cfg.win_size).to(y.device)

    pad = int((cfg.n_fft - cfg.hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[window_key],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    # Trailing axis of view_as_real is (real, imag).
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec


# Module-level tensor caches used by the extractors in this file. They are
# filled lazily; the string keys include the device name, so each device
# holds its own copy of a tensor.
mel_basis: dict = dict()  # mel filterbank matrices
hann_window: dict = dict()  # Hann analysis windows


def extract_mel_features(
    y,
    cfg,
    center=False,
):
    """Extract mel features

    The signal is reflect-padded by ``(n_fft - hop_size) / 2`` samples per
    side. It is then transformed with ``torch.stft`` using a Hann window and
    converted to magnitude (with a 1e-9 floor inside the square root). The
    magnitude is projected onto a librosa mel filterbank and log-compressed
    with ``spectral_normalize_torch``.

    The filterbank and window are cached in the module-level ``mel_basis`` /
    ``hann_window`` dicts. Each key includes every parameter the cached tensor
    depends on, plus the device, so a cache hit is always valid for this call.

    Args:
        y (tensor): audio data in tensor. Values outside [-1, 1] are only
            reported with a print, not clipped.
        cfg (dict): configuration in cfg.preprocess; attributes read are
            ``sample_rate``, ``n_fft``, ``n_mel``, ``hop_size``, ``win_size``,
            ``fmin`` and ``fmax``.
        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.

    Returns:
        tensor: a tensor containing the mel feature calculated based on STFT
        result, with dimension 0 removed when it has size 1.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    device = str(y.device)
    # The former check `cfg.fmax not in mel_basis` compared a number against
    # string keys, so it never matched and everything was rebuilt per call.
    # Keying on all shaping parameters makes real caching safe.
    mel_key = (
        f"{cfg.fmax}_{cfg.fmin}_{cfg.n_mel}_{cfg.n_fft}_{cfg.sample_rate}_{device}"
    )
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=cfg.sample_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mel,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    window_key = f"{cfg.win_size}_{device}"
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(cfg.win_size).to(y.device)

    pad = int((cfg.n_fft - cfg.hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[window_key],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec.squeeze(0)


def extract_mel_features_tts(
    y,
    cfg,
    center=False,
    taco=False,
    _stft=None,
):
    """Extract mel features

    Two paths are supported:

    * ``taco=False``: reflect-pad, ``torch.stft`` with a Hann window,
      magnitude, librosa mel projection and log compression. The filterbank
      and window are cached in the module-level ``mel_basis`` /
      ``hann_window`` dicts, keyed on every parameter they depend on plus the
      device.
    * ``taco=True``: the audio is clipped to [-1, 1], detached from the
      autograd graph and passed to ``_stft.mel_spectrogram``, whose mel
      output is returned as is.

    Args:
        y (tensor): audio data in tensor
        cfg (dict): configuration in cfg.preprocess (read only when
            ``taco`` is False)
        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.
        taco (bool, optional): use tacotron mel, i.e. delegate to ``_stft``.
            Defaults to False.
        _stft (optional): object exposing ``mel_spectrogram(audio)`` that
            returns ``(mel, energy)``; required when ``taco`` is True.

    Returns:
        tensor: a tensor containing the mel feature, with dimension 0
        removed when it has size 1.

    Raises:
        ValueError: if ``taco`` is True and ``_stft`` is None.
    """
    global mel_basis, hann_window

    if taco:
        if _stft is None:
            raise ValueError("`_stft` must be provided when taco=True")
        audio = torch.clip(y, -1, 1)
        # Equivalent to the deprecated torch.autograd.Variable(audio, requires_grad=False).
        audio = audio.detach()
        spec, _energy = _stft.mel_spectrogram(audio)
    else:
        if torch.min(y) < -1.0:
            print("min value is ", torch.min(y))
        if torch.max(y) > 1.0:
            print("max value is ", torch.max(y))

        device = str(y.device)
        # Key on every parameter the cached tensors depend on; the former
        # `cfg.fmax not in mel_basis` test never matched the string keys.
        mel_key = (
            f"{cfg.fmax}_{cfg.fmin}_{cfg.n_mel}_{cfg.n_fft}_{cfg.sample_rate}_{device}"
        )
        if mel_key not in mel_basis:
            mel = librosa_mel_fn(
                sr=cfg.sample_rate,
                n_fft=cfg.n_fft,
                n_mels=cfg.n_mel,
                fmin=cfg.fmin,
                fmax=cfg.fmax,
            )
            mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        window_key = f"{cfg.win_size}_{device}"
        if window_key not in hann_window:
            hann_window[window_key] = torch.hann_window(cfg.win_size).to(y.device)

        pad = int((cfg.n_fft - cfg.hop_size) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            cfg.n_fft,
            hop_length=cfg.hop_size,
            win_length=cfg.win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)

    # Both branches already yield a mel spectrogram; it must not be projected
    # or log-compressed a second time.
    return spec.squeeze(0)


def amplitude_phase_spectrum(y, cfg):
    """Return log-amplitude, phase, and the real/imaginary STFT parts of ``y``.

    Args:
        y (tensor): audio samples. It is unsqueezed at dim 1 and reflect-padded
            on the last dim, so a 2-D ``(batch, samples)`` layout is presumably
            expected (TODO confirm).
        cfg: object providing ``n_fft``, ``hop_size`` and ``win_size``.

    Returns:
        tuple: ``(log_amplitude, phase, rea, imag)``. When the batch size is 1
        the batch axis is dropped and each tensor is ``[n_fft//2+1, frames]``;
        otherwise each is ``[batch_size, n_fft//2+1, frames]``.
    """
    window = torch.hann_window(cfg.win_size).to(y.device)

    side = int((cfg.n_fft - cfg.hop_size) / 2)
    padded = torch.nn.functional.pad(
        y.unsqueeze(1), (side, side), mode="reflect"
    ).squeeze(1)

    spectrum = torch.view_as_real(
        torch.stft(
            padded,
            cfg.n_fft,
            hop_length=cfg.hop_size,
            win_length=cfg.win_size,
            window=window,
            center=False,
            return_complex=True,
        )
    )
    # A singleton batch axis is removed so single-utterance inputs give 2-D outputs.
    if spectrum.size(0) == 1:
        spectrum = spectrum.squeeze(0)

    # The trailing axis of view_as_real holds (real, imag). `...` covers both
    # the batched [B, F, T, 2] and unbatched [F, T, 2] layouts.
    rea = spectrum[..., 0]
    imag = spectrum[..., 1]

    # abs() is a value no-op on sqrt's output but zeroes the gradient at an
    # exactly-zero magnitude, where sqrt's own gradient is infinite.
    magnitude = torch.abs(torch.sqrt(rea.pow(2) + imag.pow(2)))
    log_amplitude = torch.log(magnitude + 1e-5)
    phase = torch.atan2(imag, rea)

    return log_amplitude, phase, rea, imag