Skip to content

Commit 658064e

Browse files
committed
api: fix kwargs processing for tensor functions
1 parent 3047085 commit 658064e

5 files changed

Lines changed: 55 additions & 13 deletions

File tree

devito/builtins/initializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
7979
eqs = [eq.xreplace(subs) for eq in eqs]
8080

8181
op = dv.Operator(eqs, name=name, **kwargs)
82+
8283
try:
8384
op()
8485
except ValueError:

devito/types/basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,15 @@ def __subfunc_setup__(cls, *args, **kwargs):
14891489
"""Setup each component of the tensor as a Devito type."""
14901490
return []
14911491

1492+
@classmethod
1493+
def _sympify(self, arg):
1494+
# This is used internally by sympy to process arguments at rebuilt. And since
1495+
# some of our properties are non-sympyfiable we need to have a fallback
1496+
try:
1497+
return super()._sympify(arg)
1498+
except sympy.SympifyError:
1499+
return arg
1500+
14921501
@property
14931502
def grid(self):
14941503
"""

devito/types/dense.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ def __init_finalize__(self, *args, **kwargs):
10191019

10201020
# Space order
10211021
space_order = kwargs.get('space_order', 1)
1022-
if isinstance(space_order, int):
1022+
if is_integer(space_order):
10231023
self._space_order = space_order
10241024
elif isinstance(space_order, tuple) and len(space_order) >= 2:
10251025
self._space_order = space_order[0]
@@ -1175,7 +1175,7 @@ def __halo_setup__(self, **kwargs):
11751175
halo = tuple(halo[d] for d in self.dimensions)
11761176
else:
11771177
space_order = kwargs.get('space_order', 1)
1178-
if isinstance(space_order, int):
1178+
if is_integer(space_order):
11791179
v = (space_order, space_order)
11801180
halo = [v if i.is_Space else (0, 0) for i in self.dimensions]
11811181

@@ -1208,12 +1208,12 @@ def __padding_setup__(self, **kwargs):
12081208
elif isinstance(padding, DimensionTuple):
12091209
padding = tuple(padding[d] for d in self.dimensions)
12101210

1211-
elif isinstance(padding, int):
1211+
elif is_integer(padding):
12121212
padding = tuple((0, padding) if d.is_Space else (0, 0)
12131213
for d in self.dimensions)
12141214

12151215
elif isinstance(padding, tuple) and len(padding) == self.ndim:
1216-
padding = tuple((0, i) if isinstance(i, int) else i for i in padding)
1216+
padding = tuple((0, i) if is_integer(i) else i for i in padding)
12171217

12181218
else:
12191219
raise TypeError("`padding` must be int or %d-tuple of ints" % self.ndim)
@@ -1398,7 +1398,7 @@ def __init_finalize__(self, *args, **kwargs):
13981398
self._time_order = kwargs.get('time_order', 1)
13991399
super().__init_finalize__(*args, **kwargs)
14001400

1401-
if not isinstance(self.time_order, int):
1401+
if not is_integer(self.time_order):
14021402
raise TypeError("`time_order` must be int")
14031403

14041404
self.save = kwargs.get('save')
@@ -1420,7 +1420,7 @@ def __indices_setup__(cls, *args, **kwargs):
14201420
time_dim = kwargs.get('time_dim')
14211421

14221422
if time_dim is None:
1423-
time_dim = grid.time_dim if isinstance(save, int) else grid.stepping_dim
1423+
time_dim = grid.time_dim if is_integer(save) else grid.stepping_dim
14241424
elif not (isinstance(time_dim, Dimension) and time_dim.is_Time):
14251425
raise TypeError("`time_dim` must be a time dimension")
14261426
dimensions = list(Function.__indices_setup__(**kwargs)[0])
@@ -1450,7 +1450,7 @@ def __shape_setup__(cls, **kwargs):
14501450
shape.insert(cls._time_position, time_order + 1)
14511451
elif isinstance(save, Buffer):
14521452
shape.insert(cls._time_position, save.val)
1453-
elif isinstance(save, int):
1453+
elif is_integer(save):
14541454
shape.insert(cls._time_position, save)
14551455
else:
14561456
raise TypeError("`save` can be None, int or Buffer, not %s" % type(save))

devito/types/tensor.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import cached_property
33

44
import numpy as np
5+
from sympy.matrices.matrixbase import MatrixBase
56
from sympy.core.sympify import converter as sympify_converter
67

78
from devito.finite_differences import Differentiable
@@ -107,10 +108,21 @@ def __subfunc_setup__(cls, *args, **kwargs):
107108
start = i if (symm or diag) else 0
108109
stop = i + 1 if diag else len(dims)
109110
for j in range(start, stop):
110-
kwargs["name"] = "%s_%s%s" % (name, d.name, dims[j].name)
111-
kwargs["staggered"] = (stagg[i][j] if stagg is not None
112-
else (NODE if i == j else (d, dims[j])))
113-
funcs2[j] = cls._sub_type(**kwargs)
111+
staggj = (stagg[i][j] if stagg is not None
112+
else (NODE if i == j else (d, dims[j])))
113+
# Setup kwargs for subfunction
114+
# Through rebuilding or user input, the kwargs could be
115+
# Tensors as well from a per-component property
116+
sub_kwargs = {'name': f"{name}_{d.name}{dims[j].name}",
117+
'staggered': staggj}
118+
for k, v in kwargs.items():
119+
if isinstance(v, MatrixBase):
120+
sub_kwargs[k] = v[i, j]
121+
elif isinstance(v, (list, tuple)):
122+
sub_kwargs[k] = v[i][j]
123+
else:
124+
sub_kwargs[k] = v
125+
funcs2[j] = cls._sub_type(**sub_kwargs)
114126
funcs.append(funcs2)
115127

116128
# Symmetrize and fill diagonal if symmetric
@@ -169,7 +181,11 @@ def root_dimensions(self):
169181
@cached_property
170182
def space_order(self):
171183
"""The space order for all components."""
172-
return ({a.space_order for a in self} - {None}).pop()
184+
orders = self.applyfunc(lambda x: x.space_order)
185+
if len(set(orders)) > 1:
186+
return orders
187+
else:
188+
return orders[0]
173189

174190
@property
175191
def is_diagonal(self):
@@ -321,7 +337,16 @@ def __subfunc_setup__(cls, *args, **kwargs):
321337
for i, d in enumerate(dims):
322338
kwargs["name"] = "%s_%s" % (name, d.name)
323339
kwargs["staggered"] = stagg[i] if stagg is not None else d
324-
funcs.append(cls._sub_type(**kwargs))
340+
# Setup kwargs for subfunction
341+
# Through rebuilding or user input, the kwargs could be
342+
# Tensors as well from a per-component property
343+
sub_kwargs = {}
344+
for k, v in kwargs.items():
345+
if isinstance(v, (list, tuple, MatrixBase)):
346+
sub_kwargs[k] = v[i]
347+
else:
348+
sub_kwargs[k] = v
349+
funcs.append(cls._sub_type(**sub_kwargs))
325350

326351
return funcs
327352

tests/test_tensors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,10 @@ def test_diag(func1):
494494
else:
495495
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
496496
assert all(f2[i, i] == f1 for i in range(3))
497+
498+
499+
@pytest.mark.parametrize('func1', [TensorFunction, VectorFunction])
500+
def test_kwargs(func1):
501+
orders = [[1, 2], [3, 4]] if func1 is TensorFunction else [1, 2]
502+
f = func1(name="f", grid=Grid((5, 5)), space_order=orders, symmetric=False)
503+
assert f.space_order == sympy.Matrix(orders)

0 commit comments

Comments
 (0)