Skip to content

Commit b419050

Browse files
committed
compiler: add switch for static_cast vs reinterpret_cast
1 parent d385957 commit b419050

5 files changed

Lines changed: 17 additions & 8 deletions

File tree

devito/arch/compiler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def __init__(self):
181181

182182
fields = {'cc', 'ld'}
183183
default_cpp = False
184+
_cxxstd = 'c++14'
185+
_cstd = 'c99'
184186

185187
def __init__(self, **kwargs):
186188
_name = kwargs.pop('name', self.__class__.__name__)
@@ -255,7 +257,7 @@ def version(self):
255257

256258
@property
257259
def std(self):
258-
return 'c++14' if self._cpp else 'c99'
260+
return self._cxxstd if self._cpp else self._cstd
259261

260262
def get_version(self):
261263
result, stdout, stderr = call_capture_output((self.cc, "--version"))
@@ -491,7 +493,7 @@ def __init_finalize__(self, **kwargs):
491493
language = kwargs.pop('language', configuration['language'])
492494
platform = kwargs.pop('platform', configuration['platform'])
493495

494-
if platform is NvidiaDevice:
496+
if isinstance(platform, NvidiaDevice):
495497
self.cflags.remove(f'-std={self.std}')
496498
# Add flags for OpenMP offloading
497499
if language in ['C', 'openmp']:

devito/operator/operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,8 @@ def parse_kwargs(**kwargs):
14241424
kwargs['compiler'] = configuration['compiler'].__new_with__()
14251425

14261426
# Make sure compiler and language are compatible
1427-
if kwargs['compiler']._cpp and kwargs['language'] in ['C', 'openmp']:
1427+
if compiler is not None and kwargs['compiler']._cpp and \
1428+
kwargs['language'] in ['C', 'openmp']:
14281429
kwargs['language'] = 'CXX' if kwargs['language'] == 'C' else 'CXXopenmp'
14291430
if 'CXX' in kwargs['language'] and not kwargs['compiler']._cpp:
14301431
kwargs['compiler'] = kwargs['compiler'].__new_with__(cpp=True)

devito/passes/iet/languages/CXX.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,6 @@ def _print_Cast(self, expr):
9191
tstr = self._print(expr._C_ctype)
9292
if 'void' in tstr:
9393
return super()._print_Cast(expr)
94-
cast = f'static_cast<{tstr}{self._print(expr.stars)}>'
94+
caster = 'reinterpret_cast' if expr.reinterpret else 'static_cast'
95+
cast = f'{caster}<{tstr}{self._print(expr.stars)}>'
9596
return self._print_UnaryOp(expr, op=cast, parenthesize=True)

devito/passes/iet/languages/openacc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,11 @@ def place_devptr(self, iet, **kwargs):
236236

237237
dpf = List(body=[
238238
self.lang.mapper['map-serial-present'](hp, tdp),
239-
Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp)))
239+
Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp, reinterpret=True)))
240240
])
241241

242242
ffp = FieldFromPointer(f._C_field_dmap, f._C_symbol)
243-
ctdp = cast_mapper((hp.dtype, '*'))(tdp)
243+
ctdp = cast_mapper((hp.dtype, '*'))(tdp, reinterpret=True)
244244
cast = DummyExpr(ffp, ctdp)
245245

246246
ret = Return(ctdp)

devito/symbolics/extended_sympy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,9 +384,9 @@ class Cast(UnaryOp):
384384
"""
385385

386386
__rargs__ = ('base', )
387-
__rkwargs__ = ('dtype', 'stars')
387+
__rkwargs__ = ('dtype', 'stars', 'reinterpret')
388388

389-
def __new__(cls, base, dtype=None, stars=None, **kwargs):
389+
def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs):
390390
try:
391391
if issubclass(dtype, np.generic) and sympify(base).is_Number:
392392
base = sympify(dtype(base))
@@ -397,6 +397,7 @@ def __new__(cls, base, dtype=None, stars=None, **kwargs):
397397
obj = super().__new__(cls, base)
398398
obj._stars = stars or ''
399399
obj._dtype = dtype
400+
obj._reinterpret = reinterpret
400401
return obj
401402

402403
def _hashable_content(self):
@@ -412,6 +413,10 @@ def stars(self):
412413
def dtype(self):
413414
return self._dtype
414415

416+
@property
417+
def reinterpret(self):
418+
return self._reinterpret
419+
415420
@property
416421
def _C_ctype(self):
417422
ctype = ctypes_vector_mapper.get(self.dtype, self.dtype)

0 commit comments

Comments
 (0)