# ai_providers.py
"""Groq-backed completion providers usable via LangChain or a config-list dict."""
import os
import time
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from groq import Groq
from groq.types.chat import ChatCompletion
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import BaseModel, Field

# Shared defaults so GroqLLM and GroqProvider cannot drift apart.
_DEFAULT_MODEL = "deepseek-r1-distill-llama-70b"
_GROQ_API_BASE = "https://api.groq.com/openai/v1"


class GroqLLM(LLM, BaseModel):
    """LangChain ``LLM`` wrapper around the Groq chat-completions endpoint.

    Each call sends the prompt as a single ``user`` message and returns the
    text of the first choice. Failed requests are retried with a linearly
    growing delay (``retry_delay * attempt_number``).
    """

    api_key: str = Field(..., description="Groq API key")
    model: str = Field(default=_DEFAULT_MODEL, description="Model name")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=4000, description="Maximum number of tokens to generate")
    top_p: float = Field(default=1.0, description="Top p sampling parameter")
    retry_attempts: int = Field(
        default=3, description="Total number of request attempts (values < 1 are treated as 1)"
    )
    retry_delay: float = Field(
        default=1.0, description="Base delay in seconds; attempt k (1-based) waits retry_delay * k"
    )
    client: Any = Field(default=None, description="Groq client instance")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Only build a client when none was injected, so callers (or tests)
        # can supply a preconfigured one.
        if self.client is None:
            self.client = Groq(api_key=self.api_key)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "groq"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to Groq and return the first choice's text.

        Args:
            prompt: Text sent as the single user message.
            stop: Optional stop sequences forwarded to the API.
            run_manager: Accepted for LangChain compatibility; unused here.
            **kwargs: Accepted for LangChain compatibility; currently ignored.

        Returns:
            The completion text, or ``""`` if the API returned no content.

        Raises:
            Exception: Whatever the final failed attempt raised, re-raised
                unchanged once all retry attempts are exhausted.
        """
        attempts = max(1, self.retry_attempts)
        for attempt in range(attempts):
            try:
                completion: ChatCompletion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    top_p=self.top_p,
                    stop=stop,
                )
                content = completion.choices[0].message.content
                # NOTE(review): the SDK appears to type message content as
                # optional; normalise to "" so the declared `str` return holds.
                return content if content is not None else ""
            except Exception:
                if attempt == attempts - 1:
                    raise  # bare raise keeps the original traceback intact
                time.sleep(self.retry_delay * (attempt + 1))
        # Unreachable: the last attempt either returns or re-raises.
        raise RuntimeError("retry loop exited without a result")


class AIProvider:
    """Interface for completion providers; subclasses must override both methods."""

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the model's completion for ``prompt``."""
        raise NotImplementedError

    def get_config(self) -> Dict[str, Any]:
        """Return provider configuration for downstream consumers."""
        raise NotImplementedError


class GroqProvider(AIProvider):
    """``AIProvider`` backed by a ``GroqLLM`` with fixed sampling defaults."""

    def __init__(self, api_key: str):
        """Create the provider.

        Raises:
            ValueError: If ``api_key`` is falsy.
        """
        if not api_key:
            raise ValueError("API key cannot be None or empty")
        self.api_key = api_key
        self.model = _DEFAULT_MODEL
        self.llm = GroqLLM(
            api_key=api_key,
            model=self.model,
            temperature=0.7,
            max_tokens=4000,
            top_p=1.0,
        )

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the completion for ``prompt``; extra ``kwargs`` are ignored."""
        return self.llm(prompt)

    def get_llm(self) -> GroqLLM:
        """Return the underlying LangChain LLM."""
        return self.llm

    def get_config(self) -> Dict[str, Any]:
        """Return the LLM plus a single-entry ``config_list`` describing it.

        Sampling values are read from ``self.llm`` so the config always
        reflects the LLM that will actually be used.
        """
        return {
            "llm": self.llm,
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": self.llm.temperature,
                "max_tokens": self.llm.max_tokens,
                "api_base": _GROQ_API_BASE,
            }],
        }


class AIProviderFactory:
    """Factory that builds the configured ``AIProvider``."""

    @staticmethod
    def create_provider(api_key: Optional[str] = None) -> AIProvider:
        """Build a ``GroqProvider``.

        Args:
            api_key: Groq API key. If missing or blank, ``GROQ_API_KEY`` is
                read from the environment after loading a ``.env`` file.

        Returns:
            A ready-to-use ``GroqProvider``.

        Raises:
            ValueError: If no usable key is found, or provider construction fails.
        """
        # Treat a blank string like a missing key so the env fallback applies.
        if not api_key or (isinstance(api_key, str) and not api_key.strip()):
            load_dotenv()
            api_key = os.getenv("GROQ_API_KEY")

        if not api_key or not isinstance(api_key, str) or not api_key.strip():
            raise ValueError("GROQ_API_KEY must be provided either directly or through environment variables")

        api_key = api_key.strip()

        try:
            return GroqProvider(api_key=api_key)
        except Exception as e:
            # Chain the cause so the underlying failure remains visible.
            raise ValueError(f"Failed to create Groq provider: {str(e)}") from e