Spaces:
Sleeping
Sleeping
File size: 6,123 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""
C++ code printer
"""
from itertools import chain
from sympy.codegen.ast import Type, none
from .codeprinter import requires
from .c import C89CodePrinter, C99CodePrinter
# These are defined in the other file so we can avoid importing sympy.codegen
# from the top-level 'import sympy'. Export them here as well.
from sympy.printing.codeprinter import cxxcode # noqa:F401
# from https://en.cppreference.com/w/cpp/keyword
# Reserved words keyed by C++ standard. Each value is a list; the printer
# classes in this module turn them into sets for their ``reserved_words``.
reserved = {
    'C++98': [
        'and', 'and_eq', 'asm', 'auto', 'bitand', 'bitor', 'bool', 'break',
        'case', 'catch', 'char', 'class', 'compl', 'const', 'const_cast',
        'continue', 'default', 'delete', 'do', 'double', 'dynamic_cast',
        'else', 'enum', 'explicit', 'export', 'extern', 'false', 'float',
        'for', 'friend', 'goto', 'if', 'inline', 'int', 'long', 'mutable',
        'namespace', 'new', 'not', 'not_eq', 'operator', 'or', 'or_eq',
        'private', 'protected', 'public', 'register', 'reinterpret_cast',
        'return', 'short', 'signed', 'sizeof', 'static', 'static_cast',
        'struct', 'switch', 'template', 'this', 'throw', 'true', 'try',
        'typedef', 'typeid', 'typename', 'union', 'unsigned', 'using',
        'virtual', 'void', 'volatile', 'wchar_t', 'while', 'xor', 'xor_eq'
    ]
}

# C++11 extends the C++98 list; ``+`` already produces a new list, so the
# C++98 entry is not mutated.
reserved['C++11'] = reserved['C++98'] + [
    'alignas', 'alignof', 'char16_t', 'char32_t', 'constexpr', 'decltype',
    'noexcept', 'nullptr', 'static_assert', 'thread_local'
]

# C++17 starts from a copy of the C++11 list and drops 'register'.
# NOTE(review): cppreference describes 'register' as unused but still
# reserved in C++17 -- confirm whether it should really be removed here.
reserved['C++17'] = reserved['C++11'][:]
reserved['C++17'].remove('register')

# TM TS: atomic_cancel, atomic_commit, atomic_noexcept, synchronized
# concepts TS: concept, requires
# module TS: import, module
# SymPy function name -> C++ math function name, grouped by the standard
# under which each entry is listed.
_math_functions = {
    'C++98': {'Mod': 'fmod', 'ceiling': 'ceil'},
    'C++11': {'gamma': 'tgamma'},
    'C++17': {'beta': 'beta', 'Ei': 'expint', 'zeta': 'riemann_zeta'},
}

# from https://en.cppreference.com/w/cpp/header/cmath
# The remaining entries map a SymPy name to its lower-cased spelling
# ('Pow' is left out of the C++98 list).
# NOTE: the loop variable ``k`` stays bound at module level afterwards.
for k in (
    'Abs', 'exp', 'log', 'log10', 'sqrt', 'sin', 'cos', 'tan',
    'asin', 'acos', 'atan', 'atan2', 'sinh', 'cosh', 'tanh', 'floor',
):
    _math_functions['C++98'][k] = k.lower()

for k in ('asinh', 'acosh', 'atanh', 'erf', 'erfc'):
    _math_functions['C++11'][k] = k.lower()
def _attach_print_method(cls, sympy_name, func_name):
    """Attach a ``_print_<sympy_name>`` method to ``cls``.

    The generated method prints ``expr`` as a call to ``func_name``,
    prefixed with the printer's ``_ns`` attribute, with every element of
    ``expr.args`` printed through ``self._print`` and joined by ``', '``.

    Parameters
    ==========

    cls : type
        Printer class that receives the new method.
    sympy_name : str
        Name of the SymPy function; used to build the method name.
    func_name : str
        Name of the function to emit in the generated code.

    Raises
    ======

    ValueError
        If ``cls`` already has an attribute called ``_print_<sympy_name>``.
    """
    meth_name = '_print_%s' % sympy_name
    if hasattr(cls, meth_name):
        raise ValueError("Edit method (or subclass) instead of overwriting.")

    def _print_method(self, expr):
        return '{}{}({})'.format(self._ns, func_name, ', '.join(map(self._print, expr.args)))

    # Document the method with the name it is attached for; this must come
    # from the argument, not from a module-level leftover such as ``k``.
    _print_method.__doc__ = "Prints code for %s" % sympy_name
    setattr(cls, meth_name, _print_method)
def _attach_print_methods(cls, cont):
    """Attach one print method to ``cls`` per entry of ``cont[cls.standard]``.

    ``cont`` is indexed by ``cls.standard``; every ``(sympy_name, cxx_name)``
    item of the selected mapping is forwarded to ``_attach_print_method``.
    """
    per_standard = cont[cls.standard]
    for name_pair in per_standard.items():
        _attach_print_method(cls, *name_pair)
class _CXXCodePrinterBase:
    """Printing rules shared by the C++ printers in this module.

    The class relies on ``self._print`` and on a cooperating ``__init__``
    further along the MRO, so it is meant to be combined with another
    printer class rather than used alone.
    """

    printmethod = "_cxxcode"
    language = 'C++'
    _ns = 'std::'  # namespace prefix put in front of library names

    def __init__(self, settings=None):
        # Any falsy ``settings`` value is replaced by a fresh empty dict.
        if not settings:
            settings = {}
        super().__init__(settings)

    @requires(headers={'algorithm'})
    def _print_Max(self, expr):
        """Print ``Max`` as nested two-argument ``max`` calls."""
        from sympy.functions.elementary.miscellaneous import Max
        operands = expr.args
        if len(operands) != 1:
            head = self._print(operands[0])
            # The remaining operands are wrapped in a new Max and printed
            # recursively.
            tail = self._print(Max(*operands[1:]))
            return "%smax(%s, %s)" % (self._ns, head, tail)
        return self._print(operands[0])

    @requires(headers={'algorithm'})
    def _print_Min(self, expr):
        """Print ``Min`` as nested two-argument ``min`` calls."""
        from sympy.functions.elementary.miscellaneous import Min
        operands = expr.args
        if len(operands) != 1:
            head = self._print(operands[0])
            # The remaining operands are wrapped in a new Min and printed
            # recursively.
            tail = self._print(Min(*operands[1:]))
            return "%smin(%s, %s)" % (self._ns, head, tail)
        return self._print(operands[0])

    def _print_using(self, expr):
        """Print a ``using`` statement; raise ValueError if an alias is set."""
        if expr.alias != none:
            raise ValueError("C++98 does not support type aliases")
        return 'using %s' % expr.type

    def _print_Raise(self, rs):
        """Print a ``throw`` of the single argument of ``rs``."""
        (thrown,) = rs.args
        return 'throw %s' % self._print(thrown)

    @requires(headers={'stdexcept'})
    def _print_RuntimeError_(self, re):
        """Print a ``runtime_error`` construction from the single argument."""
        (text,) = re.args
        return "%sruntime_error(%s)" % (self._ns, self._print(text))
class CXX98CodePrinter(_CXXCodePrinterBase, C89CodePrinter):
    """C++98 printer: the shared C++ rules combined with the C89 printer."""

    standard = 'C++98'
    # Built as a set from the C++98 keyword list.
    reserved_words = {*reserved['C++98']}
# _attach_print_methods(CXX98CodePrinter, _math_functions)
class CXX11CodePrinter(_CXXCodePrinterBase, C99CodePrinter):
    """C++11 printer: the shared C++ rules combined with the C99 printer."""

    standard = 'C++11'
    reserved_words = {*reserved['C++11']}

    # Start from the C++98 printer's mappings; the pairs listed afterwards
    # are added on top. Each value is ``(C++ type name, headers or None)``.
    type_mappings = dict(chain(
        CXX98CodePrinter.type_mappings.items(),
        # Fixed-width integers: '<name>_t', each tagged with 'cstdint'.
        (
            (Type(int_name), (int_name + '_t', {'cstdint'}))
            for int_name in ('int8', 'int16', 'int32', 'int64',
                             'uint8', 'uint16', 'uint32', 'uint64')
        ),
        [
            (Type('complex64'), ('std::complex<float>', {'complex'})),
            (Type('complex128'), ('std::complex<double>', {'complex'})),
            (Type('bool'), ('bool', None)),
        ],
    ))

    def _print_using(self, expr):
        """Print a ``using`` statement, with ``alias = type`` when aliased."""
        if expr.alias != none:
            # Every field of ``expr`` is printed via ``self._print`` and then
            # substituted into the template by name.
            printed_fields = expr.kwargs(apply=self._print)
            return 'using %(alias)s = %(type)s' % printed_fields
        return super()._print_using(expr)
# _attach_print_methods(CXX11CodePrinter, _math_functions)
class CXX17CodePrinter(_CXXCodePrinterBase, C99CodePrinter):
    """C++17 printer: the shared C++ rules combined with the C99 printer.

    Adds the C++17 math functions (``beta``, ``expint``, ``riemann_zeta``)
    to the known-function table and prints type aliases with the
    ``using alias = type`` syntax.
    """

    standard = 'C++17'
    reserved_words = set(reserved['C++17'])

    # NOTE(review): unlike CXX11CodePrinter, no ``type_mappings`` is defined
    # here, so whatever the base classes provide is used -- confirm that is
    # the intended mapping for C++17.

    # Known-function table of the C99 printer extended by the C++17 entries.
    _kf = dict(C99CodePrinter._kf, **_math_functions['C++17'])

    def _print_using(self, expr):
        """Print a ``using`` statement, with ``alias = type`` when aliased.

        Without this override the inherited implementation raises
        ``ValueError("C++98 does not support type aliases")`` for aliased
        statements, although alias declarations are available since C++11.
        """
        if expr.alias == none:
            return super()._print_using(expr)
        return 'using %(alias)s = %(type)s' % expr.kwargs(apply=self._print)

    def _print_beta(self, expr):
        """Print ``beta`` through the generic math-function printer."""
        return self._print_math_func(expr)

    def _print_Ei(self, expr):
        """Print ``Ei`` through the generic math-function printer."""
        return self._print_math_func(expr)

    def _print_zeta(self, expr):
        """Print ``zeta`` through the generic math-function printer."""
        return self._print_math_func(expr)
# _attach_print_methods(CXX17CodePrinter, _math_functions)
# Printer classes keyed by the lower-cased name of the standard they
# declare ('c++98', 'c++11', 'c++17').
cxx_code_printers = {
    printer_cls.standard.lower(): printer_cls
    for printer_cls in (CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter)
}
|