msy127 committed
Commit 540d541 · 1 Parent(s): 21c8007

Update app.py

Files changed (1)
  1. app.py +17 -12
app.py CHANGED
@@ -26,29 +26,34 @@ vectordb = Chroma(
 
 retriever = vectordb.as_retriever(search_kwargs={"k": 5})
 
+from typing import Optional, List, Dict, Any
+# The required modules or classes, such as LLM, KwArgsModel, and InferenceClient, must also be imported
+
 class CustomInferenceClient:
     def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
-        self.llm = LLM(model_name=model_name, hf_token=hf_token)  # LLM instance
-        self.kw_args_model = KwArgsModel(kwargs=kwargs)  # KwArgsModel instance
-        self.inference_client = InferenceClient(model=model_name, token=hf_token)
-        self.model_name = model_name
+        self.llm = LLM(model_name=model_name, hf_token=hf_token)  # create the LLM instance
+        self.kw_args_model = KwArgsModel(kwargs=kwargs)  # create the KwArgsModel instance
+        self.inference_client = InferenceClient(model=model_name, token=hf_token)  # create the InferenceClient instance
+        self.model_name = model_name  # store the model name
 
-    # Slightly modify the existing methods so they use the internal LLM and KwArgsModel instances
+    # implement the _call method
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         if stop is not None:
-            raise ValueError("stop kwargs are not permitted.")
-        kwargs = self.kw_args_model.kwargs  # get kwargs from the KwArgsModel instance
-        response_gen = self.inference_client.text_generation(prompt, **kwargs, stream=True)
-        response = ''.join(response_gen)
-        return response
+            raise ValueError("stop kwargs are not permitted.")  # raise an error if a stop argument is passed
+        kwargs = self.kw_args_model.kwargs  # get kwargs from the KwArgsModel
+        response_gen = self.inference_client.text_generation(prompt, **kwargs, stream=True)  # request text generation
+        response = ''.join(response_gen)  # build the string from the stream
+        return response  # return the generated string
 
+    # implement the _llm_type property
     @property
     def _llm_type(self) -> str:
-        return "custom"
+        return "custom"  # designate as a custom type
 
+    # implement the _identifying_params property
     @property
     def _identifying_params(self) -> dict:
-        return {"model_name": self.model_name}
+        return {"model_name": self.model_name}  # return the model name
 
 kwargs = {"max_new_tokens":256, "temperature":0.9, "top_p":0.6, "repetition_penalty":1.3, "do_sample":True}
 
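As the added comment notes, LLM, KwArgsModel, and InferenceClient still need to be imported: InferenceClient comes from huggingface_hub, but LLM and KwArgsModel are not defined in this file. The sketch below shows one minimal way those two helpers could look, together with a usage example. The LLM and KwArgsModel definitions and the model name are illustrative assumptions, not part of this commit.

# Minimal sketch (assumption): stand-in definitions for the helpers this file
# expects to import. Only InferenceClient is a real huggingface_hub class.
from typing import Optional, Dict, Any
from dataclasses import dataclass

from huggingface_hub import InferenceClient


@dataclass
class LLM:
    # hypothetical holder for the wrapped model name and token
    model_name: str
    hf_token: str


@dataclass
class KwArgsModel:
    # hypothetical holder for the generation kwargs; defaults to an empty dict
    kwargs: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        if self.kwargs is None:
            self.kwargs = {}


# Usage mirroring the kwargs dict at the bottom of the diff
# (the model name and token are placeholders):
kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6,
          "repetition_penalty": 1.3, "do_sample": True}
client = CustomInferenceClient(model_name="HuggingFaceH4/zephyr-7b-beta",
                               hf_token="hf_...", kwargs=kwargs)
print(client._call("What documents does the retriever return?"))

With stream=True (and no details requested), huggingface_hub's text_generation yields plain string chunks, which is why _call can reassemble the full response with ''.join(response_gen).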