Skip to content

Commit 48a18cf

Browse files
committed
WIP WIP
1 parent f3d7aca commit 48a18cf

3 files changed

Lines changed: 102 additions & 9 deletions

File tree

devito/passes/iet/definitions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,14 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
184184
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
185185
dimensions=(Dimension(name='d'),), scope='rvalue')
186186
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
187+
sizeofelem_param = Symbol(name='sizeofelem', dtype=size_t)
188+
sizeofelem_arg = 1 #TODO
187189

188190
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
189191
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
190192
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
191193
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
194+
ffp4 = FieldFromPointer(obj._C_field_sizeofelem, obj._C_symbol)
192195

193196
# Allocate the Array struct
194197
memptr = VOID(Byref(obj._C_symbol), '**')
@@ -213,6 +216,7 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
213216
ndims_param - 1,
214217
))
215218
init.append(DummyExpr(ffp3, ffp2*arity_param))
219+
init.append(DummyExpr(ffp4, sizeofelem_param))
216220

217221
# Allocate the underlying host data
218222
memptr = VOID(Byref(ffp0), '**')
@@ -237,6 +241,7 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
237241
args[args.index(arity_param)] = arity_arg
238242
args[args.index(ndims_param)] = ndims_arg
239243
args[args.index(shape_param)] = shape_arg
244+
args[args.index(sizeofelem_param)] = sizeofelem_arg #TODO
240245
alloc = Call(name, args, retobj=obj)
241246

242247
# Same story for the frees
@@ -261,10 +266,14 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
261266
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
262267
dimensions=(Dimension(name='d'),), scope='rvalue')
263268
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
269+
sizeofelem_param = Symbol(name='sizeofelem', dtype=size_t)
270+
sizeofelem_arg = 1 #TODO
271+
from IPython import embed; embed()
264272

265273
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
266274
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
267275
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
276+
ffp4 = FieldFromPointer(obj._C_field_sizeofelem, obj._C_symbol)
268277

269278
# Allocate the Bundle struct
270279
memptr = VOID(Byref(obj._C_symbol), '**')
@@ -305,6 +314,7 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
305314
args[args.index(arity_param)] = arity_arg
306315
args[args.index(ndims_param)] = ndims_arg
307316
args[args.index(shape_param)] = shape_arg
317+
args[args.index(sizeofelem_param)] = sizeofelem_arg #TODO
308318
alloc = Call(name, args, retobj=obj)
309319

310320
# Same story for the frees

devito/passes/iet/engine.py

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
from collections import OrderedDict, defaultdict
22
from functools import partial, singledispatch, wraps
33

4-
from devito.ir.iet import (Call, ExprStmt, Iteration, SyncSpot, AsyncCallable,
5-
FindNodes, FindSymbols, MapNodes, MetaCall, Transformer,
6-
EntryFunction, ThreadCallable, Uxreplace,
7-
derive_parameters)
4+
import numpy as np
5+
6+
from devito.ir.iet import (
7+
Call, DummyExpr, ExprStmt, Expression, List, Iteration, SyncSpot,
8+
AsyncCallable, FindNodes, FindSymbols, MapNodes, MetaCall, Transformer,
9+
EntryFunction, ThreadCallable, Uxreplace, derive_parameters
10+
)
811
from devito.ir.support import SymbolRegistry
912
from devito.mpi.distributed import MPINeighborhood
13+
from devito.mpi.routines import CopyBuffer
1014
from devito.passes import needs_transfer
11-
from devito.symbolics import FieldFromComposite, FieldFromPointer
15+
from devito.symbolics import (FieldFromComposite, FieldFromPointer, Byref, Cast,
16+
Deref, search)
1217
from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass
13-
from devito.types import (Array, Bundle, CompositeObject, Lock, IncrDimension,
14-
ModuloDimension, Indirection, Pointer, SharedData,
15-
ThreadArray, Temp, NPThreads, NThreadsBase, Wildcard)
18+
from devito.types import (
19+
Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension,
20+
ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp,
21+
NPThreads, NThreadsBase, Wildcard
22+
)
1623
from devito.types.args import ArgProvider
1724
from devito.types.dense import DiscreteFunction
1825
from devito.types.dimension import AbstractIncrDimension, BlockDimension
@@ -147,6 +154,7 @@ def apply(self, func, **kwargs):
147154
# Minimize code size
148155
if len(efuncs) > len(self.efuncs):
149156
efuncs = reuse_compounds(efuncs, self.sregistry)
157+
efuncs = abstract_component_accesses(self.root, efuncs, self.sregistry)
150158
efuncs = reuse_efuncs(self.root, efuncs, self.sregistry)
151159

