Skip to content

Commit 58e6310

Browse files
committed
compiler: make visitor language parametric
1 parent 3e9e931 commit 58e6310

8 files changed

Lines changed: 64 additions & 53 deletions

File tree

devito/ir/iet/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def writes(self):
153153
return ()
154154

155155
def _signature_items(self):
156-
return (str(self.ccode),)
156+
return (str(self),)
157157

158158

159159
class ExprStmt:

devito/ir/iet/visitors.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,9 @@ class CGen(Visitor):
176176
Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
177177
"""
178178

179-
def __init__(self, *args, **kwargs):
179+
def __init__(self, *args, language=None, **kwargs):
180180
super().__init__(*args, **kwargs)
181+
self.language = language
181182

182183
# The following mappers may be customized by subclasses (that is,
183184
# backend-specific CGen-erators)
@@ -189,6 +190,9 @@ def __init__(self, *args, **kwargs):
189190
}
190191
_restrict_keyword = 'restrict'
191192

193+
def ccode(self, expr, **kwargs):
194+
return ccode(expr, language=self.language, **kwargs)
195+
192196
def _gen_struct_decl(self, obj, masked=()):
193197
"""
194198
Convert ctypes.Struct -> cgen.Structure.
@@ -222,7 +226,7 @@ def _gen_struct_decl(self, obj, masked=()):
222226
try:
223227
entries.append(self._gen_value(i, 0, masked=('const',)))
224228
except AttributeError:
225-
cstr = ccode(ct)
229+
cstr = self.ccode(ct)
226230
if ct is c_restrict_void_p:
227231
cstr = '%srestrict' % cstr
228232
entries.append(c.Value(cstr, n))
@@ -244,10 +248,10 @@ def _gen_value(self, obj, mode=1, masked=()):
244248
if getattr(obj.function, k, False) and v not in masked]
245249

246250
if (obj._mem_stack or obj._mem_constant) and mode == 1:
247-
strtype = ccode(obj._C_typedata)
248-
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
251+
strtype = self.ccode(obj._C_typedata)
252+
strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape)
249253
else:
250-
strtype = ccode(obj._C_ctype)
254+
strtype = self.ccode(obj._C_ctype)
251255
strshape = ''
252256
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
253257
if not obj._mem_stack:
@@ -261,7 +265,7 @@ def _gen_value(self, obj, mode=1, masked=()):
261265
strobj = '%s%s' % (strname, strshape)
262266

263267
if obj.is_LocalObject and obj.cargs and mode == 1:
264-
arguments = [ccode(i) for i in obj.cargs]
268+
arguments = [self.ccode(i) for i in obj.cargs]
265269
strobj = MultilineCall(strobj, arguments, True)
266270

267271
value = c.Value(strtype, strobj)
@@ -275,9 +279,9 @@ def _gen_value(self, obj, mode=1, masked=()):
275279
if obj.is_Array and obj.initvalue is not None and mode == 1:
276280
init = ListInitializer(obj.initvalue)
277281
if not obj._mem_constant or init.is_numeric:
278-
value = c.Initializer(value, ccode(init))
282+
value = c.Initializer(value, self.ccode(init))
279283
elif obj.is_LocalObject and obj.initvalue is not None and mode == 1:
280-
value = c.Initializer(value, ccode(obj.initvalue))
284+
value = c.Initializer(value, self.ccode(obj.initvalue))
281285

282286
return value
283287

@@ -311,7 +315,7 @@ def _args_call(self, args):
311315
else:
312316
ret.append(i._C_name)
313317
except AttributeError:
314-
ret.append(ccode(i))
318+
ret.append(self.ccode(i))
315319
return ret
316320

317321
def _gen_signature(self, o, is_declaration=False):
@@ -388,7 +392,7 @@ def visit_tuple(self, o):
388392
def visit_PointerCast(self, o):
389393
f = o.function
390394
i = f.indexed
391-
cstr = ccode(i._C_typedata)
395+
cstr = self.ccode(i._C_typedata)
392396

