mgbam commited on
Commit
a606af6
·
verified ·
1 Parent(s): b41faae

Update services.py

Browse files
Files changed (1) hide show
  1. services.py +19 -34
services.py CHANGED
@@ -8,7 +8,7 @@ This module has been refactored to support multiple LLM providers:
8
  - Fireworks AI
9
  - OpenAI
10
  - Google Gemini
11
- - DeepSeek (Direct API)
12
  """
13
  import os
14
  import logging
@@ -23,7 +23,8 @@ from groq import Groq
23
  import fireworks.client as Fireworks
24
  import openai
25
  import google.generativeai as genai
26
- from deepseek import OpenaiClient as DeepSeekClient
 
27
 
28
  # --- Setup Logging & Environment ---
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -38,7 +39,6 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
38
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
39
  DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
40
 
41
-
42
  # --- Type Definitions ---
43
  Messages = List[Dict[str, Any]]
44
 
@@ -50,8 +50,16 @@ class LLMService:
50
  self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
51
  self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
52
  self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
53
- self.deepseek_client = DeepSeekClient(api_key=DEEPSEEK_API_KEY) if DEEPSEEK_API_KEY else None
54
 
 
 
 
 
 
 
 
 
 
55
  if FIREWORKS_API_KEY:
56
  Fireworks.api_key = FIREWORKS_API_KEY
57
  self.fireworks_client = Fireworks
@@ -65,10 +73,9 @@ class LLMService:
65
  self.gemini_model = None
66
 
67
  def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
68
- """Gemini requires a slightly different message format."""
69
  gemini_messages = []
70
  for msg in messages:
71
- # Gemini uses 'model' for assistant role
72
  role = 'model' if msg['role'] == 'assistant' else 'user'
73
  gemini_messages.append({'role': role, 'parts': [msg['content']]})
74
  return gemini_messages
@@ -76,14 +83,11 @@ class LLMService:
76
  def generate_code_stream(
77
  self, model_id: str, messages: Messages, max_tokens: int = 8192
78
  ) -> Generator[str, None, None]:
79
- """
80
- Streams code generation, dispatching to the correct provider based on model_id.
81
- """
82
  provider, model_name = model_id.split('/', 1)
83
  logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
84
 
85
  try:
86
- # --- OpenAI, Groq, DeepSeek, Fireworks (OpenAI-compatible) ---
87
  if provider in ['openai', 'groq', 'deepseek', 'fireworks']:
88
  client_map = {
89
  'openai': self.openai_client,
@@ -95,7 +99,6 @@ class LLMService:
95
  if not client:
96
  raise ValueError(f"{provider.capitalize()} API key not configured.")
97
 
98
- # Fireworks has a slightly different call signature
99
  if provider == 'fireworks':
100
  stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
101
  else:
@@ -105,7 +108,6 @@ class LLMService:
105
  if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
106
  yield chunk.choices[0].delta.content
107
 
108
- # --- Google Gemini ---
109
  elif provider == 'gemini':
110
  if not self.gemini_model:
111
  raise ValueError("Gemini API key not configured.")
@@ -114,17 +116,14 @@ class LLMService:
114
  for chunk in stream:
115
  yield chunk.text
116
 
117
- # --- Hugging Face ---
118
  elif provider == 'huggingface':
119
  if not self.hf_client:
120
  raise ValueError("Hugging Face API token not configured.")
121
- # For HF, model_name is the rest of the ID, e.g., baidu/ERNIE...
122
  hf_model_id = model_id.split('/', 1)[1]
123
  stream = self.hf_client.chat_completion(model=hf_model_id, messages=messages, stream=True, max_tokens=max_tokens)
124
  for chunk in stream:
125
  if chunk.choices[0].delta.content:
126
  yield chunk.choices[0].delta.content
127
-
128
  else:
129
  raise ValueError(f"Unknown provider: {provider}")
130
 
@@ -132,35 +131,21 @@ class LLMService:
132
  logging.error(f"LLM API Error with provider {provider}: {e}")
133
  yield f"Error from {provider.capitalize()}: {str(e)}"
134
 
 
135
  class SearchService:
136
- """A wrapper for the Tavily Search API."""
137
  def __init__(self, api_key: str = TAVILY_API_KEY):
138
  if not api_key:
139
  logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
140
  self.client = None
141
  else:
142
- try:
143
- self.client = TavilyClient(api_key=api_key)
144
- except Exception as e:
145
- logging.error(f"Failed to initialize Tavily client: {e}")
146
- self.client = None
147
-
148
  def is_available(self) -> bool:
149
- """Checks if the search service is configured and available."""
150
  return self.client is not None
151
-
152
  def search(self, query: str, max_results: int = 5) -> str:
153
- """Performs a web search and returns a formatted string of results."""
154
- if not self.is_available():
155
- return "Web search is not available."
156
  try:
157
- response = self.client.search(
158
- query, search_depth="advanced", max_results=min(max(1, max_results), 10)
159
- )
160
- results = [
161
- f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}"
162
- for res in response.get('results', [])
163
- ]
164
  return "Web Search Results:\n\n" + "\n---\n".join(results) if results else "No search results found."
165
  except Exception as e:
166
  logging.error(f"Tavily search error: {e}")
 
8
  - Fireworks AI
9
  - OpenAI
10
  - Google Gemini
11
+ - DeepSeek (Direct API via OpenAI client)
12
  """
