Update hf_client.py
Browse files- hf_client.py +23 -5
hf_client.py
CHANGED
@@ -5,18 +5,36 @@ from tavily import TavilyClient
|
|
5 |
import os
|
6 |
|
7 |
# HF Inference Client
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
if not HF_TOKEN:
|
10 |
-
raise RuntimeError(
|
|
|
|
|
|
|
11 |
|
12 |
-
def get_inference_client(model_id, provider="auto"):
|
13 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
if model_id == "moonshotai/Kimi-K2-Instruct":
|
15 |
provider = "groq"
|
|
|
|
|
|
|
|
|
16 |
return InferenceClient(
|
17 |
provider=provider,
|
18 |
api_key=HF_TOKEN,
|
19 |
-
bill_to=
|
20 |
)
|
21 |
|
22 |
# Tavily Search Client
|
|
|
5 |
import os
|
6 |
|
7 |
# HF Inference Client
|
8 |
+
|
9 |
+
# Supported billing targets
|
10 |
+
_VALID_BILL_TO = {"huggingface", "fairworksai", "groq"}
|
11 |
+
|
12 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
13 |
if not HF_TOKEN:
|
14 |
+
raise RuntimeError(
|
15 |
+
"HF_TOKEN environment variable is not set. "
|
16 |
+
"Please set it to your Hugging Face API token."
|
17 |
+
)
|
18 |
|
19 |
+
def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
|
20 |
+
"""
|
21 |
+
Return an InferenceClient configured with the correct provider and billing target.
|
22 |
+
|
23 |
+
- If model_id == "moonshotai/Kimi-K2-Instruct", force provider to "groq".
|
24 |
+
- If the requested provider is not one of the supported billing targets,
|
25 |
+
default billing to "groq".
|
26 |
+
"""
|
27 |
+
# force certain models onto groq hardware
|
28 |
if model_id == "moonshotai/Kimi-K2-Instruct":
|
29 |
provider = "groq"
|
30 |
+
|
31 |
+
# determine billing target
|
32 |
+
bill_to = provider if provider in _VALID_BILL_TO else "groq"
|
33 |
+
|
34 |
return InferenceClient(
|
35 |
provider=provider,
|
36 |
api_key=HF_TOKEN,
|
37 |
+
bill_to=bill_to
|
38 |
)
|
39 |
|
40 |
# Tavily Search Client
|