393397
if f.is_PointerArray:
394398
# lvalue
@@ -410,7 +414,7 @@ def visit_PointerCast(self, o):
410414
else:
411415
v = f.name
412416
if o.flat is None:
413-
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
417+
shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape)
414418
rshape = '(*)%s' % shape
415419
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
416420
else:
@@ -443,9 +447,9 @@ def visit_Dereference(self, o):
443447
a0, a1 = o.functions
444448
if a1.is_PointerArray or a1.is_TempFunction:
445449
i = a1.indexed
446-
cstr = ccode(i._C_typedata)
450+
cstr = self.ccode(i._C_typedata)
447451
if o.flat is None:
448-
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
452+
shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:])
449453
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
450454
a1.dim.name)
451455
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
@@ -484,8 +488,8 @@ def visit_Definition(self, o):
484488
return self._gen_value(o.function)
485489

486490
def visit_Expression(self, o):
487-
lhs = ccode(o.expr.lhs, dtype=o.dtype)
488-
rhs = ccode(o.expr.rhs, dtype=o.dtype)
491+
lhs = self.ccode(o.expr.lhs, dtype=o.dtype)
492+
rhs = self.ccode(o.expr.rhs, dtype=o.dtype)
489493

490494
if o.init:
491495
code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs)
@@ -498,8 +502,8 @@ def visit_Expression(self, o):
498502
return code
499503

500504
def visit_AugmentedExpression(self, o):
501-
c_lhs = ccode(o.expr.lhs, dtype=o.dtype)
502-
c_rhs = ccode(o.expr.rhs, dtype=o.dtype)
505+
c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype)
506+
c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype)
503507
code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs))
504508
if o.pragmas:
505509
code = c.Module(self._visit(o.pragmas) + (code,))
@@ -518,7 +522,7 @@ def visit_Call(self, o, nested_call=False):
518522
o.templates)
519523
if retobj.is_Indexed or \
520524
isinstance(retobj, (FieldFromComposite, FieldFromPointer)):
521-
return c.Assign(ccode(retobj), call)
525+
return c.Assign(self.ccode(retobj), call)
522526
else:
523527
return c.Initializer(c.Value(rettype, retobj._C_name), call)
524528

@@ -532,9 +536,9 @@ def visit_Conditional(self, o):
532536
then_body = c.Block(self._visit(then_body))
533537
if else_body:
534538
else_body = c.Block(self._visit(else_body))
535-
return c.If(ccode(o.condition), then_body, else_body)
539+
return c.If(self.ccode(o.condition), then_body, else_body)
536540
else:
537-
return c.If(ccode(o.condition), then_body)
541+
return c.If(self.ccode(o.condition), then_body)
538542

539543
def visit_Iteration(self, o):
540544
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
@@ -544,23 +548,23 @@ def visit_Iteration(self, o):
544548

545549
# For backward direction flip loop bounds
546550
if o.direction == Backward:
547-
loop_init = 'int %s = %s' % (o.index, ccode(_max))
548-
loop_cond = '%s >= %s' % (o.index, ccode(_min))
551+
loop_init = 'int %s = %s' % (o.index, self.ccode(_max))
552+
loop_cond = '%s >= %s' % (o.index, self.ccode(_min))
549553
loop_inc = '%s -= %s' % (o.index, o.limits[2])
550554
else:
551-
loop_init = 'int %s = %s' % (o.index, ccode(_min))
552-
loop_cond = '%s <= %s' % (o.index, ccode(_max))
555+
loop_init = 'int %s = %s' % (o.index, self.ccode(_min))
556+
loop_cond = '%s <= %s' % (o.index, self.ccode(_max))
553557
loop_inc = '%s += %s' % (o.index, o.limits[2])
554558

555559
# Append unbounded indices, if any
556560
if o.uindices:
557-
uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices]
561+
uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices]
558562
loop_init = c.Line(', '.join([loop_init] + uinit))
559563

