# inference.py
# -------------------------------------------------------------
# Unified wrapper around hf_client.get_inference_client
# with automatic provider‑routing based on the model registry
# (see models.py) and graceful fall‑back to Groq.
# -------------------------------------------------------------
from __future__ import annotations
from typing import Dict, Generator, List, Optional
from hf_client import get_inference_client
from models import find_model
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_provider(model_id: str, override: str | None) -> str:
"""
Decide which provider to use.
Priority:
1. Explicit *override* arg supplied by caller.
2. Model registry default_provider (see models.py).
3. "auto" – lets HF route to the first available provider.
"""
if override:
return override
meta = find_model(model_id)
return getattr(meta, "default_provider", "auto") if meta else "auto"
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> str:
"""
Blocking convenience wrapper – returns the full assistant reply.
Parameters
----------
model_id : HF or provider‑qualified model path (e.g. "openai/gpt-4").
messages : OpenAI‑style [{'role': ..., 'content': ...}, …].
provider : Optional provider override; otherwise auto‑resolved.
max_tokens : Token budget for generation.
kwargs : Forward‑compatible extra arguments (temperature, etc.).
Returns
-------
str – assistant message content.
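
    Example
    -------
    A minimal sketch; the model path is illustrative.

    >>> msgs = [{"role": "user", "content": "Hello!"}]
    >>> reply = chat_completion("openai/gpt-4", msgs)
    >>> print(reply)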
"""
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    resp = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        **kwargs,
    )
    return resp.choices[0].message.content


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> Generator[str, None, None]:
"""
Yield the assistant response *incrementally*.
Example
-------
>>> for chunk in stream_chat_completion(model, msgs):
... print(chunk, end='', flush=True)
"""
client = get_inference_client(model_id, _resolve_provider(model_id, provider))
stream = client.chat.completions.create(
model=model_id,
messages=messages,
max_tokens=max_tokens,
stream=True,
**kwargs,
)
# HF Inference returns chunks with .choices[0].delta.content
for chunk in stream:
delta: str | None = getattr(chunk.choices[0].delta, "content", None)
if delta:
yield delta
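

# ------------------------------------------------------------------
# Usage sketch (not part of the public API)
# ------------------------------------------------------------------
# A minimal, illustrative example of how the two wrappers above are meant
# to be called. The model id is hypothetical (substitute any entry from
# models.py), and running it assumes hf_client has valid credentials for
# the resolved provider.
if __name__ == "__main__":
    demo_model = "openai/gpt-4"  # illustrative; pick a model registered in models.py
    demo_msgs = [{"role": "user", "content": "Say hello in one sentence."}]

    # Blocking call: returns the full reply at once.
    print(chat_completion(demo_model, demo_msgs, max_tokens=128))

    # Streaming call: prints tokens as they arrive.
    for piece in stream_chat_completion(demo_model, demo_msgs, max_tokens=128):
        print(piece, end="", flush=True)
    print()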