Skip to content

Commit 90aeff2

Browse files
committed
api: only allow per-component kwargs as Matrix to avoid conflicts
1 parent 8a10fbf commit 90aeff2

3 files changed

Lines changed: 9 additions & 20 deletions

File tree

devito/types/tensor.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,26 +80,15 @@ def __init_finalize__(self, *args, **kwargs):
8080
self._space_dimensions = inds
8181

8282
@classmethod
83-
def _component_kwargs(cls, *inds, **kwargs):
83+
def _component_kwargs(cls, inds, **kwargs):
8484
"""
8585
Get the kwargs for a single component
8686
from the kwargs of the TensorFunction.
8787
"""
8888
kw = {}
8989
for k, v in kwargs.items():
90-
if k in ('staggered', 'name', 'dimensions', 'shape'):
91-
# Standard Function kwargs
92-
kw[k] = v
93-
elif isinstance(v, MatrixBase):
94-
if len(inds) > 1:
95-
kw[k] = v[inds[0], inds[1]]
96-
else:
97-
kw[k] = v[inds[0]]
98-
elif isinstance(v, (list, tuple)):
99-
if len(inds) > 1:
100-
kw[k] = v[inds[0]][inds[1]]
101-
else:
102-
kw[k] = v[inds[0]]
90+
if isinstance(v, MatrixBase):
91+
kw[k] = v[inds]
10392
else:
10493
kw[k] = v
10594
return kw
@@ -135,7 +124,7 @@ def __subfunc_setup__(cls, *args, **kwargs):
135124
for j in range(start, stop):
136125
staggj = (stagg[i][j] if stagg is not None
137126
else (NODE if i == j else (d, dims[j])))
138-
sub_kwargs = cls._component_kwargs(i, j, **kwargs)
127+
sub_kwargs = cls._component_kwargs((i, j), **kwargs)
139128
sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}",
140129
'staggered': staggj})
141130
funcs2[j] = cls._sub_type(**sub_kwargs)

docker/Dockerfile.devito

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ ARG GROUP_ID=1000
1313

1414
################## Install devito ############################################
1515

16+
ENV PIP_USE_PEP517=1
1617
# Install pip dependencies
1718
RUN python3 -m venv /venv && \
18-
/venv/bin/pip install --no-cache-dir --upgrade pip && \
19+
/venv/bin/pip install --no-cache-dir --upgrade pip wheel setuptools && \
1920
/venv/bin/pip install --no-cache-dir jupyter && \
20-
/venv/bin/pip install --no-cache-dir --upgrade wheel setuptools && \
2121
ln -fs /app/nvtop/build/src/nvtop /venv/bin/nvtop
2222

2323
# Copy Devito

tests/test_tensors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import sympy
3-
from sympy import Rational
3+
from sympy import Rational, Matrix
44

55
import pytest
66

@@ -498,6 +498,6 @@ def test_diag(func1):
498498

499499
@pytest.mark.parametrize('func1', [TensorFunction, VectorFunction])
500500
def test_kwargs(func1):
501-
orders = [[1, 2], [3, 4]] if func1 is TensorFunction else [1, 2]
501+
orders = Matrix([[1, 2], [3, 4]]) if func1 is TensorFunction else Matrix([1, 2])
502502
f = func1(name="f", grid=Grid((5, 5)), space_order=orders, symmetric=False)
503-
assert f.space_order == sympy.Matrix(orders)
503+
assert f.space_order == orders

0 commit comments

Comments
 (0)