Skip to content

Commit 23eabda

Browse files
committed
WIP WIP
1 parent 48a18cf commit 23eabda

1 file changed

Lines changed: 35 additions & 45 deletions

File tree

devito/passes/iet/engine.py

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from functools import partial, singledispatch, wraps
33

44
import numpy as np
5+
from sympy import Mul
56

67
from devito.ir.iet import (
78
Call, DummyExpr, ExprStmt, Expression, List, Iteration, SyncSpot,
89
AsyncCallable, FindNodes, FindSymbols, MapNodes, MetaCall, Transformer,
9-
EntryFunction, ThreadCallable, Uxreplace, derive_parameters
10+
EntryFunction, ThreadCallable, KernelLaunch, Uxreplace, derive_parameters
1011
)
1112
from devito.ir.support import SymbolRegistry
1213
from devito.mpi.distributed import MPINeighborhood
@@ -337,62 +338,51 @@ def abstract_component_accesses(root, efuncs, sregistry=None):
337338
if not isinstance(efunc, candidates):
338339
continue
339340

341+
# We have to run the pass when the KernelLaunches (if any) are visible,
342+
# otherwise it gets way too complicated
343+
calls = FindNodes(Call).visit(efunc)
344+
try:
345+
call, = [c for c in calls if isinstance(c, KernelLaunch)]
346+
except ValueError:
347+
continue
348+
kernel = efuncs[call.name]
349+
350+
# There's either one single ComponentAccess or it's not worth it (as
351+
# in the case of a whole-Bundle access)
340352
found = defaultdict(set)
341-
for e in FindNodes(Expression).visit(efunc):
353+
for e in FindNodes(Expression).visit(kernel):
342354
for v in search(e.expr, ComponentAccess):
343355
found[e].add(v)
344-
345-
# Is it a supported `efunc` structure?
346356
if len(found) != 1:
347357
continue
348358
expr, compaccs = found.popitem()
349-
if len(compaccs) != 1:
350-
# Pointless -- either none or a whole-Bundle access, so not worth it
351-
continue
359+
assert len(compaccs) == 1
352360

353361
ca, = compaccs
354362
f = ca.function
355363

356-
dtype = f.c0.indexed._C_ctype
357-
358364
# Access the same entry as `ca` but via pointer arithmetic instead
359-
from IPython import embed; embed()
360-
base = Symbol(name='base', dtype=dtype)
361-
ptr = Symbol(name='ptr', dtype=dtype)
362365
index = Symbol(name='index', dtype=np.uint32, is_const=True)
363-
364-
body = List(body=[
365-
DummyExpr(base, Cast(Byref(f.indexed[ca.indices]), dtype=dtype,
366-
reinterpret=True)),
367-
DummyExpr(ptr, base + index),
368-
Uxreplace({ca: Deref(ptr)}).visit(expr)
369-
])
370-
371-
efunc1 = Transformer({expr: body}).visit(efunc)
372-
efunc1 = efunc1._rebuild(parameters=(*efunc1.parameters, index))
373-
374-
processed[k] = (efunc1, ca)
375-
376-
if not processed:
377-
return efuncs
378-
379-
# Update the Call sites
380-
mapper = {}
381-
dag = create_call_graph(root.name, efuncs)
382-
for k, (efunc, ca) in processed.items():
383-
mapper[k] = efunc
384-
385-
for i in dag.downstream(k):
386-
caller = efuncs[i]
387-
388-
calls = [c for c in FindNodes(Call).visit(caller) if c.name == k]
389-
subs = {c: c._rebuild(arguments=(*c.arguments, ca.index))
390-
for c in calls}
391-
392-
mapper[i] = Transformer(subs).visit(caller)
393-
394-
# Update the efuncs
395-
efuncs = {**dict(efuncs), **mapper}
366+
indices = [Mul(f.ncomp, i, evaluate=False) for i in ca.indices] #TODO
367+
indices[-1] += index
368+
kernel1 = Uxreplace({ca: f.c0.indexed[indices]}).visit(kernel)
369+
370+
# Update the parameters list
371+
parameters = [*kernel1.parameters, index]
372+
parameters[parameters.index(f.indexed)] = f.c0.indexed
373+
kernel1 = kernel1._rebuild(parameters=parameters)
374+
375+
# Update the Call site
376+
args = [*call.arguments, ca.index]
377+
args[args.index(f.dmap)] = Cast(f.dmap, dtype=f.c0.indexed._C_ctype,
378+
reinterpret=True)
379+
call1 = call._rebuild(arguments=args)
380+
efunc1 = Transformer({call: call1}).visit(efunc)
381+
382+
# Store the new Callables
383+
processed.update({kernel1.name: kernel1, k: efunc1})
384+
385+
efuncs = {**efuncs, **processed}
396386

397387
return efuncs
398388

0 commit comments

Comments
 (0)