# models.py
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ModelInfo:
    """
    Represents metadata for an inference model.

    Attributes:
        name: Human-readable name of the model.
        id: Unique model identifier (HF/externally routed).
        description: Short description of the model's capabilities.
        default_provider: Preferred inference provider ("auto", "groq", "openai", "gemini", "fireworks").
    """
    name: str
    id: str
    description: str
    default_provider: str = "auto"

# Registry of supported models
AVAILABLE_MODELS: List[ModelInfo] = [
    ModelInfo(
        name="Moonshot Kimi-K2",
        id="moonshotai/Kimi-K2-Instruct",
        description="Moonshot AI Kimi-K2-Instruct model for code generation and general tasks",
        default_provider="groq"
    ),
    ModelInfo(
        name="DeepSeek V3",
        id="deepseek-ai/DeepSeek-V3-0324",
        description="DeepSeek V3 model for code generation",
    ),
    ModelInfo(
        name="DeepSeek R1",
        id="deepseek-ai/DeepSeek-R1-0528",
        description="DeepSeek R1 model for code generation",
    ),
    ModelInfo(
        name="ERNIE-4.5-VL",
        id="baidu/ERNIE-4.5-VL-424B-A47B-Base-PT",
        description="ERNIE-4.5-VL model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="MiniMax M1",
        id="MiniMaxAI/MiniMax-M1-80k",
        description="MiniMax M1 model for code generation and general tasks",
    ),
    ModelInfo(
        name="Qwen3-235B-A22B",
        id="Qwen/Qwen3-235B-A22B",
        description="Qwen3-235B-A22B model for code generation and general tasks",
    ),
    ModelInfo(
        name="SmolLM3-3B",
        id="HuggingFaceTB/SmolLM3-3B",
        description="SmolLM3-3B model for code generation and general tasks",
    ),
    ModelInfo(
        name="GLM-4.1V-9B-Thinking",
        id="THUDM/GLM-4.1V-9B-Thinking",
        description="GLM-4.1V-9B-Thinking model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="OpenAI GPT-4",
        id="openai/gpt-4",
        description="OpenAI GPT-4 model via HF Inference Providers",
        default_provider="openai"
    ),
    ModelInfo(
        name="Gemini Pro",
        id="gemini/pro",
        description="Google Gemini Pro model via HF Inference Providers",
        default_provider="gemini"
    ),
    ModelInfo(
        name="Fireworks AI",
        id="fireworks-ai/fireworks-v1",
        description="Fireworks AI model via HF Inference Providers",
        default_provider="fireworks"
    ),
]


def find_model(identifier: str) -> Optional[ModelInfo]:
    """
    Look up a model by its human-readable name or identifier.

    Args:
        identifier: ModelInfo.name (case-insensitive) or ModelInfo.id
    Returns:
        The matching ModelInfo or None if not found.
    """
    identifier_lower = identifier.lower()
    for model in AVAILABLE_MODELS:
        if model.id == identifier or model.name.lower() == identifier_lower:
            return model
    return None
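
# Example lookups -- matching is case-insensitive on the display name and
# exact on the model id:
#   find_model("deepseek v3")                   # -> the DeepSeek V3 entry
#   find_model("deepseek-ai/DeepSeek-V3-0324")  # -> the same entry
#   find_model("not-a-model")                   # -> None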


# inference.py
from typing import Dict, Iterator, List, Optional
from hf_client import get_inference_client
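
# `hf_client` is this project's own wrapper module and is not shown here.
# The sketch below is an assumption for illustration only, not the actual
# implementation: it resolves "auto" against the registry in models.py and
# returns a huggingface_hub.InferenceClient, whose OpenAI-compatible
# `chat.completions.create` interface the functions below rely on. The name
# `_sketch_get_inference_client` is hypothetical and deliberately distinct
# so it does not shadow the real import above.
def _sketch_get_inference_client(model_id: str, provider: str = "auto"):
    import os
    from huggingface_hub import InferenceClient
    resolved = provider
    if provider == "auto":
        # In a standalone inference.py this would be: from models import find_model
        model = find_model(model_id)
        if model is not None:
            resolved = model.default_provider
    return InferenceClient(provider=resolved, api_key=os.environ.get("HF_TOKEN"))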


def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096
) -> str:
    """
    Send a chat completion request to the appropriate inference provider.

    Args:
        model_id: The model identifier to use.
        messages: A list of OpenAI-style {'role','content'} messages.
        provider: Optional override for provider; uses model default if None.
        max_tokens: Maximum tokens to generate.

    Returns:
        The assistant's response content.
    """
    # Initialize client (provider resolution inside)
    client = get_inference_client(model_id, provider or "auto")
    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens
    )
    # Extract and return first choice content
    return response.choices[0].message.content
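
# Example call (illustrative; assumes hf_client can authenticate, e.g. via
# an HF_TOKEN environment variable, and uses a model id from the registry):
#   reply = chat_completion(
#       "moonshotai/Kimi-K2-Instruct",
#       [{"role": "user", "content": "Explain Python list comprehensions."}],
#   )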


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096
) -> Iterator[str]:
    """
    Generator for streaming chat completions.
    Yields partial message chunks as strings.
    """
    client = get_inference_client(model_id, provider or "auto")
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True
    )
    for chunk in stream:
        # Some providers interleave keep-alive or usage-only chunks that
        # carry an empty `choices` list; skip those defensively.
        if not chunk.choices:
            continue
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
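

# Minimal streaming demo, runnable as a script. Purely illustrative: it
# assumes working credentials for hf_client and picks a registry model.
if __name__ == "__main__":
    model = find_model("DeepSeek V3")
    assert model is not None
    for piece in stream_chat_completion(
        model.id,
        [{"role": "user", "content": "Stream a two-line poem."}],
    ):
        print(piece, end="", flush=True)
    print()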