33from sympy import S
44import numpy as np
55
6- from devito .finite_differences import IndexDerivative
6+ from devito .finite_differences import IndexDerivative , Weights
77from devito .ir import Backward , Forward , Interval , IterationSpace , Queue
88from devito .passes .clusters .misc import fuse
99from devito .symbolics import BasicWrapperMixin , reuse_if_untouched , uxreplace
@@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):
9494
9595
9696@_core .register (Symbol )
97- @_core .register (Indexed )
9897@_core .register (BasicWrapperMixin )
9998def _ (expr , c , ispace , weights , reusables , mapper , ** kwargs ):
10099 return expr , []
101100
102101
102+ @_core .register (Indexed )
103+ def _ (expr , c , ispace , weights , reusables , mapper , ** kwargs ):
104+ if not isinstance (expr .function , Weights ):
105+ return expr , []
106+
107+ # Lower or reuse a previously lowered Weights array
108+ sregistry = kwargs ['sregistry' ]
109+ subs_user = kwargs ['subs' ]
110+
111+ w0 = expr .function
112+ k = tuple (w0 .weights )
113+ try :
114+ w = weights [k ]
115+ except KeyError :
116+ name = sregistry .make_name (prefix = 'w' )
117+ dtype = infer_dtype ([w0 .dtype , c .dtype ]) # At least np.float32
118+ initvalue = tuple (i .subs (subs_user ) for i in k )
119+ w = weights [k ] = w0 ._rebuild (name = name , dtype = dtype , initvalue = initvalue )
120+
121+ rebuilt = expr ._subs (w0 .indexed , w .indexed )
122+
123+ return rebuilt , []
124+
125+
103126@_core .register (IndexDerivative )
104127def _ (expr , c , ispace , weights , reusables , mapper , ** kwargs ):
105128 sregistry = kwargs ['sregistry' ]
106129 options = kwargs ['options' ]
107- subs_user = kwargs ['subs' ]
108130
109131 try :
110132 cbk0 = deriv_schedule_registry [options ['deriv-schedule' ]]
@@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
117139
118140 # Create the concrete Weights array, or reuse an already existing one
119141 # if possible
120- name = sregistry .make_name (prefix = 'w' )
121- w0 = ideriv .weights .function
122- dtype = infer_dtype ([w0 .dtype , c .dtype ]) # At least np.float32
123- k = tuple (w0 .weights )
124- try :
125- w = weights [k ]
126- except KeyError :
127- initvalue = tuple (i .subs (subs_user ) for i in k )
128- w = weights [k ] = w0 ._rebuild (name = name , dtype = dtype , initvalue = initvalue )
142+ w , _ = _core (ideriv .weights , c , ispace , weights , reusables , mapper , ** kwargs )
129143
130144 # Replace the abstract Weights array with the concrete one
131- subs = {w0 . indexed : w .indexed }
145+ subs = {ideriv . weights . base : w .base }
132146 init = uxreplace (init , subs )
133147 ideriv = uxreplace (ideriv , subs )
134148
@@ -158,10 +172,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
158172 # NOTE: created before recurring so that we ultimately get a sound ordering
159173 try :
160174 s = reusables .pop ()
161- assert np .can_cast (s .dtype , dtype )
175+ assert np .can_cast (s .dtype , w . dtype )
162176 except KeyError :
163177 name = sregistry .make_name (prefix = 'r' )
164- s = Symbol (name = name , dtype = dtype )
178+ s = Symbol (name = name , dtype = w . dtype )
165179
166180 # Go inside `expr` and recursively lower any nested IndexDerivatives
167181 expr , processed = _core (expr , c , ispace1 , weights , reusables , mapper , ** kwargs )
0 commit comments