import os
from typing import Optional

from huggingface_hub import InferenceClient
from tavily import TavilyClient

_VALID_BILL_TO = {
    "huggingface",
    "fairworksai",
    "groq",
    "openai",
    "gemini",
    "fireworks",
}

HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN is not set. Please configure your Hugging Face token.")

TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily_client: Optional[TavilyClient] = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        print(f"Warning: failed to init Tavily: {e}")
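

# A minimal sketch of how the optional Tavily client could be used for web search.
# The helper name `web_search`, its signature, and the result handling are
# illustrative assumptions rather than something the rest of this module requires;
# TavilyClient.search() is the documented Tavily call.
def web_search(query: str, max_results: int = 5) -> list:
    """Return Tavily search results for `query`, or an empty list if Tavily is unavailable."""
    if tavily_client is None:
        return []
    try:
        response = tavily_client.search(query=query, max_results=max_results)
        return response.get("results", [])
    except Exception as e:
        print(f"Tavily search failed: {e}")
        return []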


def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """
    Select the appropriate inference provider and billing target:
      • moonshotai/Kimi-K2-Instruct → groq
      • openai/... or the GPT family → openai
      • gemini/... or google/...     → gemini
      • fireworks/...                → fireworks
      • anything else                → huggingface

    Billing (`bill_to`) uses the selected provider when it appears in
    _VALID_BILL_TO and falls back to "groq" otherwise.
    """
    if model_id == "moonshotai/Kimi-K2-Instruct":
        provider = "groq"
    elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        provider = "openai"
    elif model_id.startswith("gemini/") or model_id.startswith("google/"):
        provider = "gemini"
    elif model_id.startswith("fireworks/"):
        provider = "fireworks"
    else:
        provider = "huggingface"

    # Bill the selected provider only if it is an approved billing target;
    # otherwise fall back to billing "groq".
    bill_to = provider if provider in _VALID_BILL_TO else "groq"

    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to=bill_to,
    )
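

# Example usage, as a sketch: the prompt and max_tokens below are illustrative
# assumptions, not values this module mandates. chat_completion() is the
# huggingface_hub InferenceClient chat API.
if __name__ == "__main__":
    client = get_inference_client("moonshotai/Kimi-K2-Instruct")
    completion = client.chat_completion(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        model="moonshotai/Kimi-K2-Instruct",
        max_tokens=64,
    )
    print(completion.choices[0].message.content)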