|
|
import os |
|
|
from huggingface_hub import InferenceClient |
|
|
from tavily import TavilyClient |
|
|
|
|
|
|
|
|
|
|
|
# API credentials, read once from the process environment at import time.
# os.getenv returns None for any variable that is not set.
HF_TOKEN, GROQ_API_KEY, TAVILY_API_KEY = (
    os.getenv(var_name) for var_name in ("HF_TOKEN", "GROQ_API_KEY", "TAVILY_API_KEY")
)

# The Hugging Face token is mandatory (it is the api_key for every
# InferenceClient built in this module), so fail fast at import when it is
# missing or empty.
if HF_TOKEN is None or HF_TOKEN == "":
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroqChatClient:
    """Chat-completion client for Groq's OpenAI-compatible API.

    Mirrors the ``client.chat.completions.create(...)`` call shape so that
    ``get_inference_client`` can return it in place of a Hugging Face
    ``InferenceClient``.

    With ``openai>=1.0`` a dedicated ``openai.OpenAI`` client bound to Groq's
    endpoint is used. With older ``openai`` releases the legacy
    ``openai.ChatCompletion`` API is used. In that case the Groq credentials
    are passed per request instead of being written into the ``openai``
    module's global configuration, which would affect every other user of
    ``openai`` in the process.
    """

    # Groq's OpenAI-compatible REST endpoint.
    BASE_URL = "https://api.groq.com/openai/v1"

    def __init__(self, api_key: str, model: str = "mixtral-8x7b-32768"):
        """Create a client bound to ``api_key`` with a default ``model``.

        Args:
            api_key: Groq API key.
            model: Groq model id used when ``create`` is called without an
                explicit ``model``.
                NOTE(review): confirm this default is still served by Groq.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        import openai

        self.openai = openai
        self.model = model
        self.api_key = api_key
        # ``openai.OpenAI`` only exists in openai>=1.0, which removed the
        # module-level ``ChatCompletion`` API. Fall back to that legacy API
        # (with per-request credentials) only when the class is missing.
        client_cls = getattr(openai, "OpenAI", None)
        self._client = (
            client_cls(api_key=api_key, base_url=self.BASE_URL)
            if client_cls is not None
            else None
        )

    class Chat:
        """Namespace object exposing ``completions`` (mirrors ``client.chat``)."""

        def __init__(self, openai_instance, model, request_options=None):
            self.completions = self.Completions(openai_instance, model, request_options)

        class Completions:
            """Sends chat-completion requests (mirrors ``client.chat.completions``)."""

            def __init__(self, openai_instance, model, request_options=None):
                # Either an ``openai.OpenAI``-style client (has ``chat``) or the
                # legacy ``openai`` module (has ``ChatCompletion``).
                self.client = openai_instance
                self.model = model
                # Extra keyword arguments sent with every legacy request, e.g.
                # per-request ``api_key`` / ``api_base``.
                self.request_options = dict(request_options or {})

            def create(self, messages, temperature=0.7, max_tokens=1024, model=None, **kwargs):
                """Request a chat completion.

                Args:
                    messages: Chat messages forwarded unchanged to the API.
                    temperature: Sampling temperature.
                    max_tokens: Maximum number of tokens to generate.
                    model: Model id for this request; defaults to the model the
                        client was created with.
                    **kwargs: Any other request options (e.g. ``stream=True``),
                        forwarded unchanged.

                Returns:
                    Whatever the underlying ``openai`` call returns.
                """
                params = {
                    "model": model if model is not None else self.model,
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs,
                }
                legacy_api = getattr(self.client, "ChatCompletion", None)
                if legacy_api is not None:
                    # Caller-supplied params win over the stored request options.
                    return legacy_api.create(**{**self.request_options, **params})
                return self.client.chat.completions.create(**params)

    @property
    def chat(self):
        """Return a fresh ``Chat`` namespace bound to this client's settings."""
        if self._client is not None:
            return self.Chat(self._client, self.model)
        # Legacy openai<1.0: credentials travel with each request rather than
        # via the global ``openai.api_key`` / ``openai.api_base``.
        return self.Chat(
            self.openai,
            self.model,
            {"api_key": self.api_key, "api_base": self.BASE_URL},
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_inference_client(model_id: str, provider: str = "auto"):
    """Return a chat client for ``model_id``.

    Models listed in the Groq routing table below are served through
    ``GroqChatClient`` (Groq's OpenAI-compatible API). Every other model uses
    Hugging Face's ``InferenceClient`` with ``bill_to="huggingface"``. Both are
    meant to be used interchangeably via ``client.chat.completions.create``.

    Args:
        model_id: Hugging Face model id selected by the caller.
        provider: Inference provider passed to ``InferenceClient``; not used
            for Groq-routed models.

    Returns:
        A ``GroqChatClient`` for Groq-routed models, otherwise an
        ``InferenceClient``.

    Raises:
        RuntimeError: if a Groq-routed model is requested but GROQ_API_KEY
            is not set.
    """
    # Hugging Face model id -> model id to request on Groq.
    # NOTE(review): Groq lists Kimi K2 under a lower-case id; confirm against
    # Groq's current model list.
    groq_model_ids = {
        "moonshotai/Kimi-K2-Instruct": "moonshotai/kimi-k2-instruct",
    }

    groq_model = groq_model_ids.get(model_id)
    if groq_model is not None:
        if not GROQ_API_KEY:
            raise RuntimeError("GROQ_API_KEY is not set. Required for Groq models.")
        # Pass the model explicitly. Otherwise GroqChatClient falls back to its
        # own default model and the requested model is silently ignored.
        return GroqChatClient(api_key=GROQ_API_KEY, model=groq_model)

    return InferenceClient(
        model=model_id,
        provider=provider,
        api_key=HF_TOKEN,
        bill_to="huggingface",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional Tavily web-search client. It stays None when TAVILY_API_KEY is
# unset/empty or when construction raises; failures are printed, not raised.
if not TAVILY_API_KEY:
    tavily_client = None
else:
    tavily_client = None
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:  # best-effort: search support is optional
        print(f"Failed to initialize Tavily client: {e}")
        tavily_client = None
|
|
|