Spaces:
Sleeping
Sleeping
from sympy.printing.dot import (purestr, styleof, attrprint, dotnode, | |
dotedges, dotprint) | |
from sympy.core.basic import Basic | |
from sympy.core.expr import Expr | |
from sympy.core.numbers import (Float, Integer) | |
from sympy.core.singleton import S | |
from sympy.core.symbol import (Symbol, symbols) | |
from sympy.printing.repr import srepr | |
from sympy.abc import x | |
def test_purestr():
    """Check ``purestr`` with and without ``with_args=True``.

    Each case lists an expression, the string ``purestr`` must return for
    it, and the tuple of argument strings that accompanies that string
    when ``with_args=True`` is passed.
    """
    cases = [
        (Symbol('x'), "Symbol('x')", ()),
        (Basic(S(1), S(2)), "Basic(Integer(1), Integer(2))",
         ('Integer(1)', 'Integer(2)')),
        (Float(2), "Float('2.0', precision=53)", ()),
    ]

    # Plain form: only the string is returned.
    for expr, text, _ in cases:
        assert purestr(expr) == text

    # with_args form: a (string, argument-strings) pair is returned.
    for expr, text, arg_texts in cases:
        assert purestr(expr, with_args=True) == (text, arg_texts)
def test_styleof():
    """Check the style dictionary ``styleof`` selects for an expression.

    With one entry for ``Basic`` and one for ``Expr``, a plain ``Basic``
    gets only the first entry, while ``x + 1`` gets the shape from the
    ``Basic`` entry together with the colour from the ``Expr`` entry.
    """
    basic_style = {'color': 'blue', 'shape': 'ellipse'}
    expr_style = {'color': 'black'}
    styles = [(Basic, basic_style), (Expr, expr_style)]

    plain_basic = styleof(Basic(S(1)), styles)
    assert plain_basic == {'color': 'blue', 'shape': 'ellipse'}

    expression = styleof(x + 1, styles)
    assert expression == {'color': 'black', 'shape': 'ellipse'}
def test_attrprint():
    """Check that ``attrprint`` renders a dict as quoted ``"key"="value"`` pairs."""
    attributes = {'color': 'blue', 'shape': 'ellipse'}
    rendered = attrprint(attributes)
    assert rendered == '"color"="blue", "shape"="ellipse"'
def test_dotnode():
    """Check the DOT node statement ``dotnode`` emits for several expressions.

    Every expected statement has the shape ``"<name>" [<attributes>];``.
    With ``repeat=True`` the expected node name carries a ``_()`` suffix.
    """
    def node(name, label):
        # Expected node statement for a black, elliptical node.
        return '"%s" ["color"="black", "label"="%s", "shape"="ellipse"];' % (
            name, label)

    assert dotnode(x, repeat=False) == node("Symbol('x')", "x")

    # On failure, report what dotnode actually produced.
    assert dotnode(x+2, repeat=False) == \
        node("Add(Integer(2), Symbol('x'))", "Add"), \
        dotnode(x+2,repeat=0)

    sum_name = "Add(Symbol('x'), Pow(Symbol('x'), Integer(2)))"
    assert dotnode(x + x**2, repeat=False) == node(sum_name, "Add")
    assert dotnode(x + x**2, repeat=True) == node(sum_name + "_()", "Add")
def test_dotedges():
    """Check the DOT edge statements ``dotedges`` emits for ``x + 2``.

    The edges are sorted before comparison. With ``repeat=True`` the
    expected node names carry position suffixes (``_()``, ``_(0,)``,
    ``_(1,)``).
    """
    parent = "Add(Integer(2), Symbol('x'))"

    without_repeat = sorted(dotedges(x+2, repeat=False))
    assert without_repeat == [
        '"' + parent + '" -> "Integer(2)";',
        '"' + parent + '" -> "Symbol(\'x\')";',
    ]

    with_repeat = sorted(dotedges(x + 2, repeat=True))
    assert with_repeat == [
        '"' + parent + '_()" -> "Integer(2)_(0,)";',
        '"' + parent + '_()" -> "Symbol(\'x\')_(1,)";',
    ]
