Skip to content

Commit 184fdaf

Browse files
committed
mpi: Fix Bundles support and improve codegen
1 parent 6e7466e commit 184fdaf

4 files changed

Lines changed: 137 additions & 65 deletions

File tree

devito/mpi/routines.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
IndexedPointer, Macro, cast, subs_op_args)
1919
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize,
2020
flatten, generator, is_integer)
21-
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
22-
CompositeObject, CustomDimension)
21+
from devito.types import (Array, Bag, BundleView, Dimension, Eq, Symbol,
22+
LocalObject, CompositeObject, CustomDimension)
2323

2424
__all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry']
2525

@@ -295,8 +295,15 @@ def _make_bundles(self, hs):
295295
halo_scheme = halo_scheme.drop(components)
296296

297297
# Existing Bundles are preserved
298-
if hse.bundle and set(components) == set(hse.bundle.components):
299-
halo_scheme = halo_scheme.add(hse.bundle, hse)
298+
if hse.bundle:
299+
if set(components) == set(hse.bundle.components):
300+
halo_scheme = halo_scheme.add(hse.bundle, hse)
301+
else:
302+
name = f'bundleview_{hse.bundle.name}'
303+
bundle_view = BundleView(
304+
name=name, components=components, parent=hse.bundle
305+
)
306+
halo_scheme = halo_scheme.add(bundle_view, hse)
300307
continue
301308

302309
# We recast everything else as Bags for simplicity -- worst case
@@ -367,15 +374,13 @@ def _make_copy(self, f, hse, key, swap=False):
367374
name = 'scatter%s' % key
368375

369376
if isinstance(f, Bag):
370-
if hse.bundle is not None:
371-
# `f` is the only component of `hse.bundle` that is
372-
# being communicated
373-
assert f.ncomp == 1
374-
i = hse.bundle.components.index(f.c0)
375-
eqns.append(Eq(*swap(buf[[0] + bdims], hse.bundle[[i] + findices])))
376-
else:
377-
for i, c in enumerate(f.components):
378-
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
377+
for i, c in enumerate(f.components):
378+
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
379+
elif isinstance(f, BundleView):
380+
assert f.parent is hse.bundle
381+
for i, c in enumerate(f.components):
382+
indices = [f.parent.components.index(c), *findices]
383+
eqns.append(Eq(*swap(buf[[i] + bdims], f.parent[indices])))
379384
else:
380385
assert f.is_Bundle
381386
for i in range(f.ncomp):

devito/passes/iet/definitions.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
SizeOf, VOID, pow_to_mul, unevaluate)
2020
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
2121
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
22-
DeviceRM, Eq, Symbol)
22+
DeviceRM, Eq, Symbol, size_t)
2323

2424
__all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage']
2525

