# hf_client.py
import os
from typing import Optional
from huggingface_hub import InferenceClient
from tavily import TavilyClient
# Supported billing targets (your secrets)
_VALID_BILL_TO = {
    "huggingface",   # HF_TOKEN
    "fairworksai",   # if you have a special billing target
    "groq",          # GROQ_API_KEY
    "openai",        # OPENAI_API_KEY
    "gemini",        # GEMINI_API_KEY
    "fireworks",     # FIREWORKS_API_KEY
}
# Load your HF token
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN is not set. Please configure your Hugging Face token.")
# (Optional) Tavily client for web search
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily_client: Optional[TavilyClient] = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        print(f"❗ Failed to init Tavily: {e}")
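# A minimal, optional search helper built on the Tavily client above. This is a
# sketch, not part of the original module: the name `web_search` is ours, and it
# assumes the standard tavily-python `search(query=..., max_results=...)` call,
# which returns a dict with a "results" list.
def web_search(query: str, max_results: int = 5) -> list:
    """Return Tavily search results, or an empty list if the client is unavailable."""
    if tavily_client is None:
        return []
    try:
        response = tavily_client.search(query=query, max_results=max_results)
        return response.get("results", [])
    except Exception as e:
        print(f"❗ Tavily search failed: {e}")
        return []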
def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
"""
Selects the appropriate inference provider & billing target:
β€’ moonshotai/Kimi-K2-Instruct β†’ groq
β€’ openai/... or GPT family β†’ openai
β€’ gemini/... or google/... β†’ gemini
β€’ fireworks/... β†’ fireworks
β€’ otherwise β†’ huggingface (billing to groq)
"""
    # override by model prefix
    if model_id == "moonshotai/Kimi-K2-Instruct":
        provider = "groq"
    elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        provider = "openai"
    elif model_id.startswith("gemini/") or model_id.startswith("google/"):
        provider = "gemini"
    elif model_id.startswith("fireworks/"):
        provider = "fireworks"
    else:
        provider = "huggingface"

    # ensure billing target is valid, else fall back to groq
    bill_to = provider if provider in _VALID_BILL_TO else "groq"

    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to=bill_to,
    )
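

# A minimal usage sketch, not part of the original module. It assumes the routed
# provider accepts huggingface_hub's OpenAI-style `chat_completion()` call and
# that the example model id is valid for that provider.
if __name__ == "__main__":
    client = get_inference_client("moonshotai/Kimi-K2-Instruct")
    completion = client.chat_completion(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        model="moonshotai/Kimi-K2-Instruct",
        max_tokens=64,
    )
    print(completion.choices[0].message.content)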