Skip to content

Commit aafde92

Browse files
committed
api: allow spacing as grid input
1 parent 8d1995d commit aafde92

2 files changed

Lines changed: 43 additions & 8 deletions

File tree

devito/types/grid.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ class Grid(CartesianDiscretization, ArgProvider):
150150
_default_dimensions = ('x', 'y', 'z')
151151

152152
def __init__(self, shape, extent=None, origin=None, dimensions=None,
153-
time_dimension=None, dtype=np.float32, subdomains=None,
154-
comm=None, topology=None):
153+
spacing=None, time_dimension=None, dtype=np.float32,
154+
subdomains=None, comm=None, topology=None):
155155
shape = as_tuple(shape)
156156

157157
# Create or pull the SpaceDimensions
@@ -193,9 +193,16 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None,
193193
self._topology = None
194194
self._distributor = Distributor(shape, dimensions, comm, self._topology)
195195

196-
# The physical extent
197-
extent = as_tuple(extent or tuple(1. for _ in self.shape))
198-
self._extent = tuple(dtype(e) for e in extent)
196+
# The physical extent and grid spacing
197+
if spacing is not None:
198+
self._spacing = tuple(dtype(s) for s in as_tuple(spacing))
199+
else:
200+
self._spacing = None
201+
202+
if extent is not None:
203+
self._extent = tuple(dtype(e) for e in as_tuple(extent))
204+
else:
205+
self._extent = tuple(1. for _ in shape) if spacing is None else None
199206

200207
# The origin of the grid
201208
origin = as_tuple(origin or tuple(0. for _ in self.shape))
@@ -230,10 +237,13 @@ def __repr__(self):
230237
return 'Grid' + \
231238
f'[extent={self.extent}, shape={self.shape}, dimensions={self.dimensions}]'
232239

233-
@property
240+
@cached_property
234241
def extent(self):
235242
"""Physical extent of the domain in m."""
236-
return self._extent
243+
if self._extent is not None:
244+
return self._extent
245+
extent = ((np.array(self.shape) - 1)*np.array(self.spacing)).astype(self.dtype)
246+
return as_tuple(extent)
237247

238248
@property
239249
def origin(self):
@@ -293,9 +303,11 @@ def volume_cell(self):
293303
"""Volume of a single cell e.g h_x*h_y*h_z in 3D."""
294304
return prod(d.spacing for d in self.dimensions).subs(self.spacing_map)
295305

296-
@property
306+
@cached_property
297307
def spacing(self):
298308
"""Spacing between grid points in m."""
309+
if self._spacing is not None:
310+
return self._spacing
299311
spacing = (np.array(self.extent) / (np.array(self.shape) - 1)).astype(self.dtype)
300312
return as_tuple(spacing)
301313

tests/test_symbolics.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,29 @@ def test_real():
132132
assert s.is_imaginary is np.iscomplexobj(dtype(0))
133133

134134

135+
@pytest.mark.parametrize('spacing, extent, shape, expected, broken', [
136+
((0.5, 0.5), None, (11, 11), ((0.5, 0.5), (5.0, 5.0)), False),
137+
(None, (5.0, 5.0), (11, 11), ((0.5, 0.5), (5.0, 5.0)), False),
138+
((0.5, 0.5), (5.0, 5.0), (11, 11), ((0.5, 0.5), (5.0, 5.0)), False),
139+
(None, (.3, .3), (151, 146), ((0.002, 0.002), (.3, .3)), 'spacing'),
140+
((.002, .002), (.3, .3), (151, 146), ((0.002, 0.002), (.3, .3)), False),
141+
((.002, .002), None, (151, 146), ((0.002, 0.002), (.3, .3)), 'extent'),
142+
(None, None, (11, 11), ((.1, .1), (1.0, 1.0)), False),
143+
])
144+
def test_grid_inputs(spacing, extent, shape, expected, broken):
145+
grid = Grid(shape=shape, spacing=spacing, extent=extent)
146+
sp, ex = expected
147+
if broken == 'spacing':
148+
assert grid.spacing != sp
149+
else:
150+
assert np.allclose(grid.spacing, sp, atol=0, rtol=0)
151+
152+
if broken == 'extent':
153+
assert grid.extent != ex
154+
else:
155+
assert np.allclose(grid.extent, ex, atol=0, rtol=0)
156+
157+
135158
def test_constant():
136159
c = Constant(name='c')
137160

0 commit comments

Comments
 (0)