Skip to content

Commit 8e0a2d3

Browse files
committed
compiler: add scalar type option
1 parent d24e6e1 commit 8e0a2d3

6 files changed

Lines changed: 25 additions & 6 deletions

File tree

devito/core/cpu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def _normalize_kwargs(cls, **kwargs):
3232
o['mpi'] = oo.pop('mpi')
3333
o['parallel'] = o['openmp'] # Backwards compatibility
3434

35+
# Minimum scalar type
36+
o['scalar-min-type'] = oo.pop('scalar-min-type', cls.SCALAR_MIN_TYPE)
37+
3538
# Buffering
3639
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
3740

devito/core/gpu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def _normalize_kwargs(cls, **kwargs):
4040
o['mpi'] = oo.pop('mpi')
4141
o['parallel'] = True
4242

43+
# Minimum scalar type
44+
o['scalar-min-type'] = oo.pop('scalar-min-type', cls.SCALAR_MIN_TYPE)
45+
4346
# Buffering
4447
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
4548

devito/core/operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from collections.abc import Iterable
22
from functools import cached_property
33

4+
import numpy as np
5+
46
from devito.core.autotuning import autotune
57
from devito.exceptions import InvalidOperator
68
from devito.ir import FindSymbols
@@ -67,6 +69,11 @@ class BasicOperator(Operator):
6769
intensity of the generated kernel.
6870
"""
6971

72+
SCALAR_MIN_TYPE = np.float16
73+
"""
74+
Minimum datatype for a scalar temporary that aliases a common sub-expression or a CIRE temporary.
75+
"""
76+
7077
PAR_COLLAPSE_NCORES = 4
7178
"""
7279
Use a collapse clause if the number of available physical cores is greater

devito/passes/clusters/aliases.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,11 @@ class CireTransformer:
109109
def __init__(self, sregistry, options, platform):
110110
self.sregistry = sregistry
111111
self.platform = platform
112-
113112
self.opt_minstorage = options['min-storage']
114113
self.opt_rotate = options['cire-rotate']
115114
self.opt_ftemps = options['cire-ftemps']
116115
self.opt_mingain = options['cire-mingain']
116+
self.opt_mindtype = options['scalar-min-type']
117117
self.opt_multisubdomain = True
118118

119119
def _aliases_from_clusters(self, clusters, exclude, meta):
@@ -143,7 +143,7 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
143143

144144
# Schedule -> [Clusters]_k
145145
processed, subs = lower_schedule(schedule, meta, self.sregistry,
146-
self.opt_ftemps)
146+
self.opt_ftemps, self.opt_mindtype)
147147

148148
# [Clusters]_k -> [Clusters]_k (optimization)
149149
if self.opt_multisubdomain:
@@ -831,7 +831,7 @@ def optimize_schedule_rotations(schedule, sregistry):
831831
return schedule.rebuild(*processed, rmapper=rmapper)
832832

833833

834-
def lower_schedule(schedule, meta, sregistry, ftemps):
834+
def lower_schedule(schedule, meta, sregistry, ftemps, mindtype):
835835
"""
836836
Turn a Schedule into a sequence of Clusters.
837837
"""
@@ -849,7 +849,8 @@ def lower_schedule(schedule, meta, sregistry, ftemps):
849849
# This prevents cases such as `floor(a*b)` with `a` and `b` floats
850850
# that would creat a temporary `int r = b` leading to erronous
851851
# numerical results
852-
dtype = sympy_dtype(pivot, base=meta.dtype)
852+
mindtype = None if writeto else mindtype
853+
dtype = sympy_dtype(pivot, base=meta.dtype, smin=mindtype)
853854

854855
if writeto:
855856
# The Dimensions defining the shape of Array

devito/passes/clusters/cse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import defaultdict
22
from functools import cached_property, singledispatch
33

4+
import numpy as np
45
import sympy
56
from sympy import Add, Function, Indexed, Mul, Pow
67
try:
@@ -69,11 +70,12 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
6970
"""
7071
min_cost = options['cse-min-cost']
7172
mode = options['cse-algo']
73+
mindtype = np.promote_types(options['scalar-min-type'], cluster.dtype).type
7274

7375
if cluster.is_fence:
7476
return cluster
7577

76-
make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)
78+
make = lambda: CTemp(name=sregistry.make_name(), dtype=mindtype)
7779

7880
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
7981

devito/symbolics/inspection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def has_integer_args(*args):
298298
return res
299299

300300

301-
def sympy_dtype(expr, base=None, default=None):
301+
def sympy_dtype(expr, base=None, default=None, smin=None):
302302
"""
303303
Infer the dtype of the expression.
304304
"""
@@ -322,4 +322,7 @@ def sympy_dtype(expr, base=None, default=None):
322322
else:
323323
dtype = np.promote_types(dtype, np.complex64).type
324324

325+
if smin is not None and not np.issubdtype(dtype, np.integer):
326+
dtype = np.promote_types(dtype, smin).type
327+
325328
return dtype

0 commit comments

Comments
 (0)