|
1 | 1 | from devito.types.equation import PetscEq |
2 | | -from devito.tools import as_tuple |
| 2 | +from devito.tools import filter_ordered, as_tuple |
| 3 | +from devito.types import Symbol, SteppingDimension, TimeDimension |
| 4 | +from devito.operations.solve import eval_time_derivatives |
| 5 | +from devito.symbolics import retrieve_functions, retrieve_dimensions |
| 6 | + |
3 | 7 | from devito.petsc.types import (LinearSolverMetaData, PETScArray, DMDALocalInfo, |
4 | 8 | FieldData, MultipleFieldData, Jacobian, Residual, |
5 | 9 | MixedResidual, MixedJacobian, InitialGuess) |
6 | 10 | from devito.petsc.types.equation import EssentialBC |
7 | 11 | from devito.petsc.solver_parameters import (linear_solver_parameters, |
8 | 12 | format_options_prefix) |
9 | | -from devito.petsc.utils import get_funcs, generate_time_mapper |
10 | 13 |
|
11 | 14 |
|
12 | 15 | __all__ = ['PETScSolve'] |
@@ -186,5 +189,45 @@ def linear_solve_args(self): |
186 | 189 | return targets[0], funcs, all_data |
187 | 190 |
|
188 | 191 |
|
| 192 | +def get_funcs(exprs): |
| 193 | + funcs = [ |
| 194 | + f for e in exprs |
| 195 | + for f in retrieve_functions(eval_time_derivatives(e.lhs - e.rhs)) |
| 196 | + ] |
| 197 | + return as_tuple(filter_ordered(funcs)) |
| 198 | + |
| 199 | + |
| 200 | +def generate_time_mapper(exprs): |
| 201 | + """ |
| 202 | + Replace time indices with `Symbols` in expressions used within |
| 203 | + PETSc callback functions. These symbols are Uxreplaced at the IET |
| 204 | + level to align with the `TimeDimension` and `ModuloDimension` objects |
| 205 | + present in the initial lowering. |
| 206 | + NOTE: All functions used in PETSc callback functions are attached to |
| 207 | + the `SolverMetaData` object, which is passed through the initial lowering |
| 208 | + (and subsequently dropped and replaced with calls to run the solver). |
| 209 | + Therefore, the appropriate time loop will always be correctly generated inside |
| 210 | + the main kernel. |
| 211 | + Examples |
| 212 | + -------- |
| 213 | + >>> exprs = (Eq(f1(t + dt, x, y), g1(t + dt, x, y) + g2(t, x, y)*f1(t, x, y)),) |
| 214 | + >>> generate_time_mapper(exprs) |
| 215 | + {t + dt: tau0, t: tau1} |
| 216 | + """ |
| 217 | + # First, map any actual TimeDimensions |
| 218 | + time_indices = [d for d in retrieve_dimensions(exprs) if isinstance(d, TimeDimension)] |
| 219 | + |
| 220 | + funcs = get_funcs(exprs) |
| 221 | + |
| 222 | + time_indices.extend(list({ |
| 223 | + i if isinstance(d, SteppingDimension) else d |
| 224 | + for f in funcs |
| 225 | + for i, d in zip(f.indices, f.dimensions) |
| 226 | + if d.is_Time |
| 227 | + })) |
| 228 | + tau_symbs = [Symbol('tau%d' % i) for i in range(len(time_indices))] |
| 229 | + return dict(zip(time_indices, tau_symbs)) |
| 230 | + |
| 231 | + |
189 | 232 | localinfo = DMDALocalInfo(name='info', liveness='eager') |
190 | 233 | prefixes = ['y', 'x', 'f', 'b'] |
0 commit comments