Spaces:
Sleeping
Sleeping
| """ | |
| Fortran code printer | |
| The FCodePrinter converts single SymPy expressions into single Fortran | |
| expressions, using the functions defined in the Fortran 77 standard where | |
| possible. Some useful pointers to Fortran can be found on wikipedia: | |
| https://en.wikipedia.org/wiki/Fortran | |
| Most of the code below is based on the "Professional Programmer\'s Guide to | |
| Fortran77" by Clive G. Page: | |
| https://www.star.le.ac.uk/~cgp/prof77.html | |
| Fortran is a case-insensitive language. This might cause trouble because | |
| SymPy is case sensitive. So, fcode adds underscores to variable names when | |
| it is necessary to make them different for Fortran. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| from collections import defaultdict | |
| from itertools import chain | |
| import string | |
| from sympy.codegen.ast import ( | |
| Assignment, Declaration, Pointer, value_const, | |
| float32, float64, float80, complex64, complex128, int8, int16, int32, | |
| int64, intc, real, integer, bool_, complex_, none, stderr, stdout | |
| ) | |
| from sympy.codegen.fnodes import ( | |
| allocatable, isign, dsign, cmplx, merge, literal_dp, elemental, pure, | |
| intent_in, intent_out, intent_inout | |
| ) | |
| from sympy.core import S, Add, N, Float, Symbol | |
| from sympy.core.function import Function | |
| from sympy.core.numbers import equal_valued | |
| from sympy.core.relational import Eq | |
| from sympy.sets import Range | |
| from sympy.printing.codeprinter import CodePrinter | |
| from sympy.printing.precedence import precedence, PRECEDENCE | |
| from sympy.printing.printer import printer_context | |
| # These are defined in the other file so we can avoid importing sympy.codegen | |
| # from the top-level 'import sympy'. Export them here as well. | |
| from sympy.printing.codeprinter import fcode, print_fcode # noqa:F401 | |
# SymPy function name -> Fortran intrinsic name.  Most intrinsics share the
# SymPy spelling; the ones that differ are listed explicitly afterwards.
# Insertion order matches the historical literal.
known_functions = {
    **{name: name for name in (
        "sin", "cos", "tan", "asin", "acos", "atan", "atan2",
        "sinh", "cosh", "tanh", "log", "exp", "erf",
    )},
    "Abs": "abs",
    "conjugate": "conjg",
    "Max": "max",
    "Min": "min",
}
class FCodePrinter(CodePrinter):
    """A printer to convert SymPy expressions to strings of Fortran code"""

    printmethod = "_fcode"
    language = "Fortran"

    # Generic type tokens and the sized type each one stands for by default.
    type_aliases = {
        integer: int32,
        real: float64,
        complex_: complex128,
    }

    # Sized type tokens and the Fortran type spelling emitted for them.
    type_mappings = {
        intc: 'integer(c_int)',
        float32: 'real*4',  # real(kind(0.e0))
        float64: 'real*8',  # real(kind(0.d0))
        float80: 'real*10',  # real(kind(????))
        complex64: 'complex*8',
        complex128: 'complex*16',
        int8: 'integer*1',
        int16: 'integer*2',
        int32: 'integer*4',
        int64: 'integer*8',
        bool_: 'logical'
    }

    # Types whose spelling needs a name imported from a Fortran module:
    # {type: {module name: imported name}}.
    type_modules = {
        intc: {'iso_c_binding': 'c_int'}
    }

    # Base-class defaults overlaid with the Fortran-specific ones.
    _default_settings: dict[str, Any] = {
        **CodePrinter._default_settings,
        'precision': 17,
        'user_functions': {},
        'source_format': 'fixed',
        'contract': True,
        'standard': 77,
        'name_mangling': True,
    }

    # Fortran spellings of the logical operators.
    _operators = {
        'and': '.and.',
        'or': '.or.',
        'xor': '.neqv.',
        'equivalent': '.eqv.',
        'not': '.not. ',
    }

    # Relational operators whose Fortran spelling differs from SymPy's.
    _relationals = {
        '!=': '/=',
    }
def __init__(self, settings=None):
    """Set up a Fortran printer.

    ``settings`` is an optional dict of printer options.  Its
    ``type_aliases`` and ``type_mappings`` entries (if any) are merged
    over the class-level tables; every remaining key is handed to the
    base-class initializer.  The caller's dict is left unmodified.

    Raises ``ValueError`` when the ``standard`` setting is not one of the
    recognised Fortran standards.
    """
    # Work on a private copy: the two type tables are popped below, and the
    # caller's dictionary must not lose those keys as a side effect.
    settings = dict(settings) if settings else {}

    self.mangled_symbols = {}  # Dict showing mapping of all words
    self.used_name = []
    self.type_aliases = dict(chain(self.type_aliases.items(),
                                   settings.pop('type_aliases', {}).items()))
    self.type_mappings = dict(chain(self.type_mappings.items(),
                                    settings.pop('type_mappings', {}).items()))
    super().__init__(settings)

    # Built-in intrinsic table, with user-supplied functions taking priority.
    self.known_functions = dict(known_functions)
    userfuncs = settings.get('user_functions', {})
    self.known_functions.update(userfuncs)

    standards = {66, 77, 90, 95, 2003, 2008}
    if self._settings['standard'] not in standards:
        raise ValueError("Unknown Fortran standard: %s" % self._settings[
                         'standard'])
    self.module_uses = defaultdict(set)  # e.g.: use iso_c_binding, only: c_int
| def _lead(self): | |
| if self._settings['source_format'] == 'fixed': | |
| return {'code': " ", 'cont': " @ ", 'comment': "C "} | |
| elif self._settings['source_format'] == 'free': | |
| return {'code': "", 'cont': " ", 'comment': "! "} | |
| else: | |
| raise ValueError("Unknown source format: %s" % self._settings['source_format']) | |
def _print_Symbol(self, expr):
    """Print a symbol, mangling its name first when that setting is on.

    With ``name_mangling`` enabled, underscores are appended to a name until
    its lower-cased form is unused, so that symbols differing only in case
    stay distinct.
    """
    if self._settings['name_mangling'] == True:
        if expr not in self.mangled_symbols:
            candidate = expr.name
            while candidate.lower() in self.used_name:
                candidate += '_'
            self.used_name.append(candidate.lower())
            # Keep the original object when no underscore had to be added.
            self.mangled_symbols[expr] = (
                expr if candidate == expr.name else Symbol(candidate))
        expr = expr.xreplace(self.mangled_symbols)
    return super()._print_Symbol(expr)
| def _rate_index_position(self, p): | |
| return -p*5 | |
| def _get_statement(self, codestring): | |
| return codestring | |
| def _get_comment(self, text): | |
| return "! {}".format(text) | |
| def _declare_number_const(self, name, value): | |
| return "parameter ({} = {})".format(name, self._print(value)) | |
def _print_NumberSymbol(self, expr):
    """Register a number symbol with its numeric value and print its name.

    A number symbol without a dedicated printer method is recorded in
    ``self._number_symbols`` together with its value evaluated to the
    configured precision.
    """
    digits = self._settings['precision']
    approximation = Float(expr.evalf(digits))
    self._number_symbols.add((expr, approximation))
    return str(expr)
| def _format_code(self, lines): | |
| return self._wrap_fortran(self.indent_code(lines)) | |
| def _traverse_matrix_indices(self, mat): | |
| rows, cols = mat.shape | |
| return ((i, j) for j in range(cols) for i in range(rows)) | |
| def _get_loop_opening_ending(self, indices): | |
| open_lines = [] | |
| close_lines = [] | |
| for i in indices: | |
| # fortran arrays start at 1 and end at dimension | |
| var, start, stop = map(self._print, | |
| [i.label, i.lower + 1, i.upper + 1]) | |
| open_lines.append("do %s = %s, %s" % (var, start, stop)) | |
| close_lines.append("end do") | |
| return open_lines, close_lines | |
def _print_sign(self, expr):
    """Print ``sign(arg)`` as a ``merge`` node that yields zero at zero.

    Three variants are built depending on the argument's assumptions:
    integer, complex (or infinite), and everything else.
    """
    from sympy.functions.elementary.complexes import Abs
    (arg,) = expr.args
    if arg.is_integer:
        replacement = merge(0, isign(1, arg), Eq(arg, 0))
    elif arg.is_complex or arg.is_infinite:
        complex_zero = cmplx(literal_dp(0), literal_dp(0))
        replacement = merge(complex_zero, arg/Abs(arg),
                            Eq(Abs(arg), literal_dp(0)))
    else:
        replacement = merge(literal_dp(0), dsign(literal_dp(1), arg),
                            Eq(arg, literal_dp(0)))
    return self._print(replacement)
def _print_Piecewise(self, expr):
    """Print a Piecewise as an ``if`` block or as nested ``merge`` calls.

    A Piecewise containing assignments becomes an ``if / else if / else``
    construct.  Otherwise it is printed inline with ``merge``, which needs
    the ``standard`` setting to be at least 95.

    Raises ``ValueError`` when the last branch is not an ``(expr, True)``
    default, and ``NotImplementedError`` for inline use with an older
    standard.
    """
    if expr.args[-1].cond != True:
        # We need the last conditional to be a True, otherwise the resulting
        # function may not return a result.
        raise ValueError("All Piecewise expressions must contain an "
                         "(expr, True) statement to be used as a default "
                         "condition. Without one, the generated "
                         "expression may not evaluate to anything under "
                         "some condition.")

    if expr.has(Assignment):
        final = len(expr.args) - 1
        out = []
        for position, (branch, cond) in enumerate(expr.args):
            if position == 0:
                header = "if (%s) then" % self._print(cond)
            elif position == final and cond == True:
                header = "else"
            else:
                header = "else if (%s) then" % self._print(cond)
            out.append(header)
            out.append(self._print(branch))
        out.append("end if")
        return "\n".join(out)

    if not (self._settings["standard"] >= 95):
        # `merge` is not supported prior to F95
        raise NotImplementedError("Using Piecewise as an expression using "
                                  "inline operators is not supported in "
                                  "standards earlier than Fortran95.")

    # The piecewise was used in an expression, so inline operators are
    # needed.  These do not work for statements that span multiple lines
    # (Matrix or Indexed expressions).
    pattern = "merge({T}, {F}, {COND})"
    code = self._print(expr.args[-1].expr)
    # Wrap from the innermost (last non-default) branch outwards.
    for branch, cond in reversed(list(expr.args[:-1])):
        branch_code = self._print(branch)
        cond_code = self._print(cond)
        code = pattern.format(T=branch_code, F=code, COND=cond_code)
    return code
def _print_MatrixElement(self, expr):
    """Print a matrix element as ``parent(i + 1, j + 1)``."""
    parent_code = self.parenthesize(expr.parent, PRECEDENCE["Atom"],
                                    strict=True)
    # Shift both indices by one.
    row = expr.i + 1
    col = expr.j + 1
    return "{}({}, {})".format(parent_code, row, col)
def _print_Add(self, expr):
    """Print a sum, folding numeric complex parts into one ``cmplx`` call.

    Purely real and purely imaginary numeric terms are collected into
    ``cmplx(re,im)``; remaining terms are appended with an explicit sign.
    Sums without imaginary numeric terms go to the generic printer.
    """
    reals, imaginaries, others = [], [], []
    for term in expr.args:
        if term.is_number and term.is_real:
            reals.append(term)
        elif term.is_number and term.is_imaginary:
            imaginaries.append(term)
        else:
            others.append(term)

    if not imaginaries:
        return CodePrinter._print_Add(self, expr)

    if not others:
        return "cmplx(%s,%s)" % (
            self._print(Add(*reals)),
            self._print(-S.ImaginaryUnit*Add(*imaginaries)),
        )

    outer_prec = precedence(expr)
    rest = Add(*others)
    rest_code = self._print(rest)
    # Lift a leading minus out so it can act as the joining operator.
    sign = "+"
    if rest_code.startswith('-'):
        sign = "-"
        rest_code = rest_code[1:]
    if precedence(rest) < outer_prec:
        rest_code = "(%s)" % rest_code
    return "cmplx(%s,%s) %s %s" % (
        self._print(Add(*reals)),
        self._print(-S.ImaginaryUnit*Add(*imaginaries)),
        sign, rest_code,
    )
def _print_Function(self, expr):
    """Print a function call after numerically evaluating its arguments.

    If rebuilding the call with evaluated arguments no longer gives a
    ``Function``, the result is printed through the normal dispatch instead.
    """
    # All constant function args are evaluated as floats
    digits = self._settings['precision']
    args = [N(a, digits) for a in expr.args]
    eval_expr = expr.func(*args)
    if isinstance(eval_expr, Function):
        return CodePrinter._print_Function(self, expr.func(*args))
    return self._print(eval_expr)
| def _print_Mod(self, expr): | |
| # NOTE : Fortran has the functions mod() and modulo(). modulo() behaves | |
| # the same wrt to the sign of the arguments as Python and SymPy's | |
| # modulus computations (% and Mod()) but is not available in Fortran 66 | |
| # or Fortran 77, thus we raise an error. | |
| if self._settings['standard'] in [66, 77]: | |
| msg = ("Python % operator and SymPy's Mod() function are not " | |
| "supported by Fortran 66 or 77 standards.") | |
| raise NotImplementedError(msg) | |
| else: | |
| x, y = expr.args | |
| return " modulo({}, {})".format(self._print(x), self._print(y)) | |
| def _print_ImaginaryUnit(self, expr): | |
| # purpose: print complex numbers nicely in Fortran. | |
| return "cmplx(0,1)" | |
| def _print_int(self, expr): | |
| return str(expr) | |
def _print_Mul(self, expr):
    """Print a product; purely imaginary numbers become ``cmplx(0,im)``."""
    # purpose: print complex numbers nicely in Fortran.
    if not (expr.is_number and expr.is_imaginary):
        return CodePrinter._print_Mul(self, expr)
    imaginary_part = -S.ImaginaryUnit*expr
    return "cmplx(0,%s)" % (
        self._print(imaginary_part)
    )
def _print_Pow(self, expr):
    """Print a power, special-casing exponents ``-1`` and ``0.5``.

    Exponent ``-1`` is printed as a division and exponent ``0.5`` as
    ``sqrt``.  An integer base under ``sqrt`` is first turned into a
    double-precision value.  Other exponents use the generic printer.
    """
    PREC = precedence(expr)
    if equal_valued(expr.exp, -1):
        numerator = self._print(literal_dp(1))
        return '%s/%s' % (numerator, self.parenthesize(expr.base, PREC))

    if not equal_valued(expr.exp, 0.5):
        return CodePrinter._print_Pow(self, expr)

    base = expr.base
    if not base.is_integer:
        return 'sqrt(%s)' % self._print(base)
    # Fortran intrinsic sqrt() does not accept integer argument
    if base.is_Number:
        template = 'sqrt(%s.0d0)'
    else:
        template = 'sqrt(dble(%s))'
    return template % self._print(base)
| def _print_Rational(self, expr): | |
| p, q = int(expr.p), int(expr.q) | |
| return "%d.0d0/%d.0d0" % (p, q) | |
def _print_Float(self, expr):
    """Print a float with a ``d`` exponent marker.

    The first ``e`` in the generic printer's output is replaced by ``d``;
    when there is none, ``d0`` is appended.
    """
    printed = CodePrinter._print_Float(self, expr)
    mantissa, marker, exponent = printed.partition('e')
    if marker:
        return "%sd%s" % (mantissa, exponent)
    return "%sd0" % printed
| def _print_Relational(self, expr): | |
| lhs_code = self._print(expr.lhs) | |
| rhs_code = self._print(expr.rhs) | |
| op = expr.rel_op | |
| op = op if op not in self._relationals else self._relationals[op] | |
| return "{} {} {}".format(lhs_code, op, rhs_code) | |
| def _print_Indexed(self, expr): | |
| inds = [ self._print(i) for i in expr.indices ] | |
| return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds)) | |
| def _print_Idx(self, expr): | |
| return self._print(expr.label) | |
| def _print_AugmentedAssignment(self, expr): | |
| lhs_code = self._print(expr.lhs) | |
| rhs_code = self._print(expr.rhs) | |
| return self._get_statement("{0} = {0} {1} {2}".format( | |
| self._print(lhs_code), self._print(expr.binop), self._print(rhs_code))) | |
| def _print_sum_(self, sm): | |
| params = self._print(sm.array) | |
| if sm.dim != None: # Must use '!= None', cannot use 'is not None' | |
| params += ', ' + self._print(sm.dim) | |
| if sm.mask != None: # Must use '!= None', cannot use 'is not None' | |
| params += ', mask=' + self._print(sm.mask) | |
| return '%s(%s)' % (sm.__class__.__name__.rstrip('_'), params) | |
| def _print_product_(self, prod): | |
| return self._print_sum_(prod) | |
| def _print_Do(self, do): | |
| excl = ['concurrent'] | |
| if do.step == 1: | |
| excl.append('step') | |
| step = '' | |
| else: | |
| step = ', {step}' | |
| return ( | |
| 'do {concurrent}{counter} = {first}, {last}'+step+'\n' | |
| '{body}\n' | |
| 'end do\n' | |
| ).format( | |
| concurrent='concurrent ' if do.concurrent else '', | |
| **do.kwargs(apply=lambda arg: self._print(arg), exclude=excl) | |
| ) | |
| def _print_ImpliedDoLoop(self, idl): | |
| step = '' if idl.step == 1 else ', {step}' | |
| return ('({expr}, {counter} = {first}, {last}'+step+')').format( | |
| **idl.kwargs(apply=lambda arg: self._print(arg)) | |
| ) | |
def _print_For(self, expr):
    """Print a ``For`` node over a ``Range`` as a ``do`` loop.

    Raises ``NotImplementedError`` for any iterable other than ``Range``.
    """
    target = self._print(expr.target)
    if not isinstance(expr.iterable, Range):
        raise NotImplementedError("Only iterable currently supported is Range")
    start, stop, step = expr.iterable.args
    body = self._print(expr.body)
    template = ('do {target} = {start}, {stop}, {step}\n'
                '{body}\n'
                'end do')
    # The upper bound is printed as ``stop - 1``.
    # NOTE(review): that looks right only for a positive step -- confirm.
    return template.format(target=target, start=start, stop=stop - 1,
                           step=step, body=body)
| def _print_Type(self, type_): | |
| type_ = self.type_aliases.get(type_, type_) | |
| type_str = self.type_mappings.get(type_, type_.name) | |
| module_uses = self.type_modules.get(type_) | |
| if module_uses: | |
| for k, v in module_uses: | |
| self.module_uses[k].add(v) | |
| return type_str | |
| def _print_Element(self, elem): | |
| return '{symbol}({idxs})'.format( | |
| symbol=self._print(elem.symbol), | |
| idxs=', '.join((self._print(arg) for arg in elem.indices)) | |
| ) | |
| def _print_Extent(self, ext): | |
| return str(ext) | |
def _print_Declaration(self, expr):
    """Print a variable declaration.

    With ``standard`` >= 90 the attribute form
    ``type[, parameter][, dimension(...)][, intent(...)][, allocatable] :: name``
    is emitted, followed by `` = value`` when a value is set.  For older
    standards only ``type name`` is emitted.

    Raises ``ValueError`` when more than one intent attribute is present,
    and ``NotImplementedError`` for pointers or for constant/initialised
    variables under the older standards.
    """
    var = expr.variable
    val = var.value
    dim = var.attr_params('dimension')
    intents = [intent in var.attrs for intent in (intent_in, intent_out, intent_inout)]
    if intents.count(True) == 0:
        intent = ''
    elif intents.count(True) == 1:
        intent = ', intent(%s)' % ['in', 'out', 'inout'][intents.index(True)]
    else:
        # Name the offending variable in the message, not the printer.
        raise ValueError("Multiple intents specified for %s" % (var,))

    if isinstance(var, Pointer):
        raise NotImplementedError("Pointers are not available by default in Fortran.")
    if self._settings["standard"] >= 90:
        result = '{t}{vc}{dim}{intent}{alloc} :: {s}'.format(
            t=self._print(var.type),
            vc=', parameter' if value_const in var.attrs else '',
            dim=', dimension(%s)' % ', '.join((self._print(arg) for arg in dim)) if dim else '',
            intent=intent,
            alloc=', allocatable' if allocatable in var.attrs else '',
            s=self._print(var.symbol)
        )
        if val != None:  # Must be "!= None", cannot be "is not None"
            result += ' = %s' % self._print(val)
    else:
        if value_const in var.attrs or val:
            raise NotImplementedError("F77 init./parameter statem. req. multiple lines.")
        result = ' '.join((self._print(arg) for arg in [var.type, var.symbol]))
    return result
def _print_Infinity(self, expr):
    """Print infinity as ``(huge(<zero literal>) + 1)``."""
    zero_code = self._print(literal_dp(0))
    return '(huge(%s) + 1)' % zero_code
| def _print_While(self, expr): | |
| return 'do while ({condition})\n{body}\nend do'.format(**expr.kwargs( | |
| apply=lambda arg: self._print(arg))) | |
| def _print_BooleanTrue(self, expr): | |
| return '.true.' | |
| def _print_BooleanFalse(self, expr): | |
| return '.false.' | |
| def _pad_leading_columns(self, lines): | |
| result = [] | |
| for line in lines: | |
| if line.startswith('!'): | |
| result.append(self._lead['comment'] + line[1:].lstrip()) | |
| else: | |
| result.append(self._lead['code'] + line) | |
| return result | |
def _wrap_fortran(self, lines):
    """Wrap long Fortran lines

    Argument:
        lines  --  a list of lines (without \\n character)

    Returns a new list in which over-long lines have been split into
    several lines.

    A comment line is split at white space. Code lines are split with a more
    complex rule to give nice results.
    """
    # routine to find split point in a code line
    my_alnum = set("_+-." + string.digits + string.ascii_letters)
    my_white = set(" \t()")

    def split_pos_code(line, endpos):
        # Return the index at which ``line`` should be cut so that the first
        # piece is no longer than ``endpos`` characters.
        if len(line) <= endpos:
            return len(line)
        pos = endpos
        # A position is a valid cut when the character class changes there:
        # between "alnum" and non-"alnum", or between "white" and
        # non-"white", as defined by the two sets above.
        split = lambda pos: \
            (line[pos] in my_alnum and line[pos - 1] not in my_alnum) or \
            (line[pos] not in my_alnum and line[pos - 1] in my_alnum) or \
            (line[pos] in my_white and line[pos - 1] not in my_white) or \
            (line[pos] not in my_white and line[pos - 1] in my_white)
        # Walk backwards to the nearest valid cut; if none is found before
        # the start of the line, cut at ``endpos``.
        while not split(pos):
            pos -= 1
            if pos == 0:
                return endpos
        return pos

    # split line by line and add the split lines to result
    result = []
    # In free format every wrapped piece except the last gets a trailing
    # ' &'; in the other format nothing is appended.
    if self._settings['source_format'] == 'free':
        trailing = ' &'
    else:
        trailing = ''
    for line in lines:
        if line.startswith(self._lead['comment']):
            # comment line
            if len(line) > 72:
                # First piece: cut at the last space between index 6 and
                # 72, or at 72 when there is no such space.
                pos = line.rfind(" ", 6, 72)
                if pos == -1:
                    pos = 72
                hunk = line[:pos]
                line = line[pos:].lstrip()
                result.append(hunk)
                # Remaining pieces are at most 66 characters each and get
                # the comment lead put in front again.
                while line:
                    pos = line.rfind(" ", 0, 66)
                    if pos == -1 or len(line) < 66:
                        pos = 66
                    hunk = line[:pos]
                    line = line[pos:].lstrip()
                    result.append("%s%s" % (self._lead['comment'], hunk))
            else:
                result.append(line)
        elif line.startswith(self._lead['code']):
            # code line
            # First piece is limited to 72 characters.
            pos = split_pos_code(line, 72)
            hunk = line[:pos].rstrip()
            line = line[pos:].lstrip()
            if line:
                hunk += trailing
            result.append(hunk)
            # Following pieces are limited to 65 characters and are
            # prefixed with the continuation lead.
            while line:
                pos = split_pos_code(line, 65)
                hunk = line[:pos].rstrip()
                line = line[pos:].lstrip()
                if line:
                    hunk += trailing
                result.append("%s%s" % (self._lead['cont'], hunk))
        else:
            # Neither a comment nor a code line: keep it unchanged.
            result.append(line)
    return result
def indent_code(self, code):
    """Accepts a string of code or a list of code lines"""
    if isinstance(code, str):
        # Indent the individual lines, then glue them back together.
        return ''.join(self.indent_code(code.splitlines(True)))

    is_free = self._settings['source_format'] == 'free'
    stripped = [raw.lstrip(' \t') for raw in code]

    # Line prefixes that open / close an indentation level.  'else' is in
    # both: it closes the previous level and opens a new one.
    openers = ('do ', 'if(', 'if ', 'do\n', 'else', 'program', 'interface')
    closers = ('end do', 'enddo', 'end if', 'endif', 'else', 'end program', 'end interface')
    continuations = ('&', '&\n')

    tabwidth = 3
    level = 0
    cont_padding = 0
    indented = []
    for text in stripped:
        if text in ('', '\n'):
            indented.append(text)
            continue

        if text.startswith(closers):
            level -= 1

        if is_free:
            padding = " "*(level*tabwidth + cont_padding)
        else:
            padding = " "*level*tabwidth
        padded = "%s%s" % (padding, text)
        if not is_free:
            padded = self._pad_leading_columns([padded])[0]
        indented.append(padded)

        # A line ending in '&' gives the next line extra padding, which is
        # only applied in free format.
        cont_padding = 2*tabwidth if text.endswith(continuations) else 0
        if text.startswith(openers):
            level += 1

    if is_free:
        return indented
    return self._wrap_fortran(indented)
| def _print_GoTo(self, goto): | |
| if goto.expr: # computed goto | |
| return "go to ({labels}), {expr}".format( | |
| labels=', '.join((self._print(arg) for arg in goto.labels)), | |
| expr=self._print(goto.expr) | |
| ) | |
| else: | |
| lbl, = goto.labels | |
| return "go to %s" % self._print(lbl) | |
| def _print_Program(self, prog): | |
| return ( | |
| "program {name}\n" | |
| "{body}\n" | |
| "end program\n" | |
| ).format(**prog.kwargs(apply=lambda arg: self._print(arg))) | |
| def _print_Module(self, mod): | |
| return ( | |
| "module {name}\n" | |
| "{declarations}\n" | |
| "\ncontains\n\n" | |
| "{definitions}\n" | |
| "end module\n" | |
| ).format(**mod.kwargs(apply=lambda arg: self._print(arg))) | |
| def _print_Stream(self, strm): | |
| if strm.name == 'stdout' and self._settings["standard"] >= 2003: | |
| self.module_uses['iso_c_binding'].add('stdint=>input_unit') | |
| return 'input_unit' | |
| elif strm.name == 'stderr' and self._settings["standard"] >= 2003: | |
| self.module_uses['iso_c_binding'].add('stdint=>error_unit') | |
| return 'error_unit' | |
| else: | |
| if strm.name == 'stdout': | |
| return '*' | |
| else: | |
| return strm.name | |
def _print_Print(self, ps):
    """Print an output statement.

    Without a format string a list-directed ``print *, ...`` is produced;
    otherwise a non-advancing ``write`` to unit 0 (``stderr``), 6
    (``stdout``) or ``*`` (any other file).
    """
    # Equality test against the ``none`` token (not an identity test).
    if ps.format_string == none:
        template = "print {fmt}, {iolist}"
        fmt = '*'
    else:
        unit = {stderr: '0', stdout: '6'}.get(ps.file, '*')
        template = 'write(%(out)s, fmt="{fmt}", advance="no"), {iolist}' % {
            'out': unit
        }
        fmt = self._print(ps.format_string)
    iolist = ', '.join((self._print(arg) for arg in ps.print_args))
    return template.format(fmt=fmt, iolist=iolist)
| def _print_Return(self, rs): | |
| arg, = rs.args | |
| return "{result_name} = {arg}".format( | |
| result_name=self._context.get('result_name', 'sympy_result'), | |
| arg=self._print(arg) | |
| ) | |
| def _print_FortranReturn(self, frs): | |
| arg, = frs.args | |
| if arg: | |
| return 'return %s' % self._print(arg) | |
| else: | |
| return 'return' | |
def _head(self, entity, fp, **kwargs):
    """Build the header line of a procedure plus its argument declarations.

    ``entity`` is the leading text (for example ``'subroutine '``) and
    ``fp`` supplies the name, the parameters and an optional ``bind_C``
    attribute.  A ``result(...)`` clause is added when the ``result_name``
    setting is present and truthy.
    """
    bind_C_params = fp.attr_params('bind_C')
    if bind_C_params is None:
        bind = ''
    elif bind_C_params:
        bind = ' bind(C, name="%s")' % bind_C_params[0]
    else:
        bind = ' bind(C)'

    result_name = self._settings.get('result_name', None)
    name_code = self._print(fp.name)
    arg_names = ', '.join([self._print(arg.symbol) for arg in fp.parameters])
    result_clause = (' result(%s)' % result_name) if result_name else ''
    arg_declarations = '\n'.join(
        (self._print(Declaration(arg)) for arg in fp.parameters))

    template = (
        "{entity}{name}({arg_names}){result}{bind}\n"
        "{arg_declarations}"
    )
    return template.format(
        entity=entity,
        name=name_code,
        arg_names=arg_names,
        result=result_clause,
        bind=bind,
        arg_declarations=arg_declarations,
    )
| def _print_FunctionPrototype(self, fp): | |
| entity = "{} function ".format(self._print(fp.return_type)) | |
| return ( | |
| "interface\n" | |
| "{function_head}\n" | |
| "end function\n" | |
| "end interface" | |
| ).format(function_head=self._head(entity, fp)) | |
def _print_FunctionDefinition(self, fd):
    """Print a complete function definition.

    The ``elemental`` attribute takes priority over ``pure`` for the
    prefix.  Head and body are printed inside a printer context whose
    ``result_name`` is the function's name.
    """
    prefix = ''
    if elemental in fd.attrs:
        prefix = 'elemental '
    elif pure in fd.attrs:
        prefix = 'pure '

    entity = "{} function ".format(self._print(fd.return_type))
    template = (
        "{prefix}{function_head}\n"
        "{body}\n"
        "end function\n"
    )
    with printer_context(self, result_name=fd.name):
        return template.format(
            prefix=prefix,
            function_head=self._head(entity, fd),
            body=self._print(fd.body),
        )
| def _print_Subroutine(self, sub): | |
| return ( | |
| '{subroutine_head}\n' | |
| '{body}\n' | |
| 'end subroutine\n' | |
| ).format( | |
| subroutine_head=self._head('subroutine ', sub), | |
| body=self._print(sub.body) | |
| ) | |
| def _print_SubroutineCall(self, scall): | |
| return 'call {name}({args})'.format( | |
| name=self._print(scall.name), | |
| args=', '.join((self._print(arg) for arg in scall.subroutine_args)) | |
| ) | |
| def _print_use_rename(self, rnm): | |
| return "%s => %s" % tuple((self._print(arg) for arg in rnm.args)) | |
| def _print_use(self, use): | |
| result = 'use %s' % self._print(use.namespace) | |
| if use.rename != None: # Must be '!= None', cannot be 'is not None' | |
| result += ', ' + ', '.join([self._print(rnm) for rnm in use.rename]) | |
| if use.only != None: # Must be '!= None', cannot be 'is not None' | |
| result += ', only: ' + ', '.join([self._print(nly) for nly in use.only]) | |
| return result | |
| def _print_BreakToken(self, _): | |
| return 'exit' | |
| def _print_ContinueToken(self, _): | |
| return 'cycle' | |
| def _print_ArrayConstructor(self, ac): | |
| fmtstr = "[%s]" if self._settings["standard"] >= 2003 else '(/%s/)' | |
| return fmtstr % ', '.join((self._print(arg) for arg in ac.elements)) | |
| def _print_ArrayElement(self, elem): | |
| return '{symbol}({idxs})'.format( | |
| symbol=self._print(elem.name), | |
| idxs=', '.join((self._print(arg) for arg in elem.indices)) | |
| ) | |