Spaces:
Sleeping
Sleeping
import typing | |
import sympy | |
from sympy.core import Add, Mul | |
from sympy.core import Symbol, Expr, Float, Rational, Integer, Basic | |
from sympy.core.function import UndefinedFunction, Function | |
from sympy.core.relational import Relational, Unequality, Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan | |
from sympy.functions.elementary.complexes import Abs | |
from sympy.functions.elementary.exponential import exp, log, Pow | |
from sympy.functions.elementary.hyperbolic import sinh, cosh, tanh | |
from sympy.functions.elementary.miscellaneous import Min, Max | |
from sympy.functions.elementary.piecewise import Piecewise | |
from sympy.functions.elementary.trigonometric import sin, cos, tan, asin, acos, atan, atan2 | |
from sympy.logic.boolalg import And, Or, Xor, Implies, Boolean | |
from sympy.logic.boolalg import BooleanTrue, BooleanFalse, BooleanFunction, Not, ITE | |
from sympy.printing.printer import Printer | |
from sympy.sets import Interval | |
from mpmath.libmp.libmpf import prec_to_dps, to_str as mlib_to_str | |
from sympy.assumptions.assume import AppliedPredicate | |
from sympy.assumptions.relation.binrel import AppliedBinaryRelation | |
from sympy.assumptions.ask import Q | |
from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate | |
class SMTLibPrinter(Printer):
    """Print SymPy expressions as SMT-LIB S-expressions.

    Compound expressions are rendered as ``(op arg1 arg2 ...)`` where ``op``
    is looked up in the ``known_functions`` setting. Symbols and undefined
    functions are printed by their names; ``smtlib_code`` produces their
    declarations from ``symbol_table``.

    Settings
    ========

    precision : int or None
        ``evalf`` precision used for a ``NumberSymbol`` (e.g. ``pi``) that has
        no entry in ``known_constants``.
    known_types : dict
        Python type (``bool``, ``int``, ``float``) -> SMT sort name.
    known_constants : dict
        ``NumberSymbol`` -> SMT identifier to print instead of its value.
    known_functions : dict
        SymPy class or predicate instance -> SMT operator name.
    """
    printmethod = "_smtlib"

    # based on dReal, an automated reasoning tool for solving problems that can be encoded as first-order logic formulas over the real numbers.
    # dReal's special strength is in handling problems that involve a wide range of nonlinear real functions.
    _default_settings: dict = {
        'precision': None,
        'known_types': {
            bool: 'Bool',
            int: 'Int',
            float: 'Real'
        },
        'known_constants': {
            # pi: 'MY_VARIABLE_PI_DECLARED_ELSEWHERE',
        },
        'known_functions': {
            Add: '+',
            Mul: '*',

            Equality: '=',
            LessThan: '<=',
            GreaterThan: '>=',
            StrictLessThan: '<',
            StrictGreaterThan: '>',

            EqualityPredicate(): '=',
            LessThanPredicate(): '<=',
            GreaterThanPredicate(): '>=',
            StrictLessThanPredicate(): '<',
            StrictGreaterThanPredicate(): '>',

            exp: 'exp',
            log: 'log',
            Abs: 'abs',
            sin: 'sin',
            cos: 'cos',
            tan: 'tan',
            asin: 'arcsin',
            acos: 'arccos',
            atan: 'arctan',
            atan2: 'arctan2',
            sinh: 'sinh',
            cosh: 'cosh',
            tanh: 'tanh',
            Min: 'min',
            Max: 'max',
            Pow: 'pow',

            And: 'and',
            Or: 'or',
            Xor: 'xor',
            Not: 'not',
            ITE: 'ite',
            Implies: '=>',
        }
    }

    # Symbol -> Python type, and undefined-function class -> ``Callable[...]``.
    symbol_table: dict

    def __init__(self, settings: typing.Optional[dict] = None,
                 symbol_table=None):
        """Create a printer.

        Parameters
        ==========

        settings : dict, optional
            Overrides for ``_default_settings``; a given key replaces the
            default value for that key entirely.
        symbol_table : dict, optional
            Type information for symbols and undefined functions.
        """
        settings = settings or {}
        self.symbol_table = symbol_table or {}
        Printer.__init__(self, settings)
        self._precision = self._settings['precision']
        # Shallow copies: mutating these per-instance tables must not touch
        # the dicts stored in ``self._settings``.
        self._known_types = dict(self._settings['known_types'])
        self._known_constants = dict(self._settings['known_constants'])
        self._known_functions = dict(self._settings['known_functions'])

        for _ in self._known_types.values(): assert self._is_legal_name(_)
        for _ in self._known_constants.values(): assert self._is_legal_name(_)
        # for _ in self._known_functions.values(): assert self._is_legal_name(_)  # +, *, <, >, etc.

    def _is_legal_name(self, s: str):
        """Return True if ``s`` is non-empty, does not start with a digit and
        contains only alphanumerics and underscores."""
        if not s: return False
        if s[0].isnumeric(): return False
        return all(_.isalnum() or _ == '_' for _ in s)

    def _s_expr(self, op: str, args: typing.Union[list, tuple]) -> str:
        """Render ``(op arg1 arg2 ...)``; ``str`` args are inserted verbatim,
        anything else is printed with ``self._print``."""
        args_str = ' '.join(
            a if isinstance(a, str)
            else self._print(a)
            for a in args
        )
        return f'({op} {args_str})'

    def _print_Function(self, e):
        """Print a function application as an S-expression.

        The operator is resolved, in order, from: the expression itself, its
        class, the name of an undefined function, or the predicate of an
        ``AppliedBinaryRelation`` (whose operands are ``e.arguments``).

        Raises
        ======

        KeyError
            If no operator is known for ``e``.
        """
        if e in self._known_functions:
            op = self._known_functions[e]
        elif type(e) in self._known_functions:
            op = self._known_functions[type(e)]
        elif type(type(e)) == UndefinedFunction:
            op = e.name
        elif isinstance(e, AppliedBinaryRelation) and e.function in self._known_functions:
            op = self._known_functions[e.function]
            return self._s_expr(op, e.arguments)
        else:
            op = self._known_functions[e]  # throw KeyError

        return self._s_expr(op, e.args)

    def _print_Relational(self, e: Relational):
        return self._print_Function(e)

    def _print_BooleanFunction(self, e: BooleanFunction):
        return self._print_Function(e)

    def _print_Expr(self, e: Expr):
        return self._print_Function(e)

    def _print_Unequality(self, e: Unequality):
        """Print ``Ne(a, b)``; if ``Unequality`` has no operator of its own,
        fall back to ``(not (= a b))``."""
        if type(e) in self._known_functions:
            return self._print_Relational(e)  # default
        else:
            eq_op = self._known_functions[Equality]
            not_op = self._known_functions[Not]
            return self._s_expr(not_op, [self._s_expr(eq_op, e.args)])

    def _print_Piecewise(self, e: Piecewise):
        """Print ``Piecewise`` as a chain of nested ``ite`` expressions.

        Raises
        ======

        AssertionError
            If the last piece is not the default ``(expr, True)``: without it
            the generated term could be left undefined.
        """
        pieces = e.args
        last_expr, last_cond = pieces[-1]
        # Explicit check rather than ``assert`` so it still fires under
        # ``python -O``; otherwise the final condition would be silently
        # dropped, changing the meaning of the formula.
        if not ((last_cond is True) or isinstance(last_cond, BooleanTrue)):
            raise AssertionError(
                f"Piecewise (`{e}`) has no default term `(expr, True)` "
                f"and cannot be converted to SMT."
            )

        # Build from the innermost (default) branch outwards.
        result = self._print(last_expr)
        for piece_expr, piece_cond in reversed(pieces[:-1]):
            ite = self._known_functions[ITE]
            result = self._s_expr(ite, [piece_cond, piece_expr, result])
        return result

    def _print_Interval(self, e: Interval):
        """Return ``''`` for a fully unbounded interval and ``[start, end]``
        for a bounded one.

        Raises
        ======

        ValueError
            For one-sided (half-bounded) intervals.
        """
        if e.start.is_infinite and e.end.is_infinite:
            return ''
        elif e.start.is_infinite != e.end.is_infinite:
            raise ValueError(f'One-sided intervals (`{e}`) are not supported in SMT.')
        else:
            return f'[{e.start}, {e.end}]'

    def _print_AppliedPredicate(self, e: AppliedPredicate):
        """Rewrite sign predicates (``Q.positive(x)`` etc.) as comparisons
        against 0 and print those.

        Raises
        ======

        ValueError
            For any other predicate.
        """
        if e.function == Q.positive:
            rel = Q.gt(e.arguments[0],0)
        elif e.function == Q.negative:
            rel = Q.lt(e.arguments[0], 0)
        elif e.function == Q.zero:
            rel = Q.eq(e.arguments[0], 0)
        elif e.function == Q.nonpositive:
            rel = Q.le(e.arguments[0], 0)
        elif e.function == Q.nonnegative:
            rel = Q.ge(e.arguments[0], 0)
        elif e.function == Q.nonzero:
            rel = Q.ne(e.arguments[0], 0)
        else:
            raise ValueError(f"Predicate (`{e}`) is not handled.")

        return self._print_AppliedBinaryRelation(rel)

    def _print_AppliedBinaryRelation(self, e: AppliedBinaryRelation):
        """Print ``Q.eq/lt/...`` applications; ``Q.ne`` is routed through the
        ``Unequality`` printer."""
        if e.function == Q.ne:
            return self._print_Unequality(Unequality(*e.arguments))
        else:
            return self._print_Function(e)

    # todo: Sympy does not support quantifiers yet as of 2022, but quantifiers can be handy in SMT.
    #       For now, users can extend this class and build in their own quantifier support.
    #       See `test_quantifier_extensions()` in test_smtlib.py for an example of how this might look.

    # def _print_ForAll(self, e: ForAll):
    #     return self._s('forall', [
    #         self._s('', [
    #             self._s(sym.name, [self._type_name(sym), Interval(start, end)])
    #             for sym, start, end in e.limits
    #         ]),
    #         e.function
    #     ])

    def _print_BooleanTrue(self, x: BooleanTrue):
        return 'true'

    def _print_BooleanFalse(self, x: BooleanFalse):
        return 'false'

    def _print_Float(self, x: Float):
        """Print a ``Float`` at its own precision.

        Scientific notation ``m e k`` is expanded to ``(* m (pow 10 k))``
        using the configured ``Mul``/``Pow`` operators.

        Raises
        ======

        ValueError
            For infinities and NaN, which have no SMT real representation.
        """
        dps = prec_to_dps(x._prec)
        str_real = mlib_to_str(x._mpf_, dps, strip_zeros=True, min_fixed=None, max_fixed=None)

        if 'e' in str_real:
            (mant, exponent) = str_real.split('e')

            if exponent[0] == '+':
                exponent = exponent[1:]

            mul_op = self._known_functions[Mul]
            pow_op = self._known_functions[Pow]

            return r"(%s %s (%s 10 %s))" % (mul_op, mant, pow_op, exponent)
        elif str_real in ["+inf", "-inf"]:
            raise ValueError("Infinite values are not supported in SMT.")
        elif str_real == "nan":
            # mpmath renders NaN as 'nan'; emitting it would yield invalid SMT.
            raise ValueError("NaN values are not supported in SMT.")
        else:
            return str_real

    def _print_float(self, x: float):
        return self._print(Float(x))

    def _print_Rational(self, x: Rational):
        return self._s_expr('/', [x.p, x.q])

    def _print_Integer(self, x: Integer):
        assert x.q == 1
        return str(x.p)

    def _print_int(self, x: int):
        return str(x)

    def _print_Symbol(self, x: Symbol):
        assert self._is_legal_name(x.name)
        return x.name

    def _print_NumberSymbol(self, x):
        """Print a constant by its ``known_constants`` name if it has one,
        otherwise as its numeric value (at ``precision`` if set)."""
        name = self._known_constants.get(x)
        if name:
            return name
        else:
            f = x.evalf(self._precision) if self._precision else x.evalf()
            return self._print_Float(f)

    def _print_UndefinedFunction(self, x):
        assert self._is_legal_name(x.name)
        return x.name

    def _print_Exp1(self, x):
        """Print ``E`` as ``(exp 1)`` when ``exp`` is a known function,
        otherwise as a number constant."""
        return (
            self._print_Function(exp(1, evaluate=False))
            if exp in self._known_functions else
            self._print_NumberSymbol(x)
        )

    def emptyPrinter(self, expr):
        raise NotImplementedError(f'Cannot convert `{repr(expr)}` of type `{type(expr)}` to SMT.')
