File size: 8,858 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
"""
Class to handle llm wildcard routing and regex pattern matching
"""
import copy
import re
from re import Match
from typing import Dict, List, Optional, Tuple
from litellm import get_llm_provider
from litellm._logging import verbose_router_logger
class PatternUtils:
    """Stateless helpers for ranking regex patterns from most to least specific."""

    @staticmethod
    def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]:
        """
        Score a regex pattern so that patterns can be ordered by specificity.

        Args:
            pattern: Regex pattern to analyze

        Returns:
            Tuple of (length, complexity), where length is the number of
            characters in the pattern and complexity is the total number of
            occurrences of the regex special characters listed below.
        """
        special_chars = ("*", "+", "?", "\\", "^", "$", "|", "(", ")")

        complexity = 0
        for special_char in special_chars:
            complexity += pattern.count(special_char)

        # Length is the primary sort key, complexity breaks ties.
        return len(pattern), complexity

    @staticmethod
    def sorted_patterns(
        patterns: Dict[str, List[Dict]]
    ) -> List[Tuple[str, List[Dict]]]:
        """
        Order the pattern -> deployments mapping by descending specificity.

        The ordering is recomputed on every call; nothing is cached here.

        Args:
            patterns: mapping of regex pattern to its list of deployments

        Returns:
            List of (pattern, deployments) tuples, most specific pattern first
        """

        def _specificity(entry: Tuple[str, List[Dict]]) -> Tuple[int, int]:
            regex, _deployments = entry
            return PatternUtils.calculate_pattern_specificity(regex)

        ranked = list(patterns.items())
        ranked.sort(key=_specificity, reverse=True)
        return ranked
