Skip to content

Commit 0d57ce8

Browse files
committed
compiler: Change functioning of memory estimate to be more parseable
1 parent ca49794 commit 0d57ce8

4 files changed

Lines changed: 159 additions & 72 deletions

File tree

devito/operator/operator.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from devito.symbolics import estimate_cost, subs_op_args
3333
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3434
flatten, filter_sorted, frozendict, is_integer,
35-
split, timed_pass, timed_region, contains_val, humanbytes)
35+
split, timed_pass, timed_region, contains_val,
36+
MemoryEstimate)
3637
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3738
disk_layer)
3839
from devito.types.dimension import Thickness
@@ -870,49 +871,54 @@ def cinterface(self, force=False):
870871
def __call__(self, **kwargs):
871872
return self.apply(**kwargs)
872873

873-
def estimate_memory(self, human_readable=True, **kwargs):
874+
def estimate_memory(self, **kwargs):
874875
"""
875-
Estimate the memory consumed by the Operator.
876+
Estimate the memory consumed by the Operator without touching or allocating any
877+
data. This interface is designed to mimic `Operator.apply(**kwargs)` and can be
878+
called with the kwargs for a prospective operator execution. With no arguments,
879+
it will simply estimate memory for the default operator parameters. However, if
880+
desired, overrides can be supplied (as per `apply`) and these will be used for
881+
the memory estimate.
882+
883+
If estimating memory for an Operator which is expected to allocate large arrays,
884+
it is strongly recommended that one avoids touching the data in Python (thus
885+
avoiding allocation). `AbstractFunction` types have their data allocated lazily -
886+
the underlying array is only created at the point at which the `data`,
887+
`data_with_halo`, etc, attributes are first accessed. Thus by avoiding accessing
888+
such attributes in the memory estimation script, one can check the nominal memory
889+
usage of proposed operators far larger than will fit in system DRAM.
890+
891+
Note that this estimate will build the Operator in order to factor in memory
892+
allocation for array temporaries and buffers generated during compilation.
876893
877-
TODO: Finish this docstring
894+
Parameters
895+
----------
896+
human_readable: bool
897+
Return human-readable values, rather than raw byte counts. Default is False.
898+
**kwargs: dict
899+
As per `Operator.apply()`.
900+
901+
Returns
902+
-------
903+
summary: MemoryEstimate
904+
An estimate of memory consumed in each of the specified locations.
878905
"""
879906
# Build the arguments list for which to get the memory consumption
880907
# This is so that the estimate will factor in overrides
881908
args = self._prepare_arguments(estimate_memory=True, **kwargs)
882909
mem = args.nbytes_consumed
883910

884-
# Extra information for enhanced operators
885-
extras = self._enrich_memreport(args, human_readable=human_readable)
886-
887-
if human_readable:
888-
headline = f"Memory consumption for operator `{self.name}`:"
889-
w = len(headline)
890-
# Columns are width 10
891-
fhost = str(humanbytes(mem[host_layer])).center(10)
892-
fdevice = str(humanbytes(mem[device_layer])).center(10)
893-
894-
memreport = (
895-
"\n"
896-
f"{headline}\n"
897-
f"{'┌──────────┬──────────┐'.center(w)}\n"
898-
f"{'│ Host │ Device │'.center(w)}\n"
899-
f"{'├──────────┼──────────┤'.center(w)}\n"
900-
f"{f'│{fhost}{fdevice}│'.center(w)}\n"
901-
f"{'└──────────┴──────────┘'.center(w)}\n"
902-
)
911+
memreport = {'host': mem[host_layer], 'device': mem[device_layer]}
903912

904-
# TODO: add hinting if the specified operator won't fit
905-
else:
906-
memreport = f"{self.name} {mem[host_layer]} {mem[device_layer]}"
913+
# Extra information for enriched operators
914+
extras = self._enrich_memreport(args)
915+
memreport.update(extras)
907916

908-
if extras is not None:
909-
memreport += extras
917+
return MemoryEstimate(memreport, name=self.name)
910918

911-
info(memreport)
912-
913-
def _enrich_memreport(self, args, human_readable=True):
914-
# Hook for enriching memory report
915-
pass
919+
def _enrich_memreport(self, args):
920+
# Hook for enriching memory report with additional metadata
921+
return {}
916922

