mgbam commited on
Commit
489ab9c
·
verified ·
1 Parent(s): a73275a

Update services.py

Browse files
Files changed (1) hide show
  1. services.py +73 -68
services.py CHANGED
@@ -2,108 +2,113 @@
2
 
3
  """
4
  Manages interactions with external services like LLM providers and web search APIs.
5
-
6
- This module uses a class-based approach to encapsulate API clients and their
7
- logic, making it easy to manage connections and mock services for testing.
 
8
  """
9
  import os
10
  import logging
11
  from typing import Dict, Any, Generator, List
12
 
13
  from dotenv import load_dotenv
 
 
14
  from huggingface_hub import InferenceClient
15
  from tavily import TavilyClient
 
 
16
 
17
- # --- Setup Logging ---
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
-
20
- # --- Load Environment Variables ---
21
  load_dotenv()
 
 
22
  HF_TOKEN = os.getenv("HF_TOKEN")
23
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
24
-
25
- if not HF_TOKEN:
26
- raise ValueError("HF_TOKEN environment variable is not set. Please get a token from https://huggingface.co/settings/tokens")
27
 
28
  # --- Type Definitions ---
29
  Messages = List[Dict[str, Any]]
30
 
31
  class LLMService:
32
- """A wrapper for the Hugging Face Inference API."""
33
- def __init__(self, api_key: str = HF_TOKEN):
34
- if not api_key:
35
- raise ValueError("Hugging Face API key is required.")
36
- self.api_key = api_key
37
 
38
- def get_client(self, model_id: str, provider: str = "auto") -> InferenceClient:
39
- """Initializes and returns an InferenceClient."""
40
- return InferenceClient(provider=provider, api_key=self.api_key, bill_to="huggingface")
 
 
 
 
41
 
42
  def generate_code_stream(
43
- self, model_id: str, messages: Messages, provider: str = "auto", max_tokens: int = 10000
44
  ) -> Generator[str, None, None]:
45
  """
46
- Streams code generation from the specified model.
47
- Yields content chunks as they are received.
48
  """
49
- client = self.get_client(model_id, provider)
 
 
 
 
 
 
 
 
 
 
50
  try:
51
- stream = client.chat.completions.create(
52
- model=model_id,
53
- messages=messages,
54
- stream=True,
55
- max_tokens=max_tokens,
56
- )
57
- for chunk in stream:
58
- if chunk.choices and chunk.choices[0].delta.content:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  yield chunk.choices[0].delta.content
 
60
  except Exception as e:
61
- logging.error(f"LLM API Error for model {model_id}: {e}")
62
- yield f"Error: Could not get a response from the model. Details: {str(e)}"
63
- # Re-raise or handle as appropriate for your application flow
64
- # For this app, we yield an error message to the user.
65
 
66
 
67
  class SearchService:
68
- """A wrapper for the Tavily Search API."""
69
  def __init__(self, api_key: str = TAVILY_API_KEY):
70
- if not api_key:
71
- logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
72
- self.client = None
73
- else:
74
- try:
75
- self.client = TavilyClient(api_key=api_key)
76
- except Exception as e:
77
- logging.error(f"Failed to initialize Tavily client: {e}")
78
- self.client = None
79
-
80
  def is_available(self) -> bool:
81
- """Checks if the search service is configured and available."""
82
- return self.client is not None
83
-
84
  def search(self, query: str, max_results: int = 5) -> str:
85
- """
86
- Performs a web search and returns a formatted string of results.
87
- """
88
- if not self.is_available():
89
- return "Web search is not available."
90
-
91
- try:
92
- response = self.client.search(
93
- query,
94
- search_depth="advanced",
95
- max_results=min(max(1, max_results), 10)
96
- )
97
- results = [
98
- f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}"
99
- for res in response.get('results', [])
100
- ]
101
- return "Web Search Results:\n\n" + "\n---\n".join(results) if results else "No search results found."
102
- except Exception as e:
103
- logging.error(f"Tavily search error: {e}")
104
- return f"Search error: {str(e)}"
105
 
106
  # --- Singleton Instances ---
107
- # These instances can be imported and used throughout the application.
108
  llm_service = LLMService()
109
  search_service = SearchService()
 
2
 
