"""Implementation of :class:`ModularInteger` class. """
from __future__ import annotations | |
from typing import Any | |
import operator | |
from sympy.polys.polyutils import PicklableWithSlots | |
from sympy.polys.polyerrors import CoercionFailed | |
from sympy.polys.domains.domainelement import DomainElement | |
from sympy.utilities import public | |
from sympy.utilities.exceptions import sympy_deprecation_warning | |
class ModularInteger(PicklableWithSlots, DomainElement):
    """A class representing a modular integer.

    This base class is not usable on its own (``mod`` and ``dom`` are
    ``None``).  :func:`ModularIntegerFactory` creates one subclass per
    ``(mod, dom, sym)`` combination with those class attributes bound.
    Each instance holds a single slot, ``val``, which is the residue
    reduced modulo ``mod``.  Binary operands are either instances of the
    same class or values accepted by ``dom.convert``.  Anything else makes
    the operator return ``NotImplemented``.
    """

    # Per-subclass configuration, bound by ModularIntegerFactory:
    #   mod     -- the modulus
    #   dom     -- ground domain used to coerce operands (``dom.convert``)
    #              and to compute inverses (``dom.invert``)
    #   sym     -- if true, the deprecated to_int() returns a symmetric
    #              representative
    #   _parent -- the object returned by parent()
    mod, dom, sym, _parent = None, None, None, None

    __slots__ = ('val',)

    def parent(self):
        """Return the ``parent`` object bound to this element's class."""
        return self._parent

    def __init__(self, val):
        """Store ``val`` reduced modulo ``mod``.

        ``val`` is either an instance of this class, whose residue is
        reused, or any value that ``dom.convert`` accepts.  Errors from
        ``dom.convert`` propagate (presumably ``CoercionFailed``).
        """
        if isinstance(val, self.__class__):
            self.val = val.val % self.mod
        else:
            self.val = self.dom.convert(val) % self.mod

    def modulus(self):
        """Return the modulus ``mod`` of this element."""
        return self.mod

    def __hash__(self):
        # NOTE: equality (see _compare) also accepts plain values that
        # ``dom`` can convert.  This hash, however, is of the (val, mod)
        # pair, so it generally differs from the hash of an equal plain
        # value.
        return hash((self.val, self.mod))

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.val)

    def __str__(self):
        return "%s mod %s" % (self.val, self.mod)

    def __int__(self):
        """Return the stored residue ``val`` as an ``int``."""
        return int(self.val)

    def to_int(self):
        """Return the residue as an integer (deprecated since SymPy 1.13).

        When ``sym`` is true, residues above ``mod // 2`` are shifted down
        by ``mod`` to give a symmetric representative.  Otherwise ``val``
        is returned unchanged.  Use ``int(a)`` or ``K.to_int(a)`` instead.
        """
        sympy_deprecation_warning(
            "ModularInteger.to_int() is deprecated.\n"
            "Use int(a) or K = GF(p) and K.to_int(a) instead of a.to_int().\n",
            deprecated_since_version="1.13",
            active_deprecations_target="modularinteger-to-int",
        )
        if self.sym and self.val > self.mod // 2:
            return self.val - self.mod
        return self.val

    def __pos__(self):
        # Instances are never mutated after __init__, so returning self is safe.
        return self

    def __neg__(self):
        return self.__class__(-self.val)

    @classmethod
    def _get_val(cls, other):
        """Return ``other`` as a raw ``dom`` value, or ``None``.

        Instances of this class contribute their stored residue.  Anything
        else is coerced with ``dom.convert``, without reduction modulo
        ``mod``.  ``None`` means coercion failed, and the operators turn it
        into ``NotImplemented``.
        """
        if isinstance(other, cls):
            return other.val
        try:
            return cls.dom.convert(other)
        except CoercionFailed:
            return None

    def __add__(self, other):
        """Return ``self + other`` modulo ``mod``."""
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(self.val + val)

    def __radd__(self, other):
        """Return ``other + self``; addition is commutative."""
        return self.__add__(other)

    def __sub__(self, other):
        """Return ``self - other`` modulo ``mod``."""
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(self.val - val)

    def __rsub__(self, other):
        """Return ``other - self``, computed as ``(-self) + other``."""
        return (-self).__add__(other)

    def __mul__(self, other):
        """Return ``self * other`` modulo ``mod``."""
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(self.val * val)

    def __rmul__(self, other):
        """Return ``other * self``; multiplication is commutative."""
        return self.__mul__(other)

    def __truediv__(self, other):
        """Return ``self * other**-1`` modulo ``mod``.

        Raises whatever ``dom.invert`` raises when ``other`` has no inverse
        modulo ``mod``.
        """
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(self.val * self._invert(val))

    def __rtruediv__(self, other):
        """Return ``other * self**-1`` modulo ``mod``.

        ``other`` is coerced before ``self`` is inverted, matching
        ``__truediv__``.  An unsupported operand therefore yields
        ``NotImplemented`` rather than an inversion error.
        """
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(val * self._invert(self.val))

    def __mod__(self, other):
        """Return ``self.val % other``, reduced again modulo ``mod``."""
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(self.val % val)

    def __rmod__(self, other):
        """Return ``other % self.val``, reduced again modulo ``mod``."""
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return self.__class__(val % self.val)

    def __pow__(self, exp):
        """Return ``self**exp`` modulo ``mod``.

        A falsy ``exp`` (e.g. 0) gives ``dom.one``.  A negative ``exp``
        raises the inverse of ``self`` to ``-exp``.  The exponent is passed
        through ``int()`` before the three-argument ``pow``.
        """
        if not exp:
            return self.__class__(self.dom.one)
        if exp < 0:
            val, exp = self.invert().val, -exp
        else:
            val = self.val
        return self.__class__(pow(val, int(exp), self.mod))

    def _compare(self, other, op):
        """Apply ``op`` to ``val`` and ``other`` reduced modulo ``mod``.

        Returns ``NotImplemented`` when ``other`` cannot be coerced.
        """
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        return op(self.val, val % self.mod)

    def _compare_deprecated(self, other, op):
        """Like ``_compare``, but emit a deprecation warning first.

        Used for the ordered comparisons ``<``, ``<=``, ``>``, ``>=``.
        """
        val = self._get_val(other)
        if val is None:
            return NotImplemented
        sympy_deprecation_warning(
            "Ordered comparisons with modular integers are deprecated.\n"
            "Use e.g. int(a) < int(b) instead of a < b.\n",
            deprecated_since_version="1.13",
            active_deprecations_target="modularinteger-compare",
            # Presumably chosen so the warning is attributed to the code
            # doing the comparison (one frame above __lt__ etc.).
            stacklevel=4,
        )
        return op(self.val, val % self.mod)

    def __eq__(self, other):
        return self._compare(other, operator.eq)

    def __ne__(self, other):
        return self._compare(other, operator.ne)

    def __lt__(self, other):
        return self._compare_deprecated(other, operator.lt)

    def __le__(self, other):
        return self._compare_deprecated(other, operator.le)

    def __gt__(self, other):
        return self._compare_deprecated(other, operator.gt)

    def __ge__(self, other):
        return self._compare_deprecated(other, operator.ge)

    def __bool__(self):
        """Return ``True`` unless the residue is zero."""
        return bool(self.val)

    @classmethod
    def _invert(cls, value):
        """Return the inverse of ``value`` modulo ``mod`` via ``dom.invert``."""
        return cls.dom.invert(value, cls.mod)

    def invert(self):
        """Return the multiplicative inverse of this element."""
        return self.__class__(self._invert(self.val))
