try:
    from langchain.llms.base import LLM
except ImportError:
    raise ImportError(
        "To use the ctransformers.langchain module, please install the "
        "`langchain` python package: `pip install langchain`"
    )

from typing import Any, Dict, Optional, Sequence

from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun

from ctransformers import AutoModelForCausalLM


class CTransformers(LLM):
    """Wrapper around the C Transformers LLM interface.

    To use, you should have the `langchain` python package installed.

    The underlying model is loaded once, during validation, via
    ``AutoModelForCausalLM.from_pretrained`` and stored on ``client``.

    Example:
        .. code-block:: python

            from ctransformers.langchain import CTransformers

            llm = CTransformers(model="/path/to/model.bin", model_type="gpt2")
    """

    client: Any  #: :meta private:

    model: str
    """The path to a model file or directory or the name of a Hugging Face Hub
    model repo."""

    model_type: Optional[str] = None
    """The model type."""

    model_file: Optional[str] = None
    """The name of the model file in repo or directory."""

    config: Optional[Dict[str, Any]] = None
    """The config parameters."""

    lib: Optional[Any] = None
    """The path to a shared library or one of `avx2`, `avx`, `basic`."""

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters.

        Returns:
            A dict with the ``model``, ``model_type``, ``model_file`` and
            ``config`` fields of this instance.
        """
        return {
            "model": self.model,
            "model_type": self.model_type,
            "model_file": self.model_file,
            "config": self.config,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ctransformers"

    # `skip_on_failure=True`: a post root validator otherwise still runs after
    # a field has failed validation, in which case that field is missing from
    # `values` and the lookups below would raise a bare `KeyError` instead of
    # letting pydantic report a proper `ValidationError`.
    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate and load model from a local file or remote repo.

        Args:
            values: The validated field values of the model.

        Returns:
            The same ``values`` dict with ``client`` set to the object
            returned by ``AutoModelForCausalLM.from_pretrained``.
        """
        # `config` defaults to None; treat that as "no extra keyword args".
        config = values["config"] or {}
        values["client"] = AutoModelForCausalLM.from_pretrained(
            values["model"],
            model_type=values["model_type"],
            model_file=values["model_file"],
            lib=values["lib"],
            **config,
        )
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate text from a prompt.

        The client is always called with ``stream=True``; every chunk it
        yields is collected and, when a ``run_manager`` is given, also
        forwarded to ``run_manager.on_llm_new_token``.

        Args:
            prompt: The prompt to generate text from.
            stop: A list of sequences to stop generation when encountered.
            run_manager: Optional callback manager notified of each new chunk.

        Returns:
            The generated text.
        """
        text = []
        for chunk in self.client(prompt, stop=stop, stream=True):
            text.append(chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk, verbose=self.verbose)
        # Join once at the end rather than concatenating per chunk.
        return "".join(text)