@@ -685,8 +685,7 @@ def __setstate__(self, state):
685685 self ._distributor = SubDistributor (self )
686686
687687
688- class MultiSubDomain (AbstractSubDomain ):
689- # FIXME: Should subclass ArgProvider
688+ class MultiSubDomain (AbstractSubDomain , ArgProvider ):
690689
691690 """
692691 Abstract base class for types representing groups of SubDomains.
@@ -745,6 +744,26 @@ def _bounds_glb_to_loc(cls, dec, m, M):
745744 return bounds_m , bounds_M
746745
747746
747+ class MultiSubDomainFunction (Function ):
748+ def _arg_values (self , ** kwargs ):
749+ # TODO: Will want estimate-memory utility in due course
750+ new = kwargs .get (self .name , self )
751+
752+ # We support `op.apply(..., msd=msd1, ...)`
753+ if isinstance (new , MultiSubDomain ):
754+ # TODO: Homogenise this nomenclature throughout all MultiSubDomains
755+ new = new ._subfunction
756+
757+ return new ._arg_defaults (alias = self , ** kwargs ).reduce_all ()
758+
759+ def _arg_defaults (self , alias = None , ** kwargs ):
760+ # Either apply the overide here, or apply the default values
761+ # Will need to replace the function, but also adjust the implicit
762+ # dimension bounds -> this bit is trickier, since they only get numbered
763+ # during compilation
764+ pass
765+
766+
748767class SubDomainSet (MultiSubDomain ):
749768 """
750769 Class to define a set of N (a positive integer) subdomains.
@@ -875,9 +894,9 @@ def __subdomain_finalize_core__(self, grid):
875894 # of replacements without risking overwriting.
876895 i_dim = Dimension (f'n_{ str (id (self ))} ' )
877896 d_dim = DefaultDimension (name = 'd' , default_value = 2 * grid .dim )
878- sd_func = Function (name = self .name , grid = self ._grid ,
879- shape = (self ._n_domains , 2 * grid .dim ),
880- dimensions = (i_dim , d_dim ), dtype = np .int32 )
897+ sd_func = MultiSubDomainFunction (name = self .name , grid = self ._grid ,
898+ shape = (self ._n_domains , 2 * grid .dim ),
899+ dimensions = (i_dim , d_dim ), dtype = np .int32 )
881900
882901 dimensions = []
883902 for i , d in enumerate (grid .dimensions ):
@@ -892,6 +911,7 @@ def __subdomain_finalize_core__(self, grid):
892911 ))
893912
894913 self ._dimensions = tuple (dimensions )
914+ self ._subfunction = sd_func
895915
896916 def __subdomain_finalize__ (self ):
897917 self .__subdomain_finalize_core__ (self .grid )
0 commit comments