""" Class to handle llm wildcard routing and regex pattern matching """ import copy import re from re import Match from typing import Dict, List, Optional, Tuple from litellm import get_llm_provider from litellm._logging import verbose_router_logger class PatternUtils: @staticmethod def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]: """ Calculate pattern specificity based on length and complexity. Args: pattern: Regex pattern to analyze Returns: Tuple of (length, complexity) for sorting """ complexity_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"] ret_val = ( len(pattern), # Longer patterns more specific sum( pattern.count(char) for char in complexity_chars ), # More regex complexity ) return ret_val @staticmethod def sorted_patterns( patterns: Dict[str, List[Dict]] ) -> List[Tuple[str, List[Dict]]]: """ Cached property for patterns sorted by specificity. Returns: Sorted list of pattern-deployment tuples """ return sorted( patterns.items(), key=lambda x: PatternUtils.calculate_pattern_specificity(x[0]), reverse=True, ) class PatternMatchRouter: """ Class to handle llm wildcard routing and regex pattern matching doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing This class will store a mapping for regex pattern: List[Deployments] """ def __init__(self): self.patterns: Dict[str, List] = {} def add_pattern(self, pattern: str, llm_deployment: Dict): """ Add a regex pattern and the corresponding llm deployments to the patterns Args: pattern: str llm_deployment: str or List[str] """ # Convert the pattern to a regex regex = self._pattern_to_regex(pattern) if regex not in self.patterns: self.patterns[regex] = [] self.patterns[regex].append(llm_deployment) def _pattern_to_regex(self, pattern: str) -> str: """ Convert a wildcard pattern to a regex pattern example: pattern: openai/* regex: openai/.* pattern: openai/fo::*::static::* regex: openai/fo::.*::static::.* Args: pattern: str Returns: str: regex pattern """ # # Replace '*' with '.*' for regex matching # regex = pattern.replace("*", ".*") # # Escape other special characters # regex = re.escape(regex).replace(r"\.\*", ".*") # return f"^{regex}$" return re.escape(pattern).replace(r"\*", "(.*)") def _return_pattern_matched_deployments( self, matched_pattern: Match, deployments: List[Dict] ) -> List[Dict]: new_deployments = [] for deployment in deployments: new_deployment = copy.deepcopy(deployment) new_deployment["litellm_params"]["model"] = ( PatternMatchRouter.set_deployment_model_name( matched_pattern=matched_pattern, litellm_deployment_litellm_model=deployment["litellm_params"][ "model" ], ) ) new_deployments.append(new_deployment) return new_deployments def route( self, request: Optional[str], filtered_model_names: Optional[List[str]] = None ) -> Optional[List[Dict]]: """ Route a requested model to the corresponding llm deployments based on the regex pattern loop through all the patterns and find the matching pattern if a pattern is found, return the corresponding llm deployments if no pattern is found, return None Args: request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned. filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names Returns: Optional[List[Deployment]]: llm deployments """ try: if request is None: return None sorted_patterns = PatternUtils.sorted_patterns(self.patterns) regex_filtered_model_names = ( [self._pattern_to_regex(m) for m in filtered_model_names] if filtered_model_names is not None else [] ) for pattern, llm_deployments in sorted_patterns: if ( filtered_model_names is not None and pattern not in regex_filtered_model_names ): continue pattern_match = re.match(pattern, request) if pattern_match: return self._return_pattern_matched_deployments( matched_pattern=pattern_match, deployments=llm_deployments ) except Exception as e: verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}") return None # No matching pattern found @staticmethod def set_deployment_model_name( matched_pattern: Match, litellm_deployment_litellm_model: str, ) -> str: """ Set the model name for the matched pattern llm deployment E.g.: Case 1: model_name: llmengine/* (can be any regex pattern or wildcard pattern) litellm_params: model: openai/* if model_name = "llmengine/foo" -> model = "openai/foo" Case 2: model_name: llmengine/fo::*::static::* litellm_params: model: openai/fo::*::static::* if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz" Case 3: model_name: *meta.llama3* litellm_params: model: bedrock/meta.llama3* if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b" """ ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name if "*" not in litellm_deployment_litellm_model: return litellm_deployment_litellm_model wildcard_count = litellm_deployment_litellm_model.count("*") # Extract all dynamic segments from the request dynamic_segments = matched_pattern.groups() if len(dynamic_segments) > wildcard_count: return ( matched_pattern.string ) # default to the user input, if unable to map based on wildcards. # Replace the corresponding wildcards in the litellm model pattern with extracted segments for segment in dynamic_segments: litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace( "*", segment, 1 ) return litellm_deployment_litellm_model def get_pattern( self, model: str, custom_llm_provider: Optional[str] = None ) -> Optional[List[Dict]]: """ Check if a pattern exists for the given model and custom llm provider Args: model: str custom_llm_provider: Optional[str] Returns: bool: True if pattern exists, False otherwise """ if custom_llm_provider is None: try: ( _, custom_llm_provider, _, _, ) = get_llm_provider(model=model) except Exception: # get_llm_provider raises exception when provider is unknown pass return self.route(model) or self.route(f"{custom_llm_provider}/{model}") def get_deployments_by_pattern( self, model: str, custom_llm_provider: Optional[str] = None ) -> List[Dict]: """ Get the deployments by pattern Args: model: str custom_llm_provider: Optional[str] Returns: List[Dict]: llm deployments matching the pattern """ pattern_match = self.get_pattern(model, custom_llm_provider) if pattern_match: return pattern_match return [] # Example usage: # router = PatternRouter() # router.add_pattern('openai/*', [Deployment(), Deployment()]) # router.add_pattern('openai/fo::*::static::*', Deployment()) # print(router.route('openai/gpt-4')) # Output: [Deployment(), Deployment()] # print(router.route('openai/fo::hi::static::hi')) # Output: [Deployment()] # print(router.route('something/else')) # Output: None