Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def __new__(cls, *args, **kwargs):
# Initialization. The following attributes must be available
# when executing __init_finalize__
newobj._name = name
newobj._dimensions = dimensions
newobj._dimensions = DimensionTuple(*dimensions, getters=dimensions)
newobj._shape = cls.__shape_setup__(**kwargs)
newobj._dtype = cls.__dtype_setup__(**kwargs)

Expand Down
3 changes: 2 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,8 @@ def _eval_at(self, func):
for d in self.dimensions:
try:
if self.indices_ref[d] is not func.indices_ref[d]:
mapper[self.indices_ref[d]] = func.indices_ref[d]
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
mapper[self.indices_ref[d]] = f_idx
except KeyError:
pass

Expand Down
13 changes: 13 additions & 0 deletions tests/test_staggered_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,16 @@ def test_staggered_rebuild(stagg):
assert f2.indices[nd] == nd + nd.spacing / 2
else:
assert f2.indices[nd] == nd


def test_eval_at_different_dim():
grid = Grid(shape=(31, 17, 25))
nt = 5
x, _, _ = grid.dimensions

v = TimeFunction(name="v", grid=grid, staggered=x)
tau = TimeFunction(name="tau", grid=grid, save=nt)

eq = Eq(tau.forward, v).evaluate

assert grid.time_dim not in eq.rhs.free_symbols
Loading