From 63305962ee3d2a8b31306534b85f88fc8a46783b Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Sat, 23 May 2026 17:02:49 -0400 Subject: [PATCH 1/3] arena: support permuted Hadamard add/subt/mult on Tensor ToT The permuted, arena ToT x arena ToT overloads of add, subt, and mult (scaled and unscaled) previously threw "permuted ... of a tensor-of-tensors is not yet supported". This blocked CSV/PNO-based coupled-cluster, whose residual evaluates permuted ToT Hadamard products at the tile-op level (a binary Mult/Add op calling left.mult(right, perm) etc.). By the time a permuted product reaches a tile op, the expression engine has already brought both operands to a common (congruent) layout, so the elementwise product/sum is valid and perm is purely the result permutation. Compute the unpermuted result, then apply perm as a post-pass via permute(), which already handles arena ToT: a shallow outer-cell reindex (arena_permute_shallow) plus an inner-slab rewrite (arena_inner_permute) when the bipartite permutation's inner part is non-trivial. This mirrors the existing numeric x arena permuted-mult branches. --- src/TiledArray/tensor/tensor.h | 67 ++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index ca67641e83..8399238f5f 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -1953,33 +1953,28 @@ class Tensor { return !static_cast(perm) || perm.is_identity(); } - /// Permuted add for `Tensor` ToT operands. A non-trivial - /// permutation of arena ToT tiles is not yet supported; an identity (or - /// null) permutation falls through to the plain element-wise add. + /// Permuted add for `Tensor` ToT operands. The operands are + /// congruent by the time a permuted product reaches a tile op, so the + /// elementwise `add(right)` is valid and `perm` is the result permutation; + /// `permute` applies it (shallow outer reindex + inner-slab rewrite). template requires(is_arena_tensor_v && is_arena_tensor_v && detail::is_permutation_v) Tensor add(const Right& right, const Perm& perm) const { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::add: permuted add of a tensor-of-tensors " - "is not yet supported"); - return add(right); + auto result = add(right); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } /// Permuted scaled add for `Tensor` ToT operands; see the - /// permuted-add overload above for the permutation restriction. + /// permuted-add overload above for the congruent-operand rationale. template requires(is_arena_tensor_v && is_arena_tensor_v && detail::is_numeric_v && detail::is_permutation_v) Tensor add(const Right& right, const Scalar factor, const Perm& perm) const { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::add: permuted scaled add of a " - "tensor-of-tensors is not yet supported"); - return add(right, factor); + auto result = add(right, factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } /// Add this and \c other to construct a new tensor @@ -2382,8 +2377,15 @@ class Tensor { typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> Tensor subt(const Right& right, const Perm& perm) const { - if constexpr (is_tensor_view_v) { - // Permutation isn't supported for view inner cells (fixed storage + if constexpr (is_arena_tensor_v && + is_arena_tensor_v) { + // arena ToT x arena ToT: operands are congruent at tile-op time, so the + // elementwise `subt(right)` is valid; apply the result permutation as a + // post-pass (shallow outer reindex + inner-slab rewrite). + auto result = subt(right); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); + } else if constexpr (is_tensor_view_v) { + // Permutation isn't supported for other view inner cells (fixed storage // layout). Subt+permute would require materialization. TA_EXCEPTION( "Tensor::subt(right, perm): permutation is not " @@ -2443,11 +2445,10 @@ class Tensor { Tensor subt(const Right& right, const Scalar factor, const Perm& perm) const { if constexpr (is_arena_tensor_v && is_arena_tensor_v) { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::subt: permuted scaled subt of a " - "tensor-of-tensors is not yet supported"); - return subt(right, factor); + // arena ToT x arena ToT scaled subtraction; see the unscaled permuted + // subt overload above for the congruent-operand rationale. + auto result = subt(right, factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } else { return binary( right, @@ -2622,11 +2623,15 @@ class Tensor { decltype(auto) mult(const Right& right, const Perm& perm) const { if constexpr (is_arena_tensor_v && is_arena_tensor_v) { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::mult: permuted mult of a " - "tensor-of-tensors is not yet supported"); - return mult(right); + // arena ToT x arena ToT Hadamard product. By the time a permuted product + // reaches a tile op, the engine has already brought both operands to a + // common (congruent) layout, so the elementwise `mult(right)` is valid; + // `perm` is the result permutation (common layout -> target). Apply it + // as a post-pass: `permute` reindexes the outer cells shallowly + // (arena_permute_shallow) and rewrites the inner slab if the inner part + // of the permutation is non-trivial (arena_inner_permute). + auto result = mult(right); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } else if constexpr (detail::is_numeric_v && is_arena_tensor_v) { // t x tot: a plain scalar tile times an arena ToT tile. The 2-arg @@ -2697,11 +2702,11 @@ class Tensor { const Perm& perm) const { if constexpr (is_arena_tensor_v && is_arena_tensor_v) { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::mult: permuted scaled mult of a " - "tensor-of-tensors is not yet supported"); - return mult(right, factor); + // arena ToT x arena ToT scaled Hadamard product; see the unscaled + // permuted mult overload above for the congruent-operand rationale. + // Scale during the elementwise product, then permute the result. + auto result = mult(right, factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } else { return binary( right, From 34711c8692788de341f0e01367c454a529b10f75 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Sun, 24 May 2026 14:49:03 -0400 Subject: [PATCH 2/3] arena: support inner result permutation for ToT x scalar Scale contraction A mixed inner-Scale product (Tensor ToT x plain Tensor -> ToT) under an outer Contraction with a non-identity inner *result* permutation crashed in ArenaTensor::axpy_to(other, factor, perm), which rejects in-place permutation of view cells. The Scale path pushed the inner perm onto the per-cell op (via the fallback axpy_to(..., perm)) and dropped it from total_perm, while make_contraction_arena_plan bailed on any non-identity inner perm -- leaving the view result cell both unshaped and asked to permute itself. Mirror the inner-Contraction view handling instead: for view (arena) result cells, carry the full perm in total_perm so op_'s post-processing permute applies the inner result perm as a slab-level rewrite, and pass an identity inner perm to make_contraction_arena_plan so it builds the plan (pre-shaping result cells unpermuted) and selects the perm-free fused scale op. Owning inner cells keep applying the inner perm in the per-cell scale op (outer-only total_perm), unchanged. --- src/TiledArray/expressions/cont_engine.h | 38 +++++++++++++++++++----- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 867d18d3a9..3f8ec2eaa9 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -305,9 +305,18 @@ class ContEngine : public BinaryEngine { // this->product_type() is Tensor::Contraction, and, // this->implicit_permute_inner_ is false - return this->inner_product_type() == TensorProduct::Scale - ? BipartitePermutation(outer(this->perm_)) - : this->perm_; + if (this->inner_product_type() == TensorProduct::Scale) { + // Owning inner cells apply the inner result permutation in the + // per-cell scale op, so they carry only the outer perm here. View + // (arena) cells instead use a perm-free per-cell op + an unpermuted + // arena plan and rely on op_'s post-processing permute for the + // inner perm -- so they carry the full perm, like inner + // Contraction. + if constexpr (!TiledArray::is_tensor_view_v< + result_tile_element_type>) + return BipartitePermutation(outer(this->perm_)); + } + return this->perm_; }; auto total_perm = make_total_perm(); @@ -341,9 +350,18 @@ class ContEngine : public BinaryEngine { // this->product_type() is Tensor::Contraction, and, // this->implicit_permute_inner_ is false - return this->inner_product_type() == TensorProduct::Scale - ? BipartitePermutation(outer(this->perm_)) - : this->perm_; + if (this->inner_product_type() == TensorProduct::Scale) { + // Owning inner cells apply the inner result permutation in the + // per-cell scale op, so they carry only the outer perm here. View + // (arena) cells instead use a perm-free per-cell op + an unpermuted + // arena plan and rely on op_'s post-processing permute for the + // inner perm -- so they carry the full perm, like inner + // Contraction. + if constexpr (!TiledArray::is_tensor_view_v< + result_tile_element_type>) + return BipartitePermutation(outer(this->perm_)); + } + return this->perm_; }; auto total_perm = make_total_perm(); @@ -949,10 +967,16 @@ class ContEngine : public BinaryEngine { result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_scale) { if (this->product_type() == TensorProduct::Contraction) { + // Pass an identity inner perm: a non-identity inner *result* + // permutation is applied downstream by op_'s post-processing + // permute (carried in make_total_perm for view cells), not by + // the per-cell op -- so the plan must not bail on it. The plan + // pre-shapes result cells in the unpermuted (operand) inner + // layout; the perm-free fused scale op accumulates into them. this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan< result_tile_type, left_tile_type, right_tile_type>( - kind, std::nullopt, inner(this->perm_)); + kind, std::nullopt, Permutation{}); } } // Fallback per-element op for the scale inner-product when no From 4e4a1c145935b9d6f11ee3d46b2f8b25c834e60d Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Sun, 24 May 2026 23:45:50 -0400 Subject: [PATCH 3/3] arena: don't drop the inner result perm for owning ToT Scale contractions Commit 34711c8 made the ToT x scalar Scale contraction always hand an identity inner perm to make_contraction_arena_plan, so the plan is built and the perm-free fused scale op is selected, with the inner result perm applied downstream by op_'s post-processing permute. That is correct for view (arena) inner cells -- make_total_perm carries the full perm for them -- but is_contraction_arena_tot_v is also true for owning legacy TA::Tensor ToT inner cells, and for those make_total_perm carries only the outer perm. So an owning ToT Scale contraction with a non-identity inner result permutation lost the inner perm entirely: identity plan + perm-free op + outer-only total_perm, producing a wrong-inner-layout result (and, under distributed eval, a malformed result whose deferred destruction aborted at a later fence). Mirror the make_total_perm view/owning split here: pass an identity inner perm only for view cells; for owning cells pass inner(perm_) so the plan bails (nullopt) on a non-identity inner perm and the per-cell op applies it, exactly as before 34711c8. Restores einsum_manual/ different_nested_ranks and einsum_tot_t/ilkj_nm_eq_ij_mn_times_kl. --- src/TiledArray/expressions/cont_engine.h | 28 ++++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 3f8ec2eaa9..9091445f4e 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -967,16 +967,30 @@ class ContEngine : public BinaryEngine { result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_scale) { if (this->product_type() == TensorProduct::Contraction) { - // Pass an identity inner perm: a non-identity inner *result* - // permutation is applied downstream by op_'s post-processing - // permute (carried in make_total_perm for view cells), not by - // the per-cell op -- so the plan must not bail on it. The plan - // pre-shapes result cells in the unpermuted (operand) inner - // layout; the perm-free fused scale op accumulates into them. + // The inner perm handed to the plan must match how the inner + // *result* permutation is applied for this result cell type -- + // and the two cell types apply it in different places: + // + // * View (arena) cells: pass an identity inner perm so the + // plan is always built (pre-shaping result cells in the + // unpermuted operand inner layout) and the perm-free fused + // scale op is selected; the inner result perm is applied + // downstream by op_'s post-processing permute (carried in + // make_total_perm for view cells). + // + // * Owning cells: pass the inner result perm so the plan bails + // (nullopt) on a non-identity inner perm, falling back to the + // per-cell op that applies the inner perm itself -- matching + // the outer-only total_perm make_total_perm carries here. + // (A trivial inner perm still lets the plan + fused op run.) + Permutation plan_inner_perm; + if constexpr (!TiledArray::is_tensor_view_v< + result_tile_element_type>) + plan_inner_perm = inner(this->perm_); this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan< result_tile_type, left_tile_type, right_tile_type>( - kind, std::nullopt, Permutation{}); + kind, std::nullopt, plan_inner_perm); } } // Fallback per-element op for the scale inner-product when no