Upload operators.py with huggingface_hub
Browse files- operators.py +38 -13
operators.py
CHANGED
|
@@ -20,6 +20,7 @@ Some operators are specialized in specific tasks such as:
|
|
| 20 |
|
| 21 |
- :class:`loaders<unitxt.loaders>` for loading data.
|
| 22 |
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
|
|
|
| 23 |
|
| 24 |
Other specialized operators are used by unitxt internally:
|
| 25 |
|
|
@@ -32,7 +33,7 @@ General Operators List:
|
|
| 32 |
------------------------
|
| 33 |
"""
|
| 34 |
import collections
|
| 35 |
-
import
|
| 36 |
import operator
|
| 37 |
import os
|
| 38 |
import uuid
|
|
@@ -41,7 +42,6 @@ from abc import abstractmethod
|
|
| 41 |
from collections import Counter
|
| 42 |
from copy import deepcopy
|
| 43 |
from dataclasses import field
|
| 44 |
-
from importlib import import_module
|
| 45 |
from itertools import zip_longest
|
| 46 |
from random import Random
|
| 47 |
from typing import (
|
|
@@ -64,6 +64,7 @@ from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
|
| 64 |
from .operator import (
|
| 65 |
MultiStream,
|
| 66 |
MultiStreamOperator,
|
|
|
|
| 67 |
PagedStreamOperator,
|
| 68 |
SequentialOperator,
|
| 69 |
SideEffectOperator,
|
|
@@ -782,7 +783,7 @@ class Apply(StreamInstanceOperator):
|
|
| 782 |
elif module_name in globals():
|
| 783 |
obj = globals()[module_name]
|
| 784 |
else:
|
| 785 |
-
obj =
|
| 786 |
for part in function_name.split("."):
|
| 787 |
obj = getattr(obj, part)
|
| 788 |
return obj
|
|
@@ -963,7 +964,16 @@ class CopyFields(FieldOperator):
|
|
| 963 |
"""
|
| 964 |
|
| 965 |
def process_value(self, value: Any) -> Any:
|
| 966 |
-
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 967 |
|
| 968 |
|
| 969 |
class AddID(StreamInstanceOperator):
|
|
@@ -1230,10 +1240,13 @@ class ComputeExpressionMixin(Artifact):
|
|
| 1230 |
expression: str
|
| 1231 |
imports_list: List[str] = OptionalField(default_factory=list)
|
| 1232 |
|
|
|
|
|
|
|
|
|
|
| 1233 |
def prepare(self):
|
| 1234 |
# can not do the imports here, because object does not pickle with imports
|
| 1235 |
self.globals = {
|
| 1236 |
-
module_name:
|
| 1237 |
}
|
| 1238 |
|
| 1239 |
def compute_expression(self, instance: dict) -> Any:
|
|
@@ -1574,7 +1587,7 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
|
|
| 1574 |
calc_confidence_intervals: bool
|
| 1575 |
|
| 1576 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1577 |
-
from .metrics import Metric
|
| 1578 |
|
| 1579 |
first_instance = stream.peek()
|
| 1580 |
|
|
@@ -1593,6 +1606,16 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
|
|
| 1593 |
# by the first listed metric (as desired).
|
| 1594 |
metric_names = list(reversed(metric_names))
|
| 1595 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1596 |
for metric_name in metric_names:
|
| 1597 |
metric = self.get_artifact(metric_name)
|
| 1598 |
assert isinstance(
|
|
@@ -1600,15 +1623,17 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
|
|
| 1600 |
), f"Operator {metric_name} must be a Metric"
|
| 1601 |
|
| 1602 |
if not self.calc_confidence_intervals:
|
| 1603 |
-
|
| 1604 |
-
metric.disable_confidence_interval_calculation()
|
| 1605 |
-
elif isinstance(metric, MetricPipeline) and isinstance(
|
| 1606 |
-
metric.metric, MetricWithConfidenceInterval
|
| 1607 |
-
):
|
| 1608 |
-
metric.metric.disable_confidence_interval_calculation()
|
| 1609 |
|
| 1610 |
-
|
|
|
|
|
|
|
|
|
|
| 1611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1612 |
yield from stream
|
| 1613 |
|
| 1614 |
|
|
|
|
| 20 |
|
| 21 |
- :class:`loaders<unitxt.loaders>` for loading data.
|
| 22 |
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
| 23 |
+
- :class:`struct_data_operators<unitxt.struct_data_operators>` for structured data operators.
|
| 24 |
|
| 25 |
Other specialized operators are used by unitxt internally:
|
| 26 |
|
|
|
|
| 33 |
------------------------
|
| 34 |
"""
|
| 35 |
import collections
|
| 36 |
+
import copy
|
| 37 |
import operator
|
| 38 |
import os
|
| 39 |
import uuid
|
|
|
|
| 42 |
from collections import Counter
|
| 43 |
from copy import deepcopy
|
| 44 |
from dataclasses import field
|
|
|
|
| 45 |
from itertools import zip_longest
|
| 46 |
from random import Random
|
| 47 |
from typing import (
|
|
|
|
| 64 |
from .operator import (
|
| 65 |
MultiStream,
|
| 66 |
MultiStreamOperator,
|
| 67 |
+
PackageRequirementsMixin,
|
| 68 |
PagedStreamOperator,
|
| 69 |
SequentialOperator,
|
| 70 |
SideEffectOperator,
|
|
|
|
| 783 |
elif module_name in globals():
|
| 784 |
obj = globals()[module_name]
|
| 785 |
else:
|
| 786 |
+
obj = __import__(module_name)
|
| 787 |
for part in function_name.split("."):
|
| 788 |
obj = getattr(obj, part)
|
| 789 |
return obj
|
|
|
|
| 964 |
"""
|
| 965 |
|
| 966 |
def process_value(self, value: Any) -> Any:
|
| 967 |
+
return copy.deepcopy(value)
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class GetItemByIndex(FieldOperator):
|
| 971 |
+
"""Get from the item list by the index in the field."""
|
| 972 |
+
|
| 973 |
+
items_list: List[Any]
|
| 974 |
+
|
| 975 |
+
def process_value(self, value: Any) -> Any:
|
| 976 |
+
return self.items_list[value]
|
| 977 |
|
| 978 |
|
| 979 |
class AddID(StreamInstanceOperator):
|
|
|
|
| 1240 |
expression: str
|
| 1241 |
imports_list: List[str] = OptionalField(default_factory=list)
|
| 1242 |
|
| 1243 |
+
def verify(self):
|
| 1244 |
+
PackageRequirementsMixin.check_missing_requirements(self, self.imports_list)
|
| 1245 |
+
|
| 1246 |
def prepare(self):
|
| 1247 |
# can not do the imports here, because object does not pickle with imports
|
| 1248 |
self.globals = {
|
| 1249 |
+
module_name: __import__(module_name) for module_name in self.imports_list
|
| 1250 |
}
|
| 1251 |
|
| 1252 |
def compute_expression(self, instance: dict) -> Any:
|
|
|
|
| 1587 |
calc_confidence_intervals: bool
|
| 1588 |
|
| 1589 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1590 |
+
from .metrics import Metric
|
| 1591 |
|
| 1592 |
first_instance = stream.peek()
|
| 1593 |
|
|
|
|
| 1606 |
# by the first listed metric (as desired).
|
| 1607 |
metric_names = list(reversed(metric_names))
|
| 1608 |
|
| 1609 |
+
# Workaround: The metric/MetricPipeline modifies the stream itself, sometimes making it incompatible
|
| 1610 |
+
# for further metrics' processing, instead of just modifying the score field.
|
| 1611 |
+
# Here we keep all the fields besides the score, and restore them after the metric finishes.
|
| 1612 |
+
first_instance = stream.peek()
|
| 1613 |
+
keys_to_restore = set(first_instance.keys()).difference({"score"})
|
| 1614 |
+
multi_stream = MultiStream({"tmp": stream})
|
| 1615 |
+
multi_stream = CopyFields(
|
| 1616 |
+
field_to_field={k: f"{k}_orig" for k in keys_to_restore}
|
| 1617 |
+
)(multi_stream)
|
| 1618 |
+
|
| 1619 |
for metric_name in metric_names:
|
| 1620 |
metric = self.get_artifact(metric_name)
|
| 1621 |
assert isinstance(
|
|
|
|
| 1623 |
), f"Operator {metric_name} must be a Metric"
|
| 1624 |
|
| 1625 |
if not self.calc_confidence_intervals:
|
| 1626 |
+
metric.disable_confidence_interval_calculation()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1627 |
|
| 1628 |
+
multi_stream = metric(multi_stream)
|
| 1629 |
+
multi_stream = CopyFields(
|
| 1630 |
+
field_to_field={f"{k}_orig": k for k in keys_to_restore}
|
| 1631 |
+
)(multi_stream)
|
| 1632 |
|
| 1633 |
+
multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
|
| 1634 |
+
multi_stream
|
| 1635 |
+
)
|
| 1636 |
+
stream = multi_stream["tmp"]
|
| 1637 |
yield from stream
|
| 1638 |
|
| 1639 |
|