_modular_integer_cache: dict[tuple[Any, Any, Any], type[ModularInteger]] = {}


def ModularIntegerFactory(_mod, _dom, _sym, parent):
    """Return the ModularInteger subclass for one ``(modulus, domain, sym)``.

    ``_mod`` is coerced with ``_dom.convert``.  A ``ValueError`` is raised
    if that coercion fails or the coerced modulus is below 1.  The returned
    subclass binds ``mod``, ``dom``, ``sym`` and ``_parent``.  It is
    memoised in ``_modular_integer_cache`` keyed on ``(mod, dom, sym)``.
    Because ``parent`` is not part of the key, a cached class keeps the
    parent it was first created with.
    """
    try:
        _mod = _dom.convert(_mod)
    except CoercionFailed:
        rejected = True
    else:
        rejected = _mod < 1
    if rejected:
        raise ValueError("modulus must be a positive integer, got %s" % _mod)

    key = _mod, _dom, _sym
    cls = _modular_integer_cache.get(key)
    if cls is None:
        class cls(ModularInteger):
            mod, dom, sym = _mod, _dom, _sym
            _parent = parent

        cls.__name__ = ("SymmetricModularIntegerMod%s" if _sym
                        else "ModularIntegerMod%s") % _mod
        _modular_integer_cache[key] = cls
    return cls