# inference.py
# -------------------------------------------------------------
# Unified wrapper around hf_client.get_inference_client
# with automatic provider‑routing based on model registry
# (see models.py) and graceful fall‑back to Groq.
# -------------------------------------------------------------
from __future__ import annotations

from typing import Dict, Generator, List, Optional

from hf_client import get_inference_client
from models import find_model


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_provider(model_id: str, override: str | None) -> str:
    """
    Decide which provider to use.

    Priority:
    1. Explicit *override* arg supplied by caller.
    2. Model registry default_provider (see models.py).
    3. "auto" – lets HF route to the first available provider.
    """
    if override:
        return override

    meta = find_model(model_id)
    return getattr(meta, "default_provider", "auto") if meta else "auto"


# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> str:
    """
    Blocking convenience wrapper – returns the full assistant reply.

    Parameters
    ----------
    model_id      : HF or provider‑qualified model path (e.g. "openai/gpt-4").
    messages      : OpenAI‑style [{'role': ..., 'content': ...}, …].
    provider      : Optional provider override; otherwise auto‑resolved.
    max_tokens    : Token budget for generation.
    kwargs        : Forward‑compatible extra arguments (temperature, etc.).

    Returns
    -------
    str – assistant message content.
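
    Example
    -------
    (Illustrative – assumes credentials for the resolved provider are
    already configured.)

    >>> msgs = [{"role": "user", "content": "Say hello in one word."}]
    >>> reply = chat_completion("openai/gpt-4", msgs)
    >>> print(reply)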
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    resp = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        **kwargs,
    )
    return resp.choices[0].message.content


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
    **kwargs,
) -> Generator[str, None, None]:
    """
    Yield the assistant response *incrementally*.

    Example
    -------
    >>> for chunk in stream_chat_completion(model, msgs):
    ...     print(chunk, end='', flush=True)
    """
    client = get_inference_client(model_id, _resolve_provider(model_id, provider))
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )

    # HF Inference returns chunks with .choices[0].delta.content
    for chunk in stream:
        delta: str | None = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
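

# ------------------------------------------------------------------
# Manual smoke test – illustrative sketch only
# ------------------------------------------------------------------
# Minimal usage sketch, not part of the public API.  It assumes the model id
# is actually served by the resolved provider and that the credentials
# expected by hf_client.get_inference_client are already configured.
if __name__ == "__main__":
    demo_messages = [
        {"role": "user", "content": "In one sentence, what is streaming inference?"}
    ]

    # Blocking call – returns the full reply at once.
    print(chat_completion("openai/gpt-4", demo_messages, max_tokens=128))

    # Streaming call – print deltas as they arrive.
    for chunk in stream_chat_completion("openai/gpt-4", demo_messages, max_tokens=128):
        print(chunk, end="", flush=True)
    print()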