|
|
|
|
|
|
|
import copy
import functools
import math
import operator
import re
from abc import ABC, abstractmethod
from collections.abc import Iterable
from inspect import signature
from numbers import Integral, Real

import numpy as np
from scipy.sparse import csr_matrix, issparse

from .._config import config_context, get_config
from .validation import _is_arraylike_not_scalar
|
|
|
|
|
class InvalidParameterError(ValueError, TypeError):
    """Error raised when a parameter of a class, method or function has an
    invalid type or value.

    It subclasses both ValueError and TypeError, so it can be caught as either.
    """
|
|
|
|
|
|
|
|
|
def validate_parameter_constraints(parameter_constraints, params, caller_name):
    """Validate types and values of given parameters.

    Parameters
    ----------
    parameter_constraints : dict
        A dictionary mapping each parameter name either to a list of constraints
        or to the string "no_validation". Parameters mapped to "no_validation",
        as well as parameters absent from the dictionary, are not validated.

        A parameter is valid if it satisfies one of the constraints from the list.
        Constraints can be:
        - an Interval object, representing a continuous or discrete range of numbers
        - the string "array-like"
        - the string "sparse matrix"
        - the string "random_state"
        - callable
        - None, meaning that None is a valid value for the parameter
        - any type, meaning that any instance of this type is valid
        - an Options object, representing a set of elements of a given type
        - a StrOptions object, representing a set of strings
        - the string "boolean"
        - the string "verbose"
        - the string "cv_object"
        - the string "nan"
        - a MissingValues object representing markers for missing values
        - a HasMethods object, representing method(s) an object must have
        - a Hidden object, representing a constraint not meant to be exposed to the user

    params : dict
        A dictionary `param_name: param_value`. The parameters to validate against the
        constraints.

    caller_name : str
        The name of the estimator or function or method that called this function.

    Raises
    ------
    InvalidParameterError
        If a validated parameter satisfies none of its constraints.
    """
    for param_name, param_val in params.items():
        # Parameters without an entry in the constraints dict are deliberately
        # left unvalidated.
        if param_name not in parameter_constraints:
            continue

        constraints = parameter_constraints[param_name]

        if constraints == "no_validation":
            continue

        constraints = [make_constraint(constraint) for constraint in constraints]

        # `any` stops at the first satisfied constraint, exactly like a loop with
        # `break` would.
        if any(constraint.is_satisfied_by(param_val) for constraint in constraints):
            continue

        # No constraint is satisfied: build an informative message that only
        # mentions the constraints meant to be exposed to the user.
        visible = [constraint for constraint in constraints if not constraint.hidden]

        if not visible:
            # Every constraint is hidden (or the list is empty). Do not index into
            # an empty list: still report the invalid value to the user.
            constraints_str = "a valid value"
        elif len(visible) == 1:
            constraints_str = f"{visible[0]}"
        else:
            constraints_str = (
                f"{', '.join([str(c) for c in visible[:-1]])} or" f" {visible[-1]}"
            )

        raise InvalidParameterError(
            f"The {param_name!r} parameter of {caller_name} must be"
            f" {constraints_str}. Got {param_val!r} instead."
        )
