Skip to content

Commit c515253

Browse files
committed
compiler: make printer part of the target and differentiate C and CXX
1 parent 461fd43 commit c515253

10 files changed

Lines changed: 121 additions & 72 deletions

File tree

devito/core/cpu.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
99
factorize, fission, fuse, optimize_pows,
1010
optimize_hyperplanes)
11-
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize,
11+
from devito.passes.iet import (CTarget, CXXTarget, COmpTarget, CXXOmpTarget,
12+
avoid_denormals, linearize,
1213
mpiize, hoist_prodders, relax_incr_dimensions,
1314
check_stability)
1415
from devito.tools import timed_pass
@@ -244,7 +245,7 @@ def _normalize_kwargs(cls, **kwargs):
244245

245246
class Cpu64CustomOperator(Cpu64OperatorMixin, CustomOperator):
246247

247-
_Target = OmpTarget
248+
_Target = COmpTarget
248249

249250
@classmethod
250251
def _make_dsl_passes_mapper(cls, **kwargs):
@@ -325,20 +326,20 @@ class Cpu64NoopCOperator(Cpu64NoopOperator):
325326

326327

327328
class Cpu64NoopOmpOperator(Cpu64NoopOperator):
328-
_Target = OmpTarget
329+
_Target = COmpTarget
329330

330331

331332
class Cpu64AdvCOperator(Cpu64AdvOperator):
332333
_Target = CTarget
333334

334335

335336
class Cpu64AdvOmpOperator(Cpu64AdvOperator):
336-
_Target = OmpTarget
337+
_Target = COmpTarget
337338

338339

339340
class Cpu64FsgCOperator(Cpu64FsgOperator):
340341
_Target = CTarget
341342

342343

343344
class Cpu64FsgOmpOperator(Cpu64FsgOperator):
344-
_Target = OmpTarget
345+
_Target = COmpTarget

devito/ir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from devito.ir.equations import * # noqa
33
from devito.ir.clusters import * # noqa
44
from devito.ir.iet import * # noqa
5+
from devito.ir.printer import * # noqa

devito/ir/cgen/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from devito.ir.cgen.printer import * # noqa
Lines changed: 6 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,22 @@
1010
from sympy.core import S
1111
from sympy.core.numbers import equal_valued, Float
1212
from sympy.printing.codeprinter import CodePrinter
13-
from sympy.printing.c import C99CodePrinter
14-
from sympy.printing.cxx import CXX11CodePrinter
1513
from sympy.logic.boolalg import BooleanFunction
1614
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
1715

1816
from devito import configuration
1917
from devito.arch.compiler import AOMPCompiler
2018
from devito.symbolics.inspection import has_integer_args, sympy_dtype
21-
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
2219
from devito.types.basic import AbstractFunction
2320
from devito.tools import ctypes_to_cstr, dtype_to_ctype
2421

25-
__all__ = ['ccode']
22+
__all__ = ['BasePrinter', 'printer_registry', 'ccode']
2623

2724

2825
_prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'}
2926

3027

31-
class _DevitoPrinterBase(CodePrinter):
28+
class BasePrinter(CodePrinter):
3229

3330
"""
3431
Decorator for sympy.printing.ccode.CCodePrinter.
@@ -366,7 +363,7 @@ def _print_Fallback(self, expr):
366363

367364
# Lifted from SymPy so that we go through our own `_print_math_func`
368365
for k in ('exp log sin cos tan ceiling floor').split():
369-
setattr(_DevitoPrinterBase, '_print_%s' % k, _DevitoPrinterBase._print_math_func)
366+
setattr(BasePrinter, '_print_%s' % k, BasePrinter._print_math_func)
370367

371368

372369
# Always parenthesize IntDiv and InlineIf within expressions
@@ -377,53 +374,10 @@ def _print_Fallback(self, expr):
377374
# Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here
378375
# to always use the correct one from our printer
379376
if Version(sympy.__version__) >= Version("1.11"):
380-
setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add)
377+
setattr(sympy.printing.str.StrPrinter, '_print_Add', BasePrinter._print_Add)
381378

382379

383-
class CDevitoPrinter(_DevitoPrinterBase, C99CodePrinter):
384-
385-
_default_settings = {**_DevitoPrinterBase._default_settings,
386-
**C99CodePrinter._default_settings}
387-
_func_litterals = {np.float32: 'f', np.complex64: 'f'}
388-
_func_prefix = {np.float32: 'f', np.float64: 'f',
389-
np.complex64: 'c', np.complex128: 'c'}
390-
391-
# These cannot go through _print_xxx because they are classes not
392-
# instances
393-
type_mappings = {**C99CodePrinter.type_mappings,
394-
c_complex: 'float _Complex',
395-
c_double_complex: 'double _Complex'}
396-
397-
def _print_ImaginaryUnit(self, expr):
398-
return '_Complex_I'
399-
400-
401-
class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter):
402-
403-
_default_settings = {**_DevitoPrinterBase._default_settings,
404-
**CXX11CodePrinter._default_settings}
405-
_ns = "std::"
406-
_func_litterals = {}
407-
_func_prefix = {np.float32: 'f', np.float64: 'f'}
408-
409-
# These cannot go through _print_xxx because they are classes not
410-
# instances
411-
type_mappings = {**CXX11CodePrinter.type_mappings,
412-
c_complex: 'std::complex<float>',
413-
c_double_complex: 'std::complex<double>'}
414-
415-
def _print_ImaginaryUnit(self, expr):
416-
return f'1i{self.prec_literal(expr).lower()}'
417-
418-
419-
class AccDevitoPrinter(CXXDevitoPrinter):
420-
421-
pass
422-
423-
424-
printer_registry: dict[str, type[_DevitoPrinterBase]] = {
425-
'C': CDevitoPrinter, 'CXX': CXXDevitoPrinter,
426-
'openmp': CDevitoPrinter, 'openacc': AccDevitoPrinter}
380+
printer_registry: dict[str, type[BasePrinter]] = {'default': BasePrinter}
427381

428382

429383
def ccode(expr, language=None, **settings):
@@ -443,8 +397,5 @@ def ccode(expr, language=None, **settings):
443397
the input ``expr`` itself.
444398
"""
445399
lang = language or configuration['language']
446-
cpp = settings.get('compiler', configuration['compiler'])._cpp
447-
if lang in ['C', 'openmp'] and cpp:
448-
lang = 'CXX'
449-
printer = printer_registry.get(lang, CDevitoPrinter)
400+
printer = printer_registry.get(lang, BasePrinter)
450401
return printer(settings=settings).doprint(expr, None)

