|
| 1 | +import contextlib |
| 2 | + |
1 | 3 | import cgen as c |
2 | 4 | import numpy as np |
3 | 5 | from sympy import Expr, Not, S |
4 | 6 |
|
5 | 7 | from devito.ir.iet import ( |
6 | 8 | Break, Call, Conditional, DummyExpr, EntryFunction, FindNodes, FindSymbols, Iteration, |
7 | | - List, Return, Transformer, make_callable |
| 9 | + KernelLaunch, List, Return, Transformer, make_callable |
8 | 10 | ) |
9 | 11 | from devito.passes.iet.engine import iet_pass |
10 | 12 | from devito.symbolics import CondEq, MathFunction |
11 | 13 | from devito.tools import dtype_to_ctype |
12 | 14 | from devito.types import Eq, Inc, LocalObject, Symbol |
13 | 15 |
|
14 | | -__all__ = ['check_stability', 'error_mapper'] |
| 16 | +__all__ = ['check_launch', 'check_stability', 'error_mapper'] |
15 | 17 |
|
16 | 18 |
|
17 | 19 | def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs): |
@@ -100,6 +102,67 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): |
100 | 102 | return iet, {'efuncs': efuncs, 'includes': includes} |
101 | 103 |
|
102 | 104 |
|
def check_launch(graph, options=None, **kwargs):
    """
    Append the CHECK_LAUNCH macro after kernel launches when the `errctl`
    option is enabled, so that a failed kernel launch is handled gracefully.

    The macro must only follow launches that sit within a loop, since
    compilation will fail otherwise.

    Nothing is done, and None is returned, when `options` is None, when
    `errctl` is unset or falsy, or when no macro definition can be built for
    the language backend passed as `kwargs['langbb']`.
    """
    enabled = options is not None and options.get('errctl', False)
    if not enabled:
        return

    # The language backend is mandatory once error control is requested
    langbb = kwargs['langbb']

    headers = make_launch_macros(langbb)
    if headers:
        # `langbb` stays in `kwargs` and is forwarded along with the rest
        _check_launch(graph, definition=headers,
                      macro=[langbb['check-launch']], **kwargs)
| 123 | + |
| 124 | + |
@iet_pass
def _check_launch(iet, definition=None, macro=None, **kwargs):
    """
    Follow every KernelLaunch found inside an Iteration of `iet` with the
    nodes in `macro`.

    `macro` is a list of nodes; it is concatenated after each launch within
    a List. `definition` is returned under the 'headers' key whenever at
    least one Iteration was rewritten.

    Returns the (possibly rewritten) `iet` together with a dict of extras,
    which is empty if no launch was found inside any Iteration.
    """
    mapper = {}
    for loop in FindNodes(Iteration).visit(iet):
        # Substitution happens in two stages (first within each loop, then
        # loop by loop within the IET) to account for the edge case where a
        # kernel is launched in multiple places within the generated code,
        # once inside a loop, once outside
        wrapped = {k: List(body=[k] + macro)
                   for k in FindNodes(KernelLaunch).visit(loop)}

        if wrapped:
            mapper[loop] = Transformer(wrapped).visit(loop)

    if not mapper:
        return iet, {}

    return Transformer(mapper).visit(iet), {'headers': definition}
| 149 | + |
| 150 | + |
def make_launch_macros(langbb):
    """
    Build the definition of the CHECK_LAUNCH macro, which checks for errors
    to ensure graceful handling of failed kernel launches.

    Returns a list with a single (name, body) pair, or an empty list if
    `langbb` provides no peek-error call or no success code.
    """
    macros = []

    # A NotImplementedError raised while querying `langbb` means the macro
    # cannot be built for this backend, in which case nothing is defined
    with contextlib.suppress(NotImplementedError):
        peek_error = langbb['peek-error']
        no_error = langbb['error-none']
        body = f'if ({peek_error().name}() != {no_error}) {{break;}}'
        macros.append(('CHECK_LAUNCH', body))

    return macros
| 164 | + |
| 165 | + |
103 | 166 | class Retval(LocalObject, Expr): |
104 | 167 |
|
105 | 168 | dtype = dtype_to_ctype(np.int32) |
|
0 commit comments