import numpy as np
import torch as t
import models.utils.dist_adapter as dist
import soundfile
import librosa
from models.utils.dist_utils import print_once

class DefaultSTFTValues:
    """Fixed STFT settings; only the sample rate is taken from ``hps``.

    Attributes set on the instance:
        sr: copied from ``hps.sr``.
        n_fft: 2048.
        hop_length: 256.
        window_size: six hops, i.e. 1536.
    """

    def __init__(self, hps):
        hop = 256
        self.sr = hps.sr
        self.n_fft = 2048
        self.hop_length = hop
        # The analysis window spans six hops.
        self.window_size = hop * 6

class STFTValues:
    """Caller-supplied STFT settings paired with the sample rate from ``hps``.

    Only ``hps.sr`` is read; ``n_fft``, ``hop_length`` and ``window_size`` are
    stored unchanged under attributes of the same name.
    """

    def __init__(self, hps, n_fft, hop_length, window_size):
        self.sr = hps.sr
        self.n_fft, self.hop_length, self.window_size = n_fft, hop_length, window_size

def calculate_bandwidth(dataset, hps, duration=600):
    """Estimate amplitude and spectrogram scale statistics over part of ``dataset``.

    Items are read starting at this process's rank and stepping by
    ``max(16, world_size)`` until at least ``dataset.sr * duration`` values
    have been accumulated. Every element of an item is counted, so channels
    count towards that total.

    Args:
        dataset: Indexable object exposing ``sr``. Each item is either an
            array or a 2-element tuple/list whose first entry is the array.
            Items are averaged over axis 1 before the STFT, so they are
            presumably laid out as (time, channels) -- TODO confirm.
        hps: Hyperparameters; only ``hps.sr`` is read, to build the default
            STFT settings.
        duration: Multiplied by ``dataset.sr`` to get the number of values to
            accumulate before stopping.

    Returns:
        dict with keys
            ``l2``: mean of squares minus squared mean of the raw values,
            ``l1``: mean absolute value of the raw values,
            ``spec``: average, over the items read, of the Frobenius norm of
                each item's magnitude spectrogram.
    """
    stft_hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    # NOTE(review): nothing bounds idx by the dataset length; a dataset that is
    # too short for `duration` will fail inside dataset[idx].
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, _ = x
        samples = x.astype(np.float64)
        # n_fft is passed by keyword: librosa >= 0.10 rejects it positionally.
        stft = librosa.stft(np.mean(samples, axis=1),
                            n_fft=stft_hps.n_fft,
                            hop_length=stft_hps.hop_length,
                            win_length=stft_hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        # NOTE(review): the rest of this file imports from `models.utils.*`;
        # confirm whether this should be `models.utils.dist_utils` as well.
        from jukebox.utils.dist_utils import allreduce
        # Presumably combines each accumulator across ranks -- verify against
        # the allreduce implementation.
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2 = total_sq / n_seen - mean ** 2,
                     l1 = l1 / n_seen,
                     spec = spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth

def audio_preprocess(x, hps):
    """Preprocessing hook for audio; currently returns ``x`` untouched.

    This exists as an extra layer for experimenting with different
    preprocessing. ``hps`` is not read by the active code.
    """
    # Disabled experiment, kept for reference: blend a two-channel input into
    # mono (the standard blend is .5 left, .5 right) and restore a channel axis.
    #
    # x: NTC
    # x = x.float()
    # if x.shape[-1]==2:
    #     if hps.aug_blend:
    #         mix=t.rand((x.shape[0],1), device=x.device) #np.random.rand()
    #     else:
    #         mix = 0.5
    #     x=(mix*x[:,:,0]+(1-mix)*x[:,:,1])
    # elif x.shape[-1]==1:
    #     x=x[:,:,0]
    # else:
    #     assert False, f'Expected channels {hps.channels}. Got unknown {x.shape[-1]} channels'
    #
    # # x: NT -> NTC
    # x = x.unsqueeze(2)
    return x

def audio_postprocess(x, hps):
    """Postprocessing hook for audio; returns ``x`` unchanged and ignores ``hps``."""
    return x

def stft(sig, hps):
    """Short-time Fourier transform of ``sig`` using a Hann window.

    Args:
        sig: Real-valued tensor accepted by ``torch.stft``.
        hps: Object exposing ``n_fft``, ``hop_length`` and ``window_size``.

    Returns:
        Real tensor whose trailing dimension of size 2 holds the (real, imag)
        parts of each STFT coefficient.
    """
    window = t.hann_window(hps.window_size, device=sig.device)
    try:
        # Recent PyTorch requires return_complex to be given for real input.
        transformed = t.stft(sig, hps.n_fft, hps.hop_length,
                             win_length=hps.window_size, window=window,
                             return_complex=True)
    except TypeError:
        # Older PyTorch has no return_complex argument and already returns
        # the (real, imag) layout.
        return t.stft(sig, hps.n_fft, hps.hop_length,
                      win_length=hps.window_size, window=window)
    # Expose the complex result as (..., 2) so callers that reduce over the
    # last dimension keep working.
    return t.view_as_real(transformed)

def spec(x, hps):
    """Return the 2-norm of ``stft(x, hps)`` taken over its last dimension.

    When the last dimension holds (real, imag) pairs this is the magnitude
    spectrogram -- this assumes ``stft`` returns that layout.
    """
    transformed = stft(x, hps)
    return t.norm(transformed, p=2, dim=-1)

def norm(x):
    """Return the L2 norm of each batch item of ``x``.

    Everything after the first dimension is flattened, so the result has
    shape ``(x.shape[0],)``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose.
    flat = x.reshape(x.shape[0], -1)
    return (flat ** 2).sum(dim=-1).sqrt()

def squeeze(x):
    """Reduce ``x`` to two dimensions by averaging away a trailing channel axis.

    A 3-D input must have 1 or 2 entries in its last dimension (enforced with
    an assertion) and is averaged over that dimension. A 2-D input is
    returned as is.

    Raises:
        ValueError: if the result is not 2-D.
    """
    if len(x.shape) == 3:
        n_channels = x.shape[-1]
        assert n_channels in [1,2]
        x = t.mean(x, dim=-1)
    if len(x.shape) == 2:
        return x
    raise ValueError(f'Unknown input shape {x.shape}')

def spectral_loss(x_in, x_out, hps):
    """Per-item L2 distance between the spectrograms of ``x_in`` and ``x_out``.

    Both inputs are cast to float, passed through ``squeeze`` and transformed
    with ``spec`` using the default STFT settings; only ``hps.sr`` is read
    from ``hps``. The difference is reduced with ``norm``.
    """
    stft_hps = DefaultSTFTValues(hps)
    magnitudes = [spec(squeeze(signal.float()), stft_hps) for signal in (x_in, x_out)]
    return norm(magnitudes[0] - magnitudes[1])

def multispectral_loss(x_in, x_out, hps):
    """Average the spectral L2 distance over several STFT configurations.

    The configurations come from the parallel sequences
    ``hps.multispec_loss_n_fft``, ``hps.multispec_loss_hop_length`` and
    ``hps.multispec_loss_window_size``, which must have equal length
    (asserted). For each configuration the spectrograms of both inputs are
    compared with ``norm``; the result is the mean of those values.
    """
    n_ffts = hps.multispec_loss_n_fft
    hop_lengths = hps.multispec_loss_hop_length
    window_sizes = hps.multispec_loss_window_size
    assert len(n_ffts) == len(hop_lengths) == len(window_sizes)

    per_config = []
    for n_fft, hop_length, window_size in zip(n_ffts, hop_lengths, window_sizes):
        # A separate name keeps the caller's hps intact between iterations.
        stft_hps = STFTValues(hps, n_fft, hop_length, window_size)
        magnitude_in = spec(squeeze(x_in.float()), stft_hps)
        magnitude_out = spec(squeeze(x_out.float()), stft_hps)
        per_config.append(norm(magnitude_in - magnitude_out))
    return sum(per_config) / len(per_config)

def spectral_convergence(x_in, x_out, hps, epsilon=2e-3):
    """Spectrogram residual norm relative to the norm of the input spectrogram.

    Spectrograms use the default STFT settings. Items whose input spectrogram
    norm is not above ``epsilon`` contribute 0, and the denominator is
    clamped to at least ``epsilon``.
    """
    stft_hps = DefaultSTFTValues(hps)
    target = spec(squeeze(x_in.float()), stft_hps)
    estimate = spec(squeeze(x_out.float()), stft_hps)

    target_norm = norm(target)
    error_norm = norm(target - estimate)
    # Zero out items whose reference is too quiet to give a meaningful ratio.
    keep = (target_norm > epsilon).float()
    return (error_norm * keep) / target_norm.clamp(min=epsilon)

def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4):
    """Mean absolute difference between the log spectrograms of the inputs.

    Spectrograms use the default STFT settings; ``epsilon`` is added before
    taking the logarithm.
    """
    stft_hps = DefaultSTFTValues(hps)
    log_in = (spec(squeeze(x_in.float()), stft_hps) + epsilon).log()
    log_out = (spec(squeeze(x_out.float()), stft_hps) + epsilon).log()
    return (log_in - log_out).abs().mean()

def load_audio(file, sr, offset, duration, mono=False):
    """Load part of an audio file with librosa and return a 2-D array.

    ``offset`` and ``duration`` are divided by ``sr`` before being handed to
    ``librosa.load``, i.e. they are given here in samples at rate ``sr``.
    A 1-D result is given a leading axis of length 1.
    """
    # Librosa loads more filetypes than soundfile
    audio, _ = librosa.load(file, sr=sr, mono=mono,
                            offset=offset / sr, duration=duration / sr)
    if audio.ndim == 1:
        audio = audio[np.newaxis, :]
    return audio


def save_wav(fname, aud, sr):
    """Write each entry along the first axis of ``aud`` to its own WAV file.

    Values are clamped to [-1, 1] and moved to the CPU first. Entry ``i`` is
    written to ``{fname}/item_{i}.wav`` at sample rate ``sr``, so ``fname`` is
    used as a directory path.
    """
    clipped = t.clamp(aud, -1, 1).cpu().numpy()
    n_items = clipped.shape[0]
    for i in range(n_items):
        soundfile.write(f'{fname}/item_{i}.wav', clipped[i], samplerate=sr, format='wav')