# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of audio with a simple description.
"""

from dataclasses import dataclass, fields, replace
import json
from pathlib import Path
import random
import typing as tp

import numpy as np
import torch

from .info_audio_dataset import (
    InfoAudioDataset,
    get_keyword_or_keyword_list
)
from ..modules.conditioners import (
    ConditioningAttributes,
    SegmentWithAttributes,
    WavCondition,
)


# Smallest float32 increment; added to denominators below to avoid division by zero.
EPS: float = torch.finfo(torch.float32).eps
# Bounds (in dBFS) passed to np.random.randint when snr_mixer picks the RMS level of the
# final mixture; randint excludes the upper bound, so levels fall in [-35, -15).
TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER = -35, -15


@dataclass
class SoundInfo(SegmentWithAttributes):
    """Segment info extended with sound metadata.

    Attributes are inherited from ``SegmentWithAttributes`` plus:
        description: textual description of the segment, or None when unknown.
        self_wav: waveform conditioning for the segment; not read from metadata
            dictionaries (see ``from_dict``) and assigned after construction.
    """
    description: tp.Optional[str] = None
    self_wav: tp.Optional[torch.Tensor] = None

    @property
    def has_sound_meta(self) -> bool:
        # A description is the only sound-specific metadata checked here.
        return self.description is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Route every dataclass field into conditioning attributes.

        ``self_wav`` is stored under the wav attributes; every other field
        (including inherited ones) is stored under the text attributes.
        """
        attributes = ConditioningAttributes()

        for field_def in fields(self):
            name = field_def.name
            value = getattr(self, name)
            target = attributes.wav if name == 'self_wav' else attributes.text
            target[name] = value
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing callable for ``attribute``, or None if it is used as-is."""
        return get_keyword_or_keyword_list if attribute == 'description' else None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build an instance from a flat dictionary of field values.

        Args:
            dictionary (dict): Mapping from field name to raw value; extra keys are ignored.
            fields_required (bool): When True, every field except the post-init ones
                must be present in ``dictionary``.
        Raises:
            KeyError: If ``fields_required`` and a field is missing.
        """
        # These fields are never read from the dictionary; they may be populated later.
        post_init_attributes = ('self_wav',)
        init_kwargs: tp.Dict[str, tp.Any] = {}

        for field_def in fields(cls):
            name = field_def.name
            if name in post_init_attributes:
                continue
            if name not in dictionary:
                if fields_required:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            preprocess: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            raw_value = dictionary[name]
            init_kwargs[name] = preprocess(raw_value) if preprocess else raw_value
        return cls(**init_kwargs)