def smtlib_code(
    expr,
    auto_assert=True, auto_declare=True,
    precision=None,
    symbol_table=None,
    known_types=None, known_constants=None, known_functions=None,
    prefix_expressions=None, suffix_expressions=None,
    log_warn=None
):
    r"""Converts ``expr`` to a string of smtlib code.

    Parameters
    ==========

    expr : Expr | List[Expr] | Tuple[Expr, ...]
        A SymPy expression or system to be converted.
    auto_assert : bool, optional
        If false, do not modify expr and produce only the S-Expression equivalent of expr.
        If true, assume expr is a system and assert each boolean element.
    auto_declare : bool, optional
        If false, do not produce declarations for the symbols used in expr.
        If true, prepend all necessary declarations for variables used in expr based on symbol_table.
    precision : integer, optional
        The ``evalf(..)`` precision for numbers such as pi.
    symbol_table : dict, optional
        A dictionary where keys are ``Symbol`` or ``Function`` instances and values are their Python type i.e. ``bool``, ``int``, ``float``, or ``Callable[...]``.
        If incomplete, an attempt will be made to infer types from ``expr``.
    known_types: dict, optional
        A dictionary where keys are ``bool``, ``int``, ``float`` etc. and values are their corresponding SMT type names.
        If not given, a partial listing compatible with several solvers will be used.
    known_functions : dict, optional
        A dictionary where keys are ``Function``, ``Relational``, ``BooleanFunction``, or ``Expr`` instances and values are their SMT string representations.
        If not given, a partial listing optimized for dReal solver (but compatible with others) will be used.
    known_constants: dict, optional
        A dictionary where keys are ``NumberSymbol`` instances and values are their SMT variable names.
        When using this feature, extra caution must be taken to avoid naming collisions between user symbols and listed constants.
        If not given, constants will be expanded inline i.e. ``3.14159`` instead of ``MY_SMT_VARIABLE_FOR_PI``.
    prefix_expressions: list, optional
        A list of lists of ``str`` and/or expressions to convert into SMTLib and prefix to the output.
    suffix_expressions: list, optional
        A list of lists of ``str`` and/or expressions to convert into SMTLib and postfix to the output.
    log_warn: lambda function, optional
        A function to record all warnings during potentially risky operations.
        Soundness is a core value in SMT solving, so it is good to log all assumptions made.

    Raises
    ======

    TypeError
        If an undefined function is neither in ``known_functions`` nor typed
        in ``symbol_table``, or if inferred symbol types conflict.

    Examples
    ========

    >>> from sympy import smtlib_code, symbols, sin, Eq
    >>> x = symbols('x')
    >>> smtlib_code(sin(x).series(x).removeO(), log_warn=print)
    Could not infer type of `x`. Defaulting to float.
    Non-Boolean expression `x**5/120 - x**3/6 + x` will not be asserted. Converting to SMTLib verbatim.
    '(declare-const x Real)\n(+ x (* (/ -1 6) (pow x 3)) (* (/ 1 120) (pow x 5)))'

    >>> from sympy import Rational
    >>> x, y, tau = symbols("x, y, tau")
    >>> smtlib_code((2*tau)**Rational(7, 2), log_warn=print)
    Could not infer type of `tau`. Defaulting to float.
    Non-Boolean expression `8*sqrt(2)*tau**(7/2)` will not be asserted. Converting to SMTLib verbatim.
    '(declare-const tau Real)\n(* 8 (pow 2 (/ 1 2)) (pow tau (/ 7 2)))'

    ``Piecewise`` expressions are implemented with ``ite`` expressions by default.
    Note that if the ``Piecewise`` lacks a default term, represented by
    ``(expr, True)`` then an error will be thrown.  This is to prevent
    generating an expression that may not evaluate to anything.

    >>> from sympy import Piecewise
    >>> pw = Piecewise((x + 1, x > 0), (x, True))
    >>> smtlib_code(Eq(pw, 3), symbol_table={x: float}, log_warn=print)
    '(declare-const x Real)\n(assert (= (ite (> x 0) (+ 1 x) x) 3))'

    Custom printing can be defined for certain types by passing a dictionary of
    PythonType : "SMT Name" to the ``known_types``, ``known_constants``, and ``known_functions`` kwargs.

    >>> from typing import Callable
    >>> from sympy import Function, Add
    >>> f = Function('f')
    >>> g = Function('g')
    >>> smt_builtin_funcs = {  # functions our SMT solver will understand
    ...   f: "existing_smtlib_fcn",
    ...   Add: "sum",
    ... }
    >>> user_def_funcs = {  # functions defined by the user must have their types specified explicitly
    ...   g: Callable[[int], float],
    ... }
    >>> smtlib_code(f(x) + g(x), symbol_table=user_def_funcs, known_functions=smt_builtin_funcs, log_warn=print)
    Non-Boolean expression `f(x) + g(x)` will not be asserted. Converting to SMTLib verbatim.
    '(declare-const x Int)\n(declare-fun g (Int) Real)\n(sum (existing_smtlib_fcn x) (g x))'
    """
    log_warn = log_warn or (lambda _: None)

    # A single expression or a list/tuple of them (a "system"). A Python
    # tuple must not be sympified as a whole: that would yield one ``Tuple``
    # object, which has no SMT representation.
    if not isinstance(expr, (list, tuple)):
        expr = [expr]
    expr = [
        sympy.sympify(_, strict=True, evaluate=False, convert_xor=False)
        for _ in expr
    ]

    symbol_table = _auto_infer_smtlib_types(
        *expr, symbol_table=symbol_table or {}
    )

    # See [FALLBACK RULES]
    # Need SMTLibPrinter to populate known_functions and known_constants first.

    # Only pass settings the caller actually provided (truthy), so the
    # printer's defaults apply otherwise.
    settings = {
        key: value
        for key, value in (
            ('precision', precision),
            ('known_types', known_types),
            ('known_functions', known_functions),
            ('known_constants', known_constants),
        )
        if value
    }

    prefix_expressions = prefix_expressions or []
    suffix_expressions = suffix_expressions or []

    p = SMTLibPrinter(settings, symbol_table)

    # [FALLBACK RULES]
    for e in expr:
        for sym in e.atoms(Symbol, Function):
            if (
                sym.is_Symbol and
                sym not in p._known_constants and
                sym not in p.symbol_table
            ):
                log_warn(f"Could not infer type of `{sym}`. Defaulting to float.")
                p.symbol_table[sym] = float
            if (
                sym.is_Function and
                type(sym) not in p._known_functions and
                type(sym) not in p.symbol_table and
                not sym.is_Piecewise
            ): raise TypeError(
                f"Unknown type of undefined function `{sym}`. "
                f"Must be mapped to ``str`` in known_functions or mapped to ``Callable[..]`` in symbol_table."
            )

    declarations = []
    if auto_declare:
        # Keyed by name so each symbol/function is declared only once.
        constants = {sym.name: sym for e in expr for sym in e.free_symbols
                     if sym not in p._known_constants}
        functions = {fnc.name: fnc for e in expr for fnc in e.atoms(Function)
                     if type(fnc) not in p._known_functions and not fnc.is_Piecewise}
        declarations = [
            _auto_declare_smtlib(sym, p, log_warn)
            for sym in constants.values()
        ] + [
            _auto_declare_smtlib(fnc, p, log_warn)
            for fnc in functions.values()
        ]
        declarations = [decl for decl in declarations if decl]

    if auto_assert:
        expr = [_auto_assert_smtlib(e, p, log_warn) for e in expr]

    def _render(item):
        # Pre-rendered strings pass through; SymPy objects are printed.
        return item if isinstance(item, str) else p.doprint(item)

    return '\n'.join([
        # ';; PREFIX EXPRESSIONS',
        *[_render(e) for e in prefix_expressions],

        # ';; DECLARATIONS',
        *sorted(declarations),

        # ';; EXPRESSIONS',
        *[_render(e) for e in expr],

        # ';; SUFFIX EXPRESSIONS',
        *[_render(e) for e in suffix_expressions],
    ])
