Update hf_client.py
Browse files- hf_client.py +23 -5
hf_client.py
CHANGED
|
@@ -5,18 +5,36 @@ from tavily import TavilyClient
|
|
| 5 |
import os
|
| 6 |
|
| 7 |
# HF Inference Client
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
if not HF_TOKEN:
|
| 10 |
-
raise RuntimeError(
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
def get_inference_client(model_id, provider="auto"):
|
| 13 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
if model_id == "moonshotai/Kimi-K2-Instruct":
|
| 15 |
provider = "groq"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
return InferenceClient(
|
| 17 |
provider=provider,
|
| 18 |
api_key=HF_TOKEN,
|
| 19 |
-
bill_to=
|
| 20 |
)
|
| 21 |
|
| 22 |
# Tavily Search Client
|
|
|
|
| 5 |
import os
|
| 6 |
|
| 7 |
# HF Inference Client
|
| 8 |
+
|
| 9 |
+
# Names accepted as a billing target; anything else falls back to "groq"
# inside get_inference_client.
_VALID_BILL_TO = {"huggingface", "fairworksai", "groq"}

# Hugging Face API token, read once at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Fail fast at import: without a token no inference client can authenticate.
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
    )
|
| 18 |
|
| 19 |
+
def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """
    Build an InferenceClient with a resolved provider and billing target.

    Args:
        model_id: Model identifier. "moonshotai/Kimi-K2-Instruct" always
            overrides the requested provider with "groq"; the value is not
            otherwise passed to the client.
        provider: Requested inference provider; defaults to "auto".

    Returns:
        An InferenceClient created with the resolved provider, the module-level
        HF_TOKEN as API key, and a billing target equal to the provider when it
        is listed in _VALID_BILL_TO, or "groq" otherwise (this includes the
        default "auto").

    NOTE(review): huggingface_hub appears to treat ``bill_to`` as the name of
    a billing organization rather than a provider -- confirm that the entries
    of _VALID_BILL_TO are organizations this token is allowed to bill.
    """
    # Kimi-K2 is pinned to groq regardless of what the caller asked for.
    resolved_provider = (
        "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else provider
    )

    # Bill the provider itself only when it is a known billing target.
    if resolved_provider in _VALID_BILL_TO:
        billing_target = resolved_provider
    else:
        billing_target = "groq"

    return InferenceClient(
        provider=resolved_provider,
        api_key=HF_TOKEN,
        bill_to=billing_target,
    )
|
| 39 |
|
| 40 |
# Tavily Search Client
|