Upload folder using huggingface_hub
- README.md +4 -4
- api.py +5 -3
- dataset.py +10 -3
- error_utils.py +1 -0
- inference.py +184 -66
- llm_as_judge.py +715 -61
- llm_as_judge_chat_templates.py +5 -3
- llm_as_judge_constants.py +24 -44
- llm_as_judge_utils.py +0 -9
- loaders.py +36 -31
- metrics.py +193 -56
- schema.py +19 -2
- serializers.py +2 -1
- settings_utils.py +1 -0
- sql_utils.py +197 -7
- struct_data_operators.py +9 -9
- version.py +1 -1
README.md
CHANGED
@@ -40,11 +40,11 @@ https://github.com/IBM/unitxt/assets/23455264/baef9131-39d4-4164-90b2-05da52919f

 ### 🦄 Currently on Unitxt Catalog

 [catalog badge lines changed here; their content is not recoverable from this extract]

 ### 🦄 Run Unitxt Exploration Dashboard

api.py
CHANGED
@@ -21,7 +21,7 @@ from .loaders import LoadFromDictionary
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import
+from .schema import loads_batch
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -98,6 +98,7 @@ def create_dataset(
     train_set: Optional[List[Dict[Any, Any]]] = None,
     validation_set: Optional[List[Dict[Any, Any]]] = None,
     split: Optional[str] = None,
+    data_classification_policy: Optional[List[str]] = None,
     **kwargs,
 ) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
     """Creates dataset from input data based on a specific task.
@@ -108,6 +109,7 @@ def create_dataset(
         train_set : optional train_set
         validation_set: optional validation set
         split: optional one split to choose
+        data_classification_policy: data_classification_policy
         **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

     Returns:
@@ -129,7 +131,7 @@
             f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
         )

-    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
+    card = TaskCard(loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy), task=task)
    return load_dataset(card=card, split=split, **kwargs)


@@ -283,7 +285,7 @@ def produce(
     result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
     if not is_list:
         return result[0]
-    return Dataset.from_list(result).with_transform(
+    return Dataset.from_list(result).with_transform(loads_batch)


 def infer(
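The `create_dataset()` change above threads a new `data_classification_policy` argument through to `LoadFromDictionary`. A minimal usage sketch, assuming unitxt is installed; the task id and example data are illustrative assumptions, not taken from this commit, and if the chosen task has no `default_template` a `template=` argument would also be needed (per the error message shown in the diff):

```python
# Hypothetical sketch: task id and data are illustrative assumptions.
from unitxt.api import create_dataset

data = [{"question": "What is the capital of France?", "answer": "Paris"}]

dataset = create_dataset(
    task="tasks.qa.open",                    # assumed catalog task id
    test_set=data,
    split="test",
    data_classification_policy=["public"],   # new argument, forwarded to LoadFromDictionary
)
```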
dataset.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import datasets

@@ -50,7 +50,7 @@ from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
 from .schema import __file__ as _
-from .schema import loads_instance
+from .schema import loads_batch, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import __file__ as _
 from .settings_utils import get_constants
@@ -120,6 +120,13 @@ class Dataset(datasets.GeneratorBasedBuilder):
             dl_manager, "no_checks", **prepare_splits_kwargs
         )

+    def as_streaming_dataset(self, split: Optional[str] = None, base_path: Optional[str] = None) -> Union[Dict[str, datasets.IterableDataset], datasets.IterableDataset]:
+        return (
+            super()
+            .as_streaming_dataset(split, base_path=base_path)
+            .map(loads_instance)
+        )
+
     def as_dataset(
         self,
         split: Optional[datasets.Split] = None,
@@ -162,5 +169,5 @@ class Dataset(datasets.GeneratorBasedBuilder):
         return (
             super()
             .as_dataset(split, run_post_process, verification_mode, in_memory)
-            .with_transform(
+            .with_transform(loads_batch)
         )
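The new `as_streaming_dataset()` override maps `loads_instance` over each streamed example, while the regular `as_dataset()` path now applies `loads_batch` as a batched transform. A sketch of how the streaming path might be exercised through the Hugging Face `datasets` loader; the card/template recipe string is an illustrative assumption:

```python
# Hypothetical sketch: the recipe string is an illustrative assumption.
from datasets import load_dataset

ds = load_dataset(
    "unitxt/data",
    "card=cards.wnli,template=templates.classification.multi_class.relation.default",
    split="test",
    streaming=True,        # goes through Dataset.as_streaming_dataset -> .map(loads_instance)
    trust_remote_code=True,
)
print(next(iter(ds))["source"])
```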
error_utils.py
CHANGED
@@ -18,6 +18,7 @@ class Documentation:
     BENCHMARKS = "docs/benchmark.html"
     DATA_CLASSIFICATION_POLICY = "docs/data_classification_policy.html"
     CATALOG = "docs/saving_and_loading_from_catalog.html"
+    SETTINGS = "docs/settings.html"


 def additional_info(path: str) -> str:
inference.py
CHANGED
@@ -2,6 +2,7 @@ import abc
 import asyncio
 import base64
 import dataclasses
+import hashlib
 import io
 import json
 import logging
@@ -12,6 +13,7 @@ import time
 import uuid
 from collections import Counter
 from datetime import datetime
+from itertools import islice
 from multiprocessing.pool import ThreadPool
 from typing import (
     Any,
@@ -29,6 +31,7 @@ from typing import (
 )

 from datasets import Dataset, DatasetDict, Image
+from diskcache import Cache
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio

@@ -53,6 +56,11 @@ settings = get_settings()
 logger = get_logger()


+def batched(lst, n):
+    it = iter(lst)
+    while batch := list(islice(it, n)):
+        yield batch
+
 class StandardAPIParamsMixin(Artifact):
     model: str
     frequency_penalty: Optional[float] = None
@@ -149,6 +157,8 @@ class ListWithMetadata(List[T]):

 class InferenceEngine(Artifact):
     """Abstract base class for inference."""
+    cache_batch_size: int = 100
+    use_cache: bool = True

     @abc.abstractmethod
     def _infer(
@@ -173,6 +183,7 @@ class InferenceEngine(Artifact):
         if not settings.mock_inference_mode:
             super().prepare()  # no need to prepare a mock
         self.prepare_engine()
+        self._cache = Cache(get_settings().inference_engine_cache_path + self.__class__.__name__)

     def __call__(
         self,
@@ -181,16 +192,20 @@ class InferenceEngine(Artifact):
     ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
         return self.infer(dataset=dataset, return_meta_data=return_meta_data)

-    def infer(
-        self,
-        dataset: Union[List[Dict[str, Any]], Dataset],
-        return_meta_data: bool = False,
-    ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
-        """Verifies instances of a dataset and perform inference on the input dataset.
-
-        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
-        predictions.
-        """
+    def get_instance_cache_key(self, instance):
+        instance_key_fields = ["media", "source", "task_data"]
+        return {key: instance[key] for key in instance if key in instance_key_fields}
+
+    def _get_cache_key(self, instance: Dict[str, Any]) -> str:
+        """Generate a unique cache key for each input."""
+        record = self.get_instance_cache_key(instance)
+        record.update(self.to_dict())
+        instance_str = json.dumps(record, sort_keys=True)
+        return hashlib.md5(instance_str.encode()).hexdigest()
+
+    def verify_infer_inputs(self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool):
         if not isoftype(dataset, Union[List[Dict[str, Any]], Dataset]):
             raise Exception(
                 "Dataset passed to infer() is not list of dictionaries or Huggingface Dataset"
@@ -202,10 +217,54 @@ class InferenceEngine(Artifact):
             )

         [self.verify_instance(instance) for instance in dataset]
+
+    def infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
+        """Verifies instances of a dataset and perform inference on the input dataset.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
+        predictions.
+        """
+        self.verify_infer_inputs(dataset, return_meta_data)
         if settings.mock_inference_mode:
             result = self._mock_infer(dataset)
         else:
-            result = self._infer(dataset, return_meta_data)
+            if self.use_cache:
+                number_of_batches = len(dataset) // self.cache_batch_size + 1
+                result = []
+                for batch_index, batch in enumerate(batched(dataset, self.cache_batch_size)):
+                    cached_results = []
+                    missing_examples = []
+                    for i, item in enumerate(batch):
+                        cache_key = self._get_cache_key(item)
+                        cached_value = self._cache.get(cache_key)
+                        if cached_value is not None:
+                            cached_results.append((i, cached_value)) # each element is index in batch, and value
+                        else:
+                            missing_examples.append((i, item)) # each element is index in batch and example
+                    # infare on missing examples only, without indices

+                    logger.info(f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")
+                    if (len(missing_examples) > 0):
+                        inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
+                        # recombined to index and value
+                        inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
+                        # Add missing examples to cache
+                        for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
+                            if prediction is None:
+                                continue
+                            cache_key = self._get_cache_key(item)
+                            self._cache[cache_key] = prediction
+                    else:
+                        inferred_results=[]
+                    # Combine cached and inferred results in original order
+                    batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
+                    result.extend(batch_predictions)
+            else:
+                result = self._infer(dataset, return_meta_data)
         return ListWithMetadata(
             result,
             metadata={
@@ -221,6 +280,7 @@ class InferenceEngine(Artifact):
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         return [str(instance["source"]) for instance in dataset]

+    @abc.abstractmethod
     def get_engine_id(self):
         raise NotImplementedError()

@@ -918,16 +978,18 @@ class HFPipelineBasedInferenceEngine(
         return args

     def _create_pipeline(self, model_args: Dict[str, Any]):
-        from transformers import pipeline
+        from transformers import AutoTokenizer, pipeline

         path = self.model_name
         if settings.hf_offline_models_path is not None:
             path = os.path.join(settings.hf_offline_models_path, path)

+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = pipeline(
             model=path,
             task=self.task,
             use_fast=self.use_fast_tokenizer,
+            tokenizer=tokenizer,
             trust_remote_code=settings.allow_unverified_code,
             **model_args,
             **self.to_dict(
@@ -1302,7 +1364,7 @@ class IbmGenAiInferenceEngine(
     def _get_credentials():
         from genai import Credentials

-        api_key_env_var_name = "GENAI_KEY"
+        api_key_env_var_name = "GENAI_KEY"  # pragma: allowlist secret
         api_key = os.environ.get(api_key_env_var_name)

         assert api_key is not None, (
@@ -1718,7 +1780,7 @@ class AzureOpenAIInferenceEngine(OpenAiInferenceEngine):
         ), "Error while trying to run AzureOpenAIInferenceEngine: Missing environment variable param AZURE_OPENAI_HOST or OPENAI_API_VERSION"
         api_url = f"{azure_openapi_host}/openai/deployments/{self.model_name}/chat/completions?api-version={api_version}"

-        return {"api_key": api_key, "api_url": api_url}
+        return {"api_key": api_key, "api_url": api_url, "api_version": api_version}

     def create_client(self):
         from openai import AzureOpenAI
@@ -1727,12 +1789,13 @@ class AzureOpenAIInferenceEngine(OpenAiInferenceEngine):
         return AzureOpenAI(
             api_key=self.credentials["api_key"],
             base_url=self.credentials["api_url"],
+            api_version=self.credentials["api_version"],
             default_headers=self.get_default_headers(),
         )


 class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
-    label: str = "vllm"
+    label: str = "vllm-remote"


 class RITSInferenceEngine(
@@ -1741,6 +1804,10 @@ class RITSInferenceEngine(
     label: str = "rits"
     data_classification_policy = ["public", "proprietary"]

+    model_names_dict = {
+        "microsoft/phi-4": "microsoft-phi-4"
+    }
+
     def get_default_headers(self):
         return {"RITS_API_KEY": self.credentials["api_key"]}

@@ -1761,8 +1828,10 @@ class RITSInferenceEngine(
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
         )

-    @staticmethod
-    def _get_model_name_for_endpoint(model_name: str):
+    @classmethod
+    def _get_model_name_for_endpoint(cls, model_name: str):
+        if model_name in cls.model_names_dict:
+            return cls.model_names_dict[model_name]
         return (
             model_name.split("/")[-1]
             .lower()
@@ -1805,7 +1874,7 @@ class TogetherAiInferenceEngine(
         from together import Together
         from together.types.models import ModelType

-        api_key_env_var_name = "TOGETHER_API_KEY"
+        api_key_env_var_name = "TOGETHER_API_KEY"  # pragma: allowlist secret
         api_key = os.environ.get(api_key_env_var_name)
         assert api_key is not None, (
             f"Error while trying to run TogetherAiInferenceEngine."
@@ -1969,6 +2038,9 @@ class WMLInferenceEngineBase(
         deployment_id (str, optional):
             Deployment ID of a tuned model to be used for
             inference. Mutually exclusive with 'model_name'.
+        concurrency_limit (int):
+            Number of concurrent requests sent to a model. Default is 10,
+            which is also the maximum value for the generation.
         parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
             Defines inference parameters and their values. Deprecated attribute, please pass respective
             parameters directly to the respective class instead.
@@ -1977,6 +2049,7 @@ class WMLInferenceEngineBase(
     credentials: Optional[CredentialsWML] = None
     model_name: Optional[str] = None
     deployment_id: Optional[str] = None
+    concurrency_limit: int = 10
     label: str = "wml"
     _requirements_list = {
         "ibm_watsonx_ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
@@ -2230,11 +2303,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi

     If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.

-    Args:
-        concurrency_limit (int):
-            Number of concurrent requests sent to a model. Default is 10,
-            which is also the maximum value.
-
     Examples:
         .. code-block:: python

@@ -2258,8 +2326,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
         results = wml_inference.infer(dataset["test"])
     """

-    concurrency_limit: int = 10
-
     def verify(self):
         super().verify()

@@ -2511,6 +2577,32 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
         # images as SDK allows sending only one image per message.
         return [messages]

+    def _handle_async_requests(
+        self,
+        messages: List[List[Dict[str, Any]]],
+        params: Dict[str, Any],
+    ) -> List[Dict[str, Any]]:
+        async def handle_async_requests(start_idx, end_idx):
+            coroutines = [
+                self._model.achat(messages=messages[idx], params=params)
+                for idx in range(start_idx, end_idx)
+            ]
+            batch_results = await asyncio.gather(*coroutines)
+            return list(batch_results)
+
+        loop = asyncio.get_event_loop()
+        results = []
+
+        for batch_idx in range(0, len(messages), self.concurrency_limit):
+            batch_results = loop.run_until_complete(
+                handle_async_requests(
+                    batch_idx, min(batch_idx + self.concurrency_limit, len(messages))
+                )
+            )
+            results.extend(batch_results)
+
+        return results
+
     def _send_requests(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
@@ -2526,27 +2618,25 @@
         output_type = "message"
         params["logprobs"] = False

-
-
-
-
-
-            for message in messages:
-                result = self._model.chat(
-                    messages=message,
-                    params=params,
-                )
+        indexed_messages = [
+            (i, message)
+            for i in range(len(dataset))
+            for message in self.to_messages(dataset[i])
+        ]

-
-
-
-                        result,
-                        instance["source"],
-                        return_meta_data,
-                    )
-                )
+        results = self._handle_async_requests(
+            [msg[1] for msg in indexed_messages], params
+        )

-        return
+        return [
+            self.get_return_object(
+                result["choices"][0][output_type]["content"],
+                result,
+                dataset[idx[0]]["source"],
+                return_meta_data,
+            )
+            for result, idx in zip(results, indexed_messages)
+        ]

     def get_return_object(self, predict_result, result, input_text, return_meta_data):
         if return_meta_data:
@@ -2614,6 +2704,7 @@ def get_text_without_images(instance, image_token="<image>"):
 class LMMSEvalBaseInferenceEngine(
     InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, TorchDeviceMixin
 ):
+    label = "lmms-eval"
     model_type: str
     model_args: Dict[str, str]
     batch_size: int = 1
@@ -2623,6 +2714,9 @@ class LMMSEvalBaseInferenceEngine(
         "lmms_eval": "Install llms-eval package using 'pip install lmms-eval==0.2.4'",
     }

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_type, self.label)
+
     def prepare_engine(self):
         if not self.lazy_load:
             self._prepare_engine()
@@ -2798,6 +2892,11 @@ class VLLMParamsMixin(Artifact):


 class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin):
+    label="vllm"
+
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model, self.label)
+
     def prepare_engine(self):
         args = self.to_dict([VLLMParamsMixin])
         args.pop("model")
@@ -2883,6 +2982,9 @@ class LiteLLMInferenceEngine(

     _requirements_list: list = ["litellm", "tenacity", "tqdm", "diskcache"]

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model, self.label)
+
     def prepare_engine(self):
         # Initialize the token bucket rate limiter
         self._rate_limiter = AsyncTokenBucket(
@@ -2890,15 +2992,12 @@ class LiteLLMInferenceEngine(
             capacity=self.max_requests_per_second,
         )
         self.inference_type = "litellm"
-        import litellm
         from litellm import acompletion
-        from litellm.caching.caching import Cache

-        litellm.cache = Cache(type="disk")

         self._completion = acompletion
         # Initialize a semaphore to limit concurrency
-        self._semaphore = asyncio.Semaphore(self.max_requests_per_second)
+        self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))

     async def _infer_instance(
         self, index: int, instance: Dict[str, Any]
@@ -3010,28 +3109,34 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
     provider_specific_args: Optional[Dict[str, Dict[str,str]]] = None

     provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
-        "watsonx": {
-            "
-            "
-            "
-            "llama-3-3-70b-instruct": "watsonx/meta-llama/llama-3-3-70b-instruct",
-            "granite-3-8b-instruct": "watsonx/ibm/granite-3-8b-instruct",
-            "flan-t5-xxl": "watsonx/google/flan-t5-xxl",
-            "llama-3-2-1b-instruct": "watsonx/meta-llama/llama-3-2-1b-instruct",
-            "llama-3-2-11b-vision-instruct": "watsonx/meta-llama/llama-3-2-11b-vision-instruct",
-            "llama-3-2-90b-vision-instruct": "watsonx/meta-llama/llama-3-2-90b-vision-instruct",
-            "mistral-large-instruct": "watsonx/mistralai/mistral-large",
-        },
-        "watsonx-sdk": {
-            "llama-3-2-11b-vision-instruct": "meta-llama/llama-3-2-11b-vision-instruct",
-            "llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct",
-            "llama-3-70b-instruct": "meta-llama/llama-3-70b-instruct",
+        "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
+            "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
+            "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
+            "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
             "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
+            "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
+            "granite-guardian-3-8b": "ibm/granite-guardian-3-8b",
+            "granite-vision-3-2-2b": "ibm/granite-vision-3-2-2b",
+            "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
+            "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
+            "llama-3-1-405b-instruct": "meta-llama/llama-3-405b-instruct",
+            "llama-3-2-11b-vision-instruct": "meta-llama/llama-3-2-11b-vision-instruct",
+            "llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct",
+            "llama-3-2-3b-instruct": "meta-llama/llama-3-2-3b-instruct",
+            "llama-3-2-90b-vision-instruct": "meta-llama/llama-3-2-90b-vision-instruct",
+            "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
+            "llama-guard-3-11b-vision": "meta-llama/llama-guard-3-11b-vision",
+            "mistral-large-instruct": "mistralai/mistral-large",
+            "mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7b-instruct-v01",
         },
         "together-ai": {
             "llama-3-8b-instruct": "together_ai/meta-llama/Llama-3-8b-chat-hf",
             "llama-3-70b-instruct": "together_ai/meta-llama/Llama-3-70b-chat-hf",
+            "llama-3-1-8b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+            "llama-3-1-70b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            "llama-3-1-405b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
             "llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
+            "llama-3-3-70b-instruct": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"
         },
         "aws": {
             "llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0",
@@ -3040,6 +3145,12 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "ollama": {
             "llama-3-8b-instruct": "llama3:8b",
             "llama-3-70b-instruct": "llama3:70b",
+            "llama-3-1-8b-instruct": "llama3.1:8b",
+            "llama-3-1-70b-instruct": "llama3.1:70b",
+            "llama-3-1-405b-instruct": "llama3.1:405b",
+            "llama-3-2-1b-instruct": "llama3.2:1b",
+            "llama-3-2-3b-instruct": "llama3.2:3b",
+            "llama-3-3-70b-instruct": "llama3.3"
         },
         "bam": {
             "granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
@@ -3049,12 +3160,14 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         },
         "rits": {
             "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
+            "granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
             "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
             "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
+            "llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
+            "llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
             "llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
             "llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
             "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
-            "llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
             "mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
             "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
         },
@@ -3089,6 +3202,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "o1-preview": "azure/o1-preview",
         "gpt-4o-mini": "azure/gpt-4o-mini",
         "gpt-4o": "azure/gpt-4o",
+        "gpt-4o-2024-08-06": "azure/gpt-4o-2024-08-06",
         "gpt-4": "azure/gpt-4",
         "gpt-4-0314": "azure/gpt-4-0314",
         "gpt-4-0613": "azure/gpt-4-0613",
@@ -3133,6 +3247,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
         },
     }
+    provider_model_map["watsonx"] = {k: f"watsonx/{v}" for k,v in provider_model_map["watsonx-sdk"].items()}

     _provider_to_base_class = {
         "watsonx": LiteLLMInferenceEngine,
@@ -3190,7 +3305,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
                 del args[param]
             else:
                 del args[param]
-        self.engine = cls(**args)
+        self.engine: InferenceEngine = cls(**args)
         self.data_classification_policy = self.engine.data_classification_policy

     def _infer(
@@ -3210,7 +3325,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):

     This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
     """
-
+    label = "hf_option_selection"
     model_name: str
     batch_size: int

@@ -3218,6 +3333,9 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
         "transformers": "Install huggingface package using 'pip install --upgrade transformers"
     }

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer

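The caching added to `InferenceEngine.infer()` keys each instance on its `media`, `source`, and `task_data` fields plus the engine's own configuration, so the same prompt run through a differently configured engine does not collide in the cache. A standalone sketch of that keying scheme, mirroring `_get_cache_key` above; the `engine_config` dict stands in for `self.to_dict()` and is an illustrative assumption:

```python
# Standalone sketch of the cache-key scheme used above; engine_config
# stands in for self.to_dict() and is an illustrative assumption.
import hashlib
import json

def cache_key(instance: dict, engine_config: dict) -> str:
    instance_key_fields = ["media", "source", "task_data"]
    # keep only the instance fields that identify the request
    record = {k: instance[k] for k in instance if k in instance_key_fields}
    # mix in the engine settings so different configurations get different keys
    record.update(engine_config)
    return hashlib.md5(json.dumps(record, sort_keys=True).encode()).hexdigest()

key = cache_key(
    {"source": "What is 2+2?", "task_data": {}},
    {"model": "some-model", "temperature": 0.0},
)
print(key)  # stable hex digest, reused as the diskcache key
```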
llm_as_judge.py
CHANGED
@@ -8,15 +8,12 @@ from .dict_utils import dict_get
|
|
8 |
from .error_utils import UnitxtError
|
9 |
from .inference import (
|
10 |
InferenceEngine,
|
11 |
-
OptionSelectingByLogProbsInferenceEngine,
|
12 |
)
|
13 |
from .llm_as_judge_chat_templates import direct_template_dict, pairwise_template_dict
|
14 |
from .llm_as_judge_constants import (
|
15 |
DIRECT_CRITERIA,
|
16 |
EVALUATOR_TO_MODEL_ID,
|
17 |
EVALUATORS_METADATA,
|
18 |
-
INFERENCE_ENGINE_NAME_TO_CLASS,
|
19 |
-
MODEL_RENAMINGS,
|
20 |
PAIRWISE_CRITERIA,
|
21 |
Criteria,
|
22 |
CriteriaOption,
|
@@ -44,30 +41,50 @@ from .llm_as_judge_utils import (
|
|
44 |
get_evaluator_metadata,
|
45 |
get_parsed_context,
|
46 |
rank_indexes,
|
47 |
-
rename_model_if_required,
|
48 |
)
|
49 |
from .logging_utils import get_logger
|
50 |
from .metrics import BulkInstanceMetric
|
51 |
from .task import Task
|
52 |
from .templates import Template
|
53 |
|
|
|
54 |
|
55 |
class LLMJudge(BulkInstanceMetric):
|
|
|
|
|
|
|
|
|
|
|
56 |
inference_engine: InferenceEngine
|
57 |
-
|
58 |
-
|
59 |
-
# )
|
60 |
evaluator_name: EvaluatorNameEnum = None
|
|
|
|
|
61 |
check_positional_bias: bool = True
|
|
|
|
|
62 |
context_fields: Union[str, List[str], Dict[str, str]] = ["context"]
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
criteria_field: str = None
|
|
|
|
|
67 |
criteria: Criteria = None
|
68 |
-
|
|
|
69 |
|
70 |
def prepare(self):
|
|
|
71 |
super().prepare()
|
72 |
if isinstance(self.context_fields, str):
|
73 |
self.context_fields = [self.context_fields]
|
@@ -78,10 +95,13 @@ class LLMJudge(BulkInstanceMetric):
|
|
78 |
|
79 |
if self.evaluator_name is None:
|
80 |
self.evaluator_name = self.inference_engine.get_engine_id()
|
81 |
-
elif not isinstance(self.evaluator_name, EvaluatorNameEnum):
|
82 |
-
self.evaluator_name = EvaluatorNameEnum[self.evaluator_name]
|
83 |
|
84 |
def before_process_multi_stream(self):
|
|
|
|
|
|
|
|
|
|
|
85 |
super().before_process_multi_stream()
|
86 |
# We check the criteria here and not in verify(), because we want catalog
|
87 |
# may contain a partially initialized object, and verify() method
|
@@ -93,6 +113,14 @@ class LLMJudge(BulkInstanceMetric):
|
|
93 |
return
|
94 |
|
95 |
def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
return [
|
97 |
get_parsed_context(
|
98 |
{
|
@@ -110,6 +138,17 @@ class LLMJudge(BulkInstanceMetric):
|
|
110 |
template: Template,
|
111 |
previous_messages: Optional[List[Dict[str, str]]] = None,
|
112 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
outputs_dataset = infer(
|
114 |
instances,
|
115 |
task=task,
|
@@ -129,6 +168,14 @@ class LLMJudge(BulkInstanceMetric):
|
|
129 |
return (prompts, raw_predictions, predictions)
|
130 |
|
131 |
def clean_results(self, results: Union[dict, list]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if isinstance(results, list):
|
133 |
return [self.clean_results(x) for x in results]
|
134 |
cleaned = {
|
@@ -143,13 +190,25 @@ class LLMJudge(BulkInstanceMetric):
|
|
143 |
if not (isinstance(v, dict) and len(v) == 0)
|
144 |
}
|
145 |
|
146 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
if self.criteria is None:
|
148 |
if self.criteria_field not in task_data[0]:
|
149 |
raise UnitxtError(
|
150 |
f"The criteria field `{self.criteria_field}` required for {__class__.__name__} is not found in instance. Perhaps you meant '{get_close_matches(self.criteria_field, task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
|
151 |
)
|
152 |
-
|
153 |
f"Reading criteria from the task_data field '{self.criteria_field}'"
|
154 |
)
|
155 |
criterias = [
|
@@ -157,20 +216,31 @@ class LLMJudge(BulkInstanceMetric):
|
|
157 |
for task_data_instance in task_data
|
158 |
]
|
159 |
else:
|
160 |
-
|
161 |
"Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
|
162 |
)
|
163 |
criterias: List[Criteria] = [self.criteria] * eval_count
|
164 |
unique_criteria_names = list({criteria.name for criteria in criterias})
|
165 |
|
166 |
-
|
167 |
return criterias
|
168 |
|
169 |
|
170 |
class LLMJudgeDirect(LLMJudge):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
criteria: CriteriaWithOptions = None
|
|
|
172 |
main_score = "llm_as_judge"
|
|
|
173 |
reduction_map = {"mean": ["llm_as_judge"]}
|
|
|
174 |
|
175 |
def prepare(self):
|
176 |
super().prepare()
|
@@ -200,7 +270,7 @@ class LLMJudgeDirect(LLMJudge):
|
|
200 |
self.option_selection_task = Task(
|
201 |
input_fields={
|
202 |
"criteria_description": str,
|
203 |
-
"
|
204 |
"options": list,
|
205 |
},
|
206 |
reference_fields={},
|
@@ -209,6 +279,7 @@ class LLMJudgeDirect(LLMJudge):
|
|
209 |
)
|
210 |
|
211 |
def before_process_multi_stream(self):
|
|
|
212 |
super().before_process_multi_stream()
|
213 |
if self.criteria is not None and not isinstance(
|
214 |
self.criteria, CriteriaWithOptions
|
@@ -218,34 +289,42 @@ class LLMJudgeDirect(LLMJudge):
|
|
218 |
)
|
219 |
return
|
220 |
|
221 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
criteria_description = criteria.description
|
223 |
criteria_option_names = [o.name for o in criteria.options]
|
224 |
|
225 |
-
display_options_instruction = "Choose an
|
226 |
[
|
227 |
f'- "{o.name}"{f" if {o.description}" if o.description != "" else ""}'
|
228 |
for o in criteria.options
|
229 |
]
|
230 |
)
|
231 |
-
score_option_instruction = "".join(
|
232 |
-
[f"Score {o.name}: {o.description}\n" for o in criteria.options]
|
233 |
-
)
|
234 |
|
235 |
return (
|
236 |
criteria_description,
|
237 |
criteria_option_names,
|
238 |
display_options_instruction,
|
239 |
-
score_option_instruction,
|
240 |
)
|
241 |
|
242 |
-
def
|
243 |
unique_criteria_names = list({criteria.name for criteria in criterias})
|
244 |
if len(unique_criteria_names) == 1 and criterias[0].name != "":
|
245 |
self.main_score = "_".join(criterias[0].name.lower().split(" "))
|
246 |
self.reduction_map = {"mean": [self.main_score]}
|
247 |
|
248 |
-
def
|
249 |
self,
|
250 |
assessment_prompts,
|
251 |
assessment_outputs,
|
@@ -289,6 +368,9 @@ class LLMJudgeDirect(LLMJudge):
|
|
289 |
"summary": summarization_outputs[i]
|
290 |
if self.generate_summaries
|
291 |
else None,
|
|
|
|
|
|
|
292 |
"prompts": {
|
293 |
"assessment": assessment_prompts[i],
|
294 |
"positional_bias_assessment": assessment_prompts[
|
@@ -332,14 +414,113 @@ class LLMJudgeDirect(LLMJudge):
|
|
332 |
references: List[List[str]],
|
333 |
predictions: List[str],
|
334 |
task_data: List[Dict[str, Any]],
|
335 |
-
) ->
|
336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}'
|
338 |
)
|
339 |
evaluations_count = len(predictions)
|
340 |
# TODO: find out how to serialize and deserialize enums
|
341 |
-
criterias = self.
|
342 |
-
self.
|
343 |
contexts = self.get_contexts(task_data)
|
344 |
if self.check_positional_bias:
|
345 |
criterias += [
|
@@ -355,14 +536,13 @@ class LLMJudgeDirect(LLMJudge):
|
|
355 |
predictions += predictions
|
356 |
|
357 |
parsed_criterias = [
|
358 |
-
self.
|
359 |
]
|
360 |
|
361 |
(
|
362 |
criteria_description_list,
|
363 |
criteria_option_names_list,
|
364 |
display_options_instruction_list,
|
365 |
-
score_option_instruction_list,
|
366 |
) = zip(*parsed_criterias)
|
367 |
|
368 |
assessment_for_summaries_slice = slice(0, evaluations_count)
|
@@ -385,7 +565,7 @@ class LLMJudgeDirect(LLMJudge):
|
|
385 |
assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
|
386 |
assessment_instances, self.assessment_task, self.assessment_template
|
387 |
)
|
388 |
-
|
389 |
|
390 |
summarization_prompts = None
|
391 |
summarization_outputs = None
|
@@ -409,18 +589,22 @@ class LLMJudgeDirect(LLMJudge):
|
|
409 |
self.summarization_task,
|
410 |
self.summarization_template,
|
411 |
)
|
412 |
-
|
413 |
|
414 |
option_selection_instances = [
|
415 |
{
|
416 |
"criteria_description": criteria_description,
|
417 |
-
"
|
418 |
"options": criteria_option_names,
|
419 |
"data_classification_policy": ["public"],
|
420 |
}
|
421 |
-
for
|
|
|
|
|
|
|
|
|
422 |
criteria_description_list,
|
423 |
-
|
424 |
criteria_option_names_list,
|
425 |
)
|
426 |
]
|
@@ -441,9 +625,9 @@ class LLMJudgeDirect(LLMJudge):
|
|
441 |
self.option_selection_template,
|
442 |
previous_messages,
|
443 |
)
|
444 |
-
|
445 |
|
446 |
-
results = self.
|
447 |
assessment_prompts,
|
448 |
assessment_outputs,
|
449 |
summarization_prompts,
|
@@ -454,15 +638,19 @@ class LLMJudgeDirect(LLMJudge):
|
|
454 |
evaluations_count,
|
455 |
criterias,
|
456 |
)
|
|
|
457 |
return self.clean_results(results)
|
458 |
|
459 |
|
460 |
class LLMJudgePairwise(LLMJudge):
|
461 |
-
|
462 |
main_score = "1_winrate"
|
463 |
-
|
|
|
|
|
464 |
|
465 |
def prepare(self):
|
|
|
466 |
super().prepare()
|
467 |
self.assessment_template = pairwise_template_dict["assessment"]
|
468 |
self.summarization_template = pairwise_template_dict["summarization"]
|
@@ -501,6 +689,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
501 |
)
|
502 |
|
503 |
def before_process_multi_stream(self):
|
|
|
504 |
super().before_process_multi_stream()
|
505 |
if self.criteria is not None and not isinstance(self.criteria, Criteria):
|
506 |
raise Exception(
|
@@ -508,7 +697,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
508 |
)
|
509 |
return
|
510 |
|
511 |
-
def
|
512 |
self,
|
513 |
instance_predictions: Dict[str, str],
|
514 |
assessment_prompts,
|
@@ -520,8 +709,26 @@ class LLMJudgePairwise(LLMJudge):
|
|
520 |
selections,
|
521 |
contests_count,
|
522 |
combination_indexes,
|
523 |
-
|
524 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
response_names = list(instance_predictions.keys())
|
526 |
per_response_results = {
|
527 |
response_key: {
|
@@ -680,32 +887,479 @@ class LLMJudgePairwise(LLMJudge):
|
|
680 |
for metric in single_result.keys():
|
681 |
all_results[f"{response_name}_{metric}"] = single_result[metric]
|
682 |
|
683 |
-
all_results["criteria"] =
|
684 |
return self.clean_results(all_results)
|
685 |
|
686 |
-
def
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
692 |
)
|
693 |
|
694 |
-
def
|
695 |
self, predictions: Union[List[Dict[str, str]], List[str]]
|
696 |
):
|
697 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
698 |
|
699 |
def compute(
|
700 |
self,
|
701 |
references: List[List[str]],
|
702 |
predictions: List[str],
|
703 |
task_data: List[Dict[str, str]],
|
704 |
-
) ->
|
705 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
706 |
f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
|
707 |
)
|
708 |
-
predictions = self.
|
|
|
709 |
instances_count = len(predictions)
|
710 |
self.reduction_map = {"mean": ["score"]}
|
711 |
self.reduction_map["mean"].extend(
|
@@ -721,7 +1375,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
721 |
len(combination_indexes) for combination_indexes in combination_indexes_list
|
722 |
]
|
723 |
|
724 |
-
|
725 |
f"The evaluation will perform {sum(contests_count_list) * [1, 2][self.check_positional_bias]} ({' + '.join([f'{c * [1, 2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
|
726 |
)
|
727 |
|
@@ -752,7 +1406,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
752 |
response_pairs_list.append(response_pairs)
|
753 |
option_pairs_list.append(option_pairs)
|
754 |
|
755 |
-
criterias = self.
|
756 |
contexts = self.get_contexts(task_data)
|
757 |
if self.check_positional_bias:
|
758 |
criterias.extend(criterias)
|
@@ -786,7 +1440,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
786 |
assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
|
787 |
assessment_instances, self.assessment_task, self.assessment_template
|
788 |
)
|
789 |
-
|
790 |
|
791 |
# the slices used to get the assessment for each summary generation instance
|
792 |
# it will grab the whole assessment for a particular instance or half of it depending on the value of check_positional_bias
|
@@ -836,7 +1490,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
836 |
self.summarization_task,
|
837 |
self.summarization_template,
|
838 |
)
|
839 |
-
|
840 |
|
841 |
score_option_instruction_list = [
|
842 |
"".join(
|
@@ -884,7 +1538,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
884 |
)
|
885 |
# Selections are of the form 'Response n', so we just keep n
|
886 |
selections = [selection.split(" ")[-1] for selection in selections]
|
887 |
-
|
888 |
results = []
|
889 |
slice_start = 0
|
890 |
for i, incremental_contests_count in enumerate(incremental_contests_count_list):
|
@@ -897,7 +1551,7 @@ class LLMJudgePairwise(LLMJudge):
|
|
897 |
(incremental_contests_count_list[i - 1] if i > 0 else 0)
|
898 |
+ incremental_contests_count,
|
899 |
)
|
900 |
-
instance_results = self.
|
901 |
predictions[i],
|
902 |
assessment_prompts[sli],
|
903 |
assessment_outputs[sli],
|
|
|
8 |
from .error_utils import UnitxtError
|
9 |
from .inference import (
|
10 |
InferenceEngine,
|
|
|
11 |
)
|
12 |
from .llm_as_judge_chat_templates import direct_template_dict, pairwise_template_dict
|
13 |
from .llm_as_judge_constants import (
|
14 |
DIRECT_CRITERIA,
|
15 |
EVALUATOR_TO_MODEL_ID,
|
16 |
EVALUATORS_METADATA,
|
|
|
|
|
17 |
PAIRWISE_CRITERIA,
|
18 |
Criteria,
|
19 |
CriteriaOption,
|
|
|
41 |
get_evaluator_metadata,
|
42 |
get_parsed_context,
|
43 |
rank_indexes,
|
|
|
44 |
)
|
45 |
from .logging_utils import get_logger
|
46 |
from .metrics import BulkInstanceMetric
|
47 |
from .task import Task
|
48 |
from .templates import Template
|
49 |
|
50 |
+
logger = get_logger(__name__)
|
51 |
|
52 |
class LLMJudge(BulkInstanceMetric):
|
53 |
+
"""A metric class to evaluate instances using LLM as a Judge.
|
54 |
+
|
55 |
+
Evaluations are performed in two steps. First, the LLM is asked to generate an assessment following a CoT approach based on the criteria. Then, the same LLM is asked to select one of the available options. A summary of the general assessment can be generated for easy consumption by end users.
|
56 |
+
"""
|
57 |
+
|
58 |
inference_engine: InferenceEngine
|
59 |
+
"""The engine used for generating predictions in the different evaluation steps."""
|
60 |
+
|
|
|
61 |
evaluator_name: EvaluatorNameEnum = None
|
62 |
+
"""The name of the evaluator. It is used for score naming. If not provided `self.inference_engine.get_engine_id()` is used."""
|
63 |
+
|
64 |
check_positional_bias: bool = True
|
65 |
+
"""Flag to check for positional bias. Detecting for positional bias duplicates the amount of inference calls."""
|
66 |
+
|
67 |
context_fields: Union[str, List[str], Dict[str, str]] = ["context"]
|
68 |
+
"""Fields to be used as context. If a dict is provided, the keys are used as the final names in the prompts, while the values are used to access the context variable values in the `task_data` object."""
|
69 |
+
|
70 |
+
generate_summaries: bool = False
|
71 |
+
"""Flag to generate summaries of the assessments. Defaults to `False`."""
|
72 |
+
|
73 |
+
format: str = "formats.chat_api"
|
74 |
+
"""The format used for the inference. Defaults to `formats.chat_api` (only allowed value)."""
|
75 |
+
|
76 |
+
include_prompts_in_result: bool = True
|
77 |
+
"""Flag to include prompts in the result. Defaults to `True`."""
|
78 |
+
|
79 |
criteria_field: str = None
|
80 |
+
"""The field specifying the evaluation criteria in the `task_data` object."""
|
81 |
+
|
82 |
criteria: Criteria = None
|
83 |
+
"""The criteria used for evaluation. If the `criteria_field` is provided, it will take precedence."""
|
84 |
+
|
85 |
|
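Taken together, the fields documented above describe how a judge is configured. Below is a minimal, hypothetical configuration sketch; the inference engine, model id, and field values are assumptions chosen for illustration, not values taken from this change.

```python
# Hypothetical configuration of the judge metric described above.
# The inference engine, model name, and provider are placeholders.
from unitxt.inference import CrossProviderInferenceEngine
from unitxt.llm_as_judge import LLMJudgeDirect

judge = LLMJudgeDirect(
    inference_engine=CrossProviderInferenceEngine(
        model="llama-3-3-70b-instruct", provider="watsonx"  # assumed engine/model
    ),
    criteria_field="criteria",                # read the criteria from each task_data instance
    context_fields={"question": "question"},  # prompt-facing name -> task_data key
    check_positional_bias=True,               # doubles the number of inference calls
    generate_summaries=False,
    include_prompts_in_result=True,
)
```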
86 |
def prepare(self):
|
87 |
+
"""Prepares the `LLMJudge` instance by setting up context fields and evaluator name."""
|
88 |
super().prepare()
|
89 |
if isinstance(self.context_fields, str):
|
90 |
self.context_fields = [self.context_fields]
|
|
|
95 |
|
96 |
if self.evaluator_name is None:
|
97 |
self.evaluator_name = self.inference_engine.get_engine_id()
|
|
|
|
|
98 |
|
99 |
def before_process_multi_stream(self):
|
100 |
+
"""Checks the criteria-related fields correctness before processing multiple streams.
|
101 |
+
|
102 |
+
Raises:
|
103 |
+
UnitxtError: If both 'criteria' and 'criteria_field' are not set.
|
104 |
+
"""
|
105 |
super().before_process_multi_stream()
|
106 |
# We check the criteria here and not in verify(), because we want catalog
|
107 |
# may contain a partially initialized object, and verify() method
|
|
|
113 |
return
|
114 |
|
115 |
def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
116 |
+
"""Extracts and parses context fields from task data.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
task_data (List[Dict[str, Any]]): The task data containing context information.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
List[Dict[str, str]]: A list of parsed context dictionaries.
|
123 |
+
"""
|
124 |
return [
|
125 |
get_parsed_context(
|
126 |
{
|
|
|
138 |
template: Template,
|
139 |
previous_messages: Optional[List[Dict[str, str]]] = None,
|
140 |
):
|
141 |
+
"""Performs an evaluation step by generating predictions for the given instances.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
instances (list): The list of instances to evaluate.
|
145 |
+
task (Task): The task associated with the instances.
|
146 |
+
template (Template): The template used for generating predictions.
|
147 |
+
previous_messages (Optional[List[Dict[str, str]]]): Previous messages for context.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
Tuple[List[str], List[str], List[str]]: A tuple containing prompts, raw predictions, and processed predictions. Raw predictions differ from processed predictions only in the completion step, where the processors.match_closest_option is used.
|
151 |
+
"""
|
152 |
outputs_dataset = infer(
|
153 |
instances,
|
154 |
task=task,
|
|
|
168 |
return (prompts, raw_predictions, predictions)
|
169 |
|
170 |
def clean_results(self, results: Union[dict, list]):
|
171 |
+
"""Cleans the results by removing `None` values and empty lists and dictionaries.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
results (Union[dict, list]): The results to clean.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
Union[dict, list]: The cleaned results.
|
178 |
+
"""
|
179 |
if isinstance(results, list):
|
180 |
return [self.clean_results(x) for x in results]
|
181 |
cleaned = {
|
|
|
190 |
if not (isinstance(v, dict) and len(v) == 0)
|
191 |
}
|
192 |
|
193 |
+
def get_criteria(self, task_data, eval_count):
|
194 |
+
"""Retrieves the evaluation criteria from the `criteria_field` or from `self`.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
task_data (List[Dict[str, Any]]): The task data containing criteria information.
|
198 |
+
eval_count (int): The number of evaluations to perform.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
List[Criteria]: A list of criteria for evaluation.
|
202 |
+
|
203 |
+
Raises:
|
204 |
+
UnitxtError: If the criteria field is not found in the task data.
|
205 |
+
"""
|
206 |
if self.criteria is None:
|
207 |
if self.criteria_field not in task_data[0]:
|
208 |
raise UnitxtError(
|
209 |
f"The criteria field `{self.criteria_field}` required for {__class__.__name__} is not found in instance. Perhaps you meant '{get_close_matches(self.criteria_field, task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
|
210 |
)
|
211 |
+
logger.info(
|
212 |
f"Reading criteria from the task_data field '{self.criteria_field}'"
|
213 |
)
|
214 |
criterias = [
|
|
|
216 |
for task_data_instance in task_data
|
217 |
]
|
218 |
else:
|
219 |
+
logger.info(
|
220 |
"Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
|
221 |
)
|
222 |
criterias: List[Criteria] = [self.criteria] * eval_count
|
223 |
unique_criteria_names = list({criteria.name for criteria in criterias})
|
224 |
|
225 |
+
logger.info(f"Criteria names are '{', '.join(unique_criteria_names)}'")
|
226 |
return criterias
|
227 |
|
228 |
|
229 |
class LLMJudgeDirect(LLMJudge):
|
230 |
+
"""LLMJudgeDirect is a specialized evaluation metric that performs Direct Assessment using an LLM to score responses based on a predefined evaluation criteria.
|
231 |
+
|
232 |
+
Direct Assessment is an evaluation paradigm in which the LLM selects one of a
|
233 |
+
predefined set of options based on an assessment criterion. This approach can
|
234 |
+
be used for Likert-scale scoring (e.g., 1-5) or selecting from semantically
|
235 |
+
conditioned literals (e.g., Yes/No, Pass/Fail).
|
236 |
+
"""
|
237 |
+
|
238 |
criteria: CriteriaWithOptions = None
|
239 |
+
"""The evaluation criteria, including a name, description, a predefined set of options and and option_map."""
|
240 |
main_score = "llm_as_judge"
|
241 |
+
"""The primary score name used in the results. By default, it will take the value of the criteria name (if only one criteria is being used for evaluation) or "llm_as_judge" otherwise."""
|
242 |
reduction_map = {"mean": ["llm_as_judge"]}
|
243 |
+
"""A mapping used for score aggregation. By default, it will take the value of `{'mean': [<default_main_score_name>]}`."""
|
244 |
|
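As a concrete illustration of such a criteria object, a Likert-style definition might look like the sketch below. The names, descriptions, and scores are invented, and it assumes the `CriteriaWithOptions` constructor mirrors the fields named in its docstring.

```python
# Hypothetical Likert-style criteria with an option_map that turns the selected
# option into the numeric score reported under the main score name.
from unitxt.llm_as_judge_constants import CriteriaOption, CriteriaWithOptions

answer_quality = CriteriaWithOptions(
    name="answer_quality",
    description="Rate how well the response answers the user's question.",
    options=[
        CriteriaOption(name="Excellent", description="fully correct and complete"),
        CriteriaOption(name="Could be Improved", description="partially correct or incomplete"),
        CriteriaOption(name="Bad", description="incorrect or off-topic"),
    ],
    option_map={"Excellent": 1.0, "Could be Improved": 0.5, "Bad": 0.0},
)
```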
245 |
def prepare(self):
|
246 |
super().prepare()
|
|
|
270 |
self.option_selection_task = Task(
|
271 |
input_fields={
|
272 |
"criteria_description": str,
|
273 |
+
"display_options_instruction": str,
|
274 |
"options": list,
|
275 |
},
|
276 |
reference_fields={},
|
|
|
279 |
)
|
280 |
|
281 |
def before_process_multi_stream(self):
|
282 |
+
"""Ensures that the criteria is of type `CriteriaWithOptions`, raising an exception otherwise."""
|
283 |
super().before_process_multi_stream()
|
284 |
if self.criteria is not None and not isinstance(
|
285 |
self.criteria, CriteriaWithOptions
|
|
|
289 |
)
|
290 |
return
|
291 |
|
292 |
+
def __get_parsed_criteria(self, criteria: CriteriaWithOptions):
|
293 |
+
"""Extracts key information from the given criteria.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
criteria (CriteriaWithOptions): The evaluation criteria.
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
Tuple[str, List[str], str, str]:
|
300 |
+
- Criteria description.
|
301 |
+
- List of option names.
|
302 |
+
- Formatted instruction for displaying options.
|
303 |
+
- Instruction for scoring options.
|
304 |
+
"""
|
305 |
criteria_description = criteria.description
|
306 |
criteria_option_names = [o.name for o in criteria.options]
|
307 |
|
308 |
+
display_options_instruction = "Choose an option:\n" + "\n".join(
|
309 |
[
|
310 |
f'- "{o.name}"{f" if {o.description}" if o.description != "" else ""}'
|
311 |
for o in criteria.options
|
312 |
]
|
313 |
)
|
|
|
|
|
|
|
314 |
|
315 |
return (
|
316 |
criteria_description,
|
317 |
criteria_option_names,
|
318 |
display_options_instruction,
|
|
|
319 |
)
|
320 |
|
321 |
+
def __set_main_score(self, criterias: List[CriteriaWithOptions]):
|
322 |
unique_criteria_names = list({criteria.name for criteria in criterias})
|
323 |
if len(unique_criteria_names) == 1 and criterias[0].name != "":
|
324 |
self.main_score = "_".join(criterias[0].name.lower().split(" "))
|
325 |
self.reduction_map = {"mean": [self.main_score]}
|
326 |
|
327 |
+
def __get_results(
|
328 |
self,
|
329 |
assessment_prompts,
|
330 |
assessment_outputs,
|
|
|
368 |
"summary": summarization_outputs[i]
|
369 |
if self.generate_summaries
|
370 |
else None,
|
371 |
+
"positional_bias_summary": summarization_outputs[i]
|
372 |
+
if self.generate_summaries and self.check_positional_bias
|
373 |
+
else None,
|
374 |
"prompts": {
|
375 |
"assessment": assessment_prompts[i],
|
376 |
"positional_bias_assessment": assessment_prompts[
|
|
|
414 |
references: List[List[str]],
|
415 |
predictions: List[str],
|
416 |
task_data: List[Dict[str, Any]],
|
417 |
+
) -> List[Dict]:
|
418 |
+
r"""Performs direct assessment evaluation on the given predictions and references.
|
419 |
+
|
420 |
+
This method evaluates the quality of the predictions by calculating scores for each instance based on a criterion.
|
421 |
+
|
422 |
+
Returns:
|
423 |
+
-------
|
424 |
+
List[Dict]
|
425 |
+
A list of dictionaries containing the evaluation results for each instance. The results include the computed scores for each prediction. Each result will have the `score_name` as a prefix, which may be the criterion name if only one used, or "llm_as_judge" if several criteria were used.
|
426 |
+
|
427 |
+
Explanation of fields:
|
428 |
+
|
429 |
+
- `score`: a float representing the evaluation score for the response. The value is calculated from criteria.option_map[selected_option].
|
430 |
+
- `using_<evaluator_name>`: Equal to score.
|
431 |
+
- `positional_bias`: Boolean indicating whether the assessment detected positional bias. Its final value is selected_option != positional_bias_selected_option
|
432 |
+
- `selected_option`: The criteria option that the evaluator chose (e.g., "Could be Improved"). It is calculated by processing `option_selection_completion` using `processors.match_closest_option`
|
433 |
+
- `positional_bias_selected_option`: The criteria option that the evaluator chose when checking positional bias.
|
434 |
+
- `assessment`: The inference engine's generated text using the `prompts.assessment` prompt.
|
435 |
+
- `positional_bias_assessment`: The inference engine's generated text using the `prompts.positional_bias_assessment` prompt.
|
436 |
+
- `summary`: An LLM-generated summary of the assessment.
|
437 |
+
- `positional_bias_summary`: An LLM-generated summary of the positional bias assessment.
|
438 |
+
- `prompts`: A dictionary of prompts used in different stages of evaluation.
|
439 |
+
- `assessment`: The prompt used to instruct the model on how to assess the response.
|
440 |
+
- `positional_bias_assessment`: The prompt used to instruct the model on how to assess the response in the positional bias check.
|
441 |
+
- `summarization`: The prompt used to generate summary of the assessment.
|
442 |
+
- `option_selection`: The prompt used to generate a final judgement.
|
443 |
+
- `positional_bias_option_selection`: The prompt used to generate a final judgement in the positional bias check.
|
444 |
+
- `option_selection_completion`: The inference engine's generated text using `prompts.option_selection`.
|
445 |
+
- `positional_bias_option_selection_completion`: The inference engine's generated text using `prompts.positional_bias_option_selection`.
|
446 |
+
- `criteria`: A JSON-like string representing the evaluation criteria's artifact.
|
447 |
+
|
448 |
+
Result example:
|
449 |
+
|
450 |
+
.. code-block:: python
|
451 |
+
|
452 |
+
[
|
453 |
+
{
|
454 |
+
"answer_relevance": 1,
|
455 |
+
"answer_relevance_using_granite3.0-2b_litellm": 1,
|
456 |
+
"answer_relevance_positional_bias": false,
|
457 |
+
"answer_relevance_selected_option": "Could be Improved",
|
458 |
+
"answer_relevance_positional_bias_selected_option": "Could be Improved",
|
459 |
+
"answer_relevance_assessment": "To assess the quality of the response, l...",
|
460 |
+
"answer_relevance_positional_bias_assessment": "To assess the quality of the response, l...",
|
461 |
+
"answer_relevance_summary": "A response about apprenticeships during ...",
|
462 |
+
"answer_relevance_positional_bias_summary": "A response about apprenticeships during ...",
|
463 |
+
"answer_relevance_prompts": {
|
464 |
+
"assessment": [
|
465 |
+
{
|
466 |
+
"role": "user",
|
467 |
+
"content": "You are presented with a response gener..."
|
468 |
+
}
|
469 |
+
],
|
470 |
+
"positional_bias_assessment": [
|
471 |
+
{
|
472 |
+
"role": "user",
|
473 |
+
"content": "You are presented with a response gener..."
|
474 |
+
}
|
475 |
+
],
|
476 |
+
"summarization": [
|
477 |
+
{
|
478 |
+
"role": "user",
|
479 |
+
"content": "Transform the following assessment into ..."
|
480 |
+
}
|
481 |
+
],
|
482 |
+
"option_selection": [
|
483 |
+
{
|
484 |
+
"content": "You are presented with a response gener...",
|
485 |
+
"role": "user"
|
486 |
+
},
|
487 |
+
{
|
488 |
+
"content": "To assess the quality of the response, l...",
|
489 |
+
"role": "assistant"
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"content": "Now consider the evaluation criteria and...",
|
493 |
+
"role": "user"
|
494 |
+
}
|
495 |
+
],
|
496 |
+
"posional_bias_option_selection": [
|
497 |
+
{
|
498 |
+
"content": "You are presented with a response gener...",
|
499 |
+
"role": "user"
|
500 |
+
},
|
501 |
+
{
|
502 |
+
"content": "To assess the quality of the response, l...",
|
503 |
+
"role": "assistant"
|
504 |
+
},
|
505 |
+
{
|
506 |
+
"content": "Now consider the evaluation criteria and...",
|
507 |
+
"role": "user"
|
508 |
+
}
|
509 |
+
]
|
510 |
+
},
|
511 |
+
"answer_relevance_option_selection_completion": "Could be Improved",
|
512 |
+
"answer_relevance_positional_bias_option_selection_completion": "Could be Improved",
|
513 |
+
"answer_relevance_criteria": "{ \"__type__\": \"criteria_with_options..."
|
514 |
+
}
|
515 |
+
]
|
516 |
+
"""
|
517 |
+
logger.info(
|
518 |
f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}"'
|
519 |
)
|
520 |
evaluations_count = len(predictions)
|
521 |
# TODO: find out how to serialize and deserialize enums
|
522 |
+
criterias = self.get_criteria(task_data, evaluations_count)
|
523 |
+
self.__set_main_score(criterias)
|
524 |
contexts = self.get_contexts(task_data)
|
525 |
if self.check_positional_bias:
|
526 |
criterias += [
|
|
|
536 |
predictions += predictions
|
537 |
|
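The duplication above is what powers the positional-bias check: every instance is judged a second time with the option order reversed. A conceptual sketch of how the resulting flag could be derived (not the exact code path):

```python
# Positional bias == the judge's choice changes when only the option order changes.
def flag_positional_bias(selected_option: str, selected_with_reversed_options: str) -> bool:
    return selected_option != selected_with_reversed_options

assert flag_positional_bias("Excellent", "Excellent") is False
assert flag_positional_bias("Excellent", "Bad") is True
```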
538 |
parsed_criterias = [
|
539 |
+
self.__get_parsed_criteria(criteria) for criteria in criterias
|
540 |
]
|
541 |
|
542 |
(
|
543 |
criteria_description_list,
|
544 |
criteria_option_names_list,
|
545 |
display_options_instruction_list,
|
|
|
546 |
) = zip(*parsed_criterias)
|
547 |
|
548 |
assessment_for_summaries_slice = slice(0, evaluations_count)
|
|
|
565 |
assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
|
566 |
assessment_instances, self.assessment_task, self.assessment_template
|
567 |
)
|
568 |
+
logger.info("The assessment was generated successfully.")
|
569 |
|
570 |
summarization_prompts = None
|
571 |
summarization_outputs = None
|
|
|
589 |
self.summarization_task,
|
590 |
self.summarization_template,
|
591 |
)
|
592 |
+
logger.info("The summary was generated successfully.")
|
593 |
|
594 |
option_selection_instances = [
|
595 |
{
|
596 |
"criteria_description": criteria_description,
|
597 |
+
"display_options_instruction": display_options_instruction,
|
598 |
"options": criteria_option_names,
|
599 |
"data_classification_policy": ["public"],
|
600 |
}
|
601 |
+
for (
|
602 |
+
criteria_description,
|
603 |
+
display_options_instruction,
|
604 |
+
criteria_option_names
|
605 |
+
) in zip(
|
606 |
criteria_description_list,
|
607 |
+
display_options_instruction_list,
|
608 |
criteria_option_names_list,
|
609 |
)
|
610 |
]
|
|
|
625 |
self.option_selection_template,
|
626 |
previous_messages,
|
627 |
)
|
628 |
+
logger.info("The selections were calculated successfully.")
|
629 |
|
630 |
+
results = self.__get_results(
|
631 |
assessment_prompts,
|
632 |
assessment_outputs,
|
633 |
summarization_prompts,
|
|
|
638 |
evaluations_count,
|
639 |
criterias,
|
640 |
)
|
641 |
+
|
642 |
return self.clean_results(results)
|
643 |
|
644 |
|
645 |
class LLMJudgePairwise(LLMJudge):
|
646 |
+
"""A judge for pairwise comparison evaluations, where two or more responses are compared to determine which one is preferred based on a criterion."""
|
647 |
main_score = "1_winrate"
|
648 |
+
"""The main score metric for pairwise evaluation. By default, its value is `1_winrate`, and will take the value of the winrate of the first system."""
|
649 |
+
reduction_map = {"mean": ["score"]}
|
650 |
+
"""A mapping specifying how scores should be reduced. By default, it will be `{'main': ['score']}`"""
|
651 |
|
652 |
def prepare(self):
|
653 |
+
"""Prepares the pairwise comparison by initializing the necessary templates and tasks. These tasks will be used to assess, summarize, and select options from candidate responses."""
|
654 |
super().prepare()
|
655 |
self.assessment_template = pairwise_template_dict["assessment"]
|
656 |
self.summarization_template = pairwise_template_dict["summarization"]
|
|
|
689 |
)
|
690 |
|
691 |
def before_process_multi_stream(self):
|
692 |
+
"""Verifies that the criteria is of the correct type before processing the multi-stream data."""
|
693 |
super().before_process_multi_stream()
|
694 |
if self.criteria is not None and not isinstance(self.criteria, Criteria):
|
695 |
raise Exception(
|
|
|
697 |
)
|
698 |
return
|
699 |
|
700 |
+
def __get_instance_results(
|
701 |
self,
|
702 |
instance_predictions: Dict[str, str],
|
703 |
assessment_prompts,
|
|
|
709 |
selections,
|
710 |
contests_count,
|
711 |
combination_indexes,
|
712 |
+
criterion: Criteria,
|
713 |
):
|
714 |
+
"""Computes the results for each instance by comparing the responses and calculating metrics such as winrate, ranking, and the responses overall performance. This method processes assessment, summarization, and option selection outputs to track contest results, positional bias, and winrate.
|
715 |
+
|
716 |
+
Args:
|
717 |
+
instance_predictions (Dict[str, str]): The predictions for each response.
|
718 |
+
assessment_prompts (List[str]): The prompts for the assessment task.
|
719 |
+
assessment_outputs (List[str]): The results from the assessment task.
|
720 |
+
summarization_prompts (List[str]): The prompts for the summarization task.
|
721 |
+
summarization_outputs (List[str]): The results from the summarization task.
|
722 |
+
option_selection_prompts (List[str]): The prompts for the option selection task.
|
723 |
+
option_selection_outputs (List[str]): The results from the option selection task.
|
724 |
+
selections (List[str]): The selections made during the pairwise comparison.
|
725 |
+
contests_count (int): The total number of contests that were run.
|
726 |
+
combination_indexes (List[Tuple[int, int]]): The indexes of the response pairs that were compared.
|
727 |
+
criterion (Criteria): The criterion used to assess the responses.
|
728 |
+
|
729 |
+
Returns:
|
730 |
+
dict: A dictionary containing the results for each response, including winrate, ranking, and other metrics.
|
731 |
+
"""
|
732 |
response_names = list(instance_predictions.keys())
|
733 |
per_response_results = {
|
734 |
response_key: {
|
|
|
887 |
for metric in single_result.keys():
|
888 |
all_results[f"{response_name}_{metric}"] = single_result[metric]
|
889 |
|
890 |
+
all_results["criteria"] = criterion.to_json()
|
891 |
return self.clean_results(all_results)
|
892 |
|
893 |
+
def __parse_prediction_to_dict(self, predictions: Union[Dict[str, str], List[str]]):
|
894 |
+
"""Converts a list or dictionary of predictions into a dictionary format.
|
895 |
+
|
896 |
+
Args:
|
897 |
+
predictions (Union[Dict[str, str], List[str]]): The prediction data to convert.
|
898 |
+
|
899 |
+
Returns:
|
900 |
+
dict: The prediction data in dictionary format.
|
901 |
+
"""
|
902 |
+
if isinstance(predictions, list):
|
903 |
+
return {f"{key + 1}": value for key, value in enumerate(predictions)}
|
904 |
+
if isinstance(predictions, dict):
|
905 |
+
return predictions
|
906 |
+
raise UnitxtError(
|
907 |
+
f"Prediction may be a list or a dict. Instead got type {type(predictions)}"
|
908 |
)
|
909 |
|
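A quick illustration of the conversion described above (the response strings are made up):

```python
# List predictions are keyed by their 1-based position; dict predictions pass through.
as_dict = {f"{i + 1}": response for i, response in enumerate(["response A", "response B"])}
assert as_dict == {"1": "response A", "2": "response B"}
```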
910 |
+
def __convert_predictions_to_dicts(
|
911 |
self, predictions: Union[List[Dict[str, str]], List[str]]
|
912 |
):
|
913 |
+
"""Converts a list of predictions into a list of dictionaries.
|
914 |
+
|
915 |
+
Args:
|
916 |
+
predictions (Union[List[Dict[str, str]], List[str]]): The predictions to convert.
|
917 |
+
|
918 |
+
Returns:
|
919 |
+
List[dict]: A list of predictions in dictionary format.
|
920 |
+
"""
|
921 |
+
return [self.__parse_prediction_to_dict(prediction) for prediction in predictions]
|
922 |
+
|
923 |
+
def __set_main_score(self, predictions: List[Dict[str, str]]):
|
924 |
+
self.main_score = f"{next(iter(predictions[0].keys()))}_winrate"
|
925 |
|
926 |
def compute(
|
927 |
self,
|
928 |
references: List[List[str]],
|
929 |
predictions: List[str],
|
930 |
task_data: List[Dict[str, str]],
|
931 |
+
) -> List[Dict]:
|
932 |
+
r"""Executes the pairwise comparison evaluation, including assessment, summarization, and option selection. It computes the winrate and ranking for the responses.
|
933 |
+
|
934 |
+
Args:
|
935 |
+
references (List[List[str]]): A list of reference responses for comparison.
|
936 |
+
predictions (List[str]): A list of predicted responses.
|
937 |
+
task_data (List[Dict[str, str]]): Task data to be used for evaluation.
|
938 |
+
|
939 |
+
Returns:
|
940 |
+
-------
|
941 |
+
List[Dict[str,Dict]]
|
942 |
+
The results of the evaluation, including winrate, ranking, and other metrics.
|
943 |
+
|
944 |
+
For each instance result, the following metrics are included per response/system. Each metric name includes the system's name when predictions are provided as a list of dicts, or its index (starting from 1) when predictions are provided as a list of lists.
|
945 |
+
|
946 |
+
All the fields are arrays with length equal to `len(systems) - 1`. For any result at index `i`: `response_name[i]`'s contest against `compared_to[i]`'s result is `contest_results[i]`.
|
947 |
+
|
948 |
+
Explanation of fields:
|
949 |
+
|
950 |
+
- `summaries`: A list of LLM-generated summaries explaining the comparison results for each response.
|
951 |
+
- `contest_results`: A list of boolean values indicating whether the response won in each comparison.
|
952 |
+
- `selections`: A list of the selected system names, representing the preferred response in each comparison.
|
953 |
+
- `compared_to`: A list of system names that were compared against the given response.
|
954 |
+
- `assessments`: A list of LLM-generated assessments explaining the reasoning behind the evaluation results.
|
955 |
+
- `positional_bias_assessments`: A list of LLM-generated assessments focused on detecting positional bias in the evaluation.
|
956 |
+
- `option_selection_outputs`: A list of response names selected as the best choice based on the evaluation.
|
957 |
+
- `positional_bias`: A list of boolean values indicating whether positional bias was detected in the contest.
|
958 |
+
- `positional_bias_selection`: A list of response names representing the selected option when considering positional bias.
|
959 |
+
- `prompts`: A dictionary of prompts used in different stages of evaluation.
|
960 |
+
- `assessment`: The prompt used to instruct the model on how to assess the responses.
|
961 |
+
- `positional_bias_assessment`: The prompt used to instruct the model on how to assess positional bias.
|
962 |
+
- `option_selection`: The prompt used to guide the model in selecting the best response.
|
963 |
+
- `positional_bias_option_selection`: The prompt used for selecting the best response while checking for positional bias.
|
964 |
+
- `summary`: The prompt used to generate a summary of the assessment.
|
965 |
+
- `winrate`: A float representing the proportion of comparisons the response won.
|
966 |
+
- `llm_as_judge`: Equal to `winrate`.
|
967 |
+
- `ranking`: An integer representing the ranking position of the response based on the evaluation results. Best is 1.
|
968 |
+
- `response_name`: A string identifying the response in the evaluation.
|
969 |
+
|
970 |
+
Result example:
|
971 |
+
|
972 |
+
.. code-block:: python
|
973 |
+
|
974 |
+
[
|
975 |
+
{
|
976 |
+
"system1_contest_results": [
|
977 |
+
true,
|
978 |
+
true
|
979 |
+
],
|
980 |
+
"system1_selections": [
|
981 |
+
"system1",
|
982 |
+
"system1"
|
983 |
+
],
|
984 |
+
"system1_compared_to": [
|
985 |
+
"system2",
|
986 |
+
"system3"
|
987 |
+
],
|
988 |
+
"system1_assessments": [
|
989 |
+
"To determine the better response accordi...",
|
990 |
+
"To determine the better response accordi..."
|
991 |
+
],
|
992 |
+
"system1_positional_bias_assessments": [
|
993 |
+
"To determine the better response accordi...",
|
994 |
+
"To determine the better response accordi..."
|
995 |
+
],
|
996 |
+
"system1_option_selection_outputs": [
|
997 |
+
"system1",
|
998 |
+
"system1"
|
999 |
+
],
|
1000 |
+
"system1_positional_bias": [
|
1001 |
+
false,
|
1002 |
+
false
|
1003 |
+
],
|
1004 |
+
"system1_positional_bias_selection": [
|
1005 |
+
"system1",
|
1006 |
+
"system1"
|
1007 |
+
],
|
1008 |
+
"system1_prompts": {
|
1009 |
+
"assessment": [
|
1010 |
+
[
|
1011 |
+
{
|
1012 |
+
"role": "user",
|
1013 |
+
"content": "You are provided a pair of responses (Re..."
|
1014 |
+
}
|
1015 |
+
],
|
1016 |
+
[
|
1017 |
+
{
|
1018 |
+
"role": "user",
|
1019 |
+
"content": "You are provided a pair of responses (Re..."
|
1020 |
+
}
|
1021 |
+
]
|
1022 |
+
],
|
1023 |
+
"positional_bias_assessment": [
|
1024 |
+
[
|
1025 |
+
{
|
1026 |
+
"role": "user",
|
1027 |
+
"content": "You are provided a pair of responses (Re..."
|
1028 |
+
}
|
1029 |
+
],
|
1030 |
+
[
|
1031 |
+
{
|
1032 |
+
"role": "user",
|
1033 |
+
"content": "You are provided a pair of responses (Re..."
|
1034 |
+
}
|
1035 |
+
]
|
1036 |
+
],
|
1037 |
+
"option_selection": [
|
1038 |
+
[
|
1039 |
+
{
|
1040 |
+
"content": "You are provided a pair of responses (Re...",
|
1041 |
+
"role": "user"
|
1042 |
+
},
|
1043 |
+
{
|
1044 |
+
"content": "To determine the better response accordi...",
|
1045 |
+
"role": "assistant"
|
1046 |
+
},
|
1047 |
+
{
|
1048 |
+
"content": "Now considering the evaluation criteria,...",
|
1049 |
+
"role": "user"
|
1050 |
+
}
|
1051 |
+
],
|
1052 |
+
[
|
1053 |
+
{
|
1054 |
+
"content": "You are provided a pair of responses (Re...",
|
1055 |
+
"role": "user"
|
1056 |
+
},
|
1057 |
+
{
|
1058 |
+
"content": "To determine the better response accordi...",
|
1059 |
+
"role": "assistant"
|
1060 |
+
},
|
1061 |
+
{
|
1062 |
+
"content": "Now considering the evaluation criteria,...",
|
1063 |
+
"role": "user"
|
1064 |
+
}
|
1065 |
+
]
|
1066 |
+
],
|
1067 |
+
"positional_bias_option_selection": [
|
1068 |
+
[
|
1069 |
+
{
|
1070 |
+
"content": "You are provided a pair of responses (Re...",
|
1071 |
+
"role": "user"
|
1072 |
+
},
|
1073 |
+
{
|
1074 |
+
"content": "To determine the better response accordi...",
|
1075 |
+
"role": "assistant"
|
1076 |
+
},
|
1077 |
+
{
|
1078 |
+
"content": "Now considering the evaluation criteria,...",
|
1079 |
+
"role": "user"
|
1080 |
+
}
|
1081 |
+
],
|
1082 |
+
[
|
1083 |
+
{
|
1084 |
+
"content": "You are provided a pair of responses (Re...",
|
1085 |
+
"role": "user"
|
1086 |
+
},
|
1087 |
+
{
|
1088 |
+
"content": "To determine the better response accordi...",
|
1089 |
+
"role": "assistant"
|
1090 |
+
},
|
1091 |
+
{
|
1092 |
+
"content": "Now considering the evaluation criteria,...",
|
1093 |
+
"role": "user"
|
1094 |
+
}
|
1095 |
+
]
|
1096 |
+
]
|
1097 |
+
},
|
1098 |
+
"system1_winrate": 1.0,
|
1099 |
+
"system1_llm_as_judge": 1.0,
|
1100 |
+
"system1_ranking": 1,
|
1101 |
+
"system1_response_name": "system1",
|
1102 |
+
"system2_contest_results": [
|
1103 |
+
false,
|
1104 |
+
true
|
1105 |
+
],
|
1106 |
+
"system2_selections": [
|
1107 |
+
"system1",
|
1108 |
+
"system2"
|
1109 |
+
],
|
1110 |
+
"system2_compared_to": [
|
1111 |
+
"system1",
|
1112 |
+
"system3"
|
1113 |
+
],
|
1114 |
+
"system2_assessments": [
|
1115 |
+
"To determine the better response accordi...",
|
1116 |
+
"To determine the better response accordi..."
|
1117 |
+
],
|
1118 |
+
"system2_positional_bias_assessments": [
|
1119 |
+
"To determine the better response accordi...",
|
1120 |
+
"To determine the better response accordi..."
|
1121 |
+
],
|
1122 |
+
"system2_option_selection_outputs": [
|
1123 |
+
"system1",
|
1124 |
+
"system2"
|
1125 |
+
],
|
1126 |
+
"system2_positional_bias": [
|
1127 |
+
false,
|
1128 |
+
false
|
1129 |
+
],
|
1130 |
+
"system2_positional_bias_selection": [
|
1131 |
+
"system1",
|
1132 |
+
"system2"
|
1133 |
+
],
|
1134 |
+
"system2_prompts": {
|
1135 |
+
"assessment": [
|
1136 |
+
[
|
1137 |
+
{
|
1138 |
+
"role": "user",
|
1139 |
+
"content": "You are provided a pair of responses (Re..."
|
1140 |
+
}
|
1141 |
+
],
|
1142 |
+
[
|
1143 |
+
{
|
1144 |
+
"role": "user",
|
1145 |
+
"content": "You are provided a pair of responses (Re..."
|
1146 |
+
}
|
1147 |
+
]
|
1148 |
+
],
|
1149 |
+
"positional_bias_assessment": [
|
1150 |
+
[
|
1151 |
+
{
|
1152 |
+
"role": "user",
|
1153 |
+
"content": "You are provided a pair of responses (Re..."
|
1154 |
+
}
|
1155 |
+
],
|
1156 |
+
[
|
1157 |
+
{
|
1158 |
+
"role": "user",
|
1159 |
+
"content": "You are provided a pair of responses (Re..."
|
1160 |
+
}
|
1161 |
+
]
|
1162 |
+
],
|
1163 |
+
"option_selection": [
|
1164 |
+
[
|
1165 |
+
{
|
1166 |
+
"content": "You are provided a pair of responses (Re...",
|
1167 |
+
"role": "user"
|
1168 |
+
},
|
1169 |
+
{
|
1170 |
+
"content": "To determine the better response accordi...",
|
1171 |
+
"role": "assistant"
|
1172 |
+
},
|
1173 |
+
{
|
1174 |
+
"content": "Now considering the evaluation criteria,...",
|
1175 |
+
"role": "user"
|
1176 |
+
}
|
1177 |
+
],
|
1178 |
+
[
|
1179 |
+
{
|
1180 |
+
"content": "You are provided a pair of responses (Re...",
|
1181 |
+
"role": "user"
|
1182 |
+
},
|
1183 |
+
{
|
1184 |
+
"content": "To determine the better response accordi...",
|
1185 |
+
"role": "assistant"
|
1186 |
+
},
|
1187 |
+
{
|
1188 |
+
"content": "Now considering the evaluation criteria,...",
|
1189 |
+
"role": "user"
|
1190 |
+
}
|
1191 |
+
]
|
1192 |
+
],
|
1193 |
+
"positional_bias_option_selection": [
|
1194 |
+
[
|
1195 |
+
{
|
1196 |
+
"content": "You are provided a pair of responses (Re...",
|
1197 |
+
"role": "user"
|
1198 |
+
},
|
1199 |
+
{
|
1200 |
+
"content": "To determine the better response accordi...",
|
1201 |
+
"role": "assistant"
|
1202 |
+
},
|
1203 |
+
{
|
1204 |
+
"content": "Now considering the evaluation criteria,...",
|
1205 |
+
"role": "user"
|
1206 |
+
}
|
1207 |
+
],
|
1208 |
+
[
|
1209 |
+
{
|
1210 |
+
"content": "You are provided a pair of responses (Re...",
|
1211 |
+
"role": "user"
|
1212 |
+
},
|
1213 |
+
{
|
1214 |
+
"content": "To determine the better response accordi...",
|
1215 |
+
"role": "assistant"
|
1216 |
+
},
|
1217 |
+
{
|
1218 |
+
"content": "Now considering the evaluation criteria,...",
|
1219 |
+
"role": "user"
|
1220 |
+
}
|
1221 |
+
]
|
1222 |
+
]
|
1223 |
+
},
|
1224 |
+
"system2_winrate": 0.5,
|
1225 |
+
"system2_llm_as_judge": 0.5,
|
1226 |
+
"system2_ranking": 2,
|
1227 |
+
"system2_response_name": "system2",
|
1228 |
+
"system3_contest_results": [
|
1229 |
+
false,
|
1230 |
+
false
|
1231 |
+
],
|
1232 |
+
"system3_selections": [
|
1233 |
+
"system1",
|
1234 |
+
"system2"
|
1235 |
+
],
|
1236 |
+
"system3_compared_to": [
|
1237 |
+
"system1",
|
1238 |
+
"system2"
|
1239 |
+
],
|
1240 |
+
"system3_assessments": [
|
1241 |
+
"To determine the better response accordi...",
|
1242 |
+
"To determine the better response accordi..."
|
1243 |
+
],
|
1244 |
+
"system3_positional_bias_assessments": [
|
1245 |
+
"To determine the better response accordi...",
|
1246 |
+
"To determine the better response accordi..."
|
1247 |
+
],
|
1248 |
+
"system3_option_selection_outputs": [
|
1249 |
+
"system1",
|
1250 |
+
"system2"
|
1251 |
+
],
|
1252 |
+
"system3_positional_bias": [
|
1253 |
+
false,
|
1254 |
+
false
|
1255 |
+
],
|
1256 |
+
"system3_positional_bias_selection": [
|
1257 |
+
"system1",
|
1258 |
+
"system2"
|
1259 |
+
],
|
1260 |
+
"system3_prompts": {
|
1261 |
+
"assessment": [
|
1262 |
+
[
|
1263 |
+
{
|
1264 |
+
"role": "user",
|
1265 |
+
"content": "You are provided a pair of responses (Re..."
|
1266 |
+
}
|
1267 |
+
],
|
1268 |
+
[
|
1269 |
+
{
|
1270 |
+
"role": "user",
|
1271 |
+
"content": "You are provided a pair of responses (Re..."
|
1272 |
+
}
|
1273 |
+
]
|
1274 |
+
],
|
1275 |
+
"positional_bias_assessment": [
|
1276 |
+
[
|
1277 |
+
{
|
1278 |
+
"role": "user",
|
1279 |
+
"content": "You are provided a pair of responses (Re..."
|
1280 |
+
}
|
1281 |
+
],
|
1282 |
+
[
|
1283 |
+
{
|
1284 |
+
"role": "user",
|
1285 |
+
"content": "You are provided a pair of responses (Re..."
|
1286 |
+
}
|
1287 |
+
]
|
1288 |
+
],
|
1289 |
+
"option_selection": [
|
1290 |
+
[
|
1291 |
+
{
|
1292 |
+
"content": "You are provided a pair of responses (Re...",
|
1293 |
+
"role": "user"
|
1294 |
+
},
|
1295 |
+
{
|
1296 |
+
"content": "To determine the better response accordi...",
|
1297 |
+
"role": "assistant"
|
1298 |
+
},
|
1299 |
+
{
|
1300 |
+
"content": "Now considering the evaluation criteria,...",
|
1301 |
+
"role": "user"
|
1302 |
+
}
|
1303 |
+
],
|
1304 |
+
[
|
1305 |
+
{
|
1306 |
+
"content": "You are provided a pair of responses (Re...",
|
1307 |
+
"role": "user"
|
1308 |
+
},
|
1309 |
+
{
|
1310 |
+
"content": "To determine the better response accordi...",
|
1311 |
+
"role": "assistant"
|
1312 |
+
},
|
1313 |
+
{
|
1314 |
+
"content": "Now considering the evaluation criteria,...",
|
1315 |
+
"role": "user"
|
1316 |
+
}
|
1317 |
+
]
|
1318 |
+
],
|
1319 |
+
"positional_bias_option_selection": [
|
1320 |
+
[
|
1321 |
+
{
|
1322 |
+
"content": "You are provided a pair of responses (Re...",
|
1323 |
+
"role": "user"
|
1324 |
+
},
|
1325 |
+
{
|
1326 |
+
"content": "To determine the better response accordi...",
|
1327 |
+
"role": "assistant"
|
1328 |
+
},
|
1329 |
+
{
|
1330 |
+
"content": "Now considering the evaluation criteria,...",
|
1331 |
+
"role": "user"
|
1332 |
+
}
|
1333 |
+
],
|
1334 |
+
[
|
1335 |
+
{
|
1336 |
+
"content": "You are provided a pair of responses (Re...",
|
1337 |
+
"role": "user"
|
1338 |
+
},
|
1339 |
+
{
|
1340 |
+
"content": "To determine the better response accordi...",
|
1341 |
+
"role": "assistant"
|
1342 |
+
},
|
1343 |
+
{
|
1344 |
+
"content": "Now considering the evaluation criteria,...",
|
1345 |
+
"role": "user"
|
1346 |
+
}
|
1347 |
+
]
|
1348 |
+
]
|
1349 |
+
},
|
1350 |
+
"system3_winrate": 0.0,
|
1351 |
+
"system3_llm_as_judge": 0.0,
|
1352 |
+
"system3_ranking": 3,
|
1353 |
+
"system3_response_name": "system3",
|
1354 |
+
"criteria": "{ \"__type__\": \"criteria\", \"name\"..."
|
1355 |
+
}
|
1356 |
+
]
|
1357 |
+
"""
|
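To make the winrate and ranking fields above concrete, here is a hedged sketch of that bookkeeping. It mirrors the documented fields (winrate = fraction of contests won, ranking ordered by winrate with 1 as best) rather than the exact implementation.

```python
from typing import Dict, List

def winrates(contest_results: Dict[str, List[bool]]) -> Dict[str, float]:
    # fraction of pairwise contests each system won
    return {
        name: sum(results) / len(results) if results else 0.0
        for name, results in contest_results.items()
    }

def rankings(rates: Dict[str, float]) -> Dict[str, int]:
    # best winrate gets rank 1
    ordered = sorted(rates, key=rates.get, reverse=True)
    return {name: position + 1 for position, name in enumerate(ordered)}

rates = winrates({"system1": [True, True], "system2": [False, True], "system3": [False, False]})
assert rankings(rates) == {"system1": 1, "system2": 2, "system3": 3}
```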
1358 |
+
logger.info(
|
1359 |
f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
|
1360 |
)
|
1361 |
+
predictions = self.__convert_predictions_to_dicts(predictions)
|
1362 |
+
self.__set_main_score(predictions)
|
1363 |
instances_count = len(predictions)
|
1364 |
self.reduction_map = {"mean": ["score"]}
|
1365 |
self.reduction_map["mean"].extend(
|
|
|
1375 |
len(combination_indexes) for combination_indexes in combination_indexes_list
|
1376 |
]
|
1377 |
|
1378 |
+
logger.info(
|
1379 |
f"The evaluation will perform {sum(contests_count_list) * [1, 2][self.check_positional_bias]} ({' + '.join([f'{c * [1, 2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
|
1380 |
)
|
1381 |
|
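The log message above reports the total number of judge calls: the sum over instances of C(n_responses, 2) pairwise contests, doubled when the positional-bias check is on. A small sketch of that arithmetic:

```python
from itertools import combinations

def contests_for(n_responses: int, check_positional_bias: bool) -> int:
    pairs = len(list(combinations(range(n_responses), 2)))  # C(n, 2)
    return pairs * (2 if check_positional_bias else 1)

assert contests_for(3, check_positional_bias=False) == 3
assert contests_for(3, check_positional_bias=True) == 6
```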
|
|
1406 |
response_pairs_list.append(response_pairs)
|
1407 |
option_pairs_list.append(option_pairs)
|
1408 |
|
1409 |
+
criterias = self.get_criteria(task_data, instances_count)
|
1410 |
contexts = self.get_contexts(task_data)
|
1411 |
if self.check_positional_bias:
|
1412 |
criterias.extend(criterias)
|
|
|
1440 |
assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
|
1441 |
assessment_instances, self.assessment_task, self.assessment_template
|
1442 |
)
|
1443 |
+
logger.info("The assessment was generated successfully.")
|
1444 |
|
1445 |
# the slices used to get the assessment for each summary generation instance
|
1446 |
# it will grab the whole assessment for a particular instance or half of it depending on the value of check_positional_bias
|
|
|
1490 |
self.summarization_task,
|
1491 |
self.summarization_template,
|
1492 |
)
|
1493 |
+
logger.info("The summary was generated successfully.")
|
1494 |
|
1495 |
score_option_instruction_list = [
|
1496 |
"".join(
|
|
|
1538 |
)
|
1539 |
# Selections are of the form 'Response n', so we just keep n
|
1540 |
selections = [selection.split(" ")[-1] for selection in selections]
|
1541 |
+
logger.info("The selections were calculated successfully.")
|
1542 |
results = []
|
1543 |
slice_start = 0
|
1544 |
for i, incremental_contests_count in enumerate(incremental_contests_count_list):
|
|
|
1551 |
(incremental_contests_count_list[i - 1] if i > 0 else 0)
|
1552 |
+ incremental_contests_count,
|
1553 |
)
|
1554 |
+
instance_results = self.__get_instance_results(
|
1555 |
predictions[i],
|
1556 |
assessment_prompts[sli],
|
1557 |
assessment_outputs[sli],
|
llm_as_judge_chat_templates.py
CHANGED
@@ -29,11 +29,13 @@ Assessment: {assessment}
|
|
29 |
Summary:"""
|
30 |
),
|
31 |
"answer": InputOutputTemplate(
|
32 |
-
input_format="""Now
|
|
|
33 |
###Evaluation criteria:
|
34 |
{criteria_description}
|
35 |
-
{
|
36 |
-
|
|
|
37 |
postprocessors=["processors.match_closest_option"],
|
38 |
),
|
39 |
}
|
|
|
29 |
Summary:"""
|
30 |
),
|
31 |
"answer": InputOutputTemplate(
|
32 |
+
input_format="""Now based on the assessment, choose a criteria option. Only include the chosen option in the response. If the assessment already contains a selected option, choose that option. Don't contradict the assessment's selected option.
|
33 |
+
|
34 |
###Evaluation criteria:
|
35 |
{criteria_description}
|
36 |
+
{display_options_instruction}
|
37 |
+
|
38 |
+
The selected criteria option is: """,
|
39 |
postprocessors=["processors.match_closest_option"],
|
40 |
),
|
41 |
}
|
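The `answer` template above delegates the final parsing to `processors.match_closest_option`. As a rough, illustrative stand-in for that kind of matching (the real processor's implementation may differ), a difflib-based matcher could look like this:

```python
from difflib import get_close_matches

def match_closest_option(completion: str, options: list) -> str:
    # return the option most similar to the model's completion
    matches = get_close_matches(completion.strip(), options, n=1, cutoff=0.0)
    return matches[0] if matches else completion

assert match_closest_option("could be improved.", ["Excellent", "Could be Improved", "Bad"]) == "Could be Improved"
```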
llm_as_judge_constants.py
CHANGED
@@ -3,10 +3,6 @@ from enum import Enum
|
|
3 |
from typing import Dict, List, Optional
|
4 |
|
5 |
from .artifact import Artifact
|
6 |
-
from .inference import (
|
7 |
-
LiteLLMInferenceEngine,
|
8 |
-
RITSInferenceEngine,
|
9 |
-
)
|
10 |
|
11 |
|
12 |
class OptionSelectionStrategyEnum(str, Enum):
|
@@ -68,13 +64,13 @@ class EvaluatorTypeEnum(str, Enum):
|
|
68 |
|
69 |
class EvaluatorNameEnum(str, Enum):
|
70 |
MIXTRAL8_7b = "Mixtral8-7b"
|
71 |
-
MIXTRAL8_22b = "Mixtral8-22b"
|
72 |
MIXTRAL_LARGE = "Mixtral Large"
|
73 |
LLAMA3_8B = "Llama3-8b"
|
74 |
LLAMA3_1_405B = "Llama3.1-405b"
|
75 |
LLAMA3_1_8B = "Llama3.1-8b"
|
76 |
LLAMA3_1_70B = "Llama3.1-70b"
|
77 |
LLAMA3_2_3B = "Llama3.2-3b"
|
|
|
78 |
PROMETHEUS = "Prometheus"
|
79 |
GPT4 = "GPT-4o"
|
80 |
O1_PREVIEW = "o1-Preview"
|
@@ -84,53 +80,33 @@ class EvaluatorNameEnum(str, Enum):
|
|
84 |
GRANITE3_8B = "Granite3.0-8b"
|
85 |
GRANITE3_1_2B = "Granite3.1-2b"
|
86 |
GRANITE3_1_8B = "Granite3.1-8b"
|
|
|
87 |
|
88 |
|
89 |
class ModelProviderEnum(str, Enum):
|
90 |
WATSONX = "watsonx"
|
91 |
OPENAI = "openai"
|
92 |
RITS = "rits"
|
93 |
-
AZURE_OPENAI = "
|
94 |
|
95 |
|
96 |
EVALUATOR_TO_MODEL_ID = {
|
97 |
-
EvaluatorNameEnum.MIXTRAL8_7b: "
|
98 |
-
EvaluatorNameEnum.
|
99 |
-
EvaluatorNameEnum.
|
100 |
-
EvaluatorNameEnum.
|
101 |
-
EvaluatorNameEnum.
|
102 |
-
EvaluatorNameEnum.
|
103 |
-
EvaluatorNameEnum.LLAMA3_2_3B: "meta-llama/llama-3-2-3b-instruct",
|
104 |
-
EvaluatorNameEnum.PROMETHEUS: "kaist-ai/prometheus-8x7b-v2",
|
105 |
EvaluatorNameEnum.GPT4: "gpt-4o-2024-08-06",
|
106 |
-
EvaluatorNameEnum.O1_PREVIEW: "o1-preview
|
107 |
-
EvaluatorNameEnum.O1_MINI: "o1-mini
|
108 |
-
EvaluatorNameEnum.
|
109 |
-
EvaluatorNameEnum.
|
110 |
-
EvaluatorNameEnum.
|
111 |
-
EvaluatorNameEnum.
|
112 |
-
EvaluatorNameEnum.
|
113 |
}
|
114 |
|
115 |
-
MODEL_RENAMINGS = {
|
116 |
-
ModelProviderEnum.RITS: {
|
117 |
-
"meta-llama/llama-3-1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
118 |
-
"mistralai/mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
|
119 |
-
"ibm/granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
|
120 |
-
"ibm/granite-3.1-8b-instruct": "ibm-granite/granite-3.1-8b-instruct",
|
121 |
-
"meta-llama/llama-3-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
|
122 |
-
"mistralai/mistral-large": "mistralai/mistral-large-instruct-2407",
|
123 |
-
},
|
124 |
-
}
|
125 |
-
|
126 |
-
INFERENCE_ENGINE_NAME_TO_CLASS = {
|
127 |
-
ModelProviderEnum.WATSONX: LiteLLMInferenceEngine,
|
128 |
-
ModelProviderEnum.OPENAI: LiteLLMInferenceEngine,
|
129 |
-
ModelProviderEnum.RITS: RITSInferenceEngine,
|
130 |
-
ModelProviderEnum.AZURE_OPENAI: LiteLLMInferenceEngine,
|
131 |
-
}
|
132 |
-
|
133 |
-
|
134 |
class EvaluatorMetadata:
|
135 |
name: EvaluatorNameEnum
|
136 |
providers: List[ModelProviderEnum]
|
@@ -145,10 +121,6 @@ EVALUATORS_METADATA = [
|
|
145 |
EvaluatorNameEnum.MIXTRAL8_7b,
|
146 |
[ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
|
147 |
),
|
148 |
-
EvaluatorMetadata(
|
149 |
-
EvaluatorNameEnum.MIXTRAL8_22b,
|
150 |
-
[ModelProviderEnum.RITS],
|
151 |
-
),
|
152 |
EvaluatorMetadata(
|
153 |
EvaluatorNameEnum.MIXTRAL_LARGE,
|
154 |
[ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
|
@@ -161,6 +133,10 @@ EVALUATORS_METADATA = [
|
|
161 |
EvaluatorNameEnum.GRANITE3_1_8B,
|
162 |
[ModelProviderEnum.RITS],
|
163 |
),
|
|
|
|
|
|
|
|
|
164 |
EvaluatorMetadata(
|
165 |
EvaluatorNameEnum.GPT4,
|
166 |
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE_OPENAI],
|
@@ -185,6 +161,10 @@ EVALUATORS_METADATA = [
|
|
185 |
EvaluatorNameEnum.LLAMA3_1_405B,
|
186 |
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
187 |
),
|
|
|
|
|
|
|
|
|
188 |
]
|
189 |
|
190 |
################################ Direct Assessment Criterias ################################
|
|
|
3 |
from typing import Dict, List, Optional
|
4 |
|
5 |
from .artifact import Artifact
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
class OptionSelectionStrategyEnum(str, Enum):
|
|
|
64 |
|
65 |
class EvaluatorNameEnum(str, Enum):
|
66 |
MIXTRAL8_7b = "Mixtral8-7b"
|
|
|
67 |
MIXTRAL_LARGE = "Mixtral Large"
|
68 |
LLAMA3_8B = "Llama3-8b"
|
69 |
LLAMA3_1_405B = "Llama3.1-405b"
|
70 |
LLAMA3_1_8B = "Llama3.1-8b"
|
71 |
LLAMA3_1_70B = "Llama3.1-70b"
|
72 |
LLAMA3_2_3B = "Llama3.2-3b"
|
73 |
+
LLAMA3_3_70B = "Llama3.3-70b"
|
74 |
PROMETHEUS = "Prometheus"
|
75 |
GPT4 = "GPT-4o"
|
76 |
O1_PREVIEW = "o1-Preview"
|
|
|
80 |
GRANITE3_8B = "Granite3.0-8b"
|
81 |
GRANITE3_1_2B = "Granite3.1-2b"
|
82 |
GRANITE3_1_8B = "Granite3.1-8b"
|
83 |
+
GRANITE3_2_8B = "Granite3.2-8b"
|
84 |
|
85 |
|
86 |
class ModelProviderEnum(str, Enum):
|
87 |
WATSONX = "watsonx"
|
88 |
OPENAI = "openai"
|
89 |
RITS = "rits"
|
90 |
+
AZURE_OPENAI = "azure"
|
91 |
|
92 |
|
93 |
EVALUATOR_TO_MODEL_ID = {
|
94 |
+
EvaluatorNameEnum.MIXTRAL8_7b: "mixtral-8x7b-instruct-v01",
|
95 |
+
EvaluatorNameEnum.MIXTRAL_LARGE: "mistral-large-instruct",
|
96 |
+
EvaluatorNameEnum.LLAMA3_1_405B: "llama-3-1-405b-instruct",
|
97 |
+
EvaluatorNameEnum.LLAMA3_1_8B: "llama-3-1-70b-instruct",
|
98 |
+
EvaluatorNameEnum.LLAMA3_1_70B: "llama-3-1-70b-instruct",
|
99 |
+
EvaluatorNameEnum.LLAMA3_3_70B: "llama-3-3-70b-instruct",
|
|
|
|
|
100 |
EvaluatorNameEnum.GPT4: "gpt-4o-2024-08-06",
|
101 |
+
EvaluatorNameEnum.O1_PREVIEW: "o1-preview",
|
102 |
+
EvaluatorNameEnum.O1_MINI: "o1-mini",
|
103 |
+
EvaluatorNameEnum.GRANITE3_2B: "granite-3-2b-instruct",
|
104 |
+
EvaluatorNameEnum.GRANITE3_8B: "granite-3-8b-instruct",
|
105 |
+
EvaluatorNameEnum.GRANITE3_1_2B: "granite-3-1-2b-instruct",
|
106 |
+
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
107 |
+
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
108 |
}
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
class EvaluatorMetadata:
|
111 |
name: EvaluatorNameEnum
|
112 |
providers: List[ModelProviderEnum]
|
|
|
121 |
EvaluatorNameEnum.MIXTRAL8_7b,
|
122 |
[ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
|
123 |
),
|
|
|
|
|
|
|
|
|
124 |
EvaluatorMetadata(
|
125 |
EvaluatorNameEnum.MIXTRAL_LARGE,
|
126 |
[ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
|
|
|
133 |
EvaluatorNameEnum.GRANITE3_1_8B,
|
134 |
[ModelProviderEnum.RITS],
|
135 |
),
|
136 |
+
EvaluatorMetadata(
|
137 |
+
EvaluatorNameEnum.GRANITE3_2_8B,
|
138 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
139 |
+
),
|
140 |
EvaluatorMetadata(
|
141 |
EvaluatorNameEnum.GPT4,
|
142 |
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE_OPENAI],
|
|
|
161 |
EvaluatorNameEnum.LLAMA3_1_405B,
|
162 |
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
163 |
),
|
164 |
+
EvaluatorMetadata(
|
165 |
+
EvaluatorNameEnum.LLAMA3_3_70B,
|
166 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
167 |
+
),
|
168 |
]
|
169 |
|
170 |
################################ Direct Assessment Criterias ################################
|
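Combining the tables above, resolving an evaluator name to a concrete model id while checking that the requested provider is supported might look like the sketch below. It mirrors the role these constants play elsewhere in the package, not any exact function in this change.

```python
from unitxt.llm_as_judge_constants import (
    EVALUATOR_TO_MODEL_ID,
    EVALUATORS_METADATA,
    EvaluatorNameEnum,
    ModelProviderEnum,
)

def resolve(name: EvaluatorNameEnum, provider: ModelProviderEnum) -> str:
    metadata = next(m for m in EVALUATORS_METADATA if m.name == name)
    if provider not in metadata.providers:
        raise ValueError(f"{name} is not served by {provider}")
    return EVALUATOR_TO_MODEL_ID[name]

# e.g. resolve(EvaluatorNameEnum.LLAMA3_3_70B, ModelProviderEnum.RITS) -> "llama-3-3-70b-instruct"
```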
llm_as_judge_utils.py
CHANGED
@@ -2,10 +2,8 @@ from typing import Dict
|
|
2 |
|
3 |
from .llm_as_judge_constants import (
|
4 |
EVALUATORS_METADATA,
|
5 |
-
MODEL_RENAMINGS,
|
6 |
EvaluatorMetadata,
|
7 |
EvaluatorNameEnum,
|
8 |
-
ModelProviderEnum,
|
9 |
)
|
10 |
|
11 |
|
@@ -32,13 +30,6 @@ def get_evaluator_metadata(
|
|
32 |
raise ValueError(f"An evaluator with id {name} matched several models.")
|
33 |
return evaluator_search[0]
|
34 |
|
35 |
-
|
36 |
-
def rename_model_if_required(model_name: str, provider: ModelProviderEnum) -> str:
|
37 |
-
if provider in MODEL_RENAMINGS and model_name in MODEL_RENAMINGS[provider]:
|
38 |
-
return MODEL_RENAMINGS[provider][model_name]
|
39 |
-
return model_name
|
40 |
-
|
41 |
-
|
42 |
def rank_indexes(numbers):
|
43 |
# Generate the initial list of indices
|
44 |
indices = list(range(len(numbers)))
|
|
|
2 |
|
3 |
from .llm_as_judge_constants import (
|
4 |
EVALUATORS_METADATA,
|
|
|
5 |
EvaluatorMetadata,
|
6 |
EvaluatorNameEnum,
|
|
|
7 |
)
|
8 |
|
9 |
|
|
|
30 |
raise ValueError(f"An evaluator with id {name} matched several models.")
|
31 |
return evaluator_search[0]
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
def rank_indexes(numbers):
|
34 |
# Generate the initial list of indices
|
35 |
indices = list(range(len(numbers)))
|
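`rank_indexes` (shown truncated above) assigns a rank to every position of `numbers`. A behaviourally plausible stand-in, assuming higher values rank first; the real implementation may break ties or order differently.

```python
def rank_indexes_sketch(numbers):
    # sort positions by value, best first, then write each position's rank back
    order = sorted(range(len(numbers)), key=lambda i: numbers[i], reverse=True)
    ranks = [0] * len(numbers)
    for rank, index in enumerate(order):
        ranks[index] = rank
    return ranks

assert rank_indexes_sketch([0.5, 1.0, 0.0]) == [1, 0, 2]
```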
loaders.py
CHANGED
@@ -67,7 +67,7 @@ from huggingface_hub import HfApi
|
|
67 |
from tqdm import tqdm
|
68 |
|
69 |
from .dataclass import NonPositionalField
|
70 |
-
from .error_utils import UnitxtError, UnitxtWarning
|
71 |
from .fusion import FixedFusion
|
72 |
from .logging_utils import get_logger
|
73 |
from .operator import SourceOperator
|
@@ -80,19 +80,27 @@ from .utils import LRUCache, recursive_copy
|
|
80 |
logger = get_logger()
|
81 |
settings = get_settings()
|
82 |
|
|
|
|
|
|
|
|
|
83 |
def hf_load_dataset(path: str, *args, **kwargs):
|
84 |
if settings.hf_offline_datasets_path is not None:
|
85 |
path = os.path.join(settings.hf_offline_datasets_path, path)
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
96 |
|
97 |
class Loader(SourceOperator):
|
98 |
"""A base class for all loaders.
|
@@ -288,26 +296,21 @@ class LoadHF(LazyLoader):
|
|
288 |
if dataset is None:
|
289 |
if streaming is None:
|
290 |
streaming = self.is_streaming()
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
except ValueError as e:
|
303 |
-
if "trust_remote_code" in str(e):
|
304 |
-
raise ValueError(
|
305 |
-
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
|
306 |
-
) from e
|
307 |
self.__class__._loader_cache.max_size = settings.loader_cache_size
|
308 |
if not disable_memory_caching:
|
309 |
self.__class__._loader_cache[dataset_id] = dataset
|
310 |
-
return
|
311 |
|
312 |
def _maybe_set_classification_policy(self):
|
313 |
if os.path.exists(self.path):
|
@@ -333,7 +336,9 @@ class LoadHF(LazyLoader):
|
|
333 |
extract_on_the_fly=True,
|
334 |
),
|
335 |
)
|
336 |
-
except:
|
|
|
|
|
337 |
UnitxtWarning(
|
338 |
f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
|
339 |
)
|
@@ -599,11 +604,11 @@ class LoadFromIBMCloud(Loader):
|
|
599 |
load_ibm_cloud = LoadFromIBMCloud(
|
600 |
endpoint_url_env='IBM_CLOUD_ENDPOINT',
|
601 |
aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
|
602 |
-
aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY',
|
603 |
bucket_name='my-bucket'
|
604 |
)
|
605 |
multi_stream = load_ibm_cloud.process()
|
606 |
-
"""
|
607 |
|
608 |
endpoint_url_env: str
|
609 |
aws_access_key_id_env: str
|
|
|
67 |
from tqdm import tqdm
|
68 |
|
69 |
from .dataclass import NonPositionalField
|
70 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
71 |
from .fusion import FixedFusion
|
72 |
from .logging_utils import get_logger
|
73 |
from .operator import SourceOperator
|
|
|
80 |
logger = get_logger()
|
81 |
settings = get_settings()
|
82 |
|
83 |
+
class UnitxtUnverifiedCodeError(UnitxtError):
|
84 |
+
def __init__(self, path):
|
85 |
+
super().__init__(f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE.", Documentation.SETTINGS)
|
86 |
+
|
87 |
def hf_load_dataset(path: str, *args, **kwargs):
|
88 |
if settings.hf_offline_datasets_path is not None:
|
89 |
path = os.path.join(settings.hf_offline_datasets_path, path)
|
90 |
+
try:
|
91 |
+
return _hf_load_dataset(
|
92 |
+
path,
|
93 |
+
*args, **kwargs,
|
94 |
+
download_config=DownloadConfig(
|
95 |
+
max_retries=settings.loaders_max_retries,
|
96 |
+
),
|
97 |
+
verification_mode="no_checks",
|
98 |
+
trust_remote_code=settings.allow_unverified_code,
|
99 |
+
download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
|
100 |
+
)
|
101 |
+
except ValueError as e:
|
102 |
+
if "trust_remote_code" in str(e):
|
103 |
+
raise UnitxtUnverifiedCodeError(path) from e
|
104 |
|
105 |
class Loader(SourceOperator):
|
106 |
"""A base class for all loaders.
|
|
|
296 |
if dataset is None:
|
297 |
if streaming is None:
|
298 |
streaming = self.is_streaming()
|
299 |
+
|
300 |
+
dataset = hf_load_dataset(
|
301 |
+
self.path,
|
302 |
+
name=self.name,
|
303 |
+
data_dir=self.data_dir,
|
304 |
+
data_files=self.data_files,
|
305 |
+
revision=self.revision,
|
306 |
+
streaming=streaming,
|
307 |
+
split=split,
|
308 |
+
num_proc=self.num_proc,
|
309 |
+
)
|
|
|
|
|
|
|
|
|
|
|
310 |
self.__class__._loader_cache.max_size = settings.loader_cache_size
|
311 |
if not disable_memory_caching:
|
312 |
self.__class__._loader_cache[dataset_id] = dataset
|
313 |
+
return dataset
|
314 |
|
315 |
def _maybe_set_classification_policy(self):
|
316 |
if os.path.exists(self.path):
|
|
|
336 |
extract_on_the_fly=True,
|
337 |
),
|
338 |
)
|
339 |
+
except Exception as e:
|
340 |
+
if "trust_remote_code" in str(e):
|
341 |
+
raise UnitxtUnverifiedCodeError(self.path) from e
|
342 |
UnitxtWarning(
|
343 |
f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
|
344 |
)
|
|
|
604 |
load_ibm_cloud = LoadFromIBMCloud(
|
605 |
endpoint_url_env='IBM_CLOUD_ENDPOINT',
|
606 |
aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
|
607 |
+
aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY', # pragma: allowlist secret
|
608 |
bucket_name='my-bucket'
|
609 |
)
|
610 |
multi_stream = load_ibm_cloud.process()
|
611 |
+
"""
|
612 |
|
613 |
endpoint_url_env: str
|
614 |
aws_access_key_id_env: str
|
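The new `hf_load_dataset` wrapper above centralizes the retry configuration and the remote-code check. A short usage sketch of the settings it reads; the dataset path is a placeholder.

```python
from unitxt.settings_utils import get_settings

settings = get_settings()
settings.allow_unverified_code = True   # or export UNITXT_ALLOW_UNVERIFIED_CODE=True
settings.loaders_max_retries = 3        # retries wired into DownloadConfig above
# hf_load_dataset("some/remote-code-dataset", split="test")  # hypothetical path
```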
metrics.py
CHANGED
@@ -63,13 +63,10 @@ from .operator import (
|
|
63 |
from .operators import ArtifactFetcherMixin, Copy, Set
|
64 |
from .random_utils import get_seed
|
65 |
from .settings_utils import get_settings
|
66 |
-
from .sql_utils import get_db_connector
|
67 |
from .stream import MultiStream, Stream
|
68 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
69 |
from .utils import deep_copy, recursive_copy
|
70 |
|
71 |
-
FINQA_HASH = "42430b8613082bb4b85d49210284135d"
|
72 |
-
|
73 |
logger = get_logger()
|
74 |
settings = get_settings()
|
75 |
|
@@ -127,13 +124,18 @@ def nan_mean(x):
|
|
127 |
|
128 |
def nan_max(x):
|
129 |
with warnings.catch_warnings():
|
130 |
-
# final mean should be mean of scores, ignoring NaN, hence nanmax
|
131 |
-
# but if the group function values is NaN for ALL values, nanmean throws a
|
132 |
-
# RuntimeWarning that it is calculating the mean of an empty slice (with no non-Nans)
|
133 |
-
# this is the desired behavior, but we want to avoid the warning here
|
134 |
warnings.simplefilter("ignore", category=RuntimeWarning)
|
135 |
return np.nanmax(x)
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
class UpdateStream(InstanceOperator):
|
139 |
update: dict
|
@@ -365,6 +367,43 @@ def new_random_generator():
|
|
365 |
return np.random.default_rng(hash(get_seed()) & _max_32bit)
|
366 |
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
class ConfidenceIntervalMixin(Artifact):
|
369 |
n_resamples: int = 1000
|
370 |
confidence_level: float = 0.95
|
@@ -374,42 +413,41 @@ class ConfidenceIntervalMixin(Artifact):
|
|
374 |
def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
|
375 |
pass
|
376 |
|
377 |
-
def get_statistic(self, data: List[Any], score_names: List[str]):
|
378 |
-
def statistic_function(indices, axis=0):
|
379 |
-
# indices might be a 1D or 2D array, depending on bootstrap internals
|
380 |
-
# For simplicity, ensure we handle them as 1D.
|
381 |
-
indices = np.atleast_1d(indices).astype(int)
|
382 |
-
|
383 |
-
# Gather the subset
|
384 |
-
sample = [data[i] for i in indices]
|
385 |
-
|
386 |
-
# Compute metrics on this sample
|
387 |
-
scores = self._sample_to_scores(sample)
|
388 |
-
|
389 |
-
# Return them in consistent order
|
390 |
-
return np.array([scores[m] for m in score_names])
|
391 |
-
|
392 |
-
return statistic_function
|
393 |
|
394 |
def bootstrap(self, data: List[Any], score_names: List[str]):
|
395 |
if self.ci_score_names is not None:
|
396 |
score_names = self.ci_score_names
|
397 |
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
408 |
|
409 |
result = {}
|
410 |
for i, metric in enumerate(score_names):
|
411 |
-
|
412 |
-
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
return result
|
415 |
|
@@ -2769,7 +2807,7 @@ class FinQAEval(InstanceMetric):
|
|
2769 |
remote_url = "https://raw.githubusercontent.com/czyssrs/FinQA/dfc5b72c01ee17c442d28d5201b82a1f4e95d5af/code/evaluate/evaluate.py"
|
2770 |
local_filepath = "/tmp/finqa_eval_script.py"
|
2771 |
module_name = "finqa_eval"
|
2772 |
-
hash_of_script =
|
2773 |
|
2774 |
download_finqa_eval_script_file(remote_url, local_filepath, hash_of_script)
|
2775 |
self.finqa_module = load_finqa_eval_module_from_file(
|
@@ -3375,25 +3413,83 @@ class CustomF1(GlobalMetric):
|
|
3375 |
result["precision_macro"] = self.zero_division
|
3376 |
|
3377 |
|
3378 |
-
class
|
3379 |
-
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
|
3380 |
|
3381 |
-
prediction_type =
|
|
|
|
|
|
|
|
|
|
|
|
|
3382 |
|
3383 |
-
def
|
3384 |
-
|
|
|
|
|
|
|
|
|
|
|
3385 |
|
3386 |
-
|
3387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3388 |
|
|
|
3389 |
|
3390 |
-
class
|
3391 |
-
"""F1 Metrics that receives as input a list of (
|
3392 |
|
3393 |
prediction_type = List[Tuple[str, str]]
|
3394 |
|
3395 |
def get_element_group(self, element, additional_input):
|
3396 |
-
return element[
|
3397 |
|
3398 |
def get_element_representation(self, element, additional_input):
|
3399 |
return str(element)
|
@@ -6004,6 +6100,9 @@ class GraniteGuardianBase(InstanceMetric):
|
|
6004 |
)
|
6005 |
|
6006 |
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
|
|
|
|
|
|
6007 |
self.verify_granite_guardian_config(task_data)
|
6008 |
self.set_main_score()
|
6009 |
|
@@ -6017,7 +6116,10 @@ class GraniteGuardianBase(InstanceMetric):
|
|
6017 |
)
|
6018 |
messages = self.process_input_fields(task_data)
|
6019 |
prompt = self.get_prompt(messages)
|
6020 |
-
|
|
|
|
|
|
|
6021 |
generated_tokens_list = result[0]
|
6022 |
label, prob_of_risk = self.parse_output(generated_tokens_list)
|
6023 |
confidence_score = (
|
@@ -6030,6 +6132,7 @@ class GraniteGuardianBase(InstanceMetric):
|
|
6030 |
f"{self.main_score}_prob_of_risk": prob_of_risk,
|
6031 |
f"{self.main_score}_certainty": confidence_score,
|
6032 |
f"{self.main_score}_label": label,
|
|
|
6033 |
}
|
6034 |
logger.debug(f"Results are ready:\n{result}")
|
6035 |
return result
|
@@ -6042,7 +6145,7 @@ class GraniteGuardianBase(InstanceMetric):
|
|
6042 |
generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list
|
6043 |
]
|
6044 |
prob = self.get_probabilities(top_tokens_list)
|
6045 |
-
prob_of_risk = prob[1]
|
6046 |
|
6047 |
res = next(iter(generated_tokens_list))["text"].strip()
|
6048 |
|
@@ -6055,7 +6158,7 @@ class GraniteGuardianBase(InstanceMetric):
|
|
6055 |
|
6056 |
return label, prob_of_risk
|
6057 |
|
6058 |
-
def get_probabilities(self, top_tokens_list):
|
6059 |
import torch
|
6060 |
|
6061 |
safe_token_prob = 1e-50
|
@@ -6254,7 +6357,7 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6254 |
_requirements_list = ["sqlglot", "func_timeout"]
|
6255 |
|
6256 |
@staticmethod
|
6257 |
-
def
|
6258 |
"""Compares two DataFrames based on row content, ignoring column names.
|
6259 |
|
6260 |
Args:
|
@@ -6262,7 +6365,7 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6262 |
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
6263 |
|
6264 |
Returns:
|
6265 |
-
True if the DataFrames have the same
|
6266 |
False otherwise.
|
6267 |
"""
|
6268 |
df1.fillna(0, inplace=True)
|
@@ -6276,6 +6379,20 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6276 |
|
6277 |
return df1_rows_sorted == df2_rows_sorted
|
6278 |
|
6279 |
@staticmethod
|
6280 |
def is_subset_ignore_colnames(df1, df2):
|
6281 |
"""Checks if df1 is a subset of df2 based on row content, ignoring column names.
|
@@ -6343,6 +6460,7 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6343 |
import time
|
6344 |
|
6345 |
from func_timeout import func_timeout
|
|
|
6346 |
|
6347 |
from .sql_utils import sqlglot_optimized_equivalence
|
6348 |
|
@@ -6358,6 +6476,9 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6358 |
)
|
6359 |
end_time = time.perf_counter()
|
6360 |
gold_sql_runtime = end_time - start_time
|
|
|
|
|
|
|
6361 |
except Exception as e:
|
6362 |
gold_error = f"Error executing gold SQL: {e}"
|
6363 |
if gold_error is not None:
|
@@ -6389,10 +6510,10 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6389 |
gold_sql_runtime,
|
6390 |
0,
|
6391 |
0,
|
6392 |
-
|
6393 |
0,
|
6394 |
gold_df.to_json(),
|
6395 |
-
|
6396 |
"",
|
6397 |
)
|
6398 |
if predicted_sql.lower().strip() == gold_sql.lower().strip():
|
@@ -6417,6 +6538,9 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6417 |
)
|
6418 |
end_time = time.perf_counter()
|
6419 |
pred_sql_runtime = end_time - start_time
|
|
|
|
|
|
|
6420 |
except Exception as e:
|
6421 |
pred_error = f"Error executing predicted SQL: {e}"
|
6422 |
logger.info(pred_error)
|
@@ -6445,9 +6569,20 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6445 |
pred_res = pred_res["results"]
|
6446 |
predicted_df = pd.DataFrame(pred_res)
|
6447 |
|
6448 |
-
|
6449 |
-
|
6450 |
-
|
|
6451 |
|
6452 |
subset_non_empty_execution_result = 0
|
6453 |
non_empty_execution_result = 0
|
@@ -6473,6 +6608,8 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6473 |
)
|
6474 |
|
6475 |
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
|
|
|
|
6476 |
predicted_sql = prediction
|
6477 |
execution_result: float = 0.0
|
6478 |
|
|
|
63 |
from .operators import ArtifactFetcherMixin, Copy, Set
|
64 |
from .random_utils import get_seed
|
65 |
from .settings_utils import get_settings
|
|
|
66 |
from .stream import MultiStream, Stream
|
67 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
68 |
from .utils import deep_copy, recursive_copy
|
69 |
|
|
|
|
|
70 |
logger = get_logger()
|
71 |
settings = get_settings()
|
72 |
|
|
|
124 |
|
125 |
def nan_max(x):
|
126 |
with warnings.catch_warnings():
|
127 |
warnings.simplefilter("ignore", category=RuntimeWarning)
|
128 |
return np.nanmax(x)
|
129 |
|
130 |
+
def nan_std(x):
|
131 |
+
with warnings.catch_warnings():
|
132 |
+
warnings.simplefilter("ignore", category=RuntimeWarning)
|
133 |
+
result = np.nanstd(x)
|
134 |
+
try:
|
135 |
+
return float(result)
|
136 |
+
except:
|
137 |
+
return result
|
138 |
+
|
139 |
|
140 |
class UpdateStream(InstanceOperator):
|
141 |
update: dict
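A standalone sketch (not part of this commit) of what the new nan_std helper above does: like the existing nan_max, it suppresses numpy's all-NaN RuntimeWarning and, where possible, returns a plain Python float.

import warnings
import numpy as np

def nan_std_sketch(x):
    # Mirrors the nan_std helper added above, for illustration only.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        result = np.nanstd(x)
    try:
        return float(result)  # NaN-aware std as a built-in float
    except (TypeError, ValueError):
        return result

print(nan_std_sketch([0.5, np.nan, 0.7, 0.9]))  # ~0.1633; the NaN entry is ignored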
|
|
|
367 |
return np.random.default_rng(hash(get_seed()) & _max_32bit)
|
368 |
|
369 |
|
370 |
+
class Statistic:
|
371 |
+
"""Statistic for which the confidence interval is to be calculated.
|
372 |
+
|
373 |
+
`statistic` must be a callable that accepts ``len(data)`` samples
|
374 |
+
as separate arguments and returns the resulting statistic.
|
375 |
+
If `vectorized` is set ``True``,
|
376 |
+
`statistic` must also accept a keyword argument `axis` and be
|
377 |
+
vectorized to compute the statistic along the provided `axis`.
|
378 |
+
"""
|
379 |
+
|
380 |
+
def __init__(self, data, score_names, scorer):
|
381 |
+
self.data = data
|
382 |
+
self.score_names = score_names
|
383 |
+
self.scorer = scorer
|
384 |
+
self._history = []
|
385 |
+
|
386 |
+
def __call__(self, indices, axis=0):
|
387 |
+
# indices might be a 1D or 2D array, depending on bootstrap internals
|
388 |
+
# For simplicity, ensure we handle them as 1D.
|
389 |
+
indices = np.atleast_1d(indices).astype(int)
|
390 |
+
|
391 |
+
# Gather the subset
|
392 |
+
sample = [self.data[i] for i in indices]
|
393 |
+
|
394 |
+
# Compute metrics on this sample
|
395 |
+
scores = self.scorer(sample)
|
396 |
+
|
397 |
+
# Return them in consistent order
|
398 |
+
result = np.array([scores[m] for m in self.score_names])
|
399 |
+
self._history.append(result)
|
400 |
+
return result
|
401 |
+
def mean(self, idx):
|
402 |
+
return nan_mean([result[idx] for result in self._history])
|
403 |
+
|
404 |
+
def std(self, idx):
|
405 |
+
return nan_std([result[idx] for result in self._history])
|
406 |
+
|
407 |
class ConfidenceIntervalMixin(Artifact):
|
408 |
n_resamples: int = 1000
|
409 |
confidence_level: float = 0.95
|
|
|
413 |
def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
|
414 |
pass
|
415 |
|
416 |
|
417 |
def bootstrap(self, data: List[Any], score_names: List[str]):
|
418 |
if self.ci_score_names is not None:
|
419 |
score_names = self.ci_score_names
|
420 |
|
421 |
+
|
422 |
+
statistic = Statistic(data, score_names, self._sample_to_scores)
|
423 |
+
with warnings.catch_warnings():
|
424 |
+
warnings.filterwarnings(  # Ignore the error that arises when all sample scores are identical
|
425 |
+
"ignore",
|
426 |
+
message="invalid value encountered in divide",
|
427 |
+
category=RuntimeWarning
|
428 |
+
)
|
429 |
+
|
430 |
+
intervals = bootstrap(
|
431 |
+
(np.arange(len(data)),),
|
432 |
+
statistic=statistic,
|
433 |
+
n_resamples=self.n_resamples,
|
434 |
+
confidence_level=self.confidence_level,
|
435 |
+
random_state=new_random_generator(),
|
436 |
+
paired=False,
|
437 |
+
vectorized=False,
|
438 |
+
method="BCa",
|
439 |
+
).confidence_interval
|
440 |
+
|
441 |
|
442 |
result = {}
|
443 |
for i, metric in enumerate(score_names):
|
444 |
+
high = intervals.high[i]
|
445 |
+
low = intervals.low[i]
|
446 |
+
if np.isnan(high) and np.isnan(low):
|
447 |
+
if statistic.std(i) == 0:  # When all sample scores are identical, "BCa" fails (division by a zero std)
|
448 |
+
high = low = statistic.mean(i) # In this case we will use the mean (as there is no variance)
|
449 |
+
result[f"{metric}_ci_low"] = float(low)
|
450 |
+
result[f"{metric}_ci_high"] = float(high)
|
451 |
|
452 |
return result
|
453 |
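To make the resampling scheme above concrete, here is a minimal, self-contained sketch (illustrative only; the toy data and the scalar accuracy statistic are assumptions) of the index-bootstrap pattern the Statistic class implements: positions are resampled instead of the instances themselves, so arbitrary instance objects can be scored through a callable.

import numpy as np
from scipy.stats import bootstrap

data = [1, 0, 1, 1, 0, 1, 1, 0, 1, 1]  # toy per-instance correctness flags

def statistic(indices):
    # Resample indices, not values, as the Statistic callable above does.
    indices = np.atleast_1d(indices).astype(int)
    sample = [data[i] for i in indices]
    return float(np.mean(sample))  # stands in for _sample_to_scores

ci = bootstrap(
    (np.arange(len(data)),),  # bootstrap over positions
    statistic=statistic,
    n_resamples=1000,
    confidence_level=0.95,
    vectorized=False,
    method="BCa",
).confidence_interval
print(ci.low, ci.high)  # e.g. roughly (0.4, 0.9) for this toy sample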
|
|
|
2807 |
remote_url = "https://raw.githubusercontent.com/czyssrs/FinQA/dfc5b72c01ee17c442d28d5201b82a1f4e95d5af/code/evaluate/evaluate.py"
|
2808 |
local_filepath = "/tmp/finqa_eval_script.py"
|
2809 |
module_name = "finqa_eval"
|
2810 |
+
hash_of_script = "42430b8613082bb4b85d49210284135d" # pragma: allowlist secret
|
2811 |
|
2812 |
download_finqa_eval_script_file(remote_url, local_filepath, hash_of_script)
|
2813 |
self.finqa_module = load_finqa_eval_module_from_file(
|
|
|
3413 |
result["precision_macro"] = self.zero_division
|
3414 |
|
3415 |
|
3416 |
+
class KeyValueExtraction(GlobalMetric):
|
|
|
3417 |
|
3418 |
+
prediction_type = Dict[str,str]
|
3419 |
+
metric : Metric
|
3420 |
+
single_reference_per_prediction = True
|
3421 |
+
main_score = ""
|
3422 |
+
def prepare(self):
|
3423 |
+
super().prepare()
|
3424 |
+
self.main_score = f"{self.metric.main_score}_micro"
|
3425 |
|
3426 |
+
def compute(
|
3427 |
+
self,
|
3428 |
+
references: List[List[Any]],
|
3429 |
+
predictions: List[Any],
|
3430 |
+
task_data: List[Dict],
|
3431 |
+
) -> dict:
|
3432 |
+
references = [element[0] for element in references]
|
3433 |
|
3434 |
+
key_statistics = {}
|
3435 |
+
all_reference_keys = set()
|
3436 |
+
for reference in references:
|
3437 |
+
all_reference_keys.update(list(reference.keys()))
|
3438 |
+
for key in all_reference_keys:
|
3439 |
+
key_statistics[key]= []
|
3440 |
+
|
3441 |
+
num_prediction_keys=0
|
3442 |
+
illegal_prediction_keys=0
|
3443 |
+
for reference, prediction in zip(references, predictions):
|
3444 |
+
for key in all_reference_keys:
|
3445 |
+
if (key not in reference and key not in prediction):
|
3446 |
+
continue
|
3447 |
+
if (key in reference and key in prediction):
|
3448 |
+
multi_stream = MultiStream.from_iterables({"test": [{"prediction" : prediction[key],
|
3449 |
+
"references" : [reference[key]]}
|
3450 |
+
]})
|
3451 |
+
output_multi_stream = self.metric(multi_stream)
|
3452 |
+
output_stream = output_multi_stream["test"]
|
3453 |
+
score = next(iter(output_stream))["score"]["global"]["score"]
|
3454 |
+
key_statistics[key].append(score)
|
3455 |
+
else:
|
3456 |
+
key_statistics[key].append(0.0)
|
3457 |
+
|
3458 |
+
for key in prediction.keys():
|
3459 |
+
num_prediction_keys += 1
|
3460 |
+
if key not in all_reference_keys:
|
3461 |
+
illegal_prediction_keys += 1
|
3462 |
+
|
3463 |
+
result={}
|
3464 |
+
|
3465 |
+
average = 0
|
3466 |
+
total = 0
|
3467 |
+
|
3468 |
+
weighted_average = 0
|
3469 |
+
for key in key_statistics:
|
3470 |
+
mean_for_key = numpy.mean(key_statistics[key])
|
3471 |
+
num = len(key_statistics[key])
|
3472 |
+
total += num
|
3473 |
+
average += mean_for_key
|
3474 |
+
weighted_average += mean_for_key * num
|
3475 |
+
result[f"{self.metric.main_score}_{key}"] = mean_for_key
|
3476 |
+
|
3477 |
+
result[f"{self.metric.main_score}_micro"] = weighted_average / total
|
3478 |
+
result[f"{self.metric.main_score}_macro"] = average / len(key_statistics)
|
3479 |
+
if (num_prediction_keys !=0):
|
3480 |
+
result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 1 - 1.0 * illegal_prediction_keys / num_prediction_keys
|
3481 |
+
else:
|
3482 |
+
result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 0
|
3483 |
|
3484 |
+
return result
|
3485 |
|
3486 |
+
class NER(CustomF1):
|
3487 |
+
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
|
3488 |
|
3489 |
prediction_type = List[Tuple[str, str]]
|
3490 |
|
3491 |
def get_element_group(self, element, additional_input):
|
3492 |
+
return element[1]
|
3493 |
|
3494 |
def get_element_representation(self, element, additional_input):
|
3495 |
return str(element)
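A toy illustration (not the commit's code) of how the new KeyValueExtraction metric aggregates per-key scores; plain exact match stands in for the wrapped metric, and keys that appear only in the prediction would only affect the *_legal_keys_in_predictions score.

references = [{"name": "IBM", "year": "1911"}, {"name": "NASA", "year": "1958"}]
predictions = [{"name": "IBM", "year": "1913"}, {"name": "NASA", "country": "USA"}]

key_scores = {}
for ref, pred in zip(references, predictions):
    for key in ref:
        # 1.0 when the wrapped metric matches the reference value, else 0.0
        key_scores.setdefault(key, []).append(1.0 if pred.get(key) == ref[key] else 0.0)

micro = sum(sum(v) for v in key_scores.values()) / sum(len(v) for v in key_scores.values())
macro = sum(sum(v) / len(v) for v in key_scores.values()) / len(key_scores)
print(key_scores)    # {'name': [1.0, 1.0], 'year': [0.0, 0.0]}
print(micro, macro)  # 0.5 0.5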
|
|
|
6100 |
)
|
6101 |
|
6102 |
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
6103 |
+
# TODO replace with logic inside verify_granite_guardian_config and process_input_fields
|
6104 |
+
task_data["prediction"] = prediction
|
6105 |
+
|
6106 |
self.verify_granite_guardian_config(task_data)
|
6107 |
self.set_main_score()
|
6108 |
|
|
|
6116 |
)
|
6117 |
messages = self.process_input_fields(task_data)
|
6118 |
prompt = self.get_prompt(messages)
|
6119 |
+
data_classification_policy = task_data.get("metadata", {}).get("data_classification_policy")
|
6120 |
+
|
6121 |
+
result = self.inference_engine.infer_log_probs([{"source": prompt, "data_classification_policy": data_classification_policy}])
|
6122 |
+
|
6123 |
generated_tokens_list = result[0]
|
6124 |
label, prob_of_risk = self.parse_output(generated_tokens_list)
|
6125 |
confidence_score = (
|
|
|
6132 |
f"{self.main_score}_prob_of_risk": prob_of_risk,
|
6133 |
f"{self.main_score}_certainty": confidence_score,
|
6134 |
f"{self.main_score}_label": label,
|
6135 |
+
f"{self.main_score}_prompt": prompt,
|
6136 |
}
|
6137 |
logger.debug(f"Results are ready:\n{result}")
|
6138 |
return result
|
|
|
6145 |
generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list
|
6146 |
]
|
6147 |
prob = self.get_probabilities(top_tokens_list)
|
6148 |
+
prob_of_risk = prob[1].item()
|
6149 |
|
6150 |
res = next(iter(generated_tokens_list))["text"].strip()
|
6151 |
|
|
|
6158 |
|
6159 |
return label, prob_of_risk
|
6160 |
|
6161 |
+
def get_probabilities(self, top_tokens_list) -> Tuple[np.float32, np.float32]:
|
6162 |
import torch
|
6163 |
|
6164 |
safe_token_prob = 1e-50
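For orientation, a hedged sketch (not the actual implementation) of one way the top-token log-probabilities returned by infer_log_probs can be turned into a probability of risk; the "Yes"/"No" answer tokens and the example log-probabilities are assumptions.

import math

top_tokens = [  # hypothetical top tokens for the first generated position
    {"text": "Yes", "logprob": -0.3},
    {"text": "No", "logprob": -1.4},
    {"text": "yes", "logprob": -4.0},
]

risk_mass = sum(math.exp(t["logprob"]) for t in top_tokens if t["text"].strip().lower() == "yes")
safe_mass = sum(math.exp(t["logprob"]) for t in top_tokens if t["text"].strip().lower() == "no")
prob_of_risk = risk_mass / (risk_mass + safe_mass + 1e-50)  # 1e-50 echoes safe_token_prob above
print(round(prob_of_risk, 3))  # ~0.755 for these made-up numbers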
|
|
|
6357 |
_requirements_list = ["sqlglot", "func_timeout"]
|
6358 |
|
6359 |
@staticmethod
|
6360 |
+
def compare_dfs_ignore_colnames_ordered_rows(df1, df2):
|
6361 |
"""Compares two DataFrames based on row content, ignoring column names.
|
6362 |
|
6363 |
Args:
|
|
|
6365 |
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
6366 |
|
6367 |
Returns:
|
6368 |
+
True if the DataFrames have the same ordered rows (ignoring column names),
|
6369 |
False otherwise.
|
6370 |
"""
|
6371 |
df1.fillna(0, inplace=True)
|
|
|
6379 |
|
6380 |
return df1_rows_sorted == df2_rows_sorted
|
6381 |
|
6382 |
+
@staticmethod
|
6383 |
+
def compare_dfs_ignore_colnames_unordered_rows(df1, df2):
|
6384 |
+
"""Compares two DataFrames based on row content, ignoring row order and column names.
|
6385 |
+
|
6386 |
+
Args:
|
6387 |
+
df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
|
6388 |
+
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
6389 |
+
|
6390 |
+
Returns:
|
6391 |
+
True if the DataFrames have the same content (ignoring column names and row order),
|
6392 |
+
False otherwise.
|
6393 |
+
"""
|
6394 |
+
return set(map(tuple, df1.to_numpy())) == set(map(tuple, df2.to_numpy()))
|
6395 |
+
|
6396 |
@staticmethod
|
6397 |
def is_subset_ignore_colnames(df1, df2):
|
6398 |
"""Checks if df1 is a subset of df2 based on row content, ignoring column names.
|
|
|
6460 |
import time
|
6461 |
|
6462 |
from func_timeout import func_timeout
|
6463 |
+
from func_timeout.exceptions import FunctionTimedOut
|
6464 |
|
6465 |
from .sql_utils import sqlglot_optimized_equivalence
|
6466 |
|
|
|
6476 |
)
|
6477 |
end_time = time.perf_counter()
|
6478 |
gold_sql_runtime = end_time - start_time
|
6479 |
+
except FunctionTimedOut as e:
|
6480 |
+
pred_error = f"Timeout error executing gold SQL: {e}"
|
6481 |
+
logger.warning(pred_error)
|
6482 |
except Exception as e:
|
6483 |
gold_error = f"Error executing gold SQL: {e}"
|
6484 |
if gold_error is not None:
|
|
|
6510 |
gold_sql_runtime,
|
6511 |
0,
|
6512 |
0,
|
6513 |
+
0,
|
6514 |
0,
|
6515 |
gold_df.to_json(),
|
6516 |
+
"",
|
6517 |
"",
|
6518 |
)
|
6519 |
if predicted_sql.lower().strip() == gold_sql.lower().strip():
|
|
|
6538 |
)
|
6539 |
end_time = time.perf_counter()
|
6540 |
pred_sql_runtime = end_time - start_time
|
6541 |
+
except FunctionTimedOut as e:
|
6542 |
+
pred_error = f"Timeout error executing predicted SQL: {e}"
|
6543 |
+
logger.info(pred_error)
|
6544 |
except Exception as e:
|
6545 |
pred_error = f"Error executing predicted SQL: {e}"
|
6546 |
logger.info(pred_error)
|
|
|
6569 |
pred_res = pred_res["results"]
|
6570 |
predicted_df = pd.DataFrame(pred_res)
|
6571 |
|
6572 |
+
if "ORDER BY" in gold_sql.upper():
|
6573 |
+
execution_result = (
|
6574 |
+
1
|
6575 |
+
if self.compare_dfs_ignore_colnames_ordered_rows(predicted_df, gold_df)
|
6576 |
+
else 0
|
6577 |
+
)
|
6578 |
+
else:
|
6579 |
+
execution_result = (
|
6580 |
+
1
|
6581 |
+
if self.compare_dfs_ignore_colnames_unordered_rows(
|
6582 |
+
predicted_df, gold_df
|
6583 |
+
)
|
6584 |
+
else 0
|
6585 |
+
)
|
6586 |
|
6587 |
subset_non_empty_execution_result = 0
|
6588 |
non_empty_execution_result = 0
|
|
|
6608 |
)
|
6609 |
|
6610 |
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
6611 |
+
from .sql_utils import get_db_connector
|
6612 |
+
|
6613 |
predicted_sql = prediction
|
6614 |
execution_result: float = 0.0
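A minimal pandas illustration (not part of the commit) of why the gold query's ORDER BY decides which of the two comparison helpers applies: without ORDER BY the result is treated as a set of rows, with ORDER BY the row order must match as well.

import pandas as pd

gold = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
pred = pd.DataFrame({"c": [2, 1], "d": ["y", "x"]})  # same rows, different order and column names

unordered_equal = set(map(tuple, gold.to_numpy())) == set(map(tuple, pred.to_numpy()))
ordered_equal = [tuple(r) for r in gold.to_numpy()] == [tuple(r) for r in pred.to_numpy()]
print(unordered_equal, ordered_equal)  # True False -> acceptable only when the gold SQL has no ORDER BY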
|
6615 |
|
schema.py
CHANGED
@@ -67,8 +67,7 @@ def load_chat_source(chat_str):
|
|
67 |
)
|
68 |
return chat
|
69 |
|
70 |
-
|
71 |
-
def loads_instance(batch):
|
72 |
if (
|
73 |
"source" in batch
|
74 |
and isinstance(batch["source"][0], str)
|
@@ -86,6 +85,24 @@ def loads_instance(batch):
|
|
86 |
batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
|
87 |
return batch
|
88 |
|
89 |
|
90 |
class FinalizeDataset(InstanceOperatorValidator):
|
91 |
group_by: List[List[str]]
|
|
|
67 |
)
|
68 |
return chat
|
69 |
|
70 |
+
def loads_batch(batch):
|
|
|
71 |
if (
|
72 |
"source" in batch
|
73 |
and isinstance(batch["source"][0], str)
|
|
|
85 |
batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
|
86 |
return batch
|
87 |
|
88 |
+
def loads_instance(instance):
|
89 |
+
if (
|
90 |
+
"source" in instance
|
91 |
+
and isinstance(instance["source"], str)
|
92 |
+
and (
|
93 |
+
instance["source"].startswith('[{"role":')
|
94 |
+
or instance["source"].startswith('[{"content":')
|
95 |
+
)
|
96 |
+
):
|
97 |
+
instance["source"] = load_chat_source(instance["source"])
|
98 |
+
if (
|
99 |
+
not settings.task_data_as_text
|
100 |
+
and "task_data" in instance
|
101 |
+
and isinstance(instance["task_data"], str)
|
102 |
+
):
|
103 |
+
instance["task_data"] = json.loads(instance["task_data"])
|
104 |
+
return instance
|
105 |
+
|
106 |
|
107 |
class FinalizeDataset(InstanceOperatorValidator):
|
108 |
group_by: List[List[str]]
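A standalone sketch (not the commit's code) of what the new per-instance loader does for a single example: chat-formatted source strings and JSON task_data strings are parsed in place. Here json.loads stands in for load_chat_source, and the sample instance is made up.

import json

instance = {
    "source": '[{"role": "user", "content": "What is 2+2?"}]',
    "task_data": '{"answer": "4"}',
}
if instance["source"].startswith(('[{"role":', '[{"content":')):
    instance["source"] = json.loads(instance["source"])  # stands in for load_chat_source
instance["task_data"] = json.loads(instance["task_data"])
print(instance["source"][0]["role"], instance["task_data"]["answer"])  # user 4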
|
serializers.py
CHANGED
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Union
|
|
7 |
from .dataclass import AbstractField, Field
|
8 |
from .operators import InstanceFieldOperator
|
9 |
from .settings_utils import get_constants
|
10 |
-
from .sql_utils import get_db_connector
|
11 |
from .type_utils import isoftype, to_type_string
|
12 |
from .types import (
|
13 |
Dialog,
|
@@ -203,5 +202,7 @@ class SQLDatabaseAsSchemaSerializer(SingleTypeSerializer):
|
|
203 |
serialized_type = SQLDatabase
|
204 |
|
205 |
def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
|
|
|
|
|
206 |
connector = get_db_connector(value["db_type"])(value)
|
207 |
return connector.get_table_schema()
|
|
|
7 |
from .dataclass import AbstractField, Field
|
8 |
from .operators import InstanceFieldOperator
|
9 |
from .settings_utils import get_constants
|
|
|
10 |
from .type_utils import isoftype, to_type_string
|
11 |
from .types import (
|
12 |
Dialog,
|
|
|
202 |
serialized_type = SQLDatabase
|
203 |
|
204 |
def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
|
205 |
+
from .sql_utils import get_db_connector
|
206 |
+
|
207 |
connector = get_db_connector(value["db_type"])(value)
|
208 |
return connector.get_table_schema()
|
settings_utils.py
CHANGED
@@ -159,6 +159,7 @@ if Settings.is_uninitilized():
|
|
159 |
settings.hf_offline_datasets_path = None
|
160 |
settings.hf_offline_metrics_path = None
|
161 |
settings.hf_offline_models_path = None
|
|
|
162 |
|
163 |
if Constants.is_uninitilized():
|
164 |
constants = Constants()
|
|
|
159 |
settings.hf_offline_datasets_path = None
|
160 |
settings.hf_offline_metrics_path = None
|
161 |
settings.hf_offline_models_path = None
|
162 |
+
settings.inference_engine_cache_path = "./inference_engine_cache/"
|
163 |
|
164 |
if Constants.is_uninitilized():
|
165 |
constants = Constants()
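The new default can be overridden at runtime; a minimal sketch, assuming the unitxt.settings_utils import path, of pointing the inference-engine cache somewhere else:

from unitxt.settings_utils import get_settings

settings = get_settings()
settings.inference_engine_cache_path = "/tmp/unitxt_inference_cache/"  # hypothetical location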
|
sql_utils.py
CHANGED
@@ -1,4 +1,7 @@
|
|
|
|
1 |
import glob
|
|
|
|
|
2 |
import os
|
3 |
import re
|
4 |
import sqlite3
|
@@ -16,6 +19,14 @@ from .types import SQLDatabase
|
|
16 |
|
17 |
logger = get_logger()
|
18 |
|
19 |
|
20 |
class DatabaseConnector(ABC):
|
21 |
"""Abstract base class for database connectors."""
|
@@ -23,7 +34,7 @@ class DatabaseConnector(ABC):
|
|
23 |
def __init__(self, db_config: SQLDatabase):
|
24 |
self.db_config = db_config
|
25 |
self.databases_folder = os.path.join(
|
26 |
-
os.environ.get("
|
27 |
)
|
28 |
os.makedirs(self.databases_folder, exist_ok=True)
|
29 |
|
@@ -187,6 +198,177 @@ class InMemoryDatabaseConnector(DatabaseConnector):
|
|
187 |
conn.close()
|
188 |
|
189 |
|
190 |
@lru_cache(maxsize=128)
|
191 |
def execute_query_remote(
|
192 |
api_url: str,
|
@@ -318,12 +500,20 @@ class RemoteDatabaseConnector(DatabaseConnector):
|
|
318 |
|
319 |
def execute_query(self, query: str) -> Any:
|
320 |
"""Executes a query against the remote database, with retries for certain exceptions."""
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
)
|
328 |
|
329 |
|
|
|
1 |
+
import functools
|
2 |
import glob
|
3 |
+
import hashlib
|
4 |
+
import json
|
5 |
import os
|
6 |
import re
|
7 |
import sqlite3
|
|
|
19 |
|
20 |
logger = get_logger()
|
21 |
|
22 |
+
# Check if caching is enabled via environment variable
|
23 |
+
CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")
|
24 |
+
|
25 |
+
# Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
|
26 |
+
MAX_CACHE_SIZE = int(os.getenv("MAX_CACHE_SIZE", 10 * 1024**3))  # coerce to int so a string env value still works
|
27 |
+
|
28 |
+
_cache_instance = None
|
29 |
+
|
30 |
|
31 |
class DatabaseConnector(ABC):
|
32 |
"""Abstract base class for database connectors."""
|
|
|
34 |
def __init__(self, db_config: SQLDatabase):
|
35 |
self.db_config = db_config
|
36 |
self.databases_folder = os.path.join(
|
37 |
+
os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
|
38 |
)
|
39 |
os.makedirs(self.databases_folder, exist_ok=True)
|
40 |
|
|
|
198 |
conn.close()
|
199 |
|
200 |
|
201 |
+
def get_cache():
|
202 |
+
"""Returns a singleton cache instance, initializing it if necessary."""
|
203 |
+
global _cache_instance
|
204 |
+
if _cache_instance is None:
|
205 |
+
_cache_instance = Cache()
|
206 |
+
return _cache_instance
|
207 |
+
|
208 |
+
|
209 |
+
def generate_cache_key(*args, **kwargs):
|
210 |
+
"""Generate a stable hashable cache key for various input types.
|
211 |
+
|
212 |
+
:param args: Positional arguments of the function.
|
213 |
+
:param kwargs: Keyword arguments of the function.
|
214 |
+
:return: A hashed key as a string.
|
215 |
+
"""
|
216 |
+
try:
|
217 |
+
# Convert args and kwargs to a JSON string (sorted to ensure consistency)
|
218 |
+
serialized = json.dumps(
|
219 |
+
{"args": args, "kwargs": kwargs}, sort_keys=True, default=str
|
220 |
+
)
|
221 |
+
except TypeError:
|
222 |
+
# Fallback for non-serializable objects
|
223 |
+
serialized = repr((args, kwargs))
|
224 |
+
|
225 |
+
# Hash the serialized data
|
226 |
+
return hashlib.md5(serialized.encode()).hexdigest()
|
227 |
+
|
228 |
+
|
229 |
+
class Cache:
|
230 |
+
"""A class that provides disk-based caching functionality for a given function."""
|
231 |
+
|
232 |
+
def __init__(self):
|
233 |
+
"""Initializes the cache.
|
234 |
+
|
235 |
+
If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
|
236 |
+
cache is created using `diskcache`.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
None
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
None
|
243 |
+
"""
|
244 |
+
if CACHE_LOCATION:
|
245 |
+
try:
|
246 |
+
import diskcache
|
247 |
+
|
248 |
+
# Ensure the cache directory exists
|
249 |
+
os.makedirs(CACHE_LOCATION, exist_ok=True)
|
250 |
+
|
251 |
+
# Create a global diskcache Cache instance
|
252 |
+
self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
|
253 |
+
logger.info(f"Caching enabled at {CACHE_LOCATION}")
|
254 |
+
except ImportError as e:
|
255 |
+
raise ImportError(
|
256 |
+
"UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
|
257 |
+
"Please install diskcache `pip install diskcache` "
|
258 |
+
"or unset UNITXT_CACHE_LOCATION."
|
259 |
+
) from e
|
260 |
+
else:
|
261 |
+
self.cache = None # Disable caching
|
262 |
+
|
263 |
+
def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
|
264 |
+
if not self.cache or no_cache:
|
265 |
+
logger.info(f"Bypassing cache for key: {key}")
|
266 |
+
return compute_fn()
|
267 |
+
|
268 |
+
if refresh and key in self.cache:
|
269 |
+
logger.info(f"Refreshing cache for key: {key}")
|
270 |
+
del self.cache[key]
|
271 |
+
|
272 |
+
if key in self.cache:
|
273 |
+
logger.info(f"Cache hit for key: {key}")
|
274 |
+
return self.cache[key]
|
275 |
+
|
276 |
+
logger.info(f"Cache miss for key: {key}. Computing value...")
|
277 |
+
result = compute_fn()
|
278 |
+
self.cache[key] = result
|
279 |
+
logger.info(f"Stored result in cache for key: {key}")
|
280 |
+
return result
|
281 |
+
|
282 |
+
async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
|
283 |
+
if not self.cache or no_cache:
|
284 |
+
logger.info(f"Bypassing cache for key: {key}")
|
285 |
+
return await compute_fn()
|
286 |
+
|
287 |
+
if refresh and key in self.cache:
|
288 |
+
logger.info(f"Refreshing cache for key: {key}")
|
289 |
+
del self.cache[key]
|
290 |
+
|
291 |
+
if key in self.cache:
|
292 |
+
logger.info(f"Cache hit for key: {key}")
|
293 |
+
return self.cache[key]
|
294 |
+
|
295 |
+
logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
|
296 |
+
result = await compute_fn()
|
297 |
+
self.cache[key] = result
|
298 |
+
logger.info(f"Stored result in cache for key: {key}")
|
299 |
+
return result
|
300 |
+
|
301 |
+
def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
|
302 |
+
def decorator(func):
|
303 |
+
@functools.wraps(func)
|
304 |
+
def wrapper(*args, **kwargs):
|
305 |
+
if not self.cache or no_cache:
|
306 |
+
logger.info(f"Bypassing cache for function: {func.__name__}")
|
307 |
+
return func(*args, **kwargs)
|
308 |
+
|
309 |
+
key = key_func(func.__name__, *args, **kwargs)
|
310 |
+
|
311 |
+
if refresh and key in self.cache:
|
312 |
+
logger.info(
|
313 |
+
f"Refreshing cache for function: {func.__name__}, key: {key}"
|
314 |
+
)
|
315 |
+
del self.cache[key]
|
316 |
+
|
317 |
+
if key in self.cache:
|
318 |
+
logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
|
319 |
+
return self.cache[key]
|
320 |
+
|
321 |
+
logger.info(
|
322 |
+
f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
|
323 |
+
)
|
324 |
+
result = func(*args, **kwargs)
|
325 |
+
self.cache[key] = result
|
326 |
+
logger.info(
|
327 |
+
f"Stored result in cache for function: {func.__name__}, key: {key}"
|
328 |
+
)
|
329 |
+
return result
|
330 |
+
|
331 |
+
return wrapper
|
332 |
+
|
333 |
+
return decorator
|
334 |
+
|
335 |
+
def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
|
336 |
+
def decorator(func):
|
337 |
+
@functools.wraps(func)
|
338 |
+
async def wrapper(*args, **kwargs):
|
339 |
+
if no_cache:
|
340 |
+
logger.info(f"Bypassing cache for async function: {func.__name__}")
|
341 |
+
return await func(*args, **kwargs)
|
342 |
+
|
343 |
+
key = key_func(func.__name__, *args, **kwargs)
|
344 |
+
|
345 |
+
if refresh and key in self.cache:
|
346 |
+
logger.info(
|
347 |
+
f"Refreshing cache for async function: {func.__name__}, key: {key}"
|
348 |
+
)
|
349 |
+
del self.cache[key]
|
350 |
+
|
351 |
+
if key in self.cache:
|
352 |
+
logger.info(
|
353 |
+
f"Cache hit for async function: {func.__name__}, key: {key}"
|
354 |
+
)
|
355 |
+
return self.cache[key]
|
356 |
+
|
357 |
+
logger.info(
|
358 |
+
f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
|
359 |
+
)
|
360 |
+
result = await func(*args, **kwargs)
|
361 |
+
self.cache[key] = result
|
362 |
+
logger.info(
|
363 |
+
f"Stored result in cache for async function: {func.__name__}, key: {key}"
|
364 |
+
)
|
365 |
+
return result
|
366 |
+
|
367 |
+
return wrapper
|
368 |
+
|
369 |
+
return decorator
|
370 |
+
|
371 |
+
|
372 |
@lru_cache(maxsize=128)
|
373 |
def execute_query_remote(
|
374 |
api_url: str,
|
|
|
500 |
|
501 |
def execute_query(self, query: str) -> Any:
|
502 |
"""Executes a query against the remote database, with retries for certain exceptions."""
|
503 |
+
cache = get_cache()
|
504 |
+
|
505 |
+
cache_key = generate_cache_key(
|
506 |
+
"sql_request", self.api_url, self.database_id, query
|
507 |
+
)
|
508 |
+
return cache.get_or_set(
|
509 |
+
cache_key,
|
510 |
+
lambda: execute_query_remote(
|
511 |
+
api_url=self.api_url,
|
512 |
+
database_id=self.database_id,
|
513 |
+
api_key=self.api_key,
|
514 |
+
query=query,
|
515 |
+
timeout=self.timeout,
|
516 |
+
),
|
517 |
)
|
518 |
|
519 |
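A usage sketch (not part of the commit) of the new disk cache: with UNITXT_CACHE_LOCATION set before the module is imported and diskcache installed, identical remote SQL queries are served from disk. The unitxt.sql_utils import path and the toy compute function are assumptions.

import os
os.environ["UNITXT_CACHE_LOCATION"] = "/tmp/unitxt_cache"  # must be set before the import below

from unitxt.sql_utils import generate_cache_key, get_cache

cache = get_cache()
key = generate_cache_key("expensive_call", {"query": "SELECT 1"})
first = cache.get_or_set(key, lambda: 42)   # cache miss: computes and stores
second = cache.get_or_set(key, lambda: 0)   # cache hit: returns the stored 42
print(first, second)  # 42 42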
|
struct_data_operators.py
CHANGED
@@ -1024,24 +1024,24 @@ class ShuffleColumnsNames(TypeDependentAugmentor):
|
|
1024 |
return {"header": shuffled_header, "rows": table["rows"]}
|
1025 |
|
1026 |
|
1027 |
-
class
|
1028 |
-
"""Convert a Json string of representing key value as dictionary
|
1029 |
|
1030 |
def process_value(self, text: str) -> List[Tuple[str, str]]:
|
1031 |
try:
|
1032 |
dict_value = json.loads(text)
|
1033 |
except Exception as e:
|
1034 |
UnitxtWarning(
|
1035 |
-
f"Unable to convert input text to json format in
|
1036 |
)
|
1037 |
dict_value = {}
|
1038 |
if not isoftype(dict_value, Dict[str, Any]):
|
1039 |
UnitxtWarning(
|
1040 |
-
f"Unable to convert input text to dictionary in
|
1041 |
)
|
1042 |
dict_value = {}
|
1043 |
-
return
|
1044 |
-
(str(key), str(value))
|
1045 |
-
for key, value in dict_value.items()
|
1046 |
-
if value is not None
|
1047 |
-
]
|
|
|
1024 |
return {"header": shuffled_header, "rows": table["rows"]}
|
1025 |
|
1026 |
|
1027 |
+
class JsonStrToDict(FieldOperator):
|
1028 |
+
"""Convert a Json string of representing key value as dictionary.
|
1029 |
+
|
1030 |
+
Ensure keys and values are strings, and there are no None values.
|
1031 |
+
|
1032 |
+
"""
|
1033 |
|
1034 |
def process_value(self, text: str) -> List[Tuple[str, str]]:
|
1035 |
try:
|
1036 |
dict_value = json.loads(text)
|
1037 |
except Exception as e:
|
1038 |
UnitxtWarning(
|
1039 |
+
f"Unable to convert input text to json format in JsonStrToDict due to {e}. Text: {text}"
|
1040 |
)
|
1041 |
dict_value = {}
|
1042 |
if not isoftype(dict_value, Dict[str, Any]):
|
1043 |
UnitxtWarning(
|
1044 |
+
f"Unable to convert input text to dictionary in JsonStrToDict. Text: {text}"
|
1045 |
)
|
1046 |
dict_value = {}
|
1047 |
+
return {str(k): str(v) for k, v in dict_value.items() if v is not None}
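For illustration (not the commit's code), the conversion JsonStrToDict performs on a field value: keys and values are stringified and None values are dropped.

import json

text = '{"name": "IBM", "founded": 1911, "ceo": null}'
value = {str(k): str(v) for k, v in json.loads(text).items() if v is not None}
print(value)  # {'name': 'IBM', 'founded': '1911'}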
|
|
|
|
|
|
|
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.
|
|
|
1 |
+
version = "1.21.0"
|