mgbam committed
Commit 639177c · verified · 1 Parent(s): 4610542

Update inference.py

Files changed (1):
  1. inference.py +62 -28
inference.py CHANGED
@@ -1,37 +1,67 @@
 # inference.py
-from typing import List, Dict, Optional
+# -------------------------------------------------------------
+# Unified wrapper around hf_client.get_inference_client
+# with automatic provider routing based on the model registry
+# (see models.py) and graceful fall-back to Groq.
+# -------------------------------------------------------------
+from __future__ import annotations
+
+from typing import Dict, Generator, List, Optional
+
 from hf_client import get_inference_client
 from models import find_model


+# ------------------------------------------------------------------
+# Helpers
+# ------------------------------------------------------------------
+def _resolve_provider(model_id: str, override: str | None) -> str:
+    """
+    Decide which provider to use.
+
+    Priority:
+        1. Explicit *override* argument supplied by the caller.
+        2. Model registry default_provider (see models.py).
+        3. "auto" – lets HF route to the first available provider.
+    """
+    if override:
+        return override
+
+    meta = find_model(model_id)
+    return getattr(meta, "default_provider", "auto") if meta else "auto"
+
+
+# ------------------------------------------------------------------
+# Public API
+# ------------------------------------------------------------------
 def chat_completion(
     model_id: str,
     messages: List[Dict[str, str]],
     provider: Optional[str] = None,
-    max_tokens: int = 4096
+    max_tokens: int = 4096,
+    **kwargs,
 ) -> str:
     """
-    Send a chat completion request to the appropriate inference provider.
+    Blocking convenience wrapper that returns the full assistant reply.

-    Args:
-        model_id: The model identifier to use.
-        messages: A list of OpenAI-style {'role','content'} messages.
-        provider: Optional override for provider; uses model default if None.
-        max_tokens: Maximum tokens to generate.
+    Parameters
+    ----------
+    model_id   : HF or provider-qualified model path (e.g. "openai/gpt-4").
+    messages   : OpenAI-style [{'role': ..., 'content': ...}, …].
+    provider   : Optional provider override; otherwise auto-resolved.
+    max_tokens : Token budget for generation.
+    kwargs     : Forward-compatible extra arguments (temperature, etc.).

-    Returns:
-        The assistant's response content.
+    Returns
+    -------
+    str – assistant message content.
     """
-    # resolve default provider from registry if needed
-    if provider is None:
-        meta = find_model(model_id)
-        provider = meta.default_provider if meta else "auto"
-
-    client = get_inference_client(model_id, provider)
+    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
     resp = client.chat.completions.create(
         model=model_id,
         messages=messages,
-        max_tokens=max_tokens
+        max_tokens=max_tokens,
+        **kwargs,
     )
     return resp.choices[0].message.content

@@ -40,24 +70,28 @@ def stream_chat_completion(
     model_id: str,
     messages: List[Dict[str, str]],
     provider: Optional[str] = None,
-    max_tokens: int = 4096
-):
+    max_tokens: int = 4096,
+    **kwargs,
+) -> Generator[str, None, None]:
     """
-    Generator for streaming chat completions.
-    Yields partial message chunks as strings.
-    """
-    if provider is None:
-        meta = find_model(model_id)
-        provider = meta.default_provider if meta else "auto"
+    Yield the assistant response *incrementally*.

-    client = get_inference_client(model_id, provider)
+    Example
+    -------
+    >>> for chunk in stream_chat_completion(model, msgs):
+    ...     print(chunk, end='', flush=True)
+    """
+    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
     stream = client.chat.completions.create(
         model=model_id,
         messages=messages,
         max_tokens=max_tokens,
-        stream=True
+        stream=True,
+        **kwargs,
     )
+
+    # HF Inference returns chunks with .choices[0].delta.content
     for chunk in stream:
-        delta = getattr(chunk.choices[0].delta, "content", None)
+        delta: str | None = getattr(chunk.choices[0].delta, "content", None)
         if delta:
             yield delta
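
For context, a minimal usage sketch of the two public helpers this commit reshapes is shown below. It is illustrative only and not part of the commit: the messages, the temperature value, and the choice of model id (borrowed from the docstring example) are assumptions, and it presumes hf_client.get_inference_client returns an OpenAI-compatible client exactly as the diff suggests.

# usage_sketch.py – illustrative only; not part of this commit.
# Assumes the repo's inference, hf_client, and models modules are importable
# and that credentials for the underlying inference provider are configured.
from inference import chat_completion, stream_chat_completion

MODEL_ID = "openai/gpt-4"  # example id from the docstring; any registered model id should work
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Explain provider routing in one sentence."},
]

# Blocking call: provider is auto-resolved via _resolve_provider() -> find_model();
# extra kwargs such as temperature are forwarded to the provider (if it supports them).
reply = chat_completion(MODEL_ID, messages, max_tokens=256, temperature=0.2)
print(reply)

# Streaming call: chunks are printed as they arrive.
for chunk in stream_chat_completion(MODEL_ID, messages, max_tokens=256):
    print(chunk, end="", flush=True)
print()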