import json
from typing import Any, Coroutine, Dict, Optional, Union

import httpx

import litellm
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.openai import Batch, CreateBatchRequest
from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob

from .transformation import VertexAIBatchTransformation

|
class VertexAIBatchPrediction(VertexLLM): |
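    """Create and retrieve Vertex AI batch prediction jobs, translating between
    litellm's OpenAI-compatible ``Batch`` format and Vertex AI's
    batchPredictionJobs API."""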

    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gcs_bucket_name = gcs_bucket_name
|
    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
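        """Create a Vertex AI batch prediction job from an OpenAI-style request.

        Returns a ``Batch`` in the sync path, or a coroutine resolving to one
        when ``_is_async`` is True. ``timeout`` and ``max_retries`` are
        currently unused.
        """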
        sync_handler = _get_httpx_client()
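
        # Resolve OAuth credentials and the GCP project id for the request.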
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )
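
        # Derive the endpoint suffix for custom-proxy routing: everything after
        # the last ":" in the URL (empty if no ":" is present).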
        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }
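
        # Translate the OpenAI-style request into Vertex AI's batchPredictionJobs format.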
        vertex_batch_request: VertexAIBatchPredictionJob = (
            VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
                request=create_batch_data
            )
        )
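
        # Async callers receive the coroutine to await.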
        if _is_async is True:
            return self._async_create_batch(
                vertex_batch_request=vertex_batch_request,
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_create_batch(
        self,
        vertex_batch_request: VertexAIBatchPredictionJob,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
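        """Async counterpart of ``create_batch``: POST the job and map the response to a ``Batch``."""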
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    def create_vertex_url(
        self,
        vertex_location: str,
        vertex_project: str,
    ) -> str:
        """Return the base URL for Vertex AI batch prediction jobs."""
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
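        """Retrieve an existing Vertex AI batch prediction job by its id.

        Returns a ``Batch`` in the sync path, or a coroutine resolving to one
        when ``_is_async`` is True. ``timeout`` and ``max_retries`` are
        currently unused.
        """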
        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )
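
        # Address the specific job resource by appending its id to the base URL.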
        default_api_base = f"{default_api_base}/{batch_id}"

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.get(
            url=api_base,
            headers=headers,
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
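        """Async counterpart of ``retrieve_batch``: GET the job and map the response to a ``Batch``."""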
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response