Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3096,6 +3096,198 @@ class Tensor {
const integer ldb =
(gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? N : K);

// GEMM-based ToT scale path: for the scale contraction
// "m,k;a" * "k,n" -> "m,n;a" (left ToT, right plain scalar), recast each
// row m as one strided GEMM result_m(A_m x N) += left_m(A_m x K) *
// right(K x N), directly on the arena slab -- amortizing the per-cell AXPY
// setup over a single BLAS call. Applies for NoTranspose, matching scalar
// type, and "clean" rows (all cells present, uniform inner size A_m, laid
// out as one contiguous single-page stride-A_m block); other rows fall back
// to the per-cell AXPY loop.
if constexpr (detail::is_numeric_v<V> && is_tensor_view_v<U> &&
is_tensor_view_v<value_type>) {
using Real = std::remove_cv_t<typename value_type::value_type>;
if constexpr (std::is_same_v<std::remove_cv_t<V>, Real>) {
if (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose &&
gemm_helper.right_op() == TiledArray::math::blas::NoTranspose) {
for (integer b = 0; b != nbatch(); ++b) {
auto this_data = this->batch_data(b);
auto left_data = left.batch_data(b);
auto right_data = right.batch_data(b); // K x N row-major scalars
for (integer m = 0; m != M; ++m) {
auto* lc0 = left_data + (m * K); // left cells (m,0..K-1)
auto* rc0 = this_data + (m * N); // result cells (m,0..N-1)
// A "clean" row has all cells present, uniform inner size A, and
// laid out as one contiguous stride-A block (so the GEMM can run
// zero-copy directly on the slab). Else fall back to per-cell
// AXPY.
long A = -1;
bool clean = true;
for (integer k = 0; k != K && clean; ++k) {
const auto& c = lc0[k];
if (c.empty()) {
clean = false;
break;
}
long s = static_cast<long>(c.size());
if (A < 0)
A = s;
else if (A != s)
clean = false;
}
for (integer n = 0; n != N && clean; ++n) {
const auto& c = rc0[n];
if (c.empty()) {
clean = false;
break;
}
long s = static_cast<long>(c.size());
if (A < 0)
A = s;
else if (A != s)
clean = false;
}
// Arena cells are SIMD-padded, so the per-row inter-cell stride
// is the padded inner size (>= A). The strided GEMM requires the
// row's cells to be ONE contiguous run at constant stride -- only
// true for a single-page arena. An incrementally-built (un-
// compacted) ToT tile may span multiple pages, where the stride
// jumps at a page boundary; verify constant stride across ALL
// cells (so multi-page tiles fall back to the AXPY loop).
integer ldb = static_cast<integer>(A);
integer ldc = static_cast<integer>(A);
if (clean && A > 0) {
if (K > 1)
ldb = static_cast<integer>(lc0[1].data() - lc0[0].data());
if (N > 1)
ldc = static_cast<integer>(rc0[1].data() - rc0[0].data());
if (ldb < A || ldc < A) clean = false; // sanity
const std::ptrdiff_t sb = ldb, sc = ldc;
for (integer k = 0; clean && k != K; ++k)
if (lc0[k].data() != lc0[0].data() + k * sb) clean = false;
for (integer n = 0; clean && n != N; ++n)
if (rc0[n].data() != rc0[0].data() + n * sc) clean = false;
}
if (A <= 0) continue; // empty row -> nothing to do
if (clean) {
// result[m,n][a] += sum_k left[m,k][a] * right[k,n].
// Row-major gemm: C2(N x A) += right^T(N x K) * L2(K x A),
// where L2 = left row-m slab (K x A, ld=ldb), C2 = result row-m
// slab (N x A, ld=ldc), right is K x N (ld=N). ldb/ldc carry
// padding.
const integer Ai = static_cast<integer>(A);
TiledArray::math::blas::gemm(
TiledArray::math::blas::Transpose,
TiledArray::math::blas::NoTranspose,
/*M=*/N, /*N=*/Ai, /*K=*/K, Real(1),
/*A=*/right_data, /*lda=*/N,
/*B=*/lc0[0].data(), /*ldb=*/ldb, Real(1),
/*C=*/rc0[0].data(), /*ldc=*/ldc);
} else { // per-cell AXPY fallback for this row
for (integer n = 0; n != N; ++n) {
auto c_offset = m * N + n;
for (integer k = 0; k != K; ++k)
elem_muladd_op(*(this_data + c_offset),
*(left_data + (m * K + k)),
*(right_data + (k * N + n)));
}
}
}
}
return *this;
}
}
}

// GEMM-based scale path, mirror for T * ToT ("m,k" * "k,n;a" -> "m,n;a",
// left plain scalar, right ToT). Per column n: one GEMM
// result_n(M x A_n) += left(M x K) * right_n(K x A_n). The right/result
// column-n cells are strided over the slab (constant k-/m-stride within a
// single arena page); verify that, else fall back to per-cell AXPY.
if constexpr (detail::is_numeric_v<U> && is_tensor_view_v<V> &&
is_tensor_view_v<value_type>) {
using Real = std::remove_cv_t<typename value_type::value_type>;
if constexpr (std::is_same_v<std::remove_cv_t<U>, Real>) {
if (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose &&
gemm_helper.right_op() == TiledArray::math::blas::NoTranspose) {
for (integer b = 0; b != nbatch(); ++b) {
auto this_data = this->batch_data(b);
auto left_data = left.batch_data(b); // M x K row-major scalars
auto right_data = right.batch_data(b); // K x N ToT
for (integer n = 0; n != N; ++n) {
long A = -1;
bool clean = true;
for (integer k = 0; k != K && clean; ++k) {
const auto& c = right_data[k * N + n];
if (c.empty()) {
clean = false;
break;
}
long s = static_cast<long>(c.size());
if (A < 0)
A = s;
else if (A != s)
clean = false;
}
for (integer m = 0; m != M && clean; ++m) {
const auto& c = this_data[m * N + n];
if (c.empty()) {
clean = false;
break;
}
long s = static_cast<long>(c.size());
if (A < 0)
A = s;
else if (A != s)
clean = false;
}
integer ldb = static_cast<integer>(A); // k-stride, right col n
integer ldc = static_cast<integer>(A); // m-stride, result col n
if (clean && A > 0) {
if (K > 1)
ldb = static_cast<integer>(right_data[N + n].data() -
right_data[n].data());
if (M > 1)
ldc = static_cast<integer>(this_data[N + n].data() -
this_data[n].data());
if (ldb < A || ldc < A) clean = false;
const std::ptrdiff_t sb = ldb, sc = ldc;
for (integer k = 0; clean && k != K; ++k)
if (right_data[k * N + n].data() !=
right_data[n].data() + k * sb)
clean = false;
for (integer m = 0; clean && m != M; ++m)
if (this_data[m * N + n].data() !=
this_data[n].data() + m * sc)
clean = false;
}
if (A <= 0) continue;
if (clean) {
// C_n(M x A) += left(M x K) * B_n(K x A). Row-major gemm.
const integer Ai = static_cast<integer>(A);
TiledArray::math::blas::gemm(
TiledArray::math::blas::NoTranspose,
TiledArray::math::blas::NoTranspose,
/*M=*/M, /*N=*/Ai, /*K=*/K, Real(1),
/*A=*/left_data, /*lda=*/K,
/*B=*/right_data[n].data(), /*ldb=*/ldb, Real(1),
/*C=*/this_data[n].data(), /*ldc=*/ldc);
} else { // per-cell AXPY fallback for this column
for (integer m = 0; m != M; ++m) {
auto c_offset = m * N + n;
for (integer k = 0; k != K; ++k)
elem_muladd_op(*(this_data + c_offset),
*(left_data + (m * K + k)),
*(right_data + (k * N + n)));
}
}
}
}
return *this;
}
}
}

for (integer b = 0; b != nbatch(); ++b) {
auto this_data = this->batch_data(b);
auto left_data = left.batch_data(b);
Expand Down
Loading