Skip to content

Commit 2714d16

Browse files
committed
compiler: fix Cast for string cast type
1 parent 34cdc53 commit 2714d16

3 files changed

Lines changed: 48 additions & 3 deletions

File tree

devito/symbolics/extended_sympy.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Extended SymPy hierarchy.
33
"""
4+
import re
45

56
import numpy as np
67
import sympy
@@ -394,13 +395,29 @@ def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs):
394395
# E.g. void
395396
pass
396397

398+
dtype, stars = cls._process_dtype(dtype, stars)
399+
397400
obj = super().__new__(cls, base)
398401
obj._stars = stars or ''
399402
obj._dtype = dtype
400403
obj._reinterpret = reinterpret
401404

402405
return obj
403406

407+
@classmethod
408+
def _process_dtype(cls, dtype, stars):
409+
if not isinstance(dtype, str) or stars is not None:
410+
return dtype, stars
411+
412+
# String dtype, e.g. "float", "int*", "foo**"
413+
match = re.fullmatch(r'(\w+)\s*(\*+)?', dtype)
414+
if match:
415+
dtype = match.group(1)
416+
stars = match.group(2) or ''
417+
return dtype, stars
418+
else:
419+
return dtype, stars
420+
404421
def _hashable_content(self):
405422
return super()._hashable_content() + (self._stars,)
406423

@@ -429,7 +446,10 @@ def _C_ctype(self):
429446

430447
@property
431448
def _op(self):
432-
return f'({ctypes_to_cstr(self._C_ctype)})'
449+
cstr = ctypes_to_cstr(self._C_ctype)
450+
if self.stars:
451+
cstr = f"{cstr}{self.stars}"
452+
return f'({cstr})'
433453

434454
def __str__(self):
435455
return f"{self._op}{self.base}"

devito/tools/dtypes_lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ class c_restrict_void_p(ctypes.c_void_p):
250250

251251
def ctypes_to_cstr(ctype, toarray=None):
252252
"""Translate ctypes types into C strings."""
253-
if ctype in ctypes_vector_mapper.values():
253+
if isinstance(ctype, str):
254+
# Already a C string
255+
return ctype
256+
elif ctype in ctypes_vector_mapper.values():
254257
retval = ctype.__name__
255258
elif isinstance(ctype, CustomDtype):
256259
retval = str(ctype)

tests/test_symbolics.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def test_rvalue():
419419
assert str(Rvalue(ctype, ns, init)) == 'my::namespace::dummytype{}'
420420

421421

422-
def test_cast():
422+
def test_basecast():
423423
s = Symbol(name='s', dtype=np.float32)
424424

425425
class BarCast(BaseCast):
@@ -435,6 +435,28 @@ class BarCast(BaseCast):
435435
assert v != v1
436436

437437

438+
def test_str_cast():
439+
s = Symbol(name='s', dtype=np.float32)
440+
441+
v = Cast(s, 'foo')
442+
assert not v.stars
443+
assert v.dtype == 'foo'
444+
assert v._op == '(foo)'
445+
assert ccode(v) == '(foo)s'
446+
447+
v = Cast(s, 'foo*')
448+
assert v.stars == '*'
449+
assert v.dtype == 'foo'
450+
assert v._op == '(foo*)'
451+
assert ccode(v) == '(foo*)s'
452+
453+
v = Cast(s, 'foo **')
454+
assert v.stars == '**'
455+
assert v.dtype == 'foo'
456+
assert v._op == '(foo**)'
457+
assert ccode(v) == '(foo**)s'
458+
459+
438460
def test_findexed():
439461
grid = Grid(shape=(3, 3, 3))
440462
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)