def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
    """Build the SMT declaration for a symbol or an applied undefined function.

    Symbols become ``(declare-const name Sort)`` using ``p.symbol_table[sym]``.
    Functions become ``(declare-fun name (ArgSorts...) ResultSort)`` using the
    ``Callable[[args...], result]`` stored under the function's class.
    Anything else is reported through ``log_warn`` and yields ``None``.
    """
    if sym.is_Symbol:
        py_type = p.symbol_table[sym]
        assert isinstance(py_type, type)
        sort_name = p._known_types[py_type]
        return p._s_expr('declare-const', [sym, sort_name])

    if sym.is_Function:
        signature = p.symbol_table[type(sym)]
        assert callable(signature)
        # ``__args__`` of Callable[[A, B], R] is (A, B, R): params then result.
        sorts = [p._known_types[arg_type] for arg_type in signature.__args__]
        assert len(sorts) > 0
        param_sorts, result_sort = sorts[:-1], sorts[-1]
        params = '(' + ' '.join(param_sorts) + ')'
        return p._s_expr('declare-fun', [type(sym), params, result_sort])

    log_warn(f"Non-Symbol/Function `{sym}` will not be declared.")
    return None
def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
    """Wrap ``e`` in ``(assert ...)`` when it is boolean-valued.

    ``e`` counts as boolean-valued if it is a SymPy ``Boolean``, a symbol typed
    ``bool`` in ``p.symbol_table``, or an applied function whose declared
    ``Callable`` signature returns ``bool``. Otherwise a warning is logged and
    ``e`` is returned unchanged, to be printed verbatim.
    """
    table = p.symbol_table
    # Lazily evaluated, in order, so later checks only run if earlier ones fail
    # (the last one indexes the table and touches ``__args__``).
    checks = (
        lambda: isinstance(e, Boolean),
        lambda: e in table and table[e] == bool,
        lambda: e.is_Function and type(e) in table and table[type(e)].__args__[-1] == bool,
    )
    if not any(check() for check in checks):
        log_warn(f"Non-Boolean expression `{e}` will not be asserted. Converting to SMTLib verbatim.")
        return e
    return p._s_expr('assert', [e])
