diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index 8399238f5f..921490add8 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -3096,6 +3096,198 @@ class Tensor { const integer ldb = (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? N : K); + // GEMM-based ToT scale path: for the scale contraction + // "m,k;a" * "k,n" -> "m,n;a" (left ToT, right plain scalar), recast each + // row m as one strided GEMM result_m(A_m x N) += left_m(A_m x K) * + // right(K x N), directly on the arena slab -- amortizing the per-cell AXPY + // setup over a single BLAS call. Applies for NoTranspose, matching scalar + // type, and "clean" rows (all cells present, uniform inner size A_m, laid + // out as one contiguous single-page stride-A_m block); other rows fall back + // to the per-cell AXPY loop. + if constexpr (detail::is_numeric_v && is_tensor_view_v && + is_tensor_view_v) { + using Real = std::remove_cv_t; + if constexpr (std::is_same_v, Real>) { + if (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose && + gemm_helper.right_op() == TiledArray::math::blas::NoTranspose) { + for (integer b = 0; b != nbatch(); ++b) { + auto this_data = this->batch_data(b); + auto left_data = left.batch_data(b); + auto right_data = right.batch_data(b); // K x N row-major scalars + for (integer m = 0; m != M; ++m) { + auto* lc0 = left_data + (m * K); // left cells (m,0..K-1) + auto* rc0 = this_data + (m * N); // result cells (m,0..N-1) + // A "clean" row has all cells present, uniform inner size A, and + // laid out as one contiguous stride-A block (so the GEMM can run + // zero-copy directly on the slab). Else fall back to per-cell + // AXPY. + long A = -1; + bool clean = true; + for (integer k = 0; k != K && clean; ++k) { + const auto& c = lc0[k]; + if (c.empty()) { + clean = false; + break; + } + long s = static_cast(c.size()); + if (A < 0) + A = s; + else if (A != s) + clean = false; + } + for (integer n = 0; n != N && clean; ++n) { + const auto& c = rc0[n]; + if (c.empty()) { + clean = false; + break; + } + long s = static_cast(c.size()); + if (A < 0) + A = s; + else if (A != s) + clean = false; + } + // Arena cells are SIMD-padded, so the per-row inter-cell stride + // is the padded inner size (>= A). The strided GEMM requires the + // row's cells to be ONE contiguous run at constant stride -- only + // true for a single-page arena. An incrementally-built (un- + // compacted) ToT tile may span multiple pages, where the stride + // jumps at a page boundary; verify constant stride across ALL + // cells (so multi-page tiles fall back to the AXPY loop). + integer ldb = static_cast(A); + integer ldc = static_cast(A); + if (clean && A > 0) { + if (K > 1) + ldb = static_cast(lc0[1].data() - lc0[0].data()); + if (N > 1) + ldc = static_cast(rc0[1].data() - rc0[0].data()); + if (ldb < A || ldc < A) clean = false; // sanity + const std::ptrdiff_t sb = ldb, sc = ldc; + for (integer k = 0; clean && k != K; ++k) + if (lc0[k].data() != lc0[0].data() + k * sb) clean = false; + for (integer n = 0; clean && n != N; ++n) + if (rc0[n].data() != rc0[0].data() + n * sc) clean = false; + } + if (A <= 0) continue; // empty row -> nothing to do + if (clean) { + // result[m,n][a] += sum_k left[m,k][a] * right[k,n]. + // Row-major gemm: C2(N x A) += right^T(N x K) * L2(K x A), + // where L2 = left row-m slab (K x A, ld=ldb), C2 = result row-m + // slab (N x A, ld=ldc), right is K x N (ld=N). ldb/ldc carry + // padding. + const integer Ai = static_cast(A); + TiledArray::math::blas::gemm( + TiledArray::math::blas::Transpose, + TiledArray::math::blas::NoTranspose, + /*M=*/N, /*N=*/Ai, /*K=*/K, Real(1), + /*A=*/right_data, /*lda=*/N, + /*B=*/lc0[0].data(), /*ldb=*/ldb, Real(1), + /*C=*/rc0[0].data(), /*ldc=*/ldc); + } else { // per-cell AXPY fallback for this row + for (integer n = 0; n != N; ++n) { + auto c_offset = m * N + n; + for (integer k = 0; k != K; ++k) + elem_muladd_op(*(this_data + c_offset), + *(left_data + (m * K + k)), + *(right_data + (k * N + n))); + } + } + } + } + return *this; + } + } + } + + // GEMM-based scale path, mirror for T * ToT ("m,k" * "k,n;a" -> "m,n;a", + // left plain scalar, right ToT). Per column n: one GEMM + // result_n(M x A_n) += left(M x K) * right_n(K x A_n). The right/result + // column-n cells are strided over the slab (constant k-/m-stride within a + // single arena page); verify that, else fall back to per-cell AXPY. + if constexpr (detail::is_numeric_v && is_tensor_view_v && + is_tensor_view_v) { + using Real = std::remove_cv_t; + if constexpr (std::is_same_v, Real>) { + if (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose && + gemm_helper.right_op() == TiledArray::math::blas::NoTranspose) { + for (integer b = 0; b != nbatch(); ++b) { + auto this_data = this->batch_data(b); + auto left_data = left.batch_data(b); // M x K row-major scalars + auto right_data = right.batch_data(b); // K x N ToT + for (integer n = 0; n != N; ++n) { + long A = -1; + bool clean = true; + for (integer k = 0; k != K && clean; ++k) { + const auto& c = right_data[k * N + n]; + if (c.empty()) { + clean = false; + break; + } + long s = static_cast(c.size()); + if (A < 0) + A = s; + else if (A != s) + clean = false; + } + for (integer m = 0; m != M && clean; ++m) { + const auto& c = this_data[m * N + n]; + if (c.empty()) { + clean = false; + break; + } + long s = static_cast(c.size()); + if (A < 0) + A = s; + else if (A != s) + clean = false; + } + integer ldb = static_cast(A); // k-stride, right col n + integer ldc = static_cast(A); // m-stride, result col n + if (clean && A > 0) { + if (K > 1) + ldb = static_cast(right_data[N + n].data() - + right_data[n].data()); + if (M > 1) + ldc = static_cast(this_data[N + n].data() - + this_data[n].data()); + if (ldb < A || ldc < A) clean = false; + const std::ptrdiff_t sb = ldb, sc = ldc; + for (integer k = 0; clean && k != K; ++k) + if (right_data[k * N + n].data() != + right_data[n].data() + k * sb) + clean = false; + for (integer m = 0; clean && m != M; ++m) + if (this_data[m * N + n].data() != + this_data[n].data() + m * sc) + clean = false; + } + if (A <= 0) continue; + if (clean) { + // C_n(M x A) += left(M x K) * B_n(K x A). Row-major gemm. + const integer Ai = static_cast(A); + TiledArray::math::blas::gemm( + TiledArray::math::blas::NoTranspose, + TiledArray::math::blas::NoTranspose, + /*M=*/M, /*N=*/Ai, /*K=*/K, Real(1), + /*A=*/left_data, /*lda=*/K, + /*B=*/right_data[n].data(), /*ldb=*/ldb, Real(1), + /*C=*/this_data[n].data(), /*ldc=*/ldc); + } else { // per-cell AXPY fallback for this column + for (integer m = 0; m != M; ++m) { + auto c_offset = m * N + n; + for (integer k = 0; k != K; ++k) + elem_muladd_op(*(this_data + c_offset), + *(left_data + (m * K + k)), + *(right_data + (k * N + n))); + } + } + } + } + return *this; + } + } + } + for (integer b = 0; b != nbatch(); ++b) { auto this_data = this->batch_data(b); auto left_data = left.batch_data(b);