def test_dotprint():
    """Check that ``dotprint`` output contains the expected nodes and edges.

    For each expression, every statement produced by ``dotedges`` and by
    ``dotnode`` for the listed subexpressions must appear in the text
    returned by ``dotprint``.
    """
    # repeat=False: (expression, subexpressions whose nodes must appear).
    no_repeat_cases = (
        (x+2, (x, Integer(2), x+2)),
        (x+x**2, (x, Integer(2), x**2)),
    )
    for expr, node_exprs in no_repeat_cases:
        text = dotprint(expr, repeat=False)
        for edge in dotedges(expr, repeat=False):
            assert edge in text
        expected_nodes = [dotnode(sub, repeat=False) for sub in node_exprs]
        for node in expected_nodes:
            assert node in text
        assert 'digraph' in text

    # repeat=True: the root node is looked up at position ().
    expr = x+x**2
    text = dotprint(expr, repeat=True)
    for edge in dotedges(expr, repeat=True):
        assert edge in text
    assert dotnode(expr, pos=()) in text

    # repeat=True with the same symbol at two positions: both must appear.
    expr = x**x
    text = dotprint(expr, repeat=True)
    for edge in dotedges(expr, repeat=True):
        assert edge in text
    for node in (dotnode(x, pos=(0,)), dotnode(x, pos=(1,))):
        assert node in text
    assert 'digraph' in text
def test_dotprint_depth():
    """Check the ``depth`` option of ``dotprint``.

    With ``depth=1`` the node for ``3*x + 2`` is present but the node for
    ``x`` is not. Without the option, the word "depth" must not appear in
    the output.
    """
    expr = 3*x+2

    shallow = dotprint(expr, depth=1)
    assert dotnode(expr) in shallow
    assert dotnode(x) not in shallow

    unrestricted = dotprint(expr)
    assert "depth" not in unrestricted
def test_Matrix_and_non_basics():
    """Check the complete DOT text ``dotprint`` produces for a ``MatrixSymbol``.

    The expected output is assembled line by line. The empty entries are
    the blank separator lines that ``dotprint`` places after ``digraph{``,
    around each section banner and before the edge section; without them
    the equality check cannot match the generated text.

    In the expected text the ``Str('X')`` name node is blue, while the
    ``MatrixSymbol`` node and both ``Symbol('n')`` nodes are black.
    """
    from sympy.matrices.expressions.matexpr import MatrixSymbol
    n = Symbol('n')

    # Name of the root node; the "_()" suffix marks the root position.
    root = "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()"

    expected = "\n".join([
        'digraph{',
        '',
        '# Graph style',
        '"ordering"="out"',
        '"rankdir"="TD"',
        '',
        '#########',
        '# Nodes #',
        '#########',
        '',
        '"' + root + '" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];',
        '"Str(\'X\')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];',
        '"Symbol(\'n\')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];',
        '"Symbol(\'n\')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];',
        '',
        '#########',
        '# Edges #',
        '#########',
        '',
        '"' + root + '" -> "Str(\'X\')_(0,)";',
        '"' + root + '" -> "Symbol(\'n\')_(1,)";',
        '"' + root + '" -> "Symbol(\'n\')_(2,)";',
        '}',
    ])

    assert dotprint(MatrixSymbol('X', n, n)) == expected
def test_labelfunc():
    """Check that ``dotprint`` accepts ``srepr`` as its ``labelfunc``.

    The output for ``x + 2`` must contain the ``srepr`` forms of both
    leaves.
    """
    rendered = dotprint(x + 2, labelfunc=srepr)
    for fragment in ("Symbol('x')", "Integer(2)"):
        assert fragment in rendered
def test_commutative():
    """Check how ``dotprint`` treats argument order for non-commutative symbols.

    Swapping the operands of a sum must give the same output; swapping
    the operands of a product must give different output.
    """
    # Symbols are still named 'x' and 'y'; only the local names differ.
    left, right = symbols('x y', commutative=False)

    sum_forward = dotprint(left + right)
    sum_backward = dotprint(right + left)
    assert sum_forward == sum_backward

    product_forward = dotprint(left*right)
    product_backward = dotprint(right*left)
    assert product_forward != product_backward