mgbam committed on
Commit
0adaacb
·
verified ·
1 Parent(s): dfe06b1

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +43 -24
hf_client.py CHANGED
@@ -1,45 +1,64 @@
1
  import os
2
- import openai
3
  from huggingface_hub import InferenceClient
4
  from tavily import TavilyClient
5
 
6
- # === Environment Setup ===
7
 
8
  HF_TOKEN = os.getenv('HF_TOKEN')
9
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
10
 
11
  if not HF_TOKEN:
12
  raise RuntimeError(
13
  "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
14
  )
15
 
16
- # === Dynamic Inference Client ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def get_inference_client(model_id: str, provider: str = "auto"):
19
  """
20
- Return an inference client depending on model ID.
21
- Uses Groq's native API for specific models, otherwise HuggingFace InferenceClient.
 
22
  """
23
  if model_id == "moonshotai/Kimi-K2-Instruct":
24
  if not GROQ_API_KEY:
25
- raise RuntimeError("GROQ_API_KEY is required for Groq models.")
26
-
27
- # Configure OpenAI client for Groq
28
- openai.api_key = GROQ_API_KEY
29
- openai.api_base = "https://api.groq.com/openai/v1"
30
 
31
- def chat(messages, temperature=0.7, max_tokens=1024):
32
- response = openai.ChatCompletion.create(
33
- model="mixtral-8x7b-32768", # You can map the model here
34
- messages=messages,
35
- temperature=temperature,
36
- max_tokens=max_tokens
37
- )
38
- return response["choices"][0]["message"]["content"]
39
-
40
- return chat # Return callable interface
41
-
42
- # Fallback to Hugging Face
43
  return InferenceClient(
44
  model=model_id,
45
  provider=provider,
@@ -47,9 +66,9 @@ def get_inference_client(model_id: str, provider: str = "auto"):
47
  bill_to="huggingface"
48
  )
49
 
50
- # === Tavily Client ===
51
 
52
- TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
 
53
  tavily_client = None
54
  if TAVILY_API_KEY:
55
  try:
 
1
  import os
 
2
  from huggingface_hub import InferenceClient
3
  from tavily import TavilyClient
4
 
5
+ # === Required Tokens ===
6
 
7
  HF_TOKEN = os.getenv('HF_TOKEN')
8
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
9
+ TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
10
 
11
  if not HF_TOKEN:
12
  raise RuntimeError(
13
  "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
14
  )
15
 
16
+
17
+ # === Groq Adapter ===
18
+
19
+ class GroqChatClient:
20
+ def __init__(self, api_key: str, model: str = "mixtral-8x7b-32768"):
21
+ import openai
22
+ openai.api_key = api_key
23
+ openai.api_base = "https://api.groq.com/openai/v1"
24
+ self.openai = openai
25
+ self.model = model
26
+
27
+ class Chat:
28
+ def __init__(self, openai_instance, model):
29
+ self.completions = self.Completions(openai_instance, model)
30
+
31
+ class Completions:
32
+ def __init__(self, openai_instance, model):
33
+ self.client = openai_instance
34
+ self.model = model
35
+
36
+ def create(self, messages, temperature=0.7, max_tokens=1024):
37
+ return self.client.ChatCompletion.create(
38
+ model=self.model,
39
+ messages=messages,
40
+ temperature=temperature,
41
+ max_tokens=max_tokens
42
+ )
43
+
44
+ @property
45
+ def chat(self):
46
+ return self.Chat(self.openai, self.model)
47
+
48
+
49
+ # === Unified Client Factory ===
50
 
51
  def get_inference_client(model_id: str, provider: str = "auto"):
52
  """
53
+ Returns a unified client interface.
54
+ - If model_id is 'moonshotai/Kimi-K2-Instruct', use Groq via OpenAI adapter
55
+ - Otherwise, use Hugging Face's InferenceClient
56
  """
57
  if model_id == "moonshotai/Kimi-K2-Instruct":
58
  if not GROQ_API_KEY:
59
+ raise RuntimeError("GROQ_API_KEY is not set. Required for Groq models.")
60
+ return GroqChatClient(api_key=GROQ_API_KEY)
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  return InferenceClient(
63
  model=model_id,
64
  provider=provider,
 
66
  bill_to="huggingface"
67
  )
68
 
 
69
 
70
+ # === Tavily Search Client ===
71
+
72
  tavily_client = None
73
  if TAVILY_API_KEY:
74
  try: