Kano001's picture
Upload 3077 files
6a86ad5 verified
raw
history blame
3.75 kB
from sympy.printing.smtlib import smtlib_code
from sympy.assumptions.assume import AppliedPredicate
from sympy.assumptions.cnf import EncodedCNF
from sympy.assumptions.ask import Q
from sympy.core import Add, Mul
from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import Pow
from sympy.functions.elementary.miscellaneous import Min, Max
from sympy.logic.boolalg import And, Or, Xor, Implies
from sympy.logic.boolalg import Not, ITE
from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate
from sympy.external import import_module
def z3_satisfiable(expr, all_models=False):
    """Decide satisfiability of ``expr`` using the z3 SMT solver.

    Parameters
    ==========
    expr : EncodedCNF or proposition
        Anything that is not already an ``EncodedCNF`` is wrapped into a
        fresh ``EncodedCNF`` via ``add_prop``.
    all_models : bool
        Accepted for signature compatibility; it is currently not used, so
        at most a single model is ever returned.

    Returns
    =======
    dict, False or None
        A model (as produced by ``z3_model_to_sympy_model``) when z3 reports
        ``sat``, ``False`` when it reports ``unsat``, and ``None`` for any
        other verdict (e.g. ``unknown``).

    Raises
    ======
    ImportError
        If the ``z3`` module cannot be imported.
    """
    if isinstance(expr, EncodedCNF):
        encoded = expr
    else:
        encoded = EncodedCNF()
        encoded.add_prop(expr)

    z3 = import_module("z3")
    if z3 is None:
        raise ImportError("z3 is not installed")

    solver = encoded_cnf_to_z3_solver(encoded, z3)

    # z3's CheckSatResult stringifies to "sat", "unsat" or "unknown".
    verdict = str(solver.check())
    if verdict == "sat":
        return z3_model_to_sympy_model(solver.model(), encoded)
    if verdict == "unsat":
        return False
    return None
def z3_model_to_sympy_model(z3_model, enc_cnf):
    """Translate a z3 model back into an assignment of SymPy propositions.

    Parameters
    ==========
    z3_model : z3.ModelRef
        Model returned by ``Solver.model()``.  Iterating it yields the
        declarations that received an interpretation; indexing it with a
        declaration yields that declaration's value.
    enc_cnf : EncodedCNF
        The encoding the solver was built from; ``enc_cnf.encoding`` maps
        each SymPy proposition to the integer ``N`` of its Boolean constant
        ``dN``.

    Returns
    =======
    dict
        Maps every proposition whose constant ``dN`` appears in the model to
        its truth value.

    Notes
    =====
    The model can also contain the ``Real`` constants declared for symbols
    that occur in relational predicates (e.g. ``x``).  Those are not CNF
    variables, so only declarations named ``d<digits>`` whose index belongs
    to the encoding are translated.  Everything else is skipped rather than
    crashing in ``int()`` or in the Boolean conversion of a real value.
    """
    rev_enc = {value: key for key, value in enc_cnf.encoding.items()}
    sympy_model = {}
    for decl in z3_model:
        name = decl.name()
        index_str = name[1:]
        # Reject anything that is not exactly "d" followed by ASCII digits
        # (isdigit alone would accept characters such as superscripts).
        if not (name.startswith("d") and index_str.isascii()
                and index_str.isdigit()):
            continue
        index = int(index_str)
        if index not in rev_enc:
            continue
        sympy_model[rev_enc[index]] = bool(z3_model[decl])
    return sympy_model