def _auto_infer_smtlib_types(
    *exprs: Basic,
    symbol_table: typing.Optional[dict] = None
) -> dict:
    """Infer Python types (``bool``/``int``/``float``) for symbols in ``exprs``.

    Parameters
    ==========

    exprs : Basic
        The expressions whose symbols should be typed.
    symbol_table : dict, optional
        Known types; copied, never mutated. Entries for undefined-function
        classes are expected to be ``Callable[[params...], result]``.

    Returns
    =======

    dict
        A new table with the given entries plus everything inferred. Symbols
        that could not be typed are left out (callers apply fallbacks).

    Raises
    ======

    TypeError
        If the explicit rules assign two different types to one symbol.
    """
    # [TYPE INFERENCE RULES]
    # X is alone in an expr => X is bool
    # X in BooleanFunction.args => X is bool
    # X matches to a bool param of a symbol_table function => X is bool
    # X matches to an int param of a symbol_table function => X is int
    # X.is_integer => X is int
    # X == Y, where X is T => Y is T

    # [FALLBACK RULES]
    # see _auto_declare_smtlib(..)
    # X is not bool and X is not int and X is Symbol => X is float
    # else (e.g. X is Function) => error. must be specified explicitly.

    _symbols = dict(symbol_table) if symbol_table else {}

    def safe_update(syms: set, inf):
        # Record ``inf`` for each symbol; a different existing type is a conflict.
        for s in syms:
            assert s.is_Symbol
            if (old_type := _symbols.setdefault(s, inf)) != inf:
                raise TypeError(f"Could not infer type of `{s}`. Apparently both `{old_type}` and `{inf}`?")

    # EXPLICIT TYPES
    safe_update({
        e
        for e in exprs
        if e.is_Symbol
    }, bool)

    safe_update({
        symbol
        for e in exprs
        for boolfunc in e.atoms(BooleanFunction)
        for symbol in boolfunc.args
        if symbol.is_Symbol
    }, bool)

    safe_update({
        symbol
        for e in exprs
        for boolfunc in e.atoms(Function)
        if type(boolfunc) in _symbols
        for symbol, param in zip(boolfunc.args, _symbols[type(boolfunc)].__args__)
        if symbol.is_Symbol and param == bool
    }, bool)

    safe_update({
        symbol
        for e in exprs
        for intfunc in e.atoms(Function)
        if type(intfunc) in _symbols
        for symbol, param in zip(intfunc.args, _symbols[type(intfunc)].__args__)
        if symbol.is_Symbol and param == int
    }, int)

    safe_update({
        symbol
        for e in exprs
        for symbol in e.atoms(Symbol)
        if symbol.is_integer
    }, int)

    safe_update({
        symbol
        for e in exprs
        for symbol in e.atoms(Symbol)
        if symbol.is_real and not symbol.is_integer
    }, float)

    # EQUALITY RELATION RULE
    rels = [rel for expr in exprs for rel in expr.atoms(Equality)]
    rels = [
        (rel.lhs, rel.rhs) for rel in rels if rel.lhs.is_Symbol
    ] + [
        (rel.rhs, rel.lhs) for rel in rels if rel.rhs.is_Symbol
    ]

    # Propagate to a fixed point so chains such as ``x == y, y == 1`` type
    # every symbol in the chain rather than only the one next to the literal
    # (a single pass can leave ``x`` untyped, later defaulting to Real while
    # ``y`` is Int). Only still-unknown symbols are filled in, so no new
    # conflicts can arise, and each productive pass adds at least one key to
    # ``_symbols``, so the loop terminates.
    progress = True
    while progress:
        progress = False
        for infer, reltd in rels:
            if infer in _symbols:
                continue
            inference = (
                _symbols[reltd] if reltd in _symbols else
                _symbols[type(reltd)].__args__[-1]
                if reltd.is_Function and type(reltd) in _symbols else
                bool if reltd.is_Boolean else
                int if reltd.is_integer or reltd.is_Integer else
                float if reltd.is_real else
                None
            )
            if inference:
                safe_update({infer}, inference)
                progress = True

    return _symbols