Skip to content

Commit ec47714

Browse files
committed
compiler: Tweak linearize pass
1 parent f25d863 commit ec47714

2 files changed

Lines changed: 10 additions & 16 deletions

File tree

devito/passes/iet/linearization.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def linearize_accesses(iet, key0, tracker=None):
223223

224224
# 2) What `iet` *offers*
225225
# E.g. `{x_fsz0 -> u_vec->size[1]}`
226-
defines = FindSymbols('defines-aliases').visit(iet)
226+
defines = FindSymbols('defines').visit(iet)
227227
offers = filter_ordered(i for i in defines if key0(i.function))
228228
instances = {}
229229
for i in offers:
@@ -294,16 +294,9 @@ def _(f, d):
294294

295295

296296
@_generate_fsz.register(Array)
297-
def _(f, d):
298-
return f.symbolic_shape[d]
299-
300-
301297
@_generate_fsz.register(Bundle)
302298
def _(f, d):
303-
if f.is_DiscreteFunction:
304-
return _generate_fsz.registry[DiscreteFunction](f, d)
305-
else:
306-
return _generate_fsz.registry[Array](f, d)
299+
return f.symbolic_shape[d]
307300

308301

309302
@singledispatch

tests/test_linearize.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,20 +342,21 @@ def test_strides_forwarding1():
342342
linearize(graph, callback=True, options={'index-mode': 'int32'},
343343
sregistry=SymbolRegistry())
344344

345-
# Despite `a` is passed via `a.indexed`, and since it's an Array (which
346-
# have symbolic shape), we expect the stride exprs to be placed in `bar`,
347-
# and in `bar` only, as `foo` doesn't really use `a`, it just propagates it
348-
# down to `bar`
345+
# `a` is passed via `a.indexed`, so the stride exprs are expected to be
346+
# placed in `foo` and then passed down to `bar` as arguments
349347
foo = graph.root
350348
bar = graph.efuncs['bar']
351349

352350
assert len(foo.body.body) == 1
353351
assert foo.body.body[0].is_Call
352+
assert len(foo.body.strides) == 3
353+
assert foo.body.strides[0].write.name == 'y_fsz0'
354+
assert foo.body.strides[2].write.name == 'y_stride0'
354355

356+
assert len(bar.parameters) == 2
357+
assert bar.parameters[0].name == 'a'
358+
assert bar.parameters[1].name == 'y_stride0'
355359
assert len(bar.body.body) == 1
356-
assert len(bar.body.strides) == 3
357-
assert bar.body.strides[0].write.name == 'y_fsz0'
358-
assert bar.body.strides[2].write.name == 'y_stride0'
359360

360361

361362
def test_strides_forwarding2():

0 commit comments

Comments
 (0)