# hf_client.py

import os
from typing import Optional
from huggingface_hub import InferenceClient
from tavily import TavilyClient

# Supported billing targets
_VALID_BILL_TO = {"huggingface", "fireworks-ai", "groq", "openai", "gemini", "fireworks", "google"}

# Load Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. "
        "Please set it to your Hugging Face API token."
    )

# Initialize Tavily search client (optional)
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily_client: Optional[TavilyClient] = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        print(f"Failed to initialize Tavily client: {e}")
        tavily_client = None
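
# --- Usage sketch (not part of the original module) ---------------------------
# Shows how the optional Tavily client could back a simple web-search helper.
# Assumes the tavily-python `search()` API, which returns a dict with a
# "results" list of entries carrying "title", "url", and "content"; adjust if
# the response shape differs in your installed version.
def web_search_snippets(query: str, max_results: int = 3) -> str:
    """Return a short plain-text summary of Tavily results, or '' if search is unavailable."""
    if tavily_client is None:
        return ""
    try:
        response = tavily_client.search(query, max_results=max_results)
        return "\n".join(
            f"{item.get('title', '')}: {item.get('content', '')}"
            for item in response.get("results", [])
        )
    except Exception as e:
        print(f"Tavily search failed: {e}")
        return ""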


def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """
    Configure InferenceClient based on model_id prefixes:
      - moonshotai/Kimi-K2-Instruct β†’ provider "groq"
      - openai/... or GPT names       β†’ provider "openai"
      - gemini/... or google/...     β†’ provider "gemini"
      - fireworks/...                β†’ provider "fireworks"
      - otherwise                     β†’ use HF Inference provider (auto)
    Bill to the chosen provider if valid; otherwise, default to "groq".
    """
    # Override provider by model
    if model_id == "moonshotai/Kimi-K2-Instruct":
        provider = "groq"
    elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        provider = "openai"
    elif model_id.startswith("gemini/") or model_id.startswith("google/"):
        provider = "gemini"
    elif model_id.startswith("fireworks/"):
        provider = "fireworks"
    else:
        # For Hugging Face models, let HF Inference pick best provider
        provider = "auto"

    # Determine billing target
    bill_to = provider if provider in _VALID_BILL_TO else "groq"

    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to=bill_to
    )
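
# --- Example usage (assumption, not part of the original file) ----------------
# Recent huggingface_hub versions expose an OpenAI-compatible chat interface on
# InferenceClient; the model id and prompt below are placeholders.
if __name__ == "__main__":
    client = get_inference_client("moonshotai/Kimi-K2-Instruct")
    completion = client.chat.completions.create(
        model="moonshotai/Kimi-K2-Instruct",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        max_tokens=64,
    )
    print(completion.choices[0].message.content)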