import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union

import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel

from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import (
    BaseSerializer,
    OpenAICompatibleSerializer,
)
from proxy_lite.tools import Tool


class BaseClientConfig(BaseModel):
    http_timeout: float = 50
    http_concurrent_connections: int = 50


class BaseClient(BaseModel, ABC):
    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses adapt this to endpoints that handle requests differently,
        and should raise appropriate warnings where behaviour diverges.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format, for consistency.
        """
        ...

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        # `name` is defined on each concrete config (e.g. OpenAIClientConfig).
        supported_clients = {
            "openai": OpenAIClient,
            "convergence": ConvergenceClient,
        }
        if config.name not in supported_clients:
            error_message = f"Unsupported model: {config.name}."
            raise ValueError(error_message)
        return supported_clients[config.name](config=config)

    @property
    def http_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )


class OpenAIClientConfig(BaseClientConfig):
    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")


class OpenAIClient(BaseClient):
    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        # Drop unset optional parameters so the endpoint only sees explicit values.
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class ConvergenceClientConfig(BaseClientConfig):
    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"


class ConvergenceClient(OpenAIClient):
    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        try:
            response = await self.external_client.models.list()
            assert self.config.model_id in [model.id for model in response.data], (
                f"Model {self.config.model_id} not found in {response.data}"
            )
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        # Lazily check that the target model is actually served before the first request.
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,  # vLLM does not support "required"
            "response_format": response_format if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient]
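

# Usage sketch (not part of the library API): constructing a client via the
# factory and requesting a completion. This assumes a vLLM-compatible server
# is already running at the default api_base, serving the default model_id;
# `history` stands in for a MessageHistory built elsewhere (its construction
# lives in proxy_lite.history and is elided here).
#
#     import asyncio
#
#     async def main() -> None:
#         client = BaseClient.create(ConvergenceClientConfig())  # dispatches on config.name
#         completion = await client.create_completion(messages=history, temperature=0.0)
#         print(completion.choices[0].message)
#
#     asyncio.run(main())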