class SoundDataset(InfoAudioDataset):
    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
            When aug_p > 0, collated batches are subsampled to this proportion of items whether or not
            mixing is actually applied to a given batch (see `mix_samples`).
        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        if self.aug_p > 0:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Get path of JSON with metadata (description, etc.).

        If there exists a JSON next to the audio file (same stem, '.json' suffix), it is used.
        Otherwise a JSON with that file name is searched for in `external_metadata_source`, if set.

        Raises:
            FileNotFoundError: If no metadata JSON can be found. (Subclass of Exception,
                so callers catching the previous generic Exception keep working.)
        """
        info_path = Path(path).with_suffix('.json')
        if info_path.exists():
            return info_path
        if self.external_metadata_source:
            external_path = Path(self.external_metadata_source) / info_path.name
            if external_path.exists():
                return external_path
        raise FileNotFoundError(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        """Return `(wav, SoundInfo)` for `index`, with metadata merged from the segment's JSON file.

        The returned info's `self_wav` is a WavCondition wrapping `wav[None]`.
        """
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        info_path = self._get_info_path(info.meta.path)
        if Path(info_path).exists():
            # Decode explicitly as UTF-8 so non-ASCII descriptions do not depend on the
            # platform's default text encoding.
            with open(info_path, 'r', encoding='utf-8') as json_file:
                sound_data = json.load(json_file)
            # Segment info fields override any same-named keys coming from the JSON.
            sound_data.update(info_data)
            sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
            # if there are multiple descriptions, sample one randomly
            if isinstance(sound_info.description, list):
                sound_info.description = random.choice(sound_info.description)
        else:
            # Not reachable with this class's _get_info_path (it raises when nothing exists);
            # kept as a fallback for subclasses that override it to return a missing path.
            sound_info = SoundInfo.from_dict(info_data, fields_required=False)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        """Collate samples, applying audio mixing augmentation when `aug_p > 0`."""
        # when training, audio mixing is performed in the collate function
        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
        if self.aug_p > 0:
            wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
                                          snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
                                          min_overlap=self.mix_min_overlap)
        return wav, sound_info


def rms_f(x: torch.Tensor) -> torch.Tensor:
    """Root-mean-square of ``x`` computed over dimension 1 (that dimension is reduced away)."""
    mean_power = torch.mean(torch.pow(x, 2), dim=1)
    return torch.pow(mean_power, 0.5)


def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
    """Rescale each row of ``audio`` so its RMS (over dim 1) equals ``target_level`` dB re. 1.0.

    EPS is added to the measured RMS so an all-zero row does not divide by zero.
    """
    target_rms = 10 ** (target_level / 20)
    gain = target_rms / (rms_f(audio) + EPS)
    return audio * gain.unsqueeze(1)


def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Boolean flag per row (dim 0..) telling whether any |sample| along dim 1 exceeds the threshold."""
    over_threshold = audio.abs() > clipping_threshold
    return over_threshold.any(1)


def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
    """Add ``dst`` onto ``src`` starting at a random offset along dimension 1.

    The offset is drawn uniformly (Python ``random``) from ``[0, int(T * (1 - min_overlap))]``
    where ``T = src.shape[1]``, so at least ``T * min_overlap`` samples of ``src`` follow the
    offset. ``dst`` is truncated if it would run past the end of ``src``.

    Args:
        src (torch.Tensor): Base signal, indexed as [B, T]. Left unmodified.
        dst (torch.Tensor): Signal added on top of ``src``, indexed as [B, T'].
        min_overlap (float): Minimum fraction of ``src`` available after the offset.
    Returns:
        torch.Tensor: A new tensor shaped like ``src`` containing the mix.
    """
    start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
    # Work on a copy: the previous version wrote into the caller's tensor as a side effect.
    mixed = src.clone()
    # Number of dst samples that fit between the offset and the end of src.
    span = min(dst.shape[1], src.shape[1] - start)
    mixed[:, start:start + span] = mixed[:, start:start + span] + dst[:, :span]
    return mixed


