# NOTE: web-page scrape residue ("Spaces: Configuration error") was here; it is not part of the module.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from typing import Any, Dict, List, Literal, Optional | |
from camel.agents import ChatAgent | |
# Module-level logger, named after this module, used by BaseBenchmark below.
logger = logging.getLogger(__name__)
class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Concrete benchmarks must implement :meth:`download`, :meth:`load` and
    :meth:`run`. The split accessors (:meth:`train`, :meth:`valid`,
    :meth:`test`) read from ``self._data``, so :meth:`load` is expected to
    fill that dictionary with one entry per split name.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (Path): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. (default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        The data directory is created (including parents) when it does not
        exist yet.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. (default: :obj:`1`)

        Raises:
            NotADirectoryError: If ``data_dir`` exists but is not a
                directory.
        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to

        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )

        # Split name -> list of samples; expected to be filled by `load`.
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        # Result records, returned as-is by `results`.
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """

    def _get_split(self, split: str) -> List[Dict[str, Any]]:
        r"""Return one data split, loading the data first if none is loaded.

        Args:
            split (str): Key of the split inside ``self._data``.

        Returns:
            List[Dict[str, Any]]: The samples stored under ``split``.

        Raises:
            KeyError: If ``split`` is still missing after loading.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        try:
            return self._data[split]
        except KeyError:
            # Same exception type as a plain dict lookup, but with enough
            # context to tell which benchmark/split is missing.
            raise KeyError(
                f"Split {split!r} is not available for benchmark "
                f"{self.name!r}; loaded splits: {sorted(self._data)}"
            ) from None

    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.

        Raises:
            KeyError: If no ``"train"`` split is available after loading.
        """
        return self._get_split("train")

    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.

        Raises:
            KeyError: If no ``"valid"`` split is available after loading.
        """
        return self._get_split("valid")

    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.

        Raises:
            KeyError: If no ``"test"`` split is available after loading.
        """
        return self._get_split("test")

    @abstractmethod
    def run(
        self,
        agent: "ChatAgent",
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.
            *args: Additional positional arguments for implementations.
            **kwargs: Additional keyword arguments for implementations.

        Returns:
            BaseBenchmark: The benchmark instance.
        """

    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results