917923
def apply(self, **kwargs):
918924
"""

devito/tools/data_structures.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from collections import OrderedDict, deque
22
from collections.abc import Callable, Iterable, MutableSet, Mapping, Set
3-
from functools import reduce
3+
from functools import reduce, cached_property
44

55
import numpy as np
66
from multidict import MultiDict
77

88
from devito.tools import Pickable
9-
from devito.tools.utils import as_tuple, filter_ordered
9+
from devito.tools.utils import as_tuple, filter_ordered, humanbytes
1010
from devito.tools.algorithms import toposort
1111

1212
__all__ = ['Bunch', 'EnrichedTuple', 'ReducerMap', 'DefaultOrderedDict',
1313
'OrderedSet', 'Ordering', 'DAG', 'frozendict',
14-
'UnboundTuple', 'UnboundedMultiTuple']
14+
'UnboundTuple', 'UnboundedMultiTuple', 'MemoryEstimate']
1515

1616

1717
class Bunch:
@@ -660,6 +660,31 @@ def __hash__(self):
660660
return self._hash
661661

662662

663+
class MemoryEstimate(frozendict):
664+
"""
665+
An immutable wrapper for a memory estimate, showing the
666+
various values.
667+
668+
TODO: Finish this docstring
669+
"""
670+
671+
def __init__(self, *args, **kwargs):
672+
self._name = kwargs.pop('name', 'memory_estimate')
673+
super().__init__(*args, **kwargs)
674+
675+
@property
676+
def name(self):
677+
return self._name
678+
679+
@cached_property
680+
def human_readable(self):
681+
"""The memory estimate in human-readable format"""
682+
return frozendict({k: humanbytes(v) for k, v in self.items()})
683+
684+
def __repr__(self):
685+
return f'{self.__class__.__name__}({self.name}): {self.human_readable._dict}'
686+
687+
663688
class UnboundTuple(tuple):
664689
"""
665690
An UnboundedTuple is a tuple that can be

