Skip to content

Commit 8d03952

Browse files
committed
compiler: Revamp Array allocation
1 parent 5569e99 commit 8d03952

9 files changed

Lines changed: 138 additions & 65 deletions

File tree

devito/operator/operator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from devito.data import default_allocator
1414
from devito.exceptions import (CompilationError, ExecutionError, InvalidArgument,
1515
InvalidOperator)
16-
from devito.logger import debug, info, perf, warning, is_log_enabled_for, switch_log_level
16+
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
17+
switch_log_level)
1718
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
1819
from devito.ir.clusters import ClusterGroup, clusterize
19-
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, MetaCall,
20-
derive_parameters, iet_build)
20+
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
21+
MetaCall, derive_parameters, iet_build)
2122
from devito.ir.support import AccessMode, SymbolRegistry
2223
from devito.ir.stree import stree_build
2324
from devito.operator.profiling import create_profile
@@ -26,8 +27,7 @@
2627
from devito.parameters import configuration
2728
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2829
generate_macros, minimize_symbols, unevaluate,
29-
error_mapper, is_on_device)
30-
from devito.passes.iet.dtypes import lower_dtypes
30+
error_mapper, is_on_device, lower_dtypes)
3131
from devito.symbolics import estimate_cost, subs_op_args
3232
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3333
flatten, filter_sorted, frozendict, is_integer,
@@ -488,7 +488,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
488488
# Extract the necessary macros from the symbolic objects
489489
generate_macros(graph, **kwargs)
490490

491-
# Add type specific metadata
491+
# Target-specific lowering
492492
lower_dtypes(graph, **kwargs)
493493

494494
# Target-independent optimizations

devito/passes/iet/definitions.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,22 @@
99

1010
import numpy as np
1111

12-
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
13-
FindNodes, FindSymbols, MapExprStmts, Transformer,
14-
make_callable)
12+
from devito.ir import (
13+
Block, Call, Definition, DummyExpr, Iteration, List, Return, EntryFunction,
14+
FindNodes, FindSymbols, MapExprStmts, Transformer, make_callable
15+
)
1516
from devito.passes import is_gpu_create
1617
from devito.passes.iet.engine import iet_pass
1718
from devito.passes.iet.langbase import LangBB
18-
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
19-
SizeOf, VOID, pow_to_mul, unevaluate)
19+
from devito.symbolics import (
20+
Byref, DefFunction, FieldFromPointer, IndexedPointer, ListInitializer,
21+
SizeOf, VOID, pow_to_mul, unevaluate
22+
)
2023
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
21-
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
22-
DeviceRM, Eq, Symbol, size_t)
24+
from devito.types import (
25+
Array, ComponentAccess, CustomDimension, Dimension, DeviceMap, DeviceRM,
26+
Eq, Symbol, size_t
27+
)
2328

2429
__all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage']
2530

@@ -40,6 +45,7 @@ def __init__(self, *args, **kwargs):
4045
super().__init__(*args, **kwargs)
4146

4247
self.defined = set()
48+
self.includes = set()
4349

4450
def update(self, key, site, **kwargs):
4551
if key in self.defined:
@@ -67,6 +73,10 @@ def map(self, key, site, k, v):
6773

6874
self.defined.add((site[-1], key))
6975

76+
def include(self, v):
77+
if v:
78+
self.includes.add(v)
79+
7080

7181
class DataManager:
7282

@@ -132,8 +142,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
132142
alloc = Call(name, efunc.parameters)
133143

134144
storage.update(obj, site, allocs=alloc, efuncs=efunc)
135-
136-
return self.langbb['header-memcpy']
145+
storage.include(self.langbb['header-memcpy'])
137146

