Skip to content

Commit f1b8f7f

Browse files
committed
compiler: move launch check injection to later in compilation pipeline
1 parent ad2febe commit f1b8f7f

1 file changed

Lines changed: 65 additions & 2 deletions

File tree

devito/passes/iet/errors.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
import contextlib
2+
13
import cgen as c
24
import numpy as np
35
from sympy import Expr, Not, S
46

57
from devito.ir.iet import (
68
Break, Call, Conditional, DummyExpr, EntryFunction, FindNodes, FindSymbols, Iteration,
7-
List, Return, Transformer, make_callable
9+
List, Return, Transformer, KernelLaunch, make_callable, retrieve_iteration_tree
810
)
911
from devito.passes.iet.engine import iet_pass
1012
from devito.symbolics import CondEq, MathFunction
1113
from devito.tools import dtype_to_ctype
1214
from devito.types import Eq, Inc, LocalObject, Symbol
1315

14-
__all__ = ['check_stability', 'error_mapper']
16+
__all__ = ['check_stability', 'check_launch', 'error_mapper']
1517

1618

1719
def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs):
@@ -100,6 +102,67 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
100102
return iet, {'efuncs': efuncs, 'includes': includes}
101103

102104

105+
def check_launch(graph, options={}, **kwargs):
106+
"""
107+
Insert the CHECK_LAUNCH macro if errctl is set to ensure graceful handling of
108+
failed kernel launches. This macro should only be inserted if the kernel is
109+
directly within a loop, as compilation will fail otherwise.
110+
"""
111+
if not options.get('errctl', False):
112+
return
113+
114+
langbb = kwargs['langbb']
115+
116+
definition = make_launch_macros(langbb)
117+
if not definition:
118+
return
119+
120+
macro = [langbb['check-launch']]
121+
122+
_check_launch(graph, definition=definition, macro=macro, **kwargs)
123+
124+
125+
@iet_pass
126+
def _check_launch(iet, definition=None, macro=None, **kwargs):
127+
iterations = FindNodes(Iteration).visit(iet)
128+
129+
mapper = {}
130+
for i in iterations:
131+
# Two stages of substitution to account for the edge case
132+
# where a kernel is launched in multiple places within the
133+
# generated code, once inside a loop, once outside
134+
launch_mapper = {}
135+
launches = FindNodes(KernelLaunch).visit(i)
136+
137+
for launch in launches:
138+
launch_mapper[launch] = List(body=[launch] + macro)
139+
140+
if launch_mapper:
141+
mapper[i] = Transformer(launch_mapper).visit(i)
142+
143+
extras = {}
144+
if mapper:
145+
iet = Transformer(mapper).visit(iet)
146+
extras.update({'headers': definition})
147+
148+
return iet, extras
149+
150+
151+
def make_launch_macros(langbb):
152+
"""
153+
Define macros to check for errors to ensure graceful handling of failed kernel
154+
launches.
155+
"""
156+
157+
# Will skip if there is no peek-error call or success code for the langbb
158+
with contextlib.suppress(NotImplementedError):
159+
peek = langbb['peek-error']
160+
success = langbb['error-none']
161+
return [('CHECK_LAUNCH', f'if ({peek().name}() != {success}) {{break;}}')]
162+
163+
return []
164+
165+
103166
class Retval(LocalObject, Expr):
104167

105168
dtype = dtype_to_ctype(np.int32)

0 commit comments

Comments
 (0)