import io
import torch
import PIL.Image
import numpy as np
import scipy.signal
import librosa.display
import matplotlib.pyplot as plt

from torch.functional import Tensor
from torchvision.transforms import ToTensor


def compute_comparison_spectrogram(
    x: np.ndarray,
    y: np.ndarray,
    sample_rate: float = 44100,
    n_fft: int = 2048,
    hop_length: int = 1024,
) -> Tensor:
    X = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
    X_db = librosa.amplitude_to_db(np.abs(X), ref=np.max)

    Y = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    Y_db = librosa.amplitude_to_db(np.abs(Y), ref=np.max)

    fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
    img = librosa.display.specshow(
        X_db,
        ax=axs[0],
        hop_length=hop_length,
        x_axis="time",
        y_axis="log",
        sr=sample_rate,
    )
    # fig.colorbar(img, ax=axs[0])
    img = librosa.display.specshow(
        Y_db,
        ax=axs[1],
        hop_length=hop_length,
        x_axis="time",
        y_axis="log",
        sr=sample_rate,
    )
    # fig.colorbar(img, ax=axs[1])

    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="jpeg")
    buf.seek(0)
    image = PIL.Image.open(buf)
    image = ToTensor()(image)
    plt.close("all")

    return image


def plot_multi_spectrum(
    ys=None,
    Hs=None,
    legend=[],
    title="Spectrum",
    filename=None,
    sample_rate=44100,
    n_fft=1024,
    zero_mean=False,
):

    if Hs is None:
        Hs = []
        for y in ys:
            X = get_average_spectrum(y, n_fft)
            X_sm = smooth_spectrum(X)
            Hs.append(X_sm)

    bin_width = (sample_rate / 2) / (n_fft // 2)
    freqs = np.arange(0, (sample_rate / 2) + bin_width, step=bin_width)

    fig, ax1 = plt.subplots()

    for idx, H in enumerate(Hs):
        H = np.nan_to_num(H)
        H = np.clip(H, 0, np.max(H))
        H_dB = 20 * np.log10(H + 1e-8)
        if zero_mean:
            H_dB -= np.mean(H_dB)
        if "Target" in legend[idx]:
            ax1.plot(freqs, H_dB, linestyle="--", color="k")
        else:
            ax1.plot(freqs, H_dB)

    plt.legend(legend)

    ax1.set_xscale("log")
    ax1.set_ylim([-80, 0])
    ax1.set_xlim([100, 11000])
    plt.title(title)
    plt.ylabel("Magnitude (dB)")
    plt.xlabel("Frequency (Hz)")
    plt.grid(c="lightgray", which="both")

    if filename is not None:
        plt.savefig(f"{filename}.png", dpi=300)

    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="jpeg")
    buf.seek(0)
    image = PIL.Image.open(buf)
    image = ToTensor()(image)
    plt.close("all")

    return image


def smooth_spectrum(H):
    # apply Savgol filter for smoothed target curve
    return scipy.signal.savgol_filter(H, 1025, 2)


def get_average_spectrum(x, n_fft):
    X = torch.stft(x, n_fft, return_complex=True, normalized=True)
    X = X.abs()  # convert to magnitude
    X = X.mean(dim=-1).view(-1)  # average across frames
    return X