Spaces:
Running
Running
from sympy.printing.smtlib import smtlib_code | |
from sympy.assumptions.assume import AppliedPredicate | |
from sympy.assumptions.cnf import EncodedCNF | |
from sympy.assumptions.ask import Q | |
from sympy.core import Add, Mul | |
from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan | |
from sympy.functions.elementary.complexes import Abs | |
from sympy.functions.elementary.exponential import Pow | |
from sympy.functions.elementary.miscellaneous import Min, Max | |
from sympy.logic.boolalg import And, Or, Xor, Implies | |
from sympy.logic.boolalg import Not, ITE | |
from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate | |
from sympy.external import import_module | |
def z3_satisfiable(expr, all_models=False):
    """Decide satisfiability of *expr* with the z3 SMT solver.

    Parameters
    ==========

    expr : boolean expression or EncodedCNF
        Anything that is not already an ``EncodedCNF`` is wrapped in a fresh
        one via ``add_prop``.
    all_models : bool, optional
        Not consulted by this function; at most one model is returned.

    Returns
    =======

    dict
        Mapping of encoded literals to bools (built by
        ``z3_model_to_sympy_model``) when z3 reports ``"sat"``.
    False
        When z3 reports ``"unsat"``.
    None
        For any other verdict (presumably z3's ``"unknown"``).

    Raises
    ======

    ImportError
        If the ``z3`` package cannot be imported.
    """
    if isinstance(expr, EncodedCNF):
        encoded = expr
    else:
        encoded = EncodedCNF()
        encoded.add_prop(expr)

    # z3 is an optional dependency, so it is loaded lazily here.
    z3 = import_module("z3")
    if z3 is None:
        raise ImportError("z3 is not installed")

    solver = encoded_cnf_to_z3_solver(encoded, z3)
    verdict = str(solver.check())

    if verdict == "sat":
        return z3_model_to_sympy_model(solver.model(), encoded)
    if verdict == "unsat":
        return False
    return None
def z3_model_to_sympy_model(z3_model, enc_cnf):
    """Translate a z3 model into a ``{literal: bool}`` assignment.

    Only the Boolean encoding constants ``d<n>`` (one per entry of
    ``enc_cnf.encoding``) are translated back to their SymPy literal.  Any
    other constant in the model, such as the ``Real`` constants declared for
    the free symbols of relational predicates, is skipped.  Previously those
    names were fed to ``int(name[1:])``, which raised ``ValueError``.

    Parameters
    ==========

    z3_model : z3.ModelRef
        Model returned by ``Solver.model()``; iterating it yields declarations
        with a ``name()`` method, and indexing it with a declaration yields
        that declaration's value.
    enc_cnf : EncodedCNF
        The encoding used to build the solver; only ``encoding`` is read.

    Returns
    =======

    dict
        Maps each encoded literal that appears in the model to ``True`` or
        ``False``.
    """
    rev_enc = {value: key for key, value in enc_cnf.encoding.items()}
    assignment = {}
    for var in z3_model:
        name = var.name()
        index = name[1:]
        # Encoding constants are named "d<positive integer>"; everything else
        # (e.g. Real-valued problem symbols) has no counterpart in the encoding.
        if not (name.startswith("d") and index.isdecimal()):
            continue
        key = int(index)
        if key not in rev_enc:
            continue
        assignment[rev_enc[key]] = bool(z3_model[var])
    return assignment
def clause_to_assertion(clause):
    """Render one CNF clause as an SMT-LIB ``assert`` of a disjunction.

    A positive literal ``n`` becomes the constant ``d<n>``; any other literal
    ``-n`` becomes ``(not d<n>)``.

    >>> clause_to_assertion([1, -2])
    '(assert (or d1 (not d2)))'
    """
    def render(lit):
        name = f"d{abs(lit)}"
        return name if lit > 0 else f"(not {name})"

    body = " ".join(map(render, clause))
    return f"(assert (or {body}))"
def encoded_cnf_to_z3_solver(enc_cnf, z3):
    """Build a z3 ``Solver`` encoding the clauses of *enc_cnf*.

    Every encoding variable ``n`` is declared as the SMT-LIB Boolean constant
    ``d<n>``, and every clause becomes one ``(assert (or ...))`` (see
    ``clause_to_assertion``).  Applied predicates whose function is in the
    supported set below (order relations and sign predicates) are also given
    their meaning.  For a predicate encoded as ``n``, the solver receives
    ``(assert (implies d<n> <pred>))``, where ``<pred>`` is the
    ``smtlib_code`` rendering of the predicate.  Each free symbol of such
    predicates is declared as a ``Real`` constant.

    Parameters
    ==========

    enc_cnf : EncodedCNF
        The propositional problem; its ``variables``, ``data`` and
        ``encoding`` are read.
    z3 : module
        The ``z3`` Python module (``z3_satisfiable`` loads it with
        ``import_module`` and passes it in).

    Returns
    =======

    z3.Solver
        A solver loaded with all declarations and assertions.
    """
    # Predicates the SMT translation understands; all others stay purely
    # propositional.  Built once instead of on every loop iteration.
    supported = (
        Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq,
        Q.positive, Q.negative, Q.extended_negative, Q.extended_positive,
        Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive,
        Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive,
    )

    declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
    assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]

    symbols = set()
    for pred, enc in enc_cnf.encoding.items():
        if not isinstance(pred, AppliedPredicate) or pred.function not in supported:
            continue
        pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False,
                               known_functions=known_functions)
        symbols |= pred.free_symbols
        # One-directional link: d<enc> being true forces the predicate to hold.
        assertions.append(f"(assert (implies d{enc} {pred_str}))")

    # Sorted so the generated script does not depend on set iteration order.
    declarations.extend(f"(declare-const {sym} Real)"
                        for sym in sorted(symbols, key=str))

    s = z3.Solver()
    # Load declarations and assertions as one script so that every constant is
    # declared in the same parse that uses it, rather than relying on the
    # solver keeping declarations between separate ``from_string`` calls.
    s.from_string("\n".join(declarations + assertions))
    return s
# Operator names used when ``smtlib_code`` renders predicates for z3 (passed as
# ``known_functions`` by ``encoded_cnf_to_z3_solver``).  Keys are SymPy
# classes, except the relation predicates, which are keyed by instance.
known_functions = {
    # arithmetic
    Add: "+",
    Mul: "*",
    # relational classes
    Equality: "=",
    LessThan: "<=",
    GreaterThan: ">=",
    StrictLessThan: "<",
    StrictGreaterThan: ">",
    # relation predicates (instances, not classes)
    EqualityPredicate(): "=",
    LessThanPredicate(): "<=",
    GreaterThanPredicate(): ">=",
    StrictLessThanPredicate(): "<",
    StrictGreaterThanPredicate(): ">",
    # other numeric functions
    Abs: "abs",
    Min: "min",
    Max: "max",
    Pow: "^",
    # boolean connectives
    And: "and",
    Or: "or",
    Xor: "xor",
    Not: "not",
    ITE: "ite",
    Implies: "=>",
}