Skip to content

Commit 1f754dd

Browse files
committed
tests: Add a test for overrides
1 parent f4e427f commit 1f754dd

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

devito/types/grid.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ def __setstate__(self, state):
686686

687687

688688
class MultiSubDomain(AbstractSubDomain):
689+
# FIXME: Should subclass ArgProvider
689690

690691
"""
691692
Abstract base class for types representing groups of SubDomains.

tests/test_subdomains.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,8 @@ class Dummy(SubDomainSet):
726726

727727

728728
class TestBorder:
729+
# Note: This class is partially covered by doctests
730+
# TODO: Will need to add MPI tests, including overrides
729731
def test_exceptions(self):
730732
"""Test exceptions are raised for malformed specifications"""
731733
grid = Grid(shape=(5,))
@@ -741,7 +743,6 @@ def test_uneven_border(self, corners):
741743
"""Test border specifications which vary by dimension"""
742744
shape = (6, 8)
743745
grid = Grid(shape=shape)
744-
x, y = grid.dimensions
745746

746747
border = Border(grid, (1, (2, 1)), corners=corners)
747748

@@ -774,7 +775,6 @@ def test_one_sided_border(self, corners):
774775
grid = Grid(shape=shape)
775776
x, y = grid.dimensions
776777

777-
# border = Border(grid, 2, dims={x: x, y: 'right'})
778778
border = Border(grid, 1, dims={x: 'left', y: 'right'}, corners=corners)
779779

780780
f = Function(name='f', grid=grid, dtype=np.int32)
@@ -811,6 +811,27 @@ def test_border_3d(self):
811811

812812
assert np.all(f.data == check)
813813

814+
def test_override(self):
815+
"""Test overriding one border with another"""
816+
grid = Grid(shape=(6, 8))
817+
x, y = grid.dimensions
818+
819+
border0 = Border(grid, 1)
820+
border1 = Border(grid, 1, dims={x: 'left', y: 'right'}, corners='nocorners')
821+
822+
f0 = Function(name='f0', grid=grid, dtype=np.int32)
823+
f1 = Function(name='f1', grid=grid, dtype=np.int32)
824+
825+
op0 = Operator(Eq(f0, f0+1, subdomain=border0))
826+
827+
# Replace the border with a different one at runtime
828+
op0.apply(border0=border1)
829+
830+
# Compare to just running with the override border directly
831+
Operator(Eq(f1, f1+1, subdomain=border1))()
832+
833+
assert np.all(f0 == f1)
834+
814835

815836
class TestSubDomain_w_condition:
816837

0 commit comments

Comments
 (0)