|
|
|
|
|
def make_constraint(constraint):
    """Convert the constraint into the appropriate Constraint object.

    Parameters
    ----------
    constraint : object
        The constraint to convert. Accepted values are one of the strings
        "array-like", "sparse matrix", "random_state", "boolean", "verbose",
        "cv_object" or "nan"; the builtin `callable`; None; a type; a
        `_Constraint` instance (returned as is); or a `Hidden` wrapper around
        any of these.

    Returns
    -------
    constraint : instance of _Constraint
        The converted constraint.

    Raises
    ------
    ValueError
        If `constraint` is not one of the accepted values.
    """
    if isinstance(constraint, str) and constraint == "array-like":
        return _ArrayLikes()
    if isinstance(constraint, str) and constraint == "sparse matrix":
        return _SparseMatrices()
    if isinstance(constraint, str) and constraint == "random_state":
        return _RandomStates()
    # The builtin `callable` itself is used as a marker meaning "any callable".
    if constraint is callable:
        return _Callables()
    if constraint is None:
        return _NoneConstraint()
    if isinstance(constraint, type):
        return _InstancesOf(constraint)
    # Ready-made constraint objects (Interval, StrOptions, Options, HasMethods,
    # MissingValues, but also e.g. _IterablesNotString) are used as is.
    if isinstance(constraint, _Constraint):
        return constraint
    if isinstance(constraint, str) and constraint == "boolean":
        return _Booleans()
    if isinstance(constraint, str) and constraint == "verbose":
        return _VerboseHelper()
    if isinstance(constraint, str) and constraint == "cv_object":
        return _CVObjects()
    if isinstance(constraint, Hidden):
        # Work on a copy: the wrapped constraint may be a shared object that is
        # also used unhidden elsewhere, and flagging it in place would hide it
        # there as well.
        constraint = copy.copy(make_constraint(constraint.constraint))
        constraint.hidden = True
        return constraint
    if isinstance(constraint, str) and constraint == "nan":
        return _NanConstraint()
    raise ValueError(f"Unknown constraint type: {constraint}")
|
|
|
|
|
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
    """Decorator to validate types and values of functions and methods.

    Parameters
    ----------
    parameter_constraints : dict
        A dictionary `param_name: list of constraints`. See the docstring of
        `validate_parameter_constraints` for a description of the accepted constraints.

        Note that the *args and **kwargs parameters are not validated and must not be
        present in the parameter_constraints dictionary.

    prefer_skip_nested_validation : bool
        If True, the validation of parameters of inner estimators or functions
        called by the decorated function will be skipped.

        This is useful to avoid validating many times the parameters passed by the
        user from the public facing API. It's also useful to avoid validating
        parameters that we pass internally to inner functions that are guaranteed to
        be valid by the test suite.

        It should be set to True for most functions, except for those that receive
        non-validated objects as parameters or that are just wrappers around classes
        because they only perform a partial validation.

    Returns
    -------
    decorated_function : function or method
        The decorated function.
    """

    def decorator(func):
        # Expose the constraints on the function so they can be introspected
        # dynamically (e.g. for automatic testing).
        setattr(func, "_skl_parameter_constraints", parameter_constraints)

        # The signature and the names to skip depend only on `func`, so compute
        # them once at decoration time rather than on every call.
        func_sig = signature(func)
        to_ignore = {
            p.name
            for p in func_sig.parameters.values()
            if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        }
        to_ignore |= {"self", "cls"}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            global_skip_validation = get_config()["skip_parameter_validation"]
            if global_skip_validation:
                return func(*args, **kwargs)

            # Map *args/**kwargs onto the signature, filling in default values.
            bound = func_sig.bind(*args, **kwargs)
            bound.apply_defaults()
            params = {k: v for k, v in bound.arguments.items() if k not in to_ignore}

            validate_parameter_constraints(
                parameter_constraints, params, caller_name=func.__qualname__
            )

            try:
                with config_context(
                    skip_parameter_validation=(
                        prefer_skip_nested_validation or global_skip_validation
                    )
                ):
                    return func(*args, **kwargs)
            except InvalidParameterError as e:
                # When the function is just a wrapper around another validated
                # callable, report the error under this function's name to avoid
                # confusion. `\S+` (rather than `\w+`) also matches dotted or
                # nested qualnames such as "HasMethods.__init__".
                msg = re.sub(
                    r"parameter of \S+ must be",
                    f"parameter of {func.__qualname__} must be",
                    str(e),
                )
                raise InvalidParameterError(msg) from e

        return wrapper

    return decorator
|
|
|
|
|
class RealNotInt(Real):
    """A type that represents reals that are not instances of int.

    Behaves like float, but also works with values extracted from numpy arrays
    of any floating dtype:

    isinstance(1, RealNotInt) -> False
    isinstance(1.0, RealNotInt) -> True
    isinstance(np.float32(1.0), RealNotInt) -> True
    """


