Skip to content

Commit bead895

Browse files
committed
compiler: Enforce indexed weights as rightmost arg
1 parent edc9288 commit bead895

3 files changed

Lines changed: 20 additions & 4 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
2020
infer_dtype, extract_dtype, is_integer, split, is_number)
2121
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension
22-
from devito.types.basic import AbstractFunction
22+
from devito.types.basic import AbstractFunction, Indexed
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
2525
'Weights', 'Real', 'Imag', 'Conj']
@@ -769,12 +769,22 @@ def free_symbols(self):
769769
func = DifferentiableOp._rebuild
770770

771771

772+
class WeightsIndexed(Indexed):
773+
pass
774+
775+
772776
class Weights(Array):
773777

774778
"""
775779
The weights (or coefficients) of a finite-difference expansion.
776780
"""
777781

782+
# Use IndexedWeights for the underlying Indexed objects because they
783+
# are guaranteed to appear at the end on an expression's .args.
784+
# This makes it dramatically easier to implement substutions. It also makes
785+
# it easier to visually parse IndexDerivatives when looking at them
786+
_indexed_cls = WeightsIndexed
787+
778788
def __init_finalize__(self, *args, **kwargs):
779789
dimensions = as_tuple(kwargs.get('dimensions'))
780790
weights = kwargs.get('initvalue')

devito/types/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,13 @@ def sort_key(self, order=None):
16631663

16641664
def __getitem__(self, indices, **kwargs):
16651665
"""Produce a types.Indexed, rather than a sympy.Indexed."""
1666-
return Indexed(self, *as_tuple(indices))
1666+
# Is there a specific Indexed class to use?
1667+
try:
1668+
cls = self.function._indexed_cls
1669+
except AttributeError:
1670+
cls = Indexed
1671+
1672+
return cls(self, *as_tuple(indices))
16671673

16681674
def _hashable_content(self):
16691675
return super()._hashable_content() + (self.function,)

tests/test_symbolics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,8 @@ def test_canonical_ordering_of_weights():
558558

559559
assert (ccode(1.0*f[x, y, z] + 2.0*f[x, y + 1, z] + 3.0*f[x, y + 2, z]) ==
560560
'1.0F*f[x][y][z] + 2.0F*f[x][y + 1][z] + 3.0F*f[x][y + 2][z]')
561-
assert ccode(fi*wi) == 'w0[i0]*f[x][y + i0][z]'
562-
assert ccode(cf*wi) == 'w0[i0]*f[x][y + i0][z].x'
561+
assert ccode(fi*wi) == 'f[x][y + i0][z]*w0[i0]'
562+
assert ccode(cf*wi) == 'f[x][y + i0][z].x*w0[i0]'
563563

564564

565565
def test_symbolic_printing():

0 commit comments

Comments
 (0)