|
2 | 2 | from functools import partial, singledispatch, wraps |
3 | 3 |
|
4 | 4 | import numpy as np |
| 5 | +from sympy import Mul |
5 | 6 |
|
6 | 7 | from devito.ir.iet import ( |
7 | 8 | Call, DummyExpr, ExprStmt, Expression, List, Iteration, SyncSpot, |
8 | 9 | AsyncCallable, FindNodes, FindSymbols, MapNodes, MetaCall, Transformer, |
9 | | - EntryFunction, ThreadCallable, Uxreplace, derive_parameters |
| 10 | + EntryFunction, ThreadCallable, KernelLaunch, Uxreplace, derive_parameters |
10 | 11 | ) |
11 | 12 | from devito.ir.support import SymbolRegistry |
12 | 13 | from devito.mpi.distributed import MPINeighborhood |
@@ -337,62 +338,51 @@ def abstract_component_accesses(root, efuncs, sregistry=None): |
337 | 338 | if not isinstance(efunc, candidates): |
338 | 339 | continue |
339 | 340 |
|
| 341 | + # We have to run the pass when the KernelLaunches (if any) are visible, |
| 342 | + # otherwise it gets way too complicated |
| 343 | + calls = FindNodes(Call).visit(efunc) |
| 344 | + try: |
| 345 | + call, = [c for c in calls if isinstance(c, KernelLaunch)] |
| 346 | + except ValueError: |
| 347 | + continue |
| 348 | + kernel = efuncs[call.name] |
| 349 | + |
| 350 | + # There's either one single ComponentAccess or it's not worth it (as |
| 351 | + # in the case of a whole-Bundle access) |
340 | 352 | found = defaultdict(set) |
341 | | - for e in FindNodes(Expression).visit(efunc): |
| 353 | + for e in FindNodes(Expression).visit(kernel): |
342 | 354 | for v in search(e.expr, ComponentAccess): |
343 | 355 | found[e].add(v) |
344 | | - |
345 | | - # Is it a supported `efunc` structure? |
346 | 356 | if len(found) != 1: |
347 | 357 | continue |
348 | 358 | expr, compaccs = found.popitem() |
349 | | - if len(compaccs) != 1: |
350 | | - # Pointless -- either none or a whole-Bundle access, so not worth |
351 | | - continue |
| 359 | + assert len(compaccs) == 1 |
352 | 360 |
|
353 | 361 | ca, = compaccs |
354 | 362 | f = ca.function |
355 | 363 |
|
356 | | - dtype = f.c0.indexed._C_ctype |
357 | | - |
358 | 364 | # Access the same entry as `ca` but via pointer arithmetic instead |
359 | | - from IPython import embed; embed() |
360 | | - base = Symbol(name='base', dtype=dtype) |
361 | | - ptr = Symbol(name='ptr', dtype=dtype) |
362 | 365 | index = Symbol(name='index', dtype=np.uint32, is_const=True) |
363 | | - |
364 | | - body = List(body=[ |
365 | | - DummyExpr(base, Cast(Byref(f.indexed[ca.indices]), dtype=dtype, |
366 | | - reinterpret=True)), |
367 | | - DummyExpr(ptr, base + index), |
368 | | - Uxreplace({ca: Deref(ptr)}).visit(expr) |
369 | | - ]) |
370 | | - |
371 | | - efunc1 = Transformer({expr: body}).visit(efunc) |
372 | | - efunc1 = efunc1._rebuild(parameters=(*efunc1.parameters, index)) |
373 | | - |
374 | | - processed[k] = (efunc1, ca) |
375 | | - |
376 | | - if not processed: |
377 | | - return efuncs |
378 | | - |
379 | | - # Update the Call sites |
380 | | - mapper = {} |
381 | | - dag = create_call_graph(root.name, efuncs) |
382 | | - for k, (efunc, ca) in processed.items(): |
383 | | - mapper[k] = efunc |
384 | | - |
385 | | - for i in dag.downstream(k): |
386 | | - caller = efuncs[i] |
387 | | - |
388 | | - calls = [c for c in FindNodes(Call).visit(caller) if c.name == k] |
389 | | - subs = {c: c._rebuild(arguments=(*c.arguments, ca.index)) |
390 | | - for c in calls} |
391 | | - |
392 | | - mapper[i] = Transformer(subs).visit(caller) |
393 | | - |
394 | | - # Update the efuncs |
395 | | - efuncs = {**dict(efuncs), **mapper} |
| 366 | + indices = [Mul(f.ncomp, i, evaluate=False) for i in ca.indices] #TODO |
| 367 | + indices[-1] += index |
| 368 | + kernel1 = Uxreplace({ca: f.c0.indexed[indices]}).visit(kernel) |
| 369 | + |
| 370 | + # Update the parameters list |
| 371 | + parameters = [*kernel1.parameters, index] |
| 372 | + parameters[parameters.index(f.indexed)] = f.c0.indexed |
| 373 | + kernel1 = kernel1._rebuild(parameters=parameters) |
| 374 | + |
| 375 | + # Update the Call site |
| 376 | + args = [*call.arguments, ca.index] |
| 377 | + args[args.index(f.dmap)] = Cast(f.dmap, dtype=f.c0.indexed._C_ctype, |
| 378 | + reinterpret=True) |
| 379 | + call1 = call._rebuild(arguments=args) |
| 380 | + efunc1 = Transformer({call: call1}).visit(efunc) |
| 381 | + |
| 382 | + # Store the new Callables |
| 383 | + processed.update({kernel1.name: kernel1, k: efunc1}) |
| 384 | + |
| 385 | + efuncs = {**efuncs, **processed} |
396 | 386 |
|
397 | 387 | return efuncs |
398 | 388 |
|
|
0 commit comments