devito/operator/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class OperatorRegistry(OrderedDict, metaclass=Singleton):
2626
"""
2727

2828
_modes = ('noop', 'advanced', 'advanced-fsg')
29-
_languages = ('C', 'openmp', 'openacc', 'cuda', 'hip', 'sycl')
29+
_languages = ('C', 'CXX', 'openmp', 'openacc', 'cuda', 'hip', 'sycl')
3030
_accepted = _modes + tuple(product(_modes, _languages))
3131

3232
def add(self, operator, platform, mode, language='C'):

devito/passes/iet/languages/C.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from devito.ir import Call
1+
import numpy as np
2+
from sympy.printing.c import C99CodePrinter
3+
4+
from devito.ir import Call, BasePrinter, printer_registry
25
from devito.passes.iet.definitions import DataManager
36
from devito.passes.iet.orchestration import Orchestrator
47
from devito.passes.iet.langbase import LangBB
8+
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
59

610
__all__ = ['CBB', 'CDataManager', 'COrchestrator']
711

@@ -31,3 +35,26 @@ class CDataManager(DataManager):
3135

3236
class COrchestrator(Orchestrator):
3337
lang = CBB
38+
39+
40+
class CPrinter(BasePrinter, C99CodePrinter):
41+
42+
_default_settings = {**BasePrinter._default_settings,
43+
**C99CodePrinter._default_settings}
44+
_func_litterals = {np.float32: 'f', np.complex64: 'f'}
45+
_func_prefix = {np.float32: 'f', np.float64: 'f',
46+
np.complex64: 'c', np.complex128: 'c'}
47+
48+
# These cannot go through _print_xxx because they are classes not
49+
# instances
50+
type_mappings = {**C99CodePrinter.type_mappings,
51+
c_complex: 'float _Complex',
52+
c_double_complex: 'double _Complex'}
53+
54+
def _print_ImaginaryUnit(self, expr):
55+
return '_Complex_I'
56+
57+
58+
printer_registry['C'] = CPrinter
59+
printer_registry['openmp'] = CPrinter
60+
printer_registry['Copenmp'] = CPrinter

devito/passes/iet/languages/CXX.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from devito.ir import Call, UsingNamespace
1+
import numpy as np
2+
from sympy.printing.cxx import CXX11CodePrinter
3+
4+
from devito.ir import Call, UsingNamespace, BasePrinter, printer_registry
25
from devito.passes.iet.langbase import LangBB
6+
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
37

48
__all__ = ['CXXBB']
59

@@ -61,3 +65,25 @@ class CXXBB(LangBB):
6165
'complex-namespace': [UsingNamespace('std::complex_literals')],
6266
'def-complex': std_arith,
6367
}
68+
69+
70+
class CXXPrinter(BasePrinter, CXX11CodePrinter):
71+
72+
_default_settings = {**BasePrinter._default_settings,
73+
**CXX11CodePrinter._default_settings}
74+
_ns = "std::"
75+
_func_litterals = {}
76+
_func_prefix = {np.float32: 'f', np.float64: 'f'}
77+
78+
# These cannot go through _print_xxx because they are classes not
79+
# instances
80+
type_mappings = {**CXX11CodePrinter.type_mappings,
81+
c_complex: 'std::complex<float>',
82+
c_double_complex: 'std::complex<double>'}
83+
84+
def _print_ImaginaryUnit(self, expr):
85+
return f'1i{self.prec_literal(expr).lower()}'
86+
87+
88+
printer_registry['CXX'] = CXXPrinter
89+
printer_registry['CXXopenmp'] = CXXPrinter

devito/passes/iet/languages/openacc.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from devito.arch import AMDGPUX, NVIDIAX
44
from devito.ir import (Call, DeviceCall, DummyExpr, EntryFunction, List, Block,
5-
ParallelTree, Pragma, Return, FindSymbols, make_callable)
5+
ParallelTree, Pragma, Return, FindSymbols, make_callable,
6+
printer_registry)
67
from devito.passes import needs_transfer, is_on_device
78
from devito.passes.iet.definitions import DeviceAwareDataManager
89
from devito.passes.iet.engine import iet_pass
910
from devito.passes.iet.orchestration import Orchestrator
1011
from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB,
1112
PragmaIteration, PragmaTransfer)
12-
from devito.passes.iet.languages.CXX import CXXBB
13+
from devito.passes.iet.languages.CXX import CXXBB, CXXPrinter
1314
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
1415
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
1516
from devito.tools import filter_ordered, UnboundTuple
@@ -263,3 +264,11 @@ def place_devptr(self, iet, **kwargs):
263264

264265
class AccOrchestrator(Orchestrator):
265266
lang = AccBB
267+
268+
269+
class AccPrinter(CXXPrinter):
270+
271+
pass
272+
273+
274+
printer_registry['openacc'] = AccPrinter

devito/passes/iet/languages/targets.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from devito.passes.iet.languages.C import CDataManager, COrchestrator
1+
from devito.passes.iet.languages.C import CDataManager, COrchestrator, CPrinter
2+
from devito.passes.iet.languages.CXX import CXXPrinter
23
from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer,
34
OmpDataManager, DeviceOmpDataManager,
45
OmpOrchestrator, DeviceOmpOrchestrator)
56
from devito.passes.iet.languages.openacc import (DeviceAccizer, DeviceAccDataManager,
6-
AccOrchestrator)
7+
AccOrchestrator, AccPrinter)
78
from devito.passes.iet.instrument import instrument
89

910
__all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget']
@@ -13,6 +14,7 @@ class Target:
1314
Parizer = None
1415
DataManager = None
1516
Orchestrator = None
17+
Printer = None
1618

1719
@classmethod
1820
def lang(cls):
@@ -27,21 +29,52 @@ class CTarget(Target):
2729
Parizer = SimdOmpizer
2830
DataManager = CDataManager
2931
Orchestrator = COrchestrator
32+
Printer = CPrinter
3033

3134

32-
class OmpTarget(Target):
35+
class CXXTarget(Target):
36+
Parizer = SimdOmpizer
37+
DataManager = CDataManager
38+
Orchestrator = COrchestrator
39+
Printer = CXXPrinter
40+
41+
42+
class COmpTarget(Target):
3343
Parizer = Ompizer
3444
DataManager = OmpDataManager
3545
Orchestrator = OmpOrchestrator
46+
Printer = CPrinter
47+
48+
49+
OmpTarget = COmpTarget
50+
51+
52+
class CXXOmpTarget(Target):
53+
Parizer = Ompizer
54+
DataManager = OmpDataManager
55+
Orchestrator = OmpOrchestrator
56+
Printer = CXXPrinter
57+
58+
59+
class DeviceCOmpTarget(Target):
60+
Parizer = DeviceOmpizer
61+
DataManager = DeviceOmpDataManager
62+
Orchestrator = DeviceOmpOrchestrator
63+
Printer = CPrinter
64+
65+
66+
DeviceOmpTarget = DeviceCOmpTarget
3667

3768

38-
class DeviceOmpTarget(Target):
69+
class DeviceCXXOmpTarget(Target):
3970
Parizer = DeviceOmpizer
4071
DataManager = DeviceOmpDataManager
4172
Orchestrator = DeviceOmpOrchestrator
73+
Printer = CXXPrinter
4274

4375

4476
class DeviceAccTarget(Target):
4577
Parizer = DeviceAccizer
4678
DataManager = DeviceAccDataManager
4779
Orchestrator = AccOrchestrator
80+
Printer = AccPrinter

tests/test_dtypes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.passes.iet.languages.openacc import AccBB
99
from devito.passes.iet.languages.openmp import OmpBB
1010
from devito.symbolics.extended_dtypes import ctypes_vector_mapper
11-
from devito.symbolics.printer import printer_registry, _DevitoPrinterBase
11+
from devito.symbolics.printer import printer_registry, BasePrinter
1212
from devito.types.basic import Basic, Scalar, Symbol
1313
from devito.types.dense import TimeFunction
1414

@@ -27,7 +27,7 @@ def _get_language(language: str, **_) -> type[LangBB]:
2727
return _languages[language]
2828

2929

30-
def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]:
30+
def _get_printer(language: str, **_) -> type[BasePrinter]:
3131
"""
3232
Gets the printer building block type from parametrized kwargs.
3333
"""
@@ -91,7 +91,7 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None
9191
the generated code.
9292
"""
9393
# Retrieve the language-specific type mapping
94-
printer: type[_DevitoPrinterBase] = _get_printer(**kwargs)
94+
printer: type[BasePrinter] = _get_printer(**kwargs)
9595

9696
# Set up an operator
9797
grid = Grid(shape=(3, 3))

0 commit comments

Comments
 (0)