def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Function to mix clean speech and noise at various SNR levels.

    Steps: match ``noise`` length to ``clean`` (zero-pad or truncate); peak-normalize each
    source and bring it to ``target_level`` dBFS RMS; scale the noise for the requested SNR;
    add it at a random offset (``mix_pair``); rescale the mixture to a random RMS level in
    ``[TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)`` dBFS; finally peak-limit any row whose
    magnitude exceeds ``clipping_threshold``.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    if clean.shape[1] > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
    else:
        noise = noise[:, :clean.shape[1]]

    # Peak-normalize each row by its maximum *absolute* value, then normalize to target_level dBFS.
    clean = clean / (clean.abs().amax(dim=1, keepdim=True) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.abs().amax(dim=1, keepdim=True) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # Randomly select an RMS level in [TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER) dBFS (one value for the
    # whole batch; randint excludes the upper bound) and normalize the mixture to it. Clipping may
    # occasionally result; it is handled just below.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy

    # Rows whose magnitude exceeds clipping_threshold are rescaled so their peak |amplitude| sits just
    # below it. The peak must be taken over absolute values: using max() before abs() picks the largest
    # *positive* sample, which for negative-dominated rows is smaller than the true peak and would
    # amplify (rather than attenuate) the clipped signal.
    clipped = is_clipped(noisyspeech, clipping_threshold)
    if clipped.any():
        noisyspeech_maxamplevel = noisyspeech[clipped].abs().amax(dim=1, keepdim=True) / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech


def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
    """Mix ``dst`` into ``src`` with ``snr_mixer`` at a randomly drawn SNR.

    The SNR is exactly ``snr_low`` when both bounds are equal (no random draw is made);
    otherwise it comes from ``np.random.randint(snr_low, snr_high)``, which excludes ``snr_high``.
    """
    snr = snr_low if snr_low == snr_high else np.random.randint(snr_low, snr_high)
    return snr_mixer(src, dst, snr, min_overlap)


def mix_text(src_text: tp.Optional[str], dst_text: tp.Optional[str]):
    """Mix text from different sources by concatenating them.

    Identical texts are returned once; otherwise they are joined with a single space.
    A missing (None) or empty text is dropped instead of concatenated, so mixing with an
    undescribed sample (``SoundInfo.description`` is Optional) neither raises TypeError
    nor leaves stray whitespace.
    """
    if src_text == dst_text:
        return src_text
    if not src_text:
        return dst_text
    if not dst_text:
        return src_text
    return src_text + " " + dst_text


def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
                snr_low: int, snr_high: int, min_overlap: float):
    """Mix samples within a batch, summing the waveforms and concatenating the text infos.

    With probability ``aug_p`` the batch is replaced by ``k`` mixtures of randomly paired items;
    otherwise ``k`` items are randomly subsampled, so the output batch size is the same either way.
    ``k = max(1, int(mix_p * B))``.

    Args:
        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
        aug_p (float): Augmentation probability.
        mix_p (float): Proportion of items in the batch to mix (and merge) together.
        snr_low (int): Lowerbound for sampling SNR.
        snr_high (int): Upperbound for sampling SNR.
        min_overlap (float): Minimum overlap between mixed samples.
    Returns:
        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
            and mixed SoundInfo for the given batch.
    """
    # no mixing to perform within the batch
    if mix_p == 0:
        return wavs, infos

    batch_size = wavs.shape[0]
    # Keep at least one item: int(mix_p * B) rounds down to 0 for small batches, which used to
    # trip the "empty batch" assertion below or silently return an empty batch.
    k = max(1, int(mix_p * batch_size))

    if random.uniform(0, 1) < aug_p:
        # perform all augmentations on waveforms as [B, T]
        # randomly picking pairs of audio to mix
        assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
        wavs = wavs.mean(dim=1, keepdim=False)
        mixed_sources_idx = torch.randperm(batch_size)[:k]
        mixed_targets_idx = torch.randperm(batch_size)[:k]
        aug_wavs = snr_mix(
            wavs[mixed_sources_idx],
            wavs[mixed_targets_idx],
            snr_low,
            snr_high,
            min_overlap,
        )
        # mixing textual descriptions in metadata
        descriptions = [info.description for info in infos]
        aug_infos = []
        for i, j in zip(mixed_sources_idx, mixed_targets_idx):
            text = mix_text(descriptions[i], descriptions[j])
            # NOTE(review): replace() copies every other field from infos[i], including self_wav,
            # which still holds the un-mixed source waveform — confirm this is intended.
            m = replace(infos[i])
            m.description = text
            aug_infos.append(m)

        # back to [B, C, T]
        aug_wavs = aug_wavs.unsqueeze(1)
        assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
        assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
        assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"

        return aug_wavs, aug_infos  # [B, C, T]
    else:
        # randomly pick samples in the batch to match
        # the batch size when performing audio mixing
        wav_idx = torch.randperm(batch_size)[:k]
        wavs = wavs[wav_idx]
        infos = [infos[i] for i in wav_idx]
        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"

        return wavs, infos  # [B, C, T]