mgbam committed
Commit dcf9dad · verified · 1 parent: e7d5ce8

Create models.py

Files changed (1):
  models.py +155 -0
models.py ADDED
@@ -0,0 +1,155 @@
+ # models.py
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+
+ @dataclass
+ class ModelInfo:
+     """
+     Represents metadata for an inference model.
+
+     Attributes:
+         name: Human-readable name of the model.
+         id: Unique model identifier (HF/externally routed).
+         description: Short description of the model's capabilities.
+         default_provider: Preferred inference provider
+             ("auto", "groq", "openai", "gemini", "fireworks").
+     """
+     name: str
+     id: str
+     description: str
+     default_provider: str = "auto"
+
+
+ # Registry of supported models
+ AVAILABLE_MODELS: List[ModelInfo] = [
+     ModelInfo(
+         name="Moonshot Kimi-K2",
+         id="moonshotai/Kimi-K2-Instruct",
+         description="Moonshot AI Kimi-K2-Instruct model for code generation and general tasks",
+         default_provider="groq",
+     ),
+     ModelInfo(
+         name="DeepSeek V3",
+         id="deepseek-ai/DeepSeek-V3-0324",
+         description="DeepSeek V3 model for code generation",
+     ),
+     ModelInfo(
+         name="DeepSeek R1",
+         id="deepseek-ai/DeepSeek-R1-0528",
+         description="DeepSeek R1 model for code generation",
+     ),
+     ModelInfo(
+         name="ERNIE-4.5-VL",
+         id="baidu/ERNIE-4.5-VL-424B-A47B-Base-PT",
+         description="ERNIE-4.5-VL model for multimodal code generation with image support",
+     ),
+     ModelInfo(
+         name="MiniMax M1",
+         id="MiniMaxAI/MiniMax-M1-80k",
+         description="MiniMax M1 model for code generation and general tasks",
+     ),
+     ModelInfo(
+         name="Qwen3-235B-A22B",
+         id="Qwen/Qwen3-235B-A22B",
+         description="Qwen3-235B-A22B model for code generation and general tasks",
+     ),
+     ModelInfo(
+         name="SmolLM3-3B",
+         id="HuggingFaceTB/SmolLM3-3B",
+         description="SmolLM3-3B model for code generation and general tasks",
+     ),
+     ModelInfo(
+         name="GLM-4.1V-9B-Thinking",
+         id="THUDM/GLM-4.1V-9B-Thinking",
+         description="GLM-4.1V-9B-Thinking model for multimodal code generation with image support",
+     ),
+     ModelInfo(
+         name="OpenAI GPT-4",
+         id="openai/gpt-4",
+         description="OpenAI GPT-4 model via HF Inference Providers",
+         default_provider="openai",
+     ),
+     ModelInfo(
+         name="Gemini Pro",
+         id="gemini/pro",
+         description="Google Gemini Pro model via HF Inference Providers",
+         default_provider="gemini",
+     ),
+     ModelInfo(
+         name="Fireworks AI",
+         id="fireworks-ai/fireworks-v1",
+         description="Fireworks AI model via HF Inference Providers",
+         default_provider="fireworks",
+     ),
+ ]
+
+
+ def find_model(identifier: str) -> Optional[ModelInfo]:
+     """
+     Look up a model by its human-readable name or identifier.
+
+     Args:
+         identifier: ModelInfo.name (case-insensitive) or ModelInfo.id.
+
+     Returns:
+         The matching ModelInfo, or None if not found.
+     """
+     identifier_lower = identifier.lower()
+     for model in AVAILABLE_MODELS:
+         if model.id == identifier or model.name.lower() == identifier_lower:
+             return model
+     return None
+
+
+ # inference.py
+ from typing import Dict, Iterator, List, Optional
+
+ from hf_client import get_inference_client
+
+
+ def chat_completion(
+     model_id: str,
+     messages: List[Dict[str, str]],
+     provider: Optional[str] = None,
+     max_tokens: int = 4096,
+ ) -> str:
+     """
+     Send a chat completion request to the appropriate inference provider.
+
+     Args:
+         model_id: The model identifier to use.
+         messages: A list of OpenAI-style {'role', 'content'} messages.
+         provider: Optional provider override; uses the model default if None.
+         max_tokens: Maximum number of tokens to generate.
+
+     Returns:
+         The assistant's response content.
+     """
+     # Initialize the client (provider resolution happens inside)
+     client = get_inference_client(model_id, provider or "auto")
+     response = client.chat.completions.create(
+         model=model_id,
+         messages=messages,
+         max_tokens=max_tokens,
+     )
+     # Extract and return the first choice's content
+     return response.choices[0].message.content
+
+
+ def stream_chat_completion(
+     model_id: str,
+     messages: List[Dict[str, str]],
+     provider: Optional[str] = None,
+     max_tokens: int = 4096,
+ ) -> Iterator[str]:
+     """
+     Generator for streaming chat completions.
+
+     Yields partial message chunks as strings.
+     """
+     client = get_inference_client(model_id, provider or "auto")
+     stream = client.chat.completions.create(
+         model=model_id,
+         messages=messages,
+         max_tokens=max_tokens,
+         stream=True,
+     )
+     for chunk in stream:
+         delta = getattr(chunk.choices[0].delta, "content", None)
+         if delta:
+             yield delta
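
For reference, a minimal usage sketch of the helpers this commit adds. It is illustrative only: it assumes the committed models.py is importable as-is (the "# inference.py" block has not been split into its own module), and that hf_client.get_inference_client returns an OpenAI-compatible client with working provider credentials, as the calls above imply.

# example_usage.py (hypothetical, not part of this commit)
from models import find_model, chat_completion, stream_chat_completion

# Look up a registry entry by human-readable name (case-insensitive)
model = find_model("Moonshot Kimi-K2")
assert model is not None, "model not found in AVAILABLE_MODELS"

messages = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": "Write a Python function that reverses a string."},
]

# One-shot completion, honoring the model's preferred provider ("groq" here)
reply = chat_completion(model.id, messages, provider=model.default_provider)
print(reply)

# Streaming variant: print partial chunks as they arrive
for chunk in stream_chat_completion(model.id, messages):
    print(chunk, end="", flush=True)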