msy127 committed on
Commit 9531d4c · 1 Parent(s): 540d541

Update app.py

Files changed (1)
  1. app.py +25 -19
app.py CHANGED
@@ -26,34 +26,40 @@ vectordb = Chroma(
 
 retriever = vectordb.as_retriever(search_kwargs={"k": 5})
 
-from typing import Optional, List, Dict, Any
-# The required modules or classes (LLM, KwArgsModel, InferenceClient, etc.) still need to be imported
+class KwArgsModel(BaseModel):
+    kwargs: Dict[str, Any] = Field(default_factory=dict)
 
-class CustomInferenceClient:
-    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
-        self.llm = LLM(model_name=model_name, hf_token=hf_token)  # create the LLM instance
-        self.kw_args_model = KwArgsModel(kwargs=kwargs)  # create the KwArgsModel instance
-        self.inference_client = InferenceClient(model=model_name, token=hf_token)  # create the InferenceClient instance
-        self.model_name = model_name  # store the model name
+class CustomInferenceClient(LLM, KwArgsModel):
+    model_name: str
+    inference_client: InferenceClient
 
-    # implementation of the _call method
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
+        inference_client = InferenceClient(model=model_name, token=hf_token)
+        super().__init__(
+            model_name=model_name,
+            hf_token=hf_token,
+            kwargs=kwargs,
+            inference_client=inference_client
+        )
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None
+    ) -> str:
         if stop is not None:
-            raise ValueError("stop kwargs are not permitted.")  # raise an error if a stop argument is given
-        kwargs = self.kw_args_model.kwargs  # read the kwargs from the KwArgsModel
-        response_gen = self.inference_client.text_generation(prompt, **kwargs, stream=True)  # request text generation
-        response = ''.join(response_gen)  # join the stream into a single string
-        return response  # return the generated string
+            raise ValueError("stop kwargs are not permitted.")
+        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
+        response = ''.join(response_gen)
+        return response
 
-    # implementation of the _llm_type property
     @property
     def _llm_type(self) -> str:
-        return "custom"  # identify this as a custom LLM type
+        return "custom"
 
-    # implementation of the _identifying_params property
     @property
     def _identifying_params(self) -> dict:
-        return {"model_name": self.model_name}  # report the model name
+        return {"model_name": self.model_name}
 
 kwargs = {"max_new_tokens":256, "temperature":0.9, "top_p":0.6, "repetition_penalty":1.3, "do_sample":True}
 
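A minimal usage sketch (not part of this commit): one way the refactored client could be instantiated and wired to the retriever defined in the hunk above. The model id, the HF_TOKEN environment variable, and the RetrievalQA wiring are illustrative assumptions, not code from app.py.

import os
from langchain.chains import RetrievalQA

# Hypothetical model id and token source; substitute your own.
llm = CustomInferenceClient(
    model_name="mistralai/Mistral-7B-Instruct-v0.1",
    hf_token=os.environ["HF_TOKEN"],
    kwargs=kwargs,  # the generation kwargs dict defined at the end of the hunk
)

# Because CustomInferenceClient now subclasses LangChain's LLM, it can be
# plugged into a chain alongside the Chroma retriever directly.
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
print(qa_chain.run("What does the indexed corpus cover?"))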