mgbam committed · verified
Commit 4739b8c · 1 Parent(s): de11970

Update services.py

Files changed (1): services.py +22 -67

services.py CHANGED
@@ -1,22 +1,11 @@
 # /services.py
+""" Manages interactions with all external LLM and search APIs. """
 
-"""
-Manages interactions with external services like LLM providers and web search APIs.
-This module has been refactored to support multiple LLM providers:
-- Hugging Face
-- Groq
-- Fireworks AI
-- OpenAI
-- Google Gemini
-- DeepSeek (Direct API via OpenAI client)
-"""
 import os
 import logging
 from typing import Dict, Any, Generator, List
 
 from dotenv import load_dotenv
-
-# Import all necessary clients
 from huggingface_hub import InferenceClient
 from tavily import TavilyClient
 from groq import Groq
@@ -24,9 +13,6 @@ import fireworks.client as Fireworks
 import openai
 import google.generativeai as genai
 
-# <--- FIX: REMOVED the incorrect 'from deepseek import ...' line ---
-
-# --- Setup Logging & Environment ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 load_dotenv()
 
@@ -39,24 +25,17 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
 
-# --- Type Definitions ---
 Messages = List[Dict[str, Any]]
 
 class LLMService:
     """A multi-provider wrapper for LLM Inference APIs."""
-
     def __init__(self):
-        # Initialize clients only if their API keys are available
         self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
         self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
         self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
 
-        # <--- FIX: Correctly instantiate the DeepSeek client using the OpenAI library ---
         if DEEPSEEK_API_KEY:
-            self.deepseek_client = openai.OpenAI(
-                api_key=DEEPSEEK_API_KEY,
-                base_url="https://api.deepseek.com/v1"
-            )
+            self.deepseek_client = openai.OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com/v1")
         else:
             self.deepseek_client = None
 
@@ -73,84 +52,60 @@ class LLMService:
         self.gemini_model = None
 
     def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
-        # This function remains the same
         gemini_messages = []
         for msg in messages:
+            if msg['role'] == 'system': continue # Gemini doesn't use a system role in this way
             role = 'model' if msg['role'] == 'assistant' else 'user'
             gemini_messages.append({'role': role, 'parts': [msg['content']]})
         return gemini_messages
 
-    def generate_code_stream(
-        self, model_id: str, messages: Messages, max_tokens: int = 8192
-    ) -> Generator[str, None, None]:
-        # This function remains the same, as the dispatcher logic is already correct
+    def generate_code_stream(self, model_id: str, messages: Messages, max_tokens: int = 8192) -> Generator[str, None, None]:
         provider, model_name = model_id.split('/', 1)
         logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
 
         try:
             if provider in ['openai', 'groq', 'deepseek', 'fireworks']:
-                client_map = {
-                    'openai': self.openai_client,
-                    'groq': self.groq_client,
-                    'deepseek': self.deepseek_client,
-                    'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None,
-                }
+                client_map = {'openai': self.openai_client, 'groq': self.groq_client, 'deepseek': self.deepseek_client, 'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None}
                 client = client_map.get(provider)
-                if not client:
-                    raise ValueError(f"{provider.capitalize()} API key not configured.")
-
-                if provider == 'fireworks':
-                    stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
-                else:
-                    stream = client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
+                if not client: raise ValueError(f"{provider.capitalize()} API key not configured.")
 
+                stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens) if provider == 'fireworks' else client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                 for chunk in stream:
-                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
+                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content
+
             elif provider == 'gemini':
-                if not self.gemini_model:
-                    raise ValueError("Gemini API key not configured.")
+                if not self.gemini_model: raise ValueError("Gemini API key not configured.")
+                system_prompt = next((msg['content'] for msg in messages if msg['role'] == 'system'), "")
                 gemini_messages = self._prepare_messages_for_gemini(messages)
+                # Prepend system prompt to first user message for Gemini
+                if system_prompt and gemini_messages and gemini_messages[0]['role'] == 'user':
+                    gemini_messages[0]['parts'][0] = f"{system_prompt}\n\n{gemini_messages[0]['parts'][0]}"
                 stream = self.gemini_model.generate_content(gemini_messages, stream=True)
-                for chunk in stream:
-                    yield chunk.text
+                for chunk in stream: yield chunk.text
 
             elif provider == 'huggingface':
-                if not self.hf_client:
-                    raise ValueError("Hugging Face API token not configured.")
+                if not self.hf_client: raise ValueError("Hugging Face API token not configured.")
                 hf_model_id = model_id.split('/', 1)[1]
                 stream = self.hf_client.chat_completion(model=hf_model_id, messages=messages, stream=True, max_tokens=max_tokens)
                 for chunk in stream:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
+                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content
             else:
                 raise ValueError(f"Unknown provider: {provider}")
-
         except Exception as e:
             logging.error(f"LLM API Error with provider {provider}: {e}")
             yield f"Error from {provider.capitalize()}: {str(e)}"
 
-# The SearchService class remains unchanged
 class SearchService:
     def __init__(self, api_key: str = TAVILY_API_KEY):
-        if not api_key:
-            logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
-            self.client = None
-        else:
-            self.client = TavilyClient(api_key=api_key)
-    def is_available(self) -> bool:
-        return self.client is not None
+        self.client = TavilyClient(api_key=api_key) if api_key else None
+        if not self.client: logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
+    def is_available(self) -> bool: return self.client is not None
     def search(self, query: str, max_results: int = 5) -> str:
         if not self.is_available(): return "Web search is not available."
         try:
             response = self.client.search(query, search_depth="advanced", max_results=min(max(1, max_results), 10))
-            results = [f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}" for res in response.get('results', [])]
-            return "Web Search Results:\n\n" + "\n---\n".join(results) if results else "No search results found."
-        except Exception as e:
-            logging.error(f"Tavily search error: {e}")
-            return f"Search error: {str(e)}"
+            return "Web Search Results:\n\n" + "\n---\n".join([f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}" for res in response.get('results', [])])
+        except Exception as e: return f"Search error: {str(e)}"
 
-# --- Singleton Instances ---
 llm_service = LLMService()
 search_service = SearchService()
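
The substantive fix in this commit is instantiating the DeepSeek client through the OpenAI SDK with a base_url override, instead of importing a nonexistent deepseek package: DeepSeek exposes an OpenAI-compatible endpoint. A minimal standalone sketch of the same pattern (the model name "deepseek-chat" and the prompt are illustrative, not taken from this commit):

import os
import openai

# DeepSeek speaks the OpenAI chat-completions protocol, so the stock
# OpenAI client works once base_url points at DeepSeek's endpoint.
client = openai.OpenAI(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com/v1",
)

stream = client.chat.completions.create(
    model="deepseek-chat",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
)
for chunk in stream:
    # Guard empty deltas before printing, as services.py does before yielding.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")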
 
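The Gemini path changes in two coordinated steps: _prepare_messages_for_gemini now drops 'system' turns (the message format passed to generate_content has no system role), and generate_code_stream folds the system prompt into the first user turn. A self-contained sketch of that transformation, using a hypothetical conversation:

from typing import Any, Dict, List

Messages = List[Dict[str, Any]]

def prepare_for_gemini(messages: Messages) -> List[Dict[str, Any]]:
    # Same logic as the commit: skip 'system', map 'assistant' -> 'model'.
    gemini_messages = []
    for msg in messages:
        if msg['role'] == 'system':
            continue
        role = 'model' if msg['role'] == 'assistant' else 'user'
        gemini_messages.append({'role': role, 'parts': [msg['content']]})
    return gemini_messages

# Hypothetical messages, for illustration only.
msgs = [
    {'role': 'system', 'content': 'You are a concise assistant.'},
    {'role': 'user', 'content': 'Write a haiku.'},
]
system_prompt = next((m['content'] for m in msgs if m['role'] == 'system'), "")
gemini_messages = prepare_for_gemini(msgs)
if system_prompt and gemini_messages and gemini_messages[0]['role'] == 'user':
    # Prepend the system prompt to the first user turn, as the commit does.
    gemini_messages[0]['parts'][0] = f"{system_prompt}\n\n{gemini_messages[0]['parts'][0]}"
print(gemini_messages)
# -> [{'role': 'user', 'parts': ['You are a concise assistant.\n\nWrite a haiku.']}]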
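Callers address a model as '<provider>/<model_name>' and consume the generator token by token; a missing key surfaces as a yielded error string rather than an exception. A hedged usage sketch (it assumes the module imports as services, that the relevant API keys are set, and that 'openai/gpt-4o' is a valid id for your account; all three are assumptions, not part of this commit):

from services import llm_service, search_service

# Optional web context; SearchService degrades to a plain message when
# TAVILY_API_KEY is unset, so this call is always safe.
context = search_service.search("latest stable Python release", max_results=3)

messages = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": f"Context:\n{context}\n\nWrite a hello-world script."},
]

# Model ids are '<provider>/<model_name>'; 'openai/gpt-4o' is illustrative.
for token in llm_service.generate_code_stream("openai/gpt-4o", messages):
    print(token, end="", flush=True)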