def clause_to_assertion(clause):
    """Render one encoded CNF clause as an SMT-LIB ``assert`` command.

    Parameters
    ==========
    clause : iterable of int
        Encoded literals.  ``N > 0`` stands for the Boolean constant ``dN``
        and ``-N`` for its negation.  ``0`` is not a variable index.  In
        SymPy's ``EncodedCNF`` it encodes ``False`` (NOTE(review): confirm
        against the encoder).  It is therefore emitted as the SMT-LIB
        constant ``false`` rather than as an undeclared ``d0``.

    Returns
    =======
    str
        ``"(assert (or lit1 lit2 ...))"``; an empty clause (an empty
        disjunction, which is unsatisfiable) becomes ``"(assert false)"``.

    Examples
    ========
    >>> clause_to_assertion([1, -2])
    '(assert (or d1 (not d2)))'
    """
    literals = []
    for lit in clause:
        if lit == 0:
            literals.append("false")
        elif lit > 0:
            literals.append(f"d{lit}")
        else:
            literals.append(f"(not d{-lit})")
    if not literals:
        return "(assert false)"
    return "(assert (or " + " ".join(literals) + "))"
def encoded_cnf_to_z3_solver(enc_cnf, z3):
    """Build a z3 solver loaded with the clauses of ``enc_cnf``.

    The solver receives the following, all as SMT-LIB text:

    * every CNF variable ``N`` in ``enc_cnf.variables`` as a ``Bool``
      constant ``dN``;
    * every clause in ``enc_cnf.data`` as an ``assert`` (see
      ``clause_to_assertion``);
    * for every encoded ``AppliedPredicate`` whose function is one of the
      supported (in)equality/sign predicates, the implication
      ``(implies dN <predicate>)``.  Here ``<predicate>`` is printed by
      ``smtlib_code`` and each of its free symbols is declared as a
      ``Real`` constant.

    Parameters
    ==========
    enc_cnf : EncodedCNF
        The encoded formula; ``encoding`` maps propositions to variable
        indices.
    z3 : module
        The imported ``z3`` module.

    Returns
    =======
    z3.Solver
        A solver with all declarations and assertions loaded (not checked).

    NOTE(review): a SymPy symbol whose name has the form ``d<N>`` would
    clash with the Boolean constants declared here; confirm callers cannot
    produce such names.
    """
    # Predicates that have an arithmetic meaning z3 can reason about; any
    # other encoded proposition is left as an uninterpreted Boolean.
    supported = (
        Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq,
        Q.positive, Q.negative, Q.extended_negative, Q.extended_positive,
        Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive,
        Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive,
    )

    declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
    assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]

    symbols = set()
    for pred, enc in enc_cnf.encoding.items():
        if not isinstance(pred, AppliedPredicate):
            continue
        if pred.function not in supported:
            continue
        pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False,
                               known_functions=known_functions)
        symbols |= pred.free_symbols
        # The Boolean variable being true forces the arithmetic fact.
        assertions.append(f"(assert (implies d{enc} {pred_str}))")

    declarations.extend(f"(declare-const {sym} Real)" for sym in symbols)

    s = z3.Solver()
    # Parse declarations and assertions as one script so every constant
    # used by an assertion is declared in the same parse, independent of
    # whether z3 keeps declarations across separate from_string calls.
    s.from_string("\n".join(declarations) + "\n" + "\n".join(assertions))
    return s
# Operator names passed to ``smtlib_code`` through its ``known_functions``
# argument, so that the listed SymPy classes and relation-predicate
# instances are printed as these SMT-LIB function symbols.
known_functions = {
    # Arithmetic.
    Add: '+',
    Mul: '*',
    # Relational expressions.
    Equality: '=',
    LessThan: '<=',
    GreaterThan: '>=',
    StrictLessThan: '<',
    StrictGreaterThan: '>',
    # Relation predicates, keyed by predicate instance.
    EqualityPredicate(): '=',
    LessThanPredicate(): '<=',
    GreaterThanPredicate(): '>=',
    StrictLessThanPredicate(): '<',
    StrictGreaterThanPredicate(): '>',
    # Other numeric functions.
    Abs: 'abs',
    Min: 'min',
    Max: 'max',
    Pow: '^',
    # Boolean connectives.
    And: 'and',
    Or: 'or',
    Xor: 'xor',
    Not: 'not',
    ITE: 'ite',
    Implies: '=>',
}