RealNotInt.register(float)
# np.float64 already subclasses float, but other numpy floating scalars
# (float16, float32, longdouble) do not; register their common base so that
# they are accepted as well.
RealNotInt.register(np.floating)
|
|
|
|
|
def _type_name(t): |
|
"""Convert type into human readable string.""" |
|
module = t.__module__ |
|
qualname = t.__qualname__ |
|
if module == "builtins": |
|
return qualname |
|
elif t == Real: |
|
return "float" |
|
elif t == Integral: |
|
return "int" |
|
return f"{module}.{qualname}" |
|
|
|
|
|
class _Constraint(ABC): |
|
"""Base class for the constraint objects.""" |
|
|
|
def __init__(self): |
|
self.hidden = False |
|
|
|
@abstractmethod |
|
def is_satisfied_by(self, val): |
|
"""Whether or not a value satisfies the constraint. |
|
|
|
Parameters |
|
---------- |
|
val : object |
|
The value to check. |
|
|
|
Returns |
|
------- |
|
is_satisfied : bool |
|
Whether or not the constraint is satisfied by this value. |
|
""" |
|
|
|
@abstractmethod |
|
def __str__(self): |
|
"""A human readable representational string of the constraint.""" |
|
|
|
|
|
class _InstancesOf(_Constraint):
    """Constraint satisfied by any instance of one given type.

    Parameters
    ----------
    type : type
        The type that valid values must be instances of.
    """

    def __init__(self, type):
        super().__init__()
        self.type = type

    def is_satisfied_by(self, val):
        expected_type = self.type
        return isinstance(val, expected_type)

    def __str__(self):
        readable = _type_name(self.type)
        return f"an instance of {readable!r}"
|
|
|
|
|
class _NoneConstraint(_Constraint):
    """Constraint satisfied only by the `None` singleton."""

    def is_satisfied_by(self, val):
        is_none = val is None
        return is_none

    def __str__(self):
        return "None"
|
|
|
|
|
class _NanConstraint(_Constraint):
    """Constraint satisfied by NaN, i.e. the `np.nan` indicator."""

    def is_satisfied_by(self, val):
        # Integers are never NaN, and non-real values cannot be tested.
        if isinstance(val, Integral) or not isinstance(val, Real):
            return False
        return math.isnan(val)

    def __str__(self):
        return "numpy.nan"
|
|
|
|
|
class _PandasNAConstraint(_Constraint):
    """Constraint satisfied by the `pd.NA` indicator.

    Never satisfied when pandas cannot be imported.
    """

    def is_satisfied_by(self, val):
        try:
            import pandas as pd
        except ImportError:
            return False
        na_type = type(pd.NA)
        return isinstance(val, na_type) and pd.isna(val)

    def __str__(self):
        return "pandas.NA"
|
|
|
|
|
class Options(_Constraint):
    """Constraint representing a finite set of instances of a given type.

    Parameters
    ----------
    type : type
        The type that valid values must be instances of, in addition to
        belonging to `options`.

    options : set
        The set of valid scalars.

    deprecated : set or None, default=None
        A subset of the `options` to mark as deprecated in the string
        representation of the constraint.

    Raises
    ------
    ValueError
        If `deprecated` is not a subset of `options`.
    """

    def __init__(self, type, options, *, deprecated=None):
        super().__init__()
        self.type = type
        self.options = options
        self.deprecated = deprecated or set()

        if self.deprecated - self.options:
            raise ValueError("The deprecated options must be a subset of the options.")

    def is_satisfied_by(self, val):
        return isinstance(val, self.type) and val in self.options

    def _mark_if_deprecated(self, option):
        """Add a deprecated mark to an option if needed."""
        option_str = f"{option!r}"
        if option in self.deprecated:
            option_str = f"{option_str} (deprecated)"
        return option_str

    def __str__(self):
        # Set iteration order depends on (randomized) hashes, so sort by repr to
        # get reproducible error messages. Sorting by repr works for any mix of
        # option types.
        options_str = ", ".join(
            self._mark_if_deprecated(o) for o in sorted(self.options, key=repr)
        )
        return f"a {_type_name(self.type)} among {{{options_str}}}"
