@@ -2164,15 +2164,28 @@ def test_mashup(self, caplog):
21642164 expected = (check , 0 )
21652165 self .parse_output (summary , expected )
21662166
2167- def test_temp_array (self , caplog ):
2167+ @pytest .mark .parametrize ('override' , [True , False ])
2168+ def test_temp_array (self , caplog , override ):
21682169 """Check that temporary arrays will be factored into the memory calculation"""
21692170 grid = Grid (shape = (101 , 101 ))
21702171 f = TimeFunction (name = 'f' , grid = grid , space_order = 2 )
21712172 g = TimeFunction (name = 'g' , grid = grid , space_order = 2 )
21722173 a = Function (name = 'a' , grid = grid , space_order = 2 )
21732174
2174- # Fake array allocated in Python land so that shape_allocated can be used
2175- b = Function (name = 'b' , grid = grid , space_order = 0 )
2175+ if override :
2176+ grid0 = Grid (shape = (51 , 51 ))
2177+ f0 = TimeFunction (name = 'f0' , grid = grid0 , space_order = 2 )
2178+ g0 = TimeFunction (name = 'g0' , grid = grid0 , space_order = 2 )
2179+ a0 = Function (name = 'a0' , grid = grid0 , space_order = 2 )
2180+ funcs = (f0 , g0 , a0 )
2181+ kwargs = {'f' : f0 , 'g' : g0 , 'a' : a0 }
2182+
2183+ # Fake array allocated in Python land so that shape_allocated can be used
2184+ b = Function (name = 'b' , grid = grid0 , space_order = 0 )
2185+ else :
2186+ funcs = (f , g , a )
2187+ kwargs = {}
2188+ b = Function (name = 'b' , grid = grid , space_order = 0 )
21762189
21772190 # Reuse an expensive function to encourage generation of an array temp
21782191 eq0 = Eq (f .forward , g + sympy .sin (a ))
@@ -2185,14 +2198,14 @@ def test_temp_array(self, caplog):
21852198 # Ensure an array temporary is created
21862199 assert self ._array_temp in str (op .ccode )
21872200
2188- summary = op .estimate_memory ()
2201+ summary = op .estimate_memory (** kwargs )
21892202 assert "Allocating" not in caplog .text
21902203
21912204 check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
2192- for func in ( f , g , a ) )
2205+ for func in funcs )
21932206
21942207 # Factor in the temp array
2195- check += reduce (mul , b .shape_allocated )* np .dtype (a .dtype ).itemsize
2208+ check += reduce (mul , b .shape_allocated )* np .dtype (b .dtype ).itemsize
21962209
21972210 expected = (check , 0 )
21982211 self .parse_output (summary , expected )
@@ -2247,44 +2260,6 @@ def test_overrides(self, caplog):
22472260 expected1 = (check1 , 0 )
22482261 self .parse_output (summary1 , expected1 )
22492262
2250- def test_overrides_w_temp_array (self , caplog ):
2251- """Check that temporary arrays are correctly adjusted for overrides"""
2252- grid = Grid (shape = (101 , 101 ))
2253- f = TimeFunction (name = 'f' , grid = grid , space_order = 2 )
2254- g = TimeFunction (name = 'g' , grid = grid , space_order = 2 )
2255- a = Function (name = 'a' , grid = grid , space_order = 2 )
2256-
2257- grid0 = Grid (shape = (51 , 51 ))
2258- f0 = TimeFunction (name = 'f0' , grid = grid0 , space_order = 2 )
2259- g0 = TimeFunction (name = 'g0' , grid = grid0 , space_order = 2 )
2260- a0 = Function (name = 'a0' , grid = grid0 , space_order = 2 )
2261-
2262- # Fake array allocated in Python land so that shape_allocated can be used
2263- b = Function (name = 'b' , grid = grid0 , space_order = 0 )
2264-
2265- # Reuse an expensive function to encourage generation of an array temp
2266- eq0 = Eq (f .forward , g + sympy .sin (a ))
2267- eq1 = Eq (g .forward , f + sympy .sin (a ))
2268-
2269- with switchconfig (log_level = 'DEBUG' ), caplog .at_level (logging .DEBUG ):
2270- op = Operator ([eq0 , eq1 ])
2271-
2272- # Regression to ensure this test functions as intended
2273- # Ensure an array temporary is created
2274- assert self ._array_temp in str (op .ccode )
2275-
2276- summary = op .estimate_memory (f = f0 , g = g0 , a = a0 )
2277- assert "Allocating" not in caplog .text
2278-
2279- check = sum (reduce (mul , func .shape_allocated )* np .dtype (func .dtype ).itemsize
2280- for func in (f0 , g0 , a0 ))
2281-
2282- # Factor in the temp array
2283- check += reduce (mul , b .shape_allocated )* np .dtype (a0 .dtype ).itemsize
2284-
2285- expected = (check , 0 )
2286- self .parse_output (summary , expected )
2287-
22882263 def test_device (self , caplog ):
22892264 # Note: this uses switchconfig and runs on all backends to reflect expected
22902265 # usage: users are likely to run the estimate on the orchestration node which
0 commit comments