560564
ustep = []
561565
for i in o.uindices:
562566
op = '=' if i.is_Modulo else '+='
563-
ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr)))
567+
ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr)))
564568
loop_inc = c.Line(', '.join([loop_inc] + ustep))
565569

566570
# Create For header+body
@@ -577,7 +581,7 @@ def visit_Pragma(self, o):
577581
return c.Pragma(o._generate)
578582

579583
def visit_While(self, o):
580-
condition = ccode(o.condition)
584+
condition = self.ccode(o.condition)
581585
if o.body:
582586
body = flatten(self._visit(i) for i in o.children)
583587
return c.While(condition, c.Block(body))

devito/operator/operator.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from devito.operator.profiling import create_profile
2424
from devito.operator.registry import operator_selector
2525
from devito.mpi import MPI
26-
from devito.parameters import configuration, switchconfig
26+
from devito.parameters import configuration
2727
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2828
generate_macros, minimize_symbols, unevaluate,
2929
error_mapper, is_on_device)
@@ -485,8 +485,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
485485

486486
# Lower IET to a target-specific IET
487487
graph = Graph(iet, **kwargs)
488-
489-
# Specialize
490488
graph = cls._specialize_iet(graph, **kwargs)
491489

492490
# Instrument the IET for C-level profiling
@@ -775,12 +773,11 @@ def _soname(self):
775773

776774
@cached_property
777775
def ccode(self):
778-
with switchconfig(compiler=self._compiler, language=self._language):
779-
try:
780-
return self._ccode_handler().visit(self)
781-
except (AttributeError, TypeError):
782-
from devito.ir.iet.visitors import CGen
783-
return CGen().visit(self)
776+
try:
777+
return self._ccode_handler(language=self._language).visit(self)
778+
except (AttributeError, TypeError):
779+
from devito.ir.iet.visitors import CGen
780+
return CGen(language=self._language).visit(self)
784781

785782
def _jit_compile(self):
786783
"""
@@ -918,7 +915,8 @@ def apply(self, **kwargs):
918915
"""
919916
# Compile the operator before building the arguments list
920917
# to avoid out of memory with greedy compilers
921-
cfunction = self.cfunction
918+
with self._profiler.timer_on('jit-compile'):
919+
cfunction = self.cfunction
922920

923921
# Build the arguments list to invoke the kernel function
924922
with self._profiler.timer_on('arguments'):

devito/passes/iet/engine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ class Graph:
4040
The `visit` method collects info about the nodes in the Graph.
4141
"""
4242

43-
def __init__(self, iet, options=None, sregistry=None, **kwargs):
43+
def __init__(self, iet, options=None, sregistry=None, language=None, **kwargs):
4444
self.efuncs = OrderedDict([(iet.name, iet)])
4545

4646
self.sregistry = sregistry
47+
self.language = language
4748

4849
self.includes = []
4950
self.headers = []
@@ -147,7 +148,7 @@ def apply(self, func, **kwargs):
147148
# Minimize code size
148149
if len(efuncs) > len(self.efuncs):
149150
efuncs = reuse_compounds(efuncs, self.sregistry)
150-
efuncs = reuse_efuncs(self.root, efuncs, self.sregistry)
151+
efuncs = reuse_efuncs(self.root, efuncs, self.sregistry, self.language)
151152

152153
self.efuncs = efuncs
153154

@@ -316,7 +317,7 @@ def _(i, sregistry=None):
316317
return i._rebuild(pname=pname, cfields=cfields, ncfields=ncfields, function=None)
317318

318319

319-
def reuse_efuncs(root, efuncs, sregistry=None):
320+
def reuse_efuncs(root, efuncs, sregistry=None, language=None):
320321
"""
321322
Generalise `efuncs` so that syntactically identical Callables may be dropped,
322323
thus maximizing code reuse.

devito/symbolics/extended_sympy.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -765,14 +765,11 @@ class SizeOf(DefFunction):
765765