|
|
|
|
|
class StrOptions(Options):
    """Constraint satisfied by one string out of a finite set.

    This is `Options` specialised to the `str` type.

    Parameters
    ----------
    options : set of str
        The accepted strings.

    deprecated : set of str or None, default=None
        A subset of `options` flagged as deprecated in the string
        representation of the constraint.
    """

    def __init__(self, options, *, deprecated=None):
        super().__init__(deprecated=deprecated, options=options, type=str)
|
|
|
|
|
class Interval(_Constraint):
    """Constraint representing a typed interval.

    Parameters
    ----------
    type : {numbers.Integral, numbers.Real, RealNotInt}
        The set of numbers in which to set the interval.

        If RealNotInt, only reals that don't have the integer type
        are allowed. For example 1.0 is allowed but 1 is not.

    left : float or int or None
        The left bound of the interval. None means left bound is -∞.

    right : float, int or None
        The right bound of the interval. None means right bound is +∞.

    closed : {"left", "right", "both", "neither"}
        Whether the interval is open or closed. Possible choices are:

        - `"left"`: the interval is closed on the left and open on the right.
          It is equivalent to the interval `[ left, right )`.
        - `"right"`: the interval is closed on the right and open on the left.
          It is equivalent to the interval `( left, right ]`.
        - `"both"`: the interval is closed.
          It is equivalent to the interval `[ left, right ]`.
        - `"neither"`: the interval is open.
          It is equivalent to the interval `( left, right )`.

    Raises
    ------
    TypeError
        If a bound does not have a type compatible with `type`.
    ValueError
        If `type` or `closed` is not one of the accepted values, if an integer
        interval is closed on an unbounded side, or if `right <= left`.

    Notes
    -----
    Setting a bound to `None` and setting the interval closed is valid. For instance,
    strictly speaking, `Interval(Real, 0, None, closed="both")` corresponds to
    `[0, +∞) U {+∞}`.

    NaN is never contained in an interval.
    """

    def __init__(self, type, left, right, *, closed):
        super().__init__()
        self.type = type
        self.left = left
        self.right = right
        self.closed = closed

        self._check_params()

    def _check_params(self):
        """Validate the interval definition, raising TypeError/ValueError."""
        if self.type not in (Integral, Real, RealNotInt):
            raise ValueError(
                "type must be either numbers.Integral, numbers.Real or RealNotInt."
                f" Got {self.type} instead."
            )

        if self.closed not in ("left", "right", "both", "neither"):
            raise ValueError(
                "closed must be either 'left', 'right', 'both' or 'neither'. "
                f"Got {self.closed} instead."
            )

        if self.type is Integral:
            suffix = "for an interval over the integers."
            if self.left is not None and not isinstance(self.left, Integral):
                raise TypeError(f"Expecting left to be an int {suffix}")
            if self.right is not None and not isinstance(self.right, Integral):
                raise TypeError(f"Expecting right to be an int {suffix}")
            # There is no integer at infinity, so an unbounded side can't be closed.
            if self.left is None and self.closed in ("left", "both"):
                raise ValueError(
                    f"left can't be None when closed == {self.closed} {suffix}"
                )
            if self.right is None and self.closed in ("right", "both"):
                raise ValueError(
                    f"right can't be None when closed == {self.closed} {suffix}"
                )
        else:
            if self.left is not None and not isinstance(self.left, Real):
                raise TypeError("Expecting left to be a real number.")
            if self.right is not None and not isinstance(self.right, Real):
                raise TypeError("Expecting right to be a real number.")

        if self.right is not None and self.left is not None and self.right <= self.left:
            raise ValueError(
                f"right can't be less than left. Got left={self.left} and "
                f"right={self.right}"
            )

    def __contains__(self, val):
        # NaN is never inside an interval. `math.isnan` (unlike `np.isnan`) also
        # handles Real types numpy has no loop for, e.g. fractions.Fraction.
        if not isinstance(val, Integral) and math.isnan(val):
            return False

        # Comparison that signals "val is outside" on each side.
        left_cmp = operator.lt if self.closed in ("left", "both") else operator.le
        right_cmp = operator.gt if self.closed in ("right", "both") else operator.ge

        left = -np.inf if self.left is None else self.left
        right = np.inf if self.right is None else self.right

        if left_cmp(val, left):
            return False
        if right_cmp(val, right):
            return False
        return True

    def is_satisfied_by(self, val):
        if not isinstance(val, self.type):
            return False

        return val in self

    def __str__(self):
        type_str = "an int" if self.type is Integral else "a float"
        left_bracket = "[" if self.closed in ("left", "both") else "("
        left_bound = "-inf" if self.left is None else self.left
        right_bound = "inf" if self.right is None else self.right
        right_bracket = "]" if self.closed in ("right", "both") else ")"

        # Display real bounds as floats even when given as ints, e.g. 1 -> 1.0.
        if not self.type == Integral and isinstance(self.left, Real):
            left_bound = float(left_bound)
        if not self.type == Integral and isinstance(self.right, Real):
            right_bound = float(right_bound)

        return (
            f"{type_str} in the range "
            f"{left_bracket}{left_bound}, {right_bound}{right_bracket}"
        )
