""" |
|
Generic utilities |
|
""" |
|
|
|
import inspect |
|
import tempfile |
|
from collections import OrderedDict, UserDict |
|
from collections.abc import MutableMapping |
|
from contextlib import ExitStack, contextmanager |
|
from dataclasses import fields |
|
from enum import Enum |
|
from typing import Any, ContextManager, List, Tuple |
|
|
|
import numpy as np |
|
|
|
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy |
|
|
|
|
|
if is_flax_available(): |
|
import jax.numpy as jnp |
|
|
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in a member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        cached = getattr(obj, attr, None)
        if cached is None:
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached
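
# Illustrative use of `cached_property` (the class below is hypothetical, not part of this
# module): the decorated method runs once per instance, later accesses hit the cache. Note
# that a `None` result is not cached, since `None` doubles as the "not computed" sentinel.
#
#     class Model:
#         @cached_property
#         def dummy_inputs(self):
#             print("computing")  # printed only on the first access
#             return {"input_ids": [0, 1, 2]}
#
#     m = Model()
#     m.dummy_inputs  # prints "computing", stores the result as `__cached_dummy_inputs`
#     m.dummy_inputs  # served from the cache, nothing is printed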
|
|
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`.
    """
    # `torch.fx` proxies stand in for tensors during symbolic tracing.
    if is_torch_fx_proxy(x):
        return True
    if is_torch_available():
        import torch

        if isinstance(x, torch.Tensor):
            return True
    if is_tf_available():
        import tensorflow as tf

        if isinstance(x, tf.Tensor):
            return True

    if is_flax_available():
        import jax.numpy as jnp
        from jax.core import Tracer

        if isinstance(x, (jnp.ndarray, Tracer)):
            return True

    return isinstance(x, np.ndarray)
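
# Example: `is_tensor` only imports frameworks that are actually available, so it is safe
# to call in minimal environments. With only NumPy installed:
#
#     is_tensor(np.zeros((2, 2)))  # True
#     is_tensor([0.0, 1.0])        # False -- plain Python sequences are not tensors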
|
|
def _is_numpy(x):
    return isinstance(x, np.ndarray)


def is_numpy_array(x):
    """
    Tests if `x` is a numpy array or not.
    """
    return _is_numpy(x)


def _is_torch(x):
    import torch

    return isinstance(x, torch.Tensor)


def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
    """
    return False if not is_torch_available() else _is_torch(x)


def _is_torch_device(x):
    import torch

    return isinstance(x, torch.device)


def is_torch_device(x):
    """
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
    """
    return False if not is_torch_available() else _is_torch_device(x)


def _is_torch_dtype(x):
    import torch

    if isinstance(x, str):
        # Also accept dtype names such as "float32" and resolve them on the torch module.
        if hasattr(torch, x):
            x = getattr(torch, x)
        else:
            return False
    return isinstance(x, torch.dtype)


def is_torch_dtype(x):
    """
    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
    """
    return False if not is_torch_available() else _is_torch_dtype(x)


def _is_tensorflow(x):
    import tensorflow as tf

    return isinstance(x, tf.Tensor)


def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
    """
    return False if not is_tf_available() else _is_tensorflow(x)


def _is_jax(x):
    import jax.numpy as jnp

    return isinstance(x, jnp.ndarray)


def is_jax_tensor(x):
    """
    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
    """
    return False if not is_flax_available() else _is_jax(x)
|
|
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Jax tensor, Numpy array or python list to a python list.
    """
    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]
    elif is_tf_tensor(obj):
        return obj.numpy().tolist()
    elif is_torch_tensor(obj):
        return obj.detach().cpu().tolist()
    elif is_jax_tensor(obj):
        return np.asarray(obj).tolist()
    elif isinstance(obj, (np.ndarray, np.number)):
        return obj.tolist()
    else:
        return obj
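
# Example: `to_py_obj` recurses through dicts, lists and tuples, so nested tensor
# structures come back as plain Python containers:
#
#     to_py_obj({"ids": np.array([[1, 2], [3, 4]])})  # {'ids': [[1, 2], [3, 4]]}
#     to_py_obj((np.int64(1), [np.float32(0.5)]))     # [1, [0.5]]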
|
|
def to_numpy(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Jax tensor, Numpy array or python list to a Numpy array.
    """
    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)
    elif is_tf_tensor(obj):
        return obj.numpy()
    elif is_torch_tensor(obj):
        return obj.detach().cpu().numpy()
    elif is_jax_tensor(obj):
        return np.asarray(obj)
    else:
        return obj
|
|
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like
    a tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a
    tuple first.

    </Tip>
    """

    def __post_init__(self):
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # If we provided an iterator as first field and the iterator is a (key, value) iterator, set the
            # associated fields.
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # If we do not have an iterator of key/values, set it as attribute
                            self[class_fields[0].name] = first_field
                        else:
                            # If we have a mixed iterator, raise an error
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
        elif first_field is not None:
            self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyError if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
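
# Illustrative subclass (hypothetical, for demonstration only). Concrete outputs are
# declared as dataclasses whose fields all default to `None`; fields left as `None` are
# skipped by `__getitem__`, `keys()` and `to_tuple()`.
#
#     from dataclasses import dataclass
#     from typing import Optional
#
#     @dataclass
#     class SampleOutput(ModelOutput):
#         logits: np.ndarray = None
#         hidden_states: Optional[np.ndarray] = None
#
#     out = SampleOutput(logits=np.ones((1, 2)))
#     out["logits"] is out.logits  # True -- dict access and attribute access agree
#     out[0]                       # integer indexing skips the `None` hidden_states
#     (logits,) = out.to_tuple()   # unpack via to_tuple(), as the Tip above advises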
|
|
class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )
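
# Example: a failed lookup lists the valid values instead of raising the terse default
# `ValueError` (`PaddingStrategy` is defined just below):
#
#     PaddingStrategy("longest")  # PaddingStrategy.LONGEST
#     PaddingStrategy("foo")      # ValueError: foo is not a valid PaddingStrategy, ...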
|
|
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in
    an IDE.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
|
|
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
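
# Example: enter several context managers with a single `with` statement; the underlying
# `ExitStack` unwinds them in reverse order on exit.
#
#     import contextlib
#
#     with ContextManagers([contextlib.suppress(KeyError), tempfile.TemporaryDirectory()]):
#         ...  # both contexts are active here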
|
|
def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    Args:
        model_class (`type`): The class of the model.
    """
    # Inspect the MRO as a string so that neither torch nor TF needs to be imported here.
    base_classes = str(inspect.getmro(model_class))

    if "keras.engine.training.Model" in base_classes:
        signature = inspect.signature(model_class.call)  # TensorFlow models
    elif "torch.nn.modules.module.Module" in base_classes:
        signature = inspect.signature(model_class.forward)  # PyTorch models
    else:
        signature = inspect.signature(model_class.__call__)  # Flax models

    for p in signature.parameters:
        if p == "return_loss" and signature.parameters[p].default is True:
            return True

    return False
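
# Example with a hypothetical (non-torch, non-Keras) class, so the generic `__call__`
# branch is used. Only a `return_loss` parameter defaulting to `True` qualifies:
#
#     class DummyModel:
#         def __call__(self, input_ids, return_loss=True):
#             ...
#
#     can_return_loss(DummyModel)  # True (would be False with `return_loss=False`)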
|
|
def find_labels(model_class):
    """
    Find the labels used by a given model.

    Args:
        model_class (`type`): The class of the model.
    """
    model_name = model_class.__name__
    base_classes = str(inspect.getmro(model_class))

    if "keras.engine.training.Model" in base_classes:
        signature = inspect.signature(model_class.call)  # TensorFlow models
    elif "torch.nn.modules.module.Module" in base_classes:
        signature = inspect.signature(model_class.forward)  # PyTorch models
    else:
        signature = inspect.signature(model_class.__call__)  # Flax models

    if "QuestionAnswering" in model_name:
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
    else:
        return [p for p in signature.parameters if "label" in p]
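
# Example with a hypothetical class: question-answering models also report their span
# parameters, everything else matches on "label" in the parameter name.
#
#     class DummyForQuestionAnswering:
#         def __call__(self, input_ids, start_positions=None, end_positions=None):
#             ...
#
#     find_labels(DummyForQuestionAnswering)  # ['start_positions', 'end_positions']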
|
|
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """Flatten a nested dict into a single level dict."""

    def _flatten_dict(d, parent_key="", delimiter="."):
        for k, v in d.items():
            key = str(parent_key) + delimiter + str(k) if parent_key else k
            if v and isinstance(v, MutableMapping):
                yield from flatten_dict(v, key, delimiter=delimiter).items()
            else:
                yield key, v

    return dict(_flatten_dict(d, parent_key, delimiter))
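
# Example: nested keys are joined with the delimiter. Note the `if v and ...` truthiness
# check above means empty sub-dicts are kept as values rather than recursed into.
#
#     flatten_dict({"a": {"b": 1, "c": {"d": 2}}, "e": 3})
#     # {'a.b': 1, 'a.c.d': 2, 'e': 3}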
|
|
@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    """Yield a fresh temporary directory if `use_temp_dir` is True, otherwise the given `working_dir`."""
    if use_temp_dir:
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield tmp_dir
    else:
        yield working_dir
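
# Example: callers can transparently switch between a real directory and a throwaway one
# ("my_dir" here is an arbitrary illustration):
#
#     with working_or_temp_dir("my_dir", use_temp_dir=True) as work_dir:
#         ...  # `work_dir` is a fresh temp dir, deleted on exit
#     with working_or_temp_dir("my_dir") as work_dir:
#         ...  # `work_dir` is just "my_dir"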
|
|
def transpose(array, axes=None):
    """
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)
    elif is_torch_tensor(array):
        return array.T if axes is None else array.permute(*axes)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.transpose(array, perm=axes)
    elif is_jax_tensor(array):
        return jnp.transpose(array, axes=axes)
    else:
        raise ValueError(f"Type not supported for transpose: {type(array)}.")


def reshape(array, newshape):
    """
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)
    elif is_torch_tensor(array):
        return array.reshape(*newshape)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.reshape(array, newshape)
    elif is_jax_tensor(array):
        return jnp.reshape(array, newshape)
    else:
        raise ValueError(f"Type not supported for reshape: {type(array)}.")


def squeeze(array, axis=None):
    """
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)
    elif is_torch_tensor(array):
        return array.squeeze() if axis is None else array.squeeze(dim=axis)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)
    elif is_jax_tensor(array):
        return jnp.squeeze(array, axis=axis)
    else:
        raise ValueError(f"Type not supported for squeeze: {type(array)}.")


def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    elif is_torch_tensor(array):
        return array.unsqueeze(dim=axis)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.expand_dims(array, axis=axis)
    elif is_jax_tensor(array):
        return jnp.expand_dims(array, axis=axis)
    else:
        raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
|
def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.
    """
    if is_numpy_array(array):
        return np.size(array)
    elif is_torch_tensor(array):
        return array.numel()
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.size(array)
    elif is_jax_tensor(array):
        return array.size
    else:
        raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
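
# Example: the framework-agnostic helpers above dispatch on the argument's type, so the
# same call works for whichever backend produced the tensor. With NumPy:
#
#     x = np.zeros((1, 2, 3))
#     transpose(x).shape        # (3, 2, 1)
#     reshape(x, (2, 3)).shape  # (2, 3)
#     squeeze(x, axis=0).shape  # (2, 3)
#     expand_dims(x, 0).shape   # (1, 1, 2, 3)
#     tensor_size(x)            # 6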