Skip to content

Commit 1c9ab2e

Browse files
committed
test: improve dtype tests log
1 parent 0c914f2 commit 1c9ab2e

1 file changed

Lines changed: 9 additions & 14 deletions

File tree

tests/test_dtypes.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,20 @@ def _get_language(language: str, **_) -> type[LangBB]:
2424
"""
2525
Gets the language building block type from parametrized kwargs.
2626
"""
27-
2827
return _languages[language]
2928

3029

3130
def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]:
3231
"""
3332
Gets the printer building block type from parametrized kwargs.
3433
"""
35-
3634
return printer_registry[language]
3735

3836

3937
def _config_kwargs(platform: str, language: str) -> dict[str, str]:
4038
"""
4139
Generates kwargs for Operator to test language-specific behavior.
4240
"""
43-
4441
return {
4542
'platform': platform,
4643
'language': language,
@@ -57,15 +54,19 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]:
5754
]
5855

5956

57+
def kw_id(kwargs):
58+
# For more readable log
59+
return "-".join(f'{k}' for k in kwargs.values())
60+
61+
6062
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
61-
@pytest.mark.parametrize('kwargs', _configs)
63+
@pytest.mark.parametrize('kwargs', _configs, ids=kw_id)
6264
def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str],
6365
expected=None) -> None:
6466
"""
6567
Tests that half and complex floats' dtypes result in the correct type
6668
strings in generated code.
6769
"""
68-
6970
# Set up an operator
7071
grid = Grid(shape=(3, 3))
7172
x, y = grid.dimensions
@@ -83,13 +84,12 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str],
8384

8485

8586
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
86-
@pytest.mark.parametrize('kwargs', _configs)
87+
@pytest.mark.parametrize('kwargs', _configs, ids=kw_id)
8788
def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None:
8889
"""
8990
Tests that variables introduced by CSE have the correct type strings in
9091
the generated code.
9192
"""
92-
9393
# Retrieve the language-specific type mapping
9494
printer: type[_DevitoPrinterBase] = _get_printer(**kwargs)
9595

@@ -108,14 +108,13 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None
108108

109109

110110
@pytest.mark.parametrize('dtype', [np.float32, np.complex64, np.complex128])
111-
@pytest.mark.parametrize('kwargs', _configs)
111+
@pytest.mark.parametrize('kwargs', _configs, ids=kw_id)
112112
def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None:
113113
np.dtype
114114
"""
115115
Tests that the correct complex headers are included when complex dtypes
116116
are present in the operator, and omitted otherwise.
117117
"""
118-
119118
# Set up an operator
120119
grid = Grid(shape=(3, 3))
121120
x, y = grid.dimensions
@@ -134,7 +133,7 @@ def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) ->
134133

135134

136135
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
137-
@pytest.mark.parametrize('kwargs', _configs)
136+
@pytest.mark.parametrize('kwargs', _configs, ids=kw_id)
138137
def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None:
139138
"""
140139
Tests that the correct literal is used for the imaginary unit.
@@ -172,7 +171,6 @@ def test_math_functions(dtype: np.dtype[np.inexact],
172171
and assigned appropriately for different float precisions and for
173172
complex floats/doubles.
174173
"""
175-
176174
# Get the expected function call string
177175
call_str = str(sym)
178176
if np.issubdtype(dtype, np.complexfloating):
@@ -198,7 +196,6 @@ def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None:
198196
"""
199197
Tests overriding complex values in op.apply().
200198
"""
201-
202199
grid = Grid(shape=(5, 5))
203200
x, y = grid.dimensions
204201

@@ -221,7 +218,6 @@ def test_complex_time_deriv(dtype: np.dtype[np.complexfloating]) -> None:
221218
"""
222219
Tests taking the time derivative of a complex-valued function.
223220
"""
224-
225221
grid = Grid(shape=(5, 5))
226222
x, y = grid.dimensions
227223
t = grid.time_dim
@@ -248,7 +244,6 @@ def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None:
248244
Tests taking the space derivative of a complex-valued function, with
249245
respect to the real and imaginary axes.
250246
"""
251-
252247
grid = Grid(shape=(7, 7), dtype=dtype)
253248
x, y = grid.dimensions
254249

0 commit comments

Comments
 (0)