Skip to content

Commit cea1b91

Browse files
committed
compiler: Add abstraction of efuncs w ComponentAccesses
1 parent 19cbe35 commit cea1b91

4 files changed

Lines changed: 126 additions & 20 deletions

File tree

devito/mpi/routines.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ def _make_halowait(self, f, hse, key, wait, msg=None):
740740

741741
parameters = list(f.handles) + list(fixed.values()) + [nb, msg]
742742

743-
return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
743+
return HaloWait(f'halowait{key}', iet, parameters)
744744

745745
def _call_halowait(self, name, f, hse, msg):
746746
nb = f.grid.distributor._obj_neighborhood
@@ -779,7 +779,7 @@ def _make_region(self, hs, key):
779779
def _make_msg(self, f, hse, key):
780780
# Only retain the halos required by the Diag scheme
781781
halos = sorted(i for i in hse.halos if isinstance(i.dim, tuple))
782-
return MPIMsgEnriched('msg%d' % key, f, halos)
782+
return MPIMsgEnriched(f'msg{key}', f, halos)
783783

784784
def _make_sendrecv(self, *args, **kwargs):
785785
return
@@ -868,7 +868,7 @@ def _make_halowait(self, f, hse, key, *args, msg=None):
868868
ncomms = Symbol(name='ncomms')
869869
iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)
870870
parameters = f.handles + tuple(fixed.values()) + (msg, ncomms)
871-
return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
871+
return HaloWait(f'halowait{key}', iet, parameters)
872872

873873
def _call_halowait(self, name, f, hse, msg):
874874
args = f.handles + tuple(hse.loc_indices.values()) + (msg, msg.npeers)
@@ -1050,9 +1050,11 @@ def __init__(self, name, body, parameters, bufg, bufs):
10501050

10511051

10521052
class HaloUpdate(MPICallable):
    """
    Marker type for halo-update routines.

    Adds nothing on top of MPICallable; it exists so that compiler passes can
    recognise these Callables through an `isinstance` check.
    """
10531054

1054-
def __init__(self, name, body, parameters):
1055-
super().__init__(name, body, parameters)
1055+
1056+
class HaloWait(MPICallable):
    """
    Marker type for the Callables built by `_make_halowait`.

    Adds nothing on top of MPICallable; it exists so that compiler passes can
    recognise these Callables through an `isinstance` check.
    """
10561058

10571059

10581060
class Remainder(ElementalFunction):
@@ -1254,12 +1256,14 @@ class MPIMsgEnriched(MPIMsg):
12541256
_C_field_ofsg = 'ofsg'
12551257
_C_field_from = 'fromrank'
12561258
_C_field_to = 'torank'
1259+
_C_field_components = 'components'
12571260

12581261
fields = MPIMsg.fields + [
12591262
(_C_field_ofss, POINTER(c_int)),
12601263
(_C_field_ofsg, POINTER(c_int)),
12611264
(_C_field_from, c_int),
1262-
(_C_field_to, c_int)
1265+
(_C_field_to, c_int),
1266+
(_C_field_components, POINTER(c_int)),
12631267
]
12641268

12651269
def _arg_defaults(self, allocator, alias=None, args=None):
@@ -1298,6 +1302,17 @@ def _arg_defaults(self, allocator, alias=None, args=None):
12981302
ofss.append(f._offset_owned[dim].left)
12991303
entry.ofss = (c_int*len(ofss))(*ofss)
13001304

1305+
# Track the component accesses for packing/unpacking as numbers
1306+
# representing the field being accessed (that is: .x -> 0, .y -> 1,
1307+
# .z -> 2, .w -> 3), if any
1308+
if isinstance(self.target, BundleView):
1309+
ncomp = self.target.ncomp
1310+
component_indices = self.target.component_indices
1311+
entry.components = (c_int*ncomp)(*component_indices)
1312+
elif self.target.is_Bundle:
1313+
ncomp = self.target.ncomp
1314+
entry.components = (c_int*ncomp)(*range(ncomp))
1315+
13011316
return {self.name: self.value}
13021317

13031318

