# ai_providers.py
from typing import Dict, Any, Optional, List | |
import os | |
from dotenv import load_dotenv | |
import time | |
from groq import Groq | |
from groq.types.chat import ChatCompletion | |
from langchain.llms.base import LLM | |
from langchain.callbacks.manager import CallbackManagerForLLMRun | |
from pydantic import Field, BaseModel | |
class GroqLLM(LLM, BaseModel):
    """LangChain-compatible LLM wrapper around the Groq chat-completions API."""

    api_key: str = Field(..., description="Groq API key")
    model: str = Field(default="deepseek-r1-distill-llama-70b", description="Model name")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=4000, description="Maximum number of tokens to generate")
    top_p: float = Field(default=1.0, description="Top p sampling parameter")
    client: Any = Field(default=None, description="Groq client instance")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Build the SDK client once, after pydantic has validated api_key.
        self.client = Groq(api_key=self.api_key)

    @property
    def _llm_type(self) -> str:
        # LangChain's LLM contract declares _llm_type as a property;
        # without @property, frameworks reading it get the bound method.
        return "groq"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> str:
        """Send *prompt* to Groq as a single user message and return the text.

        Retries up to three times with linearly increasing backoff
        (1s, 2s) and re-raises the final error unchanged.

        Raises:
            Exception: whatever the Groq client raised on the last attempt.
        """
        max_retries = 3
        retry_delay = 1
        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                completion: ChatCompletion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    top_p=self.top_p,
                    stop=stop,
                )
                content = completion.choices[0].message.content
                # The API may return None content; callers are promised str.
                return content if content is not None else ""
            except Exception as e:
                last_error = e
                if attempt == max_retries - 1:
                    # Bare raise preserves the original traceback.
                    raise
                time.sleep(retry_delay * (attempt + 1))
        # Unreachable (last attempt re-raises), but keeps the declared
        # str return honest for type checkers.
        raise RuntimeError("retry loop exited unexpectedly") from last_error
class AIProvider:
    """Abstract interface every AI completion provider must implement."""

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the model's completion for *prompt*.

        Raises:
            NotImplementedError: always on the base class; the original
            `pass` silently returned None, masking unimplemented providers.
        """
        raise NotImplementedError("subclasses must implement get_completion")

    def get_config(self) -> Dict[str, Any]:
        """Return a provider configuration mapping.

        Raises:
            NotImplementedError: always on the base class.
        """
        raise NotImplementedError("subclasses must implement get_config")
class GroqProvider(AIProvider):
    """AIProvider backed by a GroqLLM instance."""

    def __init__(self, api_key: str):
        """Build the provider and its underlying GroqLLM.

        Raises:
            ValueError: if *api_key* is None or empty.
        """
        if not api_key:
            raise ValueError("API key cannot be None or empty")
        self.api_key = api_key
        self.model = "deepseek-r1-distill-llama-70b"
        self.llm = GroqLLM(
            api_key=api_key,
            model=self.model,
            temperature=0.7,
            max_tokens=4000,
            top_p=1.0,
        )

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the completion for *prompt* via the wrapped LLM."""
        return self.llm(prompt)

    def get_llm(self) -> GroqLLM:
        """Expose the underlying GroqLLM for frameworks that need it directly."""
        return self.llm

    def get_config(self) -> Dict[str, Any]:
        """Return an OpenAI-style config mapping for this provider.

        temperature/max_tokens are read from self.llm instead of being
        hard-coded a second time, so the reported config cannot drift
        from the values actually used for completions.
        """
        return {
            "llm": self.llm,
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": self.llm.temperature,
                "max_tokens": self.llm.max_tokens,
                "api_base": "https://api.groq.com/openai/v1",
            }],
        }
class AIProviderFactory:
    """Factory for constructing AIProvider instances."""

    # @staticmethod is required: the original definition had no `self`,
    # so calling it on an instance would have bound the instance as api_key.
    @staticmethod
    def create_provider(api_key: Optional[str] = None) -> AIProvider:
        """Create a GroqProvider from *api_key* or the GROQ_API_KEY env var.

        Args:
            api_key: explicit key; when falsy, .env / environment is consulted.

        Raises:
            ValueError: if no usable key is found, or provider construction
                fails (original error chained as the cause).
        """
        if not api_key:
            # If not provided, try to load from the environment (.env supported).
            load_dotenv()
            api_key = os.getenv("GROQ_API_KEY")
        # Validate API key: present, a string, and not just whitespace.
        if not api_key or not isinstance(api_key, str) or not api_key.strip():
            raise ValueError("GROQ_API_KEY must be provided either directly or through environment variables")
        # Clean up API key
        api_key = api_key.strip()
        try:
            return GroqProvider(api_key=api_key)
        except Exception as e:
            # `from e` keeps the root cause visible in tracebacks.
            raise ValueError(f"Failed to create Groq provider: {str(e)}") from e