Skip to content

Commit 8a10fbf

Browse files
committed
api: lif tensor component kwargs init to its own method
1 parent 658064e commit 8a10fbf

1 file changed

Lines changed: 31 additions & 23 deletions

File tree

devito/types/tensor.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,31 @@ def __init_finalize__(self, *args, **kwargs):
7979
inds, _ = Function.__indices_setup__(grid=grid, dimensions=dimensions)
8080
self._space_dimensions = inds
8181

82+
@classmethod
83+
def _component_kwargs(cls, *inds, **kwargs):
84+
"""
85+
Get the kwargs for a single component
86+
from the kwargs of the TensorFunction.
87+
"""
88+
kw = {}
89+
for k, v in kwargs.items():
90+
if k in ('staggered', 'name', 'dimensions', 'shape'):
91+
# Standard Function kwargs
92+
kw[k] = v
93+
elif isinstance(v, MatrixBase):
94+
if len(inds) > 1:
95+
kw[k] = v[inds[0], inds[1]]
96+
else:
97+
kw[k] = v[inds[0]]
98+
elif isinstance(v, (list, tuple)):
99+
if len(inds) > 1:
100+
kw[k] = v[inds[0]][inds[1]]
101+
else:
102+
kw[k] = v[inds[0]]
103+
else:
104+
kw[k] = v
105+
return kw
106+
82107
@classmethod
83108
def __subfunc_setup__(cls, *args, **kwargs):
84109
"""
@@ -110,18 +135,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
110135
for j in range(start, stop):
111136
staggj = (stagg[i][j] if stagg is not None
112137
else (NODE if i == j else (d, dims[j])))
113-
# Setup kwargs for subfunction
114-
# Through rebuilding or user input, the kwargs could be
115-
# Tensors as well from a per-component property
116-
sub_kwargs = {'name': f"{name}_{d.name}{dims[j].name}",
117-
'staggered': staggj}
118-
for k, v in kwargs.items():
119-
if isinstance(v, MatrixBase):
120-
sub_kwargs[k] = v[i, j]
121-
elif isinstance(v, (list, tuple)):
122-
sub_kwargs[k] = v[i][j]
123-
else:
124-
sub_kwargs[k] = v
138+
sub_kwargs = cls._component_kwargs(i, j, **kwargs)
139+
sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}",
140+
'staggered': staggj})
125141
funcs2[j] = cls._sub_type(**sub_kwargs)
126142
funcs.append(funcs2)
127143

@@ -335,17 +351,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
335351
stagg = kwargs.get("staggered", None)
336352
name = kwargs.get("name")
337353
for i, d in enumerate(dims):
338-
kwargs["name"] = "%s_%s" % (name, d.name)
339-
kwargs["staggered"] = stagg[i] if stagg is not None else d
340-
# Setup kwargs for subfunction
341-
# Through rebuilding or user input, the kwargs could be
342-
# Tensors as well from a per-component property
343-
sub_kwargs = {}
344-
for k, v in kwargs.items():
345-
if isinstance(v, (list, tuple, MatrixBase)):
346-
sub_kwargs[k] = v[i]
347-
else:
348-
sub_kwargs[k] = v
354+
sub_kwargs = cls._component_kwargs(i, **kwargs)
355+
sub_kwargs.update({'name': f"{name}_{d.name}",
356+
'staggered': stagg[i] if stagg is not None else d})
349357
funcs.append(cls._sub_type(**sub_kwargs))
350358

351359
return funcs

0 commit comments

Comments
 (0)