File size: 2,281 Bytes
f960c36
dfe06b1
7234f57
3e19edc
0adaacb
dfe06b1
034cea3
dfe06b1
0adaacb
dfe06b1
3e19edc
e2e2a15
034cea3
e2e2a15
7234f57
0adaacb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfe06b1
 
 
0adaacb
 
 
dfe06b1
3e19edc
dfe06b1
0adaacb
 
dfe06b1
3e19edc
dfe06b1
3e19edc
 
ffca4ae
3e19edc
034cea3
dfe06b1
0adaacb
 
034cea3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from huggingface_hub import InferenceClient
from tavily import TavilyClient

# === Required Tokens ===

# All credentials come from the process environment; only HF_TOKEN is
# mandatory — the Groq and Tavily keys gate optional features.
HF_TOKEN = os.environ.get("HF_TOKEN")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Fail fast at import time rather than on the first inference call.
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
    )


# === Groq Adapter ===

class GroqChatClient:
    """Adapter that serves Groq-hosted models through the OpenAI-style
    ``client.chat.completions.create(...)`` interface used elsewhere in
    this module.

    Fix vs. the previous version: configuration used the module-level
    ``openai.api_key`` / ``openai.api_base`` globals and the
    ``openai.ChatCompletion.create`` entry point, which were removed in
    openai>=1.0 (calls raise ``APIRemovedInV1``) and, on older SDKs,
    leaked Groq credentials into every other openai user in the process.
    A dedicated ``openai.OpenAI`` client instance avoids both problems.
    """

    def __init__(self, api_key: str, model: str = "mixtral-8x7b-32768"):
        # Imported lazily so this module loads even when the Groq code
        # path (and the openai package) is never used.
        import openai
        # Groq exposes an OpenAI-compatible endpoint; point a private
        # client at it instead of mutating SDK-global state.
        self._client = openai.OpenAI(
            api_key=api_key,
            base_url="https://api.groq.com/openai/v1",
        )
        self.model = model

    class Chat:
        """Mimics ``client.chat`` so callers can treat this adapter and
        ``InferenceClient`` uniformly."""

        def __init__(self, client, model):
            self.completions = self.Completions(client, model)

        class Completions:
            """Mimics ``client.chat.completions``."""

            def __init__(self, client, model):
                self.client = client
                self.model = model

            def create(self, messages, temperature=0.7, max_tokens=1024):
                """Forward a chat-completion request to Groq.

                ``messages`` is the standard OpenAI messages list; the
                response object is returned unchanged.
                """
                return self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )

    @property
    def chat(self):
        # A fresh facade per access is cheap and keeps the adapter stateless.
        return self.Chat(self._client, self.model)


# === Unified Client Factory ===

def get_inference_client(model_id: str, provider: str = "auto"):
    """
    Return a unified chat-client for *model_id*.

    The Kimi-K2-Instruct model is served via Groq through an
    OpenAI-compatible adapter; every other model id goes through
    Hugging Face's ``InferenceClient``.

    Raises:
        RuntimeError: if the Groq path is selected but GROQ_API_KEY is unset.
    """
    wants_groq = model_id == "moonshotai/Kimi-K2-Instruct"

    if not wants_groq:
        return InferenceClient(
            model=model_id,
            provider=provider,
            api_key=HF_TOKEN,
            bill_to="huggingface",
        )

    if not GROQ_API_KEY:
        raise RuntimeError("GROQ_API_KEY is not set. Required for Groq models.")
    return GroqChatClient(api_key=GROQ_API_KEY)


# === Tavily Search Client ===

def _build_tavily_client():
    """Return a TavilyClient when TAVILY_API_KEY is set and construction
    succeeds; otherwise None (search features are simply disabled)."""
    if not TAVILY_API_KEY:
        return None
    try:
        return TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        # Best-effort: report the failure but keep the module importable.
        print(f"Failed to initialize Tavily client: {e}")
        return None


tavily_client = _build_tavily_client()