Skip to content

Commit a0af5fa

Browse files
committed
misc: Update memory report format
1 parent b587d51 commit a0af5fa

2 files changed

Lines changed: 50 additions & 25 deletions

File tree

devito/operator/operator.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -883,30 +883,38 @@ def estimate_memory(self, human_readable=True, **kwargs):
883883
args = self._prepare_arguments(estimate_memory=True, **kwargs)
884884
mem = args.nbytes_consumed
885885

886+
# Extra information for enhanced operators
887+
extras = self._enrich_memreport(args, human_readable=human_readable)
888+
886889
if human_readable:
887890
headline = f"Memory consumption for operator `{self.name}`:"
888891
w = len(headline)
889892
# Columns are width 10
890-
fdisk = str(humanbytes(mem[disk_layer])).center(10)
891893
fhost = str(humanbytes(mem[host_layer])).center(10)
892894
fdevice = str(humanbytes(mem[device_layer])).center(10)
893895

894-
# TODO: There is nominally a table generator in pytools which is a dependency
895-
# of a dependency and thus in the env anyway
896-
info(
896+
memreport = (
897897
"\n"
898898
f"{headline}\n"
899-
f"{'┌──────────┬──────────┬──────────┐'.center(w)}\n"
900-
f"{'│ Disk │ Host │ Device │'.center(w)}\n"
901-
f"{'├──────────┼──────────┼──────────┤'.center(w)}\n"
902-
f"{f'│{fdisk}{fhost}{fdevice}│'.center(w)}\n"
903-
f"{'└──────────┴──────────┴──────────┘'.center(w)}\n"
899+
f"{'┌──────────┬──────────┐'.center(w)}\n"
900+
f"{'│ Host │ Device │'.center(w)}\n"
901+
f"{'├──────────┼──────────┤'.center(w)}\n"
902+
f"{f'│{fhost}{fdevice}│'.center(w)}\n"
903+
f"{'└──────────┴──────────┘'.center(w)}\n"
904904
)
905905

906906
# TODO: add hinting if the specified operator won't fit
907-
908907
else:
909-
info(f"{self.name} {mem[disk_layer]} {mem[host_layer]} {mem[device_layer]}")
908+
memreport = f"{self.name} {mem[host_layer]} {mem[device_layer]}"
909+
910+
if extras is not None:
911+
memreport += extras
912+
913+
info(memreport)
914+
915+
def _enrich_memreport(self, args, human_readable=True):
916+
# Hook for enriching memory report
917+
pass
910918

911919
def apply(self, **kwargs):
912920
"""
@@ -1326,8 +1334,6 @@ def nbytes_avail_mapper(self):
13261334

13271335
return mapper
13281336

1329-
# TODO: This will want some suitable tests in due course
1330-
# TODO: Might want to also check the spillover onto disk
13311337
@cached_property
13321338
def nbytes_consumed(self):
13331339
"""Memory consumed by all objects in the operator"""
@@ -1376,8 +1382,6 @@ def get_nbytes(obj):
13761382
and not i.is_ArrayBasic and not i.alias]
13771383

13781384
for i in op_symbols:
1379-
# FIXME: Probably wrong for streamed functions
1380-
# TODO: Need a hook for PRO here
13811385
# Will overreport memory usage currently
13821386
try:
13831387
# TODO: is _obj even needed?
@@ -1465,6 +1469,27 @@ def nbytes_consumed_memmapped(self):
14651469

14661470
return {disk_layer: 0, host_layer: 0, device_layer: device}
14671471

1472+
@cached_property
1473+
def nbytes_snapshots(self):
1474+
disk = 0
1475+
1476+
# Symbols in the operator which may or may not carry data
1477+
op_symbols = FindSymbols().visit(self.op)
1478+
1479+
# Filter to streamed functions
1480+
op_symbols = [i for i in op_symbols if i.is_AbstractFunction
1481+
and not i.is_ArrayBasic and not i.alias]
1482+
1483+
for i in op_symbols:
1484+
try:
1485+
disk += i.size_snapshot
1486+
except AttributeError:
1487+
pass
1488+
1489+
print(disk)
1490+
1491+
return {disk_layer: 0, host_layer: 0, device_layer: 0}
1492+
14681493

14691494
def parse_kwargs(**kwargs):
14701495
"""

tests/test_operator.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,8 +2073,8 @@ def parse_output(self, output, expected):
20732073
# Check that no allocation occurs as estimate_memory should avoid data touch
20742074
assert "Allocating" not in output.text
20752075

2076-
name, disk, host, device = output.records[-1].message.split()
2077-
extracted = (name, int(disk), int(host), int(device))
2076+
name, host, device = output.records[-1].message.split()
2077+
extracted = (name, int(host), int(device))
20782078

20792079
assert extracted == expected
20802080

@@ -2093,7 +2093,7 @@ def test_basic_usage(self, caplog, shape, dtype, so):
20932093

20942094
# Check output of estimate_memory
20952095
host = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2096-
expected = ("Kernel", 0, host, 0)
2096+
expected = ("Kernel", host, 0)
20972097
self.parse_output(caplog, expected)
20982098

20992099
def test_multiple_objects(self, caplog):
@@ -2107,7 +2107,7 @@ def test_multiple_objects(self, caplog):
21072107

21082108
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21092109
for func in (f, g))
2110-
expected = ("Kernel", 0, check, 0)
2110+
expected = ("Kernel", check, 0)
21112111
self.parse_output(caplog, expected)
21122112

