wakeupmh committed
Commit beba082 · 1 Parent(s): 0943c16

test: local model

Files changed (1):
  1. services/model_handler.py +24 -0
services/model_handler.py CHANGED
@@ -17,6 +17,30 @@ class LocalHuggingFaceModel(Model):
         self.tokenizer = tokenizer
         self.max_length = max_length
 
+    async def ainvoke(self, prompt: str, **kwargs):
+        """Async invoke method"""
+        return self.invoke(prompt, **kwargs)
+
+    async def ainvoke_stream(self, prompt: str, **kwargs):
+        """Async streaming invoke method"""
+        yield self.invoke(prompt, **kwargs)
+
+    def invoke(self, prompt: str, **kwargs):
+        """Synchronous invoke method"""
+        return self.generate(prompt, **kwargs)
+
+    def invoke_stream(self, prompt: str, **kwargs):
+        """Synchronous streaming invoke method"""
+        yield self.generate(prompt, **kwargs)
+
+    def parse_provider_response(self, response: str) -> str:
+        """Parse the provider response"""
+        return response
+
+    def parse_provider_response_delta(self, delta: str) -> str:
+        """Parse the provider response delta for streaming"""
+        return delta
+
     def generate(self, prompt, **kwargs):
         try:
             inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
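
For context, here is a minimal usage sketch of the call surface this commit adds. It is hypothetical: the hunk only shows `__init__` storing `tokenizer` and `max_length`, so the `model` keyword and the flan-t5 checkpoint below are assumptions for illustration, not part of the commit.

```python
# Hypothetical usage sketch -- the `model` kwarg and checkpoint are assumed;
# only tokenizer/max_length are visible in the constructor in this hunk.
import asyncio

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from services.model_handler import LocalHuggingFaceModel

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
hf_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

llm = LocalHuggingFaceModel(model=hf_model, tokenizer=tokenizer, max_length=512)

# Synchronous paths: invoke() delegates to generate(), and invoke_stream()
# yields that same result as a single chunk.
print(llm.invoke("Summarize: local models avoid network round-trips."))
for chunk in llm.invoke_stream("Summarize: local models avoid network round-trips."):
    print(chunk)

async def main():
    # Async paths mirror the sync ones. Note that ainvoke() calls the
    # blocking invoke() directly, so generation runs on the event loop.
    print(await llm.ainvoke("Translate to German: good morning"))
    async for chunk in llm.ainvoke_stream("Translate to German: good morning"):
        print(chunk)

asyncio.run(main())
```

One consequence of this design worth noting: because `ainvoke` delegates straight to the synchronous `invoke`, generation blocks the event loop; wrapping the call in `asyncio.to_thread` would be one way to make the async path genuinely non-blocking.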