import json
import shutil
import subprocess
import urllib.parse
from pathlib import Path

import gradio as gr
from huggingface_hub import hf_hub_download, HfApi
from coremltools import ComputeUnit
from coremltools.models.utils import _is_macos, _macos_version

from transformers.onnx.utils import get_preprocessor

from exporters.coreml import export
from exporters.coreml.features import FeaturesManager
from exporters.coreml.validate import validate_model_outputs

# "Compute Units" radio label -> coremltools ComputeUnit passed to the exporter.
compute_units_mapping = {
    "All": ComputeUnit.ALL,
    "CPU": ComputeUnit.CPU_ONLY,
    "CPU + GPU": ComputeUnit.CPU_AND_GPU,
    "CPU + NE": ComputeUnit.CPU_AND_NE,
}
# Radio choices, in the mapping's insertion order.
compute_units_labels = [*compute_units_mapping]

# "Framework" radio label -> framework id handed to FeaturesManager.
framework_mapping = {"PyTorch": "pt", "TensorFlow": "tf"}
framework_labels = [*framework_mapping]

# "Precision" radio label -> value passed as `quantize` to the Core ML export.
precision_mapping = {"Float32": "float32", "Float16 quantization": "float16"}
precision_labels = [*precision_mapping]

# Validation tolerance label -> absolute tolerance; None defers to the
# Core ML config's own `atol_for_validation`.
tolerance_mapping = {"Model default": None, "1e-2": 1e-2, "1e-3": 1e-3, "1e-4": 1e-4}
tolerance_labels = [*tolerance_mapping]

# "Destination Model" radio label -> push mode ("pr": PR to the source repo,
# "new": upload to a user-named repo).
push_mapping = {"Submit a PR to the original repo": "pr", "Create a new repo": "new"}
push_labels = [*push_mapping]

def error_str(error, title="Error", model=None, task=None, framework=None, compute_units=None, precision=None, tolerance=None, destination=None):
    """Render a Markdown error report for the output panel.

    Returns "" when `error` is falsy. Otherwise the Markdown shows `title`,
    the error text, and a link that opens a pre-filled discussion on the
    Space. The link's title and description are URL-quoted and list the
    conversion settings passed in.
    """
    if error:
        quoted_title = urllib.parse.quote(f"Error converting {model}")
        settings_text = f"""Conversion Settings:

        Model: {model}
        Task: {task}
        Framework: {framework}
        Compute Units: {compute_units}
        Precision: {precision}
        Tolerance: {tolerance}
        Push to: {destination}

        Error: {error}
        """
        quoted_description = urllib.parse.quote(settings_text)
        issue_url = f"https://huggingface.co/spaces/pcuenq/transformers-to-coreml/discussions/new?title={quoted_title}&description={quoted_description}"
        return f"""
        #### {title}
        {error}

        It could be that the model is not yet compatible with the Core ML exporter. Please, open a discussion on the [Hugging Face Hub]({issue_url}) to report this issue.
        """
    return ""

def url_to_model_id(model_id_str):
    """Normalise user input (a Hub repo id or a huggingface.co URL) to a repo id.

    Anything that is not a huggingface.co URL is returned unchanged apart from
    surrounding whitespace. For URLs the first one or two path segments are
    kept, so all of these map to "org/model":

        https://huggingface.co/org/model
        https://huggingface.co/org/model/
        https://huggingface.co/org/model/tree/main
        https://huggingface.co/org/model?foo=bar

    An org-less repo such as https://huggingface.co/gpt2/blob/main/config.json
    maps to "gpt2".
    """
    # Second path segments that are Hub page sections, not part of a repo name.
    hub_sections = {"tree", "blob", "resolve", "raw", "edit", "discussions", "commits", "commit"}

    model_id_str = model_id_str.strip()
    if not model_id_str.startswith(("https://huggingface.co/", "http://huggingface.co/")):
        return model_id_str

    # urlsplit drops any query string / fragment; filtering empties handles
    # trailing and doubled slashes.
    parts = [part for part in urllib.parse.urlsplit(model_id_str).path.split("/") if part]
    if not parts:
        return ""
    if len(parts) >= 2 and parts[1] not in hub_sections:
        return f"{parts[0]}/{parts[1]}"
    return parts[0]

