Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from warnings import warn | |
import inspect | |
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning | |
from .utils import expand_tuples | |
import itertools as itl | |
class MDNotImplementedError(NotImplementedError):
    """Signal, from inside an implementation, that it declines the given arguments.

    When raised during a dispatched call, ``Dispatcher.__call__`` catches it and
    tries the next matching implementation instead of propagating it.
    """
### Functions for on_ambiguity | |
def ambiguity_warn(dispatcher, ambiguities):
    """ Emit an ``AmbiguityWarning`` describing ambiguous signatures.

    This is the default ``on_ambiguity`` callback used by ``Dispatcher.add``.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher whose registrations produced the ambiguity; only its
        ``name`` is used, to label the message.
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``.

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
class RaiseNotImplementedError:
    """Callable placeholder that raises ``NotImplementedError`` whenever invoked.

    ``ambiguity_register_error_ignore_dup`` registers instances of this class
    for ambiguous super signatures.
    """

    def __init__(self, dispatcher):
        # Kept only so the error message can name the dispatcher.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        # Only positional arguments contribute to the reported signature;
        # keyword arguments are accepted and ignored.
        arg_types = tuple(map(type, args))
        dispatcher_name = self.dispatcher.name
        raise NotImplementedError(
            "Ambiguous signature for %s: <%s>"
            % (dispatcher_name, str_signature(arg_types)))
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """
    Register an error-raising placeholder for each ambiguity's super signature.

    For every ambiguous pair the super signature is computed with
    ``super_signature``. If that signature consists of one repeated type
    (``len(set(signature)) == 1``) it is skipped. Otherwise an instance of
    ``RaiseNotImplementedError`` is registered for it, with this same function
    as the ``on_ambiguity`` callback for that registration.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    # Lazy generator: each super signature is computed just before it is used,
    # interleaved with the ``dispatcher.add`` calls below.
    supers = (tuple(super_signature(pair)) for pair in ambiguities)
    for sig in supers:
        # Skip super signatures made of a single repeated type.
        if len(set(sig)) != 1:
            dispatcher.add(
                sig, RaiseNotImplementedError(dispatcher),
                on_ambiguity=ambiguity_register_error_ignore_dup
            )
