Skip to content

Commit 4354077

Browse files
committed
dsl: Add options for including, excluding, or overlapping corners
1 parent fd79b2e commit 4354077

1 file changed

Lines changed: 54 additions & 3 deletions

File tree

devito/types/grid.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,7 @@ def bounds(self):
911911
class Border(SubDomainSet):
912912
"""
913913
A convenience class for constructing a SubDomainSet which covers specified edges
914-
of the domain to a thickness of `border`. Note that none of the subdomains in this
915-
MultiSubDomain will overlap with one another.
914+
of the domain to a thickness of `border`.
916915
917916
By default, this border covers all sides of the grid. Alternatively, it is possible
918917
to add the border selectively to specific sides by supplying, for example,
@@ -932,6 +931,10 @@ class Border(SubDomainSet):
932931
to borders on both sides of all dimensions.
933932
name : str, optional
934933
A unique name for the SubDomainSet created. Default is 'border'.
934+
corners : str, optional
935+
Behaviour at the corners. Can be set to 'overlap' for overlapping subdomains at
936+
the corners, 'nooverlap' for non-overlapping corner subdomains, or 'nocorners'
937+
to omit the corners entirely. Default is `nooverlap`.
935938
936939
Examples
937940
--------
@@ -1006,12 +1009,18 @@ class Border(SubDomainSet):
10061009
ParsedDimSpec = frozendict[Dimension, Dimension | str]
10071010

10081011
def __init__(self, grid: Grid, border: int | np.integer,
1009-
dims: DimSpec = None, name: str = 'border') -> None:
1012+
dims: DimSpec = None, name: str = 'border',
1013+
corners: str = 'nooverlap') -> None:
10101014

10111015
self._name = name
1016+
# FIXME: Needs to accept int, tuple of int, or tuple of tuple of int
10121017
self._border = border
10131018
self._border_dims = Border._parse_dims(dims, grid)
10141019

1020+
if corners not in ('overlap', 'nooverlap', 'nocorners'):
1021+
raise ValueError(f"Unrecognised corners option: {corners}")
1022+
self._corners = corners
1023+
10151024
ndomains, bounds = self._build_domains(grid)
10161025
super().__init__(N=ndomains, bounds=bounds, grid=grid)
10171026

@@ -1027,6 +1036,10 @@ def border_dims(self):
10271036
def name(self):
10281037
return self._name
10291038

1039+
@property
1040+
def corners(self):
1041+
return self._corners
1042+
10301043
@staticmethod
10311044
def _parse_dims(dims: DimSpec, grid: Grid) -> ParsedDimSpec:
10321045
if dims is None:
@@ -1045,7 +1058,41 @@ def _build_domains(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
10451058
"""
10461059
Constructs the bounds and ndomains kwargs for the SubDomainSet.
10471060
"""
1061+
if self.corners == 'overlap':
1062+
return self._build_domains_overlap(grid)
1063+
else:
1064+
return self._build_domains_nooverlap(grid)
1065+
1066+
def _build_domains_overlap(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
1067+
1068+
bounds = []
1069+
for i, (d, s) in enumerate(zip(grid.dimensions, grid.shape)):
1070+
1071+
if d in self.border_dims:
1072+
side = self.border_dims[d]
1073+
1074+
if isinstance(side, Dimension):
1075+
bounds_l = [0 if j != 2*i else s-self.border
1076+
for j in range(2*len(grid.dimensions))]
1077+
bounds_r = [0 if j != 2*i+1 else s-self.border
1078+
for j in range(2*len(grid.dimensions))]
10481079

1080+
bounds.extend([bounds_l, bounds_r])
1081+
1082+
elif side == 'left':
1083+
bounds.append([0 if j != 2*i else s-self.border
1084+
for j in range(2*len(grid.dimensions))])
1085+
1086+
elif side == 'right':
1087+
bounds.append([0 if j != 2*i+1 else s-self.border
1088+
for j in range(2*len(grid.dimensions))])
1089+
1090+
else:
1091+
raise ValueError(f"Unrecognised side value {side}")
1092+
1093+
return len(bounds), tuple(np.array(bounds))
1094+
1095+
def _build_domains_nooverlap(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
10491096
domain_map = {} # Stores the side
10501097
interval_map = {} # Stores the mapping from the side to bounds
10511098

@@ -1083,6 +1130,10 @@ def _build_domains(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
10831130
if all(i == CENTER for i in d):
10841131
abstract_domains.remove(d)
10851132

1133+
# If 'no corners' option selected, then remove any corners
1134+
if self.corners == 'nocorners' and not any(i == CENTER for i in d):
1135+
abstract_domains.remove(d)
1136+
10861137
domains = []
10871138
for dom in abstract_domains:
10881139
domains.append([interval_map[d][i]

0 commit comments

Comments
 (0)