from typing import List, Dict, Generator, Optional

from hf_client import get_inference_client


def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
) -> str:
    """
    Send a chat completion request to the appropriate inference provider.

    Args:
        model_id: The model identifier to use.
        messages: A list of OpenAI-style {'role': '...', 'content': '...'} messages.
        provider: Optional override for provider; uses model default if None.
        max_tokens: Maximum tokens to generate.

    Returns:
        The assistant's response content.
    """
    client = get_inference_client(model_id, provider or "auto")
    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
    )
|
    return response.choices[0].message.content


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
) -> Generator[str, None, None]:
    """
    Stream a chat completion from the appropriate inference provider.

    Accepts the same arguments as `chat_completion` and yields partial
    message chunks as strings as they arrive.
    """
    client = get_inference_client(model_id, provider or "auto")
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
    )
    for chunk in stream:
        # Skip chunks that carry no text (e.g. role-only deltas) and yield
        # only non-empty content pieces.
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
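

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: it assumes that
    # `hf_client.get_inference_client` can reach a provider for the model id
    # below, which is an illustrative placeholder you would replace.
    demo_model = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical example id
    demo_messages = [{"role": "user", "content": "Say hello in one sentence."}]

    # Blocking call: returns the full assistant reply at once.
    print(chat_completion(demo_model, demo_messages))

    # Streaming call: print partial chunks as they arrive.
    for piece in stream_chat_completion(demo_model, demo_messages):
        print(piece, end="", flush=True)
    print()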
|
|