Skip to content

Commit 751b272

Browse files
committed
api: enforce valid coordinates for inject/interp
1 parent 1745806 commit 751b272

4 files changed

Lines changed: 42 additions & 3 deletions

File tree

devito/operations/interpolators.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,29 @@ def wrapper(interp, *args, **kwargs):
3434
return wrapper
3535

3636

37+
def check_coords(func):
38+
@wraps(func)
39+
def wrapper(interp, *args, **kwargs):
40+
inputs = args + as_tuple(kwargs.get('expr', ()))
41+
# Subfunction of the SparseFunction use to create the interpolator
42+
sfunc = interp.sfunction
43+
# Subfunctions found in the arguments of the interpolation/injection operation
44+
a_sfuncs = {f for f in retrieve_functions(inputs)
45+
if f.is_SparseFunction} - {sfunc}
46+
if a_sfuncs:
47+
# Check that is uses the same coordinates as the interpolator's SparseFunction
48+
subfuncs = {getattr(sfunc, s, None) for s in sfunc._sub_functions}
49+
for f in a_sfuncs:
50+
for s in f._sub_functions:
51+
if getattr(f, s, None) not in subfuncs:
52+
raise ValueError(f"Interpolation/injection with {sfunc}"
53+
f"requires {f} "
54+
f"to use the same {s} as {sfunc}")
55+
56+
return func(interp, *args, **kwargs)
57+
return wrapper
58+
59+
3760
def _extract_subdomain(variables):
3861
"""
3962
Check if any of the variables provided are defined on a SubDomain
@@ -322,6 +345,7 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None
322345
return idx_subs, temps
323346

324347
@check_radius
348+
@check_coords
325349
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
326350
"""
327351
Generate equations interpolating an arbitrary expression into ``self``.
@@ -342,6 +366,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
342366
return Interpolation(expr, increment, implicit_dims, self_subs, self)
343367

344368
@check_radius
369+
@check_coords
345370
def inject(self, field, expr, implicit_dims=None):
346371
"""
347372
Generate equations injecting an arbitrary expression into a field.

devito/operator/operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,8 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
647647
else:
648648
args[k] = args.unique(k, candidate=v)
649649

650-
kwargs['args'] = args.reduce_inplace()
650+
args.reduce_inplace()
651+
kwargs['args'] = args
651652

652653
for i in discretizations:
653654
args.update(i._arg_values(**kwargs))

devito/tools/data_structures.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,6 @@ def reduce_inplace(self):
223223
for k, v in self.reduce_all().items():
224224
self[k] = v
225225

226-
return self
227-
228226

229227
class DefaultOrderedDict(OrderedDict):
230228
# Source: http://stackoverflow.com/a/6190500/562769

tests/test_interpolation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,3 +1255,18 @@ def test_inject_subdomain_mpi(self, mode):
12551255
assert data1 == None # noqa
12561256
assert data2 == None # noqa
12571257
assert data3 == None # noqa
1258+
1259+
1260+
def test_wrong_coords():
1261+
grid = Grid(shape=(11, 11))
1262+
s = SparseFunction(name='src', npoint=1, grid=grid)
1263+
s2 = SparseFunction(name='src2', npoint=1, grid=grid)
1264+
u = Function(name='u', grid=grid)
1265+
1266+
with pytest.raises(ValueError) as vinfo:
1267+
s.inject(u, expr=s2)
1268+
assert "Interpolation/injection with" in str(vinfo.value)
1269+
1270+
with pytest.raises(ValueError) as vinfo:
1271+
s.interpolate(u + s2)
1272+
assert "Interpolation/injection with" in str(vinfo.value)

0 commit comments

Comments
 (0)