# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import os
import pickle
from tqdm import tqdm
import numpy as np

from modules import whisper_extractor as whisper


def whisper_encoder_batch(model, audio_paths):
    """Encode several audio files with the Whisper audio encoder in one pass.

    Each file is loaded, padded or trimmed by ``whisper.pad_or_trim``,
    converted to a log-mel spectrogram and written into one row of a
    pre-allocated ``(len(audio_paths), 80, 3000)`` float32 tensor on the
    model's device.

    Args:
        model: Whisper model exposing ``device`` and ``embed_audio``.
        audio_paths: Sequence of audio file paths (anything ``str()`` accepts).

    Returns:
        numpy array with the encoder output for the whole batch. The original
        note gives the shape as (batch, 1500, 1024); the last dimension
        presumably depends on the Whisper model size -- TODO confirm.
    """
    device = model.device
    mel_batch = torch.zeros(
        (len(audio_paths), 80, 3000), dtype=torch.float32, device=device
    )

    for row, path in enumerate(audio_paths):
        waveform = whisper.pad_or_trim(whisper.load_audio(str(path)))
        # One (80, 3000) log-mel spectrogram per utterance.
        mel_batch[row] = whisper.log_mel_spectrogram(waveform).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        encoded = model.embed_audio(mel_batch)

    return encoded.cpu().detach().numpy()


def whisper_encoder(model, audio_path):
    """Encode a single audio file with the Whisper audio encoder.

    Args:
        model: Whisper model exposing ``device`` and ``embed_audio``.
        audio_path: Path of the audio file (anything ``str()`` accepts).

    Returns:
        numpy array with the encoder output for this utterance, with the
        leading batch dimension removed. The original note gives the shape
        as (1500, 1024); the last dimension presumably depends on the
        Whisper model size -- TODO confirm.
    """
    waveform = whisper.load_audio(str(audio_path))
    waveform = whisper.pad_or_trim(waveform)

    # Log-mel spectrogram, moved to the model's device and given a batch axis.
    mel = whisper.log_mel_spectrogram(waveform)
    mel_batch = mel.to(model.device).unsqueeze(0)

    with torch.no_grad():
        encoded = model.embed_audio(mel_batch)
        # Drop the batch axis that was added above.
        encoded = encoded.squeeze(0)

    return encoded.cpu().detach().numpy()


def get_mapped_whisper_features(
    raw_whisper_features, mapping_features, fast_mapping=True
):
    """Resample Whisper features onto the frame grid of the target features.

    Whisper: frameshift = 20ms (30s audio -> 1500 frames), hop_size = 480 in 24k
    # Ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/model.py#L136

    Currently only used for mapping to bigvgan's mels
    (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms).

    The resampling repeats every Whisper frame ``source_hop / gcd`` times and
    then averages consecutive groups of ``target_hop / gcd`` rows, i.e. with
    480 and 256 every 8 Whisper frames become 15 target frames.

    Args:
        raw_whisper_features: Indexable collection of per-utterance arrays
            whose first axis is time and last axis is the feature dimension.
        mapping_features: Iterable of per-utterance target features; only
            ``shape[0]`` (the target frame count) is used.
        fast_mapping: When True, only the leading Whisper frames needed to
            cover the target length are processed; otherwise 1500 frames are
            assumed.

    Returns:
        list of arrays, one per utterance, each with
        ``min(target_len, 2812)`` rows.
    """
    whisper_hop, mel_hop = 480, 256

    # Reduce the hop ratio so the repeat/average factors stay small (15 : 8).
    common = np.gcd(whisper_hop, mel_hop)
    src_step = whisper_hop // common
    tgt_step = mel_hop // common
    print(
        "Mapping source's {} frames => target's {} frames".format(tgt_step, src_step)
    )

    whisper_frame_cap = 1500
    # Largest number of target frames that 1500 Whisper frames can cover (2812).
    target_frame_cap = whisper_frame_cap * src_step // tgt_step

    mapped = []
    for idx, reference in enumerate(tqdm(mapping_features)):
        # reference: (target_frame_len, feature_dim)
        n_target = min(reference.shape[0], target_frame_cap)

        frames = raw_whisper_features[idx]
        dim = frames.shape[-1]

        if not fast_mapping:
            n_source = whisper_frame_cap
        else:
            # Just enough Whisper frames to cover n_target target frames.
            n_source = n_target * tgt_step // src_step + 1
            frames = frames[:n_source]

        # Largest multiple of tgt_step not exceeding n_source * src_step, so
        # that the repeated rows split evenly into groups of tgt_step.
        usable = n_source * src_step // tgt_step * tgt_step

        # (n_source * src_step, dim) -> (usable, dim)
        repeated = np.repeat(frames, src_step, axis=0)[:usable]
        # (usable / tgt_step, tgt_step, dim) -> (usable / tgt_step, dim)
        pooled = np.average(repeated.reshape(-1, tgt_step, dim), axis=1)
        assert len(pooled) >= n_target

        mapped.append(pooled[:n_target])

    return mapped


