Upload operators.py with huggingface_hub
Browse files- operators.py +120 -20
operators.py
CHANGED
|
@@ -1,11 +1,24 @@
|
|
|
|
|
|
|
|
| 1 |
import uuid
|
| 2 |
from abc import abstractmethod
|
| 3 |
from copy import deepcopy
|
| 4 |
from dataclasses import field
|
| 5 |
from itertools import zip_longest
|
| 6 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from .artifact import Artifact, fetch_artifact
|
|
|
|
| 9 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
| 10 |
from .operator import (
|
| 11 |
MultiStream,
|
|
@@ -35,24 +48,6 @@ class FromIterables(StreamInitializerOperator):
|
|
| 35 |
return MultiStream.from_iterables(iterables)
|
| 36 |
|
| 37 |
|
| 38 |
-
class RenameFields(StreamInstanceOperator):
|
| 39 |
-
"""
|
| 40 |
-
Renames fields
|
| 41 |
-
Attributes:
|
| 42 |
-
mapper (Dict[str, str]): old field names to new field names dict
|
| 43 |
-
"""
|
| 44 |
-
|
| 45 |
-
mapper: Dict[str, str]
|
| 46 |
-
|
| 47 |
-
def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
|
| 48 |
-
result = {}
|
| 49 |
-
# passes on all values to preserve ordering
|
| 50 |
-
for key, value in instance.items():
|
| 51 |
-
result[self.mapper.get(key, key)] = value
|
| 52 |
-
# doesn't warn if unnecessary mapping was supplied for efficiency
|
| 53 |
-
return result
|
| 54 |
-
|
| 55 |
-
|
| 56 |
class MapInstanceValues(StreamInstanceOperator):
|
| 57 |
"""A class used to map instance values into a stream.
|
| 58 |
|
|
@@ -179,7 +174,10 @@ class FieldOperator(StreamInstanceOperator):
|
|
| 179 |
|
| 180 |
def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
|
| 181 |
for from_field, to_field in self._field_to_field:
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
| 183 |
if self.process_every_value:
|
| 184 |
new_value = [self.process_value(value) for value in old_value]
|
| 185 |
else:
|
|
@@ -190,6 +188,49 @@ class FieldOperator(StreamInstanceOperator):
|
|
| 190 |
return instance
|
| 191 |
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
class JoinStr(FieldOperator):
|
| 194 |
"""
|
| 195 |
Joins a list of strings (contents of a field), similar to str.join()
|
|
@@ -203,6 +244,65 @@ class JoinStr(FieldOperator):
|
|
| 203 |
return self.separator.join(str(x) for x in value)
|
| 204 |
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
class ZipFieldValues(StreamInstanceOperator):
|
| 207 |
"""
|
| 208 |
Zips values of multiple fields similar to list(zip(*fields))
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import inspect
|
| 3 |
import uuid
|
| 4 |
from abc import abstractmethod
|
| 5 |
from copy import deepcopy
|
| 6 |
from dataclasses import field
|
| 7 |
from itertools import zip_longest
|
| 8 |
+
from typing import (
|
| 9 |
+
Any,
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
Generator,
|
| 13 |
+
Iterable,
|
| 14 |
+
List,
|
| 15 |
+
Optional,
|
| 16 |
+
Tuple,
|
| 17 |
+
Union,
|
| 18 |
+
)
|
| 19 |
|
| 20 |
from .artifact import Artifact, fetch_artifact
|
| 21 |
+
from .dataclass import NonPositionalField
|
| 22 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
| 23 |
from .operator import (
|
| 24 |
MultiStream,
|
|
|
|
| 48 |
return MultiStream.from_iterables(iterables)
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
class MapInstanceValues(StreamInstanceOperator):
|
| 52 |
"""A class used to map instance values into a stream.
|
| 53 |
|
|
|
|
| 174 |
|
| 175 |
def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
|
| 176 |
for from_field, to_field in self._field_to_field:
|
| 177 |
+
try:
|
| 178 |
+
old_value = dict_get(instance, from_field, use_dpath=self.use_query)
|
| 179 |
+
except TypeError as e:
|
| 180 |
+
raise TypeError(f"Failed to get {from_field} from {instance}")
|
| 181 |
if self.process_every_value:
|
| 182 |
new_value = [self.process_value(value) for value in old_value]
|
| 183 |
else:
|
|
|
|
| 188 |
return instance
|
| 189 |
|
| 190 |
|
| 191 |
+
class RenameFields(FieldOperator):
    """
    Renames fields.

    Values pass through unchanged (see ``process_value``). After the inherited
    ``process`` has run, every source field is removed from the instance
    unless it is also the target of one of the mappings, or it is a query
    path (contains "/" while ``use_query`` is on).
    """

    def process_value(self, value: Any) -> Any:
        # Renaming never alters the value itself.
        return value

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        renamed = super().process(instance=instance, stream_name=stream_name)
        # Target names of all mappings; a source that is also a target must stay.
        targets = [pair[1] for pair in self._field_to_field]
        for source, _ in self._field_to_field:
            is_query_path = self.use_query and "/" in source
            if is_query_path or source in targets:
                continue
            renamed.pop(source)
        return renamed
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class AddConstant(FieldOperator):
    """
    Adds a constant to the field value, i.e. ``value + add``.

    Args:
        add (float): the amount added to each processed value
    """

    add: float

    def process_value(self, value: Any) -> Any:
        # Operand order is kept as value + add.
        total = value + self.add
        return total
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class ShuffleFieldValues(FieldOperator):
    """
    Replaces an iterable value with a shuffled list of its items.
    """

    def process_value(self, value: Any) -> Any:
        # Copy into a fresh list so the shuffle never touches the input object.
        shuffled = [*value]
        # NOTE(review): `random` must come from module scope; its import is not
        # visible in this part of the file - confirm it is in scope.
        random.shuffle(shuffled)
        return shuffled
|
| 232 |
+
|
| 233 |
+
|
| 234 |
class JoinStr(FieldOperator):
|
| 235 |
"""
|
| 236 |
Joins a list of strings (contents of a field), similar to str.join()
|
|
|
|
| 244 |
return self.separator.join(str(x) for x in value)
|
| 245 |
|
| 246 |
|
| 247 |
+
class Apply(StreamInstanceOperator):
    """
    Calls a function on values taken from the instance and stores the result.

    Extra constructor arguments are accepted (``__allow_unexpected_arguments__``).
    ``process`` reads them from ``self._argv`` / ``self._kwargs`` and treats
    them as names of instance fields whose values become the positional and
    keyword arguments of ``function``.

    Args:
        function (Callable or str): the callable to invoke; a string is
            resolved to a callable in ``prepare`` via ``str_to_function``.
        to_field (str): instance field that receives the function's result.
    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        """Return the ``__qualname__`` used to serialize ``function``."""
        # NOTE(review): __qualname__ carries no module prefix, so a function such
        # as json.dumps serializes to "dumps" and cannot be resolved back by
        # str_to_function - confirm whether round-tripping is required.
        return function.__qualname__

    def str_to_function(self, function_str: str) -> Callable:
        """
        Resolve a (possibly dotted) name to a callable.

        A bare name is looked up among the builtins. For a dotted name the
        first component is resolved, in order, from this module's globals,
        from the builtins (e.g. ``str.upper``), or by importing it as a
        module; the remaining components are followed with ``getattr``.

        Raises:
            AttributeError: if a name or attribute along the path is missing.
            ImportError: if the first component has to be imported and cannot be.
        """
        # `__builtins__` is a module only inside `__main__`; in an imported module
        # it is a plain dict, so getattr on it fails. The `builtins` module is the
        # portable way to reach the same names.
        import builtins

        if "." not in function_str:
            return getattr(builtins, function_str)

        root_name, attr_path = function_str.split(".", 1)
        if root_name in globals():
            obj = globals()[root_name]
        elif hasattr(builtins, root_name):
            obj = getattr(builtins, root_name)
        else:
            obj = importlib.import_module(root_name)
        for part in attr_path.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        """Resolve ``function`` to a callable when it was supplied as a string."""
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)

    def get_init_dict(self):
        """Return the parent's init dict with ``function`` replaced by its string name."""
        result = super().get_init_dict()
        result["function"] = self.function_to_str(self.function)
        return result

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Call ``function`` with the referenced field values and store the result in ``to_field``."""
        # Each extra positional argument names an instance field to pass by position.
        argv = [instance[arg] for arg in self._argv]
        # NOTE(review): `_kwargs` is unpacked as (keyword, field name) pairs; it is
        # set outside this block - confirm it is not a plain dict.
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class ListFieldValues(StreamInstanceOperator):
    """
    Concatenates values of multiple fields into a list, similar to list(fields)

    Attributes:
        fields (List[str]): names of the fields to collect, in output order
        to_field (str): field that receives the resulting list
        use_query (bool): passed to dict_get as ``use_dpath`` for each lookup
    """

    # Iterated item by item in process(); a plain str here would be walked
    # character by character, hence a list of names.
    fields: List[str]
    to_field: str
    use_query: bool = False

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Store the values of ``fields``, in order, as a list under ``to_field``."""
        # `field_name` (not `field`) avoids shadowing the `dataclasses.field`
        # import at the top of the module.
        instance[self.to_field] = [
            dict_get(instance, field_name, use_dpath=self.use_query) for field_name in self.fields
        ]
        return instance
|
| 304 |
+
|
| 305 |
+
|
| 306 |
class ZipFieldValues(StreamInstanceOperator):
|
| 307 |
"""
|
| 308 |
Zips values of multiple fields similar to list(zip(*fields))
|