mgbam committed on
Commit
7234f57
·
verified ·
1 Parent(s): 71a8f19

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +42 -30
hf_client.py CHANGED
@@ -1,34 +1,56 @@
1
- ### hf_client.py
2
 
3
- from huggingface_hub import InferenceClient, HfApi
4
- from tavily import TavilyClient
5
  import os
 
 
 
6
 
7
- # HF Inference Client
8
-
9
- # Supported billing targets
10
- _VALID_BILL_TO = {"huggingface", "fairworksai", "groq"}
11
-
 
 
 
 
 
 
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  if not HF_TOKEN:
14
- raise RuntimeError(
15
- "HF_TOKEN environment variable is not set. "
16
- "Please set it to your Hugging Face API token."
17
- )
 
 
 
 
 
 
18
 
19
  def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
20
  """
21
- Return an InferenceClient configured with the correct provider and billing target.
22
-
23
- - If model_id == "moonshotai/Kimi-K2-Instruct", force provider to "groq".
24
- - If the requested provider is not one of the supported billing targets,
25
- default billing to "groq".
 
26
  """
27
- # force certain models onto groq hardware
28
  if model_id == "moonshotai/Kimi-K2-Instruct":
29
  provider = "groq"
30
-
31
- # determine billing target
 
 
 
 
 
 
 
 
32
  bill_to = provider if provider in _VALID_BILL_TO else "groq"
33
 
34
  return InferenceClient(
@@ -36,13 +58,3 @@ def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClie
36
  api_key=HF_TOKEN,
37
  bill_to=bill_to
38
  )
39
-
40
- # Tavily Search Client
41
- TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
42
- tavily_client = None
43
- if TAVILY_API_KEY:
44
- try:
45
- tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
46
- except Exception as e:
47
- print(f"Failed to initialize Tavily client: {e}")
48
- tavily_client = None
 
1
+ # hf_client.py
2
 
 
 
3
  import os
4
+ from typing import Optional
5
+ from huggingface_hub import InferenceClient
6
+ from tavily import TavilyClient
7
 
8
+ # Supported billing targets (your secrets)
9
+ _VALID_BILL_TO = {
10
+ "huggingface", # HF_TOKEN
11
+ "fairworksai", # if you have a special billing target
12
+ "groq", # GROQ_API_KEY
13
+ "openai", # OPENAI_API_KEY
14
+ "gemini", # GEMINI_API_KEY
15
+ "fireworks" # FIREWORKS_API_KEY
16
+ }
17
+
18
+ # Load your HF token
19
  HF_TOKEN = os.getenv("HF_TOKEN")
20
  if not HF_TOKEN:
21
+ raise RuntimeError("HF_TOKEN is not set. Please configure your Hugging Face token.")
22
+
23
+ # (Optional) Tavily client for web search
24
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
25
+ tavily_client: Optional[TavilyClient] = None
26
+ if TAVILY_API_KEY:
27
+ try:
28
+ tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
29
+ except Exception as e:
30
+ print(f"❗ Failed to init Tavily: {e}")
31
 
32
  def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
33
  """
34
+ Selects the appropriate inference provider & billing target:
35
+ • moonshotai/Kimi-K2-Instruct → groq
36
+ • openai/... or GPT family → openai
37
+ • gemini/... or google/... → gemini
38
+ • fireworks/... → fireworks
39
+ • otherwise → huggingface (billing to groq)
40
  """
41
+ # override by model prefix
42
  if model_id == "moonshotai/Kimi-K2-Instruct":
43
  provider = "groq"
44
+ elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
45
+ provider = "openai"
46
+ elif model_id.startswith("gemini/") or model_id.startswith("google/"):
47
+ provider = "gemini"
48
+ elif model_id.startswith("fireworks/"):
49
+ provider = "fireworks"
50
+ else:
51
+ provider = "huggingface"
52
+
53
+ # ensure billing target is valid, else fallback to groq
54
  bill_to = provider if provider in _VALID_BILL_TO else "groq"
55
 
56
  return InferenceClient(
 
58
  api_key=HF_TOKEN,
59
  bill_to=bill_to
60
  )