Skip to content

Commit 238068a

Browse files
committed
dsl: Further work on making MultiSubDomains overrideable
1 parent 1f754dd commit 238068a

3 files changed

Lines changed: 29 additions & 7 deletions

File tree

devito/operator/operator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
559559
Process runtime arguments passed to ``.apply()` and derive
560560
default values for any remaining arguments.
561561
"""
562+
from IPython import embed; embed()
562563
# Sanity check -- all user-provided keywords must be known to the Operator
563564
if not configuration['ignore-unknowns']:
564565
for k, v in kwargs.items():

devito/types/grid.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,7 @@ def __setstate__(self, state):
685685
self._distributor = SubDistributor(self)
686686

687687

688-
class MultiSubDomain(AbstractSubDomain):
689-
# FIXME: Should subclass ArgProvider
688+
class MultiSubDomain(AbstractSubDomain, ArgProvider):
690689

691690
"""
692691
Abstract base class for types representing groups of SubDomains.
@@ -745,6 +744,26 @@ def _bounds_glb_to_loc(cls, dec, m, M):
745744
return bounds_m, bounds_M
746745

747746

747+
class MultiSubDomainFunction(Function):
748+
def _arg_values(self, **kwargs):
749+
# TODO: Will want estimate-memory utility in due course
750+
new = kwargs.get(self.name, self)
751+
752+
# We support `op.apply(..., msd=msd1, ...)`
753+
if isinstance(new, MultiSubDomain):
754+
# TODO: Homogenise this nomenclature throughout all MultiSubDomains
755+
new = new._subfunction
756+
757+
return new._arg_defaults(alias=self, **kwargs).reduce_all()
758+
759+
def _arg_defaults(self, alias=None, **kwargs):
760+
# Either apply the overide here, or apply the default values
761+
# Will need to replace the function, but also adjust the implicit
762+
# dimension bounds -> this bit is trickier, since they only get numbered
763+
# during compilation
764+
pass
765+
766+
748767
class SubDomainSet(MultiSubDomain):
749768
"""
750769
Class to define a set of N (a positive integer) subdomains.
@@ -875,9 +894,9 @@ def __subdomain_finalize_core__(self, grid):
875894
# of replacements without risking overwriting.
876895
i_dim = Dimension(f'n_{str(id(self))}')
877896
d_dim = DefaultDimension(name='d', default_value=2*grid.dim)
878-
sd_func = Function(name=self.name, grid=self._grid,
879-
shape=(self._n_domains, 2*grid.dim),
880-
dimensions=(i_dim, d_dim), dtype=np.int32)
897+
sd_func = MultiSubDomainFunction(name=self.name, grid=self._grid,
898+
shape=(self._n_domains, 2*grid.dim),
899+
dimensions=(i_dim, d_dim), dtype=np.int32)
881900

882901
dimensions = []
883902
for i, d in enumerate(grid.dimensions):
@@ -892,6 +911,7 @@ def __subdomain_finalize_core__(self, grid):
892911
))
893912

894913
self._dimensions = tuple(dimensions)
914+
self._subfunction = sd_func
895915

896916
def __subdomain_finalize__(self):
897917
self.__subdomain_finalize_core__(self.grid)

tests/test_subdomains.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,8 +816,9 @@ def test_override(self):
816816
grid = Grid(shape=(6, 8))
817817
x, y = grid.dimensions
818818

819-
border0 = Border(grid, 1)
820-
border1 = Border(grid, 1, dims={x: 'left', y: 'right'}, corners='nocorners')
819+
border0 = Border(grid, 1, name='border0')
820+
border1 = Border(grid, 1, dims={x: 'left', y: 'right'},
821+
corners='nocorners', name='border1')
821822

822823
f0 = Function(name='f0', grid=grid, dtype=np.int32)
823824
f1 = Function(name='f1', grid=grid, dtype=np.int32)

0 commit comments

Comments
 (0)