Skip to content

Commit 350e23e

Browse files
committed
tests: Consolidate utility functions
1 parent ddfacda commit 350e23e

1 file changed

Lines changed: 12 additions & 25 deletions

File tree

tests/test_operator.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,19 +2069,15 @@ class TestEstimateMemory:
20692069
_array_temp = "r0L0(x, y)" if "CXX" in configuration['language'] else "r0[x][y]"
20702070
_devicelangs = ('openacc',)
20712071

2072-
def parse_output(self, summary, expected):
2073-
"""Parse estimate_memory machine-readable output"""
2072+
def parse_output(self, summary, check):
2073+
expected = ((check, check) if configuration['language']
2074+
in self._devicelangs else (check, 0))
20742075
assert (summary['host'], summary['device']) == expected
20752076

20762077
def sum_sizes(self, funcs):
20772078
return sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
20782079
for func in funcs)
20792080

2080-
def make_check(self, check):
2081-
if configuration['language'] in self._devicelangs:
2082-
return (check, check)
2083-
return (check, 0)
2084-
20852081
@pytest.mark.parametrize('shape', [(11,), (101, 101), (101, 101, 101)])
20862082
@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.float32,
20872083
np.float32, np.complex64])
@@ -2098,8 +2094,7 @@ def test_basic_usage(self, caplog, shape, dtype, so):
20982094

20992095
# Check output of estimate_memory
21002096
host = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2101-
expected = self.make_check(host)
2102-
self.parse_output(summary, expected)
2097+
self.parse_output(summary, host)
21032098

21042099
def test_multiple_objects(self, caplog):
21052100
grid = Grid(shape=(101, 101))
@@ -2112,8 +2107,7 @@ def test_multiple_objects(self, caplog):
21122107
assert "Allocating" not in caplog.text
21132108

21142109
check = self.sum_sizes((f, g))
2115-
expected = self.make_check(check)
2116-
self.parse_output(summary, expected)
2110+
self.parse_output(summary, check)
21172111

21182112
@pytest.mark.parametrize('time', [True, False])
21192113
def test_sparse(self, caplog, time):
@@ -2131,8 +2125,7 @@ def test_sparse(self, caplog, time):
21312125
assert "Allocating" not in caplog.text
21322126

21332127
check = self.sum_sizes((f, src, src.coordinates))
2134-
expected = self.make_check(check)
2135-
self.parse_output(summary, expected)
2128+
self.parse_output(summary, check)
21362129

21372130
@pytest.mark.parametrize('save', [None, Buffer(3), 10])
21382131
def test_timefunction(self, caplog, save):
@@ -2144,8 +2137,7 @@ def test_timefunction(self, caplog, save):
21442137
summary = op.estimate_memory()
21452138
assert "Allocating" not in caplog.text
21462139
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2147-
expected = self.make_check(check)
2148-
self.parse_output(summary, expected)
2140+
self.parse_output(summary, check)
21492141

21502142
def test_mashup(self, caplog):
21512143
grid = Grid(shape=(101, 101))
@@ -2167,8 +2159,7 @@ def test_mashup(self, caplog):
21672159
assert "Allocating" not in caplog.text
21682160

21692161
check = self.sum_sizes((f, g, src0, src0.coordinates, src1, src1.coordinates))
2170-
expected = self.make_check(check)
2171-
self.parse_output(summary, expected)
2162+
self.parse_output(summary, check)
21722163

21732164
@pytest.mark.parametrize('override', [True, False])
21742165
def test_temp_array(self, caplog, override):
@@ -2211,8 +2202,7 @@ def test_temp_array(self, caplog, override):
22112202

22122203
# Factor in the temp array
22132204
check += reduce(mul, b.shape_allocated)*np.dtype(b.dtype).itemsize
2214-
expected = self.make_check(check)
2215-
self.parse_output(summary, expected)
2205+
self.parse_output(summary, check)
22162206

22172207
def test_overrides(self, caplog):
22182208
def setup(size, npoint, nt, counter):
@@ -2244,16 +2234,14 @@ def setup(size, npoint, nt, counter):
22442234
summary0 = op.estimate_memory(f0=f1, tf0=tf1, s0=s1, st0=st1)
22452235

22462236
check0 = self.sum_sizes((f1, tf1, s1, s1.coordinates, st1, st1.coordinates))
2247-
expected0 = self.make_check(check0)
2248-
self.parse_output(summary0, expected0)
2237+
self.parse_output(summary0, check0)
22492238

22502239
# Check with a second set of overrides
22512240
summary1 = op.estimate_memory(f0=f2, tf0=tf2, s0=s2, st0=st2)
22522241
assert "Allocating" not in caplog.text
22532242

22542243
check1 = self.sum_sizes((f2, tf2, s2, s2.coordinates, st2, st2.coordinates))
2255-
expected1 = self.make_check(check1)
2256-
self.parse_output(summary1, expected1)
2244+
self.parse_output(summary1, check1)
22572245

22582246
def test_device(self, caplog):
22592247
# Note: this uses switchconfig and runs on all backends to reflect expected
@@ -2276,5 +2264,4 @@ def test_device(self, caplog):
22762264
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
22772265

22782266
# Matching memory allocated both on host and device for memmap
2279-
expected = (check, check)
2280-
self.parse_output(summary, expected)
2267+
self.parse_output(summary, check)

0 commit comments

Comments
 (0)