Skip to content

Commit db47bfc

Browse files
committed
compiler: Generalize ideriv lowering
1 parent db02491 commit db47bfc

1 file changed

Lines changed: 29 additions & 15 deletions

File tree

devito/passes/clusters/derivatives.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sympy import S
44
import numpy as np
55

6-
from devito.finite_differences import IndexDerivative
6+
from devito.finite_differences import IndexDerivative, Weights
77
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
88
from devito.passes.clusters.misc import fuse
99
from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace
@@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):
9494

9595

9696
@_core.register(Symbol)
97-
@_core.register(Indexed)
9897
@_core.register(BasicWrapperMixin)
9998
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
10099
return expr, []
101100

102101

102+
@_core.register(Indexed)
103+
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
104+
if not isinstance(expr.function, Weights):
105+
return expr, []
106+
107+
# Lower or reuse a previously lowered Weights array
108+
sregistry = kwargs['sregistry']
109+
subs_user = kwargs['subs']
110+
111+
w0 = expr.function
112+
k = tuple(w0.weights)
113+
try:
114+
w = weights[k]
115+
except KeyError:
116+
name = sregistry.make_name(prefix='w')
117+
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
118+
initvalue = tuple(i.subs(subs_user) for i in k)
119+
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
120+
121+
rebuilt = expr._subs(w0.indexed, w.indexed)
122+
123+
return rebuilt, []
124+
125+
103126
@_core.register(IndexDerivative)
104127
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
105128
sregistry = kwargs['sregistry']
106129
options = kwargs['options']
107-
subs_user = kwargs['subs']
108130

109131
try:
110132
cbk0 = deriv_schedule_registry[options['deriv-schedule']]
@@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
117139

118140
# Create the concrete Weights array, or reuse an already existing one
119141
# if possible
120-
name = sregistry.make_name(prefix='w')
121-
w0 = ideriv.weights.function
122-
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
123-
k = tuple(w0.weights)
124-
try:
125-
w = weights[k]
126-
except KeyError:
127-
initvalue = tuple(i.subs(subs_user) for i in k)
128-
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
142+
w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs)
129143

130144
# Replace the abstract Weights array with the concrete one
131-
subs = {w0.indexed: w.indexed}
145+
subs = {ideriv.weights.base: w.base}
132146
init = uxreplace(init, subs)
133147
ideriv = uxreplace(ideriv, subs)
134148

@@ -158,10 +172,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
158172
# NOTE: created before recurring so that we ultimately get a sound ordering
159173
try:
160174
s = reusables.pop()
161-
assert np.can_cast(s.dtype, dtype)
175+
assert np.can_cast(s.dtype, w.dtype)
162176
except KeyError:
163177
name = sregistry.make_name(prefix='r')
164-
s = Symbol(name=name, dtype=dtype)
178+
s = Symbol(name=name, dtype=w.dtype)
165179

166180
# Go inside `expr` and recursively lower any nested IndexDerivatives
167181
expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs)

0 commit comments

Comments
 (0)