152160
self.efuncs = efuncs
@@ -316,6 +324,79 @@ def _(i, sregistry=None):
316324
return i._rebuild(pname=pname, cfields=cfields, ncfields=ncfields, function=None)
317325

318326

327+
def abstract_component_accesses(root, efuncs, sregistry=None):
328+
"""
329+
Generalise `efuncs` by replacing ComponentAccesses with pointer arithmetic
330+
where possible and useful.
331+
"""
332+
candidates = (CopyBuffer,)
333+
334+
# Transform the candidate efuncs
335+
processed = {}
336+
for k, efunc in efuncs.items():
337+
if not isinstance(efunc, candidates):
338+
continue
339+
340+
found = defaultdict(set)
341+
for e in FindNodes(Expression).visit(efunc):
342+
for v in search(e.expr, ComponentAccess):
343+
found[e].add(v)
344+
345+
# Is it a supported `efunc` structure?
346+
if len(found) != 1:
347+
continue
348+
expr, compaccs = found.popitem()
349+
if len(compaccs) != 1:
350+
# Pointless -- either none or a whole-Bundle access, so not worth
351+
continue
352+
353+
ca, = compaccs
354+
f = ca.function
355+
356+
dtype = f.c0.indexed._C_ctype
357+
358+
# Access the same entry as `ca` but via pointer arithmetic instead
359+
from IPython import embed; embed()
360+
base = Symbol(name='base', dtype=dtype)
361+
ptr = Symbol(name='ptr', dtype=dtype)
362+
index = Symbol(name='index', dtype=np.uint32, is_const=True)
363+
364+
body = List(body=[
365+
DummyExpr(base, Cast(Byref(f.indexed[ca.indices]), dtype=dtype,
366+
reinterpret=True)),
367+
DummyExpr(ptr, base + index),
368+
Uxreplace({ca: Deref(ptr)}).visit(expr)
369+
])
370+
371+
efunc1 = Transformer({expr: body}).visit(efunc)
372+
efunc1 = efunc1._rebuild(parameters=(*efunc1.parameters, index))
373+
374+
processed[k] = (efunc1, ca)
375+
376+
if not processed:
377+
return efuncs
378+
379+
# Update the Call sites
380+
mapper = {}
381+
dag = create_call_graph(root.name, efuncs)
382+
for k, (efunc, ca) in processed.items():
383+
mapper[k] = efunc
384+
385+
for i in dag.downstream(k):
386+
caller = efuncs[i]
387+
388+
calls = [c for c in FindNodes(Call).visit(caller) if c.name == k]
389+
subs = {c: c._rebuild(arguments=(*c.arguments, ca.index))
390+
for c in calls}
391+
392+
mapper[i] = Transformer(subs).visit(caller)
393+
394+
# Update the efuncs
395+
efuncs = {**dict(efuncs), **mapper}
396+
397+
return efuncs
398+
399+
319400
def reuse_efuncs(root, efuncs, sregistry=None):
320401
"""
321402
Generalise `efuncs` so that syntactically identical Callables may be dropped,

devito/types/array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class MappedArrayMixin:
223223
_C_field_shape = 'shape'
224224
_C_field_size = 'size'
225225
_C_field_nbytes = 'nbytes'
226+
_C_field_sizeofelem = 'sizeofelem'
226227

227228
_C_size_type = c_uint64
228229

@@ -231,7 +232,8 @@ class MappedArrayMixin:
231232
(_C_field_dmap, c_void_p),
232233
(_C_field_shape, POINTER(_C_size_type)),
233234
(_C_field_size, _C_size_type),
234-
(_C_field_nbytes, _C_size_type)]}))
235+
(_C_field_nbytes, _C_size_type),
236+
(_C_field_sizeofelem, _C_size_type)]}))
235237

236238

237239
class ArrayMapped(MappedArrayMixin, Array):

0 commit comments

Comments
 (0)