devito/passes/iet/definitions.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,14 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
177177
"""
178178
decl = Definition(obj)
179179

180+
sizeof_dtypeN = SizeOf(obj.indexed._C_typedata)
181+
sizeof_dtype1 = SizeOf(obj.c0.indexed._C_typedata)
182+
180183
# NOTE: the `arity` is calculated such as `sizeof(float3)/sizeof(float)`
181184
# for portability reasons (since we don't know the size of compound
182185
# types a priori)
183186
arity_param = Symbol(name='arity', dtype=size_t)
184-
arity_arg = (SizeOf(obj.indexed._C_typedata) /
185-
SizeOf(obj.c0.indexed._C_typedata))
187+
arity_arg = sizeof_dtypeN / sizeof_dtype1
186188
ndims_param = Symbol(name='ndims', dtype=size_t)
187189
ndims_arg = obj.ndim
188190
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
@@ -217,7 +219,7 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
217219
dim,
218220
ndims_param - 1,
219221
))
220-
init.append(DummyExpr(ffp3, ffp2*arity_param))
222+
init.append(DummyExpr(ffp3, ffp2*arity_param*sizeof_dtype1))
221223
init.append(DummyExpr(ffp4, arity_param))
222224

223225
# Allocate the underlying host data
@@ -260,12 +262,14 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
260262
"""
261263
decl = Definition(obj)
262264

265+
sizeof_dtypeN = SizeOf(obj.indexed._C_typedata)
266+
sizeof_dtype1 = SizeOf(obj.c0.indexed._C_typedata)
267+
263268
# NOTE: the `arity` is calculated such as `sizeof(float3)/sizeof(float)`
264269
# for portability reasons (since we don't know the size of compound
265270
# types a priori)
266271
arity_param = Symbol(name='arity', dtype=size_t)
267-
arity_arg = (SizeOf(obj.indexed._C_typedata) /
268-
SizeOf(obj.c0.indexed._C_typedata))
272+
arity_arg = sizeof_dtypeN / sizeof_dtype1
269273
ndims_param = Symbol(name='ndims', dtype=size_t)
270274
ndims_arg = obj.ndim
271275
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
@@ -299,7 +303,7 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
299303
dim,
300304
ndims_param - 1,
301305
))
302-
init.append(DummyExpr(ffp3, ffp2*arity_param))
306+
init.append(DummyExpr(ffp3, ffp2*arity_param*sizeof_dtype1))
303307
init.append(DummyExpr(ffp4, arity_param))
304308

305309
# Free all of the allocated data

devito/passes/iet/engine.py

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
from collections import OrderedDict, defaultdict
22
from functools import partial, singledispatch, wraps
33

4-
from devito.ir.iet import (Call, ExprStmt, Iteration, SyncSpot, AsyncCallable,
5-
FindNodes, FindSymbols, MapNodes, MetaCall, Transformer,
6-
EntryFunction, ThreadCallable, Uxreplace,
7-
derive_parameters)
4+
import numpy as np
5+
from sympy import Mul
6+
7+
from devito.ir.iet import (
8+
Call, ExprStmt, Expression, Iteration, SyncSpot, AsyncCallable, FindNodes,
9+
FindSymbols, MapNodes, MetaCall, Transformer, EntryFunction,
10+
ThreadCallable, Uxreplace, derive_parameters
11+
)
812
from devito.ir.support import SymbolRegistry
913
from devito.mpi.distributed import MPINeighborhood
14+
from devito.mpi.routines import Gather, Scatter, HaloUpdate, HaloWait, MPIMsg
1015
from devito.passes import needs_transfer
11-
from devito.symbolics import FieldFromComposite, FieldFromPointer
16+
from devito.symbolics import (FieldFromComposite, FieldFromPointer, IndexedPointer,
17+
search)
1218
from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass
13-
from devito.types import (Array, Bundle, CompositeObject, Lock, IncrDimension,
14-
ModuloDimension, Indirection, Pointer, SharedData,
15-
ThreadArray, Temp, NPThreads, NThreadsBase, Wildcard)
19+
from devito.types import (
20+
Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension,
21+
ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp,
22+
NPThreads, NThreadsBase, Wildcard
23+
)
1624
from devito.types.args import ArgProvider
1725
from devito.types.dense import DiscreteFunction
1826
from devito.types.dimension import AbstractIncrDimension, BlockDimension
@@ -147,6 +155,7 @@ def apply(self, func, **kwargs):
147155
# Minimize code size
148156
if len(efuncs) > len(self.efuncs):
149157
efuncs = reuse_compounds(efuncs, self.sregistry)
158+
efuncs = abstract_component_accesses(efuncs)
150159
efuncs = reuse_efuncs(self.root, efuncs, self.sregistry)
151160

