Skip to content

Commit 54c5e49

Browse files
committed
api: fix interpolate with complex dtype
1 parent 342be17 commit 54c5e49

4 files changed

Lines changed: 41 additions & 3 deletions

File tree

devito/finite_differences/finite_difference.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ def first_derivative(expr, dim, fd_order, **kwargs):
157157

158158
def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients,
159159
expand, weights=None):
160-
if deriv_order == 0 and not expr.is_Add:
161-
print(expr, dim, fd_order)
162160
# Always expand time derivatives to avoid issue with buffering and streaming.
163161
# Time derivative are almost always short stencils and won't benefit from
164162
# unexpansion in the rare case the derivative is not evaluated for time stepping.

devito/ir/equations/equation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,19 @@ def __new__(cls, *args, **kwargs):
234234
shift = 0
235235
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
236236

237+
# Merge conditionals when possible. E.g if we have an implicit_dim
238+
# and there is a dimension with the same parent, we ca merged
239+
# its condition
240+
for d in input_expr.implicit_dims:
241+
if d not in conditionals:
242+
continue
243+
for cd in dict(conditionals):
244+
if cd.parent == d.parent and cd != d:
245+
cond = conditionals.pop(d)
246+
mode = cd.relation and d.relation
247+
conditionals[cd] = mode(cond, conditionals[cd])
248+
break
249+
237250
conditionals = frozendict(conditionals)
238251

239252
# Lower all Differentiable operations into SymPy operations

devito/passes/clusters/cse.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def retrieve_ctemps(exprs, mode='all'):
3636
return search(exprs, lambda expr: isinstance(expr, CTemp), mode, 'dfs')
3737

3838

39+
def cse_dtype(exprdtype, cdtype):
40+
"""
41+
Return the dtype of a CSE temporary given the dtype of the expression to be
42+
captured and the cluster's dtype.
43+
"""
44+
if np.issubdtype(cdtype, np.complexfloating):
45+
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
46+
else:
47+
# Real cluster, can safely promote to the largest precision
48+
return np.promote_types(exprdtype, cdtype).type
49+
50+
3951
@cluster_pass
4052
def cse(cluster, sregistry=None, options=None, **kwargs):
4153
"""
@@ -86,7 +98,7 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
8698
if cluster.is_fence:
8799
return cluster
88100

89-
make_dtype = lambda e: np.promote_types(e.dtype, dtype).type
101+
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
90102
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
91103

92104
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)

tests/test_interpolation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,21 @@ def test_point_symbol_types(dtype, expected):
855855
assert point_symbol.dtype is expected
856856

857857

858+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
859+
def test_interp_complex(dtype):
860+
grid = Grid((11, 11, 11))
861+
862+
sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype)
863+
sc.coordinates.data[:] = [.5, .5, .5]
864+
865+
fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype)
866+
fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape)
867+
opC = Operator([sc.interpolate(expr=fc)], name="OpC")
868+
opC()
869+
870+
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
871+
872+
858873
class SD0(SubDomain):
859874
name = 'sd0'
860875

0 commit comments

Comments
 (0)