import argparse
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple
import torch

from huggingface_hub import (
    CommitOperationAdd,
    HfApi,
    get_repo_discussions,
    hf_hub_download,
)
from huggingface_hub.file_download import repo_folder_name
from optimum.exporters.onnx import validate_model_outputs
from optimum.exporters.tasks import TasksManager
from transformers import AutoConfig, AutoTokenizer, is_torch_available
from optimum.intel.openvino import (
    OVModelForAudioClassification,
    OVModelForCausalLM,
    OVModelForFeatureExtraction,
    OVModelForImageClassification,
    OVModelForMaskedLM,
    OVModelForQuestionAnswering,
    OVModelForSeq2SeqLM,
    OVModelForSequenceClassification,
    OVModelForTokenClassification,
    OVStableDiffusionPipeline,
)
from optimum.intel.utils.constant import _TASK_ALIASES
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS

# Hugging Face Space URL that the exporter bot links to in the description of the PRs it opens.
SPACES_URL = "https://huggingface.co/spaces/echarlaix/openvino-export"


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Find an already-open pull request on ``model_id`` titled ``pr_title``.

    Returns the first matching discussion, or ``None`` when there is no match or
    when listing the repository discussions fails.
    """
    try:
        discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        return None
    # Iteration happens outside the try block, exactly as before: errors raised
    # while walking the discussions propagate to the caller.
    matches = (
        candidate
        for candidate in discussions
        if candidate.status == "open"
        and candidate.is_pull_request
        and candidate.title == pr_title
    )
    return next(matches, None)


def convert_openvino(model_id: str, task: str, folder: str) -> List:
    """Export ``model_id`` to OpenVINO in ``folder`` and build the upload operations.

    Args:
        model_id: Hub repository id of the model to export.
        task: Task name; aliases are resolved through ``_TASK_ALIASES``.
        folder: Local directory the exported model is saved to.

    Returns:
        A list of ``CommitOperationAdd`` covering every exported file whose name
        contains ``"openvino"``, at the top level of ``folder`` and inside the
        Stable Diffusion component sub-folders.

    Raises:
        ValueError: if the task is unsupported, is ``text2text-generation``, or the
            exported model's outputs differ from the original model's outputs.
    """
    task = _TASK_ALIASES.get(task, task)
    if task not in _HEAD_TO_AUTOMODELS:
        raise ValueError(f"The task '{task}' is not supported, only {_HEAD_TO_AUTOMODELS.keys()} tasks are supported")

    if task == "text2text-generation":
        raise ValueError("Export of Seq2Seq models is currently disabled.")

    # The table holds class *names*, which eval resolves against this module's
    # imports. Only a value from the library's own table is evaluated (task was
    # checked against its keys above), never raw user input.
    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
    ov_model = auto_model_class.from_pretrained(model_id, export=True)
    ov_model.save_pretrained(folder)
    # Stable Diffusion pipelines are uploaded without the output comparison.
    if not isinstance(ov_model, OVStableDiffusionPipeline):
        _validate_outputs(ov_model, model_id, task)

    return _collect_commit_operations(folder)


def _validate_outputs(ov_model, model_id: str, task: str, atol: float = 1e-3) -> None:
    """Raise ValueError if ``ov_model`` disagrees with the PyTorch model on dummy inputs.

    Only outputs present in both results as tensors are compared; anything else
    (e.g. nested tuples) is skipped.
    """
    model = TasksManager.get_model_from_task(task, model_id)
    exporter_config_class = TasksManager.get_exporter_config_constructor(
        exporter="openvino",
        model=model,
        task=task,
        model_name=model_id,
        model_type=model.config.model_type.replace("_", "-"),
    )
    openvino_config = exporter_config_class(model.config)
    inputs = openvino_config.generate_dummy_inputs(framework="pt")
    ov_outputs = ov_model(**inputs)
    # Reference pass only; no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    for output_name in ov_outputs:
        # The previous check tested `isinstance(outputs, torch.Tensor)`, which is
        # never true for a model output container, so nothing was ever compared.
        expected = outputs.get(output_name) if hasattr(outputs, "get") else None
        actual = ov_outputs[output_name]
        if not (isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor)):
            continue
        if expected.shape != actual.shape or not torch.allclose(
            expected, actual.to(expected.dtype), atol=atol
        ):
            raise ValueError(
                "The exported model does not have the same outputs as the original model. Export interrupted."
            )


def _collect_commit_operations(folder: str) -> List:
    """Build one ``CommitOperationAdd`` per exported file with "openvino" in its name.

    Top-level files of ``folder`` go to the repository root; files inside the
    Stable Diffusion component sub-folders keep their sub-folder prefix.
    """
    sd_subfolders = {"vae_encoder", "vae_decoder", "text_encoder", "unet"}
    operations = []
    # Sorted so the commit lists its files in a deterministic order.
    for entry in sorted(os.listdir(folder)):
        entry_path = os.path.join(folder, entry)
        if os.path.isfile(entry_path):
            if "openvino" in entry:
                operations.append(CommitOperationAdd(path_in_repo=entry, path_or_fileobj=entry_path))
        elif entry in sd_subfolders:
            for file_name in sorted(os.listdir(entry_path)):
                file_path = os.path.join(entry_path, file_name)
                if "openvino" in file_name and os.path.isfile(file_path):
                    # Hub paths are POSIX-style; os.path.join would use "\\" on Windows.
                    operations.append(
                        CommitOperationAdd(path_in_repo=f"{entry}/{file_name}", path_or_fileobj=file_path)
                    )
    return operations


def convert(
    api: "HfApi",
    model_id: str,
    task: str,
    force: bool = False,
) -> Tuple[str, Optional["CommitInfo"]]:
    """Export ``model_id`` to OpenVINO and open a pull request adding the files.

    Args:
        api: Hub client; the user returned by ``api.whoami()`` is credited in the
            PR description.
        model_id: Hub repository id of the model to convert.
        task: Task name, or ``"auto"`` to infer it with ``TasksManager``.
        force: Convert and open a PR even when ``openvino_model.xml`` is already
            in the repository or an open PR with the same title already exists.

    Returns:
        ``("0", result of api.create_commit)`` on success. When ``task`` is
        ``"auto"`` and inference fails, ``(error_message, None)``.

    Raises:
        Exception: if the model is already converted or already has an open PR
            (and ``force`` is False), plus anything raised by the export or upload.
    """
    pr_title = "Adding OpenVINO file of this model"
    info = api.model_info(model_id)
    filenames = {s.rfilename for s in info.siblings}

    requesting_user = api.whoami()["name"]

    if task == "auto":
        try:
            task = TasksManager.infer_task_from_model(model_id)
        except Exception as e:
            return (
                f"### Error: {e}. Please pass explicitely the task as it could not be infered.",
                None,
            )

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        try:
            pr = previous_pr(api, model_id, pr_title)
            if "openvino_model.xml" in filenames and not force:
                raise Exception(f"Model {model_id} is already converted, skipping..")
            if pr is not None and not force:
                url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
                raise Exception(
                    f"Model {model_id} already has an open PR check out [{url}]({url})"
                )

            operations = convert_openvino(model_id, task, folder)

            # Built without source indentation: the previous indented triple-quoted
            # string put 16 leading spaces on the text, which Markdown renders as
            # an indented code block (links not clickable).
            commit_description = (
                f"Beep boop I am the [OpenVINO exporter bot 🤖]({SPACES_URL}). "
                f"On behalf of [{requesting_user}](https://huggingface.co/{requesting_user}), "
                "I would like to add to this repository the exported OpenVINO model."
            )
            new_pr = api.create_commit(
                repo_id=model_id,
                operations=operations,
                commit_message=pr_title,
                commit_description=commit_description,
                create_pr=True,
            )
        finally:
            shutil.rmtree(folder)
        return "0", new_pr