@@ -168,49 +168,56 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
168168
"""
169169
decl = Definition(obj)
170170

171+
arity_param = Symbol(name='arity', dtype=size_t)
172+
arity_arg = SizeOf(obj.indexed._C_typedata)
173+
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
174+
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
175+
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
176+
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
177+
171178
# Allocate the Array struct
172179
memptr = VOID(Byref(obj._C_symbol), '**')
173180
alignment = obj._data_alignment
174181
nbytes = SizeOf(obj._C_typedata)
175-
allocs = [self.langbb['host-alloc'](memptr, alignment, nbytes)]
182+
alloc0 = self.langbb['host-alloc'](memptr, alignment, nbytes)
176183

177-
nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True)
178-
nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size
184+
# Allocate the shape array
185+
memptr = VOID(Byref(ffp1), '**')
186+
nbytes = SizeOf(obj._C_size_type)*obj.ndim
187+
alloc1 = self.langbb['host-alloc'](memptr, alignment, nbytes)
188+
189+
# Initialize the Array struct
190+
init = [*[DummyExpr(IndexedPointer(ffp1, i), s)
191+
for i, s in enumerate(obj.c0.symbolic_shape)],
192+
DummyExpr(ffp2, obj.size),
193+
DummyExpr(ffp3, ffp2*arity_param)]
179194

180195
# Allocate the underlying host data
181-
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
182196
memptr = VOID(Byref(ffp0), '**')
183-
allocs.append(self.langbb['host-alloc-pin'](memptr, alignment, nbytes_param))
184-
185-
# Initialize the Array struct
186-
ffp1 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
187-
init0 = DummyExpr(ffp1, nbytes_param)
188-
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
189-
init1 = DummyExpr(ffp2, 0)
197+
alloc2 = self.langbb['host-alloc-pin'](memptr, alignment, ffp3)
190198

199+
# Free all of the allocated data
191200
frees = [self.langbb['host-free-pin'](ffp0),
201+
self.langbb['host-free'](ffp1),
192202
self.langbb['host-free'](obj._C_symbol)]
193203

194204
# Allocate the underlying device data, if required by the backend
195-
alloc, free = self._make_dmap_allocfree(obj, nbytes_param)
196-
197-
# Chain together all allocs and frees
198-
allocs = as_tuple(allocs) + as_tuple(alloc)
199-
frees = as_tuple(free) + as_tuple(frees)
205+
alloc_dmap, free_dmap = self._make_dmap_allocfree(obj, ffp3)
200206

201207
ret = Return(obj._C_symbol)
202208

203209
# Wrap everything in a Callable so that we can reuse the same code
204210
# for equivalent Array structs
205211
name = self.sregistry.make_name(prefix='alloc')
206-
body = (decl, *allocs, init0, init1, ret)
212+
body = (decl, alloc0, alloc1, *init, alloc2, *as_tuple(alloc_dmap), ret)
207213
efunc0 = make_callable(name, body, retval=obj)
208214
args = list(efunc0.parameters)
209-
args[args.index(nbytes_param)] = nbytes_arg
215+
args[args.index(arity_param)] = arity_arg
210216
alloc = Call(name, args, retobj=obj)
211217

212218
# Same story for the frees
213219
name = self.sregistry.make_name(prefix='free')
220+
frees = as_tuple(free_dmap) + as_tuple(frees)
214221
efunc1 = make_callable(name, frees)
215222
free = Call(name, efunc1.parameters)
216223

@@ -222,35 +229,50 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
222229
"""
223230
decl = Definition(obj)
224231

232+
arity_param = Symbol(name='arity', dtype=size_t)
233+
arity_arg = SizeOf(obj.indexed._C_typedata)
234+
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
235+
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
236+
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
237+
225238
# Allocate the Bundle struct
226239
memptr = VOID(Byref(obj._C_symbol), '**')
227240
alignment = obj._data_alignment
228241
nbytes = SizeOf(obj._C_typedata)
229-
alloc = self.langbb['host-alloc'](memptr, alignment, nbytes)
242+
alloc0 = self.langbb['host-alloc'](memptr, alignment, nbytes)
230243

231-
nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True)
232-
nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size
244+
# Allocate the shape array
245+
memptr = VOID(Byref(ffp1), '**')
246+
nbytes = SizeOf(obj._C_size_type)*obj.ndim
247+
alloc1 = self.langbb['host-alloc'](memptr, alignment, nbytes)
233248

234249
# Initialize the Bundle struct
235-
ffp1 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
236-
init0 = DummyExpr(ffp1, nbytes_param)
237-
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
238-
init1 = DummyExpr(ffp2, 0)
250+
init = [*[DummyExpr(IndexedPointer(ffp1, i), s)
251+
for i, s in enumerate(obj.c0.symbolic_shape)],
252+
DummyExpr(ffp2, obj.size),
253+
DummyExpr(ffp3, ffp2*arity_param)]
239254

240-
free = self.langbb['host-free'](obj._C_symbol)
255+
# Free all of the allocated data
256+
frees = [self.langbb['host-free'](ffp1),
257+
self.langbb['host-free'](obj._C_symbol)]
241258

242259
ret = Return(obj._C_symbol)
243260

