# hf_client.py
import os
from typing import Optional
from huggingface_hub import InferenceClient
from tavily import TavilyClient

# Supported billing targets
_VALID_BILL_TO = {"huggingface", "fairworksai", "groq", "openai", "gemini", "fireworks", "googler"}

# Load Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. "
        "Please set it to your Hugging Face API token."
    )

# Initialize Tavily search client (optional)
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily_client: Optional[TavilyClient] = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        print(f"Failed to initialize Tavily client: {e}")
        tavily_client = None
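
# A minimal, hedged usage sketch for the optional Tavily client. The helper name,
# query handling, and printed fields are illustrative assumptions, not part of the
# original module; TavilyClient.search() is the Tavily SDK's search call and returns
# a dict with a "results" list.
def _tavily_search_example(query: str) -> None:
    if tavily_client is None:
        print("Tavily client not configured (set TAVILY_API_KEY to enable search).")
        return
    response = tavily_client.search(query)
    for hit in response.get("results", []):
        print(hit.get("title"), "-", hit.get("url"))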

def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """
    Configure InferenceClient based on model_id prefixes:
      - moonshotai/Kimi-K2-Instruct → provider "groq"
      - openai/... or GPT names     → provider "openai"
      - gemini/... or google/...    → provider "gemini"
      - fireworks/...               → provider "fireworks"
      - otherwise                   → use HF Inference provider ("auto")
    Bill to the chosen provider if valid; otherwise, default to "groq".
    """
    # Override provider by model
    if model_id == "moonshotai/Kimi-K2-Instruct":
        provider = "groq"
    elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        provider = "openai"
    elif model_id.startswith("gemini/") or model_id.startswith("google/"):
        provider = "gemini"
    elif model_id.startswith("fireworks/"):
        provider = "fireworks"
    else:
        # For Hugging Face models, let HF Inference pick the best provider
        provider = "auto"

    # Determine billing target
    bill_to = provider if provider in _VALID_BILL_TO else "groq"

    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to=bill_to,
    )
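
# A minimal, hedged usage sketch of get_inference_client. The model id, prompt, and
# token limit below are assumptions for illustration; chat_completion() is the
# huggingface_hub InferenceClient chat API, and a valid HF_TOKEN with billing access
# to the selected provider is required for this to actually run.
if __name__ == "__main__":
    demo_model = "moonshotai/Kimi-K2-Instruct"  # assumed model id for the demo
    client = get_inference_client(demo_model)
    completion = client.chat_completion(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        model=demo_model,
        max_tokens=64,
    )
    print(completion.choices[0].message.content)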