Upload folder using huggingface_hub
Browse files- api.py +49 -5
- artifact.py +28 -18
- collections_operators.py +60 -2
- dataclass.py +59 -0
- dataset.py +1 -1
- dialog_operators.py +10 -1
- dict_utils.py +1 -1
- error_utils.py +254 -12
- evaluate_cli.py +56 -3
- formats.py +53 -9
- fusion.py +14 -16
- inference.py +229 -126
- llm_as_judge_constants.py +1 -2
- loaders.py +107 -81
- metric.py +1 -1
- metric_utils.py +19 -12
- metrics.py +548 -654
- operator.py +23 -13
- operators.py +79 -58
- processors.py +11 -1
- schema.py +1 -1
- serializers.py +18 -2
- settings_utils.py +4 -0
- struct_data_operators.py +49 -0
- task.py +13 -8
- templates.py +13 -1
- sql_utils.py → text2sql_utils.py +488 -2
- type_utils.py +18 -2
- types.py +56 -26
- version.py +1 -1
api.py
CHANGED
@@ -11,6 +11,7 @@ from datasets.exceptions import DatasetGenerationError
|
|
11 |
from .artifact import fetch_artifact
|
12 |
from .benchmark import Benchmark
|
13 |
from .card import TaskCard
|
|
|
14 |
from .dataset_utils import get_dataset_artifact
|
15 |
from .error_utils import UnitxtError
|
16 |
from .inference import (
|
@@ -149,6 +150,36 @@ def create_dataset(
|
|
149 |
return load_dataset(card=card, split=split, **kwargs)
|
150 |
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
def _source_to_dataset(
|
153 |
source: SourceOperator,
|
154 |
split=None,
|
@@ -157,22 +188,35 @@ def _source_to_dataset(
|
|
157 |
):
|
158 |
from .dataset import Dataset as UnitxtDataset
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
stream = source()
|
161 |
|
162 |
try:
|
163 |
ds_builder = UnitxtDataset(
|
164 |
dataset_name="unitxt",
|
165 |
-
config_name=
|
166 |
version=constants.version,
|
167 |
)
|
168 |
if split is not None:
|
169 |
stream = {split: stream[split]}
|
170 |
ds_builder._generators = stream
|
171 |
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
if streaming:
|
178 |
return ds_builder.as_streaming_dataset(split=split)
|
|
|
11 |
from .artifact import fetch_artifact
|
12 |
from .benchmark import Benchmark
|
13 |
from .card import TaskCard
|
14 |
+
from .dataclass import to_dict
|
15 |
from .dataset_utils import get_dataset_artifact
|
16 |
from .error_utils import UnitxtError
|
17 |
from .inference import (
|
|
|
150 |
return load_dataset(card=card, split=split, **kwargs)
|
151 |
|
152 |
|
153 |
+
def object_to_str_without_addresses(obj):
|
154 |
+
"""Generates a string representation of a Python object while removing memory address references.
|
155 |
+
|
156 |
+
This function is useful for creating consistent and comparable string representations of objects
|
157 |
+
that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
|
158 |
+
between executions. By stripping the memory address, the function ensures that the representation
|
159 |
+
is stable and independent of the object's location in memory.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
obj: Any Python object to be converted to a string representation.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
str: A string representation of the object with memory addresses removed if present.
|
166 |
+
|
167 |
+
Example:
|
168 |
+
```python
|
169 |
+
class MyClass:
|
170 |
+
pass
|
171 |
+
|
172 |
+
obj = MyClass()
|
173 |
+
print(str(obj)) # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
|
174 |
+
print(object_to_str_without_addresses(obj)) # "<__main__.MyClass object>"
|
175 |
+
```
|
176 |
+
"""
|
177 |
+
obj_str = str(obj)
|
178 |
+
if " at 0x" in obj_str:
|
179 |
+
obj_str = obj_str.split(" at 0x")[0] + ">"
|
180 |
+
return obj_str
|
181 |
+
|
182 |
+
|
183 |
def _source_to_dataset(
|
184 |
source: SourceOperator,
|
185 |
split=None,
|
|
|
188 |
):
|
189 |
from .dataset import Dataset as UnitxtDataset
|
190 |
|
191 |
+
# Generate a unique signature for the source
|
192 |
+
source_signature = json.dumps(
|
193 |
+
to_dict(source, object_to_str_without_addresses), sort_keys=True
|
194 |
+
)
|
195 |
+
config_name = "recipe-" + short_hex_hash(source_signature)
|
196 |
+
# Obtain data stream from the source
|
197 |
stream = source()
|
198 |
|
199 |
try:
|
200 |
ds_builder = UnitxtDataset(
|
201 |
dataset_name="unitxt",
|
202 |
+
config_name=config_name, # Dictate the cache name
|
203 |
version=constants.version,
|
204 |
)
|
205 |
if split is not None:
|
206 |
stream = {split: stream[split]}
|
207 |
ds_builder._generators = stream
|
208 |
|
209 |
+
try:
|
210 |
+
ds_builder.download_and_prepare(
|
211 |
+
verification_mode="no_checks",
|
212 |
+
download_mode=None if use_cache else "force_redownload",
|
213 |
+
)
|
214 |
+
except DatasetGenerationError as e:
|
215 |
+
if e.__cause__:
|
216 |
+
raise e.__cause__ from None
|
217 |
+
if e.__context__:
|
218 |
+
raise e.__context__ from None
|
219 |
+
raise
|
220 |
|
221 |
if streaming:
|
222 |
return ds_builder.as_streaming_dataset(split=split)
|
artifact.py
CHANGED
@@ -16,13 +16,13 @@ from .dataclass import (
|
|
16 |
NonPositionalField,
|
17 |
fields,
|
18 |
)
|
19 |
-
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
20 |
from .logging_utils import get_logger
|
21 |
from .parsing_utils import (
|
22 |
separate_inside_and_outside_square_brackets,
|
23 |
)
|
24 |
from .settings_utils import get_constants, get_settings
|
25 |
-
from .text_utils import camel_to_snake_case, is_camel_case
|
26 |
from .type_utils import isoftype, issubtype
|
27 |
from .utils import (
|
28 |
artifacts_json_cache,
|
@@ -342,8 +342,10 @@ class Artifact(Dataclass):
|
|
342 |
self.verify_data_classification_policy()
|
343 |
self.prepare_args()
|
344 |
if not settings.skip_artifacts_prepare_and_verify:
|
345 |
-
self
|
346 |
-
|
|
|
|
|
347 |
|
348 |
def _to_raw_dict(self):
|
349 |
return {
|
@@ -367,11 +369,14 @@ class Artifact(Dataclass):
|
|
367 |
|
368 |
def to_json(self):
|
369 |
data = self.to_dict()
|
|
|
370 |
return json_dump(data)
|
371 |
|
372 |
def to_yaml(self):
|
|
|
|
|
373 |
data = self.to_dict()
|
374 |
-
return
|
375 |
|
376 |
def serialize(self):
|
377 |
if self.__id__ is not None:
|
@@ -449,20 +454,25 @@ class Artifact(Dataclass):
|
|
449 |
)
|
450 |
return instance
|
451 |
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
455 |
):
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
|
|
|
|
466 |
|
467 |
return instance
|
468 |
|
|
|
16 |
NonPositionalField,
|
17 |
fields,
|
18 |
)
|
19 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
20 |
from .logging_utils import get_logger
|
21 |
from .parsing_utils import (
|
22 |
separate_inside_and_outside_square_brackets,
|
23 |
)
|
24 |
from .settings_utils import get_constants, get_settings
|
25 |
+
from .text_utils import camel_to_snake_case, is_camel_case
|
26 |
from .type_utils import isoftype, issubtype
|
27 |
from .utils import (
|
28 |
artifacts_json_cache,
|
|
|
342 |
self.verify_data_classification_policy()
|
343 |
self.prepare_args()
|
344 |
if not settings.skip_artifacts_prepare_and_verify:
|
345 |
+
with error_context(self, action="Prepare Object"):
|
346 |
+
self.prepare()
|
347 |
+
with error_context(self, action="Verify Object"):
|
348 |
+
self.verify()
|
349 |
|
350 |
def _to_raw_dict(self):
|
351 |
return {
|
|
|
369 |
|
370 |
def to_json(self):
|
371 |
data = self.to_dict()
|
372 |
+
|
373 |
return json_dump(data)
|
374 |
|
375 |
def to_yaml(self):
|
376 |
+
import yaml
|
377 |
+
|
378 |
data = self.to_dict()
|
379 |
+
return yaml.dump(data)
|
380 |
|
381 |
def serialize(self):
|
382 |
if self.__id__ is not None:
|
|
|
454 |
)
|
455 |
return instance
|
456 |
|
457 |
+
with error_context(
|
458 |
+
self,
|
459 |
+
action="Sensitive Data Verification",
|
460 |
+
help="https://www.unitxt.ai/en/latest/docs/data_classification_policy.html",
|
461 |
):
|
462 |
+
if not any(
|
463 |
+
data_classification in data_classification_policy
|
464 |
+
for data_classification in instance_data_classification
|
465 |
+
):
|
466 |
+
raise UnitxtError(
|
467 |
+
f"The instance '{instance} 'has the following data classification policy "
|
468 |
+
f"'{instance_data_classification}', however, the artifact '{name}' "
|
469 |
+
f"is only configured to support the data with classification "
|
470 |
+
f"'{data_classification_policy}'. To enable this either change "
|
471 |
+
f"the 'data_classification_policy' attribute of the artifact, "
|
472 |
+
f"or modify the environment variable "
|
473 |
+
f"'UNITXT_DATA_CLASSIFICATION_POLICY' accordingly.",
|
474 |
+
Documentation.DATA_CLASSIFICATION_POLICY,
|
475 |
+
)
|
476 |
|
477 |
return instance
|
478 |
|
collections_operators.py
CHANGED
@@ -1,6 +1,8 @@
|
|
|
|
1 |
from typing import Any, Dict, Generator, List, Optional
|
2 |
|
3 |
from .dict_utils import dict_get, dict_set
|
|
|
4 |
from .operators import FieldOperator, StreamOperator
|
5 |
from .stream import Stream
|
6 |
from .utils import recursive_shallow_copy
|
@@ -13,11 +15,52 @@ class Dictify(FieldOperator):
|
|
13 |
return dict(zip(self.with_keys, tup))
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
class DictToTuplesList(FieldOperator):
|
17 |
def process_value(self, dic: Dict) -> Any:
|
18 |
return list(dic.items())
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
class Wrap(FieldOperator):
|
22 |
inside: str
|
23 |
|
@@ -64,6 +107,13 @@ class Get(FieldOperator):
|
|
64 |
return collection[self.item]
|
65 |
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
class DuplicateByList(StreamOperator):
|
68 |
field: str
|
69 |
to_field: Optional[str] = None
|
@@ -91,12 +141,16 @@ class DuplicateBySubLists(StreamOperator):
|
|
91 |
field: str
|
92 |
to_field: Optional[str] = None
|
93 |
use_deep_copy: bool = False
|
|
|
|
|
|
|
94 |
|
95 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
96 |
to_field = self.field if self.to_field is None else self.to_field
|
97 |
for instance in stream:
|
98 |
-
elements = instance
|
99 |
-
|
|
|
100 |
if self.use_deep_copy:
|
101 |
instance_copy = recursive_shallow_copy(instance)
|
102 |
instance_copy[to_field] = elements[:i]
|
@@ -109,6 +163,10 @@ class DuplicateBySubLists(StreamOperator):
|
|
109 |
yield instance_copy
|
110 |
|
111 |
|
|
|
|
|
|
|
|
|
112 |
class GetLength(FieldOperator):
|
113 |
def process_value(self, collection: Any) -> Any:
|
114 |
return len(collection)
|
|
|
1 |
+
from itertools import zip_longest
|
2 |
from typing import Any, Dict, Generator, List, Optional
|
3 |
|
4 |
from .dict_utils import dict_get, dict_set
|
5 |
+
from .operator import InstanceOperator
|
6 |
from .operators import FieldOperator, StreamOperator
|
7 |
from .stream import Stream
|
8 |
from .utils import recursive_shallow_copy
|
|
|
15 |
return dict(zip(self.with_keys, tup))
|
16 |
|
17 |
|
18 |
+
class Zip(InstanceOperator):
|
19 |
+
fields: List[str]
|
20 |
+
to_field: str
|
21 |
+
|
22 |
+
def zip(self, values):
|
23 |
+
return list(zip(*values))
|
24 |
+
|
25 |
+
def process(
|
26 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
27 |
+
) -> Dict[str, Any]:
|
28 |
+
values = []
|
29 |
+
for field in self.fields:
|
30 |
+
values.append(dict_get(instance, field))
|
31 |
+
dict_set(instance, self.to_field, self.zip(values))
|
32 |
+
return instance
|
33 |
+
|
34 |
+
|
35 |
+
class ZipLongest(Zip):
|
36 |
+
fields: List[str]
|
37 |
+
fill_value: Any = None
|
38 |
+
|
39 |
+
def zip(self, values):
|
40 |
+
return list(zip_longest(*values, fillvalue=self.fill_value))
|
41 |
+
|
42 |
+
|
43 |
class DictToTuplesList(FieldOperator):
|
44 |
def process_value(self, dic: Dict) -> Any:
|
45 |
return list(dic.items())
|
46 |
|
47 |
|
48 |
+
def flatten(container):
|
49 |
+
def _flat_gen(x):
|
50 |
+
for item in x:
|
51 |
+
if isinstance(item, (list, tuple)):
|
52 |
+
yield from _flat_gen(item)
|
53 |
+
else:
|
54 |
+
yield item
|
55 |
+
|
56 |
+
return type(container)(_flat_gen(container))
|
57 |
+
|
58 |
+
|
59 |
+
class Flatten(FieldOperator):
|
60 |
+
def process_value(self, value: Any) -> Any:
|
61 |
+
return flatten(value)
|
62 |
+
|
63 |
+
|
64 |
class Wrap(FieldOperator):
|
65 |
inside: str
|
66 |
|
|
|
107 |
return collection[self.item]
|
108 |
|
109 |
|
110 |
+
class Pop(FieldOperator):
|
111 |
+
item: Any = None
|
112 |
+
|
113 |
+
def process_value(self, collection: Any) -> Any:
|
114 |
+
return collection.pop(self.item)
|
115 |
+
|
116 |
+
|
117 |
class DuplicateByList(StreamOperator):
|
118 |
field: str
|
119 |
to_field: Optional[str] = None
|
|
|
141 |
field: str
|
142 |
to_field: Optional[str] = None
|
143 |
use_deep_copy: bool = False
|
144 |
+
start: int = 1
|
145 |
+
end: int = 0
|
146 |
+
step: int = 1
|
147 |
|
148 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
149 |
to_field = self.field if self.to_field is None else self.to_field
|
150 |
for instance in stream:
|
151 |
+
elements = dict_get(instance, self.field)
|
152 |
+
end = len(elements) + 1 + self.end
|
153 |
+
for i in range(self.start, end, self.step):
|
154 |
if self.use_deep_copy:
|
155 |
instance_copy = recursive_shallow_copy(instance)
|
156 |
instance_copy[to_field] = elements[:i]
|
|
|
163 |
yield instance_copy
|
164 |
|
165 |
|
166 |
+
class ExplodeSubLists(DuplicateBySubLists):
|
167 |
+
pass
|
168 |
+
|
169 |
+
|
170 |
class GetLength(FieldOperator):
|
171 |
def process_value(self, collection: Any) -> Any:
|
172 |
return len(collection)
|
dataclass.py
CHANGED
@@ -297,6 +297,65 @@ def _asdict_inner(obj):
|
|
297 |
return copy.deepcopy(obj)
|
298 |
|
299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
class DataclassMeta(ABCMeta):
|
301 |
"""Metaclass for Dataclass.
|
302 |
|
|
|
297 |
return copy.deepcopy(obj)
|
298 |
|
299 |
|
300 |
+
def to_dict(obj, func=copy.deepcopy, _visited=None):
|
301 |
+
"""Recursively converts an object into a dictionary representation while avoiding infinite recursion due to circular references.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
obj: Any Python object to be converted into a dictionary-like structure.
|
305 |
+
func (Callable, optional): A function applied to non-iterable objects. Defaults to `copy.deepcopy`.
|
306 |
+
_visited (set, optional): A set of object IDs used to track visited objects and prevent infinite recursion.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
dict: A dictionary representation of the input object, with supported collections and dataclasses
|
310 |
+
recursively processed.
|
311 |
+
|
312 |
+
Notes:
|
313 |
+
- Supports dataclasses, named tuples, lists, tuples, and dictionaries.
|
314 |
+
- Circular references are detected using object IDs and replaced by `func(obj)`.
|
315 |
+
- Named tuples retain their original type instead of being converted to dictionaries.
|
316 |
+
"""
|
317 |
+
# Initialize visited set on first call
|
318 |
+
if _visited is None:
|
319 |
+
_visited = set()
|
320 |
+
|
321 |
+
# Get object ID to track visited objects
|
322 |
+
obj_id = id(obj)
|
323 |
+
|
324 |
+
# If we've seen this object before, return a placeholder to avoid infinite recursion
|
325 |
+
if obj_id in _visited:
|
326 |
+
return func(obj)
|
327 |
+
|
328 |
+
# For container objects (dicts, lists, dataclasses, named tuples), add to visited set before recursing
|
329 |
+
if (
|
330 |
+
isinstance(obj, (dict, list))
|
331 |
+
or is_dataclass(obj)
|
332 |
+
or (isinstance(obj, tuple) and hasattr(obj, "_fields"))
|
333 |
+
):
|
334 |
+
_visited.add(obj_id)
|
335 |
+
|
336 |
+
if is_dataclass(obj):
|
337 |
+
return {
|
338 |
+
field.name: to_dict(getattr(obj, field.name), func, _visited)
|
339 |
+
for field in fields(obj)
|
340 |
+
}
|
341 |
+
|
342 |
+
if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
|
343 |
+
return type(obj)(*[to_dict(v, func, _visited) for v in obj])
|
344 |
+
|
345 |
+
if isinstance(obj, (list, tuple)):
|
346 |
+
return type(obj)([to_dict(v, func, _visited) for v in obj])
|
347 |
+
|
348 |
+
if isinstance(obj, dict):
|
349 |
+
return type(obj)(
|
350 |
+
{
|
351 |
+
to_dict(k, func, _visited): to_dict(v, func, _visited)
|
352 |
+
for k, v in obj.items()
|
353 |
+
}
|
354 |
+
)
|
355 |
+
|
356 |
+
return func(obj)
|
357 |
+
|
358 |
+
|
359 |
class DataclassMeta(ABCMeta):
|
360 |
"""Metaclass for Dataclass.
|
361 |
|
dataset.py
CHANGED
@@ -59,7 +59,6 @@ from .settings_utils import get_constants
|
|
59 |
from .span_lableing_operators import __file__ as _
|
60 |
from .split_utils import __file__ as _
|
61 |
from .splitters import __file__ as _
|
62 |
-
from .sql_utils import __file__ as _
|
63 |
from .standard import __file__ as _
|
64 |
from .stream import __file__ as _
|
65 |
from .stream_operators import __file__ as _
|
@@ -68,6 +67,7 @@ from .struct_data_operators import __file__ as _
|
|
68 |
from .system_prompts import __file__ as _
|
69 |
from .task import __file__ as _
|
70 |
from .templates import __file__ as _
|
|
|
71 |
from .text_utils import __file__ as _
|
72 |
from .type_utils import __file__ as _
|
73 |
from .types import __file__ as _
|
|
|
59 |
from .span_lableing_operators import __file__ as _
|
60 |
from .split_utils import __file__ as _
|
61 |
from .splitters import __file__ as _
|
|
|
62 |
from .standard import __file__ as _
|
63 |
from .stream import __file__ as _
|
64 |
from .stream_operators import __file__ as _
|
|
|
67 |
from .system_prompts import __file__ as _
|
68 |
from .task import __file__ as _
|
69 |
from .templates import __file__ as _
|
70 |
+
from .text2sql_utils import __file__ as _
|
71 |
from .text_utils import __file__ as _
|
72 |
from .type_utils import __file__ as _
|
73 |
from .types import __file__ as _
|
dialog_operators.py
CHANGED
@@ -17,7 +17,16 @@ The format of the dialog is:
|
|
17 |
from typing import Any, Dict, List, Optional
|
18 |
|
19 |
from .formats import SystemFormat
|
20 |
-
from .operators import InstanceFieldOperator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
class SerializeDialog(InstanceFieldOperator):
|
|
|
17 |
from typing import Any, Dict, List, Optional
|
18 |
|
19 |
from .formats import SystemFormat
|
20 |
+
from .operators import FieldOperator, InstanceFieldOperator
|
21 |
+
|
22 |
+
|
23 |
+
class ToDialog(FieldOperator):
|
24 |
+
def process_value(self, value: Any) -> Any:
|
25 |
+
dialog = []
|
26 |
+
for question, answer in value:
|
27 |
+
dialog.append({"role": "user", "content": question})
|
28 |
+
dialog.append({"role": "agent", "content": answer})
|
29 |
+
return dialog
|
30 |
|
31 |
|
32 |
class SerializeDialog(InstanceFieldOperator):
|
dict_utils.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any, List, Tuple
|
|
3 |
|
4 |
from .text_utils import to_pretty_string
|
5 |
|
6 |
-
indx = re.compile(r"
|
7 |
|
8 |
|
9 |
def is_index(string):
|
|
|
3 |
|
4 |
from .text_utils import to_pretty_string
|
5 |
|
6 |
+
indx = re.compile(r"^-?\d+$")
|
7 |
|
8 |
|
9 |
def is_index(string):
|
error_utils.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1 |
-
|
|
|
|
|
2 |
|
3 |
from .logging_utils import get_logger
|
|
|
4 |
|
|
|
5 |
logger = get_logger()
|
6 |
|
7 |
|
@@ -29,12 +33,9 @@ class UnitxtError(Exception):
|
|
29 |
"""Exception raised for Unitxt errors.
|
30 |
|
31 |
Args:
|
32 |
-
message (str):
|
33 |
-
|
34 |
-
|
35 |
-
relative path to additional documentation on web
|
36 |
-
If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
|
37 |
-
|
38 |
"""
|
39 |
|
40 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
@@ -47,14 +48,255 @@ class UnitxtWarning:
|
|
47 |
"""Object to format warning message to log.
|
48 |
|
49 |
Args:
|
50 |
-
message (str):
|
51 |
-
|
52 |
-
|
53 |
-
relative path to additional documentation on web
|
54 |
-
If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
|
55 |
"""
|
56 |
|
57 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
58 |
if additional_info_id is not None:
|
59 |
message += additional_info(additional_info_id)
|
60 |
logger.warning(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from typing import Any, Optional
|
4 |
|
5 |
from .logging_utils import get_logger
|
6 |
+
from .settings_utils import get_constants
|
7 |
|
8 |
+
constants = get_constants()
|
9 |
logger = get_logger()
|
10 |
|
11 |
|
|
|
33 |
"""Exception raised for Unitxt errors.
|
34 |
|
35 |
Args:
|
36 |
+
message (str): explanation of the error
|
37 |
+
additional_info_id (Optional[str]): relative path to additional documentation on web
|
38 |
+
If set, should be one of the DOCUMENTATION_* constants in the error_utils.py file.
|
|
|
|
|
|
|
39 |
"""
|
40 |
|
41 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
|
|
48 |
"""Object to format warning message to log.
|
49 |
|
50 |
Args:
|
51 |
+
message (str): explanation of the warning
|
52 |
+
additional_info_id (Optional[str]): relative path to additional documentation on web
|
53 |
+
If set, should be one of the DOCUMENTATION_* constants in the error_utils.py file.
|
|
|
|
|
54 |
"""
|
55 |
|
56 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
57 |
if additional_info_id is not None:
|
58 |
message += additional_info(additional_info_id)
|
59 |
logger.warning(message)
|
60 |
+
|
61 |
+
|
62 |
+
context_block_title = "🦄 Unitxt Error Context"
|
63 |
+
|
64 |
+
|
65 |
+
def _visible_length(text: str) -> int:
|
66 |
+
import unicodedata
|
67 |
+
|
68 |
+
ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\]8;;[^\x1b]*\x1b\\")
|
69 |
+
clean_text = ansi_escape.sub("", text)
|
70 |
+
width = 0
|
71 |
+
for char in clean_text:
|
72 |
+
if (
|
73 |
+
unicodedata.east_asian_width(char) in ("F", "W")
|
74 |
+
or 0x1F300 <= ord(char) <= 0x1F9FF
|
75 |
+
):
|
76 |
+
width += 2
|
77 |
+
else:
|
78 |
+
width += 1
|
79 |
+
return width
|
80 |
+
|
81 |
+
|
82 |
+
def _make_object_clickable(
|
83 |
+
full_obj_name: str, display_name: Optional[str] = None
|
84 |
+
) -> str:
|
85 |
+
import os
|
86 |
+
|
87 |
+
if display_name is None:
|
88 |
+
display_name = full_obj_name.split(".")[-1]
|
89 |
+
if full_obj_name.startswith("unitxt."):
|
90 |
+
parts = full_obj_name.split(".")
|
91 |
+
if len(parts) >= 2:
|
92 |
+
module_path = ".".join(parts[:2])
|
93 |
+
doc_url = f"{Documentation.URL}{module_path}.html#{full_obj_name}"
|
94 |
+
if (
|
95 |
+
os.environ.get("TERM_PROGRAM") in ["iTerm.app", "vscode"]
|
96 |
+
or os.environ.get("TERMINAL_EMULATOR") == "JetBrains-JediTerm"
|
97 |
+
):
|
98 |
+
return f"\033]8;;{doc_url}\033\\{display_name}\033]8;;\033\\"
|
99 |
+
return f"{display_name} ({doc_url})"
|
100 |
+
return display_name
|
101 |
+
|
102 |
+
|
103 |
+
def _get_existing_context(error: Exception):
|
104 |
+
"""Extract existing context from an error if it exists."""
|
105 |
+
if hasattr(error, "__error_context__"):
|
106 |
+
existing = error.__error_context__
|
107 |
+
return (
|
108 |
+
existing["original_message"],
|
109 |
+
existing["context_object"],
|
110 |
+
existing["context"],
|
111 |
+
)
|
112 |
+
return str(error), None, {}
|
113 |
+
|
114 |
+
|
115 |
+
def _format_object_context(obj: Any) -> Optional[str]:
|
116 |
+
"""Format an object for display in error context."""
|
117 |
+
if obj is None:
|
118 |
+
return None
|
119 |
+
if hasattr(obj, "__class__"):
|
120 |
+
class_name = obj.__class__.__name__
|
121 |
+
module_name = getattr(obj.__class__, "__module__", "")
|
122 |
+
else:
|
123 |
+
obj_type = type(obj)
|
124 |
+
class_name = obj_type.__name__
|
125 |
+
module_name = getattr(obj_type, "__module__", "")
|
126 |
+
if module_name:
|
127 |
+
full_name = f"{module_name}.{class_name}"
|
128 |
+
clickable_object = _make_object_clickable(full_name, class_name)
|
129 |
+
return f"Object: {clickable_object}"
|
130 |
+
return f"Object: {class_name}"
|
131 |
+
|
132 |
+
|
133 |
+
def _make_clickable_link(url: str) -> str:
|
134 |
+
"""Create a clickable terminal link."""
|
135 |
+
import os
|
136 |
+
|
137 |
+
if (
|
138 |
+
os.environ.get("TERM_PROGRAM") in ["iTerm.app", "vscode"]
|
139 |
+
or os.environ.get("TERMINAL_EMULATOR") == "JetBrains-JediTerm"
|
140 |
+
):
|
141 |
+
return f"\033]8;;{url}\033\\link\033]8;;\033\\"
|
142 |
+
return url
|
143 |
+
|
144 |
+
|
145 |
+
def _format_help_context(help_docs) -> list:
|
146 |
+
"""Format help documentation into context parts."""
|
147 |
+
parts = []
|
148 |
+
if isinstance(help_docs, str):
|
149 |
+
parts.append(f"Help: {_make_clickable_link(help_docs)}")
|
150 |
+
elif isinstance(help_docs, dict):
|
151 |
+
for label, url in help_docs.items():
|
152 |
+
parts.append(f"Help ({label}): {_make_clickable_link(url)}")
|
153 |
+
elif isinstance(help_docs, list):
|
154 |
+
for item in help_docs:
|
155 |
+
if isinstance(item, dict) and len(item) == 1:
|
156 |
+
label, url = next(iter(item.items()))
|
157 |
+
parts.append(f"Help ({label}): {_make_clickable_link(url)}")
|
158 |
+
elif isinstance(item, str):
|
159 |
+
parts.append(f"Help: {_make_clickable_link(item)}")
|
160 |
+
return parts
|
161 |
+
|
162 |
+
|
163 |
+
def _build_context_parts(context_object: Any, context: dict) -> list:
|
164 |
+
"""Build the list of context information parts."""
|
165 |
+
parts = []
|
166 |
+
ordered_keys = [
|
167 |
+
"Python",
|
168 |
+
"Unitxt",
|
169 |
+
"Stage",
|
170 |
+
"Stream",
|
171 |
+
"Index",
|
172 |
+
"Instance",
|
173 |
+
"Object",
|
174 |
+
"Action",
|
175 |
+
]
|
176 |
+
processed_keys = set()
|
177 |
+
|
178 |
+
for desired_key in ordered_keys:
|
179 |
+
for actual_key in context.keys():
|
180 |
+
if actual_key.lower() == desired_key.lower():
|
181 |
+
value = (
|
182 |
+
"unknown" if context[actual_key] is None else context[actual_key]
|
183 |
+
)
|
184 |
+
parts.append(f"{actual_key.replace('_', ' ').title()}: {value}")
|
185 |
+
processed_keys.add(actual_key)
|
186 |
+
break
|
187 |
+
|
188 |
+
if not any(key.lower() == "object" for key in processed_keys):
|
189 |
+
obj_context = _format_object_context(context_object)
|
190 |
+
if obj_context:
|
191 |
+
parts.append(obj_context)
|
192 |
+
|
193 |
+
processed_keys.add("help")
|
194 |
+
for key, value in context.items():
|
195 |
+
if key not in processed_keys:
|
196 |
+
value = "unknown" if value is None else value
|
197 |
+
parts.append(f"{key.replace('_', ' ').title()}: {value}")
|
198 |
+
|
199 |
+
if "help" in context:
|
200 |
+
parts.extend(_format_help_context(context["help"]))
|
201 |
+
else:
|
202 |
+
parts.append(f"Help: {_make_clickable_link(Documentation.URL)}")
|
203 |
+
|
204 |
+
return parts
|
205 |
+
|
206 |
+
|
207 |
+
def _create_context_box(parts: list) -> str:
|
208 |
+
"""Create a formatted box containing context information."""
|
209 |
+
if not parts:
|
210 |
+
return ""
|
211 |
+
max_width = (
|
212 |
+
max(
|
213 |
+
_visible_length(context_block_title),
|
214 |
+
max(_visible_length(part) for part in parts),
|
215 |
+
)
|
216 |
+
+ 4
|
217 |
+
)
|
218 |
+
top_line = "┌" + "─" * max_width + "┐"
|
219 |
+
bottom_line = "└" + "─" * max_width + "┘"
|
220 |
+
lines = [top_line]
|
221 |
+
lines.append(
|
222 |
+
f"│ {context_block_title}{' ' * (max_width - _visible_length(context_block_title) - 1)}│"
|
223 |
+
)
|
224 |
+
lines.append(f"│ {'-' * (max_width - 2)} │")
|
225 |
+
for part in parts:
|
226 |
+
padding = " " * (max_width - _visible_length(part) - 4)
|
227 |
+
lines.append(f"│ - {part}{padding}│")
|
228 |
+
lines.append(bottom_line)
|
229 |
+
return "\n".join(lines)
|
230 |
+
|
231 |
+
|
232 |
+
def _store_context_attributes(
|
233 |
+
error: Exception, context_object: Any, context: dict, original_message: str
|
234 |
+
):
|
235 |
+
"""Store context information in error attributes."""
|
236 |
+
error.__error_context__ = {
|
237 |
+
"context_object": context_object,
|
238 |
+
"context": context,
|
239 |
+
"original_message": original_message,
|
240 |
+
}
|
241 |
+
try:
|
242 |
+
error.original_error = type(error)(original_message)
|
243 |
+
except (TypeError, ValueError):
|
244 |
+
error.original_error = Exception(original_message)
|
245 |
+
error.context_object = context_object
|
246 |
+
error.context = context
|
247 |
+
|
248 |
+
|
249 |
+
def _add_context_to_exception(
|
250 |
+
original_error: Exception, context_object: Any = None, **context
|
251 |
+
):
|
252 |
+
"""Add context information to an exception by modifying its message."""
|
253 |
+
original_message, existing_object, existing_context = _get_existing_context(
|
254 |
+
original_error
|
255 |
+
)
|
256 |
+
final_context_object = existing_object or context_object
|
257 |
+
final_context = {
|
258 |
+
"Unitxt": constants.version,
|
259 |
+
"Python": constants.python,
|
260 |
+
**existing_context,
|
261 |
+
**context,
|
262 |
+
}
|
263 |
+
context_parts = _build_context_parts(final_context_object, final_context)
|
264 |
+
context_message = _create_context_box(context_parts)
|
265 |
+
_store_context_attributes(
|
266 |
+
original_error, final_context_object, final_context, original_message
|
267 |
+
)
|
268 |
+
if context_parts:
|
269 |
+
formatted_message = f"\n{context_message}\n\n{original_message}"
|
270 |
+
original_error.args = (formatted_message,)
|
271 |
+
else:
|
272 |
+
original_error.args = (original_message,)
|
273 |
+
|
274 |
+
|
275 |
+
@contextmanager
|
276 |
+
def error_context(context_object: Any = None, **context):
|
277 |
+
"""Context manager that catches exceptions and re-raises them with additional context.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
context_object: The object being processed (optional)
|
281 |
+
**context: Any additional context to include in the error message.
|
282 |
+
You can provide any key-value pairs that help identify where the error occurred.
|
283 |
+
|
284 |
+
Special context keys:
|
285 |
+
- help: Documentation links to help with the error.
|
286 |
+
Can be a string (single URL), dict (label: URL), or list of URLs/dicts.
|
287 |
+
|
288 |
+
Examples:
|
289 |
+
with error_context(self, operation="validation", item_id=42):
|
290 |
+
result = process_item(item)
|
291 |
+
|
292 |
+
with error_context(operation="schema_validation", help="https://docs.example.com/schema"):
|
293 |
+
validate_schema(data)
|
294 |
+
|
295 |
+
with error_context(processor, step="preprocessing", batch_size=32):
|
296 |
+
results = process_batch(batch)
|
297 |
+
"""
|
298 |
+
try:
|
299 |
+
yield
|
300 |
+
except Exception as e:
|
301 |
+
_add_context_to_exception(e, context_object, **context)
|
302 |
+
raise
|
evaluate_cli.py
CHANGED
@@ -298,9 +298,13 @@ def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
|
|
298 |
dataset_query=task_str, **overwrite_args
|
299 |
)
|
300 |
|
301 |
-
|
|
|
|
|
|
|
|
|
302 |
|
303 |
-
test_dataset = _source_to_dataset(
|
304 |
logger.info(
|
305 |
f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
|
306 |
)
|
@@ -414,6 +418,8 @@ def initialize_inference_engine(
|
|
414 |
chat_kwargs_dict=chat_kwargs_dict,
|
415 |
)
|
416 |
|
|
|
|
|
417 |
# --- Remote Model (CrossProviderInferenceEngine) ---
|
418 |
elif args.model.lower() == "cross_provider":
|
419 |
if "model_name" not in model_args_dict:
|
@@ -444,6 +450,9 @@ def initialize_inference_engine(
|
|
444 |
model=remote_model_name,
|
445 |
**model_args_dict,
|
446 |
)
|
|
|
|
|
|
|
447 |
else:
|
448 |
# This case should not be reached due to argparse choices
|
449 |
logger.error(
|
@@ -682,7 +691,7 @@ def _save_results_to_disk(
|
|
682 |
|
683 |
# prepend to the results_path name the time in a way like this: 2025-04-04T11:37:32
|
684 |
|
685 |
-
timestamp = datetime.now().strftime("%Y-%m-%dT%H
|
686 |
|
687 |
results_path = prepend_timestamp_to_path(results_path, timestamp)
|
688 |
samples_path = prepend_timestamp_to_path(samples_path, timestamp)
|
@@ -825,5 +834,49 @@ def main():
|
|
825 |
logger.info("Unitxt Evaluation CLI finished successfully.")
|
826 |
|
827 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
828 |
if __name__ == "__main__":
|
829 |
main()
|
|
|
298 |
dataset_query=task_str, **overwrite_args
|
299 |
)
|
300 |
|
301 |
+
# This hack circumvents an issue with multi-level benchmarks (such as Bluebench's translation subset) that fail when wrapped with an additional Benchmark() object.
|
302 |
+
if len(benchmark_subsets) == 1:
|
303 |
+
source = next(iter(benchmark_subsets.values()))
|
304 |
+
else:
|
305 |
+
source = Benchmark(subsets=benchmark_subsets)
|
306 |
|
307 |
+
test_dataset = _source_to_dataset(source, split=args.split)
|
308 |
logger.info(
|
309 |
f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
|
310 |
)
|
|
|
418 |
chat_kwargs_dict=chat_kwargs_dict,
|
419 |
)
|
420 |
|
421 |
+
# Keep the actual model name for the results
|
422 |
+
args.model = inference_model.model_name
|
423 |
# --- Remote Model (CrossProviderInferenceEngine) ---
|
424 |
elif args.model.lower() == "cross_provider":
|
425 |
if "model_name" not in model_args_dict:
|
|
|
450 |
model=remote_model_name,
|
451 |
**model_args_dict,
|
452 |
)
|
453 |
+
|
454 |
+
# Keep the actual model name for the results
|
455 |
+
args.model = inference_model.engine.model
|
456 |
else:
|
457 |
# This case should not be reached due to argparse choices
|
458 |
logger.error(
|
|
|
691 |
|
692 |
# prepend the time to the results_path name, in a format like this: 2025-04-04T11-37-32
|
693 |
|
694 |
+
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
695 |
|
696 |
results_path = prepend_timestamp_to_path(results_path, timestamp)
|
697 |
samples_path = prepend_timestamp_to_path(samples_path, timestamp)
|
|
|
834 |
logger.info("Unitxt Evaluation CLI finished successfully.")
|
835 |
|
836 |
|
837 |
+
def extract_scores(directory):  # pragma: no cover
    """Collect the scores stored in a directory of evaluation result files.

    Every file in ``directory`` whose name ends with
    ``evaluation_results.json`` is parsed. Each file contributes one row
    holding the model name, the run timestamp, the overall score, and the
    ``score`` of every dict-valued entry under ``results`` (one column per
    entry). Files that cannot be read or parsed are logged and skipped.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A pandas DataFrame with one row per parsed file, sorted by the
        ``Timestamp`` column in ascending order. When no file could be
        parsed, an empty DataFrame with the columns ``Model``,
        ``Timestamp`` and ``Average`` is returned.
    """
    import pandas as pd

    rows = []

    for filename in sorted(os.listdir(directory)):
        if not filename.endswith("evaluation_results.json"):
            continue

        file_path = os.path.join(directory, filename)
        try:
            with open(file_path, encoding="utf-8") as f:
                content = json.load(f)

            env_info = content.get("environment_info", {})
            results = content.get("results", {})

            row = {
                "Model": env_info.get("parsed_arguments", {}).get("model", "N/A"),
                "Timestamp": env_info.get("timestamp_utc", "N/A"),
                "Average": results.get("score", "N/A"),
            }

            # Only dict-valued entries carry their own "score"; scalar
            # entries (such as the global score itself) are not columns.
            for key, value in results.items():
                if isinstance(value, dict):
                    row[key] = value.get("score", "N/A")

            rows.append(row)
        except Exception as e:
            # Best effort: one malformed file must not abort the summary.
            logger.error(f"Error parsing results file {filename}: {e}.")

    if not rows:
        # Sorting a column-less frame by "Timestamp" raises KeyError, so
        # return an empty frame that still exposes the fixed columns.
        return pd.DataFrame(columns=["Model", "Timestamp", "Average"])

    return pd.DataFrame(rows).sort_values(by="Timestamp", ascending=True)
|
869 |
+
|
870 |
+
|
871 |
+
def summarize_cli():
    """Log a markdown table of the scores found in a results directory.

    Expects exactly one command-line argument, the results directory.
    With any other argument count, a usage message is logged and the
    process exits with status 1.
    """
    cli_args = sys.argv
    if len(cli_args) != 2:
        logger.error("Usage: python summarize_cli_results.py <results-directory>")
        sys.exit(1)

    results_dir = cli_args[1]
    scores_table = extract_scores(results_dir)

    logger.info(scores_table.to_markdown(index=False))
|
879 |
+
|
880 |
+
|
881 |
if __name__ == "__main__":
|
882 |
main()
|
formats.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import re
|
2 |
from abc import abstractmethod
|
3 |
from typing import (
|
@@ -18,6 +19,7 @@ from .image_operators import image_to_data_url
|
|
18 |
from .operator import InstanceOperator
|
19 |
from .settings_utils import get_constants
|
20 |
from .type_utils import isoftype
|
|
|
21 |
from .utils import retry_connection_with_exponential_backoff
|
22 |
|
23 |
constants = get_constants()
|
@@ -135,6 +137,9 @@ class BaseFormat(Format):
|
|
135 |
def _prepare_instance_fields(self, instance) -> Tuple[str]:
|
136 |
instance_fields = {}
|
137 |
|
|
|
|
|
|
|
138 |
for field in (
|
139 |
"source",
|
140 |
constants.instruction_field,
|
@@ -170,6 +175,7 @@ class BaseFormat(Format):
|
|
170 |
target_prefix: str,
|
171 |
demos: List[Dict[str, Any]],
|
172 |
media: Optional[Dict[str, Any]] = None,
|
|
|
173 |
) -> str:
|
174 |
"""Abstract method for formatting instances in different subclasses.
|
175 |
|
@@ -256,7 +262,10 @@ class SystemFormat(BaseFormat):
|
|
256 |
target_prefix: str,
|
257 |
demos: List[Dict[str, Any]],
|
258 |
media: Optional[Dict[str, Any]] = None,
|
|
|
259 |
) -> str:
|
|
|
|
|
260 |
demos_string = ""
|
261 |
for demo in demos:
|
262 |
demo_str = self.demo_format.format(
|
@@ -356,8 +365,18 @@ class ChatAPIFormat(BaseFormat):
|
|
356 |
)
|
357 |
|
358 |
The resulting `messages` is now a dictionary ready for sending to the OpenAI API.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
"""
|
360 |
|
|
|
|
|
361 |
def to_content(self, text: str, media: Dict[str, Any]) -> Union[str, List[Content]]:
|
362 |
# Regular expression to find <img> tags with src attribute
|
363 |
img_tag_pattern = re.compile(
|
@@ -419,12 +438,15 @@ class ChatAPIFormat(BaseFormat):
|
|
419 |
target_prefix: str,
|
420 |
demos: List[Dict[str, Any]],
|
421 |
media: Optional[Dict[str, Any]] = None,
|
|
|
422 |
) -> List[Message]:
|
423 |
messages = []
|
424 |
|
425 |
-
if system_prompt or instruction:
|
426 |
system_content = self.to_content(
|
427 |
-
system_prompt
|
|
|
|
|
428 |
media,
|
429 |
)
|
430 |
messages.append(
|
@@ -435,13 +457,22 @@ class ChatAPIFormat(BaseFormat):
|
|
435 |
)
|
436 |
|
437 |
for demo_instance in demos:
|
438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
assistant_content = self.to_content(
|
440 |
-
target_prefix + demo_instance["target"],
|
|
|
441 |
)
|
442 |
messages.extend(
|
443 |
[
|
444 |
-
{"role": "user", "content": user_content},
|
445 |
{
|
446 |
"role": "assistant",
|
447 |
"content": assistant_content,
|
@@ -449,9 +480,15 @@ class ChatAPIFormat(BaseFormat):
|
|
449 |
]
|
450 |
)
|
451 |
|
452 |
-
|
|
|
|
|
453 |
|
454 |
-
|
|
|
|
|
|
|
|
|
455 |
|
456 |
return messages
|
457 |
|
@@ -463,6 +500,7 @@ class ChatAPIFormat(BaseFormat):
|
|
463 |
target_prefix: str,
|
464 |
demos: List[Dict[str, Any]],
|
465 |
media: Optional[Dict[str, Any]] = None,
|
|
|
466 |
) -> Union[str, List[Message]]:
|
467 |
chat = self.to_chat(
|
468 |
system_prompt,
|
@@ -471,6 +509,7 @@ class ChatAPIFormat(BaseFormat):
|
|
471 |
target_prefix,
|
472 |
demos,
|
473 |
media,
|
|
|
474 |
)
|
475 |
media["images"] = []
|
476 |
return chat
|
@@ -492,6 +531,7 @@ class HFSystemFormat(ChatAPIFormat):
|
|
492 |
"""
|
493 |
|
494 |
model_name: str
|
|
|
495 |
_requirements_list = ["transformers", "Jinja2"]
|
496 |
|
497 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
@@ -509,13 +549,17 @@ class HFSystemFormat(ChatAPIFormat):
|
|
509 |
target_prefix: str,
|
510 |
demos: List[Dict[str, Any]],
|
511 |
media: Optional[Dict[str, Any]] = None,
|
|
|
512 |
) -> str:
|
513 |
chat = self.to_chat(
|
514 |
-
system_prompt, instruction, source, target_prefix, demos, media
|
515 |
)
|
516 |
return (
|
517 |
self.tokenizer.apply_chat_template(
|
518 |
-
chat,
|
|
|
|
|
|
|
519 |
)
|
520 |
+ target_prefix
|
521 |
)
|
|
|
1 |
+
import json
|
2 |
import re
|
3 |
from abc import abstractmethod
|
4 |
from typing import (
|
|
|
19 |
from .operator import InstanceOperator
|
20 |
from .settings_utils import get_constants
|
21 |
from .type_utils import isoftype
|
22 |
+
from .types import Dialog
|
23 |
from .utils import retry_connection_with_exponential_backoff
|
24 |
|
25 |
constants = get_constants()
|
|
|
137 |
def _prepare_instance_fields(self, instance) -> Tuple[str]:
|
138 |
instance_fields = {}
|
139 |
|
140 |
+
if "__turns__" in instance:
|
141 |
+
instance_fields["turns"] = instance["__turns__"]
|
142 |
+
|
143 |
for field in (
|
144 |
"source",
|
145 |
constants.instruction_field,
|
|
|
175 |
target_prefix: str,
|
176 |
demos: List[Dict[str, Any]],
|
177 |
media: Optional[Dict[str, Any]] = None,
|
178 |
+
turns: Optional[Dialog] = None,
|
179 |
) -> str:
|
180 |
"""Abstract method for formatting instances in different subclasses.
|
181 |
|
|
|
262 |
target_prefix: str,
|
263 |
demos: List[Dict[str, Any]],
|
264 |
media: Optional[Dict[str, Any]] = None,
|
265 |
+
turns: Optional[Dialog] = None,
|
266 |
) -> str:
|
267 |
+
if turns is not None and not source:
|
268 |
+
source = json.dumps(turns)
|
269 |
demos_string = ""
|
270 |
for demo in demos:
|
271 |
demo_str = self.demo_format.format(
|
|
|
365 |
)
|
366 |
|
367 |
The resulting `messages` is now a dictionary ready for sending to the OpenAI API.
|
368 |
+
|
369 |
+
By default, the instruction in the template is placed in a turn with a 'system' role.
|
370 |
+
However, some chat tokenizers will not place the default system prompt for the model
|
371 |
+
if there is a turn with an explicit 'system' role. To keep the default system prompt,
|
372 |
+
set 'place_instruction_in_user_turns=True'. This will cause the instruction of the template
|
373 |
+
to be placed in a turn with a 'user' role. Note the instruction will also be placed
|
374 |
+
in every demo turn (if demos are generated.)
|
375 |
+
|
376 |
"""
|
377 |
|
378 |
+
place_instruction_in_user_turns: bool = False
|
379 |
+
|
380 |
def to_content(self, text: str, media: Dict[str, Any]) -> Union[str, List[Content]]:
|
381 |
# Regular expression to find <img> tags with src attribute
|
382 |
img_tag_pattern = re.compile(
|
|
|
438 |
target_prefix: str,
|
439 |
demos: List[Dict[str, Any]],
|
440 |
media: Optional[Dict[str, Any]] = None,
|
441 |
+
turns: Optional[Dialog] = None,
|
442 |
) -> List[Message]:
|
443 |
messages = []
|
444 |
|
445 |
+
if system_prompt or (instruction and not self.place_instruction_in_user_turns):
|
446 |
system_content = self.to_content(
|
447 |
+
system_prompt
|
448 |
+
+ ("\n" if system_prompt != "" else "")
|
449 |
+
+ (instruction if not self.place_instruction_in_user_turns else ""),
|
450 |
media,
|
451 |
)
|
452 |
messages.append(
|
|
|
457 |
)
|
458 |
|
459 |
for demo_instance in demos:
|
460 |
+
if "__turns__" in demo_instance:
|
461 |
+
messages.extend(demo_instance["__turns__"])
|
462 |
+
else:
|
463 |
+
text = demo_instance["source"]
|
464 |
+
|
465 |
+
if instruction and self.place_instruction_in_user_turns:
|
466 |
+
text = f"{instruction}\n{text}"
|
467 |
+
source_content = self.to_content(text, media)
|
468 |
+
messages.extend([{"role": "user", "content": source_content}])
|
469 |
+
|
470 |
assistant_content = self.to_content(
|
471 |
+
target_prefix + demo_instance["target"],
|
472 |
+
media,
|
473 |
)
|
474 |
messages.extend(
|
475 |
[
|
|
|
476 |
{
|
477 |
"role": "assistant",
|
478 |
"content": assistant_content,
|
|
|
480 |
]
|
481 |
)
|
482 |
|
483 |
+
text = source
|
484 |
+
if instruction and self.place_instruction_in_user_turns:
|
485 |
+
text = f"{instruction}\n{text}"
|
486 |
|
487 |
+
if turns is None:
|
488 |
+
last_user_content = self.to_content(text, media)
|
489 |
+
messages.extend([{"role": "user", "content": last_user_content}])
|
490 |
+
else:
|
491 |
+
messages.extend(turns)
|
492 |
|
493 |
return messages
|
494 |
|
|
|
500 |
target_prefix: str,
|
501 |
demos: List[Dict[str, Any]],
|
502 |
media: Optional[Dict[str, Any]] = None,
|
503 |
+
turns: Optional[Dialog] = None,
|
504 |
) -> Union[str, List[Message]]:
|
505 |
chat = self.to_chat(
|
506 |
system_prompt,
|
|
|
509 |
target_prefix,
|
510 |
demos,
|
511 |
media,
|
512 |
+
turns,
|
513 |
)
|
514 |
media["images"] = []
|
515 |
return chat
|
|
|
531 |
"""
|
532 |
|
533 |
model_name: str
|
534 |
+
chat_kwargs_dict: Dict[str, str] = {}
|
535 |
_requirements_list = ["transformers", "Jinja2"]
|
536 |
|
537 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
|
549 |
target_prefix: str,
|
550 |
demos: List[Dict[str, Any]],
|
551 |
media: Optional[Dict[str, Any]] = None,
|
552 |
+
turns: Optional[Dialog] = None,
|
553 |
) -> str:
|
554 |
chat = self.to_chat(
|
555 |
+
system_prompt, instruction, source, target_prefix, demos, media, turns
|
556 |
)
|
557 |
return (
|
558 |
self.tokenizer.apply_chat_template(
|
559 |
+
chat,
|
560 |
+
tokenize=False,
|
561 |
+
add_generation_prompt=True,
|
562 |
+
**self.chat_kwargs_dict,
|
563 |
)
|
564 |
+ target_prefix
|
565 |
)
|
fusion.py
CHANGED
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
|
2 |
from typing import Dict, Generator, List, Optional, Union
|
3 |
|
4 |
from .dataclass import NonPositionalField
|
|
|
5 |
from .logging_utils import get_logger
|
6 |
from .operator import SourceOperator
|
7 |
from .random_utils import new_random_generator
|
@@ -92,7 +93,7 @@ class FixedFusion(BaseFusion):
|
|
92 |
max_from_this_split = max_per_this_split
|
93 |
|
94 |
logger.info(f"Processing {split} from {origin_name}...")
|
95 |
-
|
96 |
for instance in multi_stream[split]:
|
97 |
if (
|
98 |
max_from_this_split is not None
|
@@ -105,8 +106,6 @@ class FixedFusion(BaseFusion):
|
|
105 |
instance["subset"].insert(0, origin_name)
|
106 |
emitted_from_this_split += 1
|
107 |
yield instance
|
108 |
-
except Exception as e:
|
109 |
-
raise RuntimeError(f"Exception in subset: {origin_name}") from e
|
110 |
|
111 |
|
112 |
class WeightedFusion(BaseFusion):
|
@@ -164,16 +163,15 @@ class WeightedFusion(BaseFusion):
|
|
164 |
weights=[self.named_weights[name] for name in population],
|
165 |
)[0]
|
166 |
iterator = iterators[origin_name]
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
if
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
raise RuntimeError(f"Exception in subset: {origin_name}") from e
|
|
|
2 |
from typing import Dict, Generator, List, Optional, Union
|
3 |
|
4 |
from .dataclass import NonPositionalField
|
5 |
+
from .error_utils import error_context
|
6 |
from .logging_utils import get_logger
|
7 |
from .operator import SourceOperator
|
8 |
from .random_utils import new_random_generator
|
|
|
93 |
max_from_this_split = max_per_this_split
|
94 |
|
95 |
logger.info(f"Processing {split} from {origin_name}...")
|
96 |
+
with error_context(self, subset=origin_name):
|
97 |
for instance in multi_stream[split]:
|
98 |
if (
|
99 |
max_from_this_split is not None
|
|
|
106 |
instance["subset"].insert(0, origin_name)
|
107 |
emitted_from_this_split += 1
|
108 |
yield instance
|
|
|
|
|
109 |
|
110 |
|
111 |
class WeightedFusion(BaseFusion):
|
|
|
163 |
weights=[self.named_weights[name] for name in population],
|
164 |
)[0]
|
165 |
iterator = iterators[origin_name]
|
166 |
+
with error_context(self, subset=origin_name):
|
167 |
+
try:
|
168 |
+
instance = next(iterator)
|
169 |
+
if isinstance(origin_name, str):
|
170 |
+
if "subset" not in instance:
|
171 |
+
instance["subset"] = []
|
172 |
+
instance["subset"].insert(0, origin_name)
|
173 |
+
total_examples += 1
|
174 |
+
yield instance
|
175 |
+
|
176 |
+
except StopIteration:
|
177 |
+
iterators.pop(origin_name)
|
|
inference.py
CHANGED
@@ -39,7 +39,7 @@ from .artifact import Artifact
|
|
39 |
from .base_metric import Metric
|
40 |
from .dataclass import InternalField, NonPositionalField
|
41 |
from .deprecation_utils import deprecation
|
42 |
-
from .error_utils import UnitxtError, UnitxtWarning
|
43 |
from .image_operators import (
|
44 |
EncodeImageToString,
|
45 |
ImageDataString,
|
@@ -121,6 +121,8 @@ class TextGenerationInferenceOutput:
|
|
121 |
| For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
|
122 |
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
|
123 |
|
|
|
|
|
124 |
input_tokens (int) : number of input tokens to the model.
|
125 |
|
126 |
output_tokens (int) : number of output tokens to the model.
|
@@ -137,6 +139,7 @@ class TextGenerationInferenceOutput:
|
|
137 |
"""
|
138 |
|
139 |
prediction: Union[str, List[Dict[str, Any]]]
|
|
|
140 |
input_tokens: Optional[int] = None
|
141 |
output_tokens: Optional[int] = None
|
142 |
stop_reason: Optional[str] = None
|
@@ -186,12 +189,19 @@ class InferenceEngine(Artifact):
|
|
186 |
def prepare(self):
|
187 |
if not settings.mock_inference_mode:
|
188 |
super().prepare() # no need to prepare a mock
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
190 |
if self.use_cache:
|
191 |
from diskcache import Cache
|
192 |
|
193 |
self._cache = Cache(
|
194 |
-
|
|
|
|
|
195 |
)
|
196 |
|
197 |
def __call__(
|
@@ -199,7 +209,12 @@ class InferenceEngine(Artifact):
|
|
199 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
200 |
return_meta_data: bool = False,
|
201 |
) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
def get_instance_cache_key(self, instance):
|
205 |
instance_key_fields = ["media", "source", "task_data"]
|
@@ -243,54 +258,69 @@ class InferenceEngine(Artifact):
|
|
243 |
result = self._mock_infer(dataset)
|
244 |
else:
|
245 |
if self.use_cache:
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
):
|
251 |
-
|
252 |
-
|
253 |
-
for
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
) # each element is index in batch, and value
|
260 |
-
else:
|
261 |
-
missing_examples.append(
|
262 |
-
(i, item)
|
263 |
-
) # each element is index in batch and example
|
264 |
-
# infare on missing examples only, without indices
|
265 |
-
|
266 |
-
logger.info(
|
267 |
-
f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
|
268 |
-
)
|
269 |
-
if len(missing_examples) > 0:
|
270 |
-
inferred_results = self._infer(
|
271 |
-
[e[1] for e in missing_examples], return_meta_data
|
272 |
-
)
|
273 |
-
# recombined to index and value
|
274 |
-
inferred_results = list(
|
275 |
-
zip([e[0] for e in missing_examples], inferred_results)
|
276 |
-
)
|
277 |
-
# Add missing examples to cache
|
278 |
-
for (_, item), (_, prediction) in zip(
|
279 |
-
missing_examples, inferred_results
|
280 |
-
):
|
281 |
-
if prediction is None:
|
282 |
-
continue
|
283 |
cache_key = self._get_cache_key(item)
|
284 |
-
self._cache
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
else:
|
293 |
-
|
|
|
|
|
|
|
|
|
|
|
294 |
return ListWithMetadata(
|
295 |
result,
|
296 |
metadata={
|
@@ -339,7 +369,16 @@ class InferenceEngine(Artifact):
|
|
339 |
|
340 |
def to_messages(self, instance):
|
341 |
if isinstance(instance["source"], list):
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
return [
|
344 |
{
|
345 |
"role": "user",
|
@@ -521,13 +560,6 @@ class HFInferenceEngineBase(
|
|
521 |
def get_engine_id(self):
|
522 |
return get_model_and_label_id(self.model_name, self.label)
|
523 |
|
524 |
-
def decode_tokens(self, tokens: Sequence, inp_length: int) -> List[str]:
|
525 |
-
return self.processor.decode(tokens[inp_length:], skip_special_tokens=True)
|
526 |
-
|
527 |
-
@staticmethod
|
528 |
-
def create_string_from_tokens(string_tokens: List[str]) -> str:
|
529 |
-
return "".join(token for token in string_tokens)
|
530 |
-
|
531 |
def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
|
532 |
return self.model.generate(
|
533 |
**prepared_inputs,
|
@@ -598,6 +630,7 @@ class HFInferenceEngineBase(
|
|
598 |
def get_return_object(
|
599 |
self,
|
600 |
output: Union[str, List[Dict[str, Any]]],
|
|
|
601 |
output_tokens: Optional[int],
|
602 |
inp: Optional[str],
|
603 |
inp_tokens: Optional[int],
|
@@ -606,6 +639,7 @@ class HFInferenceEngineBase(
|
|
606 |
if return_meta_data:
|
607 |
return TextGenerationInferenceOutput(
|
608 |
prediction=output,
|
|
|
609 |
output_tokens=output_tokens if output_tokens is not None else None,
|
610 |
input_text=inp,
|
611 |
input_tokens=inp_tokens if inp_tokens is not None else None,
|
@@ -689,7 +723,8 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
689 |
# cause an error because the data is always on the gpu
|
690 |
# if torch.cuda.device_count() > 1:
|
691 |
# assert self.device == torch.device(0)
|
692 |
-
|
|
|
693 |
# else:
|
694 |
# if not self.load_in_8bit:
|
695 |
# args["device"] = self.device
|
@@ -717,15 +752,21 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
717 |
**model_args,
|
718 |
)
|
719 |
|
720 |
-
def prepare_inputs(self, data: Iterable) -> Mapping:
|
721 |
tokenizer_kargs = {}
|
722 |
if isinstance(data[0], list):
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
729 |
tokenizer_kargs["add_special_tokens"] = False
|
730 |
|
731 |
if self.processor.pad_token is None:
|
@@ -766,59 +807,71 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
766 |
total=len(dataset) // self.batch_size,
|
767 |
):
|
768 |
# Get the current batch
|
769 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
770 |
|
771 |
-
|
772 |
-
# 1. Tokenize inputs for the batch
|
773 |
-
tokenized_inputs = self.prepare_inputs(batch_sources)
|
774 |
|
775 |
-
#
|
776 |
input_length = (
|
777 |
1
|
778 |
if self.model.config.is_encoder_decoder
|
779 |
else tokenized_inputs.input_ids.shape[1]
|
780 |
)
|
781 |
|
782 |
-
#
|
783 |
predictions = self.make_predictions(tokenized_inputs)
|
784 |
sequences = predictions.sequences # Sequences for the current batch
|
785 |
|
786 |
-
|
787 |
-
string_tokens_batch = [
|
788 |
-
self.decode_tokens(sequence, input_length) for sequence in sequences
|
789 |
-
]
|
790 |
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
)
|
800 |
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
)
|
814 |
-
for j in range(
|
815 |
-
len(sequences)
|
816 |
-
) # Iterate through items in the current batch
|
817 |
-
]
|
818 |
|
819 |
-
# Add results from this batch to the overall list
|
820 |
all_final_outputs.extend(batch_results)
|
821 |
-
# --- End of batch processing ---
|
822 |
|
823 |
return all_final_outputs
|
824 |
|
@@ -847,7 +900,10 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
847 |
self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
|
848 |
) -> Sequence:
|
849 |
if not hasattr(self.model.config, "vocab_size"):
|
850 |
-
|
|
|
|
|
|
|
851 |
|
852 |
return super().compute_transition_scores(sequences, scores, beam_indices)
|
853 |
|
@@ -917,18 +973,35 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
917 |
|
918 |
predictions = self.make_predictions(processed_inputs)
|
919 |
|
920 |
-
|
921 |
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
|
926 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
927 |
|
928 |
results.append(
|
929 |
self.get_return_object(
|
930 |
-
output=final_outputs,
|
931 |
-
|
|
|
932 |
inp=instance["source"],
|
933 |
inp_tokens=None,
|
934 |
return_meta_data=return_meta_data,
|
@@ -1189,6 +1262,7 @@ class HFPipelineBasedInferenceEngine(
|
|
1189 |
if return_meta_data:
|
1190 |
return TextGenerationInferenceOutput(
|
1191 |
prediction=output["generated_text"],
|
|
|
1192 |
model_name=self.model_name,
|
1193 |
inference_type=self.label,
|
1194 |
input_text=inp,
|
@@ -1252,10 +1326,13 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
|
|
1252 |
for instance in dataset
|
1253 |
]
|
1254 |
|
1255 |
-
def get_return_object(
|
|
|
|
|
1256 |
if return_meta_data:
|
1257 |
return TextGenerationInferenceOutput(
|
1258 |
prediction=predict_result,
|
|
|
1259 |
input_tokens=len(instance["source"]),
|
1260 |
output_tokens=len(predict_result),
|
1261 |
model_name=self.model_name,
|
@@ -1369,21 +1446,25 @@ class OllamaInferenceEngine(
|
|
1369 |
return get_model_and_label_id(self.model, self.label)
|
1370 |
|
1371 |
def prepare_engine(self):
|
1372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1373 |
|
1374 |
def _infer(
|
1375 |
self,
|
1376 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
1377 |
return_meta_data: bool = False,
|
1378 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1379 |
-
import ollama
|
1380 |
-
|
1381 |
args = self.to_dict([StandardAPIParamsMixin])
|
1382 |
results = []
|
1383 |
model = args.pop("model")
|
1384 |
for instance in dataset:
|
1385 |
messages = self.to_messages(instance)
|
1386 |
-
response =
|
1387 |
messages=messages,
|
1388 |
model=model,
|
1389 |
options=args,
|
@@ -1877,7 +1958,7 @@ class OpenAiInferenceEngine(
|
|
1877 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
1878 |
)
|
1879 |
return TextGenerationInferenceOutput(
|
1880 |
-
prediction="-", input_tokens=0, output_tokens=0
|
1881 |
)
|
1882 |
|
1883 |
@run_with_imap
|
@@ -1894,10 +1975,12 @@ class OpenAiInferenceEngine(
|
|
1894 |
top_logprobs_response = response.choices[0].logprobs.content
|
1895 |
pred_output = [
|
1896 |
{
|
|
|
|
|
1897 |
"top_tokens": [
|
1898 |
{"text": obj.token, "logprob": obj.logprob}
|
1899 |
for obj in generated_token.top_logprobs
|
1900 |
-
]
|
1901 |
}
|
1902 |
for generated_token in top_logprobs_response
|
1903 |
]
|
@@ -1907,15 +1990,21 @@ class OpenAiInferenceEngine(
|
|
1907 |
logging.error(
|
1908 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
1909 |
)
|
1910 |
-
prediction = [
|
|
|
|
|
1911 |
return TextGenerationInferenceOutput(
|
1912 |
-
prediction=prediction,
|
|
|
|
|
|
|
1913 |
)
|
1914 |
|
1915 |
def get_return_object(self, predict_result, response, return_meta_data):
|
1916 |
if return_meta_data:
|
1917 |
return TextGenerationInferenceOutput(
|
1918 |
prediction=predict_result,
|
|
|
1919 |
input_tokens=response.usage.prompt_tokens,
|
1920 |
output_tokens=response.usage.completion_tokens,
|
1921 |
model_name=self.model_name,
|
@@ -1973,7 +2062,12 @@ class RITSInferenceEngine(
|
|
1973 |
label: str = "rits"
|
1974 |
data_classification_policy = ["public", "proprietary"]
|
1975 |
|
1976 |
-
model_names_dict = {
|
|
|
|
|
|
|
|
|
|
|
1977 |
|
1978 |
def get_default_headers(self):
|
1979 |
return {"RITS_API_KEY": self.credentials["api_key"]}
|
@@ -2606,6 +2700,7 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
|
|
2606 |
if return_meta_data:
|
2607 |
return TextGenerationInferenceOutput(
|
2608 |
prediction=predict_result,
|
|
|
2609 |
input_tokens=result["input_token_count"],
|
2610 |
output_tokens=result["generated_token_count"],
|
2611 |
model_name=self.model_name or self.deployment_id,
|
@@ -2865,6 +2960,8 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2865 |
tool_call = data[idx]["tools"]["tools"] is not None
|
2866 |
|
2867 |
output = response["choices"][0][output_type]
|
|
|
|
|
2868 |
if tool_call:
|
2869 |
if "tool_calls" in output:
|
2870 |
func = output["tool_calls"][0]["function"]
|
@@ -2877,6 +2974,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2877 |
results.append(
|
2878 |
self.get_return_object(
|
2879 |
prediction,
|
|
|
2880 |
response,
|
2881 |
str(inp),
|
2882 |
return_meta_data,
|
@@ -2885,10 +2983,13 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2885 |
|
2886 |
return results
|
2887 |
|
2888 |
-
def get_return_object(
|
|
|
|
|
2889 |
if return_meta_data:
|
2890 |
return TextGenerationInferenceOutput(
|
2891 |
prediction=predict_result,
|
|
|
2892 |
input_tokens=result["usage"]["prompt_tokens"],
|
2893 |
output_tokens=len(predict_result)
|
2894 |
if isinstance(predict_result, list)
|
@@ -3286,6 +3387,7 @@ class LiteLLMInferenceEngine(
|
|
3286 |
prediction = response["choices"][0]["message"]["content"] or ""
|
3287 |
return TextGenerationInferenceOutput(
|
3288 |
prediction=prediction,
|
|
|
3289 |
input_tokens=usage.get("prompt_tokens"),
|
3290 |
output_tokens=usage.get("completion_tokens"),
|
3291 |
model_name=response.get("model", self.model),
|
@@ -3436,21 +3538,22 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3436 |
},
|
3437 |
"rits": {
|
3438 |
"granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
|
|
|
3439 |
"granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
|
3440 |
"granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
|
3441 |
-
"llama-3-1-8b-instruct": "meta-llama/
|
3442 |
"llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
|
3443 |
"llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
|
3444 |
"llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
|
3445 |
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
3446 |
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
3447 |
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
|
3448 |
-
"llama-4-scout": "llama-4-scout-17b-16e",
|
3449 |
-
"llama-4-maverick": "llama-4-
|
3450 |
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
|
3451 |
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3452 |
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3453 |
-
"deepseek-v3": "deepseek-ai/
|
3454 |
"granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
|
3455 |
"granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
|
3456 |
},
|
|
|
39 |
from .base_metric import Metric
|
40 |
from .dataclass import InternalField, NonPositionalField
|
41 |
from .deprecation_utils import deprecation
|
42 |
+
from .error_utils import UnitxtError, UnitxtWarning, error_context
|
43 |
from .image_operators import (
|
44 |
EncodeImageToString,
|
45 |
ImageDataString,
|
|
|
121 |
| For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
|
122 |
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
|
123 |
|
124 |
+
generated_text (str): The text generated by the model (in both _infer and _infer_log_probs calls).
|
125 |
+
|
126 |
input_tokens (int) : number of input tokens to the model.
|
127 |
|
128 |
output_tokens (int) : number of output tokens to the model.
|
|
|
139 |
"""
|
140 |
|
141 |
prediction: Union[str, List[Dict[str, Any]]]
|
142 |
+
generated_text: str
|
143 |
input_tokens: Optional[int] = None
|
144 |
output_tokens: Optional[int] = None
|
145 |
stop_reason: Optional[str] = None
|
|
|
189 |
def prepare(self):
|
190 |
if not settings.mock_inference_mode:
|
191 |
super().prepare() # no need to prepare a mock
|
192 |
+
with error_context(
|
193 |
+
self,
|
194 |
+
stage="Prepare Inference Engine",
|
195 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
196 |
+
):
|
197 |
+
self.prepare_engine()
|
198 |
if self.use_cache:
|
199 |
from diskcache import Cache
|
200 |
|
201 |
self._cache = Cache(
|
202 |
+
os.path.join(
|
203 |
+
settings.inference_engine_cache_path, self.__class__.__name__
|
204 |
+
)
|
205 |
)
|
206 |
|
207 |
def __call__(
|
|
|
209 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
210 |
return_meta_data: bool = False,
|
211 |
) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
|
212 |
+
with error_context(
|
213 |
+
self,
|
214 |
+
stage="Running Inference",
|
215 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
216 |
+
):
|
217 |
+
return self.infer(dataset=dataset, return_meta_data=return_meta_data)
|
218 |
|
219 |
def get_instance_cache_key(self, instance):
|
220 |
instance_key_fields = ["media", "source", "task_data"]
|
|
|
258 |
result = self._mock_infer(dataset)
|
259 |
else:
|
260 |
if self.use_cache:
|
261 |
+
with error_context(
|
262 |
+
self,
|
263 |
+
stage="Inference Cache Handling",
|
264 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
265 |
):
|
266 |
+
number_of_batches = math.ceil(len(dataset) / self.cache_batch_size)
|
267 |
+
result = []
|
268 |
+
for batch_index, batch in enumerate(
|
269 |
+
batched(dataset, self.cache_batch_size)
|
270 |
+
):
|
271 |
+
cached_results = []
|
272 |
+
missing_examples = []
|
273 |
+
for i, item in enumerate(batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
cache_key = self._get_cache_key(item)
|
275 |
+
cached_value = self._cache.get(cache_key)
|
276 |
+
if cached_value is not None:
|
277 |
+
cached_results.append(
|
278 |
+
(i, cached_value)
|
279 |
+
) # each element is index in batch, and value
|
280 |
+
else:
|
281 |
+
missing_examples.append(
|
282 |
+
(i, item)
|
283 |
+
) # each element is index in batch and example
|
284 |
+
# infare on missing examples only, without indices
|
285 |
+
|
286 |
+
logger.info(
|
287 |
+
f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
|
288 |
+
)
|
289 |
+
if len(missing_examples) > 0:
|
290 |
+
with error_context(
|
291 |
+
self,
|
292 |
+
stage="Running Inference",
|
293 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
294 |
+
):
|
295 |
+
inferred_results = self._infer(
|
296 |
+
[e[1] for e in missing_examples], return_meta_data
|
297 |
+
)
|
298 |
+
# recombined to index and value
|
299 |
+
inferred_results = list(
|
300 |
+
zip([e[0] for e in missing_examples], inferred_results)
|
301 |
+
)
|
302 |
+
# Add missing examples to cache
|
303 |
+
for (_, item), (_, prediction) in zip(
|
304 |
+
missing_examples, inferred_results
|
305 |
+
):
|
306 |
+
if prediction is None:
|
307 |
+
continue
|
308 |
+
cache_key = self._get_cache_key(item)
|
309 |
+
self._cache[cache_key] = prediction
|
310 |
+
else:
|
311 |
+
inferred_results = []
|
312 |
+
# Combine cached and inferred results in original order
|
313 |
+
batch_predictions = [
|
314 |
+
p[1] for p in sorted(cached_results + inferred_results)
|
315 |
+
]
|
316 |
+
result.extend(batch_predictions)
|
317 |
else:
|
318 |
+
with error_context(
|
319 |
+
self,
|
320 |
+
stage="Running Inference",
|
321 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
322 |
+
):
|
323 |
+
result = self._infer(dataset, return_meta_data)
|
324 |
return ListWithMetadata(
|
325 |
result,
|
326 |
metadata={
|
|
|
369 |
|
370 |
def to_messages(self, instance):
|
371 |
if isinstance(instance["source"], list):
|
372 |
+
messages = []
|
373 |
+
for message in instance["source"]:
|
374 |
+
if "tool_calls" in message:
|
375 |
+
for tool_call in message["tool_calls"]:
|
376 |
+
if not isinstance(tool_call["function"]["arguments"], str):
|
377 |
+
tool_call["function"]["arguments"] = json.dumps(
|
378 |
+
tool_call["function"]["arguments"]
|
379 |
+
)
|
380 |
+
messages.append(message)
|
381 |
+
return messages
|
382 |
return [
|
383 |
{
|
384 |
"role": "user",
|
|
|
560 |
def get_engine_id(self):
|
561 |
return get_model_and_label_id(self.model_name, self.label)
|
562 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
|
564 |
return self.model.generate(
|
565 |
**prepared_inputs,
|
|
|
630 |
def get_return_object(
|
631 |
self,
|
632 |
output: Union[str, List[Dict[str, Any]]],
|
633 |
+
generated_text: str,
|
634 |
output_tokens: Optional[int],
|
635 |
inp: Optional[str],
|
636 |
inp_tokens: Optional[int],
|
|
|
639 |
if return_meta_data:
|
640 |
return TextGenerationInferenceOutput(
|
641 |
prediction=output,
|
642 |
+
generated_text=generated_text,
|
643 |
output_tokens=output_tokens if output_tokens is not None else None,
|
644 |
input_text=inp,
|
645 |
input_tokens=inp_tokens if inp_tokens is not None else None,
|
|
|
723 |
# cause an error because the data is always on the gpu
|
724 |
# if torch.cuda.device_count() > 1:
|
725 |
# assert self.device == torch.device(0)
|
726 |
+
if self.device_map is None:
|
727 |
+
args["device_map"] = "auto"
|
728 |
# else:
|
729 |
# if not self.load_in_8bit:
|
730 |
# args["device"] = self.device
|
|
|
752 |
**model_args,
|
753 |
)
|
754 |
|
755 |
+
def prepare_inputs(self, data: Iterable, tools: Iterable) -> Mapping:
|
756 |
tokenizer_kargs = {}
|
757 |
if isinstance(data[0], list):
|
758 |
+
processed = []
|
759 |
+
for item, item_tools in zip(data, tools):
|
760 |
+
processed.append(
|
761 |
+
self.processor.apply_chat_template(
|
762 |
+
item,
|
763 |
+
tokenize=False,
|
764 |
+
tools=item_tools,
|
765 |
+
add_generation_prompt=True,
|
766 |
+
**self.chat_kwargs_dict,
|
767 |
+
)
|
768 |
+
)
|
769 |
+
data = processed
|
770 |
tokenizer_kargs["add_special_tokens"] = False
|
771 |
|
772 |
if self.processor.pad_token is None:
|
|
|
807 |
total=len(dataset) // self.batch_size,
|
808 |
):
|
809 |
# Get the current batch
|
810 |
+
sources = []
|
811 |
+
tools = []
|
812 |
+
for instance in batch:
|
813 |
+
sources.append(instance["source"])
|
814 |
+
if "task_data" in instance and "__tools__" in instance["task_data"]:
|
815 |
+
task_data = instance["task_data"]
|
816 |
+
if isinstance(task_data, str):
|
817 |
+
task_data = json.loads(task_data)
|
818 |
+
tools.append(task_data["__tools__"])
|
819 |
+
else:
|
820 |
+
tools.append(None)
|
821 |
+
# Tokenize inputs for the batch
|
822 |
|
823 |
+
tokenized_inputs = self.prepare_inputs(sources, tools)
|
|
|
|
|
824 |
|
825 |
+
# Determine input length (handle encoder-decoder models)
|
826 |
input_length = (
|
827 |
1
|
828 |
if self.model.config.is_encoder_decoder
|
829 |
else tokenized_inputs.input_ids.shape[1]
|
830 |
)
|
831 |
|
832 |
+
# Make predictions for the batch
|
833 |
predictions = self.make_predictions(tokenized_inputs)
|
834 |
sequences = predictions.sequences # Sequences for the current batch
|
835 |
|
836 |
+
output_tokens = sequences[:, input_length:]
|
|
|
|
|
|
|
837 |
|
838 |
+
output_tokens_strings = []
|
839 |
+
for tokens in output_tokens:
|
840 |
+
output_tokens_strings.append(
|
841 |
+
[
|
842 |
+
self.processor.decode(token, skip_special_tokens=True)
|
843 |
+
for token in tokens
|
844 |
+
]
|
845 |
+
)
|
|
|
846 |
|
847 |
+
output_strings = []
|
848 |
+
for tokens in output_tokens:
|
849 |
+
output_strings.append(
|
850 |
+
self.processor.decode(tokens, skip_special_tokens=True)
|
851 |
+
)
|
852 |
+
|
853 |
+
if return_logprobs:
|
854 |
+
outputs = self.get_logprobs(predictions, output_tokens_strings)
|
855 |
+
else:
|
856 |
+
outputs = output_strings
|
857 |
+
|
858 |
+
# Create return objects for the batch
|
859 |
+
batch_results = []
|
860 |
+
for i in range(len(sequences)):
|
861 |
+
batch_results.append(
|
862 |
+
self.get_return_object(
|
863 |
+
output=outputs[i],
|
864 |
+
generated_text=output_strings[i],
|
865 |
+
output_tokens=len(output_tokens_strings[i]),
|
866 |
+
inp=sources[i],
|
867 |
+
inp_tokens=len(tokenized_inputs.encodings[i].tokens)
|
868 |
+
if tokenized_inputs.encodings is not None
|
869 |
+
else None,
|
870 |
+
return_meta_data=return_meta_data,
|
871 |
+
)
|
872 |
)
|
|
|
|
|
|
|
|
|
873 |
|
|
|
874 |
all_final_outputs.extend(batch_results)
|
|
|
875 |
|
876 |
return all_final_outputs
|
877 |
|
|
|
900 |
self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
|
901 |
) -> Sequence:
|
902 |
if not hasattr(self.model.config, "vocab_size"):
|
903 |
+
try:
|
904 |
+
self.model.config.vocab_size = self.model.vocab_size
|
905 |
+
except:
|
906 |
+
self.model.config.vocab_size = self.model.config.text_config.vocab_size
|
907 |
|
908 |
return super().compute_transition_scores(sequences, scores, beam_indices)
|
909 |
|
|
|
973 |
|
974 |
predictions = self.make_predictions(processed_inputs)
|
975 |
|
976 |
+
sequences = predictions.sequences # Sequences for the current batch
|
977 |
|
978 |
+
output_tokens = sequences[:, input_len:]
|
979 |
+
|
980 |
+
output_tokens_strings = []
|
981 |
+
for tokens in output_tokens:
|
982 |
+
output_tokens_strings.append(
|
983 |
+
[
|
984 |
+
self.processor.decode(token, skip_special_tokens=True)
|
985 |
+
for token in tokens
|
986 |
+
]
|
987 |
+
)
|
988 |
+
|
989 |
+
output_strings = []
|
990 |
+
for tokens in output_tokens:
|
991 |
+
output_strings.append(
|
992 |
+
self.processor.decode(tokens, skip_special_tokens=True)
|
993 |
+
)
|
994 |
+
|
995 |
+
if return_logprobs:
|
996 |
+
final_outputs = self.get_logprobs(predictions, output_tokens_strings)
|
997 |
+
else:
|
998 |
+
final_outputs = output_strings
|
999 |
|
1000 |
results.append(
|
1001 |
self.get_return_object(
|
1002 |
+
output=final_outputs[0],
|
1003 |
+
generated_text=output_strings,
|
1004 |
+
output_tokens=len(output_tokens_strings[0]),
|
1005 |
inp=instance["source"],
|
1006 |
inp_tokens=None,
|
1007 |
return_meta_data=return_meta_data,
|
|
|
1262 |
if return_meta_data:
|
1263 |
return TextGenerationInferenceOutput(
|
1264 |
prediction=output["generated_text"],
|
1265 |
+
generated_text=output["generated_text"],
|
1266 |
model_name=self.model_name,
|
1267 |
inference_type=self.label,
|
1268 |
input_text=inp,
|
|
|
1326 |
for instance in dataset
|
1327 |
]
|
1328 |
|
1329 |
+
def get_return_object(
|
1330 |
+
self, predict_result, generated_text, instance, return_meta_data
|
1331 |
+
):
|
1332 |
if return_meta_data:
|
1333 |
return TextGenerationInferenceOutput(
|
1334 |
prediction=predict_result,
|
1335 |
+
generated_text=self.default_inference_value,
|
1336 |
input_tokens=len(instance["source"]),
|
1337 |
output_tokens=len(predict_result),
|
1338 |
model_name=self.model_name,
|
|
|
1446 |
return get_model_and_label_id(self.model, self.label)
|
1447 |
|
1448 |
def prepare_engine(self):
|
1449 |
+
from ollama import Client
|
1450 |
+
|
1451 |
+
self.client = Client(
|
1452 |
+
host=self.credentials["api_base"]
|
1453 |
+
if self.credentials is not None and "api_base" in self.credentials
|
1454 |
+
else None
|
1455 |
+
)
|
1456 |
|
1457 |
def _infer(
|
1458 |
self,
|
1459 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
1460 |
return_meta_data: bool = False,
|
1461 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
|
|
|
|
1462 |
args = self.to_dict([StandardAPIParamsMixin])
|
1463 |
results = []
|
1464 |
model = args.pop("model")
|
1465 |
for instance in dataset:
|
1466 |
messages = self.to_messages(instance)
|
1467 |
+
response = self.client.chat(
|
1468 |
messages=messages,
|
1469 |
model=model,
|
1470 |
options=args,
|
|
|
1958 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
1959 |
)
|
1960 |
return TextGenerationInferenceOutput(
|
1961 |
+
prediction="-", generated_text="-", input_tokens=0, output_tokens=0
|
1962 |
)
|
1963 |
|
1964 |
@run_with_imap
|
|
|
1975 |
top_logprobs_response = response.choices[0].logprobs.content
|
1976 |
pred_output = [
|
1977 |
{
|
1978 |
+
"text": generated_token.token,
|
1979 |
+
"logprob": generated_token.logprob,
|
1980 |
"top_tokens": [
|
1981 |
{"text": obj.token, "logprob": obj.logprob}
|
1982 |
for obj in generated_token.top_logprobs
|
1983 |
+
],
|
1984 |
}
|
1985 |
for generated_token in top_logprobs_response
|
1986 |
]
|
|
|
1990 |
logging.error(
|
1991 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
1992 |
)
|
1993 |
+
prediction = [
|
1994 |
+
{"text": "-", "logprob": 0, "top_tokens": [{"text": "-", "logprob": 0}]}
|
1995 |
+
]
|
1996 |
return TextGenerationInferenceOutput(
|
1997 |
+
prediction=prediction,
|
1998 |
+
generated_text=prediction,
|
1999 |
+
input_tokens=0,
|
2000 |
+
output_tokens=0,
|
2001 |
)
|
2002 |
|
2003 |
def get_return_object(self, predict_result, response, return_meta_data):
|
2004 |
if return_meta_data:
|
2005 |
return TextGenerationInferenceOutput(
|
2006 |
prediction=predict_result,
|
2007 |
+
generated_text=response.choices[0].message.content,
|
2008 |
input_tokens=response.usage.prompt_tokens,
|
2009 |
output_tokens=response.usage.completion_tokens,
|
2010 |
model_name=self.model_name,
|
|
|
2062 |
label: str = "rits"
|
2063 |
data_classification_policy = ["public", "proprietary"]
|
2064 |
|
2065 |
+
model_names_dict = {
|
2066 |
+
"microsoft/phi-4": "microsoft-phi-4",
|
2067 |
+
"meta-llama/llama-4-maverick-17b-128e-instruct-fp8": "llama-4-mvk-17b-128e-fp8",
|
2068 |
+
"deepseek-ai/DeepSeek-V3": "deepseek-v3-h200",
|
2069 |
+
"meta-llama/Llama-3.1-8B-Instruct": "llama-3-1-8b-instruct",
|
2070 |
+
}
|
2071 |
|
2072 |
def get_default_headers(self):
|
2073 |
return {"RITS_API_KEY": self.credentials["api_key"]}
|
|
|
2700 |
if return_meta_data:
|
2701 |
return TextGenerationInferenceOutput(
|
2702 |
prediction=predict_result,
|
2703 |
+
generated_text=result["generated_text"],
|
2704 |
input_tokens=result["input_token_count"],
|
2705 |
output_tokens=result["generated_token_count"],
|
2706 |
model_name=self.model_name or self.deployment_id,
|
|
|
2960 |
tool_call = data[idx]["tools"]["tools"] is not None
|
2961 |
|
2962 |
output = response["choices"][0][output_type]
|
2963 |
+
if "content" not in output:
|
2964 |
+
output["content"] = ""
|
2965 |
if tool_call:
|
2966 |
if "tool_calls" in output:
|
2967 |
func = output["tool_calls"][0]["function"]
|
|
|
2974 |
results.append(
|
2975 |
self.get_return_object(
|
2976 |
prediction,
|
2977 |
+
response["choices"][0]["message"]["content"],
|
2978 |
response,
|
2979 |
str(inp),
|
2980 |
return_meta_data,
|
|
|
2983 |
|
2984 |
return results
|
2985 |
|
2986 |
+
def get_return_object(
|
2987 |
+
self, predict_result, generated_text, result, input_text, return_meta_data
|
2988 |
+
):
|
2989 |
if return_meta_data:
|
2990 |
return TextGenerationInferenceOutput(
|
2991 |
prediction=predict_result,
|
2992 |
+
generated_text=generated_text,
|
2993 |
input_tokens=result["usage"]["prompt_tokens"],
|
2994 |
output_tokens=len(predict_result)
|
2995 |
if isinstance(predict_result, list)
|
|
|
3387 |
prediction = response["choices"][0]["message"]["content"] or ""
|
3388 |
return TextGenerationInferenceOutput(
|
3389 |
prediction=prediction,
|
3390 |
+
generated_text=response["choices"][0]["message"]["content"],
|
3391 |
input_tokens=usage.get("prompt_tokens"),
|
3392 |
output_tokens=usage.get("completion_tokens"),
|
3393 |
model_name=response.get("model", self.model),
|
|
|
3538 |
},
|
3539 |
"rits": {
|
3540 |
"granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
|
3541 |
+
"granite-3-1-8b-instruct": "ibm-granite/granite-3.1-8b-instruct",
|
3542 |
"granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
|
3543 |
"granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
|
3544 |
+
"llama-3-1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
3545 |
"llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
|
3546 |
"llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
|
3547 |
"llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
|
3548 |
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
3549 |
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
3550 |
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
|
3551 |
+
"llama-4-scout": "meta-llama/llama-4-scout-17b-16e",
|
3552 |
+
"llama-4-maverick": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
|
3553 |
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
|
3554 |
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3555 |
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3556 |
+
"deepseek-v3": "deepseek-ai/DeepSeek-V3",
|
3557 |
"granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
|
3558 |
"granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
|
3559 |
},
|
llm_as_judge_constants.py
CHANGED
@@ -125,7 +125,7 @@ EVALUATOR_TO_MODEL_ID = {
|
|
125 |
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
126 |
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
127 |
EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
|
128 |
-
EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-
|
129 |
EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
|
130 |
EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
|
131 |
}
|
@@ -198,7 +198,6 @@ EVALUATORS_METADATA = [
|
|
198 |
[
|
199 |
ModelProviderEnum.WATSONX,
|
200 |
ModelProviderEnum.TOGETHER_AI,
|
201 |
-
ModelProviderEnum.RITS,
|
202 |
ModelProviderEnum.OLLAMA,
|
203 |
],
|
204 |
),
|
|
|
125 |
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
126 |
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
127 |
EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
|
128 |
+
EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-v3",
|
129 |
EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
|
130 |
EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
|
131 |
}
|
|
|
198 |
[
|
199 |
ModelProviderEnum.WATSONX,
|
200 |
ModelProviderEnum.TOGETHER_AI,
|
|
|
201 |
ModelProviderEnum.OLLAMA,
|
202 |
],
|
203 |
),
|
loaders.py
CHANGED
@@ -66,7 +66,7 @@ from tqdm import tqdm
|
|
66 |
|
67 |
from .dataclass import NonPositionalField
|
68 |
from .dict_utils import dict_get
|
69 |
-
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
70 |
from .fusion import FixedFusion
|
71 |
from .logging_utils import get_logger
|
72 |
from .operator import SourceOperator
|
@@ -90,23 +90,27 @@ class UnitxtUnverifiedCodeError(UnitxtError):
|
|
90 |
|
91 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
92 |
def hf_load_dataset(path: str, *args, **kwargs):
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
110 |
|
111 |
|
112 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
@@ -218,13 +222,15 @@ class Loader(SourceOperator):
|
|
218 |
pass
|
219 |
|
220 |
def load_data(self) -> MultiStream:
|
221 |
-
|
|
|
|
|
|
|
|
|
222 |
iterables = self.load_iterables()
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
return iterables
|
227 |
-
return MultiStream.from_iterables(iterables, copying=True)
|
228 |
|
229 |
def process(self) -> MultiStream:
|
230 |
self._maybe_set_classification_policy()
|
@@ -514,9 +520,13 @@ class LoadCSV(LoadWithPandas):
|
|
514 |
sep: str = ","
|
515 |
|
516 |
def read_dataframe(self, file) -> pd.DataFrame:
|
517 |
-
|
518 |
-
|
519 |
-
|
|
|
|
|
|
|
|
|
520 |
|
521 |
|
522 |
def read_file(source) -> bytes:
|
@@ -560,32 +570,36 @@ class LoadJsonFile(LoadWithPandas):
|
|
560 |
data_field: Optional[str] = None
|
561 |
|
562 |
def read_dataframe(self, file) -> pd.DataFrame:
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
)
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
instances = data
|
577 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
578 |
raise UnitxtError(
|
579 |
-
|
580 |
)
|
581 |
-
|
582 |
-
|
583 |
-
if self.data_field is not None:
|
584 |
-
raise UnitxtError(
|
585 |
-
"Can not load from a specific 'data_field' when loading multiple lines (lines=True)"
|
586 |
-
)
|
587 |
-
dataframe = pd.read_json(file, lines=self.lines, **args)
|
588 |
-
return dataframe
|
589 |
|
590 |
|
591 |
class LoadFromSklearn(LazyLoader):
|
@@ -631,8 +645,12 @@ class LoadFromSklearn(LazyLoader):
|
|
631 |
dataset_id = str(self) + "_" + split
|
632 |
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
633 |
if dataset is None:
|
634 |
-
|
635 |
-
|
|
|
|
|
|
|
|
|
636 |
df = pd.DataFrame([split_data["data"], targets]).T
|
637 |
df.columns = ["data", "target"]
|
638 |
dataset = df.to_dict("records")
|
@@ -851,18 +869,22 @@ class LoadFromIBMCloud(Loader):
|
|
851 |
if self.data_dir is not None
|
852 |
else data_file
|
853 |
)
|
854 |
-
with
|
855 |
-
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
|
865 |
-
|
|
|
|
|
|
|
|
|
866 |
|
867 |
if isinstance(self.data_files, list):
|
868 |
dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
|
@@ -946,22 +968,26 @@ class LoadFromDictionary(Loader):
|
|
946 |
|
947 |
def verify(self):
|
948 |
super().verify()
|
949 |
-
|
950 |
-
|
951 |
-
|
952 |
-
|
953 |
-
|
954 |
-
|
955 |
-
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
-
for
|
960 |
-
if
|
961 |
-
raise ValueError(
|
962 |
-
|
963 |
-
|
964 |
-
)
|
|
|
|
|
|
|
|
|
965 |
|
966 |
def _maybe_set_classification_policy(self):
|
967 |
self.set_default_data_classification(
|
@@ -1127,7 +1153,7 @@ class LoadFromAPI(Loader):
|
|
1127 |
chunksize: int = 100000
|
1128 |
loader_limit: Optional[int] = None
|
1129 |
streaming: bool = False
|
1130 |
-
api_key_env_var: Optional[str] =
|
1131 |
headers: Optional[Dict[str, Any]] = None
|
1132 |
data_field: str = "data"
|
1133 |
method: str = "GET"
|
|
|
66 |
|
67 |
from .dataclass import NonPositionalField
|
68 |
from .dict_utils import dict_get
|
69 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
70 |
from .fusion import FixedFusion
|
71 |
from .logging_utils import get_logger
|
72 |
from .operator import SourceOperator
|
|
|
90 |
|
91 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
92 |
def hf_load_dataset(path: str, *args, **kwargs):
|
93 |
+
with error_context(
|
94 |
+
stage="Raw Dataset Download",
|
95 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
96 |
+
):
|
97 |
+
if settings.hf_offline_datasets_path is not None:
|
98 |
+
path = os.path.join(settings.hf_offline_datasets_path, path)
|
99 |
+
try:
|
100 |
+
return _hf_load_dataset(
|
101 |
+
path,
|
102 |
+
*args,
|
103 |
+
**kwargs,
|
104 |
+
verification_mode="no_checks",
|
105 |
+
trust_remote_code=settings.allow_unverified_code,
|
106 |
+
download_mode="force_redownload"
|
107 |
+
if settings.disable_hf_datasets_cache
|
108 |
+
else "reuse_dataset_if_exists",
|
109 |
+
)
|
110 |
+
except ValueError as e:
|
111 |
+
if "trust_remote_code" in str(e):
|
112 |
+
raise UnitxtUnverifiedCodeError(path) from e
|
113 |
+
raise e # Re raise
|
114 |
|
115 |
|
116 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
|
222 |
pass
|
223 |
|
224 |
def load_data(self) -> MultiStream:
|
225 |
+
with error_context(
|
226 |
+
self,
|
227 |
+
stage="Data Loading",
|
228 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
229 |
+
):
|
230 |
iterables = self.load_iterables()
|
231 |
+
if isoftype(iterables, MultiStream):
|
232 |
+
return iterables
|
233 |
+
return MultiStream.from_iterables(iterables, copying=True)
|
|
|
|
|
234 |
|
235 |
def process(self) -> MultiStream:
|
236 |
self._maybe_set_classification_policy()
|
|
|
520 |
sep: str = ","
|
521 |
|
522 |
def read_dataframe(self, file) -> pd.DataFrame:
|
523 |
+
with error_context(
|
524 |
+
stage="Raw Dataset Loading",
|
525 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
526 |
+
):
|
527 |
+
return pd.read_csv(
|
528 |
+
file, sep=self.sep, low_memory=self.streaming, **self.get_args()
|
529 |
+
)
|
530 |
|
531 |
|
532 |
def read_file(source) -> bytes:
|
|
|
570 |
data_field: Optional[str] = None
|
571 |
|
572 |
def read_dataframe(self, file) -> pd.DataFrame:
|
573 |
+
with error_context(
|
574 |
+
stage="Raw Dataset Loading",
|
575 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
576 |
+
):
|
577 |
+
args = self.get_args()
|
578 |
+
if not self.lines:
|
579 |
+
data = json.loads(read_file(file))
|
580 |
+
if self.data_field:
|
581 |
+
instances = dict_get(data, self.data_field)
|
582 |
+
if not isoftype(instances, List[Dict[str, Any]]):
|
583 |
+
raise UnitxtError(
|
584 |
+
f"{self.data_field} of file {file} is not a list of dictionariess in LoadJsonFile loader"
|
585 |
+
)
|
|
|
586 |
else:
|
587 |
+
if isoftype(data, Dict[str, Any]):
|
588 |
+
instances = [data]
|
589 |
+
elif isoftype(data, List[Dict[str, Any]]):
|
590 |
+
instances = data
|
591 |
+
else:
|
592 |
+
raise UnitxtError(
|
593 |
+
f"data of file {file} is not dictionary or a list of dictionaries in LoadJsonFile loader"
|
594 |
+
)
|
595 |
+
dataframe = pd.DataFrame(instances)
|
596 |
+
else:
|
597 |
+
if self.data_field is not None:
|
598 |
raise UnitxtError(
|
599 |
+
"Can not load from a specific 'data_field' when loading multiple lines (lines=True)"
|
600 |
)
|
601 |
+
dataframe = pd.read_json(file, lines=self.lines, **args)
|
602 |
+
return dataframe
|
|
|
|
|
|
|
|
|
|
|
|
|
603 |
|
604 |
|
605 |
class LoadFromSklearn(LazyLoader):
|
|
|
645 |
dataset_id = str(self) + "_" + split
|
646 |
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
647 |
if dataset is None:
|
648 |
+
with error_context(
|
649 |
+
stage="Raw Dataset Loading",
|
650 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
651 |
+
):
|
652 |
+
split_data = self.downloader(subset=split)
|
653 |
+
targets = [split_data["target_names"][t] for t in split_data["target"]]
|
654 |
df = pd.DataFrame([split_data["data"], targets]).T
|
655 |
df.columns = ["data", "target"]
|
656 |
dataset = df.to_dict("records")
|
|
|
869 |
if self.data_dir is not None
|
870 |
else data_file
|
871 |
)
|
872 |
+
with error_context(
|
873 |
+
stage="Raw Dataset Download",
|
874 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
875 |
+
):
|
876 |
+
with tempfile.NamedTemporaryFile() as temp_file:
|
877 |
+
# Download to a temporary file in same file partition, and then do an atomic move
|
878 |
+
self._download_from_cos(
|
879 |
+
cos,
|
880 |
+
self.bucket_name,
|
881 |
+
object_key,
|
882 |
+
local_dir + "/" + os.path.basename(temp_file.name),
|
883 |
+
)
|
884 |
+
os.renames(
|
885 |
+
local_dir + "/" + os.path.basename(temp_file.name),
|
886 |
+
local_dir + "/" + data_file,
|
887 |
+
)
|
888 |
|
889 |
if isinstance(self.data_files, list):
|
890 |
dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
|
|
|
968 |
|
969 |
def verify(self):
|
970 |
super().verify()
|
971 |
+
with error_context(
|
972 |
+
stage="Dataset Loading",
|
973 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
974 |
+
):
|
975 |
+
if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
|
976 |
+
raise ValueError(
|
977 |
+
f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
|
978 |
+
f"Expected data should map between split name and list of instances.\n"
|
979 |
+
f"Received value: {self.data}\n"
|
980 |
+
)
|
981 |
+
for split in self.data.keys():
|
982 |
+
if len(self.data[split]) == 0:
|
983 |
+
raise ValueError(f"Split {split} has no instances.")
|
984 |
+
first_instance = self.data[split][0]
|
985 |
+
for instance in self.data[split]:
|
986 |
+
if instance.keys() != first_instance.keys():
|
987 |
+
raise ValueError(
|
988 |
+
f"Not all instances in split '{split}' have the same fields.\n"
|
989 |
+
f"instance {instance} has different fields different from {first_instance}"
|
990 |
+
)
|
991 |
|
992 |
def _maybe_set_classification_policy(self):
|
993 |
self.set_default_data_classification(
|
|
|
1153 |
chunksize: int = 100000
|
1154 |
loader_limit: Optional[int] = None
|
1155 |
streaming: bool = False
|
1156 |
+
api_key_env_var: Optional[str] = None
|
1157 |
headers: Optional[Dict[str, Any]] = None
|
1158 |
data_field: str = "data"
|
1159 |
method: str = "GET"
|
metric.py
CHANGED
@@ -56,7 +56,6 @@ from .settings_utils import get_constants
|
|
56 |
from .span_lableing_operators import __file__ as _
|
57 |
from .split_utils import __file__ as _
|
58 |
from .splitters import __file__ as _
|
59 |
-
from .sql_utils import __file__ as _
|
60 |
from .standard import __file__ as _
|
61 |
from .stream import __file__ as _
|
62 |
from .stream_operators import __file__ as _
|
@@ -65,6 +64,7 @@ from .struct_data_operators import __file__ as _
|
|
65 |
from .system_prompts import __file__ as _
|
66 |
from .task import __file__ as _
|
67 |
from .templates import __file__ as _
|
|
|
68 |
from .text_utils import __file__ as _
|
69 |
from .type_utils import __file__ as _
|
70 |
from .types import __file__ as _
|
|
|
56 |
from .span_lableing_operators import __file__ as _
|
57 |
from .split_utils import __file__ as _
|
58 |
from .splitters import __file__ as _
|
|
|
59 |
from .standard import __file__ as _
|
60 |
from .stream import __file__ as _
|
61 |
from .stream_operators import __file__ as _
|
|
|
64 |
from .system_prompts import __file__ as _
|
65 |
from .task import __file__ as _
|
66 |
from .templates import __file__ as _
|
67 |
+
from .text2sql_utils import __file__ as _
|
68 |
from .text_utils import __file__ as _
|
69 |
from .type_utils import __file__ as _
|
70 |
from .types import __file__ as _
|
metric_utils.py
CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
|
|
9 |
from datasets import Features, Value
|
10 |
|
11 |
from .dataclass import Dataclass
|
12 |
-
from .error_utils import Documentation, UnitxtError
|
13 |
from .operator import (
|
14 |
InstanceOperator,
|
15 |
MultiStreamOperator,
|
@@ -36,6 +36,9 @@ from .utils import recursive_copy
|
|
36 |
|
37 |
constants = get_constants()
|
38 |
|
|
|
|
|
|
|
39 |
|
40 |
def nan_mean(scores):
|
41 |
result = mean(score for score in scores if score == score)
|
@@ -56,7 +59,10 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
|
|
56 |
yield {**original, "prediction": prediction}
|
57 |
|
58 |
def process(
|
59 |
-
self,
|
|
|
|
|
|
|
60 |
) -> MultiStream:
|
61 |
return MultiStream(
|
62 |
{
|
@@ -152,7 +158,7 @@ class SplitSubsetsAndGroups(MultiStreamOperator):
|
|
152 |
|
153 |
subset_stream_name = (
|
154 |
stream_name
|
155 |
-
+
|
156 |
+ "/".join(instance[self.subsets_field][: self.subset_depth])
|
157 |
)
|
158 |
|
@@ -190,7 +196,7 @@ def group_str_to_key_value(group_str):
|
|
190 |
|
191 |
@lru_cache(maxsize=None)
|
192 |
def stream_name_to_origin_subset_group(stream_name):
|
193 |
-
origin, subset_group = stream_name.split(
|
194 |
if "?" in subset_group:
|
195 |
subset, group = subset_group.split("?")
|
196 |
else:
|
@@ -734,22 +740,23 @@ def _compute(
|
|
734 |
predictions: List[Any],
|
735 |
references: Iterable,
|
736 |
flatten: bool = False,
|
737 |
-
split_name: str =
|
738 |
calc_confidence_intervals: bool = True,
|
739 |
):
|
740 |
_reset_env_local_catalogs()
|
741 |
register_all_artifacts()
|
742 |
recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)
|
743 |
|
744 |
-
|
745 |
-
|
746 |
-
|
|
|
747 |
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
|
752 |
-
|
753 |
return EvaluationResults(stream)
|
754 |
|
755 |
|
|
|
9 |
from datasets import Features, Value
|
10 |
|
11 |
from .dataclass import Dataclass
|
12 |
+
from .error_utils import Documentation, UnitxtError, error_context
|
13 |
from .operator import (
|
14 |
InstanceOperator,
|
15 |
MultiStreamOperator,
|
|
|
36 |
|
37 |
constants = get_constants()
|
38 |
|
39 |
+
DEFAULT_STREAM_NAME = "all_data"
|
40 |
+
DEFAULT_STREAM_SUBSET_SEPARATOR = ">>"
|
41 |
+
|
42 |
|
43 |
def nan_mean(scores):
|
44 |
result = mean(score for score in scores if score == score)
|
|
|
59 |
yield {**original, "prediction": prediction}
|
60 |
|
61 |
def process(
|
62 |
+
self,
|
63 |
+
predictions: List[str],
|
64 |
+
references: Iterable,
|
65 |
+
split_name: str = DEFAULT_STREAM_NAME,
|
66 |
) -> MultiStream:
|
67 |
return MultiStream(
|
68 |
{
|
|
|
158 |
|
159 |
subset_stream_name = (
|
160 |
stream_name
|
161 |
+
+ DEFAULT_STREAM_SUBSET_SEPARATOR
|
162 |
+ "/".join(instance[self.subsets_field][: self.subset_depth])
|
163 |
)
|
164 |
|
|
|
196 |
|
197 |
@lru_cache(maxsize=None)
|
198 |
def stream_name_to_origin_subset_group(stream_name):
|
199 |
+
origin, subset_group = stream_name.split(DEFAULT_STREAM_SUBSET_SEPARATOR)
|
200 |
if "?" in subset_group:
|
201 |
subset, group = subset_group.split("?")
|
202 |
else:
|
|
|
740 |
predictions: List[Any],
|
741 |
references: Iterable,
|
742 |
flatten: bool = False,
|
743 |
+
split_name: str = DEFAULT_STREAM_NAME,
|
744 |
calc_confidence_intervals: bool = True,
|
745 |
):
|
746 |
_reset_env_local_catalogs()
|
747 |
register_all_artifacts()
|
748 |
recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)
|
749 |
|
750 |
+
with error_context(stage="Metric Processing"):
|
751 |
+
multi_stream = recipe(
|
752 |
+
predictions=predictions, references=references, split_name=split_name
|
753 |
+
)
|
754 |
|
755 |
+
if flatten:
|
756 |
+
operator = FlattenInstances()
|
757 |
+
multi_stream = operator(multi_stream)
|
758 |
|
759 |
+
stream = multi_stream[split_name]
|
760 |
return EvaluationResults(stream)
|
761 |
|
762 |
|
metrics.py
CHANGED
@@ -8,7 +8,8 @@ import uuid
|
|
8 |
import warnings
|
9 |
from abc import ABC, abstractmethod
|
10 |
from collections import Counter, defaultdict
|
11 |
-
from dataclasses import field
|
|
|
12 |
from enum import Enum
|
13 |
from functools import lru_cache
|
14 |
from typing import (
|
@@ -42,7 +43,8 @@ from .dataclass import (
|
|
42 |
OptionalField,
|
43 |
)
|
44 |
from .deprecation_utils import deprecation
|
45 |
-
from .
|
|
|
46 |
from .inference import (
|
47 |
HFPipelineBasedInferenceEngine,
|
48 |
InferenceEngine,
|
@@ -64,6 +66,7 @@ from .operators import ArtifactFetcherMixin, Copy, FieldOperator, Set
|
|
64 |
from .random_utils import get_seed
|
65 |
from .settings_utils import get_settings
|
66 |
from .stream import MultiStream, Stream
|
|
|
67 |
from .type_utils import isoftype, parse_type_string, to_type_string
|
68 |
from .types import ToolCall
|
69 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
@@ -382,28 +385,35 @@ class MapReduceMetric(
|
|
382 |
return intermediates
|
383 |
|
384 |
def process(self, stream: Stream, stream_name: Optional[str] = None):
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
|
400 |
-
|
401 |
-
|
402 |
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
|
408 |
def compute(self, stream: Stream, stream_name: Optional[str] = None):
|
409 |
evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
|
@@ -453,6 +463,43 @@ class DictReduction(AggregationReduction[Dict[str, float]]):
|
|
453 |
return result
|
454 |
|
455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
class MeanReduction(DictReduction):
|
457 |
def reduce_list(self, lst: List[float]):
|
458 |
return nan_mean(lst)
|
@@ -468,6 +515,91 @@ class MaxReduction(DictReduction):
|
|
468 |
return float(nan_max(lst))
|
469 |
|
470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
class ReductionInstanceMetric(
|
472 |
MapReduceMetric[PredictionType, IntermediateType],
|
473 |
Generic[PredictionType, IntermediateType],
|
@@ -704,6 +836,52 @@ class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
|
704 |
}
|
705 |
|
706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
707 |
class MetricWithConfidenceInterval(Metric):
|
708 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
709 |
# Use None to disable confidence interval computation.
|
@@ -954,83 +1132,88 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
954 |
process_single_instances = True
|
955 |
|
956 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
957 |
-
|
958 |
-
|
959 |
-
|
|
|
|
|
|
|
|
|
|
|
960 |
|
961 |
-
|
962 |
|
963 |
-
|
964 |
-
|
965 |
|
966 |
-
|
967 |
-
|
968 |
|
969 |
-
|
970 |
-
|
971 |
-
|
972 |
-
|
973 |
|
974 |
-
|
975 |
-
|
976 |
-
|
977 |
|
978 |
-
|
979 |
-
|
980 |
-
|
981 |
-
|
982 |
-
|
983 |
|
984 |
-
|
985 |
-
|
986 |
-
|
987 |
-
|
988 |
-
|
989 |
-
|
990 |
-
|
991 |
-
|
992 |
-
|
993 |
-
|
994 |
-
|
995 |
-
|
996 |
-
|
997 |
-
|
998 |
-
|
999 |
-
|
1000 |
|
1001 |
-
|
1002 |
-
|
1003 |
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
|
|
1007 |
)
|
1008 |
-
)
|
1009 |
-
|
1010 |
-
global_score = {"num_of_instances": len(instances)}
|
1011 |
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
|
|
1016 |
)
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
|
1022 |
-
|
1023 |
-
score_names = [global_score["score_name"]]
|
1024 |
|
1025 |
-
|
1026 |
-
|
1027 |
-
|
1028 |
-
|
1029 |
-
|
1030 |
|
1031 |
-
|
1032 |
-
|
1033 |
-
|
1034 |
|
1035 |
def _compute(
|
1036 |
self,
|
@@ -1080,96 +1263,105 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1080 |
return instance
|
1081 |
|
1082 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
1091 |
-
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
1095 |
-
|
1096 |
-
|
1097 |
-
|
1098 |
-
|
1099 |
-
|
1100 |
-
predictions
|
1101 |
-
|
1102 |
-
|
|
|
|
|
|
|
|
|
|
|
1103 |
|
1104 |
-
|
1105 |
-
|
1106 |
-
|
1107 |
-
|
1108 |
|
1109 |
-
|
1110 |
-
|
1111 |
-
|
1112 |
|
1113 |
-
|
1114 |
-
|
1115 |
-
|
|
|
1116 |
)
|
1117 |
-
)
|
1118 |
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
-
|
1133 |
-
|
1134 |
-
|
1135 |
-
|
|
|
|
|
1136 |
|
1137 |
-
|
1138 |
-
|
1139 |
-
|
1140 |
-
|
1141 |
-
|
1142 |
-
|
1143 |
-
|
1144 |
-
|
1145 |
-
|
1146 |
-
|
1147 |
-
|
1148 |
-
|
1149 |
-
|
1150 |
-
|
1151 |
-
|
1152 |
-
|
1153 |
-
|
1154 |
-
|
1155 |
-
|
1156 |
-
|
1157 |
-
|
1158 |
-
|
1159 |
-
|
1160 |
-
|
1161 |
-
|
1162 |
-
|
1163 |
-
|
1164 |
-
|
1165 |
-
|
1166 |
-
|
1167 |
-
|
1168 |
-
|
|
|
|
|
1169 |
|
1170 |
-
|
1171 |
-
|
1172 |
-
|
1173 |
|
1174 |
@abstractmethod
|
1175 |
def compute(
|
@@ -1475,91 +1667,97 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
1475 |
assert isinstance(fields["score_fields"], list)
|
1476 |
|
1477 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1478 |
-
|
1479 |
-
|
1480 |
-
|
1481 |
-
|
1482 |
-
|
1483 |
-
|
1484 |
-
|
1485 |
-
|
1486 |
-
|
1487 |
-
|
1488 |
-
|
1489 |
-
|
1490 |
-
|
1491 |
-
#
|
1492 |
-
|
1493 |
-
|
1494 |
-
|
1495 |
-
|
1496 |
-
|
1497 |
-
|
1498 |
-
|
1499 |
-
|
1500 |
-
|
1501 |
-
|
1502 |
-
|
1503 |
-
|
1504 |
-
|
1505 |
-
|
1506 |
-
|
1507 |
-
|
1508 |
-
|
1509 |
-
|
1510 |
-
|
1511 |
-
|
1512 |
-
|
1513 |
-
|
1514 |
-
|
1515 |
-
|
1516 |
-
|
1517 |
-
|
1518 |
-
|
1519 |
-
|
1520 |
-
|
1521 |
-
|
1522 |
-
|
1523 |
-
|
|
|
|
|
|
|
|
|
|
|
1524 |
|
1525 |
-
|
1526 |
-
|
1527 |
-
|
1528 |
-
|
1529 |
-
|
1530 |
-
|
1531 |
-
|
1532 |
-
|
1533 |
-
|
1534 |
-
|
1535 |
-
|
1536 |
-
|
1537 |
-
|
1538 |
-
|
1539 |
-
|
1540 |
-
|
1541 |
-
|
1542 |
-
|
1543 |
-
|
1544 |
-
|
1545 |
-
|
1546 |
-
|
1547 |
-
|
1548 |
-
|
1549 |
-
|
1550 |
-
|
1551 |
-
|
1552 |
-
|
1553 |
-
|
1554 |
-
|
1555 |
-
|
|
|
1556 |
|
1557 |
-
|
1558 |
-
|
1559 |
|
1560 |
-
|
1561 |
-
|
1562 |
-
|
1563 |
|
1564 |
def compute_instance_scores(
|
1565 |
self, stream: Stream, stream_name: Optional[str] = None
|
@@ -6436,391 +6634,102 @@ RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = {
|
|
6436 |
}
|
6437 |
|
6438 |
|
6439 |
-
class
|
6440 |
-
|
6441 |
-
|
6442 |
-
|
6443 |
-
|
6444 |
-
"subset_non_empty_execution_result",
|
6445 |
-
"non_empty_gold_df",
|
6446 |
-
"gold_sql_runtime",
|
6447 |
-
"predicted_sql_runtime",
|
6448 |
-
"pred_to_gold_runtime_ratio",
|
6449 |
-
"gold_error",
|
6450 |
-
"predicted_error",
|
6451 |
-
]
|
6452 |
-
}
|
6453 |
main_score = "non_empty_execution_accuracy"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6454 |
ci_scores = [
|
6455 |
"execution_accuracy",
|
6456 |
"non_empty_execution_accuracy",
|
6457 |
-
"
|
|
|
6458 |
"gold_sql_runtime",
|
6459 |
"predicted_sql_runtime",
|
6460 |
]
|
6461 |
|
6462 |
-
|
6463 |
-
|
6464 |
-
|
6465 |
-
|
6466 |
-
|
6467 |
-
|
6468 |
-
|
6469 |
-
"""Compares two DataFrames based on row content, ignoring column names.
|
6470 |
-
|
6471 |
-
Args:
|
6472 |
-
df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
|
6473 |
-
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
6474 |
-
|
6475 |
-
Returns:
|
6476 |
-
True if the DataFrames have the same ordered rows (ignoring column names),
|
6477 |
-
False otherwise.
|
6478 |
-
"""
|
6479 |
-
df1.fillna(0, inplace=True)
|
6480 |
-
df2.fillna(0, inplace=True)
|
6481 |
-
|
6482 |
-
# Compare row counts first for a quick check
|
6483 |
-
if df1.shape != df2.shape:
|
6484 |
-
return False
|
6485 |
-
|
6486 |
-
# Convert DataFrames to numpy arrays of strings to handle mixed types
|
6487 |
-
df1_array = df1.values.astype(str)
|
6488 |
-
df2_array = df2.values.astype(str)
|
6489 |
-
|
6490 |
-
# Sort each row's elements (column order independence)
|
6491 |
-
df1_sorted_rows = np.array([np.sort(row) for row in df1_array])
|
6492 |
-
df2_sorted_rows = np.array([np.sort(row) for row in df2_array])
|
6493 |
-
|
6494 |
-
# Compare the sorted rows in order
|
6495 |
-
return np.array_equal(df1_sorted_rows, df2_sorted_rows)
|
6496 |
-
|
6497 |
-
@staticmethod
|
6498 |
-
def compare_dfs_ignore_colnames_unordered_rows(df1, df2):
    """Compare two DataFrames by row content, ignoring row order and column names.

    Every cell is converted to its string form (to cope with mixed dtypes),
    the values inside each row are sorted (column-order independence), and the
    resulting rows are compared as multisets (row-order independence).

    Args:
        df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
        df2 (pd.DataFrame): Pandas DataFrame 2 to compare.

    Returns:
        True if the DataFrames hold the same rows (ignoring column names,
        column order and row order), False otherwise.
    """
    # Different shapes can never match; bail out before any conversion work.
    if df1.shape != df2.shape:
        return False

    def canonical_rows(df):
        # Sort the values inside each row, then order the rows as whole
        # tuples. Sorting with axis=0 instead would sort every column
        # independently and tear rows apart, so frames with different rows
        # (e.g. {(1, 4), (2, 3)} vs {(1, 3), (2, 4)}) would compare equal.
        cells = df.values.astype(str)
        return sorted(tuple(row) for row in np.sort(cells, axis=1).tolist())

    return canonical_rows(df1) == canonical_rows(df2)
|
6523 |
-
|
6524 |
-
@staticmethod
|
6525 |
-
def compare_dfs_ignore_colnames_subset(df1, df2, ignore_row_order=True):
|
6526 |
-
"""Checks if the values of either DataFrame are a subset of the values in the other DataFrame.
|
6527 |
-
|
6528 |
-
Comparison is column order independent, and could optionally be row order independent.
|
6529 |
-
We interpret "subset" as follows:
|
6530 |
-
|
6531 |
-
- For each row in df1, there must be a matching (or superset) row in df2, i.e. the set of values
|
6532 |
-
in the df1 row is a subset of the set of values in that df2 row. Then do the same check in reverse.
|
6533 |
-
- If either condition (df1 is subset of df2 OR df2 is subset of df1) is satisfied, return True.
|
6534 |
-
|
6535 |
-
We treat an empty dataframe as a subset of nothing, while in theory is a subset of any dataframe.
|
6536 |
-
|
6537 |
-
Args:
|
6538 |
-
df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
|
6539 |
-
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
6540 |
-
ignore_row_order (bool): If True, row order doesn't matter; if False, row order is respected.
|
6541 |
-
|
6542 |
-
Returns:
|
6543 |
-
bool: True if df1 is a subset of df2 or vice versa, based on the specified row-order condition.
|
6544 |
-
|
6545 |
-
"""
|
6546 |
-
df1_array = df1.values.astype(str)
|
6547 |
-
df2_array = df2.values.astype(str)
|
6548 |
-
|
6549 |
-
df1_sorted_rows = [np.sort(row) for row in df1_array]
|
6550 |
-
df2_sorted_rows = [np.sort(row) for row in df2_array]
|
6551 |
-
|
6552 |
-
def row_is_subset(r_small, r_big):
|
6553 |
-
"""Check if all elements of r_small are in r_big."""
|
6554 |
-
return set(r_small).issubset(set(r_big))
|
6555 |
-
|
6556 |
-
def df_is_subset_of_another(rows_small, rows_big, respect_order):
|
6557 |
-
"""Check if the rows_small is subset of rows_big under the given order condition."""
|
6558 |
-
if not rows_small:
|
6559 |
-
return False # DataFrame needs to be non-empty
|
6560 |
-
|
6561 |
-
# If row order matters:
|
6562 |
-
if respect_order:
|
6563 |
-
i, j = 0, 0
|
6564 |
-
while i < len(rows_small) and j < len(rows_big):
|
6565 |
-
if row_is_subset(rows_small[i], rows_big[j]):
|
6566 |
-
i += 1
|
6567 |
-
j += 1
|
6568 |
-
return i == len(rows_small)
|
6569 |
-
# Row order doesn't matter:
|
6570 |
-
matched_indices = set()
|
6571 |
-
for r_small in rows_small:
|
6572 |
-
found_match = False
|
6573 |
-
for idx, r_big in enumerate(rows_big):
|
6574 |
-
if idx not in matched_indices and row_is_subset(r_small, r_big):
|
6575 |
-
found_match = True
|
6576 |
-
matched_indices.add(idx)
|
6577 |
-
break
|
6578 |
-
if not found_match:
|
6579 |
-
return False
|
6580 |
-
return True
|
6581 |
-
|
6582 |
-
df1_sub_df2 = df_is_subset_of_another(
|
6583 |
-
df1_sorted_rows, df2_sorted_rows, not ignore_row_order
|
6584 |
-
)
|
6585 |
-
df2_sub_df1 = df_is_subset_of_another(
|
6586 |
-
df2_sorted_rows, df1_sorted_rows, not ignore_row_order
|
6587 |
)
|
6588 |
|
6589 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6590 |
|
6591 |
-
|
6592 |
-
|
6593 |
-
|
6594 |
-
|
6595 |
|
6596 |
-
|
6597 |
-
|
6598 |
-
|
6599 |
-
|
|
|
6600 |
|
6601 |
-
Returns:
|
6602 |
-
a 12-tuple of
|
6603 |
-
1. execution_result: if df responses match
|
6604 |
-
2. non_empty_execution_result: if dfs are non-empty and match
|
6605 |
-
3. subset_non_empty_execution_result: if non-empty dfs and one is a subset of the other
|
6606 |
-
4. non_empty_gold_df: if gt df is non-empty
|
6607 |
-
5. gold_sql_runtime: ground truth query runtime
|
6608 |
-
6. predicted_sql_runtime: predicted query runtime
|
6609 |
-
7. pred_to_gold_runtime_ratio: ratio of predicted query runtime to gt query runtime
|
6610 |
-
8. gold_error: if gt has an error
|
6611 |
-
9. predicted_error: if predicted query has an error
|
6612 |
-
10. ground truth dataframe
|
6613 |
-
11. predicted query's dataframe
|
6614 |
-
12. error message (if any)
|
6615 |
-
"""
|
6616 |
-
import time
|
6617 |
|
6618 |
-
|
6619 |
-
|
|
|
|
|
6620 |
|
6621 |
-
|
6622 |
|
6623 |
-
|
6624 |
-
|
6625 |
-
|
6626 |
-
|
6627 |
-
|
6628 |
-
gold_res, gold_error = func_timeout(
|
6629 |
-
self.sql_timeout,
|
6630 |
-
connector.execute_query,
|
6631 |
-
args=(gold_sql,),
|
6632 |
-
)
|
6633 |
-
end_time = time.perf_counter()
|
6634 |
-
gold_sql_runtime = end_time - start_time
|
6635 |
-
except FunctionTimedOut as e:
|
6636 |
-
pred_error = f"Timeout error executing gold SQL: {e}"
|
6637 |
-
logger.warning(pred_error)
|
6638 |
-
except Exception as e:
|
6639 |
-
gold_error = f"Error executing gold SQL: {e}"
|
6640 |
-
if gold_error is not None:
|
6641 |
-
return (
|
6642 |
-
0,
|
6643 |
-
0,
|
6644 |
-
0,
|
6645 |
-
0,
|
6646 |
-
gold_sql_runtime,
|
6647 |
-
0,
|
6648 |
-
0,
|
6649 |
-
0,
|
6650 |
-
0,
|
6651 |
-
"",
|
6652 |
-
"",
|
6653 |
-
gold_error,
|
6654 |
-
)
|
6655 |
|
6656 |
-
|
6657 |
-
gold_res = gold_res["results"]
|
6658 |
-
gold_df = pd.DataFrame(gold_res)
|
6659 |
-
non_empty_gold_df = 0 if gold_df.empty else 1
|
6660 |
|
6661 |
-
|
6662 |
-
|
6663 |
-
|
6664 |
-
|
6665 |
-
|
6666 |
-
|
6667 |
-
|
6668 |
-
|
6669 |
-
0,
|
6670 |
-
0,
|
6671 |
-
gold_df.to_json(),
|
6672 |
-
"",
|
6673 |
-
"",
|
6674 |
-
)
|
6675 |
-
if predicted_sql.lower().strip() == gold_sql.lower().strip():
|
6676 |
-
return no_execution_match_result
|
6677 |
-
try:
|
6678 |
-
if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
|
6679 |
-
return no_execution_match_result
|
6680 |
-
except Exception as e: # Catch specific exceptions if possible
|
6681 |
-
logger.info(
|
6682 |
-
f"Couldn't test equivalent_sqls: {e}. Treating as non-equivalent and going to test with the db."
|
6683 |
-
)
|
6684 |
|
6685 |
-
|
6686 |
-
|
6687 |
-
|
6688 |
-
|
6689 |
-
|
6690 |
-
pred_res, pred_error = func_timeout(
|
6691 |
-
self.sql_timeout,
|
6692 |
-
connector.execute_query,
|
6693 |
-
args=(predicted_sql,),
|
6694 |
-
)
|
6695 |
-
end_time = time.perf_counter()
|
6696 |
-
pred_sql_runtime = end_time - start_time
|
6697 |
-
except FunctionTimedOut as e:
|
6698 |
-
pred_error = f"Timeout error executing predicted SQL: {e}"
|
6699 |
-
logger.info(pred_error)
|
6700 |
-
except Exception as e:
|
6701 |
-
pred_error = f"Error executing predicted SQL: {e}"
|
6702 |
-
logger.info(pred_error)
|
6703 |
-
|
6704 |
-
pred_to_gold_runtime_ratio = (
|
6705 |
-
float(pred_sql_runtime) / gold_sql_runtime if gold_sql_runtime > 0 else 0
|
6706 |
)
|
6707 |
|
6708 |
-
|
6709 |
-
|
6710 |
-
0,
|
6711 |
-
0,
|
6712 |
-
0,
|
6713 |
-
0,
|
6714 |
-
gold_sql_runtime,
|
6715 |
-
pred_sql_runtime,
|
6716 |
-
pred_to_gold_runtime_ratio,
|
6717 |
-
0,
|
6718 |
-
1,
|
6719 |
-
"",
|
6720 |
-
"",
|
6721 |
-
pred_error,
|
6722 |
-
)
|
6723 |
-
|
6724 |
-
if isinstance(pred_res, dict) and "results" in pred_res:
|
6725 |
-
pred_res = pred_res["results"]
|
6726 |
-
predicted_df = pd.DataFrame(pred_res)
|
6727 |
-
|
6728 |
-
subset_non_empty_execution_result = 0
|
6729 |
-
non_empty_execution_result = 0
|
6730 |
-
if "ORDER BY" in gold_sql.upper():
|
6731 |
-
execution_result = (
|
6732 |
-
1
|
6733 |
-
if self.compare_dfs_ignore_colnames_ordered_rows(predicted_df, gold_df)
|
6734 |
-
else 0
|
6735 |
-
)
|
6736 |
-
if non_empty_gold_df:
|
6737 |
-
if execution_result == 1:
|
6738 |
-
non_empty_execution_result = 1
|
6739 |
-
if self.compare_dfs_ignore_colnames_subset(
|
6740 |
-
gold_df, predicted_df, ignore_row_order=False
|
6741 |
-
):
|
6742 |
-
subset_non_empty_execution_result = 1
|
6743 |
-
else:
|
6744 |
-
execution_result = (
|
6745 |
-
1
|
6746 |
-
if self.compare_dfs_ignore_colnames_unordered_rows(
|
6747 |
-
predicted_df, gold_df
|
6748 |
-
)
|
6749 |
-
else 0
|
6750 |
-
)
|
6751 |
-
if non_empty_gold_df:
|
6752 |
-
if execution_result == 1:
|
6753 |
-
non_empty_execution_result = 1
|
6754 |
-
if self.compare_dfs_ignore_colnames_subset(
|
6755 |
-
gold_df, predicted_df, ignore_row_order=True
|
6756 |
-
):
|
6757 |
-
subset_non_empty_execution_result = 1
|
6758 |
|
6759 |
-
|
6760 |
-
|
6761 |
-
|
6762 |
-
subset_non_empty_execution_result,
|
6763 |
-
non_empty_gold_df,
|
6764 |
-
gold_sql_runtime,
|
6765 |
-
pred_sql_runtime,
|
6766 |
-
pred_to_gold_runtime_ratio,
|
6767 |
-
0,
|
6768 |
-
0,
|
6769 |
-
gold_df.to_json(),
|
6770 |
-
predicted_df.to_json(),
|
6771 |
-
pred_error,
|
6772 |
)
|
6773 |
|
6774 |
-
|
6775 |
-
from .sql_utils import get_db_connector
|
6776 |
-
|
6777 |
-
predicted_sql = prediction
|
6778 |
-
execution_result: float = 0.0
|
6779 |
-
|
6780 |
-
if predicted_sql and predicted_sql.strip() != "":
|
6781 |
-
if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
|
6782 |
-
predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
|
6783 |
-
if ";" in predicted_sql:
|
6784 |
-
predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
|
6785 |
-
|
6786 |
-
db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
|
6787 |
-
|
6788 |
-
logger.debug(
|
6789 |
-
f"Starting to get SQL execution results over DB: {task_data['db']}"
|
6790 |
-
)
|
6791 |
-
(
|
6792 |
-
execution_result,
|
6793 |
-
non_empty_execution_result,
|
6794 |
-
subset_non_empty_execution_result,
|
6795 |
-
non_empty_gold_df,
|
6796 |
-
gold_sql_runtime,
|
6797 |
-
predicted_sql_runtime,
|
6798 |
-
pred_to_gold_runtime_ratio,
|
6799 |
-
gold_error,
|
6800 |
-
predicted_error,
|
6801 |
-
gold_df_json,
|
6802 |
-
predicted_df_json,
|
6803 |
-
error_message,
|
6804 |
-
) = self.get_sql_execution_results(
|
6805 |
-
predicted_sql, references[0], db_connector
|
6806 |
-
)
|
6807 |
-
|
6808 |
-
result = {
|
6809 |
-
"execution_accuracy": float(execution_result),
|
6810 |
-
"non_empty_execution_accuracy": float(non_empty_execution_result),
|
6811 |
-
"subset_non_empty_execution_result": float(
|
6812 |
-
subset_non_empty_execution_result
|
6813 |
-
),
|
6814 |
-
"non_empty_gold_df": float(non_empty_gold_df),
|
6815 |
-
"gold_sql_runtime": float(gold_sql_runtime),
|
6816 |
-
"predicted_sql_runtime": float(predicted_sql_runtime),
|
6817 |
-
"pred_to_gold_runtime_ratio": float(pred_to_gold_runtime_ratio),
|
6818 |
-
"gold_error": float(gold_error),
|
6819 |
-
"predicted_error": float(predicted_error),
|
6820 |
-
"error_message": str(error_message),
|
6821 |
-
"gold_df_json": str(gold_df_json),
|
6822 |
-
"predicted_df_json": str(predicted_df_json),
|
6823 |
-
}
|
6824 |
result["score"] = result[self.main_score]
|
6825 |
result["score_name"] = self.main_score
|
6826 |
logger.debug(f"SQL Execution Accuracy Result: {result}")
|
@@ -6828,34 +6737,22 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
6828 |
|
6829 |
|
6830 |
class SQLNonExecutionAccuracy(InstanceMetric):
|
6831 |
-
|
6832 |
-
|
6833 |
-
|
6834 |
-
|
6835 |
-
"sqlglot_equivalence",
|
6836 |
-
"sqlglot_optimized_equivalence",
|
6837 |
-
"sqlparse_equivalence",
|
6838 |
-
"sql_exact_match",
|
6839 |
-
"sql_syntactic_equivalence",
|
6840 |
-
]
|
6841 |
-
}
|
6842 |
-
main_score = "sqlglot_equivalence"
|
6843 |
-
ci_scores = [
|
6844 |
-
"sqlglot_validity",
|
6845 |
-
"sqlparse_validity",
|
6846 |
-
"sqlglot_equivalence",
|
6847 |
-
"sqlglot_optimized_equivalence",
|
6848 |
-
"sqlparse_equivalence",
|
6849 |
-
"sql_exact_match",
|
6850 |
-
"sql_syntactic_equivalence",
|
6851 |
]
|
|
|
|
|
|
|
6852 |
|
6853 |
prediction_type = "Any" # string representation is compared
|
6854 |
|
6855 |
_requirements_list = ["sqlglot", "sqlparse"]
|
6856 |
|
6857 |
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
6858 |
-
from .
|
|
|
6859 |
is_sqlglot_parsable,
|
6860 |
is_sqlparse_parsable,
|
6861 |
sql_exact_match,
|
@@ -6864,48 +6761,45 @@ class SQLNonExecutionAccuracy(InstanceMetric):
|
|
6864 |
sqlparse_queries_equivalent,
|
6865 |
)
|
6866 |
|
6867 |
-
predicted_sql = prediction
|
6868 |
gold_sql = references[0]
|
6869 |
-
|
6870 |
-
if predicted_sql and predicted_sql.strip() != "":
|
6871 |
-
if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
|
6872 |
-
predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
|
6873 |
-
if ";" in predicted_sql:
|
6874 |
-
predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
|
6875 |
|
6876 |
is_sqlglot_parsable = is_sqlglot_parsable(predicted_sql)
|
6877 |
is_sqlparse_parsable = is_sqlparse_parsable(predicted_sql)
|
6878 |
-
|
6879 |
-
|
6880 |
-
|
6881 |
-
|
6882 |
sqlglot_parsed_queries_equivalent(predicted_sql, gold_sql)
|
6883 |
if is_sqlglot_parsable
|
6884 |
else 0
|
6885 |
),
|
6886 |
-
|
6887 |
sqlglot_optimized_equivalence(predicted_sql, gold_sql)
|
6888 |
if is_sqlglot_parsable
|
6889 |
else 0
|
6890 |
),
|
6891 |
-
|
6892 |
sqlparse_queries_equivalent(predicted_sql, gold_sql)
|
6893 |
if is_sqlparse_parsable
|
6894 |
else 0
|
6895 |
),
|
6896 |
-
|
6897 |
-
|
6898 |
-
|
|
|
|
|
6899 |
any(
|
6900 |
-
|
6901 |
-
|
6902 |
-
|
6903 |
-
|
6904 |
-
|
6905 |
-
"sql_exact_match",
|
6906 |
]
|
6907 |
)
|
6908 |
)
|
|
|
|
|
6909 |
logger.debug(f"SQL Non Execution Accuracy Result: {result}")
|
6910 |
result["score"] = result[self.main_score]
|
6911 |
result["score_name"] = self.main_score
|
|
|
8 |
import warnings
|
9 |
from abc import ABC, abstractmethod
|
10 |
from collections import Counter, defaultdict
|
11 |
+
from dataclasses import asdict, field
|
12 |
+
from dataclasses import fields as dataclasses_fields
|
13 |
from enum import Enum
|
14 |
from functools import lru_cache
|
15 |
from typing import (
|
|
|
43 |
OptionalField,
|
44 |
)
|
45 |
from .deprecation_utils import deprecation
|
46 |
+
from .dict_utils import dict_get
|
47 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
48 |
from .inference import (
|
49 |
HFPipelineBasedInferenceEngine,
|
50 |
InferenceEngine,
|
|
|
66 |
from .random_utils import get_seed
|
67 |
from .settings_utils import get_settings
|
68 |
from .stream import MultiStream, Stream
|
69 |
+
from .text2sql_utils import SQLExecutionResult, SQLNonExecutionMetricResult
|
70 |
from .type_utils import isoftype, parse_type_string, to_type_string
|
71 |
from .types import ToolCall
|
72 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
|
|
385 |
return intermediates
|
386 |
|
387 |
def process(self, stream: Stream, stream_name: Optional[str] = None):
    """Score the stream and yield every instance with its scores attached.

    Calls ``self.compute`` once for the whole stream, then walks the stream
    again in lockstep with the returned per-instance scores and yields a copy
    of each instance whose ``"score"`` entry holds the merged ``"global"`` and
    ``"instance"`` score dictionaries.

    Args:
        stream: The instances to evaluate. It is passed to ``compute`` and
            then iterated a second time here.
        stream_name: Forwarded unchanged to ``compute``.

    Yields:
        A new dict per instance: the original instance fields plus
        ``"score": {"global": ..., "instance": ...}``.
    """
    # Everything, including the yields, runs inside the error context.
    # NOTE(review): presumably errors raised here get annotated with the
    # stage/help info; the exact semantics live in error_utils.
    with error_context(
        self,
        stage="Evaluating Metric",
        help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
    ):
        instances_scores, global_scores = self.compute(stream, stream_name)
        for i, (instance, instance_scores) in enumerate(
            zip(stream, instances_scores)
        ):
            # Scores from earlier metrics, or empty placeholders when none exist.
            previous_score = instance.get("score", {"global": {}, "instance": {}})

            # The overwrite check is done once, against the first instance only.
            if i == 0:
                for key in global_scores:
                    # NOTE(review): is_original_key is defined elsewhere;
                    # presumably it filters out derived keys - confirm.
                    if is_original_key(key) and key in previous_score["global"]:
                        # NOTE(review): the warning object is only constructed;
                        # presumably construction emits it - confirm.
                        UnitxtWarning(
                            message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
                            f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
                            f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
                            f"which will yield, in this case, a score named: 'my_second_{key}')",
                            additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
                        )

            # Newly computed values win over previously recorded ones. Note that
            # global_scores is rebound, so it accumulates across iterations.
            global_scores = {**previous_score["global"], **global_scores}
            instance_scores = {**previous_score["instance"], **instance_scores}

            yield {
                **instance,
                "score": {"global": global_scores, "instance": instance_scores},
            }
|
417 |
|
418 |
def compute(self, stream: Stream, stream_name: Optional[str] = None):
|
419 |
evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
|
|
|
463 |
return result
|
464 |
|
465 |
|
466 |
+
class GroupReduction(AggregationReduction[Tuple[str, Dict[str, float]]]):
    """Reduce score dictionaries that are each tagged with an identifier.

    ``reduce`` regroups ``(identifier, {score_name: value})`` pairs by score
    name and passes every score's ``(identifier, value)`` list to
    ``reduce_list``, which subclasses override.
    """

    def reduce_list(self, lst: List[Tuple[str, float]]):
        """Collapse the ``(identifier, value)`` pairs of one score name.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def reduce(self, intermidates: Tuple[str, Dict[str, float]]):
        """Return ``{score_name: reduce_list(pairs)}`` for every score name seen.

        Args:
            intermidates: Iterable of ``(identifier, scores)`` pairs, where
                ``scores`` maps score names to values.
        """
        # Regroup by score name, keeping the identifier next to each value.
        pairs_by_score = {}
        for identifier, scores in intermidates:
            for score_name, value in scores.items():
                pairs_by_score.setdefault(score_name, []).append((identifier, value))

        return {
            score_name: self.reduce_list(pairs)
            for score_name, pairs in pairs_by_score.items()
        }
|
482 |
+
|
483 |
+
|
484 |
+
class GroupMean(GroupReduction):
    """Group reduction that averages the values with ``nan_mean``, ignoring the identifiers."""

    def reduce_list(self, lst: List[Tuple[str, float]]):
        """Return ``nan_mean`` of the second element of every pair in ``lst``."""
        values = [pair[1] for pair in lst]
        return nan_mean(values)
|
487 |
+
|
488 |
+
|
489 |
+
class SequentialSuccess(GroupReduction):
    """Fraction of items that succeed consecutively from the start of a group.

    Items are ordered by their identifier, then counted until the first one
    whose value is not strictly above ``threshold``. The count is divided by
    the total number of items.
    """

    # A value must be strictly greater than this to count as a success.
    threshold: float = 0.5

    def reduce_list(self, lst: List[Tuple[str, float]]):
        """Return the share of leading successes among ``lst``.

        Args:
            lst: ``(item_id, value)`` pairs of one score within one group.

        Returns:
            Number of consecutive successes from the first item (in
            identifier order) divided by ``len(lst)``.
        """
        # Local import keeps this block self-contained.
        import re

        def natural_key(pair):
            # Plain string ordering puts "assistant_turn_10" before
            # "assistant_turn_2", which would scramble the sequence for groups
            # with ten or more items. Splitting on digit runs (with a
            # capturing group) always alternates text / digits, starting with
            # text, so the keys compare position by position without ever
            # mixing str and int.
            chunks = re.split(r"(\d+)", str(pair[0]))
            return [
                int(chunk) if position % 2 else chunk
                for position, chunk in enumerate(chunks)
            ]

        ordered_values = [value for _, value in sorted(lst, key=natural_key)]

        successful = 0
        for value in ordered_values:
            if value > self.threshold:
                successful += 1
            else:
                # The streak ends at the first failure.
                break
        return successful / len(lst)
|
501 |
+
|
502 |
+
|
503 |
class MeanReduction(DictReduction):
|
504 |
def reduce_list(self, lst: List[float]):
|
505 |
return nan_mean(lst)
|
|
|
515 |
return float(nan_max(lst))
|
516 |
|
517 |
|
518 |
+
class GroupMetric(
    MapReduceMetric[PredictionType, IntermediateType],
    Generic[PredictionType, IntermediateType],
):
    """Wrap another map-reduce metric and aggregate its scores per group.

    The wrapped ``metric`` produces one intermediate per instance. Each
    intermediate is tagged with a group id and an item id read from the
    instance's task data. Scores are reduced first inside every group
    (``in_group_reduction``) and then across groups
    (``cross_group_reduction``).
    """

    main_score: str = None
    metric: MapReduceMetric[PredictionType, IntermediateType]
    group_id_field: str
    item_id_field: str
    in_group_reduction: GroupReduction = GroupMean()
    cross_group_reduction: GroupReduction = GroupMean()
    n_resamples = None

    def _get_group_id(self, task_data) -> str:
        """Return the string form of the value found at ``group_id_field``."""
        group_value = dict_get(task_data, self.group_id_field)
        return str(group_value)

    def _get_item_id(self, task_data) -> str:
        """Return the string form of the value found at ``item_id_field``."""
        item_value = dict_get(task_data, self.item_id_field)
        return str(item_value)

    def prepare(self):
        """Run the parent preparation and adopt the wrapped metric's main score."""
        super().prepare()
        self.main_score = self.metric.main_score

    def map_stream(
        self,
        evaluation_inputs_stream: Generator[
            EvaluationInput[PredictionType], None, None
        ],
    ) -> List[Tuple[IntermediateType, str, str]]:
        """Map the inputs with the wrapped metric and attach group/item ids.

        Returns:
            A list of ``(intermediate, group_id, item_id)`` tuples, zipped in
            stream order.
        """
        group_ids: List[str] = []
        item_ids: List[str] = []

        def recording_stream(
            inputs: Generator[EvaluationInput[PredictionType], None, None],
        ) -> Generator[
            Tuple[PredictionType, List[PredictionType], Dict[str, Any]], None, None
        ]:
            # Pass every input through untouched while noting its ids, so the
            # ids are collected as the wrapped metric consumes the stream.
            for prediction, references, task_data in inputs:
                group_ids.append(self._get_group_id(task_data))
                item_ids.append(self._get_item_id(task_data))
                yield prediction, references, task_data

        tagged_inputs = recording_stream(evaluation_inputs_stream)
        intermediates: List[IntermediateType] = list(
            self.metric.map_stream(tagged_inputs)
        )

        return list(zip(intermediates, group_ids, item_ids))

    def reduce_group(self, dialog_data: Dict[str, Dict[str, Any]]):
        """Reduce the ``{item_id: scores}`` mapping of a single group."""
        item_pairs = list(dialog_data.items())
        return self.in_group_reduction.reduce(item_pairs)

    def reduce_one(self, intermidate: Tuple[IntermediateType, str, str]):
        """Score one instance with the wrapped metric, dropping the attached ids."""
        wrapped_intermediate = intermidate[0]
        return self.metric.reduce_one(wrapped_intermediate)

    def reduce(
        self, intermediates: List[Tuple[IntermediateType, str, str]]
    ) -> Dict[str, Any]:
        """Reduce inside each group, then across the groups."""
        # group_id -> item_id -> instance scores from the wrapped metric
        scores_by_group: Dict[str, Dict[str, Any]] = {}
        for intermediate, group_id, item_id in intermediates:
            scores_by_group.setdefault(group_id, {})[item_id] = self.metric.reduce_one(
                intermediate
            )

        group_scores: Dict[str, Dict[str, Any]] = {}
        for group_id, items in scores_by_group.items():
            group_scores[group_id] = self.reduce_group(items)

        return self.cross_group_reduction.reduce(list(group_scores.items()))
|
588 |
+
|
589 |
+
|
590 |
+
class MultiTurnMetric(
    GroupMetric[PredictionType, IntermediateType],
    Generic[PredictionType, IntermediateType],
):
    """Group metric whose groups are conversations and whose items are turns.

    The group id is read from ``conversation/id``. The item id is derived
    from the length of the value at ``conversation/dialog``.
    """

    group_id_field = "conversation/id"
    item_id_field = "conversation/dialog"

    def _get_item_id(self, task_data):
        """Return ``assistant_turn_<k>`` where ``k`` is half the dialog length, rounded down."""
        dialog = dict_get(task_data, self.item_id_field)
        turn_index = len(dialog) // 2
        return f"assistant_turn_{turn_index}"
|
601 |
+
|
602 |
+
|
603 |
class ReductionInstanceMetric(
|
604 |
MapReduceMetric[PredictionType, IntermediateType],
|
605 |
Generic[PredictionType, IntermediateType],
|
|
|
836 |
}
|
837 |
|
838 |
|
839 |
+
class MultiTurnToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
    """Scores the fraction of predicted tool calls whose arguments fit the tool schema.

    For every predicted tool call, the tool with the same name is looked up in
    ``task_data["__tools__"]`` and the call's ``arguments`` are validated
    against that tool's ``parameters`` JSON schema using ``jsonschema_rs``.
    The references are not consulted by this metric.
    """

    main_score = "argument_schema_validation"
    reduction = MeanReduction()
    prediction_type = List[ToolCall]
    _requirements_list = ["jsonschema-rs"]

    def prepare(self):
        """Import ``jsonschema_rs`` lazily and keep the module on the instance."""
        super().prepare()
        import jsonschema_rs

        self._schema = jsonschema_rs

    def map(
        self,
        prediction: List[ToolCall],
        references: List[List[ToolCall]],
        task_data: Dict[str, Any],
    ) -> Dict[str, float]:
        """Compute the share of predicted tool calls with schema-valid arguments.

        Args:
            prediction: The predicted tool calls of one instance.
            references: Unused; kept for the metric interface.
            task_data: Must contain ``"__tools__"``, a list of tool definitions
                with ``["function"]["name"]`` and ``["function"]["parameters"]``.

        Returns:
            ``{"argument_schema_validation": score}`` where ``score`` is the
            mean of per-call results (1.0 valid, 0.0 invalid or unknown tool).
            A prediction without any tool call scores 0.0.
        """
        # An empty prediction used to raise ZeroDivisionError in the mean below;
        # with no valid call to credit, it is scored as 0.0 instead.
        if not prediction:
            return {"argument_schema_validation": 0.0}

        validation_scores = []
        for tool_call in prediction:
            parameters = None
            # No early exit on purpose: if several tools share a name, the
            # last matching definition wins, as before.
            for tool in task_data["__tools__"]:
                if tool["function"]["name"] == tool_call["name"]:
                    parameters = tool["function"]["parameters"]

            if parameters is None:
                # Unknown tool name (or a tool without a schema) counts as invalid.
                validation_scores.append(0.0)
                continue

            try:
                self._schema.validate(
                    parameters,
                    tool_call["arguments"],
                )
            except self._schema.ValidationError:
                validation_scores.append(0.0)
            else:
                validation_scores.append(1.0)

        argument_schema_validation = sum(validation_scores) / len(validation_scores)

        return {
            "argument_schema_validation": argument_schema_validation,
        }
|
883 |
+
|
884 |
+
|
885 |
class MetricWithConfidenceInterval(Metric):
|
886 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
887 |
# Use None to disable confidence interval computation.
|
|
|
1132 |
process_single_instances = True
|
1133 |
|
1134 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
    """Score the whole stream globally and yield every instance with its scores.

    The stream is consumed completely before anything is yielded. For each
    instance an instance-level score is recorded (computed on that single
    instance when ``self.process_single_instances`` is set, otherwise a
    placeholder). Afterwards the metric is computed once over all
    references, predictions and task data, confidence intervals are added,
    and the resulting global score is attached to every instance.

    Args:
        stream: The instances to evaluate; each must hold ``"references"``
            and ``"prediction"`` and may hold ``"task_data"``.
        stream_name: Name of the stream; not used by this method.

    Yields:
        The instances, in input order, with updated ``"score"`` entries.
    """
    with error_context(
        self,
        stage="Evaluating Metric",
        help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
    ):
        references = []
        predictions = []
        task_data = []

        instances = []

        for instance in stream:
            instance = self.verify_instance(instance)

            if "score" not in instance:
                instance["score"] = {"global": {}, "instance": {}}

            instance_references, instance_prediction = (
                instance["references"],
                instance["prediction"],
            )

            references.append(instance_references)
            predictions.append(instance_prediction)
            instances.append(instance)

            instance_task_data = (
                instance["task_data"] if "task_data" in instance else {}
            )
            task_data.append(instance_task_data)
            instance_score = None

            # for backward compatibility
            no_score_value = np.nan
            if self.process_single_instances:
                try:
                    instance_score = self._compute(
                        [instance_references],
                        [instance_prediction],
                        [instance_task_data],
                    )
                except Exception:
                    # Best effort: a metric that cannot score a single
                    # instance gets a None placeholder. Only ordinary errors
                    # are swallowed; KeyboardInterrupt/SystemExit propagate.
                    no_score_value = None
            if not instance_score:
                instance_score = {
                    "score": no_score_value,
                    "score_name": self.main_score,
                }

                if isinstance(self.main_score, str):
                    instance_score[self.main_score] = no_score_value

            instance["score"]["instance"].update(
                self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
                    instance_score, instance["score"]["instance"]
                )
            )
        self._validate_references_and_prediction(references, predictions)
        global_score = {"num_of_instances": len(instances)}

        # One computation over the full data set yields the global score.
        result = self._compute(references, predictions, task_data)
        global_score.update(
            self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
                result, global_score
            )
        )
        if self.ci_scores:
            score_names = [
                self._add_score_prefix(score_name) for score_name in self.ci_scores
            ]
        else:
            # Without explicit ci_scores, only the main global score gets a CI.
            score_names = [global_score["score_name"]]

        for score_name in score_names:
            confidence_interval = self.compute_global_confidence_intervals(
                references, predictions, task_data, score_name
            )
            global_score.update(confidence_interval)

        for instance in instances:
            self.update_and_adjust_global_score(instance, global_score)
            yield instance
|
1217 |
|
1218 |
def _compute(
|
1219 |
self,
|
|
|
1263 |
return instance
|
1264 |
|
1265 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
    """Score all instances in one ``compute`` call, then reduce to global scores.

    The stream is consumed completely, ``self.compute`` is called once on
    all references, predictions and task data, the per-instance scores are
    stored on the instances, and every reduction in ``self.reduction_map``
    ("mean" and "weighted_win_rate" are handled here) fills the global
    score that is finally attached to each instance.

    Args:
        stream: The instances to evaluate.
        stream_name: Name of the stream; not used by this method.

    Yields:
        The preprocessed instances, in input order, with updated scores.
    """
    with error_context(
        self,
        stage="Evaluating Metrics",
        help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
    ):
        instances = []
        for raw_instance in stream:
            self.verify_instance(raw_instance)
            instances.append(self.preprocess_instance(raw_instance))

        predictions = [item["prediction"] for item in instances]
        references = [item["references"] for item in instances]
        task_data = [
            item["task_data"] if "task_data" in item else {} for item in instances
        ]
        self._validate_references_and_prediction(references, predictions)
        global_score = {"num_of_instances": len(instances)}

        # Single bulk computation over every reference and prediction.
        instance_scores = self.compute(
            references=references,
            predictions=predictions,
            task_data=task_data,
        )

        # Mirror the main score into the generic "score"/"score_name" fields.
        for instance_score in instance_scores:
            instance_score["score"] = instance_score[self.main_score]
            instance_score["score_name"] = self.main_score

        for item, score in zip(instances, instance_scores):
            if "score" not in item:
                item["score"] = {"global": {}, "instance": {}}

            prefixed_scores = self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
                score, item["score"]["instance"]
            )
            item["score"]["instance"].update(prefixed_scores)

        for reduction, fields in self.reduction_map.items():
            assert (
                reduction in self.implemented_reductions
            ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"

            if reduction == "mean":
                for field_name in fields:
                    prefixed_name = self._add_score_prefix(field_name)
                    field_values = [
                        item["score"]["instance"][prefixed_name]
                        for item in instances
                    ]
                    global_score[prefixed_name] = nan_mean(field_values)
                    if field_name == self.main_score:
                        global_score["score"] = global_score[prefixed_name]
                        global_score["score_name"] = (
                            self.score_prefix + self.main_score
                        )

                # Confidence intervals: explicit ci_scores, else the main score.
                if self.ci_scores is not None:
                    ci_fields = list(set(self.ci_scores))
                else:
                    ci_fields = [self.main_score]
                ci_fields_with_prefix = [
                    self._add_score_prefix(ci_field) for ci_field in ci_fields
                ]
                confidence_interval = self.score_based_confidence_interval(
                    instances=instances, score_names=ci_fields_with_prefix
                )
                global_score.update(confidence_interval)
            if reduction == "weighted_win_rate":
                for field_name in fields:
                    prefixed_name = self._add_score_prefix(field_name)
                    total_battles = 0
                    wins = 0
                    for item in instances:
                        outcome = item["score"]["instance"][prefixed_name]
                        if outcome > 0:
                            # Positive score: a win weighted by its magnitude.
                            total_battles += outcome
                            wins += outcome
                        elif outcome < 0:
                            # Negative score: a loss weighted by its magnitude.
                            total_battles += abs(outcome)
                        else:
                            # Zero: counted as two battles with one win.
                            total_battles += 2
                            wins += 1

                    global_score[prefixed_name] = wins / total_battles
                    if field_name == self.main_score:
                        global_score["score"] = global_score[prefixed_name]
                        global_score["score_name"] = (
                            self.score_prefix + self.main_score
                        )

        for item in instances:
            self.update_and_adjust_global_score(item, global_score)
            yield item
|
1365 |
|
1366 |
@abstractmethod
|
1367 |
def compute(
|
|
|
1667 |
assert isinstance(fields["score_fields"], list)
|
1668 |
|
1669 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
    """Compute instance scores, reduce them per ``reduction_map`` and yield instances.

    Supported reduction types are "mean", "max" and "group_mean"; any other
    type raises ``ValueError``. Confidence intervals are computed only for
    the fields named in ``self.ci_scores``.

    Args:
        stream: The instances to evaluate; it is iterated a second time at
            the end to yield the instances with their scores attached.
        stream_name: Name of the stream; not used by this method.

    Yields:
        Each instance of ``stream`` with ``"score"`` set to a copy of the
        computed score of the instance at the same position.

    Raises:
        ValueError: If ``reduction_map`` holds an unsupported reduction type.
    """
    with error_context(
        self,
        stage="Evaluating Metrics",
        help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
    ):
        instance_scores = self.compute_instance_scores(stream)
        global_score = {"num_of_instances": len(instance_scores)}
        for reduction_type, reduction_params in self.reduction_map.items():
            assert (
                reduction_type in self.implemented_reductions
            ), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"

            field_name_full_prefix = ""
            # Passed on to the bootstrapping; which function it is depends on
            # whether groups are resampled as fixed units or not.
            aggregation_function = None
            if reduction_type in ("mean", "max"):
                if reduction_type == "mean":
                    aggregation_function = self.average_item_scores
                else:
                    aggregation_function = self.max_item_scores
                reduction_fields = list(set(reduction_params))
                # No group reduction: instances are resampled individually.
                scores_to_resample = instance_scores
            elif reduction_type == "group_mean":
                aggregation_function = self.average_item_scores
                self._validate_group_mean_reduction()
                if "score_fields" not in reduction_params:
                    reduction_fields = [self.main_score]
                else:
                    reduction_fields = list(set(reduction_params["score_fields"]))
                agg_func_spec = reduction_params["agg_func"]
                aggregation_function_name = str(agg_func_spec[0])
                field_name_full_prefix = "group_" + aggregation_function_name + "_"
                do_resample_as_group = reduction_params["agg_func"][2]
                if do_resample_as_group:
                    # "fixed_" marks that whole groups are resampled as units.
                    field_name_full_prefix = "fixed_" + field_name_full_prefix
                (
                    scores_to_resample,
                    aggregation_function,
                ) = self._set_up_group_mean_aggregation(
                    instance_scores,
                    reduction_params,
                    reduction_fields,
                )
            else:
                raise ValueError(
                    f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
                )

            # Global score for each reduction field.
            for field_name in reduction_fields:
                field_name_full = (
                    field_name_full_prefix + self.score_prefix + field_name
                )
                # With group resampling (3rd element of agg_func is True):
                #   scores_to_resample holds the group scores and
                #   aggregation_function takes their plain mean.
                # Without group resampling (3rd element is False):
                #   scores_to_resample holds the original instance scores and
                #   aggregation_function applies the group aggregation to them.
                # Either way, applying aggregation_function to
                # scores_to_resample yields the global score.
                global_score[field_name_full] = aggregation_function(
                    scores_to_resample, self.score_prefix + field_name
                )
                if field_name == self.main_score:
                    global_score["score"] = global_score[field_name_full]
                    global_score["score_name"] = field_name_full

            # CIs are computed only for fields listed in ci_scores; fields of
            # the reduction map do not get one automatically.
            if self.ci_scores is not None:
                ci_score_names = [
                    self.score_prefix + ci_score for ci_score in set(self.ci_scores)
                ]
                confidence_interval = self.score_based_confidence_interval(
                    instances=scores_to_resample,
                    score_names=ci_score_names,
                    ci_score_prefix=field_name_full_prefix,
                    aggregation_func=aggregation_function,
                )
                global_score.update(confidence_interval)

        for scored_instance in instance_scores:
            self.update_and_adjust_global_score(scored_instance, global_score)

        for position, instance in enumerate(stream):
            instance["score"] = recursive_copy(instance_scores[position]["score"])
            yield instance
|
1761 |
|
1762 |
def compute_instance_scores(
|
1763 |
self, stream: Stream, stream_name: Optional[str] = None
|
|
|
6634 |
}
|
6635 |
|
6636 |
|
6637 |
+
class SQLExecutionLogicAccuracy(InstanceMetric):
    """Execution-based SQL metric run on a revised version of the predicted query.

    Before execution, ``replace_select_clause(gold_sql, predicted_sql, dialect)``
    builds the query that is actually run against the database; its result is
    compared with the gold query by ``get_sql_execution_results``.
    """

    sql_timeout: float = 60.0
    prediction_type = "Any"
    _requirements_list = ["sqlglot", "func_timeout"]

    main_score = "non_empty_execution_accuracy"

    # Every int/float field of SQLExecutionResult is reported and averaged.
    all_metrics = [
        result_field.name
        for result_field in dataclasses_fields(SQLExecutionResult)
        if isinstance(result_field.type, type) and result_field.type in (int, float)
    ]

    reduction_map = {"mean": all_metrics}

    ci_scores = [
        "execution_accuracy",
        "non_empty_execution_accuracy",
        "subset_non_empty_execution_accuracy",
        "execution_accuracy_bird",
        "gold_sql_runtime",
        "predicted_sql_runtime",
    ]

    def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
        """Execute the revised prediction and the gold SQL and return the result fields.

        Args:
            references: Gold queries; only ``references[0]`` is used.
            prediction: Model output from which the SQL is extracted.
            task_data: Must hold ``"db"``, a dict with at least ``"db_type"``.

        Returns:
            The ``SQLExecutionResult`` as a dict, extended with ``"score"``
            (value of the main score) and ``"score_name"``.
        """
        from .text2sql_utils import (
            ALL_DIALECTS,
            extract_sql_from_text,
            get_db_connector,
            get_sql_execution_results,
            replace_select_clause,
        )

        predicted_sql = extract_sql_from_text(prediction)
        gold_sql = references[0]

        db_config = task_data["db"]
        db_type = db_config["db_type"]
        # Dialects outside ALL_DIALECTS are passed on as None.
        dialect = db_type if db_type in ALL_DIALECTS else None

        if gold_sql and predicted_sql:
            revised_sql = replace_select_clause(gold_sql, predicted_sql, dialect)
        else:
            # Nothing to revise when either query is empty.
            revised_sql = ""

        connector_class = get_db_connector(db_type)
        db_connector = connector_class(db_config)
        result_obj = get_sql_execution_results(
            revised_sql, gold_sql, db_connector, self.sql_timeout
        )

        result = asdict(result_obj)
        result["score"] = result[self.main_score]
        result["score_name"] = self.main_score
        logger.debug(f"SQL Execution Accuracy Result: {result}")
        return result
|
6691 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6692 |
|
6693 |
+
class SQLExecutionAccuracy(InstanceMetric):
|
6694 |
+
sql_timeout: float = 60.0
|
6695 |
+
prediction_type = "Any"
|
6696 |
+
_requirements_list = ["sqlglot", "func_timeout"]
|
6697 |
|
6698 |
+
main_score = "non_empty_execution_accuracy"
|
6699 |
|
6700 |
+
all_metrics = [
|
6701 |
+
f.name
|
6702 |
+
for f in dataclasses_fields(SQLExecutionResult)
|
6703 |
+
if isinstance(f.type, type) and f.type in (int, float)
|
6704 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6705 |
|
6706 |
+
reduction_map = {"mean": all_metrics}
|
|
|
|
|
|
|
6707 |
|
6708 |
+
ci_scores = [
|
6709 |
+
"execution_accuracy",
|
6710 |
+
"non_empty_execution_accuracy",
|
6711 |
+
"subset_non_empty_execution_accuracy",
|
6712 |
+
"execution_accuracy_bird",
|
6713 |
+
"gold_sql_runtime",
|
6714 |
+
"predicted_sql_runtime",
|
6715 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6716 |
|
6717 |
+
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
6718 |
+
from .text2sql_utils import (
|
6719 |
+
extract_sql_from_text,
|
6720 |
+
get_db_connector,
|
6721 |
+
get_sql_execution_results,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6722 |
)
|
6723 |
|
6724 |
+
predicted_sql = extract_sql_from_text(prediction)
|
6725 |
+
gold_sql = references[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6726 |
|
6727 |
+
db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
|
6728 |
+
result_obj = get_sql_execution_results(
|
6729 |
+
predicted_sql, gold_sql, db_connector, self.sql_timeout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6730 |
)
|
6731 |
|
6732 |
+
result = asdict(result_obj)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6733 |
result["score"] = result[self.main_score]
|
6734 |
result["score_name"] = self.main_score
|
6735 |
logger.debug(f"SQL Execution Accuracy Result: {result}")
|
|
|
6737 |
|
6738 |
|
6739 |
class SQLNonExecutionAccuracy(InstanceMetric):
|
6740 |
+
all_metrics = [
|
6741 |
+
f.name
|
6742 |
+
for f in dataclasses_fields(SQLNonExecutionMetricResult)
|
6743 |
+
if isinstance(f.type, type) and f.type in (int, float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6744 |
]
|
6745 |
+
reduction_map = {"mean": all_metrics}
|
6746 |
+
main_score = "sqlglot_equivalence"
|
6747 |
+
ci_scores = all_metrics
|
6748 |
|
6749 |
prediction_type = "Any" # string representation is compared
|
6750 |
|
6751 |
_requirements_list = ["sqlglot", "sqlparse"]
|
6752 |
|
6753 |
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
6754 |
+
from .text2sql_utils import (
|
6755 |
+
extract_sql_from_text,
|
6756 |
is_sqlglot_parsable,
|
6757 |
is_sqlparse_parsable,
|
6758 |
sql_exact_match,
|
|
|
6761 |
sqlparse_queries_equivalent,
|
6762 |
)
|
6763 |
|
|
|
6764 |
gold_sql = references[0]
|
6765 |
+
predicted_sql = extract_sql_from_text(prediction)
|
|
|
|
|
|
|
|
|
|
|
6766 |
|
6767 |
is_sqlglot_parsable = is_sqlglot_parsable(predicted_sql)
|
6768 |
is_sqlparse_parsable = is_sqlparse_parsable(predicted_sql)
|
6769 |
+
result_obj = SQLNonExecutionMetricResult(
|
6770 |
+
sqlglot_validity=int(is_sqlglot_parsable),
|
6771 |
+
sqlparse_validity=int(is_sqlparse_parsable),
|
6772 |
+
sqlglot_equivalence=int(
|
6773 |
sqlglot_parsed_queries_equivalent(predicted_sql, gold_sql)
|
6774 |
if is_sqlglot_parsable
|
6775 |
else 0
|
6776 |
),
|
6777 |
+
sqlglot_optimized_equivalence=int(
|
6778 |
sqlglot_optimized_equivalence(predicted_sql, gold_sql)
|
6779 |
if is_sqlglot_parsable
|
6780 |
else 0
|
6781 |
),
|
6782 |
+
sqlparse_equivalence=int(
|
6783 |
sqlparse_queries_equivalent(predicted_sql, gold_sql)
|
6784 |
if is_sqlparse_parsable
|
6785 |
else 0
|
6786 |
),
|
6787 |
+
sql_exact_match=int(sql_exact_match(predicted_sql, gold_sql)),
|
6788 |
+
sql_syntactic_equivalence=0, # will update below
|
6789 |
+
)
|
6790 |
+
|
6791 |
+
result_obj.sql_syntactic_equivalence = int(
|
6792 |
any(
|
6793 |
+
[
|
6794 |
+
result_obj.sqlglot_equivalence,
|
6795 |
+
result_obj.sqlglot_optimized_equivalence,
|
6796 |
+
result_obj.sqlparse_equivalence,
|
6797 |
+
result_obj.sql_exact_match,
|
|
|
6798 |
]
|
6799 |
)
|
6800 |
)
|
6801 |
+
|
6802 |
+
result = asdict(result_obj)
|
6803 |
logger.debug(f"SQL Non Execution Accuracy Result: {result}")
|
6804 |
result["score"] = result[self.main_score]
|
6805 |
result["score_name"] = self.main_score
|
operator.py
CHANGED
@@ -6,6 +6,7 @@ from pkg_resources import DistributionNotFound, VersionConflict, require
|
|
6 |
|
7 |
from .artifact import Artifact
|
8 |
from .dataclass import FinalField, InternalField, NonPositionalField
|
|
|
9 |
from .settings_utils import get_constants
|
10 |
from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
|
11 |
|
@@ -346,7 +347,8 @@ class StreamOperator(MultiStreamOperator):
|
|
346 |
def _process_stream(
|
347 |
self, stream: Stream, stream_name: Optional[str] = None
|
348 |
) -> Generator:
|
349 |
-
|
|
|
350 |
|
351 |
@abstractmethod
|
352 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
@@ -384,12 +386,28 @@ class PagedStreamOperator(StreamOperator):
|
|
384 |
self, stream: Stream, stream_name: Optional[str] = None
|
385 |
) -> Generator:
|
386 |
page = []
|
|
|
387 |
for instance in stream:
|
388 |
page.append(instance)
|
389 |
if len(page) >= self.page_size:
|
390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
page = []
|
392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
|
394 |
def _process_page(
|
395 |
self, page: List[Dict], stream_name: Optional[str] = None
|
@@ -442,17 +460,9 @@ class InstanceOperator(StreamOperator):
|
|
442 |
def _process_stream(
|
443 |
self, stream: Stream, stream_name: Optional[str] = None
|
444 |
) -> Generator:
|
445 |
-
|
446 |
-
|
447 |
-
for _index, instance in enumerate(stream):
|
448 |
yield self._process_instance(instance, stream_name)
|
449 |
-
except Exception as e:
|
450 |
-
if _index is None:
|
451 |
-
raise e
|
452 |
-
else:
|
453 |
-
raise ValueError(
|
454 |
-
f"Error processing instance '{_index}' from stream '{stream_name}' in {self.__class__.__name__} due to the exception above."
|
455 |
-
) from e
|
456 |
|
457 |
def _process_instance(
|
458 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
6 |
|
7 |
from .artifact import Artifact
|
8 |
from .dataclass import FinalField, InternalField, NonPositionalField
|
9 |
+
from .error_utils import error_context
|
10 |
from .settings_utils import get_constants
|
11 |
from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
|
12 |
|
|
|
347 |
def _process_stream(
    self, stream: Stream, stream_name: Optional[str] = None
) -> Generator:
    """Yield everything ``self.process`` yields, inside an error context.

    The error context is tagged with this operator and the stream name.
    """
    with error_context(self, stream=stream_name):
        processed = self.process(stream, stream_name)
        yield from processed
|
352 |
|
353 |
@abstractmethod
|
354 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
|
|
386 |
self, stream: Stream, stream_name: Optional[str] = None
|
387 |
) -> Generator:
|
388 |
page = []
|
389 |
+
page_number = 0
|
390 |
for instance in stream:
|
391 |
page.append(instance)
|
392 |
if len(page) >= self.page_size:
|
393 |
+
with error_context(
|
394 |
+
self,
|
395 |
+
stream=stream_name,
|
396 |
+
page=page_number,
|
397 |
+
page_size=len(page),
|
398 |
+
):
|
399 |
+
yield from self.process(page, stream_name)
|
400 |
page = []
|
401 |
+
page_number += 1
|
402 |
+
if page: # Handle any remaining instances in the last partial page
|
403 |
+
with error_context(
|
404 |
+
self,
|
405 |
+
stream=stream_name,
|
406 |
+
page=page_number,
|
407 |
+
page_size=len(page),
|
408 |
+
final_page=True,
|
409 |
+
):
|
410 |
+
yield from self._process_page(page, stream_name)
|
411 |
|
412 |
def _process_page(
|
413 |
self, page: List[Dict], stream_name: Optional[str] = None
|
|
|
460 |
def _process_stream(
    self, stream: Stream, stream_name: Optional[str] = None
) -> Generator:
    """Process the stream one instance at a time.

    Each instance is handled inside its own error context, tagged with the
    stream name and the zero-based position of the instance in the stream.
    """
    position = 0
    for instance in stream:
        with error_context(self, stream=stream_name, instance=position):
            yield self._process_instance(instance, stream_name)
        position += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
|
467 |
def _process_instance(
|
468 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
operators.py
CHANGED
@@ -67,7 +67,7 @@ from .artifact import Artifact, fetch_artifact
|
|
67 |
from .dataclass import NonPositionalField, OptionalField
|
68 |
from .deprecation_utils import deprecation
|
69 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
70 |
-
from .error_utils import UnitxtError
|
71 |
from .generator_utils import ReusableGenerator
|
72 |
from .operator import (
|
73 |
InstanceOperator,
|
@@ -309,7 +309,9 @@ def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
|
309 |
if not isinstance(item, dict) and item not in value_remove
|
310 |
]
|
311 |
elif isinstance(value, dict):
|
312 |
-
|
|
|
|
|
313 |
elif value in value_remove:
|
314 |
keys_to_delete.append(key)
|
315 |
elif value in value_map:
|
@@ -436,6 +438,7 @@ class InstanceFieldOperator(InstanceOperator):
|
|
436 |
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
437 |
use_query: Optional[bool] = None
|
438 |
process_every_value: bool = False
|
|
|
439 |
get_default: Any = None
|
440 |
not_exist_ok: bool = False
|
441 |
not_exist_do_nothing: bool = False
|
@@ -521,7 +524,7 @@ class InstanceFieldOperator(InstanceOperator):
|
|
521 |
) -> Dict[str, Any]:
|
522 |
self.verify_field_definition()
|
523 |
for from_field, to_field in self._field_to_field:
|
524 |
-
|
525 |
old_value = dict_get(
|
526 |
instance,
|
527 |
from_field,
|
@@ -532,11 +535,8 @@ class InstanceFieldOperator(InstanceOperator):
|
|
532 |
if self.not_exist_do_nothing:
|
533 |
continue
|
534 |
old_value = self.get_default
|
535 |
-
|
536 |
-
|
537 |
-
f"Failed to get '{from_field}' from instance due to the exception above."
|
538 |
-
) from e
|
539 |
-
try:
|
540 |
if self.process_every_value:
|
541 |
new_value = [
|
542 |
self.process_instance_value(value, instance)
|
@@ -544,15 +544,13 @@ class InstanceFieldOperator(InstanceOperator):
|
|
544 |
]
|
545 |
else:
|
546 |
new_value = self.process_instance_value(old_value, instance)
|
547 |
-
|
548 |
-
raise ValueError(
|
549 |
-
f"Failed to process field '{from_field}' from instance due to the exception above."
|
550 |
-
) from e
|
551 |
dict_set(
|
552 |
instance,
|
553 |
to_field,
|
554 |
new_value,
|
555 |
not_exist_ok=True,
|
|
|
556 |
)
|
557 |
return instance
|
558 |
|
@@ -610,11 +608,29 @@ class Rename(FieldOperator):
|
|
610 |
return res
|
611 |
|
612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
613 |
@deprecation(version="2.0.0", alternative=Rename)
|
614 |
class RenameFields(Rename):
|
615 |
pass
|
616 |
|
617 |
|
|
|
|
|
|
|
|
|
|
|
618 |
class AddConstant(FieldOperator):
|
619 |
"""Adds a constant, being argument 'add', to the processed value.
|
620 |
|
@@ -1200,9 +1216,10 @@ class ApplyOperatorsField(InstanceOperator):
|
|
1200 |
) -> Dict[str, Any]:
|
1201 |
operator_names = instance.get(self.operators_field)
|
1202 |
if operator_names is None:
|
1203 |
-
|
1204 |
-
|
1205 |
-
|
|
|
1206 |
operator_names = self.default_operators
|
1207 |
|
1208 |
if isinstance(operator_names, str):
|
@@ -1436,7 +1453,7 @@ class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
|
|
1436 |
def process(
|
1437 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
1438 |
) -> Dict[str, Any]:
|
1439 |
-
instance
|
1440 |
return instance
|
1441 |
|
1442 |
|
@@ -1821,54 +1838,58 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
1821 |
|
1822 |
# to be populated only when two or more metrics
|
1823 |
accumulated_scores = []
|
|
|
|
|
1824 |
|
1825 |
-
|
1826 |
-
|
1827 |
-
|
1828 |
-
|
1829 |
-
raise RuntimeError(
|
1830 |
-
f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
|
1831 |
-
)
|
1832 |
-
|
1833 |
-
if isinstance(metric_names, str):
|
1834 |
-
metric_names = [metric_names]
|
1835 |
-
|
1836 |
-
metrics_list = []
|
1837 |
-
for metric_name in metric_names:
|
1838 |
-
metric = self.get_artifact(metric_name)
|
1839 |
-
if isinstance(metric, MetricsList):
|
1840 |
-
metrics_list.extend(list(metric.items))
|
1841 |
-
elif isinstance(metric, Metric):
|
1842 |
-
metrics_list.append(metric)
|
1843 |
-
else:
|
1844 |
-
raise ValueError(
|
1845 |
-
f"Operator {metric_name} must be a Metric or MetricsList"
|
1846 |
)
|
1847 |
|
1848 |
-
|
1849 |
-
|
1850 |
-
|
1851 |
-
|
1852 |
-
|
1853 |
-
|
1854 |
-
|
1855 |
-
|
1856 |
-
|
1857 |
-
|
1858 |
-
|
1859 |
-
|
1860 |
-
|
1861 |
-
|
1862 |
-
|
|
|
|
|
|
|
1863 |
)
|
1864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1865 |
|
1866 |
-
|
1867 |
|
1868 |
-
|
1869 |
-
|
1870 |
-
|
1871 |
-
|
1872 |
|
1873 |
yield from multi_stream["tmp"]
|
1874 |
|
|
|
67 |
from .dataclass import NonPositionalField, OptionalField
|
68 |
from .deprecation_utils import deprecation
|
69 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
70 |
+
from .error_utils import UnitxtError, error_context
|
71 |
from .generator_utils import ReusableGenerator
|
72 |
from .operator import (
|
73 |
InstanceOperator,
|
|
|
309 |
if not isinstance(item, dict) and item not in value_remove
|
310 |
]
|
311 |
elif isinstance(value, dict):
|
312 |
+
recursive_key_value_replace(
|
313 |
+
value, target_key, value_map, value_remove
|
314 |
+
)
|
315 |
elif value in value_remove:
|
316 |
keys_to_delete.append(key)
|
317 |
elif value in value_map:
|
|
|
438 |
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
439 |
use_query: Optional[bool] = None
|
440 |
process_every_value: bool = False
|
441 |
+
set_every_value: bool = NonPositionalField(default=False)
|
442 |
get_default: Any = None
|
443 |
not_exist_ok: bool = False
|
444 |
not_exist_do_nothing: bool = False
|
|
|
524 |
) -> Dict[str, Any]:
|
525 |
self.verify_field_definition()
|
526 |
for from_field, to_field in self._field_to_field:
|
527 |
+
with error_context(self, field=from_field, action="Read Field"):
|
528 |
old_value = dict_get(
|
529 |
instance,
|
530 |
from_field,
|
|
|
535 |
if self.not_exist_do_nothing:
|
536 |
continue
|
537 |
old_value = self.get_default
|
538 |
+
|
539 |
+
with error_context(self, field=from_field, action="Process Field"):
|
|
|
|
|
|
|
540 |
if self.process_every_value:
|
541 |
new_value = [
|
542 |
self.process_instance_value(value, instance)
|
|
|
544 |
]
|
545 |
else:
|
546 |
new_value = self.process_instance_value(old_value, instance)
|
547 |
+
|
|
|
|
|
|
|
548 |
dict_set(
|
549 |
instance,
|
550 |
to_field,
|
551 |
new_value,
|
552 |
not_exist_ok=True,
|
553 |
+
set_multiple=self.set_every_value,
|
554 |
)
|
555 |
return instance
|
556 |
|
|
|
608 |
return res
|
609 |
|
610 |
|
611 |
+
class Move(InstanceOperator):
    """Moves the value stored under ``field`` to ``to_field`` within an instance.

    The value is read with ``dict_get``, removed with ``dict_delete`` and
    written to the destination with ``dict_set``.
    """

    field: str
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Relocate the value in place and return the same instance."""
        moved_value = dict_get(instance, self.field)
        # Delete before writing, so the source is gone when the target is set.
        dict_delete(instance, self.field)
        dict_set(instance, self.to_field, value=moved_value)
        return instance
|
622 |
+
|
623 |
+
|
624 |
@deprecation(version="2.0.0", alternative=Rename)
class RenameFields(Rename):
    """Deprecated alias of ``Rename``; adds no behavior of its own."""
|
627 |
|
628 |
|
629 |
+
class BytesToString(FieldOperator):
    """Field operator that replaces a value with ``str(value)``."""

    def process_value(self, value: Any) -> Any:
        """Return the ``str()`` conversion of ``value``."""
        # NOTE(review): for a ``bytes`` input, ``str()`` yields its repr
        # (e.g. "b'abc'"), not decoded text -- confirm this is the intent.
        return str(value)
|
632 |
+
|
633 |
+
|
634 |
class AddConstant(FieldOperator):
|
635 |
"""Adds a constant, being argument 'add', to the processed value.
|
636 |
|
|
|
1216 |
) -> Dict[str, Any]:
|
1217 |
operator_names = instance.get(self.operators_field)
|
1218 |
if operator_names is None:
|
1219 |
+
if self.default_operators is None:
|
1220 |
+
raise ValueError(
|
1221 |
+
f"No operators found in field '{self.operators_field}', and no default operators provided."
|
1222 |
+
)
|
1223 |
operator_names = self.default_operators
|
1224 |
|
1225 |
if isinstance(operator_names, str):
|
|
|
1453 |
def process(
|
1454 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
1455 |
) -> Dict[str, Any]:
|
1456 |
+
dict_set(instance, self.to_field, self.compute_expression(instance))
|
1457 |
return instance
|
1458 |
|
1459 |
|
|
|
1838 |
|
1839 |
# to be populated only when two or more metrics
|
1840 |
accumulated_scores = []
|
1841 |
+
with error_context(self, stage="Load Metrics"):
|
1842 |
+
first_instance = stream.peek()
|
1843 |
|
1844 |
+
metric_names = first_instance.get(self.metric_field, [])
|
1845 |
+
if not metric_names:
|
1846 |
+
raise RuntimeError(
|
1847 |
+
f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1848 |
)
|
1849 |
|
1850 |
+
if isinstance(metric_names, str):
|
1851 |
+
metric_names = [metric_names]
|
1852 |
+
|
1853 |
+
metrics_list = []
|
1854 |
+
for metric_name in metric_names:
|
1855 |
+
metric = self.get_artifact(metric_name)
|
1856 |
+
if isinstance(metric, MetricsList):
|
1857 |
+
metrics_list.extend(list(metric.items))
|
1858 |
+
elif isinstance(metric, Metric):
|
1859 |
+
metrics_list.append(metric)
|
1860 |
+
else:
|
1861 |
+
raise ValueError(
|
1862 |
+
f"Operator {metric_name} must be a Metric or MetricsList"
|
1863 |
+
)
|
1864 |
+
with error_context(self, stage="Setup Metrics"):
|
1865 |
+
for metric in metrics_list:
|
1866 |
+
metric.set_confidence_interval_calculation(
|
1867 |
+
self.calc_confidence_intervals
|
1868 |
)
|
1869 |
+
# Each metric operator computes its score and then sets the main score, overwriting
|
1870 |
+
# the previous main score value (if any). So, we need to reverse the order of the listed metrics.
|
1871 |
+
# This will cause the first listed metric to run last, and the main score will be set
|
1872 |
+
# by the first listed metric (as desired).
|
1873 |
+
metrics_list = list(reversed(metrics_list))
|
1874 |
+
|
1875 |
+
for i, metric in enumerate(metrics_list):
|
1876 |
+
if i == 0: # first metric
|
1877 |
+
multi_stream = MultiStream({"tmp": stream})
|
1878 |
+
else: # metrics with previous scores
|
1879 |
+
reusable_generator = ReusableGenerator(
|
1880 |
+
generator=update_scores_of_stream_instances,
|
1881 |
+
gen_kwargs={"stream": stream, "scores": accumulated_scores},
|
1882 |
+
)
|
1883 |
+
multi_stream = MultiStream.from_generators(
|
1884 |
+
{"tmp": reusable_generator}
|
1885 |
+
)
|
1886 |
|
1887 |
+
multi_stream = metric(multi_stream)
|
1888 |
|
1889 |
+
if i < len(metrics_list) - 1: # last metric
|
1890 |
+
accumulated_scores = []
|
1891 |
+
for inst in multi_stream["tmp"]:
|
1892 |
+
accumulated_scores.append(recursive_copy(inst["score"]))
|
1893 |
|
1894 |
yield from multi_stream["tmp"]
|
1895 |
|
processors.py
CHANGED
@@ -98,6 +98,16 @@ class ExtractWithRegex(RegexParser):
|
|
98 |
return ""
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
class ListToEmptyEntitiesTuples(FieldOperator):
|
102 |
def process_value(self, lst: Any) -> Any:
|
103 |
try:
|
@@ -286,7 +296,7 @@ class StringOrNotString(StringEquals):
|
|
286 |
|
287 |
class ExtractMtBenchRatingJudgment(FieldOperator):
|
288 |
def process_value(self, text: Any) -> Any:
|
289 |
-
match = re.search(r"\[\[([\d]+\.?[\d]*)
|
290 |
try:
|
291 |
return float(match.group(1)) / 10
|
292 |
except:
|
|
|
98 |
return ""
|
99 |
|
100 |
|
101 |
+
class GroupDictWithRegex(FieldOperator):
    """Extract the named groups of a regex match into a dictionary.

    Attributes:
        pattern: Regular expression matched against the start of the value
            (``re.match`` semantics).
    """

    pattern: str

    def process_value(self, value: Any) -> Any:
        """Return ``groupdict()`` of the match, or ``{}`` when nothing matches."""
        found = re.match(self.pattern, value)
        return found.groupdict() if found else {}
|
109 |
+
|
110 |
+
|
111 |
class ListToEmptyEntitiesTuples(FieldOperator):
|
112 |
def process_value(self, lst: Any) -> Any:
|
113 |
try:
|
|
|
296 |
|
297 |
class ExtractMtBenchRatingJudgment(FieldOperator):
|
298 |
def process_value(self, text: Any) -> Any:
|
299 |
+
match = re.search(r"\[\[([\s*\d]+\.?[\d]*\s*)(/\s*10)?\s*\]\]", text)
|
300 |
try:
|
301 |
return float(match.group(1)) / 10
|
302 |
except:
|
schema.py
CHANGED
@@ -59,7 +59,7 @@ def get_schema(stream_name):
|
|
59 |
def load_chat_source(chat_str):
|
60 |
chat = json.loads(chat_str)
|
61 |
for turn in chat:
|
62 |
-
if isinstance(turn["content"], list):
|
63 |
for content in turn["content"]:
|
64 |
if content["type"] == "image_url":
|
65 |
content["image_url"]["url"] = ImageDataString(
|
|
|
59 |
def load_chat_source(chat_str):
|
60 |
chat = json.loads(chat_str)
|
61 |
for turn in chat:
|
62 |
+
if "content" in turn and isinstance(turn["content"], list):
|
63 |
for content in turn["content"]:
|
64 |
if content["type"] == "image_url":
|
65 |
content["image_url"]["url"] = ImageDataString(
|
serializers.py
CHANGED
@@ -9,6 +9,7 @@ from .operators import InstanceFieldOperator
|
|
9 |
from .settings_utils import get_constants
|
10 |
from .type_utils import isoftype, to_type_string
|
11 |
from .types import (
|
|
|
12 |
Dialog,
|
13 |
Document,
|
14 |
Image,
|
@@ -75,7 +76,22 @@ class DialogSerializer(SingleTypeSerializer):
|
|
75 |
|
76 |
def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
|
77 |
# Convert the Dialog into a string representation, typically combining roles and content
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
|
81 |
class NumberSerializer(SingleTypeSerializer):
|
@@ -225,7 +241,7 @@ class SQLDatabaseAsSchemaSerializer(SingleTypeSerializer):
|
|
225 |
serialized_type = SQLDatabase
|
226 |
|
227 |
def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
|
228 |
-
from .
|
229 |
|
230 |
connector = get_db_connector(value["db_type"])(value)
|
231 |
return connector.get_table_schema()
|
|
|
9 |
from .settings_utils import get_constants
|
10 |
from .type_utils import isoftype, to_type_string
|
11 |
from .types import (
|
12 |
+
Conversation,
|
13 |
Dialog,
|
14 |
Document,
|
15 |
Image,
|
|
|
76 |
|
77 |
def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
    """Render a dialog as text, one ``"<role>: <content>"`` entry per turn.

    A turn without a ``"content"`` key contributes only its role prefix.
    When a turn has ``"tool_calls"``, their JSON dump is appended on a new
    line after the content. Turns are joined with newlines.
    """
    rendered_turns = []
    for turn in value:
        pieces = [f"{turn['role']}: "]
        if "content" in turn:
            pieces.append(str(turn["content"]))
        if "tool_calls" in turn:
            pieces.append("\n" + json.dumps(turn["tool_calls"]))
        rendered_turns.append("".join(pieces))
    return "\n".join(rendered_turns)
|
88 |
+
|
89 |
+
|
90 |
+
class ConversationSerializer(DialogSerializer):
    """Serializer for ``Conversation`` values.

    Delegates to ``DialogSerializer.serialize`` on the conversation's
    ``"dialog"`` entry.
    """

    serialized_type = Conversation

    def serialize(self, value: Conversation, instance: Dict[str, Any]) -> str:
        """Serialize ``value["dialog"]`` with the parent dialog serializer."""
        dialog = value["dialog"]
        return super().serialize(dialog, instance)
|
95 |
|
96 |
|
97 |
class NumberSerializer(SingleTypeSerializer):
|
|
|
241 |
serialized_type = SQLDatabase
|
242 |
|
243 |
def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
    """Return the table schema of the database described by ``value``.

    The connector class is looked up by ``value["db_type"]``, instantiated
    with ``value`` and asked for its table schema.
    """
    # Imported at call time rather than at module import time.
    from .text2sql_utils import get_db_connector

    connector_cls = get_db_connector(value["db_type"])
    return connector_cls(value).get_table_schema()
|
settings_utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import importlib.metadata
|
2 |
import importlib.util
|
3 |
import os
|
|
|
4 |
from contextlib import contextmanager
|
5 |
|
6 |
from .version import version
|
@@ -177,6 +178,9 @@ if Constants.is_uninitilized():
|
|
177 |
constants.dataset_url = "unitxt/data"
|
178 |
constants.metric_url = "unitxt/metric"
|
179 |
constants.version = version
|
|
|
|
|
|
|
180 |
constants.catalog_hierarchy_sep = "."
|
181 |
constants.env_local_catalogs_paths_sep = ":"
|
182 |
constants.non_registered_files = [
|
|
|
1 |
import importlib.metadata
|
2 |
import importlib.util
|
3 |
import os
|
4 |
+
import sys
|
5 |
from contextlib import contextmanager
|
6 |
|
7 |
from .version import version
|
|
|
178 |
constants.dataset_url = "unitxt/data"
|
179 |
constants.metric_url = "unitxt/metric"
|
180 |
constants.version = version
|
181 |
+
constants.python = (
|
182 |
+
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
183 |
+
)
|
184 |
constants.catalog_hierarchy_sep = "."
|
185 |
constants.env_local_catalogs_paths_sep = ":"
|
186 |
constants.non_registered_files = [
|
struct_data_operators.py
CHANGED
@@ -23,6 +23,7 @@ For key-value pairs, expected input format is:
|
|
23 |
{"key1": "value1", "key2": value2, "key3": "value3"}
|
24 |
"""
|
25 |
|
|
|
26 |
import json
|
27 |
import random
|
28 |
from abc import ABC, abstractmethod
|
@@ -754,11 +755,40 @@ class LoadJson(FieldOperator):
|
|
754 |
return json.loads(value, strict=False)
|
755 |
|
756 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
757 |
class ToolCallPostProcessor(FieldOperator):
|
758 |
failure_value: Any = None
|
759 |
allow_failure: bool = False
|
760 |
|
761 |
def process_value(self, value: str) -> ToolCall:
|
|
|
|
|
|
|
762 |
if self.allow_failure:
|
763 |
try:
|
764 |
result = json.loads(value)
|
@@ -776,6 +806,25 @@ class ToolCallPostProcessor(FieldOperator):
|
|
776 |
return result
|
777 |
|
778 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
779 |
class DumpJson(FieldOperator):
|
780 |
def process_value(self, value: str) -> str:
|
781 |
return json.dumps(value)
|
|
|
23 |
{"key1": "value1", "key2": value2, "key3": "value3"}
|
24 |
"""
|
25 |
|
26 |
+
import ast
|
27 |
import json
|
28 |
import random
|
29 |
from abc import ABC, abstractmethod
|
|
|
755 |
return json.loads(value, strict=False)
|
756 |
|
757 |
|
758 |
+
class PythonCallProcessor(FieldOperator):
    """Parse a Python-style call string into a tool-call dictionary."""

    def process_value(self, value: str) -> ToolCall:
        """Convert text such as ``f(a=1, b="x")`` into ``{"name", "arguments"}``.

        The text is parsed (never executed) with ``ast``. Keyword arguments
        are evaluated with ``ast.literal_eval`` and stored under their own
        names; positional arguments, if any, are stored as a list under the
        ``"_args"`` key.
        """
        call = ast.parse(value, mode="eval").body
        # ``call.func.id`` is the called name when the callee is a plain name.
        name = call.func.id
        arguments = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        if call.args:
            arguments["_args"] = [ast.literal_eval(arg) for arg in call.args]
        return {"name": name, "arguments": arguments}
|
769 |
+
|
770 |
+
|
771 |
+
def extract_possible_json_str(text):
    """Extract a potential JSON string by trimming to the outermost braces/brackets.

    The result starts at the first ``{`` or ``[`` (whichever comes first) and
    ends at the last ``}`` or ``]`` (whichever comes last). When no opening
    delimiter exists the slice starts at the beginning of ``text``; when no
    closing delimiter exists it runs to the end of ``text``.
    """
    # str.find / str.rfind return -1 for "not found"; drop those candidates.
    opening_positions = [pos for pos in (text.find("{"), text.find("[")) if pos != -1]
    closing_positions = [pos for pos in (text.rfind("}"), text.rfind("]")) if pos != -1]

    start = min(opening_positions, default=0)
    end = max(closing_positions, default=len(text) - 1)

    return text[start : end + 1]
|
782 |
+
|
783 |
+
|
784 |
class ToolCallPostProcessor(FieldOperator):
|
785 |
failure_value: Any = None
|
786 |
allow_failure: bool = False
|
787 |
|
788 |
def process_value(self, value: str) -> ToolCall:
|
789 |
+
value = extract_possible_json_str(
|
790 |
+
value
|
791 |
+
) # clear tokens such as <tool_call> focusing on the call json itself
|
792 |
if self.allow_failure:
|
793 |
try:
|
794 |
result = json.loads(value)
|
|
|
806 |
return result
|
807 |
|
808 |
|
809 |
+
class MultipleToolCallPostProcessor(FieldOperator):
    """Parse a JSON string into a list of tool calls.

    Attributes:
        failure_value: Returned when the parsed value is neither a tool call
            nor a list of tool calls, and (with ``allow_failure``) when the
            text is not valid JSON.
        allow_failure: When True, JSON decoding errors yield
            ``failure_value`` instead of propagating.
    """

    failure_value: Any = None
    allow_failure: bool = False

    def process_value(self, value: str) -> List[ToolCall]:
        """Return the tool calls encoded in ``value`` as a list.

        A single tool call is wrapped into a one-element list.
        """
        if self.allow_failure:
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError:
                return self.failure_value
        else:
            # NOTE(review): only this branch parses with strict=False; the
            # allow_failure branch uses the default strict parsing -- confirm
            # the difference is intended.
            parsed = json.loads(value, strict=False)

        if isoftype(parsed, List[ToolCall]):
            return parsed
        if isoftype(parsed, ToolCall):
            return [parsed]
        return self.failure_value
|
826 |
+
|
827 |
+
|
828 |
class DumpJson(FieldOperator):
|
829 |
def process_value(self, value: str) -> str:
|
830 |
return json.dumps(value)
|
task.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
3 |
|
4 |
from .artifact import fetch_artifact
|
5 |
from .deprecation_utils import deprecation
|
6 |
-
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
7 |
from .logging_utils import get_logger
|
8 |
from .metrics import MetricsList
|
9 |
from .operator import InstanceOperator
|
@@ -285,13 +285,18 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
|
|
285 |
) -> Dict[str, Any]:
|
286 |
instance = self.set_default_values(instance)
|
287 |
|
288 |
-
|
289 |
-
self
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
295 |
input_fields = {key: instance[key] for key in self.input_fields.keys()}
|
296 |
data_classification_policy = instance.get("data_classification_policy", [])
|
297 |
|
|
|
3 |
|
4 |
from .artifact import fetch_artifact
|
5 |
from .deprecation_utils import deprecation
|
6 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
7 |
from .logging_utils import get_logger
|
8 |
from .metrics import MetricsList
|
9 |
from .operator import InstanceOperator
|
|
|
285 |
) -> Dict[str, Any]:
|
286 |
instance = self.set_default_values(instance)
|
287 |
|
288 |
+
with error_context(
|
289 |
+
self,
|
290 |
+
stage="Schema Verification",
|
291 |
+
help="https://www.unitxt.ai/en/latest/docs/adding_task.html",
|
292 |
+
):
|
293 |
+
verify_required_schema(
|
294 |
+
self.input_fields,
|
295 |
+
instance,
|
296 |
+
class_name="Task",
|
297 |
+
id=self.__id__,
|
298 |
+
description=self.__description__,
|
299 |
+
)
|
300 |
input_fields = {key: instance[key] for key in self.input_fields.keys()}
|
301 |
data_classification_policy = instance.get("data_classification_policy", [])
|
302 |
|
templates.py
CHANGED
@@ -6,11 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
6 |
from .artifact import Artifact
|
7 |
from .collections import DictCollection, ListCollection
|
8 |
from .dataclass import NonPositionalField
|
9 |
-
from .dict_utils import dict_set
|
10 |
from .error_utils import Documentation, UnitxtError
|
11 |
from .operator import InstanceOperator, Operator
|
12 |
from .random_utils import new_random_generator
|
13 |
from .serializers import (
|
|
|
14 |
DialogSerializer,
|
15 |
ImageSerializer,
|
16 |
ListSerializer,
|
@@ -68,6 +69,7 @@ class Template(InstanceOperator):
|
|
68 |
ToolCallSerializer(),
|
69 |
ToolsSerializer(),
|
70 |
DialogSerializer(),
|
|
|
71 |
ListSerializer(),
|
72 |
SQLDatabaseAsSchemaSerializer(),
|
73 |
]
|
@@ -942,6 +944,16 @@ class MultiReferenceTemplate(InputOutputTemplate):
|
|
942 |
return target, references
|
943 |
|
944 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
945 |
def escape_chars(s, chars_to_escape):
|
946 |
for char in chars_to_escape:
|
947 |
s = s.replace(char, f"\\{char}")
|
|
|
6 |
from .artifact import Artifact
|
7 |
from .collections import DictCollection, ListCollection
|
8 |
from .dataclass import NonPositionalField
|
9 |
+
from .dict_utils import dict_get, dict_set
|
10 |
from .error_utils import Documentation, UnitxtError
|
11 |
from .operator import InstanceOperator, Operator
|
12 |
from .random_utils import new_random_generator
|
13 |
from .serializers import (
|
14 |
+
ConversationSerializer,
|
15 |
DialogSerializer,
|
16 |
ImageSerializer,
|
17 |
ListSerializer,
|
|
|
69 |
ToolCallSerializer(),
|
70 |
ToolsSerializer(),
|
71 |
DialogSerializer(),
|
72 |
+
ConversationSerializer(),
|
73 |
ListSerializer(),
|
74 |
SQLDatabaseAsSchemaSerializer(),
|
75 |
]
|
|
|
944 |
return target, references
|
945 |
|
946 |
|
947 |
+
class MultiTurnTemplate(MultiReferenceTemplate):
    """Multi-reference template that also exposes the instance's turns.

    Attributes:
        input_format: Fixed to the empty string.
        turns_field: Path, inside the instance's ``"input_fields"``, of the
            turns to expose.
    """

    input_format = ""
    turns_field: str

    def post_process_instance(self, instance):
        """Copy the turns to ``instance["__turns__"]``, then defer to the parent."""
        instance["__turns__"] = dict_get(instance["input_fields"], self.turns_field)
        return super().post_process_instance(instance)
|
955 |
+
|
956 |
+
|
957 |
def escape_chars(s, chars_to_escape):
|
958 |
for char in chars_to_escape:
|
959 |
s = s.replace(char, f"\\{char}")
|
sql_utils.py → text2sql_utils.py
RENAMED
@@ -7,9 +7,13 @@ import re
|
|
7 |
import sqlite3
|
8 |
import time
|
9 |
from abc import ABC, abstractmethod
|
|
|
|
|
10 |
from functools import lru_cache
|
11 |
-
from typing import Any, List, Optional
|
12 |
|
|
|
|
|
13 |
import requests
|
14 |
from huggingface_hub import snapshot_download
|
15 |
from requests.exceptions import ConnectionError, ReadTimeout
|
@@ -539,6 +543,17 @@ def get_db_connector(db_type: str):
|
|
539 |
return connector
|
540 |
|
541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
542 |
def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
|
543 |
"""Returns True if sqlglot does not encounter any error, False otherwise."""
|
544 |
from sqlglot import parse
|
@@ -695,7 +710,7 @@ def extract_select_info(sql: str):
|
|
695 |
|
696 |
|
697 |
def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
|
698 |
-
"""
|
699 |
try:
|
700 |
info1 = extract_select_info(sql1)
|
701 |
info2 = extract_select_info(sql2)
|
@@ -713,6 +728,7 @@ def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
|
|
713 |
|
714 |
|
715 |
def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
|
|
|
716 |
from sqlglot import exp, parse_one
|
717 |
|
718 |
try:
|
@@ -754,3 +770,473 @@ def sql_exact_match(sql1: str, sql2: str) -> bool:
|
|
754 |
return s.upper()
|
755 |
|
756 |
return normalize_sql(sql1) == normalize_sql(sql2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import sqlite3
|
8 |
import time
|
9 |
from abc import ABC, abstractmethod
|
10 |
+
from collections import Counter
|
11 |
+
from dataclasses import dataclass
|
12 |
from functools import lru_cache
|
13 |
+
from typing import Any, List, Optional, Tuple
|
14 |
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
import requests
|
18 |
from huggingface_hub import snapshot_download
|
19 |
from requests.exceptions import ConnectionError, ReadTimeout
|
|
|
543 |
return connector
|
544 |
|
545 |
|
546 |
+
@dataclass
class SQLNonExecutionMetricResult:
    """Scores for a predicted SQL query that do not require executing it."""

    # Whether SQL parses with sqlglot
    sqlglot_validity: int
    # Whether SQL parses with sqlparse
    sqlparse_validity: int
    # Semantic equivalence using sqlglot AST
    sqlglot_equivalence: int
    # Equivalence after optimization via sqlglot
    sqlglot_optimized_equivalence: int
    # Equivalence using sqlparse AST
    sqlparse_equivalence: int
    # Exact string match of predicted and gold SQL
    sql_exact_match: int
    # Any of the above equivalence conditions hold
    sql_syntactic_equivalence: int
|
555 |
+
|
556 |
+
|
557 |
def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
|
558 |
"""Returns True if sqlglot does not encounter any error, False otherwise."""
|
559 |
from sqlglot import parse
|
|
|
710 |
|
711 |
|
712 |
def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
|
713 |
+
"""Returns True if both SQL queries are naively considered equivalent."""
|
714 |
try:
|
715 |
info1 = extract_select_info(sql1)
|
716 |
info2 = extract_select_info(sql2)
|
|
|
728 |
|
729 |
|
730 |
def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
|
731 |
+
"""Return True if two SQL queries match after parsing with SQLGlot."""
|
732 |
from sqlglot import exp, parse_one
|
733 |
|
734 |
try:
|
|
|
770 |
return s.upper()
|
771 |
|
772 |
return normalize_sql(sql1) == normalize_sql(sql2)
|
773 |
+
|
774 |
+
|
775 |
+
@dataclass
class SQLExecutionResult:
    """Scores and diagnostics from executing a predicted and a gold SQL query."""

    # Whether the predicted and gold SQL results match exactly
    execution_accuracy: int
    # Same as execution_accuracy but only if gold is non-empty
    non_empty_execution_accuracy: int
    # Whether predicted is a subset of gold or vice versa, non-empty only
    subset_non_empty_execution_accuracy: int
    # Whether the predicted SQL matches gold using BIRD evaluation logic
    execution_accuracy_bird: int
    # Whether the gold SQL produced a non-empty dataframe
    non_empty_gold_df: int
    # Time taken to execute the gold SQL
    gold_sql_runtime: float
    # Time taken to execute the predicted SQL
    predicted_sql_runtime: float
    # Ratio of predicted runtime to gold runtime
    pred_to_gold_runtime_ratio: float
    # Whether the gold SQL had an execution error
    gold_error: int
    # Whether the predicted SQL had an execution error
    predicted_error: int
    # JSON representation of the gold SQL result dataframe
    gold_df_json: str
    # JSON representation of the predicted SQL result dataframe
    predicted_df_json: str
    # Error message from predicted execution if any
    error_message: str
|
796 |
+
|
797 |
+
|
798 |
+
def compare_dfs_ignore_colnames_ordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    """Compare two DataFrames row by row, ignoring column names and column order.

    Values are stringified and the values inside each row are sorted, so two
    rows match when they hold the same multiset of values. Row order is
    significant: the i-th row of ``df1`` is compared with the i-th row of
    ``df2``.

    Returns:
        True when the shapes are equal and every pair of rows matches.
    """
    if df1.shape != df2.shape:
        return False
    # Sorting along axis=1 orders the values within each row in one
    # vectorized call instead of a Python-level loop over rows.
    rows1 = np.sort(df1.values.astype(str), axis=1)
    rows2 = np.sort(df2.values.astype(str), axis=1)
    return bool(np.array_equal(rows1, rows2))
|
806 |
+
|
807 |
+
|
808 |
+
def compare_dfs_ignore_colnames_unordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    """Compare two DataFrames as multisets of rows, ignoring column names.

    Values are stringified and sorted within each row (so column order does
    not matter), then the rows themselves are sorted as whole units (so row
    order does not matter) before comparison.

    Returns:
        True when the shapes are equal and both frames contain the same
        rows, regardless of row order.
    """
    if df1.shape != df2.shape:
        return False

    def canonical_rows(df: pd.DataFrame) -> list:
        # Rows must be ordered as intact tuples. Sorting each column
        # independently (np.sort(..., axis=0)) would mix values from
        # different rows and report different result sets as equal.
        row_sorted = np.sort(df.values.astype(str), axis=1)
        return sorted(map(tuple, row_sorted.tolist()))

    return canonical_rows(df1) == canonical_rows(df2)
|
816 |
+
|
817 |
+
|
818 |
+
def compare_dfs_ignore_colnames_subset(
    df1: pd.DataFrame, df2: pd.DataFrame, ignore_row_order: bool = True
) -> bool:
    """Checks if the smaller of the two DataFrames is likely a subset of the other.

    Subset comparison is column-based, to support Text2SQL evaluation for when the
    predicted SQL dataframe has missing or additional columns. Each row is treated as
    a multiset of (stringified) values, and the function checks if every row in the
    smaller DataFrame (by column count) is a multiset subset of the corresponding row
    in the larger DataFrame. When ground truth SQL does not have ORDER BY,
    ignore_row_order can be set to True to ignore the order of rows. In this case,
    each column is sorted independently before comparison. This means that there
    could be cases where the dataframes have the exact same number of rows and the
    column values after sort are the same, but the dataframes are not actually a
    subset of each other. This is unlikely to happen in practice, but the score is
    not guaranteed to be 100% accurate and may overestimate the accuracy.

    Args:
        df1 (pd.DataFrame): The first DataFrame to compare.
        df2 (pd.DataFrame): The second DataFrame to compare.
        ignore_row_order (bool, optional): If True, ignores the order of rows by
            sorting them before comparison. Defaults to True.

    Returns:
        bool: True if the smaller DataFrame (column-wise) is likely a subset of the
            larger one, False otherwise. Always False when either DataFrame is
            empty or the row counts differ.
    """

    def rows_to_multisets(df):
        return [Counter(str(x) for x in row) for row in df.values]

    def sort_df(df):
        # Sort every column on its own, addressing columns by position and
        # taking plain arrays. This keeps the result independent of the
        # frame's index labels (assigning an index-reset Series back would
        # align on labels and produce NaN for a non-default index) and of
        # duplicate column names.
        sorted_columns = [
            df.iloc[:, position].astype(str).sort_values().to_numpy()
            for position in range(df.shape[1])
        ]
        return pd.DataFrame(dict(enumerate(sorted_columns)))

    if df1.empty or df2.empty or len(df1) != len(df2):
        return False

    # The frame with fewer columns is the subset candidate.
    subset_df, superset_df = (df1, df2) if df1.shape[1] <= df2.shape[1] else (df2, df1)

    if ignore_row_order:
        subset_df = sort_df(subset_df)
        superset_df = sort_df(superset_df)

    subset_rows = rows_to_multisets(subset_df)
    superset_rows = rows_to_multisets(superset_df)

    # A Counter yields 0 for missing keys, so plain indexing is enough.
    return all(
        all(count <= superset_row[key] for key, count in subset_row.items())
        for subset_row, superset_row in zip(subset_rows, superset_rows)
    )
|
874 |
+
|
875 |
+
|
876 |
+
def compare_dfs_bird_eval_logic(df1: pd.DataFrame, df2: pd.DataFrame) -> int:
    """Check if two SQL query result sets are exactly equal, as in BIRD evaluation.

    Each DataFrame is reduced to the set of its (stringified) rows, so
    duplicate rows and row order are ignored; the two sets must then be
    identical. This is the logic used in the original BIRD evaluation code:
    https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py.

    Args:
        df1 (pd.DataFrame): Result of one query (e.g. the predicted SQL).
        df2 (pd.DataFrame): Result of the other query (e.g. the gold SQL).

    Returns:
        int: 1 if both row sets are equal, 0 otherwise.
    """
    row_set_1 = set(map(tuple, df1.values.astype(str)))
    row_set_2 = set(map(tuple, df2.values.astype(str)))
    return int(row_set_1 == row_set_2)
|
888 |
+
|
889 |
+
|
890 |
+
def compare_result_dfs(
    gold_df: pd.DataFrame, pred_df: pd.DataFrame, gold_sql: str
) -> Tuple[int, int, int]:
    """Compares two DataFrames representing SQL query results.

    Args:
        gold_df (pd.DataFrame): The ground truth DataFrame.
        pred_df (pd.DataFrame): The predicted DataFrame.
        gold_sql (str): The ground truth SQL query string.

    Returns:
        Tuple[int, int, int]: A tuple containing:
            - match (int): 1 if the predicted DataFrame matches the gold DataFrame
            - non_empty_match (int): 1 if both DataFrames are non-empty and match,
              0 otherwise.
            - subset_match (int): 1 if both DataFrames are non-empty and the
              predicted DataFrame is a subset or superset of the gold DataFrame.

    Notes:
        - The comparison ignores column names.
        - Row order is considered only if 'ORDER BY' is present in the SQL query.
    """
    # Row order only matters when the gold query requests an ordering.
    ignore_row_order = "ORDER BY" not in gold_sql.upper()
    compare_rows = (
        compare_dfs_ignore_colnames_unordered_rows
        if ignore_row_order
        else compare_dfs_ignore_colnames_ordered_rows
    )
    match = int(compare_rows(pred_df, gold_df))

    non_empty_match = 0
    subset_match = 0
    if not gold_df.empty and not pred_df.empty:
        non_empty_match = match
        if compare_dfs_ignore_colnames_subset(
            gold_df, pred_df, ignore_row_order=ignore_row_order
        ):
            subset_match = 1
    return match, non_empty_match, subset_match
|
931 |
+
|
932 |
+
|
933 |
+
def run_query(
    sql: str, connector, sql_timeout: float
) -> Tuple[Optional[pd.DataFrame], float, str]:
    """Executes a SQL query using the provided connector with a timeout.

    Args:
        sql (str): The SQL query string to execute.
        connector: An object with an `execute_query` method that executes the SQL
            query and returns a ``(result, error)`` pair.
        sql_timeout (float): The maximum time in seconds to allow for query
            execution.

    Returns:
        Tuple[Optional[pd.DataFrame], float, str]:
            - A pandas DataFrame containing the query results, or None if an error
              occurred.
            - The duration in seconds taken to execute the query. 0.0 if the query
              was empty, timed out or raised.
            - An error message string if an error occurred, otherwise an empty
              string.

    Notes:
        - If the SQL string is empty or only whitespace, returns immediately with a
          message.
        - If the query execution exceeds the timeout, returns a timeout error
          message.
        - Any other exceptions are caught and returned as error messages.
    """
    # Reject blank queries first: this path needs neither the connector nor
    # the optional func_timeout dependency imported below.
    if not sql.strip():
        return None, 0.0, "No SQL query found in the prediction."

    import time

    from func_timeout import func_timeout
    from func_timeout.exceptions import FunctionTimedOut

    try:
        start = time.perf_counter()
        result, error = func_timeout(sql_timeout, connector.execute_query, args=(sql,))
        duration = time.perf_counter() - start
        # When the result is a dict with a "results" key, unwrap the payload.
        if isinstance(result, dict) and "results" in result:
            result = result["results"]
        if error:
            return None, duration, error
        return pd.DataFrame(result), duration, ""
    except FunctionTimedOut as e:
        return None, 0.0, f"Timeout: {e}"
    except Exception as e:
        # Deliberate best-effort: any failure is reported as an error message
        # instead of propagating to the caller.
        return None, 0.0, f"Error: {e}"
|
981 |
+
|
982 |
+
|
983 |
+
def get_sql_execution_results(
    predicted_sql: str, gold_sql: str, connector, sql_timeout: float
) -> SQLExecutionResult:
    """Execute and compare predicted and gold SQL queries, returning execution metrics.

    Args:
        predicted_sql (str): The SQL query predicted by the model.
        gold_sql (str): The reference (gold) SQL query.
        connector: Database connector object used to execute the queries.
        sql_timeout (float): Maximum time (in seconds) allowed for query execution.

    Returns:
        SQLExecutionResult: An object containing various execution metrics, including:
            - execution_accuracy (int): 1 if predicted and gold queries produce
              equivalent results, else 0.
            - non_empty_execution_accuracy (int): 1 if both queries produce non-empty
              and equivalent results, else 0.
            - subset_non_empty_execution_accuracy (int): 1 if predicted results are a
              subset or superset of gold results and non-empty, else 0. Subset
              comparison is column-based. This means that the predicted SQL dataframe
              can have missing or additional columns compared to the gold SQL dataframe.
            - execution_accuracy_bird (int): 1 if results match according to BIRD
              evaluation logic, else 0.
            - non_empty_gold_df (int): 1 if the gold query result is non-empty, else 0.
            - gold_sql_runtime (float): Execution time for the gold SQL query.
            - predicted_sql_runtime (float): Execution time for the predicted SQL query
              (0 when the predicted query was not executed).
            - pred_to_gold_runtime_ratio (float): Ratio of predicted to gold query
              runtimes (0 when the gold runtime is not positive or the predicted
              query was not executed).
            - gold_error (int): 1 if the gold query failed, else 0.
            - predicted_error (int): 1 if the predicted query failed, else 0.
            - gold_df_json (str): JSON representation of the gold query result
              DataFrame.
            - predicted_df_json (str): JSON representation of the predicted query
              result DataFrame.
            - error_message (str): Error message if any query failed, else empty
              string.

    Notes:
        - If the gold query fails, the function returns early with error details.
        - If the predicted query is identical (ignoring case and surrounding
          whitespace) or SQL-equivalent to the gold query, results are considered
          correct without executing the predicted query.
        - Otherwise, both queries are executed and their results compared using
          multiple metrics.
    """
    gold_df, gold_runtime, gold_error_msg = run_query(gold_sql, connector, sql_timeout)

    if gold_error_msg:
        # Without a gold result there is nothing to compare against.
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=0,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            gold_error=1,
            predicted_error=0,
            gold_df_json="",
            predicted_df_json="",
            error_message=gold_error_msg,
        )

    non_empty_gold_df = int(not gold_df.empty)
    gold_df_json = gold_df.to_json()

    # Fields that depend only on the (successful) gold query; shared by every
    # result built below.
    gold_fields = {
        "non_empty_gold_df": non_empty_gold_df,
        "gold_sql_runtime": gold_runtime,
        "gold_error": 0,
        "gold_df_json": gold_df_json,
    }

    # Cheap textual check first, then the sqlglot-based equivalence check.
    # NOTE(review): lower() is applied to the whole query, so queries that differ
    # only in the case of a string literal are treated as identical - confirm
    # this is intended.
    queries_equivalent = predicted_sql.strip().lower() == gold_sql.strip().lower()
    if not queries_equivalent:
        try:
            queries_equivalent = bool(
                sqlglot_optimized_equivalence(gold_sql, predicted_sql)
            )
        except Exception as e:
            # Best effort: if equivalence cannot be established, fall back to
            # executing the predicted query.
            logger.info(f"Could not check SQL equivalence: {e}")

    if queries_equivalent:
        # The predicted query is not executed; the gold result stands in for it.
        return SQLExecutionResult(
            execution_accuracy=1,
            non_empty_execution_accuracy=non_empty_gold_df,
            subset_non_empty_execution_accuracy=non_empty_gold_df,
            execution_accuracy_bird=1,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            predicted_error=0,
            predicted_df_json=gold_df_json,
            error_message="",
            **gold_fields,
        )

    pred_df, pred_runtime, pred_error_msg = run_query(
        predicted_sql, connector, sql_timeout
    )
    # Guard against division by zero when the gold runtime is not positive.
    runtime_ratio = (pred_runtime / gold_runtime) if gold_runtime > 0 else 0

    if pred_df is None:
        # The predicted query produced no result to compare.
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            predicted_sql_runtime=pred_runtime,
            pred_to_gold_runtime_ratio=runtime_ratio,
            predicted_error=1 if pred_error_msg else 0,
            predicted_df_json="",
            error_message=pred_error_msg,
            **gold_fields,
        )

    match, non_empty_match, subset_match = compare_result_dfs(
        gold_df, pred_df, gold_sql
    )
    bird_match = compare_dfs_bird_eval_logic(gold_df, pred_df)

    return SQLExecutionResult(
        execution_accuracy=match,
        non_empty_execution_accuracy=non_empty_match,
        subset_non_empty_execution_accuracy=subset_match,
        execution_accuracy_bird=bird_match,
        predicted_sql_runtime=pred_runtime,
        pred_to_gold_runtime_ratio=runtime_ratio,
        predicted_error=0,
        predicted_df_json=pred_df.to_json(),
        error_message=pred_error_msg,
        **gold_fields,
    )
|
1131 |
+
|
1132 |
+
|
1133 |
+
def replace_select_clause(
    source_query: str, target_query: str, dialect: str = "postgres"
) -> str:
    """Return `target_query` with its projection list swapped for that of `source_query`.

    Both queries are parsed with sqlglot using `dialect`; the SELECT expressions
    of the parsed source are then set on the parsed target, and the target is
    rendered back to SQL in the same dialect.

    Args:
        source_query (str): SQL query that supplies the SELECT expressions.
        target_query (str): SQL query whose SELECT expressions are replaced.
        dialect (str): SQL dialect used for parsing and rendering. A falsy value
            (e.g. "" or None) falls back to "postgres".

    Returns:
        str: The rewritten `target_query`.

    Raises:
        ValueError: If either parsed query is not a SELECT expression.
        Exception: Any error raised by sqlglot while parsing propagates unchanged.

    Example:
        >>> replace_select_clause(
        ...     "SELECT id FROM employees",
        ...     "SELECT name FROM employees WHERE age > 30"
        ... )
        'SELECT id FROM employees WHERE age > 30'
    """
    from sqlglot import exp, parse_one

    effective_dialect = dialect or "postgres"

    parsed_source = parse_one(source_query, read=effective_dialect)
    parsed_target = parse_one(target_query, read=effective_dialect)

    both_are_selects = isinstance(parsed_source, exp.Select) and isinstance(
        parsed_target, exp.Select
    )
    if not both_are_selects:
        raise ValueError("Both queries must be valid SELECT statements.")

    # Only the projection list moves; the rest of the target query is untouched.
    parsed_target.set("expressions", parsed_source.expressions)

    return parsed_target.sql(dialect=effective_dialect)
|
1173 |
+
|
1174 |
+
|
1175 |
+
# Patterns are compiled once at import time rather than on every call.
_SQL_FENCED_BLOCK_RE = re.compile(r"```sql\s+(.*?)```", re.IGNORECASE | re.DOTALL)
# A statement may start at the beginning of the text, on a new line, or after a
# colon label like "Just run:"; leading indentation/whitespace is tolerated.
_SQL_START = r"(?:^|\n|:)\s*"
_SQL_KEYWORDS = r"(?:SELECT|INSERT|UPDATE|DELETE|WITH)\s+"
_SQL_TERMINATED_RE = re.compile(
    rf"{_SQL_START}({_SQL_KEYWORDS}.*?;)", re.IGNORECASE | re.DOTALL
)
_SQL_UNTERMINATED_RE = re.compile(
    rf"{_SQL_START}({_SQL_KEYWORDS}.*)", re.IGNORECASE | re.DOTALL
)
_CODE_FENCE = "```"


def extract_sql_from_text(text: str) -> str:
    """Extracts the first SQL query from the given text.

    Priority:
    1. SQL inside fenced blocks like ```sql ... ```
    2. SQL starting at the beginning of the text, on a new line, or after a
       colon/label, and ending at the first semicolon
    3. The same, without a semicolon: everything from the keyword to the end
       of the text

    For cases 2 and 3 the statement may be preceded by whitespace (indentation
    or a leading space), and anything from a Markdown code fence onwards is
    dropped, so an untagged fenced block does not leak its closing fence into
    the returned SQL.

    Args:
        text (str): Free text (e.g. a model prediction) that may contain SQL.

    Returns:
        The SQL query string, or an empty string if not found.
    """
    # 1. Fenced SQL code block
    fenced_match = _SQL_FENCED_BLOCK_RE.search(text)
    if fenced_match:
        return fenced_match.group(1).strip()

    # 2. Inline SQL terminated by a semicolon, then 3. unterminated inline SQL
    for pattern in (_SQL_TERMINATED_RE, _SQL_UNTERMINATED_RE):
        inline_match = pattern.search(text)
        if inline_match:
            # A code fence is never part of the statement; cut the match there.
            return inline_match.group(1).split(_CODE_FENCE, 1)[0].strip()

    return ""
|
1213 |
+
|
1214 |
+
|
1215 |
+
# Display names of SQL dialects, in alphabetical (case-insensitive) order.
# NOTE(review): these look like the dialect names supported by sqlglot - confirm.
ALL_DIALECTS = [
    "Athena", "BigQuery", "ClickHouse", "Databricks", "Doris", "Drill",
    "Druid", "DuckDB", "Hive", "Materialize", "MySQL", "Oracle",
    "Postgres", "Presto", "PRQL", "Redshift", "RisingWave", "Snowflake",
    "Spark", "Spark2", "SQLite", "StarRocks", "Tableau", "Teradata",
    "Trino", "TSQL",
]  # fmt: skip
|
type_utils.py
CHANGED
@@ -503,9 +503,25 @@ def isoftype(object, typing_type):
|
|
503 |
if is_typed_dict(typing_type):
|
504 |
if not isinstance(object, dict):
|
505 |
return False
|
|
|
|
|
506 |
for key, expected_type in typing_type.__annotations__.items():
|
507 |
-
|
508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
509 |
return True
|
510 |
|
511 |
if typing_type == typing.Any:
|
|
|
503 |
if is_typed_dict(typing_type):
|
504 |
if not isinstance(object, dict):
|
505 |
return False
|
506 |
+
|
507 |
+
# Only support total=True, check each field
|
508 |
for key, expected_type in typing_type.__annotations__.items():
|
509 |
+
# Check if field is Optional (Union with None)
|
510 |
+
is_optional = (
|
511 |
+
hasattr(expected_type, "__origin__")
|
512 |
+
and expected_type.__origin__ is Union
|
513 |
+
and type(None) in expected_type.__args__
|
514 |
+
)
|
515 |
+
|
516 |
+
if key not in object:
|
517 |
+
# Field is missing - only allowed if it's Optional
|
518 |
+
if not is_optional:
|
519 |
+
return False
|
520 |
+
else:
|
521 |
+
# Field is present - check type
|
522 |
+
if not isoftype(object[key], expected_type):
|
523 |
+
return False
|
524 |
+
|
525 |
return True
|
526 |
|
527 |
if typing_type == typing.Any:
|
types.py
CHANGED
@@ -6,8 +6,52 @@ Text = NewType("Text", str)
|
|
6 |
Number = NewType("Number", Union[float, int])
|
7 |
|
8 |
|
9 |
-
class
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
content: Text
|
12 |
|
13 |
|
@@ -18,7 +62,12 @@ class RagResponse(TypedDict):
|
|
18 |
is_answerable: bool
|
19 |
|
20 |
|
21 |
-
Dialog = NewType("Dialog", List[
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
class Image(TypedDict):
|
@@ -52,36 +101,17 @@ class SQLDatabase(TypedDict):
|
|
52 |
data: Optional[Dict[str, Dict]]
|
53 |
|
54 |
|
55 |
-
class JsonSchema:
|
56 |
-
@classmethod
|
57 |
-
def __verify_type__(cls, object):
|
58 |
-
if not isinstance(object, dict):
|
59 |
-
return False
|
60 |
-
import jsonschema_rs
|
61 |
-
|
62 |
-
jsonschema_rs.meta.validate(object)
|
63 |
-
return True
|
64 |
-
|
65 |
-
|
66 |
-
class Tool(TypedDict):
|
67 |
-
name: str
|
68 |
-
description: str
|
69 |
-
parameters: JsonSchema
|
70 |
-
|
71 |
-
|
72 |
-
class ToolCall(TypedDict):
|
73 |
-
name: str
|
74 |
-
arguments: Dict[str, Any]
|
75 |
-
|
76 |
-
|
77 |
register_type(Text)
|
78 |
register_type(Number)
|
79 |
-
register_type(
|
|
|
|
|
80 |
register_type(Dialog)
|
81 |
register_type(Table)
|
82 |
register_type(Audio)
|
83 |
register_type(Image)
|
84 |
register_type(Video)
|
|
|
85 |
register_type(Document)
|
86 |
register_type(MultiDocument)
|
87 |
register_type(RagResponse)
|
|
|
6 |
Number = NewType("Number", Union[float, int])
|
7 |
|
8 |
|
9 |
+
class JsonSchema:
    """Marker type for values that are JSON Schema documents.

    Type checking is delegated to `__verify_type__` rather than to a static
    annotation.
    """

    @classmethod
    def __verify_type__(cls, object):
        """Check whether `object` qualifies as a JSON schema.

        Returns False for anything that is not a dict. A dict is handed to
        `jsonschema_rs.meta.validate`; if that call returns, the result is True.
        Any exception raised by the validator is not caught here and propagates
        to the caller.
        """
        if isinstance(object, dict):
            # Imported lazily so the dependency is only needed when a dict is checked.
            import jsonschema_rs

            jsonschema_rs.meta.validate(object)
            return True
        return False
|
18 |
+
|
19 |
+
|
20 |
+
class Tool(TypedDict):
    """Description of a tool: its name, description and parameter schema."""

    # Original fields
    name: str
    description: str
    # Checked through JsonSchema.__verify_type__ rather than a plain annotation.
    parameters: JsonSchema
    # LiteLLM extension
    # The only literal value allowed is "function"; None is also accepted.
    type: Optional[Literal["function"]]
|
27 |
+
|
28 |
+
|
29 |
+
class ToolCall(TypedDict):
    """A single tool invocation: the tool name and its keyword arguments."""

    name: str
    # Maps argument names to values of any type.
    arguments: Dict[str, Any]
|
32 |
+
|
33 |
+
|
34 |
+
class ToolCallContext(TypedDict):
    """A ToolCall wrapped with an identifier and a fixed "function" type tag."""

    id: str
    type: Literal["function"]
    # The wrapped call (tool name and arguments).
    function: ToolCall
|
38 |
+
|
39 |
+
|
40 |
+
class ToolCallTurn(TypedDict):
    """Dialog turn with role "assistant" that carries a list of tool calls."""

    role: Literal["assistant"]
    # Text accompanying the tool calls; may be None.
    content: Optional[str]
    tool_calls: List[ToolCallContext]
|
44 |
+
|
45 |
+
|
46 |
+
class ToolOutputTurn(TypedDict):
    """Dialog turn with role "tool" that carries a tool's output as text."""

    role: Literal["tool"]
    # NOTE(review): presumably matches the `id` of the ToolCallContext being
    # answered - confirm against callers.
    tool_call_id: str
    name: str
    content: str
|
51 |
+
|
52 |
+
|
53 |
+
class TextTurn(TypedDict):
    """Plain-text dialog turn."""

    # Speaker of the turn; restricted to these four literal values.
    role: Literal["system", "user", "agent", "assistant"]
    content: Text
|
56 |
|
57 |
|
|
|
62 |
is_answerable: bool
|
63 |
|
64 |
|
65 |
+
# A dialog is a list of turns; each turn is a text turn, a tool-call turn or a
# tool-output turn.
Dialog = NewType("Dialog", List[Union[TextTurn, ToolCallTurn, ToolOutputTurn]])
|
66 |
+
|
67 |
+
|
68 |
+
class Conversation(TypedDict):
    """A Dialog paired with a string identifier."""

    id: str
    dialog: Dialog
|
71 |
|
72 |
|
73 |
class Image(TypedDict):
|
|
|
101 |
data: Optional[Dict[str, Dict]]
|
102 |
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
register_type(Text)
|
105 |
register_type(Number)
|
106 |
+
register_type(TextTurn)
|
107 |
+
register_type(ToolCallTurn)
|
108 |
+
register_type(ToolOutputTurn)
|
109 |
register_type(Dialog)
|
110 |
register_type(Table)
|
111 |
register_type(Audio)
|
112 |
register_type(Image)
|
113 |
register_type(Video)
|
114 |
+
register_type(Conversation)
|
115 |
register_type(Document)
|
116 |
register_type(MultiDocument)
|
117 |
register_type(RagResponse)
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.
|
|
|
1 |
+
version = "1.25.0"
|