244261
# Wrap everything in a Callable so that we can reuse the same code
245262
# for equivalent Bundle structs
246263
name = self.sregistry.make_name(prefix='alloc')
247-
body = (decl, alloc, init0, init1, ret)
264+
body = (decl, alloc0, alloc1, *init, ret)
248265
efunc0 = make_callable(name, body, retval=obj)
249266
args = list(efunc0.parameters)
250-
args[args.index(nbytes_param)] = nbytes_arg
267+
args[args.index(arity_param)] = arity_arg
251268
alloc = Call(name, args, retobj=obj)
252269

253-
storage.update(obj, site, allocs=alloc, frees=free, efuncs=efunc0)
270+
# Same story for the frees
271+
name = self.sregistry.make_name(prefix='free')
272+
efunc1 = make_callable(name, frees)
273+
free = Call(name, efunc1.parameters)
274+
275+
storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1))
254276

255277
def _alloc_object_array_on_low_lat_mem(self, site, obj, storage):
256278
"""

devito/types/array.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ctypes import POINTER, Structure, c_void_p, c_ulong, c_uint64
1+
from ctypes import POINTER, Structure, c_void_p, c_uint64
22
from functools import cached_property
33

44
import numpy as np
@@ -10,7 +10,7 @@
1010
from devito.types.utils import CtypesFactory, DimensionTuple
1111

1212
__all__ = ['Array', 'ArrayMapped', 'ArrayObject', 'PointerArray', 'Bundle',
13-
'ComponentAccess', 'Bag']
13+
'ComponentAccess', 'Bag', 'BundleView']
1414

1515

1616
class ArrayBasic(AbstractFunction, LocalType):
@@ -60,6 +60,13 @@ def shape_allocated(self):
6060
def is_const(self):
6161
return self._is_const
6262

63+
@property
64+
def c0(self):
65+
# ArrayBasic can be used as a base class for tensorial objects (that is,
66+
# arrays whose components are AbstractFunctions). This property enables
67+
# treating the two cases uniformly in some lowering passes
68+
return self
69+
6370

6471
class Array(ArrayBasic):
6572

@@ -202,19 +209,27 @@ def _make_pointer(self, dim):
202209
return PointerArray(name='p%s' % self.name, dimensions=dim, array=self)
203210

204211

205-
class ArrayMapped(Array):
212+
class MappedArrayMixin:
206213

207214
_C_structname = 'array'
208215
_C_field_data = 'data'
209216
_C_field_dmap = 'dmap'
210-
_C_field_nbytes = 'nbytes'
217+
_C_field_shape = 'shape'
211218
_C_field_size = 'size'
219+
_C_field_nbytes = 'nbytes'
220+
221+
_C_size_type = c_uint64
212222

213223
_C_ctype = POINTER(type(_C_structname, (Structure,),
214224
{'_fields_': [(_C_field_data, c_restrict_void_p),
215-
(_C_field_nbytes, c_ulong),
216225
(_C_field_dmap, c_void_p),
217-
(_C_field_size, c_uint64)]}))
226+
(_C_field_shape, POINTER(_C_size_type)),
227+
(_C_field_size, _C_size_type),
228+
(_C_field_nbytes, _C_size_type)]}))
229+
230+
231+
class ArrayMapped(MappedArrayMixin, Array):
232+
pass
218233

219234

220235
class ArrayObject(ArrayBasic):
@@ -343,7 +358,7 @@ def array(self):
343358
return self._array
344359

345360

346-
class Bundle(ArrayBasic):
361+
class Bundle(MappedArrayMixin, ArrayBasic):
347362

348363
"""
349364
Tensor symbol representing an unrolled vector of AbstractFunctions.
@@ -417,7 +432,6 @@ def __halo_setup__(self, components=(), **kwargs):
417432

418433
@property
419434
def c0(self):
420-
# Shortcut for self.components[0]
421435
return self.components[0]
422436

423437
# Class attributes overrides
@@ -456,18 +470,26 @@ def ncomp(self):
456470
def initvalue(self):
457471
return None
458472

