# Source: LegalAI-DS / ai_providers.py
# Last commit: 892edcf "Update ai_providers.py" by Hassankhwileh (verified)
# ai_providers.py
from typing import Dict, Any, Optional, List
import os
from dotenv import load_dotenv
import time
from groq import Groq
from groq.types.chat import ChatCompletion
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field, BaseModel
class GroqLLM(LLM, BaseModel):
    """LangChain ``LLM`` wrapper around the Groq chat-completions endpoint.

    Each call sends ``prompt`` as a single ``user`` message and returns the
    text of the first choice. If the request raises, it is retried up to
    ``max_retries`` total attempts, sleeping ``retry_delay * attempt_number``
    seconds between attempts (linear backoff).
    """

    api_key: str = Field(..., description="Groq API key")
    model: str = Field(default="deepseek-r1-distill-llama-70b", description="Model name")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=4000, description="Maximum number of tokens to generate")
    top_p: float = Field(default=1.0, description="Top p sampling parameter")
    client: Any = Field(default=None, description="Groq client instance")
    max_retries: int = Field(default=3, description="Total request attempts per call (at least 1 is always made)")
    retry_delay: float = Field(default=1.0, description="Base delay in seconds; attempt k waits retry_delay * k")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The client is built after field validation so it uses the validated key.
        self.client = Groq(api_key=self.api_key)

    @property
    def _llm_type(self) -> str:
        return "groq"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters LangChain uses to distinguish LLM configurations.

        ``api_key`` is deliberately omitted so it does not leak into cache
        keys or traces.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to Groq and return the generated text.

        Args:
            prompt: Text sent as the sole user message.
            stop: Optional stop sequences forwarded to the API.
            run_manager: Accepted for LangChain compatibility; not used.
            **kwargs: Accepted for LangChain compatibility; ignored.

        Returns:
            The first choice's message content, or ``""`` if the API returned
            no content (keeps the ``str`` contract LangChain relies on).

        Raises:
            Exception: whatever the final failed attempt raised.
        """
        # Guarantee at least one request even if max_retries is set to 0 or less;
        # otherwise the loop would not run and the method would return None.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                completion: ChatCompletion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    top_p=self.top_p,
                    stop=stop,
                )
                # Indexing stays inside the try so a malformed response is retried,
                # as it was before.
                content = completion.choices[0].message.content
            except Exception:
                if attempt == attempts - 1:
                    raise  # bare raise preserves the original traceback
                time.sleep(self.retry_delay * (attempt + 1))
                continue
            # message.content is Optional in the SDK; never hand None to LangChain.
            return content or ""
        # Unreachable: the last attempt either returns or re-raises.
        raise RuntimeError("unreachable")
class AIProvider:
    """Minimal interface every AI provider implementation must satisfy.

    Subclasses must override both methods. The base versions raise
    ``NotImplementedError`` rather than silently returning ``None``.
    """

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the model's completion text for ``prompt``."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_completion()")

    def get_config(self) -> Dict[str, Any]:
        """Return the provider's configuration as a dictionary."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_config()")
class GroqProvider(AIProvider):
    """AIProvider backed by a :class:`GroqLLM` instance.

    Exposes the LangChain LLM directly (``get_llm``) plus a dictionary holding
    that LLM and a ``config_list`` entry describing the Groq OpenAI-compatible
    endpoint. Presumably ``config_list`` is consumed by an agent framework;
    confirm the expected schema against callers.
    """

    DEFAULT_MODEL = "deepseek-r1-distill-llama-70b"
    # Groq's OpenAI-compatible base URL, reported in get_config().
    API_BASE = "https://api.groq.com/openai/v1"

    def __init__(self, api_key: str):
        """Create the provider and its underlying GroqLLM.

        Raises:
            ValueError: if ``api_key`` is None, empty, or whitespace only.
        """
        # Whitespace-only keys are rejected too, matching the factory's validation.
        if not api_key or not str(api_key).strip():
            raise ValueError("API key cannot be None or empty")
        self.api_key = api_key
        self.model = self.DEFAULT_MODEL
        self.llm = GroqLLM(
            api_key=api_key,
            model=self.model,
            temperature=0.7,
            max_tokens=4000,
            top_p=1.0,
        )

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the LLM's completion for ``prompt`` (extra kwargs are ignored)."""
        return self.llm(prompt)

    def get_llm(self) -> GroqLLM:
        """Return the underlying LangChain LLM instance."""
        return self.llm

    def get_config(self) -> Dict[str, Any]:
        """Return the LLM instance plus a ``config_list`` describing it.

        Sampling values are read from ``self.llm``, so the reported config
        cannot drift from what the LLM actually uses.
        """
        return {
            "llm": self.llm,
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": self.llm.temperature,
                "max_tokens": self.llm.max_tokens,
                "api_base": self.API_BASE,
            }],
        }
class AIProviderFactory:
    """Builds the configured :class:`AIProvider` implementation (currently Groq)."""

    @staticmethod
    def create_provider(api_key: Optional[str] = None) -> AIProvider:
        """Create a provider, resolving the API key from the argument or environment.

        Args:
            api_key: Groq API key. If falsy, ``load_dotenv()`` is called and the
                key is read from the ``GROQ_API_KEY`` environment variable.

        Returns:
            A ready-to-use :class:`GroqProvider`.

        Raises:
            ValueError: if no usable key is found, or if provider construction
                fails. In the latter case the original error is chained as
                ``__cause__``.
        """
        # The explicit argument takes precedence; fall back to the environment.
        if not api_key:
            load_dotenv()
            api_key = os.getenv("GROQ_API_KEY")

        # Reject missing, non-string, or whitespace-only keys.
        if not api_key or not isinstance(api_key, str) or not api_key.strip():
            raise ValueError("GROQ_API_KEY must be provided either directly or through environment variables")

        # Strip stray whitespace, e.g. a trailing newline from a .env file.
        api_key = api_key.strip()
        try:
            return GroqProvider(api_key=api_key)
        except Exception as e:
            # Chain the cause so the real failure stays visible in tracebacks.
            raise ValueError(f"Failed to create Groq provider: {e}") from e