766766
def __new__(cls, intype, stars=None, **kwargs):
767767
stars = stars or ''
768-
769768
if not isinstance(intype, (str, ReservedWord)):
770-
intype = dtype_to_ctype(intype)
771-
if intype in ctypes_vector_mapper.values():
772-
idx = list(ctypes_vector_mapper.values()).index(intype)
769+
ctype = dtype_to_ctype(intype)
770+
if ctype in ctypes_vector_mapper.values():
771+
idx = list(ctypes_vector_mapper.values()).index(ctype)
773772
intype = list(ctypes_vector_mapper.keys())[idx]
774-
else:
775-
intype = ctypes_to_cstr(intype)
776773

777774
newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs)
778775
newobj.stars = stars

devito/symbolics/printer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from devito.symbolics.inspection import has_integer_args, sympy_dtype
2121
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
2222
from devito.types.basic import AbstractFunction
23-
from devito.tools import ctypes_to_cstr
23+
from devito.tools import ctypes_to_cstr, dtype_to_ctype
2424

2525
__all__ = ['ccode']
2626

@@ -95,6 +95,10 @@ def parenthesize(self, item, level, strict=False):
9595
return super().parenthesize(item, level, strict=strict)
9696

9797
def _print_type(self, expr):
98+
try:
99+
expr = dtype_to_ctype(expr)
100+
except TypeError:
101+
pass
98102
try:
99103
return self.type_mappings[expr]
100104
except KeyError:
@@ -422,7 +426,7 @@ class AccDevitoPrinter(CXXDevitoPrinter):
422426
'openacc': AccDevitoPrinter}
423427

424428

425-
def ccode(expr, **settings):
429+
def ccode(expr, language=None, **settings):
426430
"""Generate C++ code from an expression.
427431
428432
Parameters
@@ -438,5 +442,6 @@ def ccode(expr, **settings):
438442
The resulting code as a C++ string. If something went south, returns
439443
the input ``expr`` itself.
440444
"""
441-
printer = printer_registry.get(configuration['language'], CDevitoPrinter)
445+
lang = language or configuration['language']
446+
printer = printer_registry.get(lang, CDevitoPrinter)
442447
return printer(settings=settings).doprint(expr, None)

devito/types/basic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
22
import inspect
33
from collections import namedtuple
4-
from ctypes import POINTER, _Pointer, c_char_p, c_char
4+
from ctypes import POINTER, _Pointer, c_char_p, c_char, Structure
55
from functools import reduce, cached_property
66
from operator import mul
77

@@ -87,13 +87,18 @@ def _C_typedata(self):
8787
if isinstance(_type, CustomDtype):
8888
return _type
8989

90+
_pointer = False
9091
while issubclass(_type, _Pointer):
92+
_pointer = True
9193
_type = _type._type_
9294

9395
# `ctypes` treats C strings specially
9496
if _type is c_char_p:
9597
_type = c_char
9698

99+
if issubclass(_type, Structure) and _pointer:
100+
_type = f'struct {_type.__name__}'
101+
97102
return _type
98103

99104
@abc.abstractproperty

tests/test_dtypes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]:
5959

6060
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
6161
@pytest.mark.parametrize('kwargs', _configs)
62-
def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None:
62+
def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str],
63+
expected=None) -> None:
6364
"""
6465
Tests that half and complex floats' dtypes result in the correct type
6566
strings in generated code.
@@ -78,7 +79,7 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N
7879
params: dict[str, Basic] = {p.name: p for p in op.parameters}
7980
_u, _c = params['u'], params['c']
8081
assert type(_u.indexed._C_ctype._type_()) == ctypes_vector_mapper[dtype]
81-
assert _c._C_ctype == ctypes_vector_mapper[dtype]
82+
assert _c._C_ctype == expected or ctypes_vector_mapper[dtype]
8283

8384

8485
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])

0 commit comments

Comments
 (0)