### | |
# Dispatchers whose ``reorder`` was deferred while ordering was halted;
# ``restart_ordering`` pops and reorders each of them.
_unresolved_dispatchers: set[Dispatcher] = set()
# One-element mutable flag: while ``_resolve[0]`` is True, ``Dispatcher.reorder``
# recomputes immediately; while False it only queues the dispatcher above.
_resolve: list[bool] = [True]
def halt_ordering():
    """Defer signature reordering.

    While halted, ``Dispatcher.reorder`` adds the dispatcher to
    ``_unresolved_dispatchers`` instead of recomputing its ordering.
    Call ``restart_ordering`` to process the backlog.
    """
    # Mutate the shared flag in place; the module-level list object is kept.
    _resolve[0] = False
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Resume eager reordering and reorder every dispatcher deferred while halted.

    Parameters
    ----------
    on_ambiguity : callable
        Forwarded to ``reorder`` of each deferred dispatcher.
    """
    _resolve[0] = True
    # Drain the backlog; each dispatcher is removed before it is reordered.
    while _unresolved_dispatchers:
        _unresolved_dispatchers.pop().reorder(on_ambiguity=on_ambiguity)
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        # Mirror ``name`` into ``__name__`` so the dispatcher can stand in for a function.
        self.name = self.__name__ = name
        # Registered implementations, keyed by their type-signature tuple.
        self.funcs = {}
        # Implementation chosen for each tuple of concrete argument types seen
        # by ``__call__``; cleared whenever ``add`` registers something new.
        self._cache = {}
        # Signatures in resolution order, as computed by ``conflict.ordering``.
        self.ordering = []
        # Optional extra text included in the generated ``__doc__``.
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            # Return the undecorated function so decorators can be stacked.
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` considered for annotation dispatch.

        Returns ``None`` when ``inspect.signature`` is unavailable.
        """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    def get_func_annotations(self, func):
        """ Get annotations of function positional parameters

        Returns a tuple holding the annotation of every positional parameter
        selected by ``get_func_params``. Returns ``None`` if the parameters
        cannot be inspected, or if any of them is unannotated.
        """
        # Resolved through the instance so a subclass override of
        # ``get_func_params`` works whether it is a plain method or a classmethod.
        params = self.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects an ambiguity it calls the ``on_ambiguity`` callback
        with a dispatcher/itself, and a set of ambiguous type signature pairs
        as inputs.  See ``ambiguity_warn`` for an example.

        An empty ``signature`` is replaced by the annotations of ``func``'s
        positional parameters when all of them are annotated. Entries of
        ``signature`` that are tuples act as unions: one registration is made
        for each signature produced by ``expand_tuples``.

        Raises
        ------
        TypeError
            If, after expansion, an entry of ``signature`` is not a type.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        # Previously cached resolutions may no longer be the best match.
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute ``self.ordering`` and report ambiguities, or defer.

        While ordering is halted (``halt_ordering``), the dispatcher is only
        queued in ``_unresolved_dispatchers`` for ``restart_ordering``.
        """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types)))
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            # The first match declined; try the remaining matches in order.
            funcs = self.dispatch_iter(*types)
            # Skip the implementation that was just tried. ``dispatch`` can
            # return an exact hit from ``self.funcs`` while ``self.ordering`` is
            # stale (ordering halted), so this iterator may be empty; use a
            # default instead of letting StopIteration escape.
            next(funcs, None)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError("Matching functions for "
                                      "%s: <%s> found, but none completed successfully"
                                      % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """

        # An exact registration wins without consulting the ordering.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """ Yield every implementation whose signature accepts ``types``.

        Signatures are visited in ``self.ordering`` order. A signature matches
        when it has the same length and each type is a subclass of its entry.
        """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        # stacklevel=2 attributes the warning to the caller of ``resolve``.
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning, stacklevel=2)

        return self.dispatch(*types)

    def __getstate__(self):
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        # Every slot read elsewhere must be set here: slotted attributes have
        # no class-level fallback, so a missing one raises AttributeError.
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        # States pickled before ``doc`` was saved lack the key.
        self.doc = d.get('doc')
        self.ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        """ Combined documentation of the dispatcher and its implementations. """
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n ' + '\n '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        func = self.dispatch(*map(type, args))
        # Match ``_source``: without this, an unmatched call would report
        # ``None.__doc__`` as if it were the implementation's docstring.
        if not func:
            raise TypeError("No function found")
        return func.__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
def source(func):
    """Return ``func``'s source file path header followed by its source code."""
    return f"File: {inspect.getsourcefile(func)}\n\n{inspect.getsource(func)}"
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        """ Return ``func``'s parameters after the first one.

        The first parameter (the method's ``self``) is skipped, so annotation
        dispatch only sees the remaining parameters. Returns ``None`` when
        ``inspect.signature`` is unavailable.
        """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return itl.islice(sig.parameters.values(), 1, None)

    def __get__(self, instance, owner):
        # The bound instance is stored on the dispatcher itself, and ``__call__``
        # passes it as the first argument.
        # NOTE(review): the most recent attribute access wins, so this is not
        # safe if the same dispatcher is accessed concurrently.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        func = self.dispatch(*types)
        if not func:
            raise NotImplementedError('Could not find signature for %s: <%s>' %
                                      (self.name, str_signature(types)))
        return func(self.obj, *args, **kwargs)
def str_signature(sig):
    """ Render a tuple of types as a comma-separated list of their names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
def warning_text(name, amb):
    """ Build the message emitted for ambiguous signatures of dispatcher ``name``.

    The message lists each ambiguous pair. It then suggests, for each pair, a
    ``@dispatch`` definition on the pair's super signature. ``amb`` is
    iterated twice, once for each section.
    """
    pieces = ["\nAmbiguities exist in dispatched function %s\n\n" % (name),
              "The following signatures may result in ambiguous behavior:\n"]
    for pair in amb:
        bracketed = ', '.join('[' + str_signature(s) + ']' for s in pair)
        pieces.append("\t" + bracketed + "\n")
    pieces.append("\n\nConsider making the following additions:\n\n")
    suggestions = []
    for s in amb:
        suggestions.append('@dispatch(' + str_signature(super_signature(s))
                           + ')\ndef %s(...)' % name)
    pieces.append('\n\n'.join(suggestions))
    return ''.join(pieces)