from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelInfo:
    """
    Represents metadata for an inference model.

    Attributes:
        name: Human-readable name of the model.
        id: Unique model identifier (HF/externally routed).
        description: Short description of the model's capabilities.
        default_provider: Preferred inference provider ("auto", "groq", "openai",
            "gemini", "fireworks").
    """

    name: str
    id: str
    description: str
    default_provider: str = "auto"


AVAILABLE_MODELS: List[ModelInfo] = [
    ModelInfo(
        name="Moonshot Kimi-K2",
        id="moonshotai/Kimi-K2-Instruct",
        description="Moonshot AI Kimi-K2-Instruct model for code generation and general tasks",
        default_provider="groq",
    ),
    ModelInfo(
        name="DeepSeek V3",
        id="deepseek-ai/DeepSeek-V3-0324",
        description="DeepSeek V3 model for code generation",
    ),
    ModelInfo(
        name="DeepSeek R1",
        id="deepseek-ai/DeepSeek-R1-0528",
        description="DeepSeek R1 model for code generation",
    ),
    ModelInfo(
        name="ERNIE-4.5-VL",
        id="baidu/ERNIE-4.5-VL-424B-A47B-Base-PT",
        description="ERNIE-4.5-VL model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="MiniMax M1",
        id="MiniMaxAI/MiniMax-M1-80k",
        description="MiniMax M1 model for code generation and general tasks",
    ),
    ModelInfo(
        name="Qwen3-235B-A22B",
        id="Qwen/Qwen3-235B-A22B",
        description="Qwen3-235B-A22B model for code generation and general tasks",
    ),
    ModelInfo(
        name="SmolLM3-3B",
        id="HuggingFaceTB/SmolLM3-3B",
        description="SmolLM3-3B model for code generation and general tasks",
    ),
    ModelInfo(
        name="GLM-4.1V-9B-Thinking",
        id="THUDM/GLM-4.1V-9B-Thinking",
        description="GLM-4.1V-9B-Thinking model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="OpenAI GPT-4",
        id="openai/gpt-4",
        description="OpenAI GPT-4 model via HF Inference Providers",
        default_provider="openai",
    ),
    ModelInfo(
        name="Gemini Pro",
        id="gemini/pro",
        description="Google Gemini Pro model via HF Inference Providers",
        default_provider="gemini",
    ),
    ModelInfo(
        name="Fireworks AI",
        id="fireworks-ai/fireworks-v1",
        description="Fireworks AI model via HF Inference Providers",
        default_provider="fireworks",
    ),
]


def find_model(identifier: str) -> Optional[ModelInfo]:
    """
    Look up a model by its human-readable name or its identifier.

    Args:
        identifier: A ModelInfo.name (case-insensitive) or a ModelInfo.id.

    Returns:
        The matching ModelInfo, or None if no model matches.
    """
    identifier_lower = identifier.lower()
    for model in AVAILABLE_MODELS:
        if model.id == identifier or model.name.lower() == identifier_lower:
            return model
    return None
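
# Illustrative usage (a minimal sketch using the entries registered above):
#
#     model = find_model("deepseek v3")        # name lookup is case-insensitive
#     if model is not None:
#         print(model.id)                      # -> "deepseek-ai/DeepSeek-V3-0324"
#
#     by_id = find_model("openai/gpt-4")       # exact id lookup
#     provider = by_id.default_provider if by_id else "auto"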


from typing import Dict, List, Optional

from hf_client import get_inference_client


def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
) -> str:
    """
    Send a chat completion request to the appropriate inference provider.

    Args:
        model_id: The model identifier to use.
        messages: A list of OpenAI-style {"role", "content"} messages.
        provider: Optional provider override; falls back to "auto" if None.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The assistant's response content.
    """
    client = get_inference_client(model_id, provider or "auto")
    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
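
# Illustrative usage (a sketch; assumes hf_client.get_inference_client returns an
# OpenAI-compatible client and that credentials for the chosen provider are configured):
#
#     reply = chat_completion(
#         model_id="deepseek-ai/DeepSeek-V3-0324",
#         messages=[{"role": "user", "content": "Write a Python hello world."}],
#     )
#     print(reply)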


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
):
    """
    Generator for streaming chat completions.

    Yields partial message chunks as strings.
    """
    client = get_inference_client(model_id, provider or "auto")
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
    )
    for chunk in stream:
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
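
# Illustrative usage (a sketch; same assumptions as chat_completion above):
#
#     for token in stream_chat_completion(
#         model_id="Qwen/Qwen3-235B-A22B",
#         messages=[{"role": "user", "content": "Explain generators in one line."}],
#     ):
#         print(token, end="", flush=True)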