msy127 commited on
Commit
21c8007
·
1 Parent(s): 06abb5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -25
app.py CHANGED
@@ -26,34 +26,20 @@ vectordb = Chroma(
26
 
27
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
28
 
29
- class KwArgsModel(BaseModel):
30
- kwargs: Dict[str, Any] = Field(default_factory=dict)
31
-
32
- class CombinedMeta(type(LLM), type(KwArgsModel)):
33
- pass
34
-
35
- class CustomInferenceClient(LLM, KwArgsModel, metaclass=CombinedMeta):
36
- model_name: str
37
- inference_client: InferenceClient
38
-
39
  def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
40
- inference_client = InferenceClient(model=model_name, token=hf_token)
41
- super().__init__(
42
- model_name=model_name,
43
- hf_token=hf_token,
44
- kwargs=kwargs,
45
- inference_client=inference_client
46
- )
47
-
48
- def _call(
49
- self,
50
- prompt: str,
51
- stop: Optional[List[str]] = None
52
- ) -> str:
53
  if stop is not None:
54
  raise ValueError("stop kwargs are not permitted.")
55
- response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
56
- response = ''.join(response_gen)
 
57
  return response
58
 
59
  @property
 
26
 
27
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
28
 
29
+ class CustomInferenceClient:
 
 
 
 
 
 
 
 
 
30
  def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
31
+ self.llm = LLM(model_name=model_name, hf_token=hf_token) # LLM 인스턴스
32
+ self.kw_args_model = KwArgsModel(kwargs=kwargs) # KwArgsModel 인스턴스
33
+ self.inference_client = InferenceClient(model=model_name, token=hf_token)
34
+ self.model_name = model_name
35
+
36
+ # 기존 메서드들을 약간 수정하여 내부 LLM과 KwArgsModel 인스턴스를 사용하게 함
37
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
 
 
 
 
 
 
38
  if stop is not None:
39
  raise ValueError("stop kwargs are not permitted.")
40
+ kwargs = self.kw_args_model.kwargs # KwArgsModel 인스턴스에서 kwargs 가져옴
41
+ response_gen = self.inference_client.text_generation(prompt, **kwargs, stream=True)
42
+ response = ''.join(response_gen)
43
  return response
44
 
45
  @property