@@ -2068,16 +2068,9 @@ class TestEstimateMemory:
20682068
20692069 _array_temp = "r0L0(x, y)" if "CXX" in configuration ['language' ] else "r0[x][y]"
20702070
2071- def parse_output (self , output , expected ):
2071+ def parse_output (self , summary , expected ):
20722072 """Parse estimate_memory machine-readable output"""
2073- # Check that no allocation occurs as estimate_memory should avoid data touch
2074- assert "Allocating" not in output .text
2075-
2076- parsed = output .records [- 1 ].message .split ()
2077- name , host , device = parsed [:3 ]
2078- extracted = (name , int (host ), int (device ))
2079-
2080- assert extracted == expected
2073+ assert (summary ['host' ], summary ['device' ]) == expected
20812074
20822075 @pytest .mark .parametrize ('shape' , [(11 ,), (101 , 101 ), (101 , 101 , 101 )])
20832076 @pytest .mark .parametrize ('dtype' , [np .int8 , np .int16 , np .float32 ,
@@ -2089,13 +2082,14 @@ def test_basic_usage(self, caplog, shape, dtype, so):
20892082 with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
20902083 op = Operator (Eq (f , 1 ))
20912084
2092- # Machine-readable output for parsing
2093- op .estimate_memory (human_readable = False )
2085+ summary = op .estimate_memory ()
2086+ # Check that no allocation occurs as estimate_memory should avoid data touch
2087+ assert "Allocating" not in caplog .text
20942088
20952089 # Check output of estimate_memory
20962090 host = reduce (mul , f .shape_allocated )* np .dtype (f .dtype ).itemsize
2097- expected = ("Kernel" , host , 0 )
2098- self .parse_output (caplog , expected )
2091+ expected = (host , 0 )
2092+ self .parse_output (summary , expected )
20992093
21002094 def test_multiple_objects (self , caplog ):
21012095 grid = Grid (shape = (101 , 101 ))
@@ -2104,12 +2098,13 @@ def test_multiple_objects(self, caplog):
21042098 g = Function (name = 'g' , grid = grid , space_order = 4 , dtype = np .float64 )
21052099 with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
21062100 op = Operator ([Eq (f , 1 ), Eq (g , 1 )])
2107- op .estimate_memory (human_readable = False )
2101+ summary = op .estimate_memory ()
2102+ assert "Allocating" not in caplog .text
21082103
21092104 check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
21102105 for func in (f , g ))
2111- expected = ("Kernel" , check , 0 )
2112- self .parse_output (caplog , expected )
2106+ expected = (check , 0 )
2107+ self .parse_output (summary , expected )
21132108
21142109 @pytest .mark .parametrize ('time' , [True , False ])
21152110 def test_sparse (self , caplog , time ):
@@ -2123,12 +2118,13 @@ def test_sparse(self, caplog, time):
21232118
21242119 with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
21252120 op = Operator (src_term )
2126- op .estimate_memory (human_readable = False )
2121+ summary = op .estimate_memory ()
2122+ assert "Allocating" not in caplog .text
21272123
21282124 check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
21292125 for func in (f , src , src .coordinates ))
2130- expected = ("Kernel" , check , 0 )
2131- self .parse_output (caplog , expected )
2126+ expected = (check , 0 )
2127+ self .parse_output (summary , expected )
21322128
21332129 @pytest .mark .parametrize ('save' , [None , Buffer (3 ), 10 ])
21342130 def test_timefunction (self , caplog , save ):
@@ -2137,10 +2133,11 @@ def test_timefunction(self, caplog, save):
21372133
21382134 with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
21392135 op = Operator (Eq (f , 1 ))
2140- op .estimate_memory (human_readable = False )
2136+ summary = op .estimate_memory ()
2137+ assert "Allocating" not in caplog .text
21412138 check = reduce (mul , f .shape_allocated )* np .dtype (f .dtype ).itemsize
2142- expected = ("Kernel" , check , 0 )
2143- self .parse_output (caplog , expected )
2139+ expected = (check , 0 )
2140+ self .parse_output (summary , expected )
21442141
21452142 def test_mashup (self , caplog ):
21462143 grid = Grid (shape = (101 , 101 ))
@@ -2158,13 +2155,14 @@ def test_mashup(self, caplog):
21582155
21592156 with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
21602157 op = Operator ([eq0 , eq1 ] + src_term0 + src_term1 )
2161- op .estimate_memory (human_readable = False )
2158+ summary = op .estimate_memory ()
2159+ assert "Allocating" not in caplog .text
21622160
21632161 check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
21642162 for func in (f , g , src0 , src0 .coordinates ,
21652163 src1 , src1 .coordinates ))
2166- expected = ("Kernel" , check , 0 )
2167- self .parse_output (caplog , expected )
2164+ expected = (check , 0 )
2165+ self .parse_output (summary , expected )
21682166
21692167 def test_temp_array (self , caplog ):
21702168 """Check that temporary arrays will be factored into the memory calculation"""
@@ -2187,18 +2185,20 @@ def test_temp_array(self, caplog):
21872185 # Ensure an array temporary is created
21882186 assert self ._array_temp in str (op .ccode )
21892187
2190- op .estimate_memory (human_readable = False )
2188+ summary = op .estimate_memory ()
2189+ assert "Allocating" not in caplog .text
21912190
21922191 check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
21932192 for func in (f , g , a ))
21942193
21952194 # Factor in the temp array
21962195 check += reduce (mul , b .shape_allocated )* np .dtype (a .dtype ).itemsize
21972196
2198- expected = ("Kernel" , check , 0 )
2199- self .parse_output (caplog , expected )
2197+ expected = (check , 0 )
2198+ self .parse_output (summary , expected )
22002199
22012200 def test_overrides (self , caplog ):
2201+ # TODO: Consolidate this boilerplate
22022202 grid0 = Grid (shape = (101 , 101 ))
22032203 # Original fields
22042204 f0 = Function (name = 'f0' , grid = grid0 , space_order = 4 )
@@ -2213,6 +2213,13 @@ def test_overrides(self, caplog):
22132213 s1 = SparseFunction (name = 's1' , grid = grid1 , npoint = 200 )
22142214 st1 = SparseTimeFunction (name = 'st1' , grid = grid1 , npoint = 200 , nt = 20 )
22152215
2216+ grid2 = Grid (shape = (51 , 51 )) # Smaller grid so overrides are distinct
2217+ # Alternative replacement fields
2218+ f2 = Function (name = 'f2' , grid = grid2 , space_order = 4 )
2219+ tf2 = TimeFunction (name = 'tf2' , grid = grid2 , space_order = 4 )
2220+ s2 = SparseFunction (name = 's2' , grid = grid2 , npoint = 50 )
2221+ st2 = SparseTimeFunction (name = 'st2' , grid = grid2 , npoint = 50 , nt = 5 )
2222+
22162223 eq0 = Eq (f0 , 1 )
22172224 eq1 = Eq (tf0 , 1 )
22182225 s0_term = s0 .inject (field = f0 , expr = s0 )
@@ -2222,13 +2229,61 @@ def test_overrides(self, caplog):
22222229 op = Operator ([eq0 , eq1 ] + s0_term + st0_term )
22232230
22242231 # Apply overrides for the check
2225- op .estimate_memory (f0 = f1 , tf0 = tf1 , s0 = s1 , st0 = st1 , human_readable = False )
2232+ summary0 = op .estimate_memory (f0 = f1 , tf0 = tf1 , s0 = s1 , st0 = st1 )
2233+
2234+ check0 = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
2235+ for func in (f1 , tf1 , s1 , s1 .coordinates , st1 , st1 .coordinates ))
2236+
2237+ expected0 = (check0 , 0 )
2238+ self .parse_output (summary0 , expected0 )
2239+
2240+ # Check with a second set of overrides
2241+ summary1 = op .estimate_memory (f0 = f2 , tf0 = tf2 , s0 = s2 , st0 = st2 )
2242+ assert "Allocating" not in caplog .text
2243+
2244+ check1 = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
2245+ for func in (f2 , tf2 , s2 , s2 .coordinates , st2 , st2 .coordinates ))
2246+
2247+ expected1 = (check1 , 0 )
2248+ self .parse_output (summary1 , expected1 )
2249+
2250+ def test_overrides_w_temp_array (self , caplog ):
2251+ """Check that temporary arrays are correctly adjusted for overrides"""
2252+ grid = Grid (shape = (101 , 101 ))
2253+ f = TimeFunction (name = 'f' , grid = grid , space_order = 2 )
2254+ g = TimeFunction (name = 'g' , grid = grid , space_order = 2 )
2255+ a = Function (name = 'a' , grid = grid , space_order = 2 )
2256+
2257+ grid0 = Grid (shape = (51 , 51 ))
2258+ f0 = TimeFunction (name = 'f0' , grid = grid0 , space_order = 2 )
2259+ g0 = TimeFunction (name = 'g0' , grid = grid0 , space_order = 2 )
2260+ a0 = Function (name = 'a0' , grid = grid0 , space_order = 2 )
2261+
2262+ # Fake array allocated in Python land so that shape_allocated can be used
2263+ b = Function (name = 'b' , grid = grid0 , space_order = 0 )
2264+
2265+ # Reuse an expensive function to encourage generation of an array temp
2266+ eq0 = Eq (f .forward , g + sympy .sin (a ))
2267+ eq1 = Eq (g .forward , f + sympy .sin (a ))
2268+
2269+ with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
2270+ op = Operator ([eq0 , eq1 ])
2271+
2272+ # Regression to ensure this test functions as intended
2273+ # Ensure an array temporary is created
2274+ assert self ._array_temp in str (op .ccode )
2275+
2276+ summary = op .estimate_memory (f = f0 , g = g0 , a = a0 )
2277+ assert "Allocating" not in caplog .text
22262278
22272279 check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
2228- for func in (f1 , tf1 , s1 , s1 .coordinates , st1 , st1 .coordinates ))
2280+ for func in (f0 , g0 , a0 ))
2281+
2282+ # Factor in the temp array
2283+ check += reduce (mul , b .shape_allocated )* np .dtype (a0 .dtype ).itemsize
22292284
2230- expected = ("Kernel" , check , 0 )
2231- self .parse_output (caplog , expected )
2285+ expected = (check , 0 )
2286+ self .parse_output (summary , expected )
22322287
22332288 def test_device (self , caplog ):
22342289 # Note: this uses switchconfig and runs on all backends to reflect expected
@@ -2245,10 +2300,11 @@ def test_device(self, caplog):
22452300 with switchconfig (** config ), caplog .at_level (logging .DEBUG ):
22462301 op = Operator (Eq (f , 1 ))
22472302
2248- op .estimate_memory (human_readable = False )
2303+ summary = op .estimate_memory ()
2304+ assert "Allocating" not in caplog .text
22492305
22502306 check = reduce (mul , f .shape_allocated )* np .dtype (f .dtype ).itemsize
22512307
22522308 # Matching memory allocated both on host and device for memmap
2253- expected = ("Kernel" , check , check )
2254- self .parse_output (caplog , expected )
2309+ expected = (check , check )
2310+ self .parse_output (summary , expected )
0 commit comments