mgbam committed on
Commit
e2e2a15
·
verified ·
1 Parent(s): 7234f57

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +24 -24
hf_client.py CHANGED
@@ -5,40 +5,39 @@ from typing import Optional
5
  from huggingface_hub import InferenceClient
6
  from tavily import TavilyClient
7
 
8
- # Supported billing targets (your secrets)
9
- _VALID_BILL_TO = {
10
- "huggingface", # HF_TOKEN
11
- "fairworksai", # if you have a special billing target
12
- "groq", # GROQ_API_KEY
13
- "openai", # OPENAI_API_KEY
14
- "gemini", # GEMINI_API_KEY
15
- "fireworks" # FIREWORKS_API_KEY
16
- }
17
-
18
- # Load your HF token
19
  HF_TOKEN = os.getenv("HF_TOKEN")
20
  if not HF_TOKEN:
21
- raise RuntimeError("HF_TOKEN is not set. Please configure your Hugging Face token.")
 
 
 
22
 
23
- # (Optional) Tavily client for web search
24
- TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
25
  tavily_client: Optional[TavilyClient] = None
26
  if TAVILY_API_KEY:
27
  try:
28
  tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
29
  except Exception as e:
30
- print(f"❗ Failed to init Tavily: {e}")
 
 
31
 
32
  def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
33
  """
34
- Selects the appropriate inference provider & billing target:
35
- ‒ moonshotai/Kimi-K2-Instruct → groq
36
- ‒ openai/... or GPT family → openai
37
- ‒ gemini/... or google/... → gemini
38
- ‒ fireworks/... → fireworks
39
- ‒ otherwise → huggingface (billing to groq)
 
40
  """
41
- # override by model prefix
42
  if model_id == "moonshotai/Kimi-K2-Instruct":
43
  provider = "groq"
44
  elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
@@ -48,9 +47,10 @@ def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClie
48
  elif model_id.startswith("fireworks/"):
49
  provider = "fireworks"
50
  else:
51
- provider = "huggingface"
 
52
 
53
- # ensure billing target is valid, else fallback to groq
54
  bill_to = provider if provider in _VALID_BILL_TO else "groq"
55
 
56
  return InferenceClient(
 
5
  from huggingface_hub import InferenceClient
6
  from tavily import TavilyClient
7
 
8
+ # Supported billing targets
9
+ _VALID_BILL_TO = {"huggingface", "fairworksai", "groq", "openai", "gemini", "fireworks", "googler"}
10
+
11
+ # Load Hugging Face token
 
 
 
 
 
 
 
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  if not HF_TOKEN:
14
+ raise RuntimeError(
15
+ "HF_TOKEN environment variable is not set. "
16
+ "Please set it to your Hugging Face API token."
17
+ )
18
 
19
+ # Initialize Tavily search client (optional)
20
+ TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
21
  tavily_client: Optional[TavilyClient] = None
22
  if TAVILY_API_KEY:
23
  try:
24
  tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
25
  except Exception as e:
26
+ print(f"Failed to initialize Tavily client: {e}")
27
+ tavily_client = None
28
+
29
 
30
  def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
31
  """
32
+ Configure InferenceClient based on model_id prefixes:
33
+ - moonshotai/Kimi-K2-Instruct → provider "groq"
34
+ - openai/... or GPT names → provider "openai"
35
+ - gemini/... or google/... → provider "gemini"
36
+ - fireworks/... → provider "fireworks"
37
+ - otherwise → use HF Inference provider (auto)
38
+ Bill to the chosen provider if valid; otherwise, default to "groq".
39
  """
40
+ # Override provider by model
41
  if model_id == "moonshotai/Kimi-K2-Instruct":
42
  provider = "groq"
43
  elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
 
47
  elif model_id.startswith("fireworks/"):
48
  provider = "fireworks"
49
  else:
50
+ # For Hugging Face models, let HF Inference pick best provider
51
+ provider = "auto"
52
 
53
+ # Determine billing target
54
  bill_to = provider if provider in _VALID_BILL_TO else "groq"
55
 
56
  return InferenceClient(