Skip to content

Commit 2495bb8

Browse files
committed
tests: Update Re and Im tests to be more comprehensive
1 parent 4bc9b61 commit 2495bb8

2 files changed

Lines changed: 25 additions & 30 deletions

File tree

devito/passes/iet/languages/CXX.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
103103

104104
def _print_ImaginaryUnit(self, expr):
105105
return f'1i{self.prec_literal(expr).lower()}'
106-
# return '1i'
107106

108107
def _print_Re(self, expr):
109108
return f'{self._ns}real({self._print(expr.args[0])})'

tests/test_symbolics.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,6 @@ def test_assumptions(self, op, expr, assumptions, expected):
877877

878878

879879
class TestComplexParts:
880-
# TODO: Add a cxx switchconfig
881880
def setup_basic(self, dtype):
882881
grid = Grid(shape=(5,), extent=(4.,))
883882
f = Function(name='f', grid=grid, dtype=dtype)
@@ -887,87 +886,84 @@ def setup_basic(self, dtype):
887886
f_imag = Function(name='f_imag', grid=grid)
888887
return f, f_real, f_imag
889888

890-
def run_operator(self, eqs, cxx):
891-
if cxx:
892-
with switchconfig(language='CXX'):
893-
Operator(eqs)()
894-
else:
889+
def run_operator(self, eqs, language):
890+
with switchconfig(language=language):
895891
Operator(eqs)()
896892

897-
@pytest.mark.parametrize('cxx', [False, True])
898-
def test_printing(self, cxx):
893+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
894+
def test_printing(self, language):
899895
f, f_real, f_imag = self.setup_basic(np.complex64)
900896

901897
eq_re = Eq(f_real, Re(f))
902898
eq_im = Eq(f_imag, Im(f))
903899

904-
if cxx:
905-
with switchconfig(language='CXX'):
906-
op = Operator([eq_re, eq_im])
907-
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
908-
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
900+
with switchconfig(language=language):
901+
op = Operator([eq_re, eq_im])
902+
903+
if language in ('CXX', 'CXXopenmp'):
904+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
905+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
909906

910907
else:
911-
op = Operator([eq_re, eq_im])
912908
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
913909
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
914910

915-
@pytest.mark.parametrize('cxx', [False, True])
911+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
916912
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
917-
def test_trivial(self, cxx, dtype):
913+
def test_trivial(self, language, dtype):
918914
f, f_real, f_imag = self.setup_basic(dtype)
919915

920916
eq_re = Eq(f_real, Re(f+1.))
921917
eq_im = Eq(f_imag, Im(f+1.))
922918

923-
self.run_operator([eq_re, eq_im], cxx)
919+
self.run_operator([eq_re, eq_im], language)
924920

925921
rcheck = np.array([2., 3., 4., 5., 6.])
926922
icheck = np.array([12., 11., 10., 9., 8.])
927923
assert np.all(np.isclose(f_real.data, rcheck))
928924
assert np.all(np.isclose(f_imag.data, icheck))
929925

930-
@pytest.mark.parametrize('cxx', [False, True])
926+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
931927
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
932-
def test_trivial_imag(self, cxx, dtype):
928+
def test_trivial_imag(self, language, dtype):
933929
f, f_real, f_imag = self.setup_basic(dtype)
934930

935931
eq_re = Eq(f_real, Re(f+1j))
936932
eq_im = Eq(f_imag, Im(f+1j))
937933

938-
self.run_operator([eq_re, eq_im], cxx)
934+
self.run_operator([eq_re, eq_im], language)
939935

940936
rcheck = np.array([1., 2., 3., 4., 5.])
941937
icheck = np.array([13., 12., 11., 10., 9.])
942938
assert np.all(np.isclose(f_real.data, rcheck))
943939
assert np.all(np.isclose(f_imag.data, icheck))
944940

945-
@pytest.mark.parametrize('cxx', [False, True])
946-
def test_deriv(self, cxx):
941+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
942+
def test_deriv(self, language):
947943
f, f_real, f_imag = self.setup_basic(np.complex64)
948944

949945
eq_re = Eq(f_real, Re(f.dx))
950946
eq_im = Eq(f_imag, Im(f.dx))
951947

952-
self.run_operator([eq_re, eq_im], cxx)
948+
self.run_operator([eq_re, eq_im], language)
953949

954950
assert np.all(np.isclose(f_real.data, 1.))
955951
assert np.all(np.isclose(f_imag.data, -1.))
956952

957-
@pytest.mark.parametrize('cxx', [False, True])
958-
def test_outer_deriv(self, cxx):
953+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
954+
def test_outer_deriv(self, language):
959955
f, f_real, f_imag = self.setup_basic(np.complex64)
960956

961957
eq_re = Eq(f_real, Re(f).dx)
962958
eq_im = Eq(f_imag, Im(f).dx)
963959

964-
self.run_operator([eq_re, eq_im], cxx)
960+
self.run_operator([eq_re, eq_im], language)
965961

966962
assert np.all(np.isclose(f_real.data, 1.))
967963
assert np.all(np.isclose(f_imag.data, -1.))
968964

969-
@pytest.mark.parametrize('cxx', [False, True])
970-
def test_mul(self, cxx):
965+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
966+
def test_mul(self, language):
971967
grid = Grid(shape=(5,))
972968

973969
f = Function(name='f', grid=grid, dtype=np.complex64)
@@ -987,7 +983,7 @@ def test_mul(self, cxx):
987983
eq_fh_re = Eq(fh_re, Re(f*h))
988984
eq_fh_im = Eq(fh_im, Im(f*h))
989985

990-
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
986+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
991987

992988
assert np.all(np.isclose(fg_re.data, 2.))
993989
assert np.all(np.isclose(fg_im.data, 2.))

0 commit comments

Comments
 (0)