138147
def _alloc_scalar_on_low_lat_mem(self, site, expr, storage):
139148
"""
@@ -170,6 +179,12 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
170179

171180
arity_param = Symbol(name='arity', dtype=size_t)
172181
arity_arg = SizeOf(obj.indexed._C_typedata)
182+
ndims_param = Symbol(name='ndims', dtype=size_t)
183+
ndims_arg = obj.ndim
184+
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
185+
dimensions=(Dimension(name='d'),), scope='rvalue')
186+
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
187+
173188
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
174189
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
175190
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
@@ -183,14 +198,21 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
183198

184199
# Allocate the shape array
185200
memptr = VOID(Byref(ffp1), '**')
186-
nbytes = SizeOf(obj._C_size_type)*obj.ndim
201+
nbytes = SizeOf(obj._C_size_type)*ndims_param
187202
alloc1 = self.langbb['host-alloc'](memptr, alignment, nbytes)
188203

189-
# Initialize the Array struct
190-
init = [*[DummyExpr(IndexedPointer(ffp1, i), s)
191-
for i, s in enumerate(obj.c0.symbolic_shape)],
192-
DummyExpr(ffp2, obj.size),
193-
DummyExpr(ffp3, ffp2*arity_param)]
204+
# Initialize the Array metadata
205+
dim, = shape_param.dimensions
206+
init = [DummyExpr(ffp2, 1)]
207+
init.append(Iteration(
208+
List(body=(
209+
DummyExpr(IndexedPointer(ffp1, dim), shape_param[dim]),
210+
DummyExpr(ffp2, ffp2*shape_param[dim])
211+
)),
212+
dim,
213+
ndims_param - 1,
214+
))
215+
init.append(DummyExpr(ffp3, ffp2*arity_param))
194216

195217
# Allocate the underlying host data
196218
memptr = VOID(Byref(ffp0), '**')
@@ -213,6 +235,8 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
213235
efunc0 = make_callable(name, body, retval=obj)
214236
args = list(efunc0.parameters)
215237
args[args.index(arity_param)] = arity_arg
238+
args[args.index(ndims_param)] = ndims_arg
239+
args[args.index(shape_param)] = shape_arg
216240
alloc = Call(name, args, retobj=obj)
217241

218242
# Same story for the frees
@@ -222,6 +246,7 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
222246
free = Call(name, efunc1.parameters)
223247

224248
storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1))
249+
storage.include(self.langbb['header-array'])
225250

226251
def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
227252
"""
@@ -436,18 +461,15 @@ def place_definitions(self, iet, globs=None, **kwargs):
436461
self._alloc_pointed_array_on_high_bw_mem(iet, i, storage)
437462

438463
# Handle postponed global objects
439-
includes = set()
440464
if isinstance(iet, EntryFunction) and globs:
441465
for i in sorted(globs, key=lambda f: f.name):
442-
v = self._alloc_array_on_global_mem(iet, i, storage)
443-
if v:
444-
includes.add(v)
466+
self._alloc_array_on_global_mem(iet, i, storage)
445467

446468
iet, efuncs = self._inject_definitions(iet, storage)
447469

448470
return iet, {'efuncs': efuncs,
449471
'globals': as_tuple(globs),
450-
'includes': as_tuple(includes)}
472+
'includes': as_tuple(sorted(storage.includes))}
451473

452474
@iet_pass
453475
def place_casts(self, iet, **kwargs):

devito/passes/iet/languages/C.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
from devito.passes.iet.definitions import DataManager
66
from devito.passes.iet.orchestration import Orchestrator
77
from devito.passes.iet.langbase import LangBB
8-
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
8+
from devito.symbolics import c_complex, c_double_complex
9+
from devito.tools import dtype_to_cstr
910

1011
__all__ = ['CBB', 'CDataManager', 'COrchestrator']
1112

1213

1314
class CBB(LangBB):
1415

