Elron committed (verified)
Commit d346c89 · Parent: 35fffae

Upload folder using huggingface_hub
README.md CHANGED
@@ -40,11 +40,11 @@ https://github.com/IBM/unitxt/assets/23455264/baef9131-39d4-4164-90b2-05da52919f
 
 ### 🦄 Currently on Unitxt Catalog
 
-![Abstract Tasks](https://img.shields.io/badge/Abstract_Tasks-62-blue)
-![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-3025-blue)
+![Abstract Tasks](https://img.shields.io/badge/Abstract_Tasks-64-blue)
+![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-3174-blue)
 ![Templates](https://img.shields.io/badge/Templates-342-blue)
-![Benchmarks](https://img.shields.io/badge/Benchmarks-4-blue)
-![Metrics](https://img.shields.io/badge/Metrics-422-blue)
+![Benchmarks](https://img.shields.io/badge/Benchmarks-6-blue)
+![Metrics](https://img.shields.io/badge/Metrics-462-blue)
 
 ### 🦄 Run Unitxt Exploration Dashboard
 
api.py CHANGED
@@ -21,7 +21,7 @@ from .loaders import LoadFromDictionary
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import loads_instance
+from .schema import loads_batch
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -98,6 +98,7 @@ def create_dataset(
     train_set: Optional[List[Dict[Any, Any]]] = None,
     validation_set: Optional[List[Dict[Any, Any]]] = None,
     split: Optional[str] = None,
+    data_classification_policy: Optional[List[str]] = None,
     **kwargs,
 ) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
     """Creates dataset from input data based on a specific task.
@@ -108,6 +109,7 @@ def create_dataset(
         train_set : optional train_set
         validation_set: optional validation set
         split: optional one split to choose
+        data_classification_policy: data_classification_policy
         **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())
 
     Returns:
@@ -129,7 +131,7 @@
             f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
         )
 
-    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
+    card = TaskCard(loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy), task=task)
     return load_dataset(card=card, split=split, **kwargs)
 
 
@@ -283,7 +285,7 @@ def produce(
     result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
     if not is_list:
         return result[0]
-    return Dataset.from_list(result).with_transform(loads_instance)
+    return Dataset.from_list(result).with_transform(loads_batch)
 
 
 def infer(
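The api.py hunks above thread a new `data_classification_policy` argument from `create_dataset()` into the `LoadFromDictionary` loader. A minimal usage sketch under assumptions; the task id, the `test_set` keyword, and the example fields are illustrative, not taken from this commit:

```python
# Hypothetical usage sketch of the new data_classification_policy argument.
from unitxt.api import create_dataset

data = [{"question": "What is the capital of France?", "answers": ["Paris"]}]

dataset = create_dataset(
    task="tasks.qa.open",                        # assumed catalog task id
    test_set=data,                               # assumed split keyword, mirrors train_set/validation_set
    split="test",
    data_classification_policy=["proprietary"],  # forwarded to LoadFromDictionary
)
```

The idea of the change: data loaded this way carries its classification, so only inference engines whose own `data_classification_policy` allows those classes will accept the instances.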
dataset.py CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import datasets
 
@@ -50,7 +50,7 @@ from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
 from .schema import __file__ as _
-from .schema import loads_instance
+from .schema import loads_batch, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import __file__ as _
 from .settings_utils import get_constants
@@ -120,6 +120,13 @@ class Dataset(datasets.GeneratorBasedBuilder):
             dl_manager, "no_checks", **prepare_splits_kwargs
         )
 
+    def as_streaming_dataset(self, split: Optional[str] = None, base_path: Optional[str] = None) -> Union[Dict[str, datasets.IterableDataset], datasets.IterableDataset]:
+        return (
+            super()
+            .as_streaming_dataset(split, base_path=base_path)
+            .map(loads_instance)
+        )
+
     def as_dataset(
         self,
         split: Optional[datasets.Split] = None,
@@ -162,5 +169,5 @@ class Dataset(datasets.GeneratorBasedBuilder):
         return (
             super()
             .as_dataset(split, run_post_process, verification_mode, in_memory)
-            .with_transform(loads_instance)
+            .with_transform(loads_batch)
         )
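The dataset.py hunks split the decoding step by access path: the new `as_streaming_dataset()` maps `loads_instance` per example over an `IterableDataset`, while the in-memory `as_dataset()` now applies `loads_batch` as an on-access transform. A self-contained sketch of that split using the `datasets` library; the JSON decoding of `task_data` is an assumption about what these helpers do, not the library's actual implementation:

```python
import json
from datasets import Dataset

def loads_instance_like(example):
    # per-example hook, the shape expected by IterableDataset.map()
    example["task_data"] = json.loads(example["task_data"])
    return example

def loads_batch_like(batch):
    # columnar batch hook, the shape expected by Dataset.with_transform()
    batch["task_data"] = [json.loads(x) for x in batch["task_data"]]
    return batch

ds = Dataset.from_list([{"task_data": json.dumps({"label": 1})}])
eager = ds.with_transform(loads_batch_like)                   # decoded lazily on access
streamed = ds.to_iterable_dataset().map(loads_instance_like)  # decoded while streaming
```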
error_utils.py CHANGED
@@ -18,6 +18,7 @@ class Documentation:
     BENCHMARKS = "docs/benchmark.html"
     DATA_CLASSIFICATION_POLICY = "docs/data_classification_policy.html"
     CATALOG = "docs/saving_and_loading_from_catalog.html"
+    SETTINGS = "docs/settings.html"
 
 
 def additional_info(path: str) -> str:
inference.py CHANGED
@@ -2,6 +2,7 @@ import abc
 import asyncio
 import base64
 import dataclasses
+import hashlib
 import io
 import json
 import logging
@@ -12,6 +13,7 @@ import time
 import uuid
 from collections import Counter
 from datetime import datetime
+from itertools import islice
 from multiprocessing.pool import ThreadPool
 from typing import (
     Any,
@@ -29,6 +31,7 @@ from typing import (
 )
 
 from datasets import Dataset, DatasetDict, Image
+from diskcache import Cache
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
 
@@ -53,6 +56,11 @@ settings = get_settings()
 logger = get_logger()
 
 
+def batched(lst, n):
+    it = iter(lst)
+    while batch := list(islice(it, n)):
+        yield batch
+
 class StandardAPIParamsMixin(Artifact):
     model: str
     frequency_penalty: Optional[float] = None
@@ -149,6 +157,8 @@ class ListWithMetadata(List[T]):
 
 class InferenceEngine(Artifact):
     """Abstract base class for inference."""
+    cache_batch_size: int = 100
+    use_cache: bool = True
 
     @abc.abstractmethod
     def _infer(
@@ -173,6 +183,7 @@ class InferenceEngine(Artifact):
         if not settings.mock_inference_mode:
             super().prepare()  # no need to prepare a mock
         self.prepare_engine()
+        self._cache = Cache(get_settings().inference_engine_cache_path + self.__class__.__name__)
 
     def __call__(
         self,
@@ -181,16 +192,20 @@
     ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
         return self.infer(dataset=dataset, return_meta_data=return_meta_data)
 
-    def infer(
-        self,
-        dataset: Union[List[Dict[str, Any]], Dataset],
-        return_meta_data: bool = False,
-    ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
-        """Verifies instances of a dataset and perform inference on the input dataset.
-
-        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
-        predictions.
-        """
+    def get_instance_cache_key(self, instance):
+        instance_key_fields = ["media", "source", "task_data"]
+        return {key: instance[key] for key in instance if key in instance_key_fields}
+
+    def _get_cache_key(self, instance: Dict[str, Any]) -> str:
+        """Generate a unique cache key for each input."""
+        record = self.get_instance_cache_key(instance)
+        record.update(self.to_dict())
+        instance_str = json.dumps(record, sort_keys=True)
+        return hashlib.md5(instance_str.encode()).hexdigest()
+
+    def verify_infer_inputs(self,
+                            dataset: Union[List[Dict[str, Any]], Dataset],
+                            return_meta_data: bool):
         if not isoftype(dataset, Union[List[Dict[str, Any]], Dataset]):
             raise Exception(
                 "Dataset passed to infer() is not list of dictionaries or Huggingface Dataset"
@@ -202,10 +217,54 @@
             )
 
         [self.verify_instance(instance) for instance in dataset]
+
+    def infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
+        """Verifies instances of a dataset and perform inference on the input dataset.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
+        predictions.
+        """
+        self.verify_infer_inputs(dataset, return_meta_data)
         if settings.mock_inference_mode:
             result = self._mock_infer(dataset)
         else:
-            result = self._infer(dataset, return_meta_data)
+            if self.use_cache:
+                number_of_batches = len(dataset) // self.cache_batch_size + 1
+                result = []
+                for batch_index, batch in enumerate(batched(dataset, self.cache_batch_size)):
+                    cached_results = []
+                    missing_examples = []
+                    for i, item in enumerate(batch):
+                        cache_key = self._get_cache_key(item)
+                        cached_value = self._cache.get(cache_key)
+                        if cached_value is not None:
+                            cached_results.append((i, cached_value))  # each element is index in batch, and value
+                        else:
+                            missing_examples.append((i, item))  # each element is index in batch and example
+                    # infer on missing examples only, without indices
+
+                    logger.info(f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")
+                    if (len(missing_examples) > 0):
+                        inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
+                        # recombined to index and value
+                        inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
+                        # Add missing examples to cache
+                        for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
+                            if prediction is None:
+                                continue
+                            cache_key = self._get_cache_key(item)
+                            self._cache[cache_key] = prediction
+                    else:
+                        inferred_results = []
+                    # Combine cached and inferred results in original order
+                    batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
+                    result.extend(batch_predictions)
+            else:
+                result = self._infer(dataset, return_meta_data)
         return ListWithMetadata(
             result,
             metadata={
@@ -221,6 +280,7 @@ class InferenceEngine(Artifact):
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         return [str(instance["source"]) for instance in dataset]
 
+    @abc.abstractmethod
     def get_engine_id(self):
         raise NotImplementedError()
 
@@ -918,16 +978,18 @@ class HFPipelineBasedInferenceEngine(
         return args
 
     def _create_pipeline(self, model_args: Dict[str, Any]):
-        from transformers import pipeline
+        from transformers import AutoTokenizer, pipeline
 
         path = self.model_name
         if settings.hf_offline_models_path is not None:
             path = os.path.join(settings.hf_offline_models_path, path)
 
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = pipeline(
             model=path,
             task=self.task,
             use_fast=self.use_fast_tokenizer,
+            tokenizer=tokenizer,
             trust_remote_code=settings.allow_unverified_code,
             **model_args,
             **self.to_dict(
@@ -1302,7 +1364,7 @@ class IbmGenAiInferenceEngine(
     def _get_credentials():
         from genai import Credentials
 
-        api_key_env_var_name = "GENAI_KEY"
+        api_key_env_var_name = "GENAI_KEY"  # pragma: allowlist secret
         api_key = os.environ.get(api_key_env_var_name)
 
         assert api_key is not None, (
@@ -1718,7 +1780,7 @@ class AzureOpenAIInferenceEngine(OpenAiInferenceEngine):
         ), "Error while trying to run AzureOpenAIInferenceEngine: Missing environment variable param AZURE_OPENAI_HOST or OPENAI_API_VERSION"
         api_url = f"{azure_openapi_host}/openai/deployments/{self.model_name}/chat/completions?api-version={api_version}"
 
-        return {"api_key": api_key, "api_url": api_url}
+        return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
 
     def create_client(self):
         from openai import AzureOpenAI
@@ -1727,12 +1789,13 @@ class AzureOpenAIInferenceEngine(OpenAiInferenceEngine):
         return AzureOpenAI(
             api_key=self.credentials["api_key"],
             base_url=self.credentials["api_url"],
+            api_version=self.credentials["api_version"],
             default_headers=self.get_default_headers(),
         )
 
 
 class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
-    label: str = "vllm"
+    label: str = "vllm-remote"
 
 
 class RITSInferenceEngine(
@@ -1741,6 +1804,10 @@ class RITSInferenceEngine(
     label: str = "rits"
     data_classification_policy = ["public", "proprietary"]
 
+    model_names_dict = {
+        "microsoft/phi-4": "microsoft-phi-4"
+    }
+
     def get_default_headers(self):
         return {"RITS_API_KEY": self.credentials["api_key"]}
 
@@ -1761,8 +1828,10 @@ class RITSInferenceEngine(
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
         )
 
-    @staticmethod
-    def _get_model_name_for_endpoint(model_name: str):
+    @classmethod
+    def _get_model_name_for_endpoint(cls, model_name: str):
+        if model_name in cls.model_names_dict:
+            return cls.model_names_dict[model_name]
         return (
             model_name.split("/")[-1]
             .lower()
@@ -1805,7 +1874,7 @@ class TogetherAiInferenceEngine(
         from together import Together
         from together.types.models import ModelType
 
-        api_key_env_var_name = "TOGETHER_API_KEY"
+        api_key_env_var_name = "TOGETHER_API_KEY"  # pragma: allowlist secret
        api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run TogetherAiInferenceEngine."
@@ -1969,6 +2038,9 @@ class WMLInferenceEngineBase(
         deployment_id (str, optional):
             Deployment ID of a tuned model to be used for
             inference. Mutually exclusive with 'model_name'.
+        concurrency_limit (int):
+            Number of concurrent requests sent to a model. Default is 10,
+            which is also the maximum value for the generation.
         parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
             Defines inference parameters and their values. Deprecated attribute, please pass respective
             parameters directly to the respective class instead.
@@ -1977,6 +2049,7 @@ class WMLInferenceEngineBase(
     credentials: Optional[CredentialsWML] = None
     model_name: Optional[str] = None
     deployment_id: Optional[str] = None
+    concurrency_limit: int = 10
     label: str = "wml"
     _requirements_list = {
         "ibm_watsonx_ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
@@ -2230,11 +2303,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
 
     If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
 
-    Args:
-        concurrency_limit (int):
-            Number of concurrent requests sent to a model. Default is 10,
-            which is also the maximum value.
-
     Examples:
         .. code-block:: python
 
@@ -2258,8 +2326,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
         results = wml_inference.infer(dataset["test"])
     """
 
-    concurrency_limit: int = 10
-
     def verify(self):
         super().verify()
 
@@ -2511,6 +2577,32 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
         # images as SDK allows sending only one image per message.
         return [messages]
 
+    def _handle_async_requests(
+        self,
+        messages: List[List[Dict[str, Any]]],
+        params: Dict[str, Any],
+    ) -> List[Dict[str, Any]]:
+        async def handle_async_requests(start_idx, end_idx):
+            coroutines = [
+                self._model.achat(messages=messages[idx], params=params)
+                for idx in range(start_idx, end_idx)
+            ]
+            batch_results = await asyncio.gather(*coroutines)
+            return list(batch_results)
+
+        loop = asyncio.get_event_loop()
+        results = []
+
+        for batch_idx in range(0, len(messages), self.concurrency_limit):
+            batch_results = loop.run_until_complete(
+                handle_async_requests(
+                    batch_idx, min(batch_idx + self.concurrency_limit, len(messages))
+                )
+            )
+            results.extend(batch_results)
+
+        return results
+
     def _send_requests(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
@@ -2526,27 +2618,25 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
             output_type = "message"
             params["logprobs"] = False
 
-        final_results = []
-
-        for instance in dataset:
-            messages = self.to_messages(instance)
-
-            for message in messages:
-                result = self._model.chat(
-                    messages=message,
-                    params=params,
-                )
-
-                final_results.append(
-                    self.get_return_object(
-                        result["choices"][0][output_type]["content"],
-                        result,
-                        instance["source"],
-                        return_meta_data,
-                    )
-                )
-
-        return final_results
+        indexed_messages = [
+            (i, message)
+            for i in range(len(dataset))
+            for message in self.to_messages(dataset[i])
+        ]
+
+        results = self._handle_async_requests(
+            [msg[1] for msg in indexed_messages], params
+        )
+
+        return [
+            self.get_return_object(
+                result["choices"][0][output_type]["content"],
+                result,
+                dataset[idx[0]]["source"],
+                return_meta_data,
+            )
+            for result, idx in zip(results, indexed_messages)
+        ]
 
     def get_return_object(self, predict_result, result, input_text, return_meta_data):
         if return_meta_data:
@@ -2614,6 +2704,7 @@ def get_text_without_images(instance, image_token="<image>"):
 class LMMSEvalBaseInferenceEngine(
     InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, TorchDeviceMixin
 ):
+    label = "lmms-eval"
     model_type: str
     model_args: Dict[str, str]
     batch_size: int = 1
@@ -2623,6 +2714,9 @@ class LMMSEvalBaseInferenceEngine(
         "lmms_eval": "Install llms-eval package using 'pip install lmms-eval==0.2.4'",
     }
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_type, self.label)
+
     def prepare_engine(self):
         if not self.lazy_load:
             self._prepare_engine()
@@ -2798,6 +2892,11 @@ class VLLMParamsMixin(Artifact):
 
 
 class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin):
+    label="vllm"
+
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model, self.label)
+
     def prepare_engine(self):
         args = self.to_dict([VLLMParamsMixin])
         args.pop("model")
@@ -2883,6 +2982,9 @@ class LiteLLMInferenceEngine(
 
     _requirements_list: list = ["litellm", "tenacity", "tqdm", "diskcache"]
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model, self.label)
+
     def prepare_engine(self):
         # Initialize the token bucket rate limiter
         self._rate_limiter = AsyncTokenBucket(
@@ -2890,15 +2992,12 @@ class LiteLLMInferenceEngine(
             capacity=self.max_requests_per_second,
         )
         self.inference_type = "litellm"
-        import litellm
         from litellm import acompletion
-        from litellm.caching.caching import Cache
 
-        litellm.cache = Cache(type="disk")
 
         self._completion = acompletion
         # Initialize a semaphore to limit concurrency
-        self._semaphore = asyncio.Semaphore(self.max_requests_per_second)
+        self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))
 
     async def _infer_instance(
         self, index: int, instance: Dict[str, Any]
@@ -3010,28 +3109,34 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
     provider_specific_args: Optional[Dict[str, Dict[str,str]]] = None
 
     provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
-        "watsonx": {
-            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct",
-            "llama-3-70b-instruct": "watsonx/meta-llama/llama-3-70b-instruct",
-            "llama-3-1-70b-instruct": "watsonx/meta-llama/llama-3-1-70b-instruct",
-            "llama-3-3-70b-instruct": "watsonx/meta-llama/llama-3-3-70b-instruct",
-            "granite-3-8b-instruct": "watsonx/ibm/granite-3-8b-instruct",
-            "flan-t5-xxl": "watsonx/google/flan-t5-xxl",
-            "llama-3-2-1b-instruct": "watsonx/meta-llama/llama-3-2-1b-instruct",
-            "llama-3-2-11b-vision-instruct": "watsonx/meta-llama/llama-3-2-11b-vision-instruct",
-            "llama-3-2-90b-vision-instruct": "watsonx/meta-llama/llama-3-2-90b-vision-instruct",
-            "mistral-large-instruct": "watsonx/mistralai/mistral-large",
-        },
-        "watsonx-sdk": {
-            "llama-3-2-11b-vision-instruct": "meta-llama/llama-3-2-11b-vision-instruct",
-            "llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct",
-            "llama-3-70b-instruct": "meta-llama/llama-3-70b-instruct",
+        "watsonx-sdk": {  # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
+            "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
+            "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
+            "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
             "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
+            "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
+            "granite-guardian-3-8b": "ibm/granite-guardian-3-8b",
+            "granite-vision-3-2-2b": "ibm/granite-vision-3-2-2b",
+            "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
+            "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
+            "llama-3-1-405b-instruct": "meta-llama/llama-3-405b-instruct",
+            "llama-3-2-11b-vision-instruct": "meta-llama/llama-3-2-11b-vision-instruct",
+            "llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct",
+            "llama-3-2-3b-instruct": "meta-llama/llama-3-2-3b-instruct",
+            "llama-3-2-90b-vision-instruct": "meta-llama/llama-3-2-90b-vision-instruct",
+            "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
+            "llama-guard-3-11b-vision": "meta-llama/llama-guard-3-11b-vision",
+            "mistral-large-instruct": "mistralai/mistral-large",
+            "mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7b-instruct-v01",
         },
         "together-ai": {
             "llama-3-8b-instruct": "together_ai/meta-llama/Llama-3-8b-chat-hf",
             "llama-3-70b-instruct": "together_ai/meta-llama/Llama-3-70b-chat-hf",
+            "llama-3-1-8b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+            "llama-3-1-70b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            "llama-3-1-405b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
             "llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
+            "llama-3-3-70b-instruct": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"
         },
         "aws": {
             "llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0",
@@ -3040,6 +3145,12 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "ollama": {
             "llama-3-8b-instruct": "llama3:8b",
             "llama-3-70b-instruct": "llama3:70b",
+            "llama-3-1-8b-instruct": "llama3.1:8b",
+            "llama-3-1-70b-instruct": "llama3.1:70b",
+            "llama-3-1-405b-instruct": "llama3.1:405b",
+            "llama-3-2-1b-instruct": "llama3.2:1b",
+            "llama-3-2-3b-instruct": "llama3.2:3b",
+            "llama-3-3-70b-instruct": "llama3.3"
         },
         "bam": {
             "granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
@@ -3049,12 +3160,14 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         },
         "rits": {
             "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
+            "granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
             "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
             "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
+            "llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
+            "llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
             "llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
             "llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
             "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
-            "llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
             "mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
             "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
         },
@@ -3089,6 +3202,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
             "o1-preview": "azure/o1-preview",
             "gpt-4o-mini": "azure/gpt-4o-mini",
             "gpt-4o": "azure/gpt-4o",
+            "gpt-4o-2024-08-06": "azure/gpt-4o-2024-08-06",
             "gpt-4": "azure/gpt-4",
             "gpt-4-0314": "azure/gpt-4-0314",
             "gpt-4-0613": "azure/gpt-4-0613",
@@ -3133,6 +3247,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
             "mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
         },
     }
+    provider_model_map["watsonx"] = {k: f"watsonx/{v}" for k,v in provider_model_map["watsonx-sdk"].items()}
 
     _provider_to_base_class = {
         "watsonx": LiteLLMInferenceEngine,
@@ -3190,7 +3305,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
                 del args[param]
             else:
                 del args[param]
-        self.engine = cls(**args)
+        self.engine: InferenceEngine = cls(**args)
         self.data_classification_policy = self.engine.data_classification_policy
 
     def _infer(
@@ -3210,7 +3325,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
 
     This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
     """
-
+    label = "hf_option_selection"
     model_name: str
     batch_size: int
 
@@ -3218,6 +3333,9 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
         "transformers": "Install huggingface package using 'pip install --upgrade transformers"
     }
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
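The largest inference.py change adds a persistent, batched prediction cache to `InferenceEngine.infer()`: each instance is keyed by an md5 over its relevant fields plus the engine configuration, predictions are stored in a `diskcache.Cache`, and only cache misses are sent to `_infer()`. A simplified, self-contained sketch of the same pattern; the function names, field names, and cache directory are illustrative, not the library's actual API:

```python
import hashlib
import json
from itertools import islice

from diskcache import Cache

def batched(lst, n):
    # yield successive chunks of n items
    it = iter(lst)
    while batch := list(islice(it, n)):
        yield batch

cache = Cache("./inference_cache_demo")  # illustrative location

def cache_key(instance, engine_config):
    # key = instance fields that affect the prediction + the engine configuration
    record = {k: instance[k] for k in ("source", "task_data") if k in instance}
    record.update(engine_config)
    return hashlib.md5(json.dumps(record, sort_keys=True).encode()).hexdigest()

def infer_with_cache(instances, engine_config, infer_fn, batch_size=100):
    results = []
    for batch in batched(instances, batch_size):
        hits, misses = [], []
        for i, item in enumerate(batch):
            value = cache.get(cache_key(item, engine_config))
            if value is not None:
                hits.append((i, value))
            else:
                misses.append((i, item))
        # call the real engine only for the misses
        predictions = infer_fn([item for _, item in misses]) if misses else []
        for (i, item), prediction in zip(misses, predictions):
            cache[cache_key(item, engine_config)] = prediction
        # merge cached and fresh predictions back into the original order
        merged = sorted(hits + [(i, p) for (i, _), p in zip(misses, predictions)])
        results.extend(p for _, p in merged)
    return results
```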
llm_as_judge.py CHANGED
@@ -8,15 +8,12 @@ from .dict_utils import dict_get
8
  from .error_utils import UnitxtError
9
  from .inference import (
10
  InferenceEngine,
11
- OptionSelectingByLogProbsInferenceEngine,
12
  )
13
  from .llm_as_judge_chat_templates import direct_template_dict, pairwise_template_dict
14
  from .llm_as_judge_constants import (
15
  DIRECT_CRITERIA,
16
  EVALUATOR_TO_MODEL_ID,
17
  EVALUATORS_METADATA,
18
- INFERENCE_ENGINE_NAME_TO_CLASS,
19
- MODEL_RENAMINGS,
20
  PAIRWISE_CRITERIA,
21
  Criteria,
22
  CriteriaOption,
@@ -44,30 +41,50 @@ from .llm_as_judge_utils import (
44
  get_evaluator_metadata,
45
  get_parsed_context,
46
  rank_indexes,
47
- rename_model_if_required,
48
  )
49
  from .logging_utils import get_logger
50
  from .metrics import BulkInstanceMetric
51
  from .task import Task
52
  from .templates import Template
53
 
 
54
 
55
  class LLMJudge(BulkInstanceMetric):
 
 
 
 
 
56
  inference_engine: InferenceEngine
57
- # option_selection_strategy: OptionSelectionStrategyEnum = (
58
- # OptionSelectionStrategyEnum.PARSE_OUTPUT_TEXT
59
- # )
60
  evaluator_name: EvaluatorNameEnum = None
 
 
61
  check_positional_bias: bool = True
 
 
62
  context_fields: Union[str, List[str], Dict[str, str]] = ["context"]
63
- generate_summaries: bool = True
64
- format = "formats.chat_api"
65
- include_prompts_in_result: bool = False
 
 
 
 
 
 
 
 
66
  criteria_field: str = None
 
 
67
  criteria: Criteria = None
68
- logger = get_logger()
 
69
 
70
  def prepare(self):
 
71
  super().prepare()
72
  if isinstance(self.context_fields, str):
73
  self.context_fields = [self.context_fields]
@@ -78,10 +95,13 @@ class LLMJudge(BulkInstanceMetric):
78
 
79
  if self.evaluator_name is None:
80
  self.evaluator_name = self.inference_engine.get_engine_id()
81
- elif not isinstance(self.evaluator_name, EvaluatorNameEnum):
82
- self.evaluator_name = EvaluatorNameEnum[self.evaluator_name]
83
 
84
  def before_process_multi_stream(self):
 
 
 
 
 
85
  super().before_process_multi_stream()
86
  # We check the criteria here and not in verify(), because we want catalog
87
  # may contain a partially initialized object, and verify() method
@@ -93,6 +113,14 @@ class LLMJudge(BulkInstanceMetric):
93
  return
94
 
95
  def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
 
 
 
 
 
 
 
 
96
  return [
97
  get_parsed_context(
98
  {
@@ -110,6 +138,17 @@ class LLMJudge(BulkInstanceMetric):
110
  template: Template,
111
  previous_messages: Optional[List[Dict[str, str]]] = None,
112
  ):
 
 
 
 
 
 
 
 
 
 
 
113
  outputs_dataset = infer(
114
  instances,
115
  task=task,
@@ -129,6 +168,14 @@ class LLMJudge(BulkInstanceMetric):
129
  return (prompts, raw_predictions, predictions)
130
 
131
  def clean_results(self, results: Union[dict, list]):
 
 
 
 
 
 
 
 
132
  if isinstance(results, list):
133
  return [self.clean_results(x) for x in results]
134
  cleaned = {
@@ -143,13 +190,25 @@ class LLMJudge(BulkInstanceMetric):
143
  if not (isinstance(v, dict) and len(v) == 0)
144
  }
145
 
146
- def get_criterias(self, task_data, eval_count):
 
 
 
 
 
 
 
 
 
 
 
 
147
  if self.criteria is None:
148
  if self.criteria_field not in task_data[0]:
149
  raise UnitxtError(
150
  f"The criteria field `{self.criteria_field}` required for {__class__.__name__} is not found in instance. Perhaps you meant '{get_close_matches(self.criteria_field, task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
151
  )
152
- self.logger.info(
153
  f"Reading criteria from the task_data field '{self.criteria_field}'"
154
  )
155
  criterias = [
@@ -157,20 +216,31 @@ class LLMJudge(BulkInstanceMetric):
157
  for task_data_instance in task_data
158
  ]
159
  else:
160
- self.logger.info(
161
  "Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
162
  )
163
  criterias: List[Criteria] = [self.criteria] * eval_count
164
  unique_criteria_names = list({criteria.name for criteria in criterias})
165
 
166
- self.logger.info(f"Criteria names are '{', '.join(unique_criteria_names)}'")
167
  return criterias
168
 
169
 
170
  class LLMJudgeDirect(LLMJudge):
 
 
 
 
 
 
 
 
171
  criteria: CriteriaWithOptions = None
 
172
  main_score = "llm_as_judge"
 
173
  reduction_map = {"mean": ["llm_as_judge"]}
 
174
 
175
  def prepare(self):
176
  super().prepare()
@@ -200,7 +270,7 @@ class LLMJudgeDirect(LLMJudge):
200
  self.option_selection_task = Task(
201
  input_fields={
202
  "criteria_description": str,
203
- "score_option_instruction": str,
204
  "options": list,
205
  },
206
  reference_fields={},
@@ -209,6 +279,7 @@ class LLMJudgeDirect(LLMJudge):
209
  )
210
 
211
  def before_process_multi_stream(self):
 
212
  super().before_process_multi_stream()
213
  if self.criteria is not None and not isinstance(
214
  self.criteria, CriteriaWithOptions
@@ -218,34 +289,42 @@ class LLMJudgeDirect(LLMJudge):
218
  )
219
  return
220
 
221
- def get_parsed_criteria(self, criteria: CriteriaWithOptions):
 
 
 
 
 
 
 
 
 
 
 
 
222
  criteria_description = criteria.description
223
  criteria_option_names = [o.name for o in criteria.options]
224
 
225
- display_options_instruction = "Choose an answer:\n" + "\n".join(
226
  [
227
  f'- "{o.name}"{f" if {o.description}" if o.description != "" else ""}'
228
  for o in criteria.options
229
  ]
230
  )
231
- score_option_instruction = "".join(
232
- [f"Score {o.name}: {o.description}\n" for o in criteria.options]
233
- )
234
 
235
  return (
236
  criteria_description,
237
  criteria_option_names,
238
  display_options_instruction,
239
- score_option_instruction,
240
  )
241
 
242
- def set_main_score(self, criterias: List[CriteriaWithOptions]):
243
  unique_criteria_names = list({criteria.name for criteria in criterias})
244
  if len(unique_criteria_names) == 1 and criterias[0].name != "":
245
  self.main_score = "_".join(criterias[0].name.lower().split(" "))
246
  self.reduction_map = {"mean": [self.main_score]}
247
 
248
- def get_results(
249
  self,
250
  assessment_prompts,
251
  assessment_outputs,
@@ -289,6 +368,9 @@ class LLMJudgeDirect(LLMJudge):
289
  "summary": summarization_outputs[i]
290
  if self.generate_summaries
291
  else None,
 
 
 
292
  "prompts": {
293
  "assessment": assessment_prompts[i],
294
  "positional_bias_assessment": assessment_prompts[
@@ -332,14 +414,113 @@ class LLMJudgeDirect(LLMJudge):
332
  references: List[List[str]],
333
  predictions: List[str],
334
  task_data: List[Dict[str, Any]],
335
- ) -> dict:
336
- self.logger.info(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}'
338
  )
339
  evaluations_count = len(predictions)
340
  # TODO: find out how to serialize and deserialize enums
341
- criterias = self.get_criterias(task_data, evaluations_count)
342
- self.set_main_score(criterias)
343
  contexts = self.get_contexts(task_data)
344
  if self.check_positional_bias:
345
  criterias += [
@@ -355,14 +536,13 @@ class LLMJudgeDirect(LLMJudge):
355
  predictions += predictions
356
 
357
  parsed_criterias = [
358
- self.get_parsed_criteria(criteria) for criteria in criterias
359
  ]
360
 
361
  (
362
  criteria_description_list,
363
  criteria_option_names_list,
364
  display_options_instruction_list,
365
- score_option_instruction_list,
366
  ) = zip(*parsed_criterias)
367
 
368
  assessment_for_summaries_slice = slice(0, evaluations_count)
@@ -385,7 +565,7 @@ class LLMJudgeDirect(LLMJudge):
385
  assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
386
  assessment_instances, self.assessment_task, self.assessment_template
387
  )
388
- self.logger.info("The assessment was generated successfully.")
389
 
390
  summarization_prompts = None
391
  summarization_outputs = None
@@ -409,18 +589,22 @@ class LLMJudgeDirect(LLMJudge):
409
  self.summarization_task,
410
  self.summarization_template,
411
  )
412
- self.logger.info("The summary was generated successfully.")
413
 
414
  option_selection_instances = [
415
  {
416
  "criteria_description": criteria_description,
417
- "score_option_instruction": score_option_instruction,
418
  "options": criteria_option_names,
419
  "data_classification_policy": ["public"],
420
  }
421
- for criteria_description, score_option_instruction, criteria_option_names in zip(
 
 
 
 
422
  criteria_description_list,
423
- score_option_instruction_list,
424
  criteria_option_names_list,
425
  )
426
  ]
@@ -441,9 +625,9 @@ class LLMJudgeDirect(LLMJudge):
441
  self.option_selection_template,
442
  previous_messages,
443
  )
444
- self.logger.info("The selections were calculated successfully.")
445
 
446
- results = self.get_results(
447
  assessment_prompts,
448
  assessment_outputs,
449
  summarization_prompts,
@@ -454,15 +638,19 @@ class LLMJudgeDirect(LLMJudge):
454
  evaluations_count,
455
  criterias,
456
  )
 
457
  return self.clean_results(results)
458
 
459
 
460
  class LLMJudgePairwise(LLMJudge):
461
- reduction_map = {"mean": ["score"]}
462
  main_score = "1_winrate"
463
- prediction_type = List[str]
 
 
464
 
465
  def prepare(self):
 
466
  super().prepare()
467
  self.assessment_template = pairwise_template_dict["assessment"]
468
  self.summarization_template = pairwise_template_dict["summarization"]
@@ -501,6 +689,7 @@ class LLMJudgePairwise(LLMJudge):
501
  )
502
 
503
  def before_process_multi_stream(self):
 
504
  super().before_process_multi_stream()
505
  if self.criteria is not None and not isinstance(self.criteria, Criteria):
506
  raise Exception(
@@ -508,7 +697,7 @@ class LLMJudgePairwise(LLMJudge):
508
  )
509
  return
510
 
511
- def get_instance_results(
512
  self,
513
  instance_predictions: Dict[str, str],
514
  assessment_prompts,
@@ -520,8 +709,26 @@ class LLMJudgePairwise(LLMJudge):
520
  selections,
521
  contests_count,
522
  combination_indexes,
523
- criteria: Criteria,
524
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  response_names = list(instance_predictions.keys())
526
  per_response_results = {
527
  response_key: {
@@ -680,32 +887,479 @@ class LLMJudgePairwise(LLMJudge):
680
  for metric in single_result.keys():
681
  all_results[f"{response_name}_{metric}"] = single_result[metric]
682
 
683
- all_results["criteria"] = criteria.to_json()
684
  return self.clean_results(all_results)
685
 
686
- def parse_prediction_to_dict(self, prediction: Union[Dict[str, str], List[str]]):
687
- if isinstance(prediction, list):
688
- return {f"{key + 1}": value for key, value in enumerate(prediction)}
689
-
690
- raise Exception(
691
- f"Prediction may be a list or a dict. Instead got type {type(prediction)}"
 
 
 
 
 
 
 
 
 
692
  )
693
 
694
- def convert_predictions_to_dicts(
695
  self, predictions: Union[List[Dict[str, str]], List[str]]
696
  ):
697
- return [self.parse_prediction_to_dict(prediction) for prediction in predictions]
 
 
 
 
 
 
 
 
 
 
 
698
 
699
  def compute(
700
  self,
701
  references: List[List[str]],
702
  predictions: List[str],
703
  task_data: List[Dict[str, str]],
704
- ) -> dict:
705
- self.logger.info(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
707
  )
708
- predictions = self.convert_predictions_to_dicts(predictions)
 
709
  instances_count = len(predictions)
710
  self.reduction_map = {"mean": ["score"]}
711
  self.reduction_map["mean"].extend(
@@ -721,7 +1375,7 @@ class LLMJudgePairwise(LLMJudge):
721
  len(combination_indexes) for combination_indexes in combination_indexes_list
722
  ]
723
 
724
- self.logger.info(
725
  f"The evaluation will perform {sum(contests_count_list) * [1, 2][self.check_positional_bias]} ({' + '.join([f'{c * [1, 2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
726
  )
727
 
@@ -752,7 +1406,7 @@ class LLMJudgePairwise(LLMJudge):
752
  response_pairs_list.append(response_pairs)
753
  option_pairs_list.append(option_pairs)
754
 
755
- criterias = self.get_criterias(task_data, instances_count)
756
  contexts = self.get_contexts(task_data)
757
  if self.check_positional_bias:
758
  criterias.extend(criterias)
@@ -786,7 +1440,7 @@ class LLMJudgePairwise(LLMJudge):
786
  assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
787
  assessment_instances, self.assessment_task, self.assessment_template
788
  )
789
- self.logger.info("The assessment was generated successfully.")
790
 
791
  # the slices used to get the assessment for each summary generation instance
792
  # it will grab the whole assessment for a particular instance or half of it depending on the value of check_positional_bias
@@ -836,7 +1490,7 @@ class LLMJudgePairwise(LLMJudge):
836
  self.summarization_task,
837
  self.summarization_template,
838
  )
839
- self.logger.info("The summary was generated successfully.")
840
 
841
  score_option_instruction_list = [
842
  "".join(
@@ -884,7 +1538,7 @@ class LLMJudgePairwise(LLMJudge):
884
  )
885
  # Selections are of the form 'Response n', so we just keep n
886
  selections = [selection.split(" ")[-1] for selection in selections]
887
- self.logger.info("The selections were calculated successfully.")
888
  results = []
889
  slice_start = 0
890
  for i, incremental_contests_count in enumerate(incremental_contests_count_list):
@@ -897,7 +1551,7 @@ class LLMJudgePairwise(LLMJudge):
897
  (incremental_contests_count_list[i - 1] if i > 0 else 0)
898
  + incremental_contests_count,
899
  )
900
- instance_results = self.get_instance_results(
901
  predictions[i],
902
  assessment_prompts[sli],
903
  assessment_outputs[sli],
 
8
  from .error_utils import UnitxtError
9
  from .inference import (
10
  InferenceEngine,
 
11
  )
12
  from .llm_as_judge_chat_templates import direct_template_dict, pairwise_template_dict
13
  from .llm_as_judge_constants import (
14
  DIRECT_CRITERIA,
15
  EVALUATOR_TO_MODEL_ID,
16
  EVALUATORS_METADATA,
 
 
17
  PAIRWISE_CRITERIA,
18
  Criteria,
19
  CriteriaOption,
 
41
  get_evaluator_metadata,
42
  get_parsed_context,
43
  rank_indexes,
 
44
  )
45
  from .logging_utils import get_logger
46
  from .metrics import BulkInstanceMetric
47
  from .task import Task
48
  from .templates import Template
49
 
50
+ logger = get_logger(__name__)
51
 
52
  class LLMJudge(BulkInstanceMetric):
53
+ """A metric class to evaluate instances using LLM as a Judge.
54
+
55
+ Evaluations are performed in two steps. First, the LLM is asked to generate an assessment following a CoT approach based on the criteria. Then, the same LLM is asked to select one of the available options. A summary of the general assessment can be generated for easy consumption by end users.
56
+ """
57
+
58
  inference_engine: InferenceEngine
59
+ """The engine used for generating predictions in the different evaluation steps."""
60
+
 
61
  evaluator_name: EvaluatorNameEnum = None
62
+ """The name of the evaluator. It is used for score naming. If not provided `self.inference_engine.get_engine_id()` is used."""
63
+
64
  check_positional_bias: bool = True
65
+ """Flag to check for positional bias. Detecting for positional bias duplicates the amount of inference calls."""
66
+
67
  context_fields: Union[str, List[str], Dict[str, str]] = ["context"]
68
+ """Fields to be used as context. If a dict is provided, the keys are used as the final names in the prompts, while the values are used to access the context variable values in the `task_data` object."""
69
+
70
+ generate_summaries: bool = False
71
+ """Flag to generate summaries of the assessments. Defaults to `False`."""
72
+
73
+ format: str = "formats.chat_api"
74
+ """The format used for the inference. Defaults to `formats.chat_api` (only allowed value)."""
75
+
76
+ include_prompts_in_result: bool = True
77
+ """Flag to include prompts in the result. Defaults to `True`."""
78
+
79
  criteria_field: str = None
80
+ """The field specifying the evaluation criteria in the `task_data` object."""
81
+
82
  criteria: Criteria = None
83
+ """The criteria used for evaluation. If the `criteria_field` is provided, it will take precedence."""
84
+
85
 
86
  def prepare(self):
87
+ """Prepares the `LLMJudge` instance by setting up context fields and evaluator name."""
88
  super().prepare()
89
  if isinstance(self.context_fields, str):
90
  self.context_fields = [self.context_fields]
 
95
 
96
  if self.evaluator_name is None:
97
  self.evaluator_name = self.inference_engine.get_engine_id()
 
 
98
 
99
  def before_process_multi_stream(self):
100
+ """Checks the criteria-related fields correctness before processing multiple streams.
101
+
102
+ Raises:
103
+ UnitxtError: If both 'criteria' and 'criteria_field' are not set.
104
+ """
105
  super().before_process_multi_stream()
106
  # We check the criteria here and not in verify(), because we want catalog
107
  # may contain a partially initialized object, and verify() method
 
113
  return
114
 
115
  def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
116
+ """Extracts and parses context fields from task data.
117
+
118
+ Args:
119
+ task_data (List[Dict[str, Any]]): The task data containing context information.
120
+
121
+ Returns:
122
+ List[Dict[str, str]]: A list of parsed context dictionaries.
123
+ """
124
  return [
125
  get_parsed_context(
126
  {
 
138
  template: Template,
139
  previous_messages: Optional[List[Dict[str, str]]] = None,
140
  ):
141
+ """Performs an evaluation step by generating predictions for the given instances.
142
+
143
+ Args:
144
+ instances (list): The list of instances to evaluate.
145
+ task (Task): The task associated with the instances.
146
+ template (Template): The template used for generating predictions.
147
+ previous_messages (Optional[List[Dict[str, str]]]): Previous messages for context.
148
+
149
+ Returns:
150
+ Tuple[List[str], List[str], List[str]]: A tuple containing prompts, raw predictions, and processed predictions. Raw predictions differ from processed predictions only in the option-selection step, where the `processors.match_closest_option` postprocessor is applied.
151
+ """
152
  outputs_dataset = infer(
153
  instances,
154
  task=task,
 
168
  return (prompts, raw_predictions, predictions)
169
 
170
  def clean_results(self, results: Union[dict, list]):
171
+ """Cleans the results by removing `None` values and empty lists and dictionaries.
172
+
173
+ Args:
174
+ results (Union[dict, list]): The results to clean.
175
+
176
+ Returns:
177
+ Union[dict, list]: The cleaned results.
178
+ """
179
  if isinstance(results, list):
180
  return [self.clean_results(x) for x in results]
181
  cleaned = {
 
190
  if not (isinstance(v, dict) and len(v) == 0)
191
  }
192
 
193
+ def get_criteria(self, task_data, eval_count):
194
+ """Retrieves the evaluation criteria from the `criteria_field` or from `self`.
195
+
196
+ Args:
197
+ task_data (List[Dict[str, Any]]): The task data containing criteria information.
198
+ eval_count (int): The number of evaluations to perform.
199
+
200
+ Returns:
201
+ List[Criteria]: A list of criteria for evaluation.
202
+
203
+ Raises:
204
+ UnitxtError: If the criteria field is not found in the task data.
205
+ """
206
  if self.criteria is None:
207
  if self.criteria_field not in task_data[0]:
208
  raise UnitxtError(
209
  f"The criteria field `{self.criteria_field}` required for {__class__.__name__} is not found in instance. Perhaps you meant '{get_close_matches(self.criteria_field, task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
210
  )
211
+ logger.info(
212
  f"Reading criteria from the task_data field '{self.criteria_field}'"
213
  )
214
  criterias = [
 
216
  for task_data_instance in task_data
217
  ]
218
  else:
219
+ logger.info(
220
  "Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
221
  )
222
  criterias: List[Criteria] = [self.criteria] * eval_count
223
  unique_criteria_names = list({criteria.name for criteria in criterias})
224
 
225
+ logger.info(f"Criteria names are '{', '.join(unique_criteria_names)}'")
226
  return criterias
227
 
228
 
229
  class LLMJudgeDirect(LLMJudge):
230
+ """LLMJudgeDirect is a specialized evaluation metric that performs Direct Assessment using an LLM to score responses based on a predefined evaluation criteria.
231
+
232
+ Direct Assessment is an evaluation paradigm in which the LLM selects one of a
233
+ predefined set of options based on an assessment criterion. This approach can
234
+ be used for Likert-scale scoring (e.g., 1-5) or selecting from semantically
235
+ conditioned literals (e.g., Yes/No, Pass/Fail).
236
+ """
237
+
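+ To make the Direct Assessment inputs concrete, here is a hedged sketch of a Likert-style criteria expressed as plain data, mirroring the fields the judge reads below (a description, named options, and an `option_map` that turns the selected option into a numeric score); the criterion itself is hypothetical:
+
+ ```python
+ # Hypothetical criteria, shaped like CriteriaWithOptions (name, description, options, option_map).
+ answer_relevance = {
+     "name": "answer_relevance",
+     "description": "Does the response directly address the user's question?",
+     "options": [
+         {"name": "Excellent", "description": "fully addresses the question"},
+         {"name": "Could be Improved", "description": "partially addresses the question"},
+         {"name": "Bad", "description": "does not address the question"},
+     ],
+     "option_map": {"Excellent": 1.0, "Could be Improved": 0.5, "Bad": 0.0},
+ }
+ ```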
238
  criteria: CriteriaWithOptions = None
239
+ """The evaluation criteria, including a name, description, a predefined set of options and and option_map."""
240
  main_score = "llm_as_judge"
241
+ """The primary score name used in the results. By default, it will take the value of the criteria name (if only one criteria is being used for evaluation) or "llm_as_judge" otherwise."""
242
  reduction_map = {"mean": ["llm_as_judge"]}
243
+ """A mapping used for score aggregation. By default, it will take the value of `{'mean': [<default_main_score_name>]}`."""
244
 
245
  def prepare(self):
246
  super().prepare()
 
270
  self.option_selection_task = Task(
271
  input_fields={
272
  "criteria_description": str,
273
+ "display_options_instruction": str,
274
  "options": list,
275
  },
276
  reference_fields={},
 
279
  )
280
 
281
  def before_process_multi_stream(self):
282
+ """Ensures that the criteria is of type `CriteriaWithOptions`, raising an exception otherwise."""
283
  super().before_process_multi_stream()
284
  if self.criteria is not None and not isinstance(
285
  self.criteria, CriteriaWithOptions
 
289
  )
290
  return
291
 
292
+ def __get_parsed_criteria(self, criteria: CriteriaWithOptions):
293
+ """Extracts key information from the given criteria.
294
+
295
+ Args:
296
+ criteria (CriteriaWithOptions): The evaluation criteria.
297
+
298
+ Returns:
299
+ Tuple[str, List[str], str, str]:
300
+ - Criteria description.
301
+ - List of option names.
302
+ - Formatted instruction for displaying options.
303
+ - Instruction for scoring options.
304
+ """
305
  criteria_description = criteria.description
306
  criteria_option_names = [o.name for o in criteria.options]
307
 
308
+ display_options_instruction = "Choose an option:\n" + "\n".join(
309
  [
310
  f'- "{o.name}"{f" if {o.description}" if o.description != "" else ""}'
311
  for o in criteria.options
312
  ]
313
  )
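+ For the hypothetical criteria sketched earlier, the string built here would come out roughly as follows (a sketch, not captured output):
+
+ ```python
+ # Approximate rendering of display_options_instruction for the hypothetical three-option criteria.
+ expected = (
+     'Choose an option:\n'
+     '- "Excellent" if fully addresses the question\n'
+     '- "Could be Improved" if partially addresses the question\n'
+     '- "Bad" if does not address the question'
+ )
+ ```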
 
 
 
314
 
315
  return (
316
  criteria_description,
317
  criteria_option_names,
318
  display_options_instruction,
 
319
  )
320
 
321
+ def __set_main_score(self, criterias: List[CriteriaWithOptions]):
322
  unique_criteria_names = list({criteria.name for criteria in criterias})
323
  if len(unique_criteria_names) == 1 and criterias[0].name != "":
324
  self.main_score = "_".join(criterias[0].name.lower().split(" "))
325
  self.reduction_map = {"mean": [self.main_score]}
326
 
327
+ def __get_results(
328
  self,
329
  assessment_prompts,
330
  assessment_outputs,
 
368
  "summary": summarization_outputs[i]
369
  if self.generate_summaries
370
  else None,
371
+ "positional_bias_summary": summarization_outputs[i]
372
+ if self.generate_summaries and self.check_positional_bias
373
+ else None,
374
  "prompts": {
375
  "assessment": assessment_prompts[i],
376
  "positional_bias_assessment": assessment_prompts[
 
414
  references: List[List[str]],
415
  predictions: List[str],
416
  task_data: List[Dict[str, Any]],
417
+ ) -> List[Dict]:
418
+ r"""Performs direct assessment evaluation on the given predictions and references.
419
+
420
+ This method evaluates the quality of the predictions by calculating scores for each instance based on a criterion.
421
+
422
+ Returns:
423
+ -------
424
+ List[Dict]
425
+ A list of dictionaries containing the evaluation results for each instance. The results include the computed scores for each prediction. Each result will have the `score_name` as a prefix, which may be the criterion name if only one used, or "llm_as_judge" if several criteria were used.
426
+
427
+ Explanation of fields:
428
+
429
+ - `score`: a float representing the evaluation score for the response. The value is calculated from criteria.option_map[selected_option].
430
+ - `using_<evaluator_name>`: Equal to score.
431
+ - `positional_bias`: Boolean indicating whether the assessment detected positional bias. Its final value is selected_option != positional_bias_selected_option
432
+ - `selected_option`: The criteria option that the evaluator chose (e.g., "Could be Improved"). It is calculated by processing `option_selection_completion` using `processors.match_closest_option`
433
+ - `positional_bias_selected_option`: The criteria option that the evaluator chose when checking positional bias.
434
+ - `assessment`: The inference engine's generated text using the `prompts.assessment` prompt.
435
+ - `positional_bias_assessment`: The inference engine's generated text using the `prompts.positional_bias_assessment` prompt.
436
+ - `summary`: An LLM-generated summary of the assessment.
437
+ - `positional_bias_summary`: An LLM-generated summary of the positional bias assessment.
438
+ - `prompts`: A dictionary of prompts used in different stages of evaluation.
439
+ - `assessment`: The prompt used to instruct the model on how to assess the response.
440
+ - `positional_bias_assessment`: The prompt used to instruct the model on how to assess the response in the positional bias check.
441
+ - `summarization`: The prompt used to generate a summary of the assessment.
442
+ - `option_selection`: The prompt used to generate a final judgement.
443
+ - `positional_bias_option_selection`: The prompt used to generate a final judgement in the positional bias check.
444
+ - `option_selection_completion`: The inference engine's generated text using `prompts.option_selection`.
445
+ - `positional_bias_option_selection_completion`: The inference engine's generated text using `prompts.positional_bias_option_selection`.
446
+ - `criteria`: A JSON-like string representing the evaluation criteria's artifact.
447
+
448
+ Result example:
449
+
450
+ .. code-block:: python
451
+
452
+ [
453
+ {
454
+ "answer_relevance": 1,
455
+ "answer_relevance_using_granite3.0-2b_litellm": 1,
456
+ "answer_relevance_positional_bias": false,
457
+ "answer_relevance_selected_option": "Could be Improved",
458
+ "answer_relevance_positional_bias_selected_option": "Could be Improved",
459
+ "answer_relevance_assessment": "To assess the quality of the response, l...",
460
+ "answer_relevance_positional_bias_assessment": "To assess the quality of the response, l...",
461
+ "answer_relevance_summary": "A response about apprenticeships during ...",
462
+ "answer_relevance_positional_bias_summary": "A response about apprenticeships during ...",
463
+ "answer_relevance_prompts": {
464
+ "assessment": [
465
+ {
466
+ "role": "user",
467
+ "content": "You are presented with a response gener..."
468
+ }
469
+ ],
470
+ "positional_bias_assessment": [
471
+ {
472
+ "role": "user",
473
+ "content": "You are presented with a response gener..."
474
+ }
475
+ ],
476
+ "summarization": [
477
+ {
478
+ "role": "user",
479
+ "content": "Transform the following assessment into ..."
480
+ }
481
+ ],
482
+ "option_selection": [
483
+ {
484
+ "content": "You are presented with a response gener...",
485
+ "role": "user"
486
+ },
487
+ {
488
+ "content": "To assess the quality of the response, l...",
489
+ "role": "assistant"
490
+ },
491
+ {
492
+ "content": "Now consider the evaluation criteria and...",
493
+ "role": "user"
494
+ }
495
+ ],
496
+ "posional_bias_option_selection": [
497
+ {
498
+ "content": "You are presented with a response gener...",
499
+ "role": "user"
500
+ },
501
+ {
502
+ "content": "To assess the quality of the response, l...",
503
+ "role": "assistant"
504
+ },
505
+ {
506
+ "content": "Now consider the evaluation criteria and...",
507
+ "role": "user"
508
+ }
509
+ ]
510
+ },
511
+ "answer_relevance_option_selection_completion": "Could be Improved",
512
+ "answer_relevance_positional_bias_option_selection_completion": "Could be Improved",
513
+ "answer_relevance_criteria": "{ \"__type__\": \"criteria_with_options..."
514
+ }
515
+ ]
516
+ """
517
+ logger.info(
518
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}"'
519
  )
520
  evaluations_count = len(predictions)
521
  # TODO: find out how to serialize and deserialize enums
522
+ criterias = self.get_criteria(task_data, evaluations_count)
523
+ self.__set_main_score(criterias)
524
  contexts = self.get_contexts(task_data)
525
  if self.check_positional_bias:
526
  criterias += [
 
536
  predictions += predictions
537
 
538
  parsed_criterias = [
539
+ self.__get_parsed_criteria(criteria) for criteria in criterias
540
  ]
541
 
542
  (
543
  criteria_description_list,
544
  criteria_option_names_list,
545
  display_options_instruction_list,
 
546
  ) = zip(*parsed_criterias)
547
 
548
  assessment_for_summaries_slice = slice(0, evaluations_count)
 
565
  assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
566
  assessment_instances, self.assessment_task, self.assessment_template
567
  )
568
+ logger.info("The assessment was generated successfully.")
569
 
570
  summarization_prompts = None
571
  summarization_outputs = None
 
589
  self.summarization_task,
590
  self.summarization_template,
591
  )
592
+ logger.info("The summary was generated successfully.")
593
 
594
  option_selection_instances = [
595
  {
596
  "criteria_description": criteria_description,
597
+ "display_options_instruction": display_options_instruction,
598
  "options": criteria_option_names,
599
  "data_classification_policy": ["public"],
600
  }
601
+ for (
602
+ criteria_description,
603
+ display_options_instruction,
604
+ criteria_option_names
605
+ ) in zip(
606
  criteria_description_list,
607
+ display_options_instruction_list,
608
  criteria_option_names_list,
609
  )
610
  ]
 
625
  self.option_selection_template,
626
  previous_messages,
627
  )
628
+ logger.info("The selections were calculated successfully.")
629
 
630
+ results = self.__get_results(
631
  assessment_prompts,
632
  assessment_outputs,
633
  summarization_prompts,
 
638
  evaluations_count,
639
  criterias,
640
  )
641
+
642
  return self.clean_results(results)
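+ As a usage note, the per-instance dictionaries documented above can be consumed directly; a hedged sketch with hypothetical values mirroring the documented example:
+
+ ```python
+ # Hypothetical single result entry, mirroring the documented fields.
+ result = {
+     "answer_relevance": 1,
+     "answer_relevance_selected_option": "Could be Improved",
+     "answer_relevance_positional_bias": False,
+ }
+ score = result["answer_relevance"]                   # numeric score from criteria.option_map
+ option = result["answer_relevance_selected_option"]  # the option the judge picked
+ biased = result["answer_relevance_positional_bias"]  # True if reversing the order flipped the choice
+ print(score, option, biased)
+ ```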
643
 
644
 
645
  class LLMJudgePairwise(LLMJudge):
646
+ """A judge for pairwise comparison evaluations, where two or more responses are compared to determine which one is preferred based on a criterion."""
647
  main_score = "1_winrate"
648
+ """The main score metric for pairwise evaluation. By default, its value is `1_winrate`, and will take the value of the winrate of the first system."""
649
+ reduction_map = {"mean": ["score"]}
650
+ """A mapping specifying how scores should be reduced. By default, it will be `{'main': ['score']}`"""
651
 
652
  def prepare(self):
653
+ """Prepares the pairwise comparison by initializing the necessary templates and tasks. These tasks will be used to assess, summarize, and select options from candidate responses."""
654
  super().prepare()
655
  self.assessment_template = pairwise_template_dict["assessment"]
656
  self.summarization_template = pairwise_template_dict["summarization"]
 
689
  )
690
 
691
  def before_process_multi_stream(self):
692
+ """Verifies that the criteria is of the correct type before processing the multi-stream data."""
693
  super().before_process_multi_stream()
694
  if self.criteria is not None and not isinstance(self.criteria, Criteria):
695
  raise Exception(
 
697
  )
698
  return
699
 
700
+ def __get_instance_results(
701
  self,
702
  instance_predictions: Dict[str, str],
703
  assessment_prompts,
 
709
  selections,
710
  contests_count,
711
  combination_indexes,
712
+ criterion: Criteria,
713
  ):
714
+ """Computes the results for each instance by comparing the responses and calculating metrics such as winrate, ranking, and the responses overall performance. This method processes assessment, summarization, and option selection outputs to track contest results, positional bias, and winrate.
715
+
716
+ Args:
717
+ instance_predictions (Dict[str, str]): The predictions for each response.
718
+ assessment_prompts (List[str]): The prompts for the assessment task.
719
+ assessment_outputs (List[str]): The results from the assessment task.
720
+ summarization_prompts (List[str]): The prompts for the summarization task.
721
+ summarization_outputs (List[str]): The results from the summarization task.
722
+ option_selection_prompts (List[str]): The prompts for the option selection task.
723
+ option_selection_outputs (List[str]): The results from the option selection task.
724
+ selections (List[str]): The selections made during the pairwise comparison.
725
+ contests_count (int): The total number of contests that were run.
726
+ combination_indexes (List[Tuple[int, int]]): The indexes of the response pairs that were compared.
727
+ criterion (Criteria): The criterion used to assess the responses.
728
+
729
+ Returns:
730
+ dict: A dictionary containing the results for each response, including winrate, ranking, and other metrics.
731
+ """
732
  response_names = list(instance_predictions.keys())
733
  per_response_results = {
734
  response_key: {
 
887
  for metric in single_result.keys():
888
  all_results[f"{response_name}_{metric}"] = single_result[metric]
889
 
890
+ all_results["criteria"] = criterion.to_json()
891
  return self.clean_results(all_results)
892
 
893
+ def __parse_prediction_to_dict(self, predictions: Union[Dict[str, str], List[str]]):
894
+ """Converts a list or dictionary of predictions into a dictionary format.
895
+
896
+ Args:
897
+ predictions (Union[Dict[str, str], List[str]]): The prediction data to convert.
898
+
899
+ Returns:
900
+ dict: The prediction data in dictionary format.
901
+ """
902
+ if isinstance(predictions, list):
903
+ return {f"{key + 1}": value for key, value in enumerate(predictions)}
904
+ if isinstance(predictions, dict):
905
+ return predictions
906
+ raise UnitxtError(
907
+ f"Prediction may be a list or a dict. Instead got type {type(predictions)}"
908
  )
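+ The conversion is purely positional; an illustration with hypothetical responses:
+
+ ```python
+ # Hedged sketch: list predictions are keyed by their 1-based position, dicts pass through unchanged.
+ predictions = ["first system answer", "second system answer"]
+ as_dict = {f"{i + 1}": value for i, value in enumerate(predictions)}
+ assert as_dict == {"1": "first system answer", "2": "second system answer"}
+ ```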
909
 
910
+ def __convert_predictions_to_dicts(
911
  self, predictions: Union[List[Dict[str, str]], List[str]]
912
  ):
913
+ """Converts a list of predictions into a list of dictionaries.
914
+
915
+ Args:
916
+ predictions (Union[List[Dict[str, str]], List[str]]): The predictions to convert.
917
+
918
+ Returns:
919
+ List[dict]: A list of predictions in dictionary format.
920
+ """
921
+ return [self.__parse_prediction_to_dict(prediction) for prediction in predictions]
922
+
923
+ def __set_main_score(self, predictions: List[Dict[str, str]]):
924
+ self.main_score = f"{next(iter(predictions[0].keys()))}_winrate"
925
 
926
  def compute(
927
  self,
928
  references: List[List[str]],
929
  predictions: List[str],
930
  task_data: List[Dict[str, str]],
931
+ ) -> List[Dict]:
932
+ r"""Executes the pairwise comparison evaluation, including assessment, summarization, and option selection. It computes the winrate and ranking for the responses.
933
+
934
+ Args:
935
+ references (List[List[str]]): A list of reference responses for comparison.
936
+ predictions (List[str]): A list of predicted responses.
937
+ task_data (List[Dict[str, str]]): Task data to be used for evaluation.
938
+
939
+ Returns:
940
+ -------
941
+ List[Dict[str,Dict]]
942
+ The results of the evaluation, including winrate, ranking, and other metrics.
943
+
944
+ For each instance result, the following metrics are included per response/system. Each metric name is prefixed with the system's name (if predictions were provided as a list of dicts) or with its index, starting from 1 (if predictions were provided as a list of lists).
945
+
946
+ All of these fields are arrays of length `len(systems) - 1`. For any index `i`, `contest_results[i]` is the outcome of this response's contest against `compared_to[i]`.
947
+
948
+ Explanation of fields:
949
+
950
+ - `summaries`: A list of LLM-generated summaries explaining the comparison results for each response.
951
+ - `contest_results`: A list of boolean values indicating whether the response won in each comparison.
952
+ - `selections`: A list of the selected system names, representing the preferred response in each comparison.
953
+ - `compared_to`: A list of system names that were compared against the given response.
954
+ - `assessments`: A list of LLM-generated assessments explaining the reasoning behind the evaluation results.
955
+ - `positional_bias_assessments`: A list of LLM-generated assessments focused on detecting positional bias in the evaluation.
956
+ - `option_selection_outputs`: A list of response names selected as the best choice based on the evaluation.
957
+ - `positional_bias`: A list of boolean values indicating whether positional bias was detected in the contest.
958
+ - `positional_bias_selection`: A list of response names representing the selected option when considering positional bias.
959
+ - `prompts`: A dictionary of prompts used in different stages of evaluation.
960
+ - `assessment`: The prompt used to instruct the model on how to assess the responses.
961
+ - `positional_bias_assessment`: The prompt used to instruct the model on how to assess positional bias.
962
+ - `option_selection`: The prompt used to guide the model in selecting the best response.
963
+ - `positional_bias_option_selection`: The prompt used for selecting the best response while checking for positional bias.
964
+ - `summary`: The prompt used to generate a summary of the assessment.
965
+ - `winrate`: A float representing the proportion of comparisons the response won.
966
+ - `llm_as_judge`: Equal to `winrate`.
967
+ - `ranking`: An integer representing the ranking position of the response based on the evaluation results. Best is 1.
968
+ - `response_name`: A string identifying the response in the evaluation.
969
+
970
+ Result example:
971
+
972
+ .. code-block:: python
973
+
974
+ [
975
+ {
976
+ "system1_contest_results": [
977
+ true,
978
+ true
979
+ ],
980
+ "system1_selections": [
981
+ "system1",
982
+ "system1"
983
+ ],
984
+ "system1_compared_to": [
985
+ "system2",
986
+ "system3"
987
+ ],
988
+ "system1_assessments": [
989
+ "To determine the better response accordi...",
990
+ "To determine the better response accordi..."
991
+ ],
992
+ "system1_positional_bias_assessments": [
993
+ "To determine the better response accordi...",
994
+ "To determine the better response accordi..."
995
+ ],
996
+ "system1_option_selection_outputs": [
997
+ "system1",
998
+ "system1"
999
+ ],
1000
+ "system1_positional_bias": [
1001
+ false,
1002
+ false
1003
+ ],
1004
+ "system1_positional_bias_selection": [
1005
+ "system1",
1006
+ "system1"
1007
+ ],
1008
+ "system1_prompts": {
1009
+ "assessment": [
1010
+ [
1011
+ {
1012
+ "role": "user",
1013
+ "content": "You are provided a pair of responses (Re..."
1014
+ }
1015
+ ],
1016
+ [
1017
+ {
1018
+ "role": "user",
1019
+ "content": "You are provided a pair of responses (Re..."
1020
+ }
1021
+ ]
1022
+ ],
1023
+ "positional_bias_assessment": [
1024
+ [
1025
+ {
1026
+ "role": "user",
1027
+ "content": "You are provided a pair of responses (Re..."
1028
+ }
1029
+ ],
1030
+ [
1031
+ {
1032
+ "role": "user",
1033
+ "content": "You are provided a pair of responses (Re..."
1034
+ }
1035
+ ]
1036
+ ],
1037
+ "option_selection": [
1038
+ [
1039
+ {
1040
+ "content": "You are provided a pair of responses (Re...",
1041
+ "role": "user"
1042
+ },
1043
+ {
1044
+ "content": "To determine the better response accordi...",
1045
+ "role": "assistant"
1046
+ },
1047
+ {
1048
+ "content": "Now considering the evaluation criteria,...",
1049
+ "role": "user"
1050
+ }
1051
+ ],
1052
+ [
1053
+ {
1054
+ "content": "You are provided a pair of responses (Re...",
1055
+ "role": "user"
1056
+ },
1057
+ {
1058
+ "content": "To determine the better response accordi...",
1059
+ "role": "assistant"
1060
+ },
1061
+ {
1062
+ "content": "Now considering the evaluation criteria,...",
1063
+ "role": "user"
1064
+ }
1065
+ ]
1066
+ ],
1067
+ "positional_bias_option_selection": [
1068
+ [
1069
+ {
1070
+ "content": "You are provided a pair of responses (Re...",
1071
+ "role": "user"
1072
+ },
1073
+ {
1074
+ "content": "To determine the better response accordi...",
1075
+ "role": "assistant"
1076
+ },
1077
+ {
1078
+ "content": "Now considering the evaluation criteria,...",
1079
+ "role": "user"
1080
+ }
1081
+ ],
1082
+ [
1083
+ {
1084
+ "content": "You are provided a pair of responses (Re...",
1085
+ "role": "user"
1086
+ },
1087
+ {
1088
+ "content": "To determine the better response accordi...",
1089
+ "role": "assistant"
1090
+ },
1091
+ {
1092
+ "content": "Now considering the evaluation criteria,...",
1093
+ "role": "user"
1094
+ }
1095
+ ]
1096
+ ]
1097
+ },
1098
+ "system1_winrate": 1.0,
1099
+ "system1_llm_as_judge": 1.0,
1100
+ "system1_ranking": 1,
1101
+ "system1_response_name": "system1",
1102
+ "system2_contest_results": [
1103
+ false,
1104
+ true
1105
+ ],
1106
+ "system2_selections": [
1107
+ "system1",
1108
+ "system2"
1109
+ ],
1110
+ "system2_compared_to": [
1111
+ "system1",
1112
+ "system3"
1113
+ ],
1114
+ "system2_assessments": [
1115
+ "To determine the better response accordi...",
1116
+ "To determine the better response accordi..."
1117
+ ],
1118
+ "system2_positional_bias_assessments": [
1119
+ "To determine the better response accordi...",
1120
+ "To determine the better response accordi..."
1121
+ ],
1122
+ "system2_option_selection_outputs": [
1123
+ "system1",
1124
+ "system2"
1125
+ ],
1126
+ "system2_positional_bias": [
1127
+ false,
1128
+ false
1129
+ ],
1130
+ "system2_positional_bias_selection": [
1131
+ "system1",
1132
+ "system2"
1133
+ ],
1134
+ "system2_prompts": {
1135
+ "assessment": [
1136
+ [
1137
+ {
1138
+ "role": "user",
1139
+ "content": "You are provided a pair of responses (Re..."
1140
+ }
1141
+ ],
1142
+ [
1143
+ {
1144
+ "role": "user",
1145
+ "content": "You are provided a pair of responses (Re..."
1146
+ }
1147
+ ]
1148
+ ],
1149
+ "positional_bias_assessment": [
1150
+ [
1151
+ {
1152
+ "role": "user",
1153
+ "content": "You are provided a pair of responses (Re..."
1154
+ }
1155
+ ],
1156
+ [
1157
+ {
1158
+ "role": "user",
1159
+ "content": "You are provided a pair of responses (Re..."
1160
+ }
1161
+ ]
1162
+ ],
1163
+ "option_selection": [
1164
+ [
1165
+ {
1166
+ "content": "You are provided a pair of responses (Re...",
1167
+ "role": "user"
1168
+ },
1169
+ {
1170
+ "content": "To determine the better response accordi...",
1171
+ "role": "assistant"
1172
+ },
1173
+ {
1174
+ "content": "Now considering the evaluation criteria,...",
1175
+ "role": "user"
1176
+ }
1177
+ ],
1178
+ [
1179
+ {
1180
+ "content": "You are provided a pair of responses (Re...",
1181
+ "role": "user"
1182
+ },
1183
+ {
1184
+ "content": "To determine the better response accordi...",
1185
+ "role": "assistant"
1186
+ },
1187
+ {
1188
+ "content": "Now considering the evaluation criteria,...",
1189
+ "role": "user"
1190
+ }
1191
+ ]
1192
+ ],
1193
+ "positional_bias_option_selection": [
1194
+ [
1195
+ {
1196
+ "content": "You are provided a pair of responses (Re...",
1197
+ "role": "user"
1198
+ },
1199
+ {
1200
+ "content": "To determine the better response accordi...",
1201
+ "role": "assistant"
1202
+ },
1203
+ {
1204
+ "content": "Now considering the evaluation criteria,...",
1205
+ "role": "user"
1206
+ }
1207
+ ],
1208
+ [
1209
+ {
1210
+ "content": "You are provided a pair of responses (Re...",
1211
+ "role": "user"
1212
+ },
1213
+ {
1214
+ "content": "To determine the better response accordi...",
1215
+ "role": "assistant"
1216
+ },
1217
+ {
1218
+ "content": "Now considering the evaluation criteria,...",
1219
+ "role": "user"
1220
+ }
1221
+ ]
1222
+ ]
1223
+ },
1224
+ "system2_winrate": 0.5,
1225
+ "system2_llm_as_judge": 0.5,
1226
+ "system2_ranking": 2,
1227
+ "system2_response_name": "system2",
1228
+ "system3_contest_results": [
1229
+ false,
1230
+ false
1231
+ ],
1232
+ "system3_selections": [
1233
+ "system1",
1234
+ "system2"
1235
+ ],
1236
+ "system3_compared_to": [
1237
+ "system1",
1238
+ "system2"
1239
+ ],
1240
+ "system3_assessments": [
1241
+ "To determine the better response accordi...",
1242
+ "To determine the better response accordi..."
1243
+ ],
1244
+ "system3_positional_bias_assessments": [
1245
+ "To determine the better response accordi...",
1246
+ "To determine the better response accordi..."
1247
+ ],
1248
+ "system3_option_selection_outputs": [
1249
+ "system1",
1250
+ "system2"
1251
+ ],
1252
+ "system3_positional_bias": [
1253
+ false,
1254
+ false
1255
+ ],
1256
+ "system3_positional_bias_selection": [
1257
+ "system1",
1258
+ "system2"
1259
+ ],
1260
+ "system3_prompts": {
1261
+ "assessment": [
1262
+ [
1263
+ {
1264
+ "role": "user",
1265
+ "content": "You are provided a pair of responses (Re..."
1266
+ }
1267
+ ],
1268
+ [
1269
+ {
1270
+ "role": "user",
1271
+ "content": "You are provided a pair of responses (Re..."
1272
+ }
1273
+ ]
1274
+ ],
1275
+ "positional_bias_assessment": [
1276
+ [
1277
+ {
1278
+ "role": "user",
1279
+ "content": "You are provided a pair of responses (Re..."
1280
+ }
1281
+ ],
1282
+ [
1283
+ {
1284
+ "role": "user",
1285
+ "content": "You are provided a pair of responses (Re..."
1286
+ }
1287
+ ]
1288
+ ],
1289
+ "option_selection": [
1290
+ [
1291
+ {
1292
+ "content": "You are provided a pair of responses (Re...",
1293
+ "role": "user"
1294
+ },
1295
+ {
1296
+ "content": "To determine the better response accordi...",
1297
+ "role": "assistant"
1298
+ },
1299
+ {
1300
+ "content": "Now considering the evaluation criteria,...",
1301
+ "role": "user"
1302
+ }
1303
+ ],
1304
+ [
1305
+ {
1306
+ "content": "You are provided a pair of responses (Re...",
1307
+ "role": "user"
1308
+ },
1309
+ {
1310
+ "content": "To determine the better response accordi...",
1311
+ "role": "assistant"
1312
+ },
1313
+ {
1314
+ "content": "Now considering the evaluation criteria,...",
1315
+ "role": "user"
1316
+ }
1317
+ ]
1318
+ ],
1319
+ "positional_bias_option_selection": [
1320
+ [
1321
+ {
1322
+ "content": "You are provided a pair of responses (Re...",
1323
+ "role": "user"
1324
+ },
1325
+ {
1326
+ "content": "To determine the better response accordi...",
1327
+ "role": "assistant"
1328
+ },
1329
+ {
1330
+ "content": "Now considering the evaluation criteria,...",
1331
+ "role": "user"
1332
+ }
1333
+ ],
1334
+ [
1335
+ {
1336
+ "content": "You are provided a pair of responses (Re...",
1337
+ "role": "user"
1338
+ },
1339
+ {
1340
+ "content": "To determine the better response accordi...",
1341
+ "role": "assistant"
1342
+ },
1343
+ {
1344
+ "content": "Now considering the evaluation criteria,...",
1345
+ "role": "user"
1346
+ }
1347
+ ]
1348
+ ]
1349
+ },
1350
+ "system3_winrate": 0.0,
1351
+ "system3_llm_as_judge": 0.0,
1352
+ "system3_ranking": 3,
1353
+ "system3_response_name": "system3",
1354
+ "criteria": "{ \"__type__\": \"criteria\", \"name\"..."
1355
+ }
1356
+ ]
1357
+ """
1358
+ logger.info(
1359
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
1360
  )
1361
+ predictions = self.__convert_predictions_to_dicts(predictions)
1362
+ self.__set_main_score(predictions)
1363
  instances_count = len(predictions)
1364
  self.reduction_map = {"mean": ["score"]}
1365
  self.reduction_map["mean"].extend(
 
1375
  len(combination_indexes) for combination_indexes in combination_indexes_list
1376
  ]
1377
 
1378
+ logger.info(
1379
  f"The evaluation will perform {sum(contests_count_list) * [1, 2][self.check_positional_bias]} ({' + '.join([f'{c * [1, 2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
1380
  )
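+ The count logged here follows directly from the pairing: each instance with `n` candidate responses yields `n * (n - 1) / 2` contests, doubled when positional bias is checked. A small arithmetic sketch (values hypothetical):
+
+ ```python
+ # Hedged sketch of the comparison count for one instance with three candidate responses.
+ from math import comb
+
+ n_systems = 3
+ check_positional_bias = True
+ contests = comb(n_systems, 2)                        # 3 pairwise contests
+ calls = contests * (2 if check_positional_bias else 1)
+ assert (contests, calls) == (3, 6)
+ ```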
1381
 
 
1406
  response_pairs_list.append(response_pairs)
1407
  option_pairs_list.append(option_pairs)
1408
 
1409
+ criterias = self.get_criteria(task_data, instances_count)
1410
  contexts = self.get_contexts(task_data)
1411
  if self.check_positional_bias:
1412
  criterias.extend(criterias)
 
1440
  assessment_prompts, assessment_outputs, _ = self.perform_evaluation_step(
1441
  assessment_instances, self.assessment_task, self.assessment_template
1442
  )
1443
+ logger.info("The assessment was generated successfully.")
1444
 
1445
  # the slices used to get the assessment for each summary generation instance
1446
  # it will grab the whole assessment for a particular instance or half of it depending on the value of check_positional_bias
 
1490
  self.summarization_task,
1491
  self.summarization_template,
1492
  )
1493
+ logger.info("The summary was generated successfully.")
1494
 
1495
  score_option_instruction_list = [
1496
  "".join(
 
1538
  )
1539
  # Selections are of the form 'Response n', so we just keep n
1540
  selections = [selection.split(" ")[-1] for selection in selections]
1541
+ logger.info("The selections were calculated successfully.")
1542
  results = []
1543
  slice_start = 0
1544
  for i, incremental_contests_count in enumerate(incremental_contests_count_list):
 
1551
  (incremental_contests_count_list[i - 1] if i > 0 else 0)
1552
  + incremental_contests_count,
1553
  )
1554
+ instance_results = self.__get_instance_results(
1555
  predictions[i],
1556
  assessment_prompts[sli],
1557
  assessment_outputs[sli],
llm_as_judge_chat_templates.py CHANGED
@@ -29,11 +29,13 @@ Assessment: {assessment}
29
  Summary:"""
30
  ),
31
  "answer": InputOutputTemplate(
32
- input_format="""Now consider the evaluation criteria and choose a final answer. Only include the chosen answer in the response.
 
33
  ###Evaluation criteria:
34
  {criteria_description}
35
- {score_option_instruction}
36
- The selected answer is: """,
 
37
  postprocessors=["processors.match_closest_option"],
38
  ),
39
  }
 
29
  Summary:"""
30
  ),
31
  "answer": InputOutputTemplate(
32
+ input_format="""Now based on the assessment, choose a criteria option. Only include the chosen option in the response. If the assessment already contains a selected option, choose that option. Don't contradict the assessment's selected option.
33
+
34
  ###Evaluation criteria:
35
  {criteria_description}
36
+ {display_options_instruction}
37
+
38
+ The selected criteria option is: """,
39
  postprocessors=["processors.match_closest_option"],
40
  ),
41
  }
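The `processors.match_closest_option` postprocessor referenced by this template snaps the raw completion onto one of the allowed option names. A hedged, standalone sketch of that idea (not the actual implementation), using difflib:

```python
# Illustrative only: snap a free-form completion to the closest allowed option name.
import difflib

options = ["Excellent", "Could be Improved", "Bad"]
completion = "could be improved."
closest = difflib.get_close_matches(completion.strip(". ").title(), options, n=1, cutoff=0.0)[0]
assert closest == "Could be Improved"
```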
llm_as_judge_constants.py CHANGED
@@ -3,10 +3,6 @@ from enum import Enum
3
  from typing import Dict, List, Optional
4
 
5
  from .artifact import Artifact
6
- from .inference import (
7
- LiteLLMInferenceEngine,
8
- RITSInferenceEngine,
9
- )
10
 
11
 
12
  class OptionSelectionStrategyEnum(str, Enum):
@@ -68,13 +64,13 @@ class EvaluatorTypeEnum(str, Enum):
68
 
69
  class EvaluatorNameEnum(str, Enum):
70
  MIXTRAL8_7b = "Mixtral8-7b"
71
- MIXTRAL8_22b = "Mixtral8-22b"
72
  MIXTRAL_LARGE = "Mixtral Large"
73
  LLAMA3_8B = "Llama3-8b"
74
  LLAMA3_1_405B = "Llama3.1-405b"
75
  LLAMA3_1_8B = "Llama3.1-8b"
76
  LLAMA3_1_70B = "Llama3.1-70b"
77
  LLAMA3_2_3B = "Llama3.2-3b"
 
78
  PROMETHEUS = "Prometheus"
79
  GPT4 = "GPT-4o"
80
  O1_PREVIEW = "o1-Preview"
@@ -84,53 +80,33 @@ class EvaluatorNameEnum(str, Enum):
84
  GRANITE3_8B = "Granite3.0-8b"
85
  GRANITE3_1_2B = "Granite3.1-2b"
86
  GRANITE3_1_8B = "Granite3.1-8b"
 
87
 
88
 
89
  class ModelProviderEnum(str, Enum):
90
  WATSONX = "watsonx"
91
  OPENAI = "openai"
92
  RITS = "rits"
93
- AZURE_OPENAI = "azure_openai"
94
 
95
 
96
  EVALUATOR_TO_MODEL_ID = {
97
- EvaluatorNameEnum.MIXTRAL8_7b: "mistralai/mixtral-8x7b-instruct-v01",
98
- EvaluatorNameEnum.MIXTRAL8_22b: "mistralai/mixtral-8x22B-instruct-v0.1",
99
- EvaluatorNameEnum.MIXTRAL_LARGE: "mistralai/mistral-large",
100
- EvaluatorNameEnum.LLAMA3_1_405B: "meta-llama/llama-3-405b-instruct",
101
- EvaluatorNameEnum.LLAMA3_1_8B: "meta-llama/llama-3-1-8b-instruct",
102
- EvaluatorNameEnum.LLAMA3_1_70B: "meta-llama/llama-3-1-70b-instruct",
103
- EvaluatorNameEnum.LLAMA3_2_3B: "meta-llama/llama-3-2-3b-instruct",
104
- EvaluatorNameEnum.PROMETHEUS: "kaist-ai/prometheus-8x7b-v2",
105
  EvaluatorNameEnum.GPT4: "gpt-4o-2024-08-06",
106
- EvaluatorNameEnum.O1_PREVIEW: "o1-preview-2024-09-12",
107
- EvaluatorNameEnum.O1_MINI: "o1-mini-2024-09-12",
108
- EvaluatorNameEnum.GRANITE_13B: "ibm/granite-13b-instruct-v2",
109
- EvaluatorNameEnum.GRANITE3_2B: "ibm/granite-3-2b-instruct",
110
- EvaluatorNameEnum.GRANITE3_8B: "ibm/granite-3-8b-instruct",
111
- EvaluatorNameEnum.GRANITE3_1_2B: "ibm/granite-3.1-2b-instruct",
112
- EvaluatorNameEnum.GRANITE3_1_8B: "ibm/granite-3.1-8b-instruct",
113
  }
114
 
115
- MODEL_RENAMINGS = {
116
- ModelProviderEnum.RITS: {
117
- "meta-llama/llama-3-1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
118
- "mistralai/mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
119
- "ibm/granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
120
- "ibm/granite-3.1-8b-instruct": "ibm-granite/granite-3.1-8b-instruct",
121
- "meta-llama/llama-3-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
122
- "mistralai/mistral-large": "mistralai/mistral-large-instruct-2407",
123
- },
124
- }
125
-
126
- INFERENCE_ENGINE_NAME_TO_CLASS = {
127
- ModelProviderEnum.WATSONX: LiteLLMInferenceEngine,
128
- ModelProviderEnum.OPENAI: LiteLLMInferenceEngine,
129
- ModelProviderEnum.RITS: RITSInferenceEngine,
130
- ModelProviderEnum.AZURE_OPENAI: LiteLLMInferenceEngine,
131
- }
132
-
133
-
134
  class EvaluatorMetadata:
135
  name: EvaluatorNameEnum
136
  providers: List[ModelProviderEnum]
@@ -145,10 +121,6 @@ EVALUATORS_METADATA = [
145
  EvaluatorNameEnum.MIXTRAL8_7b,
146
  [ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
147
  ),
148
- EvaluatorMetadata(
149
- EvaluatorNameEnum.MIXTRAL8_22b,
150
- [ModelProviderEnum.RITS],
151
- ),
152
  EvaluatorMetadata(
153
  EvaluatorNameEnum.MIXTRAL_LARGE,
154
  [ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
@@ -161,6 +133,10 @@ EVALUATORS_METADATA = [
161
  EvaluatorNameEnum.GRANITE3_1_8B,
162
  [ModelProviderEnum.RITS],
163
  ),
 
 
 
 
164
  EvaluatorMetadata(
165
  EvaluatorNameEnum.GPT4,
166
  [ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE_OPENAI],
@@ -185,6 +161,10 @@ EVALUATORS_METADATA = [
185
  EvaluatorNameEnum.LLAMA3_1_405B,
186
  [ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
187
  ),
 
 
 
 
188
  ]
189
 
190
  ################################ Direct Assessment Criterias ################################
 
3
  from typing import Dict, List, Optional
4
 
5
  from .artifact import Artifact
 
 
 
 
6
 
7
 
8
  class OptionSelectionStrategyEnum(str, Enum):
 
64
 
65
  class EvaluatorNameEnum(str, Enum):
66
  MIXTRAL8_7b = "Mixtral8-7b"
 
67
  MIXTRAL_LARGE = "Mixtral Large"
68
  LLAMA3_8B = "Llama3-8b"
69
  LLAMA3_1_405B = "Llama3.1-405b"
70
  LLAMA3_1_8B = "Llama3.1-8b"
71
  LLAMA3_1_70B = "Llama3.1-70b"
72
  LLAMA3_2_3B = "Llama3.2-3b"
73
+ LLAMA3_3_70B = "Llama3.3-70b"
74
  PROMETHEUS = "Prometheus"
75
  GPT4 = "GPT-4o"
76
  O1_PREVIEW = "o1-Preview"
 
80
  GRANITE3_8B = "Granite3.0-8b"
81
  GRANITE3_1_2B = "Granite3.1-2b"
82
  GRANITE3_1_8B = "Granite3.1-8b"
83
+ GRANITE3_2_8B = "Granite3.2-8b"
84
 
85
 
86
  class ModelProviderEnum(str, Enum):
87
  WATSONX = "watsonx"
88
  OPENAI = "openai"
89
  RITS = "rits"
90
+ AZURE_OPENAI = "azure"
91
 
92
 
93
  EVALUATOR_TO_MODEL_ID = {
94
+ EvaluatorNameEnum.MIXTRAL8_7b: "mixtral-8x7b-instruct-v01",
95
+ EvaluatorNameEnum.MIXTRAL_LARGE: "mistral-large-instruct",
96
+ EvaluatorNameEnum.LLAMA3_1_405B: "llama-3-1-405b-instruct",
97
+ EvaluatorNameEnum.LLAMA3_1_8B: "llama-3-1-8b-instruct",
98
+ EvaluatorNameEnum.LLAMA3_1_70B: "llama-3-1-70b-instruct",
99
+ EvaluatorNameEnum.LLAMA3_3_70B: "llama-3-3-70b-instruct",
 
 
100
  EvaluatorNameEnum.GPT4: "gpt-4o-2024-08-06",
101
+ EvaluatorNameEnum.O1_PREVIEW: "o1-preview",
102
+ EvaluatorNameEnum.O1_MINI: "o1-mini",
103
+ EvaluatorNameEnum.GRANITE3_2B: "granite-3-2b-instruct",
104
+ EvaluatorNameEnum.GRANITE3_8B: "granite-3-8b-instruct",
105
+ EvaluatorNameEnum.GRANITE3_1_2B: "granite-3-1-2b-instruct",
106
+ EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
107
+ EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
108
  }
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  class EvaluatorMetadata:
111
  name: EvaluatorNameEnum
112
  providers: List[ModelProviderEnum]
 
121
  EvaluatorNameEnum.MIXTRAL8_7b,
122
  [ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
123
  ),
 
 
 
 
124
  EvaluatorMetadata(
125
  EvaluatorNameEnum.MIXTRAL_LARGE,
126
  [ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
 
133
  EvaluatorNameEnum.GRANITE3_1_8B,
134
  [ModelProviderEnum.RITS],
135
  ),
136
+ EvaluatorMetadata(
137
+ EvaluatorNameEnum.GRANITE3_2_8B,
138
+ [ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
139
+ ),
140
  EvaluatorMetadata(
141
  EvaluatorNameEnum.GPT4,
142
  [ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE_OPENAI],
 
161
  EvaluatorNameEnum.LLAMA3_1_405B,
162
  [ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
163
  ),
164
+ EvaluatorMetadata(
165
+ EvaluatorNameEnum.LLAMA3_3_70B,
166
+ [ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
167
+ ),
168
  ]
169
 
170
  ################################ Direct Assessment Criterias ################################
llm_as_judge_utils.py CHANGED
@@ -2,10 +2,8 @@ from typing import Dict
2
 
3
  from .llm_as_judge_constants import (
4
  EVALUATORS_METADATA,
5
- MODEL_RENAMINGS,
6
  EvaluatorMetadata,
7
  EvaluatorNameEnum,
8
- ModelProviderEnum,
9
  )
10
 
11
 
@@ -32,13 +30,6 @@ def get_evaluator_metadata(
32
  raise ValueError(f"An evaluator with id {name} matched several models.")
33
  return evaluator_search[0]
34
 
35
-
36
- def rename_model_if_required(model_name: str, provider: ModelProviderEnum) -> str:
37
- if provider in MODEL_RENAMINGS and model_name in MODEL_RENAMINGS[provider]:
38
- return MODEL_RENAMINGS[provider][model_name]
39
- return model_name
40
-
41
-
42
  def rank_indexes(numbers):
43
  # Generate the initial list of indices
44
  indices = list(range(len(numbers)))
 
2
 
3
  from .llm_as_judge_constants import (
4
  EVALUATORS_METADATA,
 
5
  EvaluatorMetadata,
6
  EvaluatorNameEnum,
 
7
  )
8
 
9
 
 
30
  raise ValueError(f"An evaluator with id {name} matched several models.")
31
  return evaluator_search[0]
32
 
 
 
 
 
 
 
 
33
  def rank_indexes(numbers):
34
  # Generate the initial list of indices
35
  indices = list(range(len(numbers)))
loaders.py CHANGED
@@ -67,7 +67,7 @@ from huggingface_hub import HfApi
67
  from tqdm import tqdm
68
 
69
  from .dataclass import NonPositionalField
70
- from .error_utils import UnitxtError, UnitxtWarning
71
  from .fusion import FixedFusion
72
  from .logging_utils import get_logger
73
  from .operator import SourceOperator
@@ -80,19 +80,27 @@ from .utils import LRUCache, recursive_copy
80
  logger = get_logger()
81
  settings = get_settings()
82
 
 
 
 
 
83
  def hf_load_dataset(path: str, *args, **kwargs):
84
  if settings.hf_offline_datasets_path is not None:
85
  path = os.path.join(settings.hf_offline_datasets_path, path)
86
- return _hf_load_dataset(
87
- path,
88
- *args, **kwargs,
89
- download_config=DownloadConfig(
90
- max_retries=settings.loaders_max_retries,
91
- ),
92
- verification_mode="no_checks",
93
- trust_remote_code=settings.allow_unverified_code,
94
- download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
95
- )
 
 
 
 
96
 
97
  class Loader(SourceOperator):
98
  """A base class for all loaders.
@@ -288,26 +296,21 @@ class LoadHF(LazyLoader):
288
  if dataset is None:
289
  if streaming is None:
290
  streaming = self.is_streaming()
291
- try:
292
- dataset = hf_load_dataset(
293
- self.path,
294
- name=self.name,
295
- data_dir=self.data_dir,
296
- data_files=self.data_files,
297
- revision=self.revision,
298
- streaming=streaming,
299
- split=split,
300
- num_proc=self.num_proc,
301
- )
302
- except ValueError as e:
303
- if "trust_remote_code" in str(e):
304
- raise ValueError(
305
- f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
306
- ) from e
307
  self.__class__._loader_cache.max_size = settings.loader_cache_size
308
  if not disable_memory_caching:
309
  self.__class__._loader_cache[dataset_id] = dataset
310
- return self.__class__._loader_cache[dataset_id]
311
 
312
  def _maybe_set_classification_policy(self):
313
  if os.path.exists(self.path):
@@ -333,7 +336,9 @@ class LoadHF(LazyLoader):
333
  extract_on_the_fly=True,
334
  ),
335
  )
336
- except:
 
 
337
  UnitxtWarning(
338
  f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
339
  )
@@ -599,11 +604,11 @@ class LoadFromIBMCloud(Loader):
599
  load_ibm_cloud = LoadFromIBMCloud(
600
  endpoint_url_env='IBM_CLOUD_ENDPOINT',
601
  aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
602
- aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY',
603
  bucket_name='my-bucket'
604
  )
605
  multi_stream = load_ibm_cloud.process()
606
- """ # pragma: allowlist secret
607
 
608
  endpoint_url_env: str
609
  aws_access_key_id_env: str
 
67
  from tqdm import tqdm
68
 
69
  from .dataclass import NonPositionalField
70
+ from .error_utils import Documentation, UnitxtError, UnitxtWarning
71
  from .fusion import FixedFusion
72
  from .logging_utils import get_logger
73
  from .operator import SourceOperator
 
80
  logger = get_logger()
81
  settings = get_settings()
82
 
83
+ class UnitxtUnverifiedCodeError(UnitxtError):
84
+ def __init__(self, path):
85
+ super().__init__(f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE.", Documentation.SETTINGS)
86
+
87
  def hf_load_dataset(path: str, *args, **kwargs):
88
  if settings.hf_offline_datasets_path is not None:
89
  path = os.path.join(settings.hf_offline_datasets_path, path)
90
+ try:
91
+ return _hf_load_dataset(
92
+ path,
93
+ *args, **kwargs,
94
+ download_config=DownloadConfig(
95
+ max_retries=settings.loaders_max_retries,
96
+ ),
97
+ verification_mode="no_checks",
98
+ trust_remote_code=settings.allow_unverified_code,
99
+ download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
100
+ )
101
+ except ValueError as e:
102
+ if "trust_remote_code" in str(e):
103
+ raise UnitxtUnverifiedCodeError(path) from e
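+ For reference, the two escape hatches named in this error message are set like so (a hedged sketch; the programmatic route assumes the same settings object this module imports via `get_settings`):
+
+ ```python
+ # Either set the environment variable before loading...
+ import os
+ os.environ["UNITXT_ALLOW_UNVERIFIED_CODE"] = "True"
+
+ # ...or flip the setting programmatically (import path assumed from this repo's settings_utils module):
+ # from unitxt.settings_utils import get_settings
+ # get_settings().allow_unverified_code = True
+ ```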
104
 
105
  class Loader(SourceOperator):
106
  """A base class for all loaders.
 
296
  if dataset is None:
297
  if streaming is None:
298
  streaming = self.is_streaming()
299
+
300
+ dataset = hf_load_dataset(
301
+ self.path,
302
+ name=self.name,
303
+ data_dir=self.data_dir,
304
+ data_files=self.data_files,
305
+ revision=self.revision,
306
+ streaming=streaming,
307
+ split=split,
308
+ num_proc=self.num_proc,
309
+ )
 
 
 
 
 
310
  self.__class__._loader_cache.max_size = settings.loader_cache_size
311
  if not disable_memory_caching:
312
  self.__class__._loader_cache[dataset_id] = dataset
313
+ return dataset
314
 
315
  def _maybe_set_classification_policy(self):
316
  if os.path.exists(self.path):
 
336
  extract_on_the_fly=True,
337
  ),
338
  )
339
+ except Exception as e:
340
+ if "trust_remote_code" in str(e):
341
+ raise UnitxtUnverifiedCodeError(self.path) from e
342
  UnitxtWarning(
343
  f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
344
  )
 
604
  load_ibm_cloud = LoadFromIBMCloud(
605
  endpoint_url_env='IBM_CLOUD_ENDPOINT',
606
  aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
607
+ aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY', # pragma: allowlist secret
608
  bucket_name='my-bucket'
609
  )
610
  multi_stream = load_ibm_cloud.process()
611
+ """
612
 
613
  endpoint_url_env: str
614
  aws_access_key_id_env: str
metrics.py CHANGED
@@ -63,13 +63,10 @@ from .operator import (
63
  from .operators import ArtifactFetcherMixin, Copy, Set
64
  from .random_utils import get_seed
65
  from .settings_utils import get_settings
66
- from .sql_utils import get_db_connector
67
  from .stream import MultiStream, Stream
68
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
69
  from .utils import deep_copy, recursive_copy
70
 
71
- FINQA_HASH = "42430b8613082bb4b85d49210284135d"
72
-
73
  logger = get_logger()
74
  settings = get_settings()
75
 
@@ -127,13 +124,18 @@ def nan_mean(x):
127
 
128
  def nan_max(x):
129
  with warnings.catch_warnings():
130
- # final mean should be mean of scores, ignoring NaN, hence nanmax
131
- # but if the group function values is NaN for ALL values, nanmean throws a
132
- # RuntimeWarning that it is calculating the mean of an empty slice (with no non-Nans)
133
- # this is the desired behavior, but we want to avoid the warning here
134
  warnings.simplefilter("ignore", category=RuntimeWarning)
135
  return np.nanmax(x)
136
 
 
 
 
 
 
 
 
 
 
137
 
138
  class UpdateStream(InstanceOperator):
139
  update: dict
@@ -365,6 +367,43 @@ def new_random_generator():
365
  return np.random.default_rng(hash(get_seed()) & _max_32bit)
366
 
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  class ConfidenceIntervalMixin(Artifact):
369
  n_resamples: int = 1000
370
  confidence_level: float = 0.95
@@ -374,42 +413,41 @@ class ConfidenceIntervalMixin(Artifact):
374
  def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
375
  pass
376
 
377
- def get_statistic(self, data: List[Any], score_names: List[str]):
378
- def statistic_function(indices, axis=0):
379
- # indices might be a 1D or 2D array, depending on bootstrap internals
380
- # For simplicity, ensure we handle them as 1D.
381
- indices = np.atleast_1d(indices).astype(int)
382
-
383
- # Gather the subset
384
- sample = [data[i] for i in indices]
385
-
386
- # Compute metrics on this sample
387
- scores = self._sample_to_scores(sample)
388
-
389
- # Return them in consistent order
390
- return np.array([scores[m] for m in score_names])
391
-
392
- return statistic_function
393
 
394
  def bootstrap(self, data: List[Any], score_names: List[str]):
395
  if self.ci_score_names is not None:
396
  score_names = self.ci_score_names
397
 
398
- intervals = bootstrap(
399
- (np.arange(len(data)),),
400
- statistic=self.get_statistic(data, score_names),
401
- n_resamples=self.n_resamples,
402
- confidence_level=self.confidence_level,
403
- random_state=new_random_generator(),
404
- paired=False,
405
- vectorized=False, # set to True if your statistic function is vectorized
406
- method="BCa",
407
- ).confidence_interval
 
 
 
 
 
 
 
 
 
 
408
 
409
  result = {}
410
  for i, metric in enumerate(score_names):
411
- result[f"{metric}_ci_low"] = float(intervals.low[i])
412
- result[f"{metric}_ci_high"] = float(intervals.high[i])
 
 
 
 
 
413
 
414
  return result
415
 
@@ -2769,7 +2807,7 @@ class FinQAEval(InstanceMetric):
2769
  remote_url = "https://raw.githubusercontent.com/czyssrs/FinQA/dfc5b72c01ee17c442d28d5201b82a1f4e95d5af/code/evaluate/evaluate.py"
2770
  local_filepath = "/tmp/finqa_eval_script.py"
2771
  module_name = "finqa_eval"
2772
- hash_of_script = FINQA_HASH
2773
 
2774
  download_finqa_eval_script_file(remote_url, local_filepath, hash_of_script)
2775
  self.finqa_module = load_finqa_eval_module_from_file(
@@ -3375,25 +3413,83 @@ class CustomF1(GlobalMetric):
3375
  result["precision_macro"] = self.zero_division
3376
 
3377
 
3378
- class NER(CustomF1):
3379
- """F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
3380
 
3381
- prediction_type = List[Tuple[str, str]]
 
 
 
 
 
 
3382
 
3383
- def get_element_group(self, element, additional_input):
3384
- return element[1]
 
 
 
 
 
3385
 
3386
- def get_element_representation(self, element, additional_input):
3387
- return str(element)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3388
 
 
3389
 
3390
- class KeyValueExtraction(CustomF1):
3391
- """F1 Metrics that receives as input a list of (Key,Value) pairs."""
3392
 
3393
  prediction_type = List[Tuple[str, str]]
3394
 
3395
  def get_element_group(self, element, additional_input):
3396
- return element[0]
3397
 
3398
  def get_element_representation(self, element, additional_input):
3399
  return str(element)
@@ -6004,6 +6100,9 @@ class GraniteGuardianBase(InstanceMetric):
6004
  )
6005
 
6006
  def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
 
 
 
6007
  self.verify_granite_guardian_config(task_data)
6008
  self.set_main_score()
6009
 
@@ -6017,7 +6116,10 @@ class GraniteGuardianBase(InstanceMetric):
6017
  )
6018
  messages = self.process_input_fields(task_data)
6019
  prompt = self.get_prompt(messages)
6020
- result = self.inference_engine.infer_log_probs([{"source": prompt}])
 
 
 
6021
  generated_tokens_list = result[0]
6022
  label, prob_of_risk = self.parse_output(generated_tokens_list)
6023
  confidence_score = (
@@ -6030,6 +6132,7 @@ class GraniteGuardianBase(InstanceMetric):
6030
  f"{self.main_score}_prob_of_risk": prob_of_risk,
6031
  f"{self.main_score}_certainty": confidence_score,
6032
  f"{self.main_score}_label": label,
 
6033
  }
6034
  logger.debug(f"Results are ready:\n{result}")
6035
  return result
@@ -6042,7 +6145,7 @@ class GraniteGuardianBase(InstanceMetric):
6042
  generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list
6043
  ]
6044
  prob = self.get_probabilities(top_tokens_list)
6045
- prob_of_risk = prob[1]
6046
 
6047
  res = next(iter(generated_tokens_list))["text"].strip()
6048
 
@@ -6055,7 +6158,7 @@ class GraniteGuardianBase(InstanceMetric):
6055
 
6056
  return label, prob_of_risk
6057
 
6058
- def get_probabilities(self, top_tokens_list):
6059
  import torch
6060
 
6061
  safe_token_prob = 1e-50
@@ -6254,7 +6357,7 @@ class SQLExecutionAccuracy(InstanceMetric):
6254
  _requirements_list = ["sqlglot", "func_timeout"]
6255
 
6256
  @staticmethod
6257
- def compare_dfs_ignore_colnames(df1, df2):
6258
  """Compares two DataFrames based on row content, ignoring column names.
6259
 
6260
  Args:
@@ -6262,7 +6365,7 @@ class SQLExecutionAccuracy(InstanceMetric):
6262
  df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
6263
 
6264
  Returns:
6265
- True if the DataFrames have the same content (ignoring column names),
6266
  False otherwise.
6267
  """
6268
  df1.fillna(0, inplace=True)
@@ -6276,6 +6379,20 @@ class SQLExecutionAccuracy(InstanceMetric):
6276
 
6277
  return df1_rows_sorted == df2_rows_sorted
6278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6279
  @staticmethod
  def is_subset_ignore_colnames(df1, df2):
      """Checks if df1 is a subset of df2 based on row content, ignoring column names.

  from .operators import ArtifactFetcherMixin, Copy, Set
  from .random_utils import get_seed
  from .settings_utils import get_settings
  from .stream import MultiStream, Stream
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
  from .utils import deep_copy, recursive_copy

  logger = get_logger()
  settings = get_settings()

  def nan_max(x):
      with warnings.catch_warnings():
          warnings.simplefilter("ignore", category=RuntimeWarning)
          return np.nanmax(x)

+ def nan_std(x):
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore", category=RuntimeWarning)
+         result = np.nanstd(x)
+         try:
+             return float(result)
+         except:
+             return result
+

  class UpdateStream(InstanceOperator):
      update: dict

      return np.random.default_rng(hash(get_seed()) & _max_32bit)


+ class Statistic:
+     """Statistic for which the confidence interval is to be calculated.
+
+     `statistic` must be a callable that accepts ``len(data)`` samples
+     as separate arguments and returns the resulting statistic.
+     If `vectorized` is set ``True``,
+     `statistic` must also accept a keyword argument `axis` and be
+     vectorized to compute the statistic along the provided `axis`.
+     """
+
+     def __init__(self, data, score_names, scorer):
+         self.data = data
+         self.score_names = score_names
+         self.scorer = scorer
+         self._history = []
+
+     def __call__(self, indices, axis=0):
+         # indices might be a 1D or 2D array, depending on bootstrap internals.
+         # For simplicity, ensure we handle them as 1D.
+         indices = np.atleast_1d(indices).astype(int)
+
+         # Gather the subset
+         sample = [self.data[i] for i in indices]
+
+         # Compute metrics on this sample
+         scores = self.scorer(sample)
+
+         # Return them in consistent order
+         result = np.array([scores[m] for m in self.score_names])
+         self._history.append(result)
+         return result
+
+     def mean(self, idx):
+         return nan_mean([result[idx] for result in self._history])
+
+     def std(self, idx):
+         return nan_std([result[idx] for result in self._history])
+
  class ConfidenceIntervalMixin(Artifact):
      n_resamples: int = 1000
      confidence_level: float = 0.95

      def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
          pass

      def bootstrap(self, data: List[Any], score_names: List[str]):
          if self.ci_score_names is not None:
              score_names = self.ci_score_names

+         statistic = Statistic(data, score_names, self._sample_to_scores)
+         with warnings.catch_warnings():
+             warnings.filterwarnings(  # Ignore the error that arises when all sample scores are identical
+                 "ignore",
+                 message="invalid value encountered in divide",
+                 category=RuntimeWarning,
+             )
+
+             intervals = bootstrap(
+                 (np.arange(len(data)),),
+                 statistic=statistic,
+                 n_resamples=self.n_resamples,
+                 confidence_level=self.confidence_level,
+                 random_state=new_random_generator(),
+                 paired=False,
+                 vectorized=False,
+                 method="BCa",
+             ).confidence_interval

          result = {}
          for i, metric in enumerate(score_names):
+             high = intervals.high[i]
+             low = intervals.low[i]
+             if np.isnan(high) and np.isnan(low):
+                 if statistic.std(i) == 0:  # When all sample scores are identical, "BCa" fails (division by a std of 0)
+                     high = low = statistic.mean(i)  # In this case use the mean, as there is no variance
+             result[f"{metric}_ci_low"] = float(low)
+             result[f"{metric}_ci_high"] = float(high)

          return result
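The `Statistic` helper above exists so that `scipy.stats.bootstrap` can resample instance *indices* while an arbitrary scorer recomputes the requested scores on each resample; its recorded history backs the `mean`/`std` fallback used when BCa degenerates on identical scores. A minimal standalone sketch of the same index-resampling pattern, with a made-up accuracy scorer (not a unitxt API):

```python
import numpy as np
from scipy.stats import bootstrap

# Toy data: one dict per evaluated instance (a stand-in for unitxt instances).
data = [{"correct": c} for c in [1, 0, 1, 1, 0, 1, 1, 0, 1, 1]]

def scorer(sample):
    # Hypothetical metric recomputed on every bootstrap resample.
    return float(np.mean([d["correct"] for d in sample]))

def statistic(indices):
    # bootstrap() hands us a resample of the index array; map it back to instances.
    # (The Statistic class above does the same, but returns one value per score name.)
    return scorer([data[int(i)] for i in np.atleast_1d(indices)])

ci = bootstrap(
    (np.arange(len(data)),),  # resample indices, not the raw scores
    statistic=statistic,
    n_resamples=1000,
    confidence_level=0.95,
    vectorized=False,
    paired=False,
    method="BCa",
    random_state=np.random.default_rng(0),
).confidence_interval

print({"accuracy_ci_low": float(ci.low), "accuracy_ci_high": float(ci.high)})
```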
 
 
          remote_url = "https://raw.githubusercontent.com/czyssrs/FinQA/dfc5b72c01ee17c442d28d5201b82a1f4e95d5af/code/evaluate/evaluate.py"
          local_filepath = "/tmp/finqa_eval_script.py"
          module_name = "finqa_eval"
+         hash_of_script = "42430b8613082bb4b85d49210284135d"  # pragma: allowlist secret

          download_finqa_eval_script_file(remote_url, local_filepath, hash_of_script)
          self.finqa_module = load_finqa_eval_module_from_file(

              result["precision_macro"] = self.zero_division


+ class KeyValueExtraction(GlobalMetric):
+
+     prediction_type = Dict[str, str]
+     metric: Metric
+     single_reference_per_prediction = True
+     main_score = ""
+
+     def prepare(self):
+         super().prepare()
+         self.main_score = f"{self.metric.main_score}_micro"
+
+     def compute(
+         self,
+         references: List[List[Any]],
+         predictions: List[Any],
+         task_data: List[Dict],
+     ) -> dict:
+         references = [element[0] for element in references]
+
+         key_statistics = {}
+         all_reference_keys = set()
+         for reference in references:
+             all_reference_keys.update(list(reference.keys()))
+         for key in all_reference_keys:
+             key_statistics[key] = []
+
+         num_prediction_keys = 0
+         illegal_prediction_keys = 0
+         for reference, prediction in zip(references, predictions):
+             for key in all_reference_keys:
+                 if key not in reference and key not in prediction:
+                     continue
+                 if key in reference and key in prediction:
+                     multi_stream = MultiStream.from_iterables(
+                         {"test": [{"prediction": prediction[key], "references": [reference[key]]}]}
+                     )
+                     output_multi_stream = self.metric(multi_stream)
+                     output_stream = output_multi_stream["test"]
+                     score = next(iter(output_stream))["score"]["global"]["score"]
+                     key_statistics[key].append(score)
+                 else:
+                     key_statistics[key].append(0.0)
+
+             for key in prediction.keys():
+                 num_prediction_keys += 1
+                 if key not in all_reference_keys:
+                     illegal_prediction_keys += 1
+
+         result = {}
+
+         average = 0
+         total = 0
+
+         weighted_average = 0
+         for key in key_statistics:
+             mean_for_key = numpy.mean(key_statistics[key])
+             num = len(key_statistics[key])
+             total += num
+             average += mean_for_key
+             weighted_average += mean_for_key * num
+             result[f"{self.metric.main_score}_{key}"] = mean_for_key
+
+         result[f"{self.metric.main_score}_micro"] = weighted_average / total
+         result[f"{self.metric.main_score}_macro"] = average / len(key_statistics)
+         if num_prediction_keys != 0:
+             result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 1 - 1.0 * illegal_prediction_keys / num_prediction_keys
+         else:
+             result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 0
+
+         return result
+
+ class NER(CustomF1):
+     """F1 metric that receives as input a list of (Entity, EntityType) pairs."""

      prediction_type = List[Tuple[str, str]]

      def get_element_group(self, element, additional_input):
+         return element[1]

      def get_element_representation(self, element, additional_input):
          return str(element)
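To make `KeyValueExtraction`'s bookkeeping concrete, here is a small hand-worked sketch (plain Python, with exact string match standing in for the configurable inner metric) of how per-key scores roll up into the micro, macro, and legal-keys figures:

```python
import numpy as np

references = [{"name": "IBM", "year": "1911"}, {"name": "NASA"}]
predictions = [{"name": "IBM", "year": "1913"}, {"name": "NASA", "hq": "DC"}]

# Per-key score lists, collected over every instance where the key appears
# in the reference or the prediction (exact match as the inner metric).
key_scores = {"name": [1.0, 1.0], "year": [0.0]}

micro = sum(sum(v) for v in key_scores.values()) / sum(len(v) for v in key_scores.values())
macro = float(np.mean([np.mean(v) for v in key_scores.values()]))

# "hq" is not a reference key, so 1 of the 4 predicted keys is illegal.
legal_keys_in_predictions = 1 - 1 / 4

print(micro)                      # 2/3: weighted by how often each key is scored
print(macro)                      # (1.0 + 0.0) / 2 = 0.5
print(legal_keys_in_predictions)  # 0.75
```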
 
          )

      def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
+         # TODO replace with logic inside verify_granite_guardian_config and process_input_fields
+         task_data["prediction"] = prediction
+
          self.verify_granite_guardian_config(task_data)
          self.set_main_score()

          )
          messages = self.process_input_fields(task_data)
          prompt = self.get_prompt(messages)
+         data_classification_policy = task_data.get("metadata", {}).get("data_classification_policy")
+
+         result = self.inference_engine.infer_log_probs([{"source": prompt, "data_classification_policy": data_classification_policy}])
+
          generated_tokens_list = result[0]
          label, prob_of_risk = self.parse_output(generated_tokens_list)
          confidence_score = (

              f"{self.main_score}_prob_of_risk": prob_of_risk,
              f"{self.main_score}_certainty": confidence_score,
              f"{self.main_score}_label": label,
+             f"{self.main_score}_prompt": prompt,
          }
          logger.debug(f"Results are ready:\n{result}")
          return result

              generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list
          ]
          prob = self.get_probabilities(top_tokens_list)
+         prob_of_risk = prob[1].item()

          res = next(iter(generated_tokens_list))["text"].strip()

          return label, prob_of_risk

+     def get_probabilities(self, top_tokens_list) -> Tuple[np.float32, np.float32]:
          import torch

          safe_token_prob = 1e-50

      _requirements_list = ["sqlglot", "func_timeout"]

      @staticmethod
+     def compare_dfs_ignore_colnames_ordered_rows(df1, df2):
          """Compares two DataFrames based on row content, ignoring column names.

          Args:
              df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
              df2 (pd.DataFrame): Pandas DataFrame 2 to compare.

          Returns:
+             True if the DataFrames have the same ordered rows (ignoring column names),
              False otherwise.
          """
          df1.fillna(0, inplace=True)

          return df1_rows_sorted == df2_rows_sorted

+     @staticmethod
+     def compare_dfs_ignore_colnames_unordered_rows(df1, df2):
+         """Compares two DataFrames based on row content, ignoring row order and column names.
+
+         Args:
+             df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
+             df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
+
+         Returns:
+             True if the DataFrames have the same content (ignoring column names and row order),
+             False otherwise.
+         """
+         return set(map(tuple, df1.to_numpy())) == set(map(tuple, df2.to_numpy()))
+
      @staticmethod
      def is_subset_ignore_colnames(df1, df2):
          """Checks if df1 is a subset of df2 based on row content, ignoring column names.

@@ -6343,6 +6460,7 @@ class SQLExecutionAccuracy(InstanceMetric):
          import time

          from func_timeout import func_timeout
+         from func_timeout.exceptions import FunctionTimedOut

          from .sql_utils import sqlglot_optimized_equivalence

@@ -6358,6 +6476,9 @@ class SQLExecutionAccuracy(InstanceMetric):
              )
              end_time = time.perf_counter()
              gold_sql_runtime = end_time - start_time
+         except FunctionTimedOut as e:
+             pred_error = f"Timeout error executing gold SQL: {e}"
+             logger.warning(pred_error)
          except Exception as e:
              gold_error = f"Error executing gold SQL: {e}"
          if gold_error is not None:

@@ -6389,10 +6510,10 @@ class SQLExecutionAccuracy(InstanceMetric):
                  gold_sql_runtime,
                  0,
                  0,
-                 1,
+                 0,
                  0,
                  gold_df.to_json(),
-                 gold_df.to_json(),
+                 "",
                  "",
              )
          if predicted_sql.lower().strip() == gold_sql.lower().strip():

@@ -6417,6 +6538,9 @@ class SQLExecutionAccuracy(InstanceMetric):
              )
              end_time = time.perf_counter()
              pred_sql_runtime = end_time - start_time
+         except FunctionTimedOut as e:
+             pred_error = f"Timeout error executing predicted SQL: {e}"
+             logger.info(pred_error)
          except Exception as e:
              pred_error = f"Error executing predicted SQL: {e}"
              logger.info(pred_error)

@@ -6445,9 +6569,20 @@ class SQLExecutionAccuracy(InstanceMetric):
              pred_res = pred_res["results"]
          predicted_df = pd.DataFrame(pred_res)

-         execution_result = (
-             1 if self.compare_dfs_ignore_colnames(predicted_df, gold_df) else 0
-         )
+         if "ORDER BY" in gold_sql.upper():
+             execution_result = (
+                 1
+                 if self.compare_dfs_ignore_colnames_ordered_rows(predicted_df, gold_df)
+                 else 0
+             )
+         else:
+             execution_result = (
+                 1
+                 if self.compare_dfs_ignore_colnames_unordered_rows(
+                     predicted_df, gold_df
+                 )
+                 else 0
+             )

          subset_non_empty_execution_result = 0
          non_empty_execution_result = 0

@@ -6473,6 +6608,8 @@ class SQLExecutionAccuracy(InstanceMetric):
          )

      def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
+         from .sql_utils import get_db_connector
+
          predicted_sql = prediction
          execution_result: float = 0.0
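The `ORDER BY` branch makes execution accuracy order-sensitive only when the gold query actually constrains row order. A simplified sketch of the two comparison modes on toy DataFrames (the committed helpers additionally fill NaNs and normalize each row; column names differ on purpose):

```python
import pandas as pd

gold_df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
pred_df = pd.DataFrame({"c1": [3, 1, 2], "c2": ["z", "x", "y"]})  # same rows, different order

def same_rows_ordered(df1, df2):
    # Simplified stand-in for compare_dfs_ignore_colnames_ordered_rows.
    return df1.to_numpy().tolist() == df2.to_numpy().tolist()

def same_rows_unordered(df1, df2):
    # Mirrors compare_dfs_ignore_colnames_unordered_rows.
    return set(map(tuple, df1.to_numpy())) == set(map(tuple, df2.to_numpy()))

print(same_rows_ordered(pred_df, gold_df))    # False -> wrong for ORDER BY queries
print(same_rows_unordered(pred_df, gold_df))  # True  -> correct otherwise
```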
 
schema.py CHANGED
@@ -67,8 +67,7 @@ def load_chat_source(chat_str):
          )
      return chat

-
- def loads_instance(batch):
+ def loads_batch(batch):
      if (
          "source" in batch
          and isinstance(batch["source"][0], str)
@@ -86,6 +85,24 @@ def loads_instance(batch):
          batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
      return batch

+ def loads_instance(instance):
+     if (
+         "source" in instance
+         and isinstance(instance["source"], str)
+         and (
+             instance["source"].startswith('[{"role":')
+             or instance["source"].startswith('[{"content":')
+         )
+     ):
+         instance["source"] = load_chat_source(instance["source"])
+     if (
+         not settings.task_data_as_text
+         and "task_data" in instance
+         and isinstance(instance["task_data"], str)
+     ):
+         instance["task_data"] = json.loads(instance["task_data"])
+     return instance
+

  class FinalizeDataset(InstanceOperatorValidator):
      group_by: List[List[str]]
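The split keeps two decoders: `loads_batch` for the column-oriented dicts that `datasets` hands to batched transforms, and `loads_instance` for a single example. A rough, hedged sketch of the shapes involved and of the `task_data` decoding both perform (field values are made up):

```python
import json

# Batched form: every field maps to a list, one entry per example.
batch = {
    "source": ['[{"role": "user", "content": "What is 2+2?"}]'],
    "task_data": ['{"answer": "4"}'],
}

# Per-instance form: the same fields hold single values.
instance = {
    "source": '[{"role": "user", "content": "What is 2+2?"}]',
    "task_data": '{"answer": "4"}',
}

# Roughly what the two helpers do with task_data (ignoring the settings flag
# and the chat-source parsing shown in the diff above):
decoded_batch = {**batch, "task_data": [json.loads(d) for d in batch["task_data"]]}
decoded_instance = {**instance, "task_data": json.loads(instance["task_data"])}

print(decoded_batch["task_data"][0]["answer"])  # 4
print(decoded_instance["task_data"]["answer"])  # 4
```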
serializers.py CHANGED
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Union
  from .dataclass import AbstractField, Field
  from .operators import InstanceFieldOperator
  from .settings_utils import get_constants
- from .sql_utils import get_db_connector
  from .type_utils import isoftype, to_type_string
  from .types import (
      Dialog,
@@ -203,5 +202,7 @@ class SQLDatabaseAsSchemaSerializer(SingleTypeSerializer):
      serialized_type = SQLDatabase

      def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
+         from .sql_utils import get_db_connector
+
          connector = get_db_connector(value["db_type"])(value)
          return connector.get_table_schema()
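Moving the `get_db_connector` import into `serialize` means `serializers.py` no longer pulls in `sql_utils` at module load time; the dependency is only resolved when a SQL schema is actually serialized, which is a common way to sidestep import cycles and optional dependencies. A generic sketch of the pattern (names below are illustrative, not unitxt APIs):

```python
def serialize_schema(value: dict) -> str:
    # Deferred import: resolved on first call instead of at module import time,
    # so importing this module cannot drag in (or cycle with) the helper module.
    from json import dumps  # stand-in for `from .sql_utils import get_db_connector`

    return dumps(value, indent=2)

print(serialize_schema({"db_type": "sqlite", "tables": ["users"]}))
```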
settings_utils.py CHANGED
@@ -159,6 +159,7 @@ if Settings.is_uninitilized():
      settings.hf_offline_datasets_path = None
      settings.hf_offline_metrics_path = None
      settings.hf_offline_models_path = None
+     settings.inference_engine_cache_path = "./inference_engine_cache/"

  if Constants.is_uninitilized():
      constants = Constants()
sql_utils.py CHANGED
@@ -1,4 +1,7 @@
+ import functools
  import glob
+ import hashlib
+ import json
  import os
  import re
  import sqlite3
@@ -16,6 +19,14 @@ from .types import SQLDatabase

  logger = get_logger()

+ # Check if caching is enabled via environment variable
+ CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")
+
+ # Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
+ MAX_CACHE_SIZE = os.getenv("MAX_CACHE_SIZE", 10 * 1024**3)
+
+ _cache_instance = None
+

  class DatabaseConnector(ABC):
      """Abstract base class for database connectors."""
@@ -23,7 +34,7 @@ class DatabaseConnector(ABC):
      def __init__(self, db_config: SQLDatabase):
          self.db_config = db_config
          self.databases_folder = os.path.join(
-             os.environ.get("UNITXT_TEXT2SQL_CACHE", "cache/text2sql"), "databases"
+             os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
          )
          os.makedirs(self.databases_folder, exist_ok=True)

@@ -187,6 +198,177 @@ class InMemoryDatabaseConnector(DatabaseConnector):
              conn.close()


+ def get_cache():
+     """Returns a singleton cache instance, initializing it if necessary."""
+     global _cache_instance
+     if _cache_instance is None:
+         _cache_instance = Cache()
+     return _cache_instance
+
+
+ def generate_cache_key(*args, **kwargs):
+     """Generate a stable hashable cache key for various input types.
+
+     :param args: Positional arguments of the function.
+     :param kwargs: Keyword arguments of the function.
+     :return: A hashed key as a string.
+     """
+     try:
+         # Convert args and kwargs to a JSON string (sorted to ensure consistency)
+         serialized = json.dumps(
+             {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
+         )
+     except TypeError:
+         # Fallback for non-serializable objects
+         serialized = repr((args, kwargs))
+
+     # Hash the serialized data
+     return hashlib.md5(serialized.encode()).hexdigest()
+
+
+ class Cache:
+     """A class that provides disk-based caching functionality for a given function."""
+
+     def __init__(self):
+         """Initializes the cache.
+
+         If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
+         cache is created using `diskcache`.
+         """
+         if CACHE_LOCATION:
+             try:
+                 import diskcache
+
+                 # Ensure the cache directory exists
+                 os.makedirs(CACHE_LOCATION, exist_ok=True)
+
+                 # Create a global diskcache Cache instance
+                 self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
+                 logger.info(f"Caching enabled at {CACHE_LOCATION}")
+             except ImportError as e:
+                 raise ImportError(
+                     "UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
+                     "Please install diskcache `pip install diskcache` "
+                     "or unset UNITXT_CACHE_LOCATION."
+                 ) from e
+         else:
+             self.cache = None  # Disable caching
+
+     def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
+         if not self.cache or no_cache:
+             logger.info(f"Bypassing cache for key: {key}")
+             return compute_fn()
+
+         if refresh and key in self.cache:
+             logger.info(f"Refreshing cache for key: {key}")
+             del self.cache[key]
+
+         if key in self.cache:
+             logger.info(f"Cache hit for key: {key}")
+             return self.cache[key]
+
+         logger.info(f"Cache miss for key: {key}. Computing value...")
+         result = compute_fn()
+         self.cache[key] = result
+         logger.info(f"Stored result in cache for key: {key}")
+         return result
+
+     async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
+         if not self.cache or no_cache:
+             logger.info(f"Bypassing cache for key: {key}")
+             return await compute_fn()
+
+         if refresh and key in self.cache:
+             logger.info(f"Refreshing cache for key: {key}")
+             del self.cache[key]
+
+         if key in self.cache:
+             logger.info(f"Cache hit for key: {key}")
+             return self.cache[key]
+
+         logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
+         result = await compute_fn()
+         self.cache[key] = result
+         logger.info(f"Stored result in cache for key: {key}")
+         return result
+
+     def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
+         def decorator(func):
+             @functools.wraps(func)
+             def wrapper(*args, **kwargs):
+                 if not self.cache or no_cache:
+                     logger.info(f"Bypassing cache for function: {func.__name__}")
+                     return func(*args, **kwargs)
+
+                 key = key_func(func.__name__, *args, **kwargs)
+
+                 if refresh and key in self.cache:
+                     logger.info(
+                         f"Refreshing cache for function: {func.__name__}, key: {key}"
+                     )
+                     del self.cache[key]
+
+                 if key in self.cache:
+                     logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
+                     return self.cache[key]
+
+                 logger.info(
+                     f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
+                 )
+                 result = func(*args, **kwargs)
+                 self.cache[key] = result
+                 logger.info(
+                     f"Stored result in cache for function: {func.__name__}, key: {key}"
+                 )
+                 return result
+
+             return wrapper
+
+         return decorator
+
+     def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
+         def decorator(func):
+             @functools.wraps(func)
+             async def wrapper(*args, **kwargs):
+                 if no_cache:
+                     logger.info(f"Bypassing cache for async function: {func.__name__}")
+                     return await func(*args, **kwargs)
+
+                 key = key_func(func.__name__, *args, **kwargs)
+
+                 if refresh and key in self.cache:
+                     logger.info(
+                         f"Refreshing cache for async function: {func.__name__}, key: {key}"
+                     )
+                     del self.cache[key]
+
+                 if key in self.cache:
+                     logger.info(
+                         f"Cache hit for async function: {func.__name__}, key: {key}"
+                     )
+                     return self.cache[key]
+
+                 logger.info(
+                     f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
+                 )
+                 result = await func(*args, **kwargs)
+                 self.cache[key] = result
+                 logger.info(
+                     f"Stored result in cache for async function: {func.__name__}, key: {key}"
+                 )
+                 return result
+
+             return wrapper
+
+         return decorator
+
+
  @lru_cache(maxsize=128)
  def execute_query_remote(
      api_url: str,
@@ -318,12 +500,20 @@ class RemoteDatabaseConnector(DatabaseConnector):

      def execute_query(self, query: str) -> Any:
          """Executes a query against the remote database, with retries for certain exceptions."""
-         return execute_query_remote(
-             api_url=self.api_url,
-             database_id=self.database_id,
-             api_key=self.api_key,
-             query=query,
-             timeout=self.timeout,
+         cache = get_cache()
+
+         cache_key = generate_cache_key(
+             "sql_request", self.api_url, self.database_id, query
+         )
+         return cache.get_or_set(
+             cache_key,
+             lambda: execute_query_remote(
+                 api_url=self.api_url,
+                 database_id=self.database_id,
+                 api_key=self.api_key,
+                 query=query,
+                 timeout=self.timeout,
+             ),
          )
519
 
struct_data_operators.py CHANGED
@@ -1024,24 +1024,24 @@ class ShuffleColumnsNames(TypeDependentAugmentor):
          return {"header": shuffled_header, "rows": table["rows"]}


- class JsonStrToListOfKeyValuePairs(FieldOperator):
-     """Convert a Json string of representing key value as dictionary to list of key value pairs."""
+ class JsonStrToDict(FieldOperator):
+     """Convert a JSON string representing key/value pairs into a dictionary.
+
+     Ensure keys and values are strings, and there are no None values.
+
+     """

      def process_value(self, text: str) -> List[Tuple[str, str]]:
          try:
              dict_value = json.loads(text)
          except Exception as e:
              UnitxtWarning(
-                 f"Unable to convert input text to json format in JsonStrToListOfKeyValuePairs due to {e}. Text: {text}"
+                 f"Unable to convert input text to json format in JsonStrToDict due to {e}. Text: {text}"
              )
              dict_value = {}
          if not isoftype(dict_value, Dict[str, Any]):
              UnitxtWarning(
-                 f"Unable to convert input text to dictionary in JsonStrToListOfKeyValuePairs. Text: {text}"
+                 f"Unable to convert input text to dictionary in JsonStrToDict. Text: {text}"
              )
              dict_value = {}
-         return [
-             (str(key), str(value))
-             for key, value in dict_value.items()
-             if value is not None
-         ]
+         return {str(k): str(v) for k, v in dict_value.items() if v is not None}
 
 
 
version.py CHANGED
@@ -1 +1 @@
- version = "1.20.0"
+ version = "1.21.0"