Elron committed
Commit 39b18be · verified · 1 Parent(s): 99f75f9

Upload folder using huggingface_hub

api.py CHANGED
@@ -11,6 +11,7 @@ from datasets.exceptions import DatasetGenerationError
 from .artifact import fetch_artifact
 from .benchmark import Benchmark
 from .card import TaskCard
+from .dataclass import to_dict
 from .dataset_utils import get_dataset_artifact
 from .error_utils import UnitxtError
 from .inference import (
@@ -149,6 +150,36 @@ def create_dataset(
     return load_dataset(card=card, split=split, **kwargs)
 
 
+def object_to_str_without_addresses(obj):
+    """Generates a string representation of a Python object while removing memory address references.
+
+    This function is useful for creating consistent and comparable string representations of objects
+    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
+    between executions. By stripping the memory address, the function ensures that the representation
+    is stable and independent of the object's location in memory.
+
+    Args:
+        obj: Any Python object to be converted to a string representation.
+
+    Returns:
+        str: A string representation of the object with memory addresses removed if present.
+
+    Example:
+        ```python
+        class MyClass:
+            pass
+
+        obj = MyClass()
+        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
+        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
+        ```
+    """
+    obj_str = str(obj)
+    if " at 0x" in obj_str:
+        obj_str = obj_str.split(" at 0x")[0] + ">"
+    return obj_str
+
+
 def _source_to_dataset(
     source: SourceOperator,
     split=None,
@@ -157,22 +188,35 @@ def _source_to_dataset(
 ):
     from .dataset import Dataset as UnitxtDataset
 
+    # Generate a unique signature for the source
+    source_signature = json.dumps(
+        to_dict(source, object_to_str_without_addresses), sort_keys=True
+    )
+    config_name = "recipe-" + short_hex_hash(source_signature)
+    # Obtain the data stream from the source
     stream = source()
 
     try:
         ds_builder = UnitxtDataset(
             dataset_name="unitxt",
-            config_name="recipe-" + short_hex_hash(repr(source)),
+            config_name=config_name,  # Dictates the cache name
             version=constants.version,
         )
         if split is not None:
             stream = {split: stream[split]}
         ds_builder._generators = stream
 
-        ds_builder.download_and_prepare(
-            verification_mode="no_checks",
-            download_mode=None if use_cache else "force_redownload",
-        )
+        try:
+            ds_builder.download_and_prepare(
+                verification_mode="no_checks",
+                download_mode=None if use_cache else "force_redownload",
+            )
+        except DatasetGenerationError as e:
+            if e.__cause__:
+                raise e.__cause__ from None
+            if e.__context__:
+                raise e.__context__ from None
+            raise
 
     if streaming:
         return ds_builder.as_streaming_dataset(split=split)
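
Why this hunk matters: `repr(source)` on objects without a custom `__repr__` embeds a memory address, so the derived `config_name` changed on every Python process and defeated the on-disk dataset cache. A minimal standalone sketch of the problem and the fix (the `short_hex_hash` below is a hypothetical stand-in, not Unitxt's implementation):

```python
# Sketch: default repr() embeds a memory address that varies between runs;
# stripping " at 0x..." (as object_to_str_without_addresses does) gives a
# representation that hashes to the same cache key every time.
import hashlib

def short_hex_hash(text: str, length: int = 8) -> str:  # hypothetical stand-in
    return hashlib.sha256(text.encode()).hexdigest()[:length]

class Recipe:  # no custom __repr__
    pass

source = Recipe()
print(repr(source))  # "<__main__.Recipe object at 0x7f8b...>" -- varies per run
stable = str(source).split(" at 0x")[0] + ">"
print("recipe-" + short_hex_hash(stable))  # stable across runs
```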
artifact.py CHANGED
@@ -16,13 +16,13 @@ from .dataclass import (
     NonPositionalField,
     fields,
 )
-from .error_utils import Documentation, UnitxtError, UnitxtWarning
+from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
 from .logging_utils import get_logger
 from .parsing_utils import (
     separate_inside_and_outside_square_brackets,
 )
 from .settings_utils import get_constants, get_settings
-from .text_utils import camel_to_snake_case, is_camel_case, print_dict_as_yaml
+from .text_utils import camel_to_snake_case, is_camel_case
 from .type_utils import isoftype, issubtype
 from .utils import (
     artifacts_json_cache,
@@ -342,8 +342,10 @@ class Artifact(Dataclass):
         self.verify_data_classification_policy()
         self.prepare_args()
         if not settings.skip_artifacts_prepare_and_verify:
-            self.prepare()
-            self.verify()
+            with error_context(self, action="Prepare Object"):
+                self.prepare()
+            with error_context(self, action="Verify Object"):
+                self.verify()
 
     def _to_raw_dict(self):
         return {
@@ -367,11 +369,14 @@ class Artifact(Dataclass):
 
     def to_json(self):
         data = self.to_dict()
+
        return json_dump(data)
 
     def to_yaml(self):
+        import yaml
+
         data = self.to_dict()
-        return print_dict_as_yaml(data)
+        return yaml.dump(data)
 
     def serialize(self):
         if self.__id__ is not None:
@@ -449,20 +454,25 @@ class Artifact(Dataclass):
             )
             return instance
 
-        if not any(
-            data_classification in data_classification_policy
-            for data_classification in instance_data_classification
+        with error_context(
+            self,
+            action="Sensitive Data Verification",
+            help="https://www.unitxt.ai/en/latest/docs/data_classification_policy.html",
         ):
-            raise UnitxtError(
-                f"The instance '{instance}' has the following data classification policy "
-                f"'{instance_data_classification}', however, the artifact '{name}' "
-                f"is only configured to support the data with classification "
-                f"'{data_classification_policy}'. To enable this either change "
-                f"the 'data_classification_policy' attribute of the artifact, "
-                f"or modify the environment variable "
-                f"'UNITXT_DATA_CLASSIFICATION_POLICY' accordingly.",
-                Documentation.DATA_CLASSIFICATION_POLICY,
-            )
+            if not any(
+                data_classification in data_classification_policy
+                for data_classification in instance_data_classification
+            ):
+                raise UnitxtError(
+                    f"The instance '{instance}' has the following data classification policy "
+                    f"'{instance_data_classification}', however, the artifact '{name}' "
+                    f"is only configured to support the data with classification "
+                    f"'{data_classification_policy}'. To enable this either change "
+                    f"the 'data_classification_policy' attribute of the artifact, "
+                    f"or modify the environment variable "
+                    f"'UNITXT_DATA_CLASSIFICATION_POLICY' accordingly.",
+                    Documentation.DATA_CLASSIFICATION_POLICY,
+                )
 
         return instance

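The "Sensitive Data Verification" block above wraps an existing check; stripped of the Unitxt machinery, the underlying rule is a simple intersection test:

```python
# Standalone sketch of the data-classification rule the wrapped code enforces:
# an instance may pass only if at least one of its classifications appears in
# the artifact's configured policy.
instance_data_classification = ["proprietary"]
data_classification_policy = ["public"]

allowed = any(
    dc in data_classification_policy for dc in instance_data_classification
)
print(allowed)  # False -> the artifact raises UnitxtError, now with a help link
```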
collections_operators.py CHANGED
@@ -1,6 +1,8 @@
+from itertools import zip_longest
 from typing import Any, Dict, Generator, List, Optional
 
 from .dict_utils import dict_get, dict_set
+from .operator import InstanceOperator
 from .operators import FieldOperator, StreamOperator
 from .stream import Stream
 from .utils import recursive_shallow_copy
@@ -13,11 +15,52 @@ class Dictify(FieldOperator):
         return dict(zip(self.with_keys, tup))
 
 
+class Zip(InstanceOperator):
+    fields: List[str]
+    to_field: str
+
+    def zip(self, values):
+        return list(zip(*values))
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        values = []
+        for field in self.fields:
+            values.append(dict_get(instance, field))
+        dict_set(instance, self.to_field, self.zip(values))
+        return instance
+
+
+class ZipLongest(Zip):
+    fields: List[str]
+    fill_value: Any = None
+
+    def zip(self, values):
+        return list(zip_longest(*values, fillvalue=self.fill_value))
+
+
 class DictToTuplesList(FieldOperator):
     def process_value(self, dic: Dict) -> Any:
         return list(dic.items())
 
 
+def flatten(container):
+    def _flat_gen(x):
+        for item in x:
+            if isinstance(item, (list, tuple)):
+                yield from _flat_gen(item)
+            else:
+                yield item
+
+    return type(container)(_flat_gen(container))
+
+
+class Flatten(FieldOperator):
+    def process_value(self, value: Any) -> Any:
+        return flatten(value)
+
+
 class Wrap(FieldOperator):
     inside: str
 
@@ -64,6 +107,13 @@ class Get(FieldOperator):
         return collection[self.item]
 
 
+class Pop(FieldOperator):
+    item: Any = None
+
+    def process_value(self, collection: Any) -> Any:
+        return collection.pop(self.item)
+
+
 class DuplicateByList(StreamOperator):
     field: str
     to_field: Optional[str] = None
@@ -91,12 +141,16 @@ class DuplicateBySubLists(StreamOperator):
     field: str
     to_field: Optional[str] = None
     use_deep_copy: bool = False
+    start: int = 1
+    end: int = 0
+    step: int = 1
 
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         to_field = self.field if self.to_field is None else self.to_field
         for instance in stream:
-            elements = instance[self.field]
-            for i in range(1, len(elements) + 1):
+            elements = dict_get(instance, self.field)
+            end = len(elements) + 1 + self.end
+            for i in range(self.start, end, self.step):
                 if self.use_deep_copy:
                     instance_copy = recursive_shallow_copy(instance)
                     instance_copy[to_field] = elements[:i]
@@ -109,6 +163,10 @@ class DuplicateBySubLists(StreamOperator):
                 yield instance_copy
 
 
+class ExplodeSubLists(DuplicateBySubLists):
+    pass
+
+
 class GetLength(FieldOperator):
     def process_value(self, collection: Any) -> Any:
         return len(collection)
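
The semantics of the new operators, sketched with plain Python on raw lists (the operators apply the same logic to instance fields via `dict_get`/`dict_set`):

```python
from itertools import zip_longest

questions = ["q1", "q2", "q3"]
answers = ["a1", "a2"]

# Zip: truncates to the shortest input field.
print(list(zip(questions, answers)))          # [('q1', 'a1'), ('q2', 'a2')]

# ZipLongest: pads the shorter field with fill_value.
print(list(zip_longest(questions, answers)))  # [..., ('q3', None)]

# Flatten: recursively flattens nested lists/tuples, keeping the outer type.
def flatten(container):
    def _flat_gen(x):
        for item in x:
            if isinstance(item, (list, tuple)):
                yield from _flat_gen(item)
            else:
                yield item
    return type(container)(_flat_gen(container))

print(flatten([1, [2, (3, 4)], 5]))           # [1, 2, 3, 4, 5]

# DuplicateBySubLists with the new start/end/step knobs emits growing prefixes:
elements = ["a", "b", "c"]
start, end, step = 1, 0, 1  # the defaults
for i in range(start, len(elements) + 1 + end, step):
    print(elements[:i])     # ['a'], ['a', 'b'], ['a', 'b', 'c']
```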
dataclass.py CHANGED
@@ -297,6 +297,65 @@ def _asdict_inner(obj):
     return copy.deepcopy(obj)
 
 
+def to_dict(obj, func=copy.deepcopy, _visited=None):
+    """Recursively converts an object into a dictionary representation while avoiding infinite recursion due to circular references.
+
+    Args:
+        obj: Any Python object to be converted into a dictionary-like structure.
+        func (Callable, optional): A function applied to non-iterable objects. Defaults to `copy.deepcopy`.
+        _visited (set, optional): A set of object IDs used to track visited objects and prevent infinite recursion.
+
+    Returns:
+        dict: A dictionary representation of the input object, with supported collections and dataclasses
+        recursively processed.
+
+    Notes:
+        - Supports dataclasses, named tuples, lists, tuples, and dictionaries.
+        - Circular references are detected using object IDs and replaced by `func(obj)`.
+        - Named tuples retain their original type instead of being converted to dictionaries.
+    """
+    # Initialize visited set on first call
+    if _visited is None:
+        _visited = set()
+
+    # Get object ID to track visited objects
+    obj_id = id(obj)
+
+    # If we've seen this object before, return a placeholder to avoid infinite recursion
+    if obj_id in _visited:
+        return func(obj)
+
+    # For mutable objects, add to visited set before recursing
+    if (
+        isinstance(obj, (dict, list))
+        or is_dataclass(obj)
+        or (isinstance(obj, tuple) and hasattr(obj, "_fields"))
+    ):
+        _visited.add(obj_id)
+
+    if is_dataclass(obj):
+        return {
+            field.name: to_dict(getattr(obj, field.name), func, _visited)
+            for field in fields(obj)
+        }
+
+    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
+        return type(obj)(*[to_dict(v, func, _visited) for v in obj])
+
+    if isinstance(obj, (list, tuple)):
+        return type(obj)([to_dict(v, func, _visited) for v in obj])
+
+    if isinstance(obj, dict):
+        return type(obj)(
+            {
+                to_dict(k, func, _visited): to_dict(v, func, _visited)
+                for k, v in obj.items()
+            }
+        )
+
+    return func(obj)
+
+
 class DataclassMeta(ABCMeta):
     """Metaclass for Dataclass.

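A hedged usage sketch of the cycle handling in `to_dict`, rebuilt here on stdlib dataclasses (Unitxt's own `is_dataclass`/`fields` operate on its Dataclass system, but the mechanism is the same): a revisited object is handed to `func` instead of being recursed into.

```python
import copy
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, List

def to_dict(obj, func=copy.deepcopy, _visited=None):
    # Simplified stand-in for the function added above.
    if _visited is None:
        _visited = set()
    if id(obj) in _visited:
        return func(obj)  # cycle detected: stop recursing
    if isinstance(obj, (dict, list)) or is_dataclass(obj):
        _visited.add(id(obj))
    if is_dataclass(obj):
        return {f.name: to_dict(getattr(obj, f.name), func, _visited) for f in fields(obj)}
    if isinstance(obj, list):
        return [to_dict(v, func, _visited) for v in obj]
    if isinstance(obj, dict):
        return {k: to_dict(v, func, _visited) for k, v in obj.items()}
    return func(obj)

@dataclass
class Node:
    name: str
    children: List[Any] = field(default_factory=list)

a = Node("a")
a.children.append(a)         # circular reference
print(to_dict(a, func=str))  # {'name': 'a', 'children': ["Node(name='a', children=[...])"]}
```

This is what lets api.py serialize arbitrary recipe objects into a JSON signature without risking infinite recursion.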
dataset.py CHANGED
@@ -59,7 +59,6 @@ from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
 from .splitters import __file__ as _
-from .sql_utils import __file__ as _
 from .standard import __file__ as _
 from .stream import __file__ as _
 from .stream_operators import __file__ as _
@@ -68,6 +67,7 @@ from .struct_data_operators import __file__ as _
 from .system_prompts import __file__ as _
 from .task import __file__ as _
 from .templates import __file__ as _
+from .text2sql_utils import __file__ as _
 from .text_utils import __file__ as _
 from .type_utils import __file__ as _
 from .types import __file__ as _
dialog_operators.py CHANGED
@@ -17,7 +17,16 @@ The format of the dialog is:
 from typing import Any, Dict, List, Optional
 
 from .formats import SystemFormat
-from .operators import InstanceFieldOperator
+from .operators import FieldOperator, InstanceFieldOperator
+
+
+class ToDialog(FieldOperator):
+    def process_value(self, value: Any) -> Any:
+        dialog = []
+        for question, answer in value:
+            dialog.append({"role": "user", "content": question})
+            dialog.append({"role": "agent", "content": answer})
+        return dialog
 
 
 class SerializeDialog(InstanceFieldOperator):
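
What `ToDialog.process_value` produces for a field holding (question, answer) pairs, shown standalone:

```python
pairs = [("What is 2+2?", "4"), ("And 3+3?", "6")]

dialog = []
for question, answer in pairs:
    dialog.append({"role": "user", "content": question})
    dialog.append({"role": "agent", "content": answer})

print(dialog)
# [{'role': 'user', 'content': 'What is 2+2?'}, {'role': 'agent', 'content': '4'},
#  {'role': 'user', 'content': 'And 3+3?'}, {'role': 'agent', 'content': '6'}]
```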
dict_utils.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, List, Tuple
 
 from .text_utils import to_pretty_string
 
-indx = re.compile(r"^(\d+)$")
+indx = re.compile(r"^-?\d+$")
 
 
 def is_index(string):
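
The effect of the regex change: negative numbers now qualify as list indices, so a path component like `-1` can address the last element of a list instead of being treated as a dictionary key.

```python
import re

old = re.compile(r"^(\d+)$")
new = re.compile(r"^-?\d+$")

print(bool(old.match("-1")))  # False -- "-1" was previously a plain key
print(bool(new.match("-1")))  # True  -- "-1" is now a (negative) index
```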
error_utils.py CHANGED
@@ -1,7 +1,11 @@
-from typing import Optional
+import re
+from contextlib import contextmanager
+from typing import Any, Optional
 
 from .logging_utils import get_logger
+from .settings_utils import get_constants
 
+constants = get_constants()
 logger = get_logger()
 
 
@@ -29,12 +33,9 @@ class UnitxtError(Exception):
     """Exception raised for Unitxt errors.
 
     Args:
-        message (str):
-            explanation of the error
-        additional_info_id (Optional[str]):
-            relative path to additional documentation on web
-            If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
-
+        message (str): explanation of the error
+        additional_info_id (Optional[str]): relative path to additional documentation on web
+            If set, should be one of the DOCUMENTATION_* constants in the error_utils.py file.
     """
 
     def __init__(self, message: str, additional_info_id: Optional[str] = None):
@@ -47,14 +48,255 @@ class UnitxtWarning:
     """Object to format warning message to log.
 
     Args:
-        message (str):
-            explanation of the warning
-        additional_info_id (Optional[str]):
-            relative path to additional documentation on web
-            If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
+        message (str): explanation of the warning
+        additional_info_id (Optional[str]): relative path to additional documentation on web
+            If set, should be one of the DOCUMENTATION_* constants in the error_utils.py file.
     """
 
     def __init__(self, message: str, additional_info_id: Optional[str] = None):
         if additional_info_id is not None:
             message += additional_info(additional_info_id)
         logger.warning(message)
+
+
+context_block_title = "🦄 Unitxt Error Context"
+
+
+def _visible_length(text: str) -> int:
+    import unicodedata
+
+    ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\]8;;[^\x1b]*\x1b\\")
+    clean_text = ansi_escape.sub("", text)
+    width = 0
+    for char in clean_text:
+        if (
+            unicodedata.east_asian_width(char) in ("F", "W")
+            or 0x1F300 <= ord(char) <= 0x1F9FF
+        ):
+            width += 2
+        else:
+            width += 1
+    return width
+
+
+def _make_object_clickable(
+    full_obj_name: str, display_name: Optional[str] = None
+) -> str:
+    import os
+
+    if display_name is None:
+        display_name = full_obj_name.split(".")[-1]
+    if full_obj_name.startswith("unitxt."):
+        parts = full_obj_name.split(".")
+        if len(parts) >= 2:
+            module_path = ".".join(parts[:2])
+            doc_url = f"{Documentation.URL}{module_path}.html#{full_obj_name}"
+            if (
+                os.environ.get("TERM_PROGRAM") in ["iTerm.app", "vscode"]
+                or os.environ.get("TERMINAL_EMULATOR") == "JetBrains-JediTerm"
+            ):
+                return f"\033]8;;{doc_url}\033\\{display_name}\033]8;;\033\\"
+            return f"{display_name} ({doc_url})"
+    return display_name
+
+
+def _get_existing_context(error: Exception):
+    """Extract existing context from an error if it exists."""
+    if hasattr(error, "__error_context__"):
+        existing = error.__error_context__
+        return (
+            existing["original_message"],
+            existing["context_object"],
+            existing["context"],
+        )
+    return str(error), None, {}
+
+
+def _format_object_context(obj: Any) -> Optional[str]:
+    """Format an object for display in error context."""
+    if obj is None:
+        return None
+    if hasattr(obj, "__class__"):
+        class_name = obj.__class__.__name__
+        module_name = getattr(obj.__class__, "__module__", "")
+    else:
+        obj_type = type(obj)
+        class_name = obj_type.__name__
+        module_name = getattr(obj_type, "__module__", "")
+    if module_name:
+        full_name = f"{module_name}.{class_name}"
+        clickable_object = _make_object_clickable(full_name, class_name)
+        return f"Object: {clickable_object}"
+    return f"Object: {class_name}"
+
+
+def _make_clickable_link(url: str) -> str:
+    """Create a clickable terminal link."""
+    import os
+
+    if (
+        os.environ.get("TERM_PROGRAM") in ["iTerm.app", "vscode"]
+        or os.environ.get("TERMINAL_EMULATOR") == "JetBrains-JediTerm"
+    ):
+        return f"\033]8;;{url}\033\\link\033]8;;\033\\"
+    return url
+
+
+def _format_help_context(help_docs) -> list:
+    """Format help documentation into context parts."""
+    parts = []
+    if isinstance(help_docs, str):
+        parts.append(f"Help: {_make_clickable_link(help_docs)}")
+    elif isinstance(help_docs, dict):
+        for label, url in help_docs.items():
+            parts.append(f"Help ({label}): {_make_clickable_link(url)}")
+    elif isinstance(help_docs, list):
+        for item in help_docs:
+            if isinstance(item, dict) and len(item) == 1:
+                label, url = next(iter(item.items()))
+                parts.append(f"Help ({label}): {_make_clickable_link(url)}")
+            elif isinstance(item, str):
+                parts.append(f"Help: {_make_clickable_link(item)}")
    return parts
+
+
+def _build_context_parts(context_object: Any, context: dict) -> list:
+    """Build the list of context information parts."""
+    parts = []
+    ordered_keys = [
+        "Python",
+        "Unitxt",
+        "Stage",
+        "Stream",
+        "Index",
+        "Instance",
+        "Object",
+        "Action",
+    ]
+    processed_keys = set()
+
+    for desired_key in ordered_keys:
+        for actual_key in context.keys():
+            if actual_key.lower() == desired_key.lower():
+                value = (
+                    "unknown" if context[actual_key] is None else context[actual_key]
+                )
+                parts.append(f"{actual_key.replace('_', ' ').title()}: {value}")
+                processed_keys.add(actual_key)
+                break
+
+    if not any(key.lower() == "object" for key in processed_keys):
+        obj_context = _format_object_context(context_object)
+        if obj_context:
+            parts.append(obj_context)
+
+    processed_keys.add("help")
+    for key, value in context.items():
+        if key not in processed_keys:
+            value = "unknown" if value is None else value
+            parts.append(f"{key.replace('_', ' ').title()}: {value}")
+
+    if "help" in context:
+        parts.extend(_format_help_context(context["help"]))
+    else:
+        parts.append(f"Help: {_make_clickable_link(Documentation.URL)}")
+
+    return parts
+
+
+def _create_context_box(parts: list) -> str:
+    """Create a formatted box containing context information."""
+    if not parts:
+        return ""
+    max_width = (
+        max(
+            _visible_length(context_block_title),
+            max(_visible_length(part) for part in parts),
+        )
+        + 4
+    )
+    top_line = "┌" + "─" * max_width + "┐"
+    bottom_line = "└" + "─" * max_width + "┘"
+    lines = [top_line]
+    lines.append(
+        f"│ {context_block_title}{' ' * (max_width - _visible_length(context_block_title) - 1)}│"
+    )
+    lines.append(f"│ {'-' * (max_width - 2)} │")
+    for part in parts:
+        padding = " " * (max_width - _visible_length(part) - 4)
+        lines.append(f"│ - {part}{padding}│")
+    lines.append(bottom_line)
+    return "\n".join(lines)
+
+
+def _store_context_attributes(
+    error: Exception, context_object: Any, context: dict, original_message: str
+):
+    """Store context information in error attributes."""
+    error.__error_context__ = {
+        "context_object": context_object,
+        "context": context,
+        "original_message": original_message,
+    }
+    try:
+        error.original_error = type(error)(original_message)
+    except (TypeError, ValueError):
+        error.original_error = Exception(original_message)
+    error.context_object = context_object
+    error.context = context
+
+
+def _add_context_to_exception(
+    original_error: Exception, context_object: Any = None, **context
+):
+    """Add context information to an exception by modifying its message."""
+    original_message, existing_object, existing_context = _get_existing_context(
+        original_error
+    )
+    final_context_object = existing_object or context_object
+    final_context = {
+        "Unitxt": constants.version,
+        "Python": constants.python,
+        **existing_context,
+        **context,
+    }
+    context_parts = _build_context_parts(final_context_object, final_context)
+    context_message = _create_context_box(context_parts)
+    _store_context_attributes(
+        original_error, final_context_object, final_context, original_message
+    )
+    if context_parts:
+        formatted_message = f"\n{context_message}\n\n{original_message}"
+        original_error.args = (formatted_message,)
+    else:
+        original_error.args = (original_message,)
+
+
+@contextmanager
+def error_context(context_object: Any = None, **context):
+    """Context manager that catches exceptions and re-raises them with additional context.
+
+    Args:
+        context_object: The object being processed (optional)
+        **context: Any additional context to include in the error message.
+            You can provide any key-value pairs that help identify where the error occurred.
+
+    Special context keys:
+        - help: Documentation links to help with the error.
+          Can be a string (single URL), dict (label: URL), or list of URLs/dicts.
+
+    Examples:
+        with error_context(self, operation="validation", item_id=42):
+            result = process_item(item)
+
+        with error_context(operation="schema_validation", help="https://docs.example.com/schema"):
+            validate_schema(data)
+
+        with error_context(processor, step="preprocessing", batch_size=32):
+            results = process_batch(batch)
+    """
+    try:
+        yield
+    except Exception as e:
+        _add_context_to_exception(e, context_object, **context)
+        raise
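
A hedged usage sketch of the new `error_context` manager (import path per this commit): the original exception type survives, but its message gains the boxed context header.

```python
from unitxt.error_utils import error_context

def parse(record):
    return record["text"].strip()

record = {"no_text": 1}
try:
    with error_context(action="Parse Record", index=0):
        parse(record)
except KeyError as e:
    # Still a KeyError; its message is now prefixed with the
    # "🦄 Unitxt Error Context" box listing Action, Index, the Python and
    # Unitxt versions, and a Help link.
    print(e)
```

Nested `error_context` blocks compose: `_get_existing_context` preserves the innermost original message while the context dictionaries are merged.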
evaluate_cli.py CHANGED
@@ -298,9 +298,13 @@ def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
         dataset_query=task_str, **overwrite_args
     )
 
-    benchmark = Benchmark(subsets=benchmark_subsets)
+    # this hack circumvents an issue with multi-level benchmarks (such as Bluebench's translation subset) that fail when wrapped with an additional Benchmark() object.
+    if len(benchmark_subsets) == 1:
+        source = next(iter(benchmark_subsets.values()))
+    else:
+        source = Benchmark(subsets=benchmark_subsets)
 
-    test_dataset = _source_to_dataset(benchmark, split=args.split)
+    test_dataset = _source_to_dataset(source, split=args.split)
     logger.info(
         f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
     )
@@ -414,6 +418,8 @@ def initialize_inference_engine(
             chat_kwargs_dict=chat_kwargs_dict,
         )
 
+        # Keep the actual model name for the results
+        args.model = inference_model.model_name
     # --- Remote Model (CrossProviderInferenceEngine) ---
     elif args.model.lower() == "cross_provider":
         if "model_name" not in model_args_dict:
@@ -444,6 +450,9 @@ def initialize_inference_engine(
             model=remote_model_name,
             **model_args_dict,
         )
+
+        # Keep the actual model name for the results
+        args.model = inference_model.engine.model
     else:
         # This case should not be reached due to argparse choices
         logger.error(
@@ -682,7 +691,7 @@ def _save_results_to_disk(
 
     # prepend to the results_path name the time in a way like this: 2025-04-04T11:37:32
 
-    timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
 
     results_path = prepend_timestamp_to_path(results_path, timestamp)
     samples_path = prepend_timestamp_to_path(samples_path, timestamp)
@@ -825,5 +834,49 @@ def main():
     logger.info("Unitxt Evaluation CLI finished successfully.")
 
 
+def extract_scores(directory):  # pragma: no cover
+    import pandas as pd
+
+    data = []
+
+    for filename in sorted(os.listdir(directory)):
+        if filename.endswith("evaluation_results.json"):
+            file_path = os.path.join(directory, filename)
+            try:
+                with open(file_path, encoding="utf-8") as f:
+                    content = json.load(f)
+
+                env_info = content.get("environment_info", {})
+                timestamp = env_info.get("timestamp_utc", "N/A")
+                model = env_info.get("parsed_arguments", {}).get("model", "N/A")
+                results = content.get("results", {})
+
+                row = {}
+                row["Model"] = model
+                row["Timestamp"] = timestamp
+                row["Average"] = results.get("score", "N/A")
+
+                for key in results.keys():
+                    if isinstance(results[key], dict):
+                        score = results[key].get("score", "N/A")
+                        row[key] = score
+
+                data.append(row)
+            except Exception as e:
+                logger.error(f"Error parsing results file {filename}: {e}.")
+
+    return pd.DataFrame(data).sort_values(by="Timestamp", ascending=True)
+
+
+def summarize_cli():
+    if len(sys.argv) != 2:
+        logger.error("Usage: python summarize_cli_results.py <results-directory>")
+        sys.exit(1)
+    directory = sys.argv[1]
+    df = extract_scores(directory)
+
+    logger.info(df.to_markdown(index=False))
+
+
 if __name__ == "__main__":
     main()
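
The timestamp hunk is small but deliberate: the timestamp is prepended to file names, and `:` is not a legal filename character on Windows, which is the likely motivation for switching to hyphens.

```python
from datetime import datetime

now = datetime(2025, 4, 4, 11, 37, 32)
print(now.strftime("%Y-%m-%dT%H:%M:%S"))  # 2025-04-04T11:37:32  (invalid in Windows file names)
print(now.strftime("%Y-%m-%dT%H-%M-%S"))  # 2025-04-04T11-37-32  (safe in file names)
```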
formats.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import re
 from abc import abstractmethod
 from typing import (
@@ -18,6 +19,7 @@ from .image_operators import image_to_data_url
 from .operator import InstanceOperator
 from .settings_utils import get_constants
 from .type_utils import isoftype
+from .types import Dialog
 from .utils import retry_connection_with_exponential_backoff
 
 constants = get_constants()
@@ -135,6 +137,9 @@ class BaseFormat(Format):
     def _prepare_instance_fields(self, instance) -> Tuple[str]:
         instance_fields = {}
 
+        if "__turns__" in instance:
+            instance_fields["turns"] = instance["__turns__"]
+
         for field in (
             "source",
             constants.instruction_field,
@@ -170,6 +175,7 @@ class BaseFormat(Format):
         target_prefix: str,
         demos: List[Dict[str, Any]],
         media: Optional[Dict[str, Any]] = None,
+        turns: Optional[Dialog] = None,
     ) -> str:
         """Abstract method for formatting instances in different subclasses.
 
@@ -256,7 +262,10 @@ class SystemFormat(BaseFormat):
         target_prefix: str,
         demos: List[Dict[str, Any]],
         media: Optional[Dict[str, Any]] = None,
+        turns: Optional[Dialog] = None,
     ) -> str:
+        if turns is not None and not source:
+            source = json.dumps(turns)
         demos_string = ""
         for demo in demos:
             demo_str = self.demo_format.format(
@@ -356,8 +365,18 @@ class ChatAPIFormat(BaseFormat):
     )
 
     The resulting `messages` is now a dictionary ready for sending to the OpenAI API.
+
+    By default, the instruction in the template is placed in a turn with a 'system' role.
+    However, some chat tokenizers will not place the default system prompt for the model
+    if there is a turn with an explicit 'system' role. To keep the default system prompt,
+    set 'place_instruction_in_user_turns=True'. This will cause the instruction of the template
+    to be placed in a turn with a 'user' role. Note the instruction will also be placed
+    in every demo turn (if demos are generated).
+
     """
 
+    place_instruction_in_user_turns: bool = False
+
     def to_content(self, text: str, media: Dict[str, Any]) -> Union[str, List[Content]]:
         # Regular expression to find <img> tags with src attribute
         img_tag_pattern = re.compile(
@@ -419,12 +438,15 @@ class ChatAPIFormat(BaseFormat):
         target_prefix: str,
         demos: List[Dict[str, Any]],
         media: Optional[Dict[str, Any]] = None,
+        turns: Optional[Dialog] = None,
     ) -> List[Message]:
         messages = []
 
-        if system_prompt or instruction:
+        if system_prompt or (instruction and not self.place_instruction_in_user_turns):
             system_content = self.to_content(
-                system_prompt + ("\n" if system_prompt != "" else "") + instruction,
+                system_prompt
+                + ("\n" if system_prompt != "" else "")
+                + (instruction if not self.place_instruction_in_user_turns else ""),
                 media,
             )
             messages.append(
@@ -435,13 +457,22 @@ class ChatAPIFormat(BaseFormat):
             )
 
         for demo_instance in demos:
-            user_content = self.to_content(demo_instance["source"], media)
+            if "__turns__" in demo_instance:
+                messages.extend(demo_instance["__turns__"])
+            else:
+                text = demo_instance["source"]
+
+                if instruction and self.place_instruction_in_user_turns:
+                    text = f"{instruction}\n{text}"
+                source_content = self.to_content(text, media)
+                messages.extend([{"role": "user", "content": source_content}])
+
             assistant_content = self.to_content(
-                target_prefix + demo_instance["target"], media
+                target_prefix + demo_instance["target"],
+                media,
             )
             messages.extend(
                 [
-                    {"role": "user", "content": user_content},
                     {
                         "role": "assistant",
                         "content": assistant_content,
@@ -449,9 +480,15 @@ class ChatAPIFormat(BaseFormat):
                 ]
             )
 
-        last_user_content = self.to_content(source, media)
+        text = source
+        if instruction and self.place_instruction_in_user_turns:
+            text = f"{instruction}\n{text}"
 
-        messages.extend([{"role": "user", "content": last_user_content}])
+        if turns is None:
+            last_user_content = self.to_content(text, media)
+            messages.extend([{"role": "user", "content": last_user_content}])
+        else:
+            messages.extend(turns)
 
         return messages
 
@@ -463,6 +500,7 @@ class ChatAPIFormat(BaseFormat):
         target_prefix: str,
         demos: List[Dict[str, Any]],
         media: Optional[Dict[str, Any]] = None,
+        turns: Optional[Dialog] = None,
     ) -> Union[str, List[Message]]:
         chat = self.to_chat(
             system_prompt,
@@ -471,6 +509,7 @@ class ChatAPIFormat(BaseFormat):
             target_prefix,
             demos,
             media,
+            turns,
         )
         media["images"] = []
         return chat
@@ -492,6 +531,7 @@ class HFSystemFormat(ChatAPIFormat):
     """
 
     model_name: str
+    chat_kwargs_dict: Dict[str, str] = {}
     _requirements_list = ["transformers", "Jinja2"]
 
     @retry_connection_with_exponential_backoff(backoff_factor=2)
@@ -509,13 +549,17 @@ class HFSystemFormat(ChatAPIFormat):
         target_prefix: str,
         demos: List[Dict[str, Any]],
         media: Optional[Dict[str, Any]] = None,
+        turns: Optional[Dialog] = None,
     ) -> str:
         chat = self.to_chat(
-            system_prompt, instruction, source, target_prefix, demos, media
+            system_prompt, instruction, source, target_prefix, demos, media, turns
         )
         return (
             self.tokenizer.apply_chat_template(
-                chat, tokenize=False, add_generation_prompt=True
+                chat,
+                tokenize=False,
+                add_generation_prompt=True,
+                **self.chat_kwargs_dict,
             )
             + target_prefix
        )
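
The two message layouts `ChatAPIFormat` can now produce, sketched as plain data (illustrative values, not Unitxt output):

```python
instruction = "Answer with one word."
source = "What is the capital of France?"

# Default (place_instruction_in_user_turns=False): the instruction rides in
# the system turn, which can suppress a tokenizer's built-in system prompt.
default_layout = [
    {"role": "system", "content": instruction},
    {"role": "user", "content": source},
]

# place_instruction_in_user_turns=True: no system turn is emitted for the
# instruction; it is prepended to the user turn (and to every demo turn).
user_turn_layout = [
    {"role": "user", "content": f"{instruction}\n{source}"},
]
```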
fusion.py CHANGED
@@ -2,6 +2,7 @@ from abc import abstractmethod
 from typing import Dict, Generator, List, Optional, Union
 
 from .dataclass import NonPositionalField
+from .error_utils import error_context
 from .logging_utils import get_logger
 from .operator import SourceOperator
 from .random_utils import new_random_generator
@@ -92,7 +93,7 @@ class FixedFusion(BaseFusion):
                 max_from_this_split = max_per_this_split
 
             logger.info(f"Processing {split} from {origin_name}...")
-            try:
+            with error_context(self, subset=origin_name):
                 for instance in multi_stream[split]:
                     if (
                         max_from_this_split is not None
@@ -105,8 +106,6 @@ class FixedFusion(BaseFusion):
                         instance["subset"].insert(0, origin_name)
                     emitted_from_this_split += 1
                     yield instance
-            except Exception as e:
-                raise RuntimeError(f"Exception in subset: {origin_name}") from e
 
 
 class WeightedFusion(BaseFusion):
@@ -164,16 +163,15 @@ class WeightedFusion(BaseFusion):
                 weights=[self.named_weights[name] for name in population],
             )[0]
             iterator = iterators[origin_name]
-            try:
-                instance = next(iterator)
-                if isinstance(origin_name, str):
-                    if "subset" not in instance:
-                        instance["subset"] = []
-                    instance["subset"].insert(0, origin_name)
-                total_examples += 1
-                yield instance
-
-            except StopIteration:
-                iterators.pop(origin_name)
-            except Exception as e:
-                raise RuntimeError(f"Exception in subset: {origin_name}") from e
+            with error_context(self, subset=origin_name):
+                try:
+                    instance = next(iterator)
+                    if isinstance(origin_name, str):
+                        if "subset" not in instance:
+                            instance["subset"] = []
+                        instance["subset"].insert(0, origin_name)
+                    total_examples += 1
+                    yield instance
+
+                except StopIteration:
+                    iterators.pop(origin_name)

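Rough before/after of the fusion error-handling change, as standalone code with a simplified stand-in for `error_context`: previously a failing subset surfaced as `RuntimeError("Exception in subset: ...")`, hiding the original exception type; now the original exception is re-raised with the subset name attached as context.

```python
from contextlib import contextmanager

@contextmanager
def error_context(obj=None, **context):  # simplified stand-in, not Unitxt's
    try:
        yield
    except Exception as e:
        e.args = (f"[context: {context}] {e}",)
        raise

def read(subset):
    raise ValueError(f"bad instance in {subset}")

try:
    with error_context(subset="translation"):
        read("translation")
except ValueError as e:  # the original exception type survives
    print(e)             # [context: {'subset': 'translation'}] bad instance in translation
```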
inference.py CHANGED
@@ -39,7 +39,7 @@ from .artifact import Artifact
39
  from .base_metric import Metric
40
  from .dataclass import InternalField, NonPositionalField
41
  from .deprecation_utils import deprecation
42
- from .error_utils import UnitxtError, UnitxtWarning
43
  from .image_operators import (
44
  EncodeImageToString,
45
  ImageDataString,
@@ -121,6 +121,8 @@ class TextGenerationInferenceOutput:
121
  | For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
122
  {.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
123
 
 
 
124
  input_tokens (int) : number of input tokens to the model.
125
 
126
  output_tokens (int) : number of output tokens to the model.
@@ -137,6 +139,7 @@ class TextGenerationInferenceOutput:
137
  """
138
 
139
  prediction: Union[str, List[Dict[str, Any]]]
 
140
  input_tokens: Optional[int] = None
141
  output_tokens: Optional[int] = None
142
  stop_reason: Optional[str] = None
@@ -186,12 +189,19 @@ class InferenceEngine(Artifact):
186
  def prepare(self):
187
  if not settings.mock_inference_mode:
188
  super().prepare() # no need to prepare a mock
189
- self.prepare_engine()
 
 
 
 
 
190
  if self.use_cache:
191
  from diskcache import Cache
192
 
193
  self._cache = Cache(
194
- settings.inference_engine_cache_path + self.__class__.__name__
 
 
195
  )
196
 
197
  def __call__(
@@ -199,7 +209,12 @@ class InferenceEngine(Artifact):
199
  dataset: Union[List[Dict[str, Any]], Dataset],
200
  return_meta_data: bool = False,
201
  ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
202
- return self.infer(dataset=dataset, return_meta_data=return_meta_data)
 
 
 
 
 
203
 
204
  def get_instance_cache_key(self, instance):
205
  instance_key_fields = ["media", "source", "task_data"]
@@ -243,54 +258,69 @@ class InferenceEngine(Artifact):
243
  result = self._mock_infer(dataset)
244
  else:
245
  if self.use_cache:
246
- number_of_batches = math.ceil(len(dataset) / self.cache_batch_size)
247
- result = []
248
- for batch_index, batch in enumerate(
249
- batched(dataset, self.cache_batch_size)
250
  ):
251
- cached_results = []
252
- missing_examples = []
253
- for i, item in enumerate(batch):
254
- cache_key = self._get_cache_key(item)
255
- cached_value = self._cache.get(cache_key)
256
- if cached_value is not None:
257
- cached_results.append(
258
- (i, cached_value)
259
- ) # each element is index in batch, and value
260
- else:
261
- missing_examples.append(
262
- (i, item)
263
- ) # each element is index in batch and example
264
- # infare on missing examples only, without indices
265
-
266
- logger.info(
267
- f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
268
- )
269
- if len(missing_examples) > 0:
270
- inferred_results = self._infer(
271
- [e[1] for e in missing_examples], return_meta_data
272
- )
273
- # recombined to index and value
274
- inferred_results = list(
275
- zip([e[0] for e in missing_examples], inferred_results)
276
- )
277
- # Add missing examples to cache
278
- for (_, item), (_, prediction) in zip(
279
- missing_examples, inferred_results
280
- ):
281
- if prediction is None:
282
- continue
283
  cache_key = self._get_cache_key(item)
284
- self._cache[cache_key] = prediction
285
- else:
286
- inferred_results = []
287
- # Combine cached and inferred results in original order
288
- batch_predictions = [
289
- p[1] for p in sorted(cached_results + inferred_results)
290
- ]
291
- result.extend(batch_predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  else:
293
- result = self._infer(dataset, return_meta_data)
 
 
 
 
 
294
  return ListWithMetadata(
295
  result,
296
  metadata={
@@ -339,7 +369,16 @@ class InferenceEngine(Artifact):
339
 
340
  def to_messages(self, instance):
341
  if isinstance(instance["source"], list):
342
- return instance["source"]
 
 
 
 
 
 
 
 
 
343
  return [
344
  {
345
  "role": "user",
@@ -521,13 +560,6 @@ class HFInferenceEngineBase(
521
  def get_engine_id(self):
522
  return get_model_and_label_id(self.model_name, self.label)
523
 
524
- def decode_tokens(self, tokens: Sequence, inp_length: int) -> List[str]:
525
- return self.processor.decode(tokens[inp_length:], skip_special_tokens=True)
526
-
527
- @staticmethod
528
- def create_string_from_tokens(string_tokens: List[str]) -> str:
529
- return "".join(token for token in string_tokens)
530
-
531
  def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
532
  return self.model.generate(
533
  **prepared_inputs,
@@ -598,6 +630,7 @@ class HFInferenceEngineBase(
598
  def get_return_object(
599
  self,
600
  output: Union[str, List[Dict[str, Any]]],
 
601
  output_tokens: Optional[int],
602
  inp: Optional[str],
603
  inp_tokens: Optional[int],
@@ -606,6 +639,7 @@ class HFInferenceEngineBase(
606
  if return_meta_data:
607
  return TextGenerationInferenceOutput(
608
  prediction=output,
 
609
  output_tokens=output_tokens if output_tokens is not None else None,
610
  input_text=inp,
611
  input_tokens=inp_tokens if inp_tokens is not None else None,
@@ -689,7 +723,8 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
689
  # cause an error because the data is always on the gpu
690
  # if torch.cuda.device_count() > 1:
691
  # assert self.device == torch.device(0)
692
- args["device_map"] = "auto"
 
693
  # else:
694
  # if not self.load_in_8bit:
695
  # args["device"] = self.device
@@ -717,15 +752,21 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
717
  **model_args,
718
  )
719
 
720
- def prepare_inputs(self, data: Iterable) -> Mapping:
721
  tokenizer_kargs = {}
722
  if isinstance(data[0], list):
723
- data = self.processor.apply_chat_template(
724
- data,
725
- tokenize=False,
726
- add_generation_prompt=True,
727
- **self.chat_kwargs_dict,
728
- )
 
 
 
 
 
 
729
  tokenizer_kargs["add_special_tokens"] = False
730
 
731
  if self.processor.pad_token is None:
@@ -766,59 +807,71 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
766
  total=len(dataset) // self.batch_size,
767
  ):
768
  # Get the current batch
769
- batch_sources = [instance["source"] for instance in batch]
 
 
 
 
 
 
 
 
 
 
 
770
 
771
- # --- Process the current batch ---
772
- # 1. Tokenize inputs for the batch
773
- tokenized_inputs = self.prepare_inputs(batch_sources)
774
 
775
- # 2. Determine input length (handle encoder-decoder models)
776
  input_length = (
777
  1
778
  if self.model.config.is_encoder_decoder
779
  else tokenized_inputs.input_ids.shape[1]
780
  )
781
 
782
- # 3. Make predictions for the batch
783
  predictions = self.make_predictions(tokenized_inputs)
784
  sequences = predictions.sequences # Sequences for the current batch
785
 
786
- # 4. Decode tokens for the batch
787
- string_tokens_batch = [
788
- self.decode_tokens(sequence, input_length) for sequence in sequences
789
- ]
790
 
791
- # 5. Calculate logprobs or create strings for the batch
792
- final_outputs_batch = (
793
- self.get_logprobs(predictions, string_tokens_batch)
794
- if return_logprobs
795
- else [
796
- self.create_string_from_tokens(strings)
797
- for strings in string_tokens_batch
798
- ]
799
- )
800
 
801
- # 6. Create return objects for the batch
802
- batch_results = [
803
- self.get_return_object(
804
- output=final_outputs_batch[
805
- j
806
- ], # Output for the j-th item in the batch
807
- output_tokens=len(string_tokens_batch[j]),
808
- inp=batch[j]["source"], # Original input for the j-th item
809
- inp_tokens=len(tokenized_inputs.encodings[j].tokens)
810
- if tokenized_inputs.encodings is not None
811
- else None,
812
- return_meta_data=return_meta_data,
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  )
814
- for j in range(
815
- len(sequences)
816
- ) # Iterate through items in the current batch
817
- ]
818
 
819
- # Add results from this batch to the overall list
820
  all_final_outputs.extend(batch_results)
821
- # --- End of batch processing ---
822
 
823
  return all_final_outputs
824
 
@@ -847,7 +900,10 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
847
  self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
848
  ) -> Sequence:
849
  if not hasattr(self.model.config, "vocab_size"):
850
- self.model.config.vocab_size = self.model.vocab_size
 
 
 
851
 
852
  return super().compute_transition_scores(sequences, scores, beam_indices)
853
 
@@ -917,18 +973,35 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
917
 
918
  predictions = self.make_predictions(processed_inputs)
919
 
920
- string_tokens = self.decode_tokens(predictions.sequences[0], input_len)
921
 
922
- final_outputs = (
923
- self.get_logprobs(predictions, [string_tokens])[0]
924
- if return_logprobs
925
- else self.create_string_from_tokens(string_tokens)
926
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
927
 
928
  results.append(
929
  self.get_return_object(
930
- output=final_outputs,
931
- output_tokens=len(string_tokens),
 
932
  inp=instance["source"],
933
  inp_tokens=None,
934
  return_meta_data=return_meta_data,
@@ -1189,6 +1262,7 @@ class HFPipelineBasedInferenceEngine(
1189
  if return_meta_data:
1190
  return TextGenerationInferenceOutput(
1191
  prediction=output["generated_text"],
 
1192
  model_name=self.model_name,
1193
  inference_type=self.label,
1194
  input_text=inp,
@@ -1252,10 +1326,13 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
            for instance in dataset
        ]

-    def get_return_object(self, predict_result, instance, return_meta_data):
        if return_meta_data:
            return TextGenerationInferenceOutput(
                prediction=predict_result,
                input_tokens=len(instance["source"]),
                output_tokens=len(predict_result),
                model_name=self.model_name,
@@ -1369,21 +1446,25 @@ class OllamaInferenceEngine(
        return get_model_and_label_id(self.model, self.label)

    def prepare_engine(self):
-        pass

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], Dataset],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-        import ollama
-
        args = self.to_dict([StandardAPIParamsMixin])
        results = []
        model = args.pop("model")
        for instance in dataset:
            messages = self.to_messages(instance)
-            response = ollama.chat(
                messages=messages,
                model=model,
                options=args,
@@ -1877,7 +1958,7 @@ class OpenAiInferenceEngine(
                f"Error predicting instance {messages}:{e}. Returning empty prediction"
            )
            return TextGenerationInferenceOutput(
-                prediction="-", input_tokens=0, output_tokens=0
            )

    @run_with_imap
@@ -1894,10 +1975,12 @@ class OpenAiInferenceEngine(
            top_logprobs_response = response.choices[0].logprobs.content
            pred_output = [
                {
                    "top_tokens": [
                        {"text": obj.token, "logprob": obj.logprob}
                        for obj in generated_token.top_logprobs
-                    ]
                }
                for generated_token in top_logprobs_response
            ]
@@ -1907,15 +1990,21 @@ class OpenAiInferenceEngine(
            logging.error(
                f"Error predicting instance {messages}:{e}. Returning empty prediction"
            )
-            prediction = [{"top_tokens": [{"text": "-", "logprob": 0}]}]
            return TextGenerationInferenceOutput(
-                prediction=prediction, input_tokens=0, output_tokens=0
            )

    def get_return_object(self, predict_result, response, return_meta_data):
        if return_meta_data:
            return TextGenerationInferenceOutput(
                prediction=predict_result,
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
                model_name=self.model_name,
@@ -1973,7 +2062,12 @@ class RITSInferenceEngine(
    label: str = "rits"
    data_classification_policy = ["public", "proprietary"]

-    model_names_dict = {"microsoft/phi-4": "microsoft-phi-4"}

    def get_default_headers(self):
        return {"RITS_API_KEY": self.credentials["api_key"]}
@@ -2606,6 +2700,7 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMixin):
        if return_meta_data:
            return TextGenerationInferenceOutput(
                prediction=predict_result,
                input_tokens=result["input_token_count"],
                output_tokens=result["generated_token_count"],
                model_name=self.model_name or self.deployment_id,
@@ -2865,6 +2960,8 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
            tool_call = data[idx]["tools"]["tools"] is not None

            output = response["choices"][0][output_type]
            if tool_call:
                if "tool_calls" in output:
                    func = output["tool_calls"][0]["function"]
@@ -2877,6 +2974,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
            results.append(
                self.get_return_object(
                    prediction,
                    response,
                    str(inp),
                    return_meta_data,
@@ -2885,10 +2983,13 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):

        return results

-    def get_return_object(self, predict_result, result, input_text, return_meta_data):
        if return_meta_data:
            return TextGenerationInferenceOutput(
                prediction=predict_result,
                input_tokens=result["usage"]["prompt_tokens"],
                output_tokens=len(predict_result)
                if isinstance(predict_result, list)
@@ -3286,6 +3387,7 @@ class LiteLLMInferenceEngine(
        prediction = response["choices"][0]["message"]["content"] or ""
        return TextGenerationInferenceOutput(
            prediction=prediction,
            input_tokens=usage.get("prompt_tokens"),
            output_tokens=usage.get("completion_tokens"),
            model_name=response.get("model", self.model),
@@ -3436,21 +3538,22 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
        },
        "rits": {
            "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
            "granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
            "granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
-            "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
            "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
            "llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
            "llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
            "llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
            "llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
-            "llama-4-scout": "llama-4-scout-17b-16e",
-            "llama-4-maverick": "llama-4-mvk-17b-128e-fp8",
            "mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
            "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
            "mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
-            "deepseek-v3": "deepseek-ai/deepseek-v3-h200",
            "granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
            "granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
        },
 
 from .base_metric import Metric
 from .dataclass import InternalField, NonPositionalField
 from .deprecation_utils import deprecation
+from .error_utils import UnitxtError, UnitxtWarning, error_context
 from .image_operators import (
     EncodeImageToString,
     ImageDataString,
 
121
  | For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
122
  {.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
123
 
124
+ generated_text (str): The generated text generated by the model (in both _infer and _infer_log_probs calls).
125
+
126
  input_tokens (int) : number of input tokens to the model.
127
 
128
  output_tokens (int) : number of output tokens to the model.
 
     """

     prediction: Union[str, List[Dict[str, Any]]]
+    generated_text: str
     input_tokens: Optional[int] = None
     output_tokens: Optional[int] = None
     stop_reason: Optional[str] = None
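A minimal sketch of how the enriched output object behaves; the class below is a simplified stand-in, and only the field names are taken from the hunk above:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class TextGenerationInferenceOutput:  # simplified stand-in for the real class
    prediction: Union[str, List[Dict[str, Any]]]
    generated_text: str
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None

# For plain _infer calls prediction and generated_text coincide; for
# _infer_log_probs, prediction holds per-token dicts while generated_text
# still carries the decoded string.
out = TextGenerationInferenceOutput(
    prediction="Paris", generated_text="Paris", input_tokens=12, output_tokens=1
)
print(out.generated_text)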
 
     def prepare(self):
         if not settings.mock_inference_mode:
             super().prepare()  # no need to prepare a mock
+        with error_context(
+            self,
+            stage="Prepare Inference Engine",
+            help="https://www.unitxt.ai/en/latest/docs/inference.html",
+        ):
+            self.prepare_engine()
         if self.use_cache:
             from diskcache import Cache

             self._cache = Cache(
+                os.path.join(
+                    settings.inference_engine_cache_path, self.__class__.__name__
+                )
             )

     def __call__(

         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
+        with error_context(
+            self,
+            stage="Running Inference",
+            help="https://www.unitxt.ai/en/latest/docs/inference.html",
+        ):
+            return self.infer(dataset=dataset, return_meta_data=return_meta_data)

     def get_instance_cache_key(self, instance):
         instance_key_fields = ["media", "source", "task_data"]
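A minimal sketch of the error_context pattern added above: a context manager that re-raises any failure enriched with a stage label and help URL. This is an illustration only, not the actual unitxt implementation:

from contextlib import contextmanager

@contextmanager
def error_context(obj=None, stage="", help=""):
    try:
        yield
    except Exception as e:
        # Wrap the original error with where it happened and where to read more
        name = type(obj).__name__ if obj is not None else "module"
        raise RuntimeError(f"Failure at stage '{stage}' in {name}; see {help}") from e

with error_context(stage="Prepare Inference Engine",
                   help="https://www.unitxt.ai/en/latest/docs/inference.html"):
    pass  # engine preparation would run here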
 
             result = self._mock_infer(dataset)
         else:
             if self.use_cache:
+                with error_context(
+                    self,
+                    stage="Inference Cache Handling",
+                    help="https://www.unitxt.ai/en/latest/docs/inference.html",
                 ):
+                    number_of_batches = math.ceil(len(dataset) / self.cache_batch_size)
+                    result = []
+                    for batch_index, batch in enumerate(
+                        batched(dataset, self.cache_batch_size)
+                    ):
+                        cached_results = []
+                        missing_examples = []
+                        for i, item in enumerate(batch):
                             cache_key = self._get_cache_key(item)
+                            cached_value = self._cache.get(cache_key)
+                            if cached_value is not None:
+                                cached_results.append(
+                                    (i, cached_value)
+                                )  # each element is index in batch, and value
+                            else:
+                                missing_examples.append(
+                                    (i, item)
+                                )  # each element is index in batch and example
+                        # infer on missing examples only, without indices
+
+                        logger.info(
+                            f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
+                        )
+                        if len(missing_examples) > 0:
+                            with error_context(
+                                self,
+                                stage="Running Inference",
+                                help="https://www.unitxt.ai/en/latest/docs/inference.html",
+                            ):
+                                inferred_results = self._infer(
+                                    [e[1] for e in missing_examples], return_meta_data
+                                )
+                            # recombine to index and value
+                            inferred_results = list(
+                                zip([e[0] for e in missing_examples], inferred_results)
+                            )
+                            # Add missing examples to cache
+                            for (_, item), (_, prediction) in zip(
+                                missing_examples, inferred_results
+                            ):
+                                if prediction is None:
+                                    continue
+                                cache_key = self._get_cache_key(item)
+                                self._cache[cache_key] = prediction
+                        else:
+                            inferred_results = []
+                        # Combine cached and inferred results in original order
+                        batch_predictions = [
+                            p[1] for p in sorted(cached_results + inferred_results)
+                        ]
+                        result.extend(batch_predictions)
             else:
+                with error_context(
+                    self,
+                    stage="Running Inference",
+                    help="https://www.unitxt.ai/en/latest/docs/inference.html",
+                ):
+                    result = self._infer(dataset, return_meta_data)
         return ListWithMetadata(
             result,
             metadata={
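The hunk above follows a lookup/infer/merge pattern. Here is a self-contained sketch of the same idea with a plain dict standing in for diskcache; all names are illustrative:

from itertools import islice

def batched(iterable, n):
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk

cache = {}

def infer_with_cache(items, infer_fn, batch_size=2):
    results = []
    for batch in batched(items, batch_size):
        # Split each batch into cache hits and misses, remembering positions
        hits = [(i, cache[x]) for i, x in enumerate(batch) if x in cache]
        misses = [(i, x) for i, x in enumerate(batch) if x not in cache]
        inferred = list(zip([i for i, _ in misses], infer_fn([x for _, x in misses])))
        for (_, x), (_, y) in zip(misses, inferred):
            cache[x] = y  # populate the cache with fresh predictions
        # Re-interleave hits and fresh results in the original batch order
        results.extend(v for _, v in sorted(hits + inferred))
    return results

print(infer_with_cache(["a", "b", "a"], lambda xs: [x.upper() for x in xs]))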
 
     def to_messages(self, instance):
         if isinstance(instance["source"], list):
+            messages = []
+            for message in instance["source"]:
+                if "tool_calls" in message:
+                    for tool_call in message["tool_calls"]:
+                        if not isinstance(tool_call["function"]["arguments"], str):
+                            tool_call["function"]["arguments"] = json.dumps(
+                                tool_call["function"]["arguments"]
+                            )
+                messages.append(message)
+            return messages
         return [
             {
                 "role": "user",
 
     def get_engine_id(self):
         return get_model_and_label_id(self.model_name, self.label)

     def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
         return self.model.generate(
             **prepared_inputs,

     def get_return_object(
         self,
         output: Union[str, List[Dict[str, Any]]],
+        generated_text: str,
         output_tokens: Optional[int],
         inp: Optional[str],
         inp_tokens: Optional[int],

         if return_meta_data:
             return TextGenerationInferenceOutput(
                 prediction=output,
+                generated_text=generated_text,
                 output_tokens=output_tokens if output_tokens is not None else None,
                 input_text=inp,
                 input_tokens=inp_tokens if inp_tokens is not None else None,

         # cause an error because the data is always on the gpu
         # if torch.cuda.device_count() > 1:
         #     assert self.device == torch.device(0)
+        if self.device_map is None:
+            args["device_map"] = "auto"
         # else:
         #     if not self.load_in_8bit:
         #         args["device"] = self.device

             **model_args,
         )

+    def prepare_inputs(self, data: Iterable, tools: Iterable) -> Mapping:
         tokenizer_kargs = {}
         if isinstance(data[0], list):
+            processed = []
+            for item, item_tools in zip(data, tools):
+                processed.append(
+                    self.processor.apply_chat_template(
+                        item,
+                        tokenize=False,
+                        tools=item_tools,
+                        add_generation_prompt=True,
+                        **self.chat_kwargs_dict,
+                    )
+                )
+            data = processed
             tokenizer_kargs["add_special_tokens"] = False

         if self.processor.pad_token is None:
 
             total=len(dataset) // self.batch_size,
         ):
             # Get the current batch
+            sources = []
+            tools = []
+            for instance in batch:
+                sources.append(instance["source"])
+                if "task_data" in instance and "__tools__" in instance["task_data"]:
+                    task_data = instance["task_data"]
+                    if isinstance(task_data, str):
+                        task_data = json.loads(task_data)
+                    tools.append(task_data["__tools__"])
+                else:
+                    tools.append(None)
+            # Tokenize inputs for the batch
+            tokenized_inputs = self.prepare_inputs(sources, tools)

+            # Determine input length (handle encoder-decoder models)
             input_length = (
                 1
                 if self.model.config.is_encoder_decoder
                 else tokenized_inputs.input_ids.shape[1]
             )

+            # Make predictions for the batch
             predictions = self.make_predictions(tokenized_inputs)
             sequences = predictions.sequences  # Sequences for the current batch

+            output_tokens = sequences[:, input_length:]

+            output_tokens_strings = []
+            for tokens in output_tokens:
+                output_tokens_strings.append(
+                    [
+                        self.processor.decode(token, skip_special_tokens=True)
+                        for token in tokens
+                    ]
+                )

+            output_strings = []
+            for tokens in output_tokens:
+                output_strings.append(
+                    self.processor.decode(tokens, skip_special_tokens=True)
+                )
+
+            if return_logprobs:
+                outputs = self.get_logprobs(predictions, output_tokens_strings)
+            else:
+                outputs = output_strings
+
+            # Create return objects for the batch
+            batch_results = []
+            for i in range(len(sequences)):
+                batch_results.append(
+                    self.get_return_object(
+                        output=outputs[i],
+                        generated_text=output_strings[i],
+                        output_tokens=len(output_tokens_strings[i]),
+                        inp=sources[i],
+                        inp_tokens=len(tokenized_inputs.encodings[i].tokens)
+                        if tokenized_inputs.encodings is not None
+                        else None,
+                        return_meta_data=return_meta_data,
+                    )
                 )

             all_final_outputs.extend(batch_results)

         return all_final_outputs
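The decode step above relies on decoder-only generate() returning the prompt plus the completion, so new tokens start at the prompt length. A toy sketch of that slicing with lists standing in for tensors and a made-up vocabulary:

prompt_len = 4
sequences = [[101, 7592, 2088, 102, 3449, 2003, 1012]]  # prompt ids + completion ids

output_tokens = [seq[prompt_len:] for seq in sequences]  # keep only new tokens
fake_vocab = {3449: "paris", 2003: "is", 1012: "."}
# Per-token decode feeds logprob alignment; whole-sequence decode yields the text
output_tokens_strings = [[fake_vocab[t] for t in seq] for seq in output_tokens]
output_strings = [" ".join(tokens) for tokens in output_tokens_strings]
print(output_strings)  # ['paris is .']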

         self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
     ) -> Sequence:
         if not hasattr(self.model.config, "vocab_size"):
+            try:
+                self.model.config.vocab_size = self.model.vocab_size
+            except:
+                self.model.config.vocab_size = self.model.config.text_config.vocab_size

         return super().compute_transition_scores(sequences, scores, beam_indices)
 
 
         predictions = self.make_predictions(processed_inputs)

+        sequences = predictions.sequences  # Sequences for the current batch

+        output_tokens = sequences[:, input_len:]
+
+        output_tokens_strings = []
+        for tokens in output_tokens:
+            output_tokens_strings.append(
+                [
+                    self.processor.decode(token, skip_special_tokens=True)
+                    for token in tokens
+                ]
+            )
+
+        output_strings = []
+        for tokens in output_tokens:
+            output_strings.append(
+                self.processor.decode(tokens, skip_special_tokens=True)
+            )
+
+        if return_logprobs:
+            final_outputs = self.get_logprobs(predictions, output_tokens_strings)
+        else:
+            final_outputs = output_strings

         results.append(
             self.get_return_object(
+                output=final_outputs[0],
+                generated_text=output_strings,
+                output_tokens=len(output_tokens_strings[0]),
                 inp=instance["source"],
                 inp_tokens=None,
                 return_meta_data=return_meta_data,
 
         if return_meta_data:
             return TextGenerationInferenceOutput(
                 prediction=output["generated_text"],
+                generated_text=output["generated_text"],
                 model_name=self.model_name,
                 inference_type=self.label,
                 input_text=inp,
 
             for instance in dataset
         ]

+    def get_return_object(
+        self, predict_result, generated_text, instance, return_meta_data
+    ):
         if return_meta_data:
             return TextGenerationInferenceOutput(
                 prediction=predict_result,
+                generated_text=self.default_inference_value,
                 input_tokens=len(instance["source"]),
                 output_tokens=len(predict_result),
                 model_name=self.model_name,
 
         return get_model_and_label_id(self.model, self.label)

     def prepare_engine(self):
+        from ollama import Client
+
+        self.client = Client(
+            host=self.credentials["api_base"]
+            if self.credentials is not None and "api_base" in self.credentials
+            else None
+        )

     def _infer(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         args = self.to_dict([StandardAPIParamsMixin])
         results = []
         model = args.pop("model")
         for instance in dataset:
             messages = self.to_messages(instance)
+            response = self.client.chat(
                 messages=messages,
                 model=model,
                 options=args,
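How the client host gets picked in the hunk above, reduced to a standalone sketch: an explicit api_base in the credentials wins, otherwise None lets the client fall back to its default (typically the local daemon). The endpoint value is made up:

def resolve_host(credentials):
    if credentials is not None and "api_base" in credentials:
        return credentials["api_base"]
    return None  # client default

print(resolve_host({"api_base": "http://my-ollama:11434"}))  # custom endpoint
print(resolve_host(None))  # None -> client default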
 
             f"Error predicting instance {messages}:{e}. Returning empty prediction"
         )
         return TextGenerationInferenceOutput(
+            prediction="-", generated_text="-", input_tokens=0, output_tokens=0
         )

     @run_with_imap
 
         top_logprobs_response = response.choices[0].logprobs.content
         pred_output = [
             {
+                "text": generated_token.token,
+                "logprob": generated_token.logprob,
                 "top_tokens": [
                     {"text": obj.token, "logprob": obj.logprob}
                     for obj in generated_token.top_logprobs
+                ],
             }
             for generated_token in top_logprobs_response
         ]
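The shape of the per-token prediction after this change: each generated token now carries its own text and logprob alongside the top_tokens alternatives. Values below are made up for illustration:

pred_output = [
    {
        "text": "Paris",
        "logprob": -0.12,
        "top_tokens": [
            {"text": "Paris", "logprob": -0.12},
            {"text": "London", "logprob": -2.75},
        ],
    }
]
# The chosen token should also be the highest-probability alternative here
best = max(pred_output[0]["top_tokens"], key=lambda t: t["logprob"])
assert best["text"] == pred_output[0]["text"]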
 
         logging.error(
             f"Error predicting instance {messages}:{e}. Returning empty prediction"
         )
+        prediction = [
+            {"text": "-", "logprob": 0, "top_tokens": [{"text": "-", "logprob": 0}]}
+        ]
         return TextGenerationInferenceOutput(
+            prediction=prediction,
+            generated_text=prediction,
+            input_tokens=0,
+            output_tokens=0,
         )

     def get_return_object(self, predict_result, response, return_meta_data):
         if return_meta_data:
             return TextGenerationInferenceOutput(
                 prediction=predict_result,
+                generated_text=response.choices[0].message.content,
                 input_tokens=response.usage.prompt_tokens,
                 output_tokens=response.usage.completion_tokens,
                 model_name=self.model_name,
 
     label: str = "rits"
     data_classification_policy = ["public", "proprietary"]

+    model_names_dict = {
+        "microsoft/phi-4": "microsoft-phi-4",
+        "meta-llama/llama-4-maverick-17b-128e-instruct-fp8": "llama-4-mvk-17b-128e-fp8",
+        "deepseek-ai/DeepSeek-V3": "deepseek-v3-h200",
+        "meta-llama/Llama-3.1-8B-Instruct": "llama-3-1-8b-instruct",
+    }

     def get_default_headers(self):
         return {"RITS_API_KEY": self.credentials["api_key"]}
 
         if return_meta_data:
             return TextGenerationInferenceOutput(
                 prediction=predict_result,
+                generated_text=result["generated_text"],
                 input_tokens=result["input_token_count"],
                 output_tokens=result["generated_token_count"],
                 model_name=self.model_name or self.deployment_id,
 
             tool_call = data[idx]["tools"]["tools"] is not None

             output = response["choices"][0][output_type]
+            if "content" not in output:
+                output["content"] = ""
             if tool_call:
                 if "tool_calls" in output:
                     func = output["tool_calls"][0]["function"]
 
             results.append(
                 self.get_return_object(
                     prediction,
+                    response["choices"][0]["message"]["content"],
                     response,
                     str(inp),
                     return_meta_data,
 

         return results

+    def get_return_object(
+        self, predict_result, generated_text, result, input_text, return_meta_data
+    ):
         if return_meta_data:
             return TextGenerationInferenceOutput(
                 prediction=predict_result,
+                generated_text=generated_text,
                 input_tokens=result["usage"]["prompt_tokens"],
                 output_tokens=len(predict_result)
                 if isinstance(predict_result, list)
 
         prediction = response["choices"][0]["message"]["content"] or ""
         return TextGenerationInferenceOutput(
             prediction=prediction,
+            generated_text=response["choices"][0]["message"]["content"],
             input_tokens=usage.get("prompt_tokens"),
             output_tokens=usage.get("completion_tokens"),
             model_name=response.get("model", self.model),
 
         },
         "rits": {
             "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
+            "granite-3-1-8b-instruct": "ibm-granite/granite-3.1-8b-instruct",
             "granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
             "granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
+            "llama-3-1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
             "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
             "llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
             "llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
             "llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
             "llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
             "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
+            "llama-4-scout": "meta-llama/llama-4-scout-17b-16e",
+            "llama-4-maverick": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
             "mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
             "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
             "mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
+            "deepseek-v3": "deepseek-ai/DeepSeek-V3",
             "granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
             "granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
         },
llm_as_judge_constants.py CHANGED
@@ -125,7 +125,7 @@ EVALUATOR_TO_MODEL_ID = {
    EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
    EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
    EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
-    EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-ai/DeepSeek-V3",
    EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
    EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
}
@@ -198,7 +198,6 @@ EVALUATORS_METADATA = [
    [
        ModelProviderEnum.WATSONX,
        ModelProviderEnum.TOGETHER_AI,
-        ModelProviderEnum.RITS,
        ModelProviderEnum.OLLAMA,
    ],
),

    EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
    EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
    EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
+    EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-v3",
    EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
    EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
}

    [
        ModelProviderEnum.WATSONX,
        ModelProviderEnum.TOGETHER_AI,
        ModelProviderEnum.OLLAMA,
    ],
),
loaders.py CHANGED
@@ -66,7 +66,7 @@ from tqdm import tqdm

 from .dataclass import NonPositionalField
 from .dict_utils import dict_get
-from .error_utils import Documentation, UnitxtError, UnitxtWarning
 from .fusion import FixedFusion
 from .logging_utils import get_logger
 from .operator import SourceOperator
@@ -90,23 +90,27 @@ class UnitxtUnverifiedCodeError(UnitxtError):

 @retry_connection_with_exponential_backoff(backoff_factor=2)
 def hf_load_dataset(path: str, *args, **kwargs):
-    if settings.hf_offline_datasets_path is not None:
-        path = os.path.join(settings.hf_offline_datasets_path, path)
-    try:
-        return _hf_load_dataset(
-            path,
-            *args,
-            **kwargs,
-            verification_mode="no_checks",
-            trust_remote_code=settings.allow_unverified_code,
-            download_mode="force_redownload"
-            if settings.disable_hf_datasets_cache
-            else "reuse_dataset_if_exists",
-        )
-    except ValueError as e:
-        if "trust_remote_code" in str(e):
-            raise UnitxtUnverifiedCodeError(path) from e
-        raise e  # Re raise


 @retry_connection_with_exponential_backoff(backoff_factor=2)
@@ -218,13 +222,15 @@ class Loader(SourceOperator):
         pass

     def load_data(self) -> MultiStream:
-        try:
             iterables = self.load_iterables()
-        except Exception as e:
-            raise UnitxtError(f"Error in loader:\n{self}") from e
-        if isoftype(iterables, MultiStream):
-            return iterables
-        return MultiStream.from_iterables(iterables, copying=True)

     def process(self) -> MultiStream:
         self._maybe_set_classification_policy()
@@ -514,9 +520,13 @@ class LoadCSV(LoadWithPandas):
     sep: str = ","

     def read_dataframe(self, file) -> pd.DataFrame:
-        return pd.read_csv(
-            file, sep=self.sep, low_memory=self.streaming, **self.get_args()
-        )


 def read_file(source) -> bytes:
@@ -560,32 +570,36 @@ class LoadJsonFile(LoadWithPandas):
     data_field: Optional[str] = None

     def read_dataframe(self, file) -> pd.DataFrame:
-        args = self.get_args()
-        if not self.lines:
-            data = json.loads(read_file(file))
-            if self.data_field:
-                instances = dict_get(data, self.data_field)
-                if not isoftype(instances, List[Dict[str, Any]]):
-                    raise UnitxtError(
-                        f"{self.data_field} of file {file} is not a list of dictionariess in LoadJsonFile loader"
-                    )
-            else:
-                if isoftype(data, Dict[str, Any]):
-                    instances = [data]
-                elif isoftype(data, List[Dict[str, Any]]):
-                    instances = data
                else:
                    raise UnitxtError(
-                        f"data of file {file} is not dictionary or a list of dictionaries in LoadJsonFile loader"
                    )
-            dataframe = pd.DataFrame(instances)
-        else:
-            if self.data_field is not None:
-                raise UnitxtError(
-                    "Can not load from a specific 'data_field' when loading multiple lines (lines=True)"
-                )
-            dataframe = pd.read_json(file, lines=self.lines, **args)
-        return dataframe


 class LoadFromSklearn(LazyLoader):
@@ -631,8 +645,12 @@ class LoadFromSklearn(LazyLoader):
        dataset_id = str(self) + "_" + split
        dataset = self.__class__._loader_cache.get(dataset_id, None)
        if dataset is None:
-            split_data = self.downloader(subset=split)
-            targets = [split_data["target_names"][t] for t in split_data["target"]]
            df = pd.DataFrame([split_data["data"], targets]).T
            df.columns = ["data", "target"]
            dataset = df.to_dict("records")
@@ -851,18 +869,22 @@ class LoadFromIBMCloud(Loader):
                if self.data_dir is not None
                else data_file
            )
-            with tempfile.NamedTemporaryFile() as temp_file:
-                # Download to a temporary file in same file partition, and then do an atomic move
-                self._download_from_cos(
-                    cos,
-                    self.bucket_name,
-                    object_key,
-                    local_dir + "/" + os.path.basename(temp_file.name),
-                )
-                os.renames(
-                    local_dir + "/" + os.path.basename(temp_file.name),
-                    local_dir + "/" + data_file,
-                )

        if isinstance(self.data_files, list):
            dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
@@ -946,22 +968,26 @@ class LoadFromDictionary(Loader):

    def verify(self):
        super().verify()
-        if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
-            raise ValueError(
-                f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
-                f"Expected data should map between split name and list of instances.\n"
-                f"Received value: {self.data}\n"
-            )
-        for split in self.data.keys():
-            if len(self.data[split]) == 0:
-                raise ValueError(f"Split {split} has no instances.")
-            first_instance = self.data[split][0]
-            for instance in self.data[split]:
-                if instance.keys() != first_instance.keys():
-                    raise ValueError(
-                        f"Not all instances in split '{split}' have the same fields.\n"
-                        f"instance {instance} has different fields different from {first_instance}"
-                    )

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
@@ -1127,7 +1153,7 @@ class LoadFromAPI(Loader):
    chunksize: int = 100000
    loader_limit: Optional[int] = None
    streaming: bool = False
-    api_key_env_var: Optional[str] = ""
    headers: Optional[Dict[str, Any]] = None
    data_field: str = "data"
    method: str = "GET"
 
 from .dataclass import NonPositionalField
 from .dict_utils import dict_get
+from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
 from .fusion import FixedFusion
 from .logging_utils import get_logger
 from .operator import SourceOperator

 @retry_connection_with_exponential_backoff(backoff_factor=2)
 def hf_load_dataset(path: str, *args, **kwargs):
+    with error_context(
+        stage="Raw Dataset Download",
+        help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+    ):
+        if settings.hf_offline_datasets_path is not None:
+            path = os.path.join(settings.hf_offline_datasets_path, path)
+        try:
+            return _hf_load_dataset(
+                path,
+                *args,
+                **kwargs,
+                verification_mode="no_checks",
+                trust_remote_code=settings.allow_unverified_code,
+                download_mode="force_redownload"
+                if settings.disable_hf_datasets_cache
+                else "reuse_dataset_if_exists",
+            )
+        except ValueError as e:
+            if "trust_remote_code" in str(e):
+                raise UnitxtUnverifiedCodeError(path) from e
+            raise e  # Re-raise


 @retry_connection_with_exponential_backoff(backoff_factor=2)

         pass

     def load_data(self) -> MultiStream:
+        with error_context(
+            self,
+            stage="Data Loading",
+            help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+        ):
             iterables = self.load_iterables()
+            if isoftype(iterables, MultiStream):
+                return iterables
+            return MultiStream.from_iterables(iterables, copying=True)

     def process(self) -> MultiStream:
         self._maybe_set_classification_policy()

     sep: str = ","

     def read_dataframe(self, file) -> pd.DataFrame:
+        with error_context(
+            stage="Raw Dataset Loading",
+            help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+        ):
+            return pd.read_csv(
+                file, sep=self.sep, low_memory=self.streaming, **self.get_args()
+            )


 def read_file(source) -> bytes:
 
     data_field: Optional[str] = None

     def read_dataframe(self, file) -> pd.DataFrame:
+        with error_context(
+            stage="Raw Dataset Loading",
+            help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+        ):
+            args = self.get_args()
+            if not self.lines:
+                data = json.loads(read_file(file))
+                if self.data_field:
+                    instances = dict_get(data, self.data_field)
+                    if not isoftype(instances, List[Dict[str, Any]]):
+                        raise UnitxtError(
+                            f"{self.data_field} of file {file} is not a list of dictionaries in LoadJsonFile loader"
+                        )
                 else:
+                    if isoftype(data, Dict[str, Any]):
+                        instances = [data]
+                    elif isoftype(data, List[Dict[str, Any]]):
+                        instances = data
+                    else:
+                        raise UnitxtError(
+                            f"data of file {file} is not a dictionary or a list of dictionaries in LoadJsonFile loader"
+                        )
+                dataframe = pd.DataFrame(instances)
+            else:
+                if self.data_field is not None:
                     raise UnitxtError(
+                        "Can not load from a specific 'data_field' when loading multiple lines (lines=True)"
                     )
+                dataframe = pd.read_json(file, lines=self.lines, **args)
+            return dataframe
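The two input layouts read_dataframe accepts, as a small sketch: a single JSON document (optionally unwrapped via data_field) or JSON Lines when lines=True. The field names below are made up:

import io
import json
import pandas as pd

doc = {"data": [{"q": "1+1", "a": 2}, {"q": "2+2", "a": 4}]}
instances = doc["data"]  # what dict_get(data, "data") would return
df_doc = pd.DataFrame(instances)

jsonl = "\n".join(json.dumps(row) for row in instances)  # same rows, one per line
df_lines = pd.read_json(io.StringIO(jsonl), lines=True)

assert df_doc.equals(df_lines)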

 class LoadFromSklearn(LazyLoader):
 
         dataset_id = str(self) + "_" + split
         dataset = self.__class__._loader_cache.get(dataset_id, None)
         if dataset is None:
+            with error_context(
+                stage="Raw Dataset Loading",
+                help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+            ):
+                split_data = self.downloader(subset=split)
+                targets = [split_data["target_names"][t] for t in split_data["target"]]
             df = pd.DataFrame([split_data["data"], targets]).T
             df.columns = ["data", "target"]
             dataset = df.to_dict("records")
 
                 if self.data_dir is not None
                 else data_file
             )
+            with error_context(
+                stage="Raw Dataset Download",
+                help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+            ):
+                with tempfile.NamedTemporaryFile() as temp_file:
+                    # Download to a temporary file in same file partition, and then do an atomic move
+                    self._download_from_cos(
+                        cos,
+                        self.bucket_name,
+                        object_key,
+                        local_dir + "/" + os.path.basename(temp_file.name),
+                    )
+                    os.renames(
+                        local_dir + "/" + os.path.basename(temp_file.name),
+                        local_dir + "/" + data_file,
+                    )

         if isinstance(self.data_files, list):
             dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
 

     def verify(self):
         super().verify()
+        with error_context(
+            stage="Dataset Loading",
+            help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+        ):
+            if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
+                raise ValueError(
+                    f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
+                    f"Expected data should map between split name and list of instances.\n"
+                    f"Received value: {self.data}\n"
+                )
+            for split in self.data.keys():
+                if len(self.data[split]) == 0:
+                    raise ValueError(f"Split {split} has no instances.")
+                first_instance = self.data[split][0]
+                for instance in self.data[split]:
+                    if instance.keys() != first_instance.keys():
+                        raise ValueError(
+                            f"Not all instances in split '{split}' have the same fields.\n"
+                            f"instance {instance} has fields different from {first_instance}"
+                        )
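The shape verify() enforces, illustrated with made-up content: split name mapping to a list of instances, with every instance in a split sharing the same fields:

data = {
    "train": [
        {"question": "1+1", "answer": "2"},
        {"question": "2+2", "answer": "4"},
    ],
    "test": [{"question": "3+3", "answer": "6"}],
}
# Mirrors the loader's check: identical field sets within each split
assert all(
    instance.keys() == split[0].keys() for split in data.values() for instance in split
)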

     def _maybe_set_classification_policy(self):
         self.set_default_data_classification(
 
     chunksize: int = 100000
     loader_limit: Optional[int] = None
     streaming: bool = False
+    api_key_env_var: Optional[str] = None
     headers: Optional[Dict[str, Any]] = None
     data_field: str = "data"
     method: str = "GET"
metric.py CHANGED
@@ -56,7 +56,6 @@ from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
 from .splitters import __file__ as _
-from .sql_utils import __file__ as _
 from .standard import __file__ as _
 from .stream import __file__ as _
 from .stream_operators import __file__ as _
@@ -65,6 +64,7 @@ from .struct_data_operators import __file__ as _
 from .system_prompts import __file__ as _
 from .task import __file__ as _
 from .templates import __file__ as _
 from .text_utils import __file__ as _
 from .type_utils import __file__ as _
 from .types import __file__ as _

 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
 from .splitters import __file__ as _
 from .standard import __file__ as _
 from .stream import __file__ as _
 from .stream_operators import __file__ as _

 from .system_prompts import __file__ as _
 from .task import __file__ as _
 from .templates import __file__ as _
+from .text2sql_utils import __file__ as _
 from .text_utils import __file__ as _
 from .type_utils import __file__ as _
 from .types import __file__ as _
metric_utils.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
 from datasets import Features, Value

 from .dataclass import Dataclass
-from .error_utils import Documentation, UnitxtError
 from .operator import (
     InstanceOperator,
     MultiStreamOperator,
@@ -36,6 +36,9 @@ from .utils import recursive_copy

 constants = get_constants()

 def nan_mean(scores):
     result = mean(score for score in scores if score == score)
@@ -56,7 +59,10 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
         yield {**original, "prediction": prediction}

     def process(
-        self, predictions: List[str], references: Iterable, split_name: str = "all"
     ) -> MultiStream:
         return MultiStream(
             {
@@ -152,7 +158,7 @@ class SplitSubsetsAndGroups(MultiStreamOperator):

         subset_stream_name = (
             stream_name
-            + "://"
             + "/".join(instance[self.subsets_field][: self.subset_depth])
         )

@@ -190,7 +196,7 @@ def group_str_to_key_value(group_str):

 @lru_cache(maxsize=None)
 def stream_name_to_origin_subset_group(stream_name):
-    origin, subset_group = stream_name.split("://")
     if "?" in subset_group:
         subset, group = subset_group.split("?")
     else:
@@ -734,22 +740,23 @@ def _compute(
     predictions: List[Any],
     references: Iterable,
     flatten: bool = False,
-    split_name: str = "all",
     calc_confidence_intervals: bool = True,
 ):
     _reset_env_local_catalogs()
     register_all_artifacts()
     recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)

-    multi_stream = recipe(
-        predictions=predictions, references=references, split_name=split_name
-    )

-    if flatten:
-        operator = FlattenInstances()
-        multi_stream = operator(multi_stream)

-    stream = multi_stream[split_name]
     return EvaluationResults(stream)
 
 from datasets import Features, Value

 from .dataclass import Dataclass
+from .error_utils import Documentation, UnitxtError, error_context
 from .operator import (
     InstanceOperator,
     MultiStreamOperator,

 constants = get_constants()

+DEFAULT_STREAM_NAME = "all_data"
+DEFAULT_STREAM_SUBSET_SEPARATOR = ">>"
+

 def nan_mean(scores):
     result = mean(score for score in scores if score == score)

         yield {**original, "prediction": prediction}

     def process(
+        self,
+        predictions: List[str],
+        references: Iterable,
+        split_name: str = DEFAULT_STREAM_NAME,
     ) -> MultiStream:
         return MultiStream(
             {

         subset_stream_name = (
             stream_name
+            + DEFAULT_STREAM_SUBSET_SEPARATOR
             + "/".join(instance[self.subsets_field][: self.subset_depth])
         )

 @lru_cache(maxsize=None)
 def stream_name_to_origin_subset_group(stream_name):
+    origin, subset_group = stream_name.split(DEFAULT_STREAM_SUBSET_SEPARATOR)
     if "?" in subset_group:
         subset, group = subset_group.split("?")
     else:

     predictions: List[Any],
     references: Iterable,
     flatten: bool = False,
+    split_name: str = DEFAULT_STREAM_NAME,
     calc_confidence_intervals: bool = True,
 ):
     _reset_env_local_catalogs()
     register_all_artifacts()
     recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)

+    with error_context(stage="Metric Processing"):
+        multi_stream = recipe(
+            predictions=predictions, references=references, split_name=split_name
+        )

+        if flatten:
+            operator = FlattenInstances()
+            multi_stream = operator(multi_stream)

+        stream = multi_stream[split_name]
     return EvaluationResults(stream)
metrics.py CHANGED
@@ -8,7 +8,8 @@ import uuid
8
  import warnings
9
  from abc import ABC, abstractmethod
10
  from collections import Counter, defaultdict
11
- from dataclasses import field
 
12
  from enum import Enum
13
  from functools import lru_cache
14
  from typing import (
@@ -42,7 +43,8 @@ from .dataclass import (
42
  OptionalField,
43
  )
44
  from .deprecation_utils import deprecation
45
- from .error_utils import Documentation, UnitxtError, UnitxtWarning
 
46
  from .inference import (
47
  HFPipelineBasedInferenceEngine,
48
  InferenceEngine,
@@ -64,6 +66,7 @@ from .operators import ArtifactFetcherMixin, Copy, FieldOperator, Set
64
  from .random_utils import get_seed
65
  from .settings_utils import get_settings
66
  from .stream import MultiStream, Stream
 
67
  from .type_utils import isoftype, parse_type_string, to_type_string
68
  from .types import ToolCall
69
  from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
@@ -382,28 +385,35 @@ class MapReduceMetric(
382
  return intermediates
383
 
384
  def process(self, stream: Stream, stream_name: Optional[str] = None):
385
- instances_scores, global_scores = self.compute(stream, stream_name)
386
- for i, (instance, instance_scores) in enumerate(zip(stream, instances_scores)):
387
- previous_score = instance.get("score", {"global": {}, "instance": {}})
388
-
389
- if i == 0:
390
- for key in global_scores:
391
- if is_original_key(key) and key in previous_score["global"]:
392
- UnitxtWarning(
393
- message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
394
- f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
395
- f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
396
- f"which will yield, in this case, a score named: 'my_second_{key}')",
397
- additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
398
- )
 
 
 
 
 
 
 
399
 
400
- global_scores = {**previous_score["global"], **global_scores}
401
- instance_scores = {**previous_score["instance"], **instance_scores}
402
 
403
- yield {
404
- **instance,
405
- "score": {"global": global_scores, "instance": instance_scores},
406
- }
407
 
408
  def compute(self, stream: Stream, stream_name: Optional[str] = None):
409
  evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
@@ -453,6 +463,43 @@ class DictReduction(AggregationReduction[Dict[str, float]]):
453
  return result
454
 
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  class MeanReduction(DictReduction):
457
  def reduce_list(self, lst: List[float]):
458
  return nan_mean(lst)
@@ -468,6 +515,91 @@ class MaxReduction(DictReduction):
468
  return float(nan_max(lst))
469
 
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  class ReductionInstanceMetric(
472
  MapReduceMetric[PredictionType, IntermediateType],
473
  Generic[PredictionType, IntermediateType],
@@ -704,6 +836,52 @@ class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
704
  }
705
 
706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  class MetricWithConfidenceInterval(Metric):
708
  # The number of resamples used to estimate the confidence intervals of this metric.
709
  # Use None to disable confidence interval computation.
@@ -954,83 +1132,88 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
954
  process_single_instances = True
955
 
956
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
957
- references = []
958
- predictions = []
959
- task_data = []
 
 
 
 
 
960
 
961
- instances = []
962
 
963
- for instance in stream:
964
- instance = self.verify_instance(instance)
965
 
966
- if "score" not in instance:
967
- instance["score"] = {"global": {}, "instance": {}}
968
 
969
- instance_references, instance_prediction = (
970
- instance["references"],
971
- instance["prediction"],
972
- )
973
 
974
- references.append(instance_references)
975
- predictions.append(instance_prediction)
976
- instances.append(instance)
977
 
978
- instance_task_data = (
979
- instance["task_data"] if "task_data" in instance else {}
980
- )
981
- task_data.append(instance_task_data)
982
- instance_score = None
983
 
984
- # for backward compatibility
985
- no_score_value = np.nan
986
- if self.process_single_instances:
987
- try:
988
- instance_score = self._compute(
989
- [instance_references],
990
- [instance_prediction],
991
- [instance_task_data],
992
- )
993
- except:
994
- no_score_value = None
995
- if not instance_score:
996
- instance_score = {
997
- "score": no_score_value,
998
- "score_name": self.main_score,
999
- }
1000
 
1001
- if isinstance(self.main_score, str):
1002
- instance_score[self.main_score] = no_score_value
1003
 
1004
- instance["score"]["instance"].update(
1005
- self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1006
- instance_score, instance["score"]["instance"]
 
1007
  )
1008
- )
1009
- self._validate_references_and_prediction(references, predictions)
1010
- global_score = {"num_of_instances": len(instances)}
1011
 
1012
- result = self._compute(references, predictions, task_data)
1013
- global_score.update(
1014
- self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1015
- result, global_score
 
1016
  )
1017
- )
1018
- if self.ci_scores:
1019
- score_names = [
1020
- self._add_score_prefix(score_name) for score_name in self.ci_scores
1021
- ]
1022
- else:
1023
- score_names = [global_score["score_name"]]
1024
 
1025
- for score_name in score_names:
1026
- confidence_interval = self.compute_global_confidence_intervals(
1027
- references, predictions, task_data, score_name
1028
- )
1029
- global_score.update(confidence_interval)
1030
 
1031
- for instance in instances:
1032
- self.update_and_adjust_global_score(instance, global_score)
1033
- yield instance
1034
 
1035
  def _compute(
1036
  self,
@@ -1080,96 +1263,105 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1080
  return instance
1081
 
1082
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1083
- instances = []
1084
- for instance in stream:
1085
- self.verify_instance(instance)
1086
- instance = self.preprocess_instance(instance)
1087
- instances.append(instance)
1088
-
1089
- predictions = [instance["prediction"] for instance in instances]
1090
- references = [instance["references"] for instance in instances]
1091
- task_data = [
1092
- instance["task_data"] if "task_data" in instance else {}
1093
- for instance in instances
1094
- ]
1095
- self._validate_references_and_prediction(references, predictions)
1096
- global_score = {"num_of_instances": len(instances)}
1097
- # compute the metric over all refs and preds
1098
- instance_scores = self.compute(
1099
- references=references,
1100
- predictions=predictions,
1101
- task_data=task_data,
1102
- )
 
 
 
 
 
1103
 
1104
- # add the score and score_name fields
1105
- for instance_score in instance_scores:
1106
- instance_score["score"] = instance_score[self.main_score]
1107
- instance_score["score_name"] = self.main_score
1108
 
1109
- for instance, score in zip(instances, instance_scores):
1110
- if "score" not in instance:
1111
- instance["score"] = {"global": {}, "instance": {}}
1112
 
1113
- instance["score"]["instance"].update(
1114
- self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1115
- score, instance["score"]["instance"]
 
1116
  )
1117
- )
1118
 
1119
- for reduction, fields in self.reduction_map.items():
1120
- assert (
1121
- reduction in self.implemented_reductions
1122
- ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
1123
-
1124
- if reduction == "mean":
1125
- for field_name in fields:
1126
- field_name_with_prefix = self._add_score_prefix(field_name)
1127
- global_score[field_name_with_prefix] = nan_mean(
1128
- [
1129
- instance["score"]["instance"][field_name_with_prefix]
1130
- for instance in instances
1131
- ]
1132
- )
1133
- if field_name == self.main_score:
1134
- global_score["score"] = global_score[field_name_with_prefix]
1135
- global_score["score_name"] = self.score_prefix + self.main_score
 
 
1136
 
1137
- ci_fields = (
1138
- list(set(self.ci_scores))
1139
- if self.ci_scores is not None
1140
- else [self.main_score]
1141
- )
1142
- ci_fields_with_prefix = [
1143
- self._add_score_prefix(ci_field) for ci_field in ci_fields
1144
- ]
1145
- confidence_interval = self.score_based_confidence_interval(
1146
- instances=instances, score_names=ci_fields_with_prefix
1147
- )
1148
- global_score.update(confidence_interval)
1149
- if reduction == "weighted_win_rate":
1150
- for field_name in fields:
1151
- field_name_with_prefix = self._add_score_prefix(field_name)
1152
- total_battles = 0
1153
- wins = 0
1154
- for instance in instances:
1155
- s = instance["score"]["instance"][field_name_with_prefix]
1156
- if s > 0:
1157
- total_battles += s
1158
- wins += s
1159
- elif s < 0:
1160
- total_battles += abs(s)
1161
- else:
1162
- total_battles += 2
1163
- wins += 1
1164
-
1165
- global_score[field_name_with_prefix] = wins / total_battles
1166
- if field_name == self.main_score:
1167
- global_score["score"] = global_score[field_name_with_prefix]
1168
- global_score["score_name"] = self.score_prefix + self.main_score
 
 
1169
 
1170
- for instance in instances:
1171
- self.update_and_adjust_global_score(instance, global_score)
1172
- yield instance
1173
 
1174
  @abstractmethod
1175
  def compute(
@@ -1475,91 +1667,97 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1475
  assert isinstance(fields["score_fields"], list)
1476
 
1477
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1478
- instance_scores = self.compute_instance_scores(stream)
1479
- global_score = {"num_of_instances": len(instance_scores)}
1480
- for reduction_type, reduction_params in self.reduction_map.items():
1481
- assert (
1482
- reduction_type in self.implemented_reductions
1483
- ), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
1484
-
1485
- field_name_full_prefix = ""
1486
- # used for passing to the bootstrapping, depends on whether the groups are fixed or not
1487
- aggregation_function = None
1488
- if reduction_type == "mean":
1489
- aggregation_function = self.average_item_scores
1490
- reduction_fields = list(set(reduction_params))
1491
- # no group reduction, so resample instances individually
1492
- scores_to_resample = instance_scores
1493
- elif reduction_type == "max":
1494
- aggregation_function = self.max_item_scores
1495
- reduction_fields = list(set(reduction_params))
1496
- # no group reduction, so resample instances individually
1497
- scores_to_resample = instance_scores
1498
- elif reduction_type == "group_mean":
1499
- aggregation_function = self.average_item_scores
1500
- self._validate_group_mean_reduction()
1501
- reduction_fields = (
1502
- [self.main_score]
1503
- if "score_fields" not in reduction_params
1504
- else list(set(reduction_params["score_fields"]))
1505
- )
1506
- aggregation_function_name = str(reduction_params["agg_func"][0])
1507
- field_name_full_prefix = "group_" + aggregation_function_name + "_"
1508
- do_resample_as_group = reduction_params["agg_func"][2]
1509
- if do_resample_as_group:
1510
- # append fixed_ to name because resamples the groups as fixed units
1511
- field_name_full_prefix = "fixed_" + field_name_full_prefix
1512
- (
1513
- scores_to_resample,
1514
- aggregation_function,
1515
- ) = self._set_up_group_mean_aggregation(
1516
- instance_scores,
1517
- reduction_params,
1518
- reduction_fields,
1519
- )
1520
- else:
1521
- raise ValueError(
1522
- f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
1523
- )
 
 
 
 
 
1524
 
1525
- # calculate global scores for each reduction field
1526
- for field_name in reduction_fields:
1527
- field_name_full = (
1528
- field_name_full_prefix + self.score_prefix + field_name
1529
- )
1530
- # if group resampling (3rd element of agg_func parameter) is True, then
1531
- # 1. scores_to_resample are the group scores, and
1532
- # 2. aggregation_function is to take the raw mean
1533
- # if no group resampling (3rd element of agg_func parameter) is False, then
1534
- # 1. scores_to_resample are the original instance scores, and
1535
- # 2. aggregation_function is to apply the group aggregation from the instance scores
1536
- # either way, the application of aggregation_function to scores_to_resample yields the global score
1537
- global_score[field_name_full] = aggregation_function(
1538
- scores_to_resample, self.score_prefix + field_name
1539
- )
1540
- if field_name == self.main_score:
1541
- global_score["score"] = global_score[field_name_full]
1542
- global_score["score_name"] = field_name_full
1543
-
1544
- # need to specify which fields should have CIs calculated for them through ci_scores
1545
- # (will not automatically calculate CIs for fields in reduction map)
1546
- if self.ci_scores is not None:
1547
- confidence_interval = self.score_based_confidence_interval(
1548
- instances=scores_to_resample,
1549
- score_names=[
1550
- self.score_prefix + ci_score for ci_score in set(self.ci_scores)
1551
- ],
1552
- ci_score_prefix=field_name_full_prefix,
1553
- aggregation_func=aggregation_function,
1554
- )
1555
- global_score.update(confidence_interval)
 
1556
 
1557
- for instance in instance_scores:
1558
- self.update_and_adjust_global_score(instance, global_score)
1559
 
1560
- for i, instance in enumerate(stream):
1561
- instance["score"] = recursive_copy(instance_scores[i]["score"])
1562
- yield instance
1563
 
1564
  def compute_instance_scores(
1565
  self, stream: Stream, stream_name: Optional[str] = None
@@ -6436,391 +6634,102 @@ RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = {
6436
  }
6437
 
6438
 
6439
- class SQLExecutionAccuracy(InstanceMetric):
6440
- reduction_map = {
6441
- "mean": [
6442
- "execution_accuracy",
6443
- "non_empty_execution_accuracy",
6444
- "subset_non_empty_execution_result",
6445
- "non_empty_gold_df",
6446
- "gold_sql_runtime",
6447
- "predicted_sql_runtime",
6448
- "pred_to_gold_runtime_ratio",
6449
- "gold_error",
6450
- "predicted_error",
6451
- ]
6452
- }
6453
  main_score = "non_empty_execution_accuracy"
 
 
 
 
 
 
 
 
 
6454
  ci_scores = [
6455
  "execution_accuracy",
6456
  "non_empty_execution_accuracy",
6457
- "subset_non_empty_execution_result",
 
6458
  "gold_sql_runtime",
6459
  "predicted_sql_runtime",
6460
  ]
6461
 
6462
- prediction_type = "Any" # string representation is compared
6463
- sql_timeout = 30.0
6464
-
6465
- _requirements_list = ["sqlglot", "func_timeout"]
6466
-
6467
- @staticmethod
6468
- def compare_dfs_ignore_colnames_ordered_rows(df1, df2):
6469
- """Compares two DataFrames based on row content, ignoring column names.
6470
-
6471
- Args:
6472
- df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
6473
- df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
6474
-
6475
- Returns:
6476
- True if the DataFrames have the same ordered rows (ignoring column names),
6477
- False otherwise.
6478
- """
6479
- df1.fillna(0, inplace=True)
6480
- df2.fillna(0, inplace=True)
6481
-
6482
- # Compare row counts first for a quick check
6483
- if df1.shape != df2.shape:
6484
- return False
6485
-
6486
- # Convert DataFrames to numpy arrays of strings to handle mixed types
6487
- df1_array = df1.values.astype(str)
6488
- df2_array = df2.values.astype(str)
6489
-
6490
- # Sort each row's elements (column order independence)
6491
- df1_sorted_rows = np.array([np.sort(row) for row in df1_array])
6492
- df2_sorted_rows = np.array([np.sort(row) for row in df2_array])
6493
-
6494
- # Compare the sorted rows in order
6495
- return np.array_equal(df1_sorted_rows, df2_sorted_rows)
6496
-
6497
- @staticmethod
6498
- def compare_dfs_ignore_colnames_unordered_rows(df1, df2):
6499
- """Compares two DataFrames based on row content, ignoring row order and column names.
6500
-
6501
- Args:
6502
- df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
6503
- df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
6504
-
6505
- Returns:
6506
- True if the DataFrames have the same content (ignoring column names and row order),
6507
- False otherwise.
6508
- """
6509
- # Compare shapes early on
6510
- if df1.shape != df2.shape:
6511
- return False
6512
-
6513
- # Convert DataFrames to numpy arrays of strings (to handle mixed data types)
6514
- df1_array = df1.values.astype(str)
6515
- df2_array = df2.values.astype(str)
6516
-
6517
- # Sort columns first, then sort rows
6518
- df1_sorted = np.sort(np.sort(df1_array, axis=1), axis=0)
6519
- df2_sorted = np.sort(np.sort(df2_array, axis=1), axis=0)
6520
-
6521
- # Compare the sorted arrays
6522
- return np.array_equal(df1_sorted, df2_sorted)
6523
-
6524
- @staticmethod
6525
- def compare_dfs_ignore_colnames_subset(df1, df2, ignore_row_order=True):
6526
- """Checks if the values of either DataFrame are a subset of the values in the other DataFrame.
6527
-
6528
- Comparison is column order independent, and could optionally be row order independent.
6529
- We interpret "subset" as follows:
6530
-
6531
- - For each row in df1, there must be a matching (or superset) row in df2, i.e. the set of values
6532
- in the df1 row is a subset of the set of values in that df2 row. Then do the same check in reverse.
6533
- - If either condition (df1 is subset of df2 OR df2 is subset of df1) is satisfied, return True.
6534
-
6535
- We treat an empty dataframe as a subset of nothing, while in theory is a subset of any dataframe.
6536
-
6537
- Args:
6538
- df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
6539
- df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
6540
- ignore_row_order (bool): If True, row order doesn't matter; if False, row order is respected.
6541
-
6542
- Returns:
6543
- bool: True if df1 is a subset of df2 or vice versa, based on the specified row-order condition.
6544
-
6545
- """
6546
- df1_array = df1.values.astype(str)
6547
- df2_array = df2.values.astype(str)
6548
-
6549
- df1_sorted_rows = [np.sort(row) for row in df1_array]
6550
- df2_sorted_rows = [np.sort(row) for row in df2_array]
6551
-
6552
- def row_is_subset(r_small, r_big):
6553
- """Check if all elements of r_small are in r_big."""
6554
- return set(r_small).issubset(set(r_big))
6555
-
6556
- def df_is_subset_of_another(rows_small, rows_big, respect_order):
6557
- """Check if the rows_small is subset of rows_big under the given order condition."""
6558
- if not rows_small:
6559
- return False # DataFrame needs to be non-empty
6560
-
6561
- # If row order matters:
6562
- if respect_order:
6563
- i, j = 0, 0
6564
- while i < len(rows_small) and j < len(rows_big):
6565
- if row_is_subset(rows_small[i], rows_big[j]):
6566
- i += 1
6567
- j += 1
6568
- return i == len(rows_small)
6569
- # Row order doesn't matter:
6570
- matched_indices = set()
6571
- for r_small in rows_small:
6572
- found_match = False
6573
- for idx, r_big in enumerate(rows_big):
6574
- if idx not in matched_indices and row_is_subset(r_small, r_big):
6575
- found_match = True
6576
- matched_indices.add(idx)
6577
- break
6578
- if not found_match:
6579
- return False
6580
- return True
6581
-
6582
- df1_sub_df2 = df_is_subset_of_another(
6583
- df1_sorted_rows, df2_sorted_rows, not ignore_row_order
6584
- )
6585
- df2_sub_df1 = df_is_subset_of_another(
6586
- df2_sorted_rows, df1_sorted_rows, not ignore_row_order
6587
  )
6588
 
6589
- return df1_sub_df2 or df2_sub_df1
 
 
 
 
 
 
 
 
 
6590
 
6591
- def get_sql_execution_results(
6592
- self, predicted_sql: str, gold_sql: str, connector
6593
- ) -> (int, int, int, int, int, int, int, int, int, str, str, str):
6594
- """Runs SQL queries using the provided connector and gets scores and results.
6595
 
6596
- Args:
6597
- predicted_sql (str): predicted SQL query
6598
- gold_sql (str): gold reference SQL query
6599
- connector: database connector
6600
 
6601
- Returns:
6602
- a 12-tuple of
6603
- 1. execution_result: if df responses match
6604
- 2. non_empty_execution_result: if dfs are non-empty and match
6605
- 3. subset_non_empty_execution_result: if non-empty dfs and one is a subset of the other
6606
- 4. non_empty_gold_df: if gt df is non-empty
6607
- 5. gold_sql_runtime: ground truth query runtime
6608
- 6. predicted_sql_runtime: predicted query runtime
6609
- 7. pred_to_gold_runtime_ratio: ratio of predicted query runtime to gt query runtime
6610
- 8. gold_error: if gt has an error
6611
- 9. predicted_error: if predicted query has an error
6612
- 10. ground truth dataframe
6613
- 11. predicted query's dataframe
6614
- 12. error message (if any)
6615
- """
6616
- import time
6617
 
6618
- from func_timeout import func_timeout
6619
- from func_timeout.exceptions import FunctionTimedOut
6620
 
6621
- from .sql_utils import sqlglot_optimized_equivalence
6622
 
6623
- gold_res = None
6624
- gold_error = ""
6625
- gold_sql_runtime = 0
6626
- try:
6627
- start_time = time.perf_counter()
6628
- gold_res, gold_error = func_timeout(
6629
- self.sql_timeout,
6630
- connector.execute_query,
6631
- args=(gold_sql,),
6632
- )
6633
- end_time = time.perf_counter()
6634
- gold_sql_runtime = end_time - start_time
6635
- except FunctionTimedOut as e:
6636
- pred_error = f"Timeout error executing gold SQL: {e}"
6637
- logger.warning(pred_error)
6638
- except Exception as e:
6639
- gold_error = f"Error executing gold SQL: {e}"
6640
- if gold_error is not None:
6641
- return (
6642
- 0,
6643
- 0,
6644
- 0,
6645
- 0,
6646
- gold_sql_runtime,
6647
- 0,
6648
- 0,
6649
- 0,
6650
- 0,
6651
- "",
6652
- "",
6653
- gold_error,
6654
- )
6655
 
6656
- if isinstance(gold_res, dict) and "results" in gold_res:
6657
- gold_res = gold_res["results"]
6658
- gold_df = pd.DataFrame(gold_res)
6659
- non_empty_gold_df = 0 if gold_df.empty else 1
6660
 
6661
- no_execution_match_result = (
6662
- 1,
6663
- non_empty_gold_df,
6664
- non_empty_gold_df,
6665
- non_empty_gold_df,
6666
- gold_sql_runtime,
6667
- 0,
6668
- 0,
6669
- 0,
6670
- 0,
6671
- gold_df.to_json(),
6672
- "",
6673
- "",
6674
- )
6675
- if predicted_sql.lower().strip() == gold_sql.lower().strip():
6676
- return no_execution_match_result
6677
- try:
6678
- if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
6679
- return no_execution_match_result
6680
- except Exception as e: # Catch specific exceptions if possible
6681
- logger.info(
6682
- f"Couldn't test equivalent_sqls: {e}. Treating as non-equivalent and going to test with the db."
6683
- )
6684
 
6685
- pred_res = None
6686
- pred_error = ""
6687
- pred_sql_runtime = 0
6688
- try:
6689
- start_time = time.perf_counter()
6690
- pred_res, pred_error = func_timeout(
6691
- self.sql_timeout,
6692
- connector.execute_query,
6693
- args=(predicted_sql,),
6694
- )
6695
- end_time = time.perf_counter()
6696
- pred_sql_runtime = end_time - start_time
6697
- except FunctionTimedOut as e:
6698
- pred_error = f"Timeout error executing predicted SQL: {e}"
6699
- logger.info(pred_error)
6700
- except Exception as e:
6701
- pred_error = f"Error executing predicted SQL: {e}"
6702
- logger.info(pred_error)
6703
-
6704
- pred_to_gold_runtime_ratio = (
6705
- float(pred_sql_runtime) / gold_sql_runtime if gold_sql_runtime > 0 else 0
6706
  )
6707
 
6708
- if pred_res is None:
6709
- return (
6710
- 0,
6711
- 0,
6712
- 0,
6713
- 0,
6714
- gold_sql_runtime,
6715
- pred_sql_runtime,
6716
- pred_to_gold_runtime_ratio,
6717
- 0,
6718
- 1,
6719
- "",
6720
- "",
6721
- pred_error,
6722
- )
6723
-
6724
- if isinstance(pred_res, dict) and "results" in pred_res:
6725
- pred_res = pred_res["results"]
6726
- predicted_df = pd.DataFrame(pred_res)
6727
-
6728
- subset_non_empty_execution_result = 0
6729
- non_empty_execution_result = 0
6730
- if "ORDER BY" in gold_sql.upper():
6731
- execution_result = (
6732
- 1
6733
- if self.compare_dfs_ignore_colnames_ordered_rows(predicted_df, gold_df)
6734
- else 0
6735
- )
6736
- if non_empty_gold_df:
6737
- if execution_result == 1:
6738
- non_empty_execution_result = 1
6739
- if self.compare_dfs_ignore_colnames_subset(
6740
- gold_df, predicted_df, ignore_row_order=False
6741
- ):
6742
- subset_non_empty_execution_result = 1
6743
- else:
6744
- execution_result = (
6745
- 1
6746
- if self.compare_dfs_ignore_colnames_unordered_rows(
6747
- predicted_df, gold_df
6748
- )
6749
- else 0
6750
- )
6751
- if non_empty_gold_df:
6752
- if execution_result == 1:
6753
- non_empty_execution_result = 1
6754
- if self.compare_dfs_ignore_colnames_subset(
6755
- gold_df, predicted_df, ignore_row_order=True
6756
- ):
6757
- subset_non_empty_execution_result = 1
6758
 
6759
- return (
6760
- execution_result,
6761
- non_empty_execution_result,
6762
- subset_non_empty_execution_result,
6763
- non_empty_gold_df,
6764
- gold_sql_runtime,
6765
- pred_sql_runtime,
6766
- pred_to_gold_runtime_ratio,
6767
- 0,
6768
- 0,
6769
- gold_df.to_json(),
6770
- predicted_df.to_json(),
6771
- pred_error,
6772
  )
6773
 
6774
- def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
6775
- from .sql_utils import get_db_connector
6776
-
6777
- predicted_sql = prediction
6778
- execution_result: float = 0.0
6779
-
6780
- if predicted_sql and predicted_sql.strip() != "":
6781
- if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
6782
- predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
6783
- if ";" in predicted_sql:
6784
- predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
6785
-
6786
- db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
6787
-
6788
- logger.debug(
6789
- f"Starting to get SQL execution results over DB: {task_data['db']}"
6790
- )
6791
- (
6792
- execution_result,
6793
- non_empty_execution_result,
6794
- subset_non_empty_execution_result,
6795
- non_empty_gold_df,
6796
- gold_sql_runtime,
6797
- predicted_sql_runtime,
6798
- pred_to_gold_runtime_ratio,
6799
- gold_error,
6800
- predicted_error,
6801
- gold_df_json,
6802
- predicted_df_json,
6803
- error_message,
6804
- ) = self.get_sql_execution_results(
6805
- predicted_sql, references[0], db_connector
6806
- )
6807
-
6808
- result = {
6809
- "execution_accuracy": float(execution_result),
6810
- "non_empty_execution_accuracy": float(non_empty_execution_result),
6811
- "subset_non_empty_execution_result": float(
6812
- subset_non_empty_execution_result
6813
- ),
6814
- "non_empty_gold_df": float(non_empty_gold_df),
6815
- "gold_sql_runtime": float(gold_sql_runtime),
6816
- "predicted_sql_runtime": float(predicted_sql_runtime),
6817
- "pred_to_gold_runtime_ratio": float(pred_to_gold_runtime_ratio),
6818
- "gold_error": float(gold_error),
6819
- "predicted_error": float(predicted_error),
6820
- "error_message": str(error_message),
6821
- "gold_df_json": str(gold_df_json),
6822
- "predicted_df_json": str(predicted_df_json),
6823
- }
6824
  result["score"] = result[self.main_score]
6825
  result["score_name"] = self.main_score
6826
  logger.debug(f"SQL Execution Accuracy Result: {result}")
@@ -6828,34 +6737,22 @@ class SQLExecutionAccuracy(InstanceMetric):
6828
 
6829
 
6830
  class SQLNonExecutionAccuracy(InstanceMetric):
6831
- reduction_map = {
6832
- "mean": [
6833
- "sqlglot_validity",
6834
- "sqlparse_validity",
6835
- "sqlglot_equivalence",
6836
- "sqlglot_optimized_equivalence",
6837
- "sqlparse_equivalence",
6838
- "sql_exact_match",
6839
- "sql_syntactic_equivalence",
6840
- ]
6841
- }
6842
- main_score = "sqlglot_equivalence"
6843
- ci_scores = [
6844
- "sqlglot_validity",
6845
- "sqlparse_validity",
6846
- "sqlglot_equivalence",
6847
- "sqlglot_optimized_equivalence",
6848
- "sqlparse_equivalence",
6849
- "sql_exact_match",
6850
- "sql_syntactic_equivalence",
6851
  ]
6852
 
6853
  prediction_type = "Any" # string representation is compared
6854
 
6855
  _requirements_list = ["sqlglot", "sqlparse"]
6856
 
6857
  def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
6858
- from .sql_utils import (
6859
  is_sqlglot_parsable,
6860
  is_sqlparse_parsable,
6861
  sql_exact_match,
@@ -6864,48 +6761,45 @@ class SQLNonExecutionAccuracy(InstanceMetric):
6864
  sqlparse_queries_equivalent,
6865
  )
6866
 
6867
- predicted_sql = prediction
6868
  gold_sql = references[0]
6869
-
6870
- if predicted_sql and predicted_sql.strip() != "":
6871
- if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
6872
- predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
6873
- if ";" in predicted_sql:
6874
- predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
6875
 
6876
  is_sqlglot_parsable = is_sqlglot_parsable(predicted_sql)
6877
  is_sqlparse_parsable = is_sqlparse_parsable(predicted_sql)
6878
- result = {
6879
- "sqlglot_validity": float(is_sqlglot_parsable),
6880
- "sqlparse_validity": float(is_sqlparse_parsable),
6881
- "sqlglot_equivalence": float(
6882
  sqlglot_parsed_queries_equivalent(predicted_sql, gold_sql)
6883
  if is_sqlglot_parsable
6884
  else 0
6885
  ),
6886
- "sqlglot_optimized_equivalence": float(
6887
  sqlglot_optimized_equivalence(predicted_sql, gold_sql)
6888
  if is_sqlglot_parsable
6889
  else 0
6890
  ),
6891
- "sqlparse_equivalence": float(
6892
  sqlparse_queries_equivalent(predicted_sql, gold_sql)
6893
  if is_sqlparse_parsable
6894
  else 0
6895
  ),
6896
- "sql_exact_match": float(sql_exact_match(predicted_sql, gold_sql)),
6897
- }
6898
- result["sql_syntactic_equivalence"] = float(
6899
  any(
6900
- result[key]
6901
- for key in [
6902
- "sqlglot_equivalence",
6903
- "sqlglot_optimized_equivalence",
6904
- "sqlparse_equivalence",
6905
- "sql_exact_match",
6906
  ]
6907
  )
6908
  )
6909
  logger.debug(f"SQL Non Execution Accuracy Result: {result}")
6910
  result["score"] = result[self.main_score]
6911
  result["score_name"] = self.main_score
 
8
  import warnings
9
  from abc import ABC, abstractmethod
10
  from collections import Counter, defaultdict
11
+ from dataclasses import asdict, field
12
+ from dataclasses import fields as dataclasses_fields
13
  from enum import Enum
14
  from functools import lru_cache
15
  from typing import (
 
43
  OptionalField,
44
  )
45
  from .deprecation_utils import deprecation
46
+ from .dict_utils import dict_get
47
+ from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
48
  from .inference import (
49
  HFPipelineBasedInferenceEngine,
50
  InferenceEngine,
 
66
  from .random_utils import get_seed
67
  from .settings_utils import get_settings
68
  from .stream import MultiStream, Stream
69
+ from .text2sql_utils import SQLExecutionResult, SQLNonExecutionMetricResult
70
  from .type_utils import isoftype, parse_type_string, to_type_string
71
  from .types import ToolCall
72
  from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
 
385
  return intermediates
386
 
387
  def process(self, stream: Stream, stream_name: Optional[str] = None):
388
+ with error_context(
389
+ self,
390
+ stage="Evaluating Metric",
391
+ help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
392
+ ):
393
+ instances_scores, global_scores = self.compute(stream, stream_name)
394
+ for i, (instance, instance_scores) in enumerate(
395
+ zip(stream, instances_scores)
396
+ ):
397
+ previous_score = instance.get("score", {"global": {}, "instance": {}})
398
+
399
+ if i == 0:
400
+ for key in global_scores:
401
+ if is_original_key(key) and key in previous_score["global"]:
402
+ UnitxtWarning(
403
+ message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
404
+ f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
405
+ f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
406
+ f"which will yield, in this case, a score named: 'my_second_{key}')",
407
+ additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
408
+ )
409
 
410
+ global_scores = {**previous_score["global"], **global_scores}
411
+ instance_scores = {**previous_score["instance"], **instance_scores}
412
 
413
+ yield {
414
+ **instance,
415
+ "score": {"global": global_scores, "instance": instance_scores},
416
+ }
417
 
418
  def compute(self, stream: Stream, stream_name: Optional[str] = None):
419
  evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
 
463
  return result
464
 
465
 
466
+ class GroupReduction(AggregationReduction[Tuple[str, Dict[str, float]]]):
467
+ def reduce_list(self, lst: List[Tuple[str, float]]):
468
+ pass
469
+
470
+ def reduce(self, intermediates: List[Tuple[str, Dict[str, float]]]):
471
+ lists = {}
472
+ for id, intermediate in intermediates:
473
+ for key, val in intermediate.items():
474
+ if key not in lists:
475
+ lists[key] = []
476
+ lists[key].append((id, val))
477
+
478
+ result = {}
479
+ for key, val_list in lists.items():
480
+ result[key] = self.reduce_list(val_list)
481
+ return result
482
+
483
+
484
+ class GroupMean(GroupReduction):
485
+ def reduce_list(self, lst: List[Tuple[str, float]]):
486
+ return nan_mean([item[1] for item in lst])
487
+
488
+
489
+ class SequentialSuccess(GroupReduction):
490
+ threshold: float = 0.5
491
+
492
+ def reduce_list(self, lst: List[Tuple[str, float]]):
493
+ sorted_items = [item for _, item in sorted(lst, key=lambda x: x[0])]
494
+ successful = 0
495
+ for item in sorted_items:
496
+ if item > self.threshold:
497
+ successful += 1
498
+ else:
499
+ break
500
+ return successful / len(lst)
501
+
502
+
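A hedged usage sketch of the new SequentialSuccess reduction (the sample scores and ids are made up): items are ordered by id, then the reduction counts how many leading items clear the threshold before the first failure.

# Illustrative only: how SequentialSuccess reduces (id, score) pairs.
scores = [("turn_1", 0.9), ("turn_2", 0.8), ("turn_3", 0.2), ("turn_4", 0.9)]

threshold = 0.5
ordered = [s for _, s in sorted(scores, key=lambda x: x[0])]
successful = 0
for s in ordered:
    if s > threshold:
        successful += 1
    else:
        break  # the streak ends at the first score at or below the threshold
print(successful / len(scores))  # 0.5: the streak breaks at turn_3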
503
  class MeanReduction(DictReduction):
504
  def reduce_list(self, lst: List[float]):
505
  return nan_mean(lst)
 
515
  return float(nan_max(lst))
516
 
517
 
518
+ class GroupMetric(
519
+ MapReduceMetric[PredictionType, IntermediateType],
520
+ Generic[PredictionType, IntermediateType],
521
+ ):
522
+ main_score: str = None
523
+ metric: MapReduceMetric[PredictionType, IntermediateType]
524
+ group_id_field: str
525
+ item_id_field: str
526
+ in_group_reduction: GroupReduction = GroupMean()
527
+ cross_group_reduction: GroupReduction = GroupMean()
528
+ n_resamples = None
529
+
530
+ def _get_group_id(self, task_data) -> str:
531
+ return str(dict_get(task_data, self.group_id_field))
532
+
533
+ def _get_item_id(self, task_data) -> str:
534
+ return str(dict_get(task_data, self.item_id_field))
535
+
536
+ def prepare(self):
537
+ super().prepare()
538
+ self.main_score = self.metric.main_score
539
+
540
+ def map_stream(
541
+ self,
542
+ evaluation_inputs_stream: Generator[
543
+ EvaluationInput[PredictionType], None, None
544
+ ],
545
+ ) -> List[Tuple[IntermediateType, str, str]]:
546
+ group_ids: List[str] = []
547
+ item_ids: List[str] = []
548
+
549
+ def multi_turn_stream(
550
+ evaluation_inputs_stream: Generator[
551
+ EvaluationInput[PredictionType], None, None
552
+ ],
553
+ ) -> Generator[
554
+ Tuple[PredictionType, List[PredictionType], Dict[str, Any]], None, None
555
+ ]:
556
+ for prediction, references, task_data in evaluation_inputs_stream:
557
+ group_ids.append(self._get_group_id(task_data))
558
+ item_ids.append(self._get_item_id(task_data))
559
+ yield prediction, references, task_data
560
+
561
+ intermediates: List[IntermediateType] = list(
562
+ self.metric.map_stream(multi_turn_stream(evaluation_inputs_stream))
563
+ )
564
+
565
+ return list(zip(intermediates, group_ids, item_ids))
566
+
567
+ def reduce_group(self, dialog_data: Dict[str, Dict[str, Any]]):
568
+ return self.in_group_reduction.reduce(list(dialog_data.items()))
569
+
570
+ def reduce_one(self, intermediate: Tuple[IntermediateType, str, str]):
571
+ return self.metric.reduce_one(intermediate[0])
572
+
573
+ def reduce(
574
+ self, intermediates: List[Tuple[IntermediateType, str, str]]
575
+ ) -> Dict[str, Any]:
576
+ data: Dict[str, Dict[str, Any]] = {}
577
+ for intermediate, group_id, item_id in intermediates:
578
+ if group_id not in data:
579
+ data[group_id] = {}
580
+ data[group_id][item_id] = self.metric.reduce_one(intermediate)
581
+
582
+ group_scores: Dict[str, Dict[str, Any]] = {
583
+ dialog_id: self.reduce_group(dialog_data)
584
+ for dialog_id, dialog_data in data.items()
585
+ }
586
+
587
+ return self.cross_group_reduction.reduce(list(group_scores.items()))
588
+
589
+
590
+ class MultiTurnMetric(
591
+ GroupMetric[PredictionType, IntermediateType],
592
+ Generic[PredictionType, IntermediateType],
593
+ ):
594
+ group_id_field = "conversation/id"
595
+ item_id_field = "conversation/dialog"
596
+
597
+ def _get_item_id(self, task_data):
598
+ return "assistant_turn_" + str(
599
+ len(dict_get(task_data, self.item_id_field)) // 2
600
+ )
601
+
602
+
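As a rough sketch of the data layout these classes assume (the field paths come from the class attributes above; the instance itself is hypothetical), MultiTurnMetric pulls the group id and item id out of task_data along slash-separated paths:

# Hypothetical task_data for one instance of a multi-turn conversation;
# "conversation/id" groups instances, "conversation/dialog" derives the item id.
task_data = {
    "conversation": {
        "id": "dialog-42",
        "dialog": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ],
    }
}

group_id = task_data["conversation"]["id"]  # "dialog-42"
# Two dialog entries = one completed assistant turn (len // 2).
item_id = "assistant_turn_" + str(len(task_data["conversation"]["dialog"]) // 2)
print(group_id, item_id)  # dialog-42 assistant_turn_1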
603
  class ReductionInstanceMetric(
604
  MapReduceMetric[PredictionType, IntermediateType],
605
  Generic[PredictionType, IntermediateType],
 
836
  }
837
 
838
 
839
+ class MultiTurnToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
840
+ """Compares each predicted tool call with list of references tool call."""
841
+
842
+ main_score = "argument_schema_validation"
843
+ reduction = MeanReduction()
844
+ prediction_type = List[ToolCall]
845
+ _requirements_list = ["jsonschema-rs"]
846
+
847
+ def prepare(self):
848
+ super().prepare()
849
+ import jsonschema_rs
850
+
851
+ self._schema = jsonschema_rs
852
+
853
+ def map(
854
+ self,
855
+ prediction: List[ToolCall],
856
+ references: List[List[ToolCall]],
857
+ task_data: Dict[str, Any],
858
+ ) -> Dict[str, float]:
859
+ validation_scores = []
860
+ for tool_call in prediction:
861
+ parameters = None
862
+ for tool in task_data["__tools__"]:
863
+ if tool["function"]["name"] == tool_call["name"]:
864
+ parameters = tool["function"]["parameters"]
865
+
866
+ if parameters is None:
867
+ validation_scores.append(0.0)
868
+ else:
869
+ try:
870
+ self._schema.validate(
871
+ parameters,
872
+ tool_call["arguments"],
873
+ )
874
+ validation_scores.append(1.0)
875
+ except self._schema.ValidationError:
876
+ validation_scores.append(0.0)
877
+
878
+ argument_schema_validation = sum(validation_scores) / len(validation_scores)
879
+
880
+ return {
881
+ "argument_schema_validation": argument_schema_validation,
882
+ }
883
+
884
+
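The map step above scores each call by validating its arguments against the JSON Schema declared for the matching tool. A standalone sketch of that check, using the common jsonschema package as a stand-in (the metric itself uses jsonschema-rs); the tool schema and call below are hypothetical:

# Illustrative only: schema validation of tool-call arguments.
from jsonschema import ValidationError, validate

parameters_schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}
tool_call_arguments = {"city": "Paris"}

try:
    validate(instance=tool_call_arguments, schema=parameters_schema)
    score = 1.0  # arguments conform to the tool's declared schema
except ValidationError:
    score = 0.0
print(score)  # 1.0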
885
  class MetricWithConfidenceInterval(Metric):
886
  # The number of resamples used to estimate the confidence intervals of this metric.
887
  # Use None to disable confidence interval computation.
 
1132
  process_single_instances = True
1133
 
1134
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1135
+ with error_context(
1136
+ self,
1137
+ stage="Evaluating Metric",
1138
+ help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
1139
+ ):
1140
+ references = []
1141
+ predictions = []
1142
+ task_data = []
1143
 
1144
+ instances = []
1145
 
1146
+ for instance in stream:
1147
+ instance = self.verify_instance(instance)
1148
 
1149
+ if "score" not in instance:
1150
+ instance["score"] = {"global": {}, "instance": {}}
1151
 
1152
+ instance_references, instance_prediction = (
1153
+ instance["references"],
1154
+ instance["prediction"],
1155
+ )
1156
 
1157
+ references.append(instance_references)
1158
+ predictions.append(instance_prediction)
1159
+ instances.append(instance)
1160
 
1161
+ instance_task_data = (
1162
+ instance["task_data"] if "task_data" in instance else {}
1163
+ )
1164
+ task_data.append(instance_task_data)
1165
+ instance_score = None
1166
 
1167
+ # for backward compatibility
1168
+ no_score_value = np.nan
1169
+ if self.process_single_instances:
1170
+ try:
1171
+ instance_score = self._compute(
1172
+ [instance_references],
1173
+ [instance_prediction],
1174
+ [instance_task_data],
1175
+ )
1176
+ except:
1177
+ no_score_value = None
1178
+ if not instance_score:
1179
+ instance_score = {
1180
+ "score": no_score_value,
1181
+ "score_name": self.main_score,
1182
+ }
1183
 
1184
+ if isinstance(self.main_score, str):
1185
+ instance_score[self.main_score] = no_score_value
1186
 
1187
+ instance["score"]["instance"].update(
1188
+ self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1189
+ instance_score, instance["score"]["instance"]
1190
+ )
1191
  )
1192
+ self._validate_references_and_prediction(references, predictions)
1193
+ global_score = {"num_of_instances": len(instances)}
 
1194
 
1195
+ result = self._compute(references, predictions, task_data)
1196
+ global_score.update(
1197
+ self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1198
+ result, global_score
1199
+ )
1200
  )
1201
+ if self.ci_scores:
1202
+ score_names = [
1203
+ self._add_score_prefix(score_name) for score_name in self.ci_scores
1204
+ ]
1205
+ else:
1206
+ score_names = [global_score["score_name"]]
 
1207
 
1208
+ for score_name in score_names:
1209
+ confidence_interval = self.compute_global_confidence_intervals(
1210
+ references, predictions, task_data, score_name
1211
+ )
1212
+ global_score.update(confidence_interval)
1213
 
1214
+ for instance in instances:
1215
+ self.update_and_adjust_global_score(instance, global_score)
1216
+ yield instance
1217
 
1218
  def _compute(
1219
  self,
 
1263
  return instance
1264
 
1265
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1266
+ with error_context(
1267
+ self,
1268
+ stage="Evaluating Metrics",
1269
+ help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
1270
+ ):
1271
+ instances = []
1272
+ for instance in stream:
1273
+ self.verify_instance(instance)
1274
+ instance = self.preprocess_instance(instance)
1275
+ instances.append(instance)
1276
+
1277
+ predictions = [instance["prediction"] for instance in instances]
1278
+ references = [instance["references"] for instance in instances]
1279
+ task_data = [
1280
+ instance["task_data"] if "task_data" in instance else {}
1281
+ for instance in instances
1282
+ ]
1283
+ self._validate_references_and_prediction(references, predictions)
1284
+ global_score = {"num_of_instances": len(instances)}
1285
+ # compute the metric over all refs and preds
1286
+ instance_scores = self.compute(
1287
+ references=references,
1288
+ predictions=predictions,
1289
+ task_data=task_data,
1290
+ )
1291
 
1292
+ # add the score and score_name fields
1293
+ for instance_score in instance_scores:
1294
+ instance_score["score"] = instance_score[self.main_score]
1295
+ instance_score["score_name"] = self.main_score
1296
 
1297
+ for instance, score in zip(instances, instance_scores):
1298
+ if "score" not in instance:
1299
+ instance["score"] = {"global": {}, "instance": {}}
1300
 
1301
+ instance["score"]["instance"].update(
1302
+ self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1303
+ score, instance["score"]["instance"]
1304
+ )
1305
  )
 
1306
 
1307
+ for reduction, fields in self.reduction_map.items():
1308
+ assert (
1309
+ reduction in self.implemented_reductions
1310
+ ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
1311
+
1312
+ if reduction == "mean":
1313
+ for field_name in fields:
1314
+ field_name_with_prefix = self._add_score_prefix(field_name)
1315
+ global_score[field_name_with_prefix] = nan_mean(
1316
+ [
1317
+ instance["score"]["instance"][field_name_with_prefix]
1318
+ for instance in instances
1319
+ ]
1320
+ )
1321
+ if field_name == self.main_score:
1322
+ global_score["score"] = global_score[field_name_with_prefix]
1323
+ global_score["score_name"] = (
1324
+ self.score_prefix + self.main_score
1325
+ )
1326
 
1327
+ ci_fields = (
1328
+ list(set(self.ci_scores))
1329
+ if self.ci_scores is not None
1330
+ else [self.main_score]
1331
+ )
1332
+ ci_fields_with_prefix = [
1333
+ self._add_score_prefix(ci_field) for ci_field in ci_fields
1334
+ ]
1335
+ confidence_interval = self.score_based_confidence_interval(
1336
+ instances=instances, score_names=ci_fields_with_prefix
1337
+ )
1338
+ global_score.update(confidence_interval)
1339
+ if reduction == "weighted_win_rate":
1340
+ for field_name in fields:
1341
+ field_name_with_prefix = self._add_score_prefix(field_name)
1342
+ total_battles = 0
1343
+ wins = 0
1344
+ for instance in instances:
1345
+ s = instance["score"]["instance"][field_name_with_prefix]
1346
+ if s > 0:
1347
+ total_battles += s
1348
+ wins += s
1349
+ elif s < 0:
1350
+ total_battles += abs(s)
1351
+ else:
1352
+ total_battles += 2
1353
+ wins += 1
1354
+
1355
+ global_score[field_name_with_prefix] = wins / total_battles
1356
+ if field_name == self.main_score:
1357
+ global_score["score"] = global_score[field_name_with_prefix]
1358
+ global_score["score_name"] = (
1359
+ self.score_prefix + self.main_score
1360
+ )
1361
 
1362
+ for instance in instances:
1363
+ self.update_and_adjust_global_score(instance, global_score)
1364
+ yield instance
1365
 
1366
  @abstractmethod
1367
  def compute(
 
1667
  assert isinstance(fields["score_fields"], list)
1668
 
1669
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1670
+ with error_context(
1671
+ self,
1672
+ stage="Evaluating Metrics",
1673
+ help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
1674
+ ):
1675
+ instance_scores = self.compute_instance_scores(stream)
1676
+ global_score = {"num_of_instances": len(instance_scores)}
1677
+ for reduction_type, reduction_params in self.reduction_map.items():
1678
+ assert (
1679
+ reduction_type in self.implemented_reductions
1680
+ ), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
1681
+
1682
+ field_name_full_prefix = ""
1683
+ # used for passing to the bootstrapping, depends on whether the groups are fixed or not
1684
+ aggregation_function = None
1685
+ if reduction_type == "mean":
1686
+ aggregation_function = self.average_item_scores
1687
+ reduction_fields = list(set(reduction_params))
1688
+ # no group reduction, so resample instances individually
1689
+ scores_to_resample = instance_scores
1690
+ elif reduction_type == "max":
1691
+ aggregation_function = self.max_item_scores
1692
+ reduction_fields = list(set(reduction_params))
1693
+ # no group reduction, so resample instances individually
1694
+ scores_to_resample = instance_scores
1695
+ elif reduction_type == "group_mean":
1696
+ aggregation_function = self.average_item_scores
1697
+ self._validate_group_mean_reduction()
1698
+ reduction_fields = (
1699
+ [self.main_score]
1700
+ if "score_fields" not in reduction_params
1701
+ else list(set(reduction_params["score_fields"]))
1702
+ )
1703
+ aggregation_function_name = str(reduction_params["agg_func"][0])
1704
+ field_name_full_prefix = "group_" + aggregation_function_name + "_"
1705
+ do_resample_as_group = reduction_params["agg_func"][2]
1706
+ if do_resample_as_group:
1707
+ # append fixed_ to name because resamples the groups as fixed units
1708
+ field_name_full_prefix = "fixed_" + field_name_full_prefix
1709
+ (
1710
+ scores_to_resample,
1711
+ aggregation_function,
1712
+ ) = self._set_up_group_mean_aggregation(
1713
+ instance_scores,
1714
+ reduction_params,
1715
+ reduction_fields,
1716
+ )
1717
+ else:
1718
+ raise ValueError(
1719
+ f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
1720
+ )
1721
 
1722
+ # calculate global scores for each reduction field
1723
+ for field_name in reduction_fields:
1724
+ field_name_full = (
1725
+ field_name_full_prefix + self.score_prefix + field_name
1726
+ )
1727
+ # if group resampling (3rd element of agg_func parameter) is True, then
1728
+ # 1. scores_to_resample are the group scores, and
1729
+ # 2. aggregation_function is to take the raw mean
1730
+ # if no group resampling (3rd element of agg_func parameter) is False, then
1731
+ # 1. scores_to_resample are the original instance scores, and
1732
+ # 2. aggregation_function is to apply the group aggregation from the instance scores
1733
+ # either way, the application of aggregation_function to scores_to_resample yields the global score
1734
+ global_score[field_name_full] = aggregation_function(
1735
+ scores_to_resample, self.score_prefix + field_name
1736
+ )
1737
+ if field_name == self.main_score:
1738
+ global_score["score"] = global_score[field_name_full]
1739
+ global_score["score_name"] = field_name_full
1740
+
1741
+ # need to specify which fields should have CIs calculated for them through ci_scores
1742
+ # (will not automatically calculate CIs for fields in reduction map)
1743
+ if self.ci_scores is not None:
1744
+ confidence_interval = self.score_based_confidence_interval(
1745
+ instances=scores_to_resample,
1746
+ score_names=[
1747
+ self.score_prefix + ci_score
1748
+ for ci_score in set(self.ci_scores)
1749
+ ],
1750
+ ci_score_prefix=field_name_full_prefix,
1751
+ aggregation_func=aggregation_function,
1752
+ )
1753
+ global_score.update(confidence_interval)
1754
 
1755
+ for instance in instance_scores:
1756
+ self.update_and_adjust_global_score(instance, global_score)
1757
 
1758
+ for i, instance in enumerate(stream):
1759
+ instance["score"] = recursive_copy(instance_scores[i]["score"])
1760
+ yield instance
1761
 
1762
  def compute_instance_scores(
1763
  self, stream: Stream, stream_name: Optional[str] = None
 
6634
  }
6635
 
6636
 
6637
+ class SQLExecutionLogicAccuracy(InstanceMetric):
6638
+ sql_timeout: float = 60.0
6639
+ prediction_type = "Any"
6640
+ _requirements_list = ["sqlglot", "func_timeout"]
6641
+
6642
  main_score = "non_empty_execution_accuracy"
6643
+
6644
+ all_metrics = [
6645
+ f.name
6646
+ for f in dataclasses_fields(SQLExecutionResult)
6647
+ if isinstance(f.type, type) and f.type in (int, float)
6648
+ ]
6649
+
6650
+ reduction_map = {"mean": all_metrics}
6651
+
6652
  ci_scores = [
6653
  "execution_accuracy",
6654
  "non_empty_execution_accuracy",
6655
+ "subset_non_empty_execution_accuracy",
6656
+ "execution_accuracy_bird",
6657
  "gold_sql_runtime",
6658
  "predicted_sql_runtime",
6659
  ]
6660
 
6661
+ def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
6662
+ from .text2sql_utils import (
6663
+ ALL_DIALECTS,
6664
+ extract_sql_from_text,
6665
+ get_db_connector,
6666
+ get_sql_execution_results,
6667
+ replace_select_clause,
6668
  )
6669
 
6670
+ predicted_sql = extract_sql_from_text(prediction)
6671
+ gold_sql = references[0]
6672
+ dialect = task_data["db"]["db_type"]
6673
+ if dialect not in ALL_DIALECTS:
6674
+ dialect = None
6675
+ revised_sql = (
6676
+ replace_select_clause(gold_sql, predicted_sql, dialect)
6677
+ if gold_sql and predicted_sql
6678
+ else ""
6679
+ )
6680
 
6681
+ db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
6682
+ result_obj = get_sql_execution_results(
6683
+ revised_sql, gold_sql, db_connector, self.sql_timeout
6684
+ )
6685
 
6686
+ result = asdict(result_obj)
6687
+ result["score"] = result[self.main_score]
6688
+ result["score_name"] = self.main_score
6689
+ logger.debug(f"SQL Execution Accuracy Result: {result}")
6690
+ return result
6691
 
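The all_metrics list above is derived from the numeric fields of the result dataclass. A minimal sketch of that pattern with a stand-in dataclass (SQLExecutionResult itself lives in text2sql_utils and is not shown in this diff):

# Illustrative only: collecting the int/float field names of a dataclass,
# mirroring how all_metrics is built from SQLExecutionResult.
from dataclasses import dataclass, fields as dataclasses_fields

@dataclass
class FakeExecutionResult:  # stand-in for SQLExecutionResult
    execution_accuracy: int
    gold_sql_runtime: float
    error_message: str

numeric_metrics = [
    f.name
    for f in dataclasses_fields(FakeExecutionResult)
    if isinstance(f.type, type) and f.type in (int, float)
]
print(numeric_metrics)  # ['execution_accuracy', 'gold_sql_runtime']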
6692
 
6693
+ class SQLExecutionAccuracy(InstanceMetric):
6694
+ sql_timeout: float = 60.0
6695
+ prediction_type = "Any"
6696
+ _requirements_list = ["sqlglot", "func_timeout"]
6697
 
6698
+ main_score = "non_empty_execution_accuracy"
6699
 
6700
+ all_metrics = [
6701
+ f.name
6702
+ for f in dataclasses_fields(SQLExecutionResult)
6703
+ if isinstance(f.type, type) and f.type in (int, float)
6704
+ ]
6705
 
6706
+ reduction_map = {"mean": all_metrics}
6707
 
6708
+ ci_scores = [
6709
+ "execution_accuracy",
6710
+ "non_empty_execution_accuracy",
6711
+ "subset_non_empty_execution_accuracy",
6712
+ "execution_accuracy_bird",
6713
+ "gold_sql_runtime",
6714
+ "predicted_sql_runtime",
6715
+ ]
6716
 
6717
+ def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
6718
+ from .text2sql_utils import (
6719
+ extract_sql_from_text,
6720
+ get_db_connector,
6721
+ get_sql_execution_results,
6722
  )
6723
 
6724
+ predicted_sql = extract_sql_from_text(prediction)
6725
+ gold_sql = references[0]
6726
 
6727
+ db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
6728
+ result_obj = get_sql_execution_results(
6729
+ predicted_sql, gold_sql, db_connector, self.sql_timeout
6730
  )
6731
 
6732
+ result = asdict(result_obj)
6733
  result["score"] = result[self.main_score]
6734
  result["score_name"] = self.main_score
6735
  logger.debug(f"SQL Execution Accuracy Result: {result}")
 
6737
 
6738
 
6739
  class SQLNonExecutionAccuracy(InstanceMetric):
6740
+ all_metrics = [
6741
+ f.name
6742
+ for f in dataclasses_fields(SQLNonExecutionMetricResult)
6743
+ if isinstance(f.type, type) and f.type in (int, float)
6744
  ]
6745
+ reduction_map = {"mean": all_metrics}
6746
+ main_score = "sqlglot_equivalence"
6747
+ ci_scores = all_metrics
6748
 
6749
  prediction_type = "Any" # string representation is compared
6750
 
6751
  _requirements_list = ["sqlglot", "sqlparse"]
6752
 
6753
  def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
6754
+ from .text2sql_utils import (
6755
+ extract_sql_from_text,
6756
  is_sqlglot_parsable,
6757
  is_sqlparse_parsable,
6758
  sql_exact_match,
 
6761
  sqlparse_queries_equivalent,
6762
  )
6763
 
6764
  gold_sql = references[0]
6765
+ predicted_sql = extract_sql_from_text(prediction)
6766
 
6767
  is_sqlglot_parsable = is_sqlglot_parsable(predicted_sql)
6768
  is_sqlparse_parsable = is_sqlparse_parsable(predicted_sql)
6769
+ result_obj = SQLNonExecutionMetricResult(
6770
+ sqlglot_validity=int(is_sqlglot_parsable),
6771
+ sqlparse_validity=int(is_sqlparse_parsable),
6772
+ sqlglot_equivalence=int(
6773
  sqlglot_parsed_queries_equivalent(predicted_sql, gold_sql)
6774
  if is_sqlglot_parsable
6775
  else 0
6776
  ),
6777
+ sqlglot_optimized_equivalence=int(
6778
  sqlglot_optimized_equivalence(predicted_sql, gold_sql)
6779
  if is_sqlglot_parsable
6780
  else 0
6781
  ),
6782
+ sqlparse_equivalence=int(
6783
  sqlparse_queries_equivalent(predicted_sql, gold_sql)
6784
  if is_sqlparse_parsable
6785
  else 0
6786
  ),
6787
+ sql_exact_match=int(sql_exact_match(predicted_sql, gold_sql)),
6788
+ sql_syntactic_equivalence=0, # will update below
6789
+ )
6790
+
6791
+ result_obj.sql_syntactic_equivalence = int(
6792
  any(
6793
+ [
6794
+ result_obj.sqlglot_equivalence,
6795
+ result_obj.sqlglot_optimized_equivalence,
6796
+ result_obj.sqlparse_equivalence,
6797
+ result_obj.sql_exact_match,
6798
  ]
6799
  )
6800
  )
6801
+
6802
+ result = asdict(result_obj)
6803
  logger.debug(f"SQL Non Execution Accuracy Result: {result}")
6804
  result["score"] = result[self.main_score]
6805
  result["score_name"] = self.main_score
operator.py CHANGED
@@ -6,6 +6,7 @@ from pkg_resources import DistributionNotFound, VersionConflict, require
6
 
7
  from .artifact import Artifact
8
  from .dataclass import FinalField, InternalField, NonPositionalField
9
  from .settings_utils import get_constants
10
  from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
11
 
@@ -346,7 +347,8 @@ class StreamOperator(MultiStreamOperator):
346
  def _process_stream(
347
  self, stream: Stream, stream_name: Optional[str] = None
348
  ) -> Generator:
349
- yield from self.process(stream, stream_name)
350
 
351
  @abstractmethod
352
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
@@ -384,12 +386,28 @@ class PagedStreamOperator(StreamOperator):
384
  self, stream: Stream, stream_name: Optional[str] = None
385
  ) -> Generator:
386
  page = []
387
  for instance in stream:
388
  page.append(instance)
389
  if len(page) >= self.page_size:
390
- yield from self.process(page, stream_name)
391
  page = []
392
- yield from self._process_page(page, stream_name)
393
 
394
  def _process_page(
395
  self, page: List[Dict], stream_name: Optional[str] = None
@@ -442,17 +460,9 @@ class InstanceOperator(StreamOperator):
442
  def _process_stream(
443
  self, stream: Stream, stream_name: Optional[str] = None
444
  ) -> Generator:
445
- try:
446
- _index = None
447
- for _index, instance in enumerate(stream):
448
  yield self._process_instance(instance, stream_name)
449
- except Exception as e:
450
- if _index is None:
451
- raise e
452
- else:
453
- raise ValueError(
454
- f"Error processing instance '{_index}' from stream '{stream_name}' in {self.__class__.__name__} due to the exception above."
455
- ) from e
456
 
457
  def _process_instance(
458
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
6
 
7
  from .artifact import Artifact
8
  from .dataclass import FinalField, InternalField, NonPositionalField
9
+ from .error_utils import error_context
10
  from .settings_utils import get_constants
11
  from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
12
 
 
347
  def _process_stream(
348
  self, stream: Stream, stream_name: Optional[str] = None
349
  ) -> Generator:
350
+ with error_context(self, stream=stream_name):
351
+ yield from self.process(stream, stream_name)
352
 
353
  @abstractmethod
354
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
 
386
  self, stream: Stream, stream_name: Optional[str] = None
387
  ) -> Generator:
388
  page = []
389
+ page_number = 0
390
  for instance in stream:
391
  page.append(instance)
392
  if len(page) >= self.page_size:
393
+ with error_context(
394
+ self,
395
+ stream=stream_name,
396
+ page=page_number,
397
+ page_size=len(page),
398
+ ):
399
+ yield from self.process(page, stream_name)
400
  page = []
401
+ page_number += 1
402
+ if page: # Handle any remaining instances in the last partial page
403
+ with error_context(
404
+ self,
405
+ stream=stream_name,
406
+ page=page_number,
407
+ page_size=len(page),
408
+ final_page=True,
409
+ ):
410
+ yield from self._process_page(page, stream_name)
411
 
412
  def _process_page(
413
  self, page: List[Dict], stream_name: Optional[str] = None
 
460
  def _process_stream(
461
  self, stream: Stream, stream_name: Optional[str] = None
462
  ) -> Generator:
463
+ for _index, instance in enumerate(stream):
464
+ with error_context(self, stream=stream_name, instance=_index):
465
  yield self._process_instance(instance, stream_name)
466
 
467
  def _process_instance(
468
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
operators.py CHANGED
@@ -67,7 +67,7 @@ from .artifact import Artifact, fetch_artifact
67
  from .dataclass import NonPositionalField, OptionalField
68
  from .deprecation_utils import deprecation
69
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
70
- from .error_utils import UnitxtError
71
  from .generator_utils import ReusableGenerator
72
  from .operator import (
73
  InstanceOperator,
@@ -309,7 +309,9 @@ def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
309
  if not isinstance(item, dict) and item not in value_remove
310
  ]
311
  elif isinstance(value, dict):
312
- pass # Skip or handle dict values if needed
313
  elif value in value_remove:
314
  keys_to_delete.append(key)
315
  elif value in value_map:
@@ -436,6 +438,7 @@ class InstanceFieldOperator(InstanceOperator):
436
  field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
437
  use_query: Optional[bool] = None
438
  process_every_value: bool = False
439
  get_default: Any = None
440
  not_exist_ok: bool = False
441
  not_exist_do_nothing: bool = False
@@ -521,7 +524,7 @@ class InstanceFieldOperator(InstanceOperator):
521
  ) -> Dict[str, Any]:
522
  self.verify_field_definition()
523
  for from_field, to_field in self._field_to_field:
524
- try:
525
  old_value = dict_get(
526
  instance,
527
  from_field,
@@ -532,11 +535,8 @@ class InstanceFieldOperator(InstanceOperator):
532
  if self.not_exist_do_nothing:
533
  continue
534
  old_value = self.get_default
535
- except Exception as e:
536
- raise ValueError(
537
- f"Failed to get '{from_field}' from instance due to the exception above."
538
- ) from e
539
- try:
540
  if self.process_every_value:
541
  new_value = [
542
  self.process_instance_value(value, instance)
@@ -544,15 +544,13 @@ class InstanceFieldOperator(InstanceOperator):
544
  ]
545
  else:
546
  new_value = self.process_instance_value(old_value, instance)
547
- except Exception as e:
548
- raise ValueError(
549
- f"Failed to process field '{from_field}' from instance due to the exception above."
550
- ) from e
551
  dict_set(
552
  instance,
553
  to_field,
554
  new_value,
555
  not_exist_ok=True,
556
  )
557
  return instance
558
 
@@ -610,11 +608,29 @@ class Rename(FieldOperator):
610
  return res
611
 
612
 
613
  @deprecation(version="2.0.0", alternative=Rename)
614
  class RenameFields(Rename):
615
  pass
616
 
617
 
618
  class AddConstant(FieldOperator):
619
  """Adds a constant, being argument 'add', to the processed value.
620
 
@@ -1200,9 +1216,10 @@ class ApplyOperatorsField(InstanceOperator):
1200
  ) -> Dict[str, Any]:
1201
  operator_names = instance.get(self.operators_field)
1202
  if operator_names is None:
1203
- assert (
1204
- self.default_operators is not None
1205
- ), f"No operators found in field '{self.operators_field}', and no default operators provided."
1206
  operator_names = self.default_operators
1207
 
1208
  if isinstance(operator_names, str):
@@ -1436,7 +1453,7 @@ class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
1436
  def process(
1437
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1438
  ) -> Dict[str, Any]:
1439
- instance[self.to_field] = self.compute_expression(instance)
1440
  return instance
1441
 
1442
 
@@ -1821,54 +1838,58 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1821
 
1822
  # to be populated only when two or more metrics
1823
  accumulated_scores = []
1824
 
1825
- first_instance = stream.peek()
1826
-
1827
- metric_names = first_instance.get(self.metric_field, [])
1828
- if not metric_names:
1829
- raise RuntimeError(
1830
- f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
1831
- )
1832
-
1833
- if isinstance(metric_names, str):
1834
- metric_names = [metric_names]
1835
-
1836
- metrics_list = []
1837
- for metric_name in metric_names:
1838
- metric = self.get_artifact(metric_name)
1839
- if isinstance(metric, MetricsList):
1840
- metrics_list.extend(list(metric.items))
1841
- elif isinstance(metric, Metric):
1842
- metrics_list.append(metric)
1843
- else:
1844
- raise ValueError(
1845
- f"Operator {metric_name} must be a Metric or MetricsList"
1846
  )
1847
 
1848
- for metric in metrics_list:
1849
- metric.set_confidence_interval_calculation(self.calc_confidence_intervals)
1850
- # Each metric operator computes its score and then sets the main score, overwriting
1851
- # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
1852
- # This will cause the first listed metric to run last, and the main score will be set
1853
- # by the first listed metric (as desired).
1854
- metrics_list = list(reversed(metrics_list))
1855
-
1856
- for i, metric in enumerate(metrics_list):
1857
- if i == 0: # first metric
1858
- multi_stream = MultiStream({"tmp": stream})
1859
- else: # metrics with previous scores
1860
- reusable_generator = ReusableGenerator(
1861
- generator=update_scores_of_stream_instances,
1862
- gen_kwargs={"stream": stream, "scores": accumulated_scores},
1863
  )
1864
- multi_stream = MultiStream.from_generators({"tmp": reusable_generator})
1865
 
1866
- multi_stream = metric(multi_stream)
1867
 
1868
- if i < len(metrics_list) - 1: # last metric
1869
- accumulated_scores = []
1870
- for inst in multi_stream["tmp"]:
1871
- accumulated_scores.append(recursive_copy(inst["score"]))
1872
 
1873
  yield from multi_stream["tmp"]
1874
 
 
67
  from .dataclass import NonPositionalField, OptionalField
68
  from .deprecation_utils import deprecation
69
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
70
+ from .error_utils import UnitxtError, error_context
71
  from .generator_utils import ReusableGenerator
72
  from .operator import (
73
  InstanceOperator,
 
309
  if not isinstance(item, dict) and item not in value_remove
310
  ]
311
  elif isinstance(value, dict):
312
+ recursive_key_value_replace(
313
+ value, target_key, value_map, value_remove
314
+ )
315
  elif value in value_remove:
316
  keys_to_delete.append(key)
317
  elif value in value_map:
 
438
  field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
439
  use_query: Optional[bool] = None
440
  process_every_value: bool = False
441
+ set_every_value: bool = NonPositionalField(default=False)
442
  get_default: Any = None
443
  not_exist_ok: bool = False
444
  not_exist_do_nothing: bool = False
 
524
  ) -> Dict[str, Any]:
525
  self.verify_field_definition()
526
  for from_field, to_field in self._field_to_field:
527
+ with error_context(self, field=from_field, action="Read Field"):
528
  old_value = dict_get(
529
  instance,
530
  from_field,
 
535
  if self.not_exist_do_nothing:
536
  continue
537
  old_value = self.get_default
538
+
539
+ with error_context(self, field=from_field, action="Process Field"):
540
  if self.process_every_value:
541
  new_value = [
542
  self.process_instance_value(value, instance)
 
544
  ]
545
  else:
546
  new_value = self.process_instance_value(old_value, instance)
547
+
548
  dict_set(
549
  instance,
550
  to_field,
551
  new_value,
552
  not_exist_ok=True,
553
+ set_multiple=self.set_every_value,
554
  )
555
  return instance
556
 
 
608
  return res
609
 
610
 
611
+ class Move(InstanceOperator):
612
+ field: str
613
+ to_field: str
614
+
615
+ def process(
616
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
617
+ ) -> Dict[str, Any]:
618
+ value = dict_get(instance, self.field)
619
+ dict_delete(instance, self.field)
620
+ dict_set(instance, self.to_field, value=value)
621
+ return instance
622
+
623
+
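A hedged usage sketch of the new Move operator (the instance dict is invented): it reads the source field, deletes it, and rewrites the value at the target path. For a flat instance, the effect is equivalent to:

# Illustrative only: what Move does to a flat instance, expressed with
# plain dict operations (the operator itself uses dict_get/dict_set paths).
instance = {"source": 7, "other": "kept"}

value = instance.pop("source")   # read + delete the source field
instance["target"] = value       # write it at the new location
print(instance)                  # {'other': 'kept', 'target': 7}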
624
  @deprecation(version="2.0.0", alternative=Rename)
625
  class RenameFields(Rename):
626
  pass
627
 
628
 
629
+ class BytesToString(FieldOperator):
630
+ def process_value(self, value: Any) -> Any:
631
+ return str(value)
632
+
633
+
634
  class AddConstant(FieldOperator):
635
  """Adds a constant, being argument 'add', to the processed value.
636
 
 
1216
  ) -> Dict[str, Any]:
1217
  operator_names = instance.get(self.operators_field)
1218
  if operator_names is None:
1219
+ if self.default_operators is None:
1220
+ raise ValueError(
1221
+ f"No operators found in field '{self.operators_field}', and no default operators provided."
1222
+ )
1223
  operator_names = self.default_operators
1224
 
1225
  if isinstance(operator_names, str):
 
1453
  def process(
1454
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1455
  ) -> Dict[str, Any]:
1456
+ dict_set(instance, self.to_field, self.compute_expression(instance))
1457
  return instance
1458
 
1459
 
 
1838
 
1839
  # to be populated only when two or more metrics
1840
  accumulated_scores = []
1841
+ with error_context(self, stage="Load Metrics"):
1842
+ first_instance = stream.peek()
1843
 
1844
+ metric_names = first_instance.get(self.metric_field, [])
1845
+ if not metric_names:
1846
+ raise RuntimeError(
1847
+ f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
1848
  )
1849
 
1850
+ if isinstance(metric_names, str):
1851
+ metric_names = [metric_names]
1852
+
1853
+ metrics_list = []
1854
+ for metric_name in metric_names:
1855
+ metric = self.get_artifact(metric_name)
1856
+ if isinstance(metric, MetricsList):
1857
+ metrics_list.extend(list(metric.items))
1858
+ elif isinstance(metric, Metric):
1859
+ metrics_list.append(metric)
1860
+ else:
1861
+ raise ValueError(
1862
+ f"Operator {metric_name} must be a Metric or MetricsList"
1863
+ )
1864
+ with error_context(self, stage="Setup Metrics"):
1865
+ for metric in metrics_list:
1866
+ metric.set_confidence_interval_calculation(
1867
+ self.calc_confidence_intervals
1868
  )
1869
+ # Each metric operator computes its score and then sets the main score, overwriting
1870
+ # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
1871
+ # This will cause the first listed metric to run last, and the main score will be set
1872
+ # by the first listed metric (as desired).
1873
+ metrics_list = list(reversed(metrics_list))
1874
+
1875
+ for i, metric in enumerate(metrics_list):
1876
+ if i == 0: # first metric
1877
+ multi_stream = MultiStream({"tmp": stream})
1878
+ else: # metrics with previous scores
1879
+ reusable_generator = ReusableGenerator(
1880
+ generator=update_scores_of_stream_instances,
1881
+ gen_kwargs={"stream": stream, "scores": accumulated_scores},
1882
+ )
1883
+ multi_stream = MultiStream.from_generators(
1884
+ {"tmp": reusable_generator}
1885
+ )
1886
 
1887
+ multi_stream = metric(multi_stream)
1888
 
1889
+ if i < len(metrics_list) - 1:  # not the last metric
1890
+ accumulated_scores = []
1891
+ for inst in multi_stream["tmp"]:
1892
+ accumulated_scores.append(recursive_copy(inst["score"]))
1893
 
1894
  yield from multi_stream["tmp"]
1895
 
processors.py CHANGED
@@ -98,6 +98,16 @@ class ExtractWithRegex(RegexParser):
98
  return ""
99
 
100
 
101
  class ListToEmptyEntitiesTuples(FieldOperator):
102
  def process_value(self, lst: Any) -> Any:
103
  try:
@@ -286,7 +296,7 @@ class StringOrNotString(StringEquals):
286
 
287
  class ExtractMtBenchRatingJudgment(FieldOperator):
288
  def process_value(self, text: Any) -> Any:
289
- match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
290
  try:
291
  return float(match.group(1)) / 10
292
  except:
 
98
  return ""
99
 
100
 
101
+ class GroupDictWithRegex(FieldOperator):
102
+ pattern: str
103
+
104
+ def process_value(self, value: Any) -> Any:
105
+ match = re.match(self.pattern, value)
106
+ if match:
107
+ return match.groupdict()
108
+ return {}
109
+
110
+
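A short sketch of GroupDictWithRegex with an assumed pattern: the named groups of the regex become the keys of the returned dict, and a non-matching value yields {}.

# Illustrative only: named regex groups become dict keys (pattern is made up).
import re

pattern = r"(?P<key>\w+)=(?P<value>\w+)"

match = re.match(pattern, "lang=python")
print(match.groupdict())  # {'key': 'lang', 'value': 'python'}

match = re.match(pattern, "no pairs here")
print(match.groupdict() if match else {})  # {}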
111
  class ListToEmptyEntitiesTuples(FieldOperator):
112
  def process_value(self, lst: Any) -> Any:
113
  try:
 
296
 
297
  class ExtractMtBenchRatingJudgment(FieldOperator):
298
  def process_value(self, text: Any) -> Any:
299
+ match = re.search(r"\[\[([\s*\d]+\.?[\d]*\s*)(/\s*10)?\s*\]\]", text)
300
  try:
301
  return float(match.group(1)) / 10
302
  except:
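The widened pattern now tolerates whitespace and an optional "/10" suffix inside the double brackets. A quick check of what it accepts (the example judgment strings are made up):

# Illustrative only: what the updated rating regex now matches.
import re

pattern = r"\[\[([\s*\d]+\.?[\d]*\s*)(/\s*10)?\s*\]\]"
for text in ["Rating: [[8]]", "Rating: [[ 7.5 ]]", "Rating: [[9/10]]"]:
    match = re.search(pattern, text)
    print(float(match.group(1)) / 10)  # 0.8, then 0.75, then 0.9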
schema.py CHANGED
@@ -59,7 +59,7 @@ def get_schema(stream_name):
59
  def load_chat_source(chat_str):
60
  chat = json.loads(chat_str)
61
  for turn in chat:
62
- if isinstance(turn["content"], list):
63
  for content in turn["content"]:
64
  if content["type"] == "image_url":
65
  content["image_url"]["url"] = ImageDataString(
 
59
  def load_chat_source(chat_str):
60
  chat = json.loads(chat_str)
61
  for turn in chat:
62
+ if "content" in turn and isinstance(turn["content"], list):
63
  for content in turn["content"]:
64
  if content["type"] == "image_url":
65
  content["image_url"]["url"] = ImageDataString(
serializers.py CHANGED
@@ -9,6 +9,7 @@ from .operators import InstanceFieldOperator
9
  from .settings_utils import get_constants
10
  from .type_utils import isoftype, to_type_string
11
  from .types import (
 
12
  Dialog,
13
  Document,
14
  Image,
@@ -75,7 +76,22 @@ class DialogSerializer(SingleTypeSerializer):
75
 
76
  def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
77
  # Convert the Dialog into a string representation, typically combining roles and content
78
- return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value)
79
 
80
 
81
  class NumberSerializer(SingleTypeSerializer):
@@ -225,7 +241,7 @@ class SQLDatabaseAsSchemaSerializer(SingleTypeSerializer):
225
  serialized_type = SQLDatabase
226
 
227
  def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
228
- from .sql_utils import get_db_connector
229
 
230
  connector = get_db_connector(value["db_type"])(value)
231
  return connector.get_table_schema()
 
9
  from .settings_utils import get_constants
10
  from .type_utils import isoftype, to_type_string
11
  from .types import (
12
+ Conversation,
13
  Dialog,
14
  Document,
15
  Image,
 
76
 
77
  def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
78
  # Convert the Dialog into a string representation, typically combining roles and content
79
+ turns = []
80
+ for turn in value:
81
+ turn_str = f"{turn['role']}: "
82
+ if "content" in turn:
83
+ turn_str += str(turn["content"])
84
+ if "tool_calls" in turn:
85
+ turn_str += "\n" + json.dumps(turn["tool_calls"])
86
+ turns.append(turn_str)
87
+ return "\n".join(turns)
88
+
89
+
90
+ class ConversationSerializer(DialogSerializer):
91
+ serialized_type = Conversation
92
+
93
+ def serialize(self, value: Conversation, instance: Dict[str, Any]) -> str:
94
+ return super().serialize(value["dialog"], instance)
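A sketch of the serialized output for a dialog containing a tool call (the turn contents are invented): each turn becomes "role: content", with any tool_calls appended as JSON on the following line.

# Illustrative only: the string shape the new serializer produces.
import json

dialog = [
    {"role": "user", "content": "What is the weather in Paris?"},
    {"role": "assistant", "tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]},
]

turns = []
for turn in dialog:
    turn_str = f"{turn['role']}: "
    if "content" in turn:
        turn_str += str(turn["content"])
    if "tool_calls" in turn:
        turn_str += "\n" + json.dumps(turn["tool_calls"])
    turns.append(turn_str)
print("\n".join(turns))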
95
 
96
 
97
  class NumberSerializer(SingleTypeSerializer):
 
241
  serialized_type = SQLDatabase
242
 
243
  def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
244
+ from .text2sql_utils import get_db_connector
245
 
246
  connector = get_db_connector(value["db_type"])(value)
247
  return connector.get_table_schema()
settings_utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import importlib.metadata
2
  import importlib.util
3
  import os
4
  from contextlib import contextmanager
5
 
6
  from .version import version
@@ -177,6 +178,9 @@ if Constants.is_uninitilized():
177
  constants.dataset_url = "unitxt/data"
178
  constants.metric_url = "unitxt/metric"
179
  constants.version = version
180
  constants.catalog_hierarchy_sep = "."
181
  constants.env_local_catalogs_paths_sep = ":"
182
  constants.non_registered_files = [
 
1
  import importlib.metadata
2
  import importlib.util
3
  import os
4
+ import sys
5
  from contextlib import contextmanager
6
 
7
  from .version import version
 
178
  constants.dataset_url = "unitxt/data"
179
  constants.metric_url = "unitxt/metric"
180
  constants.version = version
181
+ constants.python = (
182
+ f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
183
+ )
184
  constants.catalog_hierarchy_sep = "."
185
  constants.env_local_catalogs_paths_sep = ":"
186
  constants.non_registered_files = [
struct_data_operators.py CHANGED
@@ -23,6 +23,7 @@ For key-value pairs, expected input format is:
23
  {"key1": "value1", "key2": value2, "key3": "value3"}
24
  """
25
 
26
  import json
27
  import random
28
  from abc import ABC, abstractmethod
@@ -754,11 +755,40 @@ class LoadJson(FieldOperator):
754
  return json.loads(value, strict=False)
755
 
756
 
757
  class ToolCallPostProcessor(FieldOperator):
758
  failure_value: Any = None
759
  allow_failure: bool = False
760
 
761
  def process_value(self, value: str) -> ToolCall:
762
  if self.allow_failure:
763
  try:
764
  result = json.loads(value)
@@ -776,6 +806,25 @@ class ToolCallPostProcessor(FieldOperator):
776
  return result
777
 
778
 
779
  class DumpJson(FieldOperator):
780
  def process_value(self, value: str) -> str:
781
  return json.dumps(value)
 
23
  {"key1": "value1", "key2": value2, "key3": "value3"}
24
  """
25
 
26
+ import ast
27
  import json
28
  import random
29
  from abc import ABC, abstractmethod
 
755
  return json.loads(value, strict=False)
756
 
757
 
758
+ class PythonCallProcessor(FieldOperator):
759
+ def process_value(self, value: str) -> ToolCall:
760
+ expr = ast.parse(value, mode="eval").body
761
+ function = expr.func.id
762
+ args = {}
763
+ for kw in expr.keywords:
764
+ args[kw.arg] = ast.literal_eval(kw.value)
765
+ # Handle positional args, if any
766
+ if expr.args:
767
+ args["_args"] = [ast.literal_eval(arg) for arg in expr.args]
768
+ return {"name": function, "arguments": args}
769
+
770
+
771
+ def extract_possible_json_str(text):
772
+ """Extract potential JSON string from text by finding outermost braces/brackets."""
773
+ # Find first opening delimiter
774
+ start_positions = [pos for pos in [text.find("{"), text.find("[")] if pos != -1]
775
+ start = min(start_positions) if start_positions else 0
776
+
777
+ # Find last closing delimiter
778
+ end_positions = [pos for pos in [text.rfind("}"), text.rfind("]")] if pos != -1]
779
+ end = max(end_positions) if end_positions else len(text) - 1
780
+
781
+ return text[start : end + 1]
782
+
783
+
784
  class ToolCallPostProcessor(FieldOperator):
785
  failure_value: Any = None
786
  allow_failure: bool = False
787
 
788
  def process_value(self, value: str) -> ToolCall:
789
+ value = extract_possible_json_str(
790
+ value
791
+ ) # strip wrapper tokens such as <tool_call>, keeping only the call JSON itself
792
  if self.allow_failure:
793
  try:
794
  result = json.loads(value)
 
806
  return result
807
 
808
 
809
+ class MultipleToolCallPostProcessor(FieldOperator):
810
+ failure_value: Any = None
811
+ allow_failure: bool = False
812
+
813
+ def process_value(self, value: str) -> List[ToolCall]:
814
+ if self.allow_failure:
815
+ try:
816
+ result = json.loads(value)
817
+ except json.JSONDecodeError:
818
+ return self.failure_value
819
+ else:
820
+ result = json.loads(value, strict=False)
821
+ if isoftype(result, List[ToolCall]):
822
+ return result
823
+ if not isoftype(result, ToolCall):
824
+ return self.failure_value
825
+ return [result]
826
+
827
+
828
  class DumpJson(FieldOperator):
829
  def process_value(self, value: str) -> str:
830
  return json.dumps(value)
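
To make the two new extraction paths above concrete, here is a self-contained sketch (the helper names mirror the diff; the sample inputs are illustrative):

```python
import ast

def parse_python_call(value: str) -> dict:
    # Mirrors PythonCallProcessor.process_value: parse a Python-style call
    # expression and convert it into a tool-call dict.
    expr = ast.parse(value, mode="eval").body
    args = {kw.arg: ast.literal_eval(kw.value) for kw in expr.keywords}
    if expr.args:  # positional arguments are collected under "_args"
        args["_args"] = [ast.literal_eval(arg) for arg in expr.args]
    return {"name": expr.func.id, "arguments": args}

def extract_possible_json_str(text: str) -> str:
    # Mirrors the helper above: keep the span between the first opening and
    # last closing brace/bracket, dropping wrapper tokens around it.
    starts = [pos for pos in [text.find("{"), text.find("[")] if pos != -1]
    ends = [pos for pos in [text.rfind("}"), text.rfind("]")] if pos != -1]
    start = min(starts) if starts else 0
    end = max(ends) if ends else len(text) - 1
    return text[start : end + 1]

print(parse_python_call('get_weather("Paris", unit="C")'))
# {'name': 'get_weather', 'arguments': {'unit': 'C', '_args': ['Paris']}}
print(extract_possible_json_str('<tool_call>{"name": "f", "arguments": {}}</tool_call>'))
# {"name": "f", "arguments": {}}
```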
task.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
3
 
4
  from .artifact import fetch_artifact
5
  from .deprecation_utils import deprecation
6
- from .error_utils import Documentation, UnitxtError, UnitxtWarning
7
  from .logging_utils import get_logger
8
  from .metrics import MetricsList
9
  from .operator import InstanceOperator
@@ -285,13 +285,18 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
285
  ) -> Dict[str, Any]:
286
  instance = self.set_default_values(instance)
287
 
288
- verify_required_schema(
289
- self.input_fields,
290
- instance,
291
- class_name="Task",
292
- id=self.__id__,
293
- description=self.__description__,
294
- )
295
  input_fields = {key: instance[key] for key in self.input_fields.keys()}
296
  data_classification_policy = instance.get("data_classification_policy", [])
297
 
 
3
 
4
  from .artifact import fetch_artifact
5
  from .deprecation_utils import deprecation
6
+ from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
7
  from .logging_utils import get_logger
8
  from .metrics import MetricsList
9
  from .operator import InstanceOperator
 
285
  ) -> Dict[str, Any]:
286
  instance = self.set_default_values(instance)
287
 
288
+ with error_context(
289
+ self,
290
+ stage="Schema Verification",
291
+ help="https://www.unitxt.ai/en/latest/docs/adding_task.html",
292
+ ):
293
+ verify_required_schema(
294
+ self.input_fields,
295
+ instance,
296
+ class_name="Task",
297
+ id=self.__id__,
298
+ description=self.__description__,
299
+ )
300
  input_fields = {key: instance[key] for key in self.input_fields.keys()}
301
  data_classification_policy = instance.get("data_classification_policy", [])
302
 
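`error_context` itself is defined in `error_utils` and is not shown in this diff; a rough sketch of the pattern, assuming it is a context manager that enriches exceptions with the failing artifact and contextual hints:

```python
from contextlib import contextmanager

@contextmanager
def error_context(obj, **context):
    # Hypothetical stand-in for unitxt's error_context: re-raise failures
    # with the originating object and contextual hints in the message.
    try:
        yield
    except Exception as e:
        details = ", ".join(f"{k}={v}" for k, v in context.items())
        raise RuntimeError(f"{type(obj).__name__} failed ({details}): {e}") from e
```

The net effect is that a schema-verification failure now points at the task and the relevant documentation page instead of surfacing a bare exception.
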
templates.py CHANGED
@@ -6,11 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
6
  from .artifact import Artifact
7
  from .collections import DictCollection, ListCollection
8
  from .dataclass import NonPositionalField
9
- from .dict_utils import dict_set
10
  from .error_utils import Documentation, UnitxtError
11
  from .operator import InstanceOperator, Operator
12
  from .random_utils import new_random_generator
13
  from .serializers import (
14
  DialogSerializer,
15
  ImageSerializer,
16
  ListSerializer,
@@ -68,6 +69,7 @@ class Template(InstanceOperator):
68
  ToolCallSerializer(),
69
  ToolsSerializer(),
70
  DialogSerializer(),
71
  ListSerializer(),
72
  SQLDatabaseAsSchemaSerializer(),
73
  ]
@@ -942,6 +944,16 @@ class MultiReferenceTemplate(InputOutputTemplate):
942
  return target, references
943
 
944
 
945
  def escape_chars(s, chars_to_escape):
946
  for char in chars_to_escape:
947
  s = s.replace(char, f"\\{char}")
 
6
  from .artifact import Artifact
7
  from .collections import DictCollection, ListCollection
8
  from .dataclass import NonPositionalField
9
+ from .dict_utils import dict_get, dict_set
10
  from .error_utils import Documentation, UnitxtError
11
  from .operator import InstanceOperator, Operator
12
  from .random_utils import new_random_generator
13
  from .serializers import (
14
+ ConversationSerializer,
15
  DialogSerializer,
16
  ImageSerializer,
17
  ListSerializer,
 
69
  ToolCallSerializer(),
70
  ToolsSerializer(),
71
  DialogSerializer(),
72
+ ConversationSerializer(),
73
  ListSerializer(),
74
  SQLDatabaseAsSchemaSerializer(),
75
  ]
 
944
  return target, references
945
 
946
 
947
+ class MultiTurnTemplate(MultiReferenceTemplate):
948
+ input_format = ""
949
+ turns_field: str
950
+
951
+ def post_process_instance(self, instance):
952
+ turns = dict_get(instance["input_fields"], self.turns_field)
953
+ instance["__turns__"] = turns
954
+ return super().post_process_instance(instance)
955
+
956
+
957
  def escape_chars(s, chars_to_escape):
958
  for char in chars_to_escape:
959
  s = s.replace(char, f"\\{char}")
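
`MultiTurnTemplate` pulls the turns out of the instance's input fields and stashes them under `__turns__` for downstream chat formatting. A toy illustration of that bookkeeping, assuming `dict_get` resolves "/"-separated paths into nested dicts (the field name "conversation" is made up for the example):

```python
instance = {
    "input_fields": {
        "conversation": {
            "id": "conv-1",
            "dialog": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello!"},
            ],
        }
    }
}

turns_field = "conversation/dialog"  # nested path, as dict_get would resolve it
outer, inner = turns_field.split("/")
instance["__turns__"] = instance["input_fields"][outer][inner]  # dict_get equivalent
print(instance["__turns__"][0])  # {'role': 'user', 'content': 'Hi'}
```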
sql_utils.py → text2sql_utils.py RENAMED
@@ -7,9 +7,13 @@ import re
7
  import sqlite3
8
  import time
9
  from abc import ABC, abstractmethod
10
  from functools import lru_cache
11
- from typing import Any, List, Optional
12
 
13
  import requests
14
  from huggingface_hub import snapshot_download
15
  from requests.exceptions import ConnectionError, ReadTimeout
@@ -539,6 +543,17 @@ def get_db_connector(db_type: str):
539
  return connector
540
 
541
 
542
  def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
543
  """Returns True if sqlglot does not encounter any error, False otherwise."""
544
  from sqlglot import parse
@@ -695,7 +710,7 @@ def extract_select_info(sql: str):
695
 
696
 
697
  def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
698
- """Return True if both SQL queries are naively considered equivalent."""
699
  try:
700
  info1 = extract_select_info(sql1)
701
  info2 = extract_select_info(sql2)
@@ -713,6 +728,7 @@ def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
713
 
714
 
715
  def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
716
  from sqlglot import exp, parse_one
717
 
718
  try:
@@ -754,3 +770,473 @@ def sql_exact_match(sql1: str, sql2: str) -> bool:
754
  return s.upper()
755
 
756
  return normalize_sql(sql1) == normalize_sql(sql2)
 
7
  import sqlite3
8
  import time
9
  from abc import ABC, abstractmethod
10
+ from collections import Counter
11
+ from dataclasses import dataclass
12
  from functools import lru_cache
13
+ from typing import Any, List, Optional, Tuple
14
 
15
+ import numpy as np
16
+ import pandas as pd
17
  import requests
18
  from huggingface_hub import snapshot_download
19
  from requests.exceptions import ConnectionError, ReadTimeout
 
543
  return connector
544
 
545
 
546
+ @dataclass
547
+ class SQLNonExecutionMetricResult:
548
+ sqlglot_validity: int # Whether SQL parses with sqlglot
549
+ sqlparse_validity: int # Whether SQL parses with sqlparse
550
+ sqlglot_equivalence: int # Semantic equivalence using sqlglot AST
551
+ sqlglot_optimized_equivalence: int # Equivalence after optimization via sqlglot
552
+ sqlparse_equivalence: int # Equivalence using sqlparse AST
553
+ sql_exact_match: int # Exact string match of predicted and gold SQL
554
+ sql_syntactic_equivalence: int # Any of the above equivalence conditions hold
555
+
556
+
557
  def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
558
  """Returns True if sqlglot does not encounter any error, False otherwise."""
559
  from sqlglot import parse
 
710
 
711
 
712
  def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
713
+ """Returns True if both SQL queries are naively considered equivalent."""
714
  try:
715
  info1 = extract_select_info(sql1)
716
  info2 = extract_select_info(sql2)
 
728
 
729
 
730
  def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
731
+ """Return True if two SQL queries match after parsing with SQLGlot."""
732
  from sqlglot import exp, parse_one
733
 
734
  try:
 
770
  return s.upper()
771
 
772
  return normalize_sql(sql1) == normalize_sql(sql2)
773
+
774
+
775
+ @dataclass
776
+ class SQLExecutionResult:
777
+ execution_accuracy: int # Whether the predicted and gold SQL results match exactly
778
+ non_empty_execution_accuracy: (
779
+ int # Same as execution_accuracy but only if gold is non-empty
780
+ )
781
+ subset_non_empty_execution_accuracy: (
782
+ int # Whether predicted is a subset of gold or vice versa, non-empty only
783
+ )
784
+ execution_accuracy_bird: (
785
+ int # Whether the predicted SQL matches gold using BIRD evaluation logic
786
+ )
787
+ non_empty_gold_df: int # Whether the gold SQL produced a non-empty dataframe
788
+ gold_sql_runtime: float # Time taken to execute the gold SQL
789
+ predicted_sql_runtime: float # Time taken to execute the predicted SQL
790
+ pred_to_gold_runtime_ratio: float # Ratio of predicted runtime to gold runtime
791
+ gold_error: int # Whether the gold SQL had an execution error
792
+ predicted_error: int # Whether the predicted SQL had an execution error
793
+ gold_df_json: str # JSON representation of the gold SQL result dataframe
794
+ predicted_df_json: str # JSON representation of the predicted SQL result dataframe
795
+ error_message: str # Error message from predicted execution if any
796
+
797
+
798
+ def compare_dfs_ignore_colnames_ordered_rows(
799
+ df1: pd.DataFrame, df2: pd.DataFrame
800
+ ) -> bool:
801
+ if df1.shape != df2.shape:
802
+ return False
803
+ df1_sorted_rows = np.array([np.sort(row) for row in df1.values.astype(str)])
804
+ df2_sorted_rows = np.array([np.sort(row) for row in df2.values.astype(str)])
805
+ return np.array_equal(df1_sorted_rows, df2_sorted_rows)
806
+
807
+
808
+ def compare_dfs_ignore_colnames_unordered_rows(
809
+ df1: pd.DataFrame, df2: pd.DataFrame
810
+ ) -> bool:
811
+ if df1.shape != df2.shape:
812
+ return False
813
+ df1_sorted = np.sort(np.sort(df1.values.astype(str), axis=1), axis=0)
814
+ df2_sorted = np.sort(np.sort(df2.values.astype(str), axis=1), axis=0)
815
+ return np.array_equal(df1_sorted, df2_sorted)
816
+
817
+
818
+ def compare_dfs_ignore_colnames_subset(
819
+ df1: pd.DataFrame, df2: pd.DataFrame, ignore_row_order: bool = True
820
+ ) -> bool:
821
+ """Checks if the smaller of the two DataFrames is likely a subset of the other.
822
+
823
+ Subset comparison is column-based, to support Text2SQL evaluation for when the
824
+ predicted SQL dataframe has missing or additional columns. Each row is treated as
825
+ a multiset of (stringified) values, and the function checks if every row in the
826
+ smaller DataFrame (by column count) is a multiset subset of the corresponding row
827
+ in the larger DataFrame. When ground truth SQL does not have ORDER BY,
828
+ ignore_row_order can be set to True to ignore the order of rows. In this case,
829
+ column values are sorted before comparison. As a result, two dataframes may
830
+ have the same number of rows and identical column values after sorting without
831
+ one actually being a subset of the other. This is unlikely in practice, but
832
+ the score is not guaranteed to be 100% accurate and may overestimate the
833
+ accuracy.
834
+
835
+ Args:
836
+ df1 (pd.DataFrame): The first DataFrame to compare.
837
+ df2 (pd.DataFrame): The second DataFrame to compare.
838
+ ignore_row_order (bool, optional): If True, ignores the order of rows by
839
+ sorting them before comparison. Defaults to True.
840
+
841
+ Returns:
842
+ bool: True if the smaller DataFrame (column-wise) is likely a subset of the
843
+ larger one, False otherwise.
844
+ """
845
+
846
+ def row_to_multiset(row):
847
+ return Counter(str(x) for x in row)
848
+
849
+ def rows_to_multisets(df):
850
+ return [row_to_multiset(row) for row in df.values]
851
+
852
+ def sort_df(df):
853
+ sorted_df = df.copy()
854
+ for col in sorted_df.columns:
855
+ sorted_df[col] = sorted_df[col].astype(str).sort_values(ignore_index=True)
856
+ return sorted_df
857
+
858
+ if df1.empty or df2.empty or len(df1) != len(df2):
859
+ return False
860
+
861
+ subset_df, superset_df = (df1, df2) if df1.shape[1] <= df2.shape[1] else (df2, df1)
862
+
863
+ if ignore_row_order:
864
+ subset_df = sort_df(subset_df)
865
+ superset_df = sort_df(superset_df)
866
+
867
+ subset_rows = rows_to_multisets(subset_df)
868
+ superset_rows = rows_to_multisets(superset_df)
869
+
870
+ for r1, r2 in zip(subset_rows, superset_rows):
871
+ if not all(r1[k] <= r2.get(k, 0) for k in r1):
872
+ return False
873
+ return True
874
+
875
+
876
+ def compare_dfs_bird_eval_logic(df1: pd.DataFrame, df2: pd.DataFrame):
877
+ """Check if two SQL query result sets are exactly equal, as in BIRD evaluation.
878
+
879
+ This function checks whether the set of rows in the gold result (`df1`) is
880
+ exactly equal to the set of rows in the predicted result (`df2`). This is
881
+ the logic used in the original BIRD
882
+ evaluation code:
883
+ https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py.
884
+ """
885
+ df1_set = {tuple(row) for row in df1.values.astype(str)}
886
+ df2_set = {tuple(row) for row in df2.values.astype(str)}
887
+ return int(df1_set == df2_set)
888
+
889
+
890
+ def compare_result_dfs(
891
+ gold_df: pd.DataFrame, pred_df: pd.DataFrame, gold_sql: str
892
+ ) -> Tuple[int, int, int]:
893
+ """Compares two DataFrames representing SQL query results.
894
+
895
+ Args:
896
+ gold_df (pd.DataFrame): The ground truth DataFrame.
897
+ pred_df (pd.DataFrame): The predicted DataFrame.
898
+ gold_sql (str): The ground truth SQL query string.
899
+
900
+ Returns:
901
+ Tuple[int, int, int]: A tuple containing:
902
+ - match (int): 1 if the predicted DataFrame matches the gold DataFrame, 0 otherwise.
903
+ - non_empty_match (int): 1 if both DataFrames are non-empty and match,
904
+ 0 otherwise.
905
+ - subset_match (int): 1 if the predicted DataFrame is a subset or
906
+ superset of the gold DataFrame, 0 otherwise.
907
+
908
+ Notes:
909
+ - The comparison ignores column names.
910
+ - Row order is considered only if 'ORDER BY' is present in the SQL query.
911
+ """
912
+ subset_match = 0
913
+ non_empty_match = 0
914
+ if "ORDER BY" in gold_sql.upper():
915
+ match = int(compare_dfs_ignore_colnames_ordered_rows(pred_df, gold_df))
916
+ if not gold_df.empty and not pred_df.empty:
917
+ non_empty_match = match
918
+ if compare_dfs_ignore_colnames_subset(
919
+ gold_df, pred_df, ignore_row_order=False
920
+ ):
921
+ subset_match = 1
922
+ else:
923
+ match = int(compare_dfs_ignore_colnames_unordered_rows(pred_df, gold_df))
924
+ if not gold_df.empty and not pred_df.empty:
925
+ non_empty_match = match
926
+ if compare_dfs_ignore_colnames_subset(
927
+ gold_df, pred_df, ignore_row_order=True
928
+ ):
929
+ subset_match = 1
930
+ return match, non_empty_match, subset_match
931
+
932
+
933
+ def run_query(
934
+ sql: str, connector, sql_timeout: float
935
+ ) -> Tuple[Optional[pd.DataFrame], float, str]:
936
+ """Executes a SQL query using the provided connector with a timeout.
937
+
938
+ Args:
939
+ sql (str): The SQL query string to execute.
940
+ connector: An object with an `execute_query` method that executes the SQL
941
+ query.
942
+ sql_timeout (float): The maximum time in seconds to allow for query
943
+ execution.
944
+
945
+ Returns:
946
+ Tuple[Optional[pd.DataFrame], float, str]:
947
+ - A pandas DataFrame containing the query results, or None if an error
948
+ occurred.
949
+ - The duration in seconds taken to execute the query. 0.0 if an error.
950
+ - An error message string if an error occurred, otherwise an empty
951
+ string.
952
+
953
+ Notes:
954
+ - If the SQL string is empty or only whitespace, returns immediately with a
955
+ message.
956
+ - If the query execution exceeds the timeout, returns a timeout error
957
+ message.
958
+ - Any other exceptions are caught and returned as error messages.
959
+ """
960
+ import time
961
+
962
+ from func_timeout import func_timeout
963
+ from func_timeout.exceptions import FunctionTimedOut
964
+
965
+ if not sql.strip():
966
+ return None, 0.0, "No SQL query found in the prediction."
967
+
968
+ try:
969
+ start = time.perf_counter()
970
+ result, error = func_timeout(sql_timeout, connector.execute_query, args=(sql,))
971
+ duration = time.perf_counter() - start
972
+ if isinstance(result, dict) and "results" in result:
973
+ result = result["results"]
974
+ if error:
975
+ return None, duration, error
976
+ return pd.DataFrame(result), duration, ""
977
+ except FunctionTimedOut as e:
978
+ return None, 0.0, f"Timeout: {e}"
979
+ except Exception as e:
980
+ return None, 0.0, f"Error: {e}"
981
+
982
+
983
+ def get_sql_execution_results(
984
+ predicted_sql: str, gold_sql: str, connector, sql_timeout: float
985
+ ) -> SQLExecutionResult:
986
+ """Execute and compare predicted and gold SQL queries, returning execution metrics.
987
+
988
+ Args:
989
+ predicted_sql (str): The SQL query predicted by the model.
990
+ gold_sql (str): The reference (gold) SQL query.
991
+ connector: Database connector object used to execute the queries.
992
+ sql_timeout (float): Maximum time (in seconds) allowed for query execution.
993
+
994
+ Returns:
995
+ SQLExecutionResult: An object containing various execution metrics, including:
996
+ - execution_accuracy (int): 1 if predicted and gold queries produce
997
+ equivalent results, else 0.
998
+ - non_empty_execution_accuracy (int): 1 if both queries produce non-empty
999
+ and equivalent results, else 0.
1000
+ - subset_non_empty_execution_accuracy (int): 1 if predicted results are a
1001
+ subset or superset of gold results and non-empty, else 0. Subset
1002
+ comparison is column-based. This means that the predicted SQL dataframe
1003
+ can have missing or additional columns compared to the gold SQL dataframe.
1004
+ - execution_accuracy_bird (int): 1 if results match according to BIRD
1005
+ evaluation logic, else 0.
1006
+ - non_empty_gold_df (int): 1 if the gold query result is non-empty, else 0.
1007
+ - gold_sql_runtime (float): Execution time for the gold SQL query.
1008
+ - predicted_sql_runtime (float): Execution time for the predicted SQL query.
1009
+ - pred_to_gold_runtime_ratio (float): Ratio of predicted to gold query
1010
+ runtimes.
1011
+ - gold_error (int): 1 if the gold query failed, else 0.
1012
+ - predicted_error (int): 1 if the predicted query failed, else 0.
1013
+ - gold_df_json (str): JSON representation of the gold query result
1014
+ DataFrame.
1015
+ - predicted_df_json (str): JSON representation of the predicted query
1016
+ result DataFrame.
1017
+ - error_message (str): Error message if any query failed, else empty
1018
+ string.
1019
+
1020
+ Notes:
1021
+ - If the gold query fails, the function returns early with error details.
1022
+ - If the predicted query is identical or SQL-equivalent to the gold query,
1023
+ results are considered correct without execution.
1024
+ - Otherwise, both queries are executed and their results compared using
1025
+ multiple metrics.
1026
+ """
1027
+ gold_df, gold_runtime, gold_error_msg = run_query(gold_sql, connector, sql_timeout)
1028
+ gold_error = int(bool(gold_error_msg))
1029
+
1030
+ if gold_error:
1031
+ return SQLExecutionResult(
1032
+ execution_accuracy=0,
1033
+ non_empty_execution_accuracy=0,
1034
+ subset_non_empty_execution_accuracy=0,
1035
+ execution_accuracy_bird=0,
1036
+ non_empty_gold_df=0,
1037
+ gold_sql_runtime=gold_runtime,
1038
+ predicted_sql_runtime=0,
1039
+ pred_to_gold_runtime_ratio=0,
1040
+ gold_error=gold_error,
1041
+ predicted_error=0,
1042
+ gold_df_json="",
1043
+ predicted_df_json="",
1044
+ error_message=gold_error_msg,
1045
+ )
1046
+
1047
+ non_empty_gold_df = int(not gold_df.empty)
1048
+ if predicted_sql.strip().lower() == gold_sql.strip().lower():
1049
+ return SQLExecutionResult(
1050
+ execution_accuracy=1,
1051
+ non_empty_execution_accuracy=non_empty_gold_df,
1052
+ subset_non_empty_execution_accuracy=non_empty_gold_df,
1053
+ execution_accuracy_bird=1,
1054
+ non_empty_gold_df=non_empty_gold_df,
1055
+ gold_sql_runtime=gold_runtime,
1056
+ predicted_sql_runtime=0,
1057
+ pred_to_gold_runtime_ratio=0,
1058
+ gold_error=0,
1059
+ predicted_error=0,
1060
+ gold_df_json=gold_df.to_json(),
1061
+ predicted_df_json=gold_df.to_json(),
1062
+ error_message="",
1063
+ )
1064
+
1065
+ try:
1066
+ if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
1067
+ return SQLExecutionResult(
1068
+ execution_accuracy=1,
1069
+ non_empty_execution_accuracy=non_empty_gold_df,
1070
+ subset_non_empty_execution_accuracy=non_empty_gold_df,
1071
+ execution_accuracy_bird=1,
1072
+ non_empty_gold_df=non_empty_gold_df,
1073
+ gold_sql_runtime=gold_runtime,
1074
+ predicted_sql_runtime=0,
1075
+ pred_to_gold_runtime_ratio=0,
1076
+ gold_error=0,
1077
+ predicted_error=0,
1078
+ gold_df_json=gold_df.to_json(),
1079
+ predicted_df_json=gold_df.to_json(),
1080
+ error_message="",
1081
+ )
1082
+ except Exception as e:
1083
+ logger.info(f"Could not check SQL equivalence: {e}")
1084
+
1085
+ pred_df, pred_runtime, pred_error_msg = run_query(
1086
+ predicted_sql, connector, sql_timeout
1087
+ )
1088
+ pred_error = 1 if pred_error_msg else 0
1089
+
1090
+ if pred_df is None:
1091
+ return SQLExecutionResult(
1092
+ execution_accuracy=0,
1093
+ non_empty_execution_accuracy=0,
1094
+ subset_non_empty_execution_accuracy=0,
1095
+ execution_accuracy_bird=0,
1096
+ non_empty_gold_df=non_empty_gold_df,
1097
+ gold_sql_runtime=gold_runtime,
1098
+ predicted_sql_runtime=pred_runtime,
1099
+ pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
1100
+ if gold_runtime > 0
1101
+ else 0,
1102
+ gold_error=0,
1103
+ predicted_error=pred_error,
1104
+ gold_df_json=gold_df.to_json(),
1105
+ predicted_df_json="",
1106
+ error_message=pred_error_msg,
1107
+ )
1108
+
1109
+ match, non_empty_match, subset_match = compare_result_dfs(
1110
+ gold_df, pred_df, gold_sql
1111
+ )
1112
+ bird_match = compare_dfs_bird_eval_logic(gold_df, pred_df)
1113
+
1114
+ return SQLExecutionResult(
1115
+ execution_accuracy=match,
1116
+ non_empty_execution_accuracy=non_empty_match,
1117
+ subset_non_empty_execution_accuracy=subset_match,
1118
+ execution_accuracy_bird=bird_match,
1119
+ non_empty_gold_df=non_empty_gold_df,
1120
+ gold_sql_runtime=gold_runtime,
1121
+ predicted_sql_runtime=pred_runtime,
1122
+ pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
1123
+ if gold_runtime > 0
1124
+ else 0,
1125
+ gold_error=0,
1126
+ predicted_error=0,
1127
+ gold_df_json=gold_df.to_json(),
1128
+ predicted_df_json=pred_df.to_json(),
1129
+ error_message=pred_error_msg,
1130
+ )
1131
+
1132
+
1133
+ def replace_select_clause(
1134
+ source_query: str, target_query: str, dialect: str = "postgres"
1135
+ ) -> str:
1136
+ """Replaces the SELECT clause of the target SQL query with the SELECT clause from the source SQL query.
1137
+
1138
+ Args:
1139
+ source_query (str): SQL query whose SELECT clause will be used.
1140
+ target_query (str): SQL query whose SELECT clause will be replaced.
1141
+ dialect (str): SQL dialect for parsing and rendering (default: "postgres").
1142
+
1143
+ Returns:
1144
+ str: A new SQL query with the SELECT clause of `target_query` replaced by that of `source_query`.
1145
+
1146
+ Raises:
1147
+ ValueError: If either query is not a valid SELECT statement.
1148
+
1149
+ Example:
1150
+ >>> replace_select_clause(
1151
+ ... "SELECT id FROM employees",
1152
+ ... "SELECT name FROM employees WHERE age > 30"
1153
+ ... )
1154
+ "SELECT id FROM employees WHERE age > 30"
1155
+ """
1156
+ from sqlglot import exp, parse_one
1157
+
1158
+ if not dialect:
1159
+ dialect = "postgres"
1160
+
1161
+ # Parse queries using the specified dialect
1162
+ source_ast = parse_one(source_query, read=dialect)
1163
+ target_ast = parse_one(target_query, read=dialect)
1164
+
1165
+ if not isinstance(source_ast, exp.Select) or not isinstance(target_ast, exp.Select):
1166
+ raise ValueError("Both queries must be valid SELECT statements.")
1167
+
1168
+ # Replace SELECT expressions in the target with those from the source
1169
+ target_ast.set("expressions", source_ast.expressions)
1170
+
1171
+ # Return the updated SQL string using the dialect
1172
+ return target_ast.sql(dialect=dialect)
1173
+
1174
+
1175
+ def extract_sql_from_text(text: str) -> str:
1176
+ """Extracts the first SQL query from the given text.
1177
+
1178
+ Priority:
1179
+ 1. SQL inside fenced blocks like ```sql ... ```
1180
+ 2. SQL starting on a new line or after a colon/label
1181
+ 3. SQL without semicolon
1182
+
1183
+ Returns:
1184
+ The SQL query string, or an empty string if not found.
1185
+ """
1186
+ # 1. Look for fenced SQL code block
1187
+ fenced_block_pattern = re.compile(r"```sql\s+(.*?)```", re.IGNORECASE | re.DOTALL)
1188
+ match = fenced_block_pattern.search(text)
1189
+ if match:
1190
+ return match.group(1).strip()
1191
+
1192
+ # 2. Inline SQL with semicolon
1193
+ sql_keywords = r"(?:SELECT|INSERT|UPDATE|DELETE|WITH)\s+"
1194
+ sql_start = (
1195
+ r"(?:^|\n|:\s*)" # Start of string, newline, or colon label like "Just run:"
1196
+ )
1197
+ sql_pattern = re.compile(
1198
+ rf"{sql_start}({sql_keywords}.*?;)", re.IGNORECASE | re.DOTALL
1199
+ )
1200
+ match = sql_pattern.search(text)
1201
+ if match:
1202
+ return match.group(1).strip()
1203
+
1204
+ # 3. Inline SQL without semicolon
1205
+ fallback_pattern = re.compile(
1206
+ rf"{sql_start}({sql_keywords}.*)", re.IGNORECASE | re.DOTALL
1207
+ )
1208
+ fallback_match = fallback_pattern.search(text)
1209
+ if fallback_match:
1210
+ return fallback_match.group(1).strip()
1211
+
1212
+ return ""
1213
+
1214
+
1215
+ ALL_DIALECTS = [
1216
+ "Athena",
1217
+ "BigQuery",
1218
+ "ClickHouse",
1219
+ "Databricks",
1220
+ "Doris",
1221
+ "Drill",
1222
+ "Druid",
1223
+ "DuckDB",
1224
+ "Hive",
1225
+ "Materialize",
1226
+ "MySQL",
1227
+ "Oracle",
1228
+ "Postgres",
1229
+ "Presto",
1230
+ "PRQL",
1231
+ "Redshift",
1232
+ "RisingWave",
1233
+ "Snowflake",
1234
+ "Spark",
1235
+ "Spark2",
1236
+ "SQLite",
1237
+ "StarRocks",
1238
+ "Tableau",
1239
+ "Teradata",
1240
+ "Trino",
1241
+ "TSQL",
1242
+ ]
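
As a sanity check of the column-name-agnostic comparison logic defined above, here is the unordered-rows variant exercised on a predicted result whose columns are renamed and reordered relative to gold (the data is illustrative):

```python
import numpy as np
import pandas as pd

def compare_dfs_ignore_colnames_unordered_rows(df1, df2):
    # Same logic as above: stringify, sort within rows, then sort rows.
    if df1.shape != df2.shape:
        return False
    df1_sorted = np.sort(np.sort(df1.values.astype(str), axis=1), axis=0)
    df2_sorted = np.sort(np.sort(df2.values.astype(str), axis=1), axis=0)
    return np.array_equal(df1_sorted, df2_sorted)

gold = pd.DataFrame({"name": ["alice", "bob"], "age": [30, 25]})
pred = pd.DataFrame({"years": [25, 30], "person": ["bob", "alice"]})
print(compare_dfs_ignore_colnames_unordered_rows(gold, pred))  # True
```

Note that stringifying values means `30` and `"30"` compare equal, which is usually the desired leniency for Text2SQL execution accuracy.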
type_utils.py CHANGED
@@ -503,9 +503,25 @@ def isoftype(object, typing_type):
503
  if is_typed_dict(typing_type):
504
  if not isinstance(object, dict):
505
  return False
506
  for key, expected_type in typing_type.__annotations__.items():
507
- if key not in object or not isoftype(object[key], expected_type):
508
- return False
509
  return True
510
 
511
  if typing_type == typing.Any:
 
503
  if is_typed_dict(typing_type):
504
  if not isinstance(object, dict):
505
  return False
506
+
507
+ # Only total=True TypedDicts are supported; check each declared field
508
  for key, expected_type in typing_type.__annotations__.items():
509
+ # Check if field is Optional (Union with None)
510
+ is_optional = (
511
+ hasattr(expected_type, "__origin__")
512
+ and expected_type.__origin__ is Union
513
+ and type(None) in expected_type.__args__
514
+ )
515
+
516
+ if key not in object:
517
+ # Field is missing - only allowed if it's Optional
518
+ if not is_optional:
519
+ return False
520
+ else:
521
+ # Field is present - check type
522
+ if not isoftype(object[key], expected_type):
523
+ return False
524
+
525
  return True
526
 
527
  if typing_type == typing.Any:
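
The relaxed TypedDict check above changes `isoftype` so that Optional fields may be absent entirely. A distilled sketch of just the presence rule (the real code also recurses into `isoftype` for fields that are present):

```python
from typing import Optional, Union

def required_keys_present(obj: dict, annotations: dict) -> bool:
    # A declared field may be missing only when its annotation is Optional,
    # i.e. a Union that includes None.
    for key, expected in annotations.items():
        is_optional = (
            hasattr(expected, "__origin__")
            and expected.__origin__ is Union
            and type(None) in expected.__args__
        )
        if key not in obj and not is_optional:
            return False
    return True

tool_annotations = {"name": str, "description": str, "type": Optional[str]}
print(required_keys_present({"name": "f", "description": "demo"}, tool_annotations))  # True
print(required_keys_present({"description": "demo"}, tool_annotations))               # False
```

This is what lets the extended `Tool` type below (with its optional LiteLLM-style `type` field) keep validating dictionaries that omit `type`.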
types.py CHANGED
@@ -6,8 +6,52 @@ Text = NewType("Text", str)
6
  Number = NewType("Number", Union[float, int])
7
 
8
 
9
- class Turn(TypedDict):
10
- role: Literal["system", "user", "agent"]
11
  content: Text
12
 
13
 
@@ -18,7 +62,12 @@ class RagResponse(TypedDict):
18
  is_answerable: bool
19
 
20
 
21
- Dialog = NewType("Dialog", List[Turn])
22
 
23
 
24
  class Image(TypedDict):
@@ -52,36 +101,17 @@ class SQLDatabase(TypedDict):
52
  data: Optional[Dict[str, Dict]]
53
 
54
 
55
- class JsonSchema:
56
- @classmethod
57
- def __verify_type__(cls, object):
58
- if not isinstance(object, dict):
59
- return False
60
- import jsonschema_rs
61
-
62
- jsonschema_rs.meta.validate(object)
63
- return True
64
-
65
-
66
- class Tool(TypedDict):
67
- name: str
68
- description: str
69
- parameters: JsonSchema
70
-
71
-
72
- class ToolCall(TypedDict):
73
- name: str
74
- arguments: Dict[str, Any]
75
-
76
-
77
  register_type(Text)
78
  register_type(Number)
79
- register_type(Turn)
80
  register_type(Dialog)
81
  register_type(Table)
82
  register_type(Audio)
83
  register_type(Image)
84
  register_type(Video)
85
  register_type(Document)
86
  register_type(MultiDocument)
87
  register_type(RagResponse)
 
6
  Number = NewType("Number", Union[float, int])
7
 
8
 
9
+ class JsonSchema:
10
+ @classmethod
11
+ def __verify_type__(cls, object):
12
+ if not isinstance(object, dict):
13
+ return False
14
+ import jsonschema_rs
15
+
16
+ jsonschema_rs.meta.validate(object)
17
+ return True
18
+
19
+
20
+ class Tool(TypedDict):
21
+ # Original fields
22
+ name: str
23
+ description: str
24
+ parameters: JsonSchema
25
+ # LiteLLM extension
26
+ type: Optional[Literal["function"]]
27
+
28
+
29
+ class ToolCall(TypedDict):
30
+ name: str
31
+ arguments: Dict[str, Any]
32
+
33
+
34
+ class ToolCallContext(TypedDict):
35
+ id: str
36
+ type: Literal["function"]
37
+ function: ToolCall
38
+
39
+
40
+ class ToolCallTurn(TypedDict):
41
+ role: Literal["assistant"]
42
+ content: Optional[str]
43
+ tool_calls: List[ToolCallContext]
44
+
45
+
46
+ class ToolOutputTurn(TypedDict):
47
+ role: Literal["tool"]
48
+ tool_call_id: str
49
+ name: str
50
+ content: str
51
+
52
+
53
+ class TextTurn(TypedDict):
54
+ role: Literal["system", "user", "agent", "assistant"]
55
  content: Text
56
 
57
 
 
62
  is_answerable: bool
63
 
64
 
65
+ Dialog = NewType("Dialog", List[Union[TextTurn, ToolCallTurn, ToolOutputTurn]])
66
+
67
+
68
+ class Conversation(TypedDict):
69
+ id: str
70
+ dialog: Dialog
71
 
72
 
73
  class Image(TypedDict):
 
101
  data: Optional[Dict[str, Dict]]
102
 
103
 
 
104
  register_type(Text)
105
  register_type(Number)
106
+ register_type(TextTurn)
107
+ register_type(ToolCallTurn)
108
+ register_type(ToolOutputTurn)
109
  register_type(Dialog)
110
  register_type(Table)
111
  register_type(Audio)
112
  register_type(Image)
113
  register_type(Video)
114
+ register_type(Conversation)
115
  register_type(Document)
116
  register_type(MultiDocument)
117
  register_type(RagResponse)
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.24.0"
 
1
+ version = "1.25.0"