# models.py
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelInfo:
    """
    Represents metadata for an inference model.

    Attributes:
        name: Human-readable name of the model.
        id: Unique model identifier (HF/externally routed).
        description: Short description of the model's capabilities.
        default_provider: Preferred inference provider ("auto", "groq", "openai", "gemini", "fireworks").
    """

    name: str
    id: str
    description: str
    default_provider: str = "auto"


# Registry of supported models
AVAILABLE_MODELS: List[ModelInfo] = [
    ModelInfo(
        name="Moonshot Kimi-K2",
        id="moonshotai/Kimi-K2-Instruct",
        description="Moonshot AI Kimi-K2-Instruct model for code generation and general tasks",
        default_provider="groq",
    ),
    ModelInfo(
        name="DeepSeek V3",
        id="deepseek-ai/DeepSeek-V3-0324",
        description="DeepSeek V3 model for code generation",
    ),
    ModelInfo(
        name="DeepSeek R1",
        id="deepseek-ai/DeepSeek-R1-0528",
        description="DeepSeek R1 model for code generation",
    ),
    ModelInfo(
        name="ERNIE-4.5-VL",
        id="baidu/ERNIE-4.5-VL-424B-A47B-Base-PT",
        description="ERNIE-4.5-VL model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="MiniMax M1",
        id="MiniMaxAI/MiniMax-M1-80k",
        description="MiniMax M1 model for code generation and general tasks",
    ),
    ModelInfo(
        name="Qwen3-235B-A22B",
        id="Qwen/Qwen3-235B-A22B",
        description="Qwen3-235B-A22B model for code generation and general tasks",
    ),
    ModelInfo(
        name="SmolLM3-3B",
        id="HuggingFaceTB/SmolLM3-3B",
        description="SmolLM3-3B model for code generation and general tasks",
    ),
    ModelInfo(
        name="GLM-4.1V-9B-Thinking",
        id="THUDM/GLM-4.1V-9B-Thinking",
        description="GLM-4.1V-9B-Thinking model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="OpenAI GPT-4",
        id="openai/gpt-4",
        description="OpenAI GPT-4 model via HF Inference Providers",
        default_provider="openai",
    ),
    ModelInfo(
        name="Gemini Pro",
        id="gemini/pro",
        description="Google Gemini Pro model via HF Inference Providers",
        default_provider="gemini",
    ),
    ModelInfo(
        name="Fireworks AI",
        id="fireworks-ai/fireworks-v1",
        description="Fireworks AI model via HF Inference Providers",
        default_provider="fireworks",
    ),
]
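

# Example of registering an extra entry at runtime (illustrative sketch only;
# the model id below is hypothetical and simply shows the expected shape of an
# entry a host application might append):
#
#     AVAILABLE_MODELS.append(
#         ModelInfo(
#             name="My Custom Model",
#             id="my-org/my-custom-model",
#             description="Custom model entry added by the host application",
#             default_provider="auto",
#         )
#     )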


def find_model(identifier: str) -> Optional[ModelInfo]:
    """
    Look up a model by its human-readable name or identifier.

    Args:
        identifier: ModelInfo.name (case-insensitive) or ModelInfo.id.

    Returns:
        The matching ModelInfo, or None if not found.
    """
    identifier_lower = identifier.lower()
    for model in AVAILABLE_MODELS:
        if model.id == identifier or model.name.lower() == identifier_lower:
            return model
    return None
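
# Example lookup (illustrative sketch; purely local, no network access needed):
#
#     model = find_model("deepseek v3")                      # match by name, case-insensitive
#     same = find_model("deepseek-ai/DeepSeek-V3-0324")      # match by exact id
#     assert model is same and model.default_provider == "auto"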


# inference.py
from typing import Dict, List, Optional

from hf_client import get_inference_client


def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
) -> str:
    """
    Send a chat completion request to the appropriate inference provider.

    Args:
        model_id: The model identifier to use.
        messages: A list of OpenAI-style {'role': ..., 'content': ...} messages.
        provider: Optional provider override; uses the model's default if None.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The assistant's response content.
    """
    # Initialize the client (provider resolution happens inside get_inference_client)
    client = get_inference_client(model_id, provider or "auto")
    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
    )
    # Extract and return the first choice's content
    return response.choices[0].message.content
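
# Example call (illustrative sketch; assumes hf_client.get_inference_client can
# authenticate with the chosen provider, e.g. via an HF token in the environment):
#
#     answer = chat_completion(
#         model_id="Qwen/Qwen3-235B-A22B",
#         messages=[{"role": "user", "content": "Write a Python hello world."}],
#     )
#     print(answer)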


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
):
    """
    Generator for streaming chat completions.

    Yields partial message chunks as strings.
    """
    client = get_inference_client(model_id, provider or "auto")
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
    )
    for chunk in stream:
        # Each streamed chunk carries an incremental delta; skip empty deltas.
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
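

if __name__ == "__main__":
    # Minimal streaming demo (illustrative sketch; assumes valid provider
    # credentials are configured for hf_client.get_inference_client and that
    # the chosen model id is available to your account).
    demo_messages = [{"role": "user", "content": "Say hello in one short sentence."}]
    for piece in stream_chat_completion(
        "deepseek-ai/DeepSeek-V3-0324", demo_messages, max_tokens=64
    ):
        print(piece, end="", flush=True)
    print()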