File size: 2,010 Bytes
7234f57
f960c36
 
7234f57
 
 
3e19edc
7234f57
 
 
 
 
 
 
 
 
 
 
ac4a3a2
3e19edc
7234f57
 
 
 
 
 
 
 
 
 
3e19edc
ac4a3a2
 
7234f57
 
 
 
 
 
ac4a3a2
7234f57
3e19edc
 
7234f57
 
 
 
 
 
 
 
 
 
ac4a3a2
 
3e19edc
 
 
ac4a3a2
3e19edc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# hf_client.py

import os
from typing import Optional
from huggingface_hub import InferenceClient
from tavily import TavilyClient

# Names accepted as a billing target. The note beside each entry is the secret
# the original author associated with it.
_VALID_BILL_TO = {
    "fairworksai",  # special billing target, if you have one
    "fireworks",    # FIREWORKS_API_KEY
    "gemini",       # GEMINI_API_KEY
    "groq",         # GROQ_API_KEY
    "huggingface",  # HF_TOKEN
    "openai",       # OPENAI_API_KEY
}

# Hugging Face access token, read once at import time. Importing this module
# fails fast when the variable is missing or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise RuntimeError("HF_TOKEN is not set. Please configure your Hugging Face token.")

# Optional Tavily web-search client. It stays None when TAVILY_API_KEY is unset
# or empty, and also when constructing the client raises (best-effort: the
# failure is printed and the module import continues).
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
tavily_client: Optional[TavilyClient] = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as init_error:
        print(f"❗ Failed to init Tavily: {init_error}")

def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """
    Build an InferenceClient with a provider and billing target chosen for the model.

    Routing rules, checked in order; a matching rule overrides ``provider``:
      - moonshotai/Kimi-K2-Instruct          -> groq
      - openai/... or gpt-4 / gpt-3.5-turbo  -> openai
      - gemini/... or google/...             -> gemini
      - fireworks/...                        -> fireworks
      - anything else                        -> the caller's ``provider``, or
                                                huggingface when it is "auto"/empty

    Args:
        model_id: Model identifier, e.g. "moonshotai/Kimi-K2-Instruct".
        provider: Provider to use when no routing rule matches. The default
            "auto" (or an empty value) selects "huggingface".

    Returns:
        An InferenceClient authenticated with HF_TOKEN. ``bill_to`` equals the
        chosen provider when it is listed in _VALID_BILL_TO, and "groq" otherwise.
    """
    # NOTE(review): confirm that these provider names match the identifiers the
    # installed huggingface_hub release accepts for ``provider``.
    if model_id == "moonshotai/Kimi-K2-Instruct":
        provider = "groq"
    elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        provider = "openai"
    elif model_id.startswith(("gemini/", "google/")):
        provider = "gemini"
    elif model_id.startswith("fireworks/"):
        provider = "fireworks"
    elif not provider or provider == "auto":
        # No rule matched and the caller expressed no preference.
        provider = "huggingface"
    # Otherwise no rule matched: keep the provider the caller asked for instead
    # of silently discarding it.

    # A caller-supplied provider may not be a known billing target; fall back
    # to groq in that case.
    bill_to = provider if provider in _VALID_BILL_TO else "groq"

    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to=bill_to,
    )