# /services.py
"""
Manages interactions with external services like LLM providers and web search APIs.
This module has been refactored to support multiple LLM providers:
- Hugging Face (for standard and multimodal models)
- Groq (for high-speed inference)
- Fireworks AI
"""
import os
import logging
from typing import Dict, Any, Generator, List
from dotenv import load_dotenv
# Import all necessary clients
from huggingface_hub import InferenceClient
from tavily import TavilyClient
from groq import Groq
import fireworks.client as fireworks_client

# --- Setup Logging & Environment ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()

# --- API Keys ---
HF_TOKEN = os.getenv("HF_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
# --- Type Definitions ---
Messages = List[Dict[str, Any]]
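# Example of the Messages shape (OpenAI-style chat messages, as accepted by
# all three provider clients):
#   [{"role": "system", "content": "You are a helpful assistant."},
#    {"role": "user", "content": "Hello!"}]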

class LLMService:
    """A multi-provider wrapper for LLM inference APIs."""

    def __init__(self):
        # Initialize each client only if its API key is available.
        self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.fireworks_client = fireworks_client if FIREWORKS_API_KEY else None
        if self.fireworks_client:
            self.fireworks_client.api_key = FIREWORKS_API_KEY

    def generate_code_stream(
        self, model_id: str, messages: Messages, max_tokens: int = 8000
    ) -> Generator[str, None, None]:
        """
        Streams code generation, dispatching to the correct provider based on model_id.

        model_id is either 'provider/model-name' (e.g. 'groq/<model-name>' or
        'fireworks/<model-name>') or a full Hugging Face model ID, which falls
        through to the default Hugging Face route.
        """
        provider = "huggingface"  # Default provider
        model_name = model_id
        if '/' in model_id:
            parts = model_id.split('/', 1)
            if parts[0] in ('groq', 'fireworks', 'huggingface'):
                provider = parts[0]
                model_name = parts[1]
        logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
        try:
            # --- Groq Provider ---
            if provider == 'groq':
                if not self.groq_client:
                    raise ValueError("Groq API key is not configured.")
                stream = self.groq_client.chat.completions.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

            # --- Fireworks AI Provider ---
            elif provider == 'fireworks':
                if not self.fireworks_client:
                    raise ValueError("Fireworks AI API key is not configured.")
                stream = self.fireworks_client.ChatCompletion.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

            # --- Hugging Face Provider (Default) ---
            else:
                if not self.hf_client:
                    raise ValueError("Hugging Face API token is not configured.")
                # For HF, model_name is the full original model_id.
                stream = self.hf_client.chat_completion(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    # Guard against empty/None deltas, as in the other branches.
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
        except Exception as e:
            logging.error(f"LLM API Error with provider {provider}: {e}")
            yield f"Error from {provider.capitalize()}: {str(e)}"

class SearchService:
    # (This class remains unchanged)
    def __init__(self, api_key: str = TAVILY_API_KEY):
        # ... existing code ...

    def is_available(self) -> bool:
        # ... existing code ...

    def search(self, query: str, max_results: int = 5) -> str:
        # ... existing code ...

# --- Singleton Instances ---
llm_service = LLMService()
search_service = SearchService()
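
# --- Usage sketch ---
# A minimal demo of the streaming API. The Groq model ID below is an
# illustrative assumption, not a guarantee of what the provider hosts;
# any bare Hugging Face model ID would fall through to the default branch.
if __name__ == "__main__":
    demo_messages: Messages = [
        {"role": "user", "content": "Write a Python function that reverses a string."}
    ]
    for token in llm_service.generate_code_stream("groq/llama3-70b-8192", demo_messages):
        print(token, end="", flush=True)
    print()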