3
  """
4
  Manages interactions with external services like LLM providers and web search APIs.
5
+ This module has been refactored to support multiple LLM providers:
6
+ - Hugging Face (for standard and multimodal models)
7
+ - Groq (for high-speed inference)
8
+ - Fireworks AI
9
  """
10
  import os
11
  import logging
12
  from typing import Dict, Any, Generator, List
13
 
14
  from dotenv import load_dotenv
15
+
16
+ # Import all necessary clients
17
  from huggingface_hub import InferenceClient
18
  from tavily import TavilyClient
19
+ from groq import Groq
20
+ import fireworks.client as Fireworks
21
 
22
+ # --- Setup Logging & Environment ---
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
24
  load_dotenv()
25
+
26
+ # --- API Keys ---
27
  HF_TOKEN = os.getenv("HF_TOKEN")
28
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
29
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
30
+ FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
 
31
 
32
  # --- Type Definitions ---
33
  Messages = List[Dict[str, Any]]
34
 
35
  class LLMService:
36
+ """A multi-provider wrapper for LLM Inference APIs."""
 
 
 
 
37
 
38
+ def __init__(self):
39
+ # Initialize clients if their API keys are available
40
+ self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
41
+ self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
42
+ self.fireworks_client = Fireworks if FIREWORKS_API_KEY else None
43
+ if self.fireworks_client:
44
+ self.fireworks_client.api_key = FIREWORKS_API_KEY
45
 
46
  def generate_code_stream(
47
+ self, model_id: str, messages: Messages, max_tokens: int = 8000
48
  ) -> Generator[str, None, None]:
49
  """
50
+ Streams code generation, dispatching to the correct provider based on model_id.
51
+ The model_id format is 'provider/model-name' or a full HF model_id.
52
  """
53
+ provider = "huggingface" # Default provider
54
+ model_name = model_id
55
+
56
+ if '/' in model_id:
57
+ parts = model_id.split('/', 1)
58
+ if parts[0] in ['groq', 'fireworks', 'huggingface']:
59
+ provider = parts[0]
60
+ model_name = parts[1]
61
+
62
+ logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
63
+
64
  try:
65
+ # --- Groq Provider ---
66
+ if provider == 'groq':
67
+ if not self.groq_client:
68
+ raise ValueError("Groq API key is not configured.")
69
+ stream = self.groq_client.chat.completions.create(
70
+ model=model_name, messages=messages, stream=True, max_tokens=max_tokens
71
+ )
72
+ for chunk in stream:
73
+ if chunk.choices[0].delta.content:
74
+ yield chunk.choices[0].delta.content
75
+
76
+ # --- Fireworks AI Provider ---
77
+ elif provider == 'fireworks':
78
+ if not self.fireworks_client:
79
+ raise ValueError("Fireworks AI API key is not configured.")
80
+ stream = self.fireworks_client.ChatCompletion.create(
81
+ model=model_name, messages=messages, stream=True, max_tokens=max_tokens
82
+ )
83
+ for chunk in stream:
84
+ if chunk.choices[0].delta.content:
85
+ yield chunk.choices[0].delta.content
86
+
87
+ # --- Hugging Face Provider (Default) ---
88
+ else:
89
+ if not self.hf_client:
90
+ raise ValueError("Hugging Face API token is not configured.")
91
+ # For HF, the model_name is the full original model_id
92
+ stream = self.hf_client.chat_completion(
93
+ model=model_name, messages=messages, stream=True, max_tokens=max_tokens
94
+ )
95
+ for chunk in stream:
96
  yield chunk.choices[0].delta.content
97
+
98
  except Exception as e:
99
+ logging.error(f"LLM API Error with provider {provider}: {e}")
100
+ yield f"Error from {provider.capitalize()}: {str(e)}"
 
 
101
 
102
 
103
  class SearchService:
104
+ # (This class remains unchanged)
105
  def __init__(self, api_key: str = TAVILY_API_KEY):
106
+ # ... existing code ...
 
 
 
 
 
 
 
 
 
107
  def is_available(self) -> bool:
108
+ # ... existing code ...
 
 
109
  def search(self, query: str, max_results: int = 5) -> str:
110
+ # ... existing code ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # --- Singleton Instances ---
 
113
  llm_service = LLMService()
114
  search_service = SearchService()