wakeupmh committed on
Commit
2e9d088
·
1 Parent(s): 6952b41

fix: response

Browse files
Files changed (1) hide show
  1. services/model_handler.py +16 -7
services/model_handler.py CHANGED
@@ -9,6 +9,14 @@ from tenacity import retry, stop_after_attempt, wait_exponential
9
 
10
  MODEL_PATH = "google/flan-t5-small"
11
 
 
 
 
 
 
 
 
 
12
 # Personalized class for local models
13
  class LocalHuggingFaceModel(Model):
14
  def __init__(self, model, tokenizer, max_length=512):
@@ -62,31 +70,33 @@ class LocalHuggingFaceModel(Model):
62
  """Parse the provider response delta for streaming"""
63
  return delta
64
 
65
- async def aresponse(self, prompt=None, **kwargs) -> str:
66
  """Async response method - required abstract method"""
67
  if prompt is None:
68
  prompt = kwargs.get('input', '')
69
- return await self.ainvoke(prompt=prompt, **kwargs)
 
70
 
71
  async def aresponse_stream(self, prompt=None, **kwargs):
72
  """Async streaming response method - required abstract method"""
73
  if prompt is None:
74
  prompt = kwargs.get('input', '')
75
  async for chunk in self.ainvoke_stream(prompt=prompt, **kwargs):
76
- yield chunk
77
 
78
- def response(self, prompt=None, **kwargs) -> str:
79
  """Synchronous response method - required abstract method"""
80
  if prompt is None:
81
  prompt = kwargs.get('input', '')
82
- return self.invoke(prompt=prompt, **kwargs)
 
83
 
84
  def response_stream(self, prompt=None, **kwargs):
85
  """Synchronous streaming response method - required abstract method"""
86
  if prompt is None:
87
  prompt = kwargs.get('input', '')
88
  for chunk in self.invoke_stream(prompt=prompt, **kwargs):
89
- yield chunk
90
 
91
  def generate(self, prompt: str, **kwargs):
92
  try:
@@ -243,7 +253,6 @@ Output:"""
243
 
244
  # Get English translation
245
  translation = self.translator.run(prompt=translation_prompt, stream=False)
246
- logging.info(f"Translation: {translation}")
247
  if not translation:
248
  logging.error("Translation failed")
249
  return "Error: Unable to translate the query"
 
9
 
10
  MODEL_PATH = "google/flan-t5-small"
11
 
12
# Simple Response class to wrap the model output
class Response:
    """Lightweight wrapper exposing a model's output as `.content`.

    The surrounding model methods (`response`, `aresponse`, and the
    streaming variants) return/yield instances of this class so callers
    can uniformly read `response.content`.
    """

    def __init__(self, content):
        # Raw payload produced by the model (normally generated text).
        self.content = content

    def __str__(self):
        # Coerce to str so __str__ never raises a TypeError when the
        # wrapped payload is not a string (e.g. None on a failed call).
        return str(self.content)

    def __repr__(self):
        return f"Response(content={self.content!r})"
20
 # Personalized class for local models
21
  class LocalHuggingFaceModel(Model):
22
  def __init__(self, model, tokenizer, max_length=512):
 
70
  """Parse the provider response delta for streaming"""
71
  return delta
72
 
73
async def aresponse(self, prompt=None, **kwargs):
    """Async response method - required abstract method"""
    # Fall back to kwargs['input'] when no explicit prompt was given.
    effective_prompt = kwargs.get('input', '') if prompt is None else prompt
    generated = await self.ainvoke(prompt=effective_prompt, **kwargs)
    return Response(generated)
79
 
80
async def aresponse_stream(self, prompt=None, **kwargs):
    """Async streaming response method - required abstract method"""
    # Fall back to kwargs['input'] when no explicit prompt was given.
    effective_prompt = kwargs.get('input', '') if prompt is None else prompt
    async for piece in self.ainvoke_stream(prompt=effective_prompt, **kwargs):
        yield Response(piece)
86
 
87
def response(self, prompt=None, **kwargs):
    """Synchronous response method - required abstract method"""
    # Fall back to kwargs['input'] when no explicit prompt was given.
    effective_prompt = kwargs.get('input', '') if prompt is None else prompt
    generated = self.invoke(prompt=effective_prompt, **kwargs)
    return Response(generated)
93
 
94
def response_stream(self, prompt=None, **kwargs):
    """Synchronous streaming response method - required abstract method"""
    # Fall back to kwargs['input'] when no explicit prompt was given.
    effective_prompt = kwargs.get('input', '') if prompt is None else prompt
    stream = self.invoke_stream(prompt=effective_prompt, **kwargs)
    yield from map(Response, stream)
100
 
101
  def generate(self, prompt: str, **kwargs):
102
  try:
 
253
 
254
  # Get English translation
255
  translation = self.translator.run(prompt=translation_prompt, stream=False)
 
256
  if not translation:
257
  logging.error("Translation failed")
258
  return "Error: Unable to translate the query"