1516
mapper = {
17+
# Misc
18+
'header-array': None,
1619
# Complex
1720
'includes-complex': 'complex.h',
1821
# Allocs
@@ -55,3 +58,12 @@ class CPrinter(BasePrinter, C99CodePrinter):
5558

5659
def _print_ImaginaryUnit(self, expr):
5760
return '_Complex_I'
61+
62+
def _print_ListInitializer(self, expr):
63+
li = super()._print_ListInitializer(expr)
64+
if expr.dtype:
65+
# C99, unlike CXX, supports compound literals
66+
tstr = dtype_to_cstr(expr.dtype)
67+
return f'({tstr}[]){li}'
68+
else:
69+
return li

devito/passes/iet/languages/CXX.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from devito.ir import Call, UsingNamespace, BasePrinter
55
from devito.passes.iet.langbase import LangBB
6-
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
6+
from devito.symbolics import c_complex, c_double_complex
7+
from devito.tools import dtype_to_cstr
78

89
__all__ = ['CXXBB']
910

@@ -65,6 +66,8 @@ def std_arith(prefix=None):
6566
class CXXBB(LangBB):
6667

6768
mapper = {
69+
# Misc
70+
'header-array': 'array',
6871
# Complex
6972
'includes-complex': 'complex',
7073
'complex-namespace': [UsingNamespace('std::complex_literals')],
@@ -112,3 +115,12 @@ def _print_Cast(self, expr):
112115
caster = 'reinterpret_cast' if expr.reinterpret else 'static_cast'
113116
cast = f'{caster}<{tstr}{self._print(expr.stars)}>'
114117
return self._print_UnaryOp(expr, op=cast, parenthesize=True)
118+
119+
def _print_ListInitializer(self, expr):
120+
li = super()._print_ListInitializer(expr)
121+
if expr.dtype:
122+
# CXX, unlike C99, does not support compound literals
123+
tstr = dtype_to_cstr(expr.dtype)
124+
return f'std::array<{tstr}, {len(expr.params)}>{li}.data()'
125+
else:
126+
return li

devito/symbolics/extended_sympy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,20 @@ class ListInitializer(sympy.Expr, Pickable):
292292
"""
293293

294294
__rargs__ = ('params',)
295+
__rkwargs__ = ('dtype',)
295296

296-
def __new__(cls, params):
297+
def __new__(cls, params, dtype=None):
297298
args = []
298299
for p in as_tuple(params):
299300
try:
300301
args.append(sympify(p))
301302
except sympy.SympifyError:
302303
raise ValueError(f"Illegal param `{p}`")
303304
obj = sympy.Expr.__new__(cls, *args)
305+
304306
obj.params = tuple(args)
307+
obj.dtype = dtype
308+
305309
return obj
306310

307311
def __str__(self):

devito/types/array.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __indices_setup__(cls, *args, **kwargs):
4040

4141
@property
4242
def _C_name(self):
43-
if self._mem_stack or self._mem_constant:
43+
if self._mem_stack or self._mem_constant or self._mem_rvalue:
4444
# No reason to distinguish between two different names, that is
4545
# the _C_name and the name -- just `self.name` is enough
4646
return self.name
@@ -97,7 +97,7 @@ class Array(ArrayBasic):
9797
to 'local'. Used to override `_mem_local` and `_mem_mapped`.
9898
scope : str, optional
9999
The scope in the given memory space. Allowed values: 'heap', 'stack',
100-
'static', 'constant', 'shared', 'shared-remote', 'registers'.
100+
'static', 'constant', 'shared', 'shared-remote', 'registers', 'rvalue'.
101101
'static' refers to a static array in a C/C++ sense. 'constant' and
102102
'shared' mean that the Array represents an object allocated in so
103103
called constant and shared memory, respectively, which are typical of
@@ -107,7 +107,9 @@ class Array(ArrayBasic):
107107
architecture doesn't have something akin to constant memory, the Array
108108
falls back to a global, const, static array in a C/C++ sense.
109109
'registers' is used to indicate that the Array has a small static size
110-
and, as such, it could be allocated in registers. Defaults to 'heap'.
110+
and, as such, it could be allocated in registers. If `rvalue`, the
111+
Array is treated as a temporary or "transient" object, just like
112+
C++'s rvalue references or C's compound literals. Defaults to 'heap'.
111113
Note that not all scopes make sense for a given space.
112114
grid : Grid, optional
113115
Only necessary for distributed-memory parallelism; a Grid contains
@@ -142,7 +144,7 @@ def __init_finalize__(self, *args, **kwargs):
142144

143145
self._scope = kwargs.get('scope', 'heap')
144146
assert self._scope in ['heap', 'stack', 'static', 'constant', 'shared',
145-
'shared-remote', 'registers']
147+
'shared-remote', 'registers', 'rvalue']
146148

147149
self._initvalue = kwargs.get('initvalue')
148150
assert self._initvalue is None or self._scope != 'heap'
@@ -197,6 +199,10 @@ def _mem_registers(self):
197199
def _mem_constant(self):
198200
return self._scope == 'constant'
199201

202+
@property
203+
def _mem_rvalue(self):
204+
return self._scope == 'rvalue'
205+
200206
@property
201207
def initvalue(self):
202208
return self._initvalue
@@ -473,10 +479,11 @@ def initvalue(self):
473479
# Defaulting to self.c0's behaviour
474480
for i in ('_mem_internal_eager', '_mem_internal_lazy', '_mem_local',
475481
'_mem_mapped', '_mem_host', '_mem_stack', '_mem_constant',
476-
'_mem_shared', '_mem_shared_remote', '__padding_dtype__',
477-
'_size_domain', '_size_halo', '_size_owned', '_size_padding',
478-
'_size_nopad', '_size_nodomain', '_offset_domain', '_offset_halo',
479-
'_offset_owned', '_dist_dimensions', '_C_get_field', 'grid',
482+
'_mem_shared', '_mem_shared_remote', '_mem_registers',
483+
'_mem_rvalue', '__padding_dtype__', '_size_domain', '_size_halo',
484+
'_size_owned', '_size_padding', '_size_nopad', '_size_nodomain',
485+
'_offset_domain', '_offset_halo', '_offset_owned',
486+
'_dist_dimensions', '_C_get_field', 'grid',
480487
*AbstractFunction.__properties__):
481488
locals()[i] = property(lambda self, v=i: getattr(self.c0, v))
482489

@@ -517,7 +524,6 @@ def _C_ctype(self):
517524
if self._mem_mapped:
518525
return super()._C_ctype
519526
else:
520-
#TODO DROP???
521527
return POINTER(dtype_to_ctype(self.dtype))
522528

523529

devito/types/basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class CodeSymbol:
4242
* "liveness": `_mem_external`, `_mem_internal_eager`, `_mem_internal_lazy`
4343
* "space": `_mem_local`, `_mem_mapped`, `_mem_host`
4444
* "scope": `_mem_stack`, `_mem_heap`, `_mem_global`, `_mem_shared`,
45-
`_mem_shared_remote`, `_mem_constant`
45+
`_mem_shared_remote`, `_mem_constant`, `_mem_registers`,
46+
`_mem_rvalue`
4647
4748
For example, an object that is `<_mem_internal_lazy, _mem_local, _mem_heap>`
4849
is allocated within the Operator entry point, on either the host or device
@@ -230,6 +231,21 @@ def _mem_shared_remote(self):
230231
"""
231232
return False
232233

234+
@property
235+
def _mem_registers(self):
236+
"""
237+
True if the associated data is allocated in registers, False otherwise.
238+
"""
239+
return False
240+
241+
@property
242+
def _mem_rvalue(self):
243+
"""
244+
True if the associated data is allocated in a temporary (or "transient")
245+
variable, such as rvalues in CXX, False otherwise.
246+
"""
247+
return False
248+
233249

234250
class Basic(CodeSymbol):
235251

0 commit comments

Comments
 (0)