Skip to content

Commit 6d6f658

Browse files
committed
dsl: Add Re and Im operators for taking real and imaginary parts of an expression
1 parent 0c9f18b commit 6d6f658

5 files changed

Lines changed: 179 additions & 4 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights']
25+
'Weights', 'Re', 'Im']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -645,6 +645,39 @@ def __str__(self):
645645
__repr__ = __str__
646646

647647

648+
class ComplexPart(Differentiable, sympy.core.function.Application):
649+
"""Abstract class for `Re` or `Im` of an expression"""
650+
651+
def __new__(cls, *args, **kwargs):
652+
if len(args) != 1:
653+
raise ValueError(f"{cls.__name__} is constructed with exactly one arg;"
654+
f" {len(args)} were supplied.")
655+
656+
# Diffify any Add, Mul, etc which might be in the expression
657+
new_args = (diffify(args[0]),)
658+
659+
if not np.issubdtype(new_args[0].dtype, np.complexfloating):
660+
raise ValueError(f"{cls.__name__} requires a complex dtype,"
661+
f" not {new_args[0].dtype.__name__}.")
662+
663+
return super().__new__(cls, *new_args, **kwargs)
664+
665+
def __str__(self):
666+
return f"{self.__class__.__name__}({self.args[0]})"
667+
668+
__repr__ = __str__
669+
670+
671+
class Re(ComplexPart):
672+
"""Get the real part of an expression"""
673+
pass
674+
675+
676+
class Im(ComplexPart):
677+
"""Get the imaginary part of an expression"""
678+
pass
679+
680+
648681
class IndexSum(sympy.Expr, Evaluable):
649682