|
|
|
|
|
class _ArrayLikes(_Constraint):
    """Constraint satisfied by values accepted by `_is_arraylike_not_scalar`."""

    def is_satisfied_by(self, val):
        result = _is_arraylike_not_scalar(val)
        return result

    def __str__(self):
        return "an array-like"
|
|
|
|
|
class _SparseMatrices(_Constraint):
    """Constraint satisfied by values for which `scipy.sparse.issparse` is true."""

    def is_satisfied_by(self, val):
        result = issparse(val)
        return result

    def __str__(self):
        return "a sparse matrix"
|
|
|
|
|
class _Callables(_Constraint):
    """Constraint satisfied by any callable object."""

    def is_satisfied_by(self, val):
        result = callable(val)
        return result

    def __str__(self):
        return "a callable"
|
|
|
|
|
class _RandomStates(_Constraint):
    """Constraint satisfied by anything usable as a random state.

    Shortcut for the constraint list
    [Interval(Integral, 0, 2**32 - 1, closed="both"), np.random.RandomState, None]
    """

    def __init__(self):
        super().__init__()
        seed_range = Interval(Integral, 0, 2**32 - 1, closed="both")
        self._constraints = [
            seed_range,
            _InstancesOf(np.random.RandomState),
            _NoneConstraint(),
        ]

    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False

    def __str__(self):
        *leading, last = self._constraints
        return f"{', '.join(str(c) for c in leading)} or {last}"
|
|
|
|
|
class _Booleans(_Constraint):
    """Constraint satisfied by Python and numpy booleans.

    Shortcut for the constraint list [bool, np.bool_]
    """

    def __init__(self):
        super().__init__()
        self._constraints = [_InstancesOf(t) for t in (bool, np.bool_)]

    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False

    def __str__(self):
        *leading, last = self._constraints
        return f"{', '.join(str(c) for c in leading)} or {last}"
|
|
|
|
|
class _VerboseHelper(_Constraint):
    """Constraint used for `verbose` parameters.

    Shortcut for the constraint list
    [Interval(Integral, 0, None, closed="left"), bool, numpy.bool_]
    """

    def __init__(self):
        super().__init__()
        non_negative_int = Interval(Integral, 0, None, closed="left")
        self._constraints = [non_negative_int] + [
            _InstancesOf(t) for t in (bool, np.bool_)
        ]

    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False

    def __str__(self):
        *leading, last = self._constraints
        return f"{', '.join(str(c) for c in leading)} or {last}"
