# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from pathlib import Path
from typing import Optional, Union

import numpy as np

from .download_utils import ppdiffusers_bos_download
from .utils import (
    FASTDEPLOY_MODEL_NAME,
    FASTDEPLOY_WEIGHTS_NAME,
    is_fastdeploy_available,
    is_paddle_available,
    logging,
)

if is_paddle_available():
    import paddle


if is_fastdeploy_available():
    import fastdeploy as fd

    def fdtensor2pdtensor(fdtensor: fd.C.FDTensor):
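        """Convert a FastDeploy ``FDTensor`` into a ``paddle.Tensor`` via the DLPack protocol (no copy)."""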
        dltensor = fdtensor.to_dlpack()
        pdtensor = paddle.utils.dlpack.from_dlpack(dltensor)
        return pdtensor

    def pdtensor2fdtensor(pdtensor: paddle.Tensor, name: str = "", share_with_raw_ptr=False):
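        """
        Convert a ``paddle.Tensor`` into a FastDeploy ``FDTensor``. By default the conversion goes
        through DLPack; with ``share_with_raw_ptr=True`` the FDTensor is built directly on top of
        the paddle tensor's raw device pointer, so both objects share the same memory.
        """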
        if not share_with_raw_ptr:
            dltensor = paddle.utils.dlpack.to_dlpack(pdtensor)
            return fd.C.FDTensor.from_dlpack(name, dltensor)
        else:
            return fd.C.FDTensor.from_external_data(
                name,
                pdtensor.data_ptr(),
                pdtensor.shape,
                pdtensor.dtype.name,
                str(pdtensor.place),
                int(pdtensor.place.gpu_device_id()),
            )
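
    # A minimal round-trip sketch (assumes paddle and fastdeploy are importable):
    #   x = paddle.ones([2, 3], dtype="float32")
    #   fd_x = pdtensor2fdtensor(x, "x")  # paddle.Tensor -> FDTensor via DLPack (zero-copy)
    #   y = fdtensor2pdtensor(fd_x)       # FDTensor -> paddle.Tensor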


logger = logging.get_logger(__name__)


class FastDeployRuntimeModel:
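    """
    Wrapper around a ``fastdeploy.Runtime`` that loads exported Paddle inference models
    (an ``inference.pdmodel`` graph plus ``inference.pdiparams`` weights) and exposes the
    `from_pretrained`/`save_pretrained` interface used by the ppdiffusers pipelines.
    """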
    def __init__(self, model=None, **kwargs):
        logger.info("`ppdiffusers.FastDeployRuntimeModel` is experimental and might change in the future.")
        self.model = model
        self.model_save_dir = kwargs.get("model_save_dir", None)
        self.latest_model_name = kwargs.get("latest_model_name", "inference.pdmodel")
        self.latest_params_name = kwargs.get("latest_params_name", "inference.pdiparams")

    def zero_copy_infer(self, prebinded_inputs: dict, prebinded_outputs: dict, share_with_raw_ptr=True, **kwargs):
        """
        Execute inference by binding pre-allocated paddle tensors directly to the runtime,
        avoiding extra host/device copies.

        Arguments:
            prebinded_inputs (`dict[str, paddle.Tensor]`):
                Map from input name to the tensor holding that input.
            prebinded_outputs (`dict[str, paddle.Tensor]`):
                Map from output name to a pre-allocated tensor that will receive that output.
            share_with_raw_ptr (`bool`, *optional*, defaults to `True`):
                Whether to share the tensors' raw device pointers with the runtime instead of
                going through DLPack.
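
        Example (a minimal sketch; the input/output names and shapes are hypothetical):
            latent = paddle.randn([1, 4, 64, 64])
            noise_pred = paddle.zeros([1, 4, 64, 64])
            model.zero_copy_infer({"sample": latent}, {"out_sample": noise_pred})
            # `noise_pred` is filled in place by the runtime.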
        """
        for inputs_name, inputs_tensor in prebinded_inputs.items():
            input_fdtensor = pdtensor2fdtensor(inputs_tensor, inputs_name, share_with_raw_ptr=share_with_raw_ptr)
            self.model.bind_input_tensor(inputs_name, input_fdtensor)

        for outputs_name, outputs_tensor in prebinded_outputs.items():
            output_fdtensor = pdtensor2fdtensor(outputs_tensor, outputs_name, share_with_raw_ptr=share_with_raw_ptr)
            self.model.bind_output_tensor(outputs_name, output_fdtensor)

        # Run inference once, after all inputs and outputs have been bound.
        self.model.zero_copy_infer()

    def __call__(self, **kwargs):
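        """Run inference with numpy-convertible inputs; each keyword argument is fed to the runtime by name."""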
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.infer(inputs)

    @staticmethod
    def load_model(
        model_path: Union[str, Path],
        params_path: Union[str, Path],
        runtime_options: Optional["fd.RuntimeOption"] = None,
    ):
        """
        Loads a FastDeploy inference model with a fastdeploy.RuntimeOption.

        Arguments:
            model_path (`str` or `Path`):
                Model path from which to load
            params_path (`str` or `Path`):
                Params path from which to load
            runtime_options (`fd.RuntimeOption`, *optional*):
                The fastdeploy RuntimeOption used to initialize the runtime. Defaults to running
                on the CPU with the Paddle Inference backend.
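
        Example (a minimal sketch; the file paths are hypothetical):
            option = fd.RuntimeOption()
            option.use_gpu(0)  # assumption: run on GPU 0 rather than the CPU default
            runtime = FastDeployRuntimeModel.load_model(
                "unet/inference.pdmodel", "unet/inference.pdiparams", runtime_options=option
            )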
        """
        if runtime_options is None or not isinstance(runtime_options, fd.RuntimeOption):
            logger.info("No fastdeploy.RuntimeOption specified, using CPU device and paddle inference backend.")
            option = fd.RuntimeOption()
            option.use_paddle_backend()
            option.use_cpu()
        else:
            option = runtime_options
        # Cast to str so `Path` arguments are accepted as well.
        option.set_model_path(str(model_path), str(params_path))
        return fd.Runtime(option)

    def _save_pretrained(
        self,
        save_directory: Union[str, Path],
        model_file_name: Optional[str] = None,
        params_file_name: Optional[str] = None,
        **kwargs
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~FastDeployRuntimeModel.from_pretrained`] class method. It will always save the
        latest_model_name.

        Arguments:
            save_directory (`str` or `Path`):
                Directory where to save the model file.
            model_file_name(`str`, *optional*):
                Overwrites the default model file name from `"inference.pdmodel"` to `model_file_name`. This allows you to save the
                model with a different name.
            params_file_name(`str`, *optional*):
                Overwrites the default params file name from `"inference.pdiparams"` to `params_file_name`. This allows
                you to save the params with a different name.
        """

        model_file_name = model_file_name if model_file_name is not None else FASTDEPLOY_MODEL_NAME
        params_file_name = params_file_name if params_file_name is not None else FASTDEPLOY_WEIGHTS_NAME

        src_model_path = self.model_save_dir.joinpath(self.latest_model_name)
        dst_model_path = Path(save_directory).joinpath(model_file_name)

        src_params_path = self.model_save_dir.joinpath(self.latest_params_name)
        dst_params_path = Path(save_directory).joinpath(params_file_name)
        try:
            shutil.copyfile(src_model_path, dst_model_path)
            shutil.copyfile(src_params_path, dst_params_path)
        except shutil.SameFileError:
            pass

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Save a model to a directory, so that it can be re-loaded using the
        [`~FastDeployRuntimeModel.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
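
        Example (a minimal sketch; the directory name is hypothetical):
            model.save_pretrained("./exported_model")
            reloaded = FastDeployRuntimeModel.from_pretrained("./exported_model")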
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory, **kwargs)

    @classmethod
    def _from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        cache_dir: Optional[str] = None,
        model_file_name: Optional[str] = None,
        params_file_name: Optional[str] = None,
        runtime_options: Optional["fd.RuntimeOption"] = None,
        **kwargs,
    ):
        """
        Load a model from a directory or the BOS.

        Arguments:
            pretrained_model_name_or_path (`str` or `Path`):
                Local directory, or model identifier on the BOS hub, from which to load.
            cache_dir (`Union[str, Path]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            model_file_name (`str`, *optional*):
                Overwrites the default model file name from `"inference.pdmodel"` to `model_file_name`. This allows you
                to load different model files from the same repository or directory.
            params_file_name (`str`, *optional*):
                Overwrites the default params file name from `"inference.pdiparams"` to `params_file_name`. This allows
                you to load different params files from the same repository or directory.
            runtime_options (`fastdeploy.RuntimeOption`, *optional*):
                The RuntimeOption of fastdeploy.
            kwargs (`Dict`, *optional*):
                kwargs will be passed to the model during initialization
        """
        model_file_name = model_file_name if model_file_name is not None else FASTDEPLOY_MODEL_NAME
        params_file_name = params_file_name if params_file_name is not None else FASTDEPLOY_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(pretrained_model_name_or_path):
            model = FastDeployRuntimeModel.load_model(
                os.path.join(pretrained_model_name_or_path, model_file_name),
                os.path.join(pretrained_model_name_or_path, params_file_name),
                runtime_options=runtime_options,
            )
            kwargs["model_save_dir"] = Path(pretrained_model_name_or_path)
        # load model from hub
        else:
            # download model
            model_cache_path = ppdiffusers_bos_download(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                filename=model_file_name,
                cache_dir=cache_dir,
            )
            # download params
            params_cache_path = ppdiffusers_bos_download(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                filename=params_file_name,
                cache_dir=cache_dir,
            )
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            kwargs["latest_params_name"] = Path(params_cache_path).name
            model = FastDeployRuntimeModel.load_model(
                model_cache_path, params_cache_path, runtime_options=runtime_options
            )
        return cls(model=model, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        cache_dir: Optional[str] = None,
        model_file_name: Optional[str] = None,
        params_file_name: Optional[str] = None,
        runtime_options: Optional["fd.RuntimeOption"] = None,
        **model_kwargs,
    ):
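        """
        Instantiate a `FastDeployRuntimeModel` from a local directory or the BOS hub. See
        [`~FastDeployRuntimeModel._from_pretrained`] for the full argument documentation.

        Example (a minimal sketch; the model identifier is hypothetical):
            option = fd.RuntimeOption()
            option.use_paddle_backend()
            option.use_cpu()
            model = FastDeployRuntimeModel.from_pretrained(
                "some-org/some-fastdeploy-model", runtime_options=option
            )
        """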
        return cls._from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            cache_dir=cache_dir,
            model_file_name=model_file_name,
            params_file_name=params_file_name,
            runtime_options=runtime_options,
            **model_kwargs,
        )