|
17 | 17 | from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite, |
18 | 18 | IndexedPointer, Macro, cast, subs_op_args) |
19 | 19 | from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize, |
20 | | - flatten, generator, is_integer, split) |
| 20 | + flatten, generator, is_integer) |
21 | 21 | from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, |
22 | 22 | CompositeObject, CustomDimension) |
23 | 23 |
|
@@ -292,19 +292,21 @@ def _make_bundles(self, hs): |
292 | 292 |
|
293 | 293 | mapper = as_mapper(halo_scheme.fmapper, lambda i: halo_scheme.fmapper[i]) |
294 | 294 | for hse, components in mapper.items(): |
295 | | - # We recast everything as Bags for simplicity -- worst case scenario |
296 | | - # all Bags only have one component. Existing Bundles are preserved |
297 | 295 | halo_scheme = halo_scheme.drop(components) |
298 | | - bundles, candidates = split(tuple(components), lambda i: i.is_Bundle) |
299 | | - for b in bundles: |
300 | | - halo_scheme = halo_scheme.add(b, hse) |
301 | 296 |
|
| 297 | + # Existing Bundles are preserved |
| 298 | + if hse.bundle and set(components) == set(hse.bundle.components): |
| 299 | + halo_scheme = halo_scheme.add(hse.bundle, hse) |
| 300 | + continue |
| 301 | + |
| 302 | + # We recast everything else as Bags for simplicity -- worst case |
| 303 | + # scenario all Bags only have one component. |
302 | 304 | try: |
303 | | - name = "bag_%s" % "".join(f.name for f in candidates) |
304 | | - bag = Bag(name=name, components=candidates) |
| 305 | + name = "bag_%s" % "".join(f.name for f in components) |
| 306 | + bag = Bag(name=name, components=components) |
305 | 307 | halo_scheme = halo_scheme.add(bag, hse) |
306 | 308 | except ValueError: |
307 | | - for i in candidates: |
| 309 | + for i in components: |
308 | 310 | name = "bag_%s" % i.name |
309 | 311 | bag = Bag(name=name, components=i) |
310 | 312 | halo_scheme = halo_scheme.add(bag, hse) |
@@ -363,10 +365,19 @@ def _make_copy(self, f, hse, key, swap=False): |
363 | 365 | else: |
364 | 366 | swap = lambda i, j: (j, i) |
365 | 367 | name = 'scatter%s' % key |
| 368 | + |
366 | 369 | if isinstance(f, Bag): |
367 | | - for i, c in enumerate(f.components): |
368 | | - eqns.append(Eq(*swap(buf[[i] + bdims], c[findices]))) |
| 370 | + if hse.bundle is not None: |
| 371 | + # `f` is the only component of `hse.bundle` that is |
| 372 | + # being communicated |
| 373 | + assert f.ncomp == 1 |
| 374 | + i = hse.bundle.components.index(f.c0) |
| 375 | + eqns.append(Eq(*swap(buf[[0] + bdims], hse.bundle[[i] + findices]))) |
| 376 | + else: |
| 377 | + for i, c in enumerate(f.components): |
| 378 | + eqns.append(Eq(*swap(buf[[i] + bdims], c[findices]))) |
369 | 379 | else: |
| 380 | + assert f.is_Bundle |
370 | 381 | for i in range(f.ncomp): |
371 | 382 | eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices]))) |
372 | 383 |
|
|
0 commit comments