650683
"""

devito/ir/equations/equation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
99
Stencil, detect_io, detect_accesses)
1010
from devito.symbolics import IntDiv, limits_mapper, uxreplace
11-
from devito.tools import Pickable, Tag, frozendict
11+
from devito.tools import Pickable, Tag, frozendict, infer_dtype
1212
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
1313

1414
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
@@ -48,7 +48,12 @@ def directions(self):
4848

4949
@property
5050
def dtype(self):
51-
return self.lhs.dtype
51+
try:
52+
rhs_dtype = self.rhs.dtype
53+
except AttributeError:
54+
rhs_dtype = None
55+
56+
return infer_dtype({self.lhs.dtype, rhs_dtype} - {None})
5257

5358
@property
5459
def state(self):

devito/passes/iet/languages/C.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,13 @@ class CPrinter(BasePrinter, C99CodePrinter):
5555

5656
def _print_ImaginaryUnit(self, expr):
5757
return '_Complex_I'
58+
59+
def _print_Re(self, expr):
60+
"""Print an Re as an access into the second entry of a float array."""
61+
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
62+
f'({self._print(expr.args[0])})')
63+
64+
def _print_Im(self, expr):
65+
"""Print an Im as an access into the second entry of a float array."""
66+
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
67+
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
103103

104104
def _print_ImaginaryUnit(self, expr):
105105
return f'1i{self.prec_literal(expr).lower()}'
106+
# return '1i'
107+
108+
def _print_Re(self, expr):
109+
return f'{self._ns}real({self._print(expr.args[0])})'
110+
111+
def _print_Im(self, expr):
112+
return f'{self._ns}imag({self._print(expr.args[0])})'
106113

107114
def _print_Cast(self, expr):
108115
# The CXX recommended way to cast is to use static_cast

tests/test_symbolics.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max)
10+
Min, Max, Re, Im, switchconfig)
1111
from devito.finite_differences.differentiable import SafeInv, Weights
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -874,3 +874,123 @@ def test_assumptions(self, op, expr, assumptions, expected):
874874
assumptions = eval(assumptions)
875875
expected = eval(expected)
876876
assert evalrel(op, eqn, assumptions) == expected
877+
878+
879+
class TestComplexParts:
880+
# TODO: Add a cxx switchconfig
881+
def setup_basic(self, dtype):
882+
grid = Grid(shape=(5,), extent=(4.,))
883+
f = Function(name='f', grid=grid, dtype=dtype)
884+
f.data_with_halo[:] = np.arange(7) + 1j*np.arange(7, 14)[::-1]
885+
886+
f_real = Function(name='f_real', grid=grid)
887+
f_imag = Function(name='f_imag', grid=grid)
888+
return f, f_real, f_imag
889+
890+
def run_operator(self, eqs, cxx):
891+
if cxx:
892+
with switchconfig(language='CXX'):
893+
Operator(eqs)()
894+
else:
895+
Operator(eqs)()
896+
897+
@pytest.mark.parametrize('cxx', [False, True])
898+
def test_printing(self, cxx):
899+
f, f_real, f_imag = self.setup_basic(np.complex64)
900+
901+
eq_re = Eq(f_real, Re(f))
902+
eq_im = Eq(f_imag, Im(f))
903+
904+
if cxx:
905+
with switchconfig(language='CXX'):
906+
op = Operator([eq_re, eq_im])
907+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
908+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
909+
910+
else:
911+
op = Operator([eq_re, eq_im])
912+
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
913+
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
914+
915+
@pytest.mark.parametrize('cxx', [False, True])
916+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
917+
def test_trivial(self, cxx, dtype):
918+
f, f_real, f_imag = self.setup_basic(dtype)
919+
920+
eq_re = Eq(f_real, Re(f+1.))
921+
eq_im = Eq(f_imag, Im(f+1.))
922+
923+
self.run_operator([eq_re, eq_im], cxx)
924+
925+
rcheck = np.array([2., 3., 4., 5., 6.])
926+
icheck = np.array([12., 11., 10., 9., 8.])
927+
assert np.all(np.isclose(f_real.data, rcheck))
928+
assert np.all(np.isclose(f_imag.data, icheck))
929+
930+
@pytest.mark.parametrize('cxx', [False, True])
931+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
932+
def test_trivial_imag(self, cxx, dtype):
933+
f, f_real, f_imag = self.setup_basic(dtype)
934+
935+
eq_re = Eq(f_real, Re(f+1j))
936+
eq_im = Eq(f_imag, Im(f+1j))
937+
938+
self.run_operator([eq_re, eq_im], cxx)
939+
940+
rcheck = np.array([1., 2., 3., 4., 5.])
941+
icheck = np.array([13., 12., 11., 10., 9.])
942+
assert np.all(np.isclose(f_real.data, rcheck))
943+
assert np.all(np.isclose(f_imag.data, icheck))
944+
945+
@pytest.mark.parametrize('cxx', [False, True])
946+
def test_deriv(self, cxx):
947+
f, f_real, f_imag = self.setup_basic(np.complex64)
948+
949+
eq_re = Eq(f_real, Re(f.dx))
950+
eq_im = Eq(f_imag, Im(f.dx))
951+
952+
self.run_operator([eq_re, eq_im], cxx)
953+
954+
assert np.all(np.isclose(f_real.data, 1.))
955+
assert np.all(np.isclose(f_imag.data, -1.))
956+
957+
@pytest.mark.parametrize('cxx', [False, True])
958+
def test_outer_deriv(self, cxx):
959+
f, f_real, f_imag = self.setup_basic(np.complex64)
960+
961+
eq_re = Eq(f_real, Re(f).dx)
962+
eq_im = Eq(f_imag, Im(f).dx)
963+
964+
self.run_operator([eq_re, eq_im], cxx)
965+
966+
assert np.all(np.isclose(f_real.data, 1.))
967+
assert np.all(np.isclose(f_imag.data, -1.))
968+
969+
@pytest.mark.parametrize('cxx', [False, True])
970+
def test_mul(self, cxx):
971+
grid = Grid(shape=(5,))
972+
973+
f = Function(name='f', grid=grid, dtype=np.complex64)
974+
g = Function(name='g', grid=grid)
975+
h = Function(name='h', grid=grid, dtype=np.complex64)
976+
f.data[:] = 1 + 1j
977+
g.data[:] = 2
978+
h.data[:] = 2j
979+
980+
fg_re = Function(name='fg_re', grid=grid)
981+
fg_im = Function(name='fg_im', grid=grid)
982+
fh_re = Function(name='fh_re', grid=grid)
983+
fh_im = Function(name='fh_im', grid=grid)
984+
985+
eq_fg_re = Eq(fg_re, Re(f*g))
986+
eq_fg_im = Eq(fg_im, Im(f*g))
987+
eq_fh_re = Eq(fh_re, Re(f*h))
988+
eq_fh_im = Eq(fh_im, Im(f*h))
989+
990+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
991+
992+
assert np.all(np.isclose(fg_re.data, 2.))
993+
assert np.all(np.isclose(fg_im.data, 2.))
994+
995+
assert np.all(np.isclose(fh_re.data, -2.))
996+
assert np.all(np.isclose(fh_im.data, 2.))

0 commit comments

Comments
 (0)