# inference.py
# -------------------------------------------------------------
# Unified wrapper around hf_client.get_inference_client
# with automatic provider‑routing based on model registry
# (see models.py) and graceful fall‑back to Groq.
# -------------------------------------------------------------
from __future__ import annotations
from typing import Dict, Generator, List, Optional
from hf_client import get_inference_client
from models import find_model
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_provider(model_id: str, override: str | None) -> str:
    """
    Decide which provider to use.

    Priority:
    1. Explicit *override* arg supplied by caller.
    2. Model registry default_provider (see models.py).
    3. "auto" – lets HF route to the first available provider.
    """
    if override:
        return override
    meta = find_model(model_id)
    return getattr(meta, "default_provider", "auto") if meta else "auto"
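
# Illustrative resolution behaviour (a sketch, assuming the entry returned by
# find_model exposes a default_provider attribute; the model ids and provider
# names below are placeholders, not entries from models.py):
#
#     _resolve_provider("some-org/some-model", "groq")  -> "groq"   (explicit override wins)
#     _resolve_provider("some-org/some-model", None)    -> registry default_provider, if set
#     _resolve_provider("unknown/model", None)          -> "auto"   (no registry entry)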
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> str:
    """
    Blocking convenience wrapper – returns the full assistant reply.

    Parameters
    ----------
    model_id : HF or provider‑qualified model path (e.g. "openai/gpt-4").
    messages : OpenAI‑style [{'role': ..., 'content': ...}, …].
    provider : Optional provider override; otherwise auto‑resolved.
    max_tokens : Token budget for generation.
    kwargs : Forward‑compatible extra arguments (temperature, etc.).

    Returns
    -------
    str
        The assistant message content.
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    resp = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        **kwargs,
    )
    return resp.choices[0].message.content
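
# Hypothetical usage sketch (not executed on import; the model id and message
# text are placeholders, and extra kwargs such as temperature are simply
# forwarded to the underlying client):
#
#     reply = chat_completion(
#         "openai/gpt-4",
#         [{"role": "user", "content": "Summarise this repo in one line."}],
#         temperature=0.2,
#     )
#     print(reply)
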
def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> Generator[str, None, None]:
    """
    Yield the assistant response *incrementally*.

    Example
    -------
    >>> for chunk in stream_chat_completion(model, msgs):
    ...     print(chunk, end='', flush=True)
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )
    # HF Inference returns chunks with .choices[0].delta.content
    for chunk in stream:
        delta: str | None = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
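
# ------------------------------------------------------------------
# Manual smoke test (a hedged sketch: assumes hf_client is configured with
# valid credentials; the model id below is a placeholder, not a registry
# entry guaranteed to exist in models.py)
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo_messages = [{"role": "user", "content": "Say hello in one short sentence."}]
    # Blocking call: prints the whole reply at once.
    print(chat_completion("openai/gpt-4", demo_messages, max_tokens=64))
    # Streaming call: prints tokens as they arrive.
    for piece in stream_chat_completion("openai/gpt-4", demo_messages, max_tokens=64):
        print(piece, end="", flush=True)
    print()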