Skip to content

Commit 570605c

Browse files
FabioLuporinimloubout
authored andcommitted
compiler: Enhance Array codegen
1 parent a5193e2 commit 570605c

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class BasePrinter(CodePrinter):
4141
_prec_literals = {np.float32: 'F', np.complex64: 'F'}
4242

4343
_qualifiers_mapper = {
44+
'is_extern': 'extern',
4445
'is_const': 'const',
4546
'is_volatile': 'volatile',
4647
'_mem_constant': 'static',

tests/test_iet.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77

88
from devito import (Eq, Grid, Function, TimeFunction, Operator, Dimension, # noqa
99
switchconfig)
10-
from devito.ir.iet import (Call, Callable, Conditional, DeviceCall, DummyExpr,
11-
Iteration, List, KernelLaunch, Lambda, ElementalFunction,
12-
CGen, FindSymbols, filter_iterations, make_efunc,
13-
retrieve_iteration_tree, Transformer)
10+
from devito.ir.iet import (
11+
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration, List,
12+
KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations,
13+
make_efunc, retrieve_iteration_tree, Transformer
14+
)
1415
from devito.ir import SymbolRegistry
1516
from devito.passes.iet.engine import Graph
1617
from devito.passes.iet.languages.C import CDataManager
1718
from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class,
18-
FLOAT)
19+
String, FLOAT)
1920
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
20-
from devito.types import Array, LocalObject, Symbol
21+
from devito.types import CustomDimension, Array, LocalObject, Symbol
2122

2223

2324
@pytest.fixture
@@ -475,3 +476,15 @@ def test_codegen_quality0():
475476
assert len(foo.parameters) == 3
476477
assert len(foo1.parameters) == 1
477478
assert foo1.parameters[0] is a
479+
480+
481+
def test_special_array_definition():
482+
483+
class MyArray(Array):
484+
is_extern = True
485+
_data_alignment = False
486+
487+
dim = CustomDimension(name='d', symbolic_size=String(''))
488+
a = MyArray(name='a', dimensions=dim, scope='shared', dtype=np.uint8)
489+
490+
assert str(Definition(a)) == "extern unsigned char a[];"

0 commit comments

Comments
 (0)