1010from sympy .core import S
1111from sympy .core .numbers import equal_valued , Float
1212from sympy .printing .codeprinter import CodePrinter
13- from sympy .printing .c import C99CodePrinter
14- from sympy .printing .cxx import CXX11CodePrinter
1513from sympy .logic .boolalg import BooleanFunction
1614from sympy .printing .precedence import PRECEDENCE_VALUES , precedence
1715
1816from devito import configuration
1917from devito .arch .compiler import AOMPCompiler
2018from devito .symbolics .inspection import has_integer_args , sympy_dtype
21- from devito .symbolics .extended_dtypes import c_complex , c_double_complex
2219from devito .types .basic import AbstractFunction
2320from devito .tools import ctypes_to_cstr , dtype_to_ctype
2421
25- __all__ = ['ccode' ]
22+ __all__ = ['BasePrinter' , 'printer_registry' , ' ccode' ]
2623
2724
2825_prec_litterals = {np .float16 : 'F16' , np .float32 : 'F' , np .complex64 : 'F' }
2926
3027
31- class _DevitoPrinterBase (CodePrinter ):
28+ class BasePrinter (CodePrinter ):
3229
3330 """
3431 Decorator for sympy.printing.ccode.CCodePrinter.
@@ -366,7 +363,7 @@ def _print_Fallback(self, expr):
366363
367364# Lifted from SymPy so that we go through our own `_print_math_func`
368365for k in ('exp log sin cos tan ceiling floor' ).split ():
369- setattr (_DevitoPrinterBase , '_print_%s' % k , _DevitoPrinterBase ._print_math_func )
366+ setattr (BasePrinter , '_print_%s' % k , BasePrinter ._print_math_func )
370367
371368
372369# Always parenthesize IntDiv and InlineIf within expressions
@@ -377,53 +374,10 @@ def _print_Fallback(self, expr):
377374# Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here
378375# to always use the correct one from our printer
379376if Version (sympy .__version__ ) >= Version ("1.11" ):
380- setattr (sympy .printing .str .StrPrinter , '_print_Add' , _DevitoPrinterBase ._print_Add )
377+ setattr (sympy .printing .str .StrPrinter , '_print_Add' , BasePrinter ._print_Add )
381378
382379
383- class CDevitoPrinter (_DevitoPrinterBase , C99CodePrinter ):
384-
385- _default_settings = {** _DevitoPrinterBase ._default_settings ,
386- ** C99CodePrinter ._default_settings }
387- _func_litterals = {np .float32 : 'f' , np .complex64 : 'f' }
388- _func_prefix = {np .float32 : 'f' , np .float64 : 'f' ,
389- np .complex64 : 'c' , np .complex128 : 'c' }
390-
391- # These cannot go through _print_xxx because they are classes not
392- # instances
393- type_mappings = {** C99CodePrinter .type_mappings ,
394- c_complex : 'float _Complex' ,
395- c_double_complex : 'double _Complex' }
396-
397- def _print_ImaginaryUnit (self , expr ):
398- return '_Complex_I'
399-
400-
401- class CXXDevitoPrinter (_DevitoPrinterBase , CXX11CodePrinter ):
402-
403- _default_settings = {** _DevitoPrinterBase ._default_settings ,
404- ** CXX11CodePrinter ._default_settings }
405- _ns = "std::"
406- _func_litterals = {}
407- _func_prefix = {np .float32 : 'f' , np .float64 : 'f' }
408-
409- # These cannot go through _print_xxx because they are classes not
410- # instances
411- type_mappings = {** CXX11CodePrinter .type_mappings ,
412- c_complex : 'std::complex<float>' ,
413- c_double_complex : 'std::complex<double>' }
414-
415- def _print_ImaginaryUnit (self , expr ):
416- return f'1i{ self .prec_literal (expr ).lower ()} '
417-
418-
419- class AccDevitoPrinter (CXXDevitoPrinter ):
420-
421- pass
422-
423-
424- printer_registry : dict [str , type [_DevitoPrinterBase ]] = {
425- 'C' : CDevitoPrinter , 'CXX' : CXXDevitoPrinter ,
426- 'openmp' : CDevitoPrinter , 'openacc' : AccDevitoPrinter }
380+ printer_registry : dict [str , type [BasePrinter ]] = {'default' : BasePrinter }
427381
428382
429383def ccode (expr , language = None , ** settings ):
@@ -443,8 +397,5 @@ def ccode(expr, language=None, **settings):
443397 the input ``expr`` itself.
444398 """
445399 lang = language or configuration ['language' ]
446- cpp = settings .get ('compiler' , configuration ['compiler' ])._cpp
447- if lang in ['C' , 'openmp' ] and cpp :
448- lang = 'CXX'
449- printer = printer_registry .get (lang , CDevitoPrinter )
400+ printer = printer_registry .get (lang , BasePrinter )
450401 return printer (settings = settings ).doprint (expr , None )
0 commit comments