|
import os |
|
import openai |
|
from huggingface_hub import InferenceClient |
|
from tavily import TavilyClient |
|
|
|
|
|
|
|
# API credentials are read from the environment once, at import time.
HF_TOKEN = os.getenv('HF_TOKEN')

GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# A Hugging Face token is mandatory for this module: fail fast on import
# instead of surfacing an auth error on the first inference call.
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token."
    )
|
|
|
|
|
|
|
def get_inference_client(model_id: str, provider: str = "auto"):
    """
    Return an inference client depending on model ID.

    Uses Groq's OpenAI-compatible API for specific models, otherwise a
    HuggingFace InferenceClient.

    Args:
        model_id: Hub-style model identifier
            (e.g. ``"moonshotai/Kimi-K2-Instruct"``).
        provider: Inference provider name forwarded to ``InferenceClient``
            (default ``"auto"``); ignored on the Groq path.

    Returns:
        Groq path: a ``chat(messages, temperature=0.7, max_tokens=1024) -> str``
        callable. Otherwise: a ``huggingface_hub.InferenceClient`` instance.
        (Callers must handle both shapes — kept as-is for compatibility.)

    Raises:
        RuntimeError: if a Groq-routed model is requested but GROQ_API_KEY
            is not set.
    """
    if model_id == "moonshotai/Kimi-K2-Instruct":
        if not GROQ_API_KEY:
            raise RuntimeError("GROQ_API_KEY is required for Groq models.")

        # NOTE(review): this mutates the `openai` module's *global* config
        # (legacy openai<1.0 style) — any other code using `openai` in this
        # process is also redirected to Groq from here on.
        openai.api_key = GROQ_API_KEY
        openai.api_base = "https://api.groq.com/openai/v1"

        def chat(messages, temperature=0.7, max_tokens=1024):
            """Send a chat completion to Groq and return the reply text."""
            response = openai.ChatCompletion.create(
                # BUG FIX: the original hard-coded "mixtral-8x7b-32768" even
                # though this branch is only taken for Kimi-K2 — callers asked
                # for Kimi-K2 and silently got a different model. Use Groq's
                # id for the model that was actually requested.
                model="moonshotai/kimi-k2-instruct",
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response["choices"][0]["message"]["content"]

        return chat

    # Default path: route through Hugging Face Inference Providers.
    return InferenceClient(
        model=model_id,
        provider=provider,
        api_key=HF_TOKEN,
        bill_to="huggingface"
    )
|
|
|
|
|
|
|
# Tavily (web search) is optional: the client stays None when no key is
# configured or when construction fails — callers must check before use.
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')

tavily_client = None
try:
    if TAVILY_API_KEY:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
except Exception as exc:
    # Non-critical feature: report the problem and carry on without search.
    print(f"Failed to initialize Tavily client: {exc}")
    tavily_client = None
|
|