class PatternMatchRouter:
    """
    Class to handle llm wildcard routing and regex pattern matching

    doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing

    Wildcard patterns (e.g. ``openai/*``) are converted to regex strings and
    stored as a mapping of regex pattern -> List[Deployments].
    """

    def __init__(self):
        # regex pattern (as produced by _pattern_to_regex) -> deployments
        # registered under that pattern, in insertion order
        self.patterns: Dict[str, List] = {}

    def add_pattern(self, pattern: str, llm_deployment: Dict):
        """
        Register one llm deployment under a wildcard pattern.

        Calling this several times with the same pattern appends each
        deployment to that pattern's list.

        Args:
            pattern: str - wildcard pattern, e.g. "openai/*"
            llm_deployment: Dict - the deployment to return when the pattern matches
        """
        regex = self._pattern_to_regex(pattern)
        self.patterns.setdefault(regex, []).append(llm_deployment)

    def _pattern_to_regex(self, pattern: str) -> str:
        """
        Convert a wildcard pattern to a regex pattern

        Every regex special character is escaped, then each '*' becomes a
        capture group so the text it matched can be re-used later.

        example:
        pattern: openai/*
        regex: openai/(.*)

        pattern: openai/fo::*::static::*
        regex: openai/fo::(.*)::static::(.*)

        NOTE: no '^' / '$' anchors are added. `route` applies the result with
        re.match, which only anchors at the start of the request, so a pattern
        that does not end in '*' also matches requests that merely start with it.

        Args:
            pattern: str

        Returns:
            str: regex pattern
        """
        return re.escape(pattern).replace(r"\*", "(.*)")

    def _return_pattern_matched_deployments(
        self, matched_pattern: Match, deployments: List[Dict]
    ) -> List[Dict]:
        """
        Build the deployments to return for a successful match.

        Each deployment is deep-copied (the stored deployment is never
        mutated) and the copy's litellm_params["model"] has its wildcards
        filled in from the match.

        Args:
            matched_pattern: Match - result of matching the request against a pattern
            deployments: List[Dict] - deployments registered for that pattern

        Returns:
            List[Dict]: copies of the deployments with resolved model names
        """
        new_deployments = []
        for deployment in deployments:
            new_deployment = copy.deepcopy(deployment)
            new_deployment["litellm_params"]["model"] = (
                PatternMatchRouter.set_deployment_model_name(
                    matched_pattern=matched_pattern,
                    litellm_deployment_litellm_model=deployment["litellm_params"][
                        "model"
                    ],
                )
            )
            new_deployments.append(new_deployment)
        return new_deployments

    def route(
        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
    ) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern

        loop through all the patterns (most specific first) and find the matching pattern
        if a pattern is found, return the corresponding llm deployments
        if no pattern is found, return None

        Any exception raised while routing is logged at debug level and
        treated as "no match".

        Args:
            request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned.
            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names

        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        try:
            if request is None:
                return None

            # None means "no filter"; an empty set means "nothing is allowed".
            allowed_regexes = (
                {self._pattern_to_regex(m) for m in filtered_model_names}
                if filtered_model_names is not None
                else None
            )

            for pattern, llm_deployments in PatternUtils.sorted_patterns(
                self.patterns
            ):
                if allowed_regexes is not None and pattern not in allowed_regexes:
                    continue
                pattern_match = re.match(pattern, request)
                if pattern_match:
                    return self._return_pattern_matched_deployments(
                        matched_pattern=pattern_match, deployments=llm_deployments
                    )
        except Exception as e:
            verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")

        return None  # No matching pattern found

    @staticmethod
    def set_deployment_model_name(
        matched_pattern: Match,
        litellm_deployment_litellm_model: str,
    ) -> str:
        """
        Set the model name for the matched pattern llm deployment

        The i-th captured segment of the match replaces the i-th '*' of the
        deployment model. Wildcards left over once the segments run out stay
        as '*'.

        E.g.:

        Case 1:
        model_name: llmengine/* (can be any regex pattern or wildcard pattern)
        litellm_params:
            model: openai/*

        if model_name = "llmengine/foo" -> model = "openai/foo"

        Case 2:
        model_name: llmengine/fo::*::static::*
        litellm_params:
            model: openai/fo::*::static::*

        if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"

        Case 3 (more captured segments than wildcards in the deployment model):
        model_name: *meta.llama3*
        litellm_params:
            model: bedrock/meta.llama3*

        if model_name = "hello-world-meta.llama3-70b" -> model = "hello-world-meta.llama3-70b"
        (the segments cannot be mapped onto the wildcards, so the user input is returned as-is)

        Args:
            matched_pattern: Match - result of matching the request against a pattern
            litellm_deployment_litellm_model: str - litellm_params.model of the deployment

        Returns:
            str: the model name to use for this request
        """
        ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
        if "*" not in litellm_deployment_litellm_model:
            return litellm_deployment_litellm_model

        # Literal text between the wildcards: N wildcards -> N + 1 parts.
        literal_parts = litellm_deployment_litellm_model.split("*")
        wildcard_count = len(literal_parts) - 1

        # Extract all dynamic segments from the request
        dynamic_segments = matched_pattern.groups()

        if len(dynamic_segments) > wildcard_count:
            return (
                matched_pattern.string
            )  # default to the user input, if unable to map based on wildcards.

        # Interleave literal parts with segments positionally. Repeated
        # str.replace("*", segment, 1) would be wrong when a segment itself
        # contains '*': the next replacement would land inside that segment
        # instead of on the next wildcard of the template.
        pieces = [literal_parts[0]]
        for index, literal in enumerate(literal_parts[1:]):
            pieces.append(
                dynamic_segments[index] if index < len(dynamic_segments) else "*"
            )
            pieces.append(literal)
        return "".join(pieces)

    def get_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> Optional[List[Dict]]:
        """
        Look up the deployments whose pattern matches the given model

        The bare model name is tried first, then "<custom_llm_provider>/<model>".
        When no provider is passed it is looked up with get_llm_provider; if
        that lookup fails only the bare model name is tried.

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            Optional[List[Dict]]: matching llm deployments, or None if no pattern matches
        """
        if custom_llm_provider is None:
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises exception when provider is unknown
                custom_llm_provider = None

        deployments = self.route(model)
        if deployments:
            return deployments

        if custom_llm_provider is None:
            # Without a provider the prefixed lookup would be routed as the
            # literal string "None/<model>", which is never a real request.
            return None

        return self.route(f"{custom_llm_provider}/{model}")

    def get_deployments_by_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> List[Dict]:
        """
        Get the deployments by pattern

        Same lookup as get_pattern, but always returns a list.

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            List[Dict]: llm deployments matching the pattern, [] when nothing matches
        """
        return self.get_pattern(model, custom_llm_provider) or []
# Example usage (each deployment is a dict with a "litellm_params" -> "model" entry):
# router = PatternMatchRouter()
# router.add_pattern('openai/*', deployment_a)  # one deployment per call
# router.add_pattern('openai/*', deployment_b)
# router.add_pattern('openai/fo::*::static::*', deployment_c)
# print(router.route('openai/gpt-4')) # Output: copies of [deployment_a, deployment_b]
# print(router.route('openai/fo::hi::static::hi')) # Output: copy of [deployment_c] (longer pattern is tried first)
# print(router.route('something/else')) # Output: None
|