Spaces:
Configuration error
Configuration error
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
import logging | |
from itertools import cycle | |
from random import choice | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
List, | |
Union, | |
) | |
from openai import Stream | |
from camel.messages import OpenAIMessage | |
from camel.models.base_model import BaseModelBackend | |
from camel.types import ( | |
ChatCompletion, | |
ChatCompletionChunk, | |
UnifiedModelType, | |
) | |
from camel.utils import BaseTokenCounter | |
# Module-level logger named after this module, so applications can configure
# or filter its output through the standard logging hierarchy.
logger = logging.getLogger(__name__)
class ModelProcessingError(Exception):
    r"""Raised when an error occurs during model processing.

    Carries no state of its own: it behaves exactly like :obj:`Exception`
    and exists so that callers can catch model-processing failures by type.
    """
class ModelManager:
    r"""ModelManager choosing a model from provided list.

    On every :meth:`run` call the active scheduling strategy picks the
    backend that serves the request. The picked backend is stored as
    ``current_model``, and the read accessors below (``model_type``,
    ``model_config_dict``, ``token_limit``, ``token_counter``) forward to
    it, so the manager exposes the same attribute names it reads from a
    backend.

    Args:
        models (Union[BaseModelBackend, List[BaseModelBackend]]):
            model backend or list of model backends
            (e.g., model instances, APIs)
        scheduling_strategy (str): name of the method of this manager that
            defines how to select the next model. Built-in names are
            ``round_robin``, ``always_first`` and ``random_model``. A name
            that does not resolve to a callable falls back to round-robin
            with a warning. (default: :str:`round_robin`)
    """

    def __init__(
        self,
        models: Union[BaseModelBackend, List[BaseModelBackend]],
        scheduling_strategy: str = "round_robin",
    ):
        if isinstance(models, list):
            self.models = models
        else:
            self.models = [models]
        self.models_cycle = cycle(self.models)
        self.current_model = self.models[0]

        # Resolve the strategy by name. Anything that is missing or is not
        # callable (e.g. the name of a data attribute such as "models")
        # cannot act as a strategy, so fall back to round-robin.
        strategy = getattr(self, scheduling_strategy, None)
        if not callable(strategy):
            logger.warning(
                f"Provided strategy: {scheduling_strategy} is not "
                "implemented. Using default 'round robin'"
            )
            strategy = self.round_robin
        self.scheduling_strategy = strategy

    @property
    def model_type(self) -> UnifiedModelType:
        r"""Return type of the current model.

        Returns:
            UnifiedModelType: Current model type.
        """
        return self.current_model.model_type

    @property
    def model_config_dict(self) -> Dict[str, Any]:
        r"""Return model_config_dict of the current model.

        Returns:
            Dict[str, Any]: Config dictionary of the current model.
        """
        return self.current_model.model_config_dict

    @model_config_dict.setter
    def model_config_dict(self, model_config_dict: Dict[str, Any]):
        r"""Set model_config_dict to the current model.

        Args:
            model_config_dict (Dict[str, Any]): Config dictionary to be set at
                current model.
        """
        self.current_model.model_config_dict = model_config_dict

    @property
    def current_model_index(self) -> int:
        r"""Return the index of current model in self.models list.

        Returns:
            int: index of current model in given list of models.
        """
        return self.models.index(self.current_model)

    @property
    def token_limit(self) -> int:
        r"""Return the maximum token limit of the current model.

        The value is read from the current model's ``token_limit``.

        Returns:
            int: The maximum token limit for the current model.
        """
        return self.current_model.token_limit

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Return token_counter of the current model.

        Returns:
            BaseTokenCounter: The token counter of the current model.
        """
        return self.current_model.token_counter

    def add_strategy(self, name: str, strategy_fn: Callable):
        r"""Add a scheduling strategy method provided by user in case when
        none of existent strategies fits.

        The callable is bound to this manager, stored under ``name`` and
        set as the active ``self.scheduling_strategy``. It is called with
        the manager as its only argument and must return the backend to
        use.

        Args:
            name (str): The name of the strategy.
            strategy_fn (Callable): The scheduling strategy function.

        Raises:
            ValueError: If ``strategy_fn`` is not callable.
        """
        if not callable(strategy_fn):
            raise ValueError("strategy_fn must be a callable function.")

        # Plain functions are descriptors and bind through ``__get__``.
        # Callable objects and builtins are not, so bind those explicitly
        # instead of failing with AttributeError.
        bind = getattr(strategy_fn, "__get__", None)
        if bind is not None:
            bound_strategy = bind(self)
        else:
            from types import MethodType

            bound_strategy = MethodType(strategy_fn, self)

        setattr(self, name, bound_strategy)
        self.scheduling_strategy = bound_strategy
        logger.info(f"Custom strategy '{name}' added.")

    # Strategies
    def round_robin(self) -> BaseModelBackend:
        r"""Return models one by one in simple round-robin fashion.

        Returns:
            BaseModelBackend for processing incoming messages.
        """
        return next(self.models_cycle)

    def always_first(self) -> BaseModelBackend:
        r"""Always return the first model from self.models.

        Returns:
            BaseModelBackend for processing incoming messages.
        """
        return self.models[0]

    def random_model(self) -> BaseModelBackend:
        r"""Return random model from self.models list.

        Returns:
            BaseModelBackend for processing incoming messages.
        """
        return choice(self.models)

    def run(
        self, messages: List[OpenAIMessage]
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Process a list of messages by selecting a model based on
        the scheduling strategy.
        Sends the entire list of messages to the selected model,
        and returns a single response.

        If the selected model raises, the error is logged and re-raised.
        When the active strategy is ``always_first`` it is first replaced
        by ``round_robin`` and one model is drawn from the cycle, so that
        later calls are not pinned to the same first model.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat
                history in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.

        Raises:
            Exception: Whatever the selected model's ``run`` raised,
                re-raised unchanged.
        """
        self.current_model = self.scheduling_strategy()

        # Pass all messages to the selected model and get the response
        try:
            return self.current_model.run(messages)
        except Exception:
            logger.error(f"Error processing with model: {self.current_model}")
            if self.scheduling_strategy == self.always_first:
                self.scheduling_strategy = self.round_robin
                logger.warning(
                    "The scheduling strategy has been changed to 'round_robin'"
                )
                # Skip already used one
                self.current_model = self.scheduling_strategy()
            # Bare raise keeps the original exception and traceback.
            raise