|
|
|
|
|
class MissingValues(_Constraint):
    """Constraint used for `missing_values` parameters.

    Shortcut for the following constraints, in this order:

        [
            Integral,
            Interval(Real, None, None, closed="both"),
            _NanConstraint(),
            _PandasNAConstraint(),
            str,   # only when numeric_only is False
            None,  # only when numeric_only is False
        ]

    Parameters
    ----------
    numeric_only : bool, default=False
        Whether to consider only numeric missing value markers.
    """

    def __init__(self, numeric_only=False):
        super().__init__()

        self.numeric_only = numeric_only

        # Note: a plain real-interval constraint rejects NaN, hence the explicit
        # NaN constraint.
        numeric_markers = [
            _InstancesOf(Integral),
            Interval(Real, None, None, closed="both"),
            _NanConstraint(),
            _PandasNAConstraint(),
        ]
        if self.numeric_only:
            other_markers = []
        else:
            other_markers = [_InstancesOf(str), _NoneConstraint()]
        self._constraints = numeric_markers + other_markers

    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False

    def __str__(self):
        *leading, last = self._constraints
        return f"{', '.join(str(c) for c in leading)} or {last}"
|
|
|
|
|
class HasMethods(_Constraint):
    """Constraint satisfied by objects exposing all the given methods.

    Handy for parameters that follow a protocol, without tying them to a
    specific module or class.

    Parameters
    ----------
    methods : str or list of str
        Name(s) of the method(s) the object must expose as callables.
    """

    @validate_params(
        {"methods": [str, list]},
        prefer_skip_nested_validation=True,
    )
    def __init__(self, methods):
        super().__init__()
        # A single method name is normalised into a one-element list.
        self.methods = [methods] if isinstance(methods, str) else methods

    def is_satisfied_by(self, val):
        for name in self.methods:
            if not callable(getattr(val, name, None)):
                return False
        return True

    def __str__(self):
        quoted = [repr(m) for m in self.methods]
        if len(quoted) == 1:
            methods = quoted[0]
        else:
            methods = f"{', '.join(quoted[:-1])} and {quoted[-1]}"
        return f"an object implementing {methods}"
|
|
|
|
|
class _IterablesNotString(_Constraint):
    """Constraint satisfied by iterables, excluding strings."""

    def is_satisfied_by(self, val):
        if isinstance(val, str):
            return False
        return isinstance(val, Iterable)

    def __str__(self):
        return "an iterable"
|
|
|
|
|
class _CVObjects(_Constraint):
    """Constraint satisfied by anything usable as a `cv` parameter.

    Shortcut for the constraint list
        [
            Interval(Integral, 2, None, closed="left"),
            HasMethods(["split", "get_n_splits"]),
            _IterablesNotString(),
            None,
        ]
    """

    def __init__(self):
        super().__init__()
        n_folds = Interval(Integral, 2, None, closed="left")
        splitter = HasMethods(["split", "get_n_splits"])
        self._constraints = [
            n_folds,
            splitter,
            _IterablesNotString(),
            _NoneConstraint(),
        ]

    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False

    def __str__(self):
        *leading, last = self._constraints
        return f"{', '.join(str(c) for c in leading)} or {last}"
|
|
|
|
|
class Hidden:
    """Wrapper marking a constraint that must not be shown to the user.

    Parameters
    ----------
    constraint : str or _Constraint instance
        The wrapped constraint, still enforced internally.
    """

    def __init__(self, constraint):
        # Stored as given; conversion happens where the constraint is consumed.
        self.constraint = constraint
