Skip to content

Commit 85518f6

Browse files
committed
SQUASH WITH LINEARIZE TWEAK
1 parent 64139dd commit 85518f6

1 file changed

Lines changed: 8 additions & 7 deletions

File tree

tests/test_linearize.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,20 +342,21 @@ def test_strides_forwarding1():
342342
linearize(graph, callback=True, options={'index-mode': 'int32'},
343343
sregistry=SymbolRegistry())
344344

345-
# Despite `a` is passed via `a.indexed`, and since it's an Array (which
346-
# have symbolic shape), we expect the stride exprs to be placed in `bar`,
347-
# and in `bar` only, as `foo` doesn't really use `a`, it just propagates it
348-
# down to `bar`
345+
# `a` is passed via `a.indexed`, so the stride exprs are expected to be
346+
# placed in `foo` and then passed down to `bar` as arguments
349347
foo = graph.root
350348
bar = graph.efuncs['bar']
351349

352350
assert len(foo.body.body) == 1
353351
assert foo.body.body[0].is_Call
352+
assert len(foo.body.strides) == 3
353+
assert foo.body.strides[0].write.name == 'y_fsz0'
354+
assert foo.body.strides[2].write.name == 'y_stride0'
354355

356+
assert len(bar.parameters) == 2
357+
assert bar.parameters[0].name == 'a'
358+
assert bar.parameters[1].name == 'y_stride0'
355359
assert len(bar.body.body) == 1
356-
assert len(bar.body.strides) == 3
357-
assert bar.body.strides[0].write.name == 'y_fsz0'
358-
assert bar.body.strides[2].write.name == 'y_stride0'
359360

360361

361362
def test_strides_forwarding2():

0 commit comments

Comments
 (0)