|
21 | 21 | from devito.ir.clusters import ClusterGroup, clusterize |
22 | 22 | from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction, |
23 | 23 | FindSymbols, MetaCall, derive_parameters, iet_build) |
| 24 | +from devito.ir.iet.visitors import Specializer |
24 | 25 | from devito.ir.support import AccessMode, SymbolRegistry |
25 | 26 | from devito.ir.stree import stree_build |
26 | 27 | from devito.operator.profiling import create_profile |
@@ -985,16 +986,34 @@ def apply(self, **kwargs): |
985 | 986 | >>> op = Operator(Eq(u3.forward, u3 + 1)) |
986 | 987 | >>> summary = op.apply(time_M=10) |
987 | 988 | """ |
988 | | - # Compile the operator before building the arguments list |
989 | | - # to avoid out of memory with greedy compilers |
990 | | - cfunction = self.cfunction |
| 989 | + # Get items expected to be specialized |
| 990 | + specialize = as_tuple(kwargs.pop('specialize', [])) |
| 991 | + |
| 992 | + if not specialize: |
| 993 | + # Compile the operator before building the arguments list |
| 994 | + # to avoid out of memory with greedy compilers |
| 995 | + cfunction = self.cfunction |
991 | 996 |
|
992 | 997 | # Build the arguments list to invoke the kernel function |
993 | 998 | with self._profiler.timer_on('arguments-preprocess'): |
994 | 999 | args = self.arguments(**kwargs) |
995 | 1000 | with switch_log_level(comm=args.comm): |
996 | 1001 | self._emit_args_profiling('arguments-preprocess') |
997 | 1002 |
|
| 1003 | + # In the case of specialization, arguments must be processed before |
| 1004 | + # the operator can be compiled |
| 1005 | + if specialize: |
| 1006 | + specialized_args = {p: sympify(args.pop(p.name)) |
| 1007 | + for p in self.parameters if p.name in specialize} |
| 1008 | + |
| 1009 | + op = Specializer(specialized_args).visit(self) |
| 1010 | + else: |
| 1011 | + op = self |
| 1012 | + |
| 1013 | + from IPython import embed; embed() |
| 1014 | + |
| 1015 | + # TODO: Whose profiler should get used here? |
| 1016 | + |
998 | 1017 | # Invoke kernel function with args |
999 | 1018 | arg_values = [args[p.name] for p in self.parameters] |
1000 | 1019 | try: |
|
0 commit comments