Skip to content

Commit 5111506

Browse files
authored
Merge pull request #2728 from devitocodes/simplify-queue-dim-reuse
compiler: Enhance Array codegen
2 parents a5193e2 + ac68597 commit 5111506

8 files changed

Lines changed: 86 additions & 30 deletions

File tree

devito/arch/compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,22 @@ def load(self, soname):
291291
"""
292292
return npct.load_library(str(self.get_jit_dir().joinpath(soname)), '.')
293293

294+
def save_header(self, filename, code):
295+
"""
296+
Store some source code into a header file within the same temporary directory
297+
used for JIT compilation.
298+
299+
Parameters
300+
----------
301+
filename : str
302+
The name of the header file (w/o the suffix).
303+
code : str
304+
The source code to be stored.
305+
"""
306+
hfile = self.get_jit_dir().joinpath(filename).with_suffix('.h')
307+
with open(str(hfile), 'w') as f:
308+
f.write(code)
309+
294310
def save(self, soname, binary):
295311
"""
296312
Store a binary into a file within a temporary directory.

devito/ir/cgen/printer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class BasePrinter(CodePrinter):
4141
_prec_literals = {np.float32: 'F', np.complex64: 'F'}
4242

4343
_qualifiers_mapper = {
44+
'is_extern': 'extern',
4445
'is_const': 'const',
4546
'is_volatile': 'volatile',
4647
'_mem_constant': 'static',

devito/ir/iet/nodes.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class List(Node):
195195

196196
_traversable = ['body']
197197

198-
def __init__(self, header=None, body=None, footer=None):
198+
def __init__(self, header=None, body=None, footer=None, inline=False):
199199
body = as_tuple(body)
200200
if len(body) == 1 and all(type(i) is List for i in [self, body[0]]):
201201
# De-nest Lists
@@ -213,6 +213,8 @@ def __init__(self, header=None, body=None, footer=None):
213213
self.body = as_tuple(body)
214214
self.footer = as_tuple(footer)
215215

216+
self.inline = inline
217+
216218
def __repr__(self):
217219
return "<%s (%d, %d, %d)>" % (self.__class__.__name__, len(self.header),
218220
len(self.body), len(self.footer))
@@ -1047,9 +1049,8 @@ class Dereference(ExprStmt, Node):
10471049
A node encapsulating a dereference from a `pointer` to a `pointee`.
10481050
The following cases are supported:
10491051
1050-
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
1051-
* `pointer` is an ArrayObject representing a pointer to a C struct, and
1052-
`pointee` is a field in `pointer`.
1052+
* `pointer` is an AbstractFunction, and `pointee` is an Array.
1053+
* `pointer` is an AbstractObject, and `pointee` is an Array.
10531054
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
10541055
`pointee` is a Symbol representing the dereferenced value.
10551056
"""
@@ -1075,20 +1076,23 @@ def expr_symbols(self):
10751076
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
10761077
"Scalar dereference must have a pointer ctype"
10771078
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
1078-
elif self.pointer.is_PointerArray or self.pointer.is_TempFunction:
1079+
elif self.pointer.is_AbstractFunction:
10791080
ret.extend([self.pointer.indexed, self.pointee.indexed])
10801081
ret.extend(flatten(i.free_symbols
10811082
for i in self.pointee.symbolic_shape[1:]))
10821083
ret.extend(self.pointer.free_symbols)
1084+
elif self.pointer.is_AbstractObject:
1085+
ret.extend([self.pointer, self.pointee.indexed])
1086+
ret.extend(flatten(i.free_symbols
1087+
for i in self.pointee.symbolic_shape[1:]))
10831088
else:
1084-
ret.extend([self.pointer.indexed, self.pointee._C_symbol])
1089+
assert False, f"Unexpected pointer type {type(self.pointer)}"
1090+
10851091
return tuple(filter_ordered(ret))
10861092

10871093
@property
10881094
def defines(self):
1089-
if self.pointer.is_PointerArray or \
1090-
self.pointer.is_TempFunction or \
1091-
self.pointee._mem_stack:
1095+
if self.pointer.is_AbstractFunction or self.pointee._mem_stack:
10921096
return (self.pointee.indexed, self.pointee)
10931097
else:
10941098
return (self.pointee,)

devito/ir/iet/visitors.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -506,24 +506,30 @@ def visit_PointerCast(self, o):
506506

507507
def visit_Dereference(self, o):
508508
a0, a1 = o.functions
509-
if a1.is_PointerArray or a1.is_TempFunction:
510-
i = a1.indexed
511-
cstr = self.ccode(i._C_typedata)
509+
if a0.is_AbstractFunction:
510+
cstr = self.ccode(a0.indexed._C_typedata)
511+
512+
try:
513+
# Special AbstractFunctions such as PointerArray or TempFunction
514+
cdim = f'[{a1.dim.name}]'
515+
except AttributeError:
516+
cdim = ''
517+
512518
if o.flat is None:
513519
shape = ''.join(f"[{self.ccode(i)}]" for i in a0.symbolic_shape[1:])
514-
rvalue = f'({cstr} (*){shape}) {a1.name}[{a1.dim.name}]'
520+
rvalue = f'({cstr} (*){shape}) {a1.name}{cdim}'
515521
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}')
516522
else:
517-
rvalue = f'({cstr} *) {a1.name}[{a1.dim.name}]'
523+
rvalue = f'({cstr} *) {a1.name}{cdim}'
518524
lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}')
519-
if a0._data_alignment:
520-
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
525+
521526
else:
522527
if a1.is_Symbol:
523528
rvalue = f'*{a1.name}'
524529
else:
525530
rvalue = f'{a1.name}->{a0._C_name}'
526531
lvalue = self._gen_value(a0, 0)
532+
527533
return c.Initializer(lvalue, rvalue)
528534

529535
def visit_Block(self, o):
@@ -532,7 +538,11 @@ def visit_Block(self, o):
532538

533539
def visit_List(self, o):
534540
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
535-
return c.Module(o.header + (c.Collection(body),) + o.footer)
541+
if o.inline:
542+
body = c.Line(' '.join(str(i) for i in body))
543+
else:
544+
body = c.Collection(body)
545+
return c.Module(o.header + (body,) + o.footer)
536546

537547
def visit_Section(self, o):
538548
body = flatten(self._visit(i) for i in o.children)

devito/types/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,10 @@ def symbolic_shape(self):
503503
# shape, i.e. something along the lines of `(x_size, y_size, z_size)`
504504
return self.c0.symbolic_shape
505505

506+
@property
507+
def nbytes(self):
508+
return self.size*self.dtype.itemsize
509+
506510
@property
507511
def _mem_heap(self):
508512
return not any([self._mem_stack, self._mem_shared, self._mem_shared_remote])

devito/types/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _defines(self):
176176
return frozenset().union(*[i._defines for i in self])
177177

178178

179-
class Pointer(LocalObject):
179+
class Pointer(LocalObject, sympy.Expr):
180180

181181
__rkwargs__ = LocalObject.__rkwargs__ + ('dtype',)
182182

examples/performance/00_overview.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@
701701
" #pragma omp parallel num_threads(nthreads)\n",
702702
" {\n",
703703
" const int tid = omp_get_thread_num();\n",
704-
" float (*restrict r0)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr0[tid];\n",
704+
" float (*restrict r0)[z_size] = (float (*)[z_size]) pr0[tid];\n",
705705
"\n",
706706
" #pragma omp for schedule(dynamic,1)\n",
707707
" for (int x = x_m; x <= x_M; x += 1)\n",
@@ -836,7 +836,7 @@
836836
" #pragma omp parallel num_threads(nthreads)\n",
837837
" {\n",
838838
" const int tid = omp_get_thread_num();\n",
839-
" float (*restrict r1)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr1[tid];\n",
839+
" float (*restrict r1)[z_size] = (float (*)[z_size]) pr1[tid];\n",
840840
"\n",
841841
" #pragma omp for schedule(dynamic,1)\n",
842842
" for (int x = x_m; x <= x_M; x += 1)\n",
@@ -968,7 +968,7 @@
968968
" #pragma omp parallel num_threads(nthreads)\n",
969969
" {\n",
970970
" const int tid = omp_get_thread_num();\n",
971-
" float (*restrict r0)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr0[tid];\n",
971+
" float (*restrict r0)[z_size] = (float (*)[z_size]) pr0[tid];\n",
972972
"\n",
973973
" #pragma omp for schedule(dynamic,1)\n",
974974
" for (int x = x_m; x <= x_M; x += 1)\n",
@@ -1236,7 +1236,7 @@
12361236
" #pragma omp parallel num_threads(nthreads)\n",
12371237
" {\n",
12381238
" const int tid = omp_get_thread_num();\n",
1239-
" float (*restrict r2)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr2[tid];\n",
1239+
" float (*restrict r2)[z_size] = (float (*)[z_size]) pr2[tid];\n",
12401240
"\n",
12411241
" #pragma omp for collapse(2) schedule(dynamic,1)\n",
12421242
" for (int x0_blk0 = x_m; x0_blk0 <= x_M; x0_blk0 += x0_blk0_size)\n",
@@ -1348,7 +1348,7 @@
13481348
" #pragma omp parallel num_threads(nthreads)\n",
13491349
" {\n",
13501350
" const int tid = omp_get_thread_num();\n",
1351-
" float (*restrict r2)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr2[tid];\n",
1351+
" float (*restrict r2)[z_size] = (float (*)[z_size]) pr2[tid];\n",
13521352
"\n",
13531353
" #pragma omp for collapse(2) schedule(dynamic,1)\n",
13541354
" for (int x0_blk0 = x_m; x0_blk0 <= x_M; x0_blk0 += x0_blk0_size)\n",
@@ -1527,7 +1527,7 @@
15271527
" #pragma omp parallel num_threads(nthreads)\n",
15281528
" {\n",
15291529
" const int tid = omp_get_thread_num();\n",
1530-
" float (*restrict r2)[z_size] __attribute__ ((aligned (64))) = (float (*)[z_size]) pr2[tid];\n",
1530+
" float (*restrict r2)[z_size] = (float (*)[z_size]) pr2[tid];\n",
15311531
"\n",
15321532
" #pragma omp for schedule(dynamic,1)\n",
15331533
" for (int x = x_m; x <= x_M; x += 1)\n",

tests/test_iet.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77

88
from devito import (Eq, Grid, Function, TimeFunction, Operator, Dimension, # noqa
99
switchconfig)
10-
from devito.ir.iet import (Call, Callable, Conditional, DeviceCall, DummyExpr,
11-
Iteration, List, KernelLaunch, Lambda, ElementalFunction,
12-
CGen, FindSymbols, filter_iterations, make_efunc,
13-
retrieve_iteration_tree, Transformer)
10+
from devito.ir.iet import (
11+
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration, List,
12+
KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations,
13+
make_efunc, retrieve_iteration_tree, Transformer
14+
)
1415
from devito.ir import SymbolRegistry
1516
from devito.passes.iet.engine import Graph
1617
from devito.passes.iet.languages.C import CDataManager
1718
from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class,
18-
FLOAT)
19+
String, FLOAT)
1920
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
20-
from devito.types import Array, LocalObject, Symbol
21+
from devito.types import CustomDimension, Array, LocalObject, Symbol
2122

2223

2324
@pytest.fixture
@@ -475,3 +476,23 @@ def test_codegen_quality0():
475476
assert len(foo.parameters) == 3
476477
assert len(foo1.parameters) == 1
477478
assert foo1.parameters[0] is a
479+
480+
481+
def test_special_array_definition():
482+
483+
class MyArray(Array):
484+
is_extern = True
485+
_data_alignment = False
486+
487+
dim = CustomDimension(name='d', symbolic_size=String(''))
488+
a = MyArray(name='a', dimensions=dim, scope='shared', dtype=np.uint8)
489+
490+
assert str(Definition(a)) == "extern unsigned char a[];"
491+
492+
493+
def test_list_inline():
494+
expr0 = DummyExpr(Symbol(name='a'), 1)
495+
expr1 = DummyExpr(Symbol(name='b'), 2)
496+
497+
lst = List(body=[expr0, expr1], inline=True)
498+
assert str(lst) == """a = 1; b = 2;"""

0 commit comments

Comments
 (0)