@@ -342,20 +342,21 @@ def test_strides_forwarding1():
342342 linearize (graph , callback = True , options = {'index-mode' : 'int32' },
343343 sregistry = SymbolRegistry ())
344344
345- # Despite `a` is passed via `a.indexed`, and since it's an Array (which
346- # have symbolic shape), we expect the stride exprs to be placed in `bar`,
347- # and in `bar` only, as `foo` doesn't really use `a`, it just propagates it
348- # down to `bar`
345+ # `a` is passed via `a.indexed`, so the stride exprs are expected to be
346+ # placed in `foo` and then passed down to `bar` as arguments
349347 foo = graph .root
350348 bar = graph .efuncs ['bar' ]
351349
352350 assert len (foo .body .body ) == 1
353351 assert foo .body .body [0 ].is_Call
352+ assert len (foo .body .strides ) == 3
353+ assert foo .body .strides [0 ].write .name == 'y_fsz0'
354+ assert foo .body .strides [2 ].write .name == 'y_stride0'
354355
356+ assert len (bar .parameters ) == 2
357+ assert bar .parameters [0 ].name == 'a'
358+ assert bar .parameters [1 ].name == 'y_stride0'
355359 assert len (bar .body .body ) == 1
356- assert len (bar .body .strides ) == 3
357- assert bar .body .strides [0 ].write .name == 'y_fsz0'
358- assert bar .body .strides [2 ].write .name == 'y_stride0'
359360
360361
361362def test_strides_forwarding2 ():
0 commit comments