From 433e7c0c1cbd6ed02c7fe86e35dd18d0edd67d9e Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Sat, 18 Apr 2026 12:32:16 +0100 Subject: [PATCH] dsl: Add ReduceMinMax construct for joint minmax reductions --- devito/ir/clusters/algorithms.py | 60 +++++++++++++++++++++++++------- devito/ir/equations/equation.py | 9 +++-- devito/ir/iet/nodes.py | 10 +++--- devito/passes/iet/engine.py | 3 +- devito/types/basic.py | 25 +++++++++++++ devito/types/equation.py | 26 ++++++++++++-- tests/test_dle.py | 27 +++++++++++++- 7 files changed, 135 insertions(+), 25 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index d09efb4fef..9b9d71da5d 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -10,7 +10,7 @@ from devito.ir.clusters.analysis import analyze from devito.ir.clusters.cluster import Cluster, ClusterGroup from devito.ir.clusters.visitors import Queue, cluster_pass -from devito.ir.equations import OpMax, OpMin, identity_mapper +from devito.ir.equations import OpMax, OpMin, OpMinMax, identity_mapper from devito.ir.support import ( Any, Backward, Forward, IterationSpace, Scope, erange, pull_dims ) @@ -531,8 +531,19 @@ def _update(reductions): # The IterationSpace within which the global distributed reduction # must be carried out ispace = c.ispace.prefix(lambda d: d in var.free_symbols) # noqa: B023 - expr = [Eq(var, DistReduce(var, op=op, grid=grid, ispace=ispace))] - fifo.append(c.rebuild(exprs=expr, ispace=ispace)) + + if op is OpMinMax: + # MinMax not natively supported by MPI, so for now we perform two + # separate reductions (not optimal, but it will do for now) + var0, var1 = var, var._translate() + exprs = [ + Eq(var0, DistReduce(var0, op=OpMin, grid=grid, ispace=ispace)), + Eq(var1, DistReduce(var1, op=OpMax, grid=grid, ispace=ispace)) + ] + else: + exprs = [Eq(var, DistReduce(var, op=op, grid=grid, ispace=ispace))] + + fifo.append(c.rebuild(exprs=exprs, ispace=ispace)) 
processed.append(c) @@ -547,7 +558,7 @@ def normalize(clusters, sregistry=None, options=None, platform=None, **kwargs): if options['mapify-reduce']: clusters = normalize_reductions_dense(clusters, sregistry, platform) else: - clusters = normalize_reductions_minmax(clusters) + clusters = normalize_reductions_minmax(clusters, sregistry) clusters = normalize_reductions_sparse(clusters, sregistry) return clusters @@ -591,7 +602,7 @@ def pull_indexeds(expr, subs, mapper, parent=None): @cluster_pass(mode='dense') -def normalize_reductions_minmax(cluster): +def normalize_reductions_minmax(cluster, sregistry): """ Initialize the reduction variables to their neutral element and use them to compute the reduction. @@ -603,6 +614,7 @@ def normalize_reductions_minmax(cluster): init = [] processed = [] + post = [] for e in cluster.exprs: lhs, rhs = e.args f = lhs.function @@ -623,10 +635,32 @@ def normalize_reductions_minmax(cluster): processed.append(e.func(lhs, Max(lhs, rhs))) + elif e.operation is OpMinMax: + # NOTE: we need to create two different reduction variables here + # (instead of using say `n[0]` and `n[1]` directly) because that's + # essentially what OpenMP/OpenACC expect -- two different symbols + rmin = Symbol(name=sregistry.make_name(prefix='rmin'), dtype=lhs.dtype) + rmax = Symbol(name=sregistry.make_name(prefix='rmax'), dtype=lhs.dtype) + + expr0 = Eq(rmin, limits_mapper[lhs.dtype].max) + expr1 = Eq(rmax, limits_mapper[lhs.dtype].min) + ispace = cluster.ispace.project(lambda i: i not in dims) + init.append(cluster.rebuild(exprs=[expr0, expr1], ispace=ispace)) + + processed.extend([ + e.func(rmin, Min(rmin, rhs), operation=OpMin), + e.func(rmax, Max(rmax, rhs), operation=OpMax) + ]) + + # Copy-back the final result to `lhs` at the end of the reduction + expr0 = Eq(lhs, rmin) + expr1 = Eq(lhs._translate(), rmax) + post.append(cluster.rebuild(exprs=[expr0, expr1], ispace=ispace)) + else: processed.append(e) - return init + [cluster.rebuild(processed)] + return 
init + [cluster.rebuild(processed)] + post def normalize_reductions_dense(cluster, sregistry, platform): @@ -674,19 +708,20 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform): if e.is_Reduction: lhs, rhs = e.args + wf = lhs.function try: - f = rhs.function + rf = rhs.function except AttributeError: - f = None + rf = None - if lhs.function.is_Array: + if wf.is_Array and set(candidates).intersection(wf.dimensions): # Probably a compiler-generated reduction, e.g. via # recursive compilation; it's an Array already, so nothing to do processed.append(e) elif rhs in mapper: # Seen this RHS already, so reuse the Array that was created for it processed.append(e.func(lhs, mapper[rhs].indexify())) - elif f and f.is_Array and sum(flatten(f._size_nodomain)) == 0: + elif rf and rf.is_Array and sum(flatten(rf._size_nodomain)) == 0: # Special case: the RHS is an Array with no halo/padding, meaning # that the written data values are contiguous in memory, hence # we can simply reuse the Array itself as we're already in the @@ -698,8 +733,9 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform): grid = cluster.grid except ValueError: grid = None - a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims, - grid=grid) + a = mapper[rhs] = Array( + name=name, dtype=e.dtype, dimensions=dims, grid=grid + ) # Populate the Array (the "map" part) processed.append(e.func(a.indexify(), rhs, operation=None)) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 29589ed8f9..ae03ff2fb0 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -11,7 +11,7 @@ ) from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict -from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min +from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min __all__ = [ 'ClusterizedEq', @@ -20,6 +20,7 @@ 'OpInc', 'OpMax', 
'OpMin', + 'OpMinMax', 'identity_mapper', ] @@ -69,7 +70,7 @@ def operation(self): @property def is_Reduction(self): - return self.operation in (OpInc, OpMin, OpMax) + return self.operation in (OpInc, OpMin, OpMax, OpMinMax) @property def is_Increment(self): @@ -113,7 +114,8 @@ def detect(cls, expr): reduction_mapper = { Inc: OpInc, ReduceMax: OpMax, - ReduceMin: OpMin + ReduceMin: OpMin, + ReduceMinMax: OpMinMax } try: return reduction_mapper[type(expr)] @@ -130,6 +132,7 @@ def detect(cls, expr): OpInc = Operation('+') OpMax = Operation('max') OpMin = Operation('min') +OpMinMax = Operation('minmax') identity_mapper = { diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 6214515275..d106b5e811 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -13,7 +13,7 @@ from devito.data import FULL from devito.ir.cgen import ccode -from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin +from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin, OpMinMax from devito.ir.support import ( AFFINE, INBOUND, PARALLEL, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT, SEQUENTIAL, VECTORIZED, Forward, PrefetchUpdate, Property, WithLock, detect_io @@ -457,7 +457,7 @@ def reads(self): @cached_property def write(self): """The Function written by the Expression.""" - return self.expr.lhs.base.function + return self.output.base.function @cached_property def dimensions(self): @@ -467,17 +467,17 @@ def dimensions(self): @property def is_scalar(self): """True if the LHS is a scalar, False otherwise.""" - return isinstance(self.expr.lhs, (AbstractSymbol, IndexedBase, LocalObject)) + return isinstance(self.output, (AbstractSymbol, IndexedBase, LocalObject)) @property def is_tensor(self): """True if the LHS is an array entry, False otherwise.""" - return self.expr.lhs.is_Indexed + return self.output.is_Indexed @property def is_reduction(self): """True if the RHS performs a reduction operation, False otherwise.""" - return self.operation in (OpInc, OpMin, OpMax) + 
return self.operation in (OpInc, OpMin, OpMax, OpMinMax) @property def is_initializable(self): diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 43ae7801ba..9b936fba76 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -17,7 +17,7 @@ from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass from devito.types import ( - Array, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, + Array, Auto, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol, Temp, ThreadArray, Wildcard ) @@ -658,6 +658,7 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ind')) +@abstract_object.register(Auto) @abstract_object.register(Temp) @abstract_object.register(Wildcard) def _(i, mapper, sregistry): diff --git a/devito/types/basic.py b/devito/types/basic.py index e2af17bd97..32ad48ab51 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1926,6 +1926,31 @@ def _subs(self, old, new, **hints): pass return super()._subs(old, new, **hints) + def _translate(self, mapper=None): + """ + Translate the indices of the current Indexed according to the provided + `{Dimension -> offset}` mapper. For example, if the current Indexed is + `f[x+1]` and the mapper is `{x: -1}`, then the result of the translation + will be `f[x]`. + + If `mapper` is None, then the translation will be unitary increment + along the fastest varying Dimension. For example, if the current + Indexed is `f[x+1, y+2]`, then the result of the translation will be + `f[x+1, y+3]` since `x` is the fastest varying Dimension. 
+ """ + mapper = mapper or {self.dimensions[-1]: 1} + + if any(d not in mapper for d in self.dimensions): + raise ValueError( + f"Cannot translate {self} with mapper {mapper} since not " + "all dimensions are covered" + ) + + translations = [mapper.get(d, 0) for d in self.dimensions] + indices = [sum(i) for i in zip(self.indices, translations, strict=True)] + + return self.base[indices] + class IrregularFunctionInterface: diff --git a/devito/types/equation.py b/devito/types/equation.py index 3b6625c471..5cbba42e89 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -7,7 +7,7 @@ from devito.tools import Pickable, as_tuple, frozendict from devito.types.lazy import Evaluable -__all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin'] +__all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin', 'ReduceMinMax'] class Eq(sympy.Eq, Evaluable, Pickable): @@ -62,8 +62,8 @@ class Eq(sympy.Eq, Evaluable, Pickable): __rargs__ = ('lhs', 'rhs') __rkwargs__ = ('subdomain', 'coefficients', 'implicit_dims') - def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None, - **kwargs): + def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, + implicit_dims=None, **kwargs): if coefficients is not None: _ = deprecations.coeff_warn kwargs['evaluate'] = False @@ -237,3 +237,23 @@ class ReduceMax(Reduction): class ReduceMin(Reduction): pass + + +class ReduceMinMax(Reduction): + + """ + A coupled min/max Reduction. + + The left-hand side must have room for two components, one for the minimum and + one for the maximum; the behaviour is otherwise undefined. + The right-hand side is the expression to be reduced. 
def __new__(cls, lhs, rhs=0, **kwargs):
    """
    Build a coupled min/max reduction writing to `lhs`.

    Parameters
    ----------
    lhs
        The reduction target. `lhs.function` must have a truthy
        `is_AbstractFunction` attribute.
    rhs : optional
        The expression to be reduced. Defaults to 0.
    **kwargs
        Forwarded unchanged to the parent class constructor.

    Raises
    ------
    ValueError
        If `lhs` has no `function` attribute, or if that object is not
        flagged as an AbstractFunction.
    """
    # Guarded lookups, so that a left-hand side with no `.function` at all
    # (e.g. a plain number) is reported through the same ValueError as any
    # other unsuitable target, rather than as a bare AttributeError
    function = getattr(lhs, 'function', None)

    # NOTE: only the *kind* of the target is checked here; that it has room
    # for two components is not verified
    if not getattr(function, 'is_AbstractFunction', False):
        raise ValueError(
            f"The left-hand side of a {cls.__name__} must be a "
            "Function of size at least 2"
        )

    return super().__new__(cls, lhs, rhs=rhs, **kwargs)