Skip to content

Commit aab0e3e

Browse files
committed
compiler: patch efuncs graph
1 parent ad6303f commit aab0e3e

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Iterable
33
from functools import cached_property
44
from itertools import chain
5+
import numbers
56

67
import sympy
78

@@ -158,11 +159,14 @@ def _validate_expr(expr):
158159
"""
159160
if type(expr) is sympy.Derivative:
160161
raise ValueError("Cannot nest sympy.Derivative with devito.Derivative")
162+
elif isinstance(expr, numbers.Number):
163+
return sympy.S.Zero
161164
if not isinstance(expr, Differentiable):
162165
try:
163166
expr = diffify(expr)
164167
except Exception as e:
165-
raise ValueError("`expr` must be a `Differentiable` type object") from e
168+
d = type(expr)
169+
raise ValueError(f"`expr` must be a `Differentiable` type object not {d}") from e
166170
return expr
167171

168172
@staticmethod

devito/passes/iet/engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,15 @@ def create_call_graph(root, efuncs):
223223
"""
224224
dag = DAG(nodes=[root])
225225
queue = [root]
226+
defuncs = {root}
226227

227228
while queue:
228229
caller = queue.pop(0)
229230
callees = FindNodes(Call).visit(efuncs[caller])
230231

231232
for callee in filter_ordered([i.name for i in callees]):
232233
if callee in efuncs: # Exclude foreign Calls, e.g., MPI calls
234+
defuncs.add(callee)
233235
try:
234236
dag.add_node(callee)
235237
queue.append(callee)
@@ -239,7 +241,7 @@ def create_call_graph(root, efuncs):
239241
dag.add_edge(callee, caller)
240242

241243
# Sanity check
242-
assert dag.size == len(efuncs)
244+
assert dag.size == len(defuncs)
243245

244246
return dag
245247

0 commit comments

Comments
 (0)