# inference.py
# -------------------------------------------------------------
# Unified wrapper around hf_client.get_inference_client
# with automatic provider-routing based on model registry
# (see models.py) and graceful fall-back to Groq.
# -------------------------------------------------------------

from __future__ import annotations

from typing import Dict, Generator, List, Optional

from hf_client import get_inference_client
from models import find_model


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_provider(model_id: str, override: str | None) -> str:
    """
    Decide which provider to use.

    Priority:
      1. Explicit *override* arg supplied by caller.
      2. Model registry default_provider (see models.py).
      3. "auto" – lets HF route to the first available provider.
    """
    if override:
        return override
    meta = find_model(model_id)
    return getattr(meta, "default_provider", "auto") if meta else "auto"


# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> str:
    """
    Blocking convenience wrapper – returns the full assistant reply.

    Parameters
    ----------
    model_id   : HF or provider-qualified model path (e.g. "openai/gpt-4").
    messages   : OpenAI-style [{'role': ..., 'content': ...}, …].
    provider   : Optional provider override; otherwise auto-resolved.
    max_tokens : Token budget for generation.
    kwargs     : Forward-compatible extra arguments (temperature, etc.).

    Returns
    -------
    str – assistant message content.
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    resp = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        **kwargs,
    )
    return resp.choices[0].message.content


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> Generator[str, None, None]:
    """
    Yield the assistant response *incrementally*.

    Example
    -------
    >>> for chunk in stream_chat_completion(model, msgs):
    ...     print(chunk, end='', flush=True)
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )
    # HF Inference returns chunks with .choices[0].delta.content
    for chunk in stream:
        delta: str | None = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
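

# ------------------------------------------------------------------
# Usage sketch (illustrative only)
# ------------------------------------------------------------------
# A minimal, hedged example of how the two public wrappers are meant to
# be called. The model id below is a placeholder assumption – substitute
# any model registered in models.py; it is not part of this module.
if __name__ == "__main__":
    demo_model = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model id
    demo_messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]

    # Blocking call – returns the full reply as a single string.
    print(chat_completion(demo_model, demo_messages, max_tokens=64))

    # Streaming call – prints the reply chunk by chunk as it arrives.
    for chunk in stream_chat_completion(demo_model, demo_messages, max_tokens=64):
        print(chunk, end="", flush=True)
    print()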