wakeupmh committed on
Commit
ed14957
·
1 Parent(s): 743463b

fix: adapt to agent

Browse files
Files changed (1) hide show
  1. services/model_handler.py +23 -4
services/model_handler.py CHANGED
@@ -19,20 +19,39 @@ class LocalHuggingFaceModel(Model):
19
 
20
  async def ainvoke(self, prompt: str, **kwargs) -> str:
21
  """Async invoke method"""
22
- return self.invoke(prompt=prompt, **kwargs)
23
 
24
  async def ainvoke_stream(self, prompt: str, **kwargs):
25
  """Async streaming invoke method"""
26
- result = self.invoke(prompt=prompt, **kwargs)
27
  yield result
28
 
29
  def invoke(self, prompt: str, **kwargs) -> str:
30
  """Synchronous invoke method"""
31
- return self.generate(prompt=prompt, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def invoke_stream(self, prompt: str, **kwargs):
34
  """Synchronous streaming invoke method"""
35
- result = self.generate(prompt=prompt, **kwargs)
36
  yield result
37
 
38
  def parse_provider_response(self, response: str) -> str:
 
19
 
20
  async def ainvoke(self, prompt: str, **kwargs) -> str:
21
  """Async invoke method"""
22
+ return await self.invoke(prompt=prompt, **kwargs)
23
 
24
  async def ainvoke_stream(self, prompt: str, **kwargs):
25
  """Async streaming invoke method"""
26
+ result = await self.invoke(prompt=prompt, **kwargs)
27
  yield result
28
 
29
    def invoke(self, prompt: str, **kwargs) -> str:
        """Run the local Hugging Face model synchronously on *prompt*.

        Args:
            prompt: Text prompt to tokenize and feed to the model.
            **kwargs: Optional sampling controls — ``do_sample`` (default
                False), ``temperature`` (default 1.0), ``top_p`` (default 1.0).

        Returns:
            The decoded model output, or an ``"Error during generation: ..."``
            string if anything raises (deliberate best-effort: callers always
            get a string, never an exception).
        """
        try:
            # assumes self.tokenizer has a pad token configured, otherwise
            # padding=True raises — TODO confirm against model setup
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

            # Configure generation parameters
            # NOTE(review): "max_length" caps prompt+completion combined;
            # if the intent is to cap only new tokens, "max_new_tokens"
            # may be what was meant — confirm.
            # temperature/top_p are passed even when do_sample is False,
            # where they have no effect (transformers may warn).
            generation_config = {
                "max_length": self.max_length,
                "num_return_sequences": 1,
                "do_sample": kwargs.get("do_sample", False),
                "temperature": kwargs.get("temperature", 1.0),
                "top_p": kwargs.get("top_p", 1.0),
            }

            # Generate the answer
            outputs = self.model.generate(**inputs, **generation_config)
            # Decode only the first (and only) returned sequence; the decoded
            # text includes the prompt tokens, since no slicing is done here.
            decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            return decoded_output
        except Exception as e:
            # Broad catch is intentional: surface the failure as a string so
            # the agent pipeline never crashes mid-conversation.
            logging.error(f"Error in local model generation: {str(e)}")
            return f"Error during generation: {str(e)}"
51
 
52
  def invoke_stream(self, prompt: str, **kwargs):
53
  """Synchronous streaming invoke method"""
54
+ result = self.invoke(prompt=prompt, **kwargs)
55
  yield result
56
 
57
  def parse_provider_response(self, response: str) -> str: