xtensor: inline Einsum OFG in lower_dot to avoid ShapeFeature compile blow-up#2164
xtensor: inline Einsum OFG in lower_dot to avoid ShapeFeature compile blow-up#2164cetagostini wants to merge 3 commits into
Conversation
`pt.einsum` wraps its output in an `Einsum` `OpFromGraph`. The OFG is
only inlined by `inline_optimized_einsum` during `specialize`, but while
it is alive `ShapeFeature.on_import` calls `OpFromGraph.infer_shape` on
every node import during canonicalize, and `infer_shape` re-walks the
OFG's inner graph each time. When several xtensor dots are composed
(e.g. multi-layer attention), this becomes super-linear and dominates
compile time.
Inlining the OFG immediately after `einsum` removes it before any
shape-using pass ever sees it. The 2-operand case `lower_dot` produces
has no path optimisation to defer, so inlining is safe and behaviour-
preserving.
Effect on the toy multi-head attention reproducer
(block_size=32, n_embd=64, n_head=4, with grad):
n_layer plain xtensor (before) xtensor (after)
1 0.94s 3.04s 1.12s
2 2.03s 72.50s 4.07s
Adds a structural test that locks in the post-`lower_xtensor` invariant
"no OpFromGraph nodes left in the lowered graph". All existing
xtensor / tensordot / einsum tests still pass.
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
After inlining the Einsum OFG in lower_dot, the runtime shape mismatch is now raised by the inlined Dot directly (`Shape mismatch: x has ...`) instead of by np.einsum/np.dot inside the OFG's wrapper. Add that message to the regex so the test passes on all backends. Co-authored-by: Cursor <cursoragent@cursor.com>
The fix is clearly ShapeFeature, which I linked in the other PR as already been worked out. Did you try to see if simply including those changes fixed the problem? |
Using your branch LLM takes compilation to 7s instead of 2s, and training 84s instead of 32s. So, the changes in #2056 are not sufficient to kill the full overhead generated. But this two lines of code make the trick @ricardoV94 |
I'll have to look at the before/after graph, otherwise we're pushing the dirt under the carpet instead of fixing it. If you want to, you would look at the dprint before and after, and use config optimizer_verbose or the interactive rewrite to see what's actually going on. Also what's compile time? Calling function? first eval? All mix or some paths going through numba? |
Found while writing the tiny-transformer gallery notebook in #2163 — multi-layer xtensor attention with
pytensor.gradblew up super-linearly in compile time, to the point of being unusable past 2 layers.Root cause
pt.einsumwraps its output in anEinsumOpFromGraph.inline_optimized_einsumonly inlines the OFG duringspecialize, but while it is aliveShapeFeature.on_importcallsOpFromGraph.infer_shapeon every node import during canonicalize, andinfer_shapere-walks the inner graph each time. With many composedxtensor.dots (e.g. multi-layer attention) this becomes super-linear and dominates compile time.cProfile of a 3-layer xtensor attention compile on
main(32s total):This is the same family of
ShapeFeatureissues #2056 is targeting.Fix
Inline the
EinsumOFG immediately after building it inlower_dot, soShapeFeaturenever sees it. The 2-operandEinsumproduced bylower_dothas no path optimisation to defer, so inlining is safe and behaviour-preserving.Reproducer
Compile time, single thread, M-class macOS, NUMBA mode:
N_LAYERSmainThe PR adds a structural test that locks in the post-
lower_xtensorinvariant "noOpFromGraphnodes left in the lowered graph".Relation to #2163
#2163 (tiny transformer gallery notebook) hits this in its current form and the notebook is unusable without it. Once this PR lands, #2163 will be updated to drop its own (incorrect) workaround attempts and rebase.
Test plan
tests/xtensor/(non-random): 190 passed, 6 skipped, 1 xfailed (190 → 191 with new structural test)tests/tensor/test_math.py::TestTensordot/TestMatMul: 24 passed, 1 xfailedtests/tensor/test_einsum.py: 46 passed0.000e+00, max grad diff5.03e-17main; confirm the speedupMade with Cursor