import uuid
from typing import Dict
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *
class VertexAIBatchTransformation:
"""
Transforms OpenAI Batch requests to Vertex AI Batch requests
API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
"""
@classmethod
def transform_openai_batch_request_to_vertex_ai_batch_request(
cls,
request: CreateBatchRequest,
) -> VertexAIBatchPredictionJob:
"""
Transforms OpenAI Batch requests to Vertex AI Batch requests
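
        Example (illustrative; the GCS path is hypothetical):
            request = {"input_file_id": "gs://my-bucket/publishers/google/models/gemini-1.5-flash-001/input.jsonl"}
            returns a VertexAIBatchPredictionJob with:
                model          = "publishers/google/models/gemini-1.5-flash-001"
                gcsSource      = {"uris": "gs://my-bucket/publishers/google/models/gemini-1.5-flash-001/input.jsonl"}
                gcsDestination = {"outputUriPrefix": "gs://my-bucket/publishers/google/models/gemini-1.5-flash-001"}
                displayName    = "litellm-vertex-batch-<uuid4>"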
"""
request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
input_file_id = request.get("input_file_id")
if input_file_id is None:
raise ValueError("input_file_id is required, but not provided")
input_config: InputConfig = InputConfig(
gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
)
model: str = cls._get_model_from_gcs_file(input_file_id)
output_config: OutputConfig = OutputConfig(
predictionsFormat="jsonl",
gcsDestination=GcsDestination(
outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
),
)
return VertexAIBatchPredictionJob(
inputConfig=input_config,
outputConfig=output_config,
model=model,
displayName=request_display_name,
)
@classmethod
def transform_vertex_ai_batch_response_to_openai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> Batch:
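        """
        Transforms a Vertex AI BatchPredictionJob response into the OpenAI Batch object.

        Example (illustrative):
            a job named "projects/123/locations/us-central1/batchPredictionJobs/456"
            in state "JOB_STATE_RUNNING" becomes Batch(id="456", status="in_progress", ...)
        """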
return Batch(
id=cls._get_batch_id_from_vertex_ai_batch_response(response),
completion_window="24hrs",
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=response.get("createTime", "")
),
endpoint="",
input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
response
),
object="batch",
status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
response
),
)
@classmethod
def _get_batch_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
        Safely extracts the batch id from the Vertex AI Batch response
vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
returns: `3814889423749775360`
"""
_name = response.get("name", "")
if not _name:
return ""
# Split by '/' and get the last part if it exists
parts = _name.split("/")
return parts[-1] if parts else _name
@classmethod
def _get_input_file_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
Gets the input file id from the Vertex AI Batch response
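
        Example (illustrative): for a response containing
            {"inputConfig": {"gcsSource": {"uris": ["gs://my-bucket/input.jsonl"]}}}
        this returns "gs://my-bucket/input.jsonl".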
"""
input_file_id: str = ""
input_config = response.get("inputConfig")
if input_config is None:
return input_file_id
gcs_source = input_config.get("gcsSource")
if gcs_source is None:
return input_file_id
        # Vertex AI returns `uris` as a list of GCS paths; take the first entry.
        uris = gcs_source.get("uris") or []
        if len(uris) == 0:
            return input_file_id
        return uris[0]
@classmethod
def _get_output_file_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
Gets the output file id from the Vertex AI Batch response
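
        Example (illustrative): for a response containing
            {"outputConfig": {"gcsDestination": {"outputUriPrefix": "gs://my-bucket/batch-output"}}}
        this returns "gs://my-bucket/batch-output".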
"""
output_file_id: str = ""
output_config = response.get("outputConfig")
if output_config is None:
return output_file_id
gcs_destination = output_config.get("gcsDestination")
if gcs_destination is None:
return output_file_id
output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
return output_uri_prefix
@classmethod
def _get_batch_job_status_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> BatchJobStatus:
"""
Gets the batch job status from the Vertex AI Batch response
ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
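
        Example (from the mapping below):
            "JOB_STATE_SUCCEEDED"  -> "completed"
            "JOB_STATE_CANCELLING" -> "cancelling"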
"""
state_mapping: Dict[str, BatchJobStatus] = {
"JOB_STATE_UNSPECIFIED": "failed",
"JOB_STATE_QUEUED": "validating",
"JOB_STATE_PENDING": "validating",
"JOB_STATE_RUNNING": "in_progress",
"JOB_STATE_SUCCEEDED": "completed",
"JOB_STATE_FAILED": "failed",
"JOB_STATE_CANCELLING": "cancelling",
"JOB_STATE_CANCELLED": "cancelled",
"JOB_STATE_PAUSED": "in_progress",
"JOB_STATE_EXPIRED": "expired",
"JOB_STATE_UPDATING": "in_progress",
"JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
}
        vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
        # Fall back to "failed" for any state not covered by the mapping above.
        return state_mapping.get(vertex_state, "failed")
@classmethod
def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
"""
Gets the gcs uri prefix from the input file id
Example:
input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
returns: "gs://litellm-testing-bucket"
input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
returns: "gs://litellm-testing-bucket/batches"
"""
# Split the path and remove the filename
path_parts = input_file_id.rsplit("/", 1)
return path_parts[0]
@classmethod
def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
"""
Extracts the model from the gcs file uri
When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
Why?
        - Because Vertex requires the `model` param in the create batch job request, but OpenAI does not require it
gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
returns: "publishers/google/models/gemini-1.5-flash-001"
"""
        from urllib.parse import unquote

        decoded_uri = unquote(gcs_file_uri)
        if "publishers/" not in decoded_uri:
            raise ValueError(f"Could not extract model from gcs_file_uri={gcs_file_uri}")
        # Keep "publishers/<publisher>/models/<model>" - the first 3 segments after "publishers/"
        model_path = decoded_uri.split("publishers/")[1]
        parts = model_path.split("/")
        model = f"publishers/{'/'.join(parts[:3])}"
        return model
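

if __name__ == "__main__":
    # Illustrative sketch only: the bucket and object path below are hypothetical.
    # The input file must sit under a "publishers/<publisher>/models/<model>/..."
    # prefix so that _get_model_from_gcs_file can recover the model id.
    _demo_request = CreateBatchRequest(
        input_file_id="gs://my-bucket/publishers/google/models/gemini-1.5-flash-001/input.jsonl"
    )
    _vertex_job = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
        request=_demo_request
    )
    print(_vertex_job)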