mgbam committed on
Commit
dfe06b1
·
verified ·
1 Parent(s): ffca4ae

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +35 -10
hf_client.py CHANGED
@@ -1,29 +1,54 @@
1
-
2
-
3
- # hf_client.py
4
-
5
  import os
6
- from huggingface_hub import InferenceClient, HfApi
 
7
  from tavily import TavilyClient
8
 
9
- # HF Inference Client
 
10
  HF_TOKEN = os.getenv('HF_TOKEN')
 
 
11
  if not HF_TOKEN:
12
  raise RuntimeError(
13
  "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
14
  )
15
 
16
- def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
17
- """Return an InferenceClient with the appropriate provider."""
 
 
 
 
 
18
  if model_id == "moonshotai/Kimi-K2-Instruct":
19
- provider = "groq"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return InferenceClient(
 
21
  provider=provider,
22
  api_key=HF_TOKEN,
23
  bill_to="huggingface"
24
  )
25
 
26
- # Tavily Search Client
 
27
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
28
  tavily_client = None
29
  if TAVILY_API_KEY:
 
 
 
 
 
1
  import os
2
+ import openai
3
+ from huggingface_hub import InferenceClient
4
  from tavily import TavilyClient
5
 
6
+ # === Environment Setup ===
7
+
8
  HF_TOKEN = os.getenv('HF_TOKEN')
9
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
10
+
11
  if not HF_TOKEN:
12
  raise RuntimeError(
13
  "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
14
  )
15
 
16
+ # === Dynamic Inference Client ===
17
+
18
def get_inference_client(model_id: str, provider: str = "auto"):
    """Return an inference handle for ``model_id``.

    For ``"moonshotai/Kimi-K2-Instruct"`` the request is sent to Groq's
    OpenAI-compatible endpoint and a plain callable is returned; for every
    other model a Hugging Face ``InferenceClient`` is returned.

    Args:
        model_id: Hugging Face style model identifier.
        provider: Inference provider forwarded to ``InferenceClient``.
            Ignored for the Groq-routed model.

    Returns:
        Either a callable ``chat(messages, temperature=0.7, max_tokens=1024)``
        that returns the assistant message content as a string (Groq route),
        or an ``InferenceClient`` bound to ``model_id`` (all other models).
        NOTE: the two return shapes differ, so callers must handle both.

    Raises:
        RuntimeError: If the Groq route is selected but ``GROQ_API_KEY`` is
            not set.
    """
    if model_id != "moonshotai/Kimi-K2-Instruct":
        # Default route: Hugging Face InferenceClient.
        return InferenceClient(
            model=model_id,
            provider=provider,
            api_key=HF_TOKEN,
            bill_to="huggingface",
        )

    if not GROQ_API_KEY:
        raise RuntimeError("GROQ_API_KEY is required for Groq models.")

    groq_base_url = "https://api.groq.com/openai/v1"
    # Serve the model the caller actually asked for instead of a hard-coded,
    # unrelated one. NOTE(review): confirm this id against Groq's current
    # model list before deploying.
    groq_model = "moonshotai/kimi-k2-instruct"

    if hasattr(openai, "OpenAI"):
        # openai>=1.0: use a dedicated client so the module-level OpenAI
        # configuration is not redirected to Groq for the whole process.
        groq_client = openai.OpenAI(api_key=GROQ_API_KEY, base_url=groq_base_url)

        def chat(messages, temperature=0.7, max_tokens=1024):
            """Send ``messages`` to Groq and return the reply text."""
            response = groq_client.chat.completions.create(
                model=groq_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content

        return chat

    # Legacy openai<1.0: only module-level configuration is available here.
    openai.api_key = GROQ_API_KEY
    openai.api_base = groq_base_url

    def chat(messages, temperature=0.7, max_tokens=1024):
        """Send ``messages`` to Groq and return the reply text."""
        response = openai.ChatCompletion.create(
            model=groq_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return response["choices"][0]["message"]["content"]

    return chat
49
 
50
+ # === Tavily Client ===
51
+
52
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
53
  tavily_client = None
54
  if TAVILY_API_KEY: