Skip to content

Commit c845871

Browse files
committed
compiler: Fix abstract_object(Array)
1 parent 41b8af8 commit c845871

4 files changed

Lines changed: 29 additions & 7 deletions

File tree

devito/passes/iet/engine.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search
1818
from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass
1919
from devito.types import (
20-
Array, Bundle, ComponentAccess, CompositeObject, IncrDimension, Indirection, Lock,
21-
ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol, Temp,
22-
ThreadArray, Wildcard
20+
Array, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension,
21+
Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData,
22+
Symbol, Temp, ThreadArray, Wildcard
2323
)
2424
from devito.types.args import ArgProvider
2525
from devito.types.dense import DiscreteFunction
@@ -550,12 +550,19 @@ def _(i, mapper, sregistry):
550550

551551
@abstract_object.register(Array)
552552
def _(i, mapper, sregistry):
553-
if isinstance(i, Lock):
554-
name = sregistry.make_name(prefix='lock')
553+
name = sregistry.make_name(prefix=i._symbol_prefix)
554+
555+
if i.initvalue is not None:
556+
initvalue = []
557+
for v in i.initvalue:
558+
try:
559+
initvalue.append(v.xreplace(mapper))
560+
except AttributeError:
561+
initvalue.append(v)
555562
else:
556-
name = sregistry.make_name(prefix='a')
563+
initvalue = None
557564

558-
v = i._rebuild(name=name, alias=True)
565+
v = i._rebuild(name=name, initvalue=initvalue, alias=True)
559566

560567
mapper.update({
561568
i: v,
@@ -662,6 +669,16 @@ def _(i, mapper, sregistry):
662669
mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ptr'))
663670

664671

672+
@abstract_object.register(FunctionMap)
673+
def _(i, mapper, sregistry):
674+
name = sregistry.make_name(prefix=i._symbol_prefix)
675+
tensor = mapper.get(i.tensor, i.tensor)
676+
677+
v = i._rebuild(name, tensor)
678+
679+
mapper[i] = v
680+
681+
665682
@abstract_object.register(NPThreads)
666683
def _(i, mapper, sregistry):
667684
mapper[i] = i._rebuild(name=sregistry.make_name(prefix='npthreads'))

devito/types/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class Array(ArrayBasic):
134134

135135
is_Array = True
136136

137+
_symbol_prefix = 'a'
138+
137139
__rkwargs__ = (ArrayBasic.__rkwargs__ +
138140
('dimensions', 'scope', 'initvalue'))
139141

devito/types/misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'CriticalRegion',
2020
'FIndexed',
2121
'Fence',
22+
'FunctionMap',
2223
'Global',
2324
'Hyperplane',
2425
'Indirection',

devito/types/parallel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ class Lock(Array):
245245

246246
is_volatile = True
247247

248+
_symbol_prefix = 'lock'
249+
248250
# Not a performance-sensitive object
249251
_data_alignment = False
250252

0 commit comments

Comments
 (0)