# /services.py
""" Manages interactions with all external LLM and search APIs. """

import os
import logging
from typing import Dict, Any, Generator, List

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from tavily import TavilyClient
from groq import Groq
import fireworks.client as Fireworks
import openai
import google.generativeai as genai

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
load_dotenv()

# --- API Keys from .env ---
HF_TOKEN = os.getenv("HF_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")

Messages = List[Dict[str, Any]]
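# Chat history in the OpenAI message format: a list of {"role": ..., "content": ...} dicts.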

class LLMService:
    """A multi-provider wrapper for LLM Inference APIs."""
    def __init__(self):
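        # Each client is created only when its API key is present; a missing key
        # simply disables that provider rather than failing at import time.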
        self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
        
        if DEEPSEEK_API_KEY:
            self.deepseek_client = openai.OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com/v1")
        else:
            self.deepseek_client = None

        if FIREWORKS_API_KEY:
            Fireworks.api_key = FIREWORKS_API_KEY
            self.fireworks_client = Fireworks
        else:
            self.fireworks_client = None

        if GEMINI_API_KEY:
            genai.configure(api_key=GEMINI_API_KEY)
            self.gemini_model = genai.GenerativeModel('gemini-1.5-pro-latest')
        else:
            self.gemini_model = None

    def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
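        """Convert OpenAI-style messages into Gemini's role/parts format.

        System messages are dropped here; the caller folds the system prompt
        into the first user turn instead.
        """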
        gemini_messages = []
        for msg in messages:
            if msg['role'] == 'system':
                continue  # Gemini doesn't use a system role in this way
            role = 'model' if msg['role'] == 'assistant' else 'user'
            gemini_messages.append({'role': role, 'parts': [msg['content']]})
        return gemini_messages

    def generate_code_stream(self, model_id: str, messages: Messages, max_tokens: int = 8192) -> Generator[str, None, None]:
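        """Stream completion text for `model_id`, formatted as "provider/model-name".

        Yields text chunks as they arrive; on any provider error, yields a single
        human-readable error string instead of raising.
        """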
        provider, model_name = model_id.split('/', 1)
        logging.info(f"Dispatching to provider: {provider} for model: {model_name}")

        try:
            if provider in ('openai', 'groq', 'deepseek', 'fireworks'):
                client_map = {
                    'openai': self.openai_client, 'groq': self.groq_client, 'deepseek': self.deepseek_client,
                    'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None,
                }
                client = client_map.get(provider)
                if not client:
                    raise ValueError(f"{provider.capitalize()} API key not configured.")
                # The legacy fireworks.client SDK streams via ChatCompletion.create; the rest are OpenAI-compatible.
                if provider == 'fireworks':
                    stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                else:
                    stream = client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
            
            elif provider == 'gemini':
                if not self.gemini_model:
                    raise ValueError("Gemini API key not configured.")
                system_prompt = next((msg['content'] for msg in messages if msg['role'] == 'system'), "")
                gemini_messages = self._prepare_messages_for_gemini(messages)
                # Prepend the system prompt to the first user message for Gemini
                if system_prompt and gemini_messages and gemini_messages[0]['role'] == 'user':
                    gemini_messages[0]['parts'][0] = f"{system_prompt}\n\n{gemini_messages[0]['parts'][0]}"
                stream = self.gemini_model.generate_content(gemini_messages, stream=True)
                for chunk in stream:
                    # .text raises on chunks with no parts (e.g. safety-blocked), so guard first
                    if chunk.parts:
                        yield chunk.text

            elif provider == 'huggingface':
                if not self.hf_client:
                    raise ValueError("Hugging Face API token not configured.")
                # model_name already holds everything after the "huggingface/" prefix
                stream = self.hf_client.chat_completion(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
            else:
                raise ValueError(f"Unknown provider: {provider}")
        except Exception as e:
            logging.error(f"LLM API Error with provider {provider}: {e}")
            yield f"Error from {provider.capitalize()}: {str(e)}"

class SearchService:
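    """Thin wrapper around the Tavily web search API."""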
    def __init__(self, api_key: str = TAVILY_API_KEY):
        self.client = TavilyClient(api_key=api_key) if api_key else None
        if not self.client:
            logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")

    def is_available(self) -> bool:
        return self.client is not None

    def search(self, query: str, max_results: int = 5) -> str:
        if not self.is_available():
            return "Web search is not available."
        try:
            response = self.client.search(query, search_depth="advanced", max_results=min(max(1, max_results), 10))
            results = [
                f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}"
                for res in response.get('results', [])
            ]
            return "Web Search Results:\n\n" + "\n---\n".join(results)
        except Exception as e:
            return f"Search error: {str(e)}"

llm_service = LLMService()
search_service = SearchService()
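
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the service API): streams one
    # completion and prints it. Assumes GROQ_API_KEY is set in .env; the model
    # id below is illustrative and may need to be swapped for one your account
    # can access.
    demo_messages: Messages = [
        {"role": "system", "content": "You are a concise coding assistant."},
        {"role": "user", "content": "Write a Python one-liner that reverses a string."},
    ]
    for token in llm_service.generate_code_stream("groq/llama-3.1-8b-instant", demo_messages):
        print(token, end="", flush=True)
    print()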