""" |
|
Manages interactions with external services like LLM providers and web search APIs. |
|
This module has been refactored to support multiple LLM providers: |
|
- Hugging Face (for standard and multimodal models) |
|
- Groq (for high-speed inference) |
|
- Fireworks AI |
|
""" |
|
import os
import logging
from typing import Any, Dict, Generator, List, Optional

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from tavily import TavilyClient
from groq import Groq
# Legacy module-style Fireworks client; configured via a global api_key below.
import fireworks.client as Fireworks

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
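
# Expected .env entries (variable names as read above; values are placeholders):
#
#   HF_TOKEN=...
#   TAVILY_API_KEY=...
#   GROQ_API_KEY=...
#   FIREWORKS_API_KEY=...
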
Messages = List[Dict[str, Any]] |
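# OpenAI-style chat messages, the format all three providers here accept, e.g.:
#   [{"role": "system", "content": "You are a code assistant."},
#    {"role": "user", "content": "Write a sorting function."}]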
|
|
|
class LLMService:
    """A multi-provider wrapper for LLM inference APIs."""

    def __init__(self):
        # Each client is created only when its credential is present, so a
        # missing key disables that provider instead of raising at import time.
        self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        # The legacy Fireworks client is module-level state: keep a reference
        # and set its global api_key.
        self.fireworks_client = Fireworks if FIREWORKS_API_KEY else None
        if self.fireworks_client:
            self.fireworks_client.api_key = FIREWORKS_API_KEY

    def generate_code_stream(
        self, model_id: str, messages: Messages, max_tokens: int = 8000
    ) -> Generator[str, None, None]:
        """
        Streams code generation, dispatching to the correct provider based on model_id.

        model_id is either 'provider/<model-name>' with provider one of 'groq',
        'fireworks', or 'huggingface', or a full Hugging Face model id of the
        form '<org>/<model>', which falls through to the Hugging Face client.
        """
        provider = "huggingface"
        model_name = model_id

        # Strip a recognized provider prefix; any other '<org>/<model>' id is
        # treated as a full Hugging Face model id and passed through unchanged.
        if '/' in model_id:
            parts = model_id.split('/', 1)
            if parts[0] in ['groq', 'fireworks', 'huggingface']:
                provider = parts[0]
                model_name = parts[1]

        logging.info(f"Dispatching to provider: {provider} for model: {model_name}")

        try:
            if provider == 'groq':
                if not self.groq_client:
                    raise ValueError("Groq API key is not configured.")
                stream = self.groq_client.chat.completions.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

            elif provider == 'fireworks':
                if not self.fireworks_client:
                    raise ValueError("Fireworks AI API key is not configured.")
                stream = self.fireworks_client.ChatCompletion.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

            else:
                if not self.hf_client:
                    raise ValueError("Hugging Face API token is not configured.")
                stream = self.hf_client.chat_completion(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
                for chunk in stream:
                    # Guard against empty deltas (e.g. role-only or final
                    # chunks), which would otherwise yield None.
                    if chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

        except Exception as e:
            logging.error(f"LLM API Error with provider {provider}: {e}")
            yield f"Error from {provider.capitalize()}: {str(e)}"

class SearchService:
    """A thin wrapper around the Tavily web search API.

    The method bodies below are minimal implementations inferred from the
    declared signatures.
    """

    def __init__(self, api_key: Optional[str] = TAVILY_API_KEY):
        self.client = TavilyClient(api_key=api_key) if api_key else None

    def is_available(self) -> bool:
        return self.client is not None

    def search(self, query: str, max_results: int = 5) -> str:
        """Runs a web search and returns the results as a plain-text block."""
        if not self.client:
            return "Search is unavailable: TAVILY_API_KEY is not configured."
        try:
            response = self.client.search(query=query, max_results=max_results)
            # Tavily returns a dict whose 'results' list holds title/url/content.
            return "\n\n".join(
                f"{r.get('title', '')}\n{r.get('url', '')}\n{r.get('content', '')}"
                for r in response.get("results", [])
            )
        except Exception as e:
            logging.error(f"Tavily search error: {e}")
            return f"Search error: {e}"

|
# Module-level instances intended for import elsewhere.
llm_service = LLMService()
search_service = SearchService()
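
# Example usage (illustrative sketch; the model id below is a placeholder,
# substitute one enabled for your Groq account):
#
#     messages = [{"role": "user", "content": "Write a hello-world script in Python."}]
#     for token in llm_service.generate_code_stream("groq/llama-3.1-8b-instant", messages):
#         print(token, end="", flush=True)
#
#     if search_service.is_available():
#         print(search_service.search("latest Python release"))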