152161
self.efuncs = efuncs
@@ -316,6 +325,80 @@ def _(i, sregistry=None):
316325
return i._rebuild(pname=pname, cfields=cfields, ncfields=ncfields, function=None)
317326

318327

328+
def abstract_component_accesses(efuncs):
    """
    Generalise the buffer-copying routines used by halo exchanges by turning
    their ComponentAccesses into pointer arithmetic over a flattened view of
    the Bundle, where possible and useful.

    Only a HaloUpdate or HaloWait whose body contains exactly one
    Gather/Scatter node is considered; everything else is passed through
    untouched.

    Parameters
    ----------
    efuncs : dict
        Mapping from names to Callables. It is never modified.

    Returns
    -------
    dict
        A copy of `efuncs` in which each transformed caller, and the
        copy-buffer routine it invokes, has been replaced by its new version.
    """
    result = dict(efuncs)

    for name, caller in efuncs.items():
        if not isinstance(caller, (HaloUpdate, HaloWait)):
            continue

        found = FindNodes((Gather, Scatter)).visit(caller)
        if len(found) != 1:
            continue
        (call,) = found
        callee = efuncs[call.name]

        # Exactly one Bundle, one MPIMsg and one Dimension are expected. Any
        # other layout is one we do not recognise (e.g., an unsupported MPI
        # scheme?), in which case the pair is left as is
        try:
            (bundle,) = [p for p in callee.parameters if p.is_Bundle]
            (msg,) = [p for p in caller.parameters if isinstance(p, MPIMsg)]
            (dim,) = FindSymbols('dimensions').visit(caller)
        except ValueError:
            continue

        accesses = set()
        for node in FindNodes(Expression).visit(callee):
            accesses.update(search(node.expr, ComponentAccess))

        # Skip when there are no ComponentAccesses, or when their number
        # equals the number of components of the Bundle
        if not accesses or len(accesses) == bundle.ncomp:
            continue

        # Ordered by component index for deterministic codegen
        ordered = sorted(accesses, key=lambda acc: acc.index)

        # New parameters of the callee: the arity, plus one offset per
        # ComponentAccess
        arity_param = Symbol(name='arity', dtype=np.uint32, is_const=True)
        offsets = [Symbol(name=f'c{n}', dtype=np.uint32, is_const=True)
                   for n, _ in enumerate(ordered)]

        flat = bundle.func(name='flat_data', components=bundle.c0)

        # Each ComponentAccess maps to an access into `flat`: every index is
        # scaled by the arity, and the offset is added to the last one
        subs = {}
        for access, offset in zip(ordered, offsets):
            scaled = [Mul(arity_param, idx, evaluate=False)
                      for idx in access.indices]
            scaled[-1] = scaled[-1] + offset
            subs[access] = flat.indexed[scaled]

        # The new callee works on the flattened Bundle and takes the extra
        # parameters
        callee1 = Uxreplace(subs).visit(callee)
        parameters = list(callee1.parameters)
        parameters.extend([arity_param, *offsets])
        parameters[parameters.index(bundle)] = flat
        callee1 = callee1._rebuild(parameters=parameters)

        # The matching extra arguments: the arity field of the Bundle, and
        # the entries of the `components` field of the `dim`-th message
        arity_arg = FieldFromPointer(bundle._C_field_arity, bundle._C_symbol)
        compoff_args = []
        for n, _ in enumerate(offsets):
            components = FieldFromComposite(msg._C_field_components,
                                            IndexedPointer(msg, dim))
            compoff_args.append(IndexedPointer(components, n))

        arguments = [*call.arguments, arity_arg, *compoff_args]
        call1 = call._rebuild(arguments=arguments)
        caller1 = Transformer({call: call1}).visit(caller)

        result[name] = caller1
        result[callee.name] = callee1

    return result
400+
401+
319402
def reuse_efuncs(root, efuncs, sregistry=None):
320403
"""
321404
Generalise `efuncs` so that syntactically identical Callables may be dropped,

devito/types/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,10 @@ def parent(self):
567567
def handles(self):
568568
return (self.parent,)
569569

570+
@property
def component_indices(self):
    """
    The position of each of this object's components within the components
    of its parent, in the order in which `self.components` lists them.
    """
    positions = [self.parent.components.index(c) for c in self.components]
    return tuple(positions)
573+
570574

571575
class ComponentAccess(Expr, Pickable):
572576

0 commit comments

Comments
 (0)