|
1 | 1 | from collections import OrderedDict, defaultdict |
2 | 2 | from functools import partial, singledispatch, wraps |
3 | 3 |
|
4 | | -from devito.ir.iet import (Call, ExprStmt, Iteration, SyncSpot, AsyncCallable, |
5 | | - FindNodes, FindSymbols, MapNodes, MetaCall, Transformer, |
6 | | - EntryFunction, ThreadCallable, Uxreplace, |
7 | | - derive_parameters) |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from devito.ir.iet import ( |
| 7 | + Call, DummyExpr, ExprStmt, Expression, List, Iteration, SyncSpot, |
| 8 | + AsyncCallable, FindNodes, FindSymbols, MapNodes, MetaCall, Transformer, |
| 9 | + EntryFunction, ThreadCallable, Uxreplace, derive_parameters |
| 10 | +) |
8 | 11 | from devito.ir.support import SymbolRegistry |
9 | 12 | from devito.mpi.distributed import MPINeighborhood |
| 13 | +from devito.mpi.routines import CopyBuffer |
10 | 14 | from devito.passes import needs_transfer |
11 | | -from devito.symbolics import FieldFromComposite, FieldFromPointer |
| 15 | +from devito.symbolics import (FieldFromComposite, FieldFromPointer, Byref, Cast, |
| 16 | + Deref, search) |
12 | 17 | from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass |
13 | | -from devito.types import (Array, Bundle, CompositeObject, Lock, IncrDimension, |
14 | | - ModuloDimension, Indirection, Pointer, SharedData, |
15 | | - ThreadArray, Temp, NPThreads, NThreadsBase, Wildcard) |
| 18 | +from devito.types import ( |
| 19 | + Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension, |
| 20 | + ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp, |
| 21 | + NPThreads, NThreadsBase, Wildcard |
| 22 | +) |
16 | 23 | from devito.types.args import ArgProvider |
17 | 24 | from devito.types.dense import DiscreteFunction |
18 | 25 | from devito.types.dimension import AbstractIncrDimension, BlockDimension |
@@ -147,6 +154,7 @@ def apply(self, func, **kwargs): |
147 | 154 | # Minimize code size |
148 | 155 | if len(efuncs) > len(self.efuncs): |
149 | 156 | efuncs = reuse_compounds(efuncs, self.sregistry) |
| 157 | + efuncs = abstract_component_accesses(self.root, efuncs, self.sregistry) |
150 | 158 | efuncs = reuse_efuncs(self.root, efuncs, self.sregistry) |
151 | 159 |
|
152 | 160 | self.efuncs = efuncs |
@@ -316,6 +324,79 @@ def _(i, sregistry=None): |
316 | 324 | return i._rebuild(pname=pname, cfields=cfields, ncfields=ncfields, function=None) |
317 | 325 |
|
318 | 326 |
|
| 327 | +def abstract_component_accesses(root, efuncs, sregistry=None): |
| 328 | + """ |
| 329 | + Generalise `efuncs` by replacing ComponentAccesses with pointer arithmetic |
| 330 | + where possible and useful. |
| 331 | + """ |
| 332 | + candidates = (CopyBuffer,) |
| 333 | + |
| 334 | + # Transform the candidate efuncs |
| 335 | + processed = {} |
| 336 | + for k, efunc in efuncs.items(): |
| 337 | + if not isinstance(efunc, candidates): |
| 338 | + continue |
| 339 | + |
| 340 | + found = defaultdict(set) |
| 341 | + for e in FindNodes(Expression).visit(efunc): |
| 342 | + for v in search(e.expr, ComponentAccess): |
| 343 | + found[e].add(v) |
| 344 | + |
| 345 | + # Is it a supported `efunc` structure? |
| 346 | + if len(found) != 1: |
| 347 | + continue |
| 348 | + expr, compaccs = found.popitem() |
| 349 | + if len(compaccs) != 1: |
| 350 | + # Pointless -- either none or a whole-Bundle access, so not worth |
| 351 | + continue |
| 352 | + |
| 353 | + ca, = compaccs |
| 354 | + f = ca.function |
| 355 | + |
| 356 | + dtype = f.c0.indexed._C_ctype |
| 357 | + |
| 358 | + # Access the same entry as `ca` but via pointer arithmetic instead |
| 359 | + from IPython import embed; embed() |
| 360 | + base = Symbol(name='base', dtype=dtype) |
| 361 | + ptr = Symbol(name='ptr', dtype=dtype) |
| 362 | + index = Symbol(name='index', dtype=np.uint32, is_const=True) |
| 363 | + |
| 364 | + body = List(body=[ |
| 365 | + DummyExpr(base, Cast(Byref(f.indexed[ca.indices]), dtype=dtype, |
| 366 | + reinterpret=True)), |
| 367 | + DummyExpr(ptr, base + index), |
| 368 | + Uxreplace({ca: Deref(ptr)}).visit(expr) |
| 369 | + ]) |
| 370 | + |
| 371 | + efunc1 = Transformer({expr: body}).visit(efunc) |
| 372 | + efunc1 = efunc1._rebuild(parameters=(*efunc1.parameters, index)) |
| 373 | + |
| 374 | + processed[k] = (efunc1, ca) |
| 375 | + |
| 376 | + if not processed: |
| 377 | + return efuncs |
| 378 | + |
| 379 | + # Update the Call sites |
| 380 | + mapper = {} |
| 381 | + dag = create_call_graph(root.name, efuncs) |
| 382 | + for k, (efunc, ca) in processed.items(): |
| 383 | + mapper[k] = efunc |
| 384 | + |
| 385 | + for i in dag.downstream(k): |
| 386 | + caller = efuncs[i] |
| 387 | + |
| 388 | + calls = [c for c in FindNodes(Call).visit(caller) if c.name == k] |
| 389 | + subs = {c: c._rebuild(arguments=(*c.arguments, ca.index)) |
| 390 | + for c in calls} |
| 391 | + |
| 392 | + mapper[i] = Transformer(subs).visit(caller) |
| 393 | + |
| 394 | + # Update the efuncs |
| 395 | + efuncs = {**dict(efuncs), **mapper} |
| 396 | + |
| 397 | + return efuncs |
| 398 | + |
| 399 | + |
319 | 400 | def reuse_efuncs(root, efuncs, sregistry=None): |
320 | 401 | """ |
321 | 402 | Generalise `efuncs` so that syntactically identical Callables may be dropped, |
|
0 commit comments