|
|
|
|
|
import os |
|
from typing import Optional |
|
from huggingface_hub import InferenceClient |
|
from tavily import TavilyClient |
|
|
|
|
|
_VALID_BILL_TO = {"huggingface", "fairworksai", "groq", "openai", "gemini", "fireworks", "googler"} |
|
|
|
|
|
# Fail fast at import time: every InferenceClient built below needs this token.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:  # rejects both a missing and an empty-string token
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
    )
|
|
|
|
|
# Tavily search is optional: no key (or a failed init) simply leaves the
# client as None, and callers are expected to check for that.
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')

tavily_client: Optional[TavilyClient] = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as err:
        # Best-effort init: a broken Tavily setup must not take the app down.
        print(f"Failed to initialize Tavily client: {err}")
        tavily_client = None
|
|
|
|
|
def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """
    Build an InferenceClient whose provider is derived from model_id:

    - moonshotai/Kimi-K2-Instruct  -> "groq"
    - openai/... or bare GPT names -> "openai"
    - gemini/... or google/...     -> "gemini"
    - fireworks/...                -> "fireworks"
    - anything else                -> "auto" (HF Inference routing)

    Billing targets the resolved provider when it is in _VALID_BILL_TO;
    otherwise billing falls back to "groq".

    NOTE(review): the `provider` argument is overwritten on every path, so a
    caller-supplied value is effectively ignored — confirm this is intended.
    """
    if model_id == "moonshotai/Kimi-K2-Instruct":
        resolved = "groq"
    elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        resolved = "openai"
    elif model_id.startswith(("gemini/", "google/")):
        resolved = "gemini"
    elif model_id.startswith("fireworks/"):
        resolved = "fireworks"
    else:
        # No recognized prefix: defer routing to the HF Inference provider.
        resolved = "auto"

    billing_target = resolved if resolved in _VALID_BILL_TO else "groq"

    return InferenceClient(provider=resolved, api_key=HF_TOKEN, bill_to=billing_target)
|
|