# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import glob
import json
import os
import shutil
import sys
import time

import torch
from tqdm import tqdm

from models.svc.diffusion.diffusion_inference import DiffusionInference
from models.svc.comosvc.comosvc_inference import ComoSVCInference
from models.svc.transformer.transformer_inference import TransformerInference
from utils.util import load_config
from utils.audio_slicer import split_audio, merge_segments_encodec
from processors import acoustic_extractor, content_extractor


def build_inference(args, cfg, infer_type="from_dataset"):
    """Instantiate the inference wrapper selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded to the constructor.
        cfg: Loaded configuration; ``cfg.model_type`` picks the class.
        infer_type: Inference mode forwarded as-is to the constructor
            (this file uses ``"from_dataset"`` and ``"from_file"``).

    Returns:
        A ``DiffusionInference``, ``ComoSVCInference`` or
        ``TransformerInference`` instance.

    Raises:
        KeyError: If ``cfg.model_type`` is not a supported model name.
    """
    registry = dict(
        DiffWaveNetSVC=DiffusionInference,
        DiffComoSVC=ComoSVCInference,
        TransformerSVC=TransformerInference,
    )
    model_cls = registry[cfg.model_type]
    return model_cls(args, cfg, infer_type)


def prepare_for_audio_file(args, cfg, num_workers=1):
    """Split one source audio into segments and extract its features.

    Builds the ``eval.json`` metadata for the audio named by
    ``cfg.inference.source_audio_name``. Then it runs acoustic-feature
    extraction (plus mel min/max and pitch statistics) and content-feature
    extraction over those segments, printing how long each stage took.

    Args:
        args: Parsed CLI namespace. ``args.source`` is overwritten with the
            path of the generated ``eval.json``.
        cfg: Loaded configuration; reads ``cfg.preprocess.processed_dir`` and
            the ``cfg.inference.source_audio_*`` / ``segments_*`` fields.
        num_workers: Worker count forwarded to the content-feature extractor.

    Returns:
        Tuple ``(args, cfg, temp_audio_dir)``. ``temp_audio_dir`` is
        ``<processed_dir>/<audio_name>``, the per-audio cache directory.
    """
    preprocess_path = cfg.preprocess.processed_dir
    audio_name = cfg.inference.source_audio_name
    temp_audio_dir = os.path.join(preprocess_path, audio_name)

    ### eval file
    t = time.time()
    eval_file = prepare_source_eval_file(cfg, temp_audio_dir, audio_name)
    args.source = eval_file
    # The eval file is written as UTF-8 (ensure_ascii=False), so it must be
    # read back as UTF-8 too; the platform default may differ.
    with open(eval_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    print("Prepare for meta eval data: {:.1f}s".format(time.time() - t))

    ### acoustic features
    t = time.time()
    acoustic_extractor.extract_utt_acoustic_features_serial(
        metadata, temp_audio_dir, cfg
    )
    acoustic_extractor.cal_mel_min_max(
        dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
    )
    acoustic_extractor.cal_pitch_statistics_svc(
        dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
    )
    print("Prepare for acoustic features: {:.1f}s".format(time.time() - t))

    ### content features
    t = time.time()
    content_extractor.extract_utt_content_features_dataloader(
        cfg, metadata, num_workers
    )
    print("Prepare for content features: {:.1f}s".format(time.time() - t))
    return args, cfg, temp_audio_dir


def merge_for_audio_segments(audio_files, args, cfg):
    """Merge converted segment wavs into one file, then delete the segments.

    The merged result is written to
    ``<args.output_dir>/<source_audio_name>_<target_singer>.wav``, using
    ``cfg.inference.segments_overlap_duration`` as the segment overlap.

    Args:
        audio_files: Paths of the converted segment wav files. Each one is
            removed after the merge.
        args: Parsed CLI namespace (``output_dir``, ``target_singer``).
        cfg: Loaded configuration (sample rate, source name, overlap).
    """
    audio_name = cfg.inference.source_audio_name
    target_singer_name = args.target_singer

    merge_segments_encodec(
        wav_files=audio_files,
        fs=cfg.preprocess.sample_rate,
        output_path=os.path.join(
            args.output_dir, "{}_{}.wav".format(audio_name, target_singer_name)
        ),
        overlap_duration=cfg.inference.segments_overlap_duration,
    )

    for tmp_file in audio_files:
        os.remove(tmp_file)


def prepare_source_eval_file(cfg, temp_audio_dir, audio_name):
    """Split the source audio and write the eval metadata (JSON) for it.

    The audio at ``cfg.inference.source_audio_path`` is split by
    ``split_audio`` into ``<temp_audio_dir>/wavs``. Each result dict
    returned by ``split_audio`` is then tagged with:

    - ``index``: its position in the list
    - ``Dataset``: ``audio_name``
    - ``Singer``: ``audio_name``
    - ``Uid``: the original ``Uid``, prefixed with ``audio_name``

    Args:
        cfg: Loaded configuration (source path, sample rate, segment sizes).
        temp_audio_dir: Per-audio cache directory; ``eval.json`` goes here.
        audio_name: Name used for the dataset/singer fields and Uid prefix.

    Returns:
        Path of the written ``eval.json`` file.
    """
    audio_chunks_results = split_audio(
        wav_file=cfg.inference.source_audio_path,
        target_sr=cfg.preprocess.sample_rate,
        output_dir=os.path.join(temp_audio_dir, "wavs"),
        max_duration_of_segment=cfg.inference.segments_max_duration,
        overlap_duration=cfg.inference.segments_overlap_duration,
    )

    metadata = []
    for i, res in enumerate(audio_chunks_results):
        res["index"] = i
        res["Dataset"] = audio_name
        res["Singer"] = audio_name
        res["Uid"] = "{}_{}".format(audio_name, res["Uid"])
        metadata.append(res)

    # Defensive: make sure the cache directory exists before writing into it
    # (presumably split_audio already created it via the wavs/ subdir).
    os.makedirs(temp_audio_dir, exist_ok=True)
    eval_file = os.path.join(temp_audio_dir, "eval.json")
    # ensure_ascii=False emits raw non-ASCII characters, so pin the file
    # encoding instead of relying on the platform's locale default.
    with open(eval_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4, ensure_ascii=False, sort_keys=True)

    return eval_file


def cuda_relevant(deterministic=False):
    """Configure global CUDA / cuDNN backend flags before running inference.

    Releases cached CUDA memory, enables TF32 math for matmuls and cuDNN
    (TF32 is available on Ampere GPUs and newer), and switches determinism
    on or off.

    Args:
        deterministic: If True, force deterministic cuDNN kernels and
            deterministic PyTorch algorithms, and turn cuDNN autotuning
            (``benchmark``) off. If False, the reverse.
    """
    torch.cuda.empty_cache()

    cudnn = torch.backends.cudnn

    # Allow TensorFloat-32 for matmul and cuDNN ops.
    torch.backends.cuda.matmul.allow_tf32 = True
    cudnn.enabled = True
    cudnn.allow_tf32 = True

    # Determinism and autotuning are toggled in opposite directions.
    autotune = not deterministic
    cudnn.deterministic = deterministic
    cudnn.benchmark = autotune
    torch.use_deterministic_algorithms(deterministic)


def infer(args, cfg, infer_type):
    """Build the inference object for ``cfg`` and run it, printing timings.

    Args:
        args: Parsed CLI namespace, forwarded to ``build_inference``.
        cfg: Loaded configuration, forwarded to ``build_inference``.
        infer_type: Inference mode string, forwarded to ``build_inference``.

    Returns:
        Whatever the inference object's ``inference()`` method returns. In
        the "from_file" path the caller treats it as a list of segment audio
        file paths.
    """
    start = time.time()
    inferencer = build_inference(args, cfg, infer_type)
    print("Model Init: {:.1f}s".format(time.time() - start))

    start = time.time()
    outputs = inferencer.inference()
    print("Model inference: {:.1f}s".format(time.time() - start))
    return outputs


def build_parser():
    r"""Build argument parser for inference.py.
    Anything else should be put in an extra config YAML file.

    Returns:
        An ``argparse.ArgumentParser``. ``--config``, ``--vocoder_dir`` and
        ``--target_singer`` are required. Caches are kept by default; pass
        ``--no_keep_cache`` to delete them after each file is converted.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--acoustics_dir",
        type=str,
        help="Acoustics model checkpoint directory. If a directory is given, "
        "search for the latest checkpoint dir in the directory. If a specific "
        "checkpoint dir is given, directly load the checkpoint.",
    )
    parser.add_argument(
        "--vocoder_dir",
        type=str,
        required=True,
        help="Vocoder checkpoint directory. Searching behavior is the same as "
        "the acoustics one.",
    )
    parser.add_argument(
        "--target_singer",
        type=str,
        required=True,
        help="convert to a specific singer (e.g. --target_singer singer_id).",
    )
    parser.add_argument(
        "--trans_key",
        default=0,
        help="0: no pitch shift; autoshift: pitch shift;  int: key shift.",
    )
    parser.add_argument(
        "--source",
        type=str,
        default="source_audio",
        help="Source audio file or directory. If a JSON file is given, "
        "inference from dataset is applied. If a directory is given, "
        "inference from all wav/flac/mp3 audio files in the directory is applied. "
        "Default: inference from all wav/flac/mp3 audio files in ./source_audio",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="conversion_results",
        help="Output directory. Default: ./conversion_results",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="warning",
        help="Logging level. Default: warning",
    )
    # store_true with default=True can never become False on its own; it is
    # kept for backward compatibility and paired with --no_keep_cache below.
    parser.add_argument(
        "--keep_cache",
        action="store_true",
        default=True,
        help="Keep cache files. Only applicable to inference from files.",
    )
    parser.add_argument(
        "--no_keep_cache",
        dest="keep_cache",
        action="store_false",
        help="Remove cache files after conversion. "
        "Only applicable to inference from files.",
    )
    parser.add_argument(
        "--diffusion_inference_steps",
        type=int,
        default=1000,
        help="Number of inference steps. Only applicable to diffusion inference.",
    )
    return parser


def main(args_list=None):
    """Parse CLI arguments and run singing voice conversion.

    If ``--source`` is a directory, every ``.wav``/``.flac``/``.mp3`` file
    under it (searched recursively) is converted independently. For each
    file:

    1. Split it into segments and extract features into a per-file cache
       directory.
    2. Run the model in ``"from_file"`` mode.
    3. Merge the converted segments under ``<output_dir>/<audio_name>/``.

    Otherwise ``--source`` is passed to the model in ``"from_dataset"`` mode.

    Args:
        args_list: Argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.
    """
    ### Parse arguments and config
    args = build_parser().parse_args(args_list)
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    if not os.path.isdir(args.source):
        ### Infer from dataset
        infer(args, cfg, infer_type="from_dataset")
        return

    ### Infer from file

    # Get all the source audio files (.wav, .flac, .mp3); sorted so the
    # processing order is deterministic across platforms.
    source_audio_dir = args.source
    audio_list = []
    for suffix in ["wav", "flac", "mp3"]:
        audio_list += glob.glob(
            os.path.join(source_audio_dir, "**/*.{}".format(suffix)), recursive=True
        )
    audio_list.sort()
    print("There are {} source audios: ".format(len(audio_list)))

    # Infer for every file as dataset
    output_root_path = args.output_dir
    for audio_path in tqdm(audio_list):
        # basename/splitext handle OS-specific separators, and keep dotted
        # stems intact ("a.v1.wav" -> "a.v1") so different files don't collide.
        audio_name = os.path.splitext(os.path.basename(audio_path))[0]
        args.output_dir = os.path.join(output_root_path, audio_name)
        print("\n{}\nConversion for {}...\n".format("*" * 10, audio_name))

        cfg.inference.source_audio_path = audio_path
        cfg.inference.source_audio_name = audio_name
        cfg.inference.segments_max_duration = 10.0
        cfg.inference.segments_overlap_duration = 1.0

        # Prepare metadata and features
        args, cfg, cache_dir = prepare_for_audio_file(args, cfg)

        # Infer from file
        output_audio_files = infer(args, cfg, infer_type="from_file")

        # Merge the split segments
        merge_for_audio_segments(output_audio_files, args, cfg)

        # Keep or remove caches. The cache directory holds the split wavs,
        # eval.json and extracted features, so it is never empty:
        # os.removedirs would fail; remove the whole tree instead.
        if not args.keep_cache:
            shutil.rmtree(cache_dir)


if __name__ == "__main__":
    # main() takes the argument list positionally; forward the command-line
    # arguments (without the program name) explicitly.
    main(sys.argv[1:])