Skip to content

Commit 0d6ad67

Browse files
committed
sympy: Possible sympy bug workaround
1 parent 4c330f7 commit 0d6ad67

1 file changed

Lines changed: 67 additions & 0 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,73 @@ def is_Staggered(self):
113113
def is_TimeDependent(self):
114114
return any(i.is_Time for i in self.dimensions)
115115

116+
def as_independent(self, *deps, **hint):
117+
"""
118+
A near copy of sympy.core.expr.Expr.as_independent
119+
with a bug fixed
120+
"""
121+
from sympy import Symbol
122+
from sympy.core.add import _unevaluated_Add
123+
from sympy.core.mul import _unevaluated_Mul
124+
125+
from sympy.core.singleton import S
126+
from sympy.utilities.iterables import sift
127+
128+
if self is S.Zero:
129+
return (self, self)
130+
131+
func = self.func
132+
if hint.get('as_Add', isinstance(self, Add) ):
133+
want = Add
134+
else:
135+
want = Mul
136+
137+
# sift out deps into symbolic and other and ignore
138+
# all symbols but those that are in the free symbols
139+
sym = set()
140+
other = []
141+
for d in deps:
142+
if isinstance(d, Symbol): # Symbol.is_Symbol is True
143+
sym.add(d)
144+
else:
145+
other.append(d)
146+
147+
def has(e):
148+
"""return the standard has() if there are no literal symbols, else
149+
check to see that symbol-deps are in the free symbols."""
150+
has_other = e.has(*other)
151+
if not sym:
152+
return has_other
153+
return has_other or e.has(*(e.free_symbols & sym))
154+
155+
if (want is not func or
156+
not issubclass(func, Add) and not issubclass(func, Mul)):
157+
if has(self):
158+
return (want.identity, self)
159+
else:
160+
return (self, want.identity)
161+
else:
162+
if func is Add:
163+
args = list(self.args)
164+
else:
165+
args, nc = self.args_cnc()
166+
167+
d = sift(args, has)
168+
depend = d[True]
169+
indep = d[False]
170+
171+
if func is Add: # all terms were treated as commutative
172+
return (Add(*indep), _unevaluated_Add(*depend))
173+
else: # handle noncommutative by stopping at first dependent term
174+
for i, n in enumerate(nc):
175+
if has(n):
176+
depend.extend(nc[i:])
177+
break
178+
indep.append(n)
179+
return Mul(*indep), (
180+
Mul(*depend, evaluate=False) if nc else
181+
_unevaluated_Mul(*depend))
182+
116183
@cached_property
117184
def _fd(self):
118185
# Filter out all args with fd order too high

0 commit comments

Comments
 (0)