File size: 4,312 Bytes
1687ea3
 
 
 
489ab9c
 
 
 
1687ea3
 
 
 
 
 
489ab9c
 
1687ea3
 
489ab9c
 
1687ea3
489ab9c
1687ea3
 
489ab9c
 
1687ea3
 
489ab9c
 
1687ea3
 
 
 
 
489ab9c
1687ea3
489ab9c
 
 
 
 
 
 
1687ea3
 
489ab9c
1687ea3
 
489ab9c
 
1687ea3
489ab9c
 
 
 
 
 
 
 
 
 
 
1687ea3
489ab9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1687ea3
489ab9c
1687ea3
489ab9c
 
1687ea3
 
 
489ab9c
1687ea3
489ab9c
1687ea3
489ab9c
1687ea3
489ab9c
1687ea3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# /services.py

"""
Clients for the external services this app depends on: LLM providers and a
web-search API (Tavily).

Supported LLM providers, selected by a 'provider/' prefix on the model id
(ids without a recognised prefix go to Hugging Face):
- Hugging Face (default; for standard and multimodal models)
- Groq (for high-speed inference)
- Fireworks AI
"""
import os
import logging
from typing import Dict, Any, Generator, List

from dotenv import load_dotenv

# Import all necessary clients
from huggingface_hub import InferenceClient
from tavily import TavilyClient
from groq import Groq
import fireworks.client as Fireworks

# --- Setup Logging & Environment ---
# NOTE(review): basicConfig configures the process-wide root logger as a side
# effect of importing this module (and does nothing if the root logger already
# has handlers). Consider moving this to the application entry point.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# Load variables from a .env file (python-dotenv) before the lookups below.
load_dotenv()

# --- API Keys ---
# Read once at import time; each value is None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
FIREWORKS_API_KEY = os.environ.get("FIREWORKS_API_KEY")

# --- Type Definitions ---
# A chat transcript passed through unchanged to the providers; presumably
# dicts like {"role": ..., "content": ...} — confirm against callers.
Messages = List[Dict[str, Any]]

class LLMService:
    """A multi-provider wrapper for LLM inference APIs.

    A client is created for each provider whose API key/token was found in the
    environment at import time. Providers without a key are left as ``None``,
    and requests routed to them produce an error message instead of output.
    """

    # Model-id prefixes that explicitly select a provider ('<provider>/<model>').
    _PROVIDERS = ("groq", "fireworks", "huggingface")

    def __init__(self):
        self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        # ``Fireworks`` is the ``fireworks.client`` module itself, so the key is
        # set as a module-global attribute rather than passed to a constructor.
        self.fireworks_client = Fireworks if FIREWORKS_API_KEY else None
        if self.fireworks_client:
            self.fireworks_client.api_key = FIREWORKS_API_KEY

    @classmethod
    def _resolve_provider(cls, model_id: str) -> tuple:
        """Split ``model_id`` into ``(provider, model_name)``.

        A leading ``groq/``, ``fireworks/`` or ``huggingface/`` segment selects
        that provider and is stripped. Any other id (e.g. ``'org/model'`` or a
        bare name) is routed unchanged to Hugging Face.
        """
        prefix, sep, rest = model_id.partition("/")
        if sep and prefix in cls._PROVIDERS:
            return prefix, rest
        return "huggingface", model_id

    @staticmethod
    def _iter_text(stream: Any) -> Generator[str, None, None]:
        """Yield the non-empty text deltas from an OpenAI-style chat stream.

        Chunks with no ``choices``, or whose ``delta.content`` is ``None`` or
        empty, are skipped. Otherwise they would raise ``IndexError`` or leak
        ``None`` into a stream that callers treat as text.
        """
        for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content

    def generate_code_stream(
        self, model_id: str, messages: "Messages", max_tokens: int = 8000
    ) -> Generator[str, None, None]:
        """Stream generated text, dispatching on the provider encoded in ``model_id``.

        Args:
            model_id: ``'provider/model-name'`` where provider is ``groq``,
                ``fireworks`` or ``huggingface``; any other value is treated as
                a full Hugging Face model id.
            messages: Chat messages, forwarded unchanged to the provider.
            max_tokens: Generation limit, forwarded to the provider.

        Yields:
            Non-empty text fragments from the model. On any failure (including
            a missing API key) a final ``"Error from <Provider>: <message>"``
            string is yielded instead of raising. Callers that need to detect
            failures must inspect the text.
        """
        provider, model_name = self._resolve_provider(model_id)
        logging.info("Dispatching to provider: %s for model: %s", provider, model_name)

        try:
            if provider == "groq":
                if not self.groq_client:
                    raise ValueError("Groq API key is not configured.")
                stream = self.groq_client.chat.completions.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
            elif provider == "fireworks":
                if not self.fireworks_client:
                    raise ValueError("Fireworks AI API key is not configured.")
                stream = self.fireworks_client.ChatCompletion.create(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
            else:
                if not self.hf_client:
                    raise ValueError("Hugging Face API token is not configured.")
                # model_name is the full HF id (e.g. 'org/model'), or the
                # remainder after an explicit 'huggingface/' prefix.
                stream = self.hf_client.chat_completion(
                    model=model_name, messages=messages, stream=True, max_tokens=max_tokens
                )
            # All three SDKs return OpenAI-style chunks, so one drain loop serves all.
            yield from self._iter_text(stream)
        except Exception as e:
            # Errors are reported in-band as a final text fragment rather than raised.
            logging.exception("LLM API Error with provider %s: %s", provider, e)
            yield f"Error from {provider.capitalize()}: {str(e)}"


class SearchService:
    """Web-search service wrapper.

    NOTE(review): every method body in this snapshot is elided
    (``# ... existing code ...``). Each method is given a docstring so the class
    at least parses; restore the real implementations before relying on it.
    Presumably backed by the ``TavilyClient`` imported at module level — TODO confirm.
    """
    def __init__(self, api_key: str = TAVILY_API_KEY):
        """Set up the service for ``api_key``.

        The default is the module-level ``TAVILY_API_KEY``, evaluated once when
        the class is defined. It is therefore ``None`` if that environment
        variable was unset at import time. NOTE(review): the ``str`` annotation
        does not reflect that ``None`` case — confirm how the (elided) body
        handles it.
        """
        # ... existing code ...
    def is_available(self) -> bool:
        """Report whether searches can currently be performed.

        NOTE(review): body elided; presumably checks that an API key/client is
        configured — TODO confirm.
        """
        # ... existing code ...
    def search(self, query: str, max_results: int = 5) -> str:
        """Search the web for ``query`` and return the results as one string.

        NOTE(review): body elided; ``max_results`` presumably caps the number
        of hits included — TODO confirm.
        """
        # ... existing code ...

# --- Singleton Instances ---
# Constructed at import time. LLMService() only builds clients for providers
# whose API keys were found in the environment. A missing key is not reported
# here; it surfaces later as an "Error from ..." string yielded by
# generate_code_stream.
llm_service = LLMService()
search_service = SearchService()