13
  import os
14
  import logging
 
23
  import fireworks.client as Fireworks
24
  import openai
25
  import google.generativeai as genai
26
+
27
+ # <--- FIX: REMOVED the incorrect 'from deepseek import ...' line ---
28
 
29
  # --- Setup Logging & Environment ---
30
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
39
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
40
  DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
41
 
 
42
  # --- Type Definitions ---
43
  Messages = List[Dict[str, Any]]
44
 
 
50
  self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
51
  self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
52
  self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
 
53
 
54
+ # <--- FIX: Correctly instantiate the DeepSeek client using the OpenAI library ---
55
+ if DEEPSEEK_API_KEY:
56
+ self.deepseek_client = openai.OpenAI(
57
+ api_key=DEEPSEEK_API_KEY,
58
+ base_url="https://api.deepseek.com/v1"
59
+ )
60
+ else:
61
+ self.deepseek_client = None
62
+
63
  if FIREWORKS_API_KEY:
64
  Fireworks.api_key = FIREWORKS_API_KEY
65
  self.fireworks_client = Fireworks
 
73
  self.gemini_model = None
74
 
75
  def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
76
+ # This function remains the same
77
  gemini_messages = []
78
  for msg in messages:
 
79
  role = 'model' if msg['role'] == 'assistant' else 'user'
80
  gemini_messages.append({'role': role, 'parts': [msg['content']]})
81
  return gemini_messages
 
83
  def generate_code_stream(
84
  self, model_id: str, messages: Messages, max_tokens: int = 8192
85
  ) -> Generator[str, None, None]:
86
+ # This function remains the same, as the dispatcher logic is already correct
 
 
87
  provider, model_name = model_id.split('/', 1)
88
  logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
89
 
90
  try:
 
91
  if provider in ['openai', 'groq', 'deepseek', 'fireworks']:
92
  client_map = {
93
  'openai': self.openai_client,
 
99
  if not client:
100
  raise ValueError(f"{provider.capitalize()} API key not configured.")
101
 
 
102
  if provider == 'fireworks':
103
  stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
104
  else:
 
108
  if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
109
  yield chunk.choices[0].delta.content
110
 
 
111
  elif provider == 'gemini':
112
  if not self.gemini_model:
113
  raise ValueError("Gemini API key not configured.")
 
116
  for chunk in stream:
117
  yield chunk.text
118
 
 
119
  elif provider == 'huggingface':
120
  if not self.hf_client:
121
  raise ValueError("Hugging Face API token not configured.")
 
122
  hf_model_id = model_id.split('/', 1)[1]
123
  stream = self.hf_client.chat_completion(model=hf_model_id, messages=messages, stream=True, max_tokens=max_tokens)
124
  for chunk in stream:
125
  if chunk.choices[0].delta.content:
126
  yield chunk.choices[0].delta.content
 
127
  else:
128
  raise ValueError(f"Unknown provider: {provider}")
129
 
 
131
  logging.error(f"LLM API Error with provider {provider}: {e}")
132
  yield f"Error from {provider.capitalize()}: {str(e)}"
133
 
134
+ # The SearchService class remains unchanged
135
  class SearchService:
 
136
  def __init__(self, api_key: str = TAVILY_API_KEY):
137
  if not api_key:
138
  logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
139
  self.client = None
140
  else:
141
+ self.client = TavilyClient(api_key=api_key)
 
 
 
 
 
142
  def is_available(self) -> bool:
 
143
  return self.client is not None
 
144
  def search(self, query: str, max_results: int = 5) -> str:
145
+ if not self.is_available(): return "Web search is not available."
 
 
146
  try:
147
+ response = self.client.search(query, search_depth="advanced", max_results=min(max(1, max_results), 10))
148
+ results = [f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}" for res in response.get('results', [])]
 
 
 
 
 
149
  return "Web Search Results:\n\n" + "\n---\n".join(results) if results else "No search results found."
150
  except Exception as e:
151
  logging.error(f"Tavily search error: {e}")