def load_whisper_model(hps):
    """Load the Whisper model named by ``hps.whisper_model`` for inference.

    The model is moved to CUDA when ``torch.cuda.is_available()`` and is
    returned in eval mode.

    Args:
        hps: Config object with a ``whisper_model`` attribute.

    Returns:
        The loaded model after ``.eval()``.
    """
    model_name = hps.whisper_model
    print("Loading Whisper Model: ", model_name)

    model = whisper.load_model(model_name)
    if torch.cuda.is_available():
        model = model.cuda()

    return model.eval()


def load_target_acoustic_features(
    output_path, dataset, acoustic_features_name, acoustic_features_fs, dataset_type
):
    """Load the pickled acoustic features that Whisper features are mapped to.

    The file read is
    ``<output_path>/<dataset>/<acoustic_features_name>/<acoustic_features_fs>/<dataset_type>.pkl``.

    Args:
        output_path: Root directory of the preprocessed data.
        dataset: Dataset sub-directory name.
        acoustic_features_name: Feature sub-directory name. When it equals
            ``"mels"`` every loaded feature is transposed.
        acoustic_features_fs: Second-level sub-directory name (formatted with
            ``str.format``, so non-string values are accepted).
        dataset_type: Base name of the pickle file (without ``.pkl``).

    Returns:
        The unpickled collection of features. For ``"mels"`` it is a list of
        transposed arrays, (n_mels, frame_len) -> (frame_len, n_mels).

    Raises:
        OSError: If the pickle file cannot be opened.
    """
    mapping_dir = os.path.join(
        output_path,
        dataset,
        "{}/{}".format(acoustic_features_name, acoustic_features_fs),
    )
    pkl_path = os.path.join(mapping_dir, "{}.pkl".format(dataset_type))

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at feature files produced by a trusted preprocessing run.
    with open(pkl_path, "rb") as f:
        mapping_features = pickle.load(f)

    # Mels: (n_mels, frame_len) -> (frame_len, n_mels)
    if acoustic_features_name == "mels":
        print("Transposing mel features...")
        mapping_features = [feat.T for feat in mapping_features]

    # An empty split must not crash the loader just because of the log line.
    # len() is used instead of truthiness so array-like containers also work.
    first_shape = mapping_features[0].shape if len(mapping_features) > 0 else None
    print(
        "Mapping to the acoustic features {}, #sz = {}, feats[0] is {}".format(
            acoustic_features_name, len(mapping_features), first_shape
        )
    )
    return mapping_features


def extract_whisper_features_of_dataset(
    datasets,
    model,
    batch_size,
    out_dir,
):
    """Extract raw Whisper encoder features for a dataset and save them to disk.

    Utterances are encoded in batches with ``whisper_encoder_batch`` and each
    utterance's features are written to ``<out_dir>/<Uid>.npy``.

    Args:
        datasets: Sliceable sequence of utterance dicts with ``"Path"`` (audio
            file) and ``"Uid"`` (output file stem) keys.
        model: Whisper model passed through to ``whisper_encoder_batch``.
        batch_size: Maximum number of utterances encoded at once; clamped to
            the dataset size.
        out_dir: Directory receiving the ``.npy`` files.

    Raises:
        ValueError: If ``batch_size`` is not positive while there is at least
            one utterance to process (the loop could never make progress).
    """
    audio_paths = [utt["Path"] for utt in datasets]
    total = len(audio_paths)
    if total == 0:
        # Nothing to encode.
        return

    if batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size)
        )
    batch_size = min(batch_size, total)

    for start in range(0, total, batch_size):
        # Clamp so the last, possibly partial, batch is reported correctly.
        end = min(start + batch_size, total)

        # Raw features for this batch; first axis indexes the utterances.
        batch_features = whisper_encoder_batch(model, audio_paths[start:end])

        # Save the raw (unmapped) features of every utterance in the batch.
        for utt, raw_whisper_feature in zip(
            tqdm(datasets[start:end]), batch_features
        ):
            save_path = os.path.join(out_dir, utt["Uid"] + ".npy")
            np.save(save_path, raw_whisper_feature)

        print("{}/{} Done...".format(end, total))