Skip to content

Commit 65f1b37

Browse files
committed
compiler: fix real dtype
1 parent 8e0a2d3 commit 65f1b37

4 files changed

Lines changed: 12 additions & 14 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def _prec(self, expr):
7777
if dtype is None or np.issubdtype(dtype, np.integer):
7878
real = any(isinstance(i, Float) for i in expr.atoms())
7979
if real:
80-
return self.dtype
80+
try:
81+
return np.promote_types(self.dtype, np.float32).type
82+
except np.exceptions.DTypePromotionError:
83+
# Corner cases, e.g. Void, cannot (shouldn't) be promoted
84+
return self.dtype
8185
else:
8286
return dtype or self.dtype
8387
else:
@@ -89,9 +93,9 @@ def prec_literal(self, expr):
8993
def func_literal(self, expr):
9094
return self._func_litterals.get(self._prec(expr), '')
9195

92-
def func_prefix(self, expr, abs=False):
96+
def func_prefix(self, expr, mfunc=False):
9397
prefix = self._func_prefix.get(self._prec(expr), '')
94-
if abs:
98+
if mfunc:
9599
return prefix
96100
else:
97101
return '' if prefix == 'f' else prefix
@@ -235,7 +239,7 @@ def _print_Mul(self, expr):
235239

236240
def _print_fmath_func(self, name, expr):
237241
args = ",".join([self._print(i) for i in expr.args])
238-
func = f'{self.func_prefix(expr, abs=True)}{name}{self.func_literal(expr)}'
242+
func = f'{self.func_prefix(expr, mfunc=True)}{name}{self.func_literal(expr)}'
239243
return f"{self._ns}{func}({args})"
240244

241245
def _print_Min(self, expr):

devito/operator/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3333
flatten, filter_sorted, frozendict, is_integer,
3434
split, timed_pass, timed_region, contains_val)
35-
from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer,
35+
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3636
disk_layer)
3737
from devito.types.dimension import Thickness
3838

devito/types/grid.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,7 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None,
193193
extent = as_tuple(extent or tuple(1. for _ in self.shape))
194194
self._extent = tuple(dtype(e) for e in extent)
195195

196-
# Initialize SubDomains
197-
subdomains = tuple(i for i in (Domain(), Interior(), *as_tuple(subdomains)))
198-
for i in subdomains:
199-
i.__subdomain_finalize__(self)
200-
self._subdomains = subdomains
201-
196+
# The origin of the grid
202197
origin = as_tuple(origin or tuple(0. for _ in self.shape))
203198
self._origin = tuple(dtype(o) for o in origin)
204199
self._origin_symbols = tuple(Scalar(name='o_%s' % d.name, dtype=dtype,

tests/test_builtins.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def test_gs_1d_float(self, sigma):
155155
def test_gs_2d_int(self, sigma):
156156
"""Test the Gaussian smoother in 2d."""
157157

158-
a = ascent()
158+
a = ascent().astype(np.int32)
159159
sp_smoothed = gaussian_filter(a, sigma=sigma)
160160
dv_smoothed = gaussian_smooth(a, sigma=sigma)
161161

@@ -169,8 +169,7 @@ def test_gs_2d_int(self, sigma):
169169
def test_gs_2d_float(self, sigma):
170170
"""Test the Gaussian smoother in 2d."""
171171

172-
a = ascent()
173-
a = a+0.1
172+
a = ascent()+0.1
174173
sp_smoothed = gaussian_filter(a, sigma=sigma)
175174
dv_smoothed = gaussian_smooth(a, sigma=sigma)
176175

0 commit comments

Comments
 (0)