def get_pr_url(api, repo_id, title):
    """Return the URL of the first open pull request on `repo_id` titled `title`.

    Returns None when listing discussions raises, or when no open pull
    request with exactly that title exists.
    """
    try:
        discussions = api.get_repo_discussions(repo_id=repo_id)
    except Exception:
        # Best effort: the caller still reports success without a PR link.
        return None

    matching_pr = next(
        (
            item
            for item in discussions
            if item.status == "open" and item.is_pull_request and item.title == title
        ),
        None,
    )
    if matching_pr is None:
        return None
    return f"https://huggingface.co/{repo_id}/discussions/{matching_pr.num}"
        
def supported_frameworks(model_id):
    """
    Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
    Only PyTorch and Tensorflow are supported.

    Frameworks are detected from the repo's Hub tags ("pytorch" / "tf"). The
    result is sorted, so "PyTorch" comes before "TensorFlow". A repo without
    tags yields an empty list.
    """
    tag_to_label = {"pytorch": "PyTorch", "tf": "TensorFlow"}
    # `tags` can be None on the returned ModelInfo; treat that as "no tags".
    tags = HfApi().model_info(model_id).tags or []
    return sorted(label for tag, label in tag_to_label.items() if tag in tags)

def on_model_change(model):
    """Load `model`'s config and return UI updates for the settings panel.

    Returns a 4-tuple of `gr.update`s, in this order:
        (settings column, tasks radio, framework radio, error markdown).

    On success the tasks radio lists the tasks supported for the model's
    `model_type`. The framework radio is only shown when more than one
    framework is available.

    On failure the same 4 outputs are returned: the settings panel is hidden
    and the error is rendered via `error_str`. The event is wired to exactly
    four outputs, so returning nothing would make Gradio fail instead of
    showing the error.
    """
    model = url_to_model_id(model)

    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")

        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_type = config["model_type"]

        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())

        frameworks = supported_frameworks(model)
        selected_framework = frameworks[0] if frameworks else None
    except Exception as e:
        return (
            gr.update(visible=False),                                   # Settings column
            gr.update(choices=[], value=None),                          # Tasks
            gr.update(visible=False),                                   # Frameworks
            gr.update(value=error_str(e, model=model)),                 # Error
        )

    return (
        gr.update(visible=bool(model_type)),                                                    # Settings column
        gr.update(choices=tasks, value=tasks[0] if tasks else None),                            # Tasks
        gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),   # Frameworks
        gr.update(value=""),                                                                    # Error (cleared)
    )


def convert_model(preprocessor, model, model_coreml_config,
                  compute_units, precision, tolerance, output,
                  use_past=False, seq2seq=None,
                  progress=None, progress_start=0.1, progress_end=0.8):
    """Export `model` (or one half of a seq2seq model) to Core ML and save it.

    Args:
        preprocessor, model: the loaded transformers preprocessor and model.
        model_coreml_config: Core ML config class. It is instantiated with
            `model.config`, `use_past` and `seq2seq`.
        compute_units, precision: forwarded to `export`
            (`precision` is passed as `quantize`).
        tolerance: absolute validation tolerance. None means the config's
            `atol_for_validation` is used.
        output: target `.mlpackage` path. For `seq2seq="encoder"` / `"decoder"`
            the file name gets an "encoder_" / "decoder_" prefix.
        progress: callable `(fraction, desc=...)`, e.g. a `gr.Progress`.
            When None, progress reporting is skipped.
        progress_start, progress_end: fraction range reported for this export.

    Validation only runs on macOS 12 or newer.
    """
    if progress is None:
        def progress(*_args, **_kwargs):
            """No-op stand-in used when no progress tracker is supplied."""

    coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)

    model_label = "model" if seq2seq is None else seq2seq
    progress(progress_start, desc=f"Converting {model_label}")
    mlmodel = export(
        preprocessor,
        model,
        coreml_config,
        quantize=precision,
        compute_units=compute_units,
    )

    output = Path(output)
    if seq2seq in ("encoder", "decoder"):
        output = output.with_name(f"{seq2seq}_{output.name}")
    mlmodel.save(output.as_posix())

    if _is_macos() and _macos_version() >= (12, 0):
        # Report validation 80% of the way through *this* export's range, so it
        # never goes below `progress_start` whatever range the caller picks.
        validation_point = progress_start + 0.8 * (progress_end - progress_start)
        progress(validation_point, desc=f"Validating {model_label}")
        if tolerance is None:
            tolerance = coreml_config.atol_for_validation
        validate_model_outputs(coreml_config, preprocessor, model, mlmodel, tolerance)
    progress(progress_end, desc=f"Done converting {model_label}")