21132113
@pytest.mark.parametrize('time', [True, False])
@@ -2126,7 +2126,7 @@ def test_sparse(self, caplog, time):
21262126

21272127
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21282128
for func in (f, src, src.coordinates))
2129-
expected = ("Kernel", 0, check, 0)
2129+
expected = ("Kernel", check, 0)
21302130
self.parse_output(caplog, expected)
21312131

21322132
@pytest.mark.parametrize('save', [None, Buffer(3), 10])
@@ -2138,7 +2138,7 @@ def test_timefunction(self, caplog, save):
21382138
op = Operator(Eq(f, 1))
21392139
op.estimate_memory(human_readable=False)
21402140
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2141-
expected = ("Kernel", 0, check, 0)
2141+
expected = ("Kernel", check, 0)
21422142
self.parse_output(caplog, expected)
21432143

21442144
def test_mashup(self, caplog):
@@ -2162,7 +2162,7 @@ def test_mashup(self, caplog):
21622162
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21632163
for func in (f, g, src0, src0.coordinates,
21642164
src1, src1.coordinates))
2165-
expected = ("Kernel", 0, check, 0)
2165+
expected = ("Kernel", check, 0)
21662166
self.parse_output(caplog, expected)
21672167

21682168
def test_temp_array(self, caplog):
@@ -2194,7 +2194,7 @@ def test_temp_array(self, caplog):
21942194
# Factor in the temp array
21952195
check += reduce(mul, b.shape_allocated)*np.dtype(a.dtype).itemsize
21962196

2197-
expected = ("Kernel", 0, check, 0)
2197+
expected = ("Kernel", check, 0)
21982198
self.parse_output(caplog, expected)
21992199

22002200
def test_overrides(self, caplog):
@@ -2226,7 +2226,7 @@ def test_overrides(self, caplog):
22262226
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
22272227
for func in (f1, tf1, s1, s1.coordinates, st1, st1.coordinates))
22282228

2229-
expected = ("Kernel", 0, check, 0)
2229+
expected = ("Kernel", check, 0)
22302230
self.parse_output(caplog, expected)
22312231

22322232
def test_device(self, caplog):
@@ -2249,5 +2249,5 @@ def test_device(self, caplog):
22492249
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
22502250

22512251
# Matching memory allocated both on host and device for memmap
2252-
expected = ("Kernel", 0, check, check)
2252+
expected = ("Kernel", check, check)
22532253
self.parse_output(caplog, expected)

0 commit comments

Comments
 (0)