|
1 | 1 | from collections import OrderedDict, defaultdict |
2 | 2 | from functools import partial, singledispatch, wraps |
3 | 3 |
|
4 | | -from devito.ir.iet import (Call, ExprStmt, Iteration, SyncSpot, AsyncCallable, |
5 | | - FindNodes, FindSymbols, MapNodes, MetaCall, Transformer, |
6 | | - EntryFunction, ThreadCallable, Uxreplace, |
7 | | - derive_parameters) |
| 4 | +import numpy as np |
| 5 | +from sympy import Mul |
| 6 | + |
| 7 | +from devito.ir.iet import ( |
| 8 | + Call, ExprStmt, Expression, Iteration, SyncSpot, AsyncCallable, FindNodes, |
| 9 | + FindSymbols, MapNodes, MetaCall, Transformer, EntryFunction, |
| 10 | + ThreadCallable, Uxreplace, derive_parameters |
| 11 | +) |
8 | 12 | from devito.ir.support import SymbolRegistry |
9 | 13 | from devito.mpi.distributed import MPINeighborhood |
| 14 | +from devito.mpi.routines import Gather, Scatter, HaloUpdate, HaloWait, MPIMsg |
10 | 15 | from devito.passes import needs_transfer |
11 | | -from devito.symbolics import FieldFromComposite, FieldFromPointer |
| 16 | +from devito.symbolics import (FieldFromComposite, FieldFromPointer, IndexedPointer, |
| 17 | + search) |
12 | 18 | from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass |
13 | | -from devito.types import (Array, Bundle, CompositeObject, Lock, IncrDimension, |
14 | | - ModuloDimension, Indirection, Pointer, SharedData, |
15 | | - ThreadArray, Temp, NPThreads, NThreadsBase, Wildcard) |
| 19 | +from devito.types import ( |
| 20 | + Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension, |
| 21 | + ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp, |
| 22 | + NPThreads, NThreadsBase, Wildcard |
| 23 | +) |
16 | 24 | from devito.types.args import ArgProvider |
17 | 25 | from devito.types.dense import DiscreteFunction |
18 | 26 | from devito.types.dimension import AbstractIncrDimension, BlockDimension |
@@ -147,6 +155,7 @@ def apply(self, func, **kwargs): |
147 | 155 | # Minimize code size |
148 | 156 | if len(efuncs) > len(self.efuncs): |
149 | 157 | efuncs = reuse_compounds(efuncs, self.sregistry) |
| 158 | + efuncs = abstract_component_accesses(efuncs) |
150 | 159 | efuncs = reuse_efuncs(self.root, efuncs, self.sregistry) |
151 | 160 |
|
152 | 161 | self.efuncs = efuncs |
@@ -316,6 +325,80 @@ def _(i, sregistry=None): |
316 | 325 | return i._rebuild(pname=pname, cfields=cfields, ncfields=ncfields, function=None) |
317 | 326 |
|
318 | 327 |
|
def abstract_component_accesses(efuncs):
    """
    Rewrite the buffer-copy Callables used by halo exchanges so that they
    index a flattened view of a Bundle (pointer arithmetic) rather than going
    through ComponentAccesses.

    Only the following caller/callee pairs are considered:

        1) HaloUpdate/Gather
        2) HaloWait/Scatter

    A pair is left untouched unless all of the following hold:

        * the HaloUpdate/HaloWait contains exactly one Gather/Scatter node;
        * exactly one Bundle parameter (callee), one MPIMsg parameter (caller)
          and one Dimension (caller) can be identified;
        * the callee contains at least one ComponentAccess, and the number of
          distinct ComponentAccesses differs from the Bundle's `ncomp`.

    Parameters
    ----------
    efuncs : dict
        Mapper from names to Callables.

    Returns
    -------
    dict
        A shallow copy of `efuncs` in which each transformed caller and its
        callee have been replaced. `efuncs` itself is not modified.
    """
    processed = dict(efuncs)

    for name, caller in efuncs.items():
        if not isinstance(caller, (HaloUpdate, HaloWait)):
            continue

        found = FindNodes((Gather, Scatter)).visit(caller)
        if len(found) != 1:
            continue
        call = found[0]
        callee = efuncs[call.name]

        # Each of the unpackings below demands exactly one match. A ValueError
        # signals a structure we don't recognize (e.g., an unsupported MPI
        # scheme?), in which case this caller is simply skipped
        try:
            bundle, = [p for p in callee.parameters if p.is_Bundle]
            msg, = [p for p in caller.parameters if isinstance(p, MPIMsg)]
            dim, = FindSymbols('dimensions').visit(caller)
        except ValueError:
            continue

        # Gather all distinct ComponentAccesses appearing in the callee
        compaccs = set().union(
            *[search(e.expr, ComponentAccess)
              for e in FindNodes(Expression).visit(callee)]
        )
        if not compaccs:
            continue
        if len(compaccs) == bundle.ncomp:
            continue

        # Sorted for deterministic codegen
        compaccs = sorted(compaccs, key=lambda ca: ca.index)

        # Extra const uint32 parameters for the callee: one `arity` plus one
        # offset `c<n>` per ComponentAccess
        arity_param = Symbol(name='arity', dtype=np.uint32, is_const=True)
        compoff_params = []
        for n, _ in enumerate(compaccs):
            compoff_params.append(
                Symbol(name=f'c{n}', dtype=np.uint32, is_const=True)
            )

        flat = bundle.func(name='flat_data', components=bundle.c0)

        # Access the same entry as `ca`, but via pointer arithmetic on `flat`:
        # all indices are multiplied by `arity` and the offset is added to the
        # last one
        subs = {}
        for ca, offset in zip(compaccs, compoff_params):
            indices = [Mul(arity_param, idx, evaluate=False)
                       for idx in ca.indices]
            indices[-1] = indices[-1] + offset
            subs[ca] = flat.indexed[indices]

        # Transform the callee to use the flattened representation of the
        # Bundle, both in its body and in its signature
        new_callee = Uxreplace(subs).visit(callee)
        parameters = list(new_callee.parameters)
        parameters.append(arity_param)
        parameters.extend(compoff_params)
        parameters[parameters.index(bundle)] = flat
        new_callee = new_callee._rebuild(parameters=parameters)

        # Update the caller so that the call supplies the extra arguments
        arguments = list(call.arguments)
        arguments.append(FieldFromPointer(bundle._C_field_arity,
                                          bundle._C_symbol))
        for n, _ in enumerate(compoff_params):
            components = FieldFromComposite(msg._C_field_components,
                                            IndexedPointer(msg, dim))
            arguments.append(IndexedPointer(components, n))
        new_call = call._rebuild(arguments=arguments)
        new_caller = Transformer({call: new_call}).visit(caller)

        processed[name] = new_caller
        processed[callee.name] = new_callee

    return processed
| 400 | + |
| 401 | + |
319 | 402 | def reuse_efuncs(root, efuncs, sregistry=None): |
320 | 403 | """ |
321 | 404 | Generalise `efuncs` so that syntactically identical Callables may be dropped, |
|
0 commit comments