mgbam committed
Commit 1e7a57c · verified · 1 Parent(s): d4adf5c

Update inference.py

Files changed (1)
  1. inference.py +18 -8
inference.py CHANGED
@@ -1,7 +1,8 @@
 # inference.py
-
-from typing import List, Dict, Generator, Optional
+from typing import List, Dict, Optional
 from hf_client import get_inference_client
+from models import find_model
+
 
 def chat_completion(
     model_id: str,
@@ -14,20 +15,25 @@ def chat_completion(
 
     Args:
         model_id: The model identifier to use.
-        messages: A list of OpenAIstyle {'role':'...', 'content':'...'} messages.
+        messages: A list of OpenAI-style {'role','content'} messages.
         provider: Optional override for provider; uses model default if None.
         max_tokens: Maximum tokens to generate.
 
     Returns:
         The assistant's response content.
     """
-    client = get_inference_client(model_id, provider or "auto")
-    response = client.chat.completions.create(
+    # resolve default provider from registry if needed
+    if provider is None:
+        meta = find_model(model_id)
+        provider = meta.default_provider if meta else "auto"
+
+    client = get_inference_client(model_id, provider)
+    resp = client.chat.completions.create(
         model=model_id,
         messages=messages,
         max_tokens=max_tokens
     )
-    return response.choices[0].message.content
+    return resp.choices[0].message.content
 
 
 def stream_chat_completion(
@@ -35,12 +41,16 @@ def stream_chat_completion(
     messages: List[Dict[str, str]],
     provider: Optional[str] = None,
     max_tokens: int = 4096
-) -> Generator[str, None, None]:
+):
     """
     Generator for streaming chat completions.
    Yields partial message chunks as strings.
     """
-    client = get_inference_client(model_id, provider or "auto")
+    if provider is None:
+        meta = find_model(model_id)
+        provider = meta.default_provider if meta else "auto"
+
+    client = get_inference_client(model_id, provider)
     stream = client.chat.completions.create(
         model=model_id,
         messages=messages,
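
For context: this change makes both helpers resolve the default inference provider from the models registry via models.find_model instead of always falling back to "auto". The registry itself is not part of this diff, so the sketch below is only a plausible shape for it, assuming a ModelMeta entry with a default_provider attribute; the class name, field names, and example model id are assumptions, not code from this repo.

# models.py (hypothetical sketch of the registry that find_model is assumed to query)
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class ModelMeta:
    model_id: str
    default_provider: str = "auto"

# example entry only; the real registry contents will differ
_REGISTRY: Dict[str, ModelMeta] = {
    "meta-llama/Llama-3.1-8B-Instruct": ModelMeta(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        default_provider="auto",
    ),
}

def find_model(model_id: str) -> Optional[ModelMeta]:
    # return registry metadata, or None so callers fall back to "auto"
    return _REGISTRY.get(model_id)

With a registry like that in place, callers can omit provider and let the lookup decide:

from inference import chat_completion

reply = chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical example id
    messages=[{"role": "user", "content": "Hello!"}],
)
print(reply)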