459-
# Overrides defaulting to self.c0's behaviour
460-
473+
# Defaulting to self.c0's behaviour
461474
for i in ('_mem_internal_eager', '_mem_internal_lazy', '_mem_local',
462475
'_mem_mapped', '_mem_host', '_mem_stack', '_mem_constant',
463476
'_mem_shared', '_mem_shared_remote', '__padding_dtype__',
464477
'_size_domain', '_size_halo', '_size_owned', '_size_padding',
465-
'_size_nopad', '_size_nodomain', '_offset_domain',
466-
'_offset_halo', '_offset_owned', '_dist_dimensions',
467-
'_C_get_field', 'grid', 'symbolic_shape',
478+
'_size_nopad', '_size_nodomain', '_offset_domain', '_offset_halo',
479+
'_offset_owned', '_dist_dimensions', '_C_get_field', 'grid',
468480
*AbstractFunction.__properties__):
469481
locals()[i] = property(lambda self, v=i: getattr(self.c0, v))
470482

483+
# Other overrides
484+
485+
@cached_property
486+
def symbolic_shape(self):
487+
from devito.symbolics import FieldFromPointer, IndexedPointer # noqa
488+
ffp = FieldFromPointer(self._C_field_shape, self._C_symbol)
489+
ret = [s if is_integer(s) else IndexedPointer(ffp, i)
490+
for i, s in enumerate(super().symbolic_shape)]
491+
return DimensionTuple(*ret, getters=self.dimensions)
492+
471493
@property
472494
def _mem_heap(self):
473495
return not any([self._mem_stack, self._mem_shared, self._mem_shared_remote])
@@ -490,17 +512,12 @@ def __getitem__(self, index):
490512
raise ValueError("Expected %d or %d indices, got %d instead"
491513
% (self.ndim, self.ndim + 1, len(index)))
492514

493-
_C_structname = ArrayMapped._C_structname
494-
_C_field_data = ArrayMapped._C_field_data
495-
_C_field_nbytes = ArrayMapped._C_field_nbytes
496-
_C_field_dmap = ArrayMapped._C_field_dmap
497-
_C_field_size = ArrayMapped._C_field_size
498-
499515
@property
500516
def _C_ctype(self):
501517
if self._mem_mapped:
502-
return ArrayMapped._C_ctype
518+
return super()._C_ctype
503519
else:
520+
#TODO DROP???
504521
return POINTER(dtype_to_ctype(self.dtype))
505522

506523

@@ -518,6 +535,31 @@ def handles(self):
518535
return self.components
519536

520537

538+
class BundleView(Bundle):
539+
540+
"""
541+
A BundleView is like a Bundle but it doesn't represent a concrete object
542+
in the generated code. It's used by the compiler to represent a subset
543+
of the components of a Bundle.
544+
"""
545+
546+
__rkwargs__ = Bundle.__rkwargs__ + ('parent',)
547+
548+
def __new__(cls, *args, parent=None, **kwargs):
549+
obj = super().__new__(cls, *args, **kwargs)
550+
obj._parent = parent
551+
552+
return obj
553+
554+
@property
555+
def parent(self):
556+
return self._parent
557+
558+
@property
559+
def handles(self):
560+
return (self.parent,)
561+
562+
521563
class ComponentAccess(Expr, Pickable):
522564

523565
_component_names = ('x', 'y', 'z', 'w')

devito/types/misc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
__all__ = ['Timer', 'Pointer', 'VolatileInt', 'FIndexed', 'Wildcard', 'Fence',
1616
'Global', 'Hyperplane', 'Indirection', 'Temp', 'TempArray', 'Jump',
17-
'nop', 'WeakFence', 'CriticalRegion', 'Auto', 'AutoRef', 'auto']
17+
'nop', 'WeakFence', 'CriticalRegion', 'Auto', 'AutoRef', 'auto',
18+
'size_t']
1819

1920

2021
class Timer(CompositeObject):
@@ -344,7 +345,9 @@ def closing(self):
344345
"""
345346

346347

347-
# *** CXX support types
348+
# *** C/CXX support types
349+
350+
size_t = CustomDtype('size_t')
348351

349352
# NOTE: In C++, `auto` is a type specifier more than a type itself, but
350353
# it's a distinction we can afford to ignore, at least for now

0 commit comments

Comments
 (0)