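"""Chat-completion helpers for Hugging Face inference providers.

Thin wrappers around ``get_inference_client`` (see hf_client.py): a blocking
``chat_completion`` and a streaming ``stream_chat_completion``, with the
provider resolved from the model registry in models.py unless overridden.
"""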
from __future__ import annotations

from typing import Dict, Generator, List, Optional

from hf_client import get_inference_client
from models import find_model


def _resolve_provider(model_id: str, override: str | None) -> str:
    """
    Decide which provider to use.

    Priority:
    1. Explicit *override* arg supplied by caller.
    2. Model registry default_provider (see models.py).
    3. "auto" - lets HF route to the first available provider.
    """
    if override:
        return override

    meta = find_model(model_id)
    return getattr(meta, "default_provider", "auto") if meta else "auto"


def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> str:
    """
    Blocking convenience wrapper - returns the full assistant reply.

    Parameters
    ----------
    model_id : HF or provider-qualified model path (e.g. "openai/gpt-4").
    messages : OpenAI-style [{'role': ..., 'content': ...}, ...].
    provider : Optional provider override; otherwise auto-resolved.
    max_tokens : Token budget for generation.
    kwargs : Forward-compatible extra arguments (temperature, etc.).

    Returns
    -------
    str - assistant message content.
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    resp = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        **kwargs,
    )
    return resp.choices[0].message.content


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> Generator[str, None, None]:
    """
    Yield the assistant response *incrementally*.

    Example
    -------
    >>> for chunk in stream_chat_completion(model, msgs):
    ...     print(chunk, end='', flush=True)
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )

    for chunk in stream:
        delta: str | None = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
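

if __name__ == "__main__":
    # Minimal smoke-test sketch of both entry points. Assumes hf_client is
    # configured with valid credentials; the model id and prompt below are
    # illustrative placeholders, not part of the module's registry.
    demo_model = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical example id
    demo_messages = [{"role": "user", "content": "Say hello in one sentence."}]

    # Blocking call: returns the complete reply as a single string.
    print(chat_completion(demo_model, demo_messages, max_tokens=64))

    # Streaming call: prints tokens as they arrive.
    for piece in stream_chat_completion(demo_model, demo_messages, max_tokens=64):
        print(piece, end="", flush=True)
    print()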