jonathanjordan21 commited on
Commit
96b7a63
·
1 Parent(s): e2801bb

Create custom_llm.py

Browse files
Files changed (1) hide show
  1. components/custom_llm.py +54 -0
components/custom_llm.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Mapping, Optional
2
+
3
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
4
+ from langchain_core.language_models.llms import LLM
5
+ from typing import Literal
6
+ import requests
7
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate
8
+ from operator import itemgetter
9
+
10
class CustomLLM(LLM):
    """LangChain LLM wrapper around the Hugging Face serverless Inference API.

    Posts the prompt to
    ``https://api-inference.huggingface.co/models/{repo_id}`` and returns the
    generated text from the JSON response.
    """

    repo_id: str        # Hugging Face model repo id, e.g. "google/flan-t5-base"
    api_token: str      # HF API token, sent as a Bearer credential
    model_type: Literal["text2text-generation", "text-generation"]
    # Optional generation parameters; None means "let the API use its default".
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    timeout: Optional[float] = None   # generation timeout forwarded to the API
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    # Default stop sequences (pydantic deep-copies field defaults per instance,
    # so the mutable default is safe here).
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        """Identifier used by LangChain for serialization/logging."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the Inference API and return the generated text.

        Args:
            prompt: The text prompt to complete.
            stop: Per-call stop sequences; overrides ``self.stop`` when given.
            run_manager: LangChain callback manager (unused).
            **kwargs: Ignored extra arguments from the LangChain runtime.

        Raises:
            ValueError: If the API responds with an error payload.
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"

        parameters = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "timeout": self.timeout,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            # BUGFIX: honor the per-call `stop` argument — the original
            # always sent self.stop and silently ignored the parameter.
            "stop": stop if stop is not None else self.stop,
        }
        # Omit unset parameters instead of sending explicit JSON nulls,
        # which some model backends reject.
        parameters = {k: v for k, v in parameters.items() if v is not None}

        if self.model_type == "text-generation":
            # Causal LMs echo the prompt unless asked not to.
            parameters["return_full_text"] = False

        payload = {
            "inputs": prompt,
            "parameters": parameters,
            "options": {"wait_for_model": True},
        }
        data = requests.post(api_url, headers=headers, json=payload).json()

        # The Inference API signals failures with {"error": "..."} instead of
        # the usual [{"generated_text": ...}] list — surface that clearly
        # rather than crashing with an opaque KeyError/TypeError.
        if isinstance(data, dict) and "error" in data:
            raise ValueError(f"Hugging Face Inference API error: {data['error']}")
        return data[0]["generated_text"]