mgbam committed on
Commit
034cea3
·
verified ·
1 Parent(s): f70a193

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +18 -44
hf_client.py CHANGED
@@ -1,60 +1,34 @@
 
 
1
  # hf_client.py
2
 
3
  import os
4
- from typing import Optional
5
- from huggingface_hub import InferenceClient
6
  from tavily import TavilyClient
7
 
8
- # Supported billing targets
9
- _VALID_BILL_TO = {"huggingface", "fairworksai", "groq", "openai", "gemini", "fireworks", "googler"}
10
-
11
- # Load Hugging Face token
12
- HF_TOKEN = os.getenv("HF_TOKEN")
13
  if not HF_TOKEN:
14
  raise RuntimeError(
15
- "HF_TOKEN environment variable is not set. "
16
- "Please set it to your Hugging Face API token."
17
  )
18
 
19
- # Initialize Tavily search client (optional)
20
- TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
21
- tavily_client: Optional[TavilyClient] = None
22
- if TAVILY_API_KEY:
23
- try:
24
- tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
25
- except Exception as e:
26
- print(f"Failed to initialize Tavily client: {e}")
27
- tavily_client = None
28
-
29
-
30
  def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
31
- """
32
- Configure InferenceClient based on model_id prefixes:
33
- - moonshotai/Kimi-K2-Instruct → provider "groq"
34
- - openai/... or GPT names → provider "openai"
35
- - gemini/... or google/... → provider "gemini"
36
- - fireworks/... → provider "fireworks"
37
- - otherwise → use HF Inference provider (auto)
38
- Bill to the chosen provider if valid; otherwise, default to "groq".
39
- """
40
- # Override provider by model
41
  if model_id == "moonshotai/Kimi-K2-Instruct":
42
  provider = "groq"
43
- elif model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
44
- provider = "openai"
45
- elif model_id.startswith("gemini/") or model_id.startswith("google/"):
46
- provider = "gemini"
47
- elif model_id.startswith("fireworks/"):
48
- provider = "fireworks"
49
- else:
50
- # For Hugging Face models, let HF Inference pick best provider
51
- provider = "auto"
52
-
53
- # Determine billing target
54
- bill_to = provider if provider in _VALID_BILL_TO else "groq"
55
-
56
  return InferenceClient(
57
  provider=provider,
58
  api_key=HF_TOKEN,
59
- bill_to=bill_to
60
  )
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ```python
3
  # hf_client.py
4
 
5
  import os
6
+ from huggingface_hub import InferenceClient, HfApi
 
7
  from tavily import TavilyClient
8
 
9
+ # HF Inference Client
10
+ HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 
11
  if not HF_TOKEN:
12
  raise RuntimeError(
13
+ "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
 
14
  )
15
 
 
 
 
 
 
 
 
 
 
 
 
16
  def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
17
+ """Return an InferenceClient with the appropriate provider."""
 
 
 
 
 
 
 
 
 
18
  if model_id == "moonshotai/Kimi-K2-Instruct":
19
  provider = "groq"
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return InferenceClient(
21
  provider=provider,
22
  api_key=HF_TOKEN,
23
+ bill_to="my-org-name"
24
  )
25
+
26
+ # Tavily Search Client
27
+ TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
28
+ tavily_client = None
29
+ if TAVILY_API_KEY:
30
+ try:
31
+ tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
32
+ except Exception as e:
33
+ print(f"Failed to initialize Tavily client: {e}")
34
+ tavily_client = None