import json
import os
import time
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union

import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import (
    ChatCompletion,
)
from pydantic import BaseModel

from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import (
    BaseSerializer,
    OpenAICompatibleSerializer,
)
from proxy_lite.tools import Tool


class BaseClientConfig(BaseModel):
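    """Shared HTTP transport settings inherited by every client configuration."""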
    http_timeout: float = 50
    http_concurrent_connections: int = 50


class BaseClient(BaseModel, ABC):
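    """Abstract base for async chat-completion clients that return OpenAI-style ChatCompletion objects."""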
    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses are expected to adapt this to endpoints that handle
        requests differently, and should raise appropriate warnings.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format, for consistency.
        """

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
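        """Instantiate the concrete client registered for config.name."""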
        supported_clients = {
            "openai": OpenAIClient,
            "openai-azure": OpenAIClient,
            "convergence": ConvergenceClient,
            "gemini": GeminiClient,
        }
        if config.name not in supported_clients:
            error_message = f"Unsupported client: {config.name}."
            raise ValueError(error_message)
        return supported_clients[config.name](config=config)

    @property
    def http_client(self) -> httpx.AsyncClient:
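        """Create a new httpx.AsyncClient with the configured timeout and connection limits."""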
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )


class OpenAIClientConfig(BaseClientConfig):
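    """Configuration for the OpenAI chat completions API (or a compatible api_base)."""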
    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
    api_base: Optional[str] = None


class OpenAIClient(BaseClient):
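    """Client for OpenAI and OpenAI-compatible chat completion endpoints."""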
    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
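        """Construct the AsyncOpenAI SDK client (cached) using the configured API key and optional base URL."""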
        client_params = {
            "api_key": self.config.api_key,
            "http_client": self.http_client,
        }
        if self.config.api_base:
            client_params["base_url"] = self.config.api_base
        return AsyncOpenAI(**client_params)

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
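        """Serialize the message history and request a completion from the OpenAI endpoint."""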
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        # Only forward optional parameters that were actually provided.
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class ConvergenceClientConfig(BaseClientConfig):
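    """Configuration for a self-hosted, OpenAI-compatible endpoint serving the proxy-lite model."""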
    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"


class ConvergenceClient(OpenAIClient):
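    """OpenAI-compatible client that also verifies the model is served by the endpoint before use."""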
    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    _model_validated: bool = False

    async def _validate_model(self) -> None:
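        """Check that the configured model_id is available on the endpoint."""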
        try:
            response = await self.external_client.models.list()
            assert self.config.model_id in [model.id for model in response.data], (
                f"Model {self.config.model_id} not found in {response.data}"
            )
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise e

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
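        """Validate the model on first call, then request a completion from the local endpoint."""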
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,
            "response_format": response_format if response_format else {"type": "text"},
        }
        # Only forward optional parameters that were actually provided.
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class GeminiClientConfig(BaseClientConfig):
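    """Configuration for the Google Gemini generateContent API."""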
    name: Literal["gemini"] = "gemini"
    model_id: str = "gemini-2.0-flash-001"
    api_key: str = ""


class GeminiClient(BaseClient):
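    """Client for the Gemini REST API that maps requests and responses to the OpenAI ChatCompletion format."""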
    config: GeminiClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    def _convert_messages_to_gemini_format(self, messages):
        """Convert OpenAI format messages to Gemini format"""
        gemini_parts = []
        for msg in messages:
            if msg["role"] == "user":
                gemini_parts.append({"text": msg["content"]})
            elif msg["role"] == "assistant":
                gemini_parts.append({"text": msg["content"]})

        return gemini_parts

    def _clean_schema_for_gemini(self, schema):
        """Clean up a JSON schema for Gemini function calling by removing $defs and $ref."""
        if not isinstance(schema, dict):
            return schema

        cleaned = {}
        for key, value in schema.items():
            if key == "$defs":
                # Definitions are inlined below rather than passed through.
                continue
            elif key == "$ref":
                # References are resolved by _inline_definitions.
                continue
            elif isinstance(value, dict):
                cleaned[key] = self._clean_schema_for_gemini(value)
            elif isinstance(value, list):
                cleaned[key] = [self._clean_schema_for_gemini(item) for item in value]
            else:
                cleaned[key] = value

        if "$defs" in schema:
            cleaned = self._inline_definitions(cleaned, schema["$defs"])

        return cleaned

    def _inline_definitions(self, schema, definitions):
        """Inline $ref definitions into the schema"""
        if not isinstance(schema, dict):
            return schema

        if "$ref" in schema:
            # Replace the reference with its (recursively inlined) definition.
            ref_name = schema["$ref"].split("/")[-1]
            if ref_name in definitions:
                return self._inline_definitions(definitions[ref_name], definitions)
            else:
                # Unknown reference: drop the $ref key and keep the rest of the node.
                return {k: v for k, v in schema.items() if k != "$ref"}

        inlined = {}
        for key, value in schema.items():
            if isinstance(value, dict):
                inlined[key] = self._inline_definitions(value, definitions)
            elif isinstance(value, list):
                inlined[key] = [self._inline_definitions(item, definitions) for item in value]
            else:
                inlined[key] = value

        return inlined

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
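        """Call the Gemini generateContent endpoint and convert the response into an OpenAI ChatCompletion."""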
        from openai.types.chat.chat_completion import Choice
        from openai.types.chat.chat_completion_message import ChatCompletionMessage
        from openai.types.completion_usage import CompletionUsage
        from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

        # Note: seed and response_format are not used by this Gemini implementation.
        serialized_messages = self.serializer.serialize_messages(messages)

        # Flatten the OpenAI-style message history into Gemini "contents",
        # merging consecutive user/tool messages into a single user turn.
        contents = []
        current_user_text = ""

        for msg in serialized_messages:
            # Extract plain text from either string or list-of-parts content.
            content_text = ""
            if isinstance(msg["content"], list):
                for item in msg["content"]:
                    if isinstance(item, dict) and "text" in item:
                        content_text += item["text"]
                    elif isinstance(item, str):
                        content_text += item
            elif isinstance(msg["content"], str):
                content_text = msg["content"]

            if msg["role"] == "user":
                current_user_text += content_text + "\n"
            elif msg["role"] == "assistant":
                # Flush any accumulated user text before the model turn.
                if current_user_text.strip():
                    contents.append({
                        "role": "user",
                        "parts": [{"text": current_user_text.strip()}]
                    })
                    current_user_text = ""

                contents.append({
                    "role": "model",
                    "parts": [{"text": content_text}]
                })
            elif msg["role"] == "tool":
                # Gemini has no tool role in this flow; fold tool output into the next user turn.
                current_user_text += f"[ACTION COMPLETED] {content_text}\n"

        # Flush any trailing user text.
        if current_user_text.strip():
            contents.append({
                "role": "user",
                "parts": [{"text": current_user_text.strip()}]
            })

        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
            }
        }

        if tools:
            # Convert each tool schema into a Gemini function declaration.
            function_declarations = []
            for tool in tools:
                for tool_schema in tool.schema:
                    # Gemini rejects $defs/$ref, so inline and strip them first.
                    cleaned_parameters = self._clean_schema_for_gemini(tool_schema["parameters"])
                    function_declarations.append({
                        "name": tool_schema["name"],
                        "description": tool_schema["description"],
                        "parameters": cleaned_parameters
                    })

            payload["tools"] = [{
                "function_declarations": function_declarations
            }]

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.config.model_id}:generateContent?key={self.config.api_key}"

        response = await self.http_client.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()
        response_data = response.json()

        if "candidates" in response_data and len(response_data["candidates"]) > 0:
            candidate = response_data["candidates"][0]

            # Collect text and any function calls from the candidate parts.
            content = ""
            tool_calls = []

            if "content" in candidate and "parts" in candidate["content"]:
                for part in candidate["content"]["parts"]:
                    if "text" in part:
                        content += part["text"]
                    elif "functionCall" in part:
                        # Map a Gemini function call onto an OpenAI tool call.
                        func_call = part["functionCall"]
                        tool_call = ChatCompletionMessageToolCall(
                            id=f"call_{hash(str(func_call))}"[:16],
                            type="function",
                            function={
                                "name": func_call["name"],
                                "arguments": json.dumps(func_call.get("args", {}))
                            }
                        )
                        tool_calls.append(tool_call)

            choice = Choice(
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=content if content else None,
                    tool_calls=tool_calls if tool_calls else None
                ),
                finish_reason="stop"
            )

            # Wrap the result in an OpenAI ChatCompletion; token counts are
            # approximated by whitespace word counts since Gemini usage is not mapped.
            completion = ChatCompletion(
                id="gemini-" + str(hash(content))[:8],
                choices=[choice],
                created=int(time.time()),
                model=self.config.model_id,
                object="chat.completion",
                usage=CompletionUsage(
                    completion_tokens=len(content.split()),
                    prompt_tokens=sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages),
                    total_tokens=len(content.split()) + sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages)
                )
            )

            return completion
        else:
            raise Exception(f"No valid response from Gemini API: {response_data}")


ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig, GeminiClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient, GeminiClient]
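
# Minimal usage sketch (assumes OPENAI_API_KEY is set and `history` is a MessageHistory):
#
#     client = BaseClient.create(OpenAIClientConfig())
#     completion = await client.create_completion(messages=history)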