# /services.py

"""
Manages interactions with external services like LLM providers and web search APIs.
This module has been refactored to support multiple LLM providers:
- Hugging Face
- Groq
- Fireworks AI
- OpenAI
- Google Gemini
- DeepSeek (Direct API)
"""
import os
import logging
from typing import Dict, Any, Generator, List

from dotenv import load_dotenv

# Import all necessary clients
from huggingface_hub import InferenceClient
from tavily import TavilyClient
from groq import Groq
import fireworks.client as Fireworks
import openai
import google.generativeai as genai
# DeepSeek's API is OpenAI-compatible, so the OpenAI SDK is reused for it below

# --- Setup Logging & Environment ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()

# --- API Keys from .env ---
HF_TOKEN = os.getenv("HF_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")


# --- Type Definitions ---
Messages = List[Dict[str, Any]]
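# Messages follow the OpenAI-style chat format, e.g.:
#   [{"role": "system", "content": "You are a coding assistant."},
#    {"role": "user", "content": "Write a FizzBuzz function."}]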

class LLMService:
    """A multi-provider wrapper for LLM Inference APIs."""

    def __init__(self):
        # Initialize clients only if their API keys are available
        self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
        # DeepSeek exposes an OpenAI-compatible endpoint; point the OpenAI SDK at it
        self.deepseek_client = (
            openai.OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
            if DEEPSEEK_API_KEY
            else None
        )
        
        if FIREWORKS_API_KEY:
            Fireworks.api_key = FIREWORKS_API_KEY
            self.fireworks_client = Fireworks
        else:
            self.fireworks_client = None

        if GEMINI_API_KEY:
            genai.configure(api_key=GEMINI_API_KEY)
            self.gemini_model = genai.GenerativeModel('gemini-1.5-pro-latest')
        else:
            self.gemini_model = None

    def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
        """Gemini requires a slightly different message format."""
        gemini_messages = []
        for msg in messages:
            # Gemini history roles are only 'user' and 'model': 'assistant' maps to
            # 'model'; 'system' (which has no history equivalent) is folded into 'user'
            role = 'model' if msg['role'] == 'assistant' else 'user'
            gemini_messages.append({'role': role, 'parts': [msg['content']]})
        return gemini_messages

    def generate_code_stream(
        self, model_id: str, messages: Messages, max_tokens: int = 8192
    ) -> Generator[str, None, None]:
        """
        Streams code generation, dispatching to the correct provider based on model_id.
        """
        if '/' not in model_id:
            raise ValueError(f"Invalid model_id '{model_id}'; expected 'provider/model-name'.")
        provider, model_name = model_id.split('/', 1)
        logging.info(f"Dispatching to provider: {provider} for model: {model_name}")

        try:
            # --- OpenAI, Groq, DeepSeek, Fireworks (OpenAI-compatible) ---
            if provider in ['openai', 'groq', 'deepseek', 'fireworks']:
                client_map = {
                    'openai': self.openai_client,
                    'groq': self.groq_client,
                    'deepseek': self.deepseek_client,
                    'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None,
                }
                client = client_map.get(provider)
                if not client:
                    raise ValueError(f"{provider.capitalize()} API key not configured.")
                
                # Fireworks has a slightly different call signature
                if provider == 'fireworks':
                     stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                else:
                     stream = client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                
                for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

            # --- Google Gemini ---
            elif provider == 'gemini':
                if not self.gemini_model:
                    raise ValueError("Gemini API key not configured.")
                gemini_messages = self._prepare_messages_for_gemini(messages)
                stream = self.gemini_model.generate_content(
                    gemini_messages,
                    stream=True,
                    generation_config=genai.types.GenerationConfig(max_output_tokens=max_tokens),
                )
                for chunk in stream:
                    # chunk.text raises if a chunk carries no parts (e.g., a safety block)
                    if chunk.parts:
                        yield chunk.text

            # --- Hugging Face ---
            elif provider == 'huggingface':
                if not self.hf_client:
                    raise ValueError("Hugging Face API token not configured.")
                # For HF, model_name is the rest of the ID, e.g., baidu/ERNIE...
                stream = self.hf_client.chat_completion(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content

            else:
                raise ValueError(f"Unknown provider: {provider}")

        except Exception as e:
            logging.error(f"LLM API Error with provider {provider}: {e}")
            yield f"Error from {provider.capitalize()}: {str(e)}"

class SearchService:
    """A wrapper for the Tavily Search API."""
    def __init__(self, api_key: str = TAVILY_API_KEY):
        if not api_key:
            logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
            self.client = None
        else:
            try:
                self.client = TavilyClient(api_key=api_key)
            except Exception as e:
                logging.error(f"Failed to initialize Tavily client: {e}")
                self.client = None

    def is_available(self) -> bool:
        """Checks if the search service is configured and available."""
        return self.client is not None

    def search(self, query: str, max_results: int = 5) -> str:
        """Performs a web search and returns a formatted string of results."""
        if not self.is_available():
            return "Web search is not available."
        try:
            response = self.client.search(
                query, search_depth="advanced", max_results=min(max(1, max_results), 10)
            )
            results = [
                f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}"
                for res in response.get('results', [])
            ]
            return "Web Search Results:\n\n" + "\n---\n".join(results) if results else "No search results found."
        except Exception as e:
            logging.error(f"Tavily search error: {e}")
            return f"Search error: {str(e)}"

# --- Singleton Instances ---
llm_service = LLMService()
search_service = SearchService()
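
# --- Usage sketch (illustrative, not part of the service API) ---
# A minimal example of driving the singletons above; the model ID and prompt
# below are placeholders, assuming the corresponding API key is configured.
if __name__ == "__main__":
    if search_service.is_available():
        print(search_service.search("latest stable Python release", max_results=3))

    for token in llm_service.generate_code_stream(
        "openai/gpt-4o-mini",  # hypothetical choice; any configured 'provider/model-name' works
        [{"role": "user", "content": "Write a Python function that reverses a string."}],
        max_tokens=256,
    ):
        print(token, end="", flush=True)
    print()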