Skip to content

Commit cc4a53f

Browse files
committed
compiler: Generalize ideriv lowering
1 parent 01255cf commit cc4a53f

1 file changed

Lines changed: 29 additions & 15 deletions

File tree

devito/passes/clusters/derivatives.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from sympy import S
55

6-
from devito.finite_differences import IndexDerivative
6+
from devito.finite_differences import IndexDerivative, Weights
77
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
88
from devito.passes.clusters.misc import fuse
99
from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace
@@ -91,17 +91,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):
9191

9292

9393
@_core.register(Symbol)
94-
@_core.register(Indexed)
9594
@_core.register(BasicWrapperMixin)
9695
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
9796
return expr, []
9897

9998

99+
@_core.register(Indexed)
100+
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
101+
if not isinstance(expr.function, Weights):
102+
return expr, []
103+
104+
# Lower or reuse a previously lowered Weights array
105+
sregistry = kwargs['sregistry']
106+
subs_user = kwargs['subs']
107+
108+
w0 = expr.function
109+
k = tuple(w0.weights)
110+
try:
111+
w = weights[k]
112+
except KeyError:
113+
name = sregistry.make_name(prefix='w')
114+
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
115+
initvalue = tuple(i.subs(subs_user) for i in k)
116+
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
117+
118+
rebuilt = expr._subs(w0.indexed, w.indexed)
119+
120+
return rebuilt, []
121+
122+
100123
@_core.register(IndexDerivative)
101124
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
102125
sregistry = kwargs['sregistry']
103126
options = kwargs['options']
104-
subs_user = kwargs['subs']
105127

106128
try:
107129
cbk0 = deriv_schedule_registry[options['deriv-schedule']]
@@ -114,18 +136,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
114136

115137
# Create the concrete Weights array, or reuse an already existing one
116138
# if possible
117-
name = sregistry.make_name(prefix='w')
118-
w0 = ideriv.weights.function
119-
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
120-
k = tuple(w0.weights)
121-
try:
122-
w = weights[k]
123-
except KeyError:
124-
initvalue = tuple(i.subs(subs_user) for i in k)
125-
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
139+
w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs)
126140

127141
# Replace the abstract Weights array with the concrete one
128-
subs = {w0.indexed: w.indexed}
142+
subs = {ideriv.weights.base: w.base}
129143
init = uxreplace(init, subs)
130144
ideriv = uxreplace(ideriv, subs)
131145

@@ -155,10 +169,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
155169
# NOTE: created before recurring so that we ultimately get a sound ordering
156170
try:
157171
s = reusables.pop()
158-
assert np.can_cast(s.dtype, dtype)
172+
assert np.can_cast(s.dtype, w.dtype)
159173
except KeyError:
160174
name = sregistry.make_name(prefix='r')
161-
s = Symbol(name=name, dtype=dtype)
175+
s = Symbol(name=name, dtype=w.dtype)
162176

163177
# Go inside `expr` and recursively lower any nested IndexDerivatives
164178
expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs)

0 commit comments

Comments
 (0)