Skip to content

Commit 4f41ec3

Browse files
committed
api: add support for diag(tensor) and diag(vector)
1 parent 29fc0c9 commit 4f41ec3

2 files changed

Lines changed: 38 additions & 3 deletions

File tree

devito/finite_differences/operators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,19 @@ def diag(func, size=None):
165165
size of the diagonal matrix (size x size).
166166
Defaults to the number of spatial dimensions when unspecified
167167
"""
168+
from devito.types.tensor import TensorFunction, TensorTimeFunction
169+
if isinstance(func, TensorFunction):
170+
if func.is_TensorValued:
171+
return func._new(*func.shape, lambda i, j: func[i, i] if i == j else 0)
172+
else:
173+
n = func.shape[0]
174+
return func._new(n, n, lambda i, j: func[i] if i == j else 0)
175+
168176
dim = size or len(func.dimensions)
169177
dim = dim-1 if func.is_TimeDependent else dim
170178
to = getattr(func, 'time_order', 0)
171179

172-
from devito.types.tensor import TensorFunction, TensorTimeFunction
173180
tens_func = TensorTimeFunction if func.is_TimeDependent else TensorFunction
174-
175181
comps = [[func if i == j else 0 for i in range(dim)] for j in range(dim)]
176182
return tens_func(name='diag', grid=func.grid, space_order=func.space_order,
177183
components=comps, time_order=to, diagonal=True)

tests/test_tensors.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import pytest
66

77
from devito import VectorFunction, TensorFunction, VectorTimeFunction, TensorTimeFunction
8-
from devito import Grid, Function, TimeFunction, Dimension, Eq, div, grad, curl, laplace
8+
from devito import (
9+
Grid, Function, TimeFunction, Dimension, Eq, div, grad, curl, laplace, diag
10+
)
911
from devito.symbolics import retrieve_derivatives
1012
from devito.types import NODE
1113

@@ -465,3 +467,30 @@ def test_rebuild(func1):
465467
assert j.name == i.name
466468
assert j.grid == i.grid
467469
assert j.dimensions == tuple(new_dims)
470+
471+
472+
@pytest.mark.parametrize('func1', [Function, TimeFunction,
473+
TensorFunction, TensorTimeFunction,
474+
VectorFunction, VectorTimeFunction])
475+
def test_diag(func1):
476+
grid = Grid(tuple([5]*3))
477+
f1 = func1(name="f1", grid=grid)
478+
479+
f2 = diag(f1)
480+
assert isinstance(f2, TensorFunction)
481+
if f1.is_TimeDependent:
482+
assert f2.is_TimeDependent
483+
print(f2)
484+
assert f2.shape == (3, 3)
485+
# Vector input
486+
if isinstance(f1, VectorFunction):
487+
assert all(f2[i, i] == f1[i] for i in range(3))
488+
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
489+
# Tensor input
490+
elif isinstance(f1, TensorFunction):
491+
assert all(f2[i, i] == f1[i, i] for i in range(3))
492+
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
493+
# Function input
494+
else:
495+
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
496+
assert all(f2[i, i] == f1 for i in range(3))

0 commit comments

Comments
 (0)