devito/types/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ def _arg_defaults(self, alias=None, metadata=None, estimate_memory=False):
810810
To bind the argument values to different names.
811811
"""
812812
key = alias or self
813-
# TODO: Tidy this up. The idea is to avoid touching the data
813+
# Avoid touching the data if just estimating memory usage
814814
if estimate_memory:
815815
args = ReducerMap({key.name: self})
816816
else:

tests/test_operator.py

Lines changed: 91 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,16 +2068,9 @@ class TestEstimateMemory:
20682068

20692069
_array_temp = "r0L0(x, y)" if "CXX" in configuration['language'] else "r0[x][y]"
20702070

2071-
def parse_output(self, output, expected):
2071+
def parse_output(self, summary, expected):
20722072
"""Parse estimate_memory machine-readable output"""
2073-
# Check that no allocation occurs as estimate_memory should avoid data touch
2074-
assert "Allocating" not in output.text
2075-
2076-
parsed = output.records[-1].message.split()
2077-
name, host, device = parsed[:3]
2078-
extracted = (name, int(host), int(device))
2079-
2080-
assert extracted == expected
2073+
assert (summary['host'], summary['device']) == expected
20812074

20822075
@pytest.mark.parametrize('shape', [(11,), (101, 101), (101, 101, 101)])
20832076
@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.float32,
@@ -2089,13 +2082,14 @@ def test_basic_usage(self, caplog, shape, dtype, so):
20892082
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
20902083
op = Operator(Eq(f, 1))
20912084

2092-
# Machine-readable output for parsing
2093-
op.estimate_memory(human_readable=False)
2085+
summary = op.estimate_memory()
2086+
# Check that no allocation occurs as estimate_memory should avoid data touch
2087+
assert "Allocating" not in caplog.text
20942088

20952089
# Check output of estimate_memory
20962090
host = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2097-
expected = ("Kernel", host, 0)
2098-
self.parse_output(caplog, expected)
2091+
expected = (host, 0)
2092+
self.parse_output(summary, expected)
20992093

21002094
def test_multiple_objects(self, caplog):
21012095
grid = Grid(shape=(101, 101))
@@ -2104,12 +2098,13 @@ def test_multiple_objects(self, caplog):
21042098
g = Function(name='g', grid=grid, space_order=4, dtype=np.float64)
21052099
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
21062100
op = Operator([Eq(f, 1), Eq(g, 1)])
2107-
op.estimate_memory(human_readable=False)
2101+
summary = op.estimate_memory()
2102+
assert "Allocating" not in caplog.text
21082103

21092104
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21102105
for func in (f, g))
2111-
expected = ("Kernel", check, 0)
2112-
self.parse_output(caplog, expected)
2106+
expected = (check, 0)
2107+
self.parse_output(summary, expected)
21132108

21142109
@pytest.mark.parametrize('time', [True, False])
21152110
def test_sparse(self, caplog, time):
@@ -2123,12 +2118,13 @@ def test_sparse(self, caplog, time):
21232118

21242119
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
21252120
op = Operator(src_term)
2126-
op.estimate_memory(human_readable=False)
2121+
summary = op.estimate_memory()
2122+
assert "Allocating" not in caplog.text
21272123

21282124
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21292125
for func in (f, src, src.coordinates))
2130-
expected = ("Kernel", check, 0)
2131-
self.parse_output(caplog, expected)
2126+
expected = (check, 0)
2127+
self.parse_output(summary, expected)
21322128

21332129
@pytest.mark.parametrize('save', [None, Buffer(3), 10])
21342130
def test_timefunction(self, caplog, save):
@@ -2137,10 +2133,11 @@ def test_timefunction(self, caplog, save):
21372133

21382134
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
21392135
op = Operator(Eq(f, 1))
2140-
op.estimate_memory(human_readable=False)
2136+
summary = op.estimate_memory()
2137+
assert "Allocating" not in caplog.text
21412138
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2142-
expected = ("Kernel", check, 0)
2143-
self.parse_output(caplog, expected)
2139+
expected = (check, 0)
2140+
self.parse_output(summary, expected)
21442141

21452142
def test_mashup(self, caplog):
21462143
grid = Grid(shape=(101, 101))
@@ -2158,13 +2155,14 @@ def test_mashup(self, caplog):
21582155

21592156
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
21602157
op = Operator([eq0, eq1] + src_term0 + src_term1)
2161-
op.estimate_memory(human_readable=False)
2158+
summary = op.estimate_memory()
2159+
assert "Allocating" not in caplog.text
21622160

21632161
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21642162
for func in (f, g, src0, src0.coordinates,
21652163
src1, src1.coordinates))
2166-
expected = ("Kernel", check, 0)
2167-
self.parse_output(caplog, expected)
2164+
expected = (check, 0)
2165+
self.parse_output(summary, expected)
21682166

21692167
def test_temp_array(self, caplog):
21702168
"""Check that temporary arrays will be factored into the memory calculation"""
@@ -2187,18 +2185,20 @@ def test_temp_array(self, caplog):
21872185
# Ensure an array temporary is created
21882186
assert self._array_temp in str(op.ccode)
21892187

2190-
op.estimate_memory(human_readable=False)
2188+
summary = op.estimate_memory()
2189+
assert "Allocating" not in caplog.text
21912190

21922191
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21932192
for func in (f, g, a))
21942193

21952194
# Factor in the temp array
21962195
check += reduce(mul, b.shape_allocated)*np.dtype(a.dtype).itemsize
21972196

2198-
expected = ("Kernel", check, 0)
2199-
self.parse_output(caplog, expected)
2197+
expected = (check, 0)
2198+
self.parse_output(summary, expected)
22002199

22012200
def test_overrides(self, caplog):
2201+
# TODO: Consolidate this boilerplate
22022202
grid0 = Grid(shape=(101, 101))
22032203
# Original fields
22042204
f0 = Function(name='f0', grid=grid0, space_order=4)
@@ -2213,6 +2213,13 @@ def test_overrides(self, caplog):
22132213
s1 = SparseFunction(name='s1', grid=grid1, npoint=200)
22142214
st1 = SparseTimeFunction(name='st1', grid=grid1, npoint=200, nt=20)
22152215

2216+
grid2 = Grid(shape=(51, 51)) # Smaller grid so overrides are distinct
2217+
# Alternative replacement fields
2218+
f2 = Function(name='f2', grid=grid2, space_order=4)
2219+
tf2 = TimeFunction(name='tf2', grid=grid2, space_order=4)
2220+
s2 = SparseFunction(name='s2', grid=grid2, npoint=50)
2221+
st2 = SparseTimeFunction(name='st2', grid=grid2, npoint=50, nt=5)
2222+
22162223
eq0 = Eq(f0, 1)
22172224
eq1 = Eq(tf0, 1)
22182225
s0_term = s0.inject(field=f0, expr=s0)
@@ -2222,13 +2229,61 @@ def test_overrides(self, caplog):
22222229
op = Operator([eq0, eq1] + s0_term + st0_term)
22232230

22242231
# Apply overrides for the check
2225-
op.estimate_memory(f0=f1, tf0=tf1, s0=s1, st0=st1, human_readable=False)
2232+
summary0 = op.estimate_memory(f0=f1, tf0=tf1, s0=s1, st0=st1)
2233+
2234+
check0 = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
2235+
for func in (f1, tf1, s1, s1.coordinates, st1, st1.coordinates))
2236+
2237+
expected0 = (check0, 0)
2238+
self.parse_output(summary0, expected0)
2239+
2240+
# Check with a second set of overrides
2241+
summary1 = op.estimate_memory(f0=f2, tf0=tf2, s0=s2, st0=st2)
2242+
assert "Allocating" not in caplog.text
2243+
2244+
check1 = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
2245+
for func in (f2, tf2, s2, s2.coordinates, st2, st2.coordinates))
2246+
2247+
expected1 = (check1, 0)
2248+
self.parse_output(summary1, expected1)
2249+
2250+
def test_overrides_w_temp_array(self, caplog):
2251+
"""Check that temporary arrays are correctly adjusted for overrides"""
2252+
grid = Grid(shape=(101, 101))
2253+
f = TimeFunction(name='f', grid=grid, space_order=2)
2254+
g = TimeFunction(name='g', grid=grid, space_order=2)
2255+
a = Function(name='a', grid=grid, space_order=2)
2256+
2257+
grid0 = Grid(shape=(51, 51))
2258+
f0 = TimeFunction(name='f0', grid=grid0, space_order=2)
2259+
g0 = TimeFunction(name='g0', grid=grid0, space_order=2)
2260+
a0 = Function(name='a0', grid=grid0, space_order=2)
2261+
2262+
# Fake array allocated in Python land so that shape_allocated can be used
2263+
b = Function(name='b', grid=grid0, space_order=0)
2264+
2265+
# Reuse an expensive function to encourage generation of an array temp
2266+
eq0 = Eq(f.forward, g + sympy.sin(a))
2267+
eq1 = Eq(g.forward, f + sympy.sin(a))
2268+
2269+
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
2270+
op = Operator([eq0, eq1])
2271+
2272+
# Regression to ensure this test functions as intended
2273+
# Ensure an array temporary is created
2274+
assert self._array_temp in str(op.ccode)
2275+
2276+
summary = op.estimate_memory(f=f0, g=g0, a=a0)
2277+
assert "Allocating" not in caplog.text
22262278

22272279
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
2228-
for func in (f1, tf1, s1, s1.coordinates, st1, st1.coordinates))
2280+
for func in (f0, g0, a0))
2281+
2282+
# Factor in the temp array
2283+
check += reduce(mul, b.shape_allocated)*np.dtype(a0.dtype).itemsize
22292284

2230-
expected = ("Kernel", check, 0)
2231-
self.parse_output(caplog, expected)
2285+
expected = (check, 0)
2286+
self.parse_output(summary, expected)
22322287

22332288
def test_device(self, caplog):
22342289
# Note: this uses switchconfig and runs on all backends to reflect expected
@@ -2245,10 +2300,11 @@ def test_device(self, caplog):
22452300
with switchconfig(**config), caplog.at_level(logging.DEBUG):
22462301
op = Operator(Eq(f, 1))
22472302

2248-
op.estimate_memory(human_readable=False)
2303+
summary = op.estimate_memory()
2304+
assert "Allocating" not in caplog.text
22492305

22502306
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
22512307

22522308
# Matching memory allocated both on host and device for memmap
2253-
expected = ("Kernel", check, check)
2254-
self.parse_output(caplog, expected)
2309+
expected = (check, check)
2310+
self.parse_output(summary, expected)

0 commit comments

Comments
 (0)