mgbam commited on
Commit
6a1db5c
·
verified ·
1 Parent(s): 0adaacb

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +20 -33
hf_client.py CHANGED
@@ -2,61 +2,50 @@ import os
2
  from huggingface_hub import InferenceClient
3
  from tavily import TavilyClient
4
 
5
- # === Required Tokens ===
6
-
7
  HF_TOKEN = os.getenv('HF_TOKEN')
8
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
9
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
10
 
11
  if not HF_TOKEN:
12
- raise RuntimeError(
13
- "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
14
- )
15
-
16
-
17
- # === Groq Adapter ===
18
 
 
19
  class GroqChatClient:
20
- def __init__(self, api_key: str, model: str = "mixtral-8x7b-32768"):
21
  import openai
22
  openai.api_key = api_key
23
  openai.api_base = "https://api.groq.com/openai/v1"
24
- self.openai = openai
25
- self.model = model
26
 
27
  class Chat:
28
- def __init__(self, openai_instance, model):
29
- self.completions = self.Completions(openai_instance, model)
30
 
31
  class Completions:
32
- def __init__(self, openai_instance, model):
33
- self.client = openai_instance
34
- self.model = model
35
 
36
- def create(self, messages, temperature=0.7, max_tokens=1024):
37
  return self.client.ChatCompletion.create(
38
- model=self.model,
39
  messages=messages,
40
  temperature=temperature,
41
- max_tokens=max_tokens
 
42
  )
43
 
44
- @property
45
- def chat(self):
46
- return self.Chat(self.openai, self.model)
47
-
48
-
49
- # === Unified Client Factory ===
50
-
51
  def get_inference_client(model_id: str, provider: str = "auto"):
52
  """
53
- Returns a unified client interface.
54
- - If model_id is 'moonshotai/Kimi-K2-Instruct', use Groq via OpenAI adapter
55
- - Otherwise, use Hugging Face's InferenceClient
56
  """
57
  if model_id == "moonshotai/Kimi-K2-Instruct":
58
  if not GROQ_API_KEY:
59
- raise RuntimeError("GROQ_API_KEY is not set. Required for Groq models.")
60
  return GroqChatClient(api_key=GROQ_API_KEY)
61
 
62
  return InferenceClient(
@@ -66,9 +55,7 @@ def get_inference_client(model_id: str, provider: str = "auto"):
66
  bill_to="huggingface"
67
  )
68
 
69
-
70
  # === Tavily Search Client ===
71
-
72
  tavily_client = None
73
  if TAVILY_API_KEY:
74
  try:
 
2
  from huggingface_hub import InferenceClient
3
  from tavily import TavilyClient
4
 
5
+ # === API Keys ===
 
6
  HF_TOKEN = os.getenv('HF_TOKEN')
7
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
8
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
9
 
10
  if not HF_TOKEN:
11
+ raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
 
 
 
 
 
12
 
13
+ # === GROQ-Compatible Wrapper ===
14
  class GroqChatClient:
15
+ def __init__(self, api_key: str):
16
  import openai
17
  openai.api_key = api_key
18
  openai.api_base = "https://api.groq.com/openai/v1"
19
+ self.client = openai
20
+ self.chat = self.Chat(openai)
21
 
22
  class Chat:
23
+ def __init__(self, openai_client):
24
+ self.completions = self.Completions(openai_client)
25
 
26
  class Completions:
27
+ def __init__(self, openai_client):
28
+ self.client = openai_client
 
29
 
30
+ def create(self, model, messages, temperature=0.7, max_tokens=1024, **kwargs):
31
  return self.client.ChatCompletion.create(
32
+ model=model,
33
  messages=messages,
34
  temperature=temperature,
35
+ max_tokens=max_tokens,
36
+ **kwargs
37
  )
38
 
39
+ # === Inference Client Selector ===
 
 
 
 
 
 
40
  def get_inference_client(model_id: str, provider: str = "auto"):
41
  """
42
+ Returns a unified interface:
43
+ - For 'moonshotai/Kimi-K2-Instruct', uses Groq with OpenAI-compatible API
44
+ - For others, uses Hugging Face InferenceClient
45
  """
46
  if model_id == "moonshotai/Kimi-K2-Instruct":
47
  if not GROQ_API_KEY:
48
+ raise RuntimeError("GROQ_API_KEY is required for Groq-hosted models.")
49
  return GroqChatClient(api_key=GROQ_API_KEY)
50
 
51
  return InferenceClient(
 
55
  bill_to="huggingface"
56
  )
57
 
 
58
  # === Tavily Search Client ===
 
59
  tavily_client = None
60
  if TAVILY_API_KEY:
61
  try: