|
import os |
|
from huggingface_hub import InferenceClient |
|
from tavily import TavilyClient |
|
|
|
|
|
# API credentials, read from the environment once at import time.
HF_TOKEN = os.getenv('HF_TOKEN')  # Hugging Face Inference API token (mandatory)

GROQ_API_KEY = os.getenv('GROQ_API_KEY')  # Groq key; only needed for Groq-hosted models

TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')  # Tavily search key; optional



# Fail fast on import: everything downstream needs the HF token,
# so a missing value should surface immediately, not mid-request.
if not HF_TOKEN:

    raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
|
|
|
|
|
class GroqChatClient:
    """Chat client for Groq's OpenAI-compatible endpoint.

    Exposes the same ``client.chat.completions.create(...)`` surface as the
    Hugging Face ``InferenceClient`` so callers of ``get_inference_client``
    can use a single code path regardless of backend.
    """

    # Groq serves an OpenAI-compatible REST API at this base URL.
    GROQ_BASE_URL = "https://api.groq.com/openai/v1"

    def __init__(self, api_key: str):
        # Imported lazily so this module still loads when the optional
        # ``openai`` dependency is absent and Groq models are never used.
        import openai

        if hasattr(openai, "OpenAI"):
            # openai >= 1.0: use a dedicated client instance. This avoids
            # mutating module-level globals, which would leak the Groq
            # key/base URL into any other openai usage in the process, and
            # is required anyway since the module-level ChatCompletion API
            # was removed in 1.0.
            self.client = openai.OpenAI(api_key=api_key, base_url=self.GROQ_BASE_URL)
            self.chat = self.client.chat
        else:
            # Legacy openai < 1.0 module-level API (original behavior).
            # NOTE(review): this sets process-wide state on the openai
            # module; prefer upgrading to openai >= 1.0.
            openai.api_key = api_key
            openai.api_base = self.GROQ_BASE_URL
            self.client = openai
            self.chat = self.Chat(openai)

    class Chat:
        """Namespace mirroring ``client.chat`` of the modern openai client."""

        def __init__(self, openai_client):
            self.completions = self.Completions(openai_client)

        class Completions:
            """Namespace mirroring ``client.chat.completions``."""

            def __init__(self, openai_client):
                self.client = openai_client

            def create(self, model, messages, temperature=0.7, max_tokens=1024, **kwargs):
                """Create a chat completion via the legacy module-level API.

                Parameters mirror the OpenAI chat-completions endpoint;
                any extra keyword arguments are forwarded unchanged.
                """
                return self.client.ChatCompletion.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs,
                )
|
|
|
|
|
def get_inference_client(model_id: str, provider: str = "auto"):
    """Build a chat-capable client for *model_id* behind one interface.

    The Kimi-K2 model is served through Groq's OpenAI-compatible API;
    every other model is routed through the Hugging Face
    ``InferenceClient``. Both results expose
    ``chat.completions.create(...)``.
    """
    if model_id != "moonshotai/Kimi-K2-Instruct":
        # Default path: Hugging Face Inference providers.
        return InferenceClient(
            model=model_id,
            provider=provider,
            api_key=HF_TOKEN,
            bill_to="huggingface",
        )

    # Groq-hosted model: requires its own credential.
    if not GROQ_API_KEY:
        raise RuntimeError("GROQ_API_KEY is required for Groq-hosted models.")
    return GroqChatClient(api_key=GROQ_API_KEY)
|
|
|
|
|
# Optional Tavily web-search client. Remains None when no API key is
# configured or construction fails — consumers must handle None.
tavily_client = None

if TAVILY_API_KEY:

    try:

        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)

    except Exception as e:

        # Best-effort: report and continue without search rather than
        # failing module import.
        print(f"Failed to initialize Tavily client: {e}")

        tavily_client = None
|
|