def push_to_hub(destination, directory, task, precision, token=None):
    """Upload `directory` to the `destination` Hub repo as a pull request.

    The repo is created if needed (`exist_ok=True`). The local `directory`
    is always removed afterwards, even if the upload fails, so leftovers are
    never swept into a later upload of the same folder.

    Returns the URL of the matching open PR as found by `get_pr_url`,
    or None if it could not be found.
    """
    api = HfApi(token=token)
    commit_message = "Add Core ML conversion"
    try:
        api.create_repo(destination, token=token, exist_ok=True)
        api.upload_folder(
            folder_path=directory,
            repo_id=destination,
            token=token,
            create_pr=True,
            commit_message=commit_message,
            commit_description=f"Core ML conversion, task={task}, precision={precision}",
        )
    finally:
        # Equivalent of `rm -rf` without spawning a process.
        shutil.rmtree(directory, ignore_errors=True)

    return get_pr_url(api, destination, commit_message)


def _export_base_dir(model_id):
    """Return `exported/<model_id>`, refusing ids that could point outside `exported/`.

    Hub repo ids are `name` or `namespace/name`. Empty segments (including a
    leading "/"), "." / ".." segments, backslashes and extra slashes are
    rejected, because this directory is created and later deleted
    recursively.
    """
    parts = model_id.split("/")
    if not 1 <= len(parts) <= 2 or any(part in ("", ".", "..") for part in parts) or "\\" in model_id:
        raise ValueError(f"Invalid model id: {model_id!r}")
    return Path("exported") / model_id


def convert(model_id, task,
            compute_units, precision, tolerance, framework,
            push_destination, destination_model, token,
            progress=gr.Progress()):
    """Convert `model_id` to Core ML for `task` and push the result to the Hub.

    All arguments except `model_id`, `task`, `destination_model`, `token` and
    `progress` are UI labels that are translated through the module-level
    mappings.

    In "pr" mode the PR targets the original repo and the user token is not
    used. The export is written under `exported/<model_id>/coreml/<task>/`;
    that directory is removed whether the run succeeds or fails.

    Returns Markdown for the output panel: a success message with the PR link,
    or an `error_str` report.
    """
    model_id = url_to_model_id(model_id)
    compute_units = compute_units_mapping[compute_units]
    precision = precision_mapping[precision]
    tolerance = tolerance_mapping[tolerance]
    framework = framework_mapping[framework]
    push_destination = push_mapping[push_destination]
    if push_destination == "pr":
        destination_model = model_id
        token = None

    base = None
    try:
        # TODO: support legacy format
        # Built inside the try so a bad id or unselected task is reported to the user.
        base = _export_base_dir(model_id)
        output = base / "coreml" / task
        output.mkdir(parents=True, exist_ok=True)
        output = output / f"{precision}_model.mlpackage"

        progress(0, desc="Downloading model")

        preprocessor = get_preprocessor(model_id)
        model = FeaturesManager.get_model_from_feature(task, model_id, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(model, feature=task)

        # (seq2seq part, progress start, progress end) for each export.
        if task in ["seq2seq-lm", "speech-seq2seq"]:
            parts = [("encoder", 0.1, 0.4), ("decoder", 0.4, 0.7)]
        else:
            parts = [(None, 0.1, 0.7)]
        for seq2seq, start, end in parts:
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq=seq2seq,
                progress=progress,
                progress_start=start,
                progress_end=end,
            )

        progress(0.7, "Uploading model to Hub")
        pr_url = push_to_hub(destination_model, base, task, precision, token=token)
        # push_to_hub may not find the PR; link to the discussions page instead of "None".
        pr_url = pr_url or f"https://huggingface.co/{destination_model}/discussions"
        progress(1, "Done")

        did_validate = _is_macos() and _macos_version() >= (12, 0)
        result = f"""### Successfully converted!
        We opened a PR to add the Core ML weights to the model repo. Please, view and merge the PR [here]({pr_url}).

        {f"**Note**: model could not be automatically validated as this Space is not running on macOS." if not did_validate else ""}
        """

        return result
    except Exception as e:
        # Drop partial exports; otherwise the next successful run for this model
        # would upload them together with its own files.
        if base is not None:
            shutil.rmtree(base, ignore_errors=True)
        return error_str(e, model=model_id, task=task, framework=framework, compute_units=compute_units, precision=precision, tolerance=tolerance, destination=destination_model)

