IC4T committed on
Commit a5be45f · 1 Parent(s): 4093a3e
langchain/langchain/llms/ctransformers.py ADDED
@@ -0,0 +1,105 @@
+ """Wrapper around the C Transformers library."""
+ from typing import Any, Dict, Optional, Sequence
+
+ from pydantic import root_validator
+
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from langchain.llms.base import LLM
+
+
+ class CTransformers(LLM):
+     """Wrapper around the C Transformers LLM interface.
+
+     To use, you should have the ``ctransformers`` python package installed.
+     See https://github.com/marella/ctransformers
+
+     Example:
+         .. code-block:: python
+
+             from langchain.llms import CTransformers
+
+             llm = CTransformers(model="/path/to/ggml-gpt-2.bin", model_type="gpt2")
+     """
+
+     client: Any  #: :meta private:
+
+     model: str
+     """The path to a model file or directory or the name of a Hugging Face Hub
+     model repo."""
+
+     model_type: Optional[str] = None
+     """The model type."""
+
+     model_file: Optional[str] = None
+     """The name of the model file in repo or directory."""
+
+     config: Optional[Dict[str, Any]] = None
+     """The config parameters.
+     See https://github.com/marella/ctransformers#config"""
+
+     lib: Optional[str] = None
+     """The path to a shared library or one of `avx2`, `avx`, `basic`."""
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         """Get the identifying parameters."""
+         return {
+             "model": self.model,
+             "model_type": self.model_type,
+             "model_file": self.model_file,
+             "config": self.config,
+         }
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of llm."""
+         return "ctransformers"
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+         """Validate that the ``ctransformers`` package is installed."""
+         try:
+             from ctransformers import AutoModelForCausalLM
+         except ImportError:
+             raise ImportError(
+                 "Could not import `ctransformers` package. "
+                 "Please install it with `pip install ctransformers`"
+             )
+
+         config = values["config"] or {}
+         values["client"] = AutoModelForCausalLM.from_pretrained(
+             values["model"],
+             model_type=values["model_type"],
+             model_file=values["model_file"],
+             lib=values["lib"],
+             **config,
+         )
+         return values
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[Sequence[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         """Generate text from a prompt.
+
+         Args:
+             prompt: The prompt to generate text from.
+             stop: A list of sequences to stop generation when encountered.
+
+         Returns:
+             The generated text.
+
+         Example:
+             .. code-block:: python
+
+                 response = llm("Tell me a joke.")
+         """
+         text = []
+         _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
+         for chunk in self.client(prompt, stop=stop, stream=True):
+             text.append(chunk)
+             _run_manager.on_llm_new_token(chunk, verbose=self.verbose)
+         return "".join(text)
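
For reference, a minimal usage sketch of the wrapper added above. The model repo, the config values, and the streaming-to-stdout handler are illustrative assumptions, not part of this commit; the config keys follow the ctransformers config documented at https://github.com/marella/ctransformers#config, and StreamingStdOutCallbackHandler ships with langchain.

    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms import CTransformers

    # Assumed model: GGML GPT-2 weights on the Hugging Face Hub; any
    # model/model_type pair supported by ctransformers works the same way.
    llm = CTransformers(
        model="marella/gpt-2-ggml",
        model_type="gpt2",
        config={"max_new_tokens": 64, "temperature": 0.8},
        # _call forwards each streamed chunk via on_llm_new_token, so this
        # handler prints tokens as they are generated.
        callbacks=[StreamingStdOutCallbackHandler()],
    )

    print(llm("AI is going to"))

Because _call iterates over self.client(prompt, stop=stop, stream=True), callbacks receive tokens incrementally even though the method ultimately returns the joined string; passing stop sequences as llm(prompt, stop=[...]) exercises the stop parameter shown in the diff.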