# /services.py
"""
Manages interactions with external services like LLM providers and web search APIs.
This module has been refactored to support multiple LLM providers:
- Hugging Face (for standard and multimodal models)
- Groq (for high-speed inference)
- Fireworks AI
"""
import os
import logging
from typing import Dict, Any, Generator, List
from dotenv import load_dotenv
# Import all necessary clients
from huggingface_hub import InferenceClient
from tavily import TavilyClient
from groq import Groq
import fireworks.client as Fireworks
# --- Setup Logging & Environment ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()
# --- API Keys ---
HF_TOKEN = os.getenv("HF_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
# --- Type Definitions ---
Messages = List[Dict[str, Any]]
class LLMService:
    """A multi-provider wrapper for LLM Inference APIs.

    Dispatches streamed chat-completion requests to one of three providers
    based on the ``model_id`` prefix: ``groq``, ``fireworks``, or
    ``huggingface`` (the default for any unprefixed / full HF model id).
    Clients are only constructed for providers whose API key is set.
    """

    def __init__(self):
        # Initialize a client only when its API key is present; a missing key
        # leaves the client as None and is reported lazily at call time.
        self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.fireworks_client = Fireworks if FIREWORKS_API_KEY else None
        if self.fireworks_client:
            # fireworks.client is configured via a module-level attribute
            # rather than a constructor argument.
            self.fireworks_client.api_key = FIREWORKS_API_KEY

    def generate_code_stream(
        self, model_id: str, messages: Messages, max_tokens: int = 8000
    ) -> Generator[str, None, None]:
        """Stream code generation, dispatching to the provider in ``model_id``.

        Args:
            model_id: Either ``'provider/model-name'`` (provider one of
                ``groq``/``fireworks``/``huggingface``) or a full HF model id.
            messages: Chat messages in OpenAI-style ``{'role', 'content'}`` dicts.
            max_tokens: Generation cap forwarded to the provider API.

        Yields:
            Text deltas as they arrive. On any API error, a single
            human-readable error string is yielded instead of raising.
        """
        provider = "huggingface"  # Default provider
        model_name = model_id
        if '/' in model_id:
            parts = model_id.split('/', 1)
            if parts[0] in ['groq', 'fireworks', 'huggingface']:
                provider = parts[0]
                model_name = parts[1]
        logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
        try:
            # --- Groq Provider ---
            if provider == 'groq':
                if not self.groq_client:
                    raise ValueError("Groq API key is not configured.")
                stream = self.groq_client.chat.completions.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
            # --- Fireworks AI Provider ---
            elif provider == 'fireworks':
                if not self.fireworks_client:
                    raise ValueError("Fireworks AI API key is not configured.")
                stream = self.fireworks_client.ChatCompletion.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
            # --- Hugging Face Provider (Default) ---
            else:
                if not self.hf_client:
                    raise ValueError("Hugging Face API token is not configured.")
                # For HF, the model_name is the full original model_id
                stream = self.hf_client.chat_completion(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    # BUGFIX: guard against a None delta (e.g. the final chunk),
                    # matching the Groq/Fireworks branches — previously this
                    # could yield None and break string-consuming callers.
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
        except Exception as e:
            logging.error(f"LLM API Error with provider {provider}: {e}")
            yield f"Error from {provider.capitalize()}: {str(e)}"
class SearchService:
    # (This class remains unchanged)
    """Thin wrapper around the Tavily web-search API.

    NOTE(review): method bodies are elided in this snippet
    ("... existing code ..."), so only the visible signatures are
    documented here — confirm behavior against the full file.
    """
    def __init__(self, api_key: str = TAVILY_API_KEY):
        # NOTE: the default api_key is captured from the environment at
        # import time (module-level TAVILY_API_KEY), not at call time.
        # ... existing code ...
    def is_available(self) -> bool:
        # Presumably reports whether a Tavily client/key is configured — verify.
        # ... existing code ...
    def search(self, query: str, max_results: int = 5) -> str:
        # Returns search results as a single string; result formatting is
        # defined in the elided body.
        # ... existing code ...
# --- Singleton Instances ---
# Module-level singletons shared by the rest of the application.
# (Removed a stray trailing '|' artifact that made this line a syntax error.)
llm_service = LLMService()
search_service = SearchService()