#!/usr/bin/env python3
# -*- coding: utf-8 -*-
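"""
Gradio demo for telephone audio classification.

Downloads zipped model packages (TorchScript trace + vocabulary) from the
Hugging Face Hub, lets the user pick a model and an audio clip, and reports
the predicted label and its probability.
"""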
import argparse
from functools import lru_cache
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile
from typing import Tuple

import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
import torch

from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary


def get_args():
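    """Parse command-line arguments; some defaults come from the environment."""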
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/vm_sound_classification",
        type=str
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )

    args = parser.parse_args()
    return args


@lru_cache(maxsize=100)
def load_model(model_file: Path):
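    """
    Extract a zipped model package and load its TorchScript model and vocabulary.

    The archive is expected to contain a top-level directory named after the
    file stem, holding `trace_model.zip` (a torch.jit trace) and a `vocabulary`
    directory readable by `Vocabulary.from_files`. Results are cached by
    `lru_cache`, so each model is extracted and loaded at most once per process.
    """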
    # extract the zipped package into a fresh temporary directory
    out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
    if out_root.exists():
        shutil.rmtree(out_root.as_posix())
    out_root.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(model_file, "r") as f_zip:
        f_zip.extractall(path=out_root)

    tgt_path = out_root / model_file.stem
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    model = torch.jit.load(jit_model_file.as_posix())
    model.eval()

    # the extracted files are no longer needed once model and vocabulary are in memory
    shutil.rmtree(tgt_path)

    d = {
        "model": model,
        "vocabulary": vocabulary
    }
    return d


def click_button(audio: Tuple[int, np.ndarray],
                 model_name: str,
                 ground_truth: str) -> Tuple[str, float]:
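    """
    Classify one audio clip with the selected model.

    `ground_truth` is accepted only so the examples table can display it; it
    does not affect the prediction. Returns (predicted label, probability).
    """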

    sample_rate, signal = audio

    # models are downloaded into <project>/trained_models by main()
    model_file = project_path / "trained_models" / "{}.zip".format(model_name)
    d = load_model(model_file)

    model = d["model"]
    vocabulary = d["vocabulary"]

    # normalize int16 PCM to float32 in [-1, 1) and add a batch dimension
    inputs = signal / (1 << 15)
    inputs = torch.tensor(inputs, dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)

    with torch.no_grad():
        logits = model(inputs)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    label_idx = label_idx.cpu().numpy()[0]
    prob = probs.cpu().numpy()[0][label_idx]

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")

    # cast the numpy scalar to a plain python float before rounding
    return label_str, round(float(prob), 4)


def main():
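    """Download models and examples on first run, then build and launch the UI."""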
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download the models repo on first run (skipped if the directory already exists)
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # unpack the bundled example audio files
    example_zip_file = trained_model_dir / "examples.zip"
    with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
        out_root = examples_dir
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # collect model names from the downloaded zips (examples.zip is not a model)
    model_choices = list()
    for filename in trained_model_dir.glob("*.zip"):
        model_name = filename.stem
        if model_name == "examples":
            continue
        model_choices.append(model_name)
    model_choices = sorted(model_choices)

    # build example rows; each wav file sits in a subdirectory named after its label
    examples = list()
    for filename in examples_dir.glob("**/*/*.wav"):
        label = filename.parts[-2]

        examples.append([
            filename.as_posix(),
            model_choices[0],
            label
        ])

    # description shown at the top of the page
    brief_description = """
International intelligent outbound-call system: telephone audio classification, 8000 Hz sample rate, int16 samples.
"""

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=3):
                c_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    with gr.Column(scale=3):
                        c_ground_truth = gr.Textbox(label="ground_truth")

                c_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                c_label = gr.Textbox(label="label")
                c_probability = gr.Number(label="probability")

        gr.Examples(
            examples,
            inputs=[c_audio, c_model_name, c_ground_truth],
            outputs=[c_label, c_probability],
            fn=click_button,
            examples_per_page=5,
        )

        c_button.click(
            click_button,
            inputs=[c_audio, c_model_name, c_ground_truth],
            outputs=[c_label, c_probability],
        )

    # served at http://127.0.0.1:<server_port>/ (default 7860)
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return


if __name__ == "__main__":
    main()