import numpy as np
import pandas as pd

import librosa

from pathlib import Path
from typing import Callable, Literal, Optional

def load_dataset(
    paths: list,
    remove_label: Optional[list] = None,
    sr: Optional[int] = 22050,
    method: Literal["fix_length", "time_stretch"] = "fix_length",
    max_time: float = 4.0):
    """Load a folder-structured audio dataset fully into memory.

    Every ``*.wav`` file found recursively under each directory of ``paths``
    is decoded with ``librosa.load`` and stored, together with a uniformized
    version of its signal, as one row of the returned DataFrame. The whole
    dataset is held in memory, so it must fit in RAM. This is intended for
    sklearn-style use. For pytorch, convert the extracted arrays to tensors
    on the fly.

    Use ``to_numpy(df.y_uniform)`` to extract a numpy matrix with a
    ``(n_row, ...)`` shape.

    Expected folder structure, with ``paths = [paths1, paths2, ...]``::

        - paths1
            - sub1
                - blabla_GroundTruth1.wav
                - blabla_GroundTruth2.wav
            - sub2
                - ...
        - ...

    Args:
        paths (list[str | Path]): dataset root directories to parse.
        remove_label (list, optional): ground-truth labels whose files are
            skipped. Defaults to ``None``, which means ``[""]`` (skip files
            whose ground truth is empty).
        sr (int, optional): sample rate every file is resampled to by
            ``librosa.load``. ``None`` keeps each file's native rate.
            Defaults to 22050.
        method (Literal["fix_length", "time_stretch"], optional):
            uniformization method passed to ``uniformize``.
            Defaults to "fix_length".
        max_time (float, optional): common duration, in seconds, of the
            uniformized signals. Defaults to 4.0.

    Returns:
        df (pd.DataFrame): one row per kept file, with columns:
            - absolute_path (pathlib.Path): absolute path of the .wav file.
            - labels (list[str]): the sub-directory names between the root
              and the file, followed by the ``_``-separated parts of the file
              name after the first one (e.g. ``sub1/blabla_GT1.wav`` gives
              ``["sub1", "GT1"]``).
            - ground_truth (str): last element of ``labels``, i.e. the last
              ``_``-separated part of the file name, or the parent directory
              name when the file name contains no ``_``. It is ``""`` for a
              file placed directly under the root with no ``_`` in its name.
            - y_original_signal (np.ndarray): signal as returned by
              ``librosa.load`` (mono, normalized to [-1, 1], resampled to
              ``sr``).
            - y_original_duration (float): duration of y_original_signal in
              seconds.
            - y_uniform (np.ndarray): ``uniform_transform`` applied to
              y_original_signal.
            The columns exist even when no file was loaded.
        uniform_transform (Callable[[np.ndarray, int], np.ndarray]): function
            ``(signal, sample_rate) -> signal`` that uniformizes new audio
            exactly like the ``y_uniform`` column.

    Raises:
        ValueError: if ``method`` is not a supported uniformization method.
    """
    # Fail fast, before any (potentially long) audio decoding.
    if method not in ("fix_length", "time_stretch"):
        raise ValueError(f"Unknown uniformization method: {method!r}")
    if remove_label is None:
        remove_label = [""]

    columns = [
        "absolute_path",
        "labels",
        "ground_truth",
        "y_original_signal",
        "y_original_duration",
        "y_uniform",
    ]
    data = []

    def uniform_transform(y: np.ndarray, sample_rate: int) -> np.ndarray:
        """Uniformize ``y`` with the ``method``/``max_time`` of this dataset."""
        return uniformize(y, sample_rate, method, max_time)

    for path in paths:
        root = Path(path).absolute()
        # sorted() makes the row order independent of the file system.
        for wav_file in sorted(root.rglob("*.wav")):
            absolute_path = wav_file.absolute()
            *sub_dirs, _ = absolute_path.relative_to(root).parts
            # "blabla_GroundTruth1" -> ["blabla", "GroundTruth1"]; everything
            # after the first "_" is a label.
            name_parts = wav_file.stem.split("_")
            labels = [*sub_dirs, *name_parts[1:]]
            ground_truth = labels[-1] if labels else ""
            if ground_truth in remove_label:
                continue
            # librosa.load resamples to `sr`, normalizes samples to [-1, 1]
            # and downmixes stereo to mono. The returned rate goes into a
            # separate variable so `sr` stays untouched for the next file.
            y_original, file_sr = librosa.load(path=absolute_path, sr=sr)
            data.append({
                "absolute_path": absolute_path,
                "labels": labels,
                "ground_truth": ground_truth,
                "y_original_signal": y_original,
                "y_original_duration": librosa.get_duration(
                    y=y_original, sr=file_sr),
                "y_uniform": uniform_transform(y_original, file_sr),
            })
    df = pd.DataFrame(data, columns=columns)
    return df, uniform_transform

def uniformize(
        audio: np.ndarray,
        sr: int,
        method: Literal["fix_length", "time_stretch"] = "fix_length",
        max_time: float = 4.0) -> np.ndarray:
    """Bring an audio signal to a common duration of ``max_time`` seconds.

    Args:
        audio (np.ndarray): audio signal, e.g. as returned by ``librosa.load``.
        sr (int): sample rate of ``audio`` in Hz.
        method (Literal["fix_length", "time_stretch"], optional):
            ``"fix_length"`` zero-pads or truncates the signal;
            ``"time_stretch"`` changes its speed with
            ``librosa.effects.time_stretch`` so that it lasts ``max_time``
            seconds. Defaults to "fix_length".
        max_time (float, optional): target duration in seconds.
            Defaults to 4.0.

    Returns:
        np.ndarray: the uniformized signal, with ``ceil(max_time * sr)``
        samples for both methods.

    Raises:
        ValueError: if ``method`` is not one of the supported methods.
    """
    if method not in ("fix_length", "time_stretch"):
        raise ValueError(f"Unknown uniformization method: {method!r}")
    target_size = int(np.ceil(max_time * sr))
    if method == "time_stretch":
        duration = librosa.get_duration(y=audio, sr=sr)
        # An empty signal would give rate == 0; it is left as is and
        # zero-padded below, exactly like with "fix_length".
        if duration > 0:
            # rate > 1 speeds the signal up: `duration` seconds become
            # `max_time` seconds.
            audio = librosa.effects.time_stretch(audio, rate=duration / max_time)
    # time_stretch's output length is derived from len(audio) / rate and may
    # differ slightly from target_size; fix_length makes both methods return
    # exactly the same number of samples.
    return librosa.util.fix_length(audio, size=target_size)

def to_numpy(ds: pd.Series) -> np.ndarray:
    """Stack a pd.Series (i.e. a column slice) of same-shaped cells into numpy.

    Each cell is flattened, so cells of shape ``s`` give an array of shape
    ``(n_row, prod(s))``. Scalar cells give a 1-D array of shape ``(n_row,)``.

    Args:
        ds (pd.Series): column to transform, e.g. ``df.y_uniform``. All cells
            must have the same shape (required by ``np.stack``).

    Returns:
        np.ndarray: the stacked values; an empty 1-D array when ``ds`` is
        empty.
    """
    if len(ds) == 0:
        # np.stack refuses an empty sequence.
        return np.empty((0,))
    stacked = np.stack(list(ds.to_numpy()))
    n_rows, *cell_shape = stacked.shape
    if cell_shape:
        return stacked.reshape(n_rows, int(np.prod(cell_shape)))
    return stacked.reshape(n_rows)