|
|
|
|
|
def generate_invalid_param_val(constraint):
    """Return a value that does not satisfy the constraint.

    Raises a NotImplementedError if there exists no invalid value for this constraint.

    This is only useful for testing purpose.

    Parameters
    ----------
    constraint : _Constraint instance
        The constraint to generate a value for.

    Returns
    -------
    val : object
        A value that does not satisfy the constraint.

    Raises
    ------
    NotImplementedError
        If no invalid value can be generated for this constraint.
    """
    if isinstance(constraint, StrOptions):
        return f"not {' or '.join(constraint.options)}"

    if isinstance(constraint, MissingValues):
        return np.array([1, 2, 3])

    if isinstance(constraint, _VerboseHelper):
        return -1

    if isinstance(constraint, HasMethods):
        return type("HasNotMethods", (), {})()

    if isinstance(constraint, _IterablesNotString):
        return "a string"

    if isinstance(constraint, _CVObjects):
        return "not a cv object"

    if isinstance(constraint, Interval) and constraint.type is Integral:
        if constraint.left is not None:
            return constraint.left - 1
        if constraint.right is not None:
            return constraint.right + 1

        # There's no integer outside (-inf, +inf)
        raise NotImplementedError

    if isinstance(constraint, Interval) and constraint.type in (Real, RealNotInt):
        if constraint.left is not None:
            val = constraint.left - 1e-6
            if val >= constraint.left:
                # For large bounds the 1e-6 step is lost to float rounding; use
                # the closest representable float below the bound instead.
                val = np.nextafter(float(constraint.left), -np.inf)
            return val
        if constraint.right is not None:
            val = constraint.right + 1e-6
            if val <= constraint.right:
                # Same rounding issue as above, on the right side.
                val = np.nextafter(float(constraint.right), np.inf)
            return val

        # bounds are -inf, +inf
        if constraint.closed in ("right", "neither"):
            return -np.inf
        if constraint.closed in ("left", "neither"):
            return np.inf

        # interval is [-inf, +inf]
        return np.nan

    raise NotImplementedError
|
|
|
|
|
def generate_valid_param(constraint):
    """Return a value that does satisfy a constraint.

    This is only useful for testing purpose.

    Parameters
    ----------
    constraint : Constraint instance
        The constraint to generate a value for.

    Returns
    -------
    val : object
        A value that does satisfy the constraint.

    Raises
    ------
    ValueError
        If the constraint type is not supported.
    """
    if isinstance(constraint, _ArrayLikes):
        return np.array([1, 2, 3])

    if isinstance(constraint, _SparseMatrices):
        return csr_matrix([[0, 1], [1, 0]])

    if isinstance(constraint, _RandomStates):
        return np.random.RandomState(42)

    if isinstance(constraint, _Callables):
        return lambda x: x

    if isinstance(constraint, _NoneConstraint):
        return None

    if isinstance(constraint, _InstancesOf):
        if constraint.type is np.ndarray:
            # np.ndarray's constructor needs a shape; build a small array instead.
            return np.array([1, 2, 3])

        if constraint.type in (Integral, Real):
            # Abstract number types cannot be instantiated.
            return 1

        if constraint.type is RealNotInt:
            # Abstract as well, and ints are explicitly not RealNotInt.
            return 1.0

        return constraint.type()

    if isinstance(constraint, _Booleans):
        return True

    if isinstance(constraint, _VerboseHelper):
        return 1

    if isinstance(constraint, MissingValues) and constraint.numeric_only:
        return np.nan

    if isinstance(constraint, MissingValues) and not constraint.numeric_only:
        return "missing"

    if isinstance(constraint, HasMethods):
        return type(
            "ValidHasMethods", (), {m: lambda self: None for m in constraint.methods}
        )()

    if isinstance(constraint, _IterablesNotString):
        return [1, 2, 3]

    if isinstance(constraint, _CVObjects):
        return 5

    if isinstance(constraint, Options):
        for option in constraint.options:
            return option

    if isinstance(constraint, Interval):
        interval = constraint
        if interval.left is None and interval.right is None:
            val = 0
        elif interval.left is None:
            val = interval.right - 1
        elif interval.right is None:
            val = interval.left + 1
        elif interval.type is Integral:
            val = interval.left + 1
            # In e.g. [0, 1), left + 1 is the excluded right bound; the closed
            # left bound is then the valid choice.
            if val not in interval and interval.closed in ("left", "both"):
                val = interval.left
        else:
            # Real or RealNotInt: the midpoint lies strictly inside because the
            # interval guarantees left < right.
            val = (interval.left + interval.right) / 2

        if interval.type is RealNotInt:
            # Integers never satisfy a RealNotInt interval.
            val = float(val)
        return val

    raise ValueError(f"Unknown constraint type: {constraint}")
|
|