# Markdown rendered at the top of the Space.
DESCRIPTION = (
    "\n"
    "## Convert a transformers model to Core ML\n"
    "\n"
    "With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.\n"
    "\n"
    "Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).\n"
    "\n"
    "After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.\n"
)

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1. Load model info")
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name or URL, such as apple/mobilevit-small",
                placeholder="pcuenq/distilbert-base-uncased",
            )
            btn_get_tasks = gr.Button("Load")
        with gr.Column(scale=3):
            with gr.Column(visible=False) as group_settings:
                gr.Markdown("## 2. Select Task")
                radio_tasks = gr.Radio(label="Choose the task for the converted model.")
                gr.Markdown("The `default` task is suitable for feature extraction.")
                radio_framework = gr.Radio(
                    visible=False,
                    label="Framework",
                    choices=framework_labels,
                    value=framework_labels[0],
                )
                radio_compute = gr.Radio(
                    label="Compute Units",
                    choices=compute_units_labels,
                    value=compute_units_labels[0],
                )
                radio_precision = gr.Radio(
                    label="Precision",
                    choices=precision_labels,
                    value=precision_labels[0],
                )
                radio_tolerance = gr.Radio(
                    label="Absolute Tolerance for Validation",
                    choices=tolerance_labels,
                    value=tolerance_labels[0],
                )

                radio_push = gr.Radio(
                    label="Destination Model",
                    choices=push_labels,
                    value=push_labels[0],
                )
                with gr.Row(visible=False) as row_destination:
                    # TODO: public/private
                    text_destination = gr.Textbox(label="Destination model name", value="")
                    # Password field so the write token is not displayed on screen.
                    text_token = gr.Textbox(label="Token (write permissions)", value="", type="password")

                btn_convert = gr.Button("Convert")
                gr.Markdown("Conversion will take a few minutes.")


    error_output = gr.Markdown(label="Output")

    # Must match the order of the 4-tuple returned by on_model_change.
    model_info_outputs = [group_settings, radio_tasks, radio_framework, error_output]

    # Clear output. These events have no inputs, so the callback must take no
    # arguments. queue=False runs them immediately instead of behind queued
    # conversions.
    btn_get_tasks.click(lambda: gr.update(value=""), [], [error_output], queue=False)
    btn_convert.click(lambda: gr.update(value=""), [], [error_output], queue=False)
    input_model.submit(lambda: gr.update(value=""), [], [error_output], queue=False)

    input_model.submit(
        fn=on_model_change,
        inputs=input_model,
        outputs=model_info_outputs,
        queue=False,
        scroll_to_output=True
    )
    btn_get_tasks.click(
        fn=on_model_change,
        inputs=input_model,
        outputs=model_info_outputs,
        queue=False,
        scroll_to_output=True
    )

    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework, radio_push, text_destination, text_token],
        outputs=error_output,
        scroll_to_output=True
    )

    # Show the destination name / token fields only when pushing to a new repo.
    radio_push.change(
        lambda choice: gr.update(visible=push_mapping.get(choice) == "new"),
        inputs=radio_push,
        outputs=row_destination,
        queue=False,
        scroll_to_output=True
    )

    gr.HTML("""
    <div style="border-top: 0.5px solid #303030;">
      <br>
      <p style="color:gray;font-size:smaller;font-style:italic">Adapted from https://huggingface.co/spaces/diffusers/sd-to-diffusers/tree/main</p><br>
    </div>
    """)

demo.queue(concurrency_count=1, max_size=10)

if __name__ == "__main__":
    demo.launch(debug=True, share=False)