Skip to content

Commit ba010e2

Browse files
committed
api: allow spacing as grid input
1 parent 8d1995d commit ba010e2

2 files changed

Lines changed: 45 additions & 8 deletions

File tree

devito/types/grid.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ class Grid(CartesianDiscretization, ArgProvider):
150150
_default_dimensions = ('x', 'y', 'z')
151151

152152
def __init__(self, shape, extent=None, origin=None, dimensions=None,
153-
time_dimension=None, dtype=np.float32, subdomains=None,
154-
comm=None, topology=None):
153+
spacing=None, time_dimension=None, dtype=np.float32,
154+
subdomains=None, comm=None, topology=None):
155155
shape = as_tuple(shape)
156156

157157
# Create or pull the SpaceDimensions
@@ -193,9 +193,19 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None,
193193
self._topology = None
194194
self._distributor = Distributor(shape, dimensions, comm, self._topology)
195195

196-
# The physical extent
197-
extent = as_tuple(extent or tuple(1. for _ in self.shape))
198-
self._extent = tuple(dtype(e) for e in extent)
196+
# The physical extent and grid spacing
197+
if spacing is not None:
198+
self._spacing = tuple(dtype(s) for s in spacing)
199+
else:
200+
self._spacing = None
201+
202+
if extent is not None:
203+
self._extent = tuple(dtype(e) for e in extent)
204+
else:
205+
self._extent = None
206+
207+
if extent is None and spacing is None:
208+
raise ValueError("At least one of `extent` and `spacing` must be provided")
199209

200210
# The origin of the grid
201211
origin = as_tuple(origin or tuple(0. for _ in self.shape))
@@ -230,10 +240,13 @@ def __repr__(self):
230240
return 'Grid' + \
231241
f'[extent={self.extent}, shape={self.shape}, dimensions={self.dimensions}]'
232242

233-
@property
243+
@cached_property
234244
def extent(self):
235245
"""Physical extent of the domain in m."""
236-
return self._extent
246+
if self._extent is not None:
247+
return self._extent
248+
extent = ((np.array(self.shape) - 1)*np.array(self.spacing)).astype(self.dtype)
249+
return as_tuple(extent)
237250

238251
@property
239252
def origin(self):
@@ -293,9 +306,11 @@ def volume_cell(self):
293306
"""Volume of a single cell e.g h_x*h_y*h_z in 3D."""
294307
return prod(d.spacing for d in self.dimensions).subs(self.spacing_map)
295308

296-
@property
309+
@cached_property
297310
def spacing(self):
298311
"""Spacing between grid points in m."""
312+
if self._spacing is not None:
313+
return self._spacing
299314
spacing = (np.array(self.extent) / (np.array(self.shape) - 1)).astype(self.dtype)
300315
return as_tuple(spacing)
301316

tests/test_symbolics.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,28 @@ def test_real():
132132
assert s.is_imaginary is np.iscomplexobj(dtype(0))
133133

134134

135+
@pytest.mark.parametrize('spacing, extent, shape, expected, broken', [
136+
((0.5, 0.5), None, (11, 11), ((0.5, 0.5), (5.0, 5.0)), False),
137+
(None, (5.0, 5.0), (11, 11), ((0.5, 0.5), (5.0, 5.0)), False),
138+
((0.5, 0.5), (5.0, 5.0), (11, 11), ((0.5, 0.5), (5.0, 5.0)), False),
139+
(None, (.3, .3), (151, 146), ((0.002, 0.002), (.3, .3)), 'spacing'),
140+
((.002, .002), (.3, .3), (151, 146), ((0.002, 0.002), (.3, .3)), False),
141+
((.002, .002), None, (151, 146), ((0.002, 0.002), (.3, .3)), 'extent')
142+
])
143+
def test_grid_inputs(spacing, extent, shape, expected, broken):
144+
grid = Grid(shape=shape, spacing=spacing, extent=extent)
145+
sp, ex = expected
146+
if broken == 'spacing':
147+
assert grid.spacing != sp
148+
else:
149+
assert grid.spacing == sp
150+
151+
if broken == 'extent':
152+
assert grid.extent != ex
153+
else:
154+
assert grid.extent == ex
155+
156+
135157
def test_constant():